In [1]:
import cv2
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import numpy as np
import json

In [1]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Display the selected device as this cell's output.
device

device(type='cuda')

In [3]:
class KeypointDataset(Dataset):
    """Dataset pairing images with their flattened keypoint coordinates.

    The annotation file at ``json_path`` is loaded once at construction. Each
    entry must provide an ``"id"`` (the image is read from
    ``{img_path}/{id}.png``) and ``"kps"`` (keypoint coordinates, which are
    flattened so that even positions are scaled by the image width and odd
    positions by the image height, i.e. interleaved x, y values).

    Args:
        img_path: Directory that contains the ``<id>.png`` images.
        json_path: Path to the JSON annotation file.
        img_size: Side length, in pixels, of the square every image is resized
            to. Keypoints are rescaled into the same coordinate space.
            Defaults to 224.
    """

    def __init__(self, img_path, json_path, img_size=224):
        self.img_path = img_path
        self.img_size = img_size
        with open(json_path, "r") as fh:
            self.data = json.load(fh)
        # Normalisation uses the per-channel mean/std conventionally paired
        # with torchvision's ImageNet-pretrained models.
        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.data)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A ``(img, kps)`` tuple: the transformed image and a 1-D float32
            NumPy array of keypoints rescaled to the resized image.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be
                decoded (``cv2.imread`` signals both by returning ``None``).
        """
        item = self.data[index]
        img_file = f"{self.img_path}/{item['id']}.png"
        img = cv2.imread(img_file)
        if img is None:
            # Without this guard the failure surfaces as an opaque
            # AttributeError on ``img.shape`` with no hint of which file.
            raise FileNotFoundError(f"Could not read image: {img_file}")

        # Capture the original size before resizing; it drives keypoint scaling.
        h, w = img.shape[:2]

        # OpenCV loads BGR; the transform pipeline expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)

        kps = np.array(item["kps"]).flatten()
        kps = kps.astype(np.float32)

        # Map keypoints from original-image pixels into the resized image.
        kps[::2] *= float(self.img_size) / w
        kps[1::2] *= float(self.img_size) / h

        return img, kps


In [4]:
# Dataset location and batch size, named once so the train and validation
# splits cannot drift apart.
DATA_DIR = "tennis_court_det_dataset/data"
BATCH_SIZE = 16

training_dataset = KeypointDataset(f"{DATA_DIR}/images", f"{DATA_DIR}/data_train.json")
val_dataset = KeypointDataset(f"{DATA_DIR}/images", f"{DATA_DIR}/data_val.json")

# Only the training split is shuffled; a fixed validation order keeps
# per-batch validation results comparable between runs.
train_loader = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Create the model

In [5]:
# Start from torchvision's ResNet-50 with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; left unchanged because the installed version is unknown.
model = models.resnet50(pretrained=True)

# Replace the final fully connected layer with one producing 14*2 outputs
# (presumably 14 keypoints x 2 coordinates; confirm against the annotations).
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=14 * 2)



In [6]:
# Move the model's parameters and buffers onto the selected device.
model.to(device=device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

# Train the model

In [7]:
# Mean squared error between predicted and target keypoint coordinates
# ("mean" is MSELoss's default reduction, spelled out here).
criterion = torch.nn.MSELoss(reduction="mean")

# Adam over every model parameter, so the whole network is fine-tuned.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

In [8]:
epochs = 20
for epoch in range(epochs):
    for i, (imgs, kps) in enumerate(train_loader):
        # Put the batch on the same device as the model.
        imgs, kps = imgs.to(device), kps.to(device)

        # Drop gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass, then the loss against the target keypoints.
        outputs = model(imgs)
        loss = criterion(outputs, kps)

        # Backward pass followed by the parameter update.
        loss.backward()
        optimizer.step()

        # Report the current batch loss on every 10th iteration.
        if not i % 10:
            print(f"Epoch {epoch}, iter {i}, loss: {loss.item()}")

Epoch 0, iter 0, loss: 14937.2294921875
Epoch 0, iter 10, loss: 14701.458984375
Epoch 0, iter 20, loss: 14067.0869140625
Epoch 0, iter 30, loss: 13720.052734375
Epoch 0, iter 40, loss: 13659.21875
Epoch 0, iter 50, loss: 13226.248046875
Epoch 0, iter 60, loss: 12777.7939453125
Epoch 0, iter 70, loss: 12242.228515625
Epoch 0, iter 80, loss: 11415.677734375
Epoch 0, iter 90, loss: 11058.9521484375
Epoch 0, iter 100, loss: 10978.4599609375
Epoch 0, iter 110, loss: 10592.1826171875
Epoch 0, iter 120, loss: 10305.845703125
Epoch 0, iter 130, loss: 10293.701171875
Epoch 0, iter 140, loss: 9873.791015625
Epoch 0, iter 150, loss: 9748.3740234375
Epoch 0, iter 160, loss: 8943.244140625
Epoch 0, iter 170, loss: 9006.0107421875
Epoch 0, iter 180, loss: 8504.189453125
Epoch 0, iter 190, loss: 8013.3134765625
Epoch 0, iter 200, loss: 8289.111328125
Epoch 0, iter 210, loss: 7444.2880859375
Epoch 0, iter 220, loss: 7221.99609375
Epoch 0, iter 230, loss: 7140.18408203125
Epoch 0, iter 240, loss: 6603.

In [11]:
# Persist only the weights (the state dict), not the full module object.
torch.save(obj=model.state_dict(), f="keypoints_model.pt")