# Tennis Stroke Classification: Model Development

### 0. Import PyTorch and Set Up Device-Agnostic Code

In [None]:
import torch
from torch import nn

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# falling back to the CPU. `device` is a plain string accepted by `.to(...)`.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
device

In [None]:
# Release memory cached by the MPS allocator. This only makes sense when the
# MPS backend is available, so skip it on CPU-only or CUDA machines instead of
# calling into torch.mps unconditionally.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

### 1. Data Preprocessing and Data Exploration

In [None]:
import os
def walkthrough_data(dir_path):
    """Print a one-line summary for every directory under ``dir_path``.

    Walks ``dir_path`` with ``os.walk`` (top-down, starting at ``dir_path``
    itself). For each directory visited, prints how many sub-directories and
    how many files it directly contains. Nothing is returned.

    Args:
        dir_path: Root directory to walk (str or path-like). A path that does
            not exist yields no output, because ``os.walk`` produces nothing
            for it.
    """
    # Use a distinct loop name so the ``dir_path`` argument is not rebound
    # on every iteration.
    for current_dir, dirnames, filenames in os.walk(dir_path):
        # `filenames` holds every file in `current_dir`, whatever its type.
        print(f"{len(dirnames)} directories and {len(filenames)} images in {current_dir}")

walkthrough_data('dataset')

In [None]:
# Dataset root and the directories holding its train / test splits
from pathlib import Path

image_path = Path("dataset")
train_dir, test_dir = (image_path / split for split in ("train_set", "test_set"))

train_dir, test_dir

### 2. Visualize Images using Matplotlib

In [None]:
import random
from PIL import Image

# Set seed
# Seed Python's RNG so the same sample image is chosen on every run
random.seed(73)

# Every .jpeg sitting in a sub-sub-directory of the dataset root
# (e.g. <split>/<class>/<file>.jpeg)
image_path_list = [*image_path.glob("*/*/*.jpeg")]

# Pick one at random; the name of its parent folder is used as its class
random_image_path = random.choice(image_path_list)
image_class = random_image_path.parent.stem

img = Image.open(random_image_path)
print(f'Image Class: {image_class}, Image Height: {img.height}, Image Width: {img.width}')
img

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Convert the PIL image into a NumPy array
# (expected layout for RGB: [height, width, color_channels])
img_as_array = np.asarray(img)

# Draw it large, with class and array shape in the title and no axis ticks
plt.figure(figsize=(10, 7))
plt.imshow(img_as_array)
plt.title(f"Image Class: {image_class}, Image Shape: {img_as_array.shape} -> [height, width, color_channels]")
plt.axis("off")

### 3. Transforming Data into Tensors

In [None]:
def center_crop_square(img: Image.Image) -> Image.Image:
    """Return the largest square centred in ``img``.

    The side length is the shorter of the two image dimensions. The excess on
    the longer dimension is trimmed from both sides using floor division, so
    an odd excess leaves the extra pixel trimmed from the right/bottom.
    """
    w, h = img.size
    side = w if w < h else h
    x0, y0 = (w - side) // 2, (h - side) // 2
    return img.crop((x0, y0, x0 + side, y0 + side))

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def _square_and_resize():
    """Return fresh instances of the deterministic steps both pipelines share."""
    return [transforms.Lambda(center_crop_square), transforms.Resize((320, 320))]


# Training pipeline: shared crop/resize, light random augmentation, then tensor conversion
train_transform = transforms.Compose(
    _square_and_resize()
    + [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: the same crop/resize with no random augmentation
test_transform = transforms.Compose([*_square_and_resize(), transforms.ToTensor()])

In [None]:
def plot_transformed_images(image_paths: list, transform, seed = 73, n = 3):
    """Show ``n`` randomly chosen images beside their transformed versions.

    One figure is drawn per image. The original is on the left, titled with its
    PIL size. ``transform``'s output is on the right, titled with its tensor
    shape. The figure title is the name of the image's parent directory.

    Args:
        image_paths: Candidate image paths. These are ``pathlib.Path``-like,
            since ``.parent.stem`` is read from each.
        transform: Callable taking a PIL image and returning a channels-first
            tensor; the result is permuted to channels-last for ``imshow``.
        seed: Seed for choosing the images. A private ``random.Random`` is
            used, so the choice matches ``random.seed(seed)`` followed by
            ``random.sample`` without reseeding the global ``random`` module.
        n: Number of images to show. It is capped at ``len(image_paths)``
            instead of letting ``random.sample`` raise ``ValueError``.
    """
    rng = random.Random(seed)
    chosen = rng.sample(image_paths, k=min(n, len(image_paths)))
    for path in chosen:
        with Image.open(path) as f:
            fig, ax = plt.subplots(nrows=1, ncols=2)
            ax[0].imshow(f)
            ax[0].set_title(f"Original\nSize: {f.size}")
            ax[0].axis(False)

            # Convert to RGB before transforming. Grayscale or palette files
            # would otherwise become a 1-channel tensor, which imshow cannot
            # display after the permute. The model's first Conv2d also expects
            # 3 input channels.
            transformed_image = transform(f.convert("RGB")).permute(1, 2, 0)
            ax[1].imshow(transformed_image)
            ax[1].set_title(f"Transformed Image\nSize: {transformed_image.shape}")
            ax[1].axis(False)

            fig.suptitle(f"Class Name: {path.parent.stem}", fontsize=16)


plot_transformed_images(image_paths = image_path_list, transform = train_transform)

### 4. Loading image data using `ImageFolder`

In [None]:
from torchvision import datasets
# Build labelled datasets from the split directories. ImageFolder derives each
# sample's label from the name of the class sub-folder it sits in.
train_data = datasets.ImageFolder(
    root=train_dir,
    transform=train_transform,
    target_transform=None,
)
test_data = datasets.ImageFolder(
    root=test_dir,
    transform=test_transform,
)

# Each dataset builds its own class -> index mapping from its own folders. If
# the splits do not contain exactly the same class folders, the same integer
# label would mean different strokes in training and evaluation, so fail loudly.
if test_data.class_to_idx != train_data.class_to_idx:
    raise ValueError(
        f"train/test class mappings differ: {train_data.class_to_idx} vs {test_data.class_to_idx}"
    )

train_data, test_data

In [None]:
# Class names of the training set, as exposed by torchvision ImageFolder's
# `classes` attribute; displayed as the cell output.
class_names = train_data.classes
class_names

In [None]:
# The training set's `class_to_idx` mapping (class name -> integer label,
# per torchvision's ImageFolder API); displayed as the cell output.
class_dict = train_data.class_to_idx
class_dict

In [None]:
# Fetch the first training sample once and unpack it. Each dataset index
# loads the file and runs train_transform, including its random augmentations.
# Indexing twice (`train_data[0][0], train_data[0][1]`) would do that work twice.
img, label = train_data[0]
img, label

### 5. Turn loaded images into `DataLoader`s

In [None]:
# Turn the train and test datasets into DataLoaders
from torch.utils.data import DataLoader
BATCH_SIZE = 32
# os.cpu_count() may return None when the count is undetermined; DataLoader
# needs an int, and 0 means "load in the main process".
NUM_WORKERS = os.cpu_count() or 0
# NOTE(review): when DataLoader workers are started with "spawn" (the macOS
# default), functions defined inside a notebook, such as `center_crop_square`
# used by the transforms, may fail to unpickle in the workers. If iteration
# raises an AttributeError about `__main__`, set NUM_WORKERS = 0 or move the
# function into an importable module.

train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    drop_last=True,  # every training step sees a full batch
)

test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    # Keep the final partial batch. Dropping it would silently exclude up to
    # BATCH_SIZE - 1 test images from every evaluation.
    drop_last=False,
)

In [None]:
# Number of batches each loader yields per pass (per PyTorch's DataLoader `len`).
len(train_dataloader), len(test_dataloader)

### 6. Creating a CNN Model

In [None]:
# Model Class Definition
import torch
from torch import nn

class TennisStrokeClassification(nn.Module):
    '''
    Tennis Stroke Multiclass Classification Convolutional Neural Network (CNN) (Based on VGG Model).

    Expects RGB inputs, hence 3 in_channels on the first convolutional layer.

    `self.features` is three VGG-style blocks. Each block stacks
    `nn.Conv2d()` (3x3, padding 1, so height and width are unchanged),
    `nn.BatchNorm2d()` and `nn.ReLU()` layers. Each block ends with
    `nn.MaxPool2d(2, 2)`, which halves height and width. The second block is
    additionally followed by `nn.Dropout(0.2)`.

    `self.adaptive_pool` average-pools the 256-channel feature map to a fixed
    4x4 grid. As a result, the classifier always receives 256 * 4 * 4 features
    for any input at least 8x8.

    `self.classifier` flattens those features, applies dropout, and then three
    linear layers (256 * 4 * 4 -> 1024 -> 512 -> `num_classes`) with ReLU in
    between. It returns raw logits (no softmax).

    Args:
        num_classes: Number of output logits (default 4).
    '''
    def __init__(self, num_classes: int = 4):
        super().__init__()
        self.features = nn.Sequential(
            # Input: [N, 3, H, W] (320 x 320 with this notebook's transforms)
            nn.Conv2d(3, 64, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2),

            # [N, 64, H/2, W/2]  (e.g. 160 x 160)
            nn.Conv2d(64, 128, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2),
            nn.Dropout(0.2),

            # [N, 128, H/4, W/4]  (e.g. 80 x 80)
            nn.Conv2d(128, 256, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2),
            # [N, 256, H/8, W/8]  (e.g. 40 x 40)
        )
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
        # [N, 256, 4, 4]

        # Classifier
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p = 0.4),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Run the forward pass and return logits of shape [N, num_classes]."""
        x = self.features(x)
        if x.device.type == "mps":
            # Adaptive pooling is done on the CPU for MPS tensors (presumably
            # because MPS adaptive pooling only supports evenly dividing sizes;
            # TODO confirm). The result is moved back to the device the
            # features came from.
            source_device = x.device
            x = self.adaptive_pool(x.cpu()).to(source_device)
        else:
            x = self.adaptive_pool(x)
        return self.classifier(x)

In [None]:
# Build the network with its default number of classes and move its
# parameters onto the selected device (Module.to returns the module itself).
tennis_stroke_model = TennisStrokeClassification().to(device)

### 7. Creating Training and Testing Loop Functions

In [None]:
from torch.optim.lr_scheduler import StepLR
# Adam with weight decay. StepLR multiplies the learning rate by 0.5 after
# every 5 calls to scheduler.step(). Cross-entropy loss is used on raw logits.
adam_optimizer = torch.optim.Adam(
    params=tennis_stroke_model.parameters(),
    lr=0.0001,
    weight_decay=1e-4,
)
scheduler = StepLR(adam_optimizer, step_size=5, gamma=0.5)
loss_fn = nn.CrossEntropyLoss()

def accuracy_fn(pred, true):
    """Return the percentage (0-100) of positions where ``pred`` equals ``true``.

    Both arguments are tensors of class indices (predicted and target). The
    denominator is ``len(pred)``.
    """
    n_correct = (pred == true).sum().item()
    return n_correct / len(pred) * 100

In [None]:
#train_step() takes the training dataloader, test_step() takes the testing dataloader
def train_step(
        model: torch.nn.Module,
        dataloader: torch.utils.data.DataLoader,
        seed: int,
        loss_function: torch.nn.Module,
        optimization_function: torch.optim.Optimizer,
        accuracy_function,
):
    """Train ``model`` for one pass over ``dataloader`` and report the averages.

    Seeds torch's global RNG with ``seed`` before iterating, so the random draws
    made during this pass are reproducible for a given seed.

    Loss and accuracy are averaged per sample: each batch's values are weighted
    by its size, so a smaller final batch does not skew the result. This
    assumes ``loss_function`` returns a batch mean (the ``nn.CrossEntropyLoss``
    default) and ``accuracy_function`` returns a per-batch percentage.

    Returns:
        ``(train_loss, train_acc)``, which are also printed.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    torch.manual_seed(seed)
    # Send batches to wherever the model's parameters live, so inputs and
    # weights always share a device.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()  # set once; nothing in the loop changes the mode
    loss_sum, acc_sum, n_samples = 0.0, 0.0, 0
    for X_batch, y_batch in dataloader:
        X_batch, y_batch = X_batch.to(target_device), y_batch.to(target_device)
        batch_size = y_batch.size(0)
        # Forward pass and loss
        logits = model(X_batch)
        loss = loss_function(logits, y_batch)
        # Backpropagation and parameter update
        optimization_function.zero_grad()
        loss.backward()
        optimization_function.step()
        # Accumulate size-weighted metrics
        loss_sum += loss.item() * batch_size
        acc_sum += accuracy_function(logits.argmax(dim=1), y_batch) * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_step: dataloader yielded no samples")
    train_loss = loss_sum / n_samples
    train_acc = acc_sum / n_samples
    print(f"Train Loss: {train_loss} | Train Accuracy: {train_acc}")
    return train_loss, train_acc

def test_step(
        model: torch.nn.Module,
        loss_function: torch.nn.Module,
        seed: int,
        accuracy_function,
        dataloader: torch.utils.data.DataLoader
):
    torch.manual_seed(seed)
    test_loss_total, test_accuracy_total = 0, 0
    # Set to evaluation mode
    model.eval()
    with torch.inference_mode():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            # Forward pass
            y_test_preds_logits = model(X_batch)
            # Loss
            loss = loss_function(y_test_preds_logits, y_batch)
            test_loss_total += loss.item()
            # Accuracy
            accuracy = accuracy_function(y_test_preds_logits.argmax(dim = 1), y_batch)
            test_accuracy_total += accuracy
        test_acc = test_accuracy_total / len(dataloader)
        test_loss = test_loss_total / len(dataloader)
        print(f"Test Loss: {test_loss} | Test Accuracy: {test_acc}")

In [None]:
# Full training / evaluation loop over several epochs
def train_test_loop(
        model: torch.nn.Module,
        epochs: int,
        device,
        optimizer: torch.optim.Optimizer,
        scheduling_function: "torch.optim.lr_scheduler.LRScheduler",
        loss_function: torch.nn.Module,
        train_loader: "torch.utils.data.DataLoader | None" = None,
        test_loader: "torch.utils.data.DataLoader | None" = None,
        base_seed: int = 73,
):
    """Train and evaluate ``model`` for ``epochs`` epochs.

    Each epoch runs one ``train_step`` and one ``test_step``, then advances the
    learning-rate scheduler once.

    Args:
        model: Network to train; it is moved to ``device`` first.
        epochs: Number of epochs to run.
        device: Device passed to ``model.to``.
        optimizer: Optimizer used by ``train_step``.
        scheduling_function: LR scheduler; ``.step()`` is called once per epoch.
        loss_function: Loss used for both training and evaluation.
        train_loader: Training loader. Defaults to the notebook's global
            ``train_dataloader``.
        test_loader: Evaluation loader. Defaults to the notebook's global
            ``test_dataloader``.
        base_seed: Epoch ``e`` is seeded with ``base_seed + e``.
    """
    if train_loader is None:
        train_loader = train_dataloader
    if test_loader is None:
        test_loader = test_dataloader

    model = model.to(device)
    for epoch in range(epochs):
        print(f"Epoch: {epoch} ==============================")
        # Use a different but reproducible seed each epoch. train_step reseeds
        # torch's global RNG, so reusing one constant would replay the same
        # random draws (e.g. shuffle order and augmentations) every epoch.
        epoch_seed = base_seed + epoch
        train_step(
            model = model,
            dataloader = train_loader,
            seed = epoch_seed,
            loss_function = loss_function,
            optimization_function = optimizer,
            accuracy_function = accuracy_fn,
        )
        test_step(
            model = model,
            loss_function = loss_function,
            seed = epoch_seed,
            accuracy_function = accuracy_fn,
            dataloader = test_loader
        )
        scheduling_function.step()

# Train for 5 epochs with the optimizer, scheduler and loss configured above
train_test_loop(
    model=tennis_stroke_model,
    epochs=5,
    device=device,
    optimizer=adam_optimizer,
    scheduling_function=scheduler,
    loss_function=loss_fn,
)