In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td
import torchvision as tv
import matplotlib.pyplot as plt
import math
import itertools
import datetime
import time
import sys
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import glob
import random
import os

sys.path.insert(1, './code/')
from models import Create_nets
from optimizer import Get_loss_func, Get_optimizers
from utils import ReplayBuffer, LambdaLR, sample_images

In [2]:
from PIL import Image, ImageFile
# Let PIL decode image files whose data is cut short instead of raising an
# error when the image is loaded.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


# Define Training Options for This Script and Create the Experiment Record

In [4]:
class Options:
    """Plain attribute container: every keyword argument becomes an instance attribute.

    Example: ``Options(lr=0.1).lr`` evaluates to ``0.1``.
    """

    def __init__(self, **opts):
        # Write each keyword directly into the instance namespace.
        namespace = vars(self)
        for key, value in opts.items():
            namespace[key] = value

## pool_size = 1

In [5]:
# Experiment configuration. Keys are exposed as attributes through Options(**opts).
opts = dict(
    exp_name='assets/Exp1',         # outputs go under '<exp_name>-<dataset_name>/'
    epoch_start=0,                  # first epoch index of the training range
    epoch_num=20,                   # end (exclusive) of the training epoch range
    data_root='./datasets/',        # images are read from '<data_root>/<dataset_name>'
    dataset_name='photo2vangogh',
    batch_size=1,                   # batch size of the training DataLoader
    lr=0.0002,                      # read by Get_optimizers — presumably the learning rate; confirm
    b1=0.5,                         # read by Get_optimizers — presumably Adam beta1; confirm
    b2=0.999,                       # read by Get_optimizers — presumably Adam beta2; confirm
    decay_epoch=8,                  # forwarded to the LambdaLR helper — presumably where decay starts; confirm
    n_cpu=4,                        # training DataLoader uses n_cpu // 2 workers
    img_height=256,                 # random-crop height (resize target is 1.12x this value)
    img_width=256,                  # random-crop width
    input_nc_A=3,                   # channel count for domain A (1 requests grayscale)
    input_nc_B=3,                   # channel count for domain B (1 requests grayscale)
    sample_interval=200,            # call sample_images every this many batches
    checkpoint_interval=1,          # save weights every this many epochs; -1 disables saving
    n_residual_blocks=9,            # not read in this notebook — presumably used by Create_nets; confirm
    n_D_layers=4,                   # sets the target patch size: img_size // 2**n_D_layers - 2
    lambda_cyc=10,                  # weight of the cycle loss in the generator objective
    lambda_id=0.5,                  # weight of the identity loss in the generator objective
    pool_size=1,                    # not read in this notebook — presumably a helper setting; confirm
    img_result_dir='results',       # sub-folder created for sampled images
    model_result_dir='saved_models',  # sub-folder for model checkpoints
)

In [6]:
args = Options(**opts)

# Make sure the image and checkpoint output folders of this experiment exist.
for sub_dir in (args.img_result_dir, args.model_result_dir):
    os.makedirs('%s-%s/%s' % (args.exp_name, args.dataset_name, sub_dir), exist_ok=True)

# Data Processing

In [7]:
class ImageDataset(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>A`` and ``<root>/<mode>B``.

    Each item is a dict ``{'X': image_from_A, 'Y': image_from_B}`` where both
    images have been passed through the same transform pipeline.

    Args:
        args: options object; ``input_nc_A`` / ``input_nc_B`` give the channel
            count wanted for domain A / B.
        root: dataset directory containing the ``<mode>A`` and ``<mode>B`` folders.
        transforms_: sequence of transforms composed into one pipeline;
            ``None`` means no transform is applied.
        unaligned: when true, the B image is drawn at random instead of being
            paired with the A image by index.
        mode: folder prefix, e.g. ``'train'`` or ``'test'``.
    """

    # PIL mode enforced for a requested channel count. Other channel counts
    # leave the file's native mode untouched.
    _PIL_MODES = {1: 'L', 3: 'RGB'}

    def __init__(self, args, root, transforms_=None, unaligned=False, mode='train'):
        # Compose(None) would only fail later, on the first call, so fall back
        # to an empty pipeline when no transforms are given.
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.args = args
        self.unaligned = unaligned
        # Sorted so that index-based pairing is deterministic across runs.
        self.files_X = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_Y = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))

    def _load(self, path, n_channels):
        """Open ``path``, force the PIL mode matching ``n_channels``, then transform."""
        img = Image.open(path)
        pil_mode = self._PIL_MODES.get(n_channels)
        # The mode conversion has to happen on the PIL image, i.e. BEFORE the
        # transform pipeline: the pipeline's output may no longer be a PIL
        # image (e.g. after ToTensor) and would not have a convert() method.
        if pil_mode is not None and img.mode != pil_mode:
            img = img.convert(pil_mode)
        return self.transform(img)

    def __getitem__(self, index):
        # Indices wrap around so the shorter domain is reused when the two
        # folders hold a different number of files.
        path_X = self.files_X[index % len(self.files_X)]
        if self.unaligned:
            path_Y = self.files_Y[random.randint(0, len(self.files_Y)-1)]
        else:
            path_Y = self.files_Y[index % len(self.files_Y)]

        img_X = self._load(path_X, self.args.input_nc_A)
        img_Y = self._load(path_Y, self.args.input_nc_B)

        return {'X': img_X, 'Y': img_Y}

    def __len__(self):
        # Length of the larger domain, so every file of it is visited per epoch.
        return max(len(self.files_X), len(self.files_Y))

In [8]:
# Augmentation pipeline shared by both loaders: upscale by 12%, take a random
# crop of the configured size, flip at random, then map pixels to [-1, 1].
transforms_ = [
    transforms.Resize(int(args.img_height*1.12), Image.BICUBIC),
    transforms.RandomCrop((args.img_height, args.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
]

# Both splits live under the same dataset folder.
dataset_root = "%s/%s" % (args.data_root, args.dataset_name)

train_dataset = ImageDataset(args, dataset_root, transforms_=transforms_, unaligned=True, mode='train')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.n_cpu//2,
    drop_last=True,
)

test_dataset = ImageDataset(args, dataset_root, transforms_=transforms_, unaligned=True, mode='test')
test_dataloader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=1,
    drop_last=True,
)

# Train

In [9]:
# Shape (channels, height, width) of the real/fake target maps that are
# compared against the discriminator predictions in the training loop.
downscale = 2**args.n_D_layers
patch = (1, args.img_height//downscale - 2, args.img_width//downscale - 2)

# Initialize generator and discriminator
G__AB, D__B, G__BA, D__A = Create_nets(args)

# Loss functions
criterion_GAN, criterion_cycle, criterion_identity = Get_loss_func(args)

# Optimizers
optimizer_G, optimizer_D_B, optimizer_D_A = Get_optimizers(args, G__AB, G__BA, D__B, D__A )


def _epoch_decay_scheduler(optimizer):
    """Attach a torch LambdaLR scheduler driven by a fresh project LambdaLR rule."""
    decay_rule = LambdaLR(args.epoch_num, args.epoch_start, args.decay_epoch)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=decay_rule.step)


# Learning rate update schedulers (one per optimizer, each with its own rule object)
lr_scheduler_G = _epoch_decay_scheduler(optimizer_G)
lr_scheduler_D_B = _epoch_decay_scheduler(optimizer_D_B)
lr_scheduler_D_A = _epoch_decay_scheduler(optimizer_D_A)

# Buffers of previously generated samples
fake_Y_A_buffer = ReplayBuffer()
fake_X_B_buffer = ReplayBuffer()

In [None]:
# Training loop. For every batch: update both generators jointly, then
# discriminator A, then discriminator B. Per-epoch statistics are written to
# '<exp_name>-<dataset_name>/statistics/<epoch>' and checkpoints are saved
# every `checkpoint_interval` epochs.
# NOTE(review): inputs and targets are placed on `device` (set in the device
# cell). The networks returned by Create_nets are assumed to live on the same
# device — confirm.
prev_time = time.time()
for epoch in range(args.epoch_start, args.epoch_num):
    epoch_statistic = []  # one formatted line per batch, joined when written
    for i, batch in enumerate(train_dataloader):

        # Set model input
        real_X_A = batch['X'].to(device=device, dtype=torch.float32)
        real_Y_B = batch['Y'].to(device=device, dtype=torch.float32)

        # Adversarial ground truths: all-ones / all-zeros maps of the patch shape
        target_shape = (real_X_A.size(0), *patch)
        valid = torch.ones(target_shape, dtype=torch.float32, device=device)
        fake = torch.zeros(target_shape, dtype=torch.float32, device=device)

        #  Train G_A and G_B

        optimizer_G.zero_grad()
        # Identity loss: each generator should leave an image of its own
        # target domain unchanged.
        loss_id_A = criterion_identity(G__BA(real_X_A), real_X_A)
        loss_id_B = criterion_identity(G__AB(real_Y_B), real_Y_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss: generators are rewarded when the discriminators label
        # their outputs as `valid`.
        fake_X_B = G__AB(real_X_A)
        pred_fake = D__B(fake_X_B)
        loss_GAN_AB = criterion_GAN(pred_fake, valid)

        fake_Y_A = G__BA(real_Y_B)
        pred_fake = D__A(fake_Y_A)
        loss_GAN_BA = criterion_GAN(pred_fake, valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: translating forth and back should recover the input.
        recov_X_A = G__BA(fake_X_B)
        loss_cycle_A = criterion_cycle(recov_X_A, real_X_A)
        recov_Y_B = G__AB(fake_Y_A)
        loss_cycle_B = criterion_cycle(recov_Y_B, real_Y_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G =    loss_GAN + \
                    args.lambda_cyc * loss_cycle + \
                    args.lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()

        #  Train D_A

        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = D__A(real_X_A)
        loss_real = criterion_GAN(pred_real, valid)
        # Fake loss (on batch of previously generated samples); detach() keeps
        # the discriminator update from back-propagating into the generator.
        fake_Y_A_ = fake_Y_A_buffer.push_and_pop(fake_Y_A)
        pred_fake = D__A(fake_Y_A_.detach())
        loss_fake = criterion_GAN(pred_fake, fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        #  Train D_B

        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = D__B(real_Y_B)
        loss_real = criterion_GAN(pred_real, valid)
        # Fake loss (on batch of previously generated samples)
        fake_X_B_ = fake_X_B_buffer.push_and_pop(fake_X_B)
        pred_fake = D__B(fake_X_B_.detach())
        loss_fake = criterion_GAN(pred_fake, fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        # Only used for reporting.
        loss_D = (loss_D_A + loss_D_B) / 2

        # Determine approximate time left, extrapolating the duration of the
        # batch that just finished to all remaining batches.
        batches_done = epoch * len(train_dataloader) + i
        batches_left = args.epoch_num * len(train_dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Pull the scalar losses to Python floats once and reuse them for both logs.
        loss_values = (loss_D.item(), loss_G.item(),
                       loss_GAN.item(), loss_cycle.item(),
                       loss_identity.item())

        # Print log
        sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s" %
                         ((epoch+1, args.epoch_num, i, len(train_dataloader)) + loss_values + (time_left,)))

        # Save training statistics
        epoch_statistic.append('Batch:%d D_loss:%f G_loss:%f adv:%f cycle:%f identity:%f\n' % ((i,) + loss_values))

        # If at sample interval save image
        if batches_done % args.sample_interval == 0:
            sample_images(args,G__AB,G__BA, test_dataloader, epoch, batches_done)

    # Write this epoch's per-batch statistics to a file named after the epoch.
    stat_path = './%s-%s/statistics/' % (args.exp_name, args.dataset_name)
    os.makedirs(stat_path, exist_ok=True)
    with open(stat_path + str(epoch), 'w') as f:
        f.write(''.join(epoch_statistic))

    # Update learning rates
    # NOTE(review): passing `epoch` to step() is deprecated in recent PyTorch
    # releases; it is kept here so the existing schedule is unchanged — confirm
    # before switching to the argument-free form.
    lr_scheduler_G.step(epoch)
    lr_scheduler_D_B.step(epoch)
    lr_scheduler_D_A.step(epoch)

    if args.checkpoint_interval != -1 and epoch % args.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(G__AB.state_dict(), './%s-%s/%s/G__AB_%d.pth' % (args.exp_name, args.dataset_name, args.model_result_dir, epoch))
        torch.save(G__BA.state_dict(), './%s-%s/%s/G__BA_%d.pth' % (args.exp_name, args.dataset_name, args.model_result_dir, epoch))
        torch.save(D__A.state_dict(), './%s-%s/%s/D__A_%d.pth' % (args.exp_name, args.dataset_name, args.model_result_dir, epoch))
        torch.save(D__B.state_dict(), './%s-%s/%s/D__B_%d.pth' % (args.exp_name, args.dataset_name, args.model_result_dir, epoch))


[Epoch 1/10] [Batch 1690/2832] [D loss: 0.352787] [G loss: 3.055143, adv: 0.512677, cycle: 0.240764, identity: 0.269649] ETA: 2:11:14.2580340