<a href="https://colab.research.google.com/github/FarrelAD/Hology-8-2025-Data-Mining-PRIVATE/blob/dev%2Fvidi/notebooks/vidi/SFCN/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Crowd Counting with the SFCN Method

SFCN (Spatial Fully Convolutional Network) is a new method for us, and so far it has produced better results than the other models we’ve tried.

# 1. Project Setup

## Import Libraries

In [None]:
import csv
import re
import json
import math
import random
from glob import glob
from typing import List, Tuple
import numpy as np
import cv2
from scipy.spatial import KDTree
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

## Dataset Download and Setup

In [None]:
# @title Setup Kaggle secret key
!pip install -q kaggle

from google.colab import files

uploaded = files.upload()

for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))

# Then move kaggle.json into the folder where the API expects to find it.
!mkdir -p ~/.kaggle/ && mv kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json

In [None]:
# @title Setup dataset in Colab
import zipfile
import os
from google.colab import drive

drive.mount('/content/drive')

# Paths
zip_path = "/content/penyisihan-hology-8-0-2025-data-mining.zip"
drive_extract_path = "/content/drive/MyDrive/PROJECTS/Cognivio/Percobaan Hology 8 2025/dataset"
local_dataset_path = "/content/dataset"  # for current session

# ---------------------------
# Step 1: Download zip (if not exists in /content)
# ---------------------------
if not os.path.exists(zip_path):
    print("Dataset not found locally, downloading...")
    !kaggle competitions download -c penyisihan-hology-8-0-2025-data-mining -p /content
else:
    print("Dataset already exists, skipping download.")

# ---------------------------
# Step 2: Extract to Google Drive (for backup)
# ---------------------------
os.makedirs(drive_extract_path, exist_ok=True)

if not os.listdir(drive_extract_path):  # Check if folder is empty
    print("Extracting dataset to Google Drive...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(drive_extract_path)
    print("Dataset extracted to:", drive_extract_path)
else:
    print("Dataset already extracted at:", drive_extract_path)

# ---------------------------
# Step 3: Copy dataset to local /content (faster training)
# ---------------------------
if not os.path.exists(local_dataset_path):
    print("Copying dataset to Colab local storage (/content)...")
    !cp -r "$drive_extract_path" "$local_dataset_path"
else:
    print("Dataset already available in Colab local storage.")

## Some Configuration

In [None]:
# ImageNet mean and std for normalisation (fed to transforms.Normalize in
# CrowdDataset)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Dataset paths, all rooted at the local copy created by the dataset setup cell
TRAIN_IMG_DIR = os.path.join(local_dataset_path, "train", "images")
TRAIN_LABEL_DIR = os.path.join(local_dataset_path, "train", "labels")
TEST_IMG_DIR  = os.path.join(local_dataset_path, "test", "images")


# Seed Python, NumPy and PyTorch (CPU and every CUDA device) for better
# reproducibility
SEED = 1337 # or 31
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)



# Training schedule
EPOCHS = 120             # upper bound of the training loop and T_max of the cosine LR schedule
ACCUM_STEPS = 2          # passed to train_epoch as accum_steps (micro-batches per optimizer step)
EARLY_STOP_PATIENCE = 15 # epochs without a better validation MAE before training stops
CUDA_AMP = True         # CUDA Automatic Mixed Precision (AMP); a GradScaler is only created on a CUDA device

# Image / density-map geometry
BASE_SIZE = 768          # side of the square letterbox canvas every image is resized onto
DOWN = 8                 # density maps are BASE_SIZE // DOWN (or PATCH_SIZE // DOWN) pixels per side
PATCH_SIZE = 384         # side of the random training crop; 0 disables cropping in CrowdDataset
PATCHES_PER_IMAGE = 4    # multiplies the training dataset length when patch cropping is on
AVOID_EMPTY_PATCHES = True  # retry cropping (up to 10 times) until the patch contains a point
SIGMA_MODE = "adaptive"  # forwarded to make_density_map; any other value selects the constant sigma

# DataLoader settings shared by the train and validation loaders
BATCH = 4
NUM_WORKERS = 4

## Some Helper Functions

In [None]:
def derive_json_path(
    lbl_dir: str, 
    img_path: str
) -> str:
    """Return the path of the JSON label that belongs to ``img_path``.

    Candidates are tried in this order and the first existing one wins:

    1. ``<lbl_dir>/<image stem>.json``
    2. ``<lbl_dir>/<last run of digits in the stem>.json``
    3. the alphabetically first ``<lbl_dir>/<image stem>*.json``

    Args:
        lbl_dir: Directory that holds the JSON label files.
        img_path: Path of the image; only its basename is used.

    Returns:
        Path of the matching label file.

    Raises:
        FileNotFoundError: If none of the candidates exists.
    """
    # Local import: the file only imports the `glob` function itself.
    from glob import escape as glob_escape

    name = os.path.splitext(os.path.basename(img_path))[0]
    cand = os.path.join(lbl_dir, name + ".json")
    if os.path.exists(cand):
        return cand
    # Try the last run of digits in the stem (e.g. "img_0042" -> "0042.json")
    digit_runs = re.findall(r"\d+", name)
    if digit_runs:
        alt = os.path.join(lbl_dir, f"{digit_runs[-1]}.json")
        if os.path.exists(alt):
            return alt
    # Fallback: any file starting with the same name. The literal parts are
    # escaped so characters such as "[" are not read as glob syntax, and the
    # matches are sorted because glob() returns them in arbitrary order.
    pattern = os.path.join(glob_escape(lbl_dir), glob_escape(name) + "*.json")
    matches = sorted(glob(pattern))
    if matches:
        return matches[0]
    raise FileNotFoundError(f"JSON label not found for {img_path}")

def parse_points_from_json(
    path: str
) -> Tuple[np.ndarray, int]:
    """Parse annotated head points from a JSON label file.

    Accepted layouts:

    * ``{"points": [...]}`` or ``{"annotations": [...]}``, optionally with
      a ``human_num`` / ``num_human`` count field;
    * a bare top-level list of points.

    In every layout a point may be an ``[x, y, ...]`` sequence (only the
    first two values are kept) or a dict with ``"x"`` and ``"y"`` keys
    (dicts missing either key are skipped).

    Args:
        path: Path of the JSON file.

    Returns:
        ``(points, num)``. ``points`` is a float32 array of shape (N, 2)
        holding ``[x, y]`` rows. ``num`` is the count declared in the file,
        or ``None`` when the file declares none.

    Raises:
        ValueError: If the JSON matches none of the layouts, or the points
            cannot be arranged as rows of at least two values.
    """
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)

    num = None
    if isinstance(obj, dict) and ("points" in obj or "annotations" in obj):
        key = "points" if "points" in obj else "annotations"
        # `or []` also covers an explicit JSON null
        raw = obj[key] or []
        # unify possible keys for ground-truth count
        num = obj.get("human_num", obj.get("num_human", None))
    elif isinstance(obj, list):
        raw = obj
    else:
        raise ValueError(f"Unknown JSON schema: {path}")

    # Dict-style points are flattened the same way for every layout
    if len(raw) > 0 and isinstance(raw[0], dict):
        pts: List[List[float]] = [[p["x"], p["y"]] for p in raw if "x" in p and "y" in p]
    else:
        pts = raw

    if len(pts) == 0:
        return np.zeros((0, 2), np.float32), num

    pts_arr = np.asarray(pts, dtype=np.float32)
    if pts_arr.ndim != 2 or pts_arr.shape[1] < 2:
        raise ValueError(f"Points in {path} are not rows of [x, y]: shape {pts_arr.shape}")
    # Drop extra per-point columns so callers can rely on an (N, 2) array
    return np.ascontiguousarray(pts_arr[:, :2]), num

def letterbox(
    img: np.ndarray, 
    target: int = 512
) -> Tuple[np.ndarray, float, int, int]:
    """Resize an image to fit a square canvas and centre it with zero padding.

    The aspect ratio is preserved: the image is scaled so its longer side
    equals ``target`` and the remaining border is filled with zeros.

    Args:
        img: Image array whose first two axes are (height, width).
        target: Side length of the square output canvas.

    Returns:
        ``(canvas, scale, left, top)``: the padded image of size
        (target, target), the scale factor applied to the input, and the
        left/top padding in pixels. A source coordinate ``(x, y)`` maps to
        ``(x * scale + left, y * scale + top)`` on the canvas.

    Raises:
        ValueError: If the image has zero height or width.
    """
    h, w = img.shape[:2]
    if h == 0 or w == 0:
        raise ValueError("letterbox() received an empty image")
    scale = min(target / h, target / w)
    # Keep both sides >= 1 px: for very elongated images the short side can
    # round down to 0, which cv2.resize rejects.
    nh = min(target, max(1, int(round(h * scale))))
    nw = min(target, max(1, int(round(w * scale))))
    img_rs = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
    top = (target - nh) // 2
    left = (target - nw) // 2
    # Mirror the channel layout of the resized image (3-channel for the
    # colour images used in this notebook) instead of hard-coding it.
    canvas = np.zeros((target, target) + img_rs.shape[2:], dtype=img_rs.dtype)
    canvas[top:top + nh, left:left + nw] = img_rs
    return canvas, scale, left, top


def make_density_map(
    points_xy: np.ndarray,
    grid_size: int,
    down: int = 8,
    sigma_mode: str = "adaptive",
    knn: int = 3,
    beta: float = 0.3,
    const_sigma: float = 2.0,
) -> np.ndarray:
    """Render annotated points as a Gaussian density map.

    Each point contributes one Gaussian that is normalised over the part of
    its window lying inside the map, so every rendered point adds exactly 1
    to the map's sum, including points near the border.

    Args:
        points_xy: (N, 2) array of ``[x, y]`` coordinates in the pixel space
            of a ``grid_size`` x ``grid_size`` image. The input is not
            modified. Integer arrays are converted to float32 so the
            rescaling below does not truncate coordinates.
        grid_size: Side length of the (square) source image.
        down: Downsampling factor; the map is ``grid_size // down`` per side.
        sigma_mode: ``"adaptive"`` derives sigma from the nearest neighbours
            (only when there are more than 3 points); any other value, or
            too few points, uses ``const_sigma``.
        knn: Number of nearest neighbours averaged in adaptive mode.
        beta: Multiplier applied to the mean neighbour distance.
        const_sigma: Sigma (in map pixels) used outside adaptive mode.

    Returns:
        float32 array of shape ``(grid_size // down, grid_size // down)``.
    """
    target = grid_size
    dh = dw = target // down
    den = np.zeros((dh, dw), dtype=np.float32)
    if len(points_xy) == 0:
        return den

    # Work on a float copy: assigning scaled values into an integer array
    # would silently truncate them.
    pts = np.array(points_xy)
    if not np.issubdtype(pts.dtype, np.floating):
        pts = pts.astype(np.float32)
    # Scale points to the density map resolution
    pts[:, 0] = pts[:, 0] * (dw / target)
    pts[:, 1] = pts[:, 1] * (dh / target)

    n_pts = len(pts)
    if sigma_mode == "adaptive" and n_pts > 3 and knn >= 1:
        # One batched query for all points. Column 0 is each point's distance
        # to itself (0), so it is excluded from the mean.
        dists, _ = KDTree(pts).query(pts, k=min(knn + 1, n_pts))
        sigmas = np.maximum(1.0, dists[:, 1:].mean(axis=1) * beta)
    else:
        sigmas = np.full(n_pts, const_sigma, dtype=np.float64)

    for (x, y), sigma in zip(pts, sigmas):
        cx, cy, sigma = float(x), float(y), float(sigma)
        # Render within +-3 sigma, clipped to the map
        rad = int(max(1, math.ceil(3 * sigma)))
        x0, x1 = max(0, int(math.floor(cx - rad))), min(dw, int(math.ceil(cx + rad + 1)))
        y0, y1 = max(0, int(math.floor(cy - rad))), min(dh, int(math.ceil(cy + rad + 1)))
        if x1 <= x0 or y1 <= y0:
            # Window lies completely outside the map; the point is dropped
            continue
        xs = np.arange(x0, x1) - cx
        ys = np.arange(y0, y1) - cy
        xx, yy = np.meshgrid(xs, ys)
        g = np.exp(-(xx**2 + yy**2) / (2 * sigma * sigma))
        s = g.sum()
        if s > 0:
            den[y0:y1, x0:x1] += (g / s).astype(np.float32)
    return den

# 2. Data Preprocessing

In [None]:
# Collect every file in the training image folder in a stable (sorted) order,
# then shuffle with the seeded RNG so the hold-out split is random.
all_imgs = sorted(glob(os.path.join(TRAIN_IMG_DIR, "*.*")))
random.shuffle(all_imgs)

n_images = len(all_imgs)
if n_images < 2:
    raise ValueError("Need at least 2 images for training and validation")

# Hold out 10% of the images (at least one) for validation
n_val = max(1, int(0.1 * n_images))
n_train = n_images - n_val
train_imgs, val_imgs = all_imgs[:n_train], all_imgs[n_train:]

# 3. Dataset Loading

In [None]:
class CrowdDataset(Dataset):
    """Crowd-counting dataset yielding (image, density map, count) triples.

    Every image is letterboxed onto a ``base_size`` x ``base_size`` canvas.
    In ``mode="train"`` with ``patch_size > 0`` a random square patch is
    cropped from that canvas. When ``avoid_empty_patches`` is set, up to 10
    crops are tried to find one containing at least one point, and the last
    attempt is used if none does.

    Images are converted from OpenCV's BGR channel order to RGB right after
    loading, because the ImageNet mean/std used for normalisation are given
    in RGB order and the test datasets in this notebook also convert to RGB.
    """

    def __init__(
        self,
        img_dir: str,
        label_dir: str,
        base_size: int = 768,
        down: int = 8,
        aug: bool = True,
        mode: str = "train",
        patch_size: int = 0,
        patches_per_image: int = 1,
        sigma_mode: str = "adaptive",
        avoid_empty_patches: bool = False,
    ) -> None:
        """Index the images in ``img_dir`` and store the sampling options.

        Args:
            img_dir: Folder scanned (``*.*``) for images.
            label_dir: Folder holding the JSON label files.
            base_size: Side of the square letterbox canvas.
            down: Density-map downsampling factor.
            aug: Enable random flip / colour jitter (training mode only).
            mode: ``"train"`` enables augmentation and patch cropping; any
                other value yields full letterboxed images.
            patch_size: Side of the random training crop; 0 disables it.
            patches_per_image: How many times each image is visited per
                epoch when patch cropping is active (values < 1 become 1).
            sigma_mode: Forwarded to ``make_density_map``.
            avoid_empty_patches: Retry cropping to avoid point-free patches.

        Raises:
            ValueError: If ``img_dir`` contains no matching file.
        """
        super().__init__()
        self.img_paths = sorted(glob(os.path.join(img_dir, "*.*")))
        if len(self.img_paths) == 0:
            raise ValueError(f"No images found in {img_dir}")
        self.label_dir = label_dir
        self.base_size = base_size
        self.down = down
        self.aug = aug
        self.mode = mode
        self.patch_size = patch_size
        self.patches_per_image = max(1, int(patches_per_image))
        self.sigma_mode = sigma_mode
        self.avoid_empty_patches = avoid_empty_patches

        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        self.color_jit = transforms.ColorJitter(0.1, 0.1, 0.1, 0.05)

        # Compute effective length for patch training. Code that replaces
        # `img_paths` afterwards must update `effective_len` as well.
        if self.mode == "train" and self.patch_size > 0 and self.patches_per_image > 1:
            self.effective_len = len(self.img_paths) * self.patches_per_image
        else:
            self.effective_len = len(self.img_paths)

    def __len__(self) -> int:
        return self.effective_len

    def _load_img_pts(
        self, 
        idx_base: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Load one image (RGB) and its points, both in letterbox coordinates.

        Returns:
            ``(canvas, points)``: the letterboxed uint8 RGB image and a
            float32 (N, 2) array of the points that fall inside the canvas.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        pimg = self.img_paths[idx_base]
        img = cv2.imread(pimg, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {pimg}")
        # OpenCV decodes to BGR; everything downstream expects RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        plbl = derive_json_path(self.label_dir, pimg)
        pts, _ = parse_points_from_json(plbl)
        is_train_aug = self.mode == "train" and self.aug
        # Augment: horizontal flip (image and x coordinates together)
        if is_train_aug and random.random() < 0.5:
            img = img[:, ::-1, :].copy()
            if len(pts) > 0:
                pts = pts.copy()
                pts[:, 0] = (w - 1) - pts[:, 0]
        # Augment: colour jitter (photometric only, points are unaffected)
        if is_train_aug and random.random() < 0.5:
            pil = transforms.ToPILImage()(img)
            pil = self.color_jit(pil)
            img = np.array(pil)
        # Letterbox to base size and move the points with the image
        canvas, scale, left, top = letterbox(img, target=self.base_size)
        if len(pts) > 0:
            pts_tr = pts.copy()
            pts_tr[:, 0] = pts_tr[:, 0] * scale + left
            pts_tr[:, 1] = pts_tr[:, 1] * scale + top
            # Drop points that ended up outside the canvas
            inside = (
                (pts_tr[:, 0] >= 0)
                & (pts_tr[:, 0] < self.base_size)
                & (pts_tr[:, 1] >= 0)
                & (pts_tr[:, 1] < self.base_size)
            )
            pts_tr = pts_tr[inside]
        else:
            pts_tr = np.zeros((0, 2), np.float32)
        return canvas, pts_tr

    def __getitem__(
        self, 
        index: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(image, density, count)`` for one sample.

        ``image`` is the normalised tensor produced by ``ToTensor`` and
        ``Normalize``, ``density`` is the density map with a leading channel
        axis, and ``count`` is a one-element float32 tensor holding the
        number of points inside the returned image or patch.
        """
        # Map global index to an image index (for patch training)
        if self.mode == "train" and self.patch_size > 0 and self.patches_per_image > 1:
            idx_base = index // self.patches_per_image
        else:
            idx_base = index
        idx_base %= len(self.img_paths)

        img_lb, pts_tr = self._load_img_pts(idx_base)

        # Optional patch cropping
        if self.mode == "train" and self.patch_size > 0:
            ps = self.patch_size
            if ps > self.base_size:
                raise ValueError("patch_size must be <= base_size")
            max_off = self.base_size - ps
            # One attempt normally; up to 10 when empty patches are avoided.
            # Whatever the last attempt produced is used if none has points.
            max_tries = 10 if self.avoid_empty_patches else 1
            for _ in range(max_tries):
                ox = 0 if max_off <= 0 else random.randint(0, max_off)
                oy = 0 if max_off <= 0 else random.randint(0, max_off)
                crop = img_lb[oy : oy + ps, ox : ox + ps, :]
                if len(pts_tr) > 0:
                    pts_c = pts_tr.copy()
                    pts_c[:, 0] -= ox
                    pts_c[:, 1] -= oy
                    inside = (
                        (pts_c[:, 0] >= 0)
                        & (pts_c[:, 0] < ps)
                        & (pts_c[:, 1] >= 0)
                        & (pts_c[:, 1] < ps)
                    )
                    pts_c = pts_c[inside]
                else:
                    pts_c = np.zeros((0, 2), np.float32)
                if len(pts_c) > 0:
                    break
            img_out, pts_out = crop, pts_c
            grid = ps
        else:
            img_out = img_lb
            pts_out = pts_tr
            grid = self.base_size

        # Build density map
        den = make_density_map(
            pts_out,
            grid_size=grid,
            down=self.down,
            sigma_mode=self.sigma_mode,
        )

        # Sanity check (always outside training, on ~2% of training samples):
        # the density map should integrate to the number of points
        if self.mode != "train" or random.random() < 0.02:
            cnt = float(len(pts_out))
            s = float(den.sum())
            if abs(s - cnt) > 0.05:
                print(f"[warn] density sum {s:.3f} != count {cnt:.3f}")

        # Convert to tensors and normalise
        t = self.to_tensor(img_out)
        t = self.normalize(t)
        d = torch.from_numpy(den).unsqueeze(0)
        c = torch.tensor([float(len(pts_out))], dtype=torch.float32)
        return t, d, c

In [None]:
# Training dataset: augmented, random patches cropped from letterboxed images
train_ds = CrowdDataset(
    TRAIN_IMG_DIR,
    TRAIN_LABEL_DIR,
    mode="train",
    aug=True,
    base_size=BASE_SIZE,
    down=DOWN,
    sigma_mode=SIGMA_MODE,
    patch_size=PATCH_SIZE,
    patches_per_image=PATCHES_PER_IMAGE,
    avoid_empty_patches=AVOID_EMPTY_PATCHES,
)
# Validation dataset: full letterboxed images, no augmentation, no cropping
val_ds = CrowdDataset(
    TRAIN_IMG_DIR,
    TRAIN_LABEL_DIR,
    mode="val",
    aug=False,
    base_size=BASE_SIZE,
    down=DOWN,
    sigma_mode=SIGMA_MODE,
    patch_size=0,
    patches_per_image=1,
    avoid_empty_patches=False,
)


# Both datasets scanned the whole training folder; restrict each one to its
# own part of the shuffled split.
train_ds.img_paths = train_imgs
val_ds.img_paths = val_imgs

# The cached lengths still describe the full folder, so recompute them:
# every training image counts `patches_per_image` times when patches are used.
train_ds.effective_len = len(train_ds.img_paths) * (
    train_ds.patches_per_image
    if train_ds.mode == "train" and train_ds.patch_size > 0 and train_ds.patches_per_image > 1
    else 1
)
val_ds.effective_len = len(val_ds.img_paths)


# Data loaders: training is shuffled and drops the last incomplete batch,
# validation keeps a fixed order.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH,
    num_workers=NUM_WORKERS,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH,
    num_workers=NUM_WORKERS,
    shuffle=False,
    pin_memory=True,
)

# 4. Model Architecture

In [None]:
class SpatialEncoder(nn.Module):
    """Mix features along rows and columns with long depthwise kernels.

    Two horizontal (1 x k) and two vertical (k x 1) depthwise convolutions
    run on the same input. Their outputs are concatenated along the channel
    axis, projected back to ``channels`` by a 1x1 convolution and passed
    through ReLU. Spatial size and channel count are unchanged.
    """

    def __init__(self, channels: int, k: int = 9) -> None:
        super().__init__()
        half = k // 2

        def depthwise(kernel, padding):
            # groups=channels: one bias-free filter per channel
            return nn.Conv2d(
                channels, channels, kernel, padding=padding, groups=channels, bias=False
            )

        # Creation order is kept (h1, h2, v1, v2, proj) so seeded
        # initialisation and state-dict keys stay the same.
        self.h1 = depthwise((1, k), (0, half))
        self.h2 = depthwise((1, k), (0, half))
        self.v1 = depthwise((k, 1), (half, 0))
        self.v2 = depthwise((k, 1), (half, 0))
        self.proj = nn.Conv2d(4 * channels, channels, 1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branches = [conv(x) for conv in (self.h1, self.h2, self.v1, self.v2)]
        fused = self.proj(torch.cat(branches, dim=1))
        return self.act(fused)


class SFCN_VGG(nn.Module):
    """Simplified SFCN: VGG-16 (BN) front end, spatial encoder, density head.

    The network outputs a single-channel, non-negative (softplus) density
    map at 1/8 of the input resolution.
    """

    def __init__(self, pretrained: bool = True) -> None:
        super().__init__()
        weights = models.VGG16_BN_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.vgg16_bn(weights=weights)
        # First 33 feature layers = everything up to conv4_3 (+BN+ReLU);
        # that span contains three max-pools, hence stride 8.
        self.frontend = backbone.features[:33]
        self.senc = SpatialEncoder(512, k=9)

        # Regression head: three 3x3 conv + ReLU stages narrowing
        # 512 -> 256 -> 128 -> 64, then a 1x1 conv down to one channel.
        layers = []
        in_ch = 512
        for out_ch in (256, 128, 64):
            layers += [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True)]
            in_ch = out_ch
        layers.append(nn.Conv2d(in_ch, 1, 1))
        self.head = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.senc(self.frontend(x))
        # Softplus keeps the predicted density strictly positive
        return F.softplus(self.head(features))
    

# Model, optimizer, scaler, and scheduler
# NOTE(review): `lr` was referenced here but never defined anywhere in the
# notebook, so this cell raised NameError. 1e-4 is a placeholder chosen for
# this fix — confirm the learning rate that was actually intended.
lr = 1e-4
model = SFCN_VGG(pretrained=True).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
# Mixed precision is only enabled on a CUDA device; `None` selects the plain
# fp32 path in train_epoch.
scaler = torch.cuda.amp.GradScaler() if (CUDA_AMP and device.type == "cuda") else None
# Cosine decay over the whole run; stepped once per epoch in the training loop
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

# 5. Model Training

In [None]:
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    optimizer: torch.optim.Optimizer,
    scaler: torch.cuda.amp.GradScaler = None,
    accum_steps: int = 1,
    criterion: str = "mse",
    count_loss_alpha: float = 0.0,
) -> float:
    """Train for one epoch and return the mean absolute count error.

    Args:
        model: Network mapping an image batch to a density-map batch.
        loader: Yields ``(images, density_maps, counts)``; the third item
            is ignored, counts are taken as the sum of each density map.
        device: Device the batches are moved to.
        optimizer: Optimizer updating ``model``.
        scaler: When given, the forward pass runs under CUDA fp16 autocast
            and the loss is scaled through this GradScaler.
        accum_steps: Number of micro-batches accumulated per optimizer step
            (values < 1 are treated as 1).
        criterion: ``"mse"`` or ``"huber"`` loss on the density maps.
        count_loss_alpha: Weight of the auxiliary MSE between predicted and
            ground-truth counts; 0 disables it.

    Returns:
        Mean over all samples of ``|sum(pred) - sum(density)|``.

    Raises:
        ValueError: If ``criterion`` is not ``"mse"`` or ``"huber"``.
    """
    model.train()
    if criterion == "mse":
        crit = nn.MSELoss()
    elif criterion == "huber":
        crit = nn.SmoothL1Loss()
    else:
        raise ValueError("criterion must be 'mse' or 'huber'")
    accum_steps = max(1, int(accum_steps))
    use_amp = scaler is not None

    def batch_loss(preds: torch.Tensor, dens: torch.Tensor) -> torch.Tensor:
        # Map loss plus optional count loss, pre-divided so the accumulated
        # gradients average over the micro-batches.
        loss = crit(preds, dens)
        if count_loss_alpha > 0.0:
            pred_cnt = preds.sum(dim=(1, 2, 3))
            gt_cnt = dens.sum(dim=(1, 2, 3))
            loss = loss + count_loss_alpha * F.mse_loss(pred_cnt, gt_cnt)
        return loss / accum_steps

    def apply_step() -> None:
        # One optimizer update, then clear the accumulated gradients
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    running_mae, nimg = 0.0, 0
    total_pred_count, total_gt_count = 0.0, 0.0
    pending = False  # True while gradients are accumulated but not applied
    optimizer.zero_grad(set_to_none=True)
    for step, (imgs, dens, _) in enumerate(tqdm(loader, desc="Train", leave=False), 1):
        imgs = imgs.to(device)
        dens = dens.to(device)
        if use_amp:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                preds = model(imgs)
                total_loss = batch_loss(preds, dens)
            scaler.scale(total_loss).backward()
        else:
            preds = model(imgs)
            total_loss = batch_loss(preds, dens)
            total_loss.backward()
        pending = True
        if step % accum_steps == 0:
            apply_step()
            pending = False
        with torch.no_grad():
            # Sum in fp32: under AMP `preds` is fp16, which loses precision
            # (and can overflow) when a whole map is summed.
            pc = preds.detach().float().sum(dim=(1, 2, 3)).cpu().numpy()
            gc = dens.detach().float().sum(dim=(1, 2, 3)).cpu().numpy()
            running_mae += np.abs(pc - gc).sum()
            total_pred_count += pc.sum()
            total_gt_count += gc.sum()
            nimg += imgs.size(0)
            # Warn if the difference in count is large for any sample
            if step % 10 == 0:
                for i, (p, g) in enumerate(zip(pc, gc)):
                    if abs(p - g) > 1.0:
                        print(
                            f"[warn] batch {step} sample {i}: density sum {g:.1f} != pred {p:.1f}"
                        )
    if pending:
        # The number of batches was not a multiple of accum_steps: apply the
        # leftover gradients instead of discarding them. They are still
        # divided by accum_steps, so this last update is slightly smaller.
        apply_step()
    avg_pred = total_pred_count / max(1, nimg)
    avg_gt = total_gt_count / max(1, nimg)
    print(f"Avg pred count: {avg_pred:.1f} vs GT {avg_gt:.1f}")
    return running_mae / max(1, nimg)

In [None]:
@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """Compute count MAE and RMSE of ``model`` over ``loader``.

    Predicted and ground-truth counts are the sums of the predicted and
    reference density maps. Returns ``(mae, rmse)`` averaged over all
    samples seen.
    """
    model.eval()
    abs_err_sum, sq_err_sum, seen = 0.0, 0.0, 0
    for imgs, dens, _ in tqdm(loader, desc="Val", leave=False):
        imgs = imgs.to(device)
        dens = dens.to(device)
        pred_counts = model(imgs).sum(dim=(1, 2, 3))
        gt_counts = dens.sum(dim=(1, 2, 3))
        diff = (pred_counts - gt_counts).detach().cpu().numpy()
        abs_err_sum += np.abs(diff).sum()
        sq_err_sum += (diff ** 2).sum()
        seen += imgs.size(0)

    # Guard against an empty loader
    denom = max(1, seen)
    return abs_err_sum / denom, math.sqrt(sq_err_sum / denom)


# NOTE(review): this cell originally contained `criterion=.criterion`,
# `count_loss_alpha=.count_loss_alpha`, an undefined `save` and `"": vars()`.
# It looks like an `args.` prefix was stripped, leaving a syntax error. The
# three settings below restore a runnable cell. The loss settings are the
# train_epoch defaults and the checkpoint file name is a placeholder —
# confirm the values that were actually intended.
criterion = "mse"
count_loss_alpha = 0.0
save = "sfcn_best.pth"

best_mae = float("inf")
patience = EARLY_STOP_PATIENCE
bad_epochs = 0
for epoch in range(1, EPOCHS + 1):
    print(f"\nEpoch {epoch}/{EPOCHS}")
    tr_mae = train_epoch(
        model,
        train_loader,
        device,
        opt,
        scaler,
        accum_steps=ACCUM_STEPS,
        criterion=criterion,
        count_loss_alpha=count_loss_alpha,
    )
    va_mae, va_rmse = evaluate(model, val_loader, device)
    sched.step()
    print(
        f"Train MAE: {tr_mae:.3f} | Val MAE: {va_mae:.3f} | Val RMSE: {va_rmse:.3f}"
    )
    # The small margin avoids counting float noise as an improvement
    if va_mae + 1e-6 < best_mae:
        best_mae = va_mae
        bad_epochs = 0
        torch.save(
            {
                "model": model.state_dict(),
                "epoch": epoch,
                # plain float so the checkpoint holds no NumPy scalar
                "val_mae": float(va_mae),
                # Only the hyper-parameters are stored, not the whole
                # namespace that vars() would have captured.
                "args": {
                    "criterion": criterion,
                    "count_loss_alpha": count_loss_alpha,
                    "accum_steps": ACCUM_STEPS,
                    "epochs": EPOCHS,
                    "base_size": BASE_SIZE,
                    "down": DOWN,
                    "patch_size": PATCH_SIZE,
                    "sigma_mode": SIGMA_MODE,
                    "seed": SEED,
                },
            },
            save,
        )
        print(f"✅ New best: {save} (Val MAE={va_mae:.3f})")
    else:
        bad_epochs += 1
        print(f"No improvement for {bad_epochs} epoch(s).")
        if bad_epochs >= patience:
            print(
                f"Early stopping at epoch {epoch}. Best Val MAE: {best_mae:.3f}"
            )
            break
print("Training finished. Best Val MAE:", best_mae)

# 6. Evaluation

# 7. Test Set Prediction and Submit Result

In [None]:
class TestDataset1(Dataset):
    """Dataset for test images, ordered like the rows of a sample CSV.

    Each item is ``(image, image_id)``. ``image`` is ``None`` when
    ``<image_dir>/<image_id>.jpg`` does not exist or cannot be decoded.
    Because of those ``None`` items the default DataLoader collate function
    cannot batch this dataset if any image is missing.
    """
    def __init__(
        self, 
        image_dir: str, 
        sample_csv: str, 
        transform=None
    ):
        """Read the image ids from ``sample_csv``.

        Args:
            image_dir: Folder containing ``<image_id>.jpg`` files.
            sample_csv: CSV whose header is exactly
                ``image_id,predicted_count``; the first column of every
                non-empty row is used as an image id.
            transform: Optional callable invoked as
                ``transform(image=rgb_array)`` that returns a mapping with
                an ``"image"`` entry.

        Raises:
            ValueError: If the header is missing or differs from the
                expected one.
        """
        self.image_dir = image_dir
        self.transform = transform

        # Load image_ids from sample CSV. "utf-8-sig" also accepts files
        # saved with a byte-order mark, which would otherwise end up in the
        # first header name and fail the check below.
        with open(sample_csv, "r", newline="", encoding="utf-8-sig") as f:
            reader = csv.reader(f)
            header = next(reader, None)  # None for an empty file
            if header != ["image_id", "predicted_count"]:
                raise ValueError(
                    f"Sample CSV header expected ['image_id','predicted_count'], got {header}"
                )
            # csv.reader yields [] for blank lines; skip them
            self.image_ids = [row[0] for row in reader if row]

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, img_id + ".jpg")

        if not os.path.exists(img_path):
            # If image not found, return None signal
            return None, img_id

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None for unreadable files; report it the
            # same way as a missing file
            return None, img_id
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented["image"]

        return image, img_id

class TestDataset(Dataset):
    """Dataset over the ``.jpg`` files of a folder, in numeric file-name order.

    File names must have an integer stem (``"12.jpg"``). Each item is
    ``(image, file_name)`` where ``image`` is the RGB array, or whatever
    ``transform`` returns for it.
    """
    def __init__(
        self, 
        image_dir: str, 
        transform=None
    ):
        """List and sort the test images.

        Args:
            image_dir: Folder scanned for ``.jpg`` files (extension matched
                case-insensitively so ``.JPG`` files are not dropped).
            transform: Optional callable invoked as
                ``transform(image=rgb_array)`` that returns a mapping with
                an ``"image"`` entry.

        Raises:
            ValueError: If a ``.jpg`` file name has a non-integer stem.
        """
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = sorted(
            (f for f in os.listdir(image_dir) if f.lower().endswith('.jpg')),
            key=lambda x: int(os.path.splitext(x)[0]),
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None instead of raising; fail with the path
            # rather than with an opaque cvtColor error
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, img_name

# --- Prediction Function ---
def _sfcn_default_test_transform(image: np.ndarray) -> dict:
    """Default test preprocessing: letterbox, tensor conversion, ImageNet normalisation.

    Mirrors what CrowdDataset does for full (non-patch) images. It is called
    as ``transform(image=...)`` and returns ``{"image": tensor}``, which is
    the convention TestDataset expects.
    """
    canvas, _, _, _ = letterbox(image, target=BASE_SIZE)
    tensor = transforms.ToTensor()(canvas)
    tensor = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)(tensor)
    return {"image": tensor}


def generate_sfcn_predictions(
    model,
    test_dir: str,
    output_file: str = "sfcn_submission.csv",
    batch_size: int = 16,
    transform=None,
) -> pd.DataFrame:
    """Predict a crowd count for every test image and write a submission CSV.

    Args:
        model: Trained network. It may return a density map (the count is
            the sum of the map) or a ``(density, count)`` pair.
        test_dir: Folder with the test ``.jpg`` images.
        output_file: Path of the CSV to write.
        batch_size: Number of images per forward pass.
        transform: Callable used by TestDataset. When ``None``, images are
            letterboxed to ``BASE_SIZE`` and normalised like the validation
            images, so that they can be batched and fed to the model.

    Returns:
        DataFrame with columns ``image_id`` (the file name) and
        ``predicted_count`` (non-negative int), sorted by numeric file name.
    """

    print("Generating test predictions with SFCN model...")

    if transform is None:
        # Raw cv2 arrays are uint8 HWC and differ in size, so they can be
        # neither batched nor passed to the network.
        transform = _sfcn_default_test_transform

    # Dataset + DataLoader
    test_dataset = TestDataset(test_dir, transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )

    model.eval()
    predictions = []
    image_names = []

    with torch.no_grad():
        for images, names in test_loader:
            images = images.to(device)
            output = model(images)
            if isinstance(output, (tuple, list)):
                # Models with an explicit count head: (density, count)
                batch_counts = output[1]
            else:
                # Density-only models such as SFCN_VGG: integrate the map
                batch_counts = output.sum(dim=(1, 2, 3))

            batch_counts = batch_counts.detach().float().reshape(-1).cpu().numpy()
            predictions.extend(max(0, int(round(float(c)))) for c in batch_counts)
            image_names.extend(names)

    # Create submission DataFrame
    submission_df = pd.DataFrame({
        'image_id': image_names,
        'predicted_count': predictions
    })

    # Sort by the numeric part of the file name
    submission_df['sort_key'] = submission_df['image_id'].apply(lambda x: int(os.path.splitext(x)[0]))
    submission_df = submission_df.sort_values('sort_key').drop('sort_key', axis=1).reset_index(drop=True)

    # Save submission
    submission_df.to_csv(output_file, index=False)

    print(f"Submission saved to {output_file}")
    print(f"Predictions for {len(submission_df)} test images")

    # Statistics (min/max are undefined for an empty result)
    pred_counts = submission_df['predicted_count'].values
    if len(pred_counts) > 0:
        print(f"\nTest Predictions Statistics:")
        print(f"Min: {pred_counts.min()}")
        print(f"Max: {pred_counts.max()}")
        print(f"Mean: {pred_counts.mean():.2f}")
        print(f"Median: {np.median(pred_counts):.2f}")

    return submission_df

# Predict counts for the test images, then preview the first 50 rows
submission_df = generate_sfcn_predictions(model, test_dir=TEST_IMG_DIR, batch_size=16)
submission_df.head(50)