In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive so that files
# stored on Drive can be accessed through the regular filesystem.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import zipfile
import os

# Dataset archive stored on the mounted Google Drive.
zip_filename = '/content/drive/MyDrive/Datasets/tennis_court_det_dataset.zip'

# Absolute extraction target. The dataset paths used further down are rooted
# at '/content/data/data/...', so extracting to a path relative to the current
# working directory would only work when the kernel happens to run in /content.
extract_dir = '/content/data'

with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

In [None]:
import json
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class TennisCourtDataset(Dataset):
    """Images paired with their annotated keypoints.

    The annotation file is parsed once at construction time. It must contain
    a JSON array (entries are accessed by integer position), and each entry
    must provide an ``'id'`` field (image file name without the ``.png``
    extension) and a ``'kps'`` field (numeric keypoint data).

    Args:
        json_file: Path to the JSON annotation file.
        img_dir: Directory containing the ``<id>.png`` images.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.
    """

    def __init__(self, json_file, img_dir, transform=None):
        # Context manager closes the file handle deterministically instead of
        # leaving it to the garbage collector.
        with open(json_file) as f:
            self.annotations = json.load(f)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, keypoints)`` for the annotation at position ``idx``.

        ``image`` is the RGB image (after ``transform`` when one was given);
        ``keypoints`` is a ``float32`` tensor built from the entry's ``'kps'``.
        """
        entry = self.annotations[idx]
        img_id = entry['id']
        img_path = os.path.join(self.img_dir, img_id + '.png')

        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to load, after which the handle can be closed safely.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        keypoints = torch.tensor(entry['kps'], dtype=torch.float32)
        return image, keypoints

# The only preprocessing step is the PIL image -> tensor conversion.
transform = transforms.Compose([transforms.ToTensor()])

# Annotation files and image directory.
train_json = '/content/data/data/data_train.json'
val_json = '/content/data/data/data_val.json'
img_dir = '/content/data/data/images/'

# Both splits read from the same image directory and share the transform.
train_dataset = TennisCourtDataset(train_json, img_dir, transform)
val_dataset = TennisCourtDataset(val_json, img_dir, transform)

# Only the training split is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=2)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=2)

In [None]:
import torch.nn as nn
import torch

class ConvBlock(nn.Module):
    """Conv2d followed by ReLU and then BatchNorm2d.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        pad: Padding passed to the convolution.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pad=1, stride=1, bias=True):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            bias=bias,
        )
        # Activation comes before normalisation; the position of each layer in
        # the Sequential also determines the state_dict key names.
        self.block = nn.Sequential(conv, nn.ReLU(), nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Run ``x`` through conv -> ReLU -> batch norm."""
        features = self.block(x)
        return features

class UpConvBlock(nn.Module):
    """ConvTranspose2d followed by ReLU and then BatchNorm2d.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
        kernel_size: Transposed-convolution kernel size.
        pad: Padding passed to the transposed convolution.
        stride: Transposed-convolution stride.
        bias: Whether the transposed convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pad=1, stride=1, bias=True):
        super().__init__()
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            bias=bias,
        )
        # Same layer ordering as the plain conv block: activation, then norm.
        self.block = nn.Sequential(deconv, nn.ReLU(), nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Run ``x`` through transposed conv -> ReLU -> batch norm."""
        features = self.block(x)
        return features

class TrackNet(nn.Module):
    """Encoder-decoder CNN with skip connections and a linear coordinate head.

    The encoder applies three 2x max-pools, the decoder three 2x transposed
    convolutions, and each decoder stage is concatenated with the encoder
    feature map of the same resolution. The final ``out_channels``-channel map
    is average-pooled to 2x2, flattened and projected by a linear layer to two
    values per output channel.

    Args:
        input_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the final feature map and number
            of 2-value rows in the output.
    """

    def __init__(self, input_channels=3, out_channels=14):
        super().__init__()
        self.out_channels = out_channels
        self.input_channels = input_channels

        # --- Encoder ---
        self.conv1 = ConvBlock(in_channels=self.input_channels, out_channels=64)
        self.conv2 = ConvBlock(in_channels=64, out_channels=64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = ConvBlock(in_channels=64, out_channels=128)
        self.conv4 = ConvBlock(in_channels=128, out_channels=128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = ConvBlock(in_channels=128, out_channels=256)
        self.conv6 = ConvBlock(in_channels=256, out_channels=256)
        self.conv7 = ConvBlock(in_channels=256, out_channels=256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv8 = ConvBlock(in_channels=256, out_channels=512)
        self.conv9 = ConvBlock(in_channels=512, out_channels=512)
        self.conv10 = ConvBlock(in_channels=512, out_channels=512)

        # --- Decoder ---
        # The first conv after every up-convolution takes twice the upsampled
        # channel count because the matching encoder map is concatenated.
        self.upconv1 = UpConvBlock(in_channels=512, out_channels=256, kernel_size=2, stride=2, pad=0)
        self.conv11 = ConvBlock(in_channels=512, out_channels=256)
        self.conv12 = ConvBlock(in_channels=256, out_channels=256)
        self.conv13 = ConvBlock(in_channels=256, out_channels=256)

        self.upconv2 = UpConvBlock(in_channels=256, out_channels=128, kernel_size=2, stride=2, pad=0)
        self.conv14 = ConvBlock(in_channels=256, out_channels=128)
        self.conv15 = ConvBlock(in_channels=128, out_channels=128)

        self.upconv3 = UpConvBlock(in_channels=128, out_channels=64, kernel_size=2, stride=2, pad=0)
        self.conv16 = ConvBlock(in_channels=128, out_channels=64)
        self.conv17 = ConvBlock(in_channels=64, out_channels=64)
        self.conv18 = nn.Conv2d(in_channels=64, out_channels=self.out_channels, kernel_size=1, stride=1, padding=0)

        # --- Coordinate head ---
        self.adaptive_pool = nn.AdaptiveAvgPool2d((2, 2))
        self.flatten = nn.Flatten()
        # Two values per output channel. Deriving the size from out_channels
        # keeps the layer consistent with the reshape in forward(); a
        # hard-coded 2 * 14 is only correct for the default configuration.
        self.fc = nn.Linear(self.out_channels * 2 * 2, 2 * self.out_channels)
        self._init_weights()

    def forward(self, x):
        """Map a batch of images to a ``(batch, out_channels, 2)`` tensor.

        ``x`` is a 4-D ``(batch, input_channels, H, W)`` tensor. H and W need
        to be divisible by 8, otherwise the upsampled maps do not match the
        encoder maps they are concatenated with.
        """
        # Encoder
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        p1 = self.pool1(c2)

        c3 = self.conv3(p1)
        c4 = self.conv4(c3)
        p2 = self.pool2(c4)

        c5 = self.conv5(p2)
        c6 = self.conv6(c5)
        c7 = self.conv7(c6)
        p3 = self.pool3(c7)

        c8 = self.conv8(p3)
        c9 = self.conv9(c8)
        c10 = self.conv10(c9)

        # Decoder with skip connections from the encoder
        u1 = self.upconv1(c10)
        cat1 = torch.cat([u1, c7], dim=1)  # Skip connection
        c11 = self.conv11(cat1)
        c12 = self.conv12(c11)
        c13 = self.conv13(c12)

        u2 = self.upconv2(c13)
        cat2 = torch.cat([u2, c4], dim=1)  # Skip connection
        c14 = self.conv14(cat2)
        c15 = self.conv15(c14)

        u3 = self.upconv3(c15)
        cat3 = torch.cat([u3, c2], dim=1)  # Skip connection
        c16 = self.conv16(cat3)
        c17 = self.conv17(c16)
        c18 = self.conv18(c17)

        # Coordinate head
        pooled = self.adaptive_pool(c18)
        flattened = self.flatten(pooled)
        out = self.fc(flattened)
        # Explicit batch size: a size mismatch raises instead of being
        # silently folded into the batch dimension by view(-1, ...).
        out = out.view(x.size(0), self.out_channels, 2)

        return out

    def _init_weights(self):
        """Initialise conv weights uniformly in [-0.05, 0.05] and reset batch norms.

        Conv and transposed-conv biases are zeroed; batch-norm weight is set
        to 1 and bias to 0. The linear head keeps its default initialisation.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.uniform_(module.weight, -0.05, 0.05)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

In [None]:
import torch.optim as optim

# Use the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TrackNet().to(device)

# The model output and the keypoint targets are float tensors of the same
# shape holding continuous values, so this is a regression problem.
# CrossEntropyLoss would interpret dim 1 of the output as class logits and the
# float targets as class probabilities, which is meaningless for coordinates;
# mean squared error is the appropriate criterion.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:

# --- Training --- #
num_epochs = 20
best_val_loss = float('inf')  # Start at +infinity so the first epoch always counts as an improvement

for epoch in range(num_epochs):
    print(f'Epoch: {epoch}')

    # Optimisation pass over the training split.
    model.train()
    running_loss = 0.0
    for images, keypoints in train_loader:
        images, keypoints = images.to(device), keypoints.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, keypoints)
        loss.backward()
        optimizer.step()

        # Scale the batch loss by the batch size so that dividing by the
        # dataset size below gives a per-sample average (assumes the criterion
        # returns a per-batch mean).
        running_loss += images.size(0) * loss.item()

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    # --- Validation --- #
    # Same loss accumulation on the validation split, in eval mode and
    # without gradient tracking.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, keypoints in val_loader:
            images, keypoints = images.to(device), keypoints.to(device)

            outputs = model(images)
            loss = criterion(outputs, keypoints)
            val_loss += images.size(0) * loss.item()

    val_loss /= len(val_loader.dataset)
    print(f'Validation Loss: {val_loss:.4f}')

    # Checkpoint only when the validation loss improves on the best seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Modello salvato con Validation Loss migliorato: {best_val_loss:.4f}')