In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data as data
import random
from torchvision import transforms
import torchvision.models.vgg as vgg
import torch.utils.model_zoo as model_zoo
from collections import namedtuple
import torch
from PIL import Image
from torch.nn import functional as F
import os , itertools
from glob import glob
import matplotlib.pyplot as plt
import tqdm
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


Parameter

In [2]:
#model params
params = {
    'batch_size':120,
    'input_size':224,
    'resize_scale':224,
    'crop_size':224,
    'fliplr':False,
    'num_epochs':100,
    'decay_epoch':50,
    'ngf':32,   #number of generator filters
    'ndf':64,   #number of discriminator filters
    'num_resnet':6, #number of resnet blocks
    'lrG':0.0002,    #learning rate for generator
    'lrD':0.0002,    #learning rate for discriminator
    'beta1':0.5 ,    #beta1 for Adam optimizer
    'beta2':0.999 ,  #beta2 for Adam optimizer
    'lambdaA':10 ,   #lambdaA for cycle loss
    'lambdaB':10  ,  #lambdaB for cycle loss
    'img_form':'jpg'
}
model_dir='../../model/3Division_cycleGan/5x/'
data_dir = '../../data/OriginalTile/'

In [3]:
def to_np(x):
    return x.data.cpu().numpy()
def plot_train_result(real_image, gen_image, recon_image, epoch, save=False,  show=True, fig_size=(15, 15)):
    fig, axes = plt.subplots(2, 3, figsize=fig_size)
    imgs = [to_np(real_image[0]), to_np(gen_image[0]), to_np(recon_image[0]),
            to_np(real_image[1]), to_np(gen_image[1]), to_np(recon_image[1])]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        #ax.set_adjustable('box-forced')
        # Scale to 0-255
        img = img.squeeze()
        img = (((img - img.min()) * 255) / (img.max() - img.min())).transpose(1, 2, 0).astype(np.uint8)
        ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    title = 'Epoch {0}'.format(epoch + 1)
    fig.text(0.5, 0.04, title, ha='center')

    # save figure
    if save:
        save_fn = 'Result_epoch_{:d}'.format(epoch+1) + '.png'
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close()
def model_load(G_A1, G_B1, D_A1, D_B1, epoch,device,model_dir):
    G_A1.load_state_dict(torch.load(model_dir+'G_A_'+str(epoch)+'.pth',map_location=device))
    G_B1.load_state_dict(torch.load(model_dir+'G_B_'+str(epoch)+'.pth',map_location=device))
    D_A1.load_state_dict(torch.load(model_dir+'D_A_'+str(epoch)+'.pth',map_location=device))
    D_B1.load_state_dict(torch.load(model_dir+'D_B_'+str(epoch)+'.pth',map_location=device))
    G_A1.eval()
    G_B1.eval()
    D_A1.eval()
    D_B1.eval()
    return G_A1, G_B1, D_A1, D_B1

dataLoader

In [4]:
class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        return_images = []
        for image in images.data:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size-1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = Variable(torch.cat(return_images, 0))
        return return_images
        
class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir, subfolder='train', transform=None, resize_scale=None, crop_size=None, fliplr=False):
        super(DatasetFromFolder, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        self.image_filenames = [x for x in sorted(glob(self.input_path+'/*.'+params['img_form']))]
        self.image_filenames = [f.replace(self.input_path+'/', '') for f in self.image_filenames]
        print(self.image_filenames)
        self.transform = transform
        
        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr

    def __getitem__(self, index):
        # Load Image
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        img = Image.open(img_fn).convert('RGB')

        # preprocessing
        if self.resize_scale:
            img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR)

        if self.crop_size:
            x = random.randint(0, self.resize_scale - self.crop_size + 1)
            y = random.randint(0, self.resize_scale - self.crop_size + 1)
            img = img.crop((x, y, x + self.crop_size, y + self.crop_size))
        if self.fliplr:
            if random.random() < 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        return len(self.image_filenames)

CycleGAN Architecture

In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block       
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, padding=1),
                    nn.InstanceNorm2d(512), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        # FCN classification layer
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

Load dataset

In [6]:
transform = transforms.Compose([
    transforms.Resize(size=params['input_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
#Subfolders - day & night
train_data_A = DatasetFromFolder(data_dir, subfolder='5x', transform=transform,
                                resize_scale=None, crop_size=None, fliplr=params['fliplr'])
train_data_loader_A = torch.utils.data.DataLoader(dataset=train_data_A, batch_size=params['batch_size'], shuffle=False)
 

['S10-00098-13_0_0.jpg', 'S10-00098-13_0_1.jpg', 'S10-00098-13_0_2.jpg', 'S10-00098-13_0_3.jpg', 'S10-00098-13_1_0.jpg', 'S10-00098-13_1_1.jpg', 'S10-00098-13_1_2.jpg', 'S10-00098-13_1_3.jpg', 'S10-00098-13_2_0.jpg', 'S10-00098-13_2_1.jpg', 'S10-00098-13_2_2.jpg', 'S10-00098-13_2_3.jpg', 'S10-00098-13_3_0.jpg', 'S10-00098-13_3_1.jpg', 'S10-00098-13_3_2.jpg', 'S10-00098-13_3_3.jpg', 'S10-00098-13_4_0.jpg', 'S10-00098-13_4_1.jpg', 'S10-00098-13_4_2.jpg', 'S10-00098-13_4_3.jpg', 'S10-00098-4_0_0.jpg', 'S10-00098-4_0_1.jpg', 'S10-00098-4_0_2.jpg', 'S10-00098-4_0_3.jpg', 'S10-00098-4_10_0.jpg', 'S10-00098-4_10_1.jpg', 'S10-00098-4_10_2.jpg', 'S10-00098-4_10_3.jpg', 'S10-00098-4_11_0.jpg', 'S10-00098-4_11_1.jpg', 'S10-00098-4_11_2.jpg', 'S10-00098-4_11_3.jpg', 'S10-00098-4_12_0.jpg', 'S10-00098-4_12_1.jpg', 'S10-00098-4_12_2.jpg', 'S10-00098-4_12_3.jpg', 'S10-00098-4_13_0.jpg', 'S10-00098-4_13_1.jpg', 'S10-00098-4_13_2.jpg', 'S10-00098-4_13_3.jpg', 'S10-00098-4_14_0.jpg', 'S10-00098-4_14_1.j

Build model

In [7]:

#Build Model 
#G_A - Day->Night ; G_B - Night -> Day
G_A = Generator(3,3) 
G_B = Generator(3,3)

#two Discriminators
D_A = Discriminator(3)
D_B = Discriminator(3)

G_A,G_B,D_A,D_B=model_load(G_A,G_B,D_A,D_B,99,device,model_dir)

G_A = G_A.to(device)
G_B = G_B.to(device)

#two Discriminators
D_A = D_A.to(device)
D_B = D_B.to(device)
G_optimizer = torch.optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=params['lrG'], betas=(params['beta1'], params['beta2']))
D_A_optimizer = torch.optim.Adam(D_A.parameters(), lr=params['lrD'], betas=(params['beta1'], params['beta2']))
D_B_optimizer = torch.optim.Adam(D_B.parameters(), lr=params['lrD'], betas=(params['beta1'], params['beta2']))

TestModel

In [8]:
count = 0
filename_list=train_data_loader_A.dataset.image_filenames

for i, real_A in enumerate(train_data_loader_A):
    real_A = real_A.to(device)
    fake_A = G_B(real_A)
    fake_A=to_np(fake_A)
    fake_A = fake_A.squeeze()
    fake_A = (((fake_A - fake_A.min()) * 255) / (fake_A.max() - fake_A.min())).transpose(0, 2, 3, 1).astype(np.uint8)
    for j in range(len(fake_A)):
        img=Image.fromarray(fake_A[j])
        img.save('../../data/CycleGANData/3divisionTile/5x_standard/'+filename_list[i*120+j])


In [9]:
len(fake_A)

80