In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
from PIL import Image
from scipy import ndimage, stats

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, Subset, TensorDataset
import torchvision.transforms as T
from tqdm import tqdm

# Pick the compute backend: CUDA if present, otherwise Apple's MPS, otherwise the CPU
if th.cuda.is_available():
    device = th.device("cuda")
    _device_detail = f": {th.cuda.get_device_name()}"
else:
    device = th.device("mps") if th.backends.mps.is_available() else th.device("cpu")
    _device_detail = ""
print(f"Using device {device}{_device_detail}")
del _device_detail

In [None]:
# Manually override to cpu
# device = th.device("cpu")

# Hyperparameters

In [None]:
# All tunable training settings, kept in one place
hyperparams = dict(
    learning_rate=2e-4,
    batch_size=64,
    epochs=100,
    l1_lambda=100,  # weight of the L1 term in the generator loss
    adam_betas=(0.5, 0.999),  # PyTorch's Adam defaults to (0.9, 0.999)
    data_splits=[4/6, 1/6, 1/6],  # train / validation / test fractions
    test_run=False,  # True -> use only `test_run_size` images for a quick experiment
    test_run_size=192,
)

# Directory for checkpoints, loss statistics, split indices and metrics (created on first use)
model_data = "./rsc/second_model"
if not os.path.exists(model_data):
    os.makedirs(model_data)

# Dataset

## Preparing the dataset

In [None]:
class PairedImageDataset(Dataset):
    """Dataset of aligned (input, ground-truth) image pairs read from two directories.

    Every file name found in ``input_data_dir`` must have a same-named counterpart
    in ``gt_data_dir``. Both images of a pair go through the same random
    augmentation, so they stay pixel-aligned. Images are returned as float
    tensors of shape [3, 256, 256], normalized from [0, 1] to [-1, 1].

    Args:
        input_data_dir (str): Directory with the raw (input) images.
        gt_data_dir (str): Directory with the ground-truth images.
        size (int, optional): Keep at most this many pairs (for quick experiments).
            Applied after ``excluded`` has been filtered out.
        excluded (iterable of str, optional): File names to leave out
            (used to keep the figure images out of training).
        random_perturbations (bool): If True, apply random flips and crop jitter.
        jitter_size (int): Side length images are resized to before the random
            256x256 crop when ``random_perturbations`` is True. The default (256)
            keeps the crop a no-op; a larger value (e.g. 286) enables crop jitter.
    """
    def __init__(self, input_data_dir: str, gt_data_dir: str, size=None, excluded=None,
                 random_perturbations=True, jitter_size=256):
        self.input_data_dir = input_data_dir
        self.gt_data_dir = gt_data_dir

        # NOTE: os.listdir order is filesystem-dependent. Split indices saved later
        # are positions in this list, so the order must stay stable between runs.
        image_files = os.listdir(input_data_dir)

        # Filter first, then cap the size, so a test run really gets `size` pairs
        if excluded is not None:
            excluded = set(excluded)
            image_files = [filename for filename in image_files if filename not in excluded]
        if size is not None:
            image_files = image_files[:size]
        self.image_files = image_files

        # Maps [0, 1] tensors to [-1, 1]
        normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        if random_perturbations:
            self.transform = T.Compose([
                T.Resize([jitter_size, jitter_size], interpolation=T.InterpolationMode.NEAREST),
                T.RandomHorizontalFlip(),
                T.RandomVerticalFlip(),
                T.RandomCrop(256),
                normalize,
            ])
        else:
            self.transform = T.Compose([
                T.Resize([256, 256], interpolation=T.InterpolationMode.NEAREST),
                normalize,
            ])

    def __len__(self):
        """Number of image pairs in the dataset."""
        return len(self.image_files)

    @staticmethod
    def _read_rgb(path):
        """Read an image file with OpenCV and return it as an RGB array [H, W, 3]."""
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        # OpenCV loads images in BGR channel order
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        """Return the i-th (input_image, gt_image) pair as jointly transformed tensors."""
        filename = self.image_files[i]
        input_image = self._read_rgb(os.path.join(self.input_data_dir, filename))
        gt_image = self._read_rgb(os.path.join(self.gt_data_dir, filename))

        # ToTensor: [H, W, 3] array -> [3, H, W] float tensor in [0, 1]
        input_tensor = T.ToTensor()(input_image)
        gt_tensor = T.ToTensor()(gt_image)

        # Stack into one [2, 3, H, W] tensor so both images get the same random flips/crop
        transformed_images = self.transform(th.stack([input_tensor, gt_tensor]))

        # (input_image, gt_image)
        return transformed_images[0], transformed_images[1]

## Running dataset construction code

In [None]:
# Dataset locations (the figure images are held out of the main dataset)
figure_input_data_dir = "./FigureImages/input"
figure_gt_data_dir = "./FigureImages/ground_truth"
input_data_dir = "./UnderwaterImages/input"
gt_data_dir = "./UnderwaterImages/ground_truth"
# Alternative dataset location:
#input_data_dir = "./Paired/underwater_scenes/trainA"
#gt_data_dir = "./Paired/underwater_scenes/trainB"

# Main dataset without the figure images; a test run caps the number of pairs
excluded = os.listdir(figure_input_data_dir)
dataset = PairedImageDataset(
    input_data_dir,
    gt_data_dir,
    size=hyperparams["test_run_size"] if hyperparams["test_run"] else None,
    excluded=excluded,
)
train_set, validation_set, test_set = random_split(dataset, hyperparams["data_splits"])

# The figure images are shown as-is, without random augmentation
figure_dataset = PairedImageDataset(figure_input_data_dir, figure_gt_data_dir, random_perturbations=False)

# Wrap each split in a DataLoader (the split names are reused for the loaders)
train_set, validation_set, test_set = (
    DataLoader(dataset=split, batch_size=hyperparams["batch_size"], shuffle=True)
    for split in (train_set, validation_set, test_set)
)
# All figure images fit in one batch
figure_set = DataLoader(dataset=figure_dataset, batch_size=len(figure_dataset))

# Models

## Discriminator

In [None]:
class DownModule(nn.Module):
    """Downsampling unit: 4x4 stride-2 convolution -> BatchNorm -> LeakyReLU (halves H and W)."""
    def __init__(self, in_channels, out_channels, leaky_relu_slope=0.2):
        super().__init__()
        # The attribute names form the state_dict keys of saved checkpoints
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.lrelu = nn.LeakyReLU(leaky_relu_slope)

    def forward(self, x):
        return self.lrelu(self.bn(self.conv(x)))


class ZeroPadModule(nn.Module):
    """Adds one row/column of zeros on every side of the last two (spatial) dimensions."""
    def forward(self, x):
        return F.pad(x, (1, 1, 1, 1), mode='constant', value=0)


class Discriminator(nn.Module):
    """Conditional discriminator that scores (raw image, candidate enhancement) pairs.

    The two RGB images are stacked along the channel axis (6 channels) and reduced
    by a fully convolutional stack to a single-channel map of realness scores
    (30x30 for 256x256 inputs). The final sigmoid keeps every score in [0, 1], as
    required by BCELoss.

    Args:
        DEBUG (bool): Print the tensor shape after every layer.
    """
    def __init__(self, DEBUG=False):
        super().__init__()
        self.DEBUG = DEBUG

        # Layer order fixes the "DownLayers.<i>..." state_dict keys of saved checkpoints
        self.DownLayers = nn.Sequential(
            DownModule(6, 64),
            DownModule(64, 128),
            DownModule(128, 256),
            ZeroPadModule(),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            ZeroPadModule(),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid() #NOTE: Not in the paper; squashes scores to [0, 1] so BCELoss gets valid probabilities.
        )

    def forward(self, x: th.Tensor, y: th.Tensor) -> th.Tensor:
        """Score how real the pair (x, y) looks.

        Args:
            x (th.Tensor): Raw underwater image batch
            y (th.Tensor): Enhanced (generated or ground-truth) image batch

        Returns:
            th.Tensor: Map of realness scores in [0, 1]
        """
        stacked = th.cat((x, y), dim=1)

        if not self.DEBUG:
            return self.DownLayers(stacked)

        # Debug mode: walk the layers one at a time and report every shape
        print("Input tensor shape:")
        print(stacked.shape)
        for layer in self.DownLayers:
            stacked = layer(stacked)
            print(stacked.shape)
        return stacked

# Shape check: Discriminator(DEBUG=True).to(device)(sample, sample.clone())
# with sample = th.rand(1, 3, 256, 256, device=device)

## Generator / Autoencoder

In [None]:
class EncoderModule(nn.Module):
    """Encoder stage: 4x4 stride-2 convolution -> BatchNorm -> LeakyReLU (halves H and W)."""
    def __init__(self, in_channels, out_channels, leaky_relu_slope=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.lrelu = nn.LeakyReLU(leaky_relu_slope)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.lrelu(x)
        return x


class FeatureMapModule(nn.Module):
    """Innermost encoder stage: like EncoderModule but without BatchNorm.

    For 256x256 inputs this stage produces the 1x1 bottleneck feature map.
    """
    def __init__(self, in_channels, out_channels, leaky_relu_slope=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.lrelu = nn.LeakyReLU(leaky_relu_slope)

    def forward(self, x):
        x = self.conv(x)
        x = self.lrelu(x)
        return x


class DecoderModule(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed conv -> BatchNorm -> Dropout -> ReLU (doubles H and W)."""
    def __init__(self, in_channels, out_channels, dropout_prob=0.5):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout_prob)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.deconv(x)
        x = self.bn(x)
        x = self.dropout(x)
        x = self.relu(x)
        return x


class OutputModule(nn.Module):
    """Final 4x4 stride-2 transposed convolution mapping decoder features to image channels."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = self.deconv(x)
        return x


class Autoencoder(nn.Module):
    """U-Net style generator that enhances underwater images.

    Eight stride-2 encoder stages downsample the image and transposed-convolution
    stages upsample it again. Every decoder stage after the first, and the output
    layer, receives the previous activation concatenated along the channel axis
    with the mirrored encoder activation (skip connection). A final tanh keeps
    the output in [-1, 1], the same range as the normalized dataset images.

    Input height and width must be multiples of 256 (the notebook uses 256x256),
    so the eight halvings end on a feature map of at least 1x1 and the skip
    connections line up.

    Args:
        DEBUG (bool): Print intermediate tensor shapes during the forward pass.
    """
    def __init__(self, DEBUG=False):
        super().__init__()
        self.DEBUG = DEBUG

        self.EncoderLayers = nn.ModuleList([
            EncoderModule(3, 64),
            EncoderModule(64, 128),
            EncoderModule(128, 256),
            EncoderModule(256, 512),
            EncoderModule(512, 512),
            EncoderModule(512, 512),
            EncoderModule(512, 512),
            FeatureMapModule(512, 512),
        ])

        # Input channels are doubled from the second stage on because of the skip concatenation
        self.DecoderLayers = nn.ModuleList([
            DecoderModule(512, 512),
            DecoderModule(1024, 512),
            DecoderModule(1024, 512),
            DecoderModule(1024, 512, dropout_prob=0.0),
            DecoderModule(1024, 256, dropout_prob=0.0),
            DecoderModule(512, 128, dropout_prob=0.0),
            DecoderModule(256, 64, dropout_prob=0.0),
        ])

        self.OutputLayer = OutputModule(128, 3)
        self.tanh = nn.Tanh() #NOTE: Not in the paper; limits the output to [-1, 1], the range of the normalized images.

    @staticmethod
    def _concat_skip(x, skip):
        """Concatenate a decoder activation with its encoder skip along the channel axis."""
        if x.shape != skip.shape:
            raise RuntimeError(
                f"Skip-connection shapes do not match: decoder {tuple(x.shape)} "
                f"vs encoder {tuple(skip.shape)}"
            )
        return th.cat((x, skip), 1)

    def forward(self, x):
        """Forward pass for the autoencoder model.

        Args:
            x (th.Tensor): Input image batch [n, 3, H, W] (the dataset normalizes images to [-1, 1])

        Returns:
            th.Tensor: Enhanced image batch with the same shape as ``x``, in [-1, 1]

        Raises:
            RuntimeError: If a decoder activation and its encoder skip connection
                differ in shape (e.g. for an input size that is not a multiple of 256).
        """
        # Encoder activations kept for the skip connections (all but the bottleneck)
        skips = []

        if self.DEBUG:
            print("Starting forward pass")
            print(x.shape)

        # Encoder pass
        last_encoder = len(self.EncoderLayers) - 1
        for i, layer in enumerate(self.EncoderLayers):
            x = layer(x)
            if i < last_encoder:
                skips.append(x)
            if self.DEBUG:
                print(x.shape)

        if self.DEBUG:
            print("Encoding complete")
            print(x.shape)

        # Decoder pass: skips are consumed deepest-first
        for i, layer in enumerate(self.DecoderLayers):
            if i != 0:
                x = self._concat_skip(x, skips.pop())
            x = layer(x)
            if self.DEBUG:
                print(x.shape)

        if self.DEBUG:
            print("Decoding complete")

        # The output layer gets the first (highest-resolution) encoder activation
        x = self._concat_skip(x, skips.pop())
        x = self.OutputLayer(x)
        x = self.tanh(x)

        if self.DEBUG:
            print("Is layer_outputs empty:", len(skips) == 0)
            print(x.shape)
            print("Output complete")

        return x

# Training Loop

## Defining the loop

In [None]:
from IPython.display import clear_output

def train_loop(dataloader, discriminator, generator, d_optimizer, g_optimizer, loss_stats, device, 
               epochs=150, performed_epochs=0, l1_lambda=100):
    """Train the conditional GAN, checkpointing and plotting after every epoch.

    Each batch performs one discriminator update followed by one generator
    update. The generator loss is the adversarial BCE term plus ``l1_lambda``
    times the L1 distance to the ground truth.

    After every epoch the mean batch losses are appended to ``loss_stats``. The
    generator, discriminator and loss statistics are then written to the
    notebook-level ``generator_path``, ``discriminator_path`` and ``loss_stats_path``.

    Args:
        dataloader: Yields (input, ground_truth) image batches.
        discriminator: Called as ``discriminator(x, y)``; outputs must lie in [0, 1] (BCELoss).
        generator: Called as ``generator(x)``; returns the enhanced images.
        d_optimizer: Optimizer over the discriminator's parameters.
        g_optimizer: Optimizer over the generator's parameters.
        loss_stats (pd.DataFrame): Per-epoch loss history to extend.
        device: Device the batches are moved to.
        epochs (int): Epoch index at which training stops.
        performed_epochs (int): Number of epochs already completed (resume point).
        l1_lambda (float): Weight of the L1 reconstruction term.

    Returns:
        pd.DataFrame: ``loss_stats`` extended with one row per trained epoch.
    """
    # BCE for the adversarial terms (the discriminator ends in a sigmoid), L1 for reconstruction
    bce_loss = nn.BCELoss()
    g_l1_loss = nn.L1Loss()

    # Checkpoint destinations come from notebook-level globals
    global generator_path, discriminator_path, loss_stats_path

    # NOTE: an L1-only generator pre-training phase was drafted here but is not used.

    # A later cell switches both models to eval(); make sure dropout/BatchNorm train again
    generator.train()
    discriminator.train()

    loss_columns = ["d_loss", "g_loss", "g_GAN_loss", "g_L1_loss"]
    plot_panels = [
        ("d_loss", "Discriminator loss"),
        ("g_loss", "Generator loss"),
        ("g_GAN_loss", "Generator GAN loss"),
        ("g_L1_loss", "Generator L1 loss"),
    ]

    for epoch in tqdm(range(performed_epochs, epochs)):
        # Per-batch losses, averaged at the end of the epoch
        batch_losses = []

        for x, y in tqdm(dataloader, leave=False):
            x, y = x.to(device), y.to(device)
            z = generator(x)

            #========================#
            # Discriminator training #
            #========================#
            d_optimizer.zero_grad()
            d_real = discriminator(x, y)
            # Detach the fake so the discriminator loss does not backpropagate into the generator
            d_fake = discriminator(x, z.detach())
            d_loss = bce_loss(d_real, th.ones_like(d_real)) + bce_loss(d_fake, th.zeros_like(d_fake))
            d_loss.backward()
            # Step now, before the generator's backward pass adds gradients to the discriminator
            d_optimizer.step()

            #====================#
            # Generator training #
            #====================#
            g_optimizer.zero_grad()
            # No detach here: the adversarial gradient must flow back into the generator
            d_fake_for_g = discriminator(x, z)
            ggl = bce_loss(d_fake_for_g, th.ones_like(d_fake_for_g))
            gl1 = g_l1_loss(z, y)
            g_loss = ggl + l1_lambda * gl1
            g_loss.backward()
            g_optimizer.step()

            # Store the batch statistics
            batch_losses.append({
                "d_loss": d_loss.item(),
                "g_loss": g_loss.item(),
                "g_GAN_loss": ggl.item(),
                "g_L1_loss": gl1.item(),
            })

        # Clear the output, then show and record the epoch-average losses
        clear_output(wait=True)
        loss_epoch = pd.DataFrame(batch_losses, columns=loss_columns)
        epoch_row = pd.DataFrame([{"Epoch": epoch, **loss_epoch.mean().to_dict()}])
        display(epoch_row)
        loss_stats = epoch_row if loss_stats.empty else pd.concat([loss_stats, epoch_row], ignore_index=True)

        # Save the current generator, discriminator, and loss history
        th.save(generator.state_dict(), generator_path)
        th.save(discriminator.state_dict(), discriminator_path)
        loss_stats.to_csv(loss_stats_path, index=False)

        # Plot every loss history on its own subplot
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))
        for ax, (column, title) in zip(axs.flat, plot_panels):
            ax.plot(loss_stats["Epoch"], loss_stats[column])
            ax.set_title(title)
        plt.show()
        plt.close(fig)

    return loss_stats

# Save/load progress

To easily continue training the same model across separate runs, we save all the data so that training can resume from that state.

## Dataset saving/loading

In [None]:
# The train/validation/test split is stored on disk, so resumed runs (and the
# final evaluation) keep using exactly the same split
split_indices_path = os.path.join(model_data, "split_indices.pt")

if os.path.exists(split_indices_path):
    # Load the indices
    split_indices = th.load(split_indices_path)
    train_indices = split_indices["train_indices"]
    validation_indices = split_indices["validation_indices"]
    test_indices = split_indices["test_indices"]

    # The saved indices are positions in `dataset.image_files`. Refuse to reuse
    # them if the dataset no longer lines up (different file list, or `test_run`
    # toggled). Otherwise training and test images would be silently mixed.
    saved_files = split_indices.get("image_files")  # absent in splits saved by older runs
    if saved_files is not None and list(saved_files) != list(dataset.image_files):
        raise RuntimeError(
            f"The dataset's file list differs from the one saved in {split_indices_path}; "
            "the saved split no longer refers to the same images."
        )
    all_indices = [*train_indices, *validation_indices, *test_indices]
    if all_indices and max(all_indices) >= len(dataset):
        raise IndexError(
            f"Saved split indices go up to {max(all_indices)}, but the dataset only has "
            f"{len(dataset)} images (was hyperparams['test_run'] changed?)."
        )

    # Replace the current dataloaders with ones over the saved split
    train_set = DataLoader(dataset=Subset(dataset, train_indices), 
                           batch_size=hyperparams["batch_size"], shuffle=True)
    validation_set = DataLoader(dataset=Subset(dataset, validation_indices), 
                                batch_size=hyperparams["batch_size"], shuffle=True)
    test_set = DataLoader(dataset=Subset(dataset, test_indices), 
                          batch_size=hyperparams["batch_size"], shuffle=True)
else:
    # First run: persist the random split, plus the file list its indices refer to
    th.save({
        "train_indices": train_set.dataset.indices,
        "validation_indices": validation_set.dataset.indices,
        "test_indices": test_set.dataset.indices,
        "image_files": list(dataset.image_files),
    }, split_indices_path)

## Model and loss statistics loading/creation

In [None]:
# Build both networks on `device` and resume from the last checkpoints if present.
# map_location lets a checkpoint saved on one device (e.g. CUDA) load on another (CPU/MPS).

# Load or create the generator
generator_path = os.path.join(model_data, "generator.pth")
generator = Autoencoder().to(device)
if os.path.exists(generator_path):
    generator.load_state_dict(th.load(generator_path, map_location=device))

# Load or create the discriminator
discriminator_path = os.path.join(model_data, "discriminator.pth")
discriminator = Discriminator().to(device)
if os.path.exists(discriminator_path):
    discriminator.load_state_dict(th.load(discriminator_path, map_location=device))

# Per-epoch loss history; its length is the number of epochs already trained
loss_stats_path = os.path.join(model_data, "loss_stats.csv")
performed_epochs = 0
if os.path.exists(loss_stats_path):
    loss_stats = pd.read_csv(loss_stats_path)
    performed_epochs = loss_stats.shape[0]
else:
    loss_stats = pd.DataFrame(columns=["Epoch","d_loss", "g_loss", "g_GAN_loss", "g_L1_loss"])

## Running the loop

Description of various GAN training problems: 
- https://developers.google.com/machine-learning/gan/problems
- https://arxiv.org/pdf/2005.00065.pdf

In [None]:
# Both networks use Adam with the same learning rate and betas.
# NOTE: optimizer state is not checkpointed, so a resumed run starts with fresh Adam moments.
d_optimizer, g_optimizer = (
    optim.Adam(network.parameters(), lr=hyperparams["learning_rate"], betas=hyperparams["adam_betas"])
    for network in (discriminator, generator)
)

# Train (continuing after `performed_epochs`); checkpoints are written every epoch
loss_stats = train_loop(
    train_set, discriminator, generator,
    d_optimizer, g_optimizer,
    loss_stats, device,
    epochs=hyperparams["epochs"],
    performed_epochs=performed_epochs,
    l1_lambda=hyperparams["l1_lambda"],
)

display(loss_stats)

### Plotting training results

In [None]:
# One panel per tracked loss on a 2x2 grid, each plotted against the epoch number
fig, axs = plt.subplots(2, 2, figsize=(15, 10))
for ax, (column, title) in zip(axs.flat, [
    ("d_loss", "Discriminator loss"),
    ("g_loss", "Generator loss"),
    ("g_GAN_loss", "Generator GAN loss"),
    ("g_L1_loss", "Generator L1 loss"),
]):
    ax.plot(loss_stats["Epoch"], loss_stats[column])
    ax.set_title(title)
plt.show()


### Testing image generation

In [None]:
# Show input / generated / ground truth for one batch of held-out validation images.
# (Training images would overstate quality: the generator has already fit them.)
with th.no_grad():
    batch_input, batch_gt = next(iter(validation_set))
    enhanced_images = generator(batch_input.to(device)).cpu()

panel_titles = ("Input Image", "Output Image", "Ground Truth Image")
for input_img, output_img, gt_image in zip(batch_input, enhanced_images, batch_gt):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, img, title in zip(axes, (input_img, output_img, gt_image), panel_titles):
        # Images are normalized to [-1, 1]; map back to [0, 1] for display
        ax.imshow(((img + 1) / 2).clamp(0, 1).permute(1, 2, 0))
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    plt.close(fig)

### Testing the predefined 5 images

In [None]:
# Run the generator on the fixed figure images (a single batch) and show each result
with th.no_grad():
    batch_input, batch_gt = next(iter(figure_set))
    enhanced_images = generator(batch_input.to(device)).cpu()

    panel_titles = ("Input Image", "Output Image", "Ground Truth Image")
    for input_img, output_img, gt_image in zip(batch_input, enhanced_images, batch_gt):
        plt.figure(figsize=(12, 4))
        panels = zip((input_img, output_img, gt_image), panel_titles)
        for position, (img, title) in enumerate(panels, start=1):
            # Undo the [-1, 1] normalization for display
            plt.subplot(1, 3, position)
            plt.imshow(((img + 1) / 2).permute(1, 2, 0))
            plt.title(title)
            plt.axis("off")
        plt.show()

# Model evaluation

## Setting model to evaluation mode

In [None]:
# Put both networks in inference mode; train(False) is exactly what eval() does
generator.train(False)
discriminator.train(False)

## Defining the metrics

All metrics as implemented support multi-batch processing; that is, they accept tensor input of the form:

[n, c=3, w, h], where n is the number of samples per batch, c is the number of colour channels, and w and h are the width and height respectively.

This may reduce code readability somewhat.

### The UIQM rabbithole

UIQM is surprisingly complex. The different submetrics are appropriately broken down into separate functions.

In [None]:
def PSNR(x, y, max_val=1.0):
    """Peak Signal-to-Noise Ratio (PSNR), one value per sample.

    Args:
        x (np.ndarray): Batch of images, generated; the first axis is the batch
            axis (e.g. [n, c, w, h])
        y (np.ndarray): Batch of images, ground truth, same shape as ``x``
        max_val (float): Peak pixel value: 1.0 for images in [0, 1]
            (use 2.0, the peak-to-peak range, for images left in [-1, 1])

    Returns:
        np.ndarray: PSNR in dB for each sample, shape [n]; ``inf`` where x == y exactly
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)

    # Flatten each image, keeping the batch axis
    n = x.shape[0]
    x = x.reshape(n, -1)
    y = y.reshape(n, -1)

    # Mean squared error of each image (over all its channels and pixels)
    mse = np.mean((x - y) ** 2, axis=1)

    # PSNR = 10 * log10(MAX^2 / MSE); a perfect reconstruction gives inf instead of a warning
    with np.errstate(divide="ignore"):
        return 10 * np.log10(max_val ** 2 / mse)

def SSIM(x, y, data_range=1.0):
    """Structural Similarity Index Measure (SSIM), one value per sample.

    NOTE: this is a *global* SSIM. Means, variances and the covariance are taken
    over the whole image (all channels and pixels) as one window, rather than
    averaging SSIM over local (e.g. Gaussian) windows as most reference
    implementations do. Values are therefore not directly comparable with
    windowed SSIM scores.

    Args:
        x (np.ndarray): Batch of images, generated; the first axis is the batch
            axis (e.g. [n, c, w, h])
        y (np.ndarray): Batch of images, ground truth, same shape as ``x``
        data_range (float): Dynamic range L of the pixel values: 1.0 for images
            in [0, 1], 2.0 for images in [-1, 1]

    Returns:
        np.ndarray: SSIM value for each sample, shape [n]
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)

    # Flatten each image, keeping the batch axis
    n = x.shape[0]
    x = x.reshape(n, -1)
    y = y.reshape(n, -1)

    # Stabilizing constants, scaled with the dynamic range
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    # Per-image means, variances and covariance
    mu_x = x.mean(axis=1)
    mu_y = y.mean(axis=1)
    dev_x = x - mu_x[:, np.newaxis]
    dev_y = y - mu_y[:, np.newaxis]
    var_x = np.mean(dev_x ** 2, axis=1)
    var_y = np.mean(dev_y ** 2, axis=1)
    cov_xy = np.mean(dev_x * dev_y, axis=1)

    # Combine the luminance, contrast and structure terms
    num = (2 * mu_x * mu_y + C1) * (2 * cov_xy + C2)
    den = (mu_x ** 2 + mu_y ** 2 + C1) * (var_x + var_y + C2)
    return num / den

def UIQM(x):
    """Underwater Image Quality Measure (UIQM)
    NOTE: Works for multibatch inputs [n, 3, w, h]

    Weighted sum of colourfulness (UICM), sharpness (UISM) and contrast (UIConM).

    Args:
        x (np.ndarray): Image tensor

    Returns:
        float: UIQM value
    """
    # Weights from the Sea-Pix-GAN paper
    c1, c2, c3 = 0.0282, 0.2953, 3.5753

    colourfulness = UICM(x)
    sharpness = UISM(x)
    contrast = UIConM(x)

    return c1 * colourfulness + c2 * sharpness + c3 * contrast

def UICM(x):
    """Underwater Image Colorfulness Measure (UICM)
    NOTE: Works for multibatch inputs [n, 3, w, h]

    Args:
        x (np.ndarray): Image tensor

    Returns:
        float: UICM value
    """
    # Extract the R, G, and B channels of all images in batch
    R = x[:,0, :, :]
    G = x[:,1, :, :]
    B = x[:,2, :, :]
    
    # Flatten the channels for batch
    n, w, h = R.shape
    R = np.reshape(R, (n, w * h))
    G = np.reshape(G, (n, w * h))
    B = np.reshape(B, (n, w * h))
    
    # Set alpha
    alpha = 0.1
    
    # Create the RG and YB channels
    RG = R - G
    YB = 0.5 * (R + G) - B
    
    # Compute the alpha trimmed distribution along each image
    aRG = stats.trimboth(RG, alpha, axis=1)
    aYB = stats.trimboth(YB, alpha, axis=1)
    
    # Compute the mean along each image
    mean_aRG = np.mean(aRG, axis=1)
    mean_aYB = np.mean(aYB, axis=1)
        
    # Compute the alpha trimmed standard deviation along each image
    std_aRG = np.sum((RG - mean_aRG[:, np.newaxis]) ** 2, axis=1) / (w * h - 1)
    std_aYB = np.sum((YB - mean_aYB[:, np.newaxis]) ** 2, axis=1) / (w * h - 1)
    
    std_aRG = np.sqrt(std_aRG)
    std_aYB = np.sqrt(std_aYB)
        
    # Constants, as per HVSIUIQM paper
    c1 = -0.0282
    c2 = 0.1586
    
    # Compute the 2 terms of the UICM
    term_1 = np.sqrt(mean_aRG ** 2 + mean_aYB ** 2) 
    term_2 = np.sqrt(std_aRG ** 2 + std_aYB ** 2) #TODO: can std be converted to variance to save computation?
    
    return c1 * term_1 + c2 * term_2

def UISM(x):
    """Underwater Image Sharpness Measure (UISM)
    NOTE: Works for multibatch inputs [n, 3, w, h]

    Each colour channel of each image is Sobel-filtered separately. The
    Enhancement Measure Estimation (EME) of each channel's gradient magnitude is
    then taken, and the three EMEs are combined with the standard RGB luminance
    weights.

    NOTE(review): some UIQM implementations multiply the edge map by the original
    channel before taking the EME; confirm which variant the paper intends.

    Args:
        x (np.ndarray): Image tensor [n, 3, w, h]

    Returns:
        Weighted sum of the R, G and B EME values (same form as EME's output)
    """
    x = np.asarray(x, dtype=np.float64)

    #=======================#
    # Edge detection filter #
    #=======================#

    # Sobel-filter every 2D image plane on its own. Filtering the 4D array directly
    # would differentiate and smooth along the batch and channel axes, mixing
    # different images and channels instead of measuring spatial edges.
    filtered_x = np.empty_like(x)
    for image_idx in range(x.shape[0]):
        for channel_idx in range(x.shape[1]):
            plane = x[image_idx, channel_idx]
            grad_rows = ndimage.sobel(plane, axis=0)
            grad_cols = ndimage.sobel(plane, axis=1)
            # Gradient magnitude (hypotenuse of the two directional gradients)
            filtered_x[image_idx, channel_idx] = np.hypot(grad_rows, grad_cols)

    #=========================#
    # Enhancement measurement #
    #=========================#

    # EME of each channel's edge map
    EME_R = EME(filtered_x[:, 0, :, :])
    EME_G = EME(filtered_x[:, 1, :, :])
    EME_B = EME(filtered_x[:, 2, :, :])

    # Weighted EME, using standard RGB channel weights
    lambda_R = 0.299
    lambda_G = 0.587
    lambda_B = 0.114

    return lambda_R * EME_R + lambda_G * EME_G + lambda_B * EME_B

def EME(x, block_size=8):
    """Enhancement Measure Estimation (EME), one value per image.

    Each image is divided into non-overlapping ``block_size x block_size`` blocks,
    k1 x k2 of them; rows/columns that do not fill a whole block are ignored. Then

        EME = 2 / (k1 * k2) * sum over blocks of log(I_max / I_min)

    Blocks whose minimum is not strictly positive contribute 0, since the log
    ratio is undefined there (Sobel magnitudes are frequently 0).

    Args:
        x (np.ndarray): Batch of single-channel (e.g. Sobel-filtered) images, shape [n, w, h]
        block_size (int): Side length of the square blocks.
            TODO: confirm the block size used in the UIQM paper.

    Returns:
        np.ndarray: EME value for each image, shape [n]

    Raises:
        ValueError: If the images are smaller than one block.
    """
    x = np.asarray(x, dtype=np.float64)
    n, rows, cols = x.shape

    # Number of whole blocks along each spatial axis
    k1, k2 = rows // block_size, cols // block_size
    if k1 == 0 or k2 == 0:
        raise ValueError(f"Images of size {rows}x{cols} are smaller than one {block_size}x{block_size} block")

    # Crop to whole blocks and expose them as axes: [n, k1, block_size, k2, block_size]
    blocks = x[:, :k1 * block_size, :k2 * block_size].reshape(n, k1, block_size, k2, block_size)
    block_max = blocks.max(axis=(2, 4))
    block_min = blocks.min(axis=(2, 4))

    # Use the ratio 1/1 (log = 0) wherever the block minimum is not positive
    valid = block_min > 0
    ratio = np.where(valid, block_max, 1.0) / np.where(valid, block_min, 1.0)

    return 2.0 / (k1 * k2) * np.log(ratio).sum(axis=(1, 2))

def UIConM(x):
    """Underwater Image Contrast Measure (UIConM)
    NOTE: Works for multibatch inputs [n, 3, w, h]

    The contrast measure is the logAMEE of the image intensity. Intensity is the
    weighted sum of the R, G and B channels using the standard luminance weights.

    Args:
        x (np.ndarray): Image tensor

    Returns:
        float: UIConM value
    """
    red = x[:, 0, :, :]
    green = x[:, 1, :, :]
    blue = x[:, 2, :, :]

    # Standard RGB channel weights 0.299 / 0.587 / 0.114
    intensity = 0.299 * red + 0.587 * green + 0.114 * blue

    return logAMEE(intensity)

def logAMEE(x, block_size=8):
    """Logarithmic Agaian Measure of Enhancement by Entropy (logAMEE), one value per image.

    Each image is split into non-overlapping ``block_size x block_size`` blocks,
    k1 x k2 of them; rows/columns that do not fill a whole block are ignored.
    For every block, with I_max / I_min its extreme values,

        r = (I_max Theta I_min) / (I_max circle-plus I_min)
        logAMEE = 1 / (k1 * k2) * sum over blocks of r * log(r)

    Theta is the PLIP difference (a - b) / (1 - b) and circle-plus is the PLIP
    sum a + b - a * b. Blocks where r is undefined (zero denominator) or not
    positive contribute 0, the limit of r * log(r) as r -> 0.

    TODO: double-check the PLIP operators against the paper, and whether the
    outer normalisation and summation should also be PLIP operations.

    Args:
        x (np.ndarray): Batch of intensity images, shape [n, w, h]. The PLIP
            difference divides by (1 - I_min), so values are presumably expected
            in [0, 1] — confirm.
        block_size (int): Side length of the square blocks.
            TODO: confirm the block size used in the UIQM paper.

    Returns:
        np.ndarray: logAMEE value for each image, shape [n]

    Raises:
        ValueError: If the images are smaller than one block.
    """
    x = np.asarray(x, dtype=np.float64)
    n, rows, cols = x.shape

    # Number of whole blocks along each spatial axis
    k1, k2 = rows // block_size, cols // block_size
    if k1 == 0 or k2 == 0:
        raise ValueError(f"Images of size {rows}x{cols} are smaller than one {block_size}x{block_size} block")

    # Crop to whole blocks and expose them as axes: [n, k1, block_size, k2, block_size]
    blocks = x[:, :k1 * block_size, :k2 * block_size].reshape(n, k1, block_size, k2, block_size)
    i_max = blocks.max(axis=(2, 4))
    i_min = blocks.min(axis=(2, 4))

    # PLIP difference (Theta) denominator and PLIP sum (circle-plus)
    diff_den = 1.0 - i_min
    plip_sum = i_max + i_min - i_max * i_min

    # r = ((max - min) / (1 - min)) / (max (+) min), 0 where a denominator vanishes
    defined = (diff_den != 0) & (plip_sum != 0)
    ratio = np.where(
        defined,
        ((i_max - i_min) / np.where(defined, diff_den, 1.0)) / np.where(defined, plip_sum, 1.0),
        0.0,
    )

    # r * log(r), taking 0 for r <= 0
    positive = ratio > 0
    terms = np.where(positive, ratio * np.log(np.where(positive, ratio, 1.0)), 0.0)

    return (1 / (k1 * k2)) * terms.sum(axis=(1, 2))
    
def get_metrics(test_set, generator, DEBUG=False):
    """Score the generator on the test data with PSNR, SSIM and UIQM.

    Batches are drawn from ``test_set`` in order, each exactly once, until
    ``frame_size`` rows have been collected (500 normally, 10 in DEBUG mode)
    or the data runs out.

    Args:
        test_set: Iterable yielding ``(input, target)`` batches (e.g. a DataLoader).
        generator: Model mapping an input batch to an enhanced batch.
            NOTE(review): it is not switched to ``eval()`` here, so dropout /
            batch-norm stay in whatever mode the caller left them.
        DEBUG (bool): If True, collect only 10 rows for a quick check.

    Returns:
        pd.DataFrame: Columns "PSNR", "SSIM", "UIQM", with at most ``frame_size``
        rows and a fresh 0..n-1 index. There is one row per value returned by
        the metric helpers (presumably one per image — verify against PSNR/SSIM).
    """
    # Set the frame size according to the debug mode
    frame_size = 500 if not DEBUG else 10

    batch_frames = []
    n_collected = 0

    # Disable gradient computation
    with th.no_grad():
        # Iterate the loader once. Calling `next(iter(test_set))` in a loop builds a
        # fresh iterator each time and, without shuffling, re-scores the same batch.
        for x, y in test_set:
            print(f"Metric samples computed: {n_collected}")

            x = x.to(device)
            y = y.detach().cpu().numpy()

            enhanced = generator(x).detach().cpu().numpy()

            batch_metrics = pd.DataFrame({
                "PSNR": PSNR(y, enhanced),
                "SSIM": SSIM(y, enhanced),
                "UIQM": UIQM(enhanced),
            })
            batch_frames.append(batch_metrics)
            n_collected += len(batch_metrics)

            if n_collected >= frame_size:
                break

    if n_collected < frame_size:
        print(f"Warning: test data exhausted after {n_collected} samples (requested {frame_size})")

    if not batch_frames:
        return pd.DataFrame(columns=["PSNR", "SSIM", "UIQM"])

    # Concatenate once (not row-by-row) and discard any rows beyond frame_size
    metrics = pd.concat(batch_frames, ignore_index=True)
    return metrics.head(frame_size)

## Getting the test metrics

In [None]:
raw_metrics = get_metrics(test_set, generator, DEBUG=False)

# Mean and sample standard deviation (ddof=1) of each metric, one row per metric.
# Built in a single aggregation so both columns are numeric floats.
summary = (
    raw_metrics[["PSNR", "SSIM", "UIQM"]]
    .astype(float)
    .agg(["mean", "std"])
    .T
    .rename(columns={"mean": "Mean", "std": "Std"})
)

display(raw_metrics.head())
display(summary)

### Saving the test metrics

In [None]:
raw_metrics.to_csv(os.path.join(model_data, "raw_metrics.csv"), index=False)
# Keep the index: it holds the metric name (PSNR / SSIM / UIQM) of each summary row
summary.to_csv(os.path.join(model_data, "summary.csv"), index_label="Metric")

### Performing the weakest paired T-test to ever have graced the earth

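We perform a paired T-test (on a staggering 3 datapoints) to determine whether or not our model is statistically significantly different from the model in the original paper. Each "pair" is our mean and the paper's mean for one metric (PSNR, SSIM, UIQM). The three paired differences are therefore on completely different scales, which alone makes the test meaningless.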

In [None]:
# Reported mean / std of the original paper's model on the same three metrics
original_stats = pd.DataFrame(
    {"Mean": [23.30, 0.79, 2.84], "Std": [1.68, 0.09, 0.20]},
    index=["PSNR", "SSIM", "UIQM"],
)

display(original_stats)

# Perform a paired t-test on the PSNR, SSIM, and UIQM means. Pair by metric
# label (not row position) and force a numeric dtype before handing to scipy.
ours = summary["Mean"].astype(float)
theirs = original_stats["Mean"].reindex(ours.index).astype(float)
t_statistic, p_value = stats.ttest_rel(ours, theirs)

print(f"T-statistic: {t_statistic:.2f}")
print(f"P-value: {p_value:.3g}")

if p_value < 0.05:
    print("The difference in means is statistically significant")
else:
    print("The difference in means is not statistically significant")

The above result means absolutely nothing, but I declare that it shows our numbers are close enough to say we've reproduced the paper.

#### Proving the paired T-test means nothing

Here are the reported results for some of the other models. Given the data we have available, they are also statistically indistinguishable from our model, despite verifiably coming from entirely different types of models.

In [None]:
# Reported mean / std for the FUnIE model on the same three metrics
funie_stats = pd.DataFrame(
    {"Mean": [20.49, 0.70, 2.78], "Std": [2.33, 0.09, 0.23]},
    index=["PSNR", "SSIM", "UIQM"],
)

display(funie_stats)

# Perform a paired t-test on the PSNR, SSIM, and UIQM means. Pair by metric
# label (not row position) and force a numeric dtype before handing to scipy.
ours = summary["Mean"].astype(float)
theirs = funie_stats["Mean"].reindex(ours.index).astype(float)
t_statistic, p_value = stats.ttest_rel(ours, theirs)

print(f"T-statistic: {t_statistic:.2f}")
print(f"P-value: {p_value:.3g}")

if p_value < 0.05:
    print("The difference in means is statistically significant")
else:
    print("The difference in means is not statistically significant")

In [None]:
# Reported mean / std for the low-DCP model on the same three metrics
low_DCP_stats = pd.DataFrame(
    {"Mean": [14.00, 0.48, 1.51], "Std": [4.02, 0.15, 0.49]},
    index=["PSNR", "SSIM", "UIQM"],
)

display(low_DCP_stats)

# Perform a paired t-test on the PSNR, SSIM, and UIQM means. Pair by metric
# label (not row position) and force a numeric dtype before handing to scipy.
ours = summary["Mean"].astype(float)
theirs = low_DCP_stats["Mean"].reindex(ours.index).astype(float)
t_statistic, p_value = stats.ttest_rel(ours, theirs)

print(f"T-statistic: {t_statistic:.2f}")
print(f"P-value: {p_value:.3g}")

if p_value < 0.05:
    print("The difference in means is statistically significant")
else:
    print("The difference in means is not statistically significant")