# Import libraries

In [None]:
import torch
from torch.autograd import Variable
from torchvision import transforms
import argparse
import os, itertools
import numpy as np
from PIL import Image
import torch.utils.data as data
import os
import random
import matplotlib.pyplot as plt
import imageio

# Dataset preparation and Parameter Setting

We use the argparse module to define and parse command line arguments. It sets parameters for the data set, model, and learning. It also defines directories for loading data and saving results. The values of the arguments are printed and stored in the params variable.

In [None]:
def _str2bool(value):
    """Convert a command-line token such as 'true' / 'False' / '0' to a bool.

    argparse's ``type=bool`` simply calls ``bool()`` on the raw string, so every
    non-empty value (including 'False') would be parsed as True.

    Args:
        value (str or bool): The raw command-line token (or an already-parsed bool).

    Returns:
        bool: The parsed flag.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser()

#Data Set Parameter
parser.add_argument('--dataset', required=False, default='summer2winter', help='input dataset')
parser.add_argument('--batch_size', type=int, default=1, help='train batch size')
parser.add_argument('--input_size', type=int, default=256, help='input size')
parser.add_argument('--resize_scale', type=int, default=286, help='resize scale (0 is false)')
parser.add_argument('--crop_size', type=int, default=256, help='crop size (0 is false)')
# Parsed with _str2bool so that "--fliplr False" really disables flipping.
parser.add_argument('--fliplr', type=_str2bool, default=True, help='random fliplr True of False')

#Model Parameters 
parser.add_argument('--ngf', type=int, default=32) # number of generator filters
parser.add_argument('--ndf', type=int, default=64) # number of discriminator filters
parser.add_argument('--num_resnet', type=int, default=6, help='number of resnet blocks in generator')

#Learning Parameters
parser.add_argument('--num_epochs', type=int, default=70, help='number of train epochs')
parser.add_argument('--decay_epoch', type=int, default=100, help='start decaying learning rate after this number')
parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate for generator, default=0.0002')
parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate for discriminator, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--lambdaA', type=float, default=10, help='lambdaA for cycle loss')
parser.add_argument('--lambdaB', type=float, default=10, help='lambdaB for cycle loss')
# An explicit empty argv is parsed, so only the defaults above are used
# (the notebook kernel's own command line is ignored).
params = parser.parse_args([])
print(params)

# Directories for loading data and saving results
data_dir = '/content/data'
save_dir = '/content/results/'
plot_gif_dir = '/content/results/plot_gif/'
test_res_dir = '/content/test_results/'

Namespace(dataset='summer2winter', batch_size=1, input_size=256, resize_scale=286, crop_size=256, fliplr=True, ngf=32, ndf=64, num_resnet=6, num_epochs=70, decay_epoch=100, lrG=0.0001, lrD=0.0001, beta1=0.5, beta2=0.999, lambdaA=10, lambdaB=10)


Ensuring that the directories exist and can be used for storing and loading data and results.

In [None]:
# Create every working directory up front. exist_ok=True makes each call a
# no-op when the directory is already there and avoids the race between a
# separate existence check and the creation.
os.makedirs(data_dir, exist_ok=True)

os.makedirs(save_dir, exist_ok=True)

os.makedirs(test_res_dir, exist_ok=True)

os.makedirs(plot_gif_dir, exist_ok=True)

In [None]:
# Switch into the data directory so the shell commands below run there.
os.chdir(data_dir)
!pip install kaggle --upgrade
# Kaggle credentials are set as environment variables; they are blank here.
# NOTE(review): presumably these must be filled in with a real username / API
# key before the download command can authenticate - confirm, and avoid
# committing real secrets to the notebook.
os.environ['KAGGLE_USERNAME'] = ''
os.environ['KAGGLE_KEY'] = ''

# Download the summer2winter-yosemite archive and unpack it in place.
!kaggle datasets download -d balraj98/summer2winter-yosemite
!unzip summer2winter-yosemite.zip

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Downloading summer2winter-yosemite.zip to /content/data
 89% 113M/126M [00:01<00:00, 69.7MB/s] 
100% 126M/126M [00:01<00:00, 73.2MB/s]
Archive:  summer2winter-yosemite.zip
  inflating: metadata.csv            
  inflating: testA/2010-09-07 12_23_20.jpg  
  inflating: testA/2010-10-05 13_45_11.jpg  
  inflating: testA/2010-10-05 19_08_31.jpg  
  inflating: testA/2011-05-23 17_46_40.jpg  
  inflating: testA/2011-05-26 15_06_01.jpg  
  inflating: testA/2011-05-28 15_13_21.jpg  
  inflating: testA/2011-05-29 10_20_21.jpg  
  inflating: testA/2011-05-29 13_29_21.jpg  
  inflating: testA/2011-06-03 03_36_41.jpg  
  inflating: testA/2011-06-03 15_29_50.jpg  
  inflating: testA/2011-06-03 21_27_20.jpg  
  inflating: testA/2011-06-04 19_38_11.jpg  
  inflating: testA/2011-06-09 12_02_20.jpg  
  inflating: testA/2011-06-14 23_29_30.jpg  
  inflating: testA/2011-06-20 08_47_21.jpg  
  inflating: tes

# Data Transformation

Here we define the image preprocessing pipeline using transforms.Compose(). The pipeline resizes the input image to (params.input_size, params.input_size), converts it to a tensor with values in [0, 1], and then normalizes each channel using a mean of 0.5 and a standard deviation of 0.5, which maps the pixel values to the range [-1, 1].

In [None]:
# Preprocessing applied to every loaded image:
#   1. resize to a square of side params.input_size,
#   2. convert the PIL image to a tensor,
#   3. normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((params.input_size,) * 2),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Defining Auxiliary Classes

In [None]:
class DatasetFromFolder(data.Dataset):
    """
    A PyTorch dataset class for loading image data from a folder.

    Every entry returned by ``os.listdir`` for the chosen subfolder is treated
    as an image file; entries are served in sorted filename order.

    Args:
        image_dir (str): Path to the folder containing image files.
        subfolder (str, optional): Name of the subfolder within image_dir to use. Default is 'train'.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version.
        resize_scale (int, optional): Side length of the square the image is resized to. Default is None (no resizing).
        crop_size (int, optional): Side length of the random square crop. Default is None (no cropping).
        fliplr (bool, optional): Whether or not to randomly flip the image horizontally. Default is False (no flipping).
    """

    def __init__(self, image_dir, subfolder='train', transform=None, resize_scale=None, crop_size=None, fliplr=False):
        """
        Initialize the dataset by listing the files of ``image_dir/subfolder``.

        Args:
            image_dir (str): Path to the folder containing image files.
            subfolder (str, optional): Name of the subfolder within image_dir to use. Default is 'train'.
            transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version.
            resize_scale (int, optional): Side length of the square the image is resized to. Default is None (no resizing).
            crop_size (int, optional): Side length of the random square crop. Default is None (no cropping).
            fliplr (bool, optional): Whether or not to randomly flip the image horizontally. Default is False (no flipping).
        """

        super(DatasetFromFolder, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        # Sorted so that a given index always refers to the same file.
        self.image_filenames = sorted(os.listdir(self.input_path))
        self.transform = transform

        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr

    def __getitem__(self, index):
        """
        Load and preprocess an image from the dataset.

        Processing order: load as RGB, optional resize, optional random crop,
        optional random horizontal flip (probability 0.5), optional transform.

        Args:
            index (int): Index of the image to load.

        Returns:
            The preprocessed image: the output of ``transform`` when one was
            given, otherwise a PIL image.
        """

        # Load Image
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        img = Image.open(img_fn).convert('RGB')

        # preprocessing
        if self.resize_scale:
            img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR)

        if self.crop_size:
            # Use the actual image size so cropping also works without a prior
            # resize. random.randint includes BOTH end points, so the largest
            # valid offset is exactly (size - crop_size); anything larger would
            # push the crop window past the image border.
            width, height = img.size
            x = random.randint(0, max(width - self.crop_size, 0))
            y = random.randint(0, max(height - self.crop_size, 0))
            img = img.crop((x, y, x + self.crop_size, y + self.crop_size))

        if self.fliplr and random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            length (int): The number of images in the dataset.
        """
        return len(self.image_filenames)

We initialize two DatasetFromFolder objects train_data_A and train_data_B which read the image files from the directories trainA and trainB respectively. The images are then preprocessed using the transform pipeline defined earlier, with additional options for resizing, cropping and horizontal flipping. The resulting preprocessed images are then loaded into DataLoader objects train_data_loader_A and train_data_loader_B respectively, which will be used for iterating over the training data during the training process. The batch_size parameter determines how many images are loaded into memory at once, and the shuffle parameter shuffles the order of the images to ensure the model sees a different order of images during each epoch of training.

In [None]:
# Domain A: training images with resize / random-crop / random-flip preprocessing.
train_data_A = DatasetFromFolder(
    data_dir,
    subfolder='trainA',
    transform=transform,
    resize_scale=params.resize_scale,
    crop_size=params.crop_size,
    fliplr=params.fliplr,
)

train_data_loader_A = data.DataLoader(
    dataset=train_data_A,
    batch_size=params.batch_size,
    shuffle=True,
)

# Domain B: same preprocessing, images read from the trainB folder.
train_data_B = DatasetFromFolder(
    data_dir,
    subfolder='trainB',
    transform=transform,
    resize_scale=params.resize_scale,
    crop_size=params.crop_size,
    fliplr=params.fliplr,
)

train_data_loader_B = data.DataLoader(
    dataset=train_data_B,
    batch_size=params.batch_size,
    shuffle=True,
)

Here we define the datasets and data loaders for the test set, test_data_loader_A and test_data_loader_B, which are used to load images for testing the trained model. The fixed sample images, test_real_A_data and test_real_B_data, are obtained by calling the __getitem__ method on the training datasets train_data_A and train_data_B respectively, and then unsqueezing them to add a batch dimension, producing 4D tensors.

In [None]:
# Test datasets only apply `transform` (no resize / crop / flip options are
# passed) and are read in a fixed order.
test_data_A = DatasetFromFolder(data_dir, subfolder='testA', transform=transform)

test_data_loader_A = data.DataLoader(
    dataset=test_data_A,
    batch_size=params.batch_size,
    shuffle=False,
)

test_data_B = DatasetFromFolder(data_dir, subfolder='testB', transform=transform)

test_data_loader_B = data.DataLoader(
    dataset=test_data_B,
    batch_size=params.batch_size,
    shuffle=False,
)


# Pick two fixed sample images (taken from the training datasets) and add a
# leading batch dimension so each becomes a 4D tensor (BxCxHxW).
test_real_A_data = train_data_A[11].unsqueeze(0)
test_real_B_data = train_data_B[91].unsqueeze(0)
print(test_real_A_data)

tensor([[[[-0.4431, -0.4431, -0.4431,  ..., -0.3882, -0.3804, -0.3804],
          [-0.3725, -0.4353, -0.4431,  ..., -0.3882, -0.3804, -0.3804],
          [-0.3490, -0.3490, -0.3098,  ..., -0.3804, -0.3804, -0.3804],
          ...,
          [-0.3020, -0.3020, -0.3176,  ..., -0.2627, -0.0824, -0.0588],
          [-0.2863, -0.2706, -0.2784,  ..., -0.4902, -0.1608, -0.0902],
          [-0.3176, -0.2627, -0.2941,  ..., -0.3725, -0.1216, -0.1294]],

         [[-0.3961, -0.3961, -0.4039,  ..., -0.1294, -0.1216, -0.1216],
          [-0.3412, -0.3882, -0.4118,  ..., -0.1294, -0.1216, -0.1216],
          [-0.3412, -0.3412, -0.3098,  ..., -0.1216, -0.1216, -0.1216],
          ...,
          [-0.2314, -0.2314, -0.2549,  ..., -0.2235, -0.0588, -0.0431],
          [-0.1922, -0.1765, -0.1765,  ..., -0.4510, -0.1765, -0.0510],
          [-0.2235, -0.1686, -0.1922,  ..., -0.3098, -0.0902, -0.0980]],

         [[-0.3882, -0.4118, -0.4353,  ...,  0.2706,  0.2784,  0.2784],
          [-0.3176, -0.3804, -

# Defining Model

In [None]:
class ConvBlock(torch.nn.Module):
    """
    A convolutional block that consists of a convolutional layer and an optional normalization layer
    followed by an activation function.

    Note that, despite the ``batch_norm`` flag name, the normalization layer
    used is ``torch.nn.InstanceNorm2d``.

    Args:
        input_size (int): The number of input channels.
        output_size (int): The number of output channels.
        kernel_size (int, optional): The size of the kernel. Default is 3.
        stride (int, optional): The stride of the convolution. Default is 2.
        padding (int, optional): The padding added to the input. Default is 1.
        activation (str, optional): The activation function to be applied. Can be 'relu', 'lrelu', 'tanh', or 'no_act'.
            Default is 'relu'.
        batch_norm (bool, optional): Whether to apply the normalization layer. Default is True.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    # Activation names accepted by the block.
    _ACTIVATIONS = ('relu', 'lrelu', 'tanh', 'no_act')

    def __init__(self, input_size, output_size, kernel_size=3, stride=2, padding=1, activation='relu', batch_norm=True):
        super(ConvBlock, self).__init__()
        # Fail at construction time instead of silently returning None from forward().
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                "unsupported activation {!r}; expected one of {}".format(activation, self._ACTIVATIONS))
        self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm2d(output_size)
        self.activation = activation
        # All activation modules are always created (they hold no parameters);
        # forward() picks the one named by self.activation.
        self.relu = torch.nn.ReLU(True)
        self.lrelu = torch.nn.LeakyReLU(0.2, True)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Applies the convolutional block to the input.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying the convolutional block.

        Raises:
            ValueError: If ``self.activation`` was changed to an unsupported name.
        """
        out = self.conv(x)
        if self.batch_norm:
            out = self.bn(out)

        if self.activation == 'relu':
            return self.relu(out)
        if self.activation == 'lrelu':
            return self.lrelu(out)
        if self.activation == 'tanh':
            return self.tanh(out)
        if self.activation == 'no_act':
            return out
        raise ValueError("unsupported activation {!r}".format(self.activation))


class DeconvBlock(torch.nn.Module):
    """
    A deconvolutional block that consists of a transposed-convolution layer and an optional normalization layer
    followed by an activation function.

    Note that, despite the ``batch_norm`` flag name, the normalization layer
    used is ``torch.nn.InstanceNorm2d``.

    Args:
        input_size (int): The number of input channels.
        output_size (int): The number of output channels.
        kernel_size (int, optional): The size of the kernel. Default is 3.
        stride (int, optional): The stride of the deconvolution. Default is 2.
        padding (int, optional): The padding added to the input. Default is 1.
        output_padding (int, optional): The additional size added to one side of the output shape. Default is 1.
        activation (str, optional): The activation function to be applied. Can be 'relu', 'lrelu', 'tanh', or 'no_act'.
            Default is 'relu'.
        batch_norm (bool, optional): Whether to apply the normalization layer. Default is True.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    # Activation names accepted by the block.
    _ACTIVATIONS = ('relu', 'lrelu', 'tanh', 'no_act')

    def __init__(self, input_size, output_size, kernel_size=3, stride=2, padding=1, output_padding=1, activation='relu', batch_norm=True):
        super(DeconvBlock, self).__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                "unsupported activation {!r}; expected one of {}".format(activation, self._ACTIVATIONS))
        self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, output_padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm2d(output_size)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)
        # forward() dispatches to these for 'lrelu' / 'tanh'; without them those
        # activation names would fail with AttributeError.
        self.lrelu = torch.nn.LeakyReLU(0.2, True)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Performs a forward pass of the DeconvBlock.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor.

        Raises:
            ValueError: If ``self.activation`` was changed to an unsupported name.
        """
        out = self.deconv(x)
        if self.batch_norm:
            out = self.bn(out)

        if self.activation == 'relu':
            return self.relu(out)
        if self.activation == 'lrelu':
            return self.lrelu(out)
        if self.activation == 'tanh':
            return self.tanh(out)
        if self.activation == 'no_act':
            return out
        raise ValueError("unsupported activation {!r}".format(self.activation))


class ResnetBlock(torch.nn.Module):
    """
    A residual block: reflection pad -> conv -> instance norm -> ReLU -> reflection pad -> conv -> instance norm,
    whose result is added to the block's input to create a residual (skip) connection.

    Each convolution is preceded by a fixed ``ReflectionPad2d(1)``, so with the
    default 3x3 kernel, stride 1 and padding 0 the spatial size is preserved,
    which the residual addition requires.

    Args:
        num_filter (int): The number of filters (input and output channels) in the convolutional layers.
        kernel_size (int, optional): The size of the kernel in the convolutional layers. Default is 3.
        stride (int, optional): The stride of the convolutional layers. Default is 1.
        padding (int, optional): Extra zero padding applied by the convolutional layers themselves
            (in addition to the reflection padding). Default is 0.
    """
    def __init__(self, num_filter, kernel_size=3, stride=1, padding=0):
        super(ResnetBlock, self).__init__()
        conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding)
        conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding)
        # The norm and padding layers are created with default arguments and
        # the same instances are reused at both positions in the sequence.
        bn = torch.nn.InstanceNorm2d(num_filter)
        relu = torch.nn.ReLU(True)
        pad = torch.nn.ReflectionPad2d(1)

        self.resnet_block = torch.nn.Sequential(
            pad,
            conv1,
            bn,
            relu,
            pad,
            conv2,
            bn
        )

    def forward(self, x):
        """
        Apply the residual block to x: the transformed features are added to the input.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_filter, height, width).

        Returns:
            out (torch.Tensor): ``x + resnet_block(x)``, of shape (batch_size, num_filter, height, width).
        """
        # Skip connection: without it the block is a plain conv stack rather
        # than the residual block described by the class.
        out = x + self.resnet_block(x)
        return out


class Generator(torch.nn.Module):
    """
    A generator neural network model for image-to-image translation tasks.

    Architecture: reflection pad + 7x7 conv, two stride-2 downsampling conv
    blocks, ``num_resnet`` residual blocks, two stride-2 upsampling deconv
    blocks, reflection pad + 7x7 conv with tanh output.

    Args:
        input_dim (int): The number of channels in the input image.
        num_filter (int): The number of filters in the first convolutional layer of the encoder.
        output_dim (int): The number of channels in the output image.
        num_resnet (int): The number of residual blocks in the generator.

    Attributes:
        pad (torch.nn.ReflectionPad2d): The reflection padding layer (used before the first and last conv).
        conv1 (ConvBlock): The first convolutional block of the encoder.
        conv2 (ConvBlock): The second convolutional block of the encoder.
        conv3 (ConvBlock): The third convolutional block of the encoder.
        resnet_blocks (torch.nn.Sequential): The sequence of residual blocks in the generator.
        deconv1 (DeconvBlock): The first deconvolutional block of the decoder.
        deconv2 (DeconvBlock): The second deconvolutional block of the decoder.
        deconv3 (ConvBlock): The third convolutional block of the decoder.

    Methods:
        forward(x): Performs a forward pass through the generator.
        normal_weight_init(mean, std): Initializes the weights of the generator with normally distributed random values.

    """
    def __init__(self, input_dim, num_filter, output_dim, num_resnet):
        super(Generator, self).__init__()

        # Reflection padding
        self.pad = torch.nn.ReflectionPad2d(3)
        # Encoder
        self.conv1 = ConvBlock(input_dim, num_filter, kernel_size=7, stride=1, padding=0)
        self.conv2 = ConvBlock(num_filter, num_filter * 2)
        self.conv3 = ConvBlock(num_filter * 2, num_filter * 4)
        # Resnet blocks
        blocks = [ResnetBlock(num_filter * 4) for _ in range(num_resnet)]
        self.resnet_blocks = torch.nn.Sequential(*blocks)
        # Decoder
        self.deconv1 = DeconvBlock(num_filter * 4, num_filter * 2)
        self.deconv2 = DeconvBlock(num_filter * 2, num_filter)
        self.deconv3 = ConvBlock(num_filter, output_dim, kernel_size=7, stride=1, padding=0, activation='tanh', batch_norm=False)

    def forward(self, x):
        """
        Performs a forward pass through the generator.

        Args:
            x (torch.Tensor): The input image tensor.

        Returns:
            The output image tensor generated by the generator.
        """
        # Encoder
        enc1 = self.conv1(self.pad(x))
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        # Resnet blocks
        res = self.resnet_blocks(enc3)
        # Decoder
        dec1 = self.deconv1(res)
        dec2 = self.deconv2(dec1)
        out = self.deconv3(self.pad(dec2))
        return out

    def normal_weight_init(self, mean=0.0, std=0.02):
        """
        Initializes the weights of the generator with normally distributed random values.

        Conv/deconv block weights are drawn from N(mean, std). The convolutions
        inside every residual block get the same weight init and a zero bias.

        Args:
            mean (float): The mean of the normal distribution.
            std (float): The standard deviation of the normal distribution.
        """
        # modules() walks the whole tree; children() would stop at the
        # `resnet_blocks` Sequential and never reach the ResnetBlocks inside it.
        for m in self.modules():
            if isinstance(m, ConvBlock):
                torch.nn.init.normal_(m.conv.weight, mean, std)
            elif isinstance(m, DeconvBlock):
                torch.nn.init.normal_(m.deconv.weight, mean, std)
            elif isinstance(m, ResnetBlock):
                # A ResnetBlock keeps its convolutions inside a Sequential, so
                # look them up by type rather than by attribute name.
                for layer in m.modules():
                    if isinstance(layer, torch.nn.Conv2d):
                        torch.nn.init.normal_(layer.weight, mean, std)
                        if layer.bias is not None:
                            torch.nn.init.constant_(layer.bias, 0)


class Discriminator(torch.nn.Module):
    """
    A convolutional neural network that is used as a discriminator in a Generative Adversarial Network (GAN).

    It stacks five 4x4 ConvBlocks: three with stride 2 followed by two with
    stride 1; the last one has no activation and no normalization.

    Args:
        input_dim (int): the number of input channels for the first convolutional layer.
        num_filter (int): the number of filters in the first convolutional layer.
        output_dim (int): the number of output channels for the last convolutional layer.
    """
    def __init__(self, input_dim, num_filter, output_dim):
        super(Discriminator, self).__init__()

        conv1 = ConvBlock(input_dim, num_filter, kernel_size=4, stride=2, padding=1, activation='lrelu', batch_norm=False)
        conv2 = ConvBlock(num_filter, num_filter * 2, kernel_size=4, stride=2, padding=1, activation='lrelu')
        conv3 = ConvBlock(num_filter * 2, num_filter * 4, kernel_size=4, stride=2, padding=1, activation='lrelu')
        conv4 = ConvBlock(num_filter * 4, num_filter * 8, kernel_size=4, stride=1, padding=1, activation='lrelu')
        conv5 = ConvBlock(num_filter * 8, output_dim, kernel_size=4, stride=1, padding=1, activation='no_act', batch_norm=False)

        self.conv_blocks = torch.nn.Sequential(
            conv1,
            conv2,
            conv3,
            conv4,
            conv5
        )

    def forward(self, x):
        """
        Feeds the input tensor through the discriminator network.

        Args:
            x (torch.Tensor): the input tensor.

        Returns:
            The output tensor after it has passed through the discriminator network.
        """
        out = self.conv_blocks(x)
        return out

    def normal_weight_init(self, mean=0.0, std=0.02):
        """
        Initializes the weights of the convolutional layers using a normal distribution with the given mean and standard deviation.

        Args:
            mean (float): the mean of the normal distribution (default=0.0).
            std (float): the standard deviation of the normal distribution (default=0.02).
        """
        # The only direct child is the `conv_blocks` Sequential, so iterating
        # children() would never see a ConvBlock and would initialise nothing.
        # modules() descends into the Sequential.
        for m in self.modules():
            if isinstance(m, ConvBlock):
                torch.nn.init.normal_(m.conv.weight, mean, std)

The generators G_A and G_B are defined with 3 input channels, 'params.ngf' number of filters in the first layer, 3 output channels, and 'params.num_resnet' number of residual blocks. The discriminators D_A and D_B are defined with 3 input channels, 'params.ndf' number of filters in the first layer, and 1 output channel. The normal_weight_init method is called on each of the generators and discriminators to initialize their weights. Lastly, the models are moved to the GPU by calling the 'cuda()' method on each of them.

In [None]:
# Two generators (one per translation direction) and two discriminators.
G_A = Generator(3, params.ngf, 3, params.num_resnet) # input_dim, num_filter, output_dim, num_resnet
G_B = Generator(3, params.ngf, 3, params.num_resnet)

D_A = Discriminator(3, params.ndf, 1) # input_dim, num_filter, output_dim
D_B = Discriminator(3, params.ndf, 1)

G_A.normal_weight_init(mean=0.0, std=0.02)
G_B.normal_weight_init(mean=0.0, std=0.02)
D_A.normal_weight_init(mean=0.0, std=0.02)
D_B.normal_weight_init(mean=0.0, std=0.02)

# Move each network to the GPU only when one is available (same policy as
# to_var), so the cell also runs on CPU-only machines, then print its layout.
for _net in (G_A, G_B, D_A, D_B):
    if torch.cuda.is_available():
        _net.cuda()
    print(_net)


  torch.nn.init.normal(m.conv.weight, mean, std)
  torch.nn.init.normal(m.deconv.weight, mean, std)


Generator(
  (pad): ReflectionPad2d((3, 3, 3, 3))
  (conv1): ConvBlock(
    (conv): Conv2d(3, 32, kernel_size=(7, 7), stride=(1, 1))
    (bn): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (relu): ReLU(inplace=True)
    (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
    (tanh): Tanh()
  )
  (conv2): ConvBlock(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (relu): ReLU(inplace=True)
    (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
    (tanh): Tanh()
  )
  (conv3): ConvBlock(
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (relu): ReLU(inplace=True)
    (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
    (tanh): Tanh()
  )
  (resnet_blocks): Sequential(
    (0): ResnetB

Initializing the three optimizers:

1. G_optimizer for optimizing the generators G_A and G_B with the Adam optimizer
2. D_A_optimizer for optimizing the discriminator D_A with the Adam optimizer
3. D_B_optimizer for optimizing the discriminator D_B with the Adam optimizer

In [None]:
# One Adam optimizer updates both generators jointly; each discriminator has
# its own Adam optimizer. All three share the same beta coefficients.
G_optimizer = torch.optim.Adam(
    list(G_A.parameters()) + list(G_B.parameters()),
    lr=params.lrG,
    betas=(params.beta1, params.beta2),
)
D_A_optimizer = torch.optim.Adam(
    D_A.parameters(),
    lr=params.lrD,
    betas=(params.beta1, params.beta2),
)
D_B_optimizer = torch.optim.Adam(
    D_B.parameters(),
    lr=params.lrD,
    betas=(params.beta1, params.beta2),
)

# Defining Auxiliary Functions

In [None]:
def to_np(x):
    """
    Return the contents of a PyTorch tensor as a NumPy array.

    The tensor's underlying data is copied to the CPU first (a no-op if it is
    already there) and then exposed as a NumPy array.

    Args:
    x: A PyTorch tensor.

    Returns:
    A NumPy array holding the tensor's values.
    """
    host_tensor = x.data.cpu()
    return host_tensor.numpy()


def to_var(x):
    """
    Wrap a tensor in a PyTorch Variable, placing it on the GPU when CUDA is available.

    Args:
    x: The tensor to be wrapped.

    Returns:
    A Variable built from the input tensor (from its CUDA copy if CUDA is
    available, otherwise from the tensor as given).
    """
    return Variable(x.cuda() if torch.cuda.is_available() else x)


# De-normalization
def denorm(x):
    """
    Undo the [-1, 1] normalization: shift and scale the input to [0, 1].

    Values that fall outside [0, 1] after rescaling are clamped to that range.

    Args:
    x (torch.Tensor): Input tensor to be de-normalized.

    Returns:
    torch.Tensor: De-normalized tensor with every element in [0, 1].
    """
    rescaled = (x + 1) / 2
    return rescaled.clamp(min=0, max=1)


# Plot losses
def plot_loss(avg_losses, num_epochs, save=False, save_dir='results/', show=False):
    """
    Plots the losses of a GAN model.

    Args:
        avg_losses (list): Sequences of average loss values, one per curve, in the order
            D_A, D_B, G_A, G_B, cycle_A, cycle_B. If fewer than six are given, only those are
            plotted; any beyond the sixth only influence the y-axis limit.
        num_epochs (int): The total number of epochs (upper x-axis limit, also used in the file name).
        save (bool, optional): If True, saves the plot to a file. Defaults to False.
        save_dir (str, optional): The directory where the plot should be saved. Defaults to 'results/'.
        show (bool, optional): If True, displays the plot; otherwise the figure is closed. Defaults to False.

    Returns:
        None
    """
    labels = ('D_A', 'D_B', 'G_A', 'G_B', 'cycle_A', 'cycle_B')

    fig, ax = plt.subplots()
    ax.set_xlim(0, num_epochs)
    # The y-axis spans from 0 to 10% above the largest loss value seen.
    peak = 0.0
    for series in avg_losses:
        peak = max(np.max(series), peak)
    ax.set_ylim(0, peak*1.1)
    plt.xlabel('# of Epochs')
    plt.ylabel('Loss values')

    for series, label in zip(avg_losses, labels):
        plt.plot(series, label=label)
    plt.legend()

    # save figure
    if save:
        # makedirs also creates missing parent directories and tolerates an
        # existing directory; join works with or without a trailing slash.
        os.makedirs(save_dir, exist_ok=True)
        save_fn = os.path.join(save_dir, 'Loss_values_epoch_{:d}'.format(num_epochs) + '.png')
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close()


def plot_train_result(real_image, gen_image, recon_image, epoch, save=False, save_dir='results/', show=False, fig_size=(5, 5)):
    """
    Plots a 2x3 grid of real, generated and reconstructed images produced by a GAN model at a given epoch.

    Row 0 shows entry [0] of each argument and row 1 shows entry [1].

    Args:
        real_image: Indexable with at least two entries; each entry is a tensor that squeezes to a
            3-D (channels, height, width) image.
        gen_image: Generated images, same layout as real_image.
        recon_image: Reconstructed images, same layout as real_image.
        epoch (int): The zero-based epoch number (shown and saved as epoch + 1).
        save (bool, optional): If True, saves the plot to a file. Defaults to False.
        save_dir (str, optional): The directory where the plot should be saved. Defaults to 'results/'.
        show (bool, optional): If True, displays the plot; otherwise the figure is closed. Defaults to False.
        fig_size (tuple, optional): The size of the figure. Defaults to (5, 5).

    Returns:
        None
    """
    fig, axes = plt.subplots(2, 3, figsize=fig_size)

    imgs = [to_np(real_image[0]), to_np(gen_image[0]), to_np(recon_image[0]),
            to_np(real_image[1]), to_np(gen_image[1]), to_np(recon_image[1])]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        ax.set_adjustable('box')
        # Scale to 0-255 (min-max per image)
        img = img.squeeze()
        value_range = img.max() - img.min()
        if value_range == 0:
            # A constant image would divide by zero; show it as all black.
            scaled = np.zeros_like(img)
        else:
            scaled = ((img - img.min()) * 255) / value_range
        img = scaled.transpose(1, 2, 0).astype(np.uint8)
        ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    title = 'Epoch {0}'.format(epoch + 1)
    fig.text(0.5, 0.04, title, ha='center')

    # save figure
    if save:
        # makedirs also creates missing parent directories and tolerates an
        # existing directory; join works with or without a trailing slash.
        os.makedirs(save_dir, exist_ok=True)

        save_fn = os.path.join(save_dir, 'Result_epoch_{:d}'.format(epoch+1) + '.png')
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close()


def plot_test_result(real_image, gen_image, recon_image, index, save=False, save_dir='results/', show=False):
    """
    Plots the real image, generated image, and reconstructed image for a single test sample.

    The three images are drawn side by side in one row.

    Args:
    - real_image: 4-D torch.Tensor (batch, channels, height, width) representing the real image;
      it must squeeze to a 3-D image, i.e. hold a single sample.
    - gen_image: torch.Tensor with the same layout representing the generated image.
    - recon_image: torch.Tensor with the same layout representing the reconstructed image.
    - index: int, zero-based index of the test sample (saved as index + 1).
    - save: bool flag indicating whether to save the plot to a file.
    - save_dir: str representing the directory where the plot will be saved.
    - show: bool flag indicating whether to display the plot (otherwise the figure is closed).
    
    Returns: None.
    """
    # figsize is (width, height): three images wide, one image tall.
    # Dimension 3 is the image width and dimension 2 the height.
    fig_size = (real_image.size(3) * 3 / 100, real_image.size(2) / 100)
    fig, axes = plt.subplots(1, 3, figsize=fig_size)

    imgs = [to_np(real_image), to_np(gen_image), to_np(recon_image)]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        ax.set_adjustable('box')
        # Scale to 0-255 (min-max per image)
        img = img.squeeze()
        value_range = img.max() - img.min()
        if value_range == 0:
            # A constant image would divide by zero; show it as all black.
            scaled = np.zeros_like(img)
        else:
            scaled = ((img - img.min()) * 255) / value_range
        img = scaled.transpose(1, 2, 0).astype(np.uint8)
        ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    # save figure
    if save:
        # makedirs also creates missing parent directories and tolerates an
        # existing directory; join works with or without a trailing slash.
        os.makedirs(save_dir, exist_ok=True)

        save_fn = os.path.join(save_dir, 'Test_result_{:d}'.format(index + 1) + '.png')
        # Remove all outer margins so the saved file contains only the images.
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close()

# Make gif
def make_gif(dataset, num_epochs, save_dir='results/', source_dir='results/'):
    """
    Assemble the per-epoch result plots into a single animated GIF.

    Reads 'Result_epoch_1.png' ... 'Result_epoch_<num_epochs>.png' from
    source_dir and writes '<dataset>_CycleGAN_epochs_<num_epochs>.gif' to save_dir.
    Both directory strings are prepended as-is, so they need a trailing separator.

    Args:
        dataset (str): Name of the dataset, used as the GIF file-name prefix.
        num_epochs (int): Number of epochs for which result plots were saved.
        save_dir (str, optional): Directory prefix for the GIF. Defaults to 'results/'.
        source_dir (str, optional): Directory prefix of the saved plots. Defaults to 'results/'.
    """
    # NOTE(review): the captured cell output hints that imageio warned about
    # this imread call - confirm against the installed imageio version.
    frame_paths = (
        source_dir + 'Result_epoch_{:d}'.format(epoch_number) + '.png'
        for epoch_number in range(1, num_epochs + 1)
    )
    gen_image_plots = [imageio.imread(frame_path) for frame_path in frame_paths]

    gif_path = save_dir + dataset + '_CycleGAN_epochs_{:d}'.format(num_epochs) + '.gif'
    imageio.mimsave(gif_path, gen_image_plots, fps=5)


class ImagePool():
    """
    History buffer of previously generated images for CycleGAN training.

    Args:
        pool_size (int): The maximum number of images to store in the pool.
            A value of 0 or less disables the pool.

    Attributes:
        pool_size (int): The maximum number of images to store in the pool.
        num_imgs (int): The current number of images in the pool.
        images (list): The pooled images, each a detached tensor with a
            leading batch dimension of 1.
    """
    def __init__(self, pool_size):
        self.pool_size = pool_size
        # Always initialise the buffer state so the documented attributes exist
        # even when the pool is disabled.
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """
        Pass a batch of images through the pool.

        With the pool disabled (pool_size <= 0) the input is returned unchanged.
        Otherwise each image of the batch is handled independently:
        - while the pool is not full, the image is stored and returned as-is;
        - once it is full, with probability 0.5 the image replaces a randomly
          chosen pooled image and that older image is returned instead;
          otherwise the image is returned without being stored.

        Args:
            images (torch.Tensor): Batch of images, indexed along dimension 0.

        Returns:
            torch.Tensor: Batch with the same number of images, detached from
            the autograd graph (except for a disabled pool, which returns the
            input object itself).
        """
        if self.pool_size <= 0:
            return images

        # detach() replaces the legacy `.data`: pooled images must not keep the
        # graph that produced them alive.
        detached = images.detach()
        if detached.size(0) == 0:
            # torch.cat() rejects an empty list, so hand an empty batch straight back
            return detached

        return_images = []
        for image in detached:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(image)
                return_images.append(image)
            elif random.uniform(0, 1) > 0.5:
                # Swap: return an older image and keep the new one in its slot
                random_id = random.randint(0, self.pool_size - 1)
                tmp = self.images[random_id].clone()
                self.images[random_id] = image
                return_images.append(tmp)
            else:
                return_images.append(image)
        return torch.cat(return_images, 0)

Set up the loss functions used during training and initialize the lists that store the average losses of each epoch. Also initialize the image pools that store previously generated images for use when training the discriminators.

In [None]:
# Criteria used by the training loop, moved to the GPU:
# MSE for the adversarial terms, L1 for the cycle-consistency terms.
MSE_Loss = torch.nn.MSELoss().cuda()
L1_Loss = torch.nn.L1Loss().cuda()

# Per-epoch average losses (one entry is appended after every epoch)
D_A_avg_losses, D_B_avg_losses = [], []
G_A_avg_losses, G_B_avg_losses = [], []
cycle_A_avg_losses, cycle_B_avg_losses = [], []

# Pools of previously generated images, one per domain
num_pool = 50
fake_A_pool, fake_B_pool = ImagePool(num_pool), ImagePool(num_pool)

# Model Training

In [None]:
# CycleGAN training loop: per batch, one joint generator update (G_A and G_B
# share G_optimizer) followed by one update of each discriminator.
step = 0
# zip() below stops at the shorter of the two loaders, so that is the real
# number of steps per epoch (used for the progress message only).
steps_per_epoch = min(len(train_data_loader_A), len(train_data_loader_B))

for epoch in range(params.num_epochs):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    cycle_A_losses = []
    cycle_B_losses = []

    # Learning rate decay: after decay_epoch, lower each rate by a constant
    # amount per epoch so it reaches zero at num_epochs.
    if (epoch + 1) > params.decay_epoch:
        decay_span = params.num_epochs - params.decay_epoch
        D_A_optimizer.param_groups[0]['lr'] -= params.lrD / decay_span
        D_B_optimizer.param_groups[0]['lr'] -= params.lrD / decay_span
        G_optimizer.param_groups[0]['lr'] -= params.lrG / decay_span

    # training
    for i, (real_A, real_B) in enumerate(zip(train_data_loader_A, train_data_loader_B)):

        # input image data (Variable is a no-op wrapper in current PyTorch)
        real_A = real_A.cuda()
        real_B = real_B.cuda()

        # -------------------------- train generator G --------------------------
        # A --> B: the generator wants D_B to score its output as real (all ones)
        fake_B = G_A(real_A)
        D_B_fake_decision = D_B(fake_B)
        G_A_loss = MSE_Loss(D_B_fake_decision, torch.ones_like(D_B_fake_decision))

        # forward cycle loss
        recon_A = G_B(fake_B)
        cycle_A_loss = L1_Loss(recon_A, real_A) * params.lambdaA

        # B --> A
        fake_A = G_B(real_B)
        D_A_fake_decision = D_A(fake_A)
        G_B_loss = MSE_Loss(D_A_fake_decision, torch.ones_like(D_A_fake_decision))

        # backward cycle loss
        recon_B = G_A(fake_A)
        cycle_B_loss = L1_Loss(recon_B, real_B) * params.lambdaB

        # Back propagation
        G_loss = G_A_loss + G_B_loss + cycle_A_loss + cycle_B_loss
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # -------------------------- train discriminator D_A --------------------------
        D_A_real_decision = D_A(real_A)
        D_A_real_loss = MSE_Loss(D_A_real_decision, torch.ones_like(D_A_real_decision))

        # The pool returns detached images (possibly older ones), so this
        # backward pass does not reach the generators.
        fake_A = fake_A_pool.query(fake_A)

        D_A_fake_decision = D_A(fake_A)
        D_A_fake_loss = MSE_Loss(D_A_fake_decision, torch.zeros_like(D_A_fake_decision))

        # Back propagation
        D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5
        D_A_optimizer.zero_grad()
        D_A_loss.backward()
        D_A_optimizer.step()

        # -------------------------- train discriminator D_B --------------------------
        D_B_real_decision = D_B(real_B)
        # Target is shaped after the *real* decision (it was previously sized
        # from the stale fake decision of the generator step).
        D_B_real_loss = MSE_Loss(D_B_real_decision, torch.ones_like(D_B_real_decision))

        fake_B = fake_B_pool.query(fake_B)

        D_B_fake_decision = D_B(fake_B)
        D_B_fake_loss = MSE_Loss(D_B_fake_decision, torch.zeros_like(D_B_fake_decision))

        # Back propagation
        D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5
        D_B_optimizer.zero_grad()
        D_B_loss.backward()
        D_B_optimizer.step()

        # ------------------------ Print -----------------------------
        # loss values, stored as plain Python floats rather than GPU tensors
        D_A_losses.append(D_A_loss.item())
        D_B_losses.append(D_B_loss.item())
        G_A_losses.append(G_A_loss.item())
        G_B_losses.append(G_B_loss.item())
        cycle_A_losses.append(cycle_A_loss.item())
        cycle_B_losses.append(cycle_B_loss.item())

        if i % 10 == 0:
            print('Epoch [%d/%d], Step [%d/%d], D_A_loss: %.4f, D_B_loss: %.4f, G_A_loss: %.4f, G_B_loss: %.4f'
                  % (epoch+1, params.num_epochs, i+1, steps_per_epoch,
                     D_A_losses[-1], D_B_losses[-1], G_A_losses[-1], G_B_losses[-1]))
        step += 1

    D_A_avg_loss = torch.mean(torch.FloatTensor(D_A_losses))
    D_B_avg_loss = torch.mean(torch.FloatTensor(D_B_losses))
    G_A_avg_loss = torch.mean(torch.FloatTensor(G_A_losses))
    G_B_avg_loss = torch.mean(torch.FloatTensor(G_B_losses))
    cycle_A_avg_loss = torch.mean(torch.FloatTensor(cycle_A_losses))
    cycle_B_avg_loss = torch.mean(torch.FloatTensor(cycle_B_losses))

    # avg loss values for plot
    D_A_avg_losses.append(D_A_avg_loss)
    D_B_avg_losses.append(D_B_avg_loss)
    G_A_avg_losses.append(G_A_avg_loss)
    G_B_avg_losses.append(G_B_avg_loss)
    cycle_A_avg_losses.append(cycle_A_avg_loss)
    cycle_B_avg_losses.append(cycle_B_avg_loss)

    # Show result for test image. These forward passes are for visualisation
    # only, so skip building autograd graphs for them.
    with torch.no_grad():
        test_real_A = test_real_A_data.cuda()
        test_fake_B = G_A(test_real_A)
        test_recon_A = G_B(test_fake_B)

        test_real_B = test_real_B_data.cuda()
        test_fake_A = G_B(test_real_B)
        test_recon_B = G_A(test_fake_A)

    plot_train_result([test_real_A, test_real_B], [test_fake_B, test_fake_A], [test_recon_A, test_recon_B],
                            epoch, save=True, save_dir=save_dir)

    # log the images: real | translated | reconstructed, joined along axis 3
    result_AtoB = np.concatenate((to_np(test_real_A), to_np(test_fake_B), to_np(test_recon_A)), axis=3)
    result_BtoA = np.concatenate((to_np(test_real_B), to_np(test_fake_A), to_np(test_recon_B)), axis=3)

    info = { 'result_AtoB': result_AtoB.transpose(0, 2, 3, 1),  # convert to BxHxWxC
             'result_BtoA': result_BtoA.transpose(0, 2, 3, 1) }

Epoch [1/70], Step [1/1231], D_A_loss: 0.9004, D_B_loss: 0.4739, G_A_loss: 0.7661, G_B_loss: 1.6461
Epoch [1/70], Step [11/1231], D_A_loss: 0.3759, D_B_loss: 0.2603, G_A_loss: 0.3970, G_B_loss: 0.4130
Epoch [1/70], Step [21/1231], D_A_loss: 0.2525, D_B_loss: 0.2814, G_A_loss: 0.3459, G_B_loss: 0.3882
Epoch [1/70], Step [31/1231], D_A_loss: 0.1928, D_B_loss: 0.2898, G_A_loss: 0.3215, G_B_loss: 0.3438
Epoch [1/70], Step [41/1231], D_A_loss: 0.2498, D_B_loss: 0.2496, G_A_loss: 0.3088, G_B_loss: 0.4098
Epoch [1/70], Step [51/1231], D_A_loss: 0.2251, D_B_loss: 0.1930, G_A_loss: 0.4088, G_B_loss: 0.3963
Epoch [1/70], Step [61/1231], D_A_loss: 0.2801, D_B_loss: 0.1386, G_A_loss: 0.4704, G_B_loss: 0.3859
Epoch [1/70], Step [71/1231], D_A_loss: 0.2683, D_B_loss: 0.2249, G_A_loss: 0.4105, G_B_loss: 0.2886
Epoch [1/70], Step [81/1231], D_A_loss: 0.1987, D_B_loss: 0.1831, G_A_loss: 0.3772, G_B_loss: 0.4262
Epoch [1/70], Step [91/1231], D_A_loss: 0.1454, D_B_loss: 0.1863, G_A_loss: 0.3858, G_B_loss

Plot and save the average losses over the epochs, then create a GIF from the result images saved after each epoch.

In [None]:
# Plot average losses: gather the per-epoch curves in the order
# D_A, D_B, G_A, G_B, cycle_A, cycle_B and hand them to plot_loss
avg_losses = [
    D_A_avg_losses,
    D_B_avg_losses,
    G_A_avg_losses,
    G_B_avg_losses,
    cycle_A_avg_losses,
    cycle_B_avg_losses,
]
plot_loss(avg_losses, params.num_epochs, save=True, save_dir=plot_gif_dir)

# Make gif from the per-epoch result images stored in save_dir
make_gif(params.dataset, params.num_epochs, save_dir=plot_gif_dir, source_dir=save_dir)

  gen_image_plots.append(imageio.imread(save_fn))


# Testing the Model

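Generate test results for both the A-to-B and B-to-A directions. For each direction, we loop through the test data loader, translate each image with the corresponding generator, reconstruct it with the other generator, and save the real, generated, and reconstructed images side by side.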

In [None]:
# Inference only: no_grad() keeps PyTorch from recording an autograd graph for
# every test image, which would otherwise waste GPU memory.
with torch.no_grad():
    for i, real_A in enumerate(test_data_loader_A):
        # input image data
        real_A = real_A.cuda()

        # A --> B --> A
        fake_B = G_A(real_A)
        recon_A = G_B(fake_B)

        # Show result for test data
        plot_test_result(real_A, fake_B, recon_A, i, save=True, save_dir=test_res_dir + 'AtoB/')

        print('%d images are generated.' % (i + 1))

    for i, real_B in enumerate(test_data_loader_B):

        # input image data
        real_B = real_B.cuda()

        # B -> A -> B
        fake_A = G_B(real_B)
        recon_B = G_A(fake_A)

        # Show result for test data
        plot_test_result(real_B, fake_A, recon_B, i, save=True, save_dir=test_res_dir + 'BtoA/')

        print('%d images are generated.' % (i + 1))

1 images are generated.
2 images are generated.
3 images are generated.
4 images are generated.
5 images are generated.
6 images are generated.
7 images are generated.
8 images are generated.
9 images are generated.
10 images are generated.
11 images are generated.
12 images are generated.
13 images are generated.
14 images are generated.
15 images are generated.
16 images are generated.
17 images are generated.
18 images are generated.
19 images are generated.
20 images are generated.
21 images are generated.
22 images are generated.
23 images are generated.
24 images are generated.
25 images are generated.
26 images are generated.
27 images are generated.
28 images are generated.
29 images are generated.
30 images are generated.
31 images are generated.
32 images are generated.
33 images are generated.
34 images are generated.
35 images are generated.
36 images are generated.
37 images are generated.
38 images are generated.
39 images are generated.
40 images are generated.
41 images

# Saving the results

In [None]:
# Archive the training outputs and the test outputs.
# `!` runs the command in the notebook's shell; `-r` recurses into the directory.
!zip -r /content/results.zip /content/results
!zip -r /content/test_results.zip /content/test_results

# Google Colab only: request a browser download of each archive.
from google.colab import files
files.download("/content/results.zip")
files.download("/content/test_results.zip")

  adding: content/results/ (stored 0%)
  adding: content/results/plot_gifLoss_values_epoch_1.png (deflated 18%)
  adding: content/results/summer2winter_CycleGAN_epochs_1.gif (deflated 1%)
  adding: content/results/Loss_values_epoch_1.png (deflated 18%)
  adding: content/results/plot_gif/ (stored 0%)
  adding: content/results/plot_gif/summer2winter_CycleGAN_epochs_1.gif (deflated 1%)
  adding: content/results/plot_gif/Loss_values_epoch_1.png (deflated 18%)
  adding: content/results/Result_epoch_1.png (deflated 1%)
  adding: content/test_results/ (stored 0%)
  adding: content/test_results/BtoA/ (stored 0%)
  adding: content/test_results/BtoA/Test_result_57.png (deflated 0%)
  adding: content/test_results/BtoA/Test_result_216.png (deflated 0%)
  adding: content/test_results/BtoA/Test_result_78.png (deflated 0%)
  adding: content/test_results/BtoA/Test_result_169.png (deflated 0%)
  adding: content/test_results/BtoA/Test_result_81.png (deflated 0%)
  adding: content/test_results/BtoA/Test_

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>