In [393]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from torchsummary import summary
import torchvision.models as models
import torchvision.transforms.functional as ff
import torchvision.datasets as dsets
from PIL import Image
import itertools
%matplotlib inline

In [346]:
def conv_general(input_dim, output_dim, kernel_size, stride, padding=0,
                 norm=nn.InstanceNorm2d, normalize=True, activate=True, relu_factor=0):
    ops = list()
    ops.append(nn.Conv2d(input_dim, output_dim, kernel_size, stride, padding, bias=False))
    
    if normalize:
        ops.append(norm(output_dim))
    
    if activate:
        if relu_factor:
            relu = nn.LeakyReLU(relu_factor)
        else:
            relu = nn.ReLU()
        ops.append(relu)
        
    return nn.Sequential(*ops)

In [252]:
def deconv_general(input_dim, output_dim, kernel_size, stride, padding=0, output_padding=0,
                   norm=nn.InstanceNorm2d, normalize=True, activate=True, relu_factor=0):
    ops = list()
    ops.append(nn.ConvTranspose2d(input_dim, output_dim, kernel_size, stride,
                                  padding, output_padding, bias=False))
    
    if normalize:
        ops.append(norm(output_dim))
    
    if activate:
        if relu_factor:
            relu = nn.LeakyReLU(relu_factor)
        else:
            relu = nn.ReLU()
        ops.append(relu)
        
    return nn.Sequential(*ops)

In [253]:
class ResidualBlock(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(ResidualBlock, self).__init__()

        self.refl_pad = nn.ReflectionPad2d(1)
        self.conv_general = conv_general(input_dim, output_dim, 3, 1)
        self.conv = nn.Conv2d(output_dim, output_dim, 3, 1)
        self.instance_norm = nn.InstanceNorm2d(output_dim)

    def forward(self, x):
        o = self.refl_pad(x)
        o = self.conv_general(x)
        o = self.refl_pad(x)
        o = self.conv(x)
        o = self.instance_norm(x)
        
        return x + o

In [347]:
class Generator(nn.Module):

    def __init__(self, channels=64, residual_blocks=9):
        super(Generator, self).__init__()
        # 3 input image channels, 2566 output channels, 7*7 square convolution
        # kernel
        self.residual_blocks = residual_blocks
        self.refl_pad = nn.ReflectionPad2d(3)
        
        self.conv_general1 = conv_general(3, channels, 7, 1)
        self.conv_general2 = conv_general(channels, channels * 2, 3, 2, 1)
        self.conv_general3 = conv_general(channels * 2, channels * 4, 3, 2, 1)
        
        self.res_block = ResidualBlock(channels * 4, channels * 4)
        
        self.deconv_general1 = deconv_general(channels * 4, channels * 2, 3, 2, 1, 1)
        self.deconv_general2 = deconv_general(channels * 2, channels, 3, 2, 1, 1)
        
        self.conv = nn.Conv2d(channels, 3, 7, 1)
        self.tanh = nn.Tanh()
        
    def forward(self, x):
        # encoder
        x = self.refl_pad(x)
        x = self.conv_general1(x)
        x = self.conv_general2(x)
        x = self.conv_general3(x)
        
        # transformer
        for i in range(self.residual_blocks):
            x = self.res_block(x)
        
        # decoder
        x = self.deconv_general1(x)
        x = self.deconv_general2(x)        
        x = self.refl_pad(x)
        x = self.conv(x)
        x = self.tanh(x)
        
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In [357]:
class Discriminator(nn.Module):

    def __init__(self, channels=64):
        super(Discriminator, self).__init__()
        # 3 input image channels, 2566 output channels, 7*7 square convolution
        # kernel
        
        self.conv_general1 = conv_general(3, channels, 4, 2, 1, normalize=False, relu_factor=0.02)
        self.conv_general2 = conv_general(channels, channels * 2, 4, 2, 1)
        self.conv_general3 = conv_general(channels * 2, channels * 4, 4, 2, 1)
        self.conv_general4 = conv_general(channels * 4, channels * 8, 4, 1, 1)       
        self.conv = nn.Conv2d(channels * 8, 1, 4, 1, 1)
        
    def forward(self, x):
        x = self.conv_general1(x)
        print(x.shape)
        x = self.conv_general2(x)
        print(x.shape)
        x = self.conv_general3(x)
        print(x.shape)
        x = self.conv_general4(x)
        print(x.shape)
        x = self.conv(x)
        print(x.shape)
        
        return x

In [380]:
a = torch.arange(3 * 256 * 256).reshape(1, 3, 256, 256).float()

In [381]:
gen = Generator()
o = gen(a)
dis = Discriminator()
o = dis(o)

torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 31, 31])
torch.Size([1, 1, 30, 30])


In [None]:
gen_x_to_y = Generator()
gen_y_to_x = Generator()
disc_x = Discriminator()
disc_y = Discriminator()

In [339]:
a.shape

torch.Size([1, 1, 3, 3])

In [492]:
import copy
import os
import shutil

import numpy as np
import torch


def mkdir(paths):
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        if not os.path.isdir(path):
            os.makedirs(path)


def cuda_devices(gpu_ids):
    gpu_ids = [str(i) for i in gpu_ids]
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpu_ids)


def cuda(xs):
    if torch.cuda.is_available():
        if not isinstance(xs, (list, tuple)):
            return xs.cuda()
        else:
            return [x.cuda() for x in xs]
    else:
        return xs


def save_checkpoint(state, save_path, is_best=False, max_keep=None):
    # save checkpoint
    torch.save(state, save_path)

    # deal with max_keep
    save_dir = os.path.dirname(save_path)
    list_path = os.path.join(save_dir, 'latest_checkpoint')

    save_path = os.path.basename(save_path)
    if os.path.exists(list_path):
        with open(list_path) as f:
            ckpt_list = f.readlines()
            ckpt_list = [save_path + '\n'] + ckpt_list
    else:
        ckpt_list = [save_path + '\n']

    if max_keep is not None:
        for ckpt in ckpt_list[max_keep:]:
            ckpt = os.path.join(save_dir, ckpt[:-1])
            if os.path.exists(ckpt):
                os.remove(ckpt)
        ckpt_list[max_keep:] = []

    with open(list_path, 'w') as f:
        f.writelines(ckpt_list)

    # copy best
    if is_best:
        shutil.copyfile(save_path, os.path.join(save_dir, 'best_model.ckpt'))


def load_checkpoint(ckpt_dir_or_file, map_location=None, load_best=False):
    if os.path.isdir(ckpt_dir_or_file):
        if load_best:
            ckpt_path = os.path.join(ckpt_dir_or_file, 'best_model.ckpt')
        else:
            with open(os.path.join(ckpt_dir_or_file, 'latest_checkpoint')) as f:
                ckpt_path = os.path.join(ckpt_dir_or_file, f.readline()[:-1])
    else:
        ckpt_path = ckpt_dir_or_file
    ckpt = torch.load(ckpt_path, map_location=map_location)
    print(' [*] Loading checkpoint from %s succeed!' % ckpt_path)
    return ckpt


def reorganize(dataset_dir):
    dirs = {}
    dirs['trainA'] = os.path.join(dataset_dir, 'link_trainA')
    dirs['trainB'] = os.path.join(dataset_dir, 'link_trainB')
    dirs['testA'] = os.path.join(dataset_dir, 'link_testA')
    dirs['testB'] = os.path.join(dataset_dir, 'link_testB')
    mkdir(list(dirs.values()))

    for key in dirs:
        try:
            os.remove(os.path.join(dirs[key], '0'))
        except:
            pass
        os.symlink(os.path.abspath(os.path.join(dataset_dir, key)),
                   os.path.join(dirs[key], '0'))

    return dirs

In [447]:
class ItemPool(object):

    def __init__(self, max_num=50):
        self.max_num = max_num
        self.num = 0
        self.items = []

    def __call__(self, in_items):
        """`in_items` is a list of item."""
        if self.max_num <= 0:
            return in_items
        return_items = []
        for in_item in in_items:
            if self.num < self.max_num:
                self.items.append(in_item)
                self.num = self.num + 1
                return_items.append(in_item)
            else:
                if np.random.ranf() > 0.5:
                    idx = np.random.randint(0, self.max_num)
                    tmp = copy.copy(self.items[idx])
                    self.items[idx] = in_item
                    return_items.append(tmp)
                else:
                    return_items.append(in_item)
        return return_items

In [502]:
epochs = 200
start_epoch = 0
batch_size = 1
lr = 0.0002
dataset_dir = 'datasets/horse2zebra'

load_size = 286
crop_size = 256

transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.Resize(load_size),
     transforms.RandomCrop(crop_size),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)])

dataset_dirs = reorganize(dataset_dir)
a_train_data = dsets.ImageFolder(dataset_dirs['trainA'], transform=transform)
b_train_data = dsets.ImageFolder(dataset_dirs['trainB'], transform=transform)
a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)
a_train_loader = torch.utils.data.DataLoader(a_train_data, batch_size=batch_size, shuffle=True, num_workers=4)
b_train_loader = torch.utils.data.DataLoader(b_train_data, batch_size=batch_size, shuffle=True, num_workers=4)
a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=3, shuffle=True, num_workers=4)
b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=3, shuffle=True, num_workers=4)

In [522]:
# model

disc_a = Discriminator()
disc_b = Discriminator()
gen_a = Generator()
gen_b = Generator()
MSE = nn.MSELoss()
L1 = nn.L1Loss()
cuda([disc_a, disc_b, gen_a, gen_b])

disc_a_optimizer = torch.optim.Adam(disc_a.parameters(), lr=lr, betas=(0.5, 0.999))
disc_b_optimizer = torch.optim.Adam(disc_b.parameters(), lr=lr, betas=(0.5, 0.999))
gen_a_optimizer = torch.optim.Adam(gen_a.parameters(), lr=lr, betas=(0.5, 0.999))
gen_b_optimizer = torch.optim.Adam(gen_b.parameters(), lr=lr, betas=(0.5, 0.999))

a_fake_pool = ItemPool()
b_fake_pool = ItemPool()

In [523]:
with torch.no_grad():
    a_test_real = torch.autograd.Variable(iter(a_test_loader).next()[0])
    b_test_real = torch.autograd.Variable(iter(b_test_loader).next()[0])
a_test_real, b_test_real = cuda([a_test_real, b_test_real])

In [527]:
# train
for epoch in range(start_epoch, 2):
    for i, ((a_train_real, _), (b_train_real, _)) in enumerate(zip(a_train_loader, b_train_loader)):
        step = epoch * min(len(a_loader), len(b_loader)) + i + 1
        
        gen_a.train()
        gen_b.train()
        
#         a_train_real.requires_grad = True
#         b_train_real.requires_grad = True
        a_train_real, b_train_real = cuda([a_train_real, b_train_real])
    
        # generate fake images
        a_train_fake = gen_a(b_train_real)
        b_train_fake = gen_b(a_train_real)
        
        a_train_cycle = gen_a(b_train_fake)
        b_train_cycle = gen_b(a_train_fake)
        
        
        a_train_fake_disc = disc_a(a_train_fake)
        b_train_fake_disc = disc_b(b_train_fake)
        
        # generator loss
        real_label = cuda(torch.ones(a_train_fake_disc.size()))
        a_train_loss_gen = MSE(a_train_fake_disc, real_label)
        b_train_loss_gen = MSE(b_train_fake_disc, real_label)
        
        # cyclic loss
        a_train_loss_cycle = L1(a_train_cycle, a_train_real)
        b_train_loss_cycle = L1(b_train_cycle, b_train_real)
        
        train_loss_gen = a_train_loss_gen + b_train_loss_gen + 10.0 * (a_train_loss_cycle + b_train_loss_cycle)
        
        # generator backprop
        gen_a.zero_grad()
        gen_b.zero_grad()
        train_loss_gen.backward()
        gen_a_optimizer.step()
        gen_b_optimizer.step()
        
        a_train_fake = torch.Tensor(a_fake_pool([a_train_fake.detach().numpy()])[0])
        b_train_fake = torch.Tensor(b_fake_pool([b_train_fake.detach().numpy()])[0])
        a_train_fake, b_train_fake = cuda([a_train_fake, b_train_fake])
        
        
        # train discriminators
        a_train_real_disc = disc_a(a_train_real)
        a_train_fake_disc = disc_a(a_train_fake)
        b_train_real_disc = disc_b(b_train_real)
        b_train_fake_disc = disc_b(b_train_fake)
        real_label = cuda(torch.ones(a_train_fake_disc.size()))
        fake_label = cuda(torch.zeros(a_train_fake_disc.size()))
        
        # discriminator loss
        a_train_real_loss_disc = MSE(a_train_real_disc, real_label)
        a_train_fake_loss_disc = MSE(a_train_fake_disc, fake_label)
        b_train_real_loss_disc = MSE(b_train_real_disc, real_label)
        b_train_fake_loss_disc = MSE(b_train_fake_disc, fake_label)

        a_train_loss_disc = a_train_real_loss_disc + a_train_fake_loss_disc
        b_train_loss_disc = b_train_real_loss_disc + b_train_fake_loss_disc
        
        # discriminator backprop
        disc_a.zero_grad()
        disc_b.zero_grad()
        a_train_loss_disc.backward()
        b_train_loss_disc.backward()
        disc_a_optimizer.step()
        disc_b_optimizer.step()
        
        if (i + 1) % 1 == 0:
            print("Epoch: (%3d) (%5d/%5d)" % (epoch, i + 1, min(len(a_loader), len(b_loader))))

        if (i + 1) % 100 == 0:
            gen_a.eval()
            gen_b.eval()

            # train G
            a_test_fake = gen_a(b_test_real)
            b_test_fake = gen_b(a_test_real)

            a_test_cycle = gen_a(b_test_fake)
            b_test_cycle = gen_b(a_test_fake)

            pic = torch.cat([a_test_real, b_test_fake, a_test_cycle, 
                              b_test_real, a_test_fake, b_test_cycle], 
                             dim=0).data / 2.0 + 0.5

            save_dir = './sample_images_while_training'
            utils.mkdir(save_dir)
            torchvision.utils.save_image(pic, '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, epoch, i + 1, min(len(a_loader), len(b_loader))), nrow=3)

        break  

torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 31, 31])
torch.Size([1, 1, 30, 30])
torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 31, 31])
torch.Size([1, 1, 30, 30])
torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 31, 31])
torch.Size([1, 1, 30, 30])
torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 31, 31])
torch.Size([1, 1, 30, 30])
torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 31, 31])
torch.Size([1, 1, 30, 30])
torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 31, 31])
torch.Size([1, 1, 30, 30])
Epoch: (  0) (    1/ 1067)
torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 3

tensor([[[[ 3.2638e-01, -3.6851e-01,  2.3288e-01,  ...,  4.0371e-01,
            4.6333e-01,  2.3077e-01],
          [ 3.0675e-02,  4.4485e-01,  2.8081e-01,  ...,  2.0270e-01,
            1.4395e-01,  4.9138e-01],
          [-2.4351e-01,  5.0338e-01,  1.5834e-01,  ...,  2.8667e-01,
            1.5118e-01,  5.1702e-01],
          ...,
          [-3.1655e-01,  2.7835e-01, -1.8281e-02,  ...,  2.7307e-01,
            1.8649e-01,  5.5607e-03],
          [ 9.4958e-02,  3.2726e-01, -5.6268e-02,  ...,  1.7509e-01,
           -1.1852e-02,  1.3652e-01],
          [-1.0770e-01,  1.9923e-01,  4.0159e-02,  ...,  9.6832e-02,
           -5.9244e-02, -1.0726e-01]],

         [[-3.1499e-01,  5.8623e-01,  6.2428e-02,  ...,  1.8101e-01,
            1.1826e-01,  2.8713e-02],
          [-2.2162e-01, -2.0546e-01,  8.8168e-02,  ...,  3.2158e-01,
            2.6489e-01, -5.0689e-01],
          [-1.9654e-01,  2.4779e-01, -1.5417e-01,  ..., -4.3926e-02,
            3.8946e-01, -2.6373e-02],
          ...,
     

In [466]:
a_fake_pool = ItemPool()
for i, ((a_real, a_label), (b_real, b_label)) in enumerate(zip(a_loader, b_loader)):
    a_real = torch.autograd.Variable(a_real)
    b_real = torch.autograd.Variable(b_real)
    print(b_real.shape)
    a_fake = gen(b_real)
    x = a_fake_pool([a_fake])
    a_fake = torch.autograd.Variable(torch.Tensor(x[0]))
    print(a_fake.shape)
    break

torch.Size([2, 3, 256, 256])
torch.Size([2, 3, 256, 256])
torch.Size([2, 3, 256, 256])


In [435]:
a_fake.shape

torch.Size([1, 3, 256, 256])

In [475]:
x = torch.rand(1, 3, 256, 256).float()
y = torch.rand(1, 3, 256, 256).float()
torch.cat([x, y, x, y], dim=0).shape

torch.Size([4, 3, 256, 256])

In [476]:
torchvision.utils.save_image(torch.cat([x, y, x, y], dim=0), 'gheu.jpg', nrow=2)