## In this notebook, the generator is trained with a GAN (adversarial) loss, an L1 loss, and a style loss. In addition, only the current line-art frame is used as the generator's input.

In [None]:
!pip uninstall scipy
!pip install scipy==1.1.0

Uninstalling scipy-1.1.0:
  Would remove:
    /usr/local/lib/python3.6/dist-packages/scipy-1.1.0.dist-info/*
    /usr/local/lib/python3.6/dist-packages/scipy/*
Proceed (y/n)? y
  Successfully uninstalled scipy-1.1.0
Collecting scipy==1.1.0
  Using cached https://files.pythonhosted.org/packages/a8/0b/f163da98d3a01b3e0ef1cab8dd2123c34aee2bafbb1c5bffa354cc8a1730/scipy-1.1.0-cp36-cp36m-manylinux1_x86_64.whl
[31mERROR: plotnine 0.6.0 has requirement scipy>=1.2.0, but you'll have scipy 1.1.0 which is incompatible.[0m
[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.[0m
Installing collected packages: scipy
Successfully installed scipy-1.1.0


In [None]:
import sys
import os
import numpy as np
from math import log10
from os.path import join
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms

In [None]:
# Expose only GPU 0 to this process; this has to happen before CUDA is initialised.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [None]:
# Mount Google Drive so the project sources, dataset, samples, logs and checkpoints
# under /content/gdrive/MyDrive are reachable. Requires running inside Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
# Make the project modules (models, data, dataset, loss, util) importable.
# Guarded so re-running this cell does not keep appending duplicate entries to sys.path.
src_dir = '/content/gdrive/MyDrive/src_second'
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
from models import define_G, define_D, print_network
from data import get_training_set, get_test_set, create_iterator
from dataset import DatasetFromFolder
from loss import AdversarialLoss, StyleLoss, PerceptualLoss
from util import Progbar, stitch_images, postprocess, load

In [None]:
# --- Paths and run bookkeeping ------------------------------------------------
root = '/content/gdrive/MyDrive'   # Drive folder holding the dataset and all outputs
dataset = 'dataset'                # dataset sub-folder; also prefixes sample file names
logfile = 'trainlogs.dat'          # file name of the training log
checkpoint_path_G = False          # falsy -> no weights are loaded before training
checkpoint_path_D = False

# --- Data ----------------------------------------------------------------------
batchSize = 16
testBatchSize = 1
threads = 0                        # DataLoader num_workers

# --- Model -----------------------------------------------------------------------
input_nc = 1                       # channels fed to the generator
output_nc = 3                      # channels produced by the generator
ngf = 2
ndf = 2

# --- Optimisation ----------------------------------------------------------------
nEpochs = 80
lr = 0.0001                        # generator LR; the discriminator uses lr * 0.1
beta1 = 0                          # Adam beta1 for both optimisers
cuda = True
seed = 123

# --- Loss weights ------------------------------------------------------------------
L1lamb = 10
Stylelamb = 1000
Contentlamb = 0                    # 0 removes the perceptual/content term from the G loss
Adversariallamb = 0.1              # intended weight of G's adversarial term (check train() applies it)

In [None]:
# Fail fast if GPU training was requested but no GPU is visible.
# (The old message referred to a `--cuda` command-line flag that this notebook does not have.)
if cuda and not torch.cuda.is_available():
    raise RuntimeError("No GPU found, please set cuda = False in the configuration cell")

# Let cuDNN benchmark and cache the fastest convolution algorithms.
# NOTE: benchmark mode can pick non-deterministic kernels, so runs are not
# bit-exact reproducible even with the seeds below.
cudnn.benchmark = True

torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

In [None]:
print('===> Loading datasets')
root_path = root
# Both splits are built from the same <root>/<dataset> folder; presumably the
# get_training_set / get_test_set helpers select the split.
dataset_dir = join(root_path, dataset)
train_set, test_set = get_training_set(dataset_dir), get_test_set(dataset_dir)

===> Loading datasets


In [None]:
# Training batches are reshuffled every epoch; test batches keep a fixed order.
training_data_loader = DataLoader(train_set, batch_size=batchSize, shuffle=True, num_workers=threads)
testing_data_loader = DataLoader(test_set, batch_size=testBatchSize, shuffle=False, num_workers=threads)

In [None]:
# Iterator over the test set; sample() draws a preview batch from it with next().
# NOTE(review): 4 is presumably the preview batch size — confirm against data.create_iterator.
sample_iterator = create_iterator(4, test_set)

In [None]:
print('===> Building model')
# Generator: input_nc-channel line art -> output_nc-channel image.
# NOTE(review): the trailing positional args (False, [0]) are presumably a norm/dropout
# flag and the GPU id list — confirm against models.define_G / define_D. The printed
# architecture starts at 64 channels, so ngf/ndf may not be honoured — verify.
netG = define_G(input_nc, output_nc, ngf, False, [0])
# Discriminator scores (input, output) pairs concatenated along channels, hence input_nc + output_nc.
netD = define_D(input_nc + output_nc, ndf, False, [0])

===> Building model


In [None]:
# Resume from saved weights only when both checkpoint paths are configured (truthy).
if all((checkpoint_path_G, checkpoint_path_D)):
    load(checkpoint_path_G, checkpoint_path_D, netG, netD)

In [None]:
# Loss functions; AdversarialLoss / StyleLoss / PerceptualLoss come from the project's loss module.
criterionGAN = AdversarialLoss()
criterionSTYLE = StyleLoss()
# NOTE(review): Contentlamb is 0 in the configuration, so this term contributes nothing
# to the generator objective as configured; consider skipping it to save compute.
criterionCONTENT = PerceptualLoss()
criterionL1 = nn.L1Loss()
# NOTE(review): criterionMSE is not used by the training code in this notebook.
criterionMSE = nn.MSELoss()

In [None]:
# setup optimizer
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=lr * 0.1, betas=(beta1, 0.999))

In [None]:
print('---------- Networks initialized -------------')
# Generator first, then discriminator.
for network in (netG, netD):
    print_network(network)
print('-----------------------------------------------')

---------- Networks initialized -------------
InpaintGenerator(
  (encoder): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace=True)
  )
  (middle): Sequential(
    (0): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((2, 2, 2, 2))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
        (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_

In [None]:
# Persistent input/target buffers. train() resize_()s them to each batch and copies the data in,
# so 256x256 is only the initial allocation.
real_a, real_b = (torch.FloatTensor(batchSize, channels, 256, 256) for channels in (input_nc, output_nc))

In [None]:
# Move networks, loss modules and the persistent batch buffers to the GPU.
if cuda:
    netD = netD.cuda()
    netG = netG.cuda()
    criterionGAN = criterionGAN.cuda()
    criterionL1 = criterionL1.cuda()
    # Previously misspelled `critertionSTYLE`, which bound the result to a stray global.
    criterionSTYLE = criterionSTYLE.cuda()
    criterionCONTENT = criterionCONTENT.cuda()
    criterionMSE = criterionMSE.cuda()
    # Tensor.cuda() returns a new tensor, so these reassignments are required.
    real_a = real_a.cuda()
    real_b = real_b.cuda()

# Variable is a deprecated compatibility wrapper; kept for older PyTorch versions.
real_a = Variable(real_a)
real_b = Variable(real_b)

In [None]:
def train(epoch, d_update_interval=12, log_interval=7):
    """Run one training epoch of the conditional GAN.

    For every batch the generator maps the line-art input ``real_a`` to ``fake_b``.
    The discriminator scores (input, output) pairs concatenated along channels.
    The generator is then optimised with the adversarial, L1, style and content
    losses, weighted by ``Adversariallamb``, ``L1lamb``, ``Stylelamb`` and
    ``Contentlamb``.

    Args:
        epoch: 1-based epoch number, used for logging and sample names.
        d_update_interval: the discriminator is stepped on the first batch of the
            epoch and on every batch whose 1-based index is a multiple of this value.
        log_interval: every ``log_interval`` batches a log line is appended and a
            preview sample is saved.

    Relies on notebook globals: the networks, optimisers, criteria, the training
    loader and the persistent ``real_a`` / ``real_b`` buffers.
    """
    num_batches = len(training_data_loader)
    for iteration, batch in enumerate(training_data_loader, 1):
        # Copy the batch into the persistent (possibly GPU) buffers.
        real_a_cpu, real_b_cpu = batch[0], batch[1]
        with torch.no_grad():
            real_a.resize_(real_a_cpu.size()).copy_(real_a_cpu)
            real_b.resize_(real_b_cpu.size()).copy_(real_b_cpu)

        # Only the current line-art frame is fed to the generator.
        fake_b = netG(real_a)
        fake_ab = torch.cat((real_a, fake_b), 1)
        real_ab = torch.cat((real_a, real_b), 1)

        ############################
        # (1) Update D network: maximize log(D(x,y)) + log(1 - D(x,G(x)))
        ############################
        optimizerD.zero_grad()

        # Call the module (not .forward) so registered hooks run. Detach so the
        # discriminator loss does not backpropagate into the generator.
        pred_fake = netD(fake_ab.detach())
        loss_d_fake = criterionGAN(pred_fake, False, True)
        pred_real = netD(real_ab)
        loss_d_real = criterionGAN(pred_real, True, True)
        loss_d = (loss_d_fake + loss_d_real) * 0.5

        # NOTE(review): `iteration` restarts at 1 every epoch, so with fewer than
        # d_update_interval batches per epoch D is only stepped once per epoch.
        if iteration == 1 or iteration % d_update_interval == 0:
            # D gradients are only needed when D is stepped; on other iterations
            # they would be discarded by the next zero_grad() anyway.
            loss_d.backward()
            optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(x,G(x))) + L1(y,G(x)) + style + content
        ############################
        optimizerG.zero_grad()

        # First, G(A) should fool the discriminator (not detached: gradients reach G).
        # Weighted by the configured Adversariallamb, which was previously defined but never applied.
        pred_fake = netD(fake_ab)
        loss_g_gan = criterionGAN(pred_fake, True, False) * Adversariallamb

        # Second, G(A) should match B pixel-wise and in style/content features.
        loss_g_l1 = criterionL1(fake_b, real_b) * L1lamb
        loss_g_style = criterionSTYLE(fake_b, real_b) * Stylelamb
        loss_g_content = criterionCONTENT(fake_b, real_b) * Contentlamb
        loss_g = loss_g_gan + loss_g_l1 + loss_g_style + loss_g_content

        loss_g.backward()
        optimizerG.step()

        if iteration % log_interval == 0:
            logs = [("epoc", epoch), ("iter", iteration), ("Loss_G", loss_g.item()), ("Loss_D", loss_d.item()),
                    ("Loss_G_adv", loss_g_gan.item()), ("Loss_G_L1", loss_g_l1.item()),
                    ("Loss_G_style", loss_g_style.item()), ("Loss_G_content", loss_g_content.item()),
                    ("Loss_D_Real", loss_d_real.item()), ("Loss_D_Fake", loss_d_fake.item())]
            log_train_data(logs)
            sample(iteration)

        print("===> Epoch[{}]({}/{}): Loss_D: {:.4f} Loss_G: {:.4f} LossD_Fake: {:.4f} LossD_Real: {:.4f}  LossG_Adv: {:.4f} LossG_L1: {:.4f} LossG_Style {:.4f} LossG_Content {:.4f}".format(
            epoch, iteration, num_batches, loss_d.item(), loss_g.item(), loss_d_fake.item(), loss_d_real.item(),
            loss_g_gan.item(), loss_g_l1.item(), loss_g_style.item(), loss_g_content.item()))
        

In [None]:
def sample(iteration, epoch=None):
    """Save a stitched preview (input, target, prediction) for one test batch.

    The next batch is drawn from ``sample_iterator`` and passed through the generator
    without gradient tracking. The stitched image is written to
    ``<root_path>/samples_second/<dataset>_<epoch>_<iteration>.jpg``.

    Args:
        iteration: batch index within the epoch, zero-padded to two digits in the file name.
        epoch: epoch number for the file name. Defaults to the notebook-global ``epoch``
            set by the training loop, which is what this function previously always read.

    Raises:
        NameError: if no epoch is passed and no global ``epoch`` exists yet.
    """
    if epoch is None:
        epoch = globals().get('epoch')
        if epoch is None:
            raise NameError("sample() needs an epoch: pass epoch=... or call it from the training loop")

    with torch.no_grad():
        # The third element (presumably the previous frame) is unused: the generator
        # only receives the current frame.
        input_img, target, _prev_frame = next(sample_iterator)

        if cuda:
            input_img = input_img.cuda()
            target = target.cuda()

        prediction = postprocess(netG(input_img))
        input_img = postprocess(input_img)
        target = postprocess(target)

    img = stitch_images(input_img, target, prediction)
    samples_dir = root_path + "/samples_second"
    os.makedirs(samples_dir, exist_ok=True)

    # Named `sample_name` so it no longer shadows this function's own name.
    sample_name = dataset + "_" + str(epoch) + "_" + str(iteration).zfill(2) + ".jpg"
    print('\nsaving sample ' + sample_name + ' - learning rate: ' + str(lr))
    img.save(os.path.join(samples_dir, sample_name))

In [None]:
def log_train_data(loginfo):
    """Append one space-separated line of logged values to the training log.

    ``loginfo`` is a sequence of ``(name, value)`` pairs. Only the values are written,
    in order, to ``<root_path>/logs_second/<logfile>``; the folder is created if missing.
    """
    log_dir = f"{root_path}/logs_second"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    line = ' '.join(str(entry[1]) for entry in loginfo)
    with open(f"{log_dir}/{logfile}", 'a') as handle:
        handle.write(line + '\n')

In [None]:
def checkpoint(epoch):
    """Save the generator and discriminator weights for ``epoch``.

    Writes ``{'generator': netG.state_dict()}`` and ``{'discriminator': netD.state_dict()}``
    to ``<root_path>/checkpoint_second/netG_weights_epoch_<epoch>.pth`` and
    ``netD_weights_epoch_<epoch>.pth``. The folder is created if missing.
    """
    checkpoint_dir = root_path + '/checkpoint_second'
    os.makedirs(checkpoint_dir, exist_ok=True)

    net_g_model_out_path = checkpoint_dir + "/netG_weights_epoch_{}.pth".format(epoch)
    net_d_model_out_path = checkpoint_dir + "/netD_weights_epoch_{}.pth".format(epoch)

    torch.save({'generator': netG.state_dict()}, net_g_model_out_path)
    torch.save({'discriminator': netD.state_dict()}, net_d_model_out_path)

    # Report the real destination. The old message printed "checkpoint" + dataset
    # (i.e. "checkpointdataset"), which is not where the files are written.
    print("Checkpoint saved to {}".format(checkpoint_dir))

In [None]:
# Main training loop: one epoch of training, then a weight checkpoint, for epochs 1..nEpochs.
# Keep the loop variable named `epoch` at notebook level: sample() reads the global `epoch`
# to build sample file names.
for epoch in range(1, nEpochs + 1):
    train(epoch)
    checkpoint(epoch)

===> Epoch[1](1/7): Loss_D: 0.6832 Loss_G: 10.0857 LossD_Fake: 0.6447 LossD_Real: 0.7217  LossG_Adv: 0.7350 LossG_L1: 3.6607 LossG_Style 5.6899 LossG_Content 0.0000
===> Epoch[1](2/7): Loss_D: 0.6893 Loss_G: 9.2321 LossD_Fake: 0.6586 LossD_Real: 0.7200  LossG_Adv: 0.7250 LossG_L1: 3.5780 LossG_Style 4.9291 LossG_Content 0.0000
===> Epoch[1](3/7): Loss_D: 0.6916 Loss_G: 9.9341 LossD_Fake: 0.6648 LossD_Real: 0.7183  LossG_Adv: 0.7210 LossG_L1: 3.6031 LossG_Style 5.6099 LossG_Content 0.0000
===> Epoch[1](4/7): Loss_D: 0.6903 Loss_G: 9.4275 LossD_Fake: 0.6659 LossD_Real: 0.7146  LossG_Adv: 0.7206 LossG_L1: 3.6140 LossG_Style 5.0929 LossG_Content 0.0000
===> Epoch[1](5/7): Loss_D: 0.6920 Loss_G: 9.3819 LossD_Fake: 0.6699 LossD_Real: 0.7142  LossG_Adv: 0.7166 LossG_L1: 3.4342 LossG_Style 5.2311 LossG_Content 0.0000
===> Epoch[1](6/7): Loss_D: 0.6936 Loss_G: 9.5503 LossD_Fake: 0.6718 LossD_Real: 0.7154  LossG_Adv: 0.7148 LossG_L1: 3.5401 LossG_Style 5.2955 LossG_Content 0.0000

saving sample 

In [None]:
def run():
    """Vestigial entry point: call freeze_support() and print the marker line 'loop'.

    freeze_support() only has an effect in frozen Windows executables, so inside a
    notebook this just prints the marker.
    """
    mp = torch.multiprocessing
    mp.freeze_support()
    print('loop')

In [None]:
# Script-style entry point, presumably kept from a .py version of this code. In a notebook
# kernel __name__ is '__main__', so run() executes and prints 'loop'.
if __name__ == '__main__':
    run()

loop
