In [None]:
!ls -R /kaggle/input/csrnet-dataset


In [1]:
import sys
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models
from pathlib import Path
from tqdm import tqdm


In [2]:
# Add dataset.py from Kaggle input
sys.path.append("/kaggle/input/csrnet-dataset")

from dataset import CrowdDataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

# Correct Part A path
ROOT = Path("/kaggle/input/csrnet-dataset")
TRAIN_A = ROOT / "part_A_final" / "part_A_final" / "train_data"

dsA = CrowdDataset(str(TRAIN_A))
train_loader = DataLoader(dsA, batch_size=1, shuffle=True)

print("Part A images:", len(dsA))


Device: cuda
Part A images: 300


In [3]:
class CSRNet(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

        self.frontend = nn.Sequential(*list(vgg.features.children())[:23])
        self.backend = nn.Sequential(
            nn.Conv2d(512,256,3,padding=1), nn.ReLU(True),
            nn.Conv2d(256,128,3,padding=1), nn.ReLU(True),
            nn.Conv2d(128,64,3,padding=1), nn.ReLU(True),
            nn.Conv2d(64,32,3,padding=1), nn.ReLU(True),
            nn.Conv2d(32,1,1)
        )

    def forward(self, x):
        x = self.frontend(x)
        x = self.backend(x)
        return x

model = CSRNet().to(device)


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 227MB/s] 


In [4]:
criterion = nn.MSELoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=1e-5)

batch_losses = []
epoch_losses = []


In [None]:
EPOCHS = 100

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_epoch_loss = 0.0
    start = time.time()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
    for imgs, dens in pbar:
        imgs, dens = imgs.to(device), dens.to(device)

        optimizer.zero_grad()
        out = model(imgs)
        
        if out.shape != dens.shape:
            out = F.interpolate(out, size=dens.shape[2:], mode='bilinear', align_corners=False)

        loss = criterion(out, dens)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        running_epoch_loss += loss.item()

        pbar.set_postfix({"batch_loss": loss.item()})

    epoch_loss = running_epoch_loss / len(train_loader)
    epoch_losses.append(epoch_loss)

    print(f"Epoch {epoch}/{EPOCHS} | Loss: {epoch_loss:.4f} | Time: {time.time()-start:.2f}s")


In [None]:
EPOCHS = 150

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_epoch_loss = 0.0
    start = time.time()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
    for imgs, dens in pbar:
        imgs, dens = imgs.to(device), dens.to(device)

        optimizer.zero_grad()
        out = model(imgs)

        if out.shape != dens.shape:
            out = F.interpolate(out, size=dens.shape[2:], mode='bilinear', align_corners=False)

        loss = criterion(out, dens)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        running_epoch_loss += loss.item()

        pbar.set_postfix({"batch_loss": loss.item()})

    epoch_loss = running_epoch_loss / len(train_loader)
    epoch_losses.append(epoch_loss)

    print(f"Epoch {epoch}/{EPOCHS} | Loss: {epoch_loss:.4f} | Time: {time.time()-start:.2f}s")


In [None]:
torch.save(model.state_dict(), "/kaggle/working/csrnet_weights.pth")
print("Final model saved!")


In [None]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
])


In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


In [None]:
# Paths
test_img_path = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/test_data/images"
test_den_path = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/test_data/ground_truth"


In [None]:
from torch.utils.data import DataLoader

# Correct test_data root folder
test_root = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/test_data"

# Create test dataset
test_dataset = CrowdDataset(test_root)

# DataLoader
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


In [None]:
model = CSRNet().to(device)
model.load_state_dict(torch.load("/kaggle/working/csrnet_weights.pth", map_location=device))
model.eval()


In [None]:
train_root = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/train_data"

train_eval_dataset = CrowdDataset(train_root)
train_eval_loader = DataLoader(train_eval_dataset, batch_size=1, shuffle=False)


In [None]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F

# Load trained model
model = CSRNet().to(device)
model.load_state_dict(torch.load("/kaggle/working/csrnet_weights.pth", map_location=device))
model.eval()

mae = 0.0
mse = 0.0

with torch.no_grad():
    for i, (imgs, dens) in enumerate(tqdm(train_eval_loader)):
        imgs = imgs.to(device)
        dens = dens.to(device)

        out = model(imgs)

        if out.shape != dens.shape:
            out = F.interpolate(out, size=dens.shape[2:], mode="bilinear", align_corners=False)

        pred_count = out.sum().item()
        gt_count = dens.sum().item()

        error = abs(pred_count - gt_count)

        mae += error
        mse += error ** 2

# Final results
mae /= len(train_eval_loader)
mse /= len(train_eval_loader)
rmse = np.sqrt(mse)

print("Train Data Evaluation:")
print(f"MAE  = {mae:.3f}")
print(f"MSE  = {mse:.3f}")
print(f"RMSE = {rmse:.3f}")


In [None]:
with torch.no_grad():
    for i, (imgs, dens) in enumerate(train_eval_loader):
        imgs = imgs.to(device)
        dens = dens.to(device)

        out = model(imgs)

        if out.shape != dens.shape:
            out = F.interpolate(out, size=dens.shape[2:], mode="bilinear")

        pred = out.sum().item()
        gt = dens.sum().item()

        print(f"Image {i+1} | GT: {gt:.1f} | Pred: {pred:.1f} | Error: {abs(gt-pred):.1f}")


In [13]:
model = CSRNet().to(device)
model.load_state_dict(torch.load(
    "/kaggle/input/csrnet-1/pytorch/trained-part-a/1/csrnet_weights.pth",
    map_location=device
))
print("Model loaded successfully!")


Model loaded successfully!


In [14]:
start_epoch = 151
end_epoch = 300

for epoch in range(start_epoch, end_epoch + 1):
    model.train()
    running_epoch_loss = 0.0
    start = time.time()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{end_epoch}", leave=False)
    for imgs, dens in pbar:
        imgs, dens = imgs.to(device), dens.to(device)

        optimizer.zero_grad()
        out = model(imgs)

        if out.shape != dens.shape:
            out = F.interpolate(out, size=dens.shape[2:], mode='bilinear', align_corners=False)

        loss = criterion(out, dens)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        running_epoch_loss += loss.item()

        pbar.set_postfix({"batch_loss": loss.item()})

    epoch_loss = running_epoch_loss / len(train_loader)
    epoch_losses.append(epoch_loss)

    print(f"Epoch {epoch}/{end_epoch} | Loss: {epoch_loss:.4f} | Time: {time.time()-start:.2f}s")

    if epoch % 10 == 0:
        torch.save(model.state_dict(), f"/kaggle/working/csrnet_epoch_{epoch}.pth")
        print(f"Checkpoint saved for epoch {epoch}")


                                                                                    

Epoch 151/300 | Loss: 0.0149 | Time: 46.68s


                                                                                    

Epoch 152/300 | Loss: 0.0149 | Time: 39.31s


                                                                                    

Epoch 153/300 | Loss: 0.0149 | Time: 39.31s


                                                                                    

Epoch 154/300 | Loss: 0.0149 | Time: 39.30s


                                                                                    

Epoch 155/300 | Loss: 0.0149 | Time: 39.32s


                                                                                    

Epoch 156/300 | Loss: 0.0149 | Time: 39.41s


                                                                                    

Epoch 157/300 | Loss: 0.0149 | Time: 39.36s


                                                                                    

Epoch 158/300 | Loss: 0.0149 | Time: 39.43s


                                                                                    

Epoch 159/300 | Loss: 0.0149 | Time: 39.56s


                                                                                    

Epoch 160/300 | Loss: 0.0149 | Time: 39.49s
Checkpoint saved for epoch 160


                                                                                    

Epoch 161/300 | Loss: 0.0149 | Time: 39.52s


                                                                                    

Epoch 162/300 | Loss: 0.0149 | Time: 39.47s


                                                                                    

Epoch 163/300 | Loss: 0.0149 | Time: 39.52s


                                                                                    

Epoch 164/300 | Loss: 0.0149 | Time: 39.43s


                                                                                    

Epoch 165/300 | Loss: 0.0149 | Time: 39.43s


                                                                                    

Epoch 166/300 | Loss: 0.0149 | Time: 39.44s


                                                                                    

Epoch 167/300 | Loss: 0.0149 | Time: 39.51s


                                                                                    

Epoch 168/300 | Loss: 0.0149 | Time: 39.48s


                                                                                    

Epoch 169/300 | Loss: 0.0149 | Time: 39.49s


                                                                                    

Epoch 170/300 | Loss: 0.0149 | Time: 39.49s
Checkpoint saved for epoch 170


                                                                                    

Epoch 171/300 | Loss: 0.0149 | Time: 39.44s


                                                                                    

Epoch 172/300 | Loss: 0.0149 | Time: 39.28s


                                                                                    

Epoch 173/300 | Loss: 0.0149 | Time: 39.45s


                                                                                    

Epoch 174/300 | Loss: 0.0149 | Time: 39.36s


                                                                                    

Epoch 175/300 | Loss: 0.0149 | Time: 39.54s


                                                                                    

Epoch 176/300 | Loss: 0.0149 | Time: 39.60s


                                                                                    

Epoch 177/300 | Loss: 0.0149 | Time: 39.52s


                                                                                    

Epoch 178/300 | Loss: 0.0149 | Time: 39.47s


                                                                                    

Epoch 179/300 | Loss: 0.0149 | Time: 39.43s


                                                                                    

Epoch 180/300 | Loss: 0.0149 | Time: 39.36s
Checkpoint saved for epoch 180


                                                                                    

Epoch 181/300 | Loss: 0.0149 | Time: 39.39s


                                                                                    

Epoch 182/300 | Loss: 0.0149 | Time: 39.37s


                                                                                    

Epoch 183/300 | Loss: 0.0149 | Time: 39.58s


                                                                                    

Epoch 184/300 | Loss: 0.0149 | Time: 39.56s


                                                                                    

Epoch 185/300 | Loss: 0.0149 | Time: 39.58s


                                                                                    

Epoch 186/300 | Loss: 0.0149 | Time: 39.46s


                                                                                    

Epoch 187/300 | Loss: 0.0149 | Time: 39.46s


                                                                                    

Epoch 188/300 | Loss: 0.0149 | Time: 39.41s


                                                                                    

Epoch 189/300 | Loss: 0.0149 | Time: 39.51s


                                                                                    

Epoch 190/300 | Loss: 0.0149 | Time: 39.58s
Checkpoint saved for epoch 190


                                                                                    

Epoch 191/300 | Loss: 0.0149 | Time: 39.44s


                                                                                    

Epoch 192/300 | Loss: 0.0149 | Time: 39.38s


                                                                                    

Epoch 193/300 | Loss: 0.0149 | Time: 39.31s


                                                                                    

Epoch 194/300 | Loss: 0.0149 | Time: 39.31s


                                                                                    

Epoch 195/300 | Loss: 0.0149 | Time: 39.39s


                                                                                    

Epoch 196/300 | Loss: 0.0149 | Time: 39.46s


                                                                                    

Epoch 197/300 | Loss: 0.0149 | Time: 39.34s


                                                                                    

Epoch 198/300 | Loss: 0.0149 | Time: 39.42s


                                                                                    

Epoch 199/300 | Loss: 0.0149 | Time: 39.34s


                                                                                    

Epoch 200/300 | Loss: 0.0149 | Time: 39.50s
Checkpoint saved for epoch 200


                                                                                    

Epoch 201/300 | Loss: 0.0149 | Time: 39.56s


                                                                                    

Epoch 202/300 | Loss: 0.0149 | Time: 39.53s


                                                                                    

Epoch 203/300 | Loss: 0.0149 | Time: 39.54s


                                                                                    

Epoch 204/300 | Loss: 0.0149 | Time: 39.53s


                                                                                    

Epoch 205/300 | Loss: 0.0149 | Time: 39.51s


                                                                                    

Epoch 206/300 | Loss: 0.0149 | Time: 39.38s


                                                                                    

Epoch 207/300 | Loss: 0.0149 | Time: 39.48s


                                                                                    

Epoch 208/300 | Loss: 0.0149 | Time: 39.52s


                                                                                    

Epoch 209/300 | Loss: 0.0149 | Time: 39.45s


                                                                                    

Epoch 210/300 | Loss: 0.0149 | Time: 39.54s
Checkpoint saved for epoch 210


                                                                                    

Epoch 211/300 | Loss: 0.0149 | Time: 39.63s


                                                                                    

Epoch 212/300 | Loss: 0.0149 | Time: 39.53s


                                                                                    

Epoch 213/300 | Loss: 0.0149 | Time: 39.58s


                                                                                    

Epoch 214/300 | Loss: 0.0149 | Time: 39.50s


                                                                                    

Epoch 215/300 | Loss: 0.0149 | Time: 39.40s


                                                                                    

Epoch 216/300 | Loss: 0.0149 | Time: 39.41s


                                                                                    

Epoch 217/300 | Loss: 0.0149 | Time: 39.49s


                                                                                    

Epoch 218/300 | Loss: 0.0149 | Time: 39.41s


                                                                                    

Epoch 219/300 | Loss: 0.0149 | Time: 39.45s


                                                                                    

Epoch 220/300 | Loss: 0.0149 | Time: 39.47s
Checkpoint saved for epoch 220


                                                                                    

Epoch 221/300 | Loss: 0.0149 | Time: 39.57s


                                                                                    

Epoch 222/300 | Loss: 0.0149 | Time: 39.46s


                                                                                    

Epoch 223/300 | Loss: 0.0149 | Time: 39.45s


                                                                                    

Epoch 224/300 | Loss: 0.0149 | Time: 39.39s


                                                                                    

Epoch 225/300 | Loss: 0.0149 | Time: 39.50s


                                                                                    

Epoch 226/300 | Loss: 0.0149 | Time: 39.38s


                                                                                    

Epoch 227/300 | Loss: 0.0149 | Time: 39.41s


                                                                                    

Epoch 228/300 | Loss: 0.0149 | Time: 39.32s


                                                                                    

Epoch 229/300 | Loss: 0.0149 | Time: 39.44s


                                                                                    

Epoch 230/300 | Loss: 0.0149 | Time: 39.56s
Checkpoint saved for epoch 230


                                                                                    

Epoch 231/300 | Loss: 0.0149 | Time: 39.46s


                                                                                    

Epoch 232/300 | Loss: 0.0149 | Time: 39.36s


                                                                                    

Epoch 233/300 | Loss: 0.0149 | Time: 39.35s


                                                                                    

Epoch 234/300 | Loss: 0.0149 | Time: 39.32s


                                                                                    

Epoch 235/300 | Loss: 0.0149 | Time: 39.43s


                                                                                    

Epoch 236/300 | Loss: 0.0149 | Time: 39.44s


                                                                                    

Epoch 237/300 | Loss: 0.0149 | Time: 39.55s


                                                                                    

Epoch 238/300 | Loss: 0.0149 | Time: 39.69s


                                                                                    

Epoch 239/300 | Loss: 0.0149 | Time: 39.44s


                                                                                    

Epoch 240/300 | Loss: 0.0149 | Time: 39.29s
Checkpoint saved for epoch 240


                                                                                    

Epoch 241/300 | Loss: 0.0149 | Time: 39.26s


                                                                                    

Epoch 242/300 | Loss: 0.0149 | Time: 39.22s


                                                                                    

Epoch 243/300 | Loss: 0.0149 | Time: 39.35s


                                                                                    

Epoch 244/300 | Loss: 0.0149 | Time: 39.40s


                                                                                    

Epoch 245/300 | Loss: 0.0149 | Time: 39.64s


                                                                                    

Epoch 246/300 | Loss: 0.0149 | Time: 39.50s


                                                                                    

Epoch 247/300 | Loss: 0.0149 | Time: 39.44s


                                                                                    

Epoch 248/300 | Loss: 0.0149 | Time: 39.29s


                                                                                    

Epoch 249/300 | Loss: 0.0149 | Time: 39.42s


                                                                                    

Epoch 250/300 | Loss: 0.0149 | Time: 39.31s
Checkpoint saved for epoch 250


                                                                                    

Epoch 251/300 | Loss: 0.0149 | Time: 39.40s


                                                                                    

Epoch 252/300 | Loss: 0.0149 | Time: 39.39s


                                                                                    

Epoch 253/300 | Loss: 0.0149 | Time: 39.36s


                                                                                    

Epoch 254/300 | Loss: 0.0149 | Time: 39.37s


                                                                                    

Epoch 255/300 | Loss: 0.0149 | Time: 39.51s


                                                                                    

Epoch 256/300 | Loss: 0.0149 | Time: 39.40s


                                                                                    

Epoch 257/300 | Loss: 0.0149 | Time: 39.40s


                                                                                    

Epoch 258/300 | Loss: 0.0149 | Time: 39.36s


                                                                                    

Epoch 259/300 | Loss: 0.0149 | Time: 39.38s


                                                                                    

Epoch 260/300 | Loss: 0.0149 | Time: 39.35s
Checkpoint saved for epoch 260


                                                                                    

Epoch 261/300 | Loss: 0.0149 | Time: 39.42s


                                                                                    

Epoch 262/300 | Loss: 0.0149 | Time: 39.42s


                                                                                    

Epoch 263/300 | Loss: 0.0149 | Time: 39.36s


                                                                                    

Epoch 264/300 | Loss: 0.0149 | Time: 39.55s


                                                                                    

Epoch 265/300 | Loss: 0.0149 | Time: 39.47s


                                                                                    

Epoch 266/300 | Loss: 0.0149 | Time: 39.58s


                                                                                    

Epoch 267/300 | Loss: 0.0149 | Time: 39.45s


                                                                                    

Epoch 268/300 | Loss: 0.0149 | Time: 39.34s


                                                                                    

Epoch 269/300 | Loss: 0.0149 | Time: 39.38s


                                                                                    

Epoch 270/300 | Loss: 0.0149 | Time: 39.34s
Checkpoint saved for epoch 270


                                                                                    

Epoch 271/300 | Loss: 0.0149 | Time: 39.29s


                                                                                    

Epoch 272/300 | Loss: 0.0149 | Time: 39.37s


                                                                                    

Epoch 273/300 | Loss: 0.0149 | Time: 39.53s


                                                                                    

Epoch 274/300 | Loss: 0.0149 | Time: 39.39s


                                                                                    

Epoch 275/300 | Loss: 0.0149 | Time: 39.47s


                                                                                    

Epoch 276/300 | Loss: 0.0149 | Time: 39.48s


                                                                                    

Epoch 277/300 | Loss: 0.0149 | Time: 39.50s


                                                                                    

Epoch 278/300 | Loss: 0.0149 | Time: 39.42s


                                                                                    

Epoch 279/300 | Loss: 0.0149 | Time: 39.36s


                                                                                    

Epoch 280/300 | Loss: 0.0149 | Time: 39.40s
Checkpoint saved for epoch 280


                                                                                    

Epoch 281/300 | Loss: 0.0149 | Time: 39.53s


                                                                                    

Epoch 282/300 | Loss: 0.0149 | Time: 39.50s


                                                                                    

Epoch 283/300 | Loss: 0.0149 | Time: 39.37s


                                                                                    

Epoch 284/300 | Loss: 0.0149 | Time: 39.61s


                                                                                    

Epoch 285/300 | Loss: 0.0149 | Time: 39.34s


                                                                                    

Epoch 286/300 | Loss: 0.0149 | Time: 39.55s


                                                                                    

Epoch 287/300 | Loss: 0.0149 | Time: 39.59s


                                                                                    

Epoch 288/300 | Loss: 0.0149 | Time: 39.34s


                                                                                    

Epoch 289/300 | Loss: 0.0149 | Time: 39.19s


                                                                                    

Epoch 290/300 | Loss: 0.0149 | Time: 39.38s
Checkpoint saved for epoch 290


                                                                                    

Epoch 291/300 | Loss: 0.0149 | Time: 39.28s


                                                                                    

Epoch 292/300 | Loss: 0.0149 | Time: 39.29s


                                                                                    

Epoch 293/300 | Loss: 0.0149 | Time: 39.28s


                                                                                    

Epoch 294/300 | Loss: 0.0149 | Time: 39.34s


                                                                                    

Epoch 295/300 | Loss: 0.0149 | Time: 39.35s


                                                                                    

Epoch 296/300 | Loss: 0.0149 | Time: 39.43s


                                                                                    

Epoch 297/300 | Loss: 0.0149 | Time: 39.47s


                                                                                    

Epoch 298/300 | Loss: 0.0149 | Time: 39.34s


                                                                                    

Epoch 299/300 | Loss: 0.0149 | Time: 39.24s


                                                                                    

Epoch 300/300 | Loss: 0.0149 | Time: 39.36s
Checkpoint saved for epoch 300




In [17]:
torch.save(model.state_dict(), "/kaggle/working/csrnet_weights.pth")
print("Final Model Saved!")


Final Model Saved!


In [18]:
train_root = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/train_data"
train_eval_dataset = CrowdDataset(train_root)
train_eval_loader = DataLoader(train_eval_dataset, batch_size=1, shuffle=False)

import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

mae = 0.0
mse = 0.0

with torch.no_grad():
    for imgs, dens in tqdm(train_eval_loader):
        imgs, dens = imgs.to(device), dens.to(device)
        out = model(imgs)

        if out.shape != dens.shape:
            out = F.interpolate(out, size=dens.shape[2:], mode='bilinear')

        pred = out.sum().item()
        gt = dens.sum().item()

        err = abs(pred - gt)
        mae += err
        mse += err**2

mae /= len(train_eval_loader)
rmse = np.sqrt(mse / len(train_eval_loader))

print("TRAIN DATA RESULTS:")
print(f"MAE  = {mae:.3f}")
print(f"RMSE = {rmse:.3f}")


100%|██████████| 300/300 [00:17<00:00, 17.20it/s]

TRAIN DATA RESULTS:
MAE  = 25.880
RMSE = 32.325





In [19]:
model = CSRNet().to(device)
model.load_state_dict(torch.load("/kaggle/working/csrnet_epoch_300.pth", map_location=device))
model.eval()

print("Loaded FINAL trained model (epoch 300)")


Loaded FINAL trained model (epoch 300)


In [22]:
model = CSRNet().to(device)
model.load_state_dict(torch.load(
    "/kaggle/working/csrnet_epoch_300.pth",
    map_location=device
))
model.eval()

print("Loaded FINAL model (epoch 300)")


Loaded FINAL model (epoch 300)


In [27]:
print(train_loader.dataset.img_dir)


/kaggle/input/csrnet-dataset/part_A_final/part_A_final/train_data/images


In [28]:
train_root = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/train_data"
train_eval_dataset = CrowdDataset(train_root)
train_eval_loader = DataLoader(train_eval_dataset, batch_size=1, shuffle=False)


In [29]:
model = CSRNet().to(device)
model.load_state_dict(torch.load(
    "/kaggle/working/csrnet_epoch_300.pth",
    map_location=device
))
model.eval()

print("Loaded FINAL model (epoch 300)")


Loaded FINAL model (epoch 300)


In [31]:
test_root = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/test_data"
test_dataset = CrowdDataset(test_root)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print("Test dataset size:", len(test_dataset))


Test dataset size: 182


In [32]:
model = CSRNet().to(device)
model.load_state_dict(torch.load(
    "/kaggle/working/csrnet_epoch_300.pth",
    map_location=device
))
model.eval()

print("Loaded FINAL model (epoch 300)")


Loaded FINAL model (epoch 300)


In [34]:
import os

save_root = "/kaggle/working/partA_test_preprocessed"
os.makedirs(save_root, exist_ok=True)
os.makedirs(os.path.join(save_root, "images"), exist_ok=True)
os.makedirs(os.path.join(save_root, "density_maps"), exist_ok=True)

print("Folders created successfully!")


Folders created successfully!


In [38]:
import scipy.io as sio

gt_file = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/test_data/ground_truth/GT_IMG_1.mat"
data = sio.loadmat(gt_file)

print(data.keys())


dict_keys(['__header__', '__version__', '__globals__', 'image_info'])


In [39]:
import scipy.io as sio

def generate_density_map(img_path, gt_path):
    img = Image.open(img_path)
    img = np.array(img)

    # Load MATLAB .mat file (works for CSRNet)
    mat = sio.loadmat(gt_path)
    pts = mat["image_info"][0][0][0][0][0]  # shape Nx2

    density = np.zeros((img.shape[0], img.shape[1]), dtype=np.float32)

    for x, y in pts:
        if 0 <= int(y) < density.shape[0] and 0 <= int(x) < density.shape[1]:
            density[int(y), int(x)] = 1

    density = gaussian_filter(density, sigma=15)

    return density


In [7]:
import os
import numpy as np
import scipy.io as sio
from PIL import Image
from scipy.ndimage import gaussian_filter
from tqdm import tqdm

test_img_path = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/test_data/images"
test_gt_path  = "/kaggle/input/csrnet-dataset/part_A_final/part_A_final/test_data/ground_truth"

save_root = "/kaggle/working/partA_test_preprocessed"
os.makedirs(os.path.join(save_root, "images"), exist_ok=True)
os.makedirs(os.path.join(save_root, "density_maps"), exist_ok=True)

def generate_density_map(img_path, gt_path):
    img = Image.open(img_path)
    img = np.array(img)

    # Load .mat file using scipy
    mat = sio.loadmat(gt_path)
    pts = mat["image_info"][0][0][0][0][0]

    density = np.zeros((img.shape[0], img.shape[1]), dtype=np.float32)

    for x, y in pts:
        x, y = int(x), int(y)
        if 0 <= y < density.shape[0] and 0 <= x < density.shape[1]:
            density[y, x] = 1

    density = gaussian_filter(density, sigma=15)

    return density


img_files = sorted(os.listdir(test_img_path))

for img_name in tqdm(img_files):
    if img_name.endswith(".jpg"):
        img_path = os.path.join(test_img_path, img_name)

        img_number = img_name.replace("IMG_", "").replace(".jpg", "")
        gt_name = f"GT_IMG_{img_number}.mat"
        gt_path = os.path.join(test_gt_path, gt_name)

        density_map = generate_density_map(img_path, gt_path)

        # Save output
        Image.open(img_path).save(os.path.join(save_root, "images", img_name))
        np.save(os.path.join(save_root, "density_maps", img_name.replace(".jpg", ".npy")), density_map)

print("✔ Test density maps generated successfully!")


100%|██████████| 182/182 [00:18<00:00,  9.75it/s]

✔ Test density maps generated successfully!





In [10]:
# ---------------------------
# 1) IMPORTS
# ---------------------------
import sys
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import models
from torch.utils.data import DataLoader
from pathlib import Path

# ---------------------------
# 2) Load dataset.py
# ---------------------------
sys.path.append("/kaggle/input/csrnet-dataset")
from dataset import CrowdDataset    # NOW CrowdDataset is defined

# ---------------------------
# 3) Define CSRNet model
# ---------------------------
class CSRNet(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

        self.frontend = nn.Sequential(*list(vgg.features.children())[:23])
        self.backend = nn.Sequential(
            nn.Conv2d(512,256,3,padding=1), nn.ReLU(True),
            nn.Conv2d(256,128,3,padding=1), nn.ReLU(True),
            nn.Conv2d(128,64,3,padding=1), nn.ReLU(True),
            nn.Conv2d(64,32,3,padding=1), nn.ReLU(True),
            nn.Conv2d(32,1,1)
        )

    def forward(self, x):
        x = self.frontend(x)
        x = self.backend(x)
        return x

# ---------------------------
# 4) Device
# ---------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)


Device: cuda


In [11]:
test_dataset = CrowdDataset("/kaggle/working/partA_test_preprocessed")
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print("Test images:", len(test_dataset))


Test images: 182


In [17]:
import torch

MODEL_PATH = "/kaggle/input/csrnet-1/pytorch/trained-part-a/1/csrnet_weights.pth"

model = CSRNet().to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

print("Model loaded successfully!")


Model loaded successfully!


In [18]:
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

mae = 0
mse = 0

model.eval()

with torch.no_grad():
    for imgs, dens in tqdm(test_loader):
        imgs = imgs.to(device)
        dens = dens.to(device)

        out = model(imgs)

        # Fix size mismatch
        if out.shape != dens.shape:
            out = F.interpolate(out, size=dens.shape[2:], mode='bilinear', align_corners=False)

        pred = out.sum().item()
        gt = dens.sum().item()

        error = abs(pred - gt)
        mae += error
        mse += error**2

mae /= len(test_loader)
rmse = np.sqrt(mse / len(test_loader))

print("\nTEST DATA RESULTS:")
print(f"MAE  = {mae:.2f}")
print(f"RMSE = {rmse:.2f}")


100%|██████████| 182/182 [00:10<00:00, 17.03it/s]


TEST DATA RESULTS:
MAE  = 73.24
RMSE = 118.58



