#**CycleGAN applying to monet2photo**

In this notebook we are going to implement the CycleGAN architecture for the monet2photo dataset.

## Setup

In [0]:
##############################
#  SETUP PYTHON ENVIRONMENT  #  
##############################

!pip install torch torchvision
conda update pytorch torchvision

In [0]:
!wget -c https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/monet2photo.zip -O monet2photo.zip
!unzip monet2photo.zip

[1;30;43mLe flux de sortie a été tronqué et ne contient que les 5000 dernières lignes.[0m
  inflating: monet2photo/trainB/2014-01-10 12:46:39.jpg  
  inflating: monet2photo/trainB/2015-02-23 11:45:00.jpg  
  inflating: monet2photo/trainB/2015-05-23 09:38:08.jpg  
  inflating: monet2photo/trainB/2016-01-15 00:48:49.jpg  
  inflating: monet2photo/trainB/2016-12-27 18:15:26.jpg  
  inflating: monet2photo/trainB/2017-01-02 11:16:05.jpg  
  inflating: monet2photo/trainB/2015-08-06 11:08:32.jpg  
  inflating: monet2photo/trainB/2015-12-16 07:43:36.jpg  
  inflating: monet2photo/trainB/2016-03-25 12:57:28.jpg  
  inflating: monet2photo/trainB/2016-05-10 11:44:44.jpg  
  inflating: monet2photo/trainB/2015-03-11 08:01:52.jpg  
  inflating: monet2photo/trainB/2015-08-30 05:26:05.jpg  
  inflating: monet2photo/trainB/2015-10-03 19:00:56.jpg  
  inflating: monet2photo/trainB/2016-06-08 02:40:47.jpg  
  inflating: monet2photo/trainB/2016-02-28 23:11:19.jpg  
  inflating: monet2photo/trainB/2016-0

## Helper functions

In [0]:
import glob
import os
from PIL import Image

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

def init_normal_weights(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

def print_models(G_XtoY, G_YtoX, D_X, D_Y):
    
    """ Prints model information for the generators and discriminators. """
    print("                 G_XtoY                ")
    print("---------------------------------------")
    print(G_XtoY)
    print("---------------------------------------")

    print("                 G_YtoX                ")
    print("---------------------------------------")
    print(G_YtoX)
    print("---------------------------------------")

    print("                  D_X                  ")
    print("---------------------------------------")
    print(D_X)
    print("---------------------------------------")

    print("                  D_Y                  ")
    print("---------------------------------------")
    print(D_Y)
    print("---------------------------------------")
      
class LambdaLR:
    
    """ LambdaLR or LambdaLearningRate
        Allow us to decrease the learning rate from a specific epoch ('decay_start_epoch')
        This accelerates learning and become and allows to be more precise for larger epochs.
    """
    
    def __init__(self, n_epochs, offset, decay_start_epoch):
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch
        
    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

class ImageDataset(Dataset):
    
    """ ImageDataSet
        Allow us to preprocess/transform data into a desired format (resize, crop, normalize, rgb, tensor)
    """
    
    def __init__(self, root, transforms_, mode):
        self.transform = transforms.Compose(transforms_)
        self.files_X = sorted(glob.glob(os.path.join(root, "%sA" % mode) + "/*.*"))
        self.files_Y = sorted(glob.glob(os.path.join(root, "%sB" % mode) + "/*.*"))
        
    def __to_rgb(self, image):
        rgb_image = Image.new("RGB", image.size)
        rgb_image.paste(image)
        return rgb_image

    def __getitem__(self, index):
        image_X = Image.open(self.files_X[index % len(self.files_X)])
        image_Y = Image.open(self.files_Y[index % len(self.files_Y)])
        if image_X.mode != "RGB":
            image_X = self.__to_rgb(image_X)
        if image_Y.mode != "RGB":
            image_Y = self.__to_rgb(image_Y)
        item_X = self.transform(image_X)
        item_Y = self.transform(image_Y)
        return {"X": item_X, "Y": item_Y}
    
    def __len__(self):
        return max(len(self.files_X), len(self.files_Y))

## Architectures

In [0]:
import torch.nn as nn

class ResnetBlock(nn.Module):
    
    """Defines the architecture of a ResidualBlock
       Note: We decided to choose the number of ResidualBlock we would like to put in both generators
    """
    
    def __init__(self, in_features):
        super(ResnetBlock, self).__init__()

        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):

    """Defines the architecture of the generator network.
       Note: Both generators G_XtoY and G_YtoX have the same architecture in this assignment.
    """

    def __init__(self, opt):
        super(Generator, self).__init__()

        # Initial convolution block (first convolutional layer)
        out_features = 64
        model = [
            # Pads the input tensor using the reflection of the input boundary
            nn.ReflectionPad2d(opt.channels),
            nn.Conv2d(opt.channels, out_features, 7),
            # Applies Instance Normalization over a 4D input
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Two convolution blocks for downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(opt.n_residual_blocks):
            model += [ResnetBlock(out_features)]

        # Two convolution blocks for upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                # Upsamples a given multi-channel data. Double the spatial space.
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer
        model += [nn.ReflectionPad2d(opt.channels), nn.Conv2d(out_features, opt.channels, 7), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    
    """Defines the architecture of the discriminator network.
       Note: Both discriminators D_X and D_Y have the same architecture in this assignment.
    """
    
    def __init__(self, opt):
        super(Discriminator, self).__init__()

        # Calculate output shape of image discriminator
        self.output_shape = (1, opt.img_size // 2 ** 4, opt.img_size // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize = True):
            """ Returns downsampling layers of each discriminator block """
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride = 2, padding = 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace = True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 64, normalize = False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding = 1)
        )

    def forward(self, x):
        return self.model(x)

## CycleGAN

In [0]:
##############################
#    MODEL INITIALIZATION    # 
##############################

def create_model(opt):
    
    """ Builds the generators and discriminators. """
    
    # Initialize generator and discriminator
    G_XtoY = Generator(opt)
    G_YtoX = Generator(opt)
    D_X = Discriminator(opt)
    D_Y = Discriminator(opt)
    
    print_models(G_XtoY, G_YtoX, D_X, D_Y)
    
    if opt.cuda:
        G_XtoY = G_XtoY.cuda()
        G_YtoX = G_YtoX.cuda()
        D_X = D_X.cuda()
        D_Y = D_Y.cuda()
        
    if opt.epoch != 0:
        # Load pretrained models
        G_XtoY.load_state_dict(torch.load("saved_models/G_XtoY_%d.pth" % opt.epoch))
        G_YtoX.load_state_dict(torch.load("saved_models/G_YtoX_%d.pth" % opt.epoch))
        D_X.load_state_dict(torch.load("saved_models/D_X_%d.pth" % opt.epoch))
        D_Y.load_state_dict(torch.load("saved_models/D_Y_%d.pth" % opt.epoch))
    else:
        # Initialize weights
        G_XtoY.apply(init_normal_weights)
        G_YtoX.apply(init_normal_weights)
        D_X.apply(init_normal_weights)
        D_Y.apply(init_normal_weights)
        
    return G_XtoY, G_YtoX, D_X, D_Y

##############################
#       MODEL TRAINING       # 
##############################
    
def training_loop(train_dataloader, test_dataloader, opt):
    
    """ Runs the training loop.
        * Saves checkpoint every opt.checkpoint_interval iterations
        * Saves generated samples every opt.sample_interval iterations
    """
    
    # Create generators and discriminators
    G_XtoY, G_YtoX, D_X, D_Y = create_model(opt)
    
    # Losses
    loss_GAN = torch.nn.MSELoss()
    loss_cycle = torch.nn.L1Loss()
    loss_identity = torch.nn.L1Loss()

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(G_XtoY.parameters(), G_YtoX.parameters()), lr = opt.lr, betas = (opt.b1, opt.b2))
    optimizer_D_X = torch.optim.Adam(D_X.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2))
    optimizer_D_Y = torch.optim.Adam(D_Y.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2))

    # Learning rate update schedulers
    LambdaLR_schedular_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda = LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    LambdaLR_scheduler_D_X = torch.optim.lr_scheduler.LambdaLR(optimizer_D_X, lr_lambda = LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    LambdaLR_scheduler_D_Y = torch.optim.lr_scheduler.LambdaLR(optimizer_D_Y, lr_lambda = LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    
    Tensor = torch.Tensor
    if opt.cuda:  
        loss_GAN.cuda()
        loss_cycle.cuda()
        loss_identity.cuda()
        Tensor = torch.cuda.FloatTensor

    def sample_images(batches_done):
        """Saves a generated sample from the test set"""
        imgs = next(iter(test_dataloader))
        G_XtoY.eval()
        G_YtoX.eval()
        real_X = Variable(imgs["X"].type(Tensor))
        fake_Y = G_XtoY(real_X)
        real_Y = Variable(imgs["Y"].type(Tensor))
        fake_X = G_YtoX(real_Y)
        # Arange images along x-axis
        real_X = make_grid(real_X, nrow = 5, normalize = True)
        real_Y = make_grid(real_Y, nrow = 5, normalize = True)
        fake_X = make_grid(fake_X, nrow = 5, normalize = True)
        fake_Y = make_grid(fake_Y, nrow = 5, normalize = True)
        # Arange images along y-axis
        image_grid = torch.cat((real_X, fake_Y, real_Y, fake_X), 1)
        save_image(image_grid, "images/%s.png" % batches_done, normalize = False)
    
    losses_models_G = pd.DataFrame(np.zeros((len(train_dataloader), opt.n_epochs - opt.epoch + 1)))
    losses_models_D = pd.DataFrame(np.zeros((len(train_dataloader), opt.n_epochs - opt.epoch + 1)))
    losses_models_G.columns = range(opt.epoch, opt.n_epochs+1)
    losses_models_D.columns = range(opt.epoch, opt.n_epochs+1)
    # Training
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(train_dataloader):
            
            # Set model input
            real_X = Variable(batch["X"].type(Tensor))
            real_Y = Variable(batch["Y"].type(Tensor))
            
            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_X.size(0), *D_X.output_shape))), requires_grad = False)
            fake = Variable(Tensor(np.zeros((real_X.size(0), *D_X.output_shape))), requires_grad = False)
                        
            # -----------------------
            #  Train Discriminator X
            # -----------------------
            
            optimizer_D_X.zero_grad()
        
            # Real loss
            loss_real_D_X = loss_GAN(D_X(real_X), valid)
            
            # Fake loss
            fake_X = G_YtoX(real_Y)
            loss_fake_D_X = loss_GAN(D_X(fake_X), fake)
            
            
            # Total loss
            loss_D_X = (loss_real_D_X + loss_fake_D_X) / 2
        
            loss_D_X.backward()
            optimizer_D_X.step()
            
            # -----------------------
            #  Train Discriminator Y
            # -----------------------
        
            optimizer_D_Y.zero_grad()
            
            # Real loss
            loss_real_D_Y = loss_GAN(D_Y(real_Y), valid)
            
            # Fake loss
            fake_Y = G_XtoY(real_X)
            loss_fake_D_Y = loss_GAN(D_Y(fake_Y), fake)
            
            # Total loss
            loss_D_Y = (loss_real_D_Y + loss_fake_D_Y) / 2
        
            loss_D_Y.backward()
            optimizer_D_Y.step()
        
            loss_D = (loss_D_X + loss_D_Y) / 2
            losses_models_D[epoch][i] = loss_D
            
            # ------------------------------
            #  Train Generators XtoY and YtoX
            # -------------------------------
            
            G_XtoY.train()
            G_YtoX.train()
            
            optimizer_G.zero_grad()
            
            # GAN loss
            fake_Y = G_XtoY(real_X)
            loss_GAN_XtoY = loss_GAN(D_Y(fake_Y), valid)
            fake_X = G_YtoX(real_Y)
            loss_GAN_YtoX = loss_GAN(D_X(fake_X), valid)
            loss_GAN_G = (loss_GAN_XtoY + loss_GAN_YtoX) / 2
            
            # Cycle loss
            recov_X = G_YtoX(fake_Y)
            loss_cycle_X = loss_cycle(recov_X, real_X)
            recov_Y = G_XtoY(fake_X)
            loss_cycle_Y = loss_cycle(recov_Y, real_Y)
            loss_cycle_G = (loss_cycle_X + loss_cycle_Y) / 2
            
            # Identity loss
            loss_identity_X = loss_identity(G_YtoX(real_X), real_X)
            loss_identity_Y = loss_identity(G_XtoY(real_Y), real_Y)
            loss_identity_G = (loss_identity_X + loss_identity_Y) / 2
            
            # Total loss
            loss_G = loss_GAN_G + opt.lambda_cyc * loss_cycle_G + opt.lambda_id * loss_identity_G
            
            loss_G.backward()
            optimizer_G.step()
            losses_models_G[epoch][i] = loss_G
        
            # --------------
            #  Log Progress
            # --------------
        
            batches_done = epoch * len(train_dataloader) + i
        
            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(train_dataloader),
                    loss_D.item(),
                    loss_G.item()
                )
            )
        
            # Save sample image at interval
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)
        
        # Update learning rates
        LambdaLR_schedular_G.step()
        LambdaLR_scheduler_D_X.step()
        LambdaLR_scheduler_D_Y.step()
        
        # Save discriminators and generators lossses at each epoch
        losses_models_G.to_pickle("losses_models/losses_models_G_%d" % epoch)
        losses_models_D.to_pickle("losses_models/losses_models_D_%d" % epoch)
            
        # Save model at checkpoints
        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            torch.save(G_XtoY.state_dict(), "saved_models/G_XtoY_%d.pth" % epoch)
            torch.save(G_YtoX.state_dict(), "saved_models/G_YtoX_%d.pth" % epoch)
            torch.save(D_X.state_dict(), "saved_models/D_X_%d.pth" % epoch)
            torch.save(D_Y.state_dict(), "saved_models/D_Y_%d.pth" % epoch)

Models moved to GPU.


In [0]:
training_loop(train_dataloader, test_dataloader, opt)

                 G_XtoY                
---------------------------------------
Generator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace=True)
    (10): ResnetBlock(
      (block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3)

KeyboardInterrupt: ignored