<a href="https://colab.research.google.com/github/KoniHD/hw2/blob/main/notebooks/hw2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

## Clone Project

In [None]:
import os

if not os.path.exists("pyproject.toml"):
    print("Repo doesn't exist yet. Cloning from github ...")
    !git clone --quiet --depth 1 https://github.com/KoniHD/hw2.git
    os.chdir("hw2")

!uv pip install -r --quiet requirements.txt --system
os.chdir("..")
os.kill(os.getpid(), 9)  # Restart kernel to make modules available

## Download Dataset

In [None]:
# Fetch data
!mkdir -p data
!wget -q -P data/ https://s3.amazonaws.com/video.udacity-data.com/topher/2018/May/5aea1b91_train-test-data/train-test-data.zip
!unzip -q -n data/train-test-data.zip -d data

## Import libraries

In [None]:
import os
import matplotlib.pyplot as plt
import torch
import pandas as pd
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

from data.custom_transforms import (
    Rescale,
    RandomCrop,
    Normalize,
    ToTensor,
)

from data.facial_keypoints_dataset import FacialKeypointsDataset

from models.simple_cnn import Simple_CNN
from keypoint_task import KeypointDetection

## Set Hyperparameters

In [None]:
# Central experiment configuration; every cell below reads its settings from here.
config = dict(
    # --- data ---
    batch_size=16,
    img_size=224,  # side length of the random crop fed to the network
    # --- model ---
    out_dim=136,
    activation="relu",
    dropout_rate=0.3,
    batch_norm=True,
    # --- training ---
    lr=4e-3,
    max_epochs=30,
    criterion="mse",
    random_seed=42,
    patience=5,
    optimizer="adam",
)

## Load and Visualize Data

In [None]:
seed_everything(config["random_seed"], workers=True)  # Try to create deterministic results

# Transform pipeline — order matters: rescale first, then take the smaller random crop.
data_transform = transforms.Compose(
    [Rescale(250), RandomCrop(config["img_size"]), Normalize(), ToTensor()]
)
# NOTE(review): the test split reuses RandomCrop, so validation inputs (and val_loss)
# change from epoch to epoch. A deterministic crop would give a stabler metric — TODO.

training_keypoints_csv_path = os.path.join("data", "training_frames_keypoints.csv")
training_data_dir = os.path.join("data", "training")
test_keypoints_csv_path = os.path.join("data", "test_frames_keypoints.csv")
test_data_dir = os.path.join("data", "test")

# Training data: shuffled batches.
transformed_dataset = FacialKeypointsDataset(
    csv_file=training_keypoints_csv_path,
    root_dir=training_data_dir,
    transform=data_transform,
)
train_loader = DataLoader(
    transformed_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=4
)

# Test data doubles as the validation loader later on; evaluation does not need
# shuffling, and Lightning warns when a val/test loader shuffles.
test_dataset = FacialKeypointsDataset(
    csv_file=test_keypoints_csv_path, root_dir=test_data_dir, transform=data_transform
)
test_loader = DataLoader(
    test_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=4
)

# Show the first test image with its ground-truth keypoints mapped back to pixels.
# Assumes Normalize puts keypoints in [-1, 1] relative to the image centre — TODO confirm.
sample = next(iter(test_loader))
image = sample["image"][0]
keypoints = sample["keypoints"][0]
_, h, w = image.shape

fig, ax = plt.subplots()
ax.imshow(image.numpy().transpose(1, 2, 0), cmap="gray")
ax.scatter(
    keypoints[:, 0] * (w / 2) + (w / 2),
    keypoints[:, 1] * (h / 2) + (h / 2),
    c="r",
    s=20,
)
ax.set_title("Test sample with ground-truth keypoints")
ax.axis("off")
plt.show()
print(f"Image min/max:   {image.min():.4f} / {image.max():.4f}")

# Data Exploration & Sanity Checks

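Inspect basic dataset characteristics (shapes and value ranges), then sanity-check the training pipeline by **overfitting a single batch**: a correctly wired model should drive the training loss close to zero.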

In [None]:
print("===Metrics of first batch===")
batch = next(iter(train_loader))
images, keypoints = batch["image"], batch["keypoints"]

img_min, img_max = images.min(), images.max()
kpt_min, kpt_max = keypoints.min(), keypoints.max()
# Compare after rounding to one decimal so tiny numerical overshoot still counts as in range.
images_in_range = bool(
    img_min.round(decimals=1) >= 0 and img_max.round(decimals=1) <= 1
)
keypoints_in_range = bool(
    kpt_min.round(decimals=1) >= -1 and kpt_max.round(decimals=1) <= 1
)

print(f"Image shape:\t\t{images.shape}")
print(f"Keypoints shape:\t{keypoints.shape}")
print(
    f"Image min/max:\t\t{img_min:.4f} / {img_max:.4f}\t\twithin [0, 1]: {images_in_range}"
)
print(
    f"Keypoints min/max:\t{kpt_min:.4f} / {kpt_max:.4f}\twithin [-1, 1]: {keypoints_in_range}"
)

## Part 1: Direct Coordinate Regression

### Overfitting

In [None]:
default_exp_dir = "exp/simple_cnn/"

# Model without dropout / batch norm, so nothing keeps it from memorising one batch.
simple_cnn = Simple_CNN(
    out_dim=config["out_dim"],
    activation=config["activation"],
    dropout=0.0,
    batch_norm=False,
)

# Lightning wrapper.
# NOTE(review): `drouput` looks like a typo for `dropout`; it is kept unchanged because
# the KeypointDetection signature is not visible here — confirm in keypoint_task.py.
keypoint_task = KeypointDetection(
    model=simple_cnn,
    lr=config["lr"],
    criterion=config["criterion"],
    patience=config["patience"],
    optimizer=config["optimizer"],
    activation=config["activation"],
    drouput=0.0,
    batch_norm=False,
)

# Train on exactly the batch inspected above (`batch`). Letting `overfit_batches` pull
# from `train_loader` would pick its own batch (re-augmented by RandomCrop each epoch),
# which generally differs from the `images`/`keypoints` visualised in the next cell.
overfit_samples = [
    {key: value[idx] for key, value in batch.items()}
    for idx in range(len(batch["image"]))
]
overfit_loader = DataLoader(
    overfit_samples, batch_size=len(overfit_samples), shuffle=False
)

trainer = Trainer(
    max_epochs=200,
    accelerator="auto",
    deterministic="warn",
    logger=False,
    default_root_dir=default_exp_dir,
    detect_anomaly=True,
    overfit_batches=1,
    enable_autolog_hparams=False,
    enable_checkpointing=False,
)
trainer.fit(keypoint_task, train_dataloaders=overfit_loader)

metrics = trainer.callback_metrics
print(f"\n\n=============\nFinal train loss: {metrics['train_loss']:.4f}")

Visualize the overfitting results: predicted keypoints (red) versus ground truth (green).

In [None]:
# Predict on the `images` batch and overlay predictions (red) on ground truth (green).
keypoint_task.eval()
with torch.inference_mode():
    outputs = keypoint_task(images.to(keypoint_task.device))

# out_dim packs an (x, y) pair per keypoint.
n_keypoints = config["out_dim"] // 2
outputs = outputs.reshape(-1, n_keypoints, 2).cpu()
images_cpu = images.cpu()
keypoints_cpu = keypoints.cpu()

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for i, ax in enumerate(axes.flat):
    ax.axis("off")
    if i >= len(images_cpu):  # batch smaller than the grid: leave the panel empty
        continue
    _, h, w = images_cpu[i].shape
    ax.imshow(images_cpu[i].numpy().transpose(1, 2, 0), cmap="gray")
    # Map coordinates from the normalised range back to pixels.
    ax.scatter(
        outputs[i, :, 0] * (w / 2) + (w / 2),
        outputs[i, :, 1] * (h / 2) + (h / 2),
        c="r",
        s=10,
    )
    ax.scatter(
        keypoints_cpu[i, :, 0] * (w / 2) + (w / 2),
        keypoints_cpu[i, :, 1] * (h / 2) + (h / 2),
        c="g",
        s=10,
    )
fig.suptitle("Red=Predicted, Green=Ground Truth")
plt.show()

## Real training loop

In [None]:
version = 0
run_name = f"version_{version}"
# Checkpoints and both loggers share this directory, e.g. exp/simple_cnn/version_0.
run_dir = os.path.join(default_exp_dir, run_name)

simple_cnn = Simple_CNN(
    out_dim=config["out_dim"],
    activation=config["activation"],
    dropout=config["dropout_rate"],
    batch_norm=config["batch_norm"],
)

# Compile for speed; torch.compile keeps the original module in `_orig_mod`.
simple_cnn = torch.compile(simple_cnn, mode="max-autotune")

# NOTE(review): `drouput` looks like a typo for `dropout`; confirm against keypoint_task.py.
keypoint_task = KeypointDetection(
    model=simple_cnn,
    lr=config["lr"],
    criterion=config["criterion"],
    patience=config["patience"],
    optimizer=config["optimizer"],
    activation=config["activation"],
    drouput=config["dropout_rate"],
    batch_norm=config["batch_norm"],
)

# Keep the best (lowest val_loss) weights plus the last epoch.
checkpoint_callback = ModelCheckpoint(
    dirpath=run_dir,
    filename="simple-cnn",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True,
    save_weights_only=True,
    enable_version_counter=True,
)

earlystopping_callback = EarlyStopping(
    monitor="val_loss", patience=config["patience"], mode="min", verbose=True
)

trainer = Trainer(
    accelerator="auto",
    logger=[
        TensorBoardLogger(
            default_exp_dir,
            name="",
            version=run_name,
            log_graph=True,
            default_hp_metric=False,
        ),
        CSVLogger(default_exp_dir, name="", version=run_name),
    ],
    max_epochs=config["max_epochs"],
    callbacks=[checkpoint_callback, earlystopping_callback],
    deterministic="warn",
    default_root_dir=default_exp_dir,
    num_sanity_val_steps=0,
    enable_checkpointing=True,
    enable_autolog_hparams=False,
)

# NOTE(review): the test split is used as the validation set, so early stopping and
# checkpoint selection are tuned on test data. Hold out part of the training data as a
# validation split if an unbiased test score is needed.
trainer.fit(keypoint_task, train_dataloaders=train_loader, val_dataloaders=test_loader)

best_ckpt = checkpoint_callback.best_model_path
if not best_ckpt:
    raise RuntimeError(
        "No best checkpoint was saved; check that 'val_loss' is logged during validation."
    )
# Reload the best weights into the same (compiled) model object.
keypoint_task = KeypointDetection.load_from_checkpoint(
    best_ckpt, weights_only=True, model=simple_cnn
)

### Visualize Loss Curve

In [None]:
metrics = pd.read_csv(os.path.join(default_exp_dir, f"version_{version}", "metrics.csv"))

# CSVLogger writes train and val values on separate rows (NaN elsewhere). Averaging per
# epoch yields one point per epoch whether the loss was logged per step or per epoch.
train_curve = metrics[["epoch", "train_loss"]].dropna().groupby("epoch")["train_loss"].mean()
val_curve = metrics[["epoch", "val_loss"]].dropna().groupby("epoch")["val_loss"].mean()

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(train_curve.index, train_curve.values, label="Train Loss")
ax.plot(val_curve.index, val_curve.values, label="Val Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel(f"{config['criterion'].upper()} Loss")
ax.set_title(f"Part 1: Simple CNN Training Curve Version {version}")
ax.set_xlim(left=0.0)
ax.set_ylim(bottom=0.0)
ax.legend()
plt.tight_layout()
plt.show()

### Visualize using Tensorboard

In [None]:
# Open the TensorBoard widget on the experiment directory. The TensorBoardLogger in the
# training cell writes its event files under `default_exp_dir`.
%reload_ext tensorboard
%tensorboard --logdir {default_exp_dir}

**Optional:** Push the trained model weights to the Hugging Face Hub for reproducibility. This requires an `HF_TOKEN` Colab secret.

In [None]:
from google.colab import userdata

# Read the token from Colab's secrets store instead of hard-coding it.
hf_token = userdata.get("HF_TOKEN")

# Push the un-compiled module: torch.compile wraps the model and keeps the original in
# `_orig_mod`. Presumably this holds the best checkpoint's weights after the reload in
# the training cell — verify before publishing.
model_to_save = getattr(simple_cnn, "_orig_mod", simple_cnn)
model_to_save.push_to_hub(
    "KoniHD/Simple_CNN",
    config=config,
    commit_message=f"Training run version: {version}",
    private=True,
    token=hf_token,
)

In [None]:
# TODO: Training a simple CNN

In [None]:
# TODO: Visualization of results

## Part 2: Transfer Learning for Keypoint Detection

In [None]:
# TODO: Pretrained ResNet backbone

In [None]:
# TODO: Advanced pretrained models (DINO, MAE, ...)

## Part 3: Heatmap-based Keypoint Detection

In [None]:
# TODO: Heatmap synthesis and training

In [None]:
# TODO: Visualization of heatmap prediction