In [None]:
import os
import torch
from torchvision.datasets import CocoDetection
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import functional as F
from pycocotools.coco import COCO
import torchvision.transforms.v2 as T

import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision import transforms

import matplotlib.pyplot as plt

from tqdm.notebook import tqdm


In [37]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [38]:
COCO_PATH = "./data/coco/"  # change this
IMG_DIR_TRAIN = os.path.join(COCO_PATH, "images/train2017")
IMG_DIR_VAL = os.path.join(COCO_PATH, "images/val2017")
ANN_FILE_TRAIN = os.path.join(COCO_PATH, "annotations/person_keypoints_train2017.json")
ANN_FILE_VAL = os.path.join(COCO_PATH, "annotations/person_keypoints_val2017.json")

NUM_KEYPOINTS = 17
EPOCHS = 10


# Data

In [39]:
class CustomTransform:
    def __init__(self, size=(256, 256)):
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor()
        ])
    
    def __call__(self, image, target):
        return self.transform(image), target
    

In [40]:
class CocoKeypointsDataset(CocoDetection):
    def __init__(self, img_folder, ann_file, transforms=None):
        super().__init__(img_folder, ann_file)
        self.coco = COCO(ann_file)
        self._transforms = transforms

    def __getitem__(self, idx):
        img, target = super().__getitem__(idx)
        ann_ids = self.coco.getAnnIds(imgIds=self.ids[idx], iscrowd=None)
        anns = self.coco.loadAnns(ann_ids)
        # Filter for annotations with keypoints
        anns = [ann for ann in anns if 'keypoints' in ann]
        target = {
            "image_id": self.ids[idx],
            "annotations": anns
        }
        if self._transforms:
            img, target = self._transforms(img, target)
        return img, target
    

In [41]:
def coco_collate_fn(batch):
    imgs, targets = zip(*batch)
    imgs = torch.stack([img for img in imgs])
    keypoints = []
    for t in targets:
        anns = t['annotations']
        if not anns:
            keypoints.append(torch.zeros(17, 2))
            continue
        best = max(anns, key=lambda a: sum([v > 0 for v in a['keypoints'][2::3]]))
        kp = torch.tensor(best['keypoints']).view(17, 3)[:, :2]
        keypoints.append(kp)
    keypoints = torch.stack(keypoints).reshape(len(batch), -1)
    return imgs, keypoints


In [42]:
train_dataset = CocoKeypointsDataset(IMG_DIR_TRAIN, ANN_FILE_TRAIN, transforms=CustomTransform())
val_dataset = CocoKeypointsDataset(IMG_DIR_VAL, ANN_FILE_VAL, transforms=CustomTransform())

val_size = int(0.5 * len(val_dataset))
test_size = len(val_dataset) - val_size
val_dataset, test_dataset = random_split(val_dataset, [val_size, test_size])

loading annotations into memory...
Done (t=5.75s)
creating index...
index created!
loading annotations into memory...
Done (t=5.73s)
creating index...
index created!
loading annotations into memory...
Done (t=0.25s)
creating index...
index created!
loading annotations into memory...
Done (t=0.23s)
creating index...
index created!


In [43]:
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=coco_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=coco_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=coco_collate_fn)

# Model

In [44]:

# class KeypointModel(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.backbone = resnet18(pretrained=True)
#         self.backbone.fc = nn.Linear(self.backbone.fc.in_features, NUM_KEYPOINTS * 2)

#     def forward(self, x):
#         return self.backbone(x)

class KeypointModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, NUM_KEYPOINTS * 2),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


In [45]:
model = KeypointModel().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [46]:
def train_one_epoch(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0
    for imgs, targets in tqdm(dataloader):
        imgs = imgs.to(device)

        keypoints = torch.stack([
            torch.tensor(t).float()
            for t in targets
        ]).to(device)

        preds = model(imgs)
        loss = criterion(preds, keypoints)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)


In [47]:
def evaluate(model, dataloader):
    model.eval()
    preds_list, gt_list = [], []
    with torch.no_grad():
        for imgs, targets in dataloader:
            imgs = imgs.to(device)
            keypoints = targets.to(device)
            preds = model(imgs)

            preds_list.append(preds.cpu())
            gt_list.append(keypoints.cpu())
    return preds_list, gt_list


In [48]:
train_losses = []
val_losses = []

for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    train_losses.append(train_loss)

    preds_list, gt_list = evaluate(model, val_loader)
    val_loss = sum(criterion(pred, gt).item() for pred, gt in zip(preds_list, gt_list)) / len(gt_list)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")


  0%|          | 0/14786 [00:00<?, ?it/s]

  torch.tensor(t).float()


KeyboardInterrupt: 

In [None]:

def visualize_prediction(img_tensor, keypoints, pred_keypoints=None):
    img = img_tensor.permute(1, 2, 0).cpu().numpy()
    keypoints = keypoints.view(-1, 2).cpu().numpy()

    plt.imshow(img)
    plt.scatter(keypoints[:, 0], keypoints[:, 1], c='g', label='Ground Truth')
    if pred_keypoints is not None:
        pred_keypoints = pred_keypoints.view(-1, 2).cpu().numpy()
        plt.scatter(pred_keypoints[:, 0], pred_keypoints[:, 1], c='r', marker='x', label='Predicted')
    plt.legend()
    plt.show()


In [None]:
model.eval()
imgs, targets = next(iter(test_loader))
imgs = imgs.to(device)
with torch.no_grad():
    preds = model(imgs)

for i in range(5):
    img = imgs[i].cpu()
    gt_kpts = targets['keypoints'][i]
    pred_kpts = preds[i].reshape(-1, 2).cpu()

    visualize_prediction(img, gt_kpts, pred_kpts)


In [None]:
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title("Training & Validation Loss")
plt.show()
