In [1]:
import os
import numpy as np
import cv2
import json
import math
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
def make_gaussian_box(box_size=5, sigma=1.0, mu=0.0):
    """
    Create a square 2D Gaussian kernel sampled on a normalized grid.

    The kernel is evaluated on a ``box_size x box_size`` grid spanning
    [-1, 1] along each axis, so ``sigma`` and ``mu`` are in these normalized
    units, not pixels. Values are not normalized to sum to 1. For odd
    ``box_size`` the center sample equals ``exp(-mu**2 / (2 * sigma**2))``.

    Args:
        box_size (int): width/height of the square; must be >= 1.
        sigma (float): standard deviation of the Gaussian (normalized units); must be > 0.
        mu (float): radial offset of the peak from the grid center
            (0 gives a single central peak; > 0 gives a ring of radius ``mu``).

    Returns:
        np.ndarray: 2D float array of shape (box_size, box_size).

    Raises:
        ValueError: if ``box_size < 1`` or ``sigma <= 0``.
    """
    if box_size < 1:
        raise ValueError(f"box_size must be >= 1, got {box_size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")

    # np.linspace(-1, 1, 1) returns [-1.0], which would put a 1x1 kernel at a
    # corner of the normalized grid; a single sample belongs at the center.
    axis = np.linspace(-1, 1, box_size) if box_size > 1 else np.zeros(1)
    xx, yy = np.meshgrid(axis, axis)

    # Radial distance of each sample from the grid center.
    d = np.sqrt(xx**2 + yy**2)

    return np.exp(-((d - mu) ** 2) / (2 * sigma**2))

In [4]:
class HeatmapDataset(Dataset):
    """Keypoint-heatmap dataset built from an ``img``/``ann`` folder pair.

    Every ``ann/<name>.json`` file is paired with the image ``img/<name>``
    (the annotation filename minus its ``.json`` suffix). Keypoints are read
    from ``objects[i]['points']['exterior'][0]``, interpreted as ``[x, y]``
    pixel coordinates in the original (un-resized) image, and rescaled onto
    the ``img_size`` grid. Samples whose annotation does not contain exactly
    ``num_keypoints`` objects, or whose image cannot be read, are skipped.

    Each item is ``(image, heatmap)``:
      * image: float tensor (3, H, W), RGB, values in [0, 1]
      * heatmap: float32 tensor (num_keypoints, H, W); channel i holds the
        Gaussian patch for keypoint i, clipped to the image and normalized to
        sum to 1 (all zeros if the patch falls entirely outside the image).
    All samples are loaded eagerly into memory in ``__init__``.
    """

    def __init__(self, path, box_size=5, img_size=(64, 48), gaussian_box=None, num_keypoints=7):
        """
        Args:
            path (str): Path to dataset folder containing 'img' and 'ann' subfolders
            box_size (int): Size of the Gaussian box for each keypoint
            img_size (tuple): Resize images to this size (height, width)
            gaussian_box (np.ndarray | None): Precomputed kernel of shape
                (box_size, box_size); defaults to a uniform box of ones.
            num_keypoints (int): Number of keypoints (heatmap channels) per sample.

        Raises:
            ValueError: if ``gaussian_box`` is not of shape (box_size, box_size).
        """
        self.path = path
        self.box_size = box_size
        self.img_size = img_size
        self.num_keypoints = num_keypoints
        self.gaussian_box = np.asarray(gaussian_box) if gaussian_box is not None else np.ones((box_size, box_size))
        if self.gaussian_box.shape != (box_size, box_size):
            raise ValueError(
                f"gaussian_box must have shape ({box_size}, {box_size}), got {self.gaussian_box.shape}"
            )
        self.images, self.heatmaps = self._load_data()

    def _load_data(self):
        """Load every valid sample; returns two index-aligned lists (images, heatmaps)."""
        ann_dir = os.path.join(self.path, 'ann')
        img_dir = os.path.join(self.path, 'img')
        # Sorted so sample order (and therefore any seeded split) is reproducible;
        # os.listdir order is filesystem-dependent.
        ann_files = sorted(f for f in os.listdir(ann_dir) if f.endswith('.json'))

        images, heatmaps = [], []
        h, w = self.img_size

        for name in ann_files:
            with open(os.path.join(ann_dir, name), encoding='utf-8') as f:
                data = json.load(f)

            points = data.get('objects', [])
            if len(points) != self.num_keypoints:
                continue  # skip malformed annotations

            # Read the image BEFORE appending anything, so an unreadable image
            # skips the whole sample and images/heatmaps stay aligned.
            img = cv2.imread(os.path.join(img_dir, name[:-len('.json')]))
            if img is None:
                continue

            # Map annotation pixels from the original image onto the resized grid.
            orig_h, orig_w = img.shape[:2]
            heatmap = self._render_heatmap(points, scale_x=w / orig_w, scale_y=h / orig_h)

            img = cv2.resize(img, (w, h))  # cv2 dsize is (width, height)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            images.append(torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0)
            heatmaps.append(torch.from_numpy(heatmap))

        return images, heatmaps

    def _render_heatmap(self, points, scale_x=1.0, scale_y=1.0):
        """Rasterize keypoints into a (num_keypoints, H, W) float32 heatmap.

        Each keypoint's kernel is centered on the (scaled) keypoint, clipped to
        the image, and normalized to sum to 1. A keypoint whose kernel lies
        entirely outside the image leaves its channel all zeros.
        """
        h, w = self.img_size
        box = self.box_size
        half = box // 2
        heatmap = np.zeros((self.num_keypoints, h, w), dtype=np.float32)

        for i, pt in enumerate(points):
            x, y = pt['points']['exterior'][0][:2]
            # int() truncates toward zero, matching a dtype=int array cast.
            cx, cy = int(x * scale_x), int(y * scale_y)

            # Kernel footprint in image coordinates, exclusive upper bounds, so
            # both odd and even box sizes cover exactly `box` rows/cols.
            top, left = cy - half, cx - half
            y0, y1 = max(0, top), min(h, top + box)
            x0, x1 = max(0, left), min(w, left + box)
            if y0 >= y1 or x0 >= x1:
                continue  # kernel entirely outside the image

            patch = self.gaussian_box[y0 - top:y1 - top, x0 - left:x1 - left]
            total = patch.sum()
            heatmap[i, y0:y1, x0:x1] = patch / total if total > 0 else patch

        return heatmap

    def __getitem__(self, idx):
        """Return the ``(image, heatmap)`` tensor pair at ``idx``."""
        return self.images[idx], self.heatmaps[idx]

    def __len__(self):
        """Number of successfully loaded samples."""
        return len(self.images)

In [5]:
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_bn=True):
        super(ResNetBlock, self).__init__()
        self.use_bn = use_bn

        # Main branch
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=not use_bn)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=not use_bn)
        
        if use_bn:
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.bn2 = nn.BatchNorm2d(out_channels)
            # Shortcut batch norm only if in/out channels differ
            self.bn_shortcut = nn.BatchNorm2d(out_channels) if in_channels != out_channels else nn.Identity()
        else:
            self.bn1 = self.bn2 = self.bn_shortcut = nn.Identity()
        
        # Shortcut path
        self.shortcut_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=not use_bn)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Main branch
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)

        # Shortcut
        shortcut = self.shortcut_conv(x)
        shortcut = self.bn_shortcut(shortcut)

        # Residual addition + ReLU
        out = self.relu(out + shortcut)
        return out


class ResNet(nn.Module):
    def __init__(self, base_size=8, num_classes=7, use_bn=True, use_sigmoid=False, use_softmax=False):
        super(ResNet, self).__init__()
        self.use_bn = use_bn
        self.use_sigmoid = use_sigmoid
        self.use_softmax = use_softmax

        self.conv = nn.Conv2d(3, base_size, kernel_size=7, padding=3, bias=not use_bn)
        self.bn0 = nn.BatchNorm2d(base_size) if use_bn else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

        self.res1 = ResNetBlock(base_size, base_size*2, use_bn=use_bn)
        self.res2 = ResNetBlock(base_size*2, base_size*4, use_bn=use_bn)

        self.out = nn.Conv2d(base_size*4, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn0(x)
        x = self.relu(x)

        x = self.res1(x)
        x = self.res2(x)
        x = self.out(x)

        if self.use_sigmoid:
            x = torch.sigmoid(x)
        elif self.use_softmax:
            n, c, h, w = x.shape
            x = x.view(n, c, -1)         # flatten spatial dims
            x = torch.softmax(x, dim=2)  # apply softmax
            x = x.view(n, c, h, w)       # reshape back
        return x

In [6]:
class UNet(nn.Module):
    """Two-level U-Net that outputs ``num_classes`` maps at input resolution.

    Encoder: conv(3 -> base); maxpool, conv(base -> 2*base); maxpool, conv(2*base -> 4*base).
    Decoder, per level: upsample to the skip tensor's spatial size, 3x3 conv
    halving the channels, concatenate with the skip, 3x3 conv back to the
    skip's width. Each conv is followed by optional BatchNorm and ReLU; the
    head is a 1x1 conv.

    Output activation: sigmoid if ``use_sigmoid`` (the default here); else a
    per-channel softmax over all spatial positions if ``use_softmax``; else raw
    logits. ``use_sigmoid`` takes precedence when both flags are set.

    Input height/width need not be divisible by 4: upsampling targets the skip
    tensor's exact size, so odd sizes produced by max-pooling are handled
    (H and W must still be >= 4 so the deepest level is non-empty).
    """

    def __init__(self, base_size=8, num_classes=7, use_bn=False, use_sigmoid=True, use_softmax=False):
        super(UNet, self).__init__()
        self.use_bn = use_bn
        self.use_sigmoid = use_sigmoid
        self.use_softmax = use_softmax
        self.relu = nn.ReLU(inplace=True)
        self.max_pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # ---------- Encoder ----------
        self.conv_d0 = nn.Conv2d(3, base_size, kernel_size=3, padding=1, bias=not use_bn)
        self.conv_d1 = nn.Conv2d(base_size, base_size*2, kernel_size=3, padding=1, bias=not use_bn)
        self.conv_d2 = nn.Conv2d(base_size*2, base_size*4, kernel_size=3, padding=1, bias=not use_bn)

        self.bn_d0 = nn.BatchNorm2d(base_size) if use_bn else nn.Identity()
        self.bn_d1 = nn.BatchNorm2d(base_size*2) if use_bn else nn.Identity()
        self.bn_d2 = nn.BatchNorm2d(base_size*4) if use_bn else nn.Identity()

        # ---------- Decoder ----------
        self.upconv1 = nn.Conv2d(base_size*4, base_size*2, kernel_size=3, padding=1, bias=not use_bn)
        self.upconv0 = nn.Conv2d(base_size*2, base_size, kernel_size=3, padding=1, bias=not use_bn)

        self.bn_up1 = nn.BatchNorm2d(base_size*2) if use_bn else nn.Identity()
        self.bn_up0 = nn.BatchNorm2d(base_size) if use_bn else nn.Identity()

        # Input is the concatenation [skip, upsampled] -> twice the skip width.
        self.conv_u1 = nn.Conv2d(base_size*4, base_size*2, kernel_size=3, padding=1, bias=not use_bn)
        self.conv_u0 = nn.Conv2d(base_size*2, base_size, kernel_size=3, padding=1, bias=not use_bn)

        self.bn_u1 = nn.BatchNorm2d(base_size*2) if use_bn else nn.Identity()
        self.bn_u0 = nn.BatchNorm2d(base_size) if use_bn else nn.Identity()

        # ---------- Output ----------
        self.out = nn.Conv2d(base_size, num_classes, kernel_size=1)

    def _upsample_to(self, x, ref):
        """Nearest-neighbour upsample ``x`` to the spatial size of ``ref``."""
        target = tuple(ref.shape[-2:])
        if target == (2 * x.shape[-2], 2 * x.shape[-1]):
            return self.upsample(x)  # exact 2x: the fixed-factor path
        # MaxPool2d floors odd sizes, so a plain 2x upsample would be one
        # row/column short and torch.cat with the skip tensor would fail.
        return F.interpolate(x, size=target, mode='nearest')

    def forward(self, x):
        # Encoder
        c0 = self.relu(self.bn_d0(self.conv_d0(x)))
        c1 = self.relu(self.bn_d1(self.conv_d1(self.max_pool(c0))))
        c2 = self.relu(self.bn_d2(self.conv_d2(self.max_pool(c1))))

        # Decoder
        u1 = self.relu(self.bn_up1(self.upconv1(self._upsample_to(c2, c1))))
        u1 = self.relu(self.bn_u1(self.conv_u1(torch.cat([c1, u1], dim=1))))

        u0 = self.relu(self.bn_up0(self.upconv0(self._upsample_to(u1, c0))))
        u0 = self.relu(self.bn_u0(self.conv_u0(torch.cat([c0, u0], dim=1))))

        # Output
        out = self.out(u0)

        if self.use_sigmoid:
            out = torch.sigmoid(out)
        elif self.use_softmax:
            n, c, h, w = out.shape
            out = out.view(n, c, -1)          # flatten spatial dims
            out = torch.softmax(out, dim=2)   # softmax over spatial positions
            out = out.view(n, c, h, w)        # reshape back

        return out


In [7]:
def model_train(model, train_loader, val_loader, epochs=50, lr=1e-3, restore_best=False):
    """
    Train ``model`` with Adam and MSE loss, recording per-epoch losses.

    The model is moved to the module-level ``device`` and trained in place.

    Args:
        model (nn.Module): network whose outputs have the same shape as the targets.
        train_loader (DataLoader): yields ``(X, y)`` training batches.
        val_loader (DataLoader): yields ``(X, y)`` validation batches.
        epochs (int): number of passes over ``train_loader``.
        lr (float): Adam learning rate.
        restore_best (bool): if True, reload the weights from the epoch with
            the lowest validation loss before returning.

    Returns:
        tuple: ``(model, train_losses, val_losses)``. Each loss list has one
        entry per epoch: the MSE averaged per sample over the whole loader
        (NaN if the loader yielded no samples).
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    best_val = float("inf")
    best_state = None

    train_losses_per_epoch = []
    val_losses_per_epoch = []

    for _ in range(epochs):
        # ---------- TRAIN ----------
        model.train()
        total, count = 0.0, 0
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.float().to(device)
            y_batch = y_batch.float().to(device)

            optimizer.zero_grad()
            batch_loss = loss_fn(model(X_batch), y_batch)
            batch_loss.backward()
            optimizer.step()

            # MSELoss returns the batch mean; weight it by batch size so a
            # short final batch does not skew the epoch average.
            total += batch_loss.item() * X_batch.size(0)
            count += X_batch.size(0)
        train_losses_per_epoch.append(total / count if count else float("nan"))

        # ---------- VALIDATION ----------
        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch = X_batch.float().to(device)
                y_batch = y_batch.float().to(device)
                total += loss_fn(model(X_batch), y_batch).item() * X_batch.size(0)
                count += X_batch.size(0)
        avg_val_loss = total / count if count else float("nan")
        val_losses_per_epoch.append(avg_val_loss)

        # Snapshot weights on improvement (NaN compares False and never wins).
        if restore_best and avg_val_loss < best_val:
            best_val = avg_val_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, train_losses_per_epoch, val_losses_per_epoch


In [8]:
# Build the keypoint-heatmap dataset and a reproducible 90/10 train/validation split.
box_size = 5
gaussian_box = make_gaussian_box(box_size=box_size, sigma=1.0, mu=0.0)
dataset = HeatmapDataset(path='cones_dataset', box_size=box_size, img_size=(64, 48), gaussian_box=gaussian_box)
assert len(dataset) > 0, "no samples were loaded from 'cones_dataset'"

# Seeded generator so the split is identical on every Restart & Run All.
n_train = int(0.9 * len(dataset))
train, test = random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(42),
)

# Shuffle only the training data; validation order does not affect the loss.
train_loader = DataLoader(train, batch_size=64, shuffle=True)
val_loader = DataLoader(test, batch_size=64, shuffle=False)

In [9]:
# Seed before construction so both baselines start from reproducible initial weights.
torch.manual_seed(0)

# Baselines: no BatchNorm and raw (un-activated) outputs, trained against the heatmaps with MSE.
ResNet_base = ResNet(base_size=8, num_classes=7, use_bn=False, use_sigmoid=False, use_softmax=False)
UNet_base = UNet(base_size=8, num_classes=7, use_bn=False, use_sigmoid=False, use_softmax=False)
models_to_test = [
    ("ResNet_base", ResNet_base),
    ("UNet_base", UNet_base),
]

In [None]:
# Train every candidate and keep its per-epoch loss curves for the comparison plot.
EPOCHS = 100

all_train_losses = {}
all_val_losses = {}
trained_models = {}

for name, model in models_to_test:
    print(f"Training {name}...")
    trained_model, train_losses, val_losses = model_train(model, train_loader, val_loader, epochs=EPOCHS)
    trained_models[name] = trained_model
    all_train_losses[name] = train_losses
    all_val_losses[name] = val_losses
    print(f"  final train loss {train_losses[-1]:.6f} | final val loss {val_losses[-1]:.6f}")


Training ResNet_base...


In [None]:
# Plot loss curves, excluding the first epoch (its loss is typically much larger
# and would compress the y-axis). X-values keep 1-based epoch numbers.
fig, ax = plt.subplots(figsize=(10, 6))

for name, _ in models_to_test:
    val_curve = all_val_losses[name][1:]
    train_curve = all_train_losses[name][1:]
    (val_line,) = ax.plot(range(2, len(val_curve) + 2), val_curve, label=f"{name} val loss")
    # Same colour per model: solid = validation, dashed = training.
    ax.plot(
        range(2, len(train_curve) + 2),
        train_curve,
        '--',
        color=val_line.get_color(),
        label=f"{name} train loss",
    )

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss Comparison")
ax.legend()
ax.grid(True)
plt.show()