## Imports

In [44]:
# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
from torchvision import transforms
from tqdm import trange
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Torch device:", device) # Quick check to see if we're using GPU or CPU.

import optuna
from skimage.metrics import structural_similarity as ssim
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from pathlib import Path

# Custom imports
import dataset.download_and_preprocess as dl
from dataset.dataloader import KTHDataset
from autoencoder.autoencoder import AutoencoderModel



import matplotlib.pyplot as plt
from IPython.display import clear_output


Torch device: cuda


## Download the dataset and pre-process the dataset

In [45]:
action_space = ['walking', 'running', 'jogging', 'boxing', 'handwaving', 'handclapping']
for action in action_space:
    should_extract = dl.download_and_extract(action, overwrite=False)
    dl.extract_and_save_frames(action, extraction=should_extract)

## Split the data into train, test, validate

In [46]:
from random import randint
image_file_names = []
labels = []

# Create a list of all the image file names and their corresponding action labels
for action in action_space:
    folder = Path("dataset") / "KTH_data" / action
    num = randint(8000, 12000)
    i= 0
    for f in folder.glob("*.pt"):
        if i > num:
            break
        i +=1
        image_file_names.append(f)
        labels.append(action)
        
paths = np.array(image_file_names)
labels = np.array(labels)

In [47]:
# Split the dataset into training, validation, and test sets
# 70% training, 15% validation, 15% test
X_train, X_temp, y_train, y_temp = train_test_split(
    paths, labels, 
    stratify=labels, 
    test_size=0.3, 
    random_state=42
)
# Split the temporary set into validation and test sets
# 50% of the temporary set for validation and 50% for testing (again, 15% each overall)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, 
    stratify=y_temp, 
    test_size=0.5, 
    random_state=42
)


# Data augmentation for training
train_transform = transforms.Compose([
    transforms.Lambda(lambda x: torch.from_numpy(x) if isinstance(x, np.ndarray) else x),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5)
])

val_test_transform = transforms.Compose([
    transforms.Lambda(lambda x: torch.from_numpy(x) if isinstance(x, np.ndarray) else x)
])


train_dataset = KTHDataset(X_train, y_train, transform=train_transform)
val_dataset = KTHDataset(X_val, y_val, transform=val_test_transform)
test_dataset = KTHDataset(X_test, y_test, transform=val_test_transform)

# Create sample weights based on training label frequency

# Map class names to integer indices
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(action_space)}
y_train_indices = np.array([class_to_idx[label] for label in y_train])

class_sample_count = np.bincount(y_train_indices)
class_weights = 1. / class_sample_count
sample_weights = class_weights[y_train_indices]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

## Training loop helper functions

In [48]:
def ssim_accuracy_percent(output, target):
    """
    Computes SSIM-based accuracy as a percentage.

    Parameters:
        output (torch.Tensor): Reconstructed images (B, C, H, W), values in [0, 1]
        target (torch.Tensor): Ground truth images (B, C, H, W), values in [0, 1]

    Returns:
        float: SSIM-based accuracy in [0, 100]%
    """

    output_np = output.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()

    ssim_scores = []

    for i in range(output_np.shape[0]):
        out_img = output_np[i, 0]  # Extract 2D image
        tgt_img = target_np[i, 0]
        score = ssim(tgt_img, out_img, data_range=1.0)
        ssim_scores.append(score)

    return 100 * np.mean(ssim_scores)


def plot_model_metrics(train_losses, val_losses, train_accuracies, val_accuracies, epochs:int):
    """
    Plots training and validation losses and accuracies.

    Parameters:
        train_losses (list): List of training losses.
        val_losses (list): List of validation losses.
        train_accuracies (list): List of training accuracies.
        val_accuracies (list): List of validation accuracies.
    """
    
    clear_output(wait=True)  # 🧹 Clear previous plot output
    

    fig, ax1 = plt.subplots(figsize=(8, 5))
    ax1.set_xlim(0, epochs)
    
    # Plot losses (dotted, right y-axis)
    ax2 = ax1.twinx()
    ax2.plot(train_losses, label='Train Loss', color='tab:blue', linestyle='--')
    ax2.plot(val_losses, label='Val Loss', color='tab:orange', linestyle='--')
    ax2.set_ylabel('Loss')
    ax2.yaxis.set_label_position("left")
    ax2.yaxis.tick_left()

    # Plot SSIM accuracy (left y-axis)
    ax1.plot(train_accuracies, label='Train Accuracy', color='tab:blue')
    ax1.plot(val_accuracies, label='Val Accuracy', color='tab:orange')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    # ax1.set_ylim(0, 100)  # Set y-axis limits for accuracy
    ax1.set_yticks(np.arange(50, 101, 10))  # Set y-ticks for accuracy
    ax1.yaxis.set_label_position("right")
    ax1.yaxis.tick_right()

    # Legends
    lines_1, labels_1 = ax1.get_legend_handles_labels()
    lines_2, labels_2 = ax2.get_legend_handles_labels()
    ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='upper right')

    ax1.set_title('Autoencoder Training: SSIM Accuracy & Loss')
    plt.show()

def train_autoencoder(model, train_loader, val_loader):
    optimizer = torch.optim.Adam(model.parameters(), lr=model.learning_rate)
    loss_fn = nn.MSELoss()

    train_accuracies = []
    train_losses = []

    val_accuracies = []
    val_losses = []

    #for epoch in range(model.epochs):
    for epoch in trange(model.epochs, desc="Epochs"):
        model.trained_epochs += 1
        # Training
        model.train()
        running_loss = 0.0
        running_accuracy = 0.0
        for images, _ in train_loader:
            images = images.unsqueeze(1).to(device)
            outputs = model(images)
            
            loss = loss_fn(outputs, images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
            running_accuracy += ssim_accuracy_percent(outputs, images) * images.size(0)


        epoch_accuracy = running_accuracy / len(train_loader.dataset)
        epoch_loss = running_loss / len(train_loader.dataset)
        
        train_accuracies.append(epoch_accuracy)
        train_losses.append(epoch_loss)

        # Validation
        model.eval()
        val_loss = 0.0
        val_accuracy = 0.0
        with torch.no_grad():
            for images, _ in val_loader:
                images = images.unsqueeze(1).to(device)
                outputs = model(images)
                loss = loss_fn(outputs, images)
                val_loss += loss.item() * images.size(0)
                val_accuracy += ssim_accuracy_percent(outputs, images) * images.size(0)
        
        val_loss /= len(val_loader.dataset)
        val_losses.append(val_loss)

        val_accuracy = val_accuracy / len(val_loader.dataset)
        val_accuracies.append(val_accuracy)
        

        # (basic) early stopping implementation
        if len(val_losses) > 10:
            # If validation loss hasn't improved in the last 10 epochs, stop training
            if min(val_losses[-10:]) > min(val_losses[:-10]):
                print(f"Early stopping at epoch {epoch+1}")
                break
        clear_output(wait=True)
        plot_model_metrics(train_losses, val_losses, train_accuracies, val_accuracies, model.epochs)
        
    return train_losses, val_losses, train_accuracies, val_accuracies

In [49]:
def optuna_optimization(trial:optuna.Trial):
    """
    Objective function for Optuna optimization.

    Parameters:
        trial (optuna.Trial): Optuna trial object.

    Returns:
        float: Validation loss.
    """
    
    latent_dim = trial.suggest_int("latent_dim", 100, 5000, step = 100)
    batch_size = 512
    learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-2)
    
    model = AutoencoderModel(latent_dim=latent_dim, epochs=250, batch_size=batch_size, learning_rate=learning_rate).to(device)

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=model.batch_size, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=model.batch_size, shuffle=False)

    _, val_losses, _, _ = train_autoencoder(model, train_loader, val_loader)
    
    best_loss = min(val_losses)
    
    model.save(filename=f"model_trial_{trial.number}_loss_{best_loss:.4f}.pt")

    return min(best_loss)  # Return the minimum validation loss

In [None]:
study = optuna.create_study(direction="minimize")
study.optimize(optuna_optimization, n_trials=10)
best_trial = study.best_trial
best_trial_num = best_trial.number
best_loss = best_trial.value

# File path matches the pattern we saved earlier
best_model_path = f"model_trial_{best_trial_num}_loss_{best_loss:.4f}.pt"

print("Best trial params:", best_trial.params)
print("Best trial loss:", best_loss)
print("Saved model path:", best_model_path)



[I 2025-05-19 11:44:12,538] A new study created in memory with name: no-name-fe5fb683-d04b-4ff9-a515-9cfe3f80f28f
  learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-2)
Epochs:   0%|          | 0/250 [00:00<?, ?it/s]

In [None]:
import random

model.eval()
with torch.no_grad():
    # Grab 8 random indices from the test set
    indices = random.sample(range(len(test_dataset)), 8)
    sample_imgs = torch.stack([test_dataset[i][0] for i in indices])
    sample_imgs = sample_imgs.unsqueeze(1).to(device)
    reconstructions = model(sample_imgs)

fig, axs = plt.subplots(2, 8, figsize=(15, 4))
for i in range(8):
    axs[0, i].imshow(sample_imgs[i, 0].cpu(), cmap='gray')
    axs[1, i].imshow(reconstructions[i, 0].cpu(), cmap='gray')
    axs[0, i].axis('off')
    axs[1, i].axis('off')
axs[0, 0].set_title("Originals")
axs[1, 0].set_title("Reconstructions")
plt.show()
