In [1]:
# Pix2Pix.py
import inspect
import random
import time
from math import sqrt, log2

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torch.autograd import Variable

# Data
- URL (CycleGAN datasets; this notebook uses `horse2zebra`):  
https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/

In [2]:
class ResnetBlock(nn.Module):
    """Residual block: two (reflection pad -> conv -> instance norm) stages plus a skip.

    Args:
        channel: number of input and output channels of both convolutions.
        kernel: convolution kernel size.
        stride: convolution stride.
        reflect_padding: reflection padding applied before each convolution
            (the convolutions themselves use zero padding of 0).

    NOTE(review): there is no activation between the two stages -- confirm
    that this is intentional.
    """

    def __init__(self, channel, kernel, stride, reflect_padding):
        super().__init__()
        stages = []
        # Two identical stages, created in order so parameter names and
        # initialisation order stay features.1 / features.4.
        for _ in range(2):
            stages.extend([
                nn.ReflectionPad2d(reflect_padding),
                nn.Conv2d(channel, channel, kernel, stride, 0),
                nn.InstanceNorm2d(channel),
            ])
        self.features = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked stages applied to ``x``."""
        return x + self.features(x)

In [3]:
def resnet_init_weight(m, mean, std):
    """Re-draw the weights of convolution-like modules from N(mean, std).

    A module is treated as a convolution when its class name contains the
    substring 'Conv'; every other module is left untouched. Only ``weight``
    is re-initialised, not ``bias``. Intended for use with ``Module.apply``.
    """
    if 'Conv' not in type(m).__name__:
        return
    m.weight.data.normal_(mean, std)

In [4]:
class Generator(nn.Module):
    """Encoder -> residual trunk -> decoder image-to-image network.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        ngf: number of filters of the first convolution; the encoder widens
            to ``ngf*2`` and ``ngf*4`` with two stride-2 convolutions.
        nb: number of ``ResnetBlock`` modules in the trunk
            (author's note: 256x256 -> 9, 128x128 -> 6).
        dtype: tensor type every sub-network is cast to via ``.type(dtype)``.

    The decoder mirrors the encoder with two stride-2 transposed convolutions
    and ends in ``Tanh``. Only the convolutions inside the residual blocks are
    re-initialised (normal, mean 0, std 0.2).
    """

    def __init__(self, input_nc, output_nc, ngf, nb, dtype):
        super().__init__()
        self.input_nc, self.output_nc = input_nc, output_nc
        self.ngf, self.nb = ngf, nb

        # Encoder: 7x7 stem followed by two stride-2 downsampling stages.
        encoder_layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, 7, 1, 0),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        ]
        for c_in, c_out in ((ngf, ngf*2), (ngf*2, ngf*4)):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, 3, 2, 1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(True),
            ]
        self.encoder = nn.Sequential(*encoder_layers).type(dtype)

        # Residual trunk. Each block is created and re-initialised before the
        # next one is built, so the random-number stream is consumed in the
        # same order on every construction.
        def init_block_weights(module):
            resnet_init_weight(module, 0, 0.2)

        trunk = []
        for _ in range(nb):
            block = ResnetBlock(ngf * 4, 3, 1, 1)
            block.apply(init_block_weights)
            trunk.append(block)
        self.resnet_blocks = nn.Sequential(*trunk).type(dtype)

        # Decoder: two stride-2 upsampling stages, then a 7x7 projection + Tanh.
        # The attribute keeps its historical spelling because it is part of the
        # module's parameter names.
        decoder_layers = []
        for c_in, c_out in ((ngf*4, ngf*2), (ngf*2, ngf)):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, 3, 2, 1, 1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(True),
            ]
        decoder_layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, 7, 1, 0),
            nn.Tanh(),
        ]
        self.decocder = nn.Sequential(*decoder_layers).type(dtype)

    def forward(self, x):
        """Run ``x`` through the encoder, the residual trunk and the decoder."""
        return self.decocder(self.resnet_blocks(self.encoder(x)))

In [5]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator built from five 4x4 convolutions.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the returned score map.
        ndf: number of filters of the first convolution; later layers use
            ``ndf*2``, ``ndf*4`` and ``ndf*8``.
        dtype: tensor type the network is cast to via ``.type(dtype)``.

    The first three convolutions use stride 2, the last two stride 1. The
    output of the final convolution is returned as-is (no sigmoid is applied).
    """

    def __init__(self, input_nc, output_nc, ndf, dtype):
        super().__init__()
        # Conv1: no normalisation on the first layer.
        layers = [
            nn.Conv2d(input_nc, ndf, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
        ]
        # Conv2-Conv4: (in multiplier, out multiplier, stride).
        for in_mult, out_mult, stride in ((1, 2, 2), (2, 4, 2), (4, 8, 1)):
            layers += [
                nn.Conv2d(ndf*in_mult, ndf*out_mult, 4, stride, 1),
                nn.InstanceNorm2d(ndf*out_mult),
                nn.LeakyReLU(0.2, True),
            ]
        # Conv5: projection to the score map.
        layers.append(nn.Conv2d(ndf*8, output_nc, 4, 1, 1))
        self.features = nn.Sequential(*layers).type(dtype)

    def forward(self, x):
        """Return the score map produced by the convolution stack for ``x``."""
        return self.features(x)

In [6]:
# Tensor type that the networks and losses in this notebook are cast to with
# `.type(dtype)`. Use the CUDA float type when a GPU is usable and fall back to
# the CPU float type otherwise, so the notebook also runs on CPU-only machines.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

In [88]:
# Quick construction check of one generator / discriminator pair with literal
# hyper-parameters: 3 -> 3 channels, 32 generator filters, 6 residual blocks;
# 3 -> 1 channels, 64 discriminator filters.
G_A = Generator(input_nc=3, output_nc=3, ngf=32, nb=6, dtype=dtype)
D_A = Discriminator(input_nc=3, output_nc=1, ndf=64, dtype=dtype)

In [8]:
import torch.nn.functional as F
import os
from torch.utils.data import Dataset, DataLoader

In [None]:
# Scratch cell for looking up the DataLoader constructor arguments. A bare
# `DataLoader()` call raises TypeError (the `dataset` argument is required),
# which aborts "Restart & Run All", so display the signature instead.
inspect.signature(DataLoader.__init__)

In [9]:
class CycleDataset(Dataset):
    """Dataset that reads every file in a directory and yields two image tensors.

    Each file is read with ``matplotlib.image.imread`` and cut at column 256:
    columns ``[:256]`` become ``x`` and columns ``[256:]`` become ``y``. Both
    parts are converted to tensors and normalised with mean 0.5 / std 0.5 per
    channel, then average-pooled to ``size x size`` unless ``size`` is 256.

    NOTE(review): the fixed split at column 256 presumably targets side-by-side
    paired images that are 512 pixels wide -- confirm against the data. For an
    image that is only 256 pixels wide the ``y`` part has zero width.

    Args:
        file_dir: directory containing the image files (with or without a
            trailing path separator).
        size: target height and width of the returned tensors.
    """

    def __init__(self, file_dir, size):
        self.file_dir = file_dir
        # os.listdir returns entries in arbitrary order; sort so that an index
        # always refers to the same file across runs and machines.
        self.file_list = sorted(os.listdir(file_dir))
        self.size = size
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def _path(self, index):
        """Return the path of the ``index``-th file inside ``file_dir``."""
        # os.path.join works whether or not file_dir ends with a separator;
        # plain string concatenation silently built a wrong path without one.
        return os.path.join(self.file_dir, self.file_list[index])

    def resize2d(self, img):
        """Average-pool ``img`` to ``(size, size)`` without tracking gradients."""
        with torch.no_grad():
            return F.adaptive_avg_pool2d(img, (self.size, self.size))

    def __getitem__(self, index):
        """Return the ``(x, y)`` tensor pair for the ``index``-th file."""
        full = mpimg.imread(self._path(index))
        x = self.transform(full[:, :256, :])
        y = self.transform(full[:, 256:, :])
        if self.size != 256:
            x = self.resize2d(x)
            y = self.resize2d(y)
        return x, y

    def __len__(self):
        """Return the number of files found in ``file_dir``."""
        return len(self.file_list)

In [11]:
# One dataset per image domain, both resized to 128x128. Domain B must read
# from trainB; loading trainA twice would train both directions on one domain.
a_dataset = CycleDataset('data/horse2zebra/trainA/', 128)
b_dataset = CycleDataset('data/horse2zebra/trainB/', 128)

In [12]:
# One shuffling loader per domain, using the DataLoader default batch size.
a_dataloader, b_dataloader = (
    DataLoader(domain_dataset, shuffle=True)
    for domain_dataset in (a_dataset, b_dataset)
)

In [14]:
# Channel counts: generators map 3 -> 3 channels, discriminators 3 -> 1.
input_ngc, output_ngc = 3, 3
input_ndc, output_ndc = 3, 1
# Base filter counts of the generators / discriminators.
ngf, ndf = 32, 64
# Number of residual blocks (author's note: 128x128 -> 6, 256x256 -> 9).
nb = 6

# Two generators and two discriminators with identical hyper-parameters,
# constructed in the order G_A, G_B, D_A, D_B.
G_A, G_B = (Generator(input_ngc, output_ngc, ngf, nb, dtype) for _ in range(2))
D_A, D_B = (Discriminator(input_ndc, output_ndc, ndf, dtype) for _ in range(2))

In [15]:
import itertools

In [16]:
# Loss criteria, each cast to the notebook-wide tensor type.
BCE_loss, MSE_loss, L1_loss = (
    criterion().type(dtype)
    for criterion in (nn.BCELoss, nn.MSELoss, nn.L1Loss)
)

# Adam optimizers: one shared by both generators, one per discriminator.
G_optimizer = optim.Adam(
    itertools.chain.from_iterable(net.parameters() for net in (G_A, G_B)),
    lr=0.0002,
    betas=(0.5, 0.999),
)
D_A_optimizer, D_B_optimizer = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (D_A, D_B)
)

In [None]:
class ImageBuffer:
    """Unfinished placeholder: no behavior has been implemented for this class.

    The cell originally held only the bare header ``class ImageBuffer``, which
    is a SyntaxError. This docstring-only body keeps the name defined and the
    file parseable until the class is either implemented or removed.
    """

In [None]:
class image_store():
    """Fixed-capacity history buffer of previously seen images.

    While the buffer is not full, every queried image is stored and returned
    unchanged. Once it is full, each queried image is, with probability 0.5,
    swapped with a randomly chosen stored image (the stored one is returned and
    the new one takes its slot); otherwise the new image is returned as-is.

    Args:
        store_size: maximum number of images kept. A value <= 0 disables the
            buffer and ``query`` returns its input unchanged.
    """

    def __init__(self, store_size=50):
        self.store_size = store_size
        self.num_img = 0
        self.images = []

    def query(self, image):
        """Return a batch of the same size as ``image``, mixed with stored images.

        Args:
            image: tensor whose first dimension indexes the images of a batch.

        Returns:
            A tensor with one entry per input image, concatenated along
            dimension 0 and detached from the autograd graph.
        """
        if self.store_size <= 0:
            return image

        # Detach so stored images do not keep old computation graphs alive.
        batch = image.detach()
        select_imgs = []
        for i in range(batch.size()[0]):
            # Slice (not index) to keep the leading batch dimension; the old
            # code appended the whole batch once per iteration.
            single = batch[i:i + 1]
            if self.num_img < self.store_size:
                self.images.append(single)
                select_imgs.append(single)
                self.num_img += 1
            elif random.uniform(0, 1) > 0.5:
                # randrange covers every slot 0 .. store_size-1; the previous
                # exclusive upper bound of store_size-1 never hit the last one.
                ind = random.randrange(self.store_size)
                select_imgs.append(self.images[ind])
                self.images[ind] = single
            else:
                select_imgs.append(single)

        return torch.cat(select_imgs, 0)

class ImagePool():
    """History pool of previously seen images.

    While the pool is not full, every queried image is stored and returned
    unchanged. Once it is full, each queried image is, with probability 0.5,
    exchanged with a randomly chosen pooled image (a copy of the pooled one is
    returned and the new one replaces it); otherwise the new image is returned.

    Args:
        pool_size: maximum number of images kept. A value <= 0 disables the
            pool and ``query`` returns its input unchanged.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        # Always initialise the state so a disabled pool is still well-formed.
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch of the same size as ``images``, mixed with pooled images.

        Args:
            images: tensor whose first dimension indexes the images of a batch.

        Returns:
            ``images`` itself when the pool is disabled; otherwise a tensor
            with one entry per input image, concatenated along dimension 0 and
            detached from the autograd graph.
        """
        if self.pool_size <= 0:
            return images

        return_images = []
        # detach() replaces the deprecated `.data` access.
        for image in images.detach():
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(image)
                return_images.append(image)
            elif random.uniform(0, 1) > 0.5:
                # random.randint is inclusive on both ends, so every slot
                # 0 .. pool_size-1 can be chosen.
                random_id = random.randint(0, self.pool_size - 1)
                previous = self.images[random_id].clone()
                self.images[random_id] = image
                return_images.append(previous)
            else:
                return_images.append(image)
        return torch.cat(return_images, 0)