In [2]:
from ultralytics import YOLO
import numpy as np
import cv2
import torch
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch import nn
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import timm
import torch.optim as optim

In [3]:
# Number of landmarks predicted per image; each landmark is an (x, y) pair,
# so the regression head emits NUM_LANDMARKS * 2 values.
NUM_LANDMARKS = 9

class ViTLandmark(nn.Module):
    """ViT-Small backbone followed by an MLP head that regresses landmark coordinates.

    The timm classifier head is swapped for an identity so the backbone yields
    its feature vector, which is then mapped by a three-layer MLP to
    ``NUM_LANDMARKS * 2`` values squashed into (0, 1) by a sigmoid.

    Args:
        pretrained: forwarded to ``timm.create_model`` to select pretrained
            backbone weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        vit = timm.create_model('vit_small_patch16_224', pretrained=pretrained)
        # Read the feature width off the classifier before discarding it.
        feature_dim = vit.head.in_features
        vit.head = nn.Identity()
        self.backbone = vit

        # Hidden stages share the Linear -> LayerNorm -> ReLU pattern.
        widths = (feature_dim, 2048, 1024)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LayerNorm(fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], NUM_LANDMARKS * 2))
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Return sigmoid-activated head outputs for the input batch ``x``."""
        return torch.sigmoid(self.head(self.backbone(x)))

class ViTDataset(Dataset):
    """Dataset of pre-processed samples stored as individual ``*.pt`` files.

    Every file in ``folder`` is loaded with ``torch.load`` and must contain a
    mapping with an ``"image"`` tensor and a ``"keypoints"`` tensor. Keypoints
    are divided by ``image_size`` so that pixel coordinates on an
    ``image_size``-wide image land in the [0, 1] range.

    Args:
        folder: directory containing the ``*.pt`` sample files.
        image_size: divisor used to normalize keypoint coordinates.
            Defaults to 224, the value previously hard-coded.
    """

    def __init__(self, folder, image_size=224):
        self.folder = Path(folder)
        self.image_size = image_size
        # Sorted so that index -> file mapping is stable across runs.
        self.files = sorted(self.folder.glob("*.pt"))

    def __len__(self):
        """Number of sample files found in the folder."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, normalized_keypoints)`` as float32 tensors."""
        # NOTE(security): torch.load unpickles the file; only point this
        # dataset at sample files you produced yourself.
        # map_location="cpu" keeps samples on the CPU even if they were saved
        # from CUDA tensors, so DataLoader pinning and CPU-only runs work.
        data = torch.load(self.files[idx], map_location="cpu")
        img = data["image"].float()
        keypoints = data["keypoints"].float() / self.image_size

        return img, keypoints


In [4]:
# Prefer the GPU but fall back to CPU so the notebook still runs end to end.
# NOTE: torch.cuda.is_available() can return True for a GPU whose compute
# capability the installed PyTorch build does not support; in that case
# upgrade PyTorch (see https://pytorch.org/get-started/locally/).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

SPLIT_SEED = 42   # fixes the train/val partition across kernel restarts
BATCH_SIZE = 16

model = ViTLandmark(pretrained=True).to(device)

# Materialize the parameter lists: .parameters() returns a one-shot generator
# that the optimizer would otherwise exhaust, leaving empty iterators behind.
b_params = list(model.backbone.parameters())
h_params = list(model.head.parameters())

# Separate learning rates: small for the backbone, larger for the new head.
optimizer = optim.SGD([
    {'params': b_params, 'lr': 1e-5},
    {'params': h_params, 'lr': 1e-3},
], momentum=0.9)

criterion = nn.SmoothL1Loss()
dataset = ViTDataset("../data/training_data/vit")

# Hold out a fixed fraction of the samples for validation.
val_fraction = 0.1
val_len = int(len(dataset) * val_fraction)
train_len = len(dataset) - val_len
train_dataset, val_dataset = random_split(
    dataset,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Pinned host memory only helps host-to-GPU copies.
pin_memory = device == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=0, pin_memory=pin_memory)

# Validation is not shuffled so metrics and the "first batch" used for
# visual inspection are the same on every run.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=0, pin_memory=pin_memory)

NVIDIA GeForce RTX 5080 with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_90 compute_37.
If you want to use the NVIDIA GeForce RTX 5080 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



In [None]:
NUM_EPOCHS = 10
IMAGE_SIZE = 224  # scale used to convert normalized coordinates back to pixels

# Freeze the backbone so only the regression head is updated.
for p in model.backbone.parameters():
    p.requires_grad = False

best_pixel_error = float('inf')
for epoch in range(NUM_EPOCHS):
    # ----------------------------
    # Training pass
    # ----------------------------
    model.train()
    total_loss = 0.0
    train_samples = 0
    for bc, (images, keypoints) in enumerate(train_loader, start=1):
        images = images.to(device)
        keypoints = keypoints.to(device)
        # Flatten (batch, landmarks, 2) targets to match the model output.
        k_flat = keypoints.view(keypoints.size(0), -1)
        optimizer.zero_grad()
        preds = model(images)
        loss = criterion(preds, k_flat)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight by batch size so a smaller final batch is not over-counted.
        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        train_samples += batch_size
        print(f"\rBatch {bc}/{len(train_loader)}, Loss {loss.item()}", end='')

    print()  # terminate the carriage-return progress line
    avg_loss = total_loss / max(train_samples, 1)
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Loss: {avg_loss:.6f}")

    # ----------------------------
    # Validation pass
    # ----------------------------
    model.eval()
    val_loss = 0.0
    total_pixel_error = 0.0
    val_samples = 0
    with torch.no_grad():
        for images, keypoints in val_loader:
            images = images.to(device)
            keypoints = keypoints.to(device)
            k_flat = keypoints.view(keypoints.size(0), -1)

            preds = model(images)
            loss = criterion(preds, k_flat)

            # pixel error for interpretability
            preds_px = preds * IMAGE_SIZE
            k_px = k_flat * IMAGE_SIZE
            pixel_error = torch.mean(torch.abs(preds_px - k_px)).item()

            batch_size = images.size(0)
            val_loss += loss.item() * batch_size
            total_pixel_error += pixel_error * batch_size
            val_samples += batch_size

    if val_samples == 0:
        # A tiny dataset can produce an empty validation split; there is
        # nothing to score or checkpoint on in that case.
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - validation set is empty, skipping evaluation")
        continue

    avg_val_loss = val_loss / val_samples
    avg_pixel_error = total_pixel_error / val_samples
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Val Loss: {avg_val_loss:.6f}, Avg Pixel Error: {avg_pixel_error:.2f}")

    # ----------------------------
    # Checkpoint best model
    # ----------------------------
    if avg_pixel_error < best_pixel_error:
        best_pixel_error = avg_pixel_error
        torch.save(model.state_dict(), "best_vit_landmark.pth")
        print(f"New best model saved with pixel error {best_pixel_error:.2f}")

Batch 7/92, Loss 0.024677449837327003

In [5]:
# Carve a 16-sample subset out of the training split for a quick sanity run.
small_dataset = random_split(train_dataset, [16, len(train_dataset) - 16])[0]
small_loader = DataLoader(small_dataset, shuffle=True, batch_size=4)

# Keep the backbone frozen; only the head should receive gradients.
for p in model.backbone.parameters():
    p.requires_grad_(False)

for epoch in range(20):
    for images, keypoints in small_loader:
        images = images.to(device)
        keypoints = keypoints.to(device)
        optimizer.zero_grad()

        # Anomaly detection wraps the forward and backward pass of each batch.
        with torch.autograd.detect_anomaly():
            preds = model(images)

            # Value-range summary of inputs, targets and predictions,
            # printed before the loss is computed.
            with torch.no_grad():
                print(f"\n=== Diagnostics epoch {epoch} ===")
                for template, tensor in (
                    ("Images: min %.3f, max %.3f, mean %.3f", images),
                    ("Keypoints: min %.3f, max %.3f, mean %.3f", keypoints),
                    ("Preds: min %.3f, max %.3f, mean %.3f", preds),
                ):
                    print(template % (tensor.min().item(), tensor.max().item(), tensor.mean().item()))

            # Targets are flattened per sample to line up with the model output.
            targets = keypoints.view(keypoints.size(0), -1)
            loss = criterion(preds, targets)

            # Backward pass, then sum absolute gradients over every parameter
            # that actually received one.
            loss.backward()
            total_grad = 0
            for param in model.parameters():
                if param.grad is not None:
                    total_grad += param.grad.abs().sum().item()
            print("Loss:", loss.item(), "| Total grad sum:", total_grad)

        # The parameter update happens outside the anomaly-detection context.
        optimizer.step()

  with torch.autograd.detect_anomaly():



=== Diagnostics epoch 0 ===
Images: min -2.118, max 2.082, mean -0.021
Keypoints: min 0.057, max 0.955, mean 0.533
Preds: min 0.285, max 0.627, mean 0.496
Loss: 0.03823692351579666 | Total grad sum: 119.8822019174695

=== Diagnostics epoch 0 ===
Images: min -2.118, max 1.856, mean -0.706
Keypoints: min 0.031, max 0.980, mean 0.527
Preds: min 0.285, max 0.627, mean 0.496
Loss: 0.034270405769348145 | Total grad sum: 109.39575120806694

=== Diagnostics epoch 0 ===
Images: min -2.118, max 2.047, mean -0.378
Keypoints: min 0.046, max 0.939, mean 0.516
Preds: min 0.286, max 0.627, mean 0.496
Loss: 0.03586568683385849 | Total grad sum: 130.356221601367

=== Diagnostics epoch 0 ===
Images: min -2.118, max 1.995, mean -0.256
Keypoints: min 0.062, max 0.952, mean 0.520
Preds: min 0.288, max 0.626, mean 0.496
Loss: 0.038210418075323105 | Total grad sum: 147.0216919630766

=== Diagnostics epoch 1 ===
Images: min -2.118, max 2.013, mean -0.488
Keypoints: min 0.079, max 0.939, mean 0.510
Preds: min

In [6]:
# Per-channel statistics used to undo the normalization applied to the images.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def inspect_data(img_tensor, preds, gt=None):
    """Show one image in an OpenCV window with landmark overlays.

    Args:
        img_tensor: channel-first image tensor normalized with
            IMAGENET_MEAN / IMAGENET_STD.
        preds: iterable of (x, y) pixel coordinates, drawn as filled red
            dots labelled with their index.
        gt: optional iterable of (x, y) pixel coordinates, drawn as blue
            circle outlines.

    Returns:
        False if the user pressed 'q' (all windows are closed in that case),
        True for any other key.
    """
    # Channel-first tensor -> HWC array, then denormalize to [0, 255] uint8.
    rgb = img_tensor.permute(1, 2, 0).cpu().numpy() * IMAGENET_STD + IMAGENET_MEAN
    rgb_u8 = (rgb * 255).clip(0, 255).astype(np.uint8)
    canvas = cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR)  # OpenCV draws and displays BGR

    # Predictions: filled red dot plus the landmark index in green.
    for idx, point in enumerate(preds):
        px, py = (int(v) for v in point)
        cv2.circle(canvas, (px, py), 4, (0, 0, 255), -1)
        cv2.putText(canvas, str(idx), (px + 5, py - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA)

    # Ground truth, when given: blue circle outline.
    if gt is not None:
        for point in gt:
            gx, gy = (int(v) for v in point)
            cv2.circle(canvas, (gx, gy), 4, (255, 0, 0), 1)

    cv2.imshow("Debug", canvas)
    cv2.waitKey(1)
    pressed = cv2.waitKey(0)  # block until a key is pressed
    if pressed != ord('q'):
        return True
    cv2.destroyAllWindows()
    return False

# Visual check: overlay predictions and ground truth on the first image of
# the first validation batch.
model.eval()
with torch.no_grad():
    first_batch = next(iter(val_loader), None)  # None when the loader is empty
    if first_batch is not None:
        images, keypoints = first_batch
        images = images.to(device)
        keypoints = keypoints.to(device)
        preds = model(images)

        # Reshape to (batch, landmark, xy) and scale back to pixel coordinates.
        preds_px = preds.view(-1, NUM_LANDMARKS, 2) * 224
        keypoints_px = keypoints.view(-1, NUM_LANDMARKS, 2) * 224

        inspect_data(
            images[0].cpu(),
            preds_px[0].cpu().numpy(),
            keypoints_px[0].cpu().numpy(),
        )