## 0. Imports

In [1]:
# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
from torchvision import transforms
from tqdm import trange
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Torch device:", device) # Quick check to see if we're using GPU or CPU.

import optuna
from skimage.metrics import structural_similarity as ssim
from sklearn.model_selection import train_test_split
import numpy as np
import cv2
from pathlib import Path


# Custom imports
import dataset.download_and_preprocess as dl
from dataset.dataloader import KTHDataset
from autoencoder.autoencoder import AutoencoderModel


from collections import defaultdict
import matplotlib.pyplot as plt
from IPython.display import clear_output

# for reproducibility: seed every global RNG this notebook draws from.
# NumPy's RNG drives the frame sampling (np.random.choice); torch's global RNG is
# used by torchvision's random flips and by torch's random number functions.
np.random.seed(42)
torch.manual_seed(42)

Torch device: cuda


## 1. Handling the dataset
### A. Download the KTH dataset and pre-process it into torch tensors
Due to limited computing power we limit our study to the running videos of the KTH dataset only.

In [2]:
# Fetch and unpack the KTH data. overwrite=False presumably reuses an existing
# download instead of fetching it again — confirm in dataset/download_and_preprocess.py.
dl.download_and_extract(overwrite=False)
# Presumably converts the extracted videos into per-frame .pt tensors (the files the
# sampling code below globs for) — confirm in dataset/download_and_preprocess.py.
dl.extract_and_save_frames()

### B. Define a function for uniformly sampling the data based on frame/video count

In [3]:
def uniform_sampler(num_samples:int = 512) -> tuple[np.ndarray, np.ndarray]:
    """
    Samples frames from the KTH running dataset so that each video contributes a number of
    frames proportional to its share of all frames, for a total of roughly `num_samples` frames.

    Frames are read from `dataset/KTH_data_running/*.pt`; the video a frame belongs to is the
    part of its file name before `_frame_`. Sampling uses NumPy's global RNG without replacement,
    so seeding NumPy (np.random.seed) makes the result repeatable.

    Parameters:
        num_samples (int): Total number of frames to sample across all videos (default: 512).
            Because the per-video count is rounded, the actual total can differ by a few frames.
            A video never contributes more frames than it has.

    Returns:
        samples (np.ndarray): Array of sampled frame file paths (Path objects).
        sampled_video_label (np.ndarray): Array of video labels corresponding to each sampled frame.
    """
    folder = Path("dataset") / "KTH_data_running"

    # Group the frames per video in a single pass over the folder. Sorting removes the
    # dependency on the filesystem's listing order, so a fixed seed selects the same frames.
    frames_per_video = defaultdict(list)
    for frame_path in sorted(folder.glob("*.pt")):
        video_name = frame_path.name.split("_frame_")[0]
        frames_per_video[video_name].append(frame_path)

    total_frames = sum(len(frames) for frames in frames_per_video.values())

    samples = []
    sampled_video_label = []
    for video, frames in frames_per_video.items():
        # Proportional share of this video; rounding may shift the grand total slightly.
        target_samples = round((len(frames) / total_frames) * num_samples)
        # Sampling without replacement cannot yield more frames than the video contains.
        target_samples = min(target_samples, len(frames))

        chosen = np.random.choice(len(frames), size=target_samples, replace=False)
        samples.extend(frames[idx] for idx in chosen)
        sampled_video_label.extend([video] * target_samples)

    return np.array(samples), np.array(sampled_video_label)


### C. Define a function that splits the uniformly sampled dataset into train, validation, and test sets.
We split the data into 70% training, 15% validation, and 15% testing.

In [None]:
def dataset_splitter(samples, sampled_video_label, valtest_size:float=0.3) -> tuple[KTHDataset, KTHDataset, KTHDataset]:
    """
    Partitions the sampled frames into training, validation, and test subsets and wraps
    each subset in a KTHDataset carrying the matching transform pipeline.

    The first split holds out `valtest_size` of the frames, stratified by video label.
    The held-out part is then cut in half (not stratified) into validation and test.
    Both splits use random_state=42.

    Parameters:
        samples (np.ndarray): Array of sampled frame file paths (Path objects).
        sampled_video_label (np.ndarray): Array of video labels corresponding to each sampled frame.
        valtest_size (float): Fraction of the data held out for validation plus test (default: 0.3).

    Returns:
        train_dataset (KTHDataset): Training dataset with random horizontal/vertical flips.
        val_dataset (KTHDataset): Validation dataset without augmentation.
        test_dataset (KTHDataset): Test dataset without augmentation.
    """

    def _as_tensor(frame):
        # NumPy frames are converted to tensors; anything else is passed through unchanged.
        return torch.from_numpy(frame) if isinstance(frame, np.ndarray) else frame

    # Stage 1: hold out the validation+test portion, keeping each video's share intact.
    train_paths, held_out_paths, train_labels, held_out_labels = train_test_split(
        samples,
        sampled_video_label,
        test_size=valtest_size,
        stratify=sampled_video_label,
        random_state=42,
    )

    # Stage 2: halve the held-out portion into validation and test.
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        held_out_paths,
        held_out_labels,
        test_size=0.5,
        random_state=42,
    )

    # Only the training frames are augmented.
    augmenting_pipeline = transforms.Compose([
        transforms.Lambda(_as_tensor),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
    ])
    plain_pipeline = transforms.Compose([
        transforms.Lambda(_as_tensor),
    ])

    return (
        KTHDataset(train_paths, train_labels, transform=augmenting_pipeline),
        KTHDataset(val_paths, val_labels, transform=plain_pipeline),
        KTHDataset(test_paths, test_labels, transform=plain_pipeline),
    )


## 2. Training the model

### A. Plotter functions
These two functions are only used for scoring the model and plotting its performance.

In [None]:
def ssim_accuracy_percent(output, target) -> float:
    """
    Scores a batch of reconstructions by their mean SSIM against the targets, as a percentage.

    SSIM is evaluated per sample on channel 0 only, with a data range of 1.0,
    and the per-sample scores are averaged.

    Parameters:
        output (torch.Tensor): Reconstructed images (B, C, H, W), values in [0, 1]
        target (torch.Tensor): Ground truth images (B, C, H, W), values in [0, 1]

    Returns:
        float: Mean SSIM over the batch multiplied by 100
    """
    # Detach from the graph and move to host memory so skimage can work on NumPy arrays.
    reconstructed = output.detach().cpu().numpy()
    reference = target.detach().cpu().numpy()

    per_image_scores = [
        ssim(reference[idx, 0], reconstructed[idx, 0], data_range=1.0)
        for idx in range(reconstructed.shape[0])
    ]

    return 100 * np.mean(per_image_scores)


def plot_model_metrics(train_losses, val_losses, train_accuracies, val_accuracies, epochs:int, save_path=None, title='fill') -> None:
    """
    Draws the loss and accuracy curves of a training run in one figure with two y-axes.

    Losses are drawn dashed with their scale on the left; accuracies are drawn solid with
    their scale on the right, fixed to the range 30-100.

    Parameters:
        train_losses (list): Training loss per epoch.
        val_losses (list): Validation loss per epoch.
        train_accuracies (list): Training accuracy per epoch.
        val_accuracies (list): Validation accuracy per epoch.
        epochs (int): Number of epochs that were actually trained; each list must hold this many values.
        save_path (str): Where to save the figure. If None, the figure is shown instead.
        title (str): Title of the plot.
    """
    fig, accuracy_axis = plt.subplots(figsize=(8, 5))
    epoch_numbers = np.arange(1, epochs + 1)  # epochs are numbered from 1
    loss_axis = accuracy_axis.twinx()

    # Losses: dashed curves on the twin axis, whose label and ticks are moved to the left.
    loss_curves = (
        (train_losses, 'Train Loss', 'tab:blue'),
        (val_losses, 'Val Loss', 'tab:orange'),
    )
    for values, label, colour in loss_curves:
        loss_axis.plot(epoch_numbers, values, label=label, color=colour, linestyle='--')
    loss_axis.set_ylabel('Loss')
    loss_axis.yaxis.set_label_position("left")
    loss_axis.yaxis.tick_left()

    # Accuracies: solid curves on the primary axis, whose label and ticks are moved to the right.
    accuracy_curves = (
        (train_accuracies, 'Train Accuracy', 'tab:blue'),
        (val_accuracies, 'Val Accuracy', 'tab:orange'),
    )
    for values, label, colour in accuracy_curves:
        accuracy_axis.plot(epoch_numbers, values, label=label, color=colour)
    accuracy_axis.set_xlabel('Epoch')
    accuracy_axis.set_ylabel('Accuracy')
    accuracy_axis.set_xlim(1, epochs+1)

    # Fixed accuracy scale so plots of different runs are directly comparable.
    accuracy_axis.set_ylim(30, 100)
    accuracy_axis.set_yticks(np.arange(30, 101, 10))
    accuracy_axis.yaxis.set_label_position("right")
    accuracy_axis.yaxis.tick_right()

    # One combined legend: accuracy entries first, then loss entries.
    accuracy_handles, accuracy_labels = accuracy_axis.get_legend_handles_labels()
    loss_handles, loss_labels = loss_axis.get_legend_handles_labels()
    accuracy_axis.legend(accuracy_handles + loss_handles, accuracy_labels + loss_labels, loc='upper right')
    accuracy_axis.set_title(title)

    # Save to disk when a path was given, otherwise display.
    if save_path is not None:
        plt.savefig(save_path)
        plt.close(fig)
        return
    plt.show()

### B. The actual training of the autoencoder

In [None]:
def train_autoencoder(model:AutoencoderModel, train_loader, val_loader, patience:int = 10, learning_rate:float = 1e-3):
    """
    Trains the autoencoder to reconstruct its input with the Adam optimizer and MSE loss.

    Training runs for at most `model.epochs` epochs and stops early once the validation loss
    has not improved for `patience` consecutive epochs. `model.trained_epochs` is incremented
    at the start of every epoch. Whenever the validation loss improves, the model is saved via
    `model.save(Path("autoencoder") / "models", filename="Current_best")`. The weights are not
    reloaded afterwards, so the model is returned in the state of the last epoch that ran.

    Parameters:
        model (autoencoder.AutoencoderModel): The model to train; its `epochs` attribute sets the epoch limit.
        train_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        patience (int): Number of epochs with no improvement after which training will be stopped.
        learning_rate (float): Learning rate passed to Adam.

    Returns:
        tuple: A tuple containing:
            - train_losses (list): Mean training loss per epoch.
            - val_losses (list): Mean validation loss per epoch.
            - train_accuracies (list): Mean training SSIM accuracy (%) per epoch.
            - val_accuracies (list): Mean validation SSIM accuracy (%) per epoch.
    """
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []

    lowest_val_loss = float("inf")
    stale_epochs = 0  # consecutive epochs without a new lowest validation loss

    for _epoch in trange(model.epochs, desc="Epochs"):
        model.trained_epochs += 1

        # ---- Training pass ----
        model.train()
        loss_sum = 0.0
        accuracy_sum = 0.0
        for batch, _labels in train_loader:
            # Insert the channel dimension the model expects and move the batch to the device.
            batch = batch.unsqueeze(1).to(device)
            reconstruction = model(batch)

            batch_loss = criterion(reconstruction, batch)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch mean is exact even with a smaller last batch.
            n_images = batch.size(0)
            loss_sum += batch_loss.item() * n_images
            accuracy_sum += ssim_accuracy_percent(reconstruction, batch) * n_images

        n_train = len(train_loader.dataset)
        epoch_accuracy = accuracy_sum / n_train
        epoch_loss = loss_sum / n_train
        train_accuracies.append(epoch_accuracy)
        train_losses.append(epoch_loss)

        # ---- Validation pass ----
        model.eval()
        val_loss = 0.0
        val_accuracy = 0.0
        with torch.no_grad():
            for batch, _labels in val_loader:
                batch = batch.unsqueeze(1).to(device)
                reconstruction = model(batch)
                n_images = batch.size(0)
                val_loss += criterion(reconstruction, batch).item() * n_images
                val_accuracy += ssim_accuracy_percent(reconstruction, batch) * n_images

        n_val = len(val_loader.dataset)
        val_loss /= n_val
        val_losses.append(val_loss)
        val_accuracy = val_accuracy / n_val
        val_accuracies.append(val_accuracy)

        # ---- Early stopping bookkeeping ----
        improved = val_loss < lowest_val_loss
        if improved:
            lowest_val_loss = val_loss
            stale_epochs = 0
            model.save(Path("autoencoder")/ "models", filename="Current_best")
        else:
            stale_epochs += 1

        if stale_epochs >= patience:
            break

        # Replace the previous epoch's progress line in the notebook output.
        clear_output(True)
        print(f"Loss: {val_loss:5f}\tAccuracy: {val_accuracy:4.1f}")

    return train_losses, val_losses, train_accuracies, val_accuracies

## 3. Optuna hyperparameter search for finding the best activation function

In [6]:
def optuna_optimization(trial:optuna.Trial):
    """
    Objective function for Optuna: trains an autoencoder with the activation function
    suggested by the trial and reports how well it reconstructs the validation frames.

    Each call samples 512 frames, splits them, trains for up to 20 epochs (patience 10,
    learning rate 1e-3), saves the model as `model_trial_<number>`, and writes the training
    curves to `plots/autoencoder_training/trial_<number>.png`.

    Parameters:
        trial (optuna.Trial): Optuna trial object.

    Returns:
        float: Lowest validation loss reached during training.
    """

    # Define the hyperparameters to optimize
    activation_function = trial.suggest_categorical("activation_function", ["tanh", "relu", "sigmoid", "leaky_relu"])

    # Map names to classes so that only the selected activation is instantiated.
    string_activation_function = {
        "tanh": nn.Tanh,
        "relu": nn.ReLU,
        "sigmoid": nn.Sigmoid,
        "leaky_relu": nn.LeakyReLU,
    }
    af = string_activation_function[activation_function]()

    # Initialize the model
    model = AutoencoderModel(epochs=20, a_function=af).to(device)

    # Fetch the files and split (the test split is not needed during the search)
    samples, labels = uniform_sampler(512)
    train, val, __ = dataset_splitter(samples, labels)

    # Data loaders: reshuffle the training frames every epoch so the batches differ between
    # epochs; validation order does not matter and stays fixed.
    train_loader = DataLoader(train, batch_size=256, shuffle=True)
    val_loader = DataLoader(val, batch_size=256, shuffle=False)

    # Train the model
    train_losses, val_losses, train_accuracies, val_accuracies = train_autoencoder(model, train_loader, val_loader, patience=10, learning_rate=1e-3)

    # Determine best metrics for this model
    best_loss = min(val_losses)
    best_accuracy = max(val_accuracies)

    # Save the model
    model.save(filename=f"model_trial_{trial.number}")

    # Save the plot
    save_path = Path("plots") / "autoencoder_training"
    save_path.mkdir(parents=True, exist_ok=True)  # Create directory if it doesn't exist
    save_path = save_path / f"trial_{trial.number}.png"
    plot_model_metrics(train_losses, val_losses, train_accuracies, val_accuracies, model.trained_epochs, save_path=save_path, title=f"Trial {trial.number} - Loss: {best_loss:.4f} - Acc: {best_accuracy:00.1f}")

    return best_loss  # Return the minimum validation loss

### A. Conduct the activation function search on a subset of our dataset.

In [None]:
n_trials = 4
# The study is persisted in SQLite, so re-running this cell resumes the existing study.
study = optuna.create_study(direction="minimize", study_name="Autoencoder Optimization", storage="sqlite:///autoencoder_activationstudy.db", load_if_exists=True)
studydf = study.trials_dataframe()

# Count the stored trials that did not fail; only the remainder still has to run.
if studydf.empty:
    n_finished = 0
else:
    trialdf = studydf[studydf["state"] != "FAIL"]
    n_finished = trialdf.shape[0]

# Never request a negative number of trials when the study already holds more than n_trials.
n_trials_to_complete = max(0, n_trials - n_finished)
if n_finished:
    print(f"finishing study by running {n_trials_to_complete} trials")


study.optimize(optuna_optimization, n_trials=n_trials_to_complete, n_jobs=1)
print("Best parameters:", study.best_trial.params)

# Refresh the overview so the displayed table includes the trials that just ran.
studydf = study.trials_dataframe()
studydf

Epochs:   5%|▌         | 1/20 [00:28<08:57, 28.27s/it]

Loss: 0.058740	Accuracy: 21.8


## 4. Training and/or loading our model with a 100-dimensional latent space

In [None]:
model_path = Path("autoencoder") / "models" / "model.pt"


batch_size = 256
learning_rate = 1e-3

# Fetch the files and split. The test split is kept under the name `test_dataset`
# because the visual-inspection cell further down reads it.
samples, labels = uniform_sampler(2560)
train, val, test_dataset = dataset_splitter(samples, labels)


# Data loaders: reshuffle the training frames every epoch; validation order stays fixed.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)


# The architecture is the same whether we train from scratch or restore saved weights.
model = AutoencoderModel(latent_dim=100, epochs=200, a_function=nn.LeakyReLU()).to(device)

if model_path.exists():
    model.load(model_path)
else:
    # Train the model
    train_losses, val_losses, train_accuracies, val_accuracies = train_autoencoder(model, train_loader, val_loader, patience=10, learning_rate=learning_rate)

    # Determine best metrics for this model
    best_loss = min(val_losses)
    best_accuracy = max(val_accuracies)

    # Save the model
    # NOTE(review): presumably this writes the file checked by `model_path` above — confirm in AutoencoderModel.save.
    model.save(filename="model")

    # Save the plot
    save_path = Path("plots") / "autoencoder_training"
    save_path.mkdir(parents=True, exist_ok=True)  # Create directory if it doesn't exist
    save_path = save_path / "model.png"
    plot_model_metrics(train_losses, val_losses, train_accuracies, val_accuracies, model.trained_epochs, save_path=save_path, title=f"Custom trial - Loss: {best_loss:.4f} - Acc: {best_accuracy:00.1f}")

## 5. Visual inspection of the network's output

In [None]:
import random

# Reconstruct 8 randomly chosen test frames without tracking gradients.
model.eval()
with torch.no_grad():
    indices = random.sample(range(len(test_dataset)), 8)
    # Each dataset item is indexed at [0] for the frame; stack them into a batch
    # and add the channel dimension before moving the batch to the device.
    sample_imgs = torch.stack([test_dataset[i][0] for i in indices]).unsqueeze(1).to(device)
    reconstructions = model(sample_imgs)

# Top row: originals, bottom row: the corresponding reconstructions.
fig, axs = plt.subplots(2, 8, figsize=(15, 4))
for i in range(8):
    for row, batch in enumerate((sample_imgs, reconstructions)):
        axs[row, i].imshow(batch[i, 0].cpu(), cmap='gray')
        axs[row, i].axis('off')
axs[0, 0].set_title("Originals")
axs[1, 0].set_title("Reconstructions")
plt.show()


### Reconstruct a video

In [None]:
query_video = "person01_running_d1"
# NOTE(review): the sampling code reads frames from dataset/KTH_data_running, whereas this
# cell reads from dataset/KTH_data/running — confirm which folder actually holds the .pt frames.
video_location = Path("dataset") / "KTH_data" / "running"

# Collect the frame files themselves and order them by frame number, so every frame is
# used exactly once regardless of whether numbering starts at 0 or 1.
frame_paths = sorted(
    video_location.glob(f"{query_video}_frame_*.pt"),
    key=lambda path: int(path.stem.rsplit("_frame_", 1)[1]),
)
video_frames = len(frame_paths)
if video_frames == 0:
    raise FileNotFoundError(f"No frames found for {query_video} in {video_location}")

original_frames = []
processed_frames = []

model.eval()

for frame_path in frame_paths:
    # map_location keeps loading working on machines without the device the tensor was saved from.
    file = torch.load(frame_path, map_location="cpu")
    original_np = file.cpu().numpy()

    # Run through model (add batch and channel dimensions first)
    with torch.no_grad():
        reconstructed = model(file.unsqueeze(0).unsqueeze(0).to(device))
        reconstructed_np = reconstructed.squeeze().cpu().numpy()

    # Append both frames
    original_frames.append(original_np)
    processed_frames.append(reconstructed_np)

# Convert lists to arrays
original_frames = np.array(original_frames)
processed_frames = np.array(processed_frames)

# Clamp to [0, 1] so the conversion to 8-bit below cannot wrap around
original_frames = np.clip(original_frames, 0, 1)
processed_frames = np.clip(processed_frames, 0, 1)

# Add a 5-pixel black separator whose height follows the frames instead of a fixed 120
separator = np.zeros((len(original_frames), original_frames.shape[1], 5), dtype=np.float32)
combined_frames = np.concatenate((original_frames, separator, processed_frames), axis=2)

# Convert to uint8 for video
combined_uint8 = (combined_frames * 255).astype(np.uint8)

# Convert grayscale to 3-channel
combined_bgr = np.stack([combined_uint8]*3, axis=-1)  # Shape: (N, H, W, 3)

# Setup video writer
output_path = f"{query_video}_combined_reconstruction.mp4"
height, width = combined_bgr.shape[1:3]
fps = 25
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=True)
if not video_writer.isOpened():
    # Without this check a failed open would silently produce no video.
    raise RuntimeError(f"Could not open video writer for {output_path}")

try:
    for frame in combined_bgr:
        video_writer.write(frame)
finally:
    # Always finalize the file, even if writing a frame fails.
    video_writer.release()

### Lastly, we encode and save the full dataset so it can be used to train PredRNN