In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import json
import cv2
import numpy as np
import torch.nn as nn

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class KeypointsDataset(Dataset):
    """Images paired with flattened keypoint coordinates.

    ``data_file`` is a JSON file that is loaded whole into ``self.data`` and
    indexed by position. Each record is read through two keys:

    * ``'id'``  -- used to build the image path ``<img_dir>/<id>.png``
    * ``'kps'`` -- the keypoint coordinates for that image

    Args:
        img_dir: directory that contains the ``<id>.png`` images.
        data_file: path to the JSON annotation file.
    """

    # Side length (pixels) of the square image handed to the network. The
    # keypoints are rescaled with the same value so they stay aligned with
    # the resized image.
    IMG_SIZE = 224

    def __init__(self, img_dir, data_file):
        self.img_dir = img_dir
        with open(data_file, "r") as f:
            self.data = json.load(f)

        # ndarray -> PIL -> resize -> float tensor -> per-channel normalisation
        # (the mean/std values are the commonly used ImageNet statistics).
        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.IMG_SIZE, self.IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(img, kps)`` where ``img`` is the transformed image and ``kps``
            is a flat ``float32`` NumPy array. Even positions are scaled by
            ``IMG_SIZE / width`` and odd positions by ``IMG_SIZE / height``
            (assumes ``'kps'`` is a sequence of ``(x, y)`` pairs in original
            pixel coordinates -- TODO confirm against the annotation file).

        Raises:
            FileNotFoundError: if the image cannot be read from disk.
        """
        item = self.data[idx]
        img_path = f"{self.img_dir}/{item['id']}.png"
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than on `img.shape` below.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        h, w = img.shape[:2]

        # OpenCV loads BGR; the transform pipeline expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)

        # flatten() copies, so the scaling below never mutates self.data.
        kps = np.array(item['kps']).flatten()
        kps = kps.astype(np.float32)

        # Map coordinates from the original resolution to the resized image.
        kps[::2] *= float(self.IMG_SIZE) / w
        kps[1::2] *= float(self.IMG_SIZE) / h

        return img, kps


In [None]:
# Both splits read their images from the same directory; only the annotation
# file differs between them.
images_dir = "/kaggle/input/court-keypoints/data/images"

train_dataset = KeypointsDataset(
    images_dir,
    "/kaggle/input/court-keypoints/data/data_train.json",
)
val_dataset = KeypointsDataset(
    images_dir,
    "/kaggle/input/court-keypoints/data/data_val.json",
)

# Only the training split is shuffled; validation keeps its file order.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

In [None]:
class ChannelAttention(nn.Module):
    """Channel-wise attention gate.

    The input is pooled over its spatial dimensions twice (average and max).
    Both pooled descriptors go through the same bottleneck of two bias-free
    1x1 convolutions, the two results are added, and a sigmoid maps the sum
    to per-channel weights of shape ``(N, channel, 1, 1)``.

    Args:
        channel: number of channels of the input (and of the returned gate).
        ratio: bottleneck reduction factor; the hidden width is
            ``channel // ratio``.
    """

    def __init__(self, channel, ratio=16):
        super().__init__()
        hidden = channel // ratio

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # One bottleneck shared by both pooled descriptors.
        squeeze = nn.Conv2d(channel, hidden, kernel_size=1, bias=False)
        expand = nn.Conv2d(hidden, channel, kernel_size=1, bias=False)
        self.shared_MLP = nn.Sequential(squeeze, nn.ReLU(), expand)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-gated channel weights for ``x``."""
        avg_branch = self.shared_MLP(self.avg_pool(x))
        max_branch = self.shared_MLP(self.max_pool(x))
        gate_logits = avg_branch + max_branch
        return self.sigmoid(gate_logits)
 
class SpatialAttention(nn.Module):
    """Spatial attention gate.

    The channel dimension of the input is collapsed into two single-channel
    maps (mean and max over channels). The maps are stacked and passed
    through a 7x7 convolution with padding 3, which keeps the spatial size,
    followed by a sigmoid. The result has shape ``(N, 1, H, W)``.
    """

    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-gated spatial weights for ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv2d(stacked))

In [None]:
class KeypointResNet50(nn.Module):
    """Pretrained ResNet-50 backbone with a linear keypoint-regression head.

    The backbone's classification layer is replaced by ``nn.Identity`` so the
    backbone returns its feature vector, and ``self.fc`` maps that vector to
    ``num_keypoints * 2`` outputs (two values per keypoint).

    Args:
        num_keypoints: number of keypoints to regress. Defaults to 14, which
            gives the 28 outputs this model has always produced.
    """

    def __init__(self, num_keypoints=14):
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=` -- confirm against the installed version.
        self.backbone = models.resnet50(pretrained=True)

        # Read the feature width off this backbone *before* replacing its
        # classifier, instead of building a second pretrained ResNet-50 just
        # to look the number up.
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # These attention modules are registered but not applied in forward().
        # They are kept so the parameter set and state-dict keys stay the same
        # as before. NOTE(review): torchvision's ResNet pools and flattens
        # before `fc`, so attention would need the unpooled feature map.
        self.channel_attention = ChannelAttention(feature_dim)
        self.spatial_attention = SpatialAttention()

        self.fc = nn.Linear(feature_dim, num_keypoints * 2)

    def forward(self, x):
        """Return the regressed coordinates, shape ``(N, num_keypoints * 2)``."""
        features = self.backbone(x)
        return self.fc(features)

In [None]:
# Build the network and move its parameters to the selected device in one step.
model = KeypointResNet50().to(device)

In [None]:
# Mean-squared error between predicted and target coordinates, minimised
# with Adam over all model parameters.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
)

In [None]:
epochs = 20
for epoch in range(epochs):
    # Training pass: one optimizer update per mini-batch.
    model.train()
    for i, (images, targets) in enumerate(train_loader):
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()

        # Report the current batch loss every tenth iteration.
        if i % 10 == 0:
            print(f"Epoch {epoch}, iter {i}, loss: {loss.item()}")

    # Validation pass: no gradients, loss averaged over the batches.
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device)
            val_loss += criterion(model(images), targets).item()

    val_loss /= len(val_loader)
    print(f"Epoch {epoch}, val loss: {val_loss}")

In [None]:
# Persist only the learned parameters. The method is `state_dict`; the
# previous `stat_dict` spelling raised AttributeError once training finished.
torch.save(model.state_dict(), "keypoints_model.pt")