# VAE Model Architecture, Utils and Training Scripts






# Imports and drive connection

In [1]:
import re
import os
import sys
import json
import glob
import torch
import imageio
import inspect
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.nn import LayerNorm
from torch.utils.data import Dataset, DataLoader, random_split

from torch.optim.lr_scheduler import ReduceLROnPlateau
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Parameters

In [2]:
# Version tag of this experiment; it is embedded in the output directory name below.
version = "v10.02"
# Google Drive folder that the later %%writefile cells write the config and modules into;
# it is also passed to train() as eval_dir, so checkpoints, logs and plots end up here too.
save_dir = f"/content/drive/MyDrive/Herts - BSc /3rd Year/FYP/trained_models/vae_model_{version}"
# Create the folder up front (no error if it already exists).
os.makedirs(save_dir, exist_ok=True)

In [10]:
%%writefile "{save_dir}/model_config.json"
{
  "latent_dim": 512,
  "batch_size": 8,
  "num_heads": 4,
  "num_epochs": 50,
  "learning_rate": 0.00005,
  "weight_decay": 1e-5,
  "lr_scheduler": "cosine",
  "early_stop_patience": 5,
  "beta_start": 1.0,
  "beta_end": 0.1,
  "warmup_epochs": 20,
  "anneal_rate": -0.08,
  "train_ratio": 0.7,
  "val_ratio": 0.2,
  "test_ratio": 0.1,
  "hidden_dims": [32, 64, 128, 256, 512]
}

Overwriting /content/drive/MyDrive/Herts - BSc /3rd Year/FYP/trained_models/vae_model_v10.02/model_config.json


# VideoDataset Class


A helper class designed to handle video data stored as a NumPy array in (T, H, W, C) order. PyTorch's 3D convolutions expect float32 tensors organised in (C, T, H, W) order, so this class converts each video into that format to make sure it is ready for PyTorch.

*   **T** - Number of frames in the video (time).

*   **H** - Height of each frame (pixels).

*   **W** - Width of each frame (pixels).

*   **C** - Number of color channels (3 = RGB).


In [4]:
%%writefile "{save_dir}/VideoDataset.py"
import torch
from torch.utils.data import Dataset

class VideoDataset(Dataset):
    """Dataset wrapper around an array of videos indexed along its first axis.

    Every item is expected to be a NumPy array laid out as (T, H, W, C); it is
    returned as a float32 tensor laid out as (C, T, H, W).
    """

    def __init__(self, video_array):
        super().__init__()
        # The array is stored untouched; conversion happens per item in __getitem__.
        self.data = video_array

    def __len__(self):
        # Number of videos = length of the first axis.
        return len(self.data)

    def __getitem__(self, idx):
        # NumPy (T, H, W, C) -> float32 tensor -> channels-first (C, T, H, W).
        clip = torch.from_numpy(self.data[idx]).float()
        return clip.permute(3, 0, 1, 2)

Writing /content/drive/MyDrive/Herts - BSc /3rd Year/FYP/trained_models/vae_model_v10.02/VideoDataset.py


# VideoVAE Model Architecture

The VideoVAE class implements a 3D Convolutional Variational Autoencoder (VAE) for video data. The input is a short video clip with shape (batch_size, 3, 9, 128, 128) representing RGB videos with 9 frames and spatial resolution of 128x128.

The encoder has five 3D convolutional blocks with increasing channel dimension: [32, 64, 128, 256, 512], and gradually reduces the spatial dimensions while preserving the temporal dimension. It flattens the encoded features and uses two fully connected layers to produce the mean (mu) and log-variance (log_var) vectors of the latent space.

The decoder mirrors the encoder using 3D transposed convolutions to reconstruct the input video.

* 5 Encoder Blocks + FC layers for latent distribution

* 4 Decoder Blocks + Final Reconstruction Layer (the final layer performs the fifth upsampling step before mapping back to RGB)

* Latent sampling with the reparameterisation trick

Four skip connections are kept between the Encoder and Decoder blocks, but they are turned OFF!

```
reconstruction = self.decode(z, skips=None)
```



In [5]:
%%writefile "{save_dir}/model_architecture.py"
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from typing import List, Any

class VideoVAE(nn.Module):
    """3D convolutional Variational Autoencoder for short video clips.

    The network is built for inputs shaped (batch_size, in_channels, 9, 128, 128):
    the bottleneck shape is fixed to (hidden_dims[-1], 9, 4, 4), i.e. 9 frames and
    five spatial halvings of a 128x128 frame. Every convolution uses stride (1, 2, 2),
    so the time dimension is never resampled.

    Shapes with the default hidden_dims = [32, 64, 128, 256, 512] and in_channels = 3
    (B = batch_size):

        encoder_block1: (B,   3, 9, 128, 128) -> (B,  32, 9,  64,  64)
        encoder_block2: (B,  32, 9,  64,  64) -> (B,  64, 9,  32,  32)
        encoder_block3: (B,  64, 9,  32,  32) -> (B, 128, 9,  16,  16)
        encoder_block4: (B, 128, 9,  16,  16) -> (B, 256, 9,   8,   8)
        encoder_block5: (B, 256, 9,   8,   8) -> (B, 512, 9,   4,   4)
        fc_mu / fc_var: (B, 73728)            -> (B, latent_dim)
        decoder_input:  (B, latent_dim)       -> (B, 73728), viewed as (B, 512, 9, 4, 4)
        decoder_block1: (B, 512, 9,   4,   4) -> (B, 256, 9,   8,   8)
        decoder_block2: (B, 256, 9,   8,   8) -> (B, 128, 9,  16,  16)
        decoder_block3: (B, 128, 9,  16,  16) -> (B,  64, 9,  32,  32)
        decoder_block4: (B,  64, 9,  32,  32) -> (B,  32, 9,  64,  64)
        final_layer:    (B,  32, 9,  64,  64) -> (B,   3, 9, 128, 128), values in [0, 1]

    Args:
        in_channels: Number of channels of the input video (3 for RGB). The
            reconstruction has the same number of channels.
        latent_dim: Size of the latent vector.
        hidden_dims: Exactly five channel sizes, one per encoder block. The decoder
            uses them in reverse order. Defaults to [32, 64, 128, 256, 512].
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If hidden_dims does not contain exactly five entries.
    """

    # The architecture is fixed to five stride-(1, 2, 2) stages on each side of the bottleneck.
    _NUM_STAGES = 5

    def __init__(self, in_channels: int, latent_dim: int, hidden_dims: List = None, **kwargs) -> None:
        super(VideoVAE, self).__init__()
        self.latent_dim = latent_dim
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Any other length would either raise an IndexError below or build layers whose
        # channel counts do not line up, so fail early with a clear message.
        if len(hidden_dims) != self._NUM_STAGES:
            raise ValueError(
                f"hidden_dims must contain exactly {self._NUM_STAGES} channel sizes, "
                f"got {len(hidden_dims)}: {hidden_dims}"
            )
        self.hidden_dims = hidden_dims

        # Encoder blocks: channels grow, height/width halve, the 9 frames are kept.
        self.encoder_block1 = self._down_block(in_channels, hidden_dims[0])
        self.encoder_block2 = self._down_block(hidden_dims[0], hidden_dims[1])
        self.encoder_block3 = self._down_block(hidden_dims[1], hidden_dims[2])
        self.encoder_block4 = self._down_block(hidden_dims[2], hidden_dims[3])
        self.encoder_block5 = self._down_block(hidden_dims[3], hidden_dims[4])

        # Shape after the last encoder block for a (9, 128, 128) clip, and its flattened size
        # (512 * 9 * 4 * 4 = 73728 with the default hidden_dims).
        self.encoder_out_shape = (hidden_dims[-1], 9, 4, 4)
        self.flatten_dim = hidden_dims[-1] * 9 * 4 * 4

        # Heads producing the mean and the log-variance of the latent distribution.
        self.fc_mu = nn.Linear(self.flatten_dim, latent_dim)
        self.fc_var = nn.Linear(self.flatten_dim, latent_dim)

        # Decoder: expands the latent vector back to the flattened bottleneck size;
        # decode() then views it as (batch_size, *encoder_out_shape).
        self.decoder_input = nn.Linear(latent_dim, self.flatten_dim)

        # Decoder blocks mirror the encoder with the channel sizes reversed.
        hidden_dims_rev = hidden_dims[::-1]
        self.decoder_block1 = nn.Sequential(*self._up_layers(hidden_dims_rev[0], hidden_dims_rev[1]))
        self.decoder_block2 = nn.Sequential(*self._up_layers(hidden_dims_rev[1], hidden_dims_rev[2]))
        self.decoder_block3 = nn.Sequential(*self._up_layers(hidden_dims_rev[2], hidden_dims_rev[3]))
        self.decoder_block4 = nn.Sequential(*self._up_layers(hidden_dims_rev[3], hidden_dims_rev[4]))

        # Final layer: one more upsampling step at constant channel count, then a plain
        # convolution back to the input channel count and a Sigmoid to keep pixels in [0, 1].
        self.final_layer = nn.Sequential(
            *self._up_layers(hidden_dims_rev[-1], hidden_dims_rev[-1]),
            nn.Conv3d(hidden_dims_rev[-1], out_channels=in_channels,
                      kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.Sigmoid(),
        )

    @staticmethod
    def _down_block(in_ch: int, out_ch: int) -> nn.Sequential:
        """Conv3d + BatchNorm3d + LeakyReLU that halves height/width and keeps time."""
        # padding=1 compensates for the 3x3x3 kernel so that only the stride changes the size.
        return nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(out_ch),
            nn.LeakyReLU(),
        )

    @staticmethod
    def _up_layers(in_ch: int, out_ch: int) -> list:
        """ConvTranspose3d + BatchNorm3d + LeakyReLU that doubles height/width and keeps time."""
        # output_padding=(0, 1, 1) makes the spatial output exactly twice the input size.
        return [
            nn.ConvTranspose3d(in_ch, out_ch,
                               kernel_size=(3, 3, 3),
                               stride=(1, 2, 2),
                               padding=(1, 1, 1),
                               output_padding=(0, 1, 1)),
            nn.BatchNorm3d(out_ch),
            nn.LeakyReLU(),
        ]

    def encode(self, x: torch.Tensor) -> List[Tensor]:
        """Encode a batch of clips into latent distribution parameters.

        Args:
            x: Input of shape (batch_size, in_channels, 9, 128, 128).

        Returns:
            [mu, log_var, skips] where mu and log_var have shape (batch_size, latent_dim)
            and skips is the tuple of the outputs of encoder blocks 1 to 4.
        """
        e1 = self.encoder_block1(x)
        e2 = self.encoder_block2(e1)
        e3 = self.encoder_block3(e2)
        e4 = self.encoder_block4(e3)
        e5 = self.encoder_block5(e4)
        x_flat = torch.flatten(e5, start_dim=1)  # (batch_size, flatten_dim)
        mu = self.fc_mu(x_flat)
        log_var = self.fc_var(x_flat)
        return [mu, log_var, (e1, e2, e3, e4)]

    def decode(self, z: torch.Tensor, skips: tuple = None) -> torch.Tensor:
        """Decode latent vectors back to videos.

        Args:
            z: Latent vectors of shape (batch_size, latent_dim).
            skips: Optional tuple of the four encoder activations returned by encode().
                When given, they are added after each decoder block (deepest first).

        Returns:
            Reconstruction of shape (batch_size, in_channels, 9, 128, 128).
        """
        x = self.decoder_input(z)
        x = x.view(-1, *self.encoder_out_shape)
        decoder_blocks = (self.decoder_block1, self.decoder_block2,
                          self.decoder_block3, self.decoder_block4)
        # decoder_block1 pairs with encoder_block4 (skips[3]), ..., decoder_block4 with
        # encoder_block1 (skips[0]).
        for block, skip_index in zip(decoder_blocks, (3, 2, 1, 0)):
            x = block(x)
            if skips is not None:
                x = x + skips[skip_index]
        return self.final_layer(x)

    def reparameterise(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z = mu + eps * std with eps ~ N(0, I) (reparameterisation trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x: torch.Tensor, **kwargs) -> List[Tensor]:
        """Full pass: encode, sample a latent vector and decode.

        Returns:
            [reconstruction, mu, log_var].
        """
        mu, log_var, skips = self.encode(x)
        z = self.reparameterise(mu, log_var)
        # Skip connections are deliberately turned off: the reconstruction is produced
        # from the latent vector alone.
        reconstruction = self.decode(z, skips=None)
        return [reconstruction, mu, log_var]

    def sample(self, num_samples: int, device: torch.device, **kwargs) -> torch.Tensor:
        """Generate videos from latent vectors drawn from a standard normal distribution.

        Args:
            num_samples: Number of videos to generate.
            device: Anything accepted by Tensor.to(), e.g. a torch.device or "cuda".
        """
        z = torch.randn(num_samples, self.latent_dim).to(device)
        return self.decode(z)  # no skip connections

    def generate(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return only the reconstruction of x (first element of forward())."""
        return self.forward(x)[0]

Writing /content/drive/MyDrive/Herts - BSc /3rd Year/FYP/trained_models/vae_model_v10.02/model_architecture.py


# Utility functions for Training and Evaluation

* **get_data_loaders:**
Splits the dataset into training, validation, and test sets, then returns PyTorch DataLoader objects for each split

* **save_reconstruction_plots**
Visualises and saves a comparison of original and reconstructed frames

* **save_sample_plots**
Generates and saves samples from random latent vectors, showing the middle frame of each generated video

* **save_latent_distribution**
Extracts latent means from the model, applies PCA for dimensionality reduction, clusters them using KMeans, and saves a scatter plot to visualise the latent space

* **visualise_cluster_samples**
Groups encoded videos into 2 clusters with KMeans and saves a plot of the middle frame of a few videos from each cluster.

* **plot_loss_curves**
Plots and saves training and validation loss curves over epochs.

* **save_tsne_latent_visualisation**
Applies t-SNE to project high-dimensional latent vectors into 2D and visualises clusters.

* **natural_sort_key**
Provides a key for naturally sorting filenames that contain both letters and numbers.

* **create_gif_from_folder**
Combines a sequence of image files from a folder into an animated GIF.

* **log_print**
Prints and logs messages to a file simultaneously.

* **compute_loss**
Calculates the total VAE loss, including reconstruction loss (MSE) and KL divergence loss, with a beta scaling.



In [6]:
%%writefile "{save_dir}/utils.py"
import os
import re
import glob
import torch
import imageio
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from torch.utils.data import random_split, DataLoader

from VideoDataset import VideoDataset

def get_data_loaders(video_array, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1, batch_size=8, shuffle=True, num_workers=4, pin_memory=True):
    """Wrap a video array in a VideoDataset and split it into three DataLoaders.

    The split is drawn with a generator seeded with 42, so it is the same on every call.
    Train and validation sizes are the ratios truncated to whole samples; the test split
    receives whatever is left over.

    Args:
        video_array: Array of videos handed to VideoDataset.
        train_ratio, val_ratio, test_ratio: Fractions of the dataset; must sum to 1.0.
        batch_size, num_workers, pin_memory: Passed to every DataLoader.
        shuffle: Applied to the training loader only; validation and test are never shuffled.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    ratio_total = train_ratio + val_ratio + test_ratio
    assert abs(ratio_total - 1.0) < 1e-6, "Ratios must sum to 1.0"

    full_dataset = VideoDataset(video_array)
    n_total = len(full_dataset)
    n_train = int(train_ratio * n_total)
    n_val = int(val_ratio * n_total)
    n_test = n_total - n_train - n_val  # remainder, so the three sizes always add up

    split_rng = torch.Generator().manual_seed(42)
    train_subset, val_subset, test_subset = random_split(
        full_dataset, [n_train, n_val, n_test], generator=split_rng
    )

    # Options shared by all three loaders.
    common = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_subset, shuffle=shuffle, **common)
    val_loader = DataLoader(val_subset, shuffle=False, **common)
    test_loader = DataLoader(test_subset, shuffle=False, **common)
    return train_loader, val_loader, test_loader

def save_reconstruction_plots(model, data_loader, epoch, device, save_dir):
    """Save a figure comparing original and reconstructed frames of one video.

    The first video of the first batch of data_loader is used. The first, middle and last
    frame are shown when the video has at least 3 frames, otherwise every frame is shown.
    Originals are on the top row, reconstructions on the bottom row.

    Args:
        model: Model whose forward pass returns (reconstruction, mu, log_var). It is left
            in eval mode.
        data_loader: Loader yielding batches shaped (batch, C, T, H, W).
        epoch: Used in the figure title and in the file name recon_epoch_{epoch}.png.
        device: Device the batch is moved to.
        save_dir: Existing directory the PNG is written to.
    """
    model.eval()
    with torch.no_grad():
        batch = next(iter(data_loader)).to(device)
        recon_batch, _, _ = model(batch)

    # Only the first video of the batch is plotted; both tensors are (C, T, H, W).
    original = batch[0].cpu()
    reconstruction = recon_batch[0].cpu()

    T = original.shape[1]
    indices = [0, T // 2, T - 1] if T >= 3 else list(range(T))
    n_frames = len(indices)

    # squeeze=False keeps `axes` two-dimensional even when a single frame is plotted;
    # otherwise axes[0, i] fails for one-frame videos.
    fig, axes = plt.subplots(2, n_frames, figsize=(4 * n_frames, 8), squeeze=False)

    for col, idx in enumerate(indices):
        # (C, H, W) -> (H, W, C) for imshow.
        orig_frame = original[:, idx, :, :].permute(1, 2, 0).numpy()
        recon_frame = reconstruction[:, idx, :, :].permute(1, 2, 0).numpy()

        # Scale images from [-1, 1] to [0, 1] if needed (uncomment below)
        # orig_frame = (orig_frame + 1) / 2
        # recon_frame = (recon_frame + 1) / 2

        axes[0, col].imshow(orig_frame)
        axes[0, col].set_title(f"Original\nFrame {idx}")
        axes[0, col].axis("off")

        axes[1, col].imshow(recon_frame)
        axes[1, col].set_title(f"Reconstruction\nFrame {idx}")
        axes[1, col].axis("off")

    plt.suptitle(f"Reconstruction Comparison - Epoch {epoch}")
    plt.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the suptitle
    recon_plot_path = os.path.join(save_dir, f"recon_epoch_{epoch}.png")
    plt.savefig(recon_plot_path)
    plt.close()

def save_sample_plots(model, epoch, device, save_dir, num_samples=4):
    """Save a figure with the middle frame of videos generated from random latent codes.

    Args:
        model: Model exposing sample(num_samples=..., device=...) that returns a tensor
            shaped (num_samples, C, T, H, W). It is left in eval mode.
        epoch: Used in the figure title and in the file name samples_epoch_{epoch}.png.
        device: Device the samples are generated on.
        save_dir: Existing directory the PNG is written to.
        num_samples: Number of videos to generate, one panel each.
    """
    model.eval()
    with torch.no_grad():
        samples = model.sample(num_samples=num_samples, device=device)

    frame_idx = samples.shape[2] // 2  # middle frame along the time axis

    # squeeze=False keeps `axes` a 2-D array; without it a single panel is returned as a
    # bare Axes object and indexing it fails for num_samples=1.
    fig, axes = plt.subplots(1, num_samples, figsize=(4 * num_samples, 4), squeeze=False)
    for i in range(num_samples):
        # (C, H, W) -> (H, W, C) for imshow.
        sample_frame = samples[i, :, frame_idx, :, :].cpu().permute(1, 2, 0).numpy()

        # Scale image from [-1, 1] to [0, 1] if needed (uncomment below)
        # sample_frame = (sample_frame + 1) / 2

        axes[0, i].imshow(sample_frame)
        axes[0, i].set_title(f"Sample {i}")
        axes[0, i].axis("off")

    plt.suptitle(f"Random Samples - Epoch {epoch}")
    plt.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the suptitle
    sample_plot_path = os.path.join(save_dir, f"samples_epoch_{epoch}.png")
    plt.savefig(sample_plot_path)
    plt.close()

def save_latent_distribution(model, data_loader, epoch, device, save_dir, num_samples=100, num_clusters=5):
    """Save a 2-D PCA scatter plot of latent means, coloured by KMeans cluster.

    Latent means are collected batch by batch until roughly num_samples are available,
    truncated to num_samples, clustered in the full latent space and then projected to
    two principal components for plotting.

    Args:
        model: Model whose forward pass returns (reconstruction, mu, log_var). It is left
            in eval mode.
        data_loader: Loader the latent means are computed from.
        epoch: Used in the title and in the file name latent_distribution_epoch_{epoch}.png.
        device: Device the batches are moved to.
        save_dir: Existing directory the PNG is written to.
        num_samples: Maximum number of latent means to plot.
        num_clusters: Number of KMeans clusters.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            mu = model(batch)[1]
            collected.append(mu.cpu())
            # Estimate of the samples gathered so far: batches seen x current batch size.
            if len(collected) * batch.size(0) >= num_samples:
                break
    latent_means = torch.cat(collected, dim=0)[:num_samples].numpy()

    # Cluster in the original latent space ...
    cluster_labels = KMeans(n_clusters=num_clusters, random_state=0).fit(latent_means).labels_

    # ... and project to 2-D only for display.
    latent_2d = PCA(n_components=2).fit_transform(latent_means)

    plt.figure(figsize=(6, 6))
    points = plt.scatter(latent_2d[:, 0], latent_2d[:, 1], c=cluster_labels, alpha=0.7, cmap='viridis')
    plt.colorbar(points, label='Cluster Label')
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.title(f"Latent Distribution (PCA) - Epoch {epoch}")
    plt.savefig(os.path.join(save_dir, f"latent_distribution_epoch_{epoch}.png"))
    plt.close()

def visualise_cluster_samples(model, data_loader, device, save_dir, epoch, num_samples=100, num_samples_per_cluster=3):
    """Save a figure with the middle frame of a few videos from each of 2 latent clusters.

    Latent means of up to num_samples videos are clustered into 2 groups with KMeans. For
    each cluster the first num_samples_per_cluster videos encountered are shown, cluster 0
    on the top row and cluster 1 on the bottom row. Panels without a video stay blank.

    Args:
        model: Model whose forward pass returns (reconstruction, mu, log_var). It is left
            in eval mode.
        data_loader: Loader yielding batches shaped (batch, C, T, H, W).
        device: Device the batches are moved to.
        save_dir: Directory the PNG is written to; created if missing.
        epoch: Used in the file name clusters_{epoch}.png.
        num_samples: Maximum number of videos to encode and cluster.
        num_samples_per_cluster: Number of panels per cluster.
    """
    model.eval()
    latent_codes = []
    videos = []
    collected = 0  # running count, so the lists are only concatenated once at the end

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            _, mu, _ = model(batch)
            latent_codes.append(mu.cpu())
            videos.append(batch.cpu())
            collected += mu.size(0)
            if collected >= num_samples:
                break

    latent_codes = torch.cat(latent_codes, dim=0)[:num_samples].numpy()
    videos = torch.cat(videos, dim=0)[:num_samples]  # (num_samples, C, T, H, W)

    # Cluster into 2 groups
    kmeans = KMeans(n_clusters=2, random_state=42)
    cluster_labels = kmeans.fit_predict(latent_codes)

    # Keep the first num_samples_per_cluster video indices of each cluster.
    clusters = {0: [], 1: []}
    for idx, label in enumerate(cluster_labels):
        if len(clusters[label]) < num_samples_per_cluster:
            clusters[label].append(idx)

    # squeeze=False keeps `axes` two-dimensional even for num_samples_per_cluster=1.
    fig, axes = plt.subplots(2, num_samples_per_cluster,
                             figsize=(4 * num_samples_per_cluster, 8), squeeze=False)
    # Hide every panel's frame up front so clusters with fewer videos than panels do not
    # leave empty axes with ticks in the figure.
    for ax in axes.flat:
        ax.axis("off")

    for cluster_idx, indices in clusters.items():
        for i, sample_idx in enumerate(indices):
            video = videos[sample_idx]  # (C, T, H, W)
            T = video.shape[1]
            # Middle frame, (C, H, W) -> (H, W, C) for imshow.
            mid_frame = video[:, T // 2, :, :].permute(1, 2, 0).numpy()
            axes[cluster_idx, i].imshow(mid_frame)
            axes[cluster_idx, i].set_title(f"Cluster {cluster_idx} - Sample {i+1}")

    plt.suptitle("Representative Samples from Each Cluster")
    plt.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the suptitle
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"clusters_{epoch}.png")
    plt.savefig(save_path)
    plt.close()

def plot_loss_curves(train_losses, val_losses, train_rec_losses, val_rec_losses, train_kl_losses, val_kl_losses, save_dir):
    """Save loss_curves.png with total, reconstruction and KL loss per epoch.

    All six sequences hold one value per epoch; the x axis starts at epoch 1 and its
    length is taken from train_losses. Each of the three side-by-side panels shows the
    training and the validation curve of one loss.
    """
    epochs = range(1, len(train_losses) + 1)

    # One entry per panel: (title, train series, train label, val series, val label).
    panels = [
        ("Total Loss", train_losses, "Train", val_losses, "Val"),
        ("Reconstruction Loss", train_rec_losses, "Train Rec", val_rec_losses, "Val Rec"),
        ("KL Loss", train_kl_losses, "Train KL", val_kl_losses, "Val KL"),
    ]

    # Wide figure so the three panels are not cramped.
    fig, axs = plt.subplots(1, 3, figsize=(24, 5))
    for ax, (title, train_series, train_label, val_series, val_label) in zip(axs, panels):
        ax.plot(epochs, train_series, label=train_label, marker='o')
        ax.plot(epochs, val_series, label=val_label, marker='o')
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        ax.grid(True)

    plt.tight_layout()  # prevents the panels from overlapping
    plt.savefig(os.path.join(save_dir, "loss_curves.png"))
    plt.close()

def save_tsne_latent_visualisation(model, data_loader, save_path, num_samples=100):
    """Save a t-SNE scatter plot of latent means, coloured by 2 KMeans clusters.

    Args:
        model: Model whose forward pass returns (reconstruction, mu, log_var). It is left
            in eval mode; batches are moved to the device of its first parameter.
        data_loader: Loader the latent means are computed from.
        save_path: Full path of the image file to write.
        num_samples: Maximum number of latent means to embed.
    """
    model.eval()
    latent_codes = []
    count = 0
    # The model does not move during this function, so look its device up once.
    model_device = next(model.parameters()).device

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(model_device)
            _, mu, _ = model(batch)
            latent_codes.append(mu.cpu())
            count += mu.size(0)
            if count >= num_samples:
                break

    latent_codes = torch.cat(latent_codes, dim=0)[:num_samples].numpy()

    # Cluster into 2 groups
    kmeans = KMeans(n_clusters=2, random_state=42)
    cluster_labels = kmeans.fit_predict(latent_codes)

    # Reduce to 2D with t-SNE. scikit-learn requires perplexity < number of samples, so
    # the default of 30 is only lowered when fewer than 31 codes are available.
    perplexity = min(30.0, float(len(latent_codes) - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    tsne_results = tsne.fit_transform(latent_codes)

    # Plot with cluster colors
    plt.figure(figsize=(8, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=cluster_labels, cmap='tab10', alpha=0.7)
    plt.colorbar(scatter, label='Cluster Label')
    plt.title("t-SNE of Video Latent Representations")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.savefig(save_path)
    plt.close()

def natural_sort_key(s):
    """Sort key that orders embedded numbers numerically ("img2" before "img10").

    The string is split into alternating text and digit runs; digit runs become ints and
    text runs are lower-cased, so the comparison is case-insensitive.
    """
    # Raw string: '\d' in a normal string literal is an invalid escape sequence and
    # triggers a warning on recent Python versions.
    return [int(text) if text.isdigit() else text.lower() for text in re.split(r'(\d+)', s)]

def create_gif_from_folder(folder_path, gif_path, pattern="*.png", duration=0.5):
    """Combine the images of a folder into an animated GIF, in natural file-name order.

    Args:
        folder_path: Directory that is searched for frames.
        gif_path: Path of the GIF file to write.
        pattern: Glob pattern selecting the frame files inside folder_path.
        duration: Frame duration forwarded unchanged to imageio.mimsave.
            NOTE(review): how this value is interpreted (seconds vs. milliseconds) may
            depend on the installed imageio version - confirm.

    If no file matches the pattern, a message is printed and no GIF is written.
    """
    file_names = sorted(glob.glob(os.path.join(folder_path, pattern)), key=natural_sort_key)
    if not file_names:
        # Nothing to animate; skip instead of handing an empty frame list to imageio.
        print(f"No files matching '{pattern}' found in {folder_path} - skipping GIF creation")
        return
    images = [imageio.imread(filename) for filename in file_names]
    imageio.mimsave(gif_path, images, duration=duration)

def log_print(msg, file_obj):
    """Print msg to stdout and append it, newline-terminated, to file_obj.

    file_obj is flushed after every message so the log on disk stays current during
    long runs.
    """
    print(msg)
    line = msg + "\n"
    file_obj.write(line)
    file_obj.flush()

def compute_loss(recon_batch, batch, mu, logvar, beta):
    """Compute the beta-weighted VAE loss, averaged over the batch.

    Args:
        recon_batch: Reconstruction produced by the model.
        batch: Target tensor; its first dimension is used as the batch size.
        mu: Latent means.
        logvar: Latent log-variances.
        beta: Weight applied to the KL term.

    Returns:
        (loss, rec_loss, kl_loss) where rec_loss is the summed squared error divided by
        the batch size, kl_loss is the KL divergence to a standard normal divided by the
        batch size, and loss = rec_loss + beta * kl_loss.
    """
    n = batch.size(0)
    rec_loss = F.mse_loss(recon_batch, batch, reduction='sum') / n
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all latent dimensions.
    kl_sum = torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl_loss = -0.5 * kl_sum / n
    return rec_loss + beta * kl_loss, rec_loss, kl_loss

Writing /content/drive/MyDrive/Herts - BSc /3rd Year/FYP/trained_models/vae_model_v10.02/utils.py


# **VAE Training Pipeline**
This script defines the full training loop for the VideoVAE model. It includes functionality for training, validation, testing, visualisation, and logging.

### Training Phase
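Optimises the model using reconstruction and KL divergence losses, with the KL term weighted by beta. Note that `train` currently fixes beta at 1.0; the `beta_start`, `beta_end`, `warmup_epochs` and `anneal_rate` values in the config are not applied.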

### Validation Phase
Evaluates the model after each epoch and implements early stopping based on validation loss improvements.

### Testing Phase
Assesses final model performance using the test set after training concludes.


### Logging and Visualisation

* Reconstruction comparisons
* Random sample generations
* Latent space PCA plots
* t-SNE embeddings
* Cluster visualisation

### After Training
* Loss curves
* GIFs for visual progress

### Model Checkpoint
Automatically saves the best-performing model based on validation loss.

In [7]:
%%writefile "{save_dir}/train.py"
import os
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR

from utils import (
    save_reconstruction_plots,
    save_sample_plots,
    save_latent_distribution,
    plot_loss_curves,
    create_gif_from_folder,
    save_tsne_latent_visualisation,
    visualise_cluster_samples,
    log_print,
    compute_loss
)

def _run_epoch(model, loader, beta, device, optimiser=None):
    """Run one pass over loader and return (loss, rec_loss, kl_loss) averaged per sample.

    With an optimiser the model is put in train mode and updated after every batch;
    without one the model is put in eval mode and gradients are disabled.
    """
    training = optimiser is not None
    model.train(training)
    totals = [0.0, 0.0, 0.0]  # running sums of total, reconstruction and KL loss
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device)
            if training:
                optimiser.zero_grad()
            recon_batch, mu, logvar = model(batch)
            losses = compute_loss(recon_batch, batch, mu, logvar, beta)
            if training:
                losses[0].backward()
                optimiser.step()
            # compute_loss returns per-sample averages of the batch; weight them by the
            # batch size so a smaller last batch does not skew the epoch average.
            for i, value in enumerate(losses):
                totals[i] += value.item() * batch.size(0)
    n_samples = len(loader.dataset)
    return tuple(total / n_samples for total in totals)


def train(model, train_loader, val_loader, test_loader, optimiser, scheduler, config, device, eval_dir="eval_plots"):
    """Train the model with early stopping, then evaluate it on the test set.

    Per epoch: one training pass, one validation pass, a summary line in the log, a
    checkpoint of the model whenever the validation loss improves, an optional scheduler
    step and a set of evaluation plots. After the loop the test set is evaluated with the
    weights the model holds at that point (the last epoch, not the best checkpoint), and
    the loss curves and GIFs are written.

    Everything is written below eval_dir: recon/, samples/, latent/, tsne/,
    latent_sample_visualisation/, loss/, best_model/best_video_vae.pth and
    logs/training_log.txt (opened in append mode).

    Args:
        model: Model whose forward pass returns (reconstruction, mu, log_var).
        train_loader, val_loader, test_loader: DataLoaders of the three splits.
        optimiser: Optimiser updating the model parameters.
        scheduler: Learning-rate scheduler or None. ReduceLROnPlateau is stepped with the
            validation loss, every other scheduler without arguments.
        config: Dict providing "num_epochs" and "early_stop_patience".
        device: Device the model and the batches are moved to.
        eval_dir: Root directory of all outputs.
    """
    recon_dir = os.path.join(eval_dir, "recon")
    samples_dir = os.path.join(eval_dir, "samples")
    latent_dir = os.path.join(eval_dir, "latent")
    best_model_dir = os.path.join(eval_dir, "best_model")
    loss_dir = os.path.join(eval_dir, "loss")
    tsne_dir = os.path.join(eval_dir, "tsne")
    log_dir = os.path.join(eval_dir, "logs")
    latent_sample_visualisation_dir = os.path.join(eval_dir, "latent_sample_visualisation")
    for d in [recon_dir, samples_dir, latent_dir, best_model_dir, loss_dir, tsne_dir, latent_sample_visualisation_dir, log_dir]:
        os.makedirs(d, exist_ok=True)

    model.to(device)

    num_epochs = config["num_epochs"]
    patience = config["early_stop_patience"]
    num_clusters = 2
    # The KL weight is constant for the whole run. It is defined before the loop so the
    # test phase still works when num_epochs is 0.
    # NOTE(review): the beta_start / beta_end / warmup_epochs / anneal_rate entries of
    # the config are not applied - confirm whether a beta schedule is intended.
    beta = 1.0

    train_losses, val_losses = [], []
    train_rec_losses, val_rec_losses = [], []
    train_kl_losses, val_kl_losses = [], []

    best_val_loss = float('inf')
    epoch_no_improvement = 0

    log_file_path = os.path.join(log_dir, "training_log.txt")
    # The context manager closes the log even when training is interrupted by an error.
    with open(log_file_path, "a") as log_file:
        for epoch in range(1, num_epochs + 1):
            # ---- Training Phase ----
            train_loss_epoch, train_rec_loss_epoch, train_kl_loss_epoch = _run_epoch(
                model, train_loader, beta, device, optimiser)
            train_losses.append(train_loss_epoch)
            train_rec_losses.append(train_rec_loss_epoch)
            train_kl_losses.append(train_kl_loss_epoch)

            # ---- Validation Phase ----
            val_loss_epoch, val_rec_loss_epoch, val_kl_loss_epoch = _run_epoch(
                model, val_loader, beta, device)
            val_losses.append(val_loss_epoch)
            val_rec_losses.append(val_rec_loss_epoch)
            val_kl_losses.append(val_kl_loss_epoch)

            # Logged before the early-stopping check so the epoch that triggers the stop
            # also has its metrics in the log.
            log_print(f"Epoch {epoch}/{num_epochs} - Beta: {beta:.4f} - "
                  f"Train Loss: {train_loss_epoch:.4f} (Rec: {train_rec_loss_epoch:.4f}, KL: {train_kl_loss_epoch:.4f}) - "
                  f"Val Loss: {val_loss_epoch:.4f} (Rec: {val_rec_loss_epoch:.4f}, KL: {val_kl_loss_epoch:.4f})", log_file)

            # ---- Early Stopping ----
            if val_loss_epoch < best_val_loss:
                best_val_loss = val_loss_epoch
                epoch_no_improvement = 0
                torch.save(model.state_dict(), os.path.join(best_model_dir, "best_video_vae.pth"))
            else:
                epoch_no_improvement += 1
                log_print(f"Epoch {epoch} - Early stopping counter: {epoch_no_improvement}/{patience}", log_file)

            if epoch_no_improvement >= patience:
                log_print(f"Early stopping triggered at epoch {epoch}", log_file)
                break

            # ---- Learning Rate Scheduler Step ----
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(val_loss_epoch)
                else:
                    scheduler.step()

            # ---- Evaluation Plots ----
            save_reconstruction_plots(model, val_loader, epoch, device, recon_dir)
            save_sample_plots(model, epoch, device, samples_dir, num_samples=4)
            save_latent_distribution(model, val_loader, epoch, device, latent_dir, num_samples=100, num_clusters=num_clusters)
            tsne_save_path = os.path.join(tsne_dir, f"tsne_epoch_{epoch}.png")
            save_tsne_latent_visualisation(model, val_loader, tsne_save_path)
            visualise_cluster_samples(model, val_loader, device, latent_sample_visualisation_dir, epoch, num_samples=100, num_samples_per_cluster=3)

        # ---- Test Phase ----
        if len(test_loader.dataset) > 0:
            test_loss_epoch, test_rec_loss_epoch, test_kl_loss_epoch = _run_epoch(
                model, test_loader, beta, device)
            log_print(f"Test Results - Loss: {test_loss_epoch:.4f} (Rec: {test_rec_loss_epoch:.4f}, KL: {test_kl_loss_epoch:.4f})", log_file)
        else:
            # An empty test split would otherwise end the run with a division by zero
            # before the loss curves and GIFs are written.
            model.eval()
            log_print("Test set is empty - skipping test evaluation", log_file)

    # ---- Save Loss Curves ----
    plot_loss_curves(train_losses, val_losses, train_rec_losses, val_rec_losses, train_kl_losses, val_kl_losses, loss_dir)

    # ---- Create GIFs ----
    create_gif_from_folder(recon_dir, os.path.join(recon_dir, "reconstruction.gif"))
    create_gif_from_folder(samples_dir, os.path.join(samples_dir, "samples.gif"))
    create_gif_from_folder(latent_dir, os.path.join(latent_dir, "latent.gif"))

Writing /content/drive/MyDrive/Herts - BSc /3rd Year/FYP/trained_models/vae_model_v10.02/train.py


In [8]:
# Load Video Data
# The pre-processed clips are stored on Google Drive as a single .npy array.
video_data_path = '/content/drive/MyDrive/Herts - BSc /3rd Year/FYP/processed_data/video_data-9frame-v2.0.npy'
video_data = np.load(video_data_path)

# Sanity check: the recorded output is (4588, 9, 128, 128, 3), i.e. (N, T, H, W, C),
# which is the layout VideoDataset permutes to (C, T, H, W).
# NOTE(review): the model ends in a Sigmoid and is trained with MSE, so pixel values are
# presumably already scaled to [0, 1] - confirm against the preprocessing step.
print("Shape of video_data:", video_data.shape)

Shape of video_data: (4588, 9, 128, 128, 3)


# Execution
This script loads configuration settings and starts the complete training process for the VideoVAE model. It handles setup, execution, and saving of both the model and the current notebook.

### Configuration Loading
Reads training parameters from a model_config.json file.

### Data Preparation
Calls get_data_loaders() to split the dataset into training, validation, and test sets and creates DataLoaders.

### Training
Calls the train() function with all necessary args. This function handles training, validation, testing, logging, visualisation, and saving the best model.

### After training:

Saves the final model weights (video_vae_final.pth)

Copies the current Jupyter notebook into the output directory (correct version)


In [11]:
import os
import shutil
os.chdir(f"{save_dir}")

import json
import torch
import torch.optim as optim
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR

from model_architecture import VideoVAE
from train import train
from utils import get_data_loaders

if __name__ == '__main__':
    # Entry point for a training run:
    #   1. read hyper-parameters from model_config.json in `save_dir`,
    #   2. build the DataLoaders, model, optimiser and LR scheduler,
    #   3. call train(),
    #   4. save the final weights and archive a copy of this notebook.
    # Relies on `save_dir` and `video_data` defined by earlier notebook cells.
    config_path = os.path.join(save_dir, "model_config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    # Indexing with [] (rather than .get) makes a missing key fail here,
    # before any expensive work starts.
    latent_dim = config["latent_dim"]
    batch_size = config["batch_size"]
    # "num_heads" is present in the config file but is not read by this script.
    num_epochs = config["num_epochs"]
    learning_rate = config["learning_rate"]
    weight_decay = config["weight_decay"]
    lr_scheduler_type = config["lr_scheduler"]
    patience = config["early_stop_patience"]
    train_ratio = config["train_ratio"]
    val_ratio = config["val_ratio"]
    test_ratio = config["test_ratio"]
    hidden_dims = config["hidden_dims"]

    print(f"Config:\n latent_dim={latent_dim}\n learning_rate={learning_rate}\n lr_scheduler={lr_scheduler_type}\n batch_size={batch_size}")

    # Create DataLoaders
    train_loader, val_loader, test_loader = get_data_loaders(
        video_data, train_ratio=train_ratio, val_ratio=val_ratio, test_ratio=test_ratio, batch_size=batch_size
    )

    # Set Device, Model, Optimiser, and Scheduler
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Place the model on the target device *before* constructing the optimiser,
    # as the torch.optim documentation recommends. This is a no-op if train()
    # moves the model again.
    model = VideoVAE(3, latent_dim=latent_dim, hidden_dims=hidden_dims).to(device)
    optimiser = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    if lr_scheduler_type == "cosine":
        scheduler = CosineAnnealingLR(optimiser, T_max=num_epochs)
    elif lr_scheduler_type == "step":
        scheduler = StepLR(optimiser, step_size=10, gamma=0.1)
    elif lr_scheduler_type == "plateau":
        scheduler = ReduceLROnPlateau(optimiser, mode='min', factor=0.1, patience=5)
    else:
        scheduler = None
        # Any other value still disables scheduling (unchanged behaviour), but
        # a value that is not an explicit opt-out is most likely a typo in the
        # config, so make it visible instead of silently training without one.
        if lr_scheduler_type not in (None, "", "none"):
            print(f"Warning: unknown lr_scheduler '{lr_scheduler_type}'; training without an LR scheduler.")

    # Training
    train(model, train_loader, val_loader, test_loader, optimiser, scheduler, config, device, eval_dir=save_dir)

    # Save Final Model
    torch.save(model.state_dict(), os.path.join(save_dir, 'video_vae_final.pth'))
    print("Training complete. Final model saved as 'video_vae_final.pth'.")

    # Save the ipynb file
    current_ipynb_path = '/content/drive/MyDrive/Herts - BSc /3rd Year/FYP/Workflow/Final-WorkFlow/Architecture-Training.ipynb'
    destination_ipynb_path = os.path.join(save_dir, "Architecture-Training.ipynb")
    # Archiving the notebook is a convenience step: if the notebook has been
    # moved or renamed, report it rather than ending a finished training run
    # with a FileNotFoundError traceback.
    if os.path.isfile(current_ipynb_path):
        shutil.copy(current_ipynb_path, destination_ipynb_path)
        print(f"Copied {current_ipynb_path} to {destination_ipynb_path}")
    else:
        print(f"Warning: notebook not found at {current_ipynb_path}; skipping copy.")

Config:
 latent_dim=512
 learning_rate=5e-05
 lr_scheduler=cosine
 batch_size=8
Epoch 1/50 - Beta: 1.0000 - Train Loss: 9293.5215 (Rec: 8841.2979, KL: 452.2236) - Val Loss: 6217.2064 (Rec: 5786.7006, KL: 430.5058)
Epoch 2/50 - Beta: 1.0000 - Train Loss: 5237.3034 (Rec: 4856.9591, KL: 380.3443) - Val Loss: 4625.8502 (Rec: 4287.2043, KL: 338.6458)
Epoch 3/50 - Beta: 1.0000 - Train Loss: 4216.9676 (Rec: 3852.6121, KL: 364.3555) - Val Loss: 3857.4084 (Rec: 3488.2833, KL: 369.1250)
Epoch 4/50 - Beta: 1.0000 - Train Loss: 3679.4867 (Rec: 3328.5791, KL: 350.9076) - Val Loss: 3477.0154 (Rec: 3098.4047, KL: 378.6107)
Epoch 5/50 - Beta: 1.0000 - Train Loss: 3359.2341 (Rec: 3015.8065, KL: 343.4276) - Val Loss: 3342.4569 (Rec: 2973.4528, KL: 369.0041)
Epoch 6/50 - Beta: 1.0000 - Train Loss: 3117.2753 (Rec: 2781.2105, KL: 336.0647) - Val Loss: 3067.8173 (Rec: 2705.9906, KL: 361.8267)
Epoch 7/50 - Beta: 1.0000 - Train Loss: 2949.3128 (Rec: 2617.4237, KL: 331.8891) - Val Loss: 2916.5371 (Rec: 2590.26