In [1]:
import time
import os
import sys
from options.train_options import TrainOptions
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import glob
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
from pytorch_fid import fid_score
from options.test_options import TestOptions
from util.visualizer import save_images
from itertools import islice
from util import html
from options.base_options import BaseOptions


In [2]:
USE_GPU = True

dtype = torch.float32

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


# Dataset Preparation

In [20]:
sys.argv = ['ipykernel_launcher.py', '--dataroot', './datasets/maps']

In [21]:
# argument for training
class TrainOptions(BaseOptions):
    def initialize(self, parser):
        BaseOptions.initialize(self, parser)
        parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
        parser.add_argument('--display_port', type=int, default=8097, help='visdom display port')
        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        parser.add_argument('--update_html_freq', type=int, default=4000, help='frequency of saving training results to html')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan ｜ wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        # training parameters
        parser.add_argument('--niter', type=int, default=20, help='# of iter at starting learning rate')
        parser.add_argument('--niter_decay', type=int, default=20, help='# of iter to linearly decay learning rate to zero')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy: linear | step | plateau | cosine')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr_decay_iters', type=int, default=100, help='multiply by a gamma every lr_decay_iters iterations')
        # lambda parameters
        parser.add_argument('--lambda_L1', type=float, default=10.0, help='weight for |B-G(A, E(B))|')
        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight on D loss. D(G(A, E(B)))')
        parser.add_argument('--lambda_GAN2', type=float, default=1.0, help='weight on D2 loss, D(G(A, random_z))')
        parser.add_argument('--lambda_z', type=float, default=0.5, help='weight for ||E(G(random_z)) - random_z||')
        parser.add_argument('--lambda_kl', type=float, default=0.01, help='weight for KL loss')
        parser.add_argument('--use_same_D', action='store_true', help='if two Ds share the weights or not')
        self.isTrain = True
        return parser


In [22]:
sys.argv = ['ipykernel_launcher.py', '--dataroot', './datasets/maps','--model', 'bicycle_gan'] # pix2pix
opt = TrainOptions().parse()   

----------------- Options ---------------
               batch_size: 2                             
                    beta1: 0.5                           
              center_crop: False                         
          checkpoints_dir: ./checkpoints                 
            conditional_D: False                         
           continue_train: False                         
                crop_size: 256                           
                 dataroot: ./datasets/maps               	[default: None]
             dataset_mode: aligned                       
                direction: AtoB                          
              display_env: main                          
             display_freq: 400                           
               display_id: 1                             
            display_ncols: 4                             
             display_port: 8097                          
           display_server: http://localhost              
          disp

In [23]:
import importlib
import torch

def create_dataset(opt):
    """Create and return a dataset loader."""
    # Import the dataset module based on the dataset_name provided in opt
    dataset_filename = "data." + opt.dataset_mode + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    # Find the dataset class in the module, which should be a subclass of BaseDataset
    target_dataset_name = opt.dataset_mode.replace('_', '') + 'dataset'
    dataset_class = None
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower():
            dataset_class = cls
            break
            
    # Create an instance of the dataset
    dataset = dataset_class(opt)
    print("Dataset [%s] was created" % type(dataset).__name__)

    # Create a multi-threaded data loader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=not opt.serial_batches,
        num_workers=int(opt.num_threads))

    # Encapsulate the dataset and dataloader in an object that can be iterated
    class DatasetLoader:
        def __init__(self, dataset, dataloader):
            self.dataset = dataset
            self.dataloader = dataloader

        def __len__(self):
            return min(len(self.dataset), opt.max_dataset_size)

        def __iter__(self):
            for i, data in enumerate(self.dataloader):
                if i * opt.batch_size >= opt.max_dataset_size:
                    break
                yield data

    return DatasetLoader(dataset, dataloader)


In [24]:
dataset = create_dataset(opt)  
dataset_size = len(dataset)    
print('The number of training images = %d' % dataset_size)

Dataset [AlignedDataset] was created
The number of training images = 1096


## BicycleGAN

In [32]:
import torch
from models.base_model import BaseModel
from models import networks

class BiCycleGANModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def __init__(self, opt):
        if opt.isTrain:
            assert opt.batch_size % 2 == 0  # load two images at one time.

        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'D', 'G_GAN2', 'D2', 'G_L1', 'z_L1', 'kl']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A_encoded', 'real_B_encoded', 'fake_B_random', 'fake_B_encoded']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        use_D = opt.isTrain and opt.lambda_GAN > 0.0
        use_D2 = opt.isTrain and opt.lambda_GAN2 > 0.0 and not opt.use_same_D
        use_E = opt.isTrain or not opt.no_encode
        use_vae = True
        self.model_names = ['G']
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.nz, opt.ngf, netG=opt.netG,
                                      norm=opt.norm, nl=opt.nl, use_dropout=opt.use_dropout, init_type=opt.init_type, init_gain=opt.init_gain,
                                      gpu_ids=self.gpu_ids, where_add=opt.where_add, upsample=opt.upsample)
        D_output_nc = opt.input_nc + opt.output_nc if opt.conditional_D else opt.output_nc
        if use_D:
            self.model_names += ['D']
            self.netD = networks.define_D(D_output_nc, opt.ndf, netD=opt.netD, norm=opt.norm, nl=opt.nl,
                                          init_type=opt.init_type, init_gain=opt.init_gain, num_Ds=opt.num_Ds, gpu_ids=self.gpu_ids)
        if use_D2:
            self.model_names += ['D2']
            self.netD2 = networks.define_D(D_output_nc, opt.ndf, netD=opt.netD2, norm=opt.norm, nl=opt.nl,
                                           init_type=opt.init_type, init_gain=opt.init_gain, num_Ds=opt.num_Ds, gpu_ids=self.gpu_ids)
        else:
            self.netD2 = None
        if use_E:
            self.model_names += ['E']
            self.netE = networks.define_E(opt.output_nc, opt.nz, opt.nef, netE=opt.netE, norm=opt.norm, nl=opt.nl,
                                          init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids, vaeLike=use_vae)

        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(gan_mode=opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionZ = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            if use_E:
                self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_E)

            if use_D:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_D)
            if use_D2:
                self.optimizer_D2 = torch.optim.Adam(self.netD2.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizers.append(self.optimizer_D2)

    def is_train(self):
        """check if the current batch is good for training."""
        return self.opt.isTrain and self.real_A.size(0) == self.opt.batch_size

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def get_z_random(self, batch_size, nz, random_type='gauss'):
        if random_type == 'uni':
            z = torch.rand(batch_size, nz) * 2.0 - 1.0
        elif random_type == 'gauss':
            z = torch.randn(batch_size, nz)
        return z.detach().to(self.device)

    def encode(self, input_image):
        mu, logvar = self.netE.forward(input_image)
        std = logvar.mul(0.5).exp_()
        eps = self.get_z_random(std.size(0), std.size(1))
        z = eps.mul(std).add_(mu)
        return z, mu, logvar

    def test(self, z0=None, encode=False):
        with torch.no_grad():
            if encode:  # use encoded z
                z0, _ = self.netE(self.real_B)
            if z0 is None:
                z0 = self.get_z_random(self.real_A.size(0), self.opt.nz)
            self.fake_B = self.netG(self.real_A, z0)
            return self.real_A, self.fake_B, self.real_B

    def forward(self):
        # get real images
        half_size = self.opt.batch_size // 2
        # A1, B1 for encoded; A2, B2 for random
        self.real_A_encoded = self.real_A[0:half_size]
        self.real_B_encoded = self.real_B[0:half_size]
        self.real_A_random = self.real_A[half_size:]
        self.real_B_random = self.real_B[half_size:]
        # get encoded z
        self.z_encoded, self.mu, self.logvar = self.encode(self.real_B_encoded)
        # get random z
        self.z_random = self.get_z_random(self.real_A_encoded.size(0), self.opt.nz)
        # generate fake_B_encoded
        self.fake_B_encoded = self.netG(self.real_A_encoded, self.z_encoded)
        # generate fake_B_random
        self.fake_B_random = self.netG(self.real_A_encoded, self.z_random)
        if self.opt.conditional_D:   # tedious conditoinal data
            self.fake_data_encoded = torch.cat([self.real_A_encoded, self.fake_B_encoded], 1)
            self.real_data_encoded = torch.cat([self.real_A_encoded, self.real_B_encoded], 1)
            self.fake_data_random = torch.cat([self.real_A_encoded, self.fake_B_random], 1)
            self.real_data_random = torch.cat([self.real_A_random, self.real_B_random], 1)
        else:
            self.fake_data_encoded = self.fake_B_encoded
            self.fake_data_random = self.fake_B_random
            self.real_data_encoded = self.real_B_encoded
            self.real_data_random = self.real_B_random

        # compute z_predict
        if self.opt.lambda_z > 0.0:
            self.mu2, logvar2 = self.netE(self.fake_B_random)  # mu2 is a point estimate

    def backward_D(self, netD, real, fake):
        # Fake, stop backprop to the generator by detaching fake_B
        pred_fake = netD(fake.detach())
        # real
        pred_real = netD(real)
        loss_D_fake, _ = self.criterionGAN(pred_fake, False)
        loss_D_real, _ = self.criterionGAN(pred_real, True)
        # Combined loss
        loss_D = loss_D_fake + loss_D_real
        loss_D.backward()
        return loss_D, [loss_D_fake, loss_D_real]

    def backward_G_GAN(self, fake, netD=None, ll=0.0):
        if ll > 0.0:
            pred_fake = netD(fake)
            loss_G_GAN, _ = self.criterionGAN(pred_fake, True)
        else:
            loss_G_GAN = 0
        return loss_G_GAN * ll

    def backward_EG(self):
        # 1, G(A) should fool D
        self.loss_G_GAN = self.backward_G_GAN(self.fake_data_encoded, self.netD, self.opt.lambda_GAN)
        if self.opt.use_same_D:
            self.loss_G_GAN2 = self.backward_G_GAN(self.fake_data_random, self.netD, self.opt.lambda_GAN2)
        else:
            self.loss_G_GAN2 = self.backward_G_GAN(self.fake_data_random, self.netD2, self.opt.lambda_GAN2)
        # 2. KL loss
        if self.opt.lambda_kl > 0.0:
            self.loss_kl = torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) * (-0.5 * self.opt.lambda_kl)
        else:
            self.loss_kl = 0
        # 3, reconstruction |fake_B-real_B|
        if self.opt.lambda_L1 > 0.0:
            self.loss_G_L1 = self.criterionL1(self.fake_B_encoded, self.real_B_encoded) * self.opt.lambda_L1
        else:
            self.loss_G_L1 = 0.0

        self.loss_G = self.loss_G_GAN + self.loss_G_GAN2 + self.loss_G_L1 + self.loss_kl
        self.loss_G.backward(retain_graph=True)

    def update_D(self):
        self.set_requires_grad([self.netD, self.netD2], True)
        # update D1
        if self.opt.lambda_GAN > 0.0:
            self.optimizer_D.zero_grad()
            self.loss_D, self.losses_D = self.backward_D(self.netD, self.real_data_encoded, self.fake_data_encoded)
            if self.opt.use_same_D:
                self.loss_D2, self.losses_D2 = self.backward_D(self.netD, self.real_data_random, self.fake_data_random)
            self.optimizer_D.step()

        if self.opt.lambda_GAN2 > 0.0 and not self.opt.use_same_D:
            self.optimizer_D2.zero_grad()
            self.loss_D2, self.losses_D2 = self.backward_D(self.netD2, self.real_data_random, self.fake_data_random)
            self.optimizer_D2.step()

    def backward_G_alone(self):
        # 3, reconstruction |(E(G(A, z_random)))-z_random|
        if self.opt.lambda_z > 0.0:
            self.loss_z_L1 = self.criterionZ(self.mu2, self.z_random) * self.opt.lambda_z
            self.loss_z_L1.backward()
        else:
            self.loss_z_L1 = 0.0

    def update_G_and_E(self):
        # update G and E
        self.set_requires_grad([self.netD, self.netD2], False)
        self.optimizer_E.zero_grad()
        self.optimizer_G.zero_grad()
        self.backward_EG()

        # update G alone
        if self.opt.lambda_z > 0.0:
            self.set_requires_grad([self.netE], False)
            self.backward_G_alone()
            self.set_requires_grad([self.netE], True)

        self.optimizer_E.step()
        self.optimizer_G.step()

    def optimize_parameters(self):
        self.forward()
        self.update_G_and_E()
        self.update_D()

# Training

In [6]:
model = create_model(opt)      
model.setup(opt)              
visualizer = Visualizer(opt) 
total_iters = 0 
losses_data = {}
iter_data = []

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):    
    epoch_start_time = time.time() 
    iter_data_time = time.time()   
    epoch_iter = 0                  
    
    for i, data in enumerate(dataset): 
        iter_start_time = time.time()  
        if total_iters % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time
        visualizer.reset()
        total_iters += opt.batch_size
        epoch_iter += opt.batch_size
        model.set_input(data)        
        if not model.is_train():     
            print('skip this batch')
            continue
        model.optimize_parameters()   

        if total_iters % opt.display_freq == 0: 
            save_result = total_iters % opt.update_html_freq == 0
            model.compute_visuals()
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        if total_iters % opt.print_freq == 0:   
            losses = model.get_current_losses()
            t_comp = (time.time() - iter_start_time) / opt.batch_size
            visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
            if opt.display_id > 0:
                visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

        if total_iters % opt.save_latest_freq == 0:   
            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            model.save_networks('latest')

        iter_data_time = time.time()
    if epoch % opt.save_epoch_freq == 0:              
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
        model.save_networks('latest')
        model.save_networks(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
    model.update_learning_rate()                     


Setting up a new session...


initialize network with xavier
initialize network with xavier
initialize network with xavier
initialize network with xavier
model [BiCycleGANModel] was created
---------- Networks initialized -------------
[Network G] Total number of parameters : 54.795 M
[Network D] Total number of parameters : 3.459 M
[Network D2] Total number of parameters : 3.459 M
[Network E] Total number of parameters : 2.590 M
-----------------------------------------------
create web directory ./checkpoints/web...
(epoch: 1, iters: 100, time: 0.049, data: 0.793) G_GAN: 0.561 D: 0.700 G_GAN2: 0.732 D2: 1.198 G_L1: 1.111 z_L1: 1.074 kl: 0.101 
(epoch: 1, iters: 200, time: 0.042, data: 0.004) G_GAN: 0.759 D: 0.444 G_GAN2: 0.514 D2: 1.041 G_L1: 2.191 z_L1: 0.721 kl: 0.055 
(epoch: 1, iters: 300, time: 0.047, data: 0.002) G_GAN: 0.692 D: 1.450 G_GAN2: 0.534 D2: 0.696 G_L1: 2.535 z_L1: 0.516 kl: 0.014 
(epoch: 1, iters: 400, time: 0.199, data: 0.005) G_GAN: 0.941 D: 0.848 G_GAN2: 0.872 D2: 1.159 G_L1: 1.294 z_L1: 0.5

(epoch: 6, iters: 320, time: 0.050, data: 0.003) G_GAN: 0.676 D: 0.494 G_GAN2: 0.684 D2: 0.567 G_L1: 2.043 z_L1: 0.363 kl: 0.030 
(epoch: 6, iters: 420, time: 0.048, data: 0.003) G_GAN: 0.868 D: 0.746 G_GAN2: 0.502 D2: 0.963 G_L1: 0.781 z_L1: 0.294 kl: 0.025 
(epoch: 6, iters: 520, time: 0.178, data: 0.003) G_GAN: 1.187 D: 1.188 G_GAN2: 0.941 D2: 1.848 G_L1: 0.522 z_L1: 0.312 kl: 0.027 
(epoch: 6, iters: 620, time: 0.043, data: 0.004) G_GAN: 1.097 D: 0.875 G_GAN2: 1.060 D2: 1.065 G_L1: 1.546 z_L1: 0.401 kl: 0.117 
(epoch: 6, iters: 720, time: 0.044, data: 0.002) G_GAN: 1.078 D: 0.377 G_GAN2: 0.737 D2: 0.498 G_L1: 1.692 z_L1: 0.353 kl: 0.033 
(epoch: 6, iters: 820, time: 0.044, data: 0.003) G_GAN: 1.048 D: 0.944 G_GAN2: 0.608 D2: 0.907 G_L1: 0.935 z_L1: 0.380 kl: 0.024 
(epoch: 6, iters: 920, time: 0.204, data: 0.003) G_GAN: 1.030 D: 0.636 G_GAN2: 0.498 D2: 0.722 G_L1: 0.915 z_L1: 0.792 kl: 0.050 
(epoch: 6, iters: 1020, time: 0.049, data: 0.004) G_GAN: 0.731 D: 0.763 G_GAN2: 0.608 D2: 

(epoch: 11, iters: 840, time: 0.038, data: 0.003) G_GAN: 1.118 D: 0.848 G_GAN2: 1.004 D2: 0.327 G_L1: 0.882 z_L1: 0.338 kl: 0.037 
(epoch: 11, iters: 940, time: 0.053, data: 0.002) G_GAN: 1.204 D: 0.303 G_GAN2: 1.305 D2: 1.354 G_L1: 0.758 z_L1: 0.470 kl: 0.026 
(epoch: 11, iters: 1040, time: 0.191, data: 0.003) G_GAN: 1.103 D: 0.348 G_GAN2: 1.388 D2: 0.539 G_L1: 2.433 z_L1: 0.570 kl: 0.004 
End of epoch 11 / 40 	 Time Taken: 53 sec
learning rate = 0.0002000
(epoch: 12, iters: 44, time: 0.050, data: 0.004) G_GAN: 1.293 D: 0.388 G_GAN2: 1.982 D2: 0.937 G_L1: 2.682 z_L1: 0.653 kl: 0.228 
(epoch: 12, iters: 144, time: 0.057, data: 0.004) G_GAN: 0.827 D: 0.482 G_GAN2: 0.367 D2: 0.894 G_L1: 1.102 z_L1: 0.719 kl: 0.044 
(epoch: 12, iters: 244, time: 0.046, data: 0.004) G_GAN: 0.493 D: 1.059 G_GAN2: 1.988 D2: 0.298 G_L1: 1.252 z_L1: 0.375 kl: 0.025 
(epoch: 12, iters: 344, time: 0.200, data: 0.003) G_GAN: 1.038 D: 0.514 G_GAN2: 2.080 D2: 0.241 G_L1: 1.317 z_L1: 0.373 kl: 0.014 
(epoch: 12, ite

(epoch: 17, iters: 264, time: 0.057, data: 0.002) G_GAN: 1.611 D: 0.415 G_GAN2: 0.894 D2: 0.945 G_L1: 1.393 z_L1: 0.365 kl: 0.027 
(epoch: 17, iters: 364, time: 0.049, data: 0.003) G_GAN: 0.972 D: 0.454 G_GAN2: 1.058 D2: 0.377 G_L1: 0.951 z_L1: 0.424 kl: 0.025 
(epoch: 17, iters: 464, time: 0.177, data: 0.002) G_GAN: 0.863 D: 0.834 G_GAN2: 1.140 D2: 1.155 G_L1: 1.715 z_L1: 0.680 kl: 0.019 
(epoch: 17, iters: 564, time: 0.044, data: 0.003) G_GAN: 1.296 D: 0.463 G_GAN2: 1.162 D2: 0.266 G_L1: 2.083 z_L1: 0.568 kl: 0.018 
(epoch: 17, iters: 664, time: 0.055, data: 0.004) G_GAN: 1.313 D: 0.193 G_GAN2: 1.183 D2: 0.958 G_L1: 1.375 z_L1: 0.312 kl: 0.041 
(epoch: 17, iters: 764, time: 0.051, data: 0.004) G_GAN: 1.283 D: 0.457 G_GAN2: 1.816 D2: 0.204 G_L1: 1.066 z_L1: 0.601 kl: 0.101 
(epoch: 17, iters: 864, time: 0.217, data: 0.004) G_GAN: 1.690 D: 0.235 G_GAN2: 1.197 D2: 0.601 G_L1: 0.838 z_L1: 0.776 kl: 0.035 
(epoch: 17, iters: 964, time: 0.044, data: 0.005) G_GAN: 1.171 D: 0.224 G_GAN2: 0.6

(epoch: 22, iters: 784, time: 0.044, data: 0.002) G_GAN: 1.090 D: 0.766 G_GAN2: 0.486 D2: 2.201 G_L1: 1.614 z_L1: 0.293 kl: 0.068 
(epoch: 22, iters: 884, time: 0.050, data: 0.003) G_GAN: 1.290 D: 0.349 G_GAN2: 0.904 D2: 1.388 G_L1: 2.390 z_L1: 0.426 kl: 0.054 
(epoch: 22, iters: 984, time: 0.207, data: 0.002) G_GAN: 0.995 D: 0.336 G_GAN2: 0.869 D2: 0.361 G_L1: 1.318 z_L1: 0.465 kl: 0.070 
(epoch: 22, iters: 1084, time: 0.041, data: 0.003) G_GAN: 1.463 D: 0.189 G_GAN2: 0.973 D2: 0.964 G_L1: 1.200 z_L1: 0.593 kl: 0.058 
End of epoch 22 / 40 	 Time Taken: 55 sec
learning rate = 0.0001714
(epoch: 23, iters: 88, time: 0.050, data: 0.003) G_GAN: 1.765 D: 0.799 G_GAN2: 0.850 D2: 1.131 G_L1: 1.018 z_L1: 0.673 kl: 0.197 
(epoch: 23, iters: 188, time: 0.051, data: 0.002) G_GAN: 1.161 D: 0.272 G_GAN2: 1.095 D2: 0.202 G_L1: 0.990 z_L1: 0.447 kl: 0.066 
(epoch: 23, iters: 288, time: 0.199, data: 0.004) G_GAN: 1.222 D: 0.188 G_GAN2: 1.165 D2: 0.168 G_L1: 1.356 z_L1: 0.514 kl: 0.057 
(epoch: 23, ite

(epoch: 28, iters: 208, time: 0.051, data: 0.005) G_GAN: 1.418 D: 0.124 G_GAN2: 0.977 D2: 0.449 G_L1: 0.879 z_L1: 0.452 kl: 0.010 
(epoch: 28, iters: 308, time: 0.060, data: 0.004) G_GAN: 1.167 D: 0.299 G_GAN2: 1.559 D2: 0.098 G_L1: 1.240 z_L1: 0.483 kl: 0.033 
(epoch: 28, iters: 408, time: 0.223, data: 0.004) G_GAN: 0.859 D: 0.297 G_GAN2: 1.262 D2: 0.192 G_L1: 1.372 z_L1: 0.336 kl: 0.025 
saving the latest model (epoch 28, total_iters 30000)
(epoch: 28, iters: 508, time: 0.045, data: 0.003) G_GAN: 0.798 D: 0.420 G_GAN2: 0.897 D2: 0.522 G_L1: 1.035 z_L1: 0.520 kl: 0.027 
(epoch: 28, iters: 608, time: 0.050, data: 0.003) G_GAN: 1.245 D: 0.573 G_GAN2: 0.879 D2: 1.028 G_L1: 0.835 z_L1: 0.492 kl: 0.025 
(epoch: 28, iters: 708, time: 0.059, data: 0.003) G_GAN: 0.575 D: 0.682 G_GAN2: 0.831 D2: 0.334 G_L1: 1.478 z_L1: 0.508 kl: 0.028 
(epoch: 28, iters: 808, time: 0.214, data: 0.004) G_GAN: 1.451 D: 0.390 G_GAN2: 1.518 D2: 0.341 G_L1: 1.252 z_L1: 0.463 kl: 0.018 
(epoch: 28, iters: 908, time:

(epoch: 33, iters: 728, time: 0.041, data: 0.002) G_GAN: 1.641 D: 0.514 G_GAN2: 1.813 D2: 0.104 G_L1: 1.027 z_L1: 0.478 kl: 0.073 
(epoch: 33, iters: 828, time: 0.049, data: 0.002) G_GAN: 1.597 D: 0.147 G_GAN2: 1.304 D2: 0.718 G_L1: 1.622 z_L1: 0.563 kl: 0.070 
(epoch: 33, iters: 928, time: 0.228, data: 0.002) G_GAN: 1.391 D: 0.460 G_GAN2: 1.098 D2: 0.299 G_L1: 1.341 z_L1: 0.399 kl: 0.059 
(epoch: 33, iters: 1028, time: 0.043, data: 0.004) G_GAN: 1.053 D: 0.304 G_GAN2: 0.969 D2: 0.841 G_L1: 0.836 z_L1: 0.743 kl: 0.074 
End of epoch 33 / 40 	 Time Taken: 53 sec
learning rate = 0.0000667
(epoch: 34, iters: 32, time: 0.048, data: 0.002) G_GAN: 1.314 D: 1.070 G_GAN2: 1.325 D2: 0.830 G_L1: 0.941 z_L1: 0.409 kl: 0.076 
(epoch: 34, iters: 132, time: 0.052, data: 0.002) G_GAN: 1.171 D: 0.178 G_GAN2: 0.864 D2: 1.749 G_L1: 1.334 z_L1: 0.449 kl: 0.064 
(epoch: 34, iters: 232, time: 0.245, data: 0.002) G_GAN: 1.044 D: 0.374 G_GAN2: 0.974 D2: 0.265 G_L1: 1.231 z_L1: 0.506 kl: 0.092 
(epoch: 34, ite

(epoch: 39, iters: 52, time: 0.043, data: 0.003) G_GAN: 1.744 D: 0.097 G_GAN2: 1.307 D2: 0.110 G_L1: 2.375 z_L1: 0.678 kl: 0.079 
(epoch: 39, iters: 152, time: 0.047, data: 0.002) G_GAN: 1.237 D: 0.409 G_GAN2: 1.225 D2: 0.120 G_L1: 0.777 z_L1: 0.647 kl: 0.061 
(epoch: 39, iters: 252, time: 0.055, data: 0.003) G_GAN: 1.267 D: 1.657 G_GAN2: 1.072 D2: 1.065 G_L1: 0.835 z_L1: 0.818 kl: 0.065 
(epoch: 39, iters: 352, time: 0.240, data: 0.004) G_GAN: 1.194 D: 0.291 G_GAN2: 1.372 D2: 0.473 G_L1: 1.051 z_L1: 0.424 kl: 0.068 
(epoch: 39, iters: 452, time: 0.045, data: 0.003) G_GAN: 1.025 D: 0.231 G_GAN2: 1.401 D2: 0.080 G_L1: 1.330 z_L1: 0.896 kl: 0.191 
(epoch: 39, iters: 552, time: 0.051, data: 0.002) G_GAN: 1.258 D: 0.393 G_GAN2: 0.811 D2: 0.387 G_L1: 0.909 z_L1: 0.426 kl: 0.074 
(epoch: 39, iters: 652, time: 0.049, data: 0.003) G_GAN: 1.262 D: 0.214 G_GAN2: 1.210 D2: 0.151 G_L1: 1.382 z_L1: 0.402 kl: 0.062 
(epoch: 39, iters: 752, time: 0.200, data: 0.003) G_GAN: 1.377 D: 0.202 G_GAN2: 1.30

# Test

In [12]:
class TestOptions(BaseOptions):
    def initialize(self, parser):
        BaseOptions.initialize(self, parser)
        parser.add_argument('--results_dir', type=str, default='../results/', help='saves results here.')
        parser.add_argument('--phase', type=str, default='val', help='train, val, test, etc')
        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
        parser.add_argument('--n_samples', type=int, default=5, help='#samples')
        parser.add_argument('--no_encode', action='store_true', help='do not produce encoded image')
        parser.add_argument('--sync', action='store_true', help='use the same latent code for different input images')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio for the results')
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')

        self.isTrain = False
        return parser

In [13]:
real_images_dir = "real_images"
fake_images_dir = "fake_images"

os.makedirs(real_images_dir, exist_ok=True)
os.makedirs(fake_images_dir, exist_ok=True)
def tensor_to_pil(tensor):
    return transforms.ToPILImage()(tensor.cpu().detach())

In [14]:
# options
opt = TestOptions().parse()
opt.num_threads = 1  
opt.batch_size = 1  
opt.serial_batches = True

# create dataset
dataset = create_dataset(opt)
model = create_model(opt)
model.setup(opt)
model.eval()
print('Loading model %s' % opt.model)

web_dir = os.path.join(opt.results_dir, opt.phase + '_sync' if opt.sync else opt.phase)
webpage = html.HTML(web_dir, 'Training = %s, Phase = %s, Class =%s' % (opt.name, opt.phase, opt.name))

# sample random z
if opt.sync:
    z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)

# test stage
for i, data in enumerate(islice(dataset, opt.num_test)):
    model.set_input(data)
    print('process input image %3.3d/%3.3d' % (i, opt.num_test))
    if not opt.sync:
        z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)
    for nn in range(opt.n_samples + 1):
        encode = nn == 0 and not opt.no_encode
        real_A, fake_B, real_B = model.test(z_samples[[nn]], encode=encode)
        real_image_path = os.path.join(real_images_dir, f'real_image_{i}.png')
        fake_image_path = os.path.join(fake_images_dir, f'fake_image_{i}_{nn}.png')
        tensor_to_pil(real_A.squeeze(0)).save(real_image_path)
        tensor_to_pil(fake_B.squeeze(0)).save(fake_image_path)
        
        if nn == 0:
            images = [real_A, real_B, fake_B]
            names = ['input', 'ground truth', 'encoded']
            
        else:
            images.append(fake_B)
            names.append('random_sample%2.2d' % nn)

    img_path = 'input_%3.3d' % i
    save_images(webpage, images, names, img_path, aspect_ratio=opt.aspect_ratio, width=opt.crop_size)

webpage.save()


----------------- Options ---------------
             aspect_ratio: 1.0                           
               batch_size: 2                             
              center_crop: False                         
          checkpoints_dir: ./checkpoints                 
            conditional_D: False                         
                crop_size: 256                           
                 dataroot: ./datasets/maps               	[default: None]
             dataset_mode: aligned                       
                direction: AtoB                          
          display_winsize: 256                           
                    epoch: latest                        
                     eval: False                         
                  gpu_ids: 0                             
                init_gain: 0.02                          
                init_type: xavier                        
                 input_nc: 3                             
              

## Evaluations

In [12]:
fid_value = fid_score.calculate_fid_given_paths([real_images_dir, fake_images_dir],
                                                batch_size=opt.batch_size,
                                                device='cuda',
                                                dims=64)
print('FID score: ', fid_value)

100%|██████████| 50/50 [00:00<00:00, 106.62it/s]
100%|██████████| 300/300 [00:01<00:00, 197.92it/s]

FID score:  8.63273997616474



