# TODO
- GT heatmaps are not correctly calculated 

# Setup

In [61]:
import os

import torch
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader, random_split
# from torchvision.transforms import functional as F
from torch.nn import functional as F
import torchvision.transforms.functional as VF
from pycocotools.coco import COCO
import torchvision.transforms.v2 as T
import torch.nn as nn
from torchvision.models import resnet18
from torchvision import transforms

import matplotlib.pyplot as plt
from itertools import cycle

from tqdm.notebook import tqdm

import random
import numpy as np

import wandb

import math

In [62]:
wandb.login()

True

In [63]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# reduce cpu contention
torch.set_num_threads(1)
NUM_WORKERS = 6  # adjust based on CPU cores

cuda


In [64]:
COCO_PATH = "../../data/coco/"  # change this
IMG_DIR_TRAIN = os.path.join(COCO_PATH, "images/train2017")
IMG_DIR_VAL = os.path.join(COCO_PATH, "images/val2017")
ANN_FILE_TRAIN = os.path.join(COCO_PATH, "annotations/person_keypoints_train2017.json")
ANN_FILE_VAL = os.path.join(COCO_PATH, "annotations/person_keypoints_val2017.json")

REMOVE_IMAGES_WITHOUT_KEYPOINTS = True
VAL_SPLIT = 0.5
TEST_VAL_TRAIN_PERCENT = (0.1, 0.1, 0.1)
BATCH_SIZE = 128
DATA_AUGMENTATION = False

NUM_KEYPOINTS = 17

KEYPOINT_NAMES = [
    'nose',
    'left_eye',
    'right_eye',
    'left_ear',
    'right_ear',
    'left_shoulder',
    'right_shoulder',
    'left_elbow',
    'right_elbow',
    'left_wrist',
    'right_wrist',
    'left_hip',
    'right_hip',
    'left_knee',
    'right_knee',
    'left_ankle',
    'right_ankle'
]


HEATMAP_OUTPUT_STRIDE = 4
HEATMAP_SIZE = (256 // HEATMAP_OUTPUT_STRIDE, 256 // HEATMAP_OUTPUT_STRIDE)  # (64, 64)
SIGMA = 2  # Gaussian spread


# Data

In [65]:
class CustomTransform:
    def __init__(self, size=(256, 256), augmentation=False):
        self.size = size
        if augmentation:
            self.transform = T.Compose([
                T.RandomHorizontalFlip(p=0.5),
                T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
                T.Resize(size),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        else:
            self.transform = T.Compose([
                T.Resize(size),
                T.ToTensor(),
                # T.ToDtype(torch.float32, scale=True),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])


    def __call__(self, image, target):
        orig_w, orig_h = image.size
        image = self.transform(image)
        
        # Scale keypoints to new image size
        annotations = []
        for ann in target["annotations"]:
            kps = np.array(ann['keypoints']).reshape(-1, 3)
            kps[:, 0] = kps[:, 0] * (self.size[0] / orig_w)
            kps[:, 1] = kps[:, 1] * (self.size[1] / orig_h)
            ann['keypoints'] = kps.ravel().tolist()
            annotations.append(ann)
        
        return image, {
            "image_id": target["image_id"],
            "annotations": annotations,
            "orig_size": target["orig_size"]
        }


In [None]:
class CocoKeypointsDataset(CocoDetection):
    def __init__(self, img_folder, ann_file, transforms=None):
        super().__init__(img_folder, ann_file)
        self.coco = COCO(ann_file)
        self._transforms = transforms
        self.filter_without_keypoints = REMOVE_IMAGES_WITHOUT_KEYPOINTS
        self.heatmap_size = HEATMAP_SIZE
        self.output_stride = HEATMAP_OUTPUT_STRIDE
        self.sigma = SIGMA

        if self.filter_without_keypoints:
            original_ids = list(self.ids)
            self.ids = []
            for img_id in original_ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                anns = self.coco.loadAnns(ann_ids)
                anns_with_kp = [ann for ann in anns if 'keypoints' in ann and np.any(np.array(ann['keypoints']) != 0)]
                if len(anns_with_kp) > 0:
                    self.ids.append(img_id)

        # cache annotations for each image
        self.anns_per_image = {}
        for img_id in self.ids:
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            anns = self.coco.loadAnns(ann_ids)
            self.anns_per_image[img_id] = anns

    def draw_gaussian(self, heatmap, center_x, center_y):
        """Draw a 2D Gaussian on heatmap channel (vectorized, correct region and shape)"""
        height, width = heatmap.shape
        sigma = self.sigma
        
        # bounds
        x0 = int(max(0, center_x - 3 * sigma))
        y0 = int(max(0, center_y - 3 * sigma))
        x1 = int(min(width, center_x + 3 * sigma + 1))
        y1 = int(min(height, center_y + 3 * sigma + 1))
        if x1 <= x0 or y1 <= y0:
            return
        xs = np.arange(x0, x1)
        ys = np.arange(y0, y1)
        xx, yy = np.meshgrid(xs, ys, indexing='xy')  # xx,yy shape: (y1-y0, x1-x0)
        d2 = (xx - center_x) ** 2 + (yy - center_y) ** 2
        exponent = d2 / (2 * sigma ** 2)
        mask = exponent <= 4.6052
        g = np.exp(-exponent) * mask
        patch = heatmap[y0:y1, x0:x1]
        np.maximum(patch, g, out=patch)
        heatmap[y0:y1, x0:x1] = patch

    def __getitem__(self, idx):
        img, _ = super().__getitem__(idx)
        orig_w, orig_h = img.size
        anns = self.anns_per_image[self.ids[idx]]

        # Apply transforms first to get scaled keypoints
        if self._transforms:
            img, target = self._transforms(img, {
                "image_id": self.ids[idx],
                "annotations": anns,
                "orig_size": (orig_w, orig_h)
            })
            anns = target["annotations"]

        # Ensure correct order: (width, height)
        if hasattr(self._transforms, 'size'):
            resized_w, resized_h = self._transforms.size
        else:
            resized_w, resized_h = (256, 256)
        heatmap_h, heatmap_w = self.heatmap_size
        
        # print(f"resized_w={resized_w}, resized_h={resized_h}, heatmap_w={heatmap_w}, heatmap_h={heatmap_h}")
        
        # Create heatmap tensor: [NUM_KEYPOINTS, H, W]
        heatmap = np.zeros((NUM_KEYPOINTS, heatmap_h, heatmap_w), dtype=np.float32)
        for ann in anns:
            kps = np.array(ann['keypoints']).reshape(-1, 3)
            for kp_idx, (x, y, v) in enumerate(kps):
                if v > 0:  # Only visible keypoints
                    # Correct scaling: x to width, y to height
                    x_hm = x * (heatmap_w / resized_w)
                    y_hm = y * (heatmap_h / resized_h)
                    self.draw_gaussian(heatmap[kp_idx], x_hm, y_hm)
        return img, torch.from_numpy(heatmap)

In [67]:
train_dataset = CocoKeypointsDataset(
    IMG_DIR_TRAIN, 
    ANN_FILE_TRAIN, 
    transforms=CustomTransform(augmentation=DATA_AUGMENTATION), 
)
val_dataset = CocoKeypointsDataset(
    IMG_DIR_VAL, 
    ANN_FILE_VAL, 
    transforms=CustomTransform(augmentation=False), 
)

val_size = int(VAL_SPLIT * len(val_dataset))
test_size = len(val_dataset) - val_size
val_dataset, test_dataset = random_split(val_dataset, [val_size, test_size])


loading annotations into memory...
Done (t=8.08s)
creating index...
index created!
loading annotations into memory...
Done (t=7.02s)
creating index...
index created!
loading annotations into memory...
Done (t=0.25s)
creating index...
index created!
loading annotations into memory...
Done (t=0.24s)
creating index...
index created!


In [68]:
print("train dataset size:", len(train_dataset))
print("val dataset size:", len(val_dataset))
print("test dataset size:", len(test_dataset))

train dataset size: 56599
val dataset size: 1173
test dataset size: 1173


In [69]:
subset_len_train = int(TEST_VAL_TRAIN_PERCENT[0] * len(train_dataset))
subset_len_val = int(TEST_VAL_TRAIN_PERCENT[1] * len(val_dataset))
subset_len_test = int(TEST_VAL_TRAIN_PERCENT[2] * len(test_dataset))

train_dataset = torch.utils.data.Subset(train_dataset, range(subset_len_train))
val_dataset = torch.utils.data.Subset(val_dataset, range(subset_len_val))
test_dataset = torch.utils.data.Subset(test_dataset, range(subset_len_test))

In [70]:
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    pin_memory=True,            # faster GPU transfer
    # persistent_workers=True     # maintain worker pool
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    pin_memory=True,
    # persistent_workers=True
)
test_loader = DataLoader(
    test_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    pin_memory=True,
    # persistent_workers=True
)

In [71]:
print("train dataset size:", len(train_dataset))
print("val dataset size:", len(val_dataset))
print("test dataset size:", len(test_dataset))

train dataset size: 5659
val dataset size: 117
test dataset size: 117


## Visualize data

In [72]:
def unnormalize(img_tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return img_tensor * std + mean

def visualize_heatmaps(image, gt_heatmap, pred_heatmap=None, keypoint_idx=0):
    """Visualize input image with ground truth and predicted heatmaps"""
    fig, axes = plt.subplots(1, 3 if pred_heatmap is not None else 2, figsize=(15, 5))
    
    # Original image
    img = unnormalize(image).permute(1, 2, 0).cpu().numpy()
    axes[0].imshow(img)
    axes[0].set_title('Input Image')
    axes[0].axis('off')
    
    # Ground truth heatmap
    gt_hm = gt_heatmap[keypoint_idx].cpu().numpy()
    axes[1].imshow(gt_hm, cmap='jet', alpha=0.5)
    axes[1].set_title(f'GT Heatmap (Keypoint {KEYPOINT_NAMES[keypoint_idx]})')
    axes[1].axis('off')
    
    # Predicted heatmap (if available)
    if pred_heatmap is not None:
        pred_hm = pred_heatmap[keypoint_idx].detach().cpu().numpy()
        axes[2].imshow(pred_hm, cmap='jet', alpha=0.5)
        axes[2].set_title(f'Predicted Heatmap (Keypoint {KEYPOINT_NAMES[keypoint_idx]})')
        axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()


In [None]:
images, gt_heatmaps = next(iter(val_loader))
images = images.to(device)
gt_heatmaps = gt_heatmaps.to(device)

idx = random.randint(0, images.size(0)-1)

for kpidx in range(0, NUM_KEYPOINTS):
    visualize_heatmaps(
        images[idx].cpu(),
        gt_heatmaps[idx].cpu(),
        keypoint_idx = kpidx
    )

# Model

In [74]:
class HeatmapModel(nn.Module):
    def __init__(self):
        super().__init__()
        resnet = resnet18(pretrained=True)
        
        # Remove the last two layers (avgpool and fc)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        
        # Upsampling layers to increase resolution
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, NUM_KEYPOINTS, kernel_size=1)
        )
        
        # Initialize last layer
        nn.init.normal_(self.upsample[-1].weight, std=0.001)
        nn.init.constant_(self.upsample[-1].bias, 0)

    def forward(self, x):
        x = self.backbone(x)  # [B, 512, 8, 8]
        x = self.upsample(x)  # [B, 17, 64, 64]
        return x


# Model training functions

## Loss

In [76]:
def heatmap_loss(pred_heatmaps, gt_heatmaps):
    """MSE loss with emphasis on positive pixels"""
    # Basic MSE loss
    loss = F.mse_loss(pred_heatmaps, gt_heatmaps, reduction='none')
    
    # Increase weight for positive pixels
    pos_mask = gt_heatmaps > 0.1
    loss[pos_mask] *= 3.0
    
    return loss.mean()


# Train

In [79]:
EPOCHS = 100
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4

In [80]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HeatmapModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)


In [None]:
wandb_config = {
    "epochs": EPOCHS,
    "learning_rate": LEARNING_RATE,
    "weight_decay": WEIGHT_DECAY,
    "batch_size": BATCH_SIZE,
    "train_size": subset_len_train,
    "val_size": subset_len_val,
    "test_size": subset_len_test,
    "model_name": "HeatmapResNet",
    "criterion": "mse",
    "optimizer": "Adam",
    "num_keypoints": NUM_KEYPOINTS,
    "remove_images_without_keypoints": REMOVE_IMAGES_WITHOUT_KEYPOINTS,
    "val_split": VAL_SPLIT,
    "test_val_train_percent": TEST_VAL_TRAIN_PERCENT,
    "device": device,
    "data_augmentation": DATA_AUGMENTATION,
    "heatmap_stride": HEATMAP_OUTPUT_STRIDE,
    "heatmap_size": HEATMAP_SIZE,
    "heatmap_sigma": SIGMA,
}

wandb.init(
    entity="fejowo5522-",
    project="NN_Project",
    config=wandb_config,
    group="KeypointDetectionHeatmap"
)

In [None]:
early_stopping = True
patience = 10
best_val_loss = float('inf')
epochs_no_improve = 0

train_losses = []
val_losses = []


for epoch in tqdm(range(EPOCHS)):
    model.train()
    total_loss = 0.0
    
    for images, heatmaps in tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}'):
        images = images.to(device)
        heatmaps = heatmaps.to(device)
        
        # Forward pass
        pred_heatmaps = model(images)
        loss = heatmap_loss(pred_heatmaps, heatmaps)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

    # Validation loss calculation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_images, val_heatmaps in val_loader:
            val_images = val_images.to(device)
            val_heatmaps = val_heatmaps.to(device)
            val_pred_heatmaps = model(val_images)
            loss = heatmap_loss(val_pred_heatmaps, val_heatmaps)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader)
    avg_train_loss = total_loss / len(train_loader)

    wandb.log({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss
    })

    print(f'Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f} Val Loss: {avg_val_loss:.4f}')


    if early_stopping:
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), "temp_best_model.pth")
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

In [None]:
model.eval()

# Load best model if early stopping was used
if early_stopping and os.path.exists("temp_best_model.pth"):
    model.load_state_dict(torch.load("temp_best_model.pth"))

# Run model on all test data and collect predictions and ground truths
model.eval()
preds_list = []
gt_list = []
with torch.no_grad():
    for images, gt_heatmaps in test_loader:
        images = images.to(device)
        gt_heatmaps = gt_heatmaps.to(device)
        pred_heatmaps = model(images)
        preds_list.append(pred_heatmaps.cpu())
        gt_list.append(gt_heatmaps.cpu())
preds_all = torch.cat(preds_list, dim=0)
gt_all = torch.cat(gt_list, dim=0)
test_loss = heatmap_loss(preds_all, gt_all).item()

wandb.log({
    'test_loss': test_loss
})

wandb.finish()

0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇█████
test_loss,▁
train_loss,█▇▄▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,██▇▅▅▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,73.0
test_loss,2418.70361
train_loss,651.28759
val_loss,3612.5603


In [83]:
torch.save(model.state_dict(), "keypoint_model.pth")
print("Model saved to keypoint_model.pth")

Model saved to keypoint_model.pth


# Visualize

In [None]:
# After training, visualize predictions
model.eval()
with torch.no_grad():
    images, gt_heatmaps = next(iter(val_loader))
    images = images.to(device)
    gt_heatmaps = gt_heatmaps.to(device)
    pred_heatmaps = model(images)
    
    # Select random sample from batch
    idx = random.randint(0, images.size(0)-1)
    for kpidx in range(0, NUM_KEYPOINTS):
        visualize_heatmaps(
            images[idx].cpu(),
            gt_heatmaps[idx].cpu(),
            pred_heatmaps[idx].cpu(),
            keypoint_idx = kpidx
        )