In [1]:
# Mount Google Drive into the Colab filesystem so files stored under
# /content/drive (e.g. the dataset archive in MyDrive) can be read.
from google.colab import drive
drive.mount(mountpoint='/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!unzip /content/drive/MyDrive/tennis_court_det_dataset.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: data/images/JNKp7sCdQlY_2200.png  
  inflating: data/images/JNKp7sCdQlY_2250.png  
  inflating: data/images/JNKp7sCdQlY_250.png  
  inflating: data/images/JNKp7sCdQlY_300.png  
  inflating: data/images/JNKp7sCdQlY_350.png  
  inflating: data/images/JNKp7sCdQlY_400.png  
  inflating: data/images/JNKp7sCdQlY_450.png  
  inflating: data/images/JNKp7sCdQlY_50.png  
  inflating: data/images/JNKp7sCdQlY_500.png  
  inflating: data/images/JNKp7sCdQlY_550.png  
  inflating: data/images/JNKp7sCdQlY_600.png  
  inflating: data/images/JNKp7sCdQlY_650.png  
  inflating: data/images/JNKp7sCdQlY_700.png  
  inflating: data/images/JNKp7sCdQlY_750.png  
  inflating: data/images/juXbdW7z0WA_100.png  
  inflating: data/images/juXbdW7z0WA_1050.png  
  inflating: data/images/juXbdW7z0WA_1100.png  
  inflating: data/images/juXbdW7z0WA_200.png  
  inflating: data/images/juXbdW7z0WA_350.png  
  inflating: data/images/juXbdW7z0WA_40

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [5]:
import json
import cv2
import numpy as np

In [8]:
class KeypointsDataset(Dataset):
  """Images paired with flattened keypoint coordinates for regression.

  The JSON annotation file is loaded as a sequence of records. Each record
  is read through two keys:

  * ``id``  - image file name without the ``.png`` extension, looked up
    inside ``img_dir``;
  * ``kps`` - keypoint coordinates in pixels of the original image. They
    are flattened, and even positions are scaled by the width ratio while
    odd positions are scaled by the height ratio (i.e. an x, y, x, y, ...
    layout is assumed).

  Every image is resized to ``img_size`` x ``img_size`` and the keypoints
  are rescaled by the same factors so they stay aligned with the resized
  image.

  Args:
    img_dir: Directory containing the ``<id>.png`` images.
    data_file: Path to the JSON annotation file.
    img_size: Side length (pixels) of the square image fed to the network.
      Defaults to 224, the value previously hard-coded.
  """

  def __init__(self, img_dir, data_file, img_size=224):
    self.img_dir = img_dir
    self.img_size = img_size

    with open(data_file, 'r') as f:
      self.data = json.load(f)

    self.transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # ImageNet channel mean/std, as expected by ImageNet-pretrained backbones.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

  def __len__(self):
    """Return the number of annotated samples."""
    return len(self.data)

  @staticmethod
  def _scale_keypoints(kps, w, h, img_size=224):
    """Flatten ``kps`` and rescale it from a ``w`` x ``h`` image to ``img_size``.

    Returns a 1-D float32 array in which even positions were multiplied by
    ``img_size / w`` and odd positions by ``img_size / h``.
    """
    scaled = np.array(kps).flatten().astype(np.float32)
    scaled[::2] *= img_size / w
    scaled[1::2] *= img_size / h
    return scaled

  def __getitem__(self, idx):
    """Return ``(image_tensor, keypoints)`` for the sample at ``idx``.

    Raises:
      FileNotFoundError: If the image file is missing or cannot be decoded.
    """
    item = self.data[idx]
    img_path = f"{self.img_dir}/{item['id']}.png"

    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque AttributeError on `.shape` below.
    img = cv2.imread(img_path)
    if img is None:
      raise FileNotFoundError(f"Could not read image: {img_path}")
    h, w = img.shape[:2]

    # OpenCV decodes to BGR; the torchvision pipeline expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = self.transforms(img)

    kps = self._scale_keypoints(item['kps'], w, h, self.img_size)

    return img, kps


In [9]:
# Build the train/validation datasets from the extracted archive.
train_dataset = KeypointsDataset("data/images", "data/data_train.json")
val_dataset = KeypointsDataset("data/images", "data/data_val.json")

# Shuffle only the training data. Validation order has no effect on metrics,
# so it is kept fixed to make evaluation runs reproducible and comparable.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [10]:
# Start from a pretrained ResNet-50 and swap its classification head for a
# 28-output linear layer, so the network regresses the flattened keypoint
# vector (presumably 14 (x, y) pairs - confirm against the annotation file).
model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=28)
model = model.to(device=device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 51.0MB/s]


In [11]:
# Mean-squared error between predicted and target keypoint coordinates,
# optimised with Adam over all model parameters (the whole network is trained).
criterion = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [18]:
epochs = 10

# Explicitly select training mode: the backbone contains BatchNorm layers, and
# if this cell is re-run after the model was switched to eval() it would
# otherwise train silently with frozen normalisation statistics.
model.train()

for epoch in range(epochs):
  for i, (imgs, kps) in enumerate(train_loader):
    # Move the batch to the same device as the model.
    imgs = imgs.to(device)
    kps = kps.to(device)

    # Standard optimisation step: clear old gradients, forward, backward, update.
    optimizer.zero_grad()
    outputs = model(imgs)
    loss = criterion(outputs, kps)
    loss.backward()
    optimizer.step()

    # Log the loss of the current batch every 10 iterations.
    if i % 10 == 0:
      print(f"Epoch {epoch}, iter {i}, loss: {loss.item()}")

Epoch 0, iter 0, loss: 2.723433017730713
Epoch 0, iter 10, loss: 2.2327935695648193
Epoch 0, iter 20, loss: 3.3140528202056885
Epoch 0, iter 30, loss: 13.37511920928955
Epoch 0, iter 40, loss: 4.364597320556641
Epoch 0, iter 50, loss: 8.257667541503906
Epoch 0, iter 60, loss: 9.35802936553955
Epoch 0, iter 70, loss: 4.097846508026123
Epoch 0, iter 80, loss: 4.396847248077393
Epoch 0, iter 90, loss: 1.5489442348480225
Epoch 0, iter 100, loss: 6.8327531814575195
Epoch 0, iter 110, loss: 1.9263641834259033
Epoch 0, iter 120, loss: 4.01621150970459
Epoch 0, iter 130, loss: 1.2753175497055054
Epoch 0, iter 140, loss: 3.0497806072235107
Epoch 0, iter 150, loss: 7.98350715637207
Epoch 0, iter 160, loss: 1.83021080493927
Epoch 0, iter 170, loss: 9.600652694702148
Epoch 0, iter 180, loss: 12.742444038391113
Epoch 0, iter 190, loss: 2.0154855251312256
Epoch 0, iter 200, loss: 32.174163818359375
Epoch 0, iter 210, loss: 3.312422037124634
Epoch 0, iter 220, loss: 15.806962013244629
Epoch 0, iter 2

In [19]:
# Persist only the learned parameters (the state dict), not the module object;
# the file is written relative to the current working directory.
torch.save(obj=model.state_dict(), f="keypoints_model.pth")