In [None]:

from __future__ import print_function
from PIL import Image
from torch.utils.data import Dataset
import glob
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import itertools
from pathlib import Path
import torch.nn.functional as F
%matplotlib inline

# CycleGan implementation reference: https://towardsdatascience.com/overview-of-cyclegan-architecture-and-training-afee31612a2f


class ResidualBlock(nn.Module):
    """Residual unit used in the generator bottleneck.

    Two reflection-padded 3x3 convolutions (256 -> 256 channels), each followed
    by instance normalisation, with a ReLU only after the first one, plus an
    identity skip connection. Shape is preserved, e.g.
    [N, 256, 64, 64] -> [N, 256, 64, 64].
    """

    def __init__(self):
        super().__init__()
        layers = []
        for position in range(2):
            # ReflectionPad2d(1) followed by a 3x3, stride-1 conv keeps H and W.
            layers += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels=256, out_channels=256,
                          kernel_size=3, stride=1),
                nn.InstanceNorm2d(num_features=256),
            ]
            # Only the first half is activated; the second half's output is
            # added to the input without a non-linearity.
            if position == 0:
                layers.append(nn.ReLU(inplace=True))
        # Attribute name and layer order are kept so state_dict keys
        # (e.g. "conv_block.1.weight") match saved checkpoints.
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual


class Generator(nn.Module):
    """ResNet-style CycleGAN generator (3-channel image -> 3-channel image).

    Layout, in order:
      * 7x7 reflection-padded stem conv, 3 -> 64 channels;
      * two stride-2 3x3 convs, 64 -> 128 -> 256 channels;
      * ``n_residuals`` ResidualBlock modules at 256 channels;
      * two stride-2 transposed convs, 256 -> 128 -> 64 channels;
      * 7x7 reflection-padded output conv, 64 -> 3 channels, then tanh,
        so outputs lie in [-1, 1].
    For a [B, 3, 256, 256] input the bottleneck is [B, 256, 64, 64] and the
    output is [B, 3, 256, 256].

    Args:
        n_residuals: number of ResidualBlock modules in the bottleneck.
    """

    def __init__(self, n_residuals):
        super().__init__()

        # Stem: [B, 3, H, W] -> [B, 64, H, W]
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=7, stride=1),
            nn.InstanceNorm2d(num_features=64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: 64 -> 128 -> 256 channels; each stride-2 conv halves
        # H and W (rounding up for odd sizes).
        for c_in, c_out in ((64, 128), (128, 256)):
            layers.extend([
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(num_features=c_out),
                nn.ReLU(inplace=True),
            ])

        # Bottleneck residual blocks at 256 channels.
        layers.extend(ResidualBlock() for _ in range(n_residuals))

        # Upsampling: 256 -> 128 -> 64 channels; with output_padding=1 each
        # transposed conv exactly doubles H and W.
        for c_in, c_out in ((256, 128), (128, 64)):
            layers.extend([
                nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out,
                                   kernel_size=3, stride=2, padding=1,
                                   output_padding=1),
                nn.InstanceNorm2d(num_features=c_out),
                nn.ReLU(inplace=True),
            ])

        # Output head: [B, 64, H, W] -> [B, 3, H, W], squashed into [-1, 1].
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64, out_channels=3,
                      kernel_size=7, stride=1),
            nn.Tanh(),
        ])

        # One Sequential named ``stack`` in this exact order, so state_dict
        # keys ("stack.N.*") stay compatible with saved checkpoints.
        self.stack = nn.Sequential(*layers)

    def forward(self, input):
        return self.stack(input)


class Discriminator(nn.Module):
    """70x70 PatchGAN discriminator whose patch scores are averaged per image.

    Five 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 -> 1 channels). The
    first three use stride 2 and the last two stride 1. All but the first and
    the last are followed by instance normalisation, and all but the last by
    LeakyReLU(0.2). For a [B, 3, 256, 256] input the patch map is
    [B, 1, 30, 30]. ``forward`` averages it into one score per sample, [B].
    """

    def __init__(self):
        super().__init__()
        # First stage has no normalisation: [B, 3, 256, 256] -> [B, 64, 128, 128]
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Conv + InstanceNorm + LeakyReLU stages:
        #   [B, 64, 128, 128] -> [B, 128, 64, 64]   (stride 2)
        #   [B, 128, 64, 64]  -> [B, 256, 32, 32]   (stride 2)
        #   [B, 256, 32, 32]  -> [B, 512, 31, 31]   (stride 1)
        for c_in, c_out, step in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=4, stride=step, padding=1),
                nn.InstanceNorm2d(num_features=c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Patch classifier: [B, 512, 31, 31] -> [B, 1, 30, 30]. Each output
        # value sees a 70x70 area of the input image.
        layers.append(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4,
                                stride=1, padding=1))
        # Same layer order as before, so state_dict keys ("stack.N.*") match.
        self.stack = nn.Sequential(*layers)

    def forward(self, input):
        patch_scores = self.stack(input)
        # Pool over the whole patch map -> [B, 1, 1, 1], then flatten to [B].
        pooled = F.avg_pool2d(patch_scores, kernel_size=patch_scores.size()[2:])
        return pooled.view(-1)


def weights_init(m):
    """Initialise one module in place; meant to be used as ``net.apply(weights_init)``.

    * Modules whose class name contains ``'Conv'`` (Conv2d, ConvTranspose2d, ...)
      get their weight drawn from N(0, 0.02); their bias keeps PyTorch's
      default initialisation.
    * Modules whose class name contains ``'BatchNorm'`` get weight ~ N(1, 0.02)
      and bias = 0.
    * A learnable tensor that is absent (``None``, e.g. a norm layer built with
      ``affine=False``) is skipped instead of raising AttributeError.
    * All other modules (InstanceNorm2d, activations, padding, ...) are left
      untouched.

    Args:
        m: the ``nn.Module`` to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        if getattr(m, 'weight', None) is not None:
            # nn.init functions already run under no_grad, so ``.data`` is not needed.
            nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias, 0)


# ---- Training hyper-parameters ----
n_epochs = 50
batch_size = 1
dataroot = 'data/maps/'
lr = 0.0002
image_size = 256

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Discriminators: netD_A scores domain-A images, netD_B scores domain-B images.
netD_A = Discriminator().to(device)
netD_B = Discriminator().to(device)

# Generators with n_residuals bottleneck blocks: netG_AB maps A -> B, netG_BA maps B -> A.
n_residuals = 9
netG_AB = Generator(n_residuals).to(device)
netG_BA = Generator(n_residuals).to(device)

# Re-initialise every submodule with weights_init (same order as creation).
for _model in (netD_A, netD_B, netG_AB, netG_BA):
    _model.apply(weights_init)
del _model


class ImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    Domain A images are the files in ``<dataroot>/<mode>*A/`` and domain B
    images are the files in ``<dataroot>/<mode>*B/``. Item ``idx`` pairs the
    ``idx``-th A image (wrapping around) with a uniformly random B image, so
    the two domains are never presented as fixed pairs.

    Args:
        dataroot: directory containing the ``<mode>A`` / ``<mode>B`` folders
            (a trailing path separator is optional).
        transform: callable applied to each RGB PIL image.
        mode: folder-name prefix, e.g. ``'train'`` or ``'test'``.

    Raises:
        FileNotFoundError: if either domain folder yields no files.
    """

    def __init__(self, dataroot, transform, mode):
        # os.path.join builds a pattern valid with both Windows and POSIX
        # separators and tolerates a missing trailing separator on dataroot.
        # Sorting makes the index -> file mapping deterministic (glob order
        # is arbitrary).
        self.filesA = sorted(glob.glob(os.path.join(dataroot, mode + '*A', '*')))
        self.filesB = sorted(glob.glob(os.path.join(dataroot, mode + '*B', '*')))
        if not self.filesA or not self.filesB:
            raise FileNotFoundError(
                'No images found for mode %r under %r (A: %d files, B: %d files)'
                % (mode, dataroot, len(self.filesA), len(self.filesB)))
        self.transform = transform

    def __len__(self):
        # One epoch covers the larger of the two domains.
        return max(len(self.filesA), len(self.filesB))

    def __getitem__(self, idx):
        path_a = self.filesA[idx % len(self.filesA)]
        # Random B image each time, to avoid fixed A/B pairings.
        path_b = self.filesB[random.randint(0, len(self.filesB) - 1)]
        return {'A': self.transform(self._load_rgb(path_a)),
                'B': self.transform(self._load_rgb(path_b))}

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')


# Preprocessing shared by the train and test splits: resize the shorter side
# to 1.12 * image_size, take a random image_size x image_size crop, and scale
# pixel values from [0, 1] to [-1, 1] (the range of the generator's tanh).
transform = transforms.Compose([
    transforms.Resize(int(image_size * 1.12)),
    transforms.RandomCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Both splits are read from the configured `dataroot` (domain folders
# '<split>*A' and '<split>*B'), so the location is set in one place and does
# not depend on Windows-style backslash separators.
train_dataset = ImageDataset(dataroot, transform, 'train')
test_dataset = ImageDataset(dataroot, transform, 'test')

dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

# One Adam optimiser updates both generators; each discriminator has its own.
optimizerG = torch.optim.Adam(
    itertools.chain(netG_AB.parameters(), netG_BA.parameters()), lr=lr, betas=(0.5, 0.999)
)
optimizerD_A = optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerD_B = optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

# L1 for identity/cycle reconstruction; MSE (least-squares GAN) for the adversarial terms.
criterion_identity = nn.L1Loss()
criterion_cycle = nn.L1Loss()
criterion_gan = nn.MSELoss()

# Real (1) / fake (0) targets sized for a full batch of per-image discriminator scores.
real_labels = torch.full((batch_size,), 1, dtype=torch.float,
                         device=device, requires_grad=False)
fake_labels = torch.full((batch_size,), 0, dtype=torch.float,
                         device=device, requires_grad=False)


def plot_AB_pair(pair):
    """Display the 'A' images followed by the 'B' images of ``pair`` in one grid.

    Args:
        pair: dict whose 'A' and 'B' values are image tensors, each either a
            single image [3, H, W] or a batch [N, 3, H, W]. They may live on
            any device and may require grad; they are detached and copied to
            the CPU for plotting. ``make_grid(normalize=True)`` rescales the
            combined grid to [0, 1] for display.
    """
    images = []
    for key in ('A', 'B'):
        tensor = pair[key].detach().cpu()
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)
        images.append(tensor)
    # Concatenate along the batch dimension: all A images first, then all B
    # images. This works for any batch size, unlike squeezing to 3-D.
    grid = vutils.make_grid(torch.cat(images, dim=0), padding=2, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    # make_grid returns [C, H, W]; imshow expects [H, W, C].
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.show()


# Show one (A, B) sample from the training loader before training starts.
real_batch = next(iter(dataloader))
plot_AB_pair(real_batch)

# Weights of the generator objective terms.
lambda_cycle = 10.0
lambda_identity = 5.0
# Print losses and preview a test translation every `log_every` batches.
log_every = 200
# Write checkpoints every `save_every` epochs into `checkpoint_dir`.
save_every = 50
checkpoint_dir = "model/cycleGan/maps/"


def _gan_loss(pred, is_real):
    """Least-squares GAN loss of ``pred`` against an all-ones (real) or all-zeros (fake) target.

    The target is built with the same shape as ``pred``, so a final batch
    smaller than ``batch_size`` is handled correctly.
    """
    target = torch.ones_like(pred) if is_real else torch.zeros_like(pred)
    return criterion_gan(pred, target)


for epoch in range(n_epochs):

    # For each batch in the dataloader
    for i, batch in enumerate(dataloader):
        real_A = batch['A'].to(device)  # [Batch, 3, H, W]
        real_B = batch['B'].to(device)  # [Batch, 3, H, W]

        # ---- Generators (both updated together by optimizerG) ----
        optimizerG.zero_grad()

        # Cycle consistency: A -> B -> A and B -> A -> B should reproduce the input.
        fake_B = netG_AB(real_A)
        fake_A = netG_BA(real_B)
        rec_A = netG_BA(fake_B)
        rec_B = netG_AB(fake_A)
        cycle_loss_A = criterion_cycle(rec_A, real_A)
        cycle_loss_B = criterion_cycle(rec_B, real_B)

        # Identity: a generator fed an image already in its output domain
        # should leave it unchanged.
        identity_loss_A = criterion_identity(netG_BA(real_A), real_A)
        identity_loss_B = criterion_identity(netG_AB(real_B), real_B)

        # Adversarial: generators try to make the discriminators output "real".
        gan_loss_A = _gan_loss(netD_A(fake_A), True)
        gan_loss_B = _gan_loss(netD_B(fake_B), True)

        lossG = ((cycle_loss_A + cycle_loss_B) * lambda_cycle
                 + (identity_loss_A + identity_loss_B) * lambda_identity
                 + (gan_loss_A + gan_loss_B))
        lossG.backward()
        optimizerG.step()

        # ---- Discriminator A: real A -> 1, generated A -> 0 ----
        # zero_grad also clears gradients that lossG.backward() left on netD_A;
        # fake_A is detached so this step does not reach the generators.
        optimizerD_A.zero_grad()
        lossD_A = (_gan_loss(netD_A(real_A), True) +
                   _gan_loss(netD_A(fake_A.detach()), False)) * 0.5
        lossD_A.backward()
        optimizerD_A.step()

        # ---- Discriminator B: real B -> 1, generated B -> 0 ----
        optimizerD_B.zero_grad()
        lossD_B = (_gan_loss(netD_B(real_B), True) +
                   _gan_loss(netD_B(fake_B.detach()), False)) * 0.5
        lossD_B.backward()
        optimizerD_B.step()

        if i % log_every == 0:
            print('Epoch: %d/%d, batch: [%d/%d], lossG: %.2f, lossD_A: %.2f, lossD_B: %.2f' %
                  (epoch, n_epochs, i, len(dataloader),
                   lossG.item(), lossD_A.item(), lossD_B.item()))
            # Translate test A image #10 to domain B to monitor progress. Its
            # crop varies between calls because test_dataset uses the random
            # `transform`.
            with torch.no_grad():
                a = test_dataset[10]['A'].unsqueeze(0).to(device)
                fake_b = netG_AB(a)
                plot_AB_pair({'A': a.to('cpu'), 'B': fake_b.to('cpu')})

    if (epoch + 1) % save_every == 0:
        # torch.save does not create missing directories; without this the
        # run would crash at its first checkpoint.
        os.makedirs(checkpoint_dir, exist_ok=True)
        for net_name, net in (('netG_AB', netG_AB), ('netG_BA', netG_BA),
                              ('netD_A', netD_A), ('netD_B', netD_B)):
            torch.save(net.state_dict(),
                       os.path.join(checkpoint_dir, "%s_%d.pt" % (net_name, epoch + 1)))
