## 1. Import the dependencies

In [None]:
# !pip install torch_fidelity

In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
import cv2
import matplotlib.pyplot as plt
import itertools
import random
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import torch.nn.functional as F
import pandas as pd
import torch_fidelity
from tkinter import Variable


In [None]:
# Seed Python's and NumPy's random number generators with a fixed value.
SEED = 42
print("Random Seed:", SEED)
random.seed(SEED)
np.random.seed(SEED)
# NOTE(review): PyTorch is re-seeded with its own initial seed rather than
# with SEED, so torch randomness is not tied to SEED. The value is printed
# (and later passed to save_checkpoint) so it can be recorded -- confirm this
# is intended.
torch_seed = torch.initial_seed()
torch.manual_seed(torch_seed)
print("Torch seed:" , torch_seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Legacy tensor-type alias: CUDA float tensors when a GPU is available,
# otherwise the default CPU tensor type.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

## 2. Define your model

In [None]:
class ChannelAttention(nn.Module):
    """Per-channel attention gate built from pooled channel descriptors.

    The input is reduced to one value per channel twice (average pooling and
    max pooling); both descriptors go through the same 1x1-conv bottleneck,
    are summed, and squashed by a sigmoid.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction_ratio: Divisor applied to ``in_channels`` to size the
            bottleneck of the shared 1x1-conv MLP.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        bottleneck = in_channels // reduction_ratio

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared bottleneck applied to both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid gates with one value per (sample, channel)."""
        from_avg = self.fc(self.avg_pool(x))
        from_max = self.fc(self.max_pool(x))
        return self.sigmoid(from_avg + from_max)

class SpatialAttention(nn.Module):
    """Per-location attention gate computed from channel statistics.

    The channel-wise mean and channel-wise max of the input are stacked into
    a 2-channel map, convolved down to a single channel, and passed through a
    sigmoid.

    Args:
        kernel_size: Size of the square convolution kernel; padding of
            ``kernel_size // 2`` is used.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(
            2, 1, kernel_size, padding=kernel_size // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a single-channel sigmoid gate map for ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv(stacked))

class CBAM(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        channels: Channel count of the feature map being refined.
        reduction_ratio: Forwarded to ``ChannelAttention``.
        kernel_size: Forwarded to ``SpatialAttention``.
    """

    def __init__(self, channels, reduction_ratio=16, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(channels, reduction_ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        """Scale ``x`` by its channel gates, then by its spatial gates."""
        channel_refined = x * self.channel_att(x)
        # The spatial gate is computed on the channel-refined features.
        return channel_refined * self.spatial_att(channel_refined)

class ImprovedResidualBlock(nn.Module):
    """Residual block with squeeze-excitation and optional CBAM attention.

    The residual branch is two reflection-padded 3x3 convolutions, each
    followed by instance normalisation, with a LeakyReLU between them.

    Args:
        in_features: Channel count (the block keeps it unchanged).
        use_cbam: When true, a ``CBAM`` module is applied after the
            squeeze-excitation step.
    """

    def __init__(self, in_features, use_cbam=True):
        super().__init__()

        stages = []
        for final_stage in (False, True):
            stages.append(nn.ReflectionPad2d(1))
            stages.append(
                nn.Conv2d(in_features, in_features, kernel_size=3, padding=0)
            )
            stages.append(nn.InstanceNorm2d(in_features))
            if not final_stage:
                # Only the first conv stage is followed by an activation.
                stages.append(nn.LeakyReLU(0.2, inplace=True))
        self.block = nn.Sequential(*stages)

        # Squeeze-excitation is always present; CBAM is optional.
        self.se = SqueezeExcitation(in_features)
        self.use_cbam = use_cbam
        if use_cbam:
            self.cbam = CBAM(in_features)

    def forward(self, x):
        """Return ``x`` plus the attention-refined residual branch."""
        branch = self.se(self.block(x))
        if self.use_cbam:
            branch = self.cbam(branch)
        return x + branch

class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel re-weighting.

    Each channel is averaged over its spatial extent, the resulting vector is
    passed through a two-layer bottleneck ending in a sigmoid, and the input
    is multiplied by the resulting per-channel gates.

    Args:
        channel: Number of input channels.
        reduction: Divisor applied to ``channel`` for the hidden width.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its learned per-channel gates."""
        batch, channels, _rows, _cols = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)

class EnhancedSelfAttention(nn.Module):
    """Self-attention over spatial positions with a learned positional bias.

    Queries, keys and values are produced by 1x1 convolutions. A learned
    bias map (one value per spatial position) is added to the attention
    logits of every query, so that some key positions can be favoured
    regardless of content. The attended features are blended into the input
    through the learnable scalar ``gamma``, which starts at zero.

    Args:
        in_dim: Number of input channels. Queries and keys use
            ``in_dim // 8`` channels.
    """

    def __init__(self, in_dim):
        super().__init__()
        # Standard self-attention projections.
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)

        # One bias value per spatial position, stored at 64x64 and resized
        # on the fly when the feature map has a different size.
        self.pos_bias = nn.Parameter(torch.zeros(1, 1, 64, 64))
        nn.init.normal_(self.pos_bias, 0, 0.02)

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        batch_size, C, height, width = x.size()
        num_positions = height * width

        # Resize the stored bias map to the current feature-map size.
        if (height, width) != tuple(self.pos_bias.shape[2:]):
            pos_bias = F.interpolate(
                self.pos_bias,
                size=(height, width),
                mode='bilinear',
                align_corners=False,
            )
        else:
            pos_bias = self.pos_bias

        # Project to query (B, N, C'), key (B, C', N) and value (B, C, N).
        query = self.query_conv(x).view(batch_size, -1, num_positions).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, num_positions)
        value = self.value_conv(x).view(batch_size, -1, num_positions)

        # Attention logits: (B, N_query, N_key).
        attention = torch.bmm(query, key)
        # The bias map holds only N values, so it cannot be viewed as an
        # (N, N) matrix. It is applied per key position and broadcast over
        # the batch and over all queries.
        attention = attention + pos_bias.view(1, 1, num_positions)
        attention = self.softmax(attention)

        # Weighted sum of values for every query position.
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, height, width)

        # gamma starts at 0, so the block is initially an identity mapping.
        return self.gamma * out + x
    
class AdaptiveInstanceNorm(nn.Module):
    """Instance normalisation followed by an externally supplied affine map.

    Args:
        in_channel: Number of channels passed to the non-affine
            ``nn.InstanceNorm2d``.
    """

    def __init__(self, in_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel, affine=False)

    def forward(self, x, style_scale, style_bias):
        """Normalise ``x`` and return ``style_scale * norm(x) + style_bias``."""
        normalised = self.norm(x)
        modulated = style_scale * normalised + style_bias
        return modulated

class StyleModulatedResBlock(nn.Module):
    """Residual block whose normalisation uses learned scale and bias.

    Both ``AdaptiveInstanceNorm`` layers of the block are modulated by the
    same pair of learnable per-channel parameters (``style_scale`` and
    ``style_bias``). The branch ends with squeeze-excitation before being
    added back to the input.

    Args:
        in_features: Channel count (unchanged by the block).
    """

    def __init__(self, in_features):
        super().__init__()

        # Learnable modulation, shared by both normalisation layers.
        self.style_scale = nn.Parameter(torch.ones(1, in_features, 1, 1))
        self.style_bias = nn.Parameter(torch.zeros(1, in_features, 1, 1))

        def padded_conv():
            # Reflection padding keeps the spatial size with a 3x3 kernel.
            return nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, kernel_size=3, padding=0),
            )

        self.conv1 = padded_conv()
        self.adain1 = AdaptiveInstanceNorm(in_features)
        self.act1 = nn.LeakyReLU(0.2, inplace=True)

        self.conv2 = padded_conv()
        self.adain2 = AdaptiveInstanceNorm(in_features)

        self.se = SqueezeExcitation(in_features)

    def forward(self, x):
        """Return ``x`` plus the style-modulated, SE-gated residual branch."""
        hidden = self.adain1(self.conv1(x), self.style_scale, self.style_bias)
        hidden = self.act1(hidden)
        hidden = self.adain2(self.conv2(hidden), self.style_scale, self.style_bias)
        return x + self.se(hidden)

In [None]:
class Discriminator(nn.Module):
    """Discriminator with a shared trunk and two heads (patch and global).

    The trunk is four stride-2, spectrally normalised 4x4 convolutions with
    LeakyReLU activations (instance norm on all but the first). The patch
    head is one more spectrally normalised convolution producing a logit
    map; the global head average-pools the trunk output and maps it to a
    single logit with a linear layer.

    Args:
        in_channels: Channels of the input image.
        features_d: Width of the first convolution; later layers use 2x, 4x
            and 8x this value.
    """

    def __init__(self, in_channels=3, features_d=64):
        super().__init__()

        self.scale_factor = 16

        def sn_conv(c_in, c_out, stride):
            # 4x4 convolution wrapped in spectral normalisation.
            return nn.utils.spectral_norm(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1)
            )

        widths = [features_d, features_d * 2, features_d * 4, features_d * 8]

        # Shared trunk: the first stage has no normalisation layer.
        trunk = [sn_conv(in_channels, widths[0], 2), nn.LeakyReLU(0.2, inplace=True)]
        for c_in, c_out in zip(widths, widths[1:]):
            trunk.append(sn_conv(c_in, c_out, 2))
            trunk.append(nn.InstanceNorm2d(c_out))
            trunk.append(nn.LeakyReLU(0.2, inplace=True))
        self.feature_extraction = nn.Sequential(*trunk)

        # Patch head: one logit per receptive-field patch.
        self.patch_output = nn.Sequential(sn_conv(widths[-1], 1, 1))

        # Global head: one logit per image.
        self.global_output = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(widths[-1], 1),
        )

    def forward(self, x, get_features=False):
        """Return ``(patch_logits, global_logits)``.

        When ``get_features`` is true a third element is returned: a
        one-element list holding the trunk output, for feature matching.
        """
        features = self.feature_extraction(x)
        heads = (self.patch_output(features), self.global_output(features))
        if get_features:
            return heads[0], heads[1], [features]
        return heads

In [None]:
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout: a 7x7 stem to 64 channels, two stride-2 downsampling
    convolutions (each doubling the channels), ``num_residual_blocks``
    residual blocks, one ``EnhancedSelfAttention`` layer, two transposed
    convolutions (each halving the channels), and a 7x7 output convolution
    followed by ``tanh``.

    Args:
        in_channels: Channels of both the input and the generated image.
        num_residual_blocks: Number of residual blocks in the bottleneck.
    """

    def __init__(self, in_channels=3, num_residual_blocks=9):
        super().__init__()

        # Stem.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, kernel_size=7, padding=0),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: every stage halves the resolution, doubles channels.
        width = 64
        for _ in range(2):
            doubled = width * 2
            layers.extend([
                nn.Conv2d(width, doubled, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ])
            width = doubled

        # Bottleneck: blocks 0, 3, 6, ... are style-modulated; the others
        # are attention residual blocks, with CBAM enabled for indices 3..6.
        for block_idx in range(num_residual_blocks):
            if block_idx % 3 == 0:
                layers.append(StyleModulatedResBlock(width))
            else:
                layers.append(
                    ImprovedResidualBlock(width, use_cbam=3 <= block_idx <= 6)
                )

        # Global context after the residual stack.
        layers.append(EnhancedSelfAttention(width))

        # Upsampling: every stage doubles the resolution, halves channels.
        for _ in range(2):
            halved = width // 2
            layers.extend([
                nn.ConvTranspose2d(
                    width, halved, kernel_size=3, stride=2, padding=1,
                    output_padding=1,
                ),
                nn.InstanceNorm2d(halved),
                nn.ReLU(inplace=True),
            ])
            width = halved

        # Output head; tanh keeps the image in [-1, 1].
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, in_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Translate the image batch ``x``."""
        return self.model(x)

In [None]:
def relativistic_loss(real_pred, fake_pred, criterion):
    """Relativistic-average discriminator loss over patch and global heads.

    Instead of scoring real and fake predictions independently, each
    prediction is compared with the batch mean of the opposite class.

    Args:
        real_pred: ``(patch_logits, global_logits)`` for real images.
        fake_pred: ``(patch_logits, global_logits)`` for generated images.
        criterion: Loss applied to the patch-level differences. The global
            differences always use ``nn.BCEWithLogitsLoss``.

    Returns:
        Scalar tensor: the mean of the real and fake terms, each weighted
        0.7 (patch) / 0.3 (global).
    """
    real_patch, real_global = real_pred
    fake_patch, fake_global = fake_pred

    # Relativistic average for patch outputs.
    real_fake_diff_patch = real_patch - fake_patch.mean(0, keepdim=True)
    fake_real_diff_patch = fake_patch - real_patch.mean(0, keepdim=True)

    # Relativistic average for global outputs.
    real_fake_diff_global = real_global - fake_global.mean(0, keepdim=True)
    fake_real_diff_global = fake_global - real_global.mean(0, keepdim=True)

    # Patch-level loss: real should score above the average fake (target 1),
    # fake below the average real (target 0).
    loss_real_patch = criterion(real_fake_diff_patch, torch.ones_like(real_fake_diff_patch))
    loss_fake_patch = criterion(fake_real_diff_patch, torch.zeros_like(fake_real_diff_patch))

    # Global-level loss. Targets are created with *_like so they always share
    # the device, dtype and shape of the logits, instead of being placed
    # according to torch.cuda.is_available().
    bce = nn.BCEWithLogitsLoss()
    loss_real_global = bce(real_fake_diff_global, torch.ones_like(real_fake_diff_global))
    loss_fake_global = bce(fake_real_diff_global, torch.zeros_like(fake_real_diff_global))

    # Combine patch and global losses.
    loss_real = 0.7 * loss_real_patch + 0.3 * loss_real_global
    loss_fake = 0.7 * loss_fake_patch + 0.3 * loss_fake_global

    return (loss_real + loss_fake) / 2

def compute_gradient_penalty(discriminator, real_samples, fake_samples):
    """Gradient penalty on random interpolations of real and fake samples.

    For both discriminator heads the gradient of the (summed) output with
    respect to the interpolated input is computed, and the squared deviation
    of its per-sample L2 norm from 1 is penalised.

    Args:
        discriminator: Callable returning ``(patch_out, global_out)`` for an
            image batch.
        real_samples: Batch of real images, shape ``(B, C, H, W)``.
        fake_samples: Batch of generated images with the same shape.

    Returns:
        Scalar tensor ``0.7 * patch_penalty + 0.3 * global_penalty``; it
        stays attached to the autograd graph so it can be back-propagated.
    """
    batch_size = real_samples.size(0)

    # One mixing coefficient per sample. It is drawn with the default
    # generator and then moved to the device/dtype of the samples, rather
    # than being cast to a module-level legacy tensor type.
    alpha = torch.rand((batch_size, 1, 1, 1)).to(real_samples)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    patch_interpolates, global_interpolates = discriminator(interpolates)

    def penalty_for(outputs):
        # grad_outputs of ones sums every output element into the gradient.
        gradients = torch.autograd.grad(
            outputs=outputs,
            inputs=interpolates,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.reshape(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    patch_gradient_penalty = penalty_for(patch_interpolates)
    global_gradient_penalty = penalty_for(global_interpolates)

    # Combined penalty, weighted like the adversarial loss terms.
    return 0.7 * patch_gradient_penalty + 0.3 * global_gradient_penalty

def feature_matching_loss(real_features, fake_features):
    """
    Feature matching loss between two lists of feature maps.

    Args:
        real_features: List of feature tensors extracted from real images
        fake_features: List of feature tensors extracted from fake images

    Returns:
        Mean absolute difference per feature-map pair, averaged over the pairs
    """
    # The two lists are compared position by position.
    assert len(real_features) == len(fake_features), "Feature lists must have same length"

    total = 0.0
    for real_map, fake_map in zip(real_features, fake_features):
        # Each pair must agree in shape before it can be compared.
        assert real_map.shape == fake_map.shape, f"Feature shapes don't match: {real_map.shape} vs {fake_map.shape}"
        total = total + (real_map - fake_map).abs().mean()

    return total / len(real_features)

def fft_loss(real_image, fake_image, alpha=1.0):
    """
    Frequency-domain L1 loss that emphasises low-frequency differences.

    The 2-D FFT magnitude of both images is compared. The difference is
    weighted by a radial mask equal to 1 at the zero-frequency (DC) term and
    falling linearly to 0 at the farthest frequency.

    Args:
        real_image: Tensor whose last two dimensions are (height, width).
        fake_image: Tensor with the same shape as ``real_image``.
        alpha: Scalar multiplier applied to the loss.

    Returns:
        Scalar tensor: ``alpha`` times the mean absolute masked difference.
    """
    spatial_dims = (-2, -1)

    # fft2 places the DC term at index [0, 0]; fftshift moves it to the
    # centre so that it lines up with the centred radial mask below.
    # Without the shift the mask would weight the DC term by 0.
    real_magnitude = torch.fft.fftshift(
        torch.abs(torch.fft.fft2(real_image, dim=spatial_dims)), dim=spatial_dims
    )
    fake_magnitude = torch.fft.fftshift(
        torch.abs(torch.fft.fft2(fake_image, dim=spatial_dims)), dim=spatial_dims
    )

    height, width = real_image.shape[-2:]
    # Offsets from the centre bin (n // 2, where fftshift puts DC).
    rows = torch.arange(
        height, device=real_image.device, dtype=real_magnitude.dtype
    ).view(-1, 1) - height // 2
    cols = torch.arange(
        width, device=real_image.device, dtype=real_magnitude.dtype
    ).view(1, -1) - width // 2

    # Distance of every frequency bin from DC, normalised to [0, 1].
    dist = torch.sqrt(rows ** 2 + cols ** 2)
    # The clamp only matters for a 1x1 image, where the maximum is 0.
    dist = dist / dist.max().clamp(min=1.0)

    # (H, W) mask, broadcast over all leading dimensions.
    freq_mask = 1 - dist

    masked_diff = (real_magnitude - fake_magnitude) * freq_mask

    # L1 loss on the weighted magnitude difference.
    return alpha * torch.mean(torch.abs(masked_diff))

def edge_loss(real_images, fake_images, weight=1.0):
    """
    Edge preservation loss using Sobel filters.

    Each channel of both batches is filtered with horizontal and vertical
    Sobel kernels; the L1 distance between the resulting gradient-magnitude
    maps is returned.

    Args:
        real_images: Tensor of shape ``(B, C, H, W)``.
        fake_images: Tensor with the same shape as ``real_images``.
        weight: Scalar multiplier applied to the loss.

    Returns:
        Scalar tensor: ``weight`` times the mean absolute edge difference.
    """
    def sobel_filters(x):
        # The kernels follow the channel count, dtype and device of the
        # input instead of assuming 3-channel float32 images.
        channels = x.shape[1]
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=x.dtype, device=x.device
        )
        sobel_y = sobel_x.t()  # the vertical kernel is the transpose

        # One copy of the kernel per channel; groups=channels filters every
        # channel independently.
        kernel_x = sobel_x.reshape(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        kernel_y = sobel_y.reshape(1, 1, 3, 3).repeat(channels, 1, 1, 1)

        grad_x = F.conv2d(x, kernel_x, padding=1, groups=channels)
        grad_y = F.conv2d(x, kernel_y, padding=1, groups=channels)

        # The epsilon keeps the sqrt differentiable where the gradient is 0.
        return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)

    real_edges = sobel_filters(real_images)
    fake_edges = sobel_filters(fake_images)

    # L1 loss on the edge maps.
    return weight * F.l1_loss(real_edges, fake_edges)

def color_histogram_loss(real_images, fake_images, nbins=64):
    """
    Distance between the per-channel colour histograms of two image batches.

    Images are expected in [-1, 1] and are rescaled to [0, 1] before being
    binned. Each channel gets an ``nbins``-bin histogram normalised by the
    number of pixels; the histograms of all channels are concatenated and
    the mean absolute difference of their cumulative sums is returned.

    Note: the histograms are hard counts, so no gradient flows through them.

    Args:
        real_images: Tensor of shape ``(B, C, H, W)`` with values in [-1, 1].
        fake_images: Tensor with the same shape and value range.
        nbins: Number of equal-width bins per channel.

    Returns:
        Scalar tensor with the mean cumulative-histogram difference.
    """
    def histogram(x, nbins):
        batch_size, channels, height, width = x.shape

        # Scale to [0, 1] and flatten every (sample, channel) plane.
        flat = ((x + 1) / 2).reshape(batch_size * channels, -1)

        # Values outside [0, 1] are not counted. The upper edge 1.0 (a fully
        # saturated pixel) belongs to the last bin; the clamp below puts it
        # there instead of dropping it.
        in_range = (flat >= 0) & (flat <= 1)
        bin_index = (flat * nbins).floor().long().clamp(0, nbins - 1)

        # Count all planes at once instead of looping over bins in Python.
        hist = torch.zeros(batch_size * channels, nbins, device=x.device)
        hist.scatter_add_(1, bin_index, in_range.float())
        hist = hist / (height * width)

        # (B, C * nbins): the histograms of a sample, channel after channel.
        return hist.view(batch_size, channels * nbins)

    real_hist = histogram(real_images, nbins)
    fake_hist = histogram(fake_images, nbins)

    # Earth Mover's Distance approximation via cumulative histograms.
    diff = torch.abs(torch.cumsum(real_hist, dim=1) - torch.cumsum(fake_hist, dim=1))
    return torch.mean(diff)

# Build matching learning-rate schedulers for the generator and both discriminators
def get_lr_schedulers(optimizer_G, optimizer_D_A, optimizer_D_B, num_epochs):
    """Return ``(scheduler_G, scheduler_D_A, scheduler_D_B)``.

    All three are ``LambdaLR`` schedulers sharing one rule: the multiplier
    stays at 1.0 for the first ``num_epochs // 2`` epochs and then decreases
    linearly, reaching 0 after another ``num_epochs // 2`` epochs.
    """
    decay_span = num_epochs // 2

    def lambda_rule(epoch):
        # Number of epochs spent in the decay phase (never negative).
        epochs_into_decay = epoch - decay_span
        if epochs_into_decay < 0:
            epochs_into_decay = 0
        return 1.0 - epochs_into_decay / float(decay_span)

    return tuple(
        torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
        for optimizer in (optimizer_G, optimizer_D_A, optimizer_D_B)
    )

class AdaptiveLossWeights:
    """Adjust loss weights during training from a moving average of losses.

    Every tracked loss keeps a window of its most recent values. Once every
    window is full, each loss's share of the summed averages is compared with
    the equal share ``1 / number_of_losses``: a loss above 1.5x that share
    has its weight multiplied by 0.9, a loss below 0.5x has it multiplied by
    1.1.
    """

    def __init__(self, initial_weights):
        """
        Args:
            initial_weights: dict of initial weights for different loss
                components. It is copied, so the caller's dict is never
                modified by later updates.
        """
        self.weights = dict(initial_weights)
        self.loss_history = {name: [] for name in self.weights}
        self.window_size = 50  # Moving average window

    def update(self, current_losses):
        """Record ``current_losses`` and return the (possibly adjusted) weights.

        Args:
            current_losses: dict mapping loss names to their latest values.
                Names that are not tracked are ignored.

        Returns:
            The dict of current weights (the instance's own dict).
        """
        # Add current losses to history, keeping only the recent window.
        for loss_name, loss_value in current_losses.items():
            history = self.loss_history.get(loss_name)
            if history is None:
                continue
            history.append(loss_value)
            if len(history) > self.window_size:
                history.pop(0)

        # Only adjust once every tracked loss has a full window.
        if not all(len(hist) >= self.window_size for hist in self.loss_history.values()):
            return self.weights

        avg_losses = {
            name: sum(hist[-self.window_size:]) / self.window_size
            for name, hist in self.loss_history.items()
        }

        # With nothing tracked, or all averages at zero, the ratios are
        # undefined; leave the weights untouched instead of dividing by 0.
        total_loss = sum(avg_losses.values())
        if total_loss == 0:
            return self.weights

        # Adjust weights inversely to loss ratios (larger loss -> smaller weight).
        target_ratio = 1.0 / len(self.weights)
        for loss_name in self.weights:
            ratio = avg_losses[loss_name] / total_loss
            if ratio > target_ratio * 1.5:  # Loss is too dominant
                self.weights[loss_name] *= 0.9
            elif ratio < target_ratio * 0.5:  # Loss is too small
                self.weights[loss_name] *= 1.1

        return self.weights

In [None]:
class ImagePool:
    """History buffer of previously generated images.

    While the buffer is filling up, incoming images are stored and returned
    unchanged. Once it is full, each incoming image is, with probability
    0.5, swapped with a randomly chosen stored image (the stored one is
    returned); otherwise the incoming image is returned as is.
    """

    def __init__(self, pool_size=50):
        """Create an empty pool holding at most ``pool_size`` images."""
        self.pool_size = pool_size
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """
        Return a batch assembled from ``images`` and the stored history.

        A pool of size 0 is a pass-through and returns ``images`` itself.
        """
        if self.pool_size == 0:
            return images

        selected = []
        for sample in images:
            # Detached view with a leading batch dimension of 1.
            sample = sample.data.unsqueeze(0)

            if self.num_imgs < self.pool_size:
                # Pool not full yet: store the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                selected.append(sample)
                continue

            if random.uniform(0, 1) > 0.5:
                # Hand back a stored image and replace it with the new one.
                slot = random.randint(0, self.pool_size - 1)
                selected.append(self.images[slot].clone())
                self.images[slot] = sample
            else:
                selected.append(sample)

        return torch.cat(selected, 0)

In [None]:
def print_epoch_losses(epoch, gen_losses, gen_AB_losses, gen_BA_losses, id_losses, gan_losses, cycle_losses, fm_losses, discA_losses, discB_losses):
    """Print the per-epoch mean of every tracked loss list.

    Each ``*_losses`` argument is a non-empty sequence of numbers collected
    during the epoch. All means are computed before anything is printed.
    """
    def average(values):
        return sum(values) / len(values)

    (gen_avg, ab_avg, ba_avg, id_avg, gan_avg,
     cycle_avg, fm_avg, disc_a_avg, disc_b_avg) = [
        average(values)
        for values in (
            gen_losses, gen_AB_losses, gen_BA_losses, id_losses, gan_losses,
            cycle_losses, fm_losses, discA_losses, discB_losses,
        )
    ]

    report = [
        f"Epoch {epoch}:",
        f"  Generator Total Loss: {gen_avg:.4f}",
        f"    G_AB (A->B) Loss:  {ab_avg:.4f}",
        f"    G_BA (B->A) Loss:  {ba_avg:.4f}",
        f"    Identity Loss:     {id_avg:.4f}",
        f"    GAN Loss:          {gan_avg:.4f}",
        f"    Cycle Loss:        {cycle_avg:.4f}",
        f"    Feature Match:     {fm_avg:.4f}",
        f"  Discriminator A Loss: {disc_a_avg:.4f}",
        f"  Discriminator B Loss: {disc_b_avg:.4f}",
    ]
    for line in report:
        print(line)

def plot_losses(gen_loss_history, gen_AB_loss_history, gen_BA_loss_history, 
                id_loss_history, gan_loss_history, cycle_loss_history, fm_loss_history,
                discA_loss_history, discB_loss_history):
    """Plot the training loss histories in a 2x2 grid and show the figure.

    Every argument is a per-epoch sequence of loss values. The x axis is
    numbered from 1 to ``len(gen_loss_history)``.
    """
    epochs = range(1, len(gen_loss_history) + 1)

    def draw_panel(position, title, curves):
        # Each curve is (values, label, color, extra plot keyword arguments).
        plt.subplot(2, 2, position)
        for values, label, color, extra in curves:
            plt.plot(epochs, values, label=label, color=color, **extra)
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.title(title)
        plt.legend()
        plt.grid()

    solid = {}
    dashed = {"linestyle": "dashed"}

    plt.figure(figsize=(15, 10))

    # Panel 1: generator totals.
    draw_panel(1, "Generator Total Losses", [
        (gen_loss_history, "Total Generator Loss", "blue", solid),
        (gen_AB_loss_history, "G_AB Loss", "cyan", dashed),
        (gen_BA_loss_history, "G_BA Loss", "magenta", dashed),
    ])

    # Panel 2: individual generator loss components.
    draw_panel(2, "Loss Components", [
        (id_loss_history, "Identity Loss", "green", solid),
        (gan_loss_history, "GAN Loss", "red", solid),
        (cycle_loss_history, "Cycle Loss", "orange", solid),
        (fm_loss_history, "Feature Match Loss", "purple", solid),
    ])

    # Panel 3: both discriminators.
    draw_panel(3, "Discriminator Losses", [
        (discA_loss_history, "Discriminator A Loss", "teal", solid),
        (discB_loss_history, "Discriminator B Loss", "brown", solid),
    ])

    # Panel 4: overview of the key losses, with the discriminators averaged.
    mean_disc_history = [
        (a + b) / 2 for a, b in zip(discA_loss_history, discB_loss_history)
    ]
    draw_panel(4, "Key Losses", [
        (gen_loss_history, "Generator Loss", "blue", solid),
        (cycle_loss_history, "Cycle Loss", "orange", solid),
        (gan_loss_history, "GAN Loss", "red", solid),
        (mean_disc_history, "Avg Discriminator Loss", "purple", solid),
    ])

    plt.tight_layout()
    plt.show()

def show_images(real_A, fake_B, recov_A, real_B, fake_A, recov_B, size=(256, 256), num_samples=5):
    """Display a 2x3 figure of image grids for both translation directions.

    Top row: real A, fake B, reconstructed A. Bottom row: real B, fake A,
    reconstructed B. Only the first ``num_samples`` images of each batch are
    shown, laid out side by side. The ``size`` argument is accepted but not
    used by this function.
    """
    def to_grid(batch):
        # Map [-1, 1] back to [0, 1], tile the images in one row, and
        # convert to an HxWxC NumPy array for imshow.
        shown = (batch[:num_samples] * 0.5 + 0.5).clamp(0, 1)
        return make_grid(shown, nrow=num_samples).permute(1, 2, 0).cpu().numpy()

    # Panels in row-major order, matching the axes grid below.
    panels = [
        (real_A, "Real A (Domain A)"),
        (fake_B, "Fake B (Generated from A)"),
        (recov_A, "Reconstructed A"),
        (real_B, "Real B (Domain B)"),
        (fake_A, "Fake A (Generated from B)"),
        (recov_B, "Reconstructed B"),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    for axis, (batch, title) in zip(axes.flat, panels):
        axis.imshow(to_grid(batch))
        axis.set_title(title)
        axis.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Training configuration shared by the cells below.
image_size = (256, 256)
batch_size = 4
n_epochs = 100
lr = 2e-4
betas = (0.5, 0.999)  # presumably the Adam betas -- optimizer setup is not in view
# Loss-term weights; presumably consumed by the training loop, which is not
# visible here -- TODO confirm.
lambda_cyc = 10.0
lambda_identity = 5.0
lambda_gan = 1.0
lambda_fm = 5.0
lambda_gp = 10.0
# Capacities, presumably for the ImagePool history buffers of each domain.
pool_size_A = 50
pool_size_B = 50
num_residual_blocks = 9  # passed to Generator below
use_attention = True  # NOTE(review): not referenced by the visible code
features_d = 64  # passed to Discriminator below
file_name = "EnhancedCycleGAN-256x256"  # prefix of checkpoint file names

In [None]:
def get_face_specific_transforms(img_size=256):
    """
    Build the augmentation pipeline for face-to-cartoon translation.

    The image is resized 30 pixels larger than ``img_size``, randomly
    cropped back to ``img_size``, randomly flipped, mildly colour-jittered
    and affinely perturbed, then converted to a tensor normalised with mean
    0.5 and std 0.5 per channel.
    """
    enlarged = img_size + 30

    # Geometric basics: oversize, crop for position variety, mirror.
    geometry = [
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
    ]

    # Mild appearance / pose perturbations so facial structure is preserved.
    perturbations = [
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.RandomAffine(
            degrees=5,
            translate=(0.05, 0.05),
            scale=(0.95, 1.05),
            fill=0,  # uncovered areas are filled with black
        ),
    ]

    # Conversion for the model.
    to_model_input = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    return transforms.Compose(geometry + perturbations + to_model_input)

In [None]:

"""
Step 3. Define Loss
"""
criterion_GAN = nn.MSELoss()  # For adversarial loss
criterion_cycle = nn.L1Loss()  # For cycle consistency loss
criterion_identity = nn.L1Loss()  # For identity loss

"""
Step 4. Initalize G and D¶
"""
# Two generators (A->B and B->A) and one discriminator per domain.
G_AB = Generator(in_channels=3, num_residual_blocks=num_residual_blocks)
G_BA = Generator(in_channels=3, num_residual_blocks=num_residual_blocks)
D_A = Discriminator(in_channels=3, features_d=features_d)
D_B = Discriminator(in_channels=3, features_d=features_d)

## Total parameters in CycleGAN should be less than 60MB
# Sum of the element counts of every parameter tensor of the four networks.
total_params = sum(p.numel() for p in G_AB.parameters()) + \
               sum(p.numel() for p in G_BA.parameters()) + \
               sum(p.numel() for p in D_A.parameters()) + \
               sum(p.numel() for p in D_B.parameters())


"""
# modification of parameters computation is forbidden
"""
# NOTE(review): the count is divided by 1024 * 1024, not 1e6, so the printed
# "million" figure is in units of 2**20 parameters.
total_params_million = total_params / (1024 * 1024)
print(f'Total parameters in CycleGAN model: {total_params_million:.2f} million')

# Move the networks to the GPU when one is available.
cuda = torch.cuda.is_available()
print(f'cuda: {cuda}')
if cuda:
    G_AB = G_AB.cuda()
    D_B = D_B.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()

# NOTE(review): these .cuda() calls are not guarded by `if cuda:`. The
# recorded cell output shows the cell completing with cuda: False -- confirm
# whether the guard was intended.
criterion_GAN = criterion_GAN.cuda()
criterion_cycle = criterion_cycle.cuda()
criterion_identity = criterion_identity.cuda()

"""
Step 5. Configure Optimizers
"""


Total parameters in CycleGAN model: 27.30 million
cuda: False


In [None]:
# -------------------------
# Save Model
# -------------------------
def save_checkpoint(model, model_name, torch_seed, checkpoint_dir='models'):
    """Write a model's weights and the torch seed to a ``.pth`` file.

    The file is stored as ``<checkpoint_dir>/<file_name>-<model_name>-best.pth``
    where ``file_name`` is a notebook-level global. The directory is created
    if it does not exist yet.

    Args:
        model: Module whose ``state_dict()`` is stored under the key ``'model'``.
        model_name: Label embedded in the file name (e.g. ``"G_AB"``).
        torch_seed: Seed value stored under the key ``'torch_seed'``.
        checkpoint_dir: Target directory for the checkpoint file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    target_path = os.path.join(checkpoint_dir, f"{file_name}-{model_name}-best.pth")
    payload = dict(model=model.state_dict(), torch_seed=torch_seed)

    torch.save(payload, target_path)
    print(f"Model saved to {target_path}")

## 3. Load the data

In [None]:
"""
Step 6. DataLoader
"""
class ImageDataset(Dataset):
    """Dataset yielding one image from domain A and one from domain B.

    File names in both training folders are sorted and split by
    ``split_ratio``: the leading part is the training split, the remainder is
    the validation split.

    Args:
        data_dir: Root folder containing ``VAE_generation/train`` and
            ``VAE_generation_Cartoon/train``.
        mode: ``'train'`` or ``'valid'``.
        transforms: Optional callable applied to each PIL image.
        split_ratio: Fraction of the sorted files assigned to the training
            split.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'valid'``.
    """

    def __init__(self, data_dir, mode='train', transforms=None, split_ratio=0.99):
        # Fail immediately instead of leaving files_A/files_B undefined and
        # crashing later with an AttributeError inside the DataLoader.
        if mode not in ('train', 'valid'):
            raise ValueError(f"mode must be 'train' or 'valid', got {mode!r}")

        A_dir = os.path.join(data_dir, 'VAE_generation/train') # modification forbidden
        B_dir = os.path.join(data_dir, 'VAE_generation_Cartoon/train')  # modification forbidden

        files_A = sorted(os.listdir(A_dir))
        files_B = sorted(os.listdir(B_dir))

        split_idx_A = int(len(files_A) * split_ratio)
        split_idx_B = int(len(files_B) * split_ratio)

        if mode == 'train':
            names_A, names_B = files_A[:split_idx_A], files_B[:split_idx_B]
        else:
            names_A, names_B = files_A[split_idx_A:], files_B[split_idx_B:]

        self.files_A = [os.path.join(A_dir, name) for name in names_A]
        self.files_B = [os.path.join(B_dir, name) for name in names_B]
        self.transforms = transforms

    def __len__(self):
        # The epoch length is defined by domain A.
        return len(self.files_A)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a detached 3-channel RGB copy."""
        # The context manager closes the file handle; convert() keeps
        # grayscale/RGBA/palette files compatible with a 3-channel Normalize.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        file_A = self.files_A[index]
        if not self.files_B:
            raise IndexError("domain B split contains no images")
        # Wrap the index so a shorter domain-B split does not raise IndexError.
        file_B = self.files_B[index % len(self.files_B)]

        img_A = self._load_rgb(file_A)
        img_B = self._load_rgb(file_B)

        if self.transforms is not None:
            img_A = self.transforms(img_A)
            img_B = self.transforms(img_B)

        return img_A, img_B

# Dataset root: local path by default, Kaggle path kept as a commented toggle.
data_dir = './image_image_translation'
# data_dir = '/kaggle/input/group-project/image_image_translation'

batch_size = 5
image_size = (256, 256)

# Augmentation-free preprocessing: resize, convert to tensor, then normalise
# every channel with mean 0.5 and std 0.5.
data_transforms = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Augmenting pipeline, used for the training split only.
face_transforms = get_face_specific_transforms(image_size)

# Training batches are shuffled and use the augmenting transforms.
trainloader = DataLoader(
    dataset=ImageDataset(data_dir, mode='train', transforms=face_transforms),
    batch_size=batch_size, shuffle=True, num_workers=3,
)

# Validation batches keep file order and use the plain transforms.
validloader = DataLoader(
    dataset=ImageDataset(data_dir, mode='valid', transforms=data_transforms),
    batch_size=batch_size, shuffle=False, num_workers=3,
)

FileNotFoundError: [WinError 3] The system cannot find the path specified: '/kaggle/input/group-project/image_image_translation\\VAE_generation/train'

In [None]:
def weights_init_normal(m):
    """Initialise convolution and ``BatchNorm2d`` layers in place.

    Intended for ``nn.Module.apply``, which calls it on every sub-module.
    Layers are matched by class-name substring, so the function also sees
    custom wrapper modules whose name contains "Conv"; modules without a
    tensor ``weight``/``bias`` are left untouched instead of raising.

    * Class name contains ``"Conv"``: weight ~ N(0, 0.02), bias = 0.
    * Class name contains ``"BatchNorm2d"``: weight ~ N(1, 0.02), bias = 0.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__

    # getattr with a default covers wrapper modules that have no such
    # attribute; the isinstance check covers layers built with bias=False or
    # affine=False, where the attribute exists but is None.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    has_weight = isinstance(weight, torch.Tensor)
    has_bias = isinstance(bias, torch.Tensor)

    if "Conv" in classname:
        if has_weight:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
        if has_bias:
            torch.nn.init.constant_(bias.data, 0.0)
    elif "BatchNorm2d" in classname:
        if has_weight:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if has_bias:
            torch.nn.init.constant_(bias.data, 0.0)

# Apply weight initialization
# Module.apply calls weights_init_normal on every sub-module of each network.
# The last call is the final expression of the cell, so its return value is
# what the notebook displays as the cell output.
G_AB.apply(weights_init_normal)
G_BA.apply(weights_init_normal)
D_A.apply(weights_init_normal)
D_B.apply(weights_init_normal)

## 4. Train your **model**

In [None]:

def train(G_AB, G_BA, D_A, D_B, train_loader, validloader, num_epochs=200, batch_size=4):
    """Train both generators and both discriminators of the CycleGAN.

    Each iteration first updates the two generators on a weighted sum of GAN,
    cycle, identity, frequency, edge, colour and feature-matching losses,
    then updates each discriminator on a relativistic loss plus a gradient
    penalty, using an image pool of earlier fakes. Every 10th epoch one
    validation batch is translated and shown. After the last epoch the loss
    curves are plotted and all four networks are checkpointed.

    Reads the notebook-level globals ``lr``, ``betas``, ``pool_size_A``,
    ``pool_size_B``, ``Tensor`` and ``torch_seed``.

    Args:
        G_AB: Generator translating domain A to domain B.
        G_BA: Generator translating domain B to domain A.
        D_A: Discriminator for domain A.
        D_B: Discriminator for domain B.
        train_loader: Iterable of training batches, each either an
            ``(img_A, img_B)`` pair or a dict with keys ``'A'`` and ``'B'``.
        validloader: Iterable of ``(img_A, img_B)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        batch_size: Unused; kept so existing calls keep working.

    Returns:
        Tuple ``(G_AB, G_BA, D_A, D_B, gen_loss_history)`` where the last
        item holds the mean total generator loss of every epoch.
    """
    # Setup optimizers
    optimizer_G = torch.optim.Adam(
        itertools.chain(G_AB.parameters(), G_BA.parameters()),
        lr=lr, betas=betas
    )
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=betas)
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=betas)

    # Setup LR schedulers
    scheduler_G, scheduler_D_A, scheduler_D_B = get_lr_schedulers(
        optimizer_G, optimizer_D_A, optimizer_D_B, num_epochs
    )

    # Create image pools
    fake_A_pool = ImagePool(pool_size_A)
    fake_B_pool = ImagePool(pool_size_B)

    # Initialize adaptive loss weights
    loss_weights = AdaptiveLossWeights({
        'cycle': 10.0,
        'identity': 5.0,
        'gan': 1.0,
        'frequency': 2.0,
        'edge': 5.0,
        'color': 1.0,
        'feature_matching': 10.0
    })

    # Criterion handed to relativistic_loss; created once instead of per call.
    mse = nn.MSELoss()

    # Loss histories for plotting
    gen_loss_history = []
    gen_AB_loss_history = []
    gen_BA_loss_history = []
    id_loss_history = []
    gan_loss_history = []
    cycle_loss_history = []
    fm_loss_history = []
    discA_loss_history = []
    discB_loss_history = []

    # Training loop
    for epoch in range(num_epochs):
        # The validation step below puts the generators in eval mode; switch
        # back so the following epochs do not silently train in eval mode.
        G_AB.train()
        G_BA.train()

        # Initialize epoch losses
        epoch_gen_losses = []
        epoch_gen_AB_losses = []
        epoch_gen_BA_losses = []
        epoch_id_losses = []
        epoch_gan_losses = []
        epoch_cycle_losses = []
        epoch_fm_losses = []
        epoch_discA_losses = []
        epoch_discB_losses = []

        for i, batch in enumerate(train_loader):
            # ImageDataset yields (img_A, img_B) pairs; dict batches are still
            # accepted for loaders that provide 'A'/'B' keys.
            if isinstance(batch, dict):
                real_A, real_B = batch['A'], batch['B']
            else:
                real_A, real_B = batch
            real_A = real_A.type(Tensor)
            real_B = real_B.type(Tensor)

            #-------------------------------
            # Train Generators
            #-------------------------------
            optimizer_G.zero_grad()

            # Identity loss
            identity_A = G_BA(real_A)
            identity_B = G_AB(real_B)
            loss_id_A = F.l1_loss(identity_A, real_A)
            loss_id_B = F.l1_loss(identity_B, real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            fake_A = G_BA(real_B)

            # Get features for feature matching loss. Every call keeps its own
            # auxiliary output so each pair passed to relativistic_loss comes
            # from a single discriminator forward pass.
            real_A_pred, real_A_aux, real_A_features = D_A(real_A, get_features=True)
            fake_A_pred, fake_A_aux, fake_A_features = D_A(fake_A, get_features=True)
            real_B_pred, real_B_aux, real_B_features = D_B(real_B, get_features=True)
            fake_B_pred, fake_B_aux, fake_B_features = D_B(fake_B, get_features=True)

            # Calculate GAN loss using relativistic loss
            loss_GAN_AB = relativistic_loss((real_B_pred, real_B_aux), (fake_B_pred, fake_B_aux), mse)
            loss_GAN_BA = relativistic_loss((real_A_pred, real_A_aux), (fake_A_pred, fake_A_aux), mse)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            recov_B = G_AB(fake_A)
            loss_cycle_A = F.l1_loss(recov_A, real_A)
            loss_cycle_B = F.l1_loss(recov_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Feature matching loss
            loss_fm_A = feature_matching_loss(real_A_features, fake_A_features)
            loss_fm_B = feature_matching_loss(real_B_features, fake_B_features)
            loss_fm = (loss_fm_A + loss_fm_B) / 2

            # Auxiliary losses, computed once per direction and reused for the
            # per-generator bookkeeping below.
            loss_freq_A = fft_loss(real_A, recov_A)
            loss_freq_B = fft_loss(real_B, recov_B)
            loss_freq = (loss_freq_A + loss_freq_B) / 2
            loss_edge_A = edge_loss(real_A, fake_A)
            loss_edge_B = edge_loss(real_B, fake_B)
            loss_edge = (loss_edge_A + loss_edge_B) / 2
            loss_color_A = color_histogram_loss(real_A, fake_A)
            loss_color_B = color_histogram_loss(real_B, fake_B)
            loss_color = loss_color_A + loss_color_B

            # Get current loss weights
            weights = loss_weights.update({
                'cycle': loss_cycle.item(),
                'identity': loss_identity.item(),
                'gan': loss_GAN.item(),
                'frequency': loss_freq.item(),
                'edge': loss_edge.item(),
                'color': loss_color.item(),
                'feature_matching': loss_fm.item()
            })

            # Total generator loss with weighted components
            loss_G = (
                weights['gan'] * loss_GAN +
                weights['cycle'] * loss_cycle +
                weights['identity'] * loss_identity +
                weights['frequency'] * loss_freq +
                weights['edge'] * loss_edge +
                weights['color'] * loss_color +
                weights['feature_matching'] * loss_fm
            )

            # Individual generator losses. They are only logged, never
            # back-propagated, so no autograd graph is needed for them.
            with torch.no_grad():
                loss_G_AB = (
                    weights['gan'] * loss_GAN_AB +
                    weights['cycle'] * loss_cycle_B +
                    weights['identity'] * loss_id_B +
                    weights['frequency'] * loss_freq_B +
                    weights['edge'] * loss_edge_B +
                    weights['color'] * loss_color_B +
                    weights['feature_matching'] * loss_fm_B
                )

                loss_G_BA = (
                    weights['gan'] * loss_GAN_BA +
                    weights['cycle'] * loss_cycle_A +
                    weights['identity'] * loss_id_A +
                    weights['frequency'] * loss_freq_A +
                    weights['edge'] * loss_edge_A +
                    weights['color'] * loss_color_A +
                    weights['feature_matching'] * loss_fm_A
                )

            loss_G.backward()
            optimizer_G.step()

            #-------------------------------
            # Train Discriminator A
            #-------------------------------
            optimizer_D_A.zero_grad()

            # Use image pool for discriminator training
            fake_A_ = fake_A_pool.query(fake_A)

            # Real and fake discriminator outputs
            real_A_pred, real_A_aux = D_A(real_A)
            fake_A_pred, fake_A_aux = D_A(fake_A_.detach())

            # Relativistic discriminator loss
            loss_D_A = relativistic_loss((real_A_pred, real_A_aux), (fake_A_pred, fake_A_aux), mse)

            # Add gradient penalty
            loss_D_A_gp = compute_gradient_penalty(D_A, real_A, fake_A_.detach())
            loss_D_A += 10.0 * loss_D_A_gp

            loss_D_A.backward()
            optimizer_D_A.step()

            #-------------------------------
            # Train Discriminator B
            #-------------------------------
            optimizer_D_B.zero_grad()

            # Use image pool for discriminator training
            fake_B_ = fake_B_pool.query(fake_B)

            # Real and fake discriminator outputs
            real_B_pred, real_B_aux = D_B(real_B)
            fake_B_pred, fake_B_aux = D_B(fake_B_.detach())

            # Relativistic discriminator loss
            loss_D_B = relativistic_loss((real_B_pred, real_B_aux), (fake_B_pred, fake_B_aux), mse)

            # Add gradient penalty
            loss_D_B_gp = compute_gradient_penalty(D_B, real_B, fake_B_.detach())
            loss_D_B += 10.0 * loss_D_B_gp

            loss_D_B.backward()
            optimizer_D_B.step()

            # Store batch losses
            epoch_gen_losses.append(loss_G.item())
            epoch_gen_AB_losses.append(loss_G_AB.item())
            epoch_gen_BA_losses.append(loss_G_BA.item())
            epoch_id_losses.append(loss_identity.item())
            epoch_gan_losses.append(loss_GAN.item())
            epoch_cycle_losses.append(loss_cycle.item())
            epoch_fm_losses.append(loss_fm.item())
            epoch_discA_losses.append(loss_D_A.item())
            epoch_discB_losses.append(loss_D_B.item())

            # Print progress
            if i % 50 == 0:
                print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(train_loader)}] "
                      f"[D_A loss: {loss_D_A.item():.4f}] [D_B loss: {loss_D_B.item():.4f}] "
                      f"[G loss: {loss_G.item():.4f}, adv: {loss_GAN.item():.4f}, cycle: {loss_cycle.item():.4f}, "
                      f"id: {loss_identity.item():.4f}, fm: {loss_fm.item():.4f}]")

        # Update learning rates
        scheduler_G.step()
        scheduler_D_A.step()
        scheduler_D_B.step()

        # Calculate average epoch losses
        gen_loss_history.append(sum(epoch_gen_losses) / len(epoch_gen_losses))
        gen_AB_loss_history.append(sum(epoch_gen_AB_losses) / len(epoch_gen_AB_losses))
        gen_BA_loss_history.append(sum(epoch_gen_BA_losses) / len(epoch_gen_BA_losses))
        id_loss_history.append(sum(epoch_id_losses) / len(epoch_id_losses))
        gan_loss_history.append(sum(epoch_gan_losses) / len(epoch_gan_losses))
        cycle_loss_history.append(sum(epoch_cycle_losses) / len(epoch_cycle_losses))
        fm_loss_history.append(sum(epoch_fm_losses) / len(epoch_fm_losses))
        discA_loss_history.append(sum(epoch_discA_losses) / len(epoch_discA_losses))
        discB_loss_history.append(sum(epoch_discB_losses) / len(epoch_discB_losses))

        # Print epoch losses
        print_epoch_losses(epoch, epoch_gen_losses, epoch_gen_AB_losses, epoch_gen_BA_losses,
                          epoch_id_losses, epoch_gan_losses, epoch_cycle_losses, epoch_fm_losses,
                          epoch_discA_losses, epoch_discB_losses)

        if (epoch + 1) % 10 == 0:
            # Set models to evaluation mode
            G_AB.eval()
            G_BA.eval()

            # Get validation samples (first batch of the validation loader)
            valid_real_A, valid_real_B = next(iter(validloader))
            valid_real_A, valid_real_B = valid_real_A.type(Tensor), valid_real_B.type(Tensor)

            # Generate validation results
            with torch.no_grad():
                # Generate fake samples
                valid_fake_B = G_AB(valid_real_A)
                valid_fake_A = G_BA(valid_real_B)

                # Generate reconstructed samples
                valid_recov_A = G_BA(valid_fake_B)
                valid_recov_B = G_AB(valid_fake_A)

            # Visualize validation results
            show_images(valid_real_A, valid_fake_B, valid_recov_A,
                    valid_real_B, valid_fake_A, valid_recov_B)

    plot_losses(gen_loss_history, gen_AB_loss_history, gen_BA_loss_history,
           id_loss_history, gan_loss_history, cycle_loss_history, fm_loss_history,
           discA_loss_history, discB_loss_history)

    save_checkpoint(G_AB, "G_AB", torch_seed)
    save_checkpoint(G_BA, "G_BA", torch_seed)
    save_checkpoint(D_A, "D_A", torch_seed)
    save_checkpoint(D_B, "D_B", torch_seed)

    return G_AB, G_BA, D_A, D_B, gen_loss_history

In [None]:
# Run training and rebind the four networks to the returned modules.
# `loss_history` receives the per-epoch mean generator loss; `n_epochs` is a
# hyper-parameter defined earlier in the notebook.
G_AB, G_BA, D_A, D_B, loss_history = train(
    G_AB, G_BA, D_A, D_B, 
    train_loader=trainloader, 
    validloader=validloader,
    num_epochs=n_epochs, 
    batch_size=batch_size
)

In [None]:
def load_model(model, model_name):
    """Restore a model from ``models/<file_name>-<model_name>-best.pth``.

    The checkpoint is mapped onto the GPU when one is available and onto the
    CPU otherwise, so a checkpoint written on a GPU machine can still be
    loaded on a CPU-only host. ``file_name`` is a notebook-level global.

    Args:
        model: Module whose architecture matches the stored state dict.
        model_name: Label used in the checkpoint file name (e.g. ``"G_AB"``).

    Returns:
        Tuple ``(model, torch_seed)``: the same module with the stored weights
        loaded and moved to the selected device, and the seed saved with it.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the checkpoint
    checkpoint = torch.load(
        f"models/{file_name}-{model_name}-best.pth", map_location=device
    )

    # Load the state dict into the model
    model.load_state_dict(checkpoint['model'])
    torch_seed = checkpoint['torch_seed']

    # Previously an unconditional model.cuda(), which fails without a GPU.
    model.to(device)

    return model, torch_seed

In [None]:
G_AB, torch_seed = load_model(G_AB, "G_AB")
G_BA, torch_seed = load_model(G_BA, "G_BA")
D_A, torch_seed = load_model(D_A, "D_A")
D_B, torch_seed = load_model(D_B, "D_B")

## 5. Evaluation

In [None]:
"""
Step 8. Generate Images
"""
## Translation 1: Raw Image --> Cartoon Image
test_dir = os.path.join(data_dir, 'VAE_generation/test') # modification forbidden
files = [os.path.join(test_dir, name) for name in os.listdir(test_dir)]

save_dir = f'./generated_cartoon_images/{file_name}'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

to_image = transforms.ToPILImage()

G_AB.eval()
for i in range(0, len(files), batch_size):
    # read images
    imgs = []
    for j in range(i, min(len(files), i+batch_size)):
        img = Image.open(files[j])
        img = data_transforms(img)
        imgs.append(img)
    imgs = torch.stack(imgs, 0).type(Tensor)

    # generate
    fake_imgs = G_AB(imgs).detach().cpu()

    # save
    for j in range(fake_imgs.size(0)):
        img = fake_imgs[j].squeeze().permute(1, 2, 0)
        img_arr = img.numpy()
        img_arr = (img_arr - np.min(img_arr)) * 255 / (np.max(img_arr) - np.min(img_arr))
        img_arr = img_arr.astype(np.uint8)

        img = to_image(img_arr)
        _, name = os.path.split(files[i+j])
        img.save(os.path.join(save_dir, name))

gt_dir = os.path.join(data_dir, 'VAE_generation_Cartoon/test')
metrics = torch_fidelity.calculate_metrics(
    input1=save_dir,
    input2=gt_dir,
    cuda=True,
    fid=True,
    isc=True
)

fid_score = metrics["frechet_inception_distance"]
is_score = metrics["inception_score_mean"]

if is_score > 0:
    s_value_1 = np.sqrt(fid_score / is_score)
    print("Geometric Mean Score:", s_value_1)
else:
    print("IS is 0, GMS cannot be computed!")

In [None]:
## Translation 2: Cartoon Image --> Raw Image
test_dir = os.path.join(data_dir, 'VAE_generation_Cartoon/test') # modification forbidden
files = [os.path.join(test_dir, name) for name in os.listdir(test_dir)]

save_dir = f'./generated_raw_images/{file_name}'
# exist_ok avoids the separate exists() check and its race window.
os.makedirs(save_dir, exist_ok=True)

G_BA.eval()
for i in range(0, len(files), batch_size):
    # read images; convert to RGB so every input has the 3 channels that
    # data_transforms normalises, and close each file handle right away
    imgs = []
    for path in files[i:i + batch_size]:
        with Image.open(path) as img:
            imgs.append(data_transforms(img.convert('RGB')))
    imgs = torch.stack(imgs, 0).type(Tensor)

    # generate; no_grad skips building an autograd graph during inference
    with torch.no_grad():
        fake_imgs = G_BA(imgs).cpu()

    # save (to_image is the ToPILImage converter created in the previous cell)
    for j in range(fake_imgs.size(0)):
        img = fake_imgs[j].squeeze().permute(1, 2, 0)
        img_arr = img.numpy()

        # Per-image min-max rescale to [0, 255]. A constant image has zero
        # range and would divide by zero, so it is written as all zeros.
        low, high = np.min(img_arr), np.max(img_arr)
        if high > low:
            img_arr = (img_arr - low) * 255 / (high - low)
        else:
            img_arr = np.zeros_like(img_arr)
        img_arr = img_arr.astype(np.uint8)

        img = to_image(img_arr)
        _, name = os.path.split(files[i+j])
        img.save(os.path.join(save_dir, name))

gt_dir = os.path.join(data_dir, 'VAE_generation/test')

metrics = torch_fidelity.calculate_metrics(
    input1 = save_dir,
    input2 = gt_dir,
    # Use the GPU only when one exists instead of a hard-coded True.
    cuda = torch.cuda.is_available(),
    fid = True,
    isc = True
)

fid_score = metrics["frechet_inception_distance"]
is_score = metrics["inception_score_mean"]

# NOTE: s_value_2 is assigned only when the inception score is positive.
if is_score > 0:
    s_value_2 = np.sqrt(fid_score / is_score)
    print("Geometric Mean Score:", s_value_2)
else:
    print("IS is 0, GMS cannot be computed!")

In [None]:
# Final score: mean of the two per-direction scores, rounded to 5 decimals.
mean_of_directions = (s_value_1 + s_value_2) / 2
s_value = np.round(mean_of_directions, decimals=5)
print(s_value)

# One-row table with columns `id` and `label`, written without the index.
df = pd.DataFrame({'id': [1], 'label': [s_value]})
csv_path = "zhirong.csv"
df.to_csv(csv_path, index=False)

print(f"CSV saved to {csv_path}")