# Re-Implementation of "Unpaired Image-to-Image Translation via Neural Schrödinger Bridge"

#### This notebook presents a re-implementation of the methods proposed in the paper "[Unpaired Image-to-Image Translation via Neural Schrödinger Bridge](https://arxiv.org/abs/2305.15086)". The original paper introduces a novel approach to address the limitations of traditional diffusion models in unpaired image-to-image (I2I) translation tasks. Diffusion models, which are generative models simulating stochastic differential equations (SDEs), often rely on a Gaussian prior assumption, which may not be suitable for all types of data distributions.

#### The Schrödinger Bridge (SB) offers a promising alternative: it learns an SDE that transports samples between two arbitrary distributions. The paper proposes the Unpaired Neural Schrödinger Bridge (UNSB), which expresses the SB problem as a sequence of adversarial learning problems. This formulation makes it possible to use advanced discriminators and regularization techniques, improving the model's ability to learn translations between unpaired datasets.

#### The re-implementation in this notebook explores the scalability and efficiency of UNSB on unpaired I2I translation, with particular attention to high-resolution images, where previous SB models have struggled.

## All necessary Imports & Set-Up
#### As a first step, we import all the libraries used throughout the project and set up the environment it needs to run.

In [None]:
import numpy as np 
import os 
import cv2
import torch 
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
import functools
from PIL import Image
import torchvision 
from torchvision.utils import save_image
import torch.optim as optim
import math
from torch.nn import init
import torchvision.utils as vutils
import time
import torchvision.models as models
from scipy.linalg import sqrtm
from tqdm import tqdm
import sys 
from sklearn.metrics.pairwise import polynomial_kernel
from torch.nn.functional import adaptive_avg_pool2d
import pathlib
import torchvision.transforms as TF
from scipy import linalg
import matplotlib.image as mpimg

#### In the following section of the code, we configure the processing device for our neural network operations. 
#### Using CUDA, PyTorch can leverage the GPU’s capabilities to significantly speed up the computations necessary for neural network training. If a GPU is not available, the code defaults to using the CPU.
#### The model was trained on Kaggle using an NVIDIA P100 GPU.

In [None]:
# Set up processing device; use GPU via CUDA if available, otherwise fallback to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Loading the Dataset

#### For our image translation experiments, we have selected the horse2zebra dataset. This dataset is widely used in image-to-image translation tasks and provides a diverse set of images of horses and zebras, making it an ideal candidate for testing the capabilities of the Unpaired Neural Schrödinger Bridge (UNSB) model in performing complex translation tasks between these two distinct classes of images.

#### Due to limited computational resources, we opted to train the model on a smaller subset of the horse2zebra dataset. 
#### Specifically:

#### - **Training Datasets**: each of the horse and zebra categories in the training set contains 200 images, totaling 400 images for model training. This allows sufficient variability to challenge the model's ability to learn meaningful translations without overwhelming our available resources.
#### - **Testing Datasets**: for testing, each category (horse and zebra) includes 120 images. This lets us evaluate the model on unseen data and gives a measure of how well it generalizes to new inputs.

#### These subsets were chosen to keep a diverse set of images, so that the model can still learn to generalize across different types of input. Working with a manageable subset lets us train and iterate without extensive computational power, a common constraint in academic or small-scale research settings, while still giving useful insight into the model's performance and scalability under restricted conditions.

In [None]:
# Paths to the training datasets for both domains:
path_trainA = '/kaggle/input/horse2zebra-new/horse2zebra/small_trainA'  # Directory path for 'small_trainA' containing a subset of horse images.
path_trainB = '/kaggle/input/horse2zebra-new/horse2zebra/small_trainB' # Directory path for 'small_trainB' containing a subset of zebra images.

# Paths to the testing datasets for both domains:
path_testA = '/kaggle/input/horse2zebra-new/horse2zebra/testA'  # Directory path for 'testA' containing horse images for testing.
path_testB = '/kaggle/input/horse2zebra-new/horse2zebra/testB'  # Directory path for 'testB' containing zebra images for testing.

## Creating a Directory for Generated Images
#### In the following section of the code, we prepare for storing the output of our image translation model by setting up a directory where generated images will be saved. This is an essential step in organizing the outputs for evaluation and visualization purposes.

#### The code checks if a directory named generated_images exists in the /kaggle/working directory. If it does not exist, the directory is created. This setup ensures that we have a designated place to store our model's outputs without any interruptions during runtime.

In [None]:
# Create the output directory for generated images (no-op if it already exists)
generated_images = '/kaggle/working/generated_images'
os.makedirs(generated_images, exist_ok=True)

## Data Preparation

#### In this section, we define a custom dataset class, ImageDataset, which inherits from PyTorch's Dataset class. This class is specifically tailored to handle image data for neural network training. Here’s how it works:

#### - Initialization (__init__): The constructor takes the directory path (img_dir) where the images are stored. It traverses the directory and collects the paths of all files with the extensions .png, .jpg, or .jpeg, so that only image files are included. It also defines a transformation pipeline that converts each image to a PIL image, resizes it to 256x256, converts it to a tensor, and normalizes it to [-1, 1].

#### - Length (__len__): This method returns the total number of images in the dataset, which the DataLoader uses to determine how many samples (and batches) there are.

#### - Get Item (__getitem__): This method retrieves an image by index, reads it from the disk, processes it through the specified transformations, and returns the final tensor ready for model input.

#### This structure facilitates the handling and preprocessing of image data, ensuring that the images are in the correct format and dimensions required by the neural network.

In [None]:
class ImageDataset(Dataset):
    def __init__(self, img_dir):
        self.img_dir = img_dir  # Store the directory path to images
        self.image_paths = []  # List to hold the paths of the images
        
        # Walk through the directory to collect all image file paths
        for root, dirs, files in os.walk(img_dir):
            for file in files:
                # Add paths for files with appropriate image extensions (case-insensitive)
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(root, file))
        self.image_paths.sort()  # os.walk order is arbitrary; sort for reproducibility

        # Define transformations for image preprocessing
        self.transform = transforms.Compose([
            transforms.ToPILImage(),  # Convert numpy arrays to PIL images for further processing
            transforms.Resize((256, 256)),  # Resize images to uniform size for the model
            transforms.ToTensor(),  # Convert images to PyTorch tensors
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize pixel values to [-1, 1]
        ])

    def __len__(self):
        # Return the total number of images available in the dataset
        return len(self.image_paths)

    def __getitem__(self, idx):
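        # Fetch the path and load the image (OpenCV loads images in BGR order)
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread returns None instead of raising on unreadable files
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert color from BGR to RGB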

        # Apply the predefined transformations to the image
        if self.transform:
            image = self.transform(image)
        
        return image  # Return the transformed image

## Dataset and DataLoader Configuration
#### In this section, we initialize our datasets and dataloaders for both the training and testing phases. The horse2zebra dataset has been split into distinct sets for horses (TrainA and TestA) and zebras (TrainB and TestB).

#### Due to our limited computational resources, we set the batch size to 1. A batch of one gives noisier gradient estimates than larger batches, but it keeps memory usage low enough to train and test the full model on machines with limited memory.

In [None]:
# Set the batch size for training the model to 1, reflecting our resource limitations.
BATCH_SIZE = 1  # Smaller batch size is not ideal for performance but necessary for our resource-constrained environment.

#### We use the custom ImageDataset class, previously defined, to handle the loading and preprocessing of images from designated directory paths. This class ensures that images are correctly resized, normalized, and transformed into tensors, making them suitable for input into our neural network.

In [None]:
# Initialize datasets for training and testing using the predefined ImageDataset class.
train_datasetA = ImageDataset(img_dir=path_trainA)  # Dataset of horse images for training.
train_datasetB = ImageDataset(img_dir=path_trainB)  # Dataset of zebra images for training.
test_datasetA = ImageDataset(img_dir=path_testA)    # Dataset of horse images for testing.
test_datasetB = ImageDataset(img_dir=path_testB)    # Dataset of zebra images for testing.

#### DataLoaders then load the data in batches and reshuffle it at every epoch, which helps generalization and prevents the model from memorizing the order of the training examples. Data loading can also be spread across worker processes through the *num_workers* argument; here it is left at its default of 0, so images are loaded in the main process.

In [None]:
# Create DataLoaders for the datasets to manage loading and batching of data during training and testing.
train_dataloaderA = DataLoader(train_datasetA, batch_size=BATCH_SIZE, shuffle=True)  # DataLoader for horse images in training.
train_dataloaderB = DataLoader(train_datasetB, batch_size=BATCH_SIZE, shuffle=True)  # DataLoader for zebra images in training.
test_dataloaderA = DataLoader(test_datasetA, batch_size=BATCH_SIZE, shuffle=True)     # DataLoader for horse images in testing.
test_dataloaderB = DataLoader(test_datasetB, batch_size=BATCH_SIZE, shuffle=True)     # DataLoader for zebra images in testing.
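#### Because the two domains are unpaired, one common way to draw training pairs is to iterate over two independently shuffled loaders in lockstep with *zip*. The sketch below illustrates this with small random *TensorDataset* stand-ins (hypothetical data, not the horse2zebra images); it is not necessarily how the training loop later in this notebook draws its samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the horse (A) and zebra (B) datasets:
# random tensors with a small spatial size to keep the example fast.
torch.manual_seed(0)
domain_a = TensorDataset(torch.randn(6, 3, 8, 8))
domain_b = TensorDataset(torch.randn(4, 3, 8, 8))

loader_a = DataLoader(domain_a, batch_size=1, shuffle=True)
loader_b = DataLoader(domain_b, batch_size=1, shuffle=True)

# zip() pairs the i-th batch of each independently shuffled loader, so the pairs
# carry no correspondence; it also stops as soon as the shorter loader runs out.
pairs = []
for (real_a,), (real_b,) in zip(loader_a, loader_b):  # TensorDataset yields 1-element tuples
    pairs.append((tuple(real_a.shape), tuple(real_b.shape)))

print(len(pairs))  # 4: loader_b has only 4 samples
```

#### Since *ImageDataset* returns a bare tensor rather than a tuple, the equivalent loop over the real loaders would be *for real_A, real_B in zip(train_dataloaderA, train_dataloaderB)*. With 200 images in each training domain, *zip* visits every training image once per pass.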

## Helper Functions for Model Construction

#### This section details the utility and helper functions that are integral to building, initializing, and visualizing outcomes from our model. These functions handle various tasks such as padding, normalization, weight initialization, and image visualization. Each function is designed to modularize the codebase, making it easier to maintain and understand.

#### - **Padding Function**: *get_pad_layer* - It returns a padding layer based on the specified type (reflect, replicate, or zero). It is used to add padding to the inputs of convolutional layers, ensuring that boundary effects are handled correctly.

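#### - **Pixel Normalization**: *PixelNorm* - A normalization technique borrowed from progressive GAN and StyleGAN generators. It rescales each feature vector along the channel dimension to unit root-mean-square (it divides by the square root of the mean of the squared entries), which keeps activation magnitudes under control. Here it is applied to the latent vector z at the start of the generator's mapping network.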

#### - **Timestep Embedding**: *get_timestep_embedding* produces the sinusoidal (Transformer-style) embedding of the discrete time step, as in DDPM-style diffusion models, and *TimestepEmbedding* passes an embedding through a small two-layer MLP. Together they let the networks condition on where along the generative trajectory the current sample is.

#### - **Weight Initialization**: *init_weights* and *init_net* - These functions initialize the network weights to help training converge. Convolutional and linear layers receive weights drawn from a zero-mean Gaussian with standard deviation *init_gain* (0.02 by default) and zero biases; LayerNorm layers receive weights drawn around 1. *init_net* optionally applies this initialization and returns the network.

#### - **Normalization and Visualization**: *Normalize* - Divides a tensor by its L_p norm along the channel dimension (p = 2 by default), producing unit-length feature vectors. *denormalize* and *visualize_images* are utility functions that map tensors from [-1, 1] back to [0, 1] and display them as an image grid. These are particularly useful during testing and when evaluating the model's performance visually.
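#### The two normalization helpers listed above differ in scale, which is easy to miss. The sketch below re-implements the formulas of *PixelNorm* and *Normalize* as plain functions (the names *pixel_norm* and *l2_normalize* are our own) and applies them to a batch of random 256-dimensional vectors: PixelNorm yields unit mean square per vector, so its L2 length is sqrt(256) = 16, while Normalize yields unit L2 length.

```python
import torch

def pixel_norm(x, eps=1e-8):
    # Same formula as PixelNorm.forward: divide by the RMS over the channel dimension
    return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + eps)

def l2_normalize(x, power=2, eps=1e-7):
    # Same formula as Normalize.forward with power=2, with a small eps guarding against 0/0
    norm = x.pow(power).sum(1, keepdim=True).pow(1.0 / power)
    return x / (norm + eps)

torch.manual_seed(0)
z = torch.randn(4, 256)  # e.g. a batch of latent vectors with 4 * ngf = 256 entries

pn = pixel_norm(z)
ln = l2_normalize(z)

print(pn.pow(2).mean(dim=1))  # about 1.0 per row: unit mean square
print(pn.norm(dim=1))         # about 16.0 per row: length grows with sqrt(dimension)
print(ln.norm(dim=1))         # about 1.0 per row: unit L2 length
```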

In [None]:
# Define padding layer based on type for use in convolutional layers
def get_pad_layer(pad_type):
    if pad_type in ['reflect', 'refl']:
        return nn.ReflectionPad2d
    elif pad_type in ['replicate', 'repl']:
        return nn.ReplicationPad2d
    elif pad_type == 'zero':
        return nn.ZeroPad2d
    else:
        raise NotImplementedError(f'Padding type {pad_type} not recognized')

# Module to normalize pixel values in images for stable training
class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
        
# Generate embeddings for timesteps in models that incorporate time dynamics
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = F.pad(emb, (0, 1), mode='constant')
    return emb
                                  
# Class to embed timestep information into network inputs
class TimestepEmbedding(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, output_dim, act=nn.ReLU()):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)  # First layer: embedding to hidden dimension
        self.act = act  # Activation function
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # Second layer: hidden dimension to output

    def forward(self, t):
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        t = t.float()  # Ensure t is of type float
        t = self.fc1(t)
        t = self.act(t)
        t = self.fc2(t)
        return t
    
# Initialize conv/linear weights from a zero-mean Gaussian (std init_gain) with zero biases
def init_weights(net, init_gain=0.02, debug=False):
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if debug:
                print(classname)  # Print class name during debugging
            # Use init_gain from the enclosing scope; re-assigning it here would make it a local
            # variable and raise UnboundLocalError in the LayerNorm branch below.
            init.normal_(m.weight.data, 0.0, init_gain)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('LayerNorm') != -1 and m.weight is not None:
            init.normal_(m.weight.data, 1.0, init_gain)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>

# Optionally initialize the network's weights, then return it.
# gpu_ids is accepted but unused here; device placement is done separately with .to(device).
def init_net(net, init_gain=0.02, gpu_ids=None, debug=False, initialize_weights=True):
    if initialize_weights:
        init_weights(net, init_gain=init_gain, debug=debug)
    return net

# Module to normalize tensors based on a power rule, useful for data and feature normalization
class Normalize(nn.Module):
    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        out = x.div(norm + 1e-7)  # small epsilon guards against division by zero
        return out
    
# Function to convert normalized image data back to standard image format
def denormalize(tensor):
    return tensor.mul(0.5).add(0.5)  # Converts from [-1, 1] to [0, 1]

# Visualize a batch of images using a grid layout
def visualize_images(images, title="Generated Images"):
    images = images.detach().cpu()  # Detach from the graph and move to CPU for plotting
    images = denormalize(images).clamp(0, 1)  # Map [-1, 1] back to [0, 1]; clamp any overshoot
    grid = vutils.make_grid(images, padding=2)  # No extra min-max rescaling: values are already in [0, 1]
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title(title)
    plt.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    plt.show()

## Define the Generator

#### The UNSB generator is designed for conditional image synthesis, where the conditions are temporal (the time step t) and style-based (a latent vector z). The architecture is structured around ResNet blocks adapted to modify their processing based on these external inputs.

#### Below is an overview of the helper components of the generator and of their roles:

#### - **Adaptive Layer**: This module adjusts the input feature maps using a style vector z. It computes scaling and bias terms (gamma and beta) from the style vector, which modulates the feature maps to produce style-specific effects.

#### - **Conditional ResNet Block**: Each ResNet block in the generator accepts two additional inputs, temporal embeddings and style vectors, that influence the block's operation, making it sensitive to both the progression of the diffusion process and the desired output characteristics.

#### - **Temporal and Style Embeddings**: These components transform raw time steps and style vectors into formats suitable for integration into the ResNet blocks, ensuring that the generator's output is appropriately varied based on the input conditions.

#### This architecture lets a single network adapt its output to the conditions provided, so the same generator can be reused at every time step of the translation process.
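#### The adaptive (style) modulation is a FiLM-style per-channel affine transform. The sketch below mirrors *AdaptiveLayer* with toy dimensions (our own choice) to show what its bias initialization does: with gamma's bias set to 1, beta's bias set to 0 and, purely for this demonstration, the weights zeroed, the modulation is exactly the identity. In *AdaptiveLayer* itself the weights are not zeroed, so gamma and beta also depend on z from the start; the bias initialization only centers them at 1 and 0.

```python
import torch
import torch.nn as nn

channels, style_dim = 8, 16  # toy sizes for illustration

# One Linear maps z to [gamma, beta], as in AdaptiveLayer
style = nn.Linear(style_dim, channels * 2)
with torch.no_grad():
    style.bias[:channels] = 1.0  # gamma bias -> scale starts at 1
    style.bias[channels:] = 0.0  # beta bias -> shift starts at 0
    style.weight.zero_()         # zeroed here only to isolate the effect of the bias init

x = torch.randn(2, channels, 4, 4)  # feature maps (B, C, H, W)
z = torch.randn(2, style_dim)       # latent vectors (B, style_dim)

gamma, beta = style(z).chunk(2, dim=1)  # each (B, C)
out = gamma[:, :, None, None] * x + beta[:, :, None, None]  # broadcast over H and W

print(torch.allclose(out, x))  # True: with zero weights the modulation is the identity
```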

In [None]:
# Generator's helper functions

class AdaptiveLayer(nn.Module):
    # Initializer for the adaptive layer which applies learned affine transformations.
    def __init__(self, in_channel, style_dim):
        super().__init__()
        self.style = nn.Linear(style_dim, in_channel * 2)  # Creates a linear transformation for style codes
        # Initialize the affine transform parameters gamma to 1 (scale) and beta to 0 (shift)
        self.style.bias.data[:in_channel] = 1
        self.style.bias.data[in_channel:] = 0

    # Forward pass which applies the affine transformation to the input features
    def forward(self, input, style):
        gamma, beta = self.style(style).chunk(2, 1)  # Split style into gamma and beta components
        gamma, beta = gamma.unsqueeze(2).unsqueeze(3), beta.unsqueeze(2).unsqueeze(3)  # Adjust dimensions for feature map
        return gamma * input + beta  # Apply the affine transformation



class ResnetBlockCond(nn.Module):
    # Initializer for a conditional ResNet block which integrates time and style-based conditioning.
    def __init__(self, dim, norm_layer, temb_dim, z_dim):
        super(ResnetBlockCond, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),  # Padding for maintaining spatial dimensions after convolution
            nn.Conv2d(dim, dim, kernel_size=3, padding=0),  # Standard convolution layer
            norm_layer(dim),  # Normalization layer
            nn.ReLU(inplace=False)  # Activation function
        ) 
        
        self.adaptive = AdaptiveLayer(dim, z_dim)  # Style-based adaptive layer
        
        self.conv_fin = nn.Sequential(
            nn.ReflectionPad2d(1),  # Additional padding for final convolution
            nn.Conv2d(dim, dim, kernel_size=3, padding=0),  # Final convolution layer
            norm_layer(dim)  # Final normalization layer
        )
        self.dense_time = nn.Linear(temb_dim, dim)  # Linear layer for transforming time conditioning
        nn.init.zeros_(self.dense_time.bias)  # Initialize the bias for the time conditioning layer to zero
        # Note: self.style is never used in forward(); the style modulation is done by self.adaptive.
        self.style = nn.Linear(z_dim, dim * 2)
        # Initialize gamma to 1 and beta to 0 for the style conditioning
        self.style.bias.data[:dim] = 1
        self.style.bias.data[dim:] = 0

    # Forward pass through the ResNet block with conditional inputs
    def forward(self, x, time_cond, z):
        time_input = self.dense_time(time_cond)  # Apply linear transformation to the time conditioning
        out = self.conv_block(x)  # Pass input through the convolutional block
        out = out + time_input[:, :, None, None]  # Add time conditioning to the features
        out = self.adaptive(out, z)  # Apply adaptive styling
        out = self.conv_fin(out)  # Final convolution to refine features
        out = x + out  # Add skip connections for better gradient flow
        return out

#### The ResNet Generator with Conditional Blocks performs image synthesis while adapting to temporal and style-based conditioning. Here's a breakdown of the architecture and its core components.

#### The generator starts with an **initial convolution layer** that uses a 7x7 kernel to expand the input image from 3 channels to *ngf* (64) feature channels. Reflection padding is applied beforehand to avoid border artifacts.

#### Next, a single stride-2 **downsampling** convolution halves the spatial resolution (from 256x256 to 128x128) while quadrupling the number of feature channels (from *ngf* = 64 to 256). Working at a lower resolution with more channels lets the subsequent blocks capture more abstract, higher-level features at a lower computational cost.

#### The core of the generator is **a series of ResNet blocks**, specifically tailored to accommodate external conditioning. Each block is designed to perform two key functions: process the incoming features to refine and transform them, and integrate external conditional inputs that influence these transformations.
#### Within each ResNet block, external conditions are integrated in two forms. First, **temporal embeddings**, which encode the current time step of the generative process, are projected to the block's channel count and added as a per-channel bias. Second, **style vectors** modulate the activations directly through a per-channel scale and shift, allowing the block to adjust attributes such as texture or overall appearance.
#### Concretely, each block applies a convolution, normalization, and ReLU; adds the time bias; applies the **adaptive** scale and shift computed from the style vector; runs a second convolution and normalization; and finally adds the block's input back through a residual connection. Each block's output is therefore a function of both the input features and the external conditions.

#### To construct the final image, the generator reverses the downsampling with a stride-2 transposed convolution, which restores the feature maps to the original 256x256 resolution and reduces the channels back to *ngf*. This stage reconstructs a full-resolution output from the abstracted features formed in the deeper layers of the network.
#### The upsampling pathway ends with a final 7x7 convolution, again preceded by reflection padding, which maps the features to the output channels. A Tanh activation then bounds the output pixel values to (-1, 1), matching the [-1, 1] normalization of the input images.
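#### The spatial sizes implied by these layer settings can be checked with the standard PyTorch output-size formulas for *Conv2d* and *ConvTranspose2d* (dilation 1). The sketch below traces a 256x256 input through the generator's stem, its single downsampling convolution, a ResNet block, and the upsampling path, confirming that the output resolution matches the input. The helper names *conv_out* and *conv_transpose_out* are our own.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # nn.Conv2d output size (dilation 1): floor((H + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    # nn.ConvTranspose2d output size (dilation 1): (H - 1) * s - 2p + k + output_padding
    return (size - 1) * stride - 2 * padding + kernel + output_padding

h = 256                                  # input resolution produced by ImageDataset
h = conv_out(h + 2 * 3, kernel=7)        # ReflectionPad2d(3) + 7x7 conv
after_stem = h
h = conv_out(h, kernel=3, stride=2, padding=1)  # stride-2 downsampling conv
at_res_blocks = h
res_block = conv_out(h + 2 * 1, kernel=3)       # ReflectionPad2d(1) + 3x3 conv keeps the size
h = conv_transpose_out(h, kernel=3, stride=2, padding=1, output_padding=1)  # upsampling
h = conv_out(h + 2 * 3, kernel=7)        # final ReflectionPad2d(3) + 7x7 conv

print(after_stem, at_res_blocks, res_block, h)  # 256 128 128 256
```

#### The *output_padding=1* in the transposed convolution is what makes 128 map back to exactly 256 rather than 255.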

In [None]:
class ResnetGenerator_cond(nn.Module):
    # Initialization of the conditional ResNet generator
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=9):
        super(ResnetGenerator_cond, self).__init__()
        
        # Ensuring the number of blocks is non-negative
        assert(n_blocks >= 0)
        # Determine if bias is needed based on the type of normalization layer
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
            
        # Initial convolution module to process input image
        self.model = nn.Sequential(
            nn.ReflectionPad2d(3),  # Padding before initial convolution
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),  # Initial convolution to transform input channel
            norm_layer(ngf),  # Normalization layer
            nn.ReLU(inplace=False)  # Activation function
        )
        
        self.ngf = ngf  # Number of generator filters
        
        # List of residual blocks with conditional inputs
        self.model_res = nn.ModuleList([])
        # Downsampling part of the model
        self.model_downsample = nn.Sequential(
            nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias),
            norm_layer(ngf * 4),
            nn.ReLU(inplace=False)
        )
        
        # Add multiple ResnetBlockCond instances for intermediate processing
        for i in range(n_blocks):
            self.model_res += [ResnetBlockCond(ngf * 4, norm_layer, temb_dim=4 * ngf, z_dim=4 * ngf)]
       
        # Upsampling part of the model to restore original image size
        self.model_upsample = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(inplace=False),
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()  # Output activation to ensure output values are between -1 and 1
        )
        
        # Define a transformation for the latent vector z
        mapping_layers = [PixelNorm(),
                          nn.Linear(self.ngf * 4, self.ngf * 4),
                          nn.LeakyReLU(0.2)]
        self.z_transform = nn.Sequential(*mapping_layers)
        
        # Time embedding layers
        modules_emb = [nn.Linear(self.ngf, self.ngf * 4)]
        nn.init.zeros_(modules_emb[-1].bias)  # Initialize the bias to zero for stability
        modules_emb += [nn.LeakyReLU(0.2), nn.Linear(self.ngf * 4, self.ngf * 4)]
        nn.init.zeros_(modules_emb[-1].bias)  # Again, initialize the bias to zero
        modules_emb += [nn.LeakyReLU(0.2)]
        self.time_embed = nn.Sequential(*modules_emb)
                                
    # Define the forward pass with conditional inputs time_cond and z
    def forward(self, x, time_cond, z):
        z_embed = self.z_transform(z)  # Transform z before feeding it to the ResNet blocks
        temb = get_timestep_embedding(time_cond, self.ngf)  # Embedding the time steps
        time_embed = self.time_embed(temb)  # Applying the time embedding
        out = self.model(x)  # Initial processing of input
        out = self.model_downsample(out)  # Apply downsampling
        for layer in self.model_res:  # Apply each ResNet block sequentially
            out = layer(out, time_embed, z_embed)
        out = self.model_upsample(out)  # Final upsampling and output layer
        return out


In [None]:
# Initialize the Generator
gen = ResnetGenerator_cond(input_nc=3, output_nc=3, ngf=64, n_blocks=9, norm_layer=nn.InstanceNorm2d).to(device)

## Discriminator Architecture
#### The discriminator plays a crucial role in distinguishing generated images from real ones, thereby guiding the generator towards producing more realistic images. It employs a series of convolutional layers that progressively downsample the input image, extracting increasingly abstract features that help separate real from fake images. Because every layer is also conditioned on the time step, the discriminator can adapt its judgment to the stage of the translation process.

#### The Key Components of our Discriminator are:

#### 1. **Conditional Convolution Block** (ConvBlock_cond): Each block in the discriminator is designed to process input features conditionally based on the temporal embedding. This block integrates convolution, normalization, activation, and optional downsampling while adjusting its processing based on the embedded temporal information.

#### 2. **Temporal Embedding Transformation**: Before being fed into the convolution blocks, the temporal embeddings are transformed through a dedicated embedding module (TimestepEmbedding). This module adapts the raw embeddings to have a suitable dimensionality and format for integration into the convolution blocks.

#### Here's an overview of the complete architecture:

#### - **Initial Layer**: The first layer maps the input channels to *ndf* (64 by default) feature channels while applying the first level of conditional processing.
#### - **Intermediate Layers**: A series of convolution blocks increase the depth of feature maps progressively, each conditioned on the transformed temporal embeddings. These layers refine the discriminator's capability to extract relevant features for authenticity determination.
#### - **Final Layer**: The last conditional block increases the number of channels once more but skips downsampling; it is followed by a final convolution that maps the features to a single-channel map of real/fake scores.

#### This architecture lets the discriminator separate real from fake images while adapting its behavior to the context provided by the temporal embeddings.
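#### The discriminator's conditional blocks (defined in the code cell below) use 4x4 convolutions with stride 1 and padding 1, each of which shrinks the feature map by one pixel, followed (except in the last block) by 2x2 average pooling. The sketch below traces a 256x256 input (an assumption matching the generator's output size) through the default configuration (*ndf* = 64, *n_layers* = 3) using the standard output-size formulas. The final convolution is not modeled, because its settings are not part of the blocks shown here.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # nn.Conv2d output size (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

def avg_pool_out(size, kernel=2, stride=2):
    # avg_pool2d output size (no padding)
    return (size - kernel) // stride + 1

def cond_block_out(size, downsample=True):
    # ConvBlock_cond as instantiated here: 4x4 conv, stride 1, padding 1, optional 2x2 avg pool
    size = conv_out(size, kernel=4, stride=1, padding=1)
    return avg_pool_out(size) if downsample else size

ndf, n_layers = 64, 3
sizes = [256]
sizes.append(cond_block_out(sizes[-1]))                    # first block
for _ in range(1, n_layers):
    sizes.append(cond_block_out(sizes[-1]))                # intermediate blocks
sizes.append(cond_block_out(sizes[-1], downsample=False))  # last conditional block

# Output channels of each block: ndf * min(2 ** n, 8)
channels = [ndf] + [ndf * min(2 ** n, 8) for n in range(1, n_layers + 1)]

print(sizes)     # [256, 127, 63, 31, 30]
print(channels)  # [64, 128, 256, 512]
```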

In [None]:
# DISCRIMINATOR
class ConvBlock_cond(nn.Module):
    """Conditional convolution block with embedding integration for discriminator."""
    def __init__(self, in_channels, out_channels, embedding_dim, kernel_size=3, stride=1, padding=1, use_bias=True, norm_layer=nn.BatchNorm2d, downsample=True):
        super(ConvBlock_cond, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=use_bias) # Standard convolution layer
        self.norm = norm_layer(out_channels) # Normalization layer specified by norm_layer argument
        self.act = nn.LeakyReLU(0.2, inplace=True) # Activation function set to LeakyReLU for stable gradients
        self.downsample = downsample # Option to downsample the feature map for reducing spatial dimensions
        self.dense = nn.Linear(embedding_dim, out_channels) # Linear layer to transform the embedding dimension to match output channels

    def forward(self, x, t_emb):
        out = self.conv(x) # Apply convolution to the input
        out = out + self.dense(t_emb)[..., None, None] # Add transformed timestep embedding to the convolution output
        out = self.norm(out) # Normalize the output
        out = self.act(out) # Apply the activation function
        # Conditionally apply downsampling
        if self.downsample:
            out = nn.functional.avg_pool2d(out, kernel_size=2, stride=2)
        return out

class NLayerDiscriminator_ncsn_new(nn.Module):
    """Discriminator that uses conditional convolution blocks to process input images."""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Initialize the discriminator with conditional convolution blocks."""
        super(NLayerDiscriminator_ncsn_new, self).__init__()
        # Determine if bias should be used based on the type of normalization layer
        use_bias = norm_layer == nn.InstanceNorm2d

        # List of modules that make up the main discriminator model
        self.model_main = nn.ModuleList()
        
        # First convolution block that processes the initial input layer
        self.model_main.append(
            ConvBlock_cond(input_nc, ndf, 4 * ndf, kernel_size=4, stride=1, padding=1, use_bias=use_bias))

        # Dynamically add intermediate convolution blocks with increasing feature depth
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            self.model_main.append(
                ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, 4 * ndf, kernel_size=4, stride=1, padding=1, use_bias=use_bias, norm_layer=norm_layer)
            )

        # Add the last convolution block without downsampling to maintain spatial dimensions
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        self.model_main.append(
            ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, 4 * ndf, kernel_size=4, stride=1, padding=1, use_bias=use_bias, norm_layer=norm_layer, downsample=False)
        )
        
        # Final convolution layer that outputs a single channel for discrimination
        self.final_conv = nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)
        # Time embedding layer that prepares the timestep embedding for integration into convolution blocks
        self.t_embed = TimestepEmbedding(
            embedding_dim=1,
            hidden_dim=4 * ndf,
            output_dim=4 * ndf,
            act=nn.LeakyReLU(0.2)
        )

    def forward(self, input, t_emb, input2=None):
        """Forward pass through the discriminator with optional dual inputs and timestep embedding."""
        t_emb = t_emb.float()  # Convert timestep embedding to float for processing
        t_emb = self.t_embed(t_emb)  # Apply embedding transformation
        # If a second input is provided, concatenate it with the first input
        out = torch.cat([input, input2], dim=1) if input2 is not None else input
        
        # Process each convolution block with the current output and timestep embedding
        for layer in self.model_main:
            out = layer(out, t_emb) if isinstance(layer, ConvBlock_cond) else layer(out)
        
        return self.final_conv(out)  # Apply the final convolution layer to produce the discriminator's output

In [None]:
# Initialize the discriminator
disc = NLayerDiscriminator_ncsn_new(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d).to(device)

## Define netE and netF

#### Our architecture includes two further networks, netF and netE.
#### - **NetF**: It is an instance of the PatchSampleF class, which samples patches from the feature maps produced by other parts of the model. If `use_mlp=True`, each sampled patch vector first passes through a small two-layer MLP that projects it to `nc` dimensions (256 here). It is then L2-normalized to unit norm, which keeps the scale of the patches consistent and helps stabilize training. As we will see later, netF is used to compute the Noise Contrastive Estimation (NCE) loss. This loss encourages corresponding patches of the source and generated images to share features, which improves the consistency of the translation.
#### - **NetE**: netE is a second discriminator with the same architecture as D, but with `input_nc=3*4 = 12` input channels. Its forward pass concatenates two (noisy input, generated output) pairs of RGB images along the channel dimension. As we will see later, its scores are used in the Schrödinger Bridge loss to estimate the entropy term. This term regularizes the transport and guides the optimization towards the optimal joint distribution between the two domains.
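#### The mechanism netF supports can be sketched without a network. First, sample the same spatial locations from the feature maps of the input and of the output. Next, L2-normalize each patch vector. Finally, score the patches with a cross-entropy in which the patch at the matching location is the positive and the other sampled locations are negatives. This NumPy sketch follows the PatchNCE formulation of CUT; the temperature 0.07 is CUT's default and is assumed here. The MLP is omitted, random arrays stand in for generator features, and none of the helper names exist in the notebook.

```python
import numpy as np

def l2_normalize(x, eps=1e-7):
    """Scale each row to unit L2 norm, similar to the Normalize(2) step applied to every patch."""
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)

def sample_patches(feat, num_patches, ids=None, rng=None):
    """feat: (C, H, W) feature map. Returns (num_patches, C) unit-norm patch vectors and the ids used."""
    C = feat.shape[0]
    flat = feat.reshape(C, -1).T                              # (H*W, C): one row per spatial location
    if ids is None:
        if rng is None:
            rng = np.random.default_rng()
        ids = rng.permutation(flat.shape[0])[:num_patches]    # random locations, no repeats
    return l2_normalize(flat[ids]), ids

def patch_nce(q, k, temperature=0.07):
    """Cross-entropy in which output patch i must match input patch i (positive)
    rather than any other sampled input patch (negatives)."""
    logits = q @ k.T / temperature                        # (N, N) scaled cosine similarities
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob))                    # positives sit on the diagonal

rng = np.random.default_rng(0)
src_feat = rng.standard_normal((16, 8, 8))                # stand-in for source-image features
k, ids = sample_patches(src_feat, num_patches=32, rng=rng)
q_aligned, _ = sample_patches(src_feat, 32, ids=ids)                        # same locations, same content
q_random, _ = sample_patches(rng.standard_normal((16, 8, 8)), 32, ids=ids)  # same locations, unrelated content
print(patch_nce(q_aligned, k) < patch_nce(q_random, k))  # True
```

#### Reusing the same `ids` for both images is what makes the diagonal meaningful. This is why `PatchSampleF.forward` accepts `patch_ids` from a previous call.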

In [None]:
# PatchSampleF extracts and processes patches from the feature maps generated by other parts of the model

class PatchSampleF(nn.Module):
    """ PatchsampleF is a class designed to sample and normalize patches from feature maps. 
    It can optionally use multi-layer perceptrons (MLPs) for further processing. During the forward pass, 
    it selects a specified number of patches from each feature map, optionally processes them through MLPs, 
    and applies L2 normalization. It also supports specifying custom patch indices or randomly selecting them if not provided """
    def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
        super(PatchSampleF, self).__init__()
        self.l2norm = Normalize(2)  
        self.use_mlp = use_mlp
        self.nc = nc
        self.mlp_init = False
        self.init_type = init_type
        self.init_gain = init_gain
        self.gpu_ids = gpu_ids

    def create_mlp(self, feats):
        """ Creation of MLP for further processing of features and patches """
        for mlp_id, feat in enumerate(feats):
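            input_nc = feat.shape[-3]  # Channel dimension for both (C, H, W) and (B, C, H, W) feature maps
            mlp = nn.Sequential(
                nn.Linear(input_nc, self.nc),
                nn.LeakyReLU(0.2),
                nn.Linear(self.nc, self.nc)
            )
            mlp.to(feat.device)  # Place the MLP on the same device as the features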
            setattr(self, f'mlp_{mlp_id}', mlp)
        self.mlp_init = True

    def forward(self, feats, num_patches=64, patch_ids=None):
        """ The forward method returns the sampled (and possibly processed) patches and their indices. 
        If no patches are sampled, it reshapes the feature maps back to their original dimensions """
        if self.use_mlp and not self.mlp_init:
            self.create_mlp(feats)

        return_feats = []
        return_ids = []

        for feat_id, feat in enumerate(feats):
            # Add batch dimension if missing
            if len(feat.shape) == 3:
                feat = feat.unsqueeze(0)

            B, C, H, W = feat.shape
            feat_reshape = feat.permute(0, 2, 3, 1).reshape(B, -1, C)  # Reshape to [B, H*W, C]

            if num_patches > 0:
                if patch_ids is not None and len(patch_ids) > feat_id:
                    current_patch_ids = patch_ids[feat_id]
                else:
                    # Generate random patch indices if none provided
                    current_patch_ids = [torch.randperm(feat_reshape.shape[1])[:num_patches].to(feat.device) for _ in range(B)]
                current_patch_ids = [torch.as_tensor(pid, dtype=torch.long, device=feat.device) for pid in current_patch_ids]  # Ensure indices are long tensors on the feature's device
                # Sampling patches
                x_sample = torch.cat([feat_reshape[b, pid, :] for b, pid in enumerate(current_patch_ids)], dim=0)
                return_ids.append(current_patch_ids)
            else:
                x_sample = feat_reshape.reshape(-1, C)
                current_patch_ids = [torch.tensor([], dtype=torch.long, device=feat.device) for _ in range(B)]
                return_ids.append(current_patch_ids)

            if self.use_mlp:
                mlp = getattr(self, f'mlp_{feat_id}')
                x_sample = mlp(x_sample)

            x_sample = self.l2norm(x_sample)

            return_feats.append(x_sample)

        # When no patches were sampled, restore the (B, C, H, W) layout (assumes all feature maps share the last map's shape)
        if num_patches == 0:
            return_feats = [f.view(B, H, W, -1).permute(0, 3, 1, 2) for f in return_feats]

        return return_feats, return_ids 

In [None]:
# define netF and netE
netF = PatchSampleF(use_mlp=True, init_type='normal', init_gain=0.02, nc=256).to(device)
netE = NLayerDiscriminator_ncsn_new(input_nc=3*4, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d).to(device)

## Loss Criteria

#### Before looking inside the model, we define the two loss criteria it uses:
* #### **Cross-Entropy Loss (for NCE)**: Used for the patch-wise contrastive (NCE) objective. One criterion is created for each entry in `nce_layers`.
* #### **Mean Squared Error Loss (GAN Loss)**: Used for adversarial training in the least-squares GAN form, comparing the discriminator's predictions with constant real/fake target labels.
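#### The least-squares adversarial objective implemented by `GANLoss` is an MSE against a constant label broadcast to the shape of the discriminator's output. Together with the 0.5-weighted combination used for D, it can be written out directly. The NumPy sketch below uses made-up discriminator predictions, and the function names are illustrative only.

```python
import numpy as np

def lsgan_d_loss(pred_real, pred_fake, real_label=1.0, fake_label=0.0):
    """Discriminator: push real predictions toward 1 and fake ones toward 0, then average."""
    loss_real = np.mean((pred_real - real_label) ** 2)
    loss_fake = np.mean((pred_fake - fake_label) ** 2)
    return 0.5 * (loss_real + loss_fake)

def lsgan_g_loss(pred_fake, real_label=1.0):
    """Generator: make the discriminator predict the 'real' label for generated samples."""
    return np.mean((pred_fake - real_label) ** 2)

pred_real = np.array([1.0, 1.0])   # D is perfectly confident on real patches
pred_fake = np.array([0.5, 0.5])   # D is unsure about generated patches
print(lsgan_d_loss(pred_real, pred_fake))  # 0.125
print(lsgan_g_loss(pred_fake))             # 0.25
```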

In [None]:
# Define Loss Criterions
def criterionNCE(nce_layers):
    """One cross-entropy criterion per NCE layer (reduction='none' keeps per-patch losses)."""
    return [nn.CrossEntropyLoss(reduction='none').to(device) for _ in nce_layers]

def criterionGAN():
    return nn.MSELoss().to(device)


class GANLoss(nn.Module):
    """Define Least Squares GAN loss."""

    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class, whose parameters are float representing labels for real and fake images """
        
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss()

    def get_target_tensor(self, prediction, target_is_real):
        """ Creates label tensors filled with the ground truth label, and with the size of the input """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """ Calculate loss given Discriminator's prediction and ground truth labels """
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)

## SB Model

#### At this point, we have described all the architectures needed for our Diffusion Model. Let's have some fun! :)

#### Backbone Structure (**forward function**):
* #### Starting from a real source image, the forward pass simulates the discretized Schrödinger Bridge up to a randomly chosen timestep. Each step moves the current state toward the generator's prediction of the target image and adds Gaussian noise. This produces the diverse intermediate states that the generator learns to refine.
* #### Each new state interpolates between the previous state and the generator's prediction, with noise whose variance shrinks to zero at the final timestep. As a result, transitions stay smooth and the image never changes abruptly.
* #### Applying the generator at every timestep lets it iteratively improve the noisy states, pulling them closer to the target distribution.
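#### Reading the update in `forward()`, and noting that `times[-1]` equals 1, each step samples $X_{t_i} = (1-s)\,X_{t_{i-1}} + s\,\hat{X}_1 + \sqrt{s(1-s)\,\tau\,(1-t_{i-1})}\,\epsilon$. Here $s = (t_i - t_{i-1})/(1 - t_{i-1})$, $\hat{X}_1$ is the generator's prediction (`Xt_1`) and $\epsilon \sim \mathcal{N}(0, I)$. The NumPy sketch below reproduces the time grid and this update. A fixed stand-in target replaces the generator call, and the helper names are hypothetical.

```python
import numpy as np

def sb_times(T=5):
    """Time grid used in forward(): harmonic increments, normalized, then mapped to [0.5, 1]."""
    incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])
    times = np.cumsum(incs)
    times = times / times[-1]
    return 0.5 * times[-1] + 0.5 * times

def bridge_step(x_prev, x1_hat, t_prev, t_cur, t_end, tau, rng):
    """One update Xt <- (1-inter)*Xt + inter*Xt_1 + sqrt(scale*tau)*noise, as in forward()."""
    delta = t_cur - t_prev
    inter = delta / (t_end - t_prev)
    scale = delta * (1 - delta / (t_end - t_prev))
    noise = rng.standard_normal(x_prev.shape)
    return (1 - inter) * x_prev + inter * x1_hat + np.sqrt(scale * tau) * noise

times = sb_times(5)
print(times)  # approximately [0.5, 0.74, 0.86, 0.94, 1.0]

rng = np.random.default_rng(0)
x = np.zeros(4)        # stand-in for the current noisy image X_t
x1_hat = np.ones(4)    # stand-in for the generator's prediction of the target
for t in range(1, len(times)):
    x = bridge_step(x, x1_hat, times[t - 1], times[t], times[-1], tau=0.01, rng=rng)
print(np.allclose(x, x1_hat))  # True: on the last step inter == 1 and the noise scale is 0
```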


#### The losses computed are : 

1. #### ***D_loss***: The loss of discriminator D is computed by the compute_D_loss method. It evaluates the least-squares adversarial loss separately on real target-domain images and on generated images, then averages the two (multiplying their sum by 0.5) so that both terms carry equal weight. Minimizing it trains D to better distinguish real images from generated ones.
2. #### ***E_loss***: The loss of discriminator E is computed by the compute_E_loss method. netE scores the matched (noisy input, generated output) pair against itself and against an independent pair built from the second input. The loss is the negative mean of the matched scores, plus a log-sum-exp over the independent scores and its square as a regularizer. The trained netE is then used by the generator's SB loss to estimate the entropy of the transition between noisy and generated images.
3. #### ***G_loss***: The loss of generator G is computed by the compute_G_loss method as the weighted sum of three components:
    * #### G_GAN loss: encourages the generator to produce images that D classifies as real
    * #### Schrödinger Bridge loss: an entropy term estimated with netE, plus a transport cost (τ times the mean squared distance between the noisy input and the generated image)
    * #### Noise Contrastive Estimation loss: promotes feature correspondence between the source image and the generated image

  #### Together, these terms let the generator produce high-quality images that follow the target distribution while preserving the content of the source.
4. #### ***NCE_loss***: The calculate_NCE_loss method computes a contrastive loss. It samples patches at the same locations from features of the source and generated images at multiple layers, and contrasts the features of patches at the same location against those at other locations. After applying a weighting factor, it returns the loss averaged across all layers.
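#### Stripped of the networks, the E and SB losses are simple functions of two sets of netE scores. One set scores the matched pair (noisy input, generated output) against itself; the other scores it against the independent pair built from `real_A2`. The NumPy sketch below mirrors the formulas in compute_E_loss and compute_G_loss. The `(T - idx) / T` weight of the entropy term follows the reference UNSB implementation and is an assumption here. The score arrays are made-up stand-ins for netE outputs, and the function names are illustrative.

```python
import numpy as np

def logsumexp(a):
    """Stable log(sum(exp(a))) over all entries, like torch.logsumexp(x.reshape(-1), dim=0)."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

def e_loss(score_joint, score_indep):
    """netE objective as in compute_E_loss: -mean(joint) + lse + lse**2."""
    lse = logsumexp(score_indep)
    return -np.mean(score_joint) + lse + lse ** 2

def sb_loss(score_joint, score_indep, x_t, x1_hat, time_idx, num_timesteps, tau):
    """Entropy estimate weighted by (T - idx) / T, plus the tau-weighted transport cost."""
    et_xy = np.mean(score_joint) - logsumexp(score_indep)
    entropy_term = -(num_timesteps - time_idx) / num_timesteps * tau * et_xy
    transport_term = tau * np.mean((x_t - x1_hat) ** 2)
    return entropy_term + transport_term

score_joint = np.array([1.0, 1.0])   # made-up netE scores for the matched pair
score_indep = np.array([0.0, 0.0])   # made-up netE scores against the independent pair
x_t = np.zeros(3)                    # stand-in for real_A_noisy
x1_hat = np.zeros(3)                 # stand-in for fake_B (identical, so no transport cost)
print(e_loss(score_joint, score_indep))  # ≈ 0.1736
print(sb_loss(score_joint, score_indep, x_t, x1_hat, time_idx=0, num_timesteps=5, tau=0.01))  # ≈ -0.00307
```

#### The generator minimizes the negative of the same `mean(joint) - logsumexp(indep)` quantity that netE's loss pushes up. This is why netE is updated before G in each training step.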

In [None]:
class SBModel(nn.Module):
    
    def __init__(self):
        """ Initializes the SBModel class, setting up parameters, loss names, model names, visual names, optimizers, and other necessary configurations """
        # Note that parameters have been taken directly from the paper, except for the number of epochs, due to our hardware computation limits 
        super(SBModel,self).__init__()
        self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE','SB']
        self.model_names = ['G','F','D','E']
        self.visual_names = ['real_A','real_A_noisy', 'fake_B', 'real_B']
        self.optimizers = []
        self.tau = 0.01  # Entropy parameter; kept equal to the tau used in forward() and compute_G_loss()
        self.device = device   
        self.lambda_NCE = 1.0 
        self.nce_idt = True
        self.nce_layers = [0,4,8,12,16]  
        self.num_patches = 256
        self.netG = gen
        self.netD = disc
        self.netE = netE
        self.netF = netF 
        self.ngf = 64
        self.criterionNCE = criterionNCE(self.nce_layers)
        self.criterionGAN = GANLoss().to(device)
        self.lr = 0.00001
        self.beta1 = 0.5
        self.beta2 = 0.999 
        
        # Defining Optimizers
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        self.optimizer_E = torch.optim.Adam(self.netE.parameters(),lr=self.lr, betas=(self.beta1, self.beta2))
        
    def data_dependent_initialize(self, dataA,dataB, dataA2, dataB2): 
        """ Prepares the model for training, computing fake images using the generator and initial losses for the generator 'G' 
        and the discriminators 'D' and 'E'. It is conditioned on whether the loss function involving the NCE term is active. 
        If so, it initializes an optimizer for netF """
        bs = 1
        self.set_input(dataA,dataB, dataA2, dataB2)
        self.real_A = self.real_A[:bs]
        self.real_B = self.real_B[:bs]
        self.real_A2 = self.real_A2[:bs]
        self.real_B2 = self.real_B2[:bs]
        self.forward()  
        self.compute_G_loss().backward()
        self.compute_D_loss().backward()
        self.compute_E_loss().backward()  
        if self.lambda_NCE > 0.0:
            self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
            self.optimizers.append(self.optimizer_F)
        
    
    def set_input(self, dataA, dataB, dataA2, dataB2):
        """ Responsible for unpacking input data from the dataloader and performing any necessary preprocessing steps """
        self.real_A = dataA.to(device)
        self.real_B = dataB.to(device)
        self.real_A2 = dataA2.to(device)
        self.real_B2 = dataB2.to(device)
        
    def set_requires_grad(self, nets, requires_grad=True):
        """ Toggles the requirement for gradient computation for the parameters in the provided networks. 
        It is helpful for controlling which parts of the model are trainable during different training phases """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
    
    def get_current_losses(self):
        """ Retrieves the current training losses/errors from the model and returns them as a dictionary """
        errors_ret = {}
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret
        
    def optimize_parameters(self):
        """ It ensures that the gradients are properly calculated and used to update the network weights during training """
        
        # Forward pass
        self.forward()
        
        # Set models to training mode
        self.netG.train()
        self.netE.train()
        self.netD.train()
        self.netF.train()
        
        # Update Discriminator D 
        self.set_requires_grad(self.netD, True) # Enables gradient calculation for D
        self.optimizer_D.zero_grad()  # Zeros the gradients of D_optimizer
        self.loss_D = self.compute_D_loss()  # Computes D loss 
        self.loss_D.backward()  # Backpropagates the gradient 
        torch.nn.utils.clip_grad_norm_(self.netD.parameters(), max_norm=1)   # Clip gradients to avoid exploding gradients 
        self.optimizer_D.step()  # Updates the parameters of D 
        
        # Update Discriminator E
        self.set_requires_grad(self.netE, True)  # Enables gradient calculation for E
        self.optimizer_E.zero_grad()  # Zeros the gradients of E_optimizer
        self.loss_E = self.compute_E_loss()  # Computes E loss
        self.loss_E.backward()   # Backpropagates the gradient 
        torch.nn.utils.clip_grad_norm_(self.netE.parameters(), max_norm=1)  # Clip gradients to avoid exploding gradients 
        self.optimizer_E.step()  # Updates the parameters of E
    
        # Update Generator G
        self.set_requires_grad(self.netD, False)  # Disables gradient calculation for discriminator D since it is not being updated in this step  
        self.set_requires_grad(self.netE, False)  # Disables gradient calculation for discriminator E since it is not being updated in this step  
        
        self.optimizer_G.zero_grad()  # Zeros the gradient of G_optimizer
        self.optimizer_F.zero_grad()  # Zeros the gradient of F_optimizer 
        
        self.loss_G = self.compute_G_loss()   # Compute G loss 
        self.loss_G.backward()  # Backpropagates the gradient
        
        torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=1)  # Clip gradients to avoid exploding gradients 
        self.optimizer_G.step()  # Updates the parameters of G
    
        torch.nn.utils.clip_grad_norm_(self.netF.parameters(), max_norm=1)  # Clip gradients to avoid exploding gradients 
        self.optimizer_F.step()  # Updates the parameters of F

    def forward(self):
        """ Diffusion Process, described above """
        tau = 0.01  # Entropy parameter 
        T = 5  # Number of time steps 
        incs = np.array([0] + [1/(i+1) for i in range(T-1)])  # Array of incremental values used to define time steps 
        times = np.cumsum(incs)
        times = times / times[-1]
        times = 0.5 * times[-1] + 0.5 * times
        times = torch.tensor(times).float().cuda()   # Array of normalized time steps, scaled and shifted 
        self.times = times
        bs =  self.real_A.size(0)
        time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()  # Randomly selected time step index 
        self.time_idx = time_idx  
        self.timestep     = times[time_idx]  # Actual time value corresponding to 'time_idx'
        
        with torch.no_grad():
            self.netG.eval()
            for t in range(0, self.time_idx.int().item() + 1):  # Iteration over each time step up to the current index 
                if t > 0:
                    # Interpolation factors based on the current and previous time steps for temporal interpolation -> Paper Fig. 3 
                    delta = times[t] - times[t-1]   
                    denom = times[-1] - times[t-1]  
                    inter = (delta / denom)         
                    scale = (delta * (1 - delta / denom))  
                    
                
                """ Handling Input 1 """
                Xt       = self.real_A if (t == 0) else (1-inter)* Xt + inter * Xt_1.detach() + (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A.to(device))
                # Xt is updated using its previous state, the output from the previous timestep Xt_1, and added Gaussian noise 
                time_idx = (t * torch.ones(size=[self.real_A.shape[0]]).to(self.real_A.to(device))).long()
                time     = times[time_idx]
                z        = torch.randn(size=[self.real_A.shape[0],4*self.ngf]).to(self.real_A.to(device))
                Xt_1     = self.netG(Xt, time_idx, z) # Xt_1 is the output of the generator given the noisy input Xt, current time_idx, and latent vector z. 
        
                """ Handling input 2 """
                # A second, independent source image is propagated in the same way. Its (noisy, generated) pair is only used in compute_E_loss and compute_G_loss, where it serves as the independent sample in the log-sum-exp term 
                Xt2 = self.real_A2 if (t == 0) else (1-inter)*Xt2 + inter*Xt_12.detach() + (scale*tau).sqrt() * torch.randn_like(Xt2).to(self.real_A2.to(device))
                # Xt2 is updated using its previous state, the output from the previous timestep Xt_12, and added Gaussian noise 
                time_idx = (t * torch.ones(size=[self.real_A.shape[0]]).to(self.real_A.to(device))).long()
                time     = times[time_idx]
                z        = torch.randn(size=[self.real_A.shape[0], 4 * self.ngf]).to(self.real_A.to(device))
                Xt_12    = self.netG(Xt2, time_idx, z)  # Xt_12 is the output of the generator given the noisy input Xt2, current time_idx, and latent vector z.
                
                if self.nce_idt:
                    XtB = self.real_B if (t == 0) else (1-inter) * XtB + inter * Xt_1B.detach() + (scale * tau).sqrt() * torch.randn_like(XtB).to(self.real_A.to(device))
                    # XtB is updated using its previous state, the output from the previous timestep Xt_1B, and added Gaussian noise 
                    time_idx = (t * torch.ones(size=[self.real_A.shape[0]]).to(self.real_A.to(device))).long()
                    time     = times[time_idx]
                    z        = torch.randn(size=[self.real_A.shape[0],4*self.ngf]).to(self.real_A.to(device))
                    Xt_1B = self.netG(XtB, time_idx, z)  # Xt_1B is the output of the generator given the noisy input XtB, current time_idx, and latent vector z.
                    
            if self.nce_idt:
                self.XtB = XtB.detach()
                
            self.real_A_noisy = Xt.detach()
            self.real_A_noisy2 = Xt2.detach()          
        
        z_in    = torch.randn(size=[2*bs,4*self.ngf]).to(self.real_A.to(device))  # Random noise for generator inputs 
        z_in2    = torch.randn(size=[bs,4*self.ngf]).to(self.real_A.to(device))   # Random noise for the second generator input 
        
        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.nce_idt else self.real_A   # Concatenate real horse and real zebra image if identity loss is enabled 
        self.realt = torch.cat((self.real_A_noisy, self.XtB), dim=0) if self.nce_idt else self.real_A_noisy  # Concatenate noisy horse and noisy zebra if identity loss is enabled 
        self.fake = self.netG(self.realt,self.time_idx,z_in)   # Apply the generator to the concatenated noisy images (source, plus target if identity loss is enabled) 
        self.fake_B2 =  self.netG(self.real_A_noisy2,self.time_idx,z_in2)  # Apply the generator to the second set of noisy image 
        self.fake_B = self.fake[:self.real_A.size(0)]  # Extract "generated zebra" (horse with zebra's features) from self.fake   

    def compute_D_loss(self):
        """ Computation of Discriminator D loss, combining losses for real and fake images """
        bs = self.real_A.size(0)
        fake = self.fake_B.detach()   # Obtained Fake Images 
        std = torch.rand(size=[1]).item()
        pred_fake = self.netD(fake,self.time_idx)    # Discriminator D's predictions for fake images 
        self.loss_D_fake = self.criterionGAN(pred_fake, False).mean() #Computes Adversarial Loss for fake images, setting target label to 'False' to denote that the images are fake
        self.pred_real = self.netD(self.real_B,self.time_idx)  # Discriminator D's predictions for real images 
        loss_D_real = self.criterionGAN(self.pred_real, True)  # Computes Adversarial Loss for real images, setting target label to 'True' to denote that the images are real
        self.loss_D_real = loss_D_real.mean()
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5  # Total Discriminator D loss 
        return self.loss_D
    
    def compute_E_loss(self):
        """ Computation of Discriminator E loss, aiming to bridge the gap between distributions """
        bs =  self.real_A.size(0)
        
        XtXt_1 = torch.cat([self.real_A_noisy,self.fake_B.detach()], dim=1)   # Concatenation of noisy input image with the corresponding generated image 
        XtXt_2 = torch.cat([self.real_A_noisy2,self.fake_B2.detach()], dim=1) # Concatenation of noisy input image 2 with the corresponding generated image 2 
        temp = torch.logsumexp(self.netE(XtXt_1, self.time_idx, XtXt_2).reshape(-1), dim=0).mean()  # Log-sum-exp of netE's scores on the (matched pair, independent pair) input
        # logsumexp computes log(sum(exp(.))) in a numerically stable way; here it acts as the log-normalizer of the entropy estimate 
        self.loss_E = -self.netE(XtXt_1, self.time_idx, XtXt_1).mean() +temp + temp**2  # Negative mean score on the matched pair, plus the log-sum-exp term and its square as a regularizer 
        
        return self.loss_E
    
    def compute_G_loss(self):
        """ Compute Generator G Loss, given by the combination of G_GAN, SB and NCE loss """
        bs = 1
        tau = 0.01
        lambda_GAN = 1.0
        lambda_SB = 1.0
        lambda_NCE = 1.0
        
        fake = self.fake_B
        std = torch.rand(size=[1]).item() 
        
        if lambda_GAN > 0:
            pred_fake = self.netD(fake,self.time_idx)  # Discriminator D predictions on generated images 
            self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() # Compares fake predictions to the target label 'True', indicating these should be real 
        else:
            self.loss_G_GAN = 0

        self.loss_SB = 0
        if lambda_SB > 0:
            XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1)
            XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1)
            bs = 1
            ET_XY    = self.netE(XtXt_1, self.time_idx, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time_idx, XtXt_2).reshape(-1), dim=0)  # Helps in aligning the distributions 
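            num_timesteps = len(self.times)  # T, the number of discretization steps defined in forward()
            self.loss_SB = -(num_timesteps - self.time_idx[0]) / num_timesteps * tau * ET_XY  # Entropy term weighted by (T - idx) / T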
            self.loss_SB += self.tau*torch.mean((self.real_A_noisy-self.fake_B)**2)
        
        if lambda_NCE > 0:
            self.loss_NCE = self.calculate_NCE_loss(self.real_A, fake, lambda_NCE) # NCE loss helps in aligning the feature distributions of the real and generated images 
        else: 
            self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0

        self.loss_G = lambda_GAN * self.loss_G_GAN + lambda_SB * self.loss_SB + lambda_NCE * self.loss_NCE # Total Generator loss 
        return self.loss_G
        
    def calculate_NCE_loss(self, src, tgt, lambda_NCE):
        """ Computation of Noise Contrastive Estimation Loss, measuring similarity between patches extracted from source and target images"""
        num_patches = 256
        nce_layers = [0,4,8,12,16]
        num_layers = len(nce_layers)
        z = torch.randn(size=[self.real_A.size(0),4*self.ngf]).to(self.real_A.to(device))
        feat_q = self.netG(tgt, self.time_idx, z)  # Feature map obtained from the generator for the generated (target) images 
        feat_k = self.netG(src, self.time_idx, z)  # Feature map obtained from the generator for the source images 
        feat_k_pool, sample_ids = self.netF(feat_k, num_patches, None)  # Through netF, sample patches from the source feature maps 
        feat_q_pool, _ = self.netF(feat_q, num_patches, sample_ids)     # Sample the same locations from the generated feature maps 

        nce_T = 0.07  # Temperature of the contrastive softmax (0.07 is the default of CUT's PatchNCE loss)
        total_nce_loss = 0.0
        for f_q, f_k, crit in zip(feat_q_pool, feat_k_pool, self.criterionNCE):
            # Similarity of every generated patch with every source patch: the patch at the same
            # location (the diagonal) is the positive, the other sampled locations are negatives
            logits = torch.mm(f_q, f_k.detach().t()) / nce_T
            targets = torch.arange(logits.size(0), device=logits.device)
            loss = crit(logits, targets) * lambda_NCE  # Cross-entropy with the matching patch as the target class
            total_nce_loss += loss.mean()
        return total_nce_loss / num_layers  # Loss averaged over the sampled patches and layers  

## Introduction to Inception V3

#### Inception v3 is a convolutional neural network pretrained on ImageNet; here it is used to extract meaningful features from images. The checkpoint at `FID_WEIGHTS_URL` below is the one distributed by pytorch-fid, which ports the Inception weights used by the original TensorFlow FID implementation.

#### Inception v3 maps each image to a high-level representation that captures the information needed to compare the quality and similarity of generated images with real ones.

#### Below, we have provided the model that will be used for calculating FID and KID values: 
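#### For reference, FID compares the mean and covariance of the Inception features of real and generated images: $\text{FID} = \|\mu_r - \mu_g\|^2 + \operatorname{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2}\big)$. The NumPy sketch below computes it from two feature matrices, for example the block-3 outputs of the model below, flattened to shape (N, 2048). pytorch-fid computes the matrix square root with `scipy.linalg.sqrtm`. This sketch instead sums the square roots of the eigenvalues of $\Sigma_r\Sigma_g$, a swapped-in shortcut that gives the same trace in exact arithmetic. The function names are illustrative.

```python
import numpy as np

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # For covariance matrices, sigma1 @ sigma2 has real, non-negative eigenvalues,
    # so Tr(sqrt(sigma1 @ sigma2)) is the sum of their square roots
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sum(np.sqrt(np.clip(eigvals.real, 0, None)))
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

def fid_from_features(feats_real, feats_fake):
    """feats_*: (N, D) arrays of image features, one row per image."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_fake, rowvar=False)
    return fid_from_stats(mu1, s1, mu2, s2)

eye = np.eye(2)
print(fid_from_stats(np.zeros(2), eye, np.array([3.0, 4.0]), eye))  # 25.0 (only the means differ)
```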

In [None]:
# Inception V3
# URL for the pretrained model weights file
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

class InceptionV3(nn.Module):
    # Index of the default Inception block to return,
    # corresponds to the output of the final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their respective output block indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,   # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-auxiliary classifier features
        2048: 3  # Final average pooling features
    }

    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
        '''Initialization function'''
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        # Ensure that the last needed block is not greater than 3
        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        # Load the correct pretrained Inception model
        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(pretrained=True)

        # Block 0: input to the first max pooling
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: first max pooling to the second max pooling
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: second max pooling to the auxiliary classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: auxiliary classifier to the final average pooling
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        # Set requires_grad for all parameters based on the requires_grad flag
        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        '''Return the features of the requested blocks (with normalize_input=True, inputs are expected in (0, 1))'''
        outp = []
        x = inp

        # Resize input if the flag is set
        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        # Normalize input if the flag is set
        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        # Pass through the blocks sequentially
        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            # Break if the last needed block is reached
            if idx == self.last_needed_block:
                break

        return outp


def _inception_v3(*args, **kwargs):
    try:
        version = tuple(map(int, torchvision.__version__.split('.')[:2]))
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    if version >= (0, 6):
        kwargs['init_weights'] = False

    return torchvision.models.inception_v3(*args, **kwargs)


def fid_inception_v3():
    '''The Inception model for FID computation uses a different set of weights and has a slightly different structure than torchvision's Inception.
        This method first constructs torchvision's Inception and then patches the necessary parts that are different in the FID Inception model.'''
    inception = _inception_v3(num_classes=1008,
                              aux_logits=False,
                              pretrained=False)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    # Load the state dictionary for the FID Inception model
    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception

class FIDInceptionA(torchvision.models.inception.InceptionA):
    '''InceptionA block patched for FID computation'''
    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: TensorFlow's average pool does not include the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

class FIDInceptionC(torchvision.models.inception.InceptionC):
    '''InceptionC block patched for FID computation'''
    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: TensorFlow's average pool does not include the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)

class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    '''First InceptionE block patched for FID computation'''
    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: TensorFlow's average pool does not include the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    '''Second InceptionE block patched for FID computation'''
    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling.
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

# Define the block index by the feature dimension
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
# Create the InceptionV3 model instance with the specified block index
model = InceptionV3([block_idx])
# Move the model to the GPU
inception_v3 = model.cuda()
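#### The FID-patched Inception blocks above replace torchvision's average pooling with `F.avg_pool2d(..., count_include_pad=False)` so that the features match the TensorFlow Inception network on which FID is defined. The NumPy sketch below is only an illustration (`avg_pool3x3` is a hypothetical helper, not part of the pipeline): on a constant image, counting the padded zeros darkens the borders, while ignoring them keeps every output equal to 1.

```python
import numpy as np


def avg_pool3x3(x, count_include_pad):
    """3x3 average pooling with stride 1 and zero padding 1 on a 2-D array."""
    h, w = x.shape
    padded = np.pad(x, 1, mode="constant")                 # zero-padded input
    is_real = np.pad(np.ones_like(x), 1, mode="constant")  # 1 where a real pixel sits
    out = np.empty((h, w), dtype=float)
    for i in range(h):
        for j in range(w):
            window_sum = padded[i:i + 3, j:j + 3].sum()
            if count_include_pad:
                # Padded zeros count toward the denominator
                out[i, j] = window_sum / 9.0
            else:
                # Average over real pixels only, as in the TensorFlow Inception
                out[i, j] = window_sum / is_real[i:i + 3, j:j + 3].sum()
    return out


x = np.ones((4, 4))
print(avg_pool3x3(x, count_include_pad=True)[0, 0])   # corner: 4 real pixels / 9
print(avg_pool3x3(x, count_include_pad=False)[0, 0])  # corner: 4 real pixels / 4
```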

## Fréchet Inception Distance (FID)
#### FID measures the distance between two Gaussian distributions fitted to the Inception-v3 features (the 2048-dimensional output of the final average pooling) of real and generated images. To compute it, the mean and covariance of the features of each image set are estimated, and FID is defined as:

\begin{equation}
        \mathrm{FID} = \lVert\mu_r - \mu_g\rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r\Sigma_g\right)^{1/2}\right)
\end{equation}

#### where:
* #### $\mu_r$ and $\mu_g$ are the mean vectors of the features of real and generated images, respectively.
* #### $\Sigma_r$ and $\Sigma_g$ are the covariance matrices of the features of real and generated images, respectively.
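#### As a sanity check of the formula, the NumPy sketch below evaluates it with a symmetric eigendecomposition instead of `scipy.linalg.sqrtm`; this is an alternative technique, not the notebook's implementation, and `psd_sqrt`/`frechet_distance` are hypothetical helpers. For covariance matrices $\Sigma_1, \Sigma_2$ it uses $\operatorname{Tr}\big((\Sigma_1\Sigma_2)^{1/2}\big) = \operatorname{Tr}\big((\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2})^{1/2}\big)$, where the right-hand matrix is symmetric positive semi-definite, so its square root is real. For 1-D Gaussians with standard deviations $\sigma_1, \sigma_2$ the distance reduces to $(\mu_1-\mu_2)^2 + (\sigma_1-\sigma_2)^2$.

```python
import numpy as np


def psd_sqrt(a):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh(a)
    w = np.clip(w, 0.0, None)  # clip tiny negative eigenvalues caused by round-off
    return (v * np.sqrt(w)) @ v.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    s1_half = psd_sqrt(sigma1)
    # Tr((S1 S2)^{1/2}) computed from the symmetric PSD matrix S1^{1/2} S2 S1^{1/2}
    tr_covmean = np.trace(psd_sqrt(s1_half @ sigma2 @ s1_half))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


# 1-D check: N(0, 1^2) vs N(3, 2^2) gives (0 - 3)^2 + (1 - 2)^2 = 10
print(frechet_distance(0.0, 1.0, 3.0, 4.0))

# Usage on synthetic stand-in "features" (500 samples, 8 dimensions)
rng = np.random.default_rng(0)
feats_real = rng.normal(size=(500, 8))
feats_fake = rng.normal(loc=0.5, size=(500, 8))
fid = frechet_distance(feats_real.mean(axis=0), np.cov(feats_real, rowvar=False),
                       feats_fake.mean(axis=0), np.cov(feats_fake, rowvar=False))
```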

#### The per-epoch FID implementation, which computes the statistics directly from batches of image tensors, is as follows:


In [None]:
def epoch_calculate_activation_statistics(images, model, batch_size=128, dims=2048, cuda=False):
    '''Return the mean and covariance of the Inception features of a tensor of images'''
    # Set the model to evaluation mode
    model.eval()

    # Select device
    if cuda:
        model = model.cuda()
        images = images.cuda()
    else:
        model = model.cpu()
        images = images.cpu()

    feats = []
    # No need to track gradients for this operation
    with torch.no_grad():
        # Process the images in chunks of batch_size to bound memory usage
        for start in range(0, len(images), batch_size):
            pred = model(images[start:start + batch_size])[0]

            # 4D output (batch, channels, height, width): pool to 1x1 and flatten
            if pred.dim() == 4:
                if pred.size(2) != 1 or pred.size(3) != 1:
                    pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
                pred = pred.view(pred.size(0), -1)
            # 2D output (batch, features) is already flat
            elif pred.dim() != 2:
                raise RuntimeError("Unexpected output dimensions from the model.")

            feats.append(pred.cpu().numpy())

    act = np.concatenate(feats, axis=0)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma

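from scipy.linalg import sqrtm

def epoch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    '''Compute the Fréchet distance between two Gaussians given their means and covariances'''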
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('FID calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
   
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)


def epoch_calculate_fretchet(images_real, images_fake, model):
    '''Compute the FID between a batch of real images and a batch of generated images'''
    mu_1, sigma_1 = epoch_calculate_activation_statistics(images_real, model, cuda=True)
    mu_2, sigma_2 = epoch_calculate_activation_statistics(images_fake, model, cuda=True)

    # Get Frechet distance
    fid_value = epoch_calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2)
    return fid_value
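#### Because the per-epoch FID is computed on a single batch, the covariance is estimated from far fewer samples than the 2048 feature dimensions. The NumPy sketch below (synthetic features; the sizes 10 and 64 are arbitrary assumptions) shows that a covariance estimated from $n$ samples has rank at most $n-1$. This is why `sqrtm` can return a non-finite or complex result on batch statistics and why the code falls back to adding `eps` to the diagonal. It also means per-batch FID values depend strongly on the batch size, so they are best used to compare epochs with each other rather than with the full-dataset FID.

```python
import numpy as np

rng = np.random.default_rng(0)
n_images, dims = 10, 64                      # far fewer samples than feature dimensions
feats = rng.normal(size=(n_images, dims))    # stand-in for one batch of Inception features

sigma = np.cov(feats, rowvar=False)          # (dims, dims) covariance estimate
rank = np.linalg.matrix_rank(sigma)          # mean-centering limits the rank to n_images - 1

eps = 1e-6
sigma_reg = sigma + eps * np.eye(dims)       # diagonal offset, as in the fallback branch
rank_reg = np.linalg.matrix_rank(sigma_reg)

print(sigma.shape, rank, rank_reg)           # the offset restores full rank
```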

#### Computing FID on the entire dataset after training:

In [None]:
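# FID adapted from the official pytorch-fid implementation (https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py), with a slight modification (a small offset is always added to the covariance diagonals) to avoid imaginary components in the matrix square root
import pathlib

from scipy import linalg
from tqdm import tqdm

# Supported image formats
IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}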

class ImagePathDataset(torch.utils.data.Dataset):
    '''Dataset that loads RGB images from a list of file paths'''
    def __init__(self, files, transforms=None):
        self.files = files  # List of image file paths
        self.transforms = transforms  # Optional transforms to apply to images

    def __len__(self):
        '''Return the number of images'''
        return len(self.files)  

    def __getitem__(self, i):
        '''Return the item at position i'''
        path = self.files[i]  # Get the image path
        img = Image.open(path).convert("RGB")  # Open image and convert to RGB
        if self.transforms is not None:
            img = self.transforms(img)  # Apply transforms if any
        return img  

def get_fid_activations(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1):
    '''Return the Inception activations of all image files as an (N, dims) array'''
    # Distinct name so that the KID get_activations defined later does not shadow this function
    model.eval()  # Set model to evaluation mode

    if batch_size > len(files):
        print("Warning: batch size is bigger than the data size. Setting batch size to data size")
        batch_size = len(files)  # Adjust batch size if it's larger than the number of files

    dataset = ImagePathDataset(files, transforms=transforms.ToTensor())  # Create dataset from image files
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )  # Create dataloader for batching

    pred_arr = np.empty((len(files), dims))  # Initialize array to hold activations

    start_idx = 0  # Start index for filling pred_arr

    for batch in tqdm(dataloader):
        batch = batch.to(device)  # Move batch to device

        with torch.no_grad():
            pred = model(batch)[0]  # Get model predictions

        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))  # Apply adaptive average pooling

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()  # Convert predictions to numpy array

        pred_arr[start_idx : start_idx + pred.shape[0]] = pred  # Store predictions in pred_arr

        start_idx = start_idx + pred.shape[0]  # Update start index

    return pred_arr  # Return the array of activations

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    '''Compute FID value'''
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    # Ensure that sigma1 and sigma2 have the same shape
    max_shape = max(sigma1.shape[0], sigma2.shape[0])
    sigma1 = np.pad(sigma1, ((0, max_shape - sigma1.shape[0]), (0, max_shape - sigma1.shape[1])), mode='constant')
    sigma2 = np.pad(sigma2, ((0, max_shape - sigma2.shape[0]), (0, max_shape - sigma2.shape[1])), mode='constant')

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # Add a small value to the diagonal to improve numerical stability
    offset1 = np.eye(sigma1.shape[0]) * eps
    offset2 = np.eye(sigma2.shape[0]) * eps

    covmean, _ = linalg.sqrtm((sigma1 + offset1).dot(sigma2 + offset2), disp=False)
    if not np.isfinite(covmean).all():
        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
        print(msg)
        covmean = linalg.sqrtm((sigma1 + offset1).dot(sigma2 + offset2))

    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

def calculate_activation_statistics(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1):
    act = get_fid_activations(files, model, batch_size, dims, device, num_workers)  # Get activations for the files
    mu = np.mean(act, axis=0)  # Calculate mean of activations
    sigma = np.cov(act, rowvar=False)  # Calculate covariance of activations
    return mu, sigma

def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1):
    if path.endswith(".npz"):
        with np.load(path) as f:
            m, s = f["mu"][:], f["sigma"][:]  # Load precomputed statistics
    else:
        path = pathlib.Path(path)
        files = sorted([file for ext in IMAGE_EXTENSIONS for file in path.glob("*.{}".format(ext))])  # Gather image files
        m, s = calculate_activation_statistics(files, model, batch_size, dims, device, num_workers)  # Compute statistics

    return m, s  # Return mean and covariance

def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError("Invalid path: %s" % p)  # Check if paths exist

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]  # Get block index for InceptionV3

    model = InceptionV3([block_idx]).to(device)  # Initialize InceptionV3 model

    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims, device, num_workers)  # Compute statistics for first path
    m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, dims, device, num_workers)  # Compute statistics for second path
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)  # Calculate FID

    return fid_value  # Return FID value

def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
    if not os.path.exists(paths[0]):
        raise RuntimeError("Invalid path: %s" % paths[0])  # Check if input path exists

    if os.path.exists(paths[1]):
        raise RuntimeError("Existing output file: %s" % paths[1])  # Check if output file already exists

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]  # Get block index for InceptionV3

    model = InceptionV3([block_idx]).to(device)  # Initialize InceptionV3 model

    print(f"Saving statistics for {paths[0]}")

    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims, device, num_workers)  # Compute statistics for input path

    np.savez_compressed(paths[1], mu=m1, sigma=s1)  # Save statistics to file
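#### `save_fid_stats` caches the mean and covariance of a reference image set in an `.npz` file, and `compute_statistics_of_path` loads such a file instead of recomputing activations whenever the path ends in `.npz`. A minimal sketch of that round trip with NumPy alone, using synthetic activations and a hypothetical file name inside a temporary directory:

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)
act = rng.normal(size=(100, 16))                  # stand-in for Inception activations
mu, sigma = act.mean(axis=0), np.cov(act, rowvar=False)

with tempfile.TemporaryDirectory() as tmp:
    stats_path = os.path.join(tmp, "real_stats.npz")     # hypothetical cache file
    np.savez_compressed(stats_path, mu=mu, sigma=sigma)  # same keys that save_fid_stats writes
    with np.load(stats_path) as f:                       # same reads as compute_statistics_of_path
        m, s = f["mu"][:], f["sigma"][:]
```

#### With the functions above, this corresponds to calling `save_fid_stats([real_dir, "real_stats.npz"], batch_size, device, dims)` once (`real_dir` being a hypothetical folder of real images) and then passing `"real_stats.npz"` as the first entry of `paths` to `calculate_fid_given_paths`.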

## Kernel Inception Distance (KID)

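#### KID measures the distance between the distributions of Inception features of generated and real images using the squared Maximum Mean Discrepancy (MMD) with a cubic polynomial kernel (scikit-learn's `polynomial_kernel` with `degree=3`, `coef0=1` and, by default, `gamma` equal to one over the feature dimension), not a Gaussian kernel.

#### MMD compares two distributions through the average kernel similarity within each sample set and across the two sets; it is zero when the distributions are identical. Unlike FID, KID does not assume that the features are Gaussian, and it has an unbiased estimator.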
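#### Reading it off the code, for $m$ generated features $x_i$ and $m$ real features $y_i$ of dimension $d$, the `'unbiased'` estimate computed by `_mmd2_and_variance` in the full-dataset KID code (and averaged over random subsets by `polynomial_mmd_averages`) is:

\begin{equation}
        \widehat{\mathrm{MMD}}^2_u = \frac{1}{m(m-1)}\sum_{i \neq j} k(x_i, x_j) + \frac{1}{m(m-1)}\sum_{i \neq j} k(y_i, y_j) - \frac{2}{m^2}\sum_{i=1}^{m}\sum_{j=1}^{m} k(x_i, y_j), \qquad k(x, y) = \left(\frac{1}{d}x^\top y + 1\right)^3
\end{equation}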

#### The per-epoch KID implementation, computed on batches of activations, is as follows:


In [None]:
import sys

from sklearn.metrics.pairwise import polynomial_kernel
from tqdm import tqdm

def epoch_polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1, var_at_m=None, ret_var=True):
    '''Compute the biased polynomial-kernel MMD^2 (Maximum Mean Discrepancy); no variance is estimated'''
    # Compute the polynomial kernel for generated vs generated, real vs real, and generated vs real
    K_XX = polynomial_kernel(codes_g, codes_g, degree=degree, gamma=gamma, coef0=coef0)
    K_YY = polynomial_kernel(codes_r, codes_r, degree=degree, gamma=gamma, coef0=coef0)
    K_XY = polynomial_kernel(codes_g, codes_r, degree=degree, gamma=gamma, coef0=coef0)

    # Biased MMD^2: the diagonal terms K_XX[i, i] and K_YY[i, i] are included in the means
    mmd2 = np.mean(K_XX) + np.mean(K_YY) - 2 * np.mean(K_XY)

    if not ret_var:
        return mmd2

    # Placeholder variance: NaN (not None) so that the caller can store it in a float array
    return mmd2, np.nan

def epoch_polynomial_mmd_averages(codes_g, codes_r, n_subsets=10, subset_size=1000, ret_var=True, output=sys.stdout, **kernel_args):
    '''Compute the polynomial MMD averages over multiple subsets'''
    # Adjust subset size if it's larger than the number of available codes
    actual_subset_size = min(subset_size, len(codes_g), len(codes_r))

    m = min(len(codes_g), len(codes_r))
    mmds = np.zeros(n_subsets)
    vars = np.zeros(n_subsets) if ret_var else None
    choice = np.random.choice

    # Iterate over the number of subsets and compute MMD for each subset
    with tqdm(range(n_subsets), desc='MMD', file=output) as bar:
        for i in bar:
            if actual_subset_size < subset_size:
                # If the actual subset size is smaller than desired, allow replacement
                g = codes_g[choice(len(codes_g), actual_subset_size, replace=True)]
                r = codes_r[choice(len(codes_r), actual_subset_size, replace=True)]
            else:
                g = codes_g[choice(len(codes_g), actual_subset_size, replace=False)]
                r = codes_r[choice(len(codes_r), actual_subset_size, replace=False)]
            o = epoch_polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var)
            if ret_var:
                mmds[i], vars[i] = o
            else:
                mmds[i] = o
            bar.set_postfix({'mean': mmds[:i+1].mean()})
    return (mmds, vars) if ret_var else mmds

def epoch_calculate_kid_given_activations(activations_real, activations_fake):
    '''Compute KID (Kernel Inception Distance) given activations'''
    return epoch_polynomial_mmd_averages(activations_real, activations_fake, n_subsets=10)

def epoch_calculate_activations(images, model, cuda=False):
    '''Compute activations of images using the model'''
    model.eval()
    batch_size = images.size(0)
    if cuda:
        images = images.cuda()
        model.cuda()
    with torch.no_grad():
        pred = model(images)
        pred = pred[0]
        # Check if the output is 4D (batch, channels, height, width)
        if pred.dim() == 4:
            pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
        pred = pred.view(batch_size, -1)
    # Return as numpy array for consistency with KID function
    return pred.cpu().numpy()  

def epoch_compute_mmd_simple(codes_g, codes_r, degree=3, gamma=None, coef0=1):
    '''Compute MMD (Maximum Mean Discrepancy) using a simple polynomial kernel'''
    if gamma is None:
        # Default gamma is 1/number of features
        gamma = 1.0 / codes_g.shape[1]  

    # Compute the polynomial kernel for generated vs generated, real vs real, and generated vs real
    K_XX = polynomial_kernel(codes_g, codes_g, degree=degree, gamma=gamma, coef0=coef0)
    K_YY = polynomial_kernel(codes_r, codes_r, degree=degree, gamma=gamma, coef0=coef0)
    K_XY = polynomial_kernel(codes_g, codes_r, degree=degree, gamma=gamma, coef0=coef0)

    # Calculate the MMD value
    mmd2 = np.mean(K_XX) + np.mean(K_YY) - 2 * np.mean(K_XY)
    return mmd2
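#### The per-epoch helpers average the full kernel matrices, so the diagonal terms $k(x_i, x_i)$ enter the estimate and it stays positive even when both sets come from the same distribution. The NumPy sketch below compares that biased estimate with the unbiased one used for the full-dataset KID. The features are synthetic, and `cubic_kernel`, `mmd2_biased` and `mmd2_unbiased` are hypothetical helpers written to mirror the defaults of scikit-learn's `polynomial_kernel(degree=3, gamma=None, coef0=1)`.

```python
import numpy as np


def cubic_kernel(a, b):
    """k(x, y) = (x . y / d + 1)^3 for all row pairs of a and b."""
    d = a.shape[1]
    return (a @ b.T / d + 1.0) ** 3


def mmd2_biased(x, y):
    """Mean of the full kernel matrices, diagonal included (as in the per-epoch helpers)."""
    return cubic_kernel(x, x).mean() + cubic_kernel(y, y).mean() - 2.0 * cubic_kernel(x, y).mean()


def mmd2_unbiased(x, y):
    """Unbiased MMD^2: drops k(x_i, x_i) and k(y_i, y_i) from the within-set sums."""
    m = x.shape[0]
    k_xx, k_yy, k_xy = cubic_kernel(x, x), cubic_kernel(y, y), cubic_kernel(x, y)
    within = (k_xx.sum() - np.trace(k_xx)) + (k_yy.sum() - np.trace(k_yy))
    return within / (m * (m - 1)) - 2.0 * k_xy.sum() / (m * m)


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 16))                 # "generated" features
y_same = rng.normal(size=(200, 16))            # same distribution as x
y_shift = rng.normal(loc=1.0, size=(200, 16))  # shifted distribution

print(mmd2_biased(x, y_same), mmd2_unbiased(x, y_same))  # the biased value is inflated by the diagonal
print(mmd2_unbiased(x, y_shift))                          # clearly positive for different distributions
```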

#### Computing KID on the entire dataset after training:

In [None]:
# Computing KID on the entire dataset

def get_activations(files, model, batch_size=1, dims=2048,
                    cuda=False, verbose=False):
    """Calculates the activations of the pool_3 layer for all images."""
    # Set the model to evaluation mode
    model.eval()
    # Determine if the input is numpy arrays
    is_numpy = isinstance(files[0], np.ndarray)

    # Check if the number of images is a multiple of the batch size
    if len(files) % batch_size != 0:
        print(('Warning: number of images is not a multiple of the '
               'batch size. Some samples are going to be ignored.'))
    # Adjust batch size if it is larger than the number of images
    if batch_size > len(files):
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = len(files)

    # Calculate the number of batches and the number of images used
    n_batches = len(files) // batch_size
    n_used_imgs = n_batches * batch_size

    # Initialize an array to store the activations
    pred_arr = np.empty((n_used_imgs, dims))

    # Loop through each batch
    for i in tqdm(range(n_batches)):
        if verbose:
            print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)
        start = i * batch_size
        end = start + batch_size

        # Preprocess images based on their format
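        if is_numpy:
            # Rescale from [-1, 1] to [0, 1] and convert to a tensor (assumes float arrays in NCHW layout)
            images = (np.copy(files[start:end]) + 1) / 2.
            images = torch.from_numpy(images)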
        else:
            images = [np.array(Image.open(str(f))) for f in files[start:end]]
            images = np.stack(images).astype(np.float32) / 255.
            images = torch.from_numpy(images)
            if len(images.shape) == 3:
                images = torch.unsqueeze(images, dim=-1).expand(-1, -1, -1, 3)
            images = images.permute((0, 3, 1, 2))

        batch = images.float()
        if cuda:
            batch = batch.cuda()

        # Get the model predictions
        pred = model(batch)[0]

        # Apply adaptive average pooling if the output dimensions are not 1x1
        if pred.shape[2] != 1 or pred.shape[3] != 1:
            pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

        # Store the activations in the array
        pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)

    if verbose:
        print('done', torch.min(images))

    return pred_arr

def extract_lenet_features(imgs, net):
    """Extract features using a LeNet model."""
    net.eval()
    feats = []
    imgs = imgs.reshape([-1, 100] + list(imgs.shape[1:]))
    if imgs[0].min() < -0.001:
        imgs = (imgs + 1) / 2.0
    print(imgs.shape, imgs.min(), imgs.max())
    imgs = torch.from_numpy(imgs)
    for i, images in enumerate(imgs):
        feats.append(net.extract_features(images).detach().cpu().numpy())
    feats = np.vstack(feats)
    return feats

def _compute_activations(path, model, batch_size, dims, cuda, model_type):
    """Compute activations for the given path using the specified model."""
    if not isinstance(path, np.ndarray):
        import glob
        jpg = os.path.join(path, '*.jpg')
        png = os.path.join(path, '*.png')
        path = glob.glob(jpg) + glob.glob(png)
        if len(path) > 50000:
            import random
            random.shuffle(path)
            path = path[:50000]
    if model_type == 'inception':
        act = get_activations(path, model, batch_size, dims, cuda)
    elif model_type == 'lenet':
        act = extract_lenet_features(path, model)
    return act

def calculate_kid_given_paths(paths, batch_size, cuda, dims, model_type='inception'):
    """Calculates the KID of two paths"""
    pths = []
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)
        if os.path.isdir(p):
            pths.append(p)
        elif p.endswith('.npy'):
            np_imgs = np.load(p)
            if np_imgs.shape[0] > 50000: np_imgs = np_imgs[np.random.permutation(np.arange(np_imgs.shape[0]))][:50000]
            pths.append(np_imgs)

    if model_type == 'inception':
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        model = InceptionV3([block_idx])
    elif model_type == 'lenet':
        model = LeNet5()
        model.load_state_dict(torch.load('./models/lenet.pth'))
    if cuda:
        model.cuda()

    act_true = _compute_activations(pths[0], model, batch_size, dims, cuda, model_type)
    pths = pths[1:]
    results = []
    for j, pth in enumerate(pths):
        print(paths[j + 1])
        actj = _compute_activations(pth, model, batch_size, dims, cuda, model_type)
        kid_values = polynomial_mmd_averages(act_true, actj, n_subsets=100)
        results.append((paths[j + 1], kid_values[0].mean(), kid_values[0].std()))
    return results

def _sqn(arr):
    """Square norm of the flattened array."""
    flat = np.ravel(arr)
    return flat.dot(flat)

def polynomial_mmd_averages(codes_g, codes_r, n_subsets=50, subset_size=1000,
                            ret_var=True, output=sys.stdout, **kernel_args):
    """Compute MMD averages over several subsets."""
    m = min(codes_g.shape[0], codes_r.shape[0])
    subset_size = min(subset_size, m)  # Ensure subset_size is not larger than available samples
    mmds = np.zeros(n_subsets)
    if ret_var:
        vars = np.zeros(n_subsets)
    choice = np.random.choice

    with tqdm(range(n_subsets), desc='MMD', file=output) as bar:
        for i in bar:
            g = codes_g[choice(len(codes_g), subset_size, replace=False)]
            r = codes_r[choice(len(codes_r), subset_size, replace=False)]
            o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var)
            if ret_var:
                mmds[i], vars[i] = o
            else:
                mmds[i] = o
            bar.set_postfix({'mean': mmds[:i+1].mean()})
    return (mmds, vars) if ret_var else mmds

def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1,
                   var_at_m=None, ret_var=True):
    """Compute polynomial MMD between two sets of codes."""
    X = codes_g
    Y = codes_r

    K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0)
    K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0)
    K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)

    return _mmd2_and_variance(K_XX, K_XY, K_YY,
                              var_at_m=var_at_m, ret_var=ret_var)

def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False,
                       mmd_est='unbiased', block_size=1024,
                       var_at_m=None, ret_var=True):
    """Calculate the MMD2 (Maximum Mean Discrepancy) and its variance."""
    
    m = K_XX.shape[0]  # Number of samples

    # Ensure that the kernel matrices have the correct shapes
    assert K_XX.shape == (m, m)
    assert K_XY.shape == (m, m)
    assert K_YY.shape == (m, m)

    # If var_at_m is not provided, set it to m
    if var_at_m is None:
        var_at_m = m

    # Initialize diagonal and sum variables based on whether unit diagonal is used
    if unit_diagonal:
        diag_X = diag_Y = 1
        sum_diag_X = sum_diag_Y = m
        sum_diag2_X = sum_diag2_Y = m
    else:
        diag_X = np.diagonal(K_XX)
        diag_Y = np.diagonal(K_YY)
        sum_diag_X = diag_X.sum()
        sum_diag_Y = diag_Y.sum()
        sum_diag2_X = _sqn(diag_X)
        sum_diag2_Y = _sqn(diag_Y)

    # Calculate the sum of the kernel matrices excluding the diagonal
    Kt_XX_sums = K_XX.sum(axis=1) - diag_X
    Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
    K_XY_sums_0 = K_XY.sum(axis=0)
    K_XY_sums_1 = K_XY.sum(axis=1)

    Kt_XX_sum = Kt_XX_sums.sum()
    Kt_YY_sum = Kt_YY_sums.sum()
    K_XY_sum = K_XY_sums_0.sum()

    # Calculate MMD2 based on the chosen estimation method
    if mmd_est == 'biased':
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
                + (Kt_YY_sum + sum_diag_Y) / (m * m)
                - 2 * K_XY_sum / (m * m))
    else:
        assert mmd_est in {'unbiased', 'u-statistic'}
        mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1))
        if mmd_est == 'unbiased':
            mmd2 -= 2 * K_XY_sum / (m * m)
        else:
            mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m - 1))

    # Return MMD2 if variance is not required
    if not ret_var:
        return mmd2

    # Calculate terms needed for variance estimation
    Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X
    Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y
    K_XY_2_sum = _sqn(K_XY)
    dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1)
    dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0)
    
    if m <= 2:
        # Return zero variance if m is less than or equal to 2
        return mmd2, 0
    
    m1 = m - 1
    m2 = m - 2

    # Estimate zeta1 and zeta2 for variance calculation
    zeta1_est = (
        1 / (m * m1 * m2) * (
            _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum)
        - 1 / (m * m1) ** 2 * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2)
        + 1 / (m * m * m1) * (
            _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum)
        - 2 / m ** 4 * K_XY_sum ** 2
        - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
        + 2 / (m ** 3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
    )
    zeta2_est = (
        1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum)
        - 1 / (m * m1) ** 2 * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2)
        + 2 / (m * m) * K_XY_2_sum
        - 2 / m ** 4 * K_XY_sum ** 2
        - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
        + 4 / (m ** 3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
    )
    var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est
               + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est)

    return mmd2, var_est  # Return MMD2 and its variance

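#### The function above works on precomputed kernel matrices (`K_XX`, `K_XY`, `K_YY`). The standard KID definition uses the cubic polynomial kernel $k(x, y) = (x^\top y / d + 1)^3$ on Inception features, where $d$ is the feature dimension. The NumPy sketch below builds those matrices from raw feature arrays and computes the same quantity as the `'unbiased'` branch above. The names `polynomial_kernel` and `kid_unbiased` are hypothetical helpers, not functions from this notebook.

```python
import numpy as np

def polynomial_kernel(X, Y, degree=3, gamma=None, coef0=1.0):
    """k(x, y) = (gamma * <x, y> + coef0) ** degree, with gamma = 1 / dim by default (KID setting)."""
    if gamma is None:
        gamma = 1.0 / X.shape[1]
    return (gamma * X @ Y.T + coef0) ** degree

def kid_unbiased(feats_real, feats_fake):
    """Unbiased MMD^2 between two equally sized (m, d) feature sets."""
    m = feats_real.shape[0]
    assert feats_fake.shape[0] == m and m > 1
    K_XX = polynomial_kernel(feats_real, feats_real)
    K_YY = polynomial_kernel(feats_fake, feats_fake)
    K_XY = polynomial_kernel(feats_real, feats_fake)
    # Exclude self-similarities (the diagonals) from the within-set sums
    sum_XX = K_XX.sum() - np.trace(K_XX)
    sum_YY = K_YY.sum() - np.trace(K_YY)
    return (sum_XX + sum_YY) / (m * (m - 1)) - 2 * K_XY.sum() / (m * m)
```

#### Because the estimator is unbiased, it can be slightly negative when both sets come from the same distribution.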
## Training SB Model 

#### The training loop is defined as follows:

In [None]:
import time  # used below to time epochs and optimization steps

if __name__ == '__main__':

    # Model 
    sb_model_compl = SBModel().to(device)
    
    total_iters = 0      
    optimize_time = 0.1
    epoch_count = 1
    n_epochs = 90
    n_epochs_decay = 90
    print_freq = 100 
    gpu_ids = [1]
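    output_dir = "/kaggle/working/generated_images"
    os.makedirs(output_dir, exist_ok=True)  # make sure the folder exists before save_image() writes to it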
    
    # Lists for Losses, FID and KID metrics 
    losses_list = []
    fid_list = []
    kid_list = []
    
    sb_model_compl.train()
    
    # Training 
    times = []
    for epoch in range(epoch_count, n_epochs + n_epochs_decay +1):    
        epoch_start_time = time.time()  
        iter_data_time = time.time()    
        epoch_iter = 0     # the number of training iterations in current epoch, reset to 0 every epoch
         
        
        for i, ((dataA, dataB), (dataA2, dataB2)) in enumerate(zip(zip(train_dataloaderA, train_dataloaderB), zip(train_dataloaderA, train_dataloaderB))):  
            dataA = dataA.to(device)
            dataB = dataB.to(device)
            dataA2 = dataA2.to(device)
            dataB2 = dataB2.to(device)
            
            iter_start_time = time.time()  
            if total_iters % print_freq == 0:
                t_data = iter_start_time - iter_data_time

            batch_size = 1
            total_iters += batch_size
            epoch_iter += batch_size
            if len(gpu_ids) > 0:
                torch.cuda.synchronize()
            optimize_start_time = time.time()
            if epoch == epoch_count and i == 0:
                sb_model_compl.data_dependent_initialize(dataA,dataB, dataA2, dataB2)
            sb_model_compl.set_input(dataA,dataB, dataA2, dataB2)  # unpack data from dataset and apply preprocessing
            sb_model_compl.optimize_parameters()   # calculate loss functions, get gradients, update network weights
            
            torch.cuda.synchronize()
            optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time
            
            if total_iters % print_freq == 0:    # print training losses and save logging information 
                losses = sb_model_compl.get_current_losses()
                print(losses)
                
                

            iter_data_time = time.time()

        
        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, n_epochs + n_epochs_decay, time.time() - epoch_start_time))
         
        # Visualize and save the generated fake images at the end of each epoch
        sb_model_compl.forward()  
        fake_B_images = sb_model_compl.fake_B
        fake_B2_images = sb_model_compl.fake_B2
        real_A_images = sb_model_compl.real_A
        real_B_images = sb_model_compl.real_B
        visualize_images(fake_B_images, title=f"Generated Zebras at Epoch {epoch}")
        visualize_images(fake_B2_images, title=f"Generated Zebras 2 at Epoch {epoch}")
        
        # Save the generated images
        save_image(fake_B_images, os.path.join(output_dir, f"generated_zebras_epoch{epoch}.png"))
        save_image(fake_B2_images, os.path.join(output_dir, f"generated_zebras_2_epoch{epoch}.png"))
        
        # Compute FID 
        fretchet_dist= epoch_calculate_fretchet(dataB,sb_model_compl.fake_B.to(device),inception_v3)
        print(f'Epoch {epoch}: FID:', fretchet_dist)
        
        # Compute activations for KID per epoch 
        activations_real = epoch_calculate_activations(dataB, inception_v3, cuda=True)
        activations_fake = epoch_calculate_activations(sb_model_compl.fake_B.to(device), inception_v3, cuda=True)

        # Calculate KID
        kid_value = epoch_compute_mmd_simple(activations_real, activations_fake)
        print(f'Epoch {epoch}: KID: {kid_value}')
        
        losses_list.append(sb_model_compl.get_current_losses())  # losses from the last iteration of the epoch
        fid_list.append(fretchet_dist)
        kid_list.append(kid_value)

#### Computing FID and KID on the entire dataset after training:

In [None]:
DIMS = 2048
PATHS = ["/kaggle/input/horse2zebra-new/horse2zebra/small_trainB", "/kaggle/working/generated_images"]  
NUM_WORKERS = 1
MODEL_TYPE = 'inception'

fid_value = calculate_fid_given_paths(PATHS, BATCH_SIZE, device, DIMS, NUM_WORKERS)
print("FID: ", fid_value)

results = calculate_kid_given_paths(PATHS, BATCH_SIZE, device, DIMS, model_type= MODEL_TYPE)
for p, m, s in results:
    print('KID (%s): %.3f (%.3f)' % (p, m, s))

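#### FID compares Gaussians fitted to the Inception activations of real and generated images: $\mathrm{FID} = \lVert \mu_r - \mu_g \rVert^2 + \mathrm{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\big)$. The NumPy sketch below implements this formula. Reference implementations such as pytorch-fid compute the matrix square root with `scipy.linalg.sqrtm`. This sketch instead takes the trace term from the eigenvalues of $\Sigma_r \Sigma_g$, which are real and non-negative for covariance matrices. `frechet_distance` and `activation_stats` are hypothetical helpers, not the internals of `calculate_fid_given_paths`.

```python
import numpy as np

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})."""
    diff = mu1 - mu2
    # Tr((sigma1 sigma2)^{1/2}) = sum of square roots of the eigenvalues of sigma1 @ sigma2;
    # clip tiny negative values caused by floating-point round-off
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)

def activation_stats(acts):
    """Mean and covariance of an (N, D) activation matrix; requires N >= 2."""
    return acts.mean(axis=0), np.cov(acts, rowvar=False)
```

#### The covariance estimate needs at least two samples. With a single image, `np.cov` cannot produce a meaningful covariance, so dataset-level scores like the ones above are more reliable than per-image ones.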
#### Plotting losses, KID and FID for training: 

In [None]:
# Define the range of epochs based on the length of losses_list
epochs = range(1, len(losses_list) + 1)

# Collect each loss term across epochs, using the keys of the first loss dictionary
metrics = {key: [] for key in losses_list[0]} 

# Populate the lists with values for each epoch
for loss_dict in losses_list:
    for key, value in loss_dict.items():
        metrics[key].append(value)

# Create a plot figure
plt.figure(figsize=(12, 8))

# Plot each metric stored in the metrics dictionary
for metric, values in metrics.items():
    plt.plot(epochs, values, label=metric)

plt.title('Training Loss Metrics over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss Value')
plt.legend()
plt.grid(True)
plt.show()

# Plot FID
plt.figure(figsize=(10, 5))
plt.plot(epochs, fid_list, label='FID')
plt.xlabel('Epochs')
plt.ylabel('FID Value')
plt.title('Fréchet Inception Distance (FID)')
plt.legend()
plt.grid(True)
plt.show()

# Plot KID
plt.figure(figsize=(10, 5))
plt.plot(epochs, kid_list, label='KID')
plt.xlabel('Epochs')
plt.ylabel('KID Value')
plt.title('Kernel Inception Distance (KID)')
plt.legend()
plt.grid(True)
plt.show()

#### Saving the model:

In [None]:
# saving the model
torch.save(sb_model_compl.state_dict(), "/kaggle/working/our_pretrained_sb_model.pth")

## SB Test 

#### For the testing phase, we define the same model, adapting its functions for inference:

In [None]:
class SBModel_test(nn.Module):
    """ Test-time version of SBModel: sets up the visual names, the number of sampling steps and the generator """
    def __init__(self):
        super(SBModel_test,self).__init__()
        self.visual_names = ['real']
        self.T = 5
        for NFE in range(self.T):
            fake_name = 'fake_' + str(NFE+1)
            self.visual_names.append(fake_name)
                
        self.tau = 0.1    # note: forward() uses its own local tau = 0.01, not this attribute
        self.device = device
        self.netG = gen   # generator whose weights are loaded from the saved checkpoint
        self.ngf = 64     # base number of generator filters; the latent noise in forward() has size 4 * ngf
        self.lr = 0.00001 # optimizer settings below are unused at test time
        self.beta1 = 0.5
        self.beta2 = 0.999
        
        
    def data_dependent_initialize(self, dataA,dataB): 
        """ Initializes the model using input data. It sets the input, trims the batch size, and performs a forward pass 
        to prepare the model for evaluation """
        
        bs = 1
        self.set_input(dataA,dataB)
        self.real_A = self.real_A[:bs]
        self.real_B = self.real_B[:bs]
        self.forward()   
    
    def set_input(self, dataA, dataB):
        """ Responsible for unpacking input data from the dataloader and performing any necessary preprocessing steps """
        self.real_A = dataA.to(device)
        self.real_B = dataB.to(device)
    
    
    def forward(self):
        """ It processes the input tensor through a series of timesteps, refining it iteratively by 
        blending it with noise and the generator's output at each step. The time steps are predefined and normalized, 
        and a random time index is selected to determine which time step to use for each batch element. The generator 
        is used in evaluation mode to produce outputs based on the current input tensor, the time step, and some random noise. """
        tau = 0.01
        T = 5
        incs = np.array([0] + [1/(i+1) for i in range(T-1)])
        times = np.cumsum(incs)
        times = times / times[-1]
        times = 0.5 * times[-1] + 0.5 * times
        times = torch.tensor(times).float().cuda()
        self.times = times
        bs =  1
        time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[bs]).cuda()).long()
        self.time_idx = time_idx
        self.timestep     = times[time_idx]
        with torch.no_grad():
            self.netG.eval()
            for t in range(T):
                if t > 0:
                    delta = times[t] - times[t-1]
                    denom = times[-1] - times[t-1]
                    inter = (delta / denom)
                    scale = (delta * (1 - delta / denom))
                Xt       = self.real_A if (t == 0) else (1-inter) * Xt + inter * Xt_1.detach() + (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A.device)
                time_idx = (t * torch.ones(size=[self.real_A.shape[0]]).to(self.real_A.device)).long()
                time     = times[time_idx]
                z        = torch.randn(size=[self.real_A.shape[0],4*self.ngf]).to(self.real_A.device)
                Xt_1     = self.netG(Xt, time_idx, z)

                self.Xt_1 = Xt_1

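#### The sampling loop in `forward()` uses a fixed time grid and a stochastic interpolation between the current state `Xt` and the generator's latest prediction `Xt_1`. The NumPy sketch below reproduces only that arithmetic, to make the schedule and its coefficients visible. The model itself is not involved: `generator(x, t)` is a hypothetical stand-in for `netG(Xt, time_idx, z)`.

```python
import numpy as np

def sb_time_schedule(T=5):
    """Time grid from forward(): harmonic increments, normalized to [0, 1], then mapped to [0.5, 1]."""
    incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])
    times = np.cumsum(incs)
    times = times / times[-1]
    return 0.5 * times[-1] + 0.5 * times

def bridge_coefficients(times):
    """Interpolation weight `inter` and noise scale `scale` for each step t > 0."""
    inter, scale = [], []
    for t in range(1, len(times)):
        delta = times[t] - times[t - 1]
        denom = times[-1] - times[t - 1]
        inter.append(delta / denom)
        scale.append(delta * (1 - delta / denom))
    return np.array(inter), np.array(scale)

def sample(x0, generator, T=5, tau=0.01, rng=None):
    """NumPy version of the loop; `generator` is a hypothetical callable (x, t) -> prediction."""
    rng = np.random.default_rng() if rng is None else rng
    inter, scale = bridge_coefficients(sb_time_schedule(T))
    x = x0
    pred = generator(x, 0)  # t = 0: start from the input image
    for t in range(1, T):
        noise = rng.standard_normal(x.shape)
        # Move part of the way towards the last prediction and add bridge noise
        x = (1 - inter[t - 1]) * x + inter[t - 1] * pred + np.sqrt(scale[t - 1] * tau) * noise
        pred = generator(x, t)
    return pred
```

#### For `T = 5` the grid is `[0.5, 0.74, 0.86, 0.94, 1.0]`, so it starts at 0.5 rather than 0. At the final step `inter = 1` and `scale = 0`, so the last state fed to the generator is exactly the previous prediction, with no added noise.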

#### Testing the model : 

In [None]:
# Create output images directories
results_dir = '/kaggle/working/results_dir'
os.makedirs(results_dir, exist_ok=True)

In [None]:
if __name__ == '__main__':
    # Initialize test parameters
    aspect_ratio = 1.0
    
    # Hard-code some parameters for the test
    num_threads = 0   # Load data in the main process (no worker threads)
    batch_size = 1    # Test code only supports batch_size = 1
    serial_batches = True  # Disable data shuffling
    no_flip = True    # No flip
    
    sb_model_test = SBModel_test().to(device)
    
    pretrained_dict = torch.load("/kaggle/working/our_pretrained_sb_model.pth")
    model_dict = sb_model_test.state_dict()
    
    fid_list_test = []
    kid_list_test = []

    # Filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
    # Update the new model's dict with the pretrained dict
    model_dict.update(pretrained_dict)  
    sb_model_test.load_state_dict(model_dict)
    sb_model_test.eval()
    
    for i, (dataA, dataB) in enumerate(zip(test_dataloaderA, test_dataloaderB)):
        dataA = dataA.to(device)
        dataB = dataB.to(device)
        if i == 0:
            sb_model_test.data_dependent_initialize(dataA, dataB)
            sb_model_test.eval()
            
        # Unpack data from data loader
        sb_model_test.set_input(dataA, dataB)  
        # Run inference
        sb_model_test.forward()  
        
        fake_B_images = sb_model_test.Xt_1
        visualize_images(fake_B_images.to(device), title="Generated Zebras")
        
        # Save the generated images if needed
        save_image(fake_B_images, os.path.join(results_dir, f"generated_zebras_{i}.png"))
        
        #Compute FID 
        fretchet_dist = epoch_calculate_fretchet(dataB, fake_B_images.to(device), inception_v3)
        print('FID:', fretchet_dist)
        
        # Compute activations
        activations_real = epoch_calculate_activations(dataB, inception_v3, cuda=True)
        activations_fake = epoch_calculate_activations(fake_B_images.to(device), inception_v3, cuda=True)
        
        # Calculate KID
        kid_value = epoch_compute_mmd_simple(activations_real, activations_fake)
        print('KID:', kid_value)
        
        fid_list_test.append(fretchet_dist)
        kid_list_test.append(kid_value)
        
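#### The checkpoint loading above keeps only the entries whose key exists in the test model with the same shape. It then merges them into the model's own state dict, so `load_state_dict` sees a complete, consistent dictionary. Entries that only the training model has are presumably dropped this way, for example discriminator weights if `SBModel` registers them. The sketch below shows the same filtering on plain dictionaries of NumPy arrays. `filter_compatible` and the example keys are hypothetical.

```python
import numpy as np

def filter_compatible(pretrained, target):
    """Keep entries whose key exists in `target` with an identical shape."""
    return {k: v for k, v in pretrained.items()
            if k in target and target[k].shape == v.shape}

# Hypothetical parameter dicts: the checkpoint also holds discriminator weights
# and one generator layer whose shape no longer matches the model
checkpoint = {"netG.w": np.ones((4, 4)), "netG.b": np.ones(8), "netD.w": np.ones((2, 2))}
model = {"netG.w": np.zeros((4, 4)), "netG.b": np.zeros(4)}

compatible = filter_compatible(checkpoint, model)
model.update(compatible)  # mismatched "netG.b" keeps its current values
```

#### Mismatched layers are skipped silently, so printing the keys that were dropped is a cheap sanity check after loading.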

#### Plotting FID and KID for testing:

In [None]:
# Plotting FID and KID
plt.figure(figsize=(12, 6))

# Plot FID
plt.subplot(1, 2, 1)
plt.plot(fid_list_test, label='FID')
plt.xlabel('Test Sample')
plt.ylabel('FID Score')
plt.title('FID over Test Samples')
plt.legend()

# Plot KID
plt.subplot(1, 2, 2)
plt.plot(kid_list_test, label='KID')
plt.xlabel('Test Sample')
plt.ylabel('KID Score')
plt.title('KID over Test Samples')
plt.legend()

# Show plots
plt.tight_layout()
plt.show()

#### To conclude, we calculate the FID and KID values on the entire test set after the testing phase:

In [None]:
# Define parameters here
PATHS = ["/kaggle/input/horse2zebra-new/horse2zebra/testB", "/kaggle/working/results_dir"]  # Replace with your actual paths

fid_value = calculate_fid_given_paths(PATHS, BATCH_SIZE, device, DIMS, NUM_WORKERS)
print("FID Test: ", fid_value)

kid_values = calculate_kid_given_paths(PATHS, BATCH_SIZE, device, DIMS, model_type = MODEL_TYPE)
for p, m, s in kid_values:
    print('KID Test (%s): %.3f (%.3f)' % (p, m, s))

## Extras: Further Experiments

### We conducted several experiments to determine which of the implemented architectures produced the best results (see the **experiments** folder in our **GitHub repository** for further details)

### 1. Feature Injection

#### During the re-implementation of the model, we tried to develop a generator that used Feature Injection.
#### This technique aims to enhance the network's flexibility for applications like image generation or style transfer by conditioning the transformations within the network on additional, contextually relevant data.
#### This process dynamically modifies the network’s behavior, which theoretically allows for richer and more context-sensitive outputs. 
#### However, a major issue was encountered with this approach, as indicated by unusually high loss values during training (G_GAN Loss: 3.2). This could be a result of the added complexity and potential instabilities introduced by direct feature modification, which might not be well-controlled by the current training regimen or data normalization strategies.
#### Visually, the generated images were unrealistic, as shown below:

<div style="text-align:center">
    <img src="images/feature_injection_1.png" alt="Image 1" width="300" style="margin-right: 20px">
    <img src="images/feature_injection_2.png" alt="Image 2" width="300" style="margin-left: 20px">
</div>
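#### The exact injection layer we tried is not reproduced here. One common way to condition feature maps on an extra vector is FiLM (feature-wise linear modulation). FiLM predicts a per-channel scale and shift from the conditioning input and applies them to every spatial position. The NumPy sketch below assumes this form. The projection matrices `W_gamma` and `W_beta` stand in for hypothetical learned weights.

```python
import numpy as np

def film_inject(features, cond, W_gamma, W_beta):
    """FiLM-style feature injection.
    features: (N, C, H, W) feature maps; cond: (N, K) conditioning vectors;
    W_gamma, W_beta: (K, C) projections (hypothetical learned weights)."""
    gamma = cond @ W_gamma  # (N, C) per-channel scale
    beta = cond @ W_beta    # (N, C) per-channel shift
    # Broadcast the per-channel parameters over the spatial dimensions
    return gamma[:, :, None, None] * features + beta[:, :, None, None]
```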

### 2. CycleGAN Training 
#### Furthermore, before building the final diffusion model, we tested our final generator and discriminator with CycleGAN training to better understand their behavior. Below are some of the resulting images:

<div style="display: flex; justify-content: center; align-items: center;">
  <figure style="margin: 0 20px; text-align: center;">
    <img src="images/cyclegan_1.png" width="300">
    <figcaption><strong>From horse to zebra</strong></figcaption>
  </figure>
  <figure style="margin: 0 20px; text-align: center;">
    <img src="images/cyclegan_2.png" width="300">
    <figcaption><strong>From zebra to horse</strong></figcaption>
  </figure>
</div>
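#### CycleGAN trains two generators, horse to zebra and zebra to horse, so that translating an image to the other domain and back reconstructs the input. The sketch below covers the generator-side terms of that objective in NumPy. It assumes the least-squares GAN loss and a cycle weight of 10, the defaults of the reference CycleGAN implementation. `G_AB` and `G_BA` are hypothetical callables standing in for the two generators.

```python
import numpy as np

def cycle_consistency_loss(real_A, real_B, G_AB, G_BA, lambda_cyc=10.0):
    """L1 cycle loss: A -> B -> A and B -> A -> B reconstructions should match the inputs."""
    rec_A = G_BA(G_AB(real_A))
    rec_B = G_AB(G_BA(real_B))
    return lambda_cyc * (np.abs(rec_A - real_A).mean() + np.abs(rec_B - real_B).mean())

def lsgan_generator_loss(d_fake):
    """Least-squares GAN generator loss: push D(fake) towards the 'real' label 1."""
    return np.mean((d_fake - 1.0) ** 2)
```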


### 3. Epochs and Reduced Dataset 

#### Last but not least, we note that we did not train our model on the full horse2zebra dataset. As specified at the beginning, the model was trained on Kaggle using a P100 GPU.
#### Due to limited computational resources, we opted to train the model on a smaller subset of the original dataset. Specifically, we ran several experiments using:

- #### **Training Datasets**: initially, 200 images for each of the horse and zebra categories in the training set; then 400 images each for trainA and trainB, for a total of 800 training images.
- #### **Testing Datasets**: 120 images for each category (horse and zebra).

#### We observed that increasing the number of images in the dataset also increases the computational time per epoch: with 200 images for trainA and trainB, each epoch took approximately 220 seconds, while with 400 images each epoch took about 440 seconds. Since a Kaggle notebook can run for at most 12 hours, we limited the number of epochs accordingly.
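#### The per-epoch timings translate directly into an upper bound on the number of epochs per session. A quick back-of-the-envelope check, ignoring evaluation and I/O overhead (`max_epochs` is a hypothetical helper):

```python
SESSION_SECONDS = 12 * 60 * 60  # 12-hour Kaggle session limit

def max_epochs(seconds_per_epoch, budget=SESSION_SECONDS):
    """Upper bound on full training epochs that fit in the time budget."""
    return budget // seconds_per_epoch

for n_images, sec in [(200, 220), (400, 440)]:
    print(f"{n_images} images per domain: ~{sec} s/epoch -> at most {max_epochs(sec)} epochs")
```

#### The bounds are 196 and 98 epochs respectively. The runs of 180 and 80 epochs reported below stay under them, leaving headroom for the per-epoch visualization and FID/KID evaluation.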

#### Although a larger dataset allows the model to converge in fewer epochs, and the original paper trained for 400 epochs on the entire dataset, we achieved good results with considerably fewer epochs.

#### The results we obtained are shown below:

- #### Training with 200 images for trainA and trainB for 180 epochs   

<div style="text-align:center">
    <img src="images/zebra1_results.png" alt="Zebra 1" width="30%" />
    <img src="images/zebra2_results.png" alt="Zebra 2" width="30%" />
    <img src="images/zebra1_200.png" alt="Zebra 3" width="30%" />
</div>

<div style="margin-left: 40px;">

#### Quantitative Results:

| Split    | FID    | KID    |
|----------|--------|--------|
| Train    | 195.77 | 0.135  |
| Test     | 120.06 | 0.089  |

</div>




#### Plotting FID, KID, and losses for training:

<div style="text-align:center">
    <img src="images/FID_train_200.png" alt="FID" width="40%" />
    <img src="images/KID_train_200.png" alt="KID" width="40%" />
    <img src="images/losses_200.png" alt="losses" width="40%" />
</div>

#### Plotting FID and KID for testing:

<div style="text-align:center">
    <img src="images/FID_test_200.png" alt="FID" width="30%" style="margin-right: 20px;" />
    <img src="images/KID_test_200.png" alt="KID" width="30%" style="margin-left: 20px;" />
</div>

- #### Training with 400 images for trainA and trainB for 80 epochs

<div style="text-align:center">
    <img src="images/zebra1_400.png" alt="Zebra 1" width="30%" />
    <img src="images/zebra2_400.png" alt="Zebra 2" width="30%" />
    <img src="images/zebra3_results.jpg" alt="Zebra 3" width="30%" />
</div>

<div style="margin-left: 40px;">

#### Quantitative Results:

| Split    | FID    | KID    |
|----------|--------|--------|
| Train    | 197.76 | 0.131  |
| Test     | 101.70 | 0.082  |

</div>

#### Plotting FID, KID, and losses for training:
<div style="text-align:center">
    <img src="images/FID__train_400.png" alt="FID" width="40%" />
    <img src="images/KID_train_400.png" alt="KID" width="40%" />
    <img src="images/losses_400.png" alt="losses" width="40%" />
</div>

#### Plotting FID and KID for testing:
<div style="text-align:center">
    <img src="images/FID_test_400.png" alt="FID" width="30%" style="margin-right: 20px;" />
    <img src="images/KID_test_400.png" alt="KID" width="30%" style="margin-left: 20px;" />
</div>
