CycleGAN \\
Kasra Borazjani - 810196662 \\
Hamid Salemi - 810196479

# Part 1. Dataloader

In [40]:
import os
import cv2
import numpy as np
from abc import ABC, abstractmethod
import torch
from collections import OrderedDict
from torch.optim import lr_scheduler
import itertools
import torch.nn as nn
from torch.nn import init
import functools
import time

In [66]:
# File-name suffixes (case-sensitive) that are treated as images.
IMG_EXTENSIONS = [
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.png', '.PNG',
    '.ppm', '.PPM',
    '.bmp', '.BMP',
    '.tif', '.TIF',
    '.tiff', '.TIFF',
]


def is_image_file(filename):
    """Tell whether *filename* ends with one of the suffixes in IMG_EXTENSIONS."""
    # str.endswith accepts a tuple of candidate suffixes and checks them all.
    suffixes = tuple(IMG_EXTENSIONS)
    return filename.endswith(suffixes)

In [41]:
class ImagePool():
    """Buffer of previously generated images.

    `query` lets the caller feed a discriminator with a mix of the images
    just produced and images stored on earlier calls, instead of only the
    latest ones.
    """

    def __init__(self, pool_size):
        """Create the pool.

        Parameters:
            pool_size (int) -- capacity of the buffer; 0 disables buffering.
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch drawn from the current images and the pool.

        Parameters:
            images -- batch of images (iterated along dimension 0).

        While the pool is not full every incoming image is stored and
        returned unchanged.  Once it is full, each incoming image is, with
        probability 0.5, swapped with a uniformly chosen stored image (the
        stored one is returned); otherwise it is returned as is.
        """
        if self.pool_size == 0:  # buffering disabled: pass the batch through
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:
                # Pool not full yet: store the image and return it as well.
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            elif np.random.uniform(0, 1) > 0.5:
                # np.random.randint excludes the upper bound, so passing
                # pool_size covers every slot 0 .. pool_size - 1.
                random_id = np.random.randint(0, self.pool_size)
                tmp = self.images[random_id].clone()
                self.images[random_id] = image
                return_images.append(tmp)
            else:
                return_images.append(image)
        if not return_images:  # empty batch: nothing to concatenate
            return images
        return torch.cat(return_images, 0)

In [68]:
def make_dataset(path, max_size):
    """Gather image file paths found anywhere under *path*.

    Walks the directory tree, keeps every file accepted by `is_image_file`,
    and returns at most *max_size* of the collected paths.
    """
    found = [
        os.path.join(folder, candidate)
        for folder, _, candidates in sorted(os.walk(path))
        for candidate in candidates
        if is_image_file(candidate)
    ]
    limit = min(max_size, len(found))
    return found[:limit]

In [43]:
class customDataSet():
  """Unaligned two-domain image dataset.

  Images of domain A are read from `<dataroot>/<phase>A` and images of
  domain B from `<dataroot>/<phase>B`.  Each item pairs one A image with one
  B image; the pairing is by index when `serial_batches` is true and random
  otherwise.
  """

  def __init__(self, dataroot, phase, max_size, direction, in_channels, out_channels, serial_batches):
    """Index the image files of both domains.

    Parameters:
        dataroot (str)        -- dataset root directory
        phase (str)           -- sub-directory prefix, e.g. 'train' -> trainA / trainB
        max_size (int)        -- maximum number of files kept per domain
        direction (str)       -- 'AtoB' or 'BtoA'
        in_channels (int)     -- channel count of the network input
        out_channels (int)    -- channel count of the network output
        serial_batches (bool) -- pair A and B by index instead of randomly
    """
    self.dir_A = os.path.join(dataroot, phase + 'A')
    self.dir_B = os.path.join(dataroot, phase + 'B')
    self.paths_A = sorted(make_dataset(self.dir_A, max_size))
    self.paths_B = sorted(make_dataset(self.dir_B, max_size))
    self.A_size = len(self.paths_A)
    self.B_size = len(self.paths_B)
    btoA = direction == 'BtoA'
    # With direction 'BtoA' the roles of the two channel counts are swapped.
    self.input_nc = out_channels if btoA else in_channels
    self.output_nc = in_channels if btoA else out_channels
    self.serial_batches = serial_batches

  @staticmethod
  def _load_image(path, channels):
    """Read *path* with OpenCV and return a float32 CxHxW tensor in [-1, 1]."""
    flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR
    img = cv2.imread(path, flag)
    if img is None:  # cv2.imread signals failure by returning None
      raise IOError('could not read image file %s' % path)
    if channels == 1:
      img = img[:, :, np.newaxis]
    else:
      img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads colour images as BGR
    chw = np.ascontiguousarray(img.transpose(2, 0, 1), dtype=np.float32)
    # Map 0..255 to -1..1, the range of the generator's final Tanh.
    return torch.from_numpy(chw / 127.5 - 1.0)

  def __getitem__(self, index):
    """Return {'A', 'B', 'A_paths', 'B_paths'} for the given index.

    No resizing or cropping is done here, so batching with batch_size > 1
    requires all images to share one size.
    """
    if self.A_size == 0 or self.B_size == 0:
      raise IndexError('dataset is empty: no images found in %s or %s' % (self.dir_A, self.dir_B))
    A_path = self.paths_A[index % self.A_size]

    if self.serial_batches:
      index_B = index % self.B_size
    else:
      # torch's generator is used so the draw follows torch's seeding.
      index_B = int(torch.randint(0, self.B_size, (1,)).item())

    B_path = self.paths_B[index_B]
    A = self._load_image(A_path, self.input_nc)
    B = self._load_image(B_path, self.output_nc)

    return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

  def __len__(self):
    """Size of the larger of the two domains."""
    return max(self.A_size, self.B_size)


In [64]:
class customDataLoader():
  """Wraps `customDataSet` in a torch DataLoader and caps iteration at `max_size` samples."""

  def __init__(self, dataroot, phase, max_size, direction, in_channels, out_channels, serial_batches, batch_size, num_threads):
    """Build the dataset and the underlying torch DataLoader.

    The loader shuffles unless `serial_batches` is true and uses
    `num_threads` worker processes.
    """
    self.dataset = customDataSet(dataroot, phase, max_size, direction, in_channels, out_channels, serial_batches)
    self.dataloader = torch.utils.data.DataLoader(
        self.dataset,
        batch_size=batch_size,
        shuffle=not serial_batches,
        num_workers=int(num_threads))
    self.max_size = max_size
    self.batch_size = batch_size

  def load_data(self):
    """Return this object, which is itself iterable over batches."""
    return self

  # Original spelling of the accessor, kept so existing callers keep working.
  load_Data = load_data

  def __len__(self):
    """Number of samples (not batches), capped at `max_size`."""
    return min(len(self.dataset), self.max_size)

  def __iter__(self):
    """Yield batches until `max_size` samples have been produced."""
    for i, data in enumerate(self.dataloader):
      if i * self.batch_size >= self.max_size:
        break
      yield data

In [45]:
def create_dataset(dataroot, phase, max_size, direction, in_channels, out_channels, serial_batches, batch_size, num_threads):
    """Build a `customDataLoader` from the given options and return its iterable data source."""
    data_loader = customDataLoader(dataroot, phase, max_size, direction, in_channels, out_channels, serial_batches, batch_size, num_threads)
    # The loader class names this accessor `load_Data`; the lower-case
    # spelling used here before raised AttributeError.
    dataset = data_loader.load_Data()
    return dataset

# Part 2. Model

## 2.1 Generator

### 2.1.1. Residual Blocks

In [46]:
class ResnetBlock(nn.Module):
    """Residual block: two padded 3x3 convolutions whose output is added to the input."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Build the convolutional branch; see `build_conv_block` for the arguments."""
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the convolutional branch as an nn.Sequential.

        Parameters:
            dim (int)          -- number of channels in and out of each conv
            padding_type (str) -- 'reflect', 'replicate' or 'zero'
            norm_layer         -- callable building a normalization layer for `dim` channels
            use_dropout (bool) -- insert Dropout(0.5) after the first activation
            use_bias (bool)    -- whether the convolutions carry a bias

        Raises NotImplementedError for any other padding_type.
        """

        def padded_conv():
            # Explicit padding layers go in front of an unpadded conv; 'zero'
            # padding is delegated to the conv's own `padding` argument.
            if padding_type == 'reflect':
                prefix, conv_pad = [nn.ReflectionPad2d(1)], 0
            elif padding_type == 'replicate':
                prefix, conv_pad = [nn.ReplicationPad2d(1)], 0
            elif padding_type == 'zero':
                prefix, conv_pad = [], 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            conv = nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias)
            return prefix + [conv, norm_layer(dim)]

        layers = padded_conv()
        layers.append(nn.ReLU(True))
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers.extend(padded_conv())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the convolutional branch applied to it."""
        return x + self.conv_block(x)

### 2.1.2. Resnet-Based Generator

In [47]:
class ResnetGenerator(nn.Module):
    """Generator: 7x7 stem, two stride-2 downsampling convs, `n_blocks` residual
    blocks, two transposed-conv upsampling stages and a 7x7 Tanh output layer."""

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Build the network.

        Parameters:
            input_nc (int)     -- channels of the input image
            output_nc (int)    -- channels of the output image
            ngf (int)          -- filters of the first convolution
            norm_layer         -- normalization layer factory
            use_dropout (bool) -- use dropout inside the residual blocks
            n_blocks (int)     -- number of residual blocks
            padding_type (str) -- padding used in the residual blocks
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()

        # Convolutions get a bias only when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        n_downsampling = 2

        # stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        # encoder: each stage halves the resolution and doubles the channels
        for stage in range(n_downsampling):
            width = ngf * 2 ** stage
            layers.append(nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1, bias=use_bias))
            layers.append(norm_layer(width * 2))
            layers.append(nn.ReLU(True))

        # residual trunk at the bottleneck width
        bottleneck = ngf * 2 ** n_downsampling
        for _ in range(n_blocks):
            layers.append(ResnetBlock(bottleneck, padding_type=padding_type, norm_layer=norm_layer,
                                      use_dropout=use_dropout, use_bias=use_bias))

        # decoder: each stage doubles the resolution and halves the channels
        for stage in range(n_downsampling):
            width = ngf * 2 ** (n_downsampling - stage)
            half = int(width / 2)
            layers.append(nn.ConvTranspose2d(width, half, kernel_size=3, stride=2,
                                             padding=1, output_padding=1, bias=use_bias))
            layers.append(norm_layer(half))
            layers.append(nn.ReLU(True))

        # output head
        layers.append(nn.ReflectionPad2d(3))
        layers.append(nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run the input through the sequential model."""
        return self.model(input)

Initializing the network weights:

In [48]:
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize the weights of every Conv/Linear/BatchNorm2d layer in *net*.

    Parameters:
        net (nn.Module)   -- network to initialize (modified in place)
        init_type (str)   -- 'normal', 'xavier', 'kaiming' or 'orthogonal'
        init_gain (float) -- std for 'normal', gain for 'xavier'/'orthogonal'

    Raises:
        NotImplementedError -- for any other init_type.  Before, the value
        was only printed and a normal initialization was always applied.
    """
    if init_type not in ('normal', 'xavier', 'kaiming', 'orthogonal'):
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

    def init_func(m):  # applied to every sub-module by net.apply
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            else:  # 'orthogonal'
                init.orthogonal_(weight.data, gain=init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname and weight is not None:
            # weight is None when the layer was built with affine=False;
            # there is nothing to initialize in that case.
            init.normal_(weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>

Initializing a network: move it to the GPU(s) when GPU ids are given, then initialize its weights:

In [49]:
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Prepare *net* for use and return it.

    When `gpu_ids` is non-empty the network is moved to the first listed
    device and wrapped in DataParallel over all listed devices (CUDA must
    be available).  The weights are then initialized with `init_weights`.
    """
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert(torch.cuda.is_available())
        first_device = gpu_ids[0]
        net.to(first_device)
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net

batch normalization

In [50]:
def get_norm_layer(norm_type='batch'):
    """Return a factory for the requested normalization layer.

    Only 'batch' is supported: a partial of nn.BatchNorm2d with learnable
    affine parameters and running statistics.  Any other value raises
    NotImplementedError.
    """
    if norm_type != 'batch':
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)

total generator definition:

In [51]:
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create and initialize a generator.

    Only netG == 'resnet_6blocks' is supported (a ResnetGenerator with six
    residual blocks); other names raise NotImplementedError.  The network
    is returned after being passed through `init_net`.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netG != 'resnet_6blocks':
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)

    generator = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer,
                                use_dropout=use_dropout, n_blocks=6)
    return init_net(generator, init_type, init_gain, gpu_ids)

## 2.2. Discriminator

### 2.2.1 PatchGAN

PatchGAN definition

In [52]:
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of 4x4 convolutions ending in a
    1-channel prediction map."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Build the discriminator.

        Parameters:
            input_nc (int) -- channels of the input image
            ndf (int)      -- filters of the first convolution
            n_layers (int) -- number of stride-2 convolutions, counting the first
            norm_layer     -- normalization layer factory
        """
        super(NLayerDiscriminator, self).__init__()
        if isinstance(norm_layer, functools.partial):  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters (capped at 8 * ndf)
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # one more convolution at stride 1 before the prediction layer
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)

PatchGAN discriminator initialization

In [53]:
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create and initialize a discriminator.

    Parameters:
        input_nc (int)   -- channels of the images fed to the discriminator
        ndf (int)        -- filters of the first convolution
        netD (str)       -- architecture name; only 'basic' (PatchGAN) is supported
        n_layers_D (int) -- number of stride-2 conv layers in the discriminator
        norm (str)       -- normalization layer type
        init_type, init_gain, gpu_ids -- forwarded to `init_net`

    Raises NotImplementedError for an unknown netD.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        # Honour the caller's layer count (it was hard-coded to 3 before).
        net = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers_D, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)

## 2.3. GAN Loss

In [54]:
class GANLoss(nn.Module):
    """GAN objective that builds the target tensor itself, so callers only
    say whether the prediction should count as real or fake."""

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Set up the loss.

        Parameters:
            gan_mode (str)            -- only 'lsgan' (MSE loss) is supported
            target_real_label (float) -- label value for real images
            target_fake_label (float) -- label value for fake images

        Raises NotImplementedError for any other gan_mode.
        """
        super(GANLoss, self).__init__()
        # Buffers follow the module when it is moved with .to(device).
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real or fake label expanded to the shape of *prediction*."""
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Compute the loss between *prediction* and the matching label tensor."""
        if self.gan_mode not in ['lsgan']:
            # Only reachable if gan_mode was changed after construction.
            raise NotImplementedError('gan mode %s not implemented' % self.gan_mode)
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)

## 2.4. Model Definition

define scheduler:

In [55]:
def get_scheduler(optimizer, lr_policy, epoch_count, n_epochs, n_epochs_decay, lr_decay_iters):
    """Return a learning-rate scheduler for *optimizer*.

    Only lr_policy == 'linear' is supported: the rate keeps its initial
    value while `epoch + epoch_count <= n_epochs` and then decreases
    linearly, in steps of 1 / (n_epochs_decay + 1) of the initial rate.
    `lr_decay_iters` is accepted for interface compatibility and unused.

    Raises:
        NotImplementedError -- for any other lr_policy (previously the
        exception object was returned instead of being raised).
    """
    if lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy)
    return scheduler

base model placeholder:

In [56]:
class BaseModel(ABC):
    """Abstract base class for models.

    Subclasses implement `set_input`, `forward` and `optimize_parameters`,
    and fill `loss_names`, `model_names`, `visual_names` and `optimizers`.
    Networks are looked up as attributes named 'net' + model name, losses as
    'loss_' + loss name.
    """

    def __init__(self, gpu_ids, isTrain, checkpoints_dir, name, preprocess, lr_policy, epoch_count, n_epochs, n_epochs_decay, lr_decay_iters, continue_train, load_iter, epoch, verbose):
        """Store the options shared by all models.

        Parameters:
            gpu_ids (list)        -- GPU ids; empty means CPU
            isTrain (bool)        -- training (True) or test (False) mode
            checkpoints_dir (str) -- root directory for checkpoints
            name (str)            -- experiment name (sub-directory of checkpoints_dir)
            preprocess (str)      -- preprocessing mode; only checked against 'scale_width'
            lr_policy, epoch_count, n_epochs, n_epochs_decay, lr_decay_iters
                                  -- forwarded to `get_scheduler` in `setup`
            continue_train (bool) -- load saved networks before training
            load_iter (int)       -- if > 0, load the 'iter_<load_iter>' checkpoint
            epoch                 -- checkpoint suffix to load otherwise (e.g. 'latest')
            verbose (bool)        -- print the full network architecture
        """
        self.gpu_ids = gpu_ids
        self.isTrain = isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(checkpoints_dir, name)  # save all the checkpoints to save_dir
        if preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'
        self.lr_policy = lr_policy
        self.epoch_count = epoch_count
        self.n_epochs = n_epochs
        self.n_epochs_decay = n_epochs_decay
        self.lr_decay_iters = lr_decay_iters
        self.continue_train = continue_train  # read by setup()
        self.load_iter = load_iter
        self.epoch = epoch
        self.verbose = verbose

    @abstractmethod
    def set_input(self, input):
        """Unpack one batch from the data loader into model attributes."""
        pass

    @abstractmethod
    def forward(self):
        """Run the forward pass."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Compute losses and gradients and update the network weights."""
        pass

    def setup(self):
        """Create schedulers (training), load networks if requested, print networks."""
        if self.isTrain:
            self.schedulers = [get_scheduler(optimizer, self.lr_policy, self.epoch_count, self.n_epochs, self.n_epochs_decay, self.lr_decay_iters) for optimizer in self.optimizers]
        if not self.isTrain or self.continue_train:
            load_suffix = 'iter_%d' % self.load_iter if self.load_iter > 0 else self.epoch
            self.load_networks(load_suffix)
        self.print_networks(self.verbose)

    def eval(self):
        """Put every network listed in model_names into eval mode."""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Run forward and compute_visuals without tracking gradients."""
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Hook for computing extra output images; does nothing by default."""
        pass

    def get_image_paths(self):
        """Return the image paths of the current batch."""
        return self.image_paths

    def update_learning_rate(self):
        """Step every scheduler and print the old and new learning rate."""
        old_lr = self.optimizers[0].param_groups[0]['lr']
        for scheduler in self.schedulers:
            if self.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate %.7f -> %.7f' % (old_lr, lr))

    def get_current_visuals(self):
        """Return an OrderedDict mapping each visual name to its attribute value."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return an OrderedDict mapping each loss name to its value as a float."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save every network to '<save_dir>/<epoch>_net_<name>.pth'."""
        # Nothing else in this class creates the directory.
        os.makedirs(self.save_dir, exist_ok=True)
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # GPU networks are DataParallel-wrapped: save the inner
                    # module's weights from the CPU, then move back.
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Drop InstanceNorm entries from *state_dict* that the module does not hold.

        Follows the dotted key path `keys` through `module`; at the leaf it
        removes running_mean/running_var when the module has none, and always
        removes num_batches_tracked, for InstanceNorm layers.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
               (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load every network from '<save_dir>/<epoch>_net_<name>.pth'."""
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the parameter count of every network (and its architecture if verbose)."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad on all parameters of one network or a list of networks."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad


CycleGAN model class, which is a subclass of the base model class:

In [57]:
class CycleGANModel(BaseModel):
    """CycleGAN: two generators (G_A: A -> B, G_B: B -> A) and, in training,
    two discriminators (D_A judges domain B images, D_B judges domain A
    images), trained with GAN, cycle-consistency and identity losses."""

    def __init__(self, gpu_ids, isTrain, checkpoints_dir, name, preprocess, lr_policy, epoch_count, n_epochs, n_epochs_decay,
                 lr_decay_iters, continue_train, load_iter, epoch, verbose, input_nc, output_nc, ngf, netG, norm, no_dropout,
                 init_type, init_gain, ndf, netD, n_layers_D, pool_size, gan_mode, lr, beta1, direction,
                 lambda_A=10.0, lambda_B=10.0, lambda_identity=0.5):
        """Build the networks, losses and optimizers.

        The first 14 arguments are forwarded to `BaseModel.__init__`.

        Parameters (CycleGAN-specific):
            input_nc, output_nc (int) -- channels of domain A / domain B images
            ngf, ndf (int)            -- base filter counts of generators / discriminators
            netG, netD (str)          -- architecture names for `define_G` / `define_D`
            norm (str)                -- normalization layer type
            no_dropout (bool)         -- disable dropout in the generators
            init_type, init_gain      -- weight initialization options
            n_layers_D (int)          -- discriminator layer count
            pool_size (int)           -- size of the generated-image history buffers
            gan_mode (str)            -- GAN objective, passed to `GANLoss`
            lr, beta1 (float)         -- Adam learning rate and first momentum term
            direction (str)           -- 'AtoB' or 'BtoA'
            lambda_A (float)          -- weight of the cycle loss A -> B -> A
            lambda_B (float)          -- weight of the cycle loss B -> A -> B
            lambda_identity (float)   -- identity-loss weight relative to the cycle weights; 0 disables it
        """
        # The loss weights belong to this class; BaseModel.__init__ does not accept them.
        BaseModel.__init__(self, gpu_ids, isTrain, checkpoints_dir, name, preprocess, lr_policy, epoch_count, n_epochs,
                           n_epochs_decay, lr_decay_iters, continue_train, load_iter, epoch, verbose)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        self.lambda_A = lambda_A
        self.lambda_B = lambda_B
        self.lambda_identity = lambda_identity
        self.direction = direction
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_B(A) and idt_A=G_A(B)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = define_G(input_nc, output_nc, ngf, netG, norm,
                               not no_dropout, init_type, init_gain, self.gpu_ids)
        self.netG_B = define_G(output_nc, input_nc, ngf, netG, norm,
                               not no_dropout, init_type, init_gain, self.gpu_ids)

        if self.isTrain:  # define discriminators
            self.netD_A = define_D(output_nc, ndf, netD,
                                   n_layers_D, norm, init_type, init_gain, self.gpu_ids)
            self.netD_B = define_D(input_nc, ndf, netD,
                                   n_layers_D, norm, init_type, init_gain, self.gpu_ids)

        if self.isTrain:
            if self.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert(input_nc == output_nc)
            self.fake_A_pool = ImagePool(pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = GANLoss(gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=lr, betas=(beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=lr, betas=(beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack a batch dict ('A', 'B', 'A_paths', 'B_paths'), swapping domains for 'BtoA'."""
        AtoB = self.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Compute the translated and the reconstructed images for both domains."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Compute the discriminator loss for one discriminator and backpropagate it.

        The fake batch is detached so no gradient reaches the generators.
        Returns the loss (mean of the real and fake terms).
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """GAN loss for D_A, using fake B images drawn through the history pool."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """GAN loss for D_B, using fake A images drawn through the history pool."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Compute the combined generator loss (GAN + cycle + identity) and backpropagate it."""
        lambda_idt = self.lambda_identity
        lambda_A = self.lambda_A
        lambda_B = self.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: update both generators, then both discriminators."""
        # forward
        self.forward()      # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero
        self.backward_D_A()      # calculate gradients for D_A
        self.backward_D_B()      # calculate gradients for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights


In [58]:
def create_model(gpu_ids, isTrain, checkpoints_dir, name, preprocess, lr_policy, epoch_count, n_epochs, n_epochs_decay,
                 lr_decay_iters, continue_train, load_iter, epoch, verbose, input_nc, output_nc, ngf, netG, norm, no_dropout,
                 init_type, init_gain, ndf, netD, n_layers_D, pool_size, gan_mode, lr, beta1, direction):
    """Instantiate a `CycleGANModel` with the given options and return it.

    This is a module-level function, so it takes no `self`: the stray
    leading `self` parameter shifted every argument by one position and
    left `direction` unfilled when called as `create_model(gpu_ids, ...)`.
    """
    instance = CycleGANModel(gpu_ids, isTrain, checkpoints_dir, name, preprocess, lr_policy, epoch_count, n_epochs, n_epochs_decay,
                             lr_decay_iters, continue_train, load_iter, epoch, verbose, input_nc, output_nc, ngf, netG, norm, no_dropout,
                             init_type, init_gain, ndf, netD, n_layers_D, pool_size, gan_mode, lr, beta1, direction)
    print("model was created")
    return instance

In [62]:
def _parse_gpu_ids(gpu_ids):
    """Normalize a GPU spec ('0', '0,1', '-1', 0, [0, 1]) to a list of non-negative ints."""
    if isinstance(gpu_ids, int):
        gpu_ids = [gpu_ids]
    elif isinstance(gpu_ids, str):
        gpu_ids = [token for token in gpu_ids.split(',') if token.strip()]
    parsed = [int(i) for i in gpu_ids]
    return [i for i in parsed if i >= 0]  # negative ids mean "no GPU"


def train_model(
    dataroot,
    name,
    model,
    gpu_ids,
    checkpoints_dir,
    input_nc,
    output_nc,
    ngf,
    ndf,
    netD,
    netG,
    n_layers_D,
    norm,
    init_type,
    init_gain,
    no_dropout,
    direction,
    serial_batches,
    num_threads,
    batch_size,
    max_size,
    epoch,
    load_iter,
    verbose,
    display_freq,
    display_ncols,
    display_id,
    phase,
    epoch_count,
    continue_train,
    save_by_iter,
    save_epoch_freq,
    save_latest_freq,
    n_epochs,
    n_epochs_decay,
    beta1,
    lr,
    gan_mode,
    pool_size,
    lr_policy,
    lr_decay_iters,
    print_freq=100,
    preprocess='none',
):
    """Train a CycleGAN model and periodically save checkpoints.

    Builds the dataset with `create_dataset` and the model with
    `create_model` (in training mode), then runs epochs `epoch_count` ..
    `n_epochs + n_epochs_decay`.

    Notes on parameters:
        gpu_ids      -- list of ints, or a string such as '0' / '0,1' / '-1'
        epoch        -- checkpoint suffix to load when resuming (e.g. 'latest')
        display_freq -- every this many samples `compute_visuals` is called
        print_freq   -- every this many samples the current losses are printed
        preprocess   -- forwarded to the model constructor
        model, display_ncols, display_id -- accepted but unused here
    All remaining parameters are forwarded unchanged to the dataset and
    model constructors.
    """
    gpu_ids = _parse_gpu_ids(gpu_ids)
    # Checkpoints are written below <checkpoints_dir>/<name>.
    os.makedirs(os.path.join(checkpoints_dir, name), exist_ok=True)

    dataset = create_dataset(dataroot, phase, max_size, direction, input_nc, output_nc, serial_batches, batch_size, num_threads)
    dataset_size = len(dataset)    # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    is_train = True  # this function only trains
    cycle_gan = create_model(gpu_ids, is_train, checkpoints_dir, name, preprocess, lr_policy, epoch_count, n_epochs, n_epochs_decay,
                             lr_decay_iters, continue_train, load_iter, epoch, verbose, input_nc, output_nc, ngf, netG, norm, no_dropout,
                             init_type, init_gain, ndf, netD, n_layers_D, pool_size, gan_mode, lr, beta1, direction)
    cycle_gan.setup()              # load and print networks; create schedulers
    total_iters = 0                # the total number of training samples seen
    t_data = 0.0                   # data-loading time of the last measured iteration

    last_epoch = n_epochs + n_epochs_decay
    # A separate loop variable keeps the `epoch` argument (checkpoint suffix) intact.
    for current_epoch in range(epoch_count, last_epoch + 1):
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # samples seen in the current epoch
        cycle_gan.update_learning_rate()    # update learning rates in the beginning of every epoch.
        for data in dataset:  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            if total_iters % print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += batch_size
            epoch_iter += batch_size
            cycle_gan.set_input(data)         # unpack data from dataset
            cycle_gan.optimize_parameters()   # calculate loss functions, get gradients, update network weights

            if total_iters % display_freq == 0:
                cycle_gan.compute_visuals()

            if total_iters % print_freq == 0:    # print training losses
                losses = cycle_gan.get_current_losses()
                t_comp = (time.time() - iter_start_time) / batch_size
                message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (current_epoch, epoch_iter, t_comp, t_data)
                message += ' '.join('%s: %.3f' % (key, value) for key, value in losses.items())
                print(message)

            if total_iters % save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' % (current_epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if save_by_iter else 'latest'
                cycle_gan.save_networks(save_suffix)

            iter_data_time = time.time()
        if current_epoch % save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (current_epoch, total_iters))
            cycle_gan.save_networks('latest')
            cycle_gan.save_networks(current_epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' % (current_epoch, last_epoch, time.time() - epoch_start_time))

# Part 3. Training

## 3.1. Downloading the dataset

In [None]:
# Google Colab only: mount Google Drive at /content/gdrive.
from google.colab import drive

drive.mount('/content/gdrive')

# Work from the project folder on the mounted drive.
# NOTE(review): the path is relative, so this presumably relies on the
# working directory being /content (the mount point's parent) — confirm.
os.chdir('gdrive/My Drive/NNDL - Spring 00/Mini Project 3')


In [33]:
# Download the horse2zebra dataset archive; it is saved as horse2zebra.zip in the current working directory.
!wget -N 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip' 


--2021-07-25 18:52:56--  https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
Resolving people.eecs.berkeley.edu (people.eecs.berkeley.edu)... 128.32.244.190
Connecting to people.eecs.berkeley.edu (people.eecs.berkeley.edu)|128.32.244.190|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 116867962 (111M) [application/zip]
Saving to: ‘horse2zebra.zip’


2021-07-25 18:53:35 (2.85 MB/s) - ‘horse2zebra.zip’ saved [116867962/116867962]



In [None]:
# Extract the downloaded archive into ./datasets/.
# The download step saves horse2zebra.zip in the current working directory,
# not in ./datasets/, so it is unpacked (and then removed) from there.
# -p keeps mkdir from failing when the directories already exist.
!mkdir -p ./datasets/horse2zebra
!unzip './horse2zebra.zip' -d ./datasets/
!rm './horse2zebra.zip'

## 3.2. Specifying the constants needed to run the code

In [60]:
# --- data and experiment ---
dataroot = './datasets/horse2zebra'
name = 'horse2zebra'
model = 'cycle_gan'
# init_net indexes gpu_ids[0] as a device and hands the whole value to
# DataParallel, so this must be a list of device ids, not the string '0'.
gpu_ids = [0]
checkpoints_dir = './checkpoints'

# --- network architecture ---
input_nc = 3
output_nc = 3
ngf = 65  # NOTE(review): unusual next to ndf = 64 — possibly a typo for 64; confirm
ndf = 64
netD = 'basic'
netG = 'resnet_6blocks'
n_layers_D = 3
norm = 'batch'
init_type= 'normal'
init_gain = 0.02
no_dropout = True

# --- data loading ---
direction = 'AtoB'
serial_batches = True
num_threads = 4
batch_size = 4
max_size = 100

# --- checkpoint loading / logging ---
epoch = 'latest'
load_iter = 0
verbose = True
display_freq = 50
display_ncols = 4
display_id = 1
phase = 'train'
epoch_count= 1
# NOTE(review): with continue_train = True, BaseModel.setup loads
# '<epoch>_net_<name>.pth' from <checkpoints_dir>/<name>, so those files must
# already exist; set to False for a first run.
continue_train = True
save_by_iter = True
save_epoch_freq = 5
save_latest_freq = 5000

# --- optimization ---
n_epochs = 3
n_epochs_decay = 2 
beta1 = 0.5
lr = 0.0002
gan_mode = 'lsgan'
pool_size = 50
lr_policy = 'linear'
lr_decay_iters = 2

In [None]:
# Launch training with the constants defined above; every argument is
# passed by keyword, named after the train_model parameter it fills.
train_model(
    dataroot=dataroot,
    name=name,
    model=model,
    gpu_ids=gpu_ids,
    checkpoints_dir=checkpoints_dir,
    input_nc=input_nc,
    output_nc=output_nc,
    ngf=ngf,
    ndf=ndf,
    netD=netD,
    netG=netG,
    n_layers_D=n_layers_D,
    norm=norm,
    init_type=init_type,
    init_gain=init_gain,
    no_dropout=no_dropout,
    direction=direction,
    serial_batches=serial_batches,
    num_threads=num_threads,
    batch_size=batch_size,
    max_size=max_size,
    epoch=epoch,
    load_iter=load_iter,
    verbose=verbose,
    display_freq=display_freq,
    display_ncols=display_ncols,
    display_id=display_id,
    phase=phase,
    epoch_count=epoch_count,
    continue_train=continue_train,
    save_by_iter=save_by_iter,
    save_epoch_freq=save_epoch_freq,
    save_latest_freq=save_latest_freq,
    n_epochs=n_epochs,
    n_epochs_decay=n_epochs_decay,
    beta1=beta1,
    lr=lr,
    gan_mode=gan_mode,
    pool_size=pool_size,
    lr_policy=lr_policy,
    lr_decay_iters=lr_decay_iters,
)