# Import Libraries

In [None]:
import torch

# Prefer the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [19]:
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import pandas as pd
import numpy as np
import os
from torch.utils import data
from collections import OrderedDict
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import datetime
import random
from torch.autograd import Variable
# Module-level compute device (first CUDA GPU if available, else CPU); passed to train() in the final cell.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Parameters

In [2]:
# Number of epochs; handed to train() in the final cell and embedded in the saved-model filename.
epochs = 100 
# In the visible notebook this value is only read by assign_model_id() when naming the saved model.
num_pairs = 200
# Number of ResidualBlock modules create_generator() inserts between its encoder and decoder.
num_residual_blocks = 5

# Helper Functions

In [21]:
def conv_block(in_channel, out_channel, activation='relu', *args, **kwargs):
    """Build a Conv2d -> BatchNorm2d -> activation stack.

    Args:
        in_channel: Input channel count of the convolution.
        out_channel: Output channel count; also the BatchNorm2d feature count.
        activation: 'relu' selects nn.ReLU(); 'lrelu' selects
            nn.LeakyReLU(negative_slope=0.2, inplace=True).
        *args, **kwargs: Forwarded unchanged to nn.Conv2d (kernel size, stride, ...).

    Returns:
        nn.Sequential holding the three layers in that order.

    Raises:
        KeyError: If `activation` is not one of the two supported names.
    """
    conv = nn.Conv2d(in_channel, out_channel, *args, **kwargs)
    norm = nn.BatchNorm2d(out_channel)
    choices = {
        'lrelu': nn.LeakyReLU(negative_slope=0.2, inplace=True),
        'relu': nn.ReLU(),
    }
    return nn.Sequential(conv, norm, choices[activation])
    

class ResidualBlock(nn.Module):
    """Residual unit: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus a skip connection, then ReLU.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by both convolutions.
        stride: Stride of the first convolution only.
        downsample: Optional module applied to the input to produce the skip
            branch. It is required whenever the main branch changes the shape
            (in_channels != out_channels or stride != 1), otherwise the
            addition in forward() cannot match shapes.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # The second conv must keep out_channels (bn2 normalises out_channels features)
        # and must not stride again, or the main branch would shrink twice while the
        # shortcut shrinks once. For in_channels == out_channels and stride == 1 this
        # is identical to the previous layer shapes, so saved weights still load.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(main_branch(x) + skip(x))."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Explicit None check: container modules define __len__, so an empty
        # nn.Sequential would be falsy under a plain truthiness test.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

def print_images(iml, dir, epoch, save_mode_on=True):
    """Draw eight image tensors on a 2x4 grid and either save or show the figure.

    Args:
        iml: Sequence of at least eight tensors, in the order of the panel
            titles below (real / fake / reconstructed / identity for A, then
            the same for B). Each tensor is channel-first (C, H, W), optionally
            with a leading batch axis, in which case the first sample is drawn.
            Values are assumed to lie in [-1, 1] -- TODO confirm for inputs that
            do not come from the Tanh generator or the 0.5/0.5 normalisation.
        dir: Directory the JPEG is written to when `save_mode_on` is True.
        epoch: Label inserted into the file name 'epoch-{}.jpg'.
        save_mode_on: True saves and closes the figure; False calls plt.show().
    """
    titles = ['Real-A', 'Fake-B (A->B)', 'Recon-A (A->B->A)', 'Identity-A (A->A)', 'Real-B', 'Fake-A (B->A)', 'Recon-B (B->A->B)', 'Identity-B (B->B)']
    fig, arr = plt.subplots(2, 4, figsize=(12, 6))
    # arr.flat walks the axes row by row, matching the order of `titles` and `iml`.
    for idx, (ax, title) in enumerate(zip(arr.flat, titles)):
        im = iml[idx].detach().cpu().float()
        if im.dim() == 4:
            im = im[0]  # drop the batch axis, keep the first sample
        if im.dim() == 3:
            # (C, H, W) -> (H, W, C). The former `.T` reversed every axis and gave
            # (W, H, C), i.e. a transposed picture; it is also deprecated for
            # tensors that are not 2-D.
            im = im.permute(1, 2, 0)
        # Map [-1, 1] to [0, 1]; clamp so imshow receives a valid float RGB range.
        im = ((im + 1) / 2.0).clamp(0, 1)
        ax.axis('off')
        ax.imshow(im.numpy(), vmin=0, vmax=1)
        ax.set_title(title, fontweight="bold")
    fig.tight_layout()

    if save_mode_on:
        fig.savefig(os.path.join(dir, 'epoch-{}.jpg'.format(epoch)))
        # Close this figure specifically so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

def assign_model_id(modeldir):
    """Return the run identifier used as the saved-model file stem.

    The identifier is assembled from the module-level `epochs`, `num_pairs`
    and `num_residual_blocks`; the `modeldir` argument is accepted but not used.
    """
    pieces = [
        'ep_', str(epochs),
        '-pairs_', str(num_pairs),
        '-resblocks_', str(num_residual_blocks),
    ]
    return ''.join(pieces)

def to_var(x):
    """Wrap `x` in a Variable, first moving it to the default CUDA device when CUDA is available."""
    moved = x.cuda() if torch.cuda.is_available() else x
    return Variable(moved)

class ImagePool():
    """Bounded history of previously seen images.

    query() hands back, for each incoming image, either that image or (with
    probability one half once the buffer is full) an older stored one, which
    the new image then replaces. A `pool_size` of 0 disables the buffer.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        # Storage is only created for an enabled pool.
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch of the same length as `images`, drawn from the input and the history.

        Args:
            images: Batch tensor; it is iterated along its first axis.

        Returns:
            `images` itself when the pool is disabled, otherwise a new tensor
            concatenated along axis 0.
        """
        if self.pool_size == 0:
            return images

        chosen = []
        for sample in images:
            # Re-add the batch axis; `.data` drops the autograd history.
            sample = torch.unsqueeze(sample.data, 0)
            if self.num_imgs < self.pool_size:
                # Still filling: store the new image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                chosen.append(sample)
                continue
            if random.uniform(0, 1) > 0.5:
                # Swap: return a random stored image and keep the new one instead.
                slot = random.randint(0, self.pool_size - 1)
                previous = self.images[slot].clone()
                self.images[slot] = sample
                chosen.append(previous)
            else:
                chosen.append(sample)
        return torch.cat(chosen, 0)

In [4]:
# Colab-only: mount Google Drive at /content/drive, the root under which the dataset path used later lives.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Make the Drive root the working directory; manage_folders() builds its Output/ tree under os.getcwd().
%cd /content/drive/MyDrive
!pwd

/content/drive/MyDrive
/content/drive/MyDrive


# Generator

In [6]:
def create_generator():
    """Assemble one generator: convolutional encoder, residual trunk, transposed-conv decoder.

    The number of residual blocks is read from the module-level
    `num_residual_blocks`. Sub-module names ('encoder', 'res1', ..., 'decoder'
    and the layer names inside them) become state-dict keys, so they are kept
    exactly as they are.

    Returns:
        nn.Sequential taking a 3-channel image batch and producing a
        3-channel batch squashed by Tanh.
    """
    generator = nn.Sequential(OrderedDict([]))

    # Encoder: 7x7 stem followed by two stride-2 convolutions (3 -> 64 -> 128 -> 256 channels).
    enc = OrderedDict()
    enc['conv1'] = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=1, padding=3, bias=False)
    enc['bnorm1'] = nn.BatchNorm2d(64)
    enc['relu1'] = nn.ReLU()
    enc['conv2'] = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=2, padding=1, bias=False)
    enc['bnorm2'] = nn.BatchNorm2d(128)
    enc['relu2'] = nn.ReLU()
    enc['conv3'] = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=2, padding=1, bias=False)
    enc['bnorm3'] = nn.BatchNorm2d(256)
    enc['relu3'] = nn.ReLU()
    generator.add_module('encoder', nn.Sequential(enc))

    # Residual trunk operating on 256 channels; blocks are named res1, res2, ...
    for block_no in range(1, num_residual_blocks + 1):
        generator.add_module('res{}'.format(block_no), ResidualBlock(256, 256))

    # Decoder: two stride-2 transposed convolutions (256 -> 64 -> 32 channels),
    # then reflection padding and a 7x7 convolution back to 3 channels.
    dec = OrderedDict()
    dec['deconv1'] = nn.ConvTranspose2d(256, 64, kernel_size=(3, 3), stride=2, padding=1, output_padding=1)
    dec['bnorm4'] = nn.BatchNorm2d(64)
    dec['relu4'] = nn.ReLU()
    dec['deconv2'] = nn.ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=2, padding=1, output_padding=1)
    dec['bnorm5'] = nn.BatchNorm2d(32)
    dec['relu5'] = nn.ReLU()
    dec['reflectpad'] = nn.ReflectionPad2d(3)
    dec['conv4'] = nn.Conv2d(32, 3, kernel_size=(7, 7), stride=1, bias=True)
    dec['tanh'] = nn.Tanh()
    generator.add_module('decoder', nn.Sequential(dec))

    return generator

# Loss Functions

In [7]:
def gan_loss(p, real):
    """Mean-squared-error GAN loss of discriminator output `p` against a constant target.

    Args:
        p: Discriminator output tensor (any shape).
        real: True scores `p` against an all-ones target, False against all zeros.

    Returns:
        Scalar tensor, F.mse_loss(p, target).
    """
    # Build the target directly with p's shape, dtype and device. This avoids a
    # CPU allocation plus transfer on every call and removes the dependency on
    # the module-level `device` matching wherever `p` actually lives.
    target = torch.ones_like(p) if real else torch.zeros_like(p)
    return F.mse_loss(p, target)

def cycle_loss(reconstructed, real):
    """Mean absolute error between a reconstructed batch and the batch it should reproduce."""
    return F.l1_loss(input=reconstructed, target=real, reduction='mean')

def identity_loss(identity, real):
    """Mean absolute error between a generator's output on a target-domain image and that image."""
    return F.l1_loss(input=identity, target=real, reduction='mean')

# Discriminator:


In [8]:
def create_discriminator():
    """Assemble one discriminator as a named nn.Sequential.

    Layout: a 4x4 stride-2 convolution with LeakyReLU, three conv_block stages
    (64 -> 128 -> 256 -> 512 channels, strides 2, 2, 1, each with 'lrelu'
    activation), and a final 4x4 convolution producing a single-channel score map.
    Layer names become state-dict keys and are kept unchanged.
    """
    layers = [
        ('conv', nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(4, 4), stride=2, padding=1)),
        ('lrelu1', nn.LeakyReLU(0.2, True)),
    ]

    # (input channels, output channels, stride) for convblock1..convblock3.
    stage_specs = [(64, 128, 2), (128, 256, 2), (256, 512, 1)]
    for stage_no, (c_in, c_out, step) in enumerate(stage_specs, start=1):
        stage = conv_block(in_channel=c_in, out_channel=c_out, activation='lrelu',
                           kernel_size=(4, 4), stride=step, padding=1, bias=False)
        layers.append(('convblock{}'.format(stage_no), stage))

    layers.append(('patch', nn.Conv2d(in_channels=512, out_channels=1, kernel_size=(4, 4), stride=1, padding=1)))
    return nn.Sequential(OrderedDict(layers))

# GAN Model

In [22]:
class cycleGAN(nn.Module):
    """Two generators, two discriminators, their optimizers and loss bookkeeping.

    `generator_a2b` maps domain A to B and is judged by `disc_b`;
    `generator_b2a` maps B to A and is judged by `disc_a`.

    Attributes the training loop toggles between batches:
        is_training: When True, backward_G / backward_D call backward() and step
            their optimizer; when False they only evaluate the losses.
        save_losses: When True, the current batch's losses are appended to the
            tr_* dictionaries (training) or val_* dictionaries (validation).
        im_list: The eight tensors of the latest forward() call, in the panel
            order print_images uses.
    """

    # Key order is kept by the loss dictionaries, and therefore by tables built from them.
    _GEN_LOSS_KEYS = ('a2b', 'b2a', 'id_a2b', 'id_b2a', 'a2b2a', 'b2a2b', 'loss_gen_total')
    _DIS_LOSS_KEYS = ('loss_dis_b', 'loss_dis_a', 'loss_dis_total')

    def __init__(self, learning_rate=2e-4):
        """Create networks, image pools, loss logs and Adam optimizers.

        Args:
            learning_rate: Learning rate shared by the generator and discriminator optimizers.
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.LAMBDA_CYCLE = 10.0  # weight of the two cycle-consistency terms
        self.LAMBDA_ID = 0.5      # weight of the two identity terms
        pool = 50                 # capacity of each fake-image history buffer

        self.is_training = True
        self.save_losses = False

        self.fake_A_pool = ImagePool(pool)
        self.fake_B_pool = ImagePool(pool)

        self.tr_gen_loss_dict = {key: [] for key in self._GEN_LOSS_KEYS}
        self.tr_dis_loss_dict = {key: [] for key in self._DIS_LOSS_KEYS}
        self.val_gen_loss_dict = {key: [] for key in self._GEN_LOSS_KEYS}
        self.val_dis_loss_dict = {key: [] for key in self._DIS_LOSS_KEYS}

        self.im_list = []

        self.generator_a2b = create_generator()
        self.generator_b2a = create_generator()

        self.disc_a = create_discriminator()
        self.disc_b = create_discriminator()

        # One optimizer per role, each covering both networks of that role.
        self.opt_G = torch.optim.Adam(itertools.chain(self.generator_a2b.parameters(), self.generator_b2a.parameters()), lr=self.learning_rate)
        self.opt_D = torch.optim.Adam(itertools.chain(self.disc_a.parameters(), self.disc_b.parameters()), lr=self.learning_rate)

    @staticmethod
    def _record(history, values):
        """Append the Python float of every loss tensor in `values` to `history[key]`."""
        for key, value in values.items():
            history[key].append(value.item())

    def forward(self, real_a, real_b):
        """Run both translation directions, both cycles and both identity mappings.

        Returns:
            (fake_a2b, recon_b2a, fake_b2a, recon_a2b, identity_a2b, identity_b2a), where
            recon_b2a is A -> B -> A, recon_a2b is B -> A -> B, identity_a2b is
            generator_a2b applied to real_b and identity_b2a is generator_b2a
            applied to real_a.
        """
        fake_a2b = self.generator_a2b(real_a)
        recon_b2a = self.generator_b2a(fake_a2b)
        fake_b2a = self.generator_b2a(real_b)
        recon_a2b = self.generator_a2b(fake_b2a)
        identity_a2b = self.generator_a2b(real_b)
        identity_b2a = self.generator_b2a(real_a)
        self.im_list = [real_a, fake_a2b, recon_b2a, identity_b2a, real_b, fake_b2a, recon_a2b, identity_a2b]

        return fake_a2b, recon_b2a, fake_b2a, recon_a2b, identity_a2b, identity_b2a

    def backward_G(self, real_a, real_b, fake_a2b, recon_b2a, fake_b2a, recon_a2b, identity_a2b, identity_b2a):
        """Compute the total generator loss; in training mode also back-propagate and step opt_G."""
        if self.is_training:
            # Freeze the discriminators so this step accumulates no gradients in them.
            self.set_requires_grad([self.disc_a, self.disc_b], False)
            self.opt_G.zero_grad()

        loss_identity_a2b = identity_loss(identity_a2b, real_b)
        loss_identity_b2a = identity_loss(identity_b2a, real_a)

        # Generators are rewarded when the discriminators score their fakes as real.
        loss_gan_gen_a2b = gan_loss(self.disc_b(fake_a2b), True)
        loss_gan_gen_b2a = gan_loss(self.disc_a(fake_b2a), True)
        a2b2a = cycle_loss(recon_b2a, real_a)
        b2a2b = cycle_loss(recon_a2b, real_b)

        loss_gen_total = loss_gan_gen_a2b + loss_gan_gen_b2a \
            + (a2b2a + b2a2b) * self.LAMBDA_CYCLE \
            + (loss_identity_a2b + loss_identity_b2a) * self.LAMBDA_ID

        if self.is_training:
            loss_gen_total.backward()
            self.opt_G.step()

        if self.save_losses:
            history = self.tr_gen_loss_dict if self.is_training else self.val_gen_loss_dict
            self._record(history, {
                'a2b': loss_gan_gen_a2b,
                'b2a': loss_gan_gen_b2a,
                'id_a2b': loss_identity_a2b,
                'id_b2a': loss_identity_b2a,
                'a2b2a': a2b2a,
                'b2a2b': b2a2b,
                'loss_gen_total': loss_gen_total,
            })

    def backward_D(self, real_a, real_b, fake_a2b, fake_b2a):
        """Compute both discriminator losses; in training mode also back-propagate and step opt_D."""
        if self.is_training:
            # Only training batches go through the history pools. Querying them
            # during validation would store validation fakes in the training
            # replay buffer and score the validation loss on stale images.
            fake_a2b = self.fake_B_pool.query(fake_a2b)
            fake_b2a = self.fake_A_pool.query(fake_b2a)

            self.set_requires_grad([self.disc_a, self.disc_b], True)
            self.opt_D.zero_grad()

        # detach(): the discriminator update must not reach into the generators.
        loss_gan_dis_a_real = gan_loss(self.disc_a(real_a), True)
        loss_gan_dis_a_fake = gan_loss(self.disc_a(fake_b2a.detach()), False)

        loss_gan_dis_b_real = gan_loss(self.disc_b(real_b), True)
        loss_gan_dis_b_fake = gan_loss(self.disc_b(fake_a2b.detach()), False)

        loss_dis_a = (loss_gan_dis_a_real + loss_gan_dis_a_fake) * 0.5
        loss_dis_b = (loss_gan_dis_b_real + loss_gan_dis_b_fake) * 0.5

        loss_dis_total = loss_dis_a + loss_dis_b

        if self.is_training:
            loss_dis_total.backward()
            self.opt_D.step()

        if self.save_losses:
            history = self.tr_dis_loss_dict if self.is_training else self.val_dis_loss_dict
            self._record(history, {
                'loss_dis_b': loss_dis_b,
                'loss_dis_a': loss_dis_a,
                'loss_dis_total': loss_dis_total,
            })

    def set_requires_grad(self, nets, requires_grad=False):
        """Set `requires_grad` on every parameter of one network or a list of networks (None entries are skipped)."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def optimize_parameters(self, real_a, real_b):
        """Process one batch pair: forward pass, generator update, then discriminator update."""
        fake_a2b, recon_b2a, fake_b2a, recon_a2b, identity_a2b, identity_b2a = self.forward(real_a, real_b)
        self.backward_G(real_a, real_b, fake_a2b, recon_b2a, fake_b2a, recon_a2b, identity_a2b, identity_b2a)
        self.backward_D(real_a, real_b, fake_a2b, fake_b2a)

# Directories

In [10]:
def manage_folders():
    """Create this run's output directory tree under the current working directory.

    Layout (the timestamp has minute resolution):
        Output/<YYYY_MM_DD-HH:MM>/loss-graphs
        Output/<YYYY_MM_DD-HH:MM>/generated-images/train
        Output/<YYYY_MM_DD-HH:MM>/generated-images/val
        Output/<YYYY_MM_DD-HH:MM>/saved-models

    Directories that already exist are reused, so calling this twice within the
    same minute no longer raises FileExistsError.

    Returns:
        Tuple (trdir, valdir, graphdir, modeldir) of the four leaf directories.
    """
    # NOTE: the ':' in the timestamp is kept so existing output paths stay
    # comparable; it is not a valid file-name character on Windows.
    timee = datetime.datetime.now().strftime("%Y_%m_%d-%H:%M")
    out = os.path.join(os.getcwd(), 'Output', timee)

    graphdir = os.path.join(out, 'loss-graphs')
    imdir = os.path.join(out, 'generated-images')
    trdir = os.path.join(imdir, 'train')
    valdir = os.path.join(imdir, 'val')
    modeldir = os.path.join(out, 'saved-models')

    # makedirs also creates the missing parents (Output/, the run folder, generated-images/).
    for path in (graphdir, trdir, valdir, modeldir):
        os.makedirs(path, exist_ok=True)

    return trdir, valdir, graphdir, modeldir

# Dataset

In [11]:
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

def pic_loader(pic_type, data_dir='/content/drive/MyDrive/selfie2anime', batch_size=1, num_workers=2):
    """Build the train and test DataLoaders for one image domain.

    Images are read with ImageFolder from `<data_dir>/train<pic_type>` and
    `<data_dir>/test<pic_type>`, converted to tensors and normalised per channel
    with mean 0.5 and std 0.5.

    Args:
        pic_type: Domain suffix appended to 'train' / 'test' (the notebook uses 'A' and 'B').
        data_dir: Dataset root. Defaults to the Drive location this notebook was written for.
        batch_size: Batch size of both loaders.
        num_workers: Worker process count of both loaders.

    Returns:
        (train_dloader, test_dloader); only the training loader shuffles.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_path = os.path.join(data_dir, 'train{}'.format(pic_type))
    test_path = os.path.join(data_dir, 'test{}'.format(pic_type))

    train_dataset = datasets.ImageFolder(train_path, transform)
    test_dataset = datasets.ImageFolder(test_path, transform)

    train_dloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_dloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_dloader, test_dloader

In [12]:
# NOTE: this module-level data_dir is not passed to pic_loader below, and nothing else in the visible notebook reads it.
data_dir = '/content/drive/MyDrive/selfie2anime'

# The final cell hands the 'A' loaders to train() as its "S" datasets and the 'B' loaders as its "A" datasets.
dataloader_A, test_dataloader_A = pic_loader('A')
dataloader_B, test_dataloader_B = pic_loader('B')

# Train

In [13]:
def _run_pass(model, loader_s, loader_a, device, training):
    """Feed paired batches from both loaders through the model once.

    Losses are recorded for the final batch pair only. Iteration stops at the
    shorter loader; the labels each loader yields are ignored.

    Returns:
        Number of batch pairs available (0 when either loader is empty).
    """
    model.is_training = training
    model.save_losses = False
    num_batches = min(len(loader_s), len(loader_a))

    # zip() opens a fresh iterator on each loader and ends cleanly when either
    # runs out, instead of raising StopIteration from a manual next() call.
    for batch_no, ((X_s, _), (X_a, _)) in enumerate(zip(loader_s, loader_a), start=1):
        X_s, X_a = X_s.to(device), X_a.to(device)
        if batch_no == num_batches:
            model.save_losses = True
        model.optimize_parameters(X_s, X_a)

    model.save_losses = False
    return num_batches


def train(train_dataset_SS, validation_dataset_SS, train_dataset_AA, validation_dataset_AA, epochs, device):
    """Train a cycleGAN, validate after every epoch, then save losses and weights.

    Args:
        train_dataset_SS: Training loader of the first domain (passed to the model as `real_a`).
        validation_dataset_SS: Validation loader of the first domain.
        train_dataset_AA: Training loader of the second domain (passed to the model as `real_b`).
        validation_dataset_AA: Validation loader of the second domain.
        epochs: Number of epochs to run.
        device: torch.device the model and every batch are moved to.

    Every loader must yield (image_batch, label) pairs; the labels are not used.

    Side effects:
        Reads the module-level `trdir`, `valdir`, `graphdir` and `modeldir`
        (as returned by manage_folders()), which must exist before the call.
        Writes one image grid per epoch and phase, four loss CSV files and the
        model state dict. Nothing is written if the run is interrupted.
    """
    model = cycleGAN().to(device)

    for epoch in range(epochs):
        print('Epoch', epoch+1, '------------------')

        # Training. Loader iterators in recent PyTorch releases no longer offer a
        # `.next()` method, so pairing is done with zip() inside _run_pass.
        if _run_pass(model, train_dataset_SS, train_dataset_AA, device, training=True):
            print('Tr - Generator Loss:', np.round(model.tr_gen_loss_dict['loss_gen_total'][-1], decimals=4))
            print('Tr - Dicriminator Loss:', np.round(model.tr_dis_loss_dict['loss_dis_total'][-1], decimals=4))
            print_images(model.im_list, trdir, str(epoch), save_mode_on=True)

        # Validation: same computation, no gradients and no optimizer steps.
        with torch.no_grad():
            num_val = _run_pass(model, validation_dataset_SS, validation_dataset_AA, device, training=False)
        if num_val:
            print('Val - Generator Loss:', np.round(model.val_gen_loss_dict['loss_gen_total'][-1], decimals=4))
            print('Val - Dicriminator Loss:', np.round(model.val_dis_loss_dict['loss_dis_total'][-1], decimals=4))
            print_images(model.im_list, valdir, str(epoch), save_mode_on=True)

    # Persist the per-epoch loss histories and the final weights.
    histories = (
        ('tr_gen_losses.csv', model.tr_gen_loss_dict),
        ('tr_dis_losses.csv', model.tr_dis_loss_dict),
        ('val_gen_losses.csv', model.val_gen_loss_dict),
        ('val_dis_losses.csv', model.val_dis_loss_dict),
    )
    for filename, history in histories:
        pd.DataFrame.from_dict(history).to_csv(os.path.join(graphdir, filename), index=False)
    torch.save(model.state_dict(), os.path.join(modeldir, assign_model_id(modeldir)) + '.pth')

# Run training

In [23]:
# Create this run's output folders; train() reads these four names as module-level globals.
trdir, valdir, graphdir, modeldir = manage_folders()
train(
    train_dataset_SS=dataloader_A,
    validation_dataset_SS=test_dataloader_A,
    train_dataset_AA=dataloader_B,
    validation_dataset_AA=test_dataloader_B,
    epochs=epochs,
    device=device,
)

Epoch 1 ------------------


KeyboardInterrupt: ignored