# Set Up

In [None]:
# Connect to Google Drive.
# The google.colab package only exists inside a Colab runtime, so guard the
# import: outside Colab there is nothing to mount and the rest of the script
# can still run.
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True
    drive.mount('/content/drive')

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import os
import os.path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from skimage.metrics import structural_similarity as ssim
import time

In [None]:
# Create save folder if it doesn't exist.
# exist_ok=True makes this a single idempotent call and avoids the
# check-then-create race of testing os.path.exists() first.
save_folder = 'location_to_save'
os.makedirs(save_folder, exist_ok=True)

# Set Hyperparameters

In [None]:
# Hyperparameters
# -- optimisation --
epochs = 150        # upper bound on training epochs (early stopping may end sooner)
batch_size = 32     # samples per DataLoader batch
lr = 1e-3           # learning rate handed to the optimiser
# -- model --
in_channels = 1     # channels per input image
latent_dim = 35     # length of the VAE latent vector
# -- not referenced by the code visible in this notebook --
input_size = 40
k = 20

# Load Data

In [None]:
# Load and normalise data
raw = np.load('location_of_npy_file')

# Work in floating point: with an integer-typed file the normalised values
# would otherwise be truncated when stored back into an integer array.
samples = raw if np.issubdtype(raw.dtype, np.floating) else raw.astype(np.float32)

# Per-sample min-max scaling to [0, 1], vectorised over every non-batch axis.
sample_axes = tuple(range(1, samples.ndim))
min_vals = samples.min(axis=sample_axes, keepdims=True)
max_vals = samples.max(axis=sample_axes, keepdims=True)
value_range = max_vals - min_vals
# A constant sample has zero range; divide by 1 instead so it maps to all
# zeros rather than NaN (which would poison the training loss).
value_range[value_range == 0] = 1
data_norm = (samples - min_vals) / value_range

# Add the channel axis expected by Conv2d: (N, 1, H, W), float32.
data_norm_reshape = data_norm.reshape(-1, 1, raw.shape[1], raw.shape[2]).astype(np.float32)

# Split data: 70 % train, then the remaining 30 % is halved into test / validation
data_train, data_test = train_test_split(data_norm_reshape, test_size=0.3, random_state=42, shuffle=True)
data_test, data_val = train_test_split(data_test, test_size=0.5, random_state=42, shuffle=True)

# Only the training loader is shuffled; validation/test order stays fixed.
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(data_val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(data_test, batch_size=batch_size, shuffle=False)

# ANN-based VAE Function Definitions

In [None]:
class VanillaVAE(nn.Module):
    """Convolutional variational autoencoder for square images.

    The encoder is three stride-2 convolutions (32, 64, 128 channels), each
    halving the spatial size, followed by two linear heads producing the
    latent mean and log-variance. The decoder mirrors this with three
    stride-2 transposed convolutions and a final sigmoid, so outputs lie in
    [0, 1] and have the same shape as the input.

    Args:
        in_channels: Number of channels of the input (and output) images.
        latent_dim: Dimensionality of the latent vector.
        input_size: Height/width of the square input images. Must be a
            positive multiple of 8 (one halving per encoder stage) so the
            decoder reproduces the input size exactly. Defaults to 40,
            which gives the 5x5 bottleneck the model originally hard-coded.

    Raises:
        ValueError: If ``input_size`` is not a positive multiple of 8.
    """

    def __init__(self, in_channels, latent_dim, input_size: int = 40) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.latent_dim = latent_dim

        modules = []
        hidden_dims = [32, 64, 128]
        # Keep an un-reversed copy: decode() needs the encoder's last width.
        self.hidden_dims = hidden_dims.copy()

        # Each encoder stage halves the spatial size and each decoder stage
        # doubles it, so the round trip is exact only for multiples of 2**stages.
        downscale = 2 ** len(hidden_dims)
        if input_size <= 0 or input_size % downscale != 0:
            raise ValueError(
                f"input_size must be a positive multiple of {downscale}, got {input_size}"
            )
        self.input_size = input_size
        self.feature_size = input_size // downscale
        flat_features = hidden_dims[-1] * self.feature_size ** 2

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim, flat_features)
        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        # Last upsampling stage, then a stride-1 layer mapping back to the
        # image channels; the sigmoid keeps outputs in [0, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(hidden_dims[-1], out_channels=self.in_channels, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def encode(self, input):
        """Encode a batch of images into ``[mu, log_var]`` of the latent Gaussian."""
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]

    def decode(self, z):
        """Map latent vectors ``z`` of shape (batch, latent_dim) back to image space."""
        result = self.decoder_input(z)
        result = result.view(-1, self.hidden_dims[-1], self.feature_size, self.feature_size)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` so sampling stays differentiable."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input):
        """Return ``(reconstruction, mu, log_var)`` for a batch of images."""
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def loss_function(self, recons_img, input_img, mu, log_var, kld_weight) -> dict:
        """Compute the VAE objective.

        Returns a dict with the weighted total (``'loss'``), the mean-squared
        reconstruction error (``'Reconstruction_Loss'``) and the batch-mean
        KL divergence to a standard normal (``'KLD'``).
        """
        recons_loss = F.mse_loss(recons_img, input_img)
        # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': kld_loss}

    def sample(self, num_samples: int, current_device: int, **kwargs):
        """Decode ``num_samples`` latent vectors drawn from N(0, I) on ``current_device``."""
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(current_device)
        samples = self.decode(z)
        return samples

# Training and Validation

In [None]:
# Instantiate the VAE class
network = VanillaVAE(in_channels, latent_dim)
optimizer = torch.optim.AdamW(network.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.001)

# First epoch index for the loop below. Nothing in this notebook sets it
# (there is no checkpoint-resume step), so training starts from scratch.
start_epoch = 0

# Initialise variables for early stopping
best_val_loss = float('inf')
patience = 0
patience_limit = 10

# Initialise lists to store training and validation losses
train_losses_list = []
val_losses_list = []

# Initialise lists to store evaluation metrics
train_metrics_list = []
val_metrics_list = []


def _reconstruction_metrics(originals, reconstructions):
    """Return MSE, MAE and mean per-image SSIM between two image batches.

    Both arrays are (N, C, H, W); every channel is scored as its own 2-D image.
    """
    # reshape (not np.squeeze) so a batch of one image keeps its leading axis.
    originals = originals.reshape(-1, *originals.shape[-2:])
    reconstructions = reconstructions.reshape(-1, *reconstructions.shape[-2:])
    originals_flat = originals.reshape(len(originals), -1)
    reconstructions_flat = reconstructions.reshape(len(reconstructions), -1)
    # SSIM is a 2-D structural measure, so score each image separately and
    # average. Inputs are min-max normalised and the decoder ends in a
    # sigmoid, hence the value range of both arrays is 1.0.
    ssim_scores = [
        ssim(original, reconstructed, data_range=1.0)
        for original, reconstructed in zip(originals, reconstructions)
    ]
    return {
        'MSE': mean_squared_error(originals_flat, reconstructions_flat),
        'MAE': mean_absolute_error(originals_flat, reconstructions_flat),
        'SSIM': float(np.mean(ssim_scores)),
    }


# KLD weights are loop-invariant, so compute them once.
# NOTE(review): training and validation use different weights (1 / number of
# batches in each loader), so the two loss curves are not on the same scale.
train_kld_weight = 1 / len(train_loader)
val_kld_weight = 1 / len(val_loader)

total_start_time = time.time()

# Training loop
for epoch in range(start_epoch, epochs):
    # Training phase
    network.train()
    train_loss = 0.0
    start_time = time.time()
    train_reconstruction = []
    train_data = []

    for data_batch in train_loader:
        optimizer.zero_grad()
        recons, mu, log_var = network(data_batch)
        losses = network.loss_function(recons, data_batch, mu, log_var, train_kld_weight)
        losses['loss'].backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample even when
        # the last batch is smaller.
        train_loss += losses['loss'].item() * data_batch.size(0)
        train_reconstruction.append(recons.detach().numpy())
        train_data.append(data_batch.detach().numpy())

    train_loss /= len(train_loader.dataset)
    train_losses_list.append(train_loss)

    train_metrics = _reconstruction_metrics(
        np.concatenate(train_data, axis=0),
        np.concatenate(train_reconstruction, axis=0),
    )
    train_metrics_list.append(train_metrics)

    # Validation phase
    network.eval()
    val_loss = 0.0
    val_reconstruction = []
    val_data = []

    with torch.no_grad():
        for data_batch in val_loader:
            recons, mu, log_var = network(data_batch)
            val_losses = network.loss_function(recons, data_batch, mu, log_var, val_kld_weight)
            val_loss += val_losses['loss'].item() * data_batch.size(0)
            val_reconstruction.append(recons.detach().numpy())
            val_data.append(data_batch.detach().numpy())

    val_loss /= len(val_loader.dataset)
    val_losses_list.append(val_loss)

    end_time = time.time()
    epoch_time = end_time - start_time
    print(f"Epoch [{epoch + 1}/{epochs}], Time: {epoch_time:.2f} seconds")

    val_metrics = _reconstruction_metrics(
        np.concatenate(val_data, axis=0),
        np.concatenate(val_reconstruction, axis=0),
    )
    val_metrics_list.append(val_metrics)

    print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss}, Val Loss: {val_loss}")
    print(f"Epoch [{epoch+1}/{epochs}], Train MSE: {train_metrics['MSE']}, Train MAE: {train_metrics['MAE']}, Train SSIM: {train_metrics['SSIM']}")
    print(f"Epoch [{epoch+1}/{epochs}], Val MSE: {val_metrics['MSE']}, Val MAE: {val_metrics['MAE']}, Val SSIM: {val_metrics['SSIM']}")

    # Early stopping: checkpoint on improvement, otherwise count stale epochs.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(network.state_dict(), os.path.join(save_folder, 'best_model.pth'))
        patience = 0
    else:
        patience += 1

    if patience >= patience_limit:
        print("Early stopping due to no improvement in validation loss.")
        break

total_end_time = time.time()
total_training_time = total_end_time - total_start_time
print(f"Total training time: {total_training_time:.2f} seconds")

# Evaluation

In [None]:
# Evaluation
# Training checkpoints the weights with the lowest validation loss, but the
# in-memory network holds the weights of the last epoch run. Restore the
# checkpoint (when one exists) so the test metrics describe the best model.
best_model_path = os.path.join(save_folder, 'best_model.pth')
if os.path.exists(best_model_path):
    network.load_state_dict(torch.load(best_model_path, map_location='cpu'))

reconstruction = []
data_test = []

network.eval()
with torch.no_grad():
    for data_batch in test_loader:
        recons, mu, log_var = network(data_batch)
        reconstruction.append(recons.numpy())
        data_test.append(data_batch.numpy())

# data_test keeps its channel axis: (N, C, H, W)
data_test = np.concatenate(data_test, axis=0)

# Calculate evaluation metrics
reconstruction = np.concatenate(reconstruction, axis=0)
# Collapse the channel axis into the image axis -> (N, H, W). reshape is used
# instead of np.squeeze so a single test image keeps its leading axis.
reconstruction = reconstruction.reshape(-1, *reconstruction.shape[-2:])
test_images = data_test.reshape(-1, *data_test.shape[-2:])

data_test_flat = test_images.reshape(len(test_images), -1)  # Flatten to 2D array
reconstruction_flat = reconstruction.reshape(len(reconstruction), -1)  # Flatten to 2D array

mse = mean_squared_error(data_test_flat, reconstruction_flat)
mae = mean_absolute_error(data_test_flat, reconstruction_flat)
# SSIM is a 2-D structural measure: score each image on its own and average,
# rather than treating the (N, H*W) matrix as one image. Inputs are min-max
# normalised and the decoder ends in a sigmoid, so the value range is 1.0.
ssim_score = float(np.mean([
    ssim(actual, predicted, data_range=1.0)
    for actual, predicted in zip(test_images, reconstruction)
]))

print(f"Mean Squared Error: {mse}")
print(f"Mean Absolute Error: {mae}")
print(f"SSIM: {ssim_score}")

# Visualisation

In [None]:
# Visualise actual and predicted images with sample indices
# Never ask for more samples than the test set holds.
n_samples = min(5, len(data_test))
if n_samples == 0:
    raise ValueError("No test samples available to visualise.")

# Select random, distinct samples
indices = np.random.choice(len(data_test), size=n_samples, replace=False)

# Plot actual and predicted images.
# squeeze=False keeps `axes` two-dimensional even with a single column.
fig, axes = plt.subplots(nrows=2, ncols=n_samples, figsize=(15, 5), squeeze=False)

for i, idx in enumerate(indices):
    # Plot actual image with index
    axes[0, i].imshow(data_test[idx].squeeze(), cmap='jet')
    axes[0, i].set_title(f"Actual ({idx})")
    axes[0, i].axis('off')

    # Plot predicted image with index
    reconstructed_img = reconstruction[idx]
    axes[1, i].imshow(reconstructed_img, cmap='jet')
    axes[1, i].set_title(f"Predicted ({idx})")
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()