train_csrnet_colab.py
Colab instructions:
1) Upload this file to Colab or paste into a notebook cell.
2) Mount Google Drive if you want to read/write there:
   from google.colab import drive
   drive.mount('/content/drive')
3) Adjust `DATA_ROOT` and `CHECKPOINT_DIR`.

In [1]:
# --- Mount Google Drive ---
# The dataset and checkpoint paths configured further down live under
# /content/drive, so the mount has to happen before any of them is touched.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# --- Check if GPU is available ---
# DEVICE is the torch.device handed to the model, the loss and every batch in
# the training cells: CUDA when PyTorch reports one, otherwise CPU.
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)


Mounted at /content/drive
Device: cuda


In [None]:
# --- Imports ---
import os, glob, time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.models as models

# --- Config paths (EDIT HERE) ---
# Root folder that holds the pre-converted tensors (and, below, checkpoints).
DATA_ROOT = "/content/drive/MyDrive/Deepvision_Project"

# Glob patterns handed to PtDataset; image and ground-truth files are paired
# by filename stem, so both folders must use the same base names.
IMG_GLOB = os.path.join(DATA_ROOT, "torch_images_trainA", "*.pt")
GT_GLOB  = os.path.join(DATA_ROOT, "torch_density_trainA", "*.pt")

# Destination for the best / periodic / final state_dict files written by
# main(). NOTE: this is a separate literal, not derived from DATA_ROOT, so it
# must be edited together with DATA_ROOT. Created eagerly if missing.
CHECKPOINT_DIR = "/content/drive/MyDrive/Deepvision_Project/checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Hyperparameters ---
BATCH_SIZE = 8      # samples per batch for both the train and val loaders
NUM_EPOCHS = 100    # number of train + validate passes run by main()
LR = 1e-5           # learning rate passed to optim.Adam
NUM_WORKERS = 2 if torch.cuda.is_available() else 0  # DataLoader worker processes
SAVE_EVERY = 5      # write a periodic checkpoint every N epochs
PRINT_FREQ = 10     # print the running training loss every N batches


In [None]:
# --- Dataset class that matches filename stems ---
class PtDataset(Dataset):
    """Dataset of (image, ground-truth) tensor pairs stored as ``.pt`` files.

    Files from the two glob patterns are paired by filename stem; a file
    whose stem has no partner in the other set is ignored. Pairs are kept
    in sorted stem order.

    Args:
        img_glob: Glob pattern matching the image tensor files.
        gt_glob: Glob pattern matching the ground-truth tensor files.
        transform: Optional callable applied to the image tensor only.

    Raises:
        ValueError: If the two patterns share no filename stem.
    """

    def __init__(self, img_glob, gt_glob, transform=None):
        image_files = sorted(glob.glob(img_glob))
        target_files = sorted(glob.glob(gt_glob))

        # Index both file lists by stem so partners can be looked up directly.
        image_by_stem = {Path(f).stem: f for f in image_files}
        target_by_stem = {Path(f).stem: f for f in target_files}

        shared_stems = sorted(
            stem for stem in image_by_stem if stem in target_by_stem
        )

        if not shared_stems:
            raise ValueError(
                f"No matching .pt pairs found.\n"
                f"Images: {len(image_files)}, GT: {len(target_files)}.\n"
                f"Check filenames and DATA_ROOT path."
            )

        self.pairs = [
            (image_by_stem[stem], target_by_stem[stem]) for stem in shared_stems
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of matched (image, ground-truth) pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` from disk and return ``(image, ground_truth)``.

        Both tensors are cast to float; a 2-D ground truth gets a leading
        dimension added so it becomes 3-D.
        """
        image_file, target_file = self.pairs[idx]
        image = torch.load(image_file).float()
        target = torch.load(target_file).float()

        # Promote a bare 2-D map to 3-D by adding a leading axis.
        target = target.unsqueeze(0) if target.ndim == 2 else target

        if self.transform:
            image = self.transform(image)

        return image, target


In [None]:
# --- CSRNet model (VGG frontend + Dilated backend) ---
class CSRNet(nn.Module):
    """CSRNet-style density regressor: VGG16-BN frontend + dilated backend.

    Args:
        load_weights: When truthy, the VGG16-BN frontend is built with
            torchvision's ImageNet-pretrained weights; otherwise it is
            randomly initialised. The backend is always freshly initialised
            (normal with std=0.01 for weights, zeros for biases).

    ``forward`` maps an image batch to a single-channel map at the reduced
    spatial resolution produced by the frontend.
    """

    def __init__(self, load_weights=True):
        super().__init__()

        vgg = self._build_vgg(load_weights)
        features = list(vgg.features.children())

        # Frontend: leading layers of VGG16-BN.
        # NOTE(review): in torchvision's vgg16_bn layout index 33 appears to be
        # the 4th max-pool, so [:34] would give a 1/16-resolution output rather
        # than the 1/8 of the CSRNet paper ([:33]) — confirm before changing;
        # the slice is left as-is so the output shape stays what callers see.
        self.frontend = nn.Sequential(*features[:34])

        # Backend: four dilated 3x3 convs (dilation 2; padding 2 keeps the
        # spatial size), two plain 3x3 convs, then a 1x1 conv down to a
        # single output channel.
        self.backend = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, 1)
        )

        self._initialize_weights()

    @staticmethod
    def _build_vgg(load_weights):
        """Create torchvision's VGG16-BN, pretrained when ``load_weights`` is truthy.

        Newer torchvision releases deprecate ``pretrained=`` in favour of a
        ``weights=`` enum; use the enum when it exists and fall back to the
        legacy keyword on older installs.
        """
        weights_enum = getattr(models, "VGG16_BN_Weights", None)
        if weights_enum is None:
            return models.vgg16_bn(pretrained=load_weights)
        # IMAGENET1K_V1 is the checkpoint the legacy ``pretrained=True`` loaded.
        chosen = weights_enum.IMAGENET1K_V1 if load_weights else None
        return models.vgg16_bn(weights=chosen)

    def forward(self, x):
        """Run the frontend then the backend and return the backend output."""
        x = self.frontend(x)
        x = self.backend(x)
        return x

    def _initialize_weights(self):
        """Re-initialise backend convs only; the frontend keeps its VGG weights."""
        for m in self.backend.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


In [None]:
# --- Train one epoch (resize GT to match pred size with count preservation) ---
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    When the prediction and the ground truth differ in shape, the ground
    truth is resized to the prediction's spatial size with area averaging
    and multiplied by the area ratio ``(Hg * Wg) / (Hp * Wp)``, so that its
    sum (the count) is kept: exactly when the sizes divide evenly, and
    approximately otherwise. This replaces the former unscaled bilinear
    fallback, which shrank the target's magnitude relative to the divisible
    path, and the ``scale * scale`` factor that was only right when height
    and width shrink by the same factor.

    Args:
        model: Network to train; switched to train mode.
        dataloader: Sized iterable yielding ``(img, gt)`` batches.
        criterion: Loss callable taking ``(pred, target)``.
        optimizer: Optimiser stepping the model's parameters.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress messages.

    Returns:
        The sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = 0.0

    for i, (img, gt) in enumerate(dataloader, 1):
        img = img.to(device)
        gt = gt.to(device)

        pred = model(img)  # (B,1,Hp,Wp)

        # Bring the GT to the prediction's resolution while keeping its sum.
        if pred.shape != gt.shape:
            _, _, Hp, Wp = pred.shape
            _, _, Hg, Wg = gt.shape
            # "area" averages each source window; multiplying by the number
            # of source pixels per output pixel turns the mean back into a sum.
            area_ratio = (Hg * Wg) / (Hp * Wp)
            gt = F.interpolate(gt, size=(Hp, Wp), mode="area") * area_ratio

        loss = criterion(pred, gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % PRINT_FREQ == 0:
            print(f"[Epoch {epoch}] Batch {i}/{len(dataloader)}  Loss: {running_loss/i:.6f}")

    epoch_loss = running_loss / len(dataloader)
    print(f"Epoch {epoch} completed. Loss: {epoch_loss:.6f}")
    return epoch_loss


# --- Validation function ---
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradients; return the mean batch loss.

    The ground truth is matched to the prediction's spatial size the same way
    as in training: area-averaged, then multiplied by the area ratio
    ``(Hg * Wg) / (Hp * Wp)`` so its sum (the count) is kept, exactly for
    evenly dividing sizes and approximately otherwise. Previously only the
    height was tested for divisibility, so a width that did not divide evenly
    still took the ``scale * scale`` path with the wrong multiplier, and the
    bilinear fallback applied no count rescaling at all.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Sized iterable yielding ``(img, gt)`` batches.
        criterion: Loss callable taking ``(pred, target)``.
        device: Device the batches are moved to.

    Returns:
        The sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for img, gt in dataloader:
            img = img.to(device)
            gt = gt.to(device)

            pred = model(img)

            if pred.shape != gt.shape:
                _, _, Hp, Wp = pred.shape
                _, _, Hg, Wg = gt.shape
                # Mean over each source window times the number of source
                # pixels per output pixel gives back the window's sum.
                area_ratio = (Hg * Wg) / (Hp * Wp)
                gt = F.interpolate(gt, size=(Hp, Wp), mode="area") * area_ratio

            loss = criterion(pred, gt)
            running_loss += loss.item()

    return running_loss / len(dataloader)


In [None]:
# --- MAIN TRAINING LOOP ---
def main():
    """Train CSRNet on the matched ``.pt`` pairs and write checkpoints.

    Builds the dataset from IMG_GLOB / GT_GLOB, holds out 10% of it (at
    least one sample) for validation, then trains for NUM_EPOCHS epochs with
    MSE loss and Adam. Weights are saved to CHECKPOINT_DIR whenever the
    validation loss improves, every SAVE_EVERY epochs, and once at the end.
    """
    # Data: one dataset, randomly partitioned into train and validation.
    full_set = PtDataset(IMG_GLOB, GT_GLOB, transform=None)
    print("Total matched samples:", len(full_set))

    holdout_fraction = 0.1
    val_count = max(1, int(len(full_set) * holdout_fraction))
    train_count = len(full_set) - val_count
    train_subset, val_subset = random_split(full_set, [train_count, val_count])

    # Both loaders share everything except shuffling.
    shared_loader_args = {
        "batch_size": BATCH_SIZE,
        "num_workers": NUM_WORKERS,
        "pin_memory": torch.cuda.is_available(),
    }
    train_loader = DataLoader(train_subset, shuffle=True, **shared_loader_args)
    val_loader = DataLoader(val_subset, shuffle=False, **shared_loader_args)

    # Model, loss and optimiser.
    net = CSRNet(load_weights=True).to(DEVICE)
    loss_fn = nn.MSELoss().to(DEVICE)
    adam = optim.Adam(net.parameters(), lr=LR)

    lowest_val = float("inf")

    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss = train_one_epoch(net, train_loader, loss_fn, adam, DEVICE, epoch)
        val_loss = validate(net, val_loader, loss_fn, DEVICE)

        print(f"Epoch {epoch}: Train={train_loss:.6f}  Val={val_loss:.6f}")

        # Keep the weights with the lowest validation loss seen so far.
        improved = val_loss < lowest_val
        if improved:
            lowest_val = val_loss
            torch.save(net.state_dict(), f"{CHECKPOINT_DIR}/csrnet_best.pth")
            print("Saved BEST model.")

        # Periodic snapshot, independent of whether this epoch improved.
        due_for_snapshot = epoch % SAVE_EVERY == 0
        if due_for_snapshot:
            torch.save(net.state_dict(), f"{CHECKPOINT_DIR}/csrnet_epoch{epoch}.pth")
            print("Saved checkpoint.")

    # Weights as they stand after the last epoch.
    torch.save(net.state_dict(), f"{CHECKPOINT_DIR}/csrnet_final.pth")
    print("Training complete. Final model saved.")

# Run training when executed as a script or notebook cell (both set
# __name__ to "__main__"); importing this file no longer starts a run.
if __name__ == "__main__":
    main()


Total matched samples: 300
[Epoch 1] Batch 10/34  Loss: 1.430713
[Epoch 1] Batch 20/34  Loss: 2.087058
[Epoch 1] Batch 30/34  Loss: 1.894538
Epoch 1 completed. Loss: 1.876620
Epoch 1: Train=1.876620  Val=1.826885
Saved BEST model.
[Epoch 2] Batch 10/34  Loss: 1.639744
[Epoch 2] Batch 20/34  Loss: 1.563399
[Epoch 2] Batch 30/34  Loss: 1.709218
Epoch 2 completed. Loss: 1.585354
Epoch 2: Train=1.585354  Val=1.256913
Saved BEST model.
[Epoch 3] Batch 10/34  Loss: 1.292880
[Epoch 3] Batch 20/34  Loss: 1.250453
[Epoch 3] Batch 30/34  Loss: 1.252015
Epoch 3 completed. Loss: 1.167400
Epoch 3: Train=1.167400  Val=1.169355
Saved BEST model.
[Epoch 4] Batch 10/34  Loss: 1.246922
[Epoch 4] Batch 20/34  Loss: 0.951809
[Epoch 4] Batch 30/34  Loss: 1.020140
Epoch 4 completed. Loss: 1.026025
Epoch 4: Train=1.026025  Val=1.144753
Saved BEST model.
[Epoch 5] Batch 10/34  Loss: 0.775790
[Epoch 5] Batch 20/34  Loss: 0.872756
[Epoch 5] Batch 30/34  Loss: 1.023157
Epoch 5 completed. Loss: 0.988547
Epoch 5: 