In [None]:
#### package imports
import os
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.models as models

from random import shuffle
from PIL import Image
from tqdm import tqdm

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

In [None]:
#### definition of class Instance-Normalization
class InstanceNormalization(nn.Module):
    """Instance normalization with a learnable per-channel affine transform.

    For every sample and channel the H x W plane is normalized to zero mean and
    unit variance, then scaled and shifted:

        out = scale * (x - mean) / sqrt(var + eps) + shift

    Note: ``var`` is the *unbiased* variance (``Tensor.var`` default), unlike
    ``nn.InstanceNorm2d`` which uses the biased estimate. It is kept as-is so that
    existing checkpoints keep producing the same outputs.

    Args:
        dim: number of channels C of the input.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, dim, eps=1e-9):
        super(InstanceNormalization, self).__init__()
        # `scale` / `shift` are state_dict keys of saved checkpoints; do not rename.
        self.scale = nn.Parameter(torch.FloatTensor(dim))
        self.shift = nn.Parameter(torch.FloatTensor(dim))
        self.eps = eps
        self._reset_parameters()

    def _reset_parameters(self):
        """(Re)initialize the affine parameters: scale ~ U(0, 1), shift = 0."""
        self.scale.data.uniform_()
        self.shift.data.zero_()

    def forward(self, x):
        """Normalize ``x`` of shape [N, C, H, W]; returns a tensor of the same shape.

        Defined as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
        still runs forward hooks registered on this module.
        """
        n, c = x.size(0), x.size(1)
        # reshape (not view) so that non-contiguous inputs are accepted too
        flat = x.reshape(n, c, -1)                    # [N, C, H*W]
        mean = flat.mean(dim=2).view(n, c, 1, 1)      # per-plane mean, [N, C, 1, 1]
        var = flat.var(dim=2).view(n, c, 1, 1)        # unbiased per-plane variance
        normalized = (x - mean) / torch.sqrt(var + self.eps)
        # [C] -> [1, C, 1, 1]; broadcasting replaces the explicit expand_as calls
        return normalized * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)

In [None]:
#### definition of Generator
class CartoonGenerator(nn.Module):
    """CartoonGAN generator: down-sampling encoder -> 8 residual blocks -> decoder.

    Layer attribute names (``conv0_4_1``, ``in0_11_2``, ...) are the state_dict keys
    of saved checkpoints, so they keep the original numbered scheme; the layers are
    also created in the original order, so default initialization consumes random
    numbers in the same sequence.
    """

    # stage numbers of the residual blocks (layers refpad0_<k>_<j>, conv0_<k>_<j>, in0_<k>_<j>)
    _RES_IDS = range(4, 12)

    def __init__(self):
        super(CartoonGenerator, self).__init__()

        # ---- encoder ----
        self.refpad0_1_1 = nn.ReflectionPad2d(3)
        self.conv0_1_1 = nn.Conv2d(3, 64, 7)
        self.in0_1_1 = InstanceNormalization(64)          # relu -> [H, W]

        self.conv0_2_1 = nn.Conv2d(64, 128, 3, 2, 1)
        self.conv0_2_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.in0_2_1 = InstanceNormalization(128)         # relu -> [H/2, W/2]

        self.conv0_3_1 = nn.Conv2d(128, 256, 3, 2, 1)
        self.conv0_3_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.in0_3_1 = InstanceNormalization(256)         # relu -> [H/4, W/4]

        # ---- residual blocks: (pad -> conv -> IN) twice, relu in between, plus skip ----
        for k in self._RES_IDS:
            for j in (1, 2):
                setattr(self, 'refpad0_%d_%d' % (k, j), nn.ReflectionPad2d(1))
                setattr(self, 'conv0_%d_%d' % (k, j), nn.Conv2d(256, 256, 3))
                setattr(self, 'in0_%d_%d' % (k, j), InstanceNormalization(256))

        # ---- decoder ----
        self.deconv0_12_1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
        self.deconv0_12_2 = nn.ConvTranspose2d(128, 128, 3, 1, 1)
        self.in0_12_1 = InstanceNormalization(128)        # relu

        self.deconv0_13_1 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
        self.deconv0_13_2 = nn.ConvTranspose2d(64, 64, 3, 1, 1)
        self.in0_13_1 = InstanceNormalization(64)         # relu

        self.refpad0_14_1 = nn.ReflectionPad2d(3)
        self.deconv0_14_1 = nn.Conv2d(64, 3, 7)           # tanh

    def _layer(self, kind, stage, idx):
        """Look up the layer named '<kind>0_<stage>_<idx>'."""
        return getattr(self, '%s0_%d_%d' % (kind, stage, idx))

    def _residual(self, k, x):
        """Residual block k: IN(conv(pad(relu(IN(conv(pad(x))))))) + x."""
        inner = F.relu(self._layer('in', k, 1)(self._layer('conv', k, 1)(self._layer('refpad', k, 1)(x))))
        return self._layer('in', k, 2)(self._layer('conv', k, 2)(self._layer('refpad', k, 2)(inner))) + x

    def forward(self, x):
        """Map an image batch [N, 3, H, W] to a cartoonized batch of the same size in [-1, 1]."""
        h = F.relu(self.in0_1_1(self.conv0_1_1(self.refpad0_1_1(x))))
        h = F.relu(self.in0_2_1(self.conv0_2_2(self.conv0_2_1(h))))
        h = F.relu(self.in0_3_1(self.conv0_3_2(self.conv0_3_1(h))))

        for k in self._RES_IDS:
            h = self._residual(k, h)

        h = F.relu(self.in0_12_1(self.deconv0_12_2(self.deconv0_12_1(h))))
        h = F.relu(self.in0_13_1(self.deconv0_13_2(self.deconv0_13_1(h))))
        return torch.tanh(self.deconv0_14_1(self.refpad0_14_1(h)))

In [None]:
#### definition of Discriminator
class Discriminator(nn.Module):
    """Patch discriminator producing a 1-channel score map.

    The first conv is unpadded and two convs use stride 2; for a 256 x 256 input
    the score map is 64 x 64. Attribute names are the state_dict keys of saved
    checkpoints and must not change.
    """

    _SLOPE = 0.2  # negative slope of every LeakyReLU

    def __init__(self):
        super(Discriminator, self).__init__()
        # stem: 3x3, 32 channels, stride 1, no padding
        self.conv_1 = nn.Conv2d(3, 32, 3, 1)

        # stage 2: stride-2 conv, then conv + instance norm
        self.conv_2_1 = nn.Conv2d(32, 64, 3, 2, 1)
        self.conv_2_2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.in_2 = InstanceNormalization(128)

        # stage 3: stride-2 conv, then conv + instance norm
        self.conv_3_1 = nn.Conv2d(128, 128, 3, 2, 1)
        self.conv_3_2 = nn.Conv2d(128, 256, 3, 1, 1)
        self.in_3 = InstanceNormalization(256)

        # stage 4: conv + instance norm
        self.conv_4 = nn.Conv2d(256, 256, 3, 1, 1)
        self.in_4 = InstanceNormalization(256)

        # head: one score per location, no activation
        self.conv5 = nn.Conv2d(256, 1, 3, 1, 1)

    def forward(self, x):
        """Return raw (unbounded) scores [N, 1, h, w] for images ``x`` [N, 3, H, W]."""
        act = lambda t: F.leaky_relu(t, negative_slope=self._SLOPE)
        stages = (
            self.conv_1,
            self.conv_2_1,
            lambda t: self.in_2(self.conv_2_2(t)),
            self.conv_3_1,
            lambda t: self.in_3(self.conv_3_2(t)),
            lambda t: self.in_4(self.conv_4(t)),
        )
        h = x
        for stage in stages:
            h = act(stage(h))
        return self.conv5(h)

In [None]:
#### definition of losses
class ContentLoss(nn.Module):
    """Sum of L1 distances between two equally long sequences of feature maps.

    In this notebook it compares the VGG feature tuples of the generated image
    and of the input photo.
    """

    def __init__(self):
        super(ContentLoss, self).__init__()
        self.partial_loss = nn.L1Loss()

    def forward(self, F1, F2):
        """Return ``sum_i L1(F1[i], F2[i])`` (mean absolute error per pair).

        Defined as ``forward`` (not ``__call__``) so module hooks still run.

        Raises:
            ValueError: if ``F1`` and ``F2`` have different lengths
                (a subclass of ``Exception``, which was raised before).
        """
        if len(F1) != len(F2):
            raise ValueError("Unmatch input features")
        loss = 0
        for f1, f2 in zip(F1, F2):
            loss = loss + self.partial_loss(f1, f2)
        return loss
    
    
class AdversialLoss(nn.Module):
    """Least-squares GAN loss: MSE against a constant 1.0 (real) or 0.0 (fake) target.

    The labels are registered as buffers so ``.to(device)`` moves them together
    with the module. (The class-name spelling is kept for backward compatibility.)
    """

    def __init__(self):
        super(AdversialLoss, self).__init__()
        self.register_buffer('true_label', torch.tensor(1.0))
        self.register_buffer('false_label', torch.tensor(0.0))

        self.loss = nn.MSELoss()

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real/fake label broadcast (as an expanded view) to ``prediction``'s shape.

        ``target_is_real`` is interpreted by truthiness, so ``1`` or ``np.bool_(True)``
        select the real label just like ``True`` (previously only the object ``True`` did).
        """
        target = self.true_label if target_is_real else self.false_label
        return target.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """MSE between ``prediction`` and the constant target chosen by ``target_is_real``.

        Defined as ``forward`` (not ``__call__``) so module hooks still run.
        """
        return self.loss(prediction, self.get_target_tensor(prediction, target_is_real))

In [None]:
#### self-implementation of VGG16_bn for feature encoding
class VGG(nn.Module):
    """VGG16 with batch normalization, used here as a frozen feature encoder.

    Every ``convX_Y`` attribute is an ``nn.Sequential`` of
    [MaxPool] -> Conv3x3 -> BatchNorm -> ReLU (``conv5_3`` additionally ends with
    a MaxPool). Attribute names and the layer order inside each Sequential define
    the state_dict keys loaded from ``vgg16.pth`` and must not change.
    """

    # (attribute name, in channels, out channels, max-pool before, max-pool after)
    _LAYOUT = (
        ('conv1_1', 3, 64, False, False),
        ('conv1_2', 64, 64, False, False),
        ('conv2_1', 64, 128, True, False),
        ('conv2_2', 128, 128, False, False),
        ('conv3_1', 128, 256, True, False),
        ('conv3_2', 256, 256, False, False),
        ('conv3_3', 256, 256, False, False),
        ('conv4_1', 256, 512, True, False),
        ('conv4_2', 512, 512, False, False),
        ('conv4_3', 512, 512, False, False),
        ('conv5_1', 512, 512, True, False),
        ('conv5_2', 512, 512, False, False),
        ('conv5_3', 512, 512, False, True),
    )

    def __init__(self):
        super(VGG, self).__init__()
        for name, c_in, c_out, pool_before, pool_after in self._LAYOUT:
            layers = []
            if pool_before:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
            if pool_after:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            setattr(self, name, nn.Sequential(*layers))

    def forward(self, x):
        """Return a tuple with the outputs of all 13 stages, from conv1_1 to conv5_3."""
        features = []
        h = x
        for name, _c_in, _c_out, _pool_before, _pool_after in self._LAYOUT:
            h = getattr(self, name)(h)
            features.append(h)
        return tuple(features)

In [None]:
#### parameters configuration
parser = argparse.ArgumentParser()
# type=float: with type=int a command-line value such as `--lr 0.0002` could not be parsed
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate, default=0.0001')
parser.add_argument('--batch_size', type=int, default=4, help='batch size during training, default=4')

# parse_known_args ignores arguments it does not define, e.g. the `-f <connection file>`
# that Jupyter passes to the kernel, on which parse_args() would exit with an error.
opt, _unknown_args = parser.parse_known_args()
print(opt)

IMG_SIZE = 256  # make sure that the data you prepare satisfies H == W >= IMG_SIZE

In [None]:
#### two metrics for classification
def binary_cross_entropy_metric(y_pred, y_true):
    """Mean binary cross-entropy of probabilities ``y_pred`` against 0/1 targets ``y_true``.

    Predictions are clipped to [eps, 1 - eps] with eps = 1e-12, and eps is also
    added inside each logarithm, so the result is always finite.
    """
    eps = 1e-12
    probs = np.clip(y_pred, eps, 1.0 - eps)
    positive_term = y_true * np.log(probs + eps)
    negative_term = (1 - y_true) * np.log(1 - probs + eps)
    per_element = positive_term + negative_term
    return -per_element.mean()

def binary_accuracy_metric(y_pred, y_true):
    """Fraction of elements where ``y_pred`` thresholded at 0.5 equals ``y_true``.

    A prediction counts as 1 only when it is strictly greater than 0.5.
    """
    hard_labels = np.greater(y_pred, 0.5).astype(np.int32)
    matches = y_true == hard_labels
    return matches.mean()

In [None]:
#### dirs and files for training recording
check_pth = os.path.join(os.getcwd(), 'checkpoints')
# exist_ok=True: re-running the cell (e.g. Restart & Run All) must not fail
os.makedirs(check_pth, exist_ok=True)

## the training losses are written to this file ('w' truncates earlier records);
## buffering=1 (line buffering) flushes each logged line immediately, so the log
## is not lost if training is interrupted before the handle is closed.
txt_pth = os.path.join(check_pth, 'records.txt')
txt_handle = open(txt_pth, 'w', buffering=1, encoding='utf-8')

In [None]:
#### data preparing

## make sure that you have training data in "./datasets", structured as follows
"""
-- datasets
  |-- real
  |-- comic

where photorealistic images are in the subdir "./datasets/real", named like ['0.jpg', '1.jpg', '2.jpg', ...],
      cartoon-style images are in the subdir "./datasets/comic", named like ['0.jpg', '1.jpg', '2.jpg', ...]
"""

## a third set, "./datasets/comic_blurred" (comics with smoothed edges), is generated below.
# makedirs(exist_ok=True) also creates missing parents and does not fail on a re-run
os.makedirs(os.path.join(os.getcwd(), 'datasets', 'comic_blurred'), exist_ok=True)

# data paths: A = photos, B = comics, C = edge-blurred comics
datasets_pth = {
    'trainA': os.path.join(os.getcwd(), 'datasets', 'real'),
    'trainB': os.path.join(os.getcwd(), 'datasets', 'comic'),
    'trainC': os.path.join(os.getcwd(), 'datasets', 'comic_blurred')
}

# the following blurring function is adapted from https://github.com/nijuyr/comixGAN

def smooth_image_edges(img, plot=False):
    """Blur the regions around edges of an image, then blur the whole image once more.

    Used to build the "comic_blurred" set: cartoon images whose crisp edges are
    smoothed away.

    Args:
        img: uint8 image as returned by ``cv2.imread`` (BGR, H x W x 3).
        plot: if True, show the original, the edges, the dilated edges and the result.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    import cv2  # OpenCV is not imported by the notebook's import cell

    # Canny edge map (H x W)
    edges = cv2.Canny(img, 30, 60)

    # Dilate the edge map with 25 iterations.
    # NOTE(review): the tuple (7, 7) is passed as the kernel *array*; OpenCV appears to
    # convert it to a tiny 2 x 1 kernel rather than a 7 x 7 structuring element (verify).
    # Kept unchanged so generated data stays comparable; use np.ones((7, 7), np.uint8)
    # for a true 7 x 7 kernel.
    dilated_edges = cv2.dilate(edges, (7, 7), iterations=25)

    # Boolean mask of the dilated edge region; 2-D masks index all colour channels.
    # (Marking non-edge pixels with -1 inside the uint8 edge map, as before, wraps to
    # 255 on NumPy 1.x and is rejected by NumPy 2, so the `== -1` / `!= -1` masks never
    # selected the intended pixels and the whole image was simply blurred.)
    edge_mask = dilated_edges != 0

    # Split the image into the part outside the edge region and the edge region
    img_no_dilated_edges = img.copy()
    img_no_dilated_edges[edge_mask] = 0
    img_only_dilated_edges = img.copy()
    img_only_dilated_edges[~edge_mask] = 0

    # Gaussian blur of the edge region only ...
    blurred_edges = cv2.GaussianBlur(img_only_dilated_edges, (9, 9), 0)
    # ... clipped back to that region (drop values blurred onto the rest of the image)
    blurred_edges[~edge_mask] = 0

    # The two parts are disjoint, so the uint8 sum cannot overflow; final Gaussian blur
    result = blurred_edges + img_no_dilated_edges
    result = cv2.GaussianBlur(result, (9, 9), 0)

    if plot:
        plt.figure(figsize=(16,10))
        plt.subplot(221),plt.imshow(img[:,:,[2,1,0]])
        plt.title('Original Image'), plt.xticks([]), plt.yticks([])
        plt.subplot(222),plt.imshow(edges, cmap = 'gray')
        plt.title('Edges'), plt.xticks([]), plt.yticks([])
        plt.subplot(224),plt.imshow(dilated_edges, cmap = 'gray')
        plt.title('Dilated edges'), plt.xticks([]), plt.yticks([])
        plt.subplot(223),plt.imshow(result[:,:,[2,1,0]])
        plt.title('Blurred Image'), plt.xticks([]), plt.yticks([])
        plt.show()
    return result

import cv2  # OpenCV is used below but is not imported by the notebook's import cell

# try an example: preview the blurring on one random comic image
# (`shuffle` is the function imported from `random`; the module `random` itself is not imported)
samples = os.listdir(datasets_pth['trainB'])
shuffle(samples)
if samples:
    sample = cv2.imread(os.path.join(datasets_pth['trainB'], samples[0]))
    if sample is not None:
        result = smooth_image_edges(sample, plot=True)

#### convert comic data in "./datasets/comic" to blurred data in "./datasets/comic_blurred"
imgs = os.listdir(datasets_pth['trainB'])
for img_name in tqdm(imgs):
    image = cv2.imread(os.path.join(datasets_pth['trainB'], img_name))
    if image is None:
        # cv2.imread returns None for files it cannot decode (e.g. non-image files)
        print("Skipping unreadable file: {}".format(img_name))
        continue
    result = smooth_image_edges(image, plot=False)
    cv2.imwrite(os.path.join(datasets_pth['trainC'], img_name), result)

In [None]:
#### Global Variables
EPOCH = 101
# CUDA device index for training; -1 means CPU. The training cells branch on
# `gpu >= 0`, so fall back to the CPU instead of failing on `.to(0)` without CUDA.
gpu = 0 if torch.cuda.is_available() else -1
# intended loss weighting: loss = lossAdv + omega * lossCon
# NOTE(review): the adversarial cell currently weights the content loss by 0.5, not omega
omega = 10
real_label = 1.0
fake_label = 0.0

In [None]:
#### preparing data
def _absolute_listing(folder):
    """Return the entries of ``folder`` (in ``os.listdir`` order) joined onto ``folder``."""
    return [os.path.join(folder, entry) for entry in os.listdir(folder)]

trainA_files = _absolute_listing(datasets_pth['trainA'])
trainB_files = _absolute_listing(datasets_pth['trainB'])
trainC_files = _absolute_listing(datasets_pth['trainC'])
l_A, l_B, l_C = len(trainA_files), len(trainB_files), len(trainC_files)
l_L = max(l_A, l_B, l_C)  # size of the largest of the three sets
print("Finish loading the datasets, and there are: <<{}>> human-faces, and <<{}>> manga-faces!".format(l_A, l_B))

In [None]:
#### initialize models
G = CartoonGenerator()
D = Discriminator()

#### load the pretrained VGG16-BN feature encoder and freeze it
vgg = VGG()
# map_location='cpu': load even if the weights were saved from GPU tensors; the model
# is moved to the training device below
vgg.load_state_dict(torch.load('vgg16.pth', map_location='cpu'))
for param in vgg.parameters():
    param.requires_grad = False
# eval(): VGG contains BatchNorm layers. In train mode they normalize with per-batch
# statistics and keep updating their running statistics on every forward pass, so the
# "frozen" encoder would drift during training even though requires_grad is False.
vgg.eval()

print("Finish initializing the models, and they are: Cartoon-Generator, Cartoon-Discriminator, and VGG16-BN")


'''Loss functions'''
criterionCon = ContentLoss()
criterionDis = nn.MSELoss()
criterionAdv = AdversialLoss()

'''Optimizers'''
optim_G = optim.Adam(G.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optim_D = optim.Adam(D.parameters(), lr=opt.lr, betas=(0.5, 0.999))

'''Use GPU'''
if gpu >= 0:
    G = G.to(gpu)
    D = D.to(gpu)
    vgg = vgg.to(gpu)

    criterionCon = criterionCon.to(gpu)
    criterionAdv = criterionAdv.to(gpu)

In [None]:
#### pretrain the generator
for epoch in range(50):
    # initialization phase: train G with the content loss only, so that the VGG
    # features of G(X) match those of the input photo X
    shuffle(trainA_files)
    shuffle(trainB_files)
    shuffle(trainC_files)

    for i_l in tqdm(range(0, l_L, opt.batch_size)):
        X = np.zeros((opt.batch_size, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)
        for c in range(opt.batch_size):
            # i_l already advances in steps of batch_size, so sample c of this batch is
            # i_l + c (the former `i_l * batch_size + c` skipped most of the photos)
            with Image.open(trainA_files[(i_l + c) % l_A]) as img:
                rgb = img.convert('RGB')  # grayscale / RGBA files would not give 3 channels
            rgb.thumbnail((IMG_SIZE, IMG_SIZE))
            X[c, :, :, :] = np.array(rgb).transpose(2, 0, 1)

        X = torch.from_numpy(X)
        X = 2 * (X / 255) - 1  # [0, 255] -> [-1, 1], the range of G's tanh output

        if gpu >= 0:
            X = X.to(gpu)

        Out = G(X)
        F1 = vgg(Out)
        F2 = vgg(X)
        loss = criterionCon(F1, F2)

        optim_G.zero_grad()
        loss.backward()
        optim_G.step()

        if i_l % 10 == 0:
            log = "This is the end of the {}-th epoch, the content loss for initialization is: {}.".format(epoch, loss.item())
            txt_handle.write(log + '\n')

    # save the first photo of the last batch and its reconstruction ([-1, 1] -> [0, 255])
    x = X[0].detach().cpu().numpy().transpose(1, 2, 0)
    o = Out[0].detach().cpu().numpy().transpose(1, 2, 0)

    img_x = Image.fromarray((np.clip(x * 0.5 + 0.5, 0.0, 1.0) * 255).astype(np.uint8))
    img_o = Image.fromarray((np.clip(o * 0.5 + 0.5, 0.0, 1.0) * 255).astype(np.uint8))

    img_x.save(os.path.join(check_pth, 'x-{}.jpg'.format(epoch)))
    img_o.save(os.path.join(check_pth, 'o-{}.jpg'.format(epoch)))

## save the pretrained generator: copy the weights to the CPU instead of moving G,
## so G (and the parameters optim_G refers to) stay on the training device
torch.save({k: v.cpu() for k, v in G.state_dict().items()}, "./checkpoints/G_pretrained.pth")

In [None]:
#### pretrain the discriminator
PRETRAIN_DISCRIMINATOR_BATCH_SIZE = 16
opt.batch_size = PRETRAIN_DISCRIMINATOR_BATCH_SIZE
# each batch: 1/4 photos (label 0), 1/2 comics (label 1), 1/4 blurred comics (label 0)
n_A = int(PRETRAIN_DISCRIMINATOR_BATCH_SIZE / 4)
n_B = int(PRETRAIN_DISCRIMINATOR_BATCH_SIZE / 2)
n_C = int(PRETRAIN_DISCRIMINATOR_BATCH_SIZE / 4)
max_iterations = max(l_A // n_A, l_B // n_B, l_C // n_C)


def _load_d_pretrain_samples(files, n_files, start, count, label_value):
    """Load ``count`` images from ``files``, cycling from index ``start``, for D pretraining.

    Images smaller than IMG_SIZE in either dimension are skipped. Each image is
    converted to RGB, shrunk with ``thumbnail`` and scaled from [0, 255] to [-1, 1],
    the range of the generator's tanh output that D has to judge later.

    Returns:
        (images [count, 3, IMG_SIZE, IMG_SIZE],
         labels [count, 1, IMG_SIZE//4, IMG_SIZE//4] filled with ``label_value``,
         index to continue from)

    Raises:
        RuntimeError: if ``files`` is empty or a full pass finds no usable image
            (the old loop would spin forever).
    """
    images = []
    idx, misses = start, 0
    while len(images) < count:
        if n_files == 0 or misses >= n_files:
            raise RuntimeError("no image of size >= {} found in the training files".format(IMG_SIZE))
        with Image.open(files[idx % n_files]) as img:
            idx += 1
            width, height = img.size
            if width < IMG_SIZE or height < IMG_SIZE:
                misses += 1
                continue
            rgb = img.convert('RGB')  # grayscale / RGBA files would not give 3 channels
        misses = 0
        rgb.thumbnail((IMG_SIZE, IMG_SIZE))
        arr = np.asarray(rgb, dtype=np.float32).transpose(2, 0, 1)
        images.append(torch.from_numpy(np.ascontiguousarray(arr / 127.5 - 1.0)))
    labels = torch.full((count, 1, IMG_SIZE // 4, IMG_SIZE // 4), float(label_value))
    return torch.stack(images, 0), labels, idx


for epoch in range(50):
    shuffle(trainA_files)
    shuffle(trainB_files)
    shuffle(trainC_files)

    i_A, i_B, i_C = 0, 0, 0
    for i in tqdm(range(max_iterations)):
        ## get input images as well as corresponding labels
        x_a, y_a, i_A = _load_d_pretrain_samples(trainA_files, l_A, i_A, n_A, 0.0)  # photos
        x_b, y_b, i_B = _load_d_pretrain_samples(trainB_files, l_B, i_B, n_B, 1.0)  # comics
        x_c, y_c, i_C = _load_d_pretrain_samples(trainC_files, l_C, i_C, n_C, 0.0)  # blurred comics
        inputs = torch.cat([x_a, x_b, x_c], 0)
        labels = torch.cat([y_a, y_b, y_c], 0)

        ## randomize the order within the batch
        randomize = np.arange(PRETRAIN_DISCRIMINATOR_BATCH_SIZE)
        np.random.shuffle(randomize)
        inputs = inputs[randomize]
        labels = labels[randomize]

        if gpu >= 0:
            inputs = inputs.to(gpu)
            labels = labels.to(gpu)

        optim_D.zero_grad()
        preds = D(inputs)
        # Least-squares loss on the raw scores, as in the adversarial stage. Clamping
        # preds to [0, 1] in place before the loss (as done previously) zeroes the
        # gradient of every score outside that interval, so those scores never get
        # pushed back toward their label.
        loss = criterionDis(preds, labels)
        loss.backward()
        optim_D.step()

        if i % 10 == 0:
            txt_handle.write("The loss is: {}".format(loss.item()) + '\n')

    ## validation on the last batch of the epoch
    probs = preds.detach().clamp(0.0, 1.0).cpu().numpy()
    acc = binary_accuracy_metric(probs, labels.cpu().numpy())
    log = ("The loss is: {}, The accuracy is: {}".format(loss.item(), acc))
    print(log)

## save the pretrained discriminator: copy the weights to the CPU instead of moving D,
## so D (and the parameters optim_D refers to) stay on the training device
torch.save({k: v.cpu() for k, v in D.state_dict().items()}, "./checkpoints/D_pretrained.pth")

In [None]:
#### adversarial training of the entire framework
opt.batch_size = 4
# number of batches per epoch
max_iterations = max(l_A // opt.batch_size, l_B // opt.batch_size, l_C // opt.batch_size)


def _load_adv_batch(files, n_files, start, batch_size):
    """Load ``batch_size`` images from ``files``, cycling from index ``start``.

    Images smaller than IMG_SIZE in either dimension are skipped. Pixels are scaled
    from [0, 255] to [-1, 1], so photos, comics, blurred comics and generated images
    (tanh output) all share one range. Previously real/blurred images were in [0, 1]
    while generated ones were in [-1, 1], which lets D tell fakes apart by range alone.

    Returns:
        (float32 tensor [batch_size, 3, IMG_SIZE, IMG_SIZE] on the training device,
         index to continue from)

    Raises:
        RuntimeError: if ``files`` is empty or a full pass finds no usable image.
    """
    batch = np.zeros((batch_size, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)
    idx, filled, misses = start, 0, 0
    while filled < batch_size:
        if n_files == 0 or misses >= n_files:
            raise RuntimeError("no image of size >= {} found in the training files".format(IMG_SIZE))
        with Image.open(files[idx % n_files]) as img:
            idx += 1
            width, height = img.size
            if width < IMG_SIZE or height < IMG_SIZE:
                misses += 1
                continue
            rgb = img.convert('RGB')  # grayscale / RGBA files would not give 3 channels
        misses = 0
        rgb.thumbnail((IMG_SIZE, IMG_SIZE))
        batch[filled] = np.array(rgb).transpose(2, 0, 1)
        filled += 1
    tensor = torch.from_numpy(batch / 127.5 - 1.0)
    if gpu >= 0:
        tensor = tensor.to(gpu)
    return tensor, idx


for epoch in range(EPOCH):
    shuffle(trainA_files)
    shuffle(trainB_files)
    shuffle(trainC_files)

    i_A = i_B = i_C = 0
    # last D loss, for logging; D is only updated every third iteration, and the
    # first log happens at iteration 0 (before any D step)
    last_loss_D = float('nan')
    # One iteration per batch. The former range(0, max_iterations, batch_size) divided
    # by the batch size a second time and only used 1/batch_size of the data per epoch.
    for i_L in tqdm(range(max_iterations)):
        ## inputs of G: photos
        X, i_A = _load_adv_batch(trainA_files, l_A, i_A, opt.batch_size)

        if (i_L + 1) % 3 == 0:
            ## 1: train D (comics -> real, blurred comics -> fake, generated -> fake)
            optim_D.zero_grad()

            # 1A: real comics
            real_X, i_B = _load_adv_batch(trainB_files, l_B, i_B, opt.batch_size)
            d_real_error = criterionAdv(D(real_X), True)

            # 1B: blurred comics
            blur_X, i_C = _load_adv_batch(trainC_files, l_C, i_C, opt.batch_size)
            d_blur_error = criterionAdv(D(blur_X), False)

            # 1C: generated images; detach() because this step updates D only
            # (no graph through G, so retain_graph is not needed)
            fake_X = G(X).detach()
            d_fake_error = criterionAdv(D(fake_X), False)

            # 1D: update
            loss_D = (d_real_error + d_blur_error + d_fake_error) / 3
            loss_D.backward()
            optim_D.step()
            last_loss_D = loss_D.item()

        ## 2: train G: fool D while keeping the VGG content of the input photo
        optim_G.zero_grad()
        Out = G(X)
        F1 = vgg(Out)
        F2 = vgg(X)
        loss_con = criterionCon(F1, F2)

        g_fake_decision = D(Out)
        loss_adv = criterionAdv(g_fake_decision, True)

        # NOTE(review): the content weight is 0.5 here, not `omega` (10) from the config cell
        loss_G = loss_adv + loss_con * 0.5
        loss_G.backward()
        optim_G.step()

        if i_L % 100 == 0:
            log = "Epoch-{}, iteration-{} >> Content loss of G: {}, Adversarial loss of G: {}, Adversarial loss of D: {}".format(epoch, i_L, loss_con.item(), loss_adv.item(), last_loss_D)
            txt_handle.write(log + '\n')
            print(log)

    ## save the first photo of the last batch next to its cartoonized version
    x = X[0].detach().cpu().numpy().transpose(1, 2, 0)
    o = Out[0].detach().cpu().numpy().transpose(1, 2, 0)
    # [-1, 1] -> [0, 255]; clipping keeps values from wrapping around in uint8
    x = (np.clip(x * 0.5 + 0.5, 0.0, 1.0) * 255).astype(np.uint8)
    o = (np.clip(o * 0.5 + 0.5, 0.0, 1.0) * 255).astype(np.uint8)

    fig, (ax_x, ax_o) = plt.subplots(1, 2)
    ax_x.set_title("Real image")
    ax_x.imshow(x)
    ax_o.set_title("Cartoonalized image")
    ax_o.imshow(o)
    fig.savefig(os.path.join(check_pth, '%04d-%06d.jpg' % (epoch, i_L)))
    plt.close(fig)  # do not keep one open figure (and its images) per epoch

## save models (CPU copies of the weights)
torch.save({k: v.cpu() for k, v in G.state_dict().items()}, "./checkpoints/G_adv.pth")
torch.save({k: v.cpu() for k, v in D.state_dict().items()}, "./checkpoints/D_adv.pth")

txt_handle.close()