# Implementation of Cycle-GAN in Jupyter Notebook

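To run this notebook on Google Colab: <br>
1. Change the runtime type to GPU<br>
2. Upload `data_img_A.zip` and `data_img_B.zip` to `/content` (they are unzipped by the first code cell)<br>
3. Execute all cells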

In [None]:
"""
A pytorch implementation of the Cycle-GAN architecture used for generating art in aivie.
The original paper by Zhu et al. can be found at: https://arxiv.org/pdf/1703.10593.pdf
This implementation is based on their more complete and performant official pytorch implementation, which can be found at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
This implementation was inspired by https://github.com/aitorzip/PyTorch-CycleGAN. However, this implementation is not a verbatim copy of the aforementioned implementation, as it was heavily modified and completely rewritten for practical and educational purposes. 
"""

In [None]:
!unzip /content/data_img_A.zip -d /content/data_img_A
!unzip /content/data_img_B.zip -d /content/data_img_B

In [None]:
import itertools
import os

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils
from tqdm import tqdm

In [None]:
class ResNetBlock(nn.Module):
    """
    Residual block used inside the generator of this Cycle-GAN.

    The block runs its input through two (reflection pad -> convolution -> instance norm)
    stages separated by a ReLU, then adds the result to the unchanged input.
    """
    def __init__(self, in_channels: int, r_padding: int = 1, kernel_size : int = 3):
        """
        Creates a ResNetBlock instance

        Args:
            in_channels (int) : number of channels in the input tensor; the block outputs the same number of channels.
            r_padding (int) : padding applied on every side by each torch.nn.ReflectionPad2d. Default is 1.
            kernel_size (int) : height and width of the window of each torch.nn.Conv2d layer. Default is 3.
        """
        super().__init__()

        # the convolutions keep the channel count, so output channels mirror the input
        self._out_channels = in_channels

        def conv_stage() -> list:
            """Returns one pad -> conv -> norm stage as a list of layers."""
            return [
                nn.ReflectionPad2d(padding=r_padding),
                nn.Conv2d(in_channels=in_channels, out_channels=self._out_channels, kernel_size=kernel_size),
                nn.InstanceNorm2d(num_features=in_channels),
            ]

        # stages are created in forward order so layer indices inside the Sequential stay stable
        first_stage = conv_stage()
        second_stage = conv_stage()
        layers = first_stage + [nn.ReLU(inplace=True)] + second_stage

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Adds the output of the convolutional path to the input (skip connection)."""
        residual = self.model(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """Image-to-image generator network of the Cycle-GAN"""

    def __init__(self, start_in: int = 3, start_out: int = 64, end_out: int = 3, ends_kernel_size: int = 7, mid_kernel_size: int = 3, r_padding: int = 3, padding: int = 1, stride: int = 2, n_resnet = 9) -> None:
        """
        Creates a Generator instance as a single torch.nn.Sequential.

        With the default arguments the layer stack is: c7s1-64, d128, d256, nine R256 residual blocks,
        two transposed-convolution upsampling blocks (producing 128 and then 64 channels) and a final
        7x7 convolution to end_out channels followed by Tanh.

        Args:
            start_in (int): number of channels in input tensor. Default is 3.
            start_out (int): number of channels produced by first torch.nn.Conv2d layer. Default is 64.
            end_out (int): number of channels in final tensor. Default is 3.
            ends_kernel_size (int): height and width of the convolutional window in the first and last torch.nn.Conv2d layer. Default is 7.
            mid_kernel_size (int): height and width of the convolutional window in the down- and upsampling layers. Default is 3.
            r_padding (int) : padding on every side for the torch.nn.ReflectionPad2d layers before the first and last convolution. Default is 3.
            padding (int): zero-padding of the down- and upsampling layers; also passed as output_padding of the upsampling layers. Default is 1.
            stride (int) : stride of the down- and upsampling layers. Default is 2.
            n_resnet (int) : number of ResNetBlock instances in the middle of the model. Default is 9.
        """
        super().__init__()

        # number of downsampling / upsampling blocks around the residual blocks
        self.NUM_DOWNSAMPLE = 2
        self.NUM_UPSAMPLE = 2

        # input stem: reflection pad, wide convolution, norm, activation
        self._arg_model = [
            nn.ReflectionPad2d(padding=r_padding),
            nn.Conv2d(in_channels=start_in, out_channels=start_out, kernel_size=ends_kernel_size),
            nn.InstanceNorm2d(num_features=start_out),
            nn.ReLU(inplace=True),
        ]

        # running channel bookkeeping, updated as layers are appended
        self._in_channels = start_in
        self._out_channels = start_out

        # downsampling: every block doubles the channel count
        for _ in range(self.NUM_DOWNSAMPLE):
            self._in_channels, self._out_channels = self._out_channels, self._out_channels * 2
            self._arg_model.extend(
                self._downsample(
                    in_channels=self._in_channels,
                    out_channels=self._out_channels,
                    padding=padding,
                    kernel_size=mid_kernel_size,
                    stride=stride,
                )
            )

        # residual blocks operate at the widest channel count
        self._arg_model.extend(ResNetBlock(in_channels=self._out_channels) for _ in range(n_resnet))

        # upsampling: every block halves the channel count
        for _ in range(self.NUM_UPSAMPLE):
            self._in_channels, self._out_channels = self._out_channels, self._out_channels // 2
            self._arg_model.extend(
                self._upsample(
                    in_channels=self._in_channels,
                    out_channels=self._out_channels,
                    kernel_size=mid_kernel_size,
                    padding=padding,
                    output_padding=padding,
                    stride=stride,
                )
            )

        # output head: reflection pad, wide convolution, Tanh
        self._arg_model.extend([
            nn.ReflectionPad2d(padding=r_padding),
            nn.Conv2d(in_channels=self._out_channels, out_channels=end_out, kernel_size=ends_kernel_size),
            nn.Tanh(),
        ])

        # assemble the full network
        self.model = nn.Sequential(*self._arg_model)


    def forward(self, x):
        """Runs x through the sequential model and returns the result"""
        return self.model(x)


    def _downsample(self, in_channels: int, out_channels: int, padding: int, kernel_size: int, stride: int) -> list:
        """
        Creates the layers of one downsampling block of the generator

        Args:
            in_channels (int): number of channels in input tensor
            out_channels (int): number of channels produced by the torch.nn.Conv2d layer
            padding (int): size of zero-padding in the torch.nn.Conv2d layer.
            kernel_size (int): height and width of the convolutional window in the torch.nn.Conv2d layer.
            stride (int) : stride of the torch.nn.Conv2d layer.
        Returns:
            list: [torch.nn.Conv2d, torch.nn.InstanceNorm2d, torch.nn.ReLU]
        """
        conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride)
        norm = nn.InstanceNorm2d(num_features=out_channels)
        return [conv, norm, nn.ReLU(inplace=True)]


    def _upsample(self, in_channels: int, out_channels: int, kernel_size: int, padding: int, output_padding: int, stride: int) -> list:
        """
        Creates the layers of one upsampling block of the generator

        Args:
            in_channels (int): number of channels in input tensor
            out_channels (int): number of channels produced by the torch.nn.ConvTranspose2d layer
            kernel_size (int): height and width of the convolutional window in the torch.nn.ConvTranspose2d layer.
            padding (int): padding argument of the torch.nn.ConvTranspose2d layer.
            output_padding (int) : additional size added to the output of the torch.nn.ConvTranspose2d layer
            stride (int) : stride of the torch.nn.ConvTranspose2d layer.
        Returns:
            list: [torch.nn.ConvTranspose2d, torch.nn.InstanceNorm2d, torch.nn.ReLU]
        """
        deconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, output_padding=output_padding, stride=stride)
        norm = nn.InstanceNorm2d(num_features=out_channels)
        return [deconv, norm, nn.ReLU(inplace=True)]

In [None]:
class Discriminator(nn.Module):
    """
    Discriminator for the Cycle-GAN.

    A stack of strided convolutional groups followed by a one-channel convolution; the
    resulting score map is averaged over its spatial dimensions to give one score per image.
    """

    def __init__(self, start_in: int = 3, start_out: int = 64, kernel_size: int = 4, padding: int = 1, stride: int = 2, negative_slope: float = 0.2, num_groups: int = 4) -> None:
        """
        Creates Discriminator instance

        Args:
            start_in (int): number of channels in input tensor. Default is 3.
            start_out (int): number of channels produced by first torch.nn.Conv2d layer. Default is 64.
            kernel_size (int): height and width for the 2D convolutional window in torch.nn.Conv2d layer. Default is 4.
            padding (int): size of zero-padding in torch.nn.Conv2d layer. Default is 1.
            stride (int): stride argument for filter in torch.nn.Conv2d layer. Default is 2.
            negative_slope (float): negative slope of the torch.nn.LeakyReLU layers. Default is 0.2
            num_groups (int): number of convolutional groups in model. Default is 4.
        """
        super().__init__()

        # arguments shared by every convolutional group
        shared = dict(kernel_size=kernel_size, padding=padding, stride=stride, negative_slope=negative_slope)

        # first group has no normalization layer
        self._arg_model = self._build_conv_groups(in_channels=start_in, out_channels=start_out, normalization=False, **shared)

        self._in_channels = start_in
        self._out_channels = start_out

        # middle groups: strided convolution with normalization, doubling the channels each time
        for _ in range(num_groups - 2):
            self._in_channels = self._out_channels
            self._out_channels *= 2
            self._arg_model += self._build_conv_groups(in_channels=self._in_channels, out_channels=self._out_channels, **shared)

        # final group doubles the channels once more but uses the default stride of torch.nn.Conv2d
        self._in_channels = self._out_channels
        self._out_channels *= 2
        self._arg_model += self._build_conv_groups(in_channels=self._in_channels, out_channels=self._out_channels, has_stride=False, **shared)

        # one-channel convolution producing the score map
        self._arg_model.append(nn.Conv2d(in_channels=self._out_channels, out_channels=1, kernel_size=kernel_size, padding=padding))

        self.model = nn.Sequential(*self._arg_model)

    def forward(self, x):
        """
        Scores a batch of images.

        Returns:
            tensor of shape (batch, 1): the score map averaged over its spatial dimensions
        """
        x = self.model(x)
        # average over the full spatial extent, then flatten to one row per image
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

    def _build_conv_groups(self, in_channels: int, out_channels: int, kernel_size: int, padding: int, stride: int, negative_slope: float, normalization: bool  = True, has_stride: bool = True) -> list:
        """
        Builds convolutional 'group' consisting of torch.nn.Conv2d, (torch.nn.InstanceNorm2d), and torch.nn.LeakyReLU

        Args:
            in_channels (int): number of channels in input tensor
            out_channels (int): number of channels produced by torch.nn.Conv2d layer.
            kernel_size (int): height and width for the 2D convolutional window in torch.nn.Conv2d layer.
            padding (int): size of zero-padding in torch.nn.Conv2d layer.
            stride (int): stride argument for filter in torch.nn.Conv2d layer; ignored when has_stride is False.
            negative_slope (float): negative slope of the torch.nn.LeakyReLU layer.
            normalization (bool): determines whether or not the group will have normalization. Normalization if True, else no normalization.
            has_stride (bool): if False, the torch.nn.Conv2d layer is built with its default stride of 1 instead of `stride`. Default is True.

        Returns:
            list : list with a torch.nn.Conv2d, (torch.nn.InstanceNorm2d), and torch.nn.LeakyReLU to be a component of full Discriminator
        """
        # torch.nn.Conv2d's default stride is 1, which is what omitting the argument selects
        conv_stride = stride if has_stride else 1
        group = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=conv_stride, padding=padding)]

        if normalization:
            group.append(nn.InstanceNorm2d(num_features=out_channels))

        group.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))

        return group

In [None]:
class DatasetAB(data.Dataset):
    """
    Dataset pairing images from class A with images from class B.

    Wraps exactly two indexable datasets whose items are sequences with the image at
    position 0 (e.g. the (image, label) tuples of torchvision's ImageFolder). The shorter
    dataset is cycled so that every item of the longer one is visited.
    """

    def __init__(self, *datasets) -> None:
        """
        Creates DatasetAB instance

        Args:
            *datasets: exactly two datasets, A first and B second.

        Raises:
            ValueError: if the number of datasets is not two.
        """
        # validate up front with a real exception: an assert would be stripped under -O
        # and would only fire once the first item is requested
        if len(datasets) != 2:
            raise ValueError('DatasetAB expects exactly 2 datasets, got {}'.format(len(datasets)))
        self.datasets = datasets

    def __len__(self) -> int:
        """Returns the length of the longer of the two datasets"""
        return max(len(dataset) for dataset in self.datasets)

    def __getitem__(self, index:int) -> tuple:
        """
        Returns the first element of item `index` from dataset A and from dataset B.

        The index wraps around each dataset's own length, so the shorter dataset repeats.
        """
        dataset_a, dataset_b = self.datasets

        item_a = dataset_a[index % len(dataset_a)]
        item_b = dataset_b[index % len(dataset_b)]

        return item_a[0], item_b[0]

In [None]:
class CycleGAN:
    """
    Complete Cycle-GAN with two generators and two discriminators.

    generator_A2B maps images of class A to class B and is judged by discriminator_B;
    generator_B2A maps B to A and is judged by discriminator_A.
    """

    # file names of the checkpoints written by train() and read by generate()
    GENERATOR_A2B_FILE = 'generator_A2B.pth'
    GENERATOR_B2A_FILE = 'generator_B2A.pth'
    DISCRIMINATOR_A_FILE = 'discriminator_A.pth'
    DISCRIMINATOR_B_FILE = 'discriminator_B.pth'

    def __init__(self, lr: float = 0.0002, trainable: bool = False, lambda_a: float = 10.0, lambda_b: float = 10.0, lambda_identity: float = 0.5, checkpoint_dir: str = '/content') -> None:
        """
        Creates Cycle-GAN instance

        Args:
            lr (float): learning rate for the torch.optim.Adam optimizers. Default is 0.0002.
            trainable (bool): determines whether or not the model can be trained. True for training, False for generating. Default is False
            lambda_a (float): weight of the A -> B -> A cycle-consistency loss. Default is 10.0.
            lambda_b (float): weight of the B -> A -> B cycle-consistency loss. Default is 10.0.
            lambda_identity (float): weight of the identity losses relative to lambda_a / lambda_b
                (each identity loss is scaled by lambda_x * lambda_identity). Default is 0.5.
            checkpoint_dir (str): directory the network weights are saved to after every epoch and
                loaded from in generate(). Default is '/content'.
        """

        self.lr = lr
        self.trainable = trainable
        self.lambda_a = lambda_a
        self.lambda_b = lambda_b
        self.lambda_identity = lambda_identity
        self.checkpoint_dir = checkpoint_dir

        # set up torch device if GPU is available
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # define generator and discriminator
        # generator and discriminator can be adjusted by overriding default params.
        self.generator_A2B = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)

        #define loss functions for generators and discriminators
        self.criterion_idt = nn.L1Loss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_gan = nn.MSELoss()

        #optimizers for discriminators and generators
        self.optim_discriminator_A = optim.Adam(self.discriminator_A.parameters(), lr = self.lr, betas=(0.5, 0.999))
        self.optim_discriminator_B = optim.Adam(self.discriminator_B.parameters(), lr = self.lr, betas=(0.5, 0.999))
        self.optim_generator = optim.Adam(itertools.chain(self.generator_A2B.parameters(), self.generator_B2A.parameters()), lr = self.lr, betas=(0.5, 0.999))

        #data handling
        self.data_loader = None

    @staticmethod
    def _build_transforms(augment: bool):
        """
        Builds the image preprocessing pipeline.

        Args:
            augment (bool): if True a random horizontal flip is included (training);
                if False the pipeline is deterministic (generation).
        Returns:
            torchvision.transforms.Compose: resize to 256x256, optional flip, tensor conversion
            and normalization of every channel with mean 0.5 and std 0.5.
        """
        steps = [transforms.Resize((256, 256))]
        if augment:
            steps.append(transforms.RandomHorizontalFlip())
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean = (0.5,0.5,0.5), std=(0.5,0.5,0.5))
        ]
        return transforms.Compose(transforms=steps)

    def _checkpoint_path(self, file_name: str) -> str:
        """Returns the full path of a checkpoint file inside checkpoint_dir"""
        return os.path.join(self.checkpoint_dir, file_name)

    def load_data(self, path_A: str, path_B: str, batch_size: int = 32) -> None:
        """
        Loads data into the model

        Args:
            path_A (str): path to the directory containing the data for A
            path_B (str): path to the directory containing the data for B
            batch_size (int): batch size for the data. Default is 32.
        """

        #transformations to be applied to data. Modify as needed
        transformations = self._build_transforms(augment=True)

        #apply transformations to image in given directory.
        images = DatasetAB(
            datasets.ImageFolder(path_A, transform=transformations),
            datasets.ImageFolder(path_B, transform=transformations)
        )
        #load dataset into dataloader for use in training
        self.data_loader = data.DataLoader(
            dataset=images,
            batch_size=batch_size,
            shuffle=True,
            # pinned memory only helps (and is only available) when copying to a GPU
            pin_memory=self.device.type == 'cuda'
        )

        # ground truth for fake and real images. Retained for backward compatibility only:
        # training builds its targets per batch so they always match the prediction's shape,
        # including for a smaller final batch.
        self.target_reals = torch.ones(batch_size, device=self.device)
        self.target_fakes = torch.zeros(batch_size, device=self.device)

    def _gan_loss(self, prediction: torch.Tensor, is_real: bool) -> torch.Tensor:
        """
        Computes the adversarial loss of a discriminator prediction.

        Args:
            prediction (torch.Tensor): output of a discriminator
            is_real (bool): True to compare against a target of ones, False against zeros
        Returns:
            torch.Tensor: scalar loss from criterion_gan
        """
        # build the target from the prediction itself so shape, dtype and device always agree
        target = torch.ones_like(prediction) if is_real else torch.zeros_like(prediction)
        return self.criterion_gan(prediction, target)

    def _generator_step(self, real_a: torch.Tensor, real_b: torch.Tensor) -> tuple:
        """
        Runs one optimization step of both generators.

        Args:
            real_a (torch.Tensor): batch of real images of class A
            real_b (torch.Tensor): batch of real images of class B
        Returns:
            tuple: (fake_a, fake_b), the generated batches detached from the graph for use
            in the discriminator steps
        """
        self.optim_generator.zero_grad()

        #GAN losses: each generator tries to make the opposing discriminator predict 'real'
        fake_b = self.generator_A2B(real_a)
        loss_gan_a2b = self._gan_loss(self.discriminator_B(fake_b), is_real=True)

        fake_a = self.generator_B2A(real_b)
        loss_gan_b2a = self._gan_loss(self.discriminator_A(fake_a), is_real=True)

        #forward and backward cycle losses
        reconstructed_a = self.generator_B2A(fake_b)
        loss_cycle_aba = self.criterion_cycle(reconstructed_a, real_a) * self.lambda_a

        reconstructed_b = self.generator_A2B(fake_a)
        loss_cycle_bab = self.criterion_cycle(reconstructed_b, real_b) * self.lambda_b

        #identity losses: a generator fed an image of its target class should return it unchanged
        identity_a = self.generator_B2A(real_a)
        loss_identity_a = self.criterion_idt(identity_a, real_a) * self.lambda_a * self.lambda_identity

        identity_b = self.generator_A2B(real_b)
        loss_identity_b = self.criterion_idt(identity_b, real_b) * self.lambda_b * self.lambda_identity

        total_loss = loss_gan_a2b + loss_gan_b2a + loss_cycle_aba + loss_cycle_bab + loss_identity_a + loss_identity_b
        total_loss.backward()

        self.optim_generator.step()

        return fake_a.detach(), fake_b.detach()

    def _discriminator_step(self, discriminator: nn.Module, optimizer: optim.Optimizer, real: torch.Tensor, fake: torch.Tensor) -> None:
        """
        Runs one optimization step of a single discriminator.

        Args:
            discriminator (nn.Module): discriminator to update
            optimizer (optim.Optimizer): optimizer holding that discriminator's parameters
            real (torch.Tensor): batch of real images of the discriminator's own class
            fake (torch.Tensor): detached batch of generated images of the same class
        """
        # zeroing here also discards gradients that reached the discriminator during the generator step
        optimizer.zero_grad()

        loss_real = self._gan_loss(discriminator(real), is_real=True)
        loss_fake = self._gan_loss(discriminator(fake), is_real=False)

        loss = 0.5 * (loss_real + loss_fake)
        loss.backward()

        optimizer.step()

    def _save_checkpoints(self) -> None:
        """Saves the state_dict of all four networks into checkpoint_dir"""
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        networks = (
            (self.generator_A2B, self.GENERATOR_A2B_FILE),
            (self.generator_B2A, self.GENERATOR_B2A_FILE),
            (self.discriminator_A, self.DISCRIMINATOR_A_FILE),
            (self.discriminator_B, self.DISCRIMINATOR_B_FILE),
        )
        for network, file_name in networks:
            torch.save(network.state_dict(), self._checkpoint_path(file_name))

    def train(self, epochs: int = 50) -> None:
        """
        Trains Cycle-GAN model and saves the weights of all networks after every epoch

        Args:
            epochs (int): number of epochs to train Cycle-GAN model for. Default is 50.

        Raises:
            RuntimeError: if the model is not trainable or no data has been loaded.
        """
        #raise necessary errors
        if not self.trainable:
            raise RuntimeError('Cannot train model when trainable is set to False')
        if self.data_loader is None:
            raise RuntimeError('No data loaded into the model for training')

        # iterate through epochs
        for _ in range(epochs):

            progress = tqdm(self.data_loader, leave = False, total=len(self.data_loader))
            #iterate through batches
            for batch_real_a, batch_real_b in progress:
                batch_real_a = batch_real_a.to(self.device)
                batch_real_b = batch_real_b.to(self.device)

                fake_a, fake_b = self._generator_step(batch_real_a, batch_real_b)

                # each discriminator is trained on real and generated images of its OWN class
                self._discriminator_step(self.discriminator_A, self.optim_discriminator_A, batch_real_a, fake_a)
                self._discriminator_step(self.discriminator_B, self.optim_discriminator_B, batch_real_b, fake_b)

            self._save_checkpoints()

    def generate(self, img_dir: str, out_dir: str, direction: str) -> None:
        """
        Generates images using trained Cycle-GAN model.

        The generator weights are loaded from the checkpoint written by train(). Images are
        processed in dataset order and written as '<index>.png'.

        Args:
            img_dir (str): directory from which images will be generated (torchvision ImageFolder layout)
            out_dir (str): directory where generated images will be outputted. Created if missing.
            direction (str): direction of image generation. Valid arguments are 'A2B' and 'B2A'

        Raises:
            ValueError: if direction is not a valid direction.
        """
        checkpoint_files = {
            'a2b': self.GENERATOR_A2B_FILE,
            'b2a': self.GENERATOR_B2A_FILE,
        }
        key = direction.lower()

        #raise error if invalid direction is passed
        if key not in checkpoint_files:
            raise ValueError('{} is not a valid direction'.format(direction))

        #load in model from same .pth file path as in train() method
        temp_generator = Generator().to(self.device)
        state = torch.load(self._checkpoint_path(checkpoint_files[key]), map_location=self.device)
        temp_generator.load_state_dict(state)
        temp_generator.eval()

        os.makedirs(out_dir, exist_ok=True)

        #load dataset into dataloader for use in generating
        temp_data_loader = data.DataLoader(
            # no random augmentation when generating
            dataset=datasets.ImageFolder(img_dir, transform=self._build_transforms(augment=False)),
            batch_size=1,
            # keep dataset order so output file indices are reproducible
            shuffle=False,
            pin_memory=self.device.type == 'cuda'
        )
        with torch.no_grad():
            for index, batch in enumerate(temp_data_loader):
                #generate image and map it from the Tanh range [-1, 1] to [0, 1]
                img = (temp_generator(batch[0].to(self.device)) + 1)/2.0

                #save image in directory
                utils.save_image(img, os.path.join(out_dir, '{}.png'.format(index)))

In [None]:
# Build a trainable Cycle-GAN, point it at the two unzipped image folders and train it.
c = CycleGAN(trainable=True)
c.load_data(
    path_A='/content/data_img_A',
    path_B='/content/data_img_B',
    batch_size=8,
)
c.train(epochs=75)
print('done')