In [None]:
# #########################################################################################
# Simple Implementation of CycleGAN
# -----------------------------------------------------------------------------------------
# For Project: CycleGAN for Winter in Singapore
# -----------------------------------------------------------------------------------------
# Author: GongjieZhang@ntu.edu.sg (Nanyang Technological University, Singapore)
# -----------------------------------------------------------------------------------------
# Prerequisites:
#     Ubuntu18.04   python3.7   cuda10.1   numpy1.17.2   matplotlib3.1.1   
#     pytorch1.3   torchvision0.4.1   tqdm4.36.1    scikit-image0.16.1
# #########################################################################################


In [None]:
# Import necessary packages

import os
import tqdm 
import random
import itertools
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter


In [None]:
# Set fixed random seeds for reproducibility

SEED = 0  # one seed shared by every random number generator used below

# Request deterministic cuDNN kernels, then seed PyTorch, NumPy and the
# Python standard-library generator with the same value.
torch.backends.cudnn.deterministic = True
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


In [None]:
# Training Configs

# Folders holding the two unpaired image domains (absolute, machine-specific paths).
summer_dir = '/home/gongjie/GJ_Research/Summer_Winter_DATASET/summer/'
winter_dir = '/home/gongjie/GJ_Research/Summer_Winter_DATASET/winter/'

# Side length of the random crop fed to the networks during training.
img_size = 512  # training resolution @ 512 x 512

batchsize = 1   # number of (summer, winter) pairs per optimisation step
lr_init = 2e-4  # initial learning rate given to all three Adam optimizers

total_num_iteration = 150000  # total number of training iterations (batches)
save_every_iter = 10000       # generator checkpoints are written every this many iterations


In [None]:
# Dataset & Sampler for Training

class ImageDataset(torch.utils.data.Dataset):
    """Unpaired summer/winter image dataset used for CycleGAN training.

    Each item is a ``(summer_img, winter_img)`` tuple of tensors. The summer
    image is selected by ``index`` (wrapping around the summer list); the
    winter image is drawn at random on every access, so the two domains are
    never paired in a fixed way.

    Every image is converted to RGB, resized with bicubic interpolation using
    ``int(img_size * 1.15)``, randomly cropped to ``img_size``, randomly
    flipped horizontally, converted to a tensor and normalised with mean 0.5
    and std 0.5 per channel (i.e. from [0, 1] to [-1, 1]).

    Args:
        summer_dir: directory containing the summer images.
        winter_dir: directory containing the winter images.
        img_size: side length of the random crop.

    Raises:
        ValueError: if either directory contains no ``.jpg`` / ``.png`` file.
    """

    def __init__(self, summer_dir, winter_dir, img_size):
        self.S_dir = summer_dir
        self.W_dir = winter_dir
        self.S_imgs = self._list_images(summer_dir)
        self.W_imgs = self._list_images(winter_dir)
        # Fail early with a clear message instead of a ZeroDivisionError /
        # randint error on the first __getitem__ call.
        if not self.S_imgs or not self.W_imgs:
            raise ValueError("No .jpg/.png images found in %r and/or %r" % (summer_dir, winter_dir))
        self.transform = [ torchvision.transforms.Resize(int(img_size*1.15), Image.BICUBIC),
                           torchvision.transforms.RandomCrop(img_size),
                           torchvision.transforms.RandomHorizontalFlip(),
                           torchvision.transforms.ToTensor(),  #  [0 - 255] --> [0 - 1.0]
                           torchvision.transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
        self.transform = torchvision.transforms.Compose(self.transform)

    @staticmethod
    def _list_images(directory):
        """Return the sorted ``.jpg`` / ``.png`` file names found in ``directory``."""
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        return sorted(filename for filename in os.listdir(directory)
                      if os.path.splitext(filename)[-1] in ('.jpg', '.png'))

    def _load(self, directory, filename):
        """Open one image, convert it to RGB and apply the training transform."""
        # The context manager closes the underlying file handle once the
        # converted copy has been transformed.
        with Image.open(os.path.join(directory, filename)) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        summer_img = self._load(self.S_dir, self.S_imgs[index % len(self.S_imgs)])
        winter_img = self._load(self.W_dir, self.W_imgs[random.randint(0, len(self.W_imgs) - 1)])
        return summer_img, winter_img

    def __len__(self):
        # Number of images in the larger domain. (Counting the characters of
        # the directory paths, as len(self.S_dir) would, is unrelated to the
        # dataset size.)
        return max(len(self.S_imgs), len(self.W_imgs))


class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    Wraps a BatchSampler, re-sampling from it until [num_iterations] iterations have been sampled.

    Args:
        batch_sampler: the wrapped batch sampler; it is re-iterated from the
            start every time it is exhausted.
        num_iterations: iteration count at which sampling stops.
        start_iter: initial value of the iteration counter, so that
            ``num_iterations - start_iter`` batches are produced.

    Raises:
        ValueError: during iteration, if the wrapped batch sampler yields no
            batch at all (otherwise the loop could never make progress).
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration < self.num_iterations:
            # if the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it
            inner_sampler = getattr(self.batch_sampler, "sampler", None)
            if hasattr(inner_sampler, "set_epoch"):
                inner_sampler.set_epoch(iteration)

            produced_any = False
            for batch in self.batch_sampler:
                produced_any = True
                yield batch
                iteration += 1
                if iteration >= self.num_iterations:
                    # Stop right after the last requested batch, without
                    # pulling another one from the wrapped sampler.
                    return

            if not produced_any:
                # An empty pass would repeat forever because `iteration`
                # never advances.
                raise ValueError("The wrapped batch sampler produced no batches; "
                                 "cannot reach the requested number of iterations.")

    def __len__(self):
        # Reports the target iteration count, independent of start_iter.
        return self.num_iterations
    

In [None]:
# helper function to show an image

def img_show(img):
    """Draw a normalised image tensor with ``plt.imshow``.

    The tensor is mapped back with ``img / 2 + 0.5``, detached, copied to the
    CPU and its first axis is moved to the end before plotting.
    """
    rescaled = img / 2.0 + 0.5     # unnormalize
    as_array = rescaled.detach().cpu().numpy()
    plt.imshow(as_array.transpose(1, 2, 0))
    
    
def transform_show(img1, img2):
    """Plot two image tensors side by side and return the figure.

    ``img1`` goes into the left panel and ``img2`` into the right panel of a
    1 x 2 grid; both panels are created without axis ticks.
    """
    fig = plt.figure(figsize=(16, 34))
    for column, image in enumerate((img1, img2), start=1):
        # img_show draws through pyplot, so it relies on the subplot that was
        # just added being the current axes.
        fig.add_subplot(1, 2, column, xticks=[], yticks=[])
        img_show(image)
    return fig


In [None]:
# Definition of Network Architectures

class ResBlock(nn.Module):
    """Residual block that keeps the channel count unchanged.

    The residual branch is: reflection pad -> 3x3 conv -> instance norm ->
    ReLU -> reflection pad -> 3x3 conv -> instance norm. Its output is added
    to the block input.
    """

    def __init__(self, in_features):
        super().__init__()

        def pad_conv_norm():
            # One "pad, 3x3 conv, instance norm" stage of the residual branch.
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(in_features, in_features, kernel_size=3),
                    nn.InstanceNorm2d(in_features)]

        layers = pad_conv_norm()
        layers.append(nn.ReLU(inplace=True))
        layers.extend(pad_conv_norm())

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual


class Gen(nn.Module):
    """Generator network mapping an image batch to an image batch.

    Layout of ``self.model`` (an ``nn.Sequential``):
      1. reflection pad + 7x7 conv to 64 channels, instance norm, ReLU;
      2. two stride-2 3x3 convs, doubling the channels each time (64 -> 128 -> 256);
      3. ``n_residual_blocks`` ResBlocks at 256 channels;
      4. two stride-2 transposed 3x3 convs, halving the channels each time (256 -> 128 -> 64);
      5. reflection pad + 7x7 conv to ``output_nc`` channels, followed by Tanh.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        n_residual_blocks: how many ResBlocks sit between the down- and
            up-sampling stages.
    """

    def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=9):
        super().__init__()

        # Initial convolution block
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, kernel_size=7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage doubles the channel count
        channels = 64
        for _ in range(2):
            doubled = channels * 2
            layers.extend([
                nn.Conv2d(channels, doubled, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled

        # Residual blocks
        layers.extend(ResBlock(channels) for _ in range(n_residual_blocks))

        # Upsampling: each stage halves the channel count
        for _ in range(2):
            halved = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, halved, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved

        # Output layer
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, kernel_size=7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

    def load(self, model):
        """Load parameters from the state dict stored at ``model``."""
        # The map_location callable returns each storage as it was
        # deserialised instead of moving it to the device it was saved from.
        state = torch.load(model, map_location=lambda storage, loc: storage)
        self.load_state_dict(state)

    def save(self, model_path):
        """Write this network's state dict to ``model_path``."""
        torch.save(self.state_dict(), model_path)


class Dis(nn.Module):
    """Fully convolutional discriminator producing one score per input image.

    ``self.model`` stacks five 4x4 convolutions (64, 128, 256, 512, 1 output
    channels). The first three use stride 2; LeakyReLU(0.2) follows the first
    four, and instance norm is applied after the second, third and fourth.
    ``forward`` averages the final single-channel map over its spatial extent
    and returns a tensor with one row per batch element.

    Args:
        input_nc: number of input channels.
    """

    def __init__(self, input_nc=3):
        super().__init__()

        # First stage has no normalisation
        layers = [nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, inplace=True)]

        # Normalised stages, described as (in_channels, out_channels, stride)
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                       nn.InstanceNorm2d(c_out),
                       nn.LeakyReLU(0.2, inplace=True)]

        # FCN classification layer
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        score_map = self.model(x)
        batch = score_map.size()[0]
        spatial = score_map.size()[2:]
        # Average over the whole spatial extent, then flatten per sample
        return F.avg_pool2d(score_map, spatial).view(batch, -1)

    def load(self, model):
        """Load parameters from the state dict stored at ``model``."""
        state = torch.load(model, map_location=lambda storage, loc: storage)
        self.load_state_dict(state)

    def save(self, model_path):
        """Write this network's state dict to ``model_path``."""
        torch.save(self.state_dict(), model_path)


def weights_init(m):
    """Re-draw the weights of convolution-like modules from N(0, 0.02).

    Intended for ``Module.apply``: any module whose class name contains
    ``'Conv'`` gets its ``weight`` filled in place; every other module, and
    all biases, are left untouched.
    """
    if 'Conv' in m.__class__.__name__:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)


In [None]:
# Definition of Loss Functions for Network Optimization

# Shared criterion instances used by the helpers below.
MSELoss = torch.nn.MSELoss()
L1Loss = torch.nn.L1Loss()


def MSErealTargetLoss(x):
    """Mean squared error between ``x`` and an all-ones ("real") target.

    Args:
        x: discriminator output; the target has shape ``(x.shape[0], 1)``.

    Returns:
        Scalar loss tensor.
    """
    # new_full creates the target with the same dtype and device as x, so
    # the helper works on CPU or GPU alike.
    target = x.new_full((x.shape[0], 1), 1.0)
    return MSELoss(x, target)


def MSEfakeTargetLoss(x):
    """Mean squared error between ``x`` and an all-zeros ("fake") target.

    Args:
        x: discriminator output; the target has shape ``(x.shape[0], 1)``.

    Returns:
        Scalar loss tensor.
    """
    target = x.new_full((x.shape[0], 1), 0.0)
    return MSELoss(x, target)


def cycleLoss(a, a_):
    """Mean absolute (L1) difference between ``a`` and ``a_``."""
    return L1Loss(a, a_)


In [None]:
# LambdaLR Scheduler Definition

class LambdaLR():
    """Learning-rate multiplier: constant, then linearly decaying.

    ``step(epoch)`` returns 1.0 while ``epoch <= decay_start`` and then falls
    linearly, reaching 0.0 at ``epoch == n_total``.

    Args:
        n_total: epoch/iteration at which the multiplier reaches zero.
        decay_start: epoch/iteration at which the decay begins; must be
            smaller than ``n_total``.
    """

    def __init__(self, n_total, decay_start):
        decay_span = n_total - decay_start
        assert (decay_span > 0), "Decay must start before the training session ends!"
        self.n_total = n_total
        self.decay_start = decay_start

    def step(self, epoch):
        elapsed = max(0, epoch - self.decay_start)        # 0 before the decay begins
        decay_span = self.n_total - self.decay_start
        return 1.0 - elapsed/decay_span
    

In [None]:
# Replay Buffer -- A trick for CycleGAN Optimization

class ReplayBuffer():
    """Bounded pool of previously seen samples.

    While the pool holds fewer than ``max_size`` samples, every incoming
    sample is stored and returned unchanged. Once it is full, each incoming
    sample is, depending on a random draw, either swapped with a randomly
    chosen stored sample (a copy of the stored one is returned, the new one
    takes its slot) or returned as is without being stored.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0)
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same length."""
        batch_out = []
        for sample in data.data:
            sample = sample.unsqueeze(0)   # restore the leading batch axis

            # Pool not full yet: store the sample and pass it straight through
            if len(self.data) < self.max_size:
                self.data.append(sample)
                batch_out.append(sample)
                continue

            # Pool full: swap with a stored sample, or pass the new one through
            if random.uniform(0,1) > 0.5:
                slot = random.randint(0, self.max_size-1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)
        return torch.cat(batch_out)
    

In [None]:
# Network Training Scheme 
# (Can be skipped if model has already been trained.)

# Initialize generators and discriminators
# Networks: one generator per translation direction, one discriminator per domain
genS2W, genW2S = Gen(), Gen()
disS, disW = Dis(), Dis()

# Re-initialise the convolution weights of all four networks, then move them to the GPU
for _net in (genS2W, genW2S, disS, disW):
    _net.apply(weights_init)
for _net in (genS2W, genW2S, disS, disW):
    _net.cuda()

# Optimizers: both generators share one Adam instance, each discriminator has its own
_adam_betas = (0.5, 0.999)
optG = torch.optim.Adam(itertools.chain(genS2W.parameters(), genW2S.parameters()),
                        lr=lr_init, betas=_adam_betas)
optD_S = torch.optim.Adam(disS.parameters(), lr=lr_init, betas=_adam_betas)
optD_W = torch.optim.Adam(disW.parameters(), lr=lr_init, betas=_adam_betas)


def _make_lr_scheduler(optimizer):
    # Multiplier schedule defined by the notebook's LambdaLR helper: the decay
    # starts at total_num_iteration//1.5 and ends at total_num_iteration.
    schedule = LambdaLR(total_num_iteration, total_num_iteration//1.5)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule.step)


lr_scheduler_G = _make_lr_scheduler(optG)
lr_scheduler_D_S = _make_lr_scheduler(optD_S)
lr_scheduler_D_W = _make_lr_scheduler(optD_W)

# Data pipeline: random sampling, wrapped so the loader yields exactly
# total_num_iteration batches instead of stopping after one pass
dataset = ImageDataset(summer_dir, winter_dir, img_size)
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = IterationBasedBatchSampler(
    torch.utils.data.sampler.BatchSampler(sampler=sampler, batch_size=batchsize, drop_last=True),
    num_iterations=total_num_iteration)
train_loader = torch.utils.data.DataLoader(dataset, num_workers=6, batch_sampler=batch_sampler, pin_memory=True)

# Pools of previously generated images shown to the discriminators
fake_S_buffer = ReplayBuffer()
fake_W_buffer = ReplayBuffer()

writer = SummaryWriter('outputs')  # specify directory to store visualization outputs

# Main optimisation loop. Each iteration performs one generator update
# (adversarial + cycle-consistency losses) followed by one update of each
# discriminator, logs the scalar losses and periodically saves the generators.
# The try/finally guarantees the TensorBoard writer is flushed and closed even
# if training is interrupted.
try:
    # tqdm wraps the loader itself (with an explicit total) so the progress
    # bar can show percentage and ETA.
    for iteration, (S_imgs, W_imgs) in enumerate(tqdm.tqdm(train_loader, total=total_num_iteration)):

        step = iteration + 1  # 1-based index used for logging and checkpoint names

        S_imgs, W_imgs = S_imgs.cuda(), W_imgs.cuda()

        # ---------------- Generator update ----------------
        optG.zero_grad()

        # No identity-mapping loss is part of loss_G, so genS2W(W_imgs) and
        # genW2S(S_imgs) are not evaluated here: they would cost two extra
        # forward passes without influencing any gradient or log value.

        # Summer to Winter, then Winter to Summer
        faked_W = genS2W(S_imgs)
        restored_S = genW2S(faked_W)

        # Winter to Summer, then Summer to Winter
        faked_S = genW2S(W_imgs)
        restored_W = genS2W(faked_S)

        # compute Adv and cyclic losses, and their updates
        AdvLossS = MSErealTargetLoss(disS(faked_S))
        AdvLossW = MSErealTargetLoss(disW(faked_W))
        CycleLoss1 = cycleLoss(S_imgs, restored_S) * 10.0
        CycleLoss2 = cycleLoss(W_imgs, restored_W) * 10.0
        loss_G = AdvLossS + AdvLossW + CycleLoss1 + CycleLoss2
        loss_G.backward()
        optG.step()

        # ---------------- Discriminator updates ----------------
        # loss_G.backward() also left gradients on the discriminators; they
        # are cleared by zero_grad() before each discriminator's own backward.
        optD_S.zero_grad()
        DisLossS = (MSEfakeTargetLoss(disS(fake_S_buffer.push_and_pop(faked_S).detach())) + MSErealTargetLoss(disS(S_imgs))) * 0.5
        DisLossS.backward()
        optD_S.step()

        optD_W.zero_grad()
        DisLossW = (MSEfakeTargetLoss(disW(fake_W_buffer.push_and_pop(faked_W).detach())) + MSErealTargetLoss(disW(W_imgs))) * 0.5
        DisLossW.backward()
        optD_W.step()

        cntAdvLoss = AdvLossS.item() + AdvLossW.item()
        cntCycleLoss = CycleLoss1.item() + CycleLoss2.item()
        cntGenLoss = loss_G.item()

        cntDisLossS = DisLossS.item()
        cntDisLossW = DisLossW.item()
        cntDisLoss = cntDisLossS + cntDisLossW

        # Log training procedure
        writer.add_scalar('AdvLoss', cntAdvLoss, step)
        writer.add_scalar('CycleLoss', cntCycleLoss, step)
        writer.add_scalar('Loss_G', cntGenLoss, step)

        writer.add_scalar('DisLossS', cntDisLossS, step)
        writer.add_scalar('DisLossW', cntDisLossW, step)
        writer.add_scalar('Loss_D', cntDisLoss, step)

        # Save model and outputs
        if step % save_every_iter == 0:
            genS2W.save('./s2w_' + "%06d" % step + '.pth')
            genW2S.save('./w2s_' + "%06d" % step + '.pth')

        lr_scheduler_G.step()
        lr_scheduler_D_S.step()
        lr_scheduler_D_W.step()
finally:
    # Flush pending events and release the event file.
    writer.close()

print('Training Completed. \n')

In [None]:
# Dataset for Inference

class ImageDatasetForInference(torch.utils.data.Dataset):
    """Single-domain image dataset used at inference time.

    Every ``.jpg`` / ``.png`` file in ``summer_dir`` is an item. Images are
    converted to RGB, resized with bicubic interpolation using
    ``int(img_size)``, converted to a tensor and normalised with mean 0.5 and
    std 0.5 per channel (i.e. from [0, 1] to [-1, 1]).

    Indexing wraps around: ``dataset[i]`` returns image ``i % len(dataset)``.

    Args:
        summer_dir: directory containing the input images.
        img_size: size passed to the resize transform.
    """

    def __init__(self, summer_dir, img_size):
        self.S_dir = summer_dir
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping (and hence output numbering) stable across runs.
        self.S_imgs = sorted(filename for filename in os.listdir(summer_dir)
                             if os.path.splitext(filename)[-1] in ('.jpg', '.png'))
        self.transform = [ torchvision.transforms.Resize(int(img_size), Image.BICUBIC),
                           torchvision.transforms.ToTensor(),  #  [0 - 255] --> [0 - 1.0] and To Tensor
                           torchvision.transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
        self.transform = torchvision.transforms.Compose(self.transform)

    def __getitem__(self, index):
        path = os.path.join(self.S_dir, self.S_imgs[(index) % len(self.S_imgs)])
        # The context manager closes the underlying file handle once the
        # converted copy has been transformed.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __len__(self):
        # Number of images found. (len(self.S_dir) would count the characters
        # of the directory path, which is unrelated to the dataset size.)
        return len(self.S_imgs)


In [None]:
# Inference Configs

# Size handed to the inference dataset's resize transform
img_size = 1024  # inference at resolution of 1024 x 1024
# Number of output images the inference loop produces
img_num = 1335
# Despite its name, this is the path of the generator checkpoint file to load
model_dir = 's2w_150000.pth'
# Folder with the images to translate
img_dir = '/home/gongjie/GJ_Research/Summer_Winter_DATASET/sg/'
# Output folder; its name encodes the resolution and the checkpoint used
store_dir = './generated_size%s_model%s' % (img_size, model_dir)


In [None]:
# Inference Scheme

# Load the trained generator, translate img_num images and save each
# input/output pair as a side-by-side figure in store_dir.
dataset = ImageDatasetForInference(img_dir, img_size)
model = Gen().cuda()
model.load(model_dir)
model.eval()

# exist_ok=True replaces the separate "exists?" check: it is race-free and a
# no-op when the folder is already there.
os.makedirs(store_dir, exist_ok=True)

with torch.no_grad():
    # Note: the dataset index wraps around, so if img_num exceeds the number
    # of available images the same inputs are processed again.
    for i in tqdm.tqdm(range(img_num)):
        img = dataset[i].cuda().unsqueeze(0)      # add the batch dimension
        gen_img = model(img).squeeze(0)
        fig = transform_show(img.squeeze(0), gen_img)
        # Save and close this specific figure rather than relying on
        # pyplot's notion of the "current" figure.
        fig.savefig(os.path.join(store_dir, "%04d" % (i+1) + '.jpg'))
        plt.close(fig)