In [1]:
#importing useful modeuls
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import os
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.utils.data import Dataset, DataLoader

In [2]:
root = "stage1_train"
folders = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]

In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.GroupNorm(8, out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.GroupNorm(8, out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_ch, out_ch)
    def forward(self, x):
        return self.conv(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear")
        self.reduce = nn.Conv2d(in_ch, in_ch // 2, kernel_size=1, bias=False)
        up_ch = in_ch // 2

        self.conv = DoubleConv(up_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        x = self.reduce(x)
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)

class UNetLiteDensity(nn.Module):
    def __init__(self, in_channels=1, base=16, out_activation="relu", patch_size=256):
        super().__init__()
        c1, c2, c3, c4, c5 = base, base*2, base*4, base*8, base*16

        self.inc = DoubleConv(in_channels, c1)
        self.d1  = Down(c1, c2)
        self.d2  = Down(c2, c3)
        self.d3  = Down(c3, c4)
        self.d4  = Down(c4, c5)

        self.u1 = Up(c5, c4, c4)
        self.u2 = Up(c4, c3, c3)
        self.u3 = Up(c3, c2, c2)
        self.u4 = Up(c2, c1, c1)

        self.outc1 = nn.Conv2d(c1, 1, kernel_size=1)
        self.outc2 = nn.Conv2d(1, 1, kernel_size=patch_size)

        if out_activation == "relu":
            self.act = nn.ReLU(inplace=True)
        elif out_activation == "softplus":
            self.act = nn.Softplus()
        else:
            print("WRONG THING")
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.d1(x1)
        x3 = self.d2(x2)
        x4 = self.d3(x3)
        x5 = self.d4(x4)

        x = self.u1(x5, x4)
        x = self.u2(x,  x3)
        x = self.u3(x,  x2)
        x = self.u4(x,  x1)

        x = self.outc1(x)
        x = self.outc2(x)
        return self.act(x)


In [10]:
def build_patch_index_tiled(root, folder_ids, P=256):
    patch_index = []
    for sample_id in folder_ids:
        sample_dir = os.path.join(root, sample_id)
        x_full = np.load(os.path.join(sample_dir, "image.npy"), mmap_mode="r")
        H, W = x_full.shape

        n_h = (H + P - 1) // P
        n_w = (W + P - 1) // P

        for i in range(n_h):
            top = i * P
            for j in range(n_w):
                left = j * P
                patch_index.append((sample_id, top, left))

    return patch_index


In [11]:
class PatchDataset(Dataset):
    def __init__(self, root, patch_indices, P=256, invert=True):
        self.root = root
        self.patch_indices = patch_indices
        self.P = P
        self.invert = invert

    def __len__(self):
        n = len(self.patch_indices)
        return 2 * n if self.invert else n
    
    def __getitem__(self, idx):
        n = len(self.patch_indices)
        inverted = self.invert and (idx >= n)
        if inverted:
            idx = idx - n

        sample_id, top, left = self.patch_indices[idx]
        sample_dir = os.path.join(self.root, sample_id)

        x_full = np.load(os.path.join(sample_dir, "image.npy"), mmap_mode="r")
        y_full = np.load(os.path.join(sample_dir, "density_varfactor0.2.npy"), mmap_mode="r")

        x_shape, y_shape = x_full.shape, y_full.shape

        assert x_shape == y_shape
        assert x_full.ndim == 2 and y_full.ndim == 2

        H, W = x_shape

        H_pad = ((H + self.P - 1) // self.P) * self.P
        W_pad = ((W + self.P - 1) // self.P) * self.P

        pad_h = H_pad - H
        pad_w = W_pad - W

        if pad_h > 0 or pad_w > 0:
            x_full = np.pad(x_full, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0)
            y_full = np.pad(y_full, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0)

        x = x_full[top:top+self.P, left:left+self.P].astype(np.float32, copy=False)
        y = y_full[top:top+self.P, left:left+self.P].astype(np.float32, copy=False)

        assert x.shape == (self.P, self.P) and y.shape == (self.P, self.P)

        if inverted:
            x = 1.0 - x

        x = torch.from_numpy(x).unsqueeze(0)
        y = torch.from_numpy(y).unsqueeze(0)
        return x, y

In [None]:
P = 256

def split_train_val(folder_ids, val_frac=0.2, seed=123):

    folder_ids = list(folder_ids)
    rng = random.Random(seed)
    rng.shuffle(folder_ids)
    n_val = int(round(len(folder_ids) * val_frac))
    val_ids = sorted(folder_ids[:n_val])
    train_ids = sorted(folder_ids[n_val:])
    
    return train_ids, val_ids

train_ids, val_ids = split_train_val(folders, val_frac=0.2, seed=123)
train_patch_indices = build_patch_index_tiled(root, train_ids, P=P)
val_patch_indices   = build_patch_index_tiled(root, val_ids, P=P)

train_dataset = PatchDataset(root, train_patch_indices, P=P, invert=True)
val_dataset = PatchDataset(root, val_patch_indices, P=P, invert=True)

In [13]:
train_loader = DataLoader(train_dataset,
                          batch_size=4,
                          shuffle=True,
                          drop_last=False,
                          num_workers=0)

val_loader = DataLoader(val_dataset,
                        batch_size=4,
                        shuffle=False,
                        drop_last=False,
                        num_workers=0)

In [14]:
def count_huber(y_pred, y_true):
    pred_c = y_pred.sum(dim=(1,2,3))
    true_c = y_true.sum(dim=(1,2,3))
    return F.smooth_l1_loss(pred_c, true_c)

def total_loss(y_pred, y_true):
    cnt = count_huber(y_pred, y_true)
    return cnt

In [None]:
model = UNetLiteDensity(in_channels=1, base=16, out_activation="relu", patch_size=P)

loss_fn = total_loss
stride = 256
lr = 0.0001
weight_decay = 0.0001
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.3, patience=2, min_lr=1e-6
)

def count_mae(pred, y):
    pred_c = pred.sum(dim=(1,2,3))
    true_c = y.sum(dim=(1,2,3))
    return (pred_c - true_c).abs().mean()

num_epochs = 40

print("Train dataset len:", len(train_dataset), flush=True)
print("Val dataset len:", len(val_dataset), flush=True)
print("Train batches per epoch:", len(train_loader), flush=True)
print("Val batches per epoch:", len(val_loader), flush=True)

train_losses = []
val_losses = []
full_val_losses = []

train_maes = []
val_maes = []
full_val_maes = []

train_epoch_nums = []
val_epoch_nums = []

hyperparameters = {"lr": lr, 
                   "weight_decay": weight_decay}
print(hyperparameters)

max_norm = 5.0
clip_count = 0
for epoch in range(num_epochs):

    
    print(f"\nEpoch {epoch} START", flush=True)
    model.train()
    train_loss_acc = 0.0
    train_mae_acc = 0.0
    nb = 0

    for step, (x, y) in enumerate(train_loader):
        if step == 0:
            print("  First train batch loaded:", x.shape, y.shape, flush=True)

        optimizer.zero_grad(set_to_none=True)

        pred = model(x)
        loss = loss_fn(pred, y)
    
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

        if grad_norm.item() > max_norm:
            clip_count += 1

        optimizer.step()

        train_loss = loss.item()
        
        train_loss_acc += train_loss
        

        with torch.no_grad():
            train_avg_mae = count_mae(pred, y).item()
        train_mae_acc += train_avg_mae
        
        if step % 50 == 0:
            with torch.no_grad():
                
                cnt = count_huber(pred, y).item()

            print(
                f"  step {step}/{len(train_loader)} "
                f"MAE={train_avg_mae:.4f} totloss={train_loss}",
                flush=True
            )

        nb += 1
        

    train_loss = train_loss_acc / nb
    train_losses.append(train_loss)

    train_mae = train_mae_acc / nb
    train_maes.append(train_mae)

    train_epoch_nums.append(epoch)

    model.eval()
    val_loss_acc = 0.0
    val_mae_acc = 0.0
    nb = 0
    
    with torch.no_grad():
        for step, (x, y) in enumerate(val_loader):

            pred = model(x)
            loss = loss_fn(pred, y)
            
            val_loss = loss.item()
            val_loss_acc += val_loss

            val_avg_mae = count_mae(pred, y).item()
            val_mae_acc += val_avg_mae

            nb += 1

    val_loss = val_loss_acc / nb
    val_losses.append(val_loss)

    val_mae = val_mae_acc / nb
    val_maes.append(val_mae)

    val_epoch_nums.append(epoch)

    full_val_loss_acc = 0.0
    full_val_mae_acc = 0.0
    nb_full = 0

    with torch.no_grad():
        for val_id in val_ids:
            x_full = np.load(os.path.join(root, val_id, "image.npy"), mmap_mode="r").astype(np.float32)
            y_full = np.load(os.path.join(root, val_id, "density_varfactor0.2.npy")).astype(np.float32)

            H, W = x_full.shape
            H_pad = ((H + P - 1) // P) * P
            W_pad = ((W + P - 1) // P) * P

            x_pad_full = np.zeros((H_pad, W_pad), dtype=np.float32)
            x_pad_full[:H, :W] = x_full

            pred_count_total = 0.0
            for top in range(0, H_pad, P):
                for left in range(0, W_pad, P):
                    x_patch = x_pad_full[top:top+P, left:left+P]
                    xt = torch.from_numpy(x_patch)[None, None, :, :]
                    pred_patch = model(xt).item()
                    pred_count_total += pred_patch
            
            y_full_4d = torch.from_numpy(y_full).unsqueeze(0).unsqueeze(0)
            pred_full_4d = torch.tensor(pred_count_total, dtype=torch.float32).view(1,1,1,1)

            loss = loss_fn(pred_full_4d, y_full_4d)
            
            full_val_loss_acc += loss.item()

            full_val_avg_mae = count_mae(pred_full_4d, y_full_4d).item()
            full_val_mae_acc += full_val_avg_mae
            

            nb_full += 1

    full_val_loss = full_val_loss_acc / nb_full
    full_val_losses.append(full_val_loss)

    full_val_mae = full_val_mae_acc / nb_full
    full_val_maes.append(full_val_mae)

    print(f"Epoch {epoch}: train_loss={train_loss:.4f}, train_MAE={train_mae:.3f} | val_loss={val_loss:.4f}, val_MAE={val_mae:.3f} | full_val_loss={full_val_loss:.4f}, full_val_MAE={full_val_mae:.3f}")
    scheduler.step(full_val_mae)

    current_lr = optimizer.param_groups[0]["lr"]
    print(f"  lr={current_lr:.2e}", flush=True)

print("epoch clip_count:", clip_count, flush=True)

Train dataset len: 3590
Val dataset len: 858
Train batches per epoch: 898
Val batches per epoch: 215
{'lr': 0.0001, 'weight_decay': 0.0001}

Epoch 0 START
  First train batch loaded: torch.Size([4, 1, 256, 256]) torch.Size([4, 1, 256, 256])


  y = torch.from_numpy(y).unsqueeze(0)


  step 0/898 MAE=5.3762 totloss=4.934874534606934
  step 50/898 MAE=3.1027 totloss=2.619500160217285
  step 100/898 MAE=12.8667 totloss=12.468442916870117
  step 150/898 MAE=5.1214 totloss=4.674820423126221
  step 200/898 MAE=13.2633 totloss=12.816695213317871
  step 250/898 MAE=3.4584 totloss=3.083408832550049
  step 300/898 MAE=15.0349 totloss=14.55339527130127
  step 350/898 MAE=3.7759 totloss=3.275911808013916
  step 400/898 MAE=6.4775 totloss=5.977530479431152
  step 450/898 MAE=4.0772 totloss=3.57716703414917
  step 500/898 MAE=6.1240 totloss=5.689952373504639
  step 550/898 MAE=4.2866 totloss=3.911609172821045
  step 600/898 MAE=3.0067 totloss=2.507296562194824
  step 650/898 MAE=8.8160 totloss=8.316025733947754
  step 700/898 MAE=4.7462 totloss=4.271746635437012
  step 750/898 MAE=1.7543 totloss=1.3442132472991943
  step 800/898 MAE=3.3577 totloss=2.8580543994903564
  step 850/898 MAE=7.0475 totloss=6.656890869140625
Epoch 0: train_loss=6.6390, train_MAE=7.091 | val_loss=4.9380