## Data Processing

In [5]:
import torch
from torch import nn
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
from PIL import Image
from torch.optim.lr_scheduler import StepLR

device = 'mps' if torch.backends.mps.is_available() else 'cpu'

image_path = Path("dataset")

train_dir = image_path / "train_set"
test_dir = image_path / "test_set"

def center_crop_square(img: Image.Image) -> Image.Image:
    """Crops the center square from a PIL image."""
    width, height = img.size
    min_dim = min(width, height)
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    right = left + min_dim
    bottom = top + min_dim
    return img.crop((left, top, right, bottom))

train_transform = transforms.Compose([
    transforms.Lambda(center_crop_square),
    transforms.Resize((320, 320)),
    transforms.RandomHorizontalFlip(p = 0.5),
    transforms.RandomRotation(degrees = 15),
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.Lambda(center_crop_square),     
    transforms.Resize((320, 320)),            
    transforms.ToTensor(),
])

train_data = datasets.ImageFolder(
    root = train_dir,
    transform = train_transform,
    target_transform = None
)

test_data = datasets.ImageFolder(
    root = test_dir,
    transform = test_transform,
)

BATCH_SIZE = 32

train_dataloader = DataLoader(
    dataset = train_data,
    batch_size = BATCH_SIZE,
    num_workers = 0,
    shuffle = True,
    drop_last = True
)

test_dataloader = DataLoader(
    dataset = test_data,
    batch_size = BATCH_SIZE,
    num_workers = 0,
    shuffle = False,
    drop_last = True
)

## Model

In [6]:
class ConvResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1):
        super().__init__()
        
        # Convolutional path
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False),
            nn.BatchNorm2d(out_channels)
        )
        
        # Shortcut path (identity/nothing or 1x1 conv to match shapes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = stride, bias = False),
                nn.BatchNorm2d(out_channels)
            )
        
        self.relu = nn.ReLU(inplace = True)
    
    def forward(self, x):
        out = self.conv_layers(x)
        out += self.shortcut(x)   # skip connection
        out = self.relu(out)      # apply ReLU() after the shortcut
        return out
    
class TennisStrokeClassifier(nn.Module):
    def __init__(self, num_classes = 4):
        super().__init__()
        self.layer1 = ConvResidualBlock(in_channels = 3, out_channels = 64)
        self.layer2 = ConvResidualBlock(in_channels = 64, out_channels = 128, stride = 2)
        self.layer3 = ConvResidualBlock(in_channels = 128, out_channels = 256, stride = 2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(256, num_classes)
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


In [7]:
tennis_stroke_model = TennisStrokeClassifier()
tennis_stroke_model = tennis_stroke_model.to(device)

sgd_optimizer = torch.optim.SGD(params = tennis_stroke_model.parameters(), lr = 0.01, momentum = 0.9, weight_decay = 1e-4)

sgd_scheduler = StepLR(
    sgd_optimizer,  
    step_size = 5,  
    gamma = 0.5    
)

loss_fn = nn.CrossEntropyLoss()

def accuracy_fn(pred, true):
    correct = torch.eq(pred, true).sum().item()
    return correct / len(pred) * 100

## Loop

In [8]:
def train_step(
        model: torch.nn.Module,
        dataloader: torch.utils.data.DataLoader,
        seed: int, 
        loss_function: torch.nn.Module,
        optimization_function: torch.optim.Optimizer,
        accuracy_function,
):
    torch.manual_seed(seed)
    train_loss_total, train_acc_total = 0, 0
    for X_batch, y_batch in dataloader:
        # Move to best device
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        model.train()
        # Forward pass
        y_train_preds_logits = model(X_batch)
        # Loss
        loss = loss_function(y_train_preds_logits, y_batch)
        train_loss_total += loss.item()
        # Backpropagation
        optimization_function.zero_grad()
        loss.backward()
        # Gradient Descent
        optimization_function.step()
        # Accuracy
        accuracy = accuracy_function(y_train_preds_logits.argmax(dim = 1), y_batch)
        train_acc_total += accuracy
    train_acc = train_acc_total / len(dataloader)
    train_loss = train_loss_total / len(dataloader)
    print(f"Train Loss: {train_loss} | Train Accuracy: {train_acc}")

def test_step(
        model: torch.nn.Module,
        loss_function: torch.nn.Module,
        seed: int,
        accuracy_function,
        dataloader: torch.utils.data.DataLoader
):
    torch.manual_seed(seed)
    test_loss_total, test_accuracy_total = 0, 0
    # Set to evaluation mode
    model.eval()
    with torch.inference_mode():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            # Forward pass
            y_test_preds_logits = model(X_batch)
            # Loss
            loss = loss_function(y_test_preds_logits, y_batch)
            test_loss_total += loss.item()
            # Accuracy
            accuracy = accuracy_function(y_test_preds_logits.argmax(dim = 1), y_batch)
            test_accuracy_total += accuracy
        test_acc = test_accuracy_total / len(dataloader)
        test_loss = test_loss_total / len(dataloader)
        print(f"Test Loss: {test_loss} | Test Accuracy: {test_acc}")

def train_test_loop(
        model: torch.nn.Module,
        epochs: int,
        device,
        optimizer: torch.optim.Optimizer,
        scheduling_function: torch.optim.lr_scheduler,
        loss_function: torch.nn.Module
):
    model = model.to(device)
    for epoch in range(epochs):
        print(f"Epoch: {epoch} ==============================")
        train_step(
            model = model,
            dataloader = train_dataloader,
            seed = 73,
            loss_function = loss_function,
            optimization_function = optimizer,
            accuracy_function = accuracy_fn,
        )
        test_step(
            model = model,
            loss_function = loss_function,
            seed = 73,
            accuracy_function = accuracy_fn,
            dataloader = test_dataloader
        )
        scheduling_function.step()

## Training

In [9]:
train_test_loop(model = tennis_stroke_model, epochs = 50, device = device, optimizer = sgd_optimizer, scheduling_function = sgd_scheduler, loss_function = loss_fn)

Train Loss: 1.2852159686710523 | Train Accuracy: 37.97554347826087
Test Loss: 1.235659846663475 | Test Accuracy: 35.416666666666664
Train Loss: 1.2196418681870336 | Train Accuracy: 41.71195652173913
Test Loss: 1.2449156753718853 | Test Accuracy: 41.25
Train Loss: 1.1709506265495135 | Train Accuracy: 44.22554347826087
Test Loss: 1.1747179809957742 | Test Accuracy: 36.875
Train Loss: 1.1529144048690796 | Train Accuracy: 44.49728260869565
Test Loss: 1.440933978309234 | Test Accuracy: 39.583333333333336
Train Loss: 1.1368577972702358 | Train Accuracy: 45.85597826086956
Test Loss: 1.159144600853324 | Test Accuracy: 39.375
Train Loss: 1.0930090948291447 | Train Accuracy: 48.77717391304348
Test Loss: 1.0799746641889214 | Test Accuracy: 48.75
Train Loss: 1.0782781722752943 | Train Accuracy: 49.18478260869565
Test Loss: 1.0908863328397274 | Test Accuracy: 40.208333333333336
Train Loss: 1.0738156360128652 | Train Accuracy: 48.91304347826087
Test Loss: 1.0989935209353765 | Test Accuracy: 38.75
Tr

KeyboardInterrupt: 

In [13]:
import gc, torch
torch.mps.empty_cache()
gc.collect()

import multiprocessing as mp
mp.active_children() 
for p in mp.active_children():
    p.terminate()