In [1]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
import random
from tqdm import tqdm

# Dataset root -- adjust to match your local checkout of COCO 2017
COCO_PATH = "../../data/coco/"

# Image folders for each split
IMG_DIR_TRAIN = os.path.join(COCO_PATH, "images/train2017")
IMG_DIR_VAL = os.path.join(COCO_PATH, "images/val2017")

# Person-keypoint annotation files for each split
ANN_FILE_TRAIN = os.path.join(COCO_PATH, "annotations/person_keypoints_train2017.json")
ANN_FILE_VAL = os.path.join(COCO_PATH, "annotations/person_keypoints_val2017.json")

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
class COCOKeypointsDataset(Dataset):
    """COCO person-keypoints dataset encoded as a single-box-per-cell grid target.

    Each sample is ``(img, target)`` where ``target`` has shape
    ``(grid_size, grid_size, 1 + 4 + num_keypoints * 3)`` laid out as
    ``[obj, x, y, w, h, kp_x1, kp_y1, kp_v1, ...]``:

    * ``obj``        -- 1 in the cell that contains a person's box centre.
    * ``x, y``       -- box centre, as an offset from the cell's top-left
                        corner in units of one cell.
    * ``w, h``       -- box size divided by ``img_size``.
    * ``kp_x, kp_y`` -- keypoint position, as an offset from the owning cell's
                        top-left corner in units of one cell (may fall outside
                        ``[0, 1]``).
    * ``kp_v``       -- 1 when the COCO visibility flag equals 2, else 0.

    If several people fall into the same cell, the one processed last
    overwrites the earlier ones.

    Args:
        img_dir: Directory containing the image files.
        ann_file: Path to the COCO annotation JSON.
        img_size: Side length the coordinates are rescaled to.
        grid_size: Number of grid cells per side.
        transform: Optional callable applied to the loaded PIL image.
    """

    # Each annotation's flat keypoint list is read as (x, y, v) triplets.
    num_keypoints = 17

    def __init__(self, img_dir, ann_file, img_size=416, grid_size=13, transform=None):
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.img_size = img_size
        self.grid_size = grid_size
        self.transform = transform
        self.cell_size = img_size / grid_size

        # Keep only images with at least one annotation of category 1 (person)
        self.img_ids = [
            img_id
            for img_id in self.coco.imgs.keys()
            if len(self.coco.getAnnIds(imgIds=img_id, catIds=[1])) > 0
        ]

    def __len__(self):
        """Return the number of images that passed the person filter."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Load image ``idx`` and build its grid target.

        Returns:
            ``(img, target)``; ``img`` is the transformed image (or the raw
            PIL image when no transform was given) and ``target`` is a float
            tensor of shape ``(grid_size, grid_size, 5 + 3 * num_keypoints)``.
        """
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])

        img = Image.open(img_path).convert('RGB')
        # Read the original size before the transform replaces the PIL image
        orig_w, orig_h = img.size

        if self.transform:
            img = self.transform(img)

        target = torch.zeros(
            (self.grid_size, self.grid_size, 1 + 4 + self.num_keypoints * 3)
        )

        # The scale factors depend only on the image, so compute them once
        scale_x = self.img_size / orig_w
        scale_y = self.img_size / orig_h

        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=[1])
        for ann in self.coco.loadAnns(ann_ids):
            # Skip crowd annotations
            if ann.get('iscrowd', 0):
                continue

            x, y, w, h = ann['bbox']
            x, w = x * scale_x, w * scale_x
            y, h = y * scale_y, h * scale_y

            cx = x + w / 2
            cy = y + h / 2

            # A centre left of / above the image must be rejected explicitly:
            # int() truncates toward zero, so it would either land in cell 0
            # with a negative offset or produce a negative index that wraps
            # around to the opposite side of the grid.
            if cx < 0 or cy < 0:
                continue

            grid_x = int(cx / self.cell_size)
            grid_y = int(cy / self.cell_size)
            if grid_x >= self.grid_size or grid_y >= self.grid_size:
                continue

            cell_x0 = grid_x * self.cell_size
            cell_y0 = grid_y * self.cell_size

            # Indexing with two ints yields a view, so writes land in `target`
            cell = target[grid_y, grid_x]
            cell[0] = 1
            cell[1] = (cx - cell_x0) / self.cell_size
            cell[2] = (cy - cell_y0) / self.cell_size
            cell[3] = w / self.img_size
            cell[4] = h / self.img_size

            keypoints = np.array(ann['keypoints']).reshape(-1, 3)
            for k, (x_k, y_k, v_k) in enumerate(keypoints):
                offset = 5 + k * 3
                cell[offset] = (x_k * scale_x - cell_x0) / self.cell_size
                cell[offset + 1] = (y_k * scale_y - cell_y0) / self.cell_size
                cell[offset + 2] = 1 if v_k == 2 else 0  # 2 = visible

        return img, target

# Preprocessing: fixed-size resize, tensor conversion, ImageNet mean/std normalisation
transform = transforms.Compose([
    transforms.Resize((416, 416)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Build the train and validation datasets (train first, then val)
train_dataset, val_dataset = (
    COCOKeypointsDataset(split_img_dir, split_ann_file, transform=transform)
    for split_img_dir, split_ann_file in (
        (IMG_DIR_TRAIN, ANN_FILE_TRAIN),
        (IMG_DIR_VAL, ANN_FILE_VAL),
    )
)

# Batch iterators: only the training split is shuffled
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
)

loading annotations into memory...
Done (t=11.15s)
creating index...
index created!
loading annotations into memory...
Done (t=0.43s)
creating index...
index created!


In [3]:
class KeypointYOLO(nn.Module):
    """Small fully-convolutional network producing a per-cell detection vector.

    The backbone is five Conv-BatchNorm-LeakyReLU-MaxPool stages, each halving
    the spatial resolution (416x416 input -> 13x13 feature map). The head maps
    those features to ``1 + 4 + 3 * num_keypoints`` raw (unactivated) channels
    per cell.
    """

    def __init__(self):
        super().__init__()
        self.grid_size = 13
        self.num_keypoints = 17

        # Feature extractor (simplified Darknet): channel widths per stage.
        # Layers are appended in a flat list so Sequential indices stay 0..19.
        stage_widths = (32, 64, 128, 256, 512)
        backbone = []
        in_ch = 3
        for out_ch in stage_widths:
            backbone.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.1),
                nn.MaxPool2d(2, 2),  # halves height and width
            ])
            in_ch = out_ch
        self.features = nn.Sequential(*backbone)

        # Detection head: one 3x3 conv block, then a 1x1 projection
        out_channels = 1 + 4 + 3 * self.num_keypoints
        self.detector = nn.Sequential(
            nn.Conv2d(512, 1024, 3, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.1),
            nn.Conv2d(1024, out_channels, 1)
        )

    def forward(self, x):
        """Run the network and return channels-last output.

        The result is permuted from (batch, channels, grid, grid) to
        (batch, grid, grid, channels).
        """
        feature_map = self.features(x)
        raw_out = self.detector(feature_map)
        return raw_out.permute(0, 2, 3, 1)

In [4]:
class KeypointLoss(nn.Module):
    """Sum-reduced detection + keypoint loss over a grid of predictions.

    Both inputs have shape ``(batch, grid, grid, 5 + 3 * K)`` laid out as
    ``[obj, x, y, w, h, kp_x1, kp_y1, kp_v1, ...]``. Objectness and keypoint
    visibility predictions are treated as logits (``BCEWithLogitsLoss``); box
    and keypoint coordinates are regressed directly with squared error.

    Args:
        lambda_coord: Weight of the box centre + size terms.
        lambda_noobj: Weight of the objectness term for empty cells.
        lambda_kp: Weight of the keypoint coordinate + visibility terms.
    """

    def __init__(self, lambda_coord=5, lambda_noobj=0.5, lambda_kp=5):
        super().__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.bce = nn.BCEWithLogitsLoss(reduction='sum')
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        self.lambda_kp = lambda_kp

    @staticmethod
    def _signed_sqrt(values, eps=1e-6):
        """Square root that keeps the sign, so negative inputs do not yield NaN."""
        return torch.sign(values) * torch.sqrt(torch.abs(values) + eps)

    def forward(self, preds, targets):
        """Return the weighted total loss as a scalar tensor.

        Args:
            preds: Raw network output, ``(batch, grid, grid, 5 + 3 * K)``.
            targets: Encoded ground truth with the same shape.
        """
        obj_mask = targets[..., 0] == 1  # Cells with objects
        noobj_mask = targets[..., 0] == 0  # Cells without objects

        # Gather the object cells once: both become (n_obj, channels)
        obj_preds = preds[obj_mask]
        obj_targets = targets[obj_mask]
        n_obj = obj_preds.shape[0]

        # Objectness loss
        obj_loss = self.bce(obj_preds[:, 0], obj_targets[:, 0])
        noobj_loss = self.bce(preds[..., 0][noobj_mask], targets[..., 0][noobj_mask])

        # Box centre
        center_loss = self.mse(obj_preds[:, 1:3], obj_targets[:, 1:3])

        # Box width/height in square-root space. Predictions are unbounded
        # raw outputs, so a plain sqrt would return NaN for negative values
        # and poison the whole loss; the signed sqrt matches the plain sqrt
        # for positive predictions.
        wh_loss = self.mse(
            self._signed_sqrt(obj_preds[:, 3:5]),
            torch.sqrt(obj_targets[:, 3:5] + 1e-6),
        )

        # Keypoints: view the trailing channels as (n_obj, K, 3) triplets.
        # n_obj is passed explicitly because reshape(-1, ...) is ambiguous
        # for a tensor with zero elements.
        num_kp = (preds.shape[-1] - 5) // 3
        kp_preds = obj_preds[:, 5:5 + 3 * num_kp].reshape(n_obj, num_kp, 3)
        kp_targets = obj_targets[:, 5:5 + 3 * num_kp].reshape(n_obj, num_kp, 3)
        vis_targets = kp_targets[..., 2]

        # Coordinate error only counts for keypoints flagged visible
        vis_mask = vis_targets == 1
        kp_coord_loss = self.mse(kp_preds[..., :2][vis_mask], kp_targets[..., :2][vis_mask])

        # Visibility classification for every keypoint of every object cell
        vis_loss = self.bce(kp_preds[..., 2], vis_targets)
        kp_loss = kp_coord_loss + vis_loss

        total_loss = (
            obj_loss +
            self.lambda_noobj * noobj_loss +
            self.lambda_coord * (center_loss + wh_loss) +
            self.lambda_kp * kp_loss
        )
        return total_loss

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=10, lr=1e-4):
    """Train ``model`` with Adam and ``KeypointLoss``, validating every epoch.

    Args:
        model: Network to optimise; moved to the module-level ``device``.
        train_loader: Iterable of ``(images, targets)`` training batches.
        val_loader: Iterable of ``(images, targets)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The trained model (on ``device``).
    """
    model = model.to(device)
    criterion = KeypointLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # ---- optimisation pass ----
        model.train()
        train_loss = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in progress:
            images, targets = (part.to(device) for part in batch)

            # Forward first, then clear stale gradients and back-propagate
            loss = criterion(model(images), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # ---- evaluation pass (no gradient tracking) ----
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                images, targets = (part.to(device) for part in batch)
                val_loss += criterion(model(images), targets).item()

        # Report the per-batch average of the summed losses
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss/len(train_loader):.4f} | "
              f"Val Loss: {val_loss/len(val_loader):.4f}")

    return model

# Build a fresh network and run the full training schedule
model = KeypointYOLO()
trained_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=10,
    lr=1e-4,
)

# Persist only the learned weights (state dict), not the whole module
weights = trained_model.state_dict()
torch.save(weights, "keypoint_yolo.pth")

Epoch 1/10:   0%|          | 0/8015 [00:00<?, ?it/s]

In [None]:
def visualize_predictions(image, preds, threshold=0.5, kp_threshold=0.5, from_logits=True):
    """Draw predicted boxes and keypoints on top of a normalised image.

    Args:
        image: ``(3, H, W)`` tensor normalised with the ImageNet mean/std.
        preds: ``(grid, grid, 5 + 3 * K)`` array (or tensor) for one image,
            laid out as ``[obj, x, y, w, h, kp_x1, kp_y1, kp_v1, ...]``.
        threshold: Minimum objectness probability for a cell to be drawn.
        kp_threshold: Minimum visibility probability for a keypoint to be drawn.
        from_logits: When True (default) objectness and visibility channels
            are passed through a sigmoid before thresholding. The network is
            trained with ``BCEWithLogitsLoss`` and has no final activation, so
            its raw output must be converted before it can be compared with a
            probability threshold. Pass False if ``preds`` already holds
            probabilities.
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    preds = np.asarray(preds)

    # Undo the normalisation and move channels last for matplotlib
    img = image.permute(1, 2, 0).cpu().numpy()
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)
    img_h, img_w = img.shape[:2]

    fig, ax = plt.subplots(1)
    ax.imshow(img)

    grid_h, grid_w = preds.shape[:2]
    cell_w = img_w / grid_w
    cell_h = img_h / grid_h
    num_kp = (preds.shape[-1] - 5) // 3

    def to_prob(values):
        # Numerically stable sigmoid: exp(-log(1 + exp(-x)))
        if not from_logits:
            return values
        return np.exp(-np.logaddexp(0.0, -values))

    obj_prob = to_prob(preds[..., 0])

    for gy in range(grid_h):
        for gx in range(grid_w):
            if obj_prob[gy, gx] <= threshold:
                continue
            cell = preds[gy, gx]

            # Box centre is a cell-relative offset; size is a fraction of the image
            bx = (gx + cell[1]) * cell_w
            by = (gy + cell[2]) * cell_h
            # An untrained/poorly trained model can emit negative sizes
            bw = max(float(cell[3]), 0.0) * img_w
            bh = max(float(cell[4]), 0.0) * img_h

            rect = plt.Rectangle(
                (bx - bw / 2, by - bh / 2), bw, bh,
                fill=False, edgecolor='red', linewidth=1
            )
            ax.add_patch(rect)

            # Keypoints are stored as (x, y, visibility) triplets after the box
            kps = cell[5:5 + 3 * num_kp].reshape(num_kp, 3)
            visible = to_prob(kps[:, 2]) > kp_threshold
            if visible.any():
                kx = (gx + kps[visible, 0]) * cell_w
                ky = (gy + kps[visible, 1]) * cell_h
                ax.scatter(kx, ky, s=20, c='blue')

    plt.show()

# Take the first validation batch and run its first image through the model
model.eval()
val_batches = iter(val_loader)
sample_img, _ = next(val_batches)
first_img = sample_img[0]

with torch.no_grad():
    # Re-add the batch dimension for the forward pass, then drop it again
    batch_out = model(first_img.unsqueeze(0).to(device))
    preds = batch_out[0].cpu().numpy()

# Overlay the predicted boxes/keypoints on the image
visualize_predictions(first_img, preds)