In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import json
import cv2
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class KeypointsDataset(Dataset):
    def __init__(self, img_dir, data_file):
        self.img_dir = img_dir
        with open(data_file, "r") as f:
            self.data = json.load(f)
        
        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        img = cv2.imread(f"{self.img_dir}/{item['id']}.png")
        h,w = img.shape[:2]

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)
        kps = np.array(item['kps']).flatten()
        kps = kps.astype(np.float32)

        kps[::2] *= 224.0 / w
        kps[1::2] *= 224.0 / h

        return img, kps


In [4]:
train_dataset = KeypointsDataset("/kaggle/input/court-keypoints/data/images","/kaggle/input/court-keypoints/data/data_train.json")
val_dataset = KeypointsDataset("/kaggle/input/court-keypoints/data/images","/kaggle/input/court-keypoints/data/data_val.json")

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

In [9]:
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class KeypointResNet50(nn.Module):
    def __init__(self):
        super(KeypointResNet50, self).__init__()
        self.backbone = models.resnet50(pretrained=True)
        self.backbone.fc = nn.Identity()

        self.conv = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0, bias=False)

        self.channel_attention = ChannelAttention(2048)
        self.spatial_attention = SpatialAttention()

        self.fc = nn.Linear(2048, 14 * 2)

    def forward(self, x):
        x = self.backbone(x)
        x = x.unsqueeze(2).unsqueeze(3)
        x = self.conv(x)
        x = self.channel_attention(x) * x
        x = self.spatial_attention(x) * x

        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [10]:
model = KeypointResNet50()
model = model.to(device)

In [11]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [12]:
epochs = 20
for epoch in range(epochs):
    model.train()
    for i, (imgs, kps) in enumerate(train_loader):
        imgs = imgs.to(device)
        kps = kps.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, kps)
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print(f"Epoch {epoch}, iter {i}, loss: {loss.item()}")

    model.eval()
    with torch.no_grad():
        val_loss = 0
        for imgs, kps in val_loader:
            imgs = imgs.to(device)
            kps = kps.to(device)

            outputs = model(imgs)
            loss = criterion(outputs, kps)
            val_loss += loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch}, val loss: {val_loss}")

Epoch 0, iter 0, loss: 14259.5859375
Epoch 0, iter 10, loss: 14713.0869140625
Epoch 0, iter 20, loss: 13434.08984375
Epoch 0, iter 30, loss: 10942.5849609375
Epoch 0, iter 40, loss: 6660.29052734375
Epoch 0, iter 50, loss: 2749.047607421875
Epoch 0, iter 60, loss: 957.94970703125
Epoch 0, iter 70, loss: 257.4103088378906
Epoch 0, iter 80, loss: 124.37671661376953
Epoch 0, iter 90, loss: 51.96255111694336
Epoch 0, iter 100, loss: 91.2490005493164
Epoch 0, iter 110, loss: 62.70871353149414
Epoch 0, iter 120, loss: 37.75197219848633
Epoch 0, iter 130, loss: 46.289615631103516
Epoch 0, iter 140, loss: 110.04586791992188
Epoch 0, iter 150, loss: 44.438941955566406
Epoch 0, iter 160, loss: 36.4776725769043
Epoch 0, iter 170, loss: 85.3827896118164
Epoch 0, iter 180, loss: 42.72153091430664
Epoch 0, iter 190, loss: 58.415977478027344
Epoch 0, iter 200, loss: 100.58003997802734
Epoch 0, iter 210, loss: 37.96141815185547
Epoch 0, iter 220, loss: 31.943378448486328
Epoch 0, iter 230, loss: 29.70

In [16]:
torch.save(model.state_dict(), "keypoints_model.pt")

In [18]:
%ls

keypoints_model.pt


In [19]:
%pwd

'/kaggle/working'

In [20]:
from IPython.display import FileLink
FileLink(r'keypoints_model.pt')