# Importing libraries

In [None]:
import torch.nn as nn
from torch.nn import functional as F
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch import ones, zeros

from tqdm import tqdm

from torchvision import transforms
from torchvision.utils import save_image

# Dataset

In [None]:
def get_abs_file_paths(dir_name):
    """Lazily yield the absolute path of every file located under ``dir_name``.

    The tree is traversed with ``os.walk``, so files inside nested
    sub-directories are included. Directories themselves are never yielded,
    and the order of the results is whatever ``os.walk`` produces.
    """
    for parent, _subdirs, names in os.walk(dir_name):
        yield from (os.path.abspath(os.path.join(parent, name)) for name in names)


class NSTDataset(Dataset):
    """Unpaired two-domain image dataset (domains ``A`` and ``B``).

    File paths are collected from ``<root_dir>/<mode>A`` and
    ``<root_dir>/<mode>B`` where ``<mode>`` is ``train`` or ``test``.
    Item ``idx`` pairs the ``idx``-th image of each domain, wrapping around
    (modulo) when a domain holds fewer images than the requested index.

    Args:
        root_dir: Directory containing the ``trainA``/``trainB`` (or
            ``testA``/``testB``) sub-directories.
        train: Selects the ``train*`` folders when true, ``test*`` otherwise.
        transform: Optional callable applied to each loaded PIL image. When
            ``None`` the RGB PIL images are returned unchanged.
        epoch_size: Value reported by ``len()``. Defaults to 1000 (the
            previously hard-coded value). Pass ``None`` to use the size of
            the larger domain instead.

    Raises:
        FileNotFoundError: If either domain directory contains no files.
    """

    def __init__(self, root_dir, train=True, transform=None, epoch_size=1000):
        self.transform = transform
        self.epoch_size = epoch_size

        mode = 'train' if train else 'test'
        # Sorted so the index -> file mapping does not depend on the
        # filesystem-specific order in which os.walk lists entries.
        self.sampleA = sorted(get_abs_file_paths(os.path.join(root_dir, f"{mode}A")))
        self.sampleB = sorted(get_abs_file_paths(os.path.join(root_dir, f"{mode}B")))

        # Fail early with a clear message instead of a ZeroDivisionError from
        # the modulo in __getitem__.
        if not self.sampleA or not self.sampleB:
            raise FileNotFoundError(
                f"No images found in '{mode}A' and/or '{mode}B' under '{root_dir}'")

    def __len__(self):
        if self.epoch_size is not None:
            return self.epoch_size
        return max(len(self.sampleA), len(self.sampleB))

    def _load(self, paths, idx):
        """Load ``paths[idx % len(paths)]`` as an RGB image and apply the transform."""
        with Image.open(paths[idx % len(paths)]) as raw:
            # convert() forces the pixel data to be read, so the file can be
            # closed afterwards; it also guarantees three channels for
            # grayscale/RGBA/palette files.
            image = raw.convert('RGB')
        if self.transform is None:
            return image
        return self.transform(image)

    def __getitem__(self, idx):
        return {'imageA': self._load(self.sampleA, idx),
                'imageB': self._load(self.sampleB, idx)}

# Generator and discriminator

In [None]:
class ResnetGenerator(nn.Module):
    """Image-to-image generator: convolutional encoder, residual core, decoder.

    Layout (in execution order):
      * ``initial_conv``: reflection pad + 7x7 conv to ``n_filters`` channels;
      * ``down_sampling1``/``down_sampling2``: stride-2 3x3 convs that double
        the channel count each time (up to ``n_filters * 4``);
      * ``residual_blks``: nine ``Residual`` blocks at ``n_filters * 4`` channels;
      * ``up_sampling1``/``up_sampling2``: stride-2 transposed convs that halve
        the channel count each time;
      * ``final_conv``: reflection pad + 7x7 conv to ``n_output_channels``,
        followed by instance norm and ``Tanh``.

    Args:
        n_input_channels: Channels of the input image tensor.
        n_filters: Channel count after the first convolution.
        n_output_channels: Channels of the produced image tensor.
        use_dropout: Forwarded to every ``Residual`` block.
    """

    def __init__(self, n_input_channels=3, n_filters=64, n_output_channels=3, use_dropout=False):
        super().__init__()

        base, double, quad = n_filters, n_filters * 2, n_filters * 4

        def norm_relu(channels):
            # Shared tail of every encoder/decoder stage.
            return [nn.InstanceNorm2d(channels), nn.ReLU()]

        def downsample(c_in, c_out):
            conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2)
            return nn.Sequential(conv, *norm_relu(c_out))

        def upsample(c_in, c_out):
            deconv = nn.ConvTranspose2d(c_in, c_out, kernel_size=3, padding=1,
                                        stride=2, output_padding=1)
            return nn.Sequential(deconv, *norm_relu(c_out))

        # Encoder
        self.initial_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(n_input_channels, base, kernel_size=7),
            *norm_relu(base),
        )
        self.down_sampling1 = downsample(base, double)
        self.down_sampling2 = downsample(double, quad)

        # Residual core
        self.residual_blks = nn.Sequential(
            *[Residual(quad, use_dropout) for _ in range(9)]
        )

        # Decoder
        self.up_sampling1 = upsample(quad, double)
        self.up_sampling2 = upsample(double, base)
        self.final_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(base, n_output_channels, kernel_size=7),
            nn.InstanceNorm2d(n_output_channels),
            nn.Tanh(),
        )

    def forward(self, X):
        """Run ``X`` through every stage in order and return the Tanh output."""
        pipeline = (
            self.initial_conv,
            self.down_sampling1,
            self.down_sampling2,
            self.residual_blks,
            self.up_sampling1,
            self.up_sampling2,
            self.final_conv,
        )
        for stage in pipeline:
            X = stage(X)
        return X


class Residual(nn.Module):
    """Residual block: two 3x3 convolutions with instance norm and a skip connection.

    Computes ``relu(norm(conv2(h)) + X)`` where ``h = relu(norm(conv1(X)))``,
    optionally passing ``h`` through ``Dropout(0.5)``. The channel count and
    spatial size of the input are preserved.

    Args:
        n_channels: Number of input (and output) channels.
        use_dropout: When true, dropout is applied after the first activation.
    """

    def __init__(self, n_channels, use_dropout=False):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(n_channels, n_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(n_channels, n_channels, **conv_kwargs)
        self.dropout = nn.Dropout(0.5) if use_dropout else None

        # A single norm layer is applied after both convolutions.
        self.bn = nn.InstanceNorm2d(n_channels)

    def forward(self, X):
        hidden = F.relu(self.bn(self.conv1(X)))
        if self.dropout is not None:
            hidden = self.dropout(hidden)
        out = self.bn(self.conv2(hidden))
        out.add_(X)  # skip connection, accumulated in place
        return F.relu(out)


class PatchGanDiscriminator(nn.Module):
    """Convolutional discriminator that returns one score per input image.

    Five stride-2 4x4 convolutions (``conv1`` plus the four in ``main``) shrink
    the input; the channel count doubles in the first three stages of ``main``
    and is kept in the fourth. ``conv2`` maps the result to a single-channel
    score map, which is averaged to one value per image and flattened to a
    1-D tensor.

    Note: the first activation uses ``F.leaky_relu_`` with its default
    negative slope, whereas the stages inside ``main`` use a slope of 0.2.

    Args:
        n_input_channels: Channels of the input image tensor.
        n_filters: Channel count produced by the first convolution.
    """

    def __init__(self, n_input_channels=3, n_filters=64):
        super().__init__()
        self.conv1 = nn.Conv2d(n_input_channels, n_filters, kernel_size=4, stride=2, padding=1)
        self.bn = nn.InstanceNorm2d(n_filters)

        def strided_stage(c_in, c_out):
            return [nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                    nn.InstanceNorm2d(c_out),
                    nn.LeakyReLU(0.2)]

        width = n_filters
        layers = []
        for _ in range(3):
            layers.extend(strided_stage(width, width * 2))
            width *= 2
        # Last strided stage keeps the channel count unchanged.
        layers.extend(strided_stage(width, width))
        self.main = nn.Sequential(*layers)

        self.conv2 = nn.Conv2d(width, 1, kernel_size=4, padding=1)
        self.adap_pooling = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, X):
        features = self.bn(self.conv1(X))
        features = F.leaky_relu_(features)
        score_map = self.conv2(self.main(features))
        pooled = self.adap_pooling(score_map)
        return torch.flatten(pooled)

# Losses

In [None]:
class CycleGanLoss(nn.Module):
    """Combined CycleGAN objective for two generators and two discriminators.

    ``calc`` returns a single scalar: generator loss (least-squares
    adversarial + cycle-consistency + identity) plus both discriminator
    losses. Although the terms are summed, the autograd graph is built so that
    one ``backward()`` call gives every network only its own gradients:

      * generators receive gradients from the generator loss only
        (discriminator losses see *detached* generated images);
      * discriminators receive gradients from the discriminator losses only
        (their parameters are frozen while scoring images for the generator
        adversarial term).

    Args:
        discriminator_A: Discriminator for domain A images.
        discriminator_B: Discriminator for domain B images.
        gen_A2B: Generator mapping domain A to domain B.
        gen_B2A: Generator mapping domain B to domain A.
        device: Device the loss modules are moved to.
    """

    def __init__(self, discriminator_A, discriminator_B, gen_A2B, gen_B2A, device):
        super(CycleGanLoss, self).__init__()

        self.discriminator_A = discriminator_A
        self.discriminator_B = discriminator_B
        self.gen_A2B = gen_A2B
        self.gen_B2A = gen_B2A

        self.mse_loss = nn.MSELoss().to(device)  # losses for generator and discriminator
        self.mae_loss = nn.L1Loss().to(device)  # identity and cycle losses

    def _generator_adversarial_loss(self, discriminator, generated):
        """Least-squares loss pushing ``discriminator(generated)`` towards 1.

        The discriminator's parameters are frozen for the duration of the
        forward pass so that this term back-propagates into the generator
        only; their previous ``requires_grad`` state is restored afterwards.
        """
        trainable = [p for p in discriminator.parameters() if p.requires_grad]
        for param in trainable:
            param.requires_grad_(False)
        try:
            prediction = discriminator(generated)
        finally:
            for param in trainable:
                param.requires_grad_(True)
        # ones_like keeps the target on the same device/dtype/shape as the prediction.
        return self.mse_loss(prediction, torch.ones_like(prediction))

    def _discriminator_loss(self, discriminator, real, generated):
        """Least-squares discriminator loss: real -> 1, generated -> 0.

        ``generated`` is detached so this term never reaches the generators.
        """
        real_pred = discriminator(real)
        fake_pred = discriminator(generated.detach())
        real_loss = self.mse_loss(real_pred, torch.ones_like(real_pred))
        generated_loss = self.mse_loss(fake_pred, torch.zeros_like(fake_pred))
        return (real_loss + generated_loss) * 0.5

    def calc(self, real_A, real_B):
        """Return the summed generator + discriminator loss for one batch.

        Args:
            real_A: Batch of domain A images.
            real_B: Batch of domain B images.

        Returns:
            Scalar tensor; see the class docstring for how gradients are routed.
        """
        generated_B = self.gen_A2B(real_A)
        generated_A = self.gen_B2A(real_B)

        cycled_A = self.gen_B2A(generated_B)
        cycled_B = self.gen_A2B(generated_A)

        identical_A = self.gen_B2A(real_A)
        identical_B = self.gen_A2B(real_B)

        # Adversarial terms: each generator tries to fool the discriminator of
        # its *target* domain.
        gen_B2A_loss = self._generator_adversarial_loss(self.discriminator_A, generated_A)
        gen_A2B_loss = self._generator_adversarial_loss(self.discriminator_B, generated_B)

        cycle_A_loss = self.mae_loss(cycled_A, real_A) * 10
        cycle_B_loss = self.mae_loss(cycled_B, real_B) * 10
        total_cycle_loss = cycle_A_loss + cycle_B_loss

        identity_A_loss = self.mae_loss(real_A, identical_A) * 0.5
        identity_B_loss = self.mae_loss(real_B, identical_B) * 0.5

        # Total generator loss = adversarial loss + cycle loss + identity loss
        total_gen_loss = gen_A2B_loss + gen_B2A_loss + total_cycle_loss \
                         + identity_A_loss + identity_B_loss

        disc_A_loss = self._discriminator_loss(self.discriminator_A, real_A, generated_A)
        disc_B_loss = self._discriminator_loss(self.discriminator_B, real_B, generated_B)

        total_loss = total_gen_loss + disc_A_loss + disc_B_loss

        return total_loss


# Training

In [None]:
import gdown

# Name of the CycleGAN dataset archive to fetch from the Berkeley mirror.
dataset_name = "monet2photo"

url = f"https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip"
output = "dataset.zip"

# Download with the progress bar enabled; the call is left as the cell's last
# expression so the notebook displays its return value.
gdown.download(url=url, output=output, quiet=False)

Downloading...
From: https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/monet2photo.zip
To: /content/dataset.zip
100%|██████████| 305M/305M [02:35<00:00, 1.96MB/s]


'dataset.zip'

In [None]:
%%capture

# Extract the downloaded archive into the working directory, then delete the
# zip file. The %%capture cell magic above suppresses the command output.
!unzip "dataset.zip"
!rm "dataset.zip"

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only).
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(17)
image_size = 256
epoch_n = 40
freq_n = 1  # a checkpoint and sample images are written every `freq_n` epochs
start_epoch = 0

# Path of a checkpoint to resume from; leave empty to train from scratch.
model_path = ""

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Maps each channel from [0, 1] to [-1, 1]; undone with `x * 0.5 + 0.5`.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

models_dir = "/content/drive/MyDrive/NST-CycleGan/models"
dataset_dir = dataset_name
images_storage_dir = "/content/drive/MyDrive/NST-CycleGan/images_output"

# torch.save / save_image do not create missing directories.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(images_storage_dir, exist_ok=True)

dataset = NSTDataset(root_dir=dataset_dir, train=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)


gen_A2B = ResnetGenerator().to(device)
gen_B2A = ResnetGenerator().to(device)
discriminator_A = PatchGanDiscriminator().to(device)
discriminator_B = PatchGanDiscriminator().to(device)

optimizer_gen_A2B = torch.optim.Adam(gen_A2B.parameters())
optimizer_gen_B2A = torch.optim.Adam(gen_B2A.parameters())
optimizer_disc_A = torch.optim.Adam(discriminator_A.parameters())
optimizer_disc_B = torch.optim.Adam(discriminator_B.parameters())

# Resume from a checkpoint when one is configured. (`is not ""` compared object
# identity with a literal, which is not a reliable emptiness test.)
if model_path:
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)
    gen_A2B.load_state_dict(checkpoint['gen_A2B'])
    optimizer_gen_A2B.load_state_dict(checkpoint['optimizer_gen_A2B'])
    gen_B2A.load_state_dict(checkpoint['gen_B2A'])
    optimizer_gen_B2A.load_state_dict(checkpoint['optimizer_gen_B2A'])
    discriminator_A.load_state_dict(checkpoint['discriminator_A'])
    optimizer_disc_A.load_state_dict(checkpoint['optimizer_disc_A'])
    discriminator_B.load_state_dict(checkpoint['discriminator_B'])
    optimizer_disc_B.load_state_dict(checkpoint['optimizer_disc_B'])
    # The checkpoint is written after its epoch has finished, so continue
    # with the following one instead of repeating it.
    start_epoch = checkpoint['epoch'] + 1


cycleGanLoss = CycleGanLoss(discriminator_A=discriminator_A, discriminator_B=discriminator_B,
                            gen_A2B=gen_A2B, gen_B2A=gen_B2A, device=device)


for epoch in range(start_epoch, epoch_n):
    p_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for idx, batch in p_bar:

        optimizer_gen_A2B.zero_grad()
        optimizer_gen_B2A.zero_grad()
        optimizer_disc_A.zero_grad()
        optimizer_disc_B.zero_grad()

        # The networks live on `device`; the batch must be moved there too.
        image_A = batch['imageA'].to(device)
        image_B = batch['imageB'].to(device)

        cur_loss = cycleGanLoss.calc(image_A, image_B)
        cur_loss.backward()

        optimizer_gen_A2B.step()
        optimizer_gen_B2A.step()
        optimizer_disc_A.step()
        optimizer_disc_B.step()

        p_bar.set_description(
            f"[{epoch}/{epoch_n - 1}][{idx}/{len(dataloader) - 1}] "
            f"total_loss: {cur_loss.item():.4f} ")

    if epoch % freq_n == 0:
        torch.save({
            'epoch': epoch,
            'gen_A2B': gen_A2B.state_dict(),
            'optimizer_gen_A2B': optimizer_gen_A2B.state_dict(),
            'gen_B2A': gen_B2A.state_dict(),
            'optimizer_gen_B2A': optimizer_gen_B2A.state_dict(),
            'discriminator_A': discriminator_A.state_dict(),
            'optimizer_disc_A': optimizer_disc_A.state_dict(),
            'discriminator_B': discriminator_B.state_dict(),
            'optimizer_disc_B': optimizer_disc_B.state_dict(),
        }, f"{models_dir}/epoch_{epoch}_model.pth")

        # Sample images come from the last batch of the epoch. Everything is
        # mapped back from [-1, 1] to [0, 1] before saving; without this the
        # real images would be clamped at 0 and appear too dark.
        save_image(image_A * 0.5 + 0.5, f"{images_storage_dir}/epoch_{epoch}_real_A.png")
        save_image(image_B * 0.5 + 0.5, f"{images_storage_dir}/epoch_{epoch}_real_B.png")

        # No autograd graph is needed just to render samples.
        with torch.no_grad():
            # "gen_A" is the translation of real_A (a domain-B image) and
            # "gen_B" the translation of real_B (a domain-A image).
            image_gen_A = (gen_A2B(image_A) * 0.5) + 0.5
            image_gen_B = (gen_B2A(image_B) * 0.5) + 0.5

        save_image(image_gen_A, f"{images_storage_dir}/epoch_{epoch}_gen_A.png")
        save_image(image_gen_B, f"{images_storage_dir}/epoch_{epoch}_gen_B.png")

