In [1]:

# Importing the required libraries

import cv2
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

In [2]:
!wget -O tennis_court_det_dataset.zip 'https://www.dropbox.com/scl/fi/bpy2vvf3qmbkz41s59at0/tennis_court_det_dataset.zip?rlkey=xraopgs3x70mpy78podsht5f7&st=l8v8gf20&dl=0'

--2024-05-29 20:33:26--  https://www.dropbox.com/scl/fi/bpy2vvf3qmbkz41s59at0/tennis_court_det_dataset.zip?rlkey=xraopgs3x70mpy78podsht5f7&st=l8v8gf20&dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.5.18, 2620:100:601d:18::a27d:512
Connecting to www.dropbox.com (www.dropbox.com)|162.125.5.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://uc56375a3ad0d9439151afb4227f.dl.dropboxusercontent.com/cd/0/inline/CT0NqUzvbdO3HF7gzdkXGvDopkM5ej9PJCCWdVizVR89W13L9QSK6k0qXmjGc3vz55B24yANTIm00Kfj5lXfKjTA4KMtTErtZ-RGyiB7PjtA-cAecufMDtsk-MdnMxf-9lzLYu9PIJq5tnP5oSKLa-be/file# [following]
--2024-05-29 20:33:27--  https://uc56375a3ad0d9439151afb4227f.dl.dropboxusercontent.com/cd/0/inline/CT0NqUzvbdO3HF7gzdkXGvDopkM5ej9PJCCWdVizVR89W13L9QSK6k0qXmjGc3vz55B24yANTIm00Kfj5lXfKjTA4KMtTErtZ-RGyiB7PjtA-cAecufMDtsk-MdnMxf-9lzLYu9PIJq5tnP5oSKLa-be/file
Resolving uc56375a3ad0d9439151afb4227f.dl.dropboxusercontent.com (uc56375a3ad0d9439151afb4227f.dl.dropboxusercon

In [3]:
!unzip tennis_court_det_dataset.zip -d /data

[1;30;43mSe han truncado las últimas 5000 líneas del flujo de salida.[0m
  inflating: /data/data/images/JNKp7sCdQlY_2200.png  
  inflating: /data/data/images/JNKp7sCdQlY_2250.png  
  inflating: /data/data/images/JNKp7sCdQlY_250.png  
  inflating: /data/data/images/JNKp7sCdQlY_300.png  
  inflating: /data/data/images/JNKp7sCdQlY_350.png  
  inflating: /data/data/images/JNKp7sCdQlY_400.png  
  inflating: /data/data/images/JNKp7sCdQlY_450.png  
  inflating: /data/data/images/JNKp7sCdQlY_50.png  
  inflating: /data/data/images/JNKp7sCdQlY_500.png  
  inflating: /data/data/images/JNKp7sCdQlY_550.png  
  inflating: /data/data/images/JNKp7sCdQlY_600.png  
  inflating: /data/data/images/JNKp7sCdQlY_650.png  
  inflating: /data/data/images/JNKp7sCdQlY_700.png  
  inflating: /data/data/images/JNKp7sCdQlY_750.png  
  inflating: /data/data/images/juXbdW7z0WA_100.png  
  inflating: /data/data/images/juXbdW7z0WA_1050.png  
  inflating: /data/data/images/juXbdW7z0WA_1100.png  
  inflating: /data/da

In [4]:
class Dataset_Preparation(Dataset):
    """Image/keypoint dataset backed by a JSON annotation list.

    Each JSON entry must provide:
      * ``"id"``  -- the image file stem; the image is read from
        ``<images_path>/<id>.png``.
      * ``"kps"`` -- keypoints as ``(x, y)`` pairs in original-image pixels.

    ``__getitem__`` returns ``(img, keypoints)``:
      * ``img``: the RGB image, resized to ``img_size x img_size``, converted to
        a tensor and normalized with ImageNet mean/std.
      * ``keypoints``: flat float32 array ``[x0, y0, x1, y1, ...]`` rescaled
        into the resized image's pixel grid (values roughly in ``[0, img_size]``).

    Args:
        images_path: Directory containing the ``.png`` images.
        data_file: Path to the JSON annotation file.
        img_size: Side length images are resized to; keypoints are rescaled to
            match. Defaults to 224.

    Raises:
        FileNotFoundError: from ``__getitem__`` if an image cannot be read.
    """

    # ImageNet channel statistics used by torchvision's pretrained backbones.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, images_path, data_file, img_size=224):
        self.image_path = images_path
        self.img_size = img_size
        with open(data_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        # Preparing data into a suitable format for the model
        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        path = f'{self.image_path}/{item["id"]}.png'
        img = cv2.imread(path)
        # cv2.imread reports a missing/unreadable file by returning None instead
        # of raising; fail here with a clear message rather than on `.shape`.
        if img is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        height, width = img.shape[:2]

        # OpenCV loads BGR; the transforms and ImageNet statistics expect RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)

        keypoints_data = self._scale_keypoints(item['kps'], width, height, self.img_size)
        return img, keypoints_data

    @staticmethod
    def _scale_keypoints(kps, width, height, size=224):
        """Flatten ``(x, y)`` pairs and rescale them from a ``width x height``
        image into a ``size x size`` image.

        Returns a new float32 array ``[x0, y0, x1, y1, ...]``. Coordinates stay
        in pixel units of the resized image -- they are NOT normalized to [0, 1].
        NOTE(review): pixel-scale targets make the MSE loss large; normalizing
        to [0, 1] may ease optimization, but would change what the model learns.
        """
        keypoints = np.asarray(kps, dtype=np.float32).flatten()
        keypoints[::2] *= size / width     # even indices are x coordinates
        keypoints[1::2] *= size / height   # odd indices are y coordinates
        return keypoints


In [14]:
# Paths follow the layout extracted by the unzip cell: /data/data/{images, *.json}.
DATA_DIR = "/data/data"
IMAGES_DIR = f"{DATA_DIR}/images"
BATCH_SIZE = 8
SEED = 42

train_dataset = Dataset_Preparation(IMAGES_DIR, f"{DATA_DIR}/data_train.json")
val_dataset = Dataset_Preparation(IMAGES_DIR, f"{DATA_DIR}/data_val.json")

# A dedicated, seeded generator makes the training shuffle order reproducible
# across kernel restarts without touching the global RNG.
loader_rng = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, generator=loader_rng)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [15]:
# ResNet-50 backbone initialised from ImageNet; its classifier head is replaced
# by a linear layer regressing one (x, y) pair per keypoint.
NUM_KEYPOINTS = 14
# `weights=` supersedes the deprecated `pretrained=True`; IMAGENET1K_V1 is the
# same checkpoint (resnet50-0676ba61.pth) that `pretrained=True` loads.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.fc = torch.nn.Linear(model.fc.in_features, NUM_KEYPOINTS * 2)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 171MB/s]


In [16]:
# Train on the GPU when CUDA is available, otherwise fall back to the CPU,
# and move the model's parameters there.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
model = model.to(device)

In [17]:
# Coordinate regression: mean squared error on the keypoint vector,
# optimised with Adam over all (backbone + head) parameters.
LEARNING_RATE = 1e-4
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [18]:
epochs = 20


def evaluate(model, loader, criterion, device):
    """Return the per-sample mean of `criterion` over all batches of `loader`.

    Puts `model` into eval mode (and leaves it there) and disables gradient
    tracking. Returns 0.0 for an empty loader.
    """
    model.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for images, keypoints in loader:
            images, keypoints = images.to(device), keypoints.to(device)
            batch_loss = criterion(model(images), keypoints)
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += batch_loss.item() * images.size(0)
            n_samples += images.size(0)
    return total_loss / max(n_samples, 1)


for epoch in range(epochs):
    # `evaluate` switches to eval mode (BatchNorm uses running stats), so
    # re-enter train mode explicitly at the start of every epoch.
    model.train()
    for i, (images, keypoints) in enumerate(train_loader):
        images = images.to(device)
        keypoints = keypoints.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, keypoints)
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print(f'Epoch {epoch}, iter {i}, loss: {loss.item()}')

    # Held-out loss once per epoch, to spot overfitting / divergence.
    val_loss = evaluate(model, val_loader, criterion, device)
    print(f'Epoch {epoch}, val loss: {val_loss}')


Epoch 0, iter 0, loss: 14039.2880859375
Epoch 0, iter 10, loss: 14223.013671875
Epoch 0, iter 20, loss: 14293.54296875
Epoch 0, iter 30, loss: 13852.2158203125
Epoch 0, iter 40, loss: 12935.94921875
Epoch 0, iter 50, loss: 13168.74609375
Epoch 0, iter 60, loss: 13902.0576171875
Epoch 0, iter 70, loss: 12392.361328125
Epoch 0, iter 80, loss: 11619.197265625
Epoch 0, iter 90, loss: 11633.0380859375
Epoch 0, iter 100, loss: 10853.9267578125
Epoch 0, iter 110, loss: 10690.7216796875
Epoch 0, iter 120, loss: 10197.9736328125
Epoch 0, iter 130, loss: 10378.74609375
Epoch 0, iter 140, loss: 9244.2060546875
Epoch 0, iter 150, loss: 9145.134765625
Epoch 0, iter 160, loss: 8700.44140625
Epoch 0, iter 170, loss: 8670.34765625
Epoch 0, iter 180, loss: 8164.2314453125
Epoch 0, iter 190, loss: 7904.232421875
Epoch 0, iter 200, loss: 7510.5517578125
Epoch 0, iter 210, loss: 7418.07958984375
Epoch 0, iter 220, loss: 7290.61328125
Epoch 0, iter 230, loss: 7096.06494140625
Epoch 0, iter 240, loss: 6840.

In [19]:
# Persist only the learned parameters (state_dict), not the pickled module;
# to reuse them, rebuild the same architecture and call load_state_dict.
MODEL_PATH = 'keypoints_model.pth'
weights = model.state_dict()
torch.save(weights, MODEL_PATH)