 ## Setup

In [10]:
import os
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms.v2 as T
from torchvision.datasets import CocoDetection
from pycocotools.coco import COCO
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
import random
import wandb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) so data shuffling
# and the randomly initialised head weights are reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included


 ## Paths

In [11]:
# Root of the COCO 2017 dataset. Set the COCO_PATH environment variable to point
# elsewhere instead of editing this cell; the default is the original relative path.
COCO_PATH = os.environ.get("COCO_PATH", "../../data/coco/")
IMG_DIR_TRAIN = os.path.join(COCO_PATH, "images/train2017")
IMG_DIR_VAL = os.path.join(COCO_PATH, "images/val2017")
ANN_FILE_TRAIN = os.path.join(COCO_PATH, "annotations/person_keypoints_train2017.json")
ANN_FILE_VAL = os.path.join(COCO_PATH, "annotations/person_keypoints_val2017.json")



In [12]:
class YOLOKeypointDataset(CocoDetection):
    """COCO person-keypoint dataset that encodes annotations as a YOLO-style grid target.

    Each item is ``(img, target)``:

    * ``img`` is the image after ``transform`` (or the raw PIL image if no transform).
    * ``target`` is a ``(grid_size, grid_size, num_keypoints * 3)`` float tensor.
      For keypoint ``i``, channels ``3*i`` and ``3*i + 1`` hold its x and y
      normalized to [0, 1] by the ORIGINAL image width/height, and channel
      ``3*i + 2`` is a 1.0 presence flag. Values are written into the grid cell
      that contains the point; every other cell stays zero. If two annotations
      put the same keypoint in the same cell, the later one overwrites the earlier.

    Args:
        img_folder: directory containing the COCO images.
        ann_file: path to a COCO keypoint annotation JSON.
        transform: optional callable applied to the image only (not the target).
        grid_size: number of cells along each side of the target grid.
        num_keypoints: keypoints per annotation to encode (17 for COCO persons).
    """

    def __init__(self, img_folder, ann_file, transform=None, grid_size=32, num_keypoints=17):
        # transform is NOT forwarded to CocoDetection so the base class returns the
        # untouched PIL image (whose size is needed below) and we apply it ourselves.
        super().__init__(img_folder, ann_file)
        self.transform = transform
        self.grid_size = grid_size
        self.num_keypoints = num_keypoints
        # Side length the image transform resizes to; callers use it to map
        # normalized coordinates back to pixels of the resized image.
        self.image_size = 256

    def __getitem__(self, index):
        img, anns = super().__getitem__(index)

        # COCO keypoints are in original-image pixels, so normalize by the size
        # before any resize; normalized coords then also hold for the resized image.
        orig_w, orig_h = img.size  # PIL Image.size is (width, height)

        if self.transform:
            img = self.transform(img)

        grid = self.grid_size
        target_tensor = torch.zeros((grid, grid, self.num_keypoints * 3))
        for ann in anns:
            if ann.get("num_keypoints", 0) == 0:
                continue
            kps = ann["keypoints"]  # flat [x1, y1, v1, x2, y2, v2, ...]
            # Never index past the target's channel count if num_keypoints is reduced.
            for i in range(min(len(kps) // 3, self.num_keypoints)):
                x, y, v = kps[3 * i: 3 * i + 3]
                if v <= 0:  # v == 0 means the keypoint is not labeled
                    continue
                x_norm = x / orig_w
                y_norm = y / orig_h
                grid_x = int(x_norm * grid)
                grid_y = int(y_norm * grid)
                if 0 <= grid_x < grid and 0 <= grid_y < grid:
                    target_tensor[grid_y, grid_x, i * 3 + 0] = x_norm
                    target_tensor[grid_y, grid_x, i * 3 + 1] = y_norm
                    target_tensor[grid_y, grid_x, i * 3 + 2] = 1.0  # labeled keypoint present
        return img, target_tensor

# Resize to the fixed 256x256 input the dataset and visualization assume, then turn
# the PIL image into a float32 tensor scaled to [0, 1]. ToImage + ToDtype(scale=True)
# is the transforms.v2 replacement for the deprecated ToTensor / ConvertImageDtype pair.
# NOTE(review): the backbone loads pretrained ImageNet weights, which are normally fed
# mean/std-normalized inputs (T.Normalize). It is not added here because
# visualize_predictions displays the input tensor directly.
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
])

train_dataset = YOLOKeypointDataset(IMG_DIR_TRAIN, ANN_FILE_TRAIN, transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)



loading annotations into memory...
Done (t=10.49s)
creating index...
index created!


In [13]:
class YOLOKeypointNet(nn.Module):
    """ResNet-18 trunk plus a small conv head that predicts (x, y, conf) per keypoint per cell.

    Input is an image batch ``(B, 3, H, W)``. Output is ``(B, grid_size, grid_size,
    num_keypoints * 3)``, with channels ``3*i, 3*i + 1, 3*i + 2`` for keypoint ``i``.

    Args:
        num_keypoints: number of keypoints predicted per cell.
        grid_size: side length of the output grid.
    """

    def __init__(self, num_keypoints=17, grid_size=32):
        super().__init__()
        self.backbone = torchvision.models.resnet18(weights="DEFAULT")
        # forward() stops before the pooling/classifier; fc is replaced by an
        # Identity (no parameters) so the classifier weights are discarded.
        self.backbone.fc = nn.Identity()

        self.head = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, num_keypoints * 3, kernel_size=1)  # (x, y, vis) per keypoint
        )
        self.grid_size = grid_size
        self.num_keypoints = num_keypoints

    def _features(self, x):
        """Run the ResNet trunk through layer4 and return its 512-channel spatial feature map."""
        # Calling backbone(x) would global-average-pool away all spatial information,
        # leaving a 1x1 map that upsamples to the same vector in every grid cell.
        b = self.backbone
        x = b.maxpool(b.relu(b.bn1(b.conv1(x))))
        x = b.layer1(x)
        x = b.layer2(x)
        x = b.layer3(x)
        return b.layer4(x)

    def forward(self, x):
        feats = self._features(x)
        # Resample the coarse feature map to the target grid resolution.
        feats = torch.nn.functional.interpolate(
            feats, size=(self.grid_size, self.grid_size), mode="bilinear", align_corners=False
        )
        out = self.head(feats)
        return out.permute(0, 2, 3, 1)  # [B, G, G, K*3]

In [14]:
def keypoint_loss(pred, target, coord_weight=1.0, noobj_weight=1.0):
    """Sum-squared YOLO-style keypoint loss, averaged over the batch dimension.

    ``pred`` and ``target`` share the layout ``(B, G, G, K * 3)``. For keypoint ``i``,
    channels ``3*i`` and ``3*i + 1`` are x, y and ``3*i + 2`` is the confidence.

    Args:
        pred: model output.
        target: grid target; a confidence > 0 marks a cell that contains the keypoint.
        coord_weight: multiplier on the x/y error, which is only counted where the
            target keypoint is present (YOLO's lambda_coord).
        noobj_weight: multiplier on the confidence error in cells with no keypoint
            (YOLO's lambda_noobj). Useful because empty cells vastly outnumber filled ones.

    The defaults (1.0, 1.0) reproduce the unweighted loss.

    Returns:
        Scalar tensor: the total loss divided by the batch size.
    """
    target_conf = target[..., 2::3]
    mask = target_conf > 0  # cells where the keypoint is present
    pos_loss = ((pred[..., 0::3] - target[..., 0::3]) ** 2 +
                (pred[..., 1::3] - target[..., 1::3]) ** 2)
    conf_sq = (pred[..., 2::3] - target_conf) ** 2
    conf_loss = torch.where(mask, conf_sq, noobj_weight * conf_sq)
    total_loss = coord_weight * (pos_loss * mask).sum() + conf_loss.sum()
    return total_loss / target.size(0)

In [None]:
# Single source of truth for hyperparameters, so the W&B config matches what actually ran.
config = {
    "epochs": 10,
    "grid_size": train_dataset.grid_size,  # model grid must match the dataset's target grid
    "lr": 1e-4,
    "batch_size": train_loader.batch_size,
}

model = YOLOKeypointNet(grid_size=config["grid_size"]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

wandb.init(project="yolo-keypoint", name="resnet18-grid32", config=config)

try:
    for epoch in range(config["epochs"]):
        model.train()
        total_loss = 0.0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for imgs, targets in loop:
            imgs = imgs.to(device)
            targets = targets.to(device)

            preds = model(imgs)
            loss = keypoint_loss(preds, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() forces a device sync; call it once per step and reuse the value.
            loss_value = loss.item()
            total_loss += loss_value
            loop.set_postfix(loss=loss_value)
            wandb.log({"loss": loss_value})

        avg_loss = total_loss / len(train_loader)
        wandb.log({"epoch": epoch + 1, "epoch_avg_loss": avg_loss})
        print(f"Epoch {epoch+1}, avg loss: {avg_loss:.4f}")
finally:
    # Close the run even if training is interrupted, so a re-run starts a clean one.
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mfejowo5522[0m ([33mfejowo5522-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1:   0%|          | 0/7393 [00:00<?, ?it/s]

In [None]:
def visualize_predictions(img, pred, threshold=0.5):
    """Display ``img`` with every predicted keypoint whose confidence exceeds ``threshold``.

    Args:
        img: image tensor shaped (C, H, W), e.g. one item of the training dataset.
        pred: a single sample's model output shaped (G, G, K * 3). For keypoint ``i``,
            channels ``3*i`` and ``3*i + 1`` are x, y (trained toward [0, 1] image-relative
            coordinates) and ``3*i + 2`` is the confidence.
        threshold: minimum confidence for a point to be drawn.
    """
    img_np = img.detach().permute(1, 2, 0).cpu().numpy()
    pred_np = pred.detach().cpu().numpy()

    # Scale by the actual image and keypoint count instead of assuming 256 px / 17 keypoints.
    height, width = img_np.shape[:2]
    num_keypoints = pred_np.shape[-1] // 3
    # (G, G, K*3) -> (G, G, K, 3) so the last axis is (x, y, conf)
    kps = pred_np.reshape(*pred_np.shape[:-1], num_keypoints, 3)
    keep = kps[..., 2] > threshold
    xs = kps[..., 0][keep] * width
    ys = kps[..., 1][keep] * height

    fig, ax = plt.subplots()
    ax.imshow(img_np)
    ax.scatter(xs, ys, c="r", s=10)  # one call for all points
    ax.set_title(f"Predicted keypoints (conf > {threshold})")
    ax.axis("off")
    plt.show()
