In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Imports

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import math
import random
import os
import os.path
import numpy as np
import logging
import argparse
from argparse import ZERO_OR_MORE
from torch.nn.modules.module import T
from torch.utils.tensorboard import SummaryWriter as writer
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.metrics import mean_squared_error, mean_absolute_error
from skimage.metrics import structural_similarity as ssim
import time

Parameters

In [5]:
epochs = 150
batch_size = 32 # reduced from 250 to 32
# n_steps = 16 # timestep
in_channels = 1
lr = 0.001 # changed from 0.001 to 0.0001
# n_class = 10
latent_dim = 35 # changed from 128 to 25 (continue with 25)
input_size = 40
k = 20 # multiplier of channel
# scheduled = True # whether to apply scheduled sampling
# checkpoint_filename = 'vae_checkpoint_32_16_250_full_dim25.pth'
save_folder = '/content/drive/MyDrive/TU060/Dissertation/Models/V3/VAE/z_dim=30/full'
# Check if the directory exists, if not, create it
if not os.path.exists(save_folder):
    os.makedirs(save_folder)

Load data

In [6]:
# Load data
raw = np.load('/content/drive/MyDrive/TU060/Dissertation/Dataset/sampled_maps_10000_01.npy')
# raw = np.load('/content/drive/MyDrive/TU060/Dissertation/Dataset/sampled_maps_20000_01.npy')
# raw = np.load('/content/drive/MyDrive/TU060/Dissertation/Dataset/sampled_maps_40000_01.npy')
# raw = np.load('/content/drive/MyDrive/TU060/Dissertation/Dataset/sampled_maps_80000_01.npy')
# raw = np.load('/content/drive/MyDrive/TU060/Dissertation/Dataset/sampled_maps_160000_01.npy')
# raw = np.load('/content/drive/MyDrive/TU060/Dissertation/Dataset/TopoMaps_s01.npy')


# Normalise data between -1 and 1
# data_norm = (raw - raw.min()) / (raw.max() - raw.min()) * 2 - 1
#Normalise the data
# data_norm = (raw - raw.min()) / (raw.max() - raw.min())

# Initialize normalized data array
data_norm = np.zeros_like(raw)

# Normalize each image individually
for i in range(len(raw)):
    min_val = raw[i].min()
    max_val = raw[i].max()
    data_norm[i] = (raw[i] - min_val) / (max_val - min_val)

# Reshape data to have a single channel
data_norm_reshape = data_norm.reshape(-1, 1, raw.shape[1], raw.shape[2]).astype(np.float32)


# Split data into training, validation, and testing
data_train, data_test = train_test_split(data_norm_reshape, test_size=0.3, random_state=42, shuffle=True)
data_test, data_val = train_test_split(data_test, test_size=0.5, random_state=42, shuffle=True)

train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(data_val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(data_test, batch_size=batch_size, shuffle=False)

# ANN-VAE

In [7]:
class VanillaVAE(nn.Module):
    def __init__(self, in_channels, latent_dim) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.latent_dim = latent_dim

        modules = []
        hidden_dims = [32, 64, 128]
        self.hidden_dims = hidden_dims.copy()

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*25, latent_dim) # changed from 4 to 25
        self.fc_var = nn.Linear(hidden_dims[-1]*25, latent_dim) # changed from 4 to 25

        # # Print the dimensions of fc_mu and fc_var
        # print(f"Dimensions of fc_mu: {self.fc_mu.weight.shape}")
        # print(f"Dimensions of fc_var: {self.fc_var.weight.shape}")



        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 25) # changed from 4 to 25

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.ConvTranspose2d(hidden_dims[-1], out_channels= self.in_channels, # deconvにしてみる
                                      kernel_size= 3, padding= 1),
                            nn.Sigmoid()) # Changed from nn.Tanh() to nn.Sigmoid()


    def encode(self, input):
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Print the size of the intermediate tensor
        # print(f"Size of result tensor in encode method: {result.size()}")

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        # Print the sizes of mu and log_var tensors
        # print(f"Size of mu tensor: {mu.size()}")
        # print(f"Size of log_var tensor: {log_var.size()}")

        return [mu, log_var]

    def decode(self, z):
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """

        result = self.decoder_input(z)
        result = result.view(-1, self.hidden_dims[-1], 5, 5) # changed from 2 to 5

        # Print the size of the intermediate tensor
        # print(f"Size of result tensor in decode method: {result.size()}")

        result = self.decoder(result)
        result = self.final_layer(result)
        return result


    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input):
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  self.decode(z), mu, log_var

    def loss_function(self,
                      recons_img,
                      input_img,
                      mu, log_var, kld_weight) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """

        recons_loss =F.mse_loss(recons_img, input_img)


        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
        loss = recons_loss + kld_weight * kld_loss
        """
        kld_loss = -0.5 * torch.mean(1 + log_var - mu ** 2 - log_var.exp())
        loss = recons_loss + kld_loss
        """
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs):
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

# Main VAE

In [8]:
# Define the directory to save the model checkpoints
checkpoint_dir = '/content/drive/MyDrive/TU060/Dissertation/Checkpoints/V3'

# Check if the checkpoint directory exists, if not, create it
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Define a function to save the model checkpoint after every epoch
def save_checkpoint(epoch, model, optimizer, loss, filename):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, os.path.join(checkpoint_dir, filename))

In [9]:
# Instantiate the VAE class
network = VanillaVAE(in_channels, latent_dim)

max_accuracy = 0
min_loss = 1000

optimizer = torch.optim.AdamW(network.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.001)

# Initialize variables for early stopping
best_val_loss = float('inf')
patience = 0
patience_limit = 10  # Number of epochs without improvement before stopping

# Check if there is a checkpoint file available
start_epoch = 0
# if os.path.exists('/content/drive/MyDrive/TU060/Dissertation/Checkpoints/vae_checkpoint_32_16_250_1000.pth'):
#     checkpoint = torch.load('/content/drive/MyDrive/TU060/Dissertation/Checkpoints/svae_checkpoint_32_16_250_1000.pth')
#     start_epoch = checkpoint['epoch']
#     best_val_loss = checkpoint['loss']
#     network.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#     print(f"Resuming from epoch {start_epoch + 1}")

# Initialize lists to store training and validation losses
train_losses_list = []
val_losses_list = []

# Initialize lists to store evaluation metrics
train_metrics_list = []
val_metrics_list = []

total_start_time = time.time()

# Training loop
for epoch in range(start_epoch, epochs):
    # Training phase
    network.train()
    train_loss = 0.0
    start_time = time.time()
    train_reconstruction = []
    train_data = []

    for batch_index, data_batch in enumerate(train_loader):
        optimizer.zero_grad()
        # data_batch = data_batch.to(device)
        recons, mu, log_var = network(data_batch)
        losses = network.loss_function(recons, data_batch, mu, log_var, 1/len(train_loader))
        losses['loss'].backward()
        optimizer.step()
        train_loss += losses['loss'].item() * data_batch.size(0)

        # print(f"Epoch [{epoch + 1}/{epochs}], Batch [{batch_index}/{len(train_loader)}], Train Loss: {losses['loss'].item():.4f}")

        train_reconstruction.append(recons.detach().numpy())
        train_data.append(data_batch.detach().numpy())

    train_loss /= len(train_loader.dataset)
    train_losses_list.append(train_loss)

    # Concatenate training data
    train_data = np.concatenate(train_data, axis=0)
    train_reconstruction = np.concatenate(train_reconstruction, axis=0)
    train_reconstruction = np.squeeze(train_reconstruction)

    train_data_flat = train_data.reshape(-1, train_data.shape[2] * train_data.shape[3])  # Flatten to 2D array
    train_reconstruction_flat = train_reconstruction.reshape(-1, train_reconstruction.shape[1] * train_reconstruction.shape[2])  # Flatten to 2D array

    # Compute training metrics
    train_metrics = {
        'MSE': mean_squared_error(train_data_flat, train_reconstruction_flat),
        'MAE': mean_absolute_error(train_data_flat, train_reconstruction_flat),
        'SSIM': ssim(train_data_flat, train_reconstruction_flat, data_range=train_reconstruction_flat.max() - train_reconstruction_flat.min())
        # 'MAPE': mean_absolute_percentage_error(train_data_flat, train_reconstruction_flat)
    }
    train_metrics_list.append(train_metrics)

    # Validation phase
    network.eval()
    val_loss = 0.0
    val_reconstruction = []
    val_data = []

    with torch.no_grad():
        for batch_idx, data_batch in enumerate(val_loader):
            # data_batch = data_batch.to(device)
            recons, mu, log_var = network(data_batch)
            val_losses = network.loss_function(recons, data_batch, mu, log_var, 1/len(val_loader))
            val_loss += val_losses['loss'].item() * data_batch.size(0)

            val_reconstruction.append(recons.detach().numpy())
            val_data.append(data_batch.detach().numpy())

    val_loss /= len(val_loader.dataset)
    val_losses_list.append(val_loss)

    end_time = time.time()
    epoch_time = end_time - start_time
    print(f"Epoch [{epoch + 1}/{epochs}], Time: {epoch_time:.2f} seconds")

    # print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss}, Val Loss: {val_loss}")

    # Concatenate validation data
    val_data = np.concatenate(val_data, axis=0)
    val_reconstruction = np.concatenate(val_reconstruction, axis=0)
    val_reconstruction = np.squeeze(val_reconstruction)

    val_data_flat = val_data.reshape(-1, val_data.shape[2] * val_data.shape[3])  # Flatten to 2D array
    val_reconstruction_flat = val_reconstruction.reshape(-1, val_reconstruction.shape[1] * val_reconstruction.shape[2])  # Flatten to 2D array

    # Compute validation metrics
    val_metrics = {
        'MSE': mean_squared_error(val_data_flat, val_reconstruction_flat),
        'MAE': mean_absolute_error(val_data_flat, val_reconstruction_flat),
        'SSIM': ssim(val_data_flat, val_reconstruction_flat, data_range=val_reconstruction_flat.max() - val_reconstruction_flat.min())
        # 'MAPE': mean_absolute_percentage_error(val_data_flat, val_reconstruction_flat)
    }
    val_metrics_list.append(val_metrics)

    # Print training and validation loss for each epoch
    print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss}, Val Loss: {val_loss}")
    print(f"Epoch [{epoch+1}/{epochs}], Train MSE: {train_metrics['MSE']}, Train MAE: {train_metrics['MAE']}, Train SSIM: {train_metrics['SSIM']}")
    print(f"Epoch [{epoch+1}/{epochs}], Val MSE: {val_metrics['MSE']}, Val MAE: {val_metrics['MAE']}, Val SSIM: {val_metrics['SSIM']}")


    # Save model checkpoint after every epoch
    # save_checkpoint(epoch, network, optimizer, best_val_loss, f'vae_checkpoint_32_16_25_full_dim25_{epoch}.pth')

    # Check for improvement in validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience = 0
        # Save the model to the specified folder
        # model_path = os.path.join(save_folder, 'best_model.pth')
        # torch.save(network.state_dict(), model_path)
    else:
        patience += 1

    # Early stopping
    if patience > patience_limit:
        print("Early stopping! No improvement in validation loss.")
        break

# Calculate time
total_end_time = time.time()
epoch_time = total_end_time - total_start_time
print(f"Time: {epoch_time:.2f} seconds")

# # Save losses as CSV
# losses_df = pd.DataFrame({
#     'Epoch': range(start_epoch, epoch +1),
#     'Train Loss': train_losses_list,
#     'Validation Loss': val_losses_list
# })
# losses_df.to_csv('vae2_training_losses_full.csv', index=False)

# Save losses as CSV
# csv_path = os.path.join(save_folder, 'vae3_30_full.csv')
# losses_df = pd.DataFrame({
#     'Epoch': range(start_epoch, epoch +1),
#     'Train Loss': train_losses_list,
#     'Validation Loss': val_losses_list,
#     'Train MSE': [metrics['MSE'] for metrics in train_metrics_list],
#     'Train MAE': [metrics['MAE'] for metrics in train_metrics_list],
#     'Train SSIM': [metrics['SSIM'] for metrics in train_metrics_list],
#     # 'Train MAPE': [metrics['MAPE'] for metrics in train_metrics_list],
#     'Validation MSE': [metrics['MSE'] for metrics in val_metrics_list],
#     'Validation MAE': [metrics['MAE'] for metrics in val_metrics_list],
#     'Validation SSIM': [metrics['SSIM'] for metrics in val_metrics_list],
#     # 'Validation MAPE': [metrics['MAPE'] for metrics in val_metrics_list]
# })
# losses_df.to_csv(csv_path, index=False)

Epoch [1/150], Time: 9.12 seconds
Epoch [1/150], Train Loss: 0.04465967149393899, Val Loss: 0.06143754307428996
Epoch [1/150], Train MSE: 0.036594949662685394, Train MAE: 0.14550326764583588, Train SSIM: 0.3179758627017082
Epoch [1/150], Val MSE: 0.031192278489470482, Val MAE: 0.13512784242630005, Val SSIM: 0.4288103322237335
Epoch [2/150], Time: 8.62 seconds
Epoch [2/150], Train Loss: 0.03520117292233876, Val Loss: 0.06637706057230632
Epoch [2/150], Train MSE: 0.028038527816534042, Train MAE: 0.12653112411499023, Train SSIM: 0.45733325781822853
Epoch [2/150], Val MSE: 0.02478119172155857, Val MAE: 0.11797056347131729, Val SSIM: 0.5172676851787544
Epoch [3/150], Time: 8.80 seconds
Epoch [3/150], Train Loss: 0.033728627852031165, Val Loss: 0.06110414771238963
Epoch [3/150], Train MSE: 0.02628406509757042, Train MAE: 0.12175945937633514, Train SSIM: 0.5070330187589432
Epoch [3/150], Val MSE: 0.02506137080490589, Val MAE: 0.11764025688171387, Val SSIM: 0.5290189527828564
Epoch [4/150], Ti

In [11]:
# Load the model
# network = VanillaVAE(in_channels, latent_dim)
# checkpoint = torch.load(model_path)
# network.load_state_dict(checkpoint)
# network.eval()

# Evaluation
reconstruction = []
data_test = []

network.eval()
with torch.no_grad():
    for batch_idx, data_batch in enumerate(test_loader):
        # data_batch = data_batch.to(device)
        recons, mu, log_var = network(data_batch)
        reconstruction.append(recons.numpy())
        data_test.append(data_batch.numpy())

# # Save data_test
# data_test = []
# with torch.no_grad():
#     for batch_idx, data_batch in enumerate(test_loader):
#         # real_img = real_img.to(device)
#         data_test.append(data_batch.numpy())
data_test = np.concatenate(data_test, axis=0)

# Calculate evaluation metrics
reconstruction = np.concatenate(reconstruction, axis=0)
reconstruction = np.squeeze(reconstruction)

data_test_flat = data_test.reshape(-1, data_test.shape[2] * data_test.shape[3])  # Flatten to 2D array
reconstruction_flat = reconstruction.reshape(-1, reconstruction.shape[1] * reconstruction.shape[2])  # Flatten to 2D array

mse = mean_squared_error(data_test_flat, reconstruction_flat)
mae = mean_absolute_error(data_test_flat, reconstruction_flat)
ssim_score = ssim(data_test_flat, reconstruction_flat,
                  data_range=reconstruction_flat.max() - reconstruction_flat.min())
# mape = mean_absolute_percentage_error(data_test_flat, reconstruction_flat)

print(f"Mean Squared Error: {mse}")
print(f"Mean Absolute Error: {mae}")
print(f"SSIM: {ssim_score}")
# print(f"Mean Absolute Percentage Error: {mape}")

Mean Squared Error: 0.038744326680898666
Mean Absolute Error: 0.1544773429632187
SSIM: 0.0006281379864081619


In [None]:
# Visualize actual and predicted images with sample indices
n_samples = 5  # Number of samples to visualize

# Select random samples
available_indices = list(range(len(data_test)))
indices = np.random.choice(available_indices, size=n_samples, replace=False)

# Plot actual and predicted images
fig, axes = plt.subplots(nrows=2, ncols=n_samples, figsize=(15, 5))

for i, idx in enumerate(indices):
    # Remove selected index from available indices
    available_indices.remove(idx)

    # Plot actual image with index
    axes[0, i].imshow(data_test[idx].squeeze(), cmap='jet')
    axes[0, i].set_title(f"Actual ({idx})")
    axes[0, i].axis('off')

    # Plot predicted image with index
    reconstructed_img = reconstruction[idx]
    axes[1, i].imshow(reconstructed_img, cmap='jet')
    axes[1, i].set_title(f"Predicted ({idx})")
    axes[1, i].axis('off')

plt.tight_layout()
# img_path = os.path.join(save_folder, 'vae3_30_full_2.png')
# plt.savefig(img_path)
plt.show()