In [None]:
import torch
import sys
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np
import random
import copy
from torch.autograd import Variable
import itertools
from collections import OrderedDict
import time
import functools
from torch.nn import init
from abc import ABC, abstractmethod

In [None]:
# ---- Optimisation hyper-parameters ----
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
LAMBDA_IDENTITY = 0.5
LAMBDA_CYCLE = 10
NUM_WORKERS = 2
NUM_EPOCHS = 10

# ---- Checkpoint switches and file names ----
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"

# File-name suffixes accepted as images (matched case-sensitively, so both
# lower- and upper-case spellings are listed).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

# ---- Device selection ----
# Every CUDA device visible to torch is listed; the first one (if any) becomes
# both DEVICE and the process-wide current CUDA device. Otherwise run on CPU.
GPU_IDS = list(range(torch.cuda.device_count()))
if GPU_IDS:
  DEVICE = torch.device('cuda:{}'.format(GPU_IDS[0]))
  torch.cuda.set_device(f'cuda:{GPU_IDS[0]}')
else:
  DEVICE = torch.device('cpu')

In [None]:
def get_transform():
  """Build the image preprocessing pipeline.

  A zoom factor is drawn once per call from {1.0, 1.1, 1.2, 1.3, 1.4}; the
  returned pipeline resizes to [400 * zoom, 600 * zoom] with bicubic
  interpolation, takes a random 256-pixel crop, randomly flips horizontally,
  converts to a tensor and normalises each of the three channels with
  mean 0.5 and std 0.5.

  Returns:
    A `transforms.Compose` object applying the steps above in order.
  """
  zoom = 1 + 0.1 * random.randint(0, 4)
  resize_to = [int(400 * zoom), int(600 * zoom)]
  return transforms.Compose([
      transforms.Resize(resize_to, transforms.functional.InterpolationMode.BICUBIC),
      transforms.RandomCrop(256),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  ])

In [None]:
def is_image_file(filename):
    """Return True when `filename` ends with one of the suffixes in IMG_EXTENSIONS.

    The comparison is case-sensitive; only the exact spellings listed in
    IMG_EXTENSIONS match.
    """
    # str.endswith accepts a tuple of candidate suffixes.
    return filename.endswith(tuple(IMG_EXTENSIONS))

In [None]:
def store_dataset(dir):
  """Recursively load every image found under `dir` into memory.

  Args:
    dir: Path of the directory to scan (sub-directories included).

  Returns:
    A tuple `(images, all_path)` of two parallel lists: the decoded RGB
    `PIL.Image` objects and the file paths they were read from. Entries are
    ordered by directory, then by file name.

  Raises:
    AssertionError: if `dir` is not an existing directory.
  """
  assert os.path.isdir(dir), '%s is not a valid directory' %dir

  images = []
  all_path = []
  for root, _, fnames in sorted(os.walk(dir)):
    # os.walk yields file names in arbitrary (filesystem-dependent) order;
    # sorting them makes index-based A/B pairing reproducible across runs.
    for fname in sorted(fnames):
      if not is_image_file(fname):
        continue
      path = os.path.join(root, fname)
      # The context manager releases the file handle as soon as the pixels are
      # decoded; convert() returns an independent in-memory image.
      with Image.open(path) as img:
        images.append(img.convert('RGB'))
      all_path.append(path)
  return images, all_path

In [None]:
def CreateDataLoader():
  """Create, announce and initialise the data-loader wrapper.

  Instantiates `CustomDatasetDataLoader`, prints the value of its `name()`,
  calls `initialize()` on it and returns the instance.
  """
  loader = CustomDatasetDataLoader()
  print(loader.name())
  loader.initialize()
  return loader

In [None]:
# Parse a comma-separated list of GPU indices into `gpu_ids` (negative entries
# and blank tokens are dropped) and make the first one the current CUDA device.
gpu_id_spec = "0"
str_ids = gpu_id_spec.split(',')
# A comprehension keeps the loop variable local, so the builtin `id` is no
# longer rebound at notebook scope.
gpu_ids = [int(token) for token in str_ids if token.strip() and int(token) >= 0]
# Only touch CUDA when it is actually usable; on a CPU-only machine
# torch.cuda.set_device would raise and abort the notebook.
if gpu_ids and torch.cuda.is_available():
  torch.cuda.set_device(gpu_ids[0])

In [None]:
class CustomDatasetDataLoader():
    """Thin wrapper pairing an `UnalignedDataset` with a torch `DataLoader`."""

    def name(self):
        """Return the identifying name of this loader."""
        return 'CustomDatasetDataLoader'

    def initialize(self):
        """Load the trainA/trainB dataset and wrap it in a DataLoader.

        Batches have BATCH_SIZE samples, are served in dataset order
        (shuffling is disabled) and are produced by NUM_WORKERS workers.
        """
        dataset = UnalignedDataset("/content/dataset/trainA", "/content/dataset/trainB")
        self.dataset = dataset
        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
        )

    def load_data(self):
        """Return the underlying `DataLoader` created by `initialize()`."""
        return self.dataloader

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

In [None]:
class UnalignedDataset(torch.utils.data.Dataset):
    """Unpaired two-domain image dataset held entirely in memory.

    Images from `dir_A` and `dir_B` are loaded eagerly by `store_dataset`.
    Sample `index` pairs A image `index % A_size` with B image
    `index % B_size`, so the shorter domain wraps around.

    Note: `get_transform()` is called once in the constructor, so the random
    zoom factor it draws stays fixed for the lifetime of the dataset; only the
    crop and flip are re-sampled per item.
    """

    def __init__(self, dir_A, dir_B):
        """Load both image folders.

        Args:
            dir_A: Directory holding domain-A images.
            dir_B: Directory holding domain-B images.
        """
        self.dir_A = dir_A
        self.dir_B = dir_B
        self.transform = get_transform()

        self.A_imgs, self.A_paths = store_dataset(self.dir_A)
        self.B_imgs, self.B_paths = store_dataset(self.dir_B)
        self.length_dataset = max(len(self.A_imgs), len(self.B_imgs))
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)

    def __len__(self):
        """Return the size of the larger of the two domains."""
        return self.length_dataset

    def __getitem__(self, index):
        """Return one (A, B) sample as a dict.

        Keys: 'A' and 'input_img' (the same transformed A tensor), 'B'
        (transformed B tensor, contrast-stretched), 'A_gray' (single-channel
        inverted luminance map of A) and the source paths 'A_paths'/'B_paths'.
        """
        a_idx = index % self.A_size
        b_idx = index % self.B_size
        A_img = self.A_imgs[a_idx]
        B_img = self.B_imgs[b_idx]
        A_path = self.A_paths[a_idx]
        B_path = self.B_paths[b_idx]

        A_img = self.transform(A_img)
        B_img = self.transform(B_img)

        input_img = A_img

        # Min-max stretch B: shift from [-1, 1] to [0, 1], stretch to the full
        # [0, 1] range, then map back to [-1, 1].
        B_img = (B_img + 1)/2.
        b_min = torch.min(B_img)
        b_range = torch.max(B_img) - b_min
        # A constant-valued crop has zero range; dividing would fill the tensor
        # with NaN, so the stretch is skipped and the values are left as-is.
        if b_range > 0:
            B_img = (B_img - b_min)/b_range
        B_img = B_img * 2. - 1

        # Inverted luminance of A: channels are shifted to [0, 2], combined
        # with the 0.299/0.587/0.114 weights, halved and subtracted from 1.
        r, g, b = input_img[0] + 1, input_img[1] + 1, input_img[2] + 1
        A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2.
        A_gray = torch.unsqueeze(A_gray, 0)

        return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img': input_img,
                'A_paths': A_path, 'B_paths': B_path}

In [None]:
def get_norm_layer(norm_type='instance'):
    """Return a factory for the requested 2-D normalisation layer.

    Args:
        norm_type: 'batch' for `nn.BatchNorm2d` (affine, running statistics
            tracked) or 'instance' for `nn.InstanceNorm2d` (no affine
            parameters, no running statistics).

    Returns:
        A `functools.partial` that builds the layer when called with the
        channel count.

    Raises:
        NotImplementedError: for any other `norm_type`.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

In [None]:
def init_weights(net, init_type='normal', init_gain=0.02):
    """Re-initialise the weights of `net` in place.

    Modules whose class name contains 'Conv' or 'Linear' get their weight
    filled according to `init_type` and their bias (when present) zeroed.
    Modules whose class name contains 'BatchNorm2d' get weight ~ N(1, init_gain)
    and bias 0.

    Args:
        net: Module whose sub-modules are visited with `net.apply`.
        init_type: 'normal', 'xavier', 'kaiming' or 'orthogonal'.
        init_gain: Std for 'normal', gain for 'xavier' and 'orthogonal';
            not used by 'kaiming'.

    Raises:
        NotImplementedError: when a Conv/Linear module is visited and
            `init_type` is not one of the supported names.
    """
    # Dispatch table: scheme name -> in-place weight initialiser.
    weight_fillers = {
        'normal': lambda w: init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=init_gain),
    }

    def init_func(m):
        cls_name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in cls_name or 'Linear' in cls_name):
            if init_type not in weight_fillers:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            weight_fillers[init_type](m.weight.data)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in cls_name:
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

In [None]:
def init_net(net, init_type='normal', init_gain=0.02):
    """Place `net` on the first GPU (when one is listed) and initialise its weights.

    When GPU_IDS is non-empty the network is moved to GPU_IDS[0] and wrapped in
    `torch.nn.DataParallel` restricted to device 0; otherwise it is left where
    it is. Weights are then initialised with `init_weights`.

    Returns:
        The (possibly DataParallel-wrapped) network.
    """
    if len(GPU_IDS) == 0:
        init_weights(net, init_type, init_gain=init_gain)
        return net

    assert torch.cuda.is_available()
    net.to(GPU_IDS[0])
    wrapped = torch.nn.DataParallel(net, device_ids=[0])
    init_weights(wrapped, init_type, init_gain=init_gain)
    return wrapped

In [None]:
def define_G(input_nc, output_nc, ngf, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02):
  """Build and initialise a U-Net generator with 8 down-sampling levels.

  Args:
    input_nc: Number of input channels.
    output_nc: Number of output channels.
    ngf: Base number of generator filters.
    norm: Normalisation type understood by `get_norm_layer`.
    use_dropout: Forwarded to `UnetGenerator`.
    init_type: Weight-initialisation scheme, forwarded to `init_net`.
    init_gain: Initialisation gain, forwarded to `init_net`.

  Returns:
    The network returned by `init_net`.
  """
  generator = UnetGenerator(
      input_nc, output_nc, 8, ngf,
      norm_layer=get_norm_layer(norm_type=norm),
      use_dropout=use_dropout,
  )
  return init_net(generator, init_type, init_gain)

In [None]:
def define_D(input_nc, ndf, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02):
  """Build and initialise an `NLayerDiscriminator`.

  Args:
    input_nc: Number of channels of the images to discriminate.
    ndf: Base number of discriminator filters.
    n_layers_D: Number of strided convolution stages; now forwarded to the
      discriminator (it was previously ignored in favour of a hard-coded 3,
      which is also the default, so existing callers are unaffected).
    norm: Normalisation type understood by `get_norm_layer`.
    init_type: Weight-initialisation scheme, forwarded to `init_net`.
    init_gain: Initialisation gain, forwarded to `init_net`.

  Returns:
    The network returned by `init_net`.
  """
  norm_layer = get_norm_layer(norm_type=norm)
  net = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers_D, norm_layer=norm_layer)
  return init_net(net, init_type, init_gain)

In [None]:
class UnetSkipConnectionBlock(nn.Module):
    """One level of a U-Net: down-sample, run the nested sub-module, up-sample.

    Except for the outermost level, `forward` concatenates the block input with
    the block output along the channel dimension (the skip connection).
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """
        Args:
            outer_nc: Channels produced by the up-sampling path of this block.
            inner_nc: Channels produced by the down-sampling convolution.
            input_nc: Channels of the block input; defaults to `outer_nc`.
            submodule: Nested block placed between the down and up paths
                (unused for the innermost block).
            outermost: Build the outermost block (no norm, Tanh output, no skip).
            innermost: Build the innermost block (no sub-module, no down norm).
            norm_layer: Normalisation layer class or `functools.partial`.
            use_dropout: Append `nn.Dropout(0.5)` on intermediate blocks.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Convolutions carry a bias only with InstanceNorm.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        # The down-convolution consumes the block *input*, so it must be sized
        # by input_nc. (It previously used outer_nc, which silently ignored
        # input_nc and broke generators whose input and output channel counts
        # differ.)
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # Input to the up-convolution is the sub-module's skip-concatenated
            # output, hence inner_nc * 2 channels.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No sub-module, so no concatenation: the up-convolution sees
            # inner_nc channels.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run the block; non-outermost blocks return `cat([x, model(x)], dim=1)`."""
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

In [None]:
class SkipModule(nn.Module):
    """Wrap a sub-module and add a 0.8-weighted copy of the input to its output."""

    def __init__(self, submodule):
        super().__init__()
        self.submodule = submodule

    def forward(self, x):
        """Return `(0.8 * x + submodule(x), submodule(x))`."""
        latent = self.submodule(x)
        blended = 0.8*x + latent
        return blended, latent

In [None]:
class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested `UnetSkipConnectionBlock`s.

    The network is built inside-out: the innermost block first, then
    `num_downs - 5` blocks at `ngf * 8` channels, then three blocks that step
    the channel count down to `ngf`, and finally the outermost block.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        """
        Args:
            input_nc: Number of input channels.
            output_nc: Number of output channels.
            num_downs: Total number of down-sampling levels.
            ngf: Number of filters of the level just inside the outermost block.
            norm_layer: Normalisation layer class or factory.
            use_dropout: Enable dropout on the `ngf * 8` intermediate blocks.
        """
        super().__init__()

        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block, norm_layer=norm_layer, use_dropout=use_dropout)
        # (outer multiplier, inner multiplier) for the three narrowing levels.
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None, submodule=block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Run `input` through the nested U-Net."""
        return self.model(input)

In [None]:
class NLayerDiscriminator(nn.Module):
  """Convolutional discriminator producing a single-channel score map.

  Layout: one stride-2 convolution without normalisation, `n_layers - 1`
  stride-2 conv/norm/LeakyReLU stages, one stride-1 conv/norm/LeakyReLU stage
  and a final stride-1 convolution to one channel. The filter multiplier
  doubles per stage and is capped at 8.
  """

  def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
    """
    Args:
      input_nc: Number of input channels.
      ndf: Number of filters of the first convolution.
      n_layers: Number of stride-2 convolutions (including the first).
      norm_layer: Normalisation layer class or `functools.partial`.
    """
    super().__init__()
    # Convolutions followed by a norm layer carry a bias only with InstanceNorm.
    norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
    use_bias = norm_cls == nn.InstanceNorm2d

    kw = 4
    padw = int(np.ceil((kw-1)/2))

    def conv_norm_act(in_mult, out_mult, stride):
      # One conv -> norm -> LeakyReLU stage.
      return [
          nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=kw, stride=stride, padding=padw, bias=use_bias),
          norm_layer(ndf * out_mult),
          nn.LeakyReLU(0.2, True),
      ]

    layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]

    mult = 1
    for n in range(1, n_layers):
      prev_mult, mult = mult, min(2**n, 8)
      layers += conv_norm_act(prev_mult, mult, stride=2)

    prev_mult, mult = mult, min(2**n_layers, 8)
    layers += conv_norm_act(prev_mult, mult, stride=1)

    layers.append(nn.Conv2d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw))

    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Return the score map for `x`."""
    return self.model(x)

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss against a constant real/fake target.

    Uses `nn.MSELoss` when `use_lsgan` is true and `nn.BCEWithLogitsLoss`
    otherwise. The target labels are stored as buffers so they follow the
    module across devices.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real or fake label expanded to the shape of `prediction`."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Return the loss of `prediction` against the chosen constant target."""
        return self.loss(prediction, self.get_target_tensor(prediction, target_is_real))

In [None]:
class ImagePool():
    """Bounded history of generated images.

    While the pool is filling, queried images are stored and returned as-is.
    Once full, each queried image is, with probability 0.5, swapped with a
    randomly chosen stored image (the stored one is returned); otherwise it is
    returned directly. A `pool_size` of 0 disables the pool.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch of the same length mixing `images` with stored ones."""
        if self.pool_size == 0:
            return images

        selected = []
        for img in images.data:
            img = torch.unsqueeze(img, 0)
            if self.num_imgs < self.pool_size:
                # Pool not full yet: remember the image and pass it through.
                self.num_imgs += 1
                self.images.append(img)
                selected.append(img)
                continue
            if random.uniform(0, 1) > 0.5:
                # Hand back an older image and store the new one in its slot.
                slot = random.randint(0, self.pool_size-1)
                evicted = self.images[slot].clone()
                self.images[slot] = img
                selected.append(evicted)
            else:
                selected.append(img)
        return Variable(torch.cat(selected, 0))

In [None]:
def print_network(net):
    """Print the module structure of `net` followed by its total parameter count."""
    total_params = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total_params)

In [None]:
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert the first image of a tensor batch to an HxWxC numpy array.

    Tensor values are mapped from [-1, 1] to [0, 255]; a single-channel image
    is repeated to three channels. A numpy array is only cast to `imtype`.
    Any other input type is returned unchanged.
    """
    if isinstance(image_tensor, np.ndarray):
        return image_tensor.astype(imtype)
    if not isinstance(image_tensor, torch.Tensor):
        return image_tensor

    first = image_tensor.data[0].cpu().float().numpy()
    if first.shape[0] == 1:
        # Grayscale: replicate the channel so the result is RGB-shaped.
        first = np.tile(first, (3, 1, 1))
    # CHW -> HWC, then rescale to the 0..255 range.
    scaled = (np.transpose(first, (1, 2, 0)) + 1) / 2.0 * 255.0
    return scaled.astype(imtype)

In [None]:
# Build the data-loader wrapper; CreateDataLoader() prints the loader name and
# loads the training images from disk.
data_loader = CreateDataLoader()
# NOTE: despite its name, `dataset` is the object returned by load_data(),
# i.e. the batch iterator the training loop walks over.
dataset = data_loader.load_data()
# Number of samples in the wrapped dataset (what len(data_loader) reports).
dataset_size = len(data_loader)

CustomDatasetDataLoader


In [None]:
# Report the number of training samples computed in the previous cell.
print('#training images = %d' % dataset_size)

#training images = 1016


In [None]:
class CycleGANModel():
  """CycleGAN training wrapper: two generators, two discriminators, their
  losses and optimizers.

  `netG_A` maps domain A to B and `netG_B` maps B to A. `netD_A` scores
  B-domain images (real_B vs. fake_B) and `netD_B` scores A-domain images.
  """

  # Weights of the two cycle-consistency terms in the generator loss.
  lambda_A = 10.0
  lambda_B = 10.0
  # Suffixes of the `net<name>` attributes handled by eval()/save_networks().
  model_names = ['G_A', 'G_B', 'D_A', 'D_B']
  # Directory used by save_networks() and load_networks().
  checkpoint_dir = '/content/checkpoints'

  def __init__(self, batchSize = 1, fineSize = 256, input_nc=3, output_nc=3):
    """Create networks, losses, image pools and optimizers.

    Args:
      batchSize: Accepted for interface compatibility; not used.
      fineSize: Accepted for interface compatibility; not used.
      input_nc: Number of channels of domain-A images.
      output_nc: Number of channels of domain-B images.
    """
    self.netG_A = define_G(input_nc = input_nc, output_nc = output_nc, ngf = 64, norm = 'instance', use_dropout = False, init_type = 'normal', init_gain = 0.02)
    self.netG_B = define_G(input_nc = output_nc, output_nc = input_nc, ngf = 64, norm = 'instance', use_dropout = False, init_type = 'normal', init_gain = 0.02)

    # NOTE(review): the discriminators use init_gain=0.2 while the generators
    # use 0.02 - confirm this difference is intentional.
    self.netD_A = define_D(input_nc = output_nc, ndf=64, n_layers_D=3, norm='instance', init_type='normal', init_gain=0.2)
    self.netD_B = define_D(input_nc = input_nc, ndf=64, n_layers_D=3, norm='instance', init_type='normal', init_gain=0.2)
    self.fake_A_pool = ImagePool(50)
    self.fake_B_pool = ImagePool(50)

    self.criterionGAN = GANLoss(use_lsgan = True).to(DEVICE)
    self.criterionCycle = torch.nn.L1Loss()
    self.criterionL1 = torch.nn.L1Loss()
    self.criterionIdt = torch.nn.L1Loss()

    self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                        lr = LEARNING_RATE, betas = (0.5, 0.999))
    self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                        lr=LEARNING_RATE, betas=(0.5, 0.999))
    # Learning rate currently applied; update_learning_rate() decays it.
    self.old_lr = LEARNING_RATE

    print('---------- Networks initialized -------------')
    print_network(self.netG_A)
    print_network(self.netG_B)
    print_network(self.netD_A)
    print_network(self.netD_B)
    print('-----------------------------------------------')

  def set_input(self, input):
    """Move the 'A' and 'B' tensors of a batch dict to DEVICE and keep 'A_paths'."""
    self.real_A = input['A'].to(DEVICE)
    self.real_B = input['B'].to(DEVICE)
    self.image_paths = input['A_paths']

  def forward(self):
    """Compute both translations and both cycle reconstructions."""
    self.fake_B = self.netG_A(self.real_A)
    self.rec_A = self.netG_B(self.fake_B)
    self.fake_A = self.netG_B(self.real_B)
    self.rec_B = self.netG_A(self.fake_A)

  def predict(self):
    """Translate the current `real_A` without tracking gradients.

    Requires `set_input` to have been called. Returns an OrderedDict with the
    numpy images 'real_A', 'fake_B' and 'rec_A'.
    """
    # torch.no_grad() replaces the removed Variable(..., volatile=True) API.
    with torch.no_grad():
      self.fake_B = self.netG_A(self.real_A)
      self.rec_A = self.netG_B(self.fake_B)

    real_A = tensor2im(self.real_A.data)
    fake_B = tensor2im(self.fake_B.data)
    rec_A = tensor2im(self.rec_A.data)
    return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ("rec_A", rec_A)])

  def get_image_paths(self):
    """Return the 'A_paths' entry of the most recent batch."""
    return self.image_paths

  def _set_requires_grad(self, nets, requires_grad):
    """Enable or disable gradient tracking for every parameter of `nets`."""
    for net in nets:
      for param in net.parameters():
        param.requires_grad = requires_grad

  def backward_D_basic(self, netD, real, fake):
    """Back-propagate the discriminator loss for one domain.

    The loss is the mean of the GAN loss on `real` (target real) and on the
    detached `fake` (target fake). Returns the loss tensor.
    """
    pred_real = netD(real)
    loss_D_real = self.criterionGAN(pred_real, True)

    # detach() keeps discriminator gradients out of the generator.
    pred_fake = netD(fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)

    loss_D = (loss_D_real + loss_D_fake) * 0.5

    loss_D.backward()
    return loss_D

  def backward_D_A(self):
    """Discriminator A step input: real_B vs. a pooled fake_B."""
    fake_B = self.fake_B_pool.query(self.fake_B)
    self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

  def backward_D_B(self):
    """Discriminator B step input: real_A vs. a pooled fake_A."""
    fake_A = self.fake_A_pool.query(self.fake_A)
    self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

  def backward_G(self, epoch):
    """Run the generators forward and back-propagate the generator loss.

    loss_G = GAN(A->B) + GAN(B->A) + lambda_A * cycle_A + lambda_B * cycle_B.
    The identity terms are disabled (kept at 0). `epoch` is accepted for
    interface compatibility and is not used.
    """
    self.loss_idt_A = 0
    self.loss_idt_B = 0

    # Single forward pass producing fake_A/fake_B/rec_A/rec_B.
    self.forward()

    self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
    self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

    # Diagnostic L1 distances only: they are stored but are not part of loss_G,
    # so they are computed on detached tensors.
    self.L1_AB = self.criterionL1(self.fake_B.detach(), self.real_B) * 10.0
    self.L1_BA = self.criterionL1(self.fake_A.detach(), self.real_A) * 10.0

    self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * self.lambda_A
    self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * self.lambda_B

    self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B

    self.loss_G.backward()

  def optimize_parameters(self, epoch):
    """Perform one generator update followed by one discriminator update."""
    # Generator step: the discriminators are only used as fixed critics, so
    # their parameters do not need gradients here.
    self._set_requires_grad([self.netD_A, self.netD_B], False)
    self.optimizer_G.zero_grad()
    self.backward_G(epoch)
    self.optimizer_G.step()

    # Discriminator step: both discriminators share `optimizer_D`, the only
    # discriminator optimizer the constructor creates.
    self._set_requires_grad([self.netD_A, self.netD_B], True)
    self.optimizer_D.zero_grad()
    self.backward_D_A()
    self.backward_D_B()
    self.optimizer_D.step()

  def get_current_errors(self):
    """Return the latest losses as an OrderedDict of Python floats."""
    # .item() replaces .data[0], which fails on 0-dim loss tensors.
    D_A = self.loss_D_A.item()
    G_A = self.loss_G_A.item()
    Cyc_A = self.loss_cycle_A.item()
    D_B = self.loss_D_B.item()
    G_B = self.loss_G_B.item()
    Cyc_B = self.loss_cycle_B.item()
    if self.lambda_A > 0.0:
        return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                            ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
    else:
        return OrderedDict([('D_A', D_A), ('G_A', G_A),
                            ('D_B', D_B), ('G_B', G_B)])

  def load_networks(self, epoch):
    """Load generator weights from `<checkpoint_dir>/<epoch>_net_<name>.pth`."""
    for name in ['G_A', 'G_B']:
      load_filename = '%s_net_%s.pth' % (epoch, name)
      load_path = os.path.join(self.checkpoint_dir, load_filename)
      net = getattr(self, 'net' + name)
      # Checkpoints store the bare module's state dict, so unwrap DataParallel.
      if isinstance(net, torch.nn.DataParallel):
          net = net.module
      print('loading the model from %s' % load_path)
      state_dict = torch.load(load_path, map_location=str(DEVICE))
      if hasattr(state_dict, '_metadata'):
          del state_dict._metadata
      net.load_state_dict(state_dict)

  def setup(self):
    """Load the generator checkpoints saved with suffix '5'."""
    load_suffix = '5'
    self.load_networks(load_suffix)

  def eval(self):
    """Switch every network listed in `model_names` to evaluation mode."""
    for name in self.model_names:
      net = getattr(self, 'net' + name)
      net.eval()

  def compute_visuals(self):
    """Calculate additional output images for visdom and HTML visualization"""
    pass

  def get_current_visuals(self):
    """Return the current real/fake/reconstructed tensors keyed by name."""
    visual_ret = OrderedDict()
    for name in ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B']:
        visual_ret[name] = getattr(self, name)
    return visual_ret

  def test(self):
    """Run `forward()` and `compute_visuals()` without tracking gradients."""
    with torch.no_grad():
      self.forward()
      self.compute_visuals()

  def save_networks(self, epoch):
    """Save every network to `<checkpoint_dir>/<epoch>_net_<name>.pth`.

    The state dict of the bare module (DataParallel unwrapped) is written with
    all tensors copied to CPU; the live networks are not moved.
    """
    # Create the target directory on first use instead of failing in torch.save.
    os.makedirs(self.checkpoint_dir, exist_ok=True)
    for name in self.model_names:
      save_filename = '%s_net_%s.pth' % (epoch, name)
      save_path = os.path.join(self.checkpoint_dir, save_filename)
      net = getattr(self, 'net' + name)
      if isinstance(net, torch.nn.DataParallel):
        net = net.module
      cpu_state = OrderedDict((key, value.detach().cpu())
                              for key, value in net.state_dict().items())
      torch.save(cpu_state, save_path)

  def update_learning_rate(self):
    """Lower the learning rate of both optimizers by 1e-6, never below zero."""
    lrd = 0.0001 / 100
    old_lr = getattr(self, 'old_lr', None)
    if old_lr is None:
      # Fall back to the rate currently configured on the generator optimizer.
      old_lr = self.optimizer_G.param_groups[0]['lr']
    # Clamp so repeated decay cannot produce a negative learning rate.
    lr = max(old_lr - lrd, 0.0)
    for optimizer in (self.optimizer_D, self.optimizer_G):
      for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    print('update learning rate: %f -> %f' % (old_lr, lr))
    self.old_lr = lr

In [None]:
# Instantiate the CycleGAN model; the constructor builds all four networks and
# prints their architectures and parameter counts.
model = CycleGANModel()

In [None]:
# Training loop: one optimisation step per batch, a 'latest' checkpoint every
# 5000 samples and a numbered checkpoint every 5th epoch.
total_steps = 0
for epoch in range(NUM_EPOCHS):
  epoch_start_time = time.time()
  loop = tqdm(dataset, leave=True)
  # The loop variable is called `batch` (not `data`) so it does not rebind the
  # `torch.utils.data as data` module imported at the top of the notebook.
  for batch in loop:
    total_steps += BATCH_SIZE
    model.set_input(batch)
    model.optimize_parameters(epoch)

    if total_steps % 5000 == 0:
      print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
      model.save_networks('latest')

  if epoch % 5 == 0:
    print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
    model.save_networks('latest')
    model.save_networks(epoch)

  print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, NUM_EPOCHS, time.time() - epoch_start_time))

  # NOTE(review): the former learning-rate schedule compared `epoch` against
  # NUM_EPOCHS, NUM_EPOCHS + 20, + 70 and + 90. range(NUM_EPOCHS) never reaches
  # any of those values, so no decay was ever applied; the unreachable branches
  # were removed and training runs at a constant learning rate. Call
  # model.update_learning_rate() here under a reachable condition if decay is
  # wanted.


100%|██████████| 1016/1016 [03:38<00:00,  4.64it/s]


saving the model at the end of epoch 0, iters 1016
End of epoch 0 / 10 	 Time Taken: 222 sec


100%|██████████| 1016/1016 [03:38<00:00,  4.64it/s]


End of epoch 1 / 10 	 Time Taken: 218 sec


 87%|████████▋ | 885/1016 [03:10<00:28,  4.64it/s]


KeyboardInterrupt: ignored