In [None]:
import os
import torch
from torchvision.datasets import CocoDetection
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import functional as F
from pycocotools.coco import COCO
import torchvision.transforms.v2 as T  # PyTorch >= 2.0 preferred

import torch
import torch.nn as nn
from torchvision.models import resnet18

In [None]:
# Pick the compute device once for the whole notebook: GPU when CUDA is usable, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Data

In [None]:
# Paths
# Root of the COCO download. Export COCO_PATH in the environment to point at your
# copy without editing the notebook; the literal below is only the fallback.
COCO_PATH = os.environ.get("COCO_PATH") or "/path/to/COCO"  # change this
# Image folders and person-keypoint annotation files for the 2017 train/val splits.
IMG_DIR_TRAIN = os.path.join(COCO_PATH, "images/train2017")
IMG_DIR_VAL = os.path.join(COCO_PATH, "images/val2017")
ANN_FILE_TRAIN = os.path.join(COCO_PATH, "annotations/person_keypoints_train2017.json")
ANN_FILE_VAL = os.path.join(COCO_PATH, "annotations/person_keypoints_val2017.json")

# Basic transform (you can expand this later)
class ToTensorTransform:
    """Paired transform: turns the image into a tensor and passes the target through unchanged."""

    def __call__(self, image, target):
        tensor_image = F.to_tensor(image)
        return tensor_image, target

# Custom dataset wrapper to use keypoints
class CocoKeypointsDataset(CocoDetection):
    """COCO dataset whose targets keep only annotations that carry keypoints.

    Each item is ``(image, target)`` where ``target`` is a dict with:

    * ``"image_id"`` -- the COCO image id stored at this index.
    * ``"annotations"`` -- the image's annotation dicts that contain a
      ``"keypoints"`` entry (possibly an empty list).

    Args:
        img_folder: Directory containing the images.
        ann_file: Path to the COCO annotation JSON file.
        transforms: Optional callable ``(image, target) -> (image, target)``
            applied to every sample.
    """

    def __init__(self, img_folder, ann_file, transforms=None):
        # CocoDetection.__init__ already parses ``ann_file`` into ``self.coco``
        # and fills ``self.ids``; constructing a second COCO index here would
        # parse the annotation JSON twice for no benefit.
        super().__init__(img_folder, ann_file)
        # Stored under a private name -- presumably to stay clear of the base
        # class's own ``transforms`` attribute.
        self._transforms = transforms

    def __getitem__(self, idx):
        # The base class returns the image together with every annotation
        # registered for this image id, so no second lookup is needed.
        img, anns = super().__getitem__(idx)
        # Filter for annotations with keypoints.
        # NOTE(review): person-keypoint files appear to include a "keypoints"
        # field even for people with no labelled joints -- confirm whether
        # ``num_keypoints > 0`` is the filter that is actually wanted.
        anns = [ann for ann in anns if 'keypoints' in ann]
        target = {
            "image_id": self.ids[idx],
            "annotations": anns
        }
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target

# Initialize datasets
train_dataset = CocoKeypointsDataset(IMG_DIR_TRAIN, ANN_FILE_TRAIN, transforms=ToTensorTransform())
val_dataset = CocoKeypointsDataset(IMG_DIR_VAL, ANN_FILE_VAL, transforms=ToTensorTransform())

# Optional: split val set for a test set
# The split is seeded so that the same images land in val/test on every run;
# otherwise reported test numbers are not comparable between kernel restarts.
SPLIT_SEED = 42
val_size = int(0.5 * len(val_dataset))
test_size = len(val_dataset) - val_size
val_dataset, test_dataset = random_split(
    val_dataset,
    [val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


def collate_samples(batch):
    """Regroup a list of ``(image, target)`` samples into ``(images, targets)`` tuples.

    Defined as a named top-level function rather than a lambda because a
    lambda cannot be pickled, which DataLoader worker processes require when
    the multiprocessing start method is ``spawn``.
    """
    return tuple(zip(*batch))


# DataLoaders
# NOTE(review): batches produced here are (tuple_of_images, tuple_of_target_dicts),
# while the training/evaluation cells index a batch as a tensor plus
# ``targets['keypoints']`` -- confirm how the two are meant to line up
# (fixed-size images and one keypoint tensor per sample would be needed to stack).
LOADER_KWARGS = dict(batch_size=8, num_workers=4, collate_fn=collate_samples)
train_loader = DataLoader(train_dataset, shuffle=True, **LOADER_KWARGS)
val_loader = DataLoader(val_dataset, shuffle=False, **LOADER_KWARGS)
test_loader = DataLoader(test_dataset, shuffle=False, **LOADER_KWARGS)

# Model

In [None]:
NUM_KEYPOINTS = 17  # default number of keypoints predicted per image


class KeypointModel(nn.Module):
    """ResNet-18 whose final layer regresses two values per keypoint.

    Args:
        num_keypoints: Number of keypoints to predict; the output width is
            ``num_keypoints * 2``. Defaults to ``NUM_KEYPOINTS``.
        pretrained: When True (default), initialise the backbone with
            ImageNet weights; when False, use random initialisation.
    """

    def __init__(self, num_keypoints=NUM_KEYPOINTS, pretrained=True):
        super().__init__()
        # ``pretrained=True`` is deprecated in recent torchvision; ``weights``
        # is its replacement, and "IMAGENET1K_V1" names the checkpoint that
        # ``pretrained=True`` used to load.
        self.backbone = resnet18(weights="IMAGENET1K_V1" if pretrained else None)
        # Swap the classification head for a regression head of the right width.
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_keypoints * 2)

    def forward(self, x):
        """Return the backbone output, shape ``(batch, num_keypoints * 2)``."""
        return self.backbone(x)


In [None]:
# Build the network and place it on the selected device, then create the
# regression loss and the optimiser that updates the network's parameters.
model = KeypointModel()
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
def train_one_epoch(model, dataloader, optimizer, criterion, device=None):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Network to train; it is switched to training mode.
        dataloader: Iterable of ``(imgs, targets)`` batches in which ``imgs``
            is a batched tensor and ``targets['keypoints']`` is a tensor whose
            leading dimension is the batch size.
        optimizer: Optimiser that updates ``model``'s parameters.
        criterion: Loss, called as ``criterion(preds, keypoints)``.
        device: Device each batch is moved to. When omitted, the device of the
            model's first parameter is used (CPU for a parameter-free model),
            so inputs always end up next to the weights.

    Returns:
        float: Average of the per-batch losses, or ``0.0`` when the loader
        yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    total_loss = 0.0
    num_batches = 0
    for imgs, targets in dataloader:
        imgs = imgs.to(device)
        # Flatten each sample's keypoints into one row to match the model output.
        keypoints = targets['keypoints'].reshape(imgs.size(0), -1).to(device)

        preds = model(imgs)
        # Targets built from integer/double data would otherwise clash with
        # the float predictions inside the loss.
        loss = criterion(preds, keypoints.to(preds.dtype))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Counting batches (instead of len(dataloader)) avoids dividing by zero on
    # an empty loader and also works for loaders without a length.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches


In [None]:
def evaluate(model, dataloader, device=None):
    """Collect predictions and ground-truth keypoints for every batch.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        dataloader: Iterable of ``(imgs, targets)`` batches in which ``imgs``
            is a batched tensor and ``targets['keypoints']`` is a tensor whose
            leading dimension is the batch size.
        device: Device each batch is moved to. When omitted, the device of the
            model's first parameter is used (CPU for a parameter-free model).

    Returns:
        tuple: ``(preds_list, gt_list)`` -- two lists with one CPU tensor per
        batch, holding the model outputs and the flattened targets.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    preds_list, gt_list = [], []
    with torch.no_grad():  # inference only: no autograd graph is built
        for imgs, targets in dataloader:
            imgs = imgs.to(device)
            # Flatten each sample's keypoints into one row to match the model output.
            keypoints = targets['keypoints'].reshape(imgs.size(0), -1).to(device)
            preds = model(imgs)

            # Move results to the CPU so they do not accumulate on the accelerator.
            preds_list.append(preds.cpu())
            gt_list.append(keypoints.cpu())
    return preds_list, gt_list


In [None]:
import matplotlib.pyplot as plt

def visualize_prediction(img_tensor, keypoints, pred_keypoints=None):
    """Show an image with ground-truth and, optionally, predicted keypoints.

    Args:
        img_tensor: Image tensor whose first axis is moved last for display
            (i.e. it is permuted with ``(1, 2, 0)``).
        keypoints: Tensor of ground-truth coordinates; it is reshaped to rows
            of two values that are plotted as (x, y).
        pred_keypoints: Optional tensor of predicted coordinates, handled the
            same way and drawn as red crosses.
    """
    # detach() lets tensors that still track gradients (e.g. raw model output)
    # be converted to NumPy; reshape() also accepts non-contiguous input,
    # where view() would raise.
    img = img_tensor.detach().permute(1, 2, 0).cpu().numpy()
    keypoints = keypoints.detach().reshape(-1, 2).cpu().numpy()

    plt.imshow(img)
    plt.scatter(keypoints[:, 0], keypoints[:, 1], c='g', label='Ground Truth')
    if pred_keypoints is not None:
        pred_keypoints = pred_keypoints.detach().reshape(-1, 2).cpu().numpy()
        plt.scatter(pred_keypoints[:, 0], pred_keypoints[:, 1], c='r', marker='x', label='Predicted')
    plt.legend()
    plt.show()
