In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data as data
import random
from torchvision import transforms
import torchvision.models.vgg as vgg
import torch.utils.model_zoo as model_zoo
from collections import namedtuple
import torch
from PIL import Image
from torch.nn import functional as F
import os , itertools
from glob import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
device



device(type='cuda', index=4)

Parameter

In [2]:
#model params
params = {
    'batch_size':1,
    'input_size':1024,
    'resize_scale':1024,
    'crop_size':1024,
    'fliplr':False,
    'num_epochs':100,
    'decay_epoch':50,
    'ngf':32,   #number of generator filters
    'ndf':64,   #number of discriminator filters
    'num_resnet':6, #number of resnet blocks
    'lrG':0.0002,    #learning rate for generator
    'lrD':0.0002,    #learning rate for discriminator
    'beta1':0.5 ,    #beta1 for Adam optimizer
    'beta2':0.999 ,  #beta2 for Adam optimizer
    'lambdaA':10 ,   #lambdaA for cycle loss
    'lambdaB':10,  #lambdaB for cycle loss
    'img_form':'jpeg'
}

data_a_dir = '../../data/normalization_type/Notstandard/'
save_dir = '../../data/normalization_type/Notstandard_gan/'

In [3]:
def to_np(x):
    return x.data.cpu().numpy()
def plot_train_result(real_image, gen_image, recon_image, epoch, save=False,  show=True, fig_size=(15, 15)):
    fig, axes = plt.subplots(2, 3, figsize=fig_size)
    imgs = [to_np(real_image[0]), to_np(gen_image[0]), to_np(recon_image[0]),
            to_np(real_image[1]), to_np(gen_image[1]), to_np(recon_image[1])]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        #ax.set_adjustable('box-forced')
        # Scale to 0-255
        img = img.squeeze()
        img = (((img - img.min()) * 255) / (img.max() - img.min())).transpose(1, 2, 0).astype(np.uint8)
        ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    title = 'Epoch {0}'.format(epoch + 1)
    fig.text(0.5, 0.04, title, ha='center')

    # save figure
    if save:
        save_fn = '../../result/cyclegan/Result_epoch_{:d}'.format(epoch+1) + '.png'
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close()

dataLoader

In [4]:
class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        return_images = []
        with torch.no_grad():  # No need to track gradients
            for image in images:
                image = torch.unsqueeze(image, 0)
                if self.num_imgs < self.pool_size:
                    self.num_imgs += 1
                    self.images.append(image)
                    return_images.append(image)
                else:
                    p = random.uniform(0, 1)
                    if p > 0.5:
                        random_id = random.randint(0, self.pool_size - 1)
                        tmp = self.images[random_id].clone()
                        self.images[random_id] = image
                        return_images.append(tmp)
                    else:
                        return_images.append(image)
            return_images = torch.cat(return_images, 0)
        return return_images.detach()  
        
class DatasetFromFolder(data.Dataset):
    def __init__(self, image_list, transform=None, resize_scale=None, crop_size=None, fliplr=False):
        super(DatasetFromFolder, self).__init__()
        self.image_list =image_list
        self.transform = transform
        
        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr

    def __getitem__(self, index):
        # Load Image
        img = Image.open(self.image_list[index]).convert('RGB')

        # preprocessing
        if self.resize_scale:
            img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR)

        if self.crop_size:
            x = random.randint(0, self.resize_scale - self.crop_size + 1)
            y = random.randint(0, self.resize_scale - self.crop_size + 1)
            img = img.crop((x, y, x + self.crop_size, y + self.crop_size))
        if self.fliplr:
            if random.random() < 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)

        if self.transform is not None:
            img = self.transform(img)

        return img,self.image_list[index]

    def __len__(self):
        return len(self.image_list)

CycleGAN Architecture

In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block       
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

Load dataset

In [6]:
transform = transforms.Compose([
    transforms.Resize(size=params['input_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
#Subfolders - day & night
data_list=glob(data_a_dir+'*.'+params['img_form'])
dataset = DatasetFromFolder(data_list, transform=transform,
                                resize_scale=params['resize_scale'], crop_size=params['crop_size'], fliplr=params['fliplr'])
data_loader = torch.utils.data.DataLoader(dataset=dataset , batch_size=params['batch_size'], shuffle=True)

 

Build model

In [7]:
#Build Model 
#G_A - Day->Night ; G_B - Night -> Day
Generator = Generator(3,3).to(device) 
Generator.load_state_dict(torch.load('../../model/cyclegan/G_B_14.pth', map_location=device))

  Generator.load_state_dict(torch.load('../../model/cyclegan/G_B_14.pth', map_location=device))


<All keys matched successfully>

predict Model

In [8]:
data_loader_tqdm=tqdm(data_loader)

for real_A,path in data_loader_tqdm:
    real_A = real_A.to(device)
    fake_B = Generator(real_A)
    fake_B = fake_B.to('cpu')
    fake_B = fake_B.detach().numpy()
    fake_B = np.squeeze(fake_B)
    fake_B = (((fake_B - fake_B.min()) * 255) / (fake_B.max() - fake_B.min())).transpose(1, 2, 0).astype(np.uint8)
    fake_B = Image.fromarray(fake_B)
    
    fake_B.save(save_dir+os.path.basename(path[0]))

100%|██████████| 4478/4478 [09:41<00:00,  7.70it/s]
