## Libraries

In [2]:
# Cell 1: Imports
import os
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet18
from PIL import Image
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.utils.annotations as foua
from fiftyone import ViewField as F


## Setup

In [3]:
# Cell 2: Paths
# Every location is derived from one dataset root, so relocating the data only
# requires editing `data_root`.
data_root = "./data/coco"
ann_file = os.path.join(data_root, "annotations", "person_keypoints_train2017.json")
img_dir = os.path.join(data_root, "train2017")

# Fail fast, before any heavy work, if the annotations are not where we expect.
assert os.path.exists(ann_file), "Annotation file not found"


## Data

In [4]:
# Cell 3: Load Annotations
# JSON text is UTF-8; pass the encoding explicitly so decoding does not depend
# on the platform's locale default (which is not UTF-8 on every system).
with open(ann_file, encoding="utf-8") as f:
    coco = json.load(f)

# Preview the first records. The file is a flat list with one dict per
# annotated person ('joints', 'joints_vis', 'image', 'scale', 'center').
print(coco[:10])

# Build record-index → file_name mapping.
# Several records can point at the same image (one record per person), so the
# keys are positions in `coco`, not unique image ids.
id_to_img = {idx: img['image'] for idx, img in enumerate(coco)}


[{'joints_vis': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'joints': [[620.0, 394.0], [616.0, 269.0], [573.0, 185.0], [647.0, 188.0], [661.0, 221.0], [656.0, 231.0], [610.0, 187.0], [647.0, 176.0], [637.0201, 189.8183], [695.9799, 108.1817], [606.0, 217.0], [553.0, 161.0], [601.0, 167.0], [692.0, 185.0], [693.0, 240.0], [688.0, 313.0]], 'image': '015601864.jpg', 'scale': 3.021046, 'center': [594.0, 257.0]}, {'joints_vis': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'joints': [[895.0, 293.0], [910.0, 279.0], [945.0, 223.0], [1012.0, 218.0], [961.0, 315.0], [960.0, 403.0], [979.0, 221.0], [906.0, 190.0], [912.4915, 190.6586], [830.5085, 182.3414], [871.0, 304.0], [883.0, 229.0], [888.0, 174.0], [924.0, 206.0], [1013.0, 203.0], [955.0, 263.0]], 'image': '015601864.jpg', 'scale': 2.472117, 'center': [952.0, 222.0]}, {'joints_vis': [0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'joints': [[-1.0, -1.0], [-1.0, -1.0], [806.0, 543.0], [720.0, 593.0], [-1.0, -1.0], [-1.0, -1.0],

In [None]:
# Cell 4: Prepare Dataset List (COCO-style format)
# Keeps one entry per annotated person whose image can actually be found on
# disk and who has enough visible joints to be a useful regression target.
FILE_NAME_LEN = 12 + 4     # 12-character zero-padded stem + 4-character extension
MIN_VISIBLE_JOINTS = 5

data = []

for ann in coco:
    if sum(ann['joints_vis']) < MIN_VISIBLE_JOINTS:
        continue

    # Left-pad short file names with zeros. A local copy is used so the records
    # in `coco` are not modified as a side effect of running this cell.
    file_name = ann['image'].rjust(FILE_NAME_LEN, '0')

    # The Dataset opens os.path.join(img_dir, sample['image']), so the existence
    # check has to be made against that same location (img_dir already ends in
    # "train2017"). Images unpacked one level deeper (img_dir/train2017/...) are
    # still accepted: the sub-folder is then kept inside the stored relative
    # path so that the later join resolves to the file that was checked here.
    candidates = (file_name, os.path.join("train2017", file_name))
    rel_path = next(
        (c for c in candidates if os.path.exists(os.path.join(img_dir, c))),
        None,
    )
    if rel_path is None:
        continue

    sample = {
        'joints': ann['joints'],
        'joints_vis': ann['joints_vis'],
        'image': rel_path,
        'scale': ann['scale'],
        'center': ann['center']
    }
    data.append(sample)

print(f"Loaded {len(data)} annotated person keypoint samples in COCO-style format.")


./data/coco\train2017\train2017\000015601864.jpg
./data/coco\train2017\train2017\000015601864.jpg
./data/coco\train2017\train2017\000015599452.jpg
./data/coco\train2017\train2017\000015599452.jpg
./data/coco\train2017\train2017\000015599452.jpg
./data/coco\train2017\train2017\000086617615.jpg
./data/coco\train2017\train2017\000086617615.jpg
./data/coco\train2017\train2017\000060111501.jpg
./data/coco\train2017\train2017\000070807258.jpg
./data/coco\train2017\train2017\000070807258.jpg
./data/coco\train2017\train2017\000002058449.jpg
./data/coco\train2017\train2017\000021233911.jpg
./data/coco\train2017\train2017\000021233911.jpg
./data/coco\train2017\train2017\000018182497.jpg
./data/coco\train2017\train2017\000018182497.jpg
./data/coco\train2017\train2017\000018340451.jpg
./data/coco\train2017\train2017\000018340451.jpg
./data/coco\train2017\train2017\000030424224.jpg
./data/coco\train2017\train2017\000030424224.jpg
./data/coco\train2017\train2017\000043194502.jpg
./data/coco\train201

In [None]:
# Cell 5: Dataset
class CocoKeypointsDataset(Dataset):
    """Serves (image, flat keypoint vector) pairs for per-person records.

    Each record is a dict with at least 'image' (a path relative to the
    notebook-level ``img_dir``) and 'joints' (a list of [x, y] pixel
    coordinates in the original image).

    Args:
        data: Sequence of record dicts.
        transform: Callable applied to the RGB PIL image. When omitted (or
            falsy) the image is resized to 256x256 and converted to a tensor.

    Note:
        Joint coordinates are always rescaled to a 256x256 frame, whatever
        ``transform`` is supplied. Every joint is rescaled the same way,
        including the [-1, -1] placeholders the annotations use for joints
        flagged as not visible.
    """

    def __init__(self, data, transform=None):
        if not transform:
            transform = T.Compose([T.Resize((256, 256)), T.ToTensor()])
        self.data = data
        self.transform = transform

    def __len__(self):
        """Number of person records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (transformed image, float32 tensor of x0, y0, x1, y1, ...)."""
        record = self.data[idx]
        picture = Image.open(os.path.join(img_dir, record['image'])).convert("RGB")

        # Per-axis factors that map original pixel coordinates into 256x256.
        width, height = picture.size
        to_network_px = np.array([256 / width, 256 / height])

        joints_px = np.array(record['joints']) * to_network_px
        target = joints_px.astype(np.float32).reshape(-1)

        return self.transform(picture), torch.tensor(target)


In [None]:
# Cell 6: Dataloaders
# Fixed, contiguous split: the first 5000 samples train, the next 200 validate.
TRAIN_END, VAL_END, BATCH_SIZE = 5000, 5200, 32

train_data, val_data = data[:TRAIN_END], data[TRAIN_END:VAL_END]
train_ds, val_ds = CocoKeypointsDataset(train_data), CocoKeypointsDataset(val_data)

# Only the training loader shuffles; validation keeps the dataset order.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)


## Model

In [None]:
# Cell 7: Model
class KeypointNet(nn.Module):
    """ResNet-18 with its final layer replaced by a keypoint regression head.

    Args:
        num_keypoints: Number of joints predicted per person. Defaults to 16,
            the number of joints each annotation record in this notebook
            carries, so the output width (2 * num_keypoints) matches the
            flattened (x, y) targets produced by the dataset. The previous
            hard-coded width of 34 (17 joints) could not be compared against
            those 32-value targets by the loss.
    """

    def __init__(self, num_keypoints=16):
        super().__init__()
        base = resnet18(weights=None)  # randomly initialised backbone
        # One (x, y) pair per joint.
        base.fc = nn.Linear(base.fc.in_features, 2 * num_keypoints)
        self.num_keypoints = num_keypoints
        self.backbone = base

    def forward(self, x):
        """Return a (batch, 2 * num_keypoints) tensor of x0, y0, x1, y1, ..."""
        return self.backbone(x)


## Train

In [None]:
# Cell 8: Training Setup
# Pick the GPU when one is available and fall back to the CPU otherwise, so
# that building the model does not raise on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = KeypointNet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [None]:
# Cell 9: Training Loop
def train_epoch(dl):
    """Run one optimisation pass over ``dl`` and return the mean batch loss.

    Operates on the notebook-level ``model``, ``loss_fn`` and ``optimizer``.

    Args:
        dl: Sized iterable yielding (images, targets) tensor batches.

    Returns:
        The sum of the per-batch loss values divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``dl`` contains no batches.
    """
    model.train()
    # Move batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    total_loss = 0.0
    for imgs, targets in tqdm(dl):
        imgs, targets = imgs.to(device), targets.to(device)
        preds = model(imgs)
        loss = loss_fn(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dl)

def eval_epoch(dl):
    """Evaluate the notebook-level ``model`` on ``dl`` without updating it.

    Args:
        dl: Sized iterable yielding (images, targets) tensor batches.

    Returns:
        The sum of the per-batch loss values divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``dl`` contains no batches.
    """
    model.eval()
    # Move batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    total_loss = 0.0
    with torch.no_grad():
        for imgs, targets in tqdm(dl):
            imgs, targets = imgs.to(device), targets.to(device)
            preds = model(imgs)
            loss = loss_fn(preds, targets)
            total_loss += loss.item()
    return total_loss / len(dl)

# One training pass followed by one validation pass per epoch.
NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
    train_loss, val_loss = train_epoch(train_dl), eval_epoch(val_dl)
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")


## Visualize

In [None]:
# Cell 10: Visualize Prediction
def visualize(img, pred_keypoints):
    """Draw keypoints on an image tensor and display the result.

    Args:
        img: Channel-first image tensor (it is permuted to channel-last here).
            Values are multiplied by 255, so they are expected to lie in
            [0, 1]; anything outside that range is clipped.
        pred_keypoints: Tensor whose elements are alternating x, y pixel
            coordinates, expressed in the coordinate frame of ``img``.
    """
    hwc = img.detach().permute(1, 2, 0).cpu().numpy()
    # Clip before casting so out-of-range values cannot wrap around in uint8.
    # The permuted view is not C-contiguous, and NumPy keeps that memory order
    # through the arithmetic and the cast; OpenCV's drawing functions need a
    # C-contiguous buffer, so make an explicit contiguous copy to draw on.
    canvas = np.ascontiguousarray((hwc * 255).clip(0, 255).astype(np.uint8))

    points = pred_keypoints.detach().cpu().reshape(-1, 2).tolist()
    for x, y in points:
        cv2.circle(canvas, (int(x), int(y)), 3, (0, 255, 0), -1)

    plt.imshow(canvas)
    plt.axis('off')
    plt.show()

# Predict keypoints for the first validation sample and overlay them.
img, _ = val_ds[0]
model.eval()
# Send the input to whichever device the model is on rather than assuming CUDA.
model_device = next(model.parameters()).device
with torch.no_grad():
    # Add a batch dimension for the forward pass, then drop it again.
    pred = model(img.unsqueeze(0).to(model_device)).squeeze(0)
visualize(img, pred)
