# Imports

In [0]:
import torch
from torch import nn
from torchsummary import summary
import torch.nn.functional as F
import os

In [0]:
def conv(n_in, n_filters, kernel_size, stride, bias=False):
    """Build a 2-D convolution padded by half the kernel size.

    Args:
        n_in: number of input channels.
        n_filters: number of output channels.
        kernel_size: integer kernel size; the padding is ``kernel_size // 2``.
        stride: convolution stride.
        bias: whether the layer carries a bias term (off by default).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    half_kernel = kernel_size // 2
    layer = nn.Conv2d(
        n_in,
        n_filters,
        kernel_size=kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
    return layer

In [0]:
def inv_conv(n_in, n_filters, kernel_size, stride, bias=False):
    """Build a 2-D transposed convolution that upsamples by ``stride``.

    The padding and output padding are derived from the arguments instead of
    being fixed at 1: ``padding = kernel_size // 2`` and
    ``output_padding = stride - 1``. For the ``kernel_size=3, stride=2`` layers
    used in this notebook this gives exactly the previous values (1 and 1).
    With an odd kernel the output spatial size is ``input_size * stride``.

    Args:
        n_in: number of input channels.
        n_filters: number of output channels.
        kernel_size: kernel size (int, or a pair of ints).
        stride: stride (int, or a pair of ints).
        bias: whether the layer carries a bias term (off by default).

    Returns:
        The configured ``nn.ConvTranspose2d`` layer.
    """
    # Accept either a single int or a (height, width) pair for both arguments.
    k_h, k_w = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    s_h, s_w = (stride, stride) if isinstance(stride, int) else stride
    return nn.ConvTranspose2d(
        n_in,
        n_filters,
        kernel_size=kernel_size,
        stride=stride,
        padding=(k_h // 2, k_w // 2),
        # output_padding must stay below the stride, so stride 1 needs 0 here.
        output_padding=(s_h - 1, s_w - 1),
        bias=bias,
    )

In [0]:
def requires_grad(module, status):
  """Switch gradient tracking on or off for every parameter of ``module``.

  Args:
    module: the ``nn.Module`` whose parameters are updated in place.
    status: boolean assigned to each parameter's ``requires_grad`` flag.
  """
  for param in module.parameters():
    param.requires_grad = status

In [0]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook's `.to(device=device)` calls still work on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Discriminator

In [0]:
def init_cnn(m):
    """Recursively initialise ``m`` and all of its sub-modules in place.

    Any module with a non-``None`` ``bias`` attribute gets that bias set to
    zero; ``nn.Conv2d`` and ``nn.Linear`` modules additionally get their
    weight re-drawn with Kaiming-normal initialisation.
    """
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)

    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)

    # Descend into the direct children; each call handles its own subtree.
    for child in m.children():
        init_cnn(child)

In [0]:
def activation(act):
  """Return a fresh activation module selected by name.

  ``'leakyrelu'`` gives ``nn.LeakyReLU`` with a negative slope of 0.2,
  ``'tanh'`` gives ``nn.Tanh``; any other value falls back to ``nn.ReLU``.
  """
  if act == 'tanh':
    return nn.Tanh()
  if act == 'leakyrelu':
    return nn.LeakyReLU(negative_slope=0.2)
  # Default for every unrecognised name.
  return nn.ReLU()

In [0]:
def conv_layer(n_in, n_filters, kernel_size, stride, do_norm=True, do_act=True, act='relu'):
    """Assemble a convolution followed by optional normalisation and activation.

    Args:
        n_in: number of input channels.
        n_filters: number of output channels.
        kernel_size: kernel size forwarded to ``conv``.
        stride: stride forwarded to ``conv``.
        do_norm: append ``nn.InstanceNorm2d(n_filters)`` when true.
        do_act: append the activation chosen by ``act`` when true.
        act: activation name understood by ``activation``.

    Returns:
        An ``nn.Sequential`` holding the stages in order: conv, norm, activation.
    """
    stages = [conv(n_in, n_filters, kernel_size, stride)]
    if do_norm:
        stages += [nn.InstanceNorm2d(n_filters)]
    if do_act:
        stages += [activation(act)]
    return nn.Sequential(*stages)

In [0]:
def inv_conv_layer(n_in, n_filters, kernel_size, stride, do_norm=True, do_act=True, act='relu'):
  """Assemble a transposed convolution followed by optional norm and activation.

  Mirrors ``conv_layer`` but starts from ``inv_conv``. The stages appear in the
  returned ``nn.Sequential`` in order: transposed conv, ``nn.InstanceNorm2d``
  (when ``do_norm``), activation named by ``act`` (when ``do_act``).
  """
  stages = [inv_conv(n_in, n_filters, kernel_size, stride)]
  if do_norm:
    stages += [nn.InstanceNorm2d(n_filters)]
  if do_act:
    stages += [activation(act)]
  return nn.Sequential(*stages)

In [0]:
class Discriminator(nn.Module):
  """Convolutional discriminator producing a single-channel map of scores.

  Four stride-2 LeakyReLU conv stages (the first without normalisation) are
  followed by one stride-1 stage and a final plain convolution down to one
  channel. All 3x3 kernels.

  Args:
    input_shape: shape tuple whose first entry is the input channel count.
  """

  def __init__(self, input_shape):
    super().__init__()
    self.loss = nn.MSELoss()  # stored on the module; not used by forward()

    channels = input_shape[0]
    # (output channels, stride, apply InstanceNorm) for each LeakyReLU stage.
    stage_plan = [
        (64, 2, False),
        (128, 2, True),
        (256, 2, True),
        (512, 2, True),
        (512, 1, True),
    ]
    stages = []
    for out_channels, stride, use_norm in stage_plan:
      stages.append(conv_layer(channels, out_channels, 3, stride, act='leakyrelu', do_norm=use_norm))
      channels = out_channels
    # Final projection to one score channel, without norm or activation.
    stages.append(conv(channels, 1, 3, 1))

    self.model = nn.Sequential(*stages)
    init_cnn(self.model)

  def forward(self, x):
    """Run ``x`` through the stacked stages and return the score map."""
    return self.model(x)

In [16]:
# Build a discriminator for 3x256x256 inputs, move it to the selected device
# and print its layer-by-layer summary.
input_shape = (3, 256, 256)
disc = Discriminator(input_shape).to(device=device)
summary(disc, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           1,728
         LeakyReLU-2         [-1, 64, 128, 128]               0
            Conv2d-3          [-1, 128, 64, 64]          73,728
    InstanceNorm2d-4          [-1, 128, 64, 64]               0
         LeakyReLU-5          [-1, 128, 64, 64]               0
            Conv2d-6          [-1, 256, 32, 32]         294,912
    InstanceNorm2d-7          [-1, 256, 32, 32]               0
         LeakyReLU-8          [-1, 256, 32, 32]               0
            Conv2d-9          [-1, 512, 16, 16]       1,179,648
   InstanceNorm2d-10          [-1, 512, 16, 16]               0
        LeakyReLU-11          [-1, 512, 16, 16]               0
           Conv2d-12          [-1, 512, 16, 16]       2,359,296
   InstanceNorm2d-13          [-1, 512, 16, 16]               0
        LeakyReLU-14          [-1, 512,

# Generator

In [0]:
class ResBlock(nn.Module):
  """Residual block: two 3x3 stride-1 conv stages plus a skip connection.

  When ``n_in`` differs from ``n_filters`` the skip path goes through an extra
  conv stage (stored as ``downsample``) so both branches have the same number
  of channels; otherwise the input is added unchanged. A ReLU is applied to
  the sum.
  """

  def __init__(self, n_in, n_filters):
    super().__init__()
    main_path = [
        conv_layer(n_in, n_filters, 3, 1, act='relu'),
        conv_layer(n_filters, n_filters, 3, 1, do_act=False),
    ]
    self.convs = nn.Sequential(*main_path)
    self.downsample = None
    self.relu = nn.ReLU()

    channels_differ = n_in != n_filters
    if channels_differ:
      # Despite the name this keeps stride 1; it only changes the channel count.
      self.downsample = conv_layer(n_in, n_filters, 3, 1, do_act=False)
      init_cnn(self.downsample)

    init_cnn(self.convs)

  def forward(self, x):
    """Return ``relu(convs(x) + skip(x))``."""
    out = self.convs(x)
    shortcut = x if self.downsample is None else self.downsample(x)
    return self.relu(out + shortcut)

In [19]:
# Sanity-check a residual block that changes the channel count (3 -> 64),
# which exercises the extra conv on the skip path.
res_block = ResBlock(3, 64).to(device=device)
summary(res_block, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           1,728
    InstanceNorm2d-2         [-1, 64, 256, 256]               0
              ReLU-3         [-1, 64, 256, 256]               0
            Conv2d-4         [-1, 64, 256, 256]          36,864
    InstanceNorm2d-5         [-1, 64, 256, 256]               0
            Conv2d-6         [-1, 64, 256, 256]           1,728
    InstanceNorm2d-7         [-1, 64, 256, 256]               0
              ReLU-8         [-1, 64, 256, 256]               0
Total params: 40,320
Trainable params: 40,320
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 256.00
Params size (MB): 0.15
Estimated Total Size (MB): 256.90
----------------------------------------------------------------


In [0]:
class Generator(nn.Module):
  """Encoder / residual / decoder image-to-image generator.

  Layout: a 7x7 stride-1 stage, two 3x3 stride-2 stages (64 -> 128 -> 256
  channels), ``n_resblock`` residual blocks at 256 channels, two transposed
  conv stages back down to 64 channels, and a 7x7 Tanh output stage.

  The output stage emits ``input_shape[0]`` channels so that the result has
  the same channel count as the input (previously hard-coded to 3, which is
  the same thing for RGB input). The trainer feeds one generator's output into
  the other and subtracts reconstructions from inputs, so the counts must match.

  Args:
    input_shape: shape tuple whose first entry is the image channel count.
    n_resblock: number of residual blocks in the middle of the network.
  """

  def __init__(self, input_shape, n_resblock=9):
    super().__init__()
    n_in = input_shape[0]
    self.loss = nn.MSELoss()  # stored on the module; not used by forward()
    res_blocks = [ResBlock(256, 256) for _ in range(n_resblock)]
    self.model = nn.Sequential(
        # Encoder
        conv_layer(n_in, 64, 7, 1, act='relu'),
        conv_layer(64, 128, 3, 2, act='relu'),
        conv_layer(128, 256, 3, 2, act='relu'),

        # Transformation at constant resolution and width
        *res_blocks,

        # Decoder
        inv_conv_layer(256, 128, 3, 2, act='relu'),
        inv_conv_layer(128, 64, 3, 2, act='relu'),
        # NOTE(review): conv_layer defaults to do_norm=True, so InstanceNorm2d
        # runs right before Tanh here - confirm that normalising the output
        # image is intended.
        conv_layer(64, n_in, 7, 1, act='tanh')
    )
    init_cnn(self.model)

  def forward(self, x):
    """Translate the batch ``x`` and return the Tanh-activated result."""
    return self.model(x)

In [21]:
# Build a generator for the same input shape, move it to the selected device
# and print its layer-by-layer summary.
gen = Generator(input_shape).to(device=device)
summary(gen, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           9,408
    InstanceNorm2d-2         [-1, 64, 256, 256]               0
              ReLU-3         [-1, 64, 256, 256]               0
            Conv2d-4        [-1, 128, 128, 128]          73,728
    InstanceNorm2d-5        [-1, 128, 128, 128]               0
              ReLU-6        [-1, 128, 128, 128]               0
            Conv2d-7          [-1, 256, 64, 64]         294,912
    InstanceNorm2d-8          [-1, 256, 64, 64]               0
              ReLU-9          [-1, 256, 64, 64]               0
           Conv2d-10          [-1, 256, 64, 64]         589,824
   InstanceNorm2d-11          [-1, 256, 64, 64]               0
             ReLU-12          [-1, 256, 64, 64]               0
           Conv2d-13          [-1, 256, 64, 64]         589,824
   InstanceNorm2d-14          [-1, 256,

# Cycle-GAN losses

In [0]:
def real_mse_loss(d_out):
  """How close is the discriminator output to "real"? Mean of ``(d_out - 1)^2``."""
  distance_to_real = d_out - 1
  return distance_to_real.pow(2).mean()

def fake_mse_loss(d_out):
  """How close is the discriminator output to "fake"? Mean of ``d_out^2``."""
  return d_out.pow(2).mean()

def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight):
  """Weighted reconstruction loss between an image and its round-trip copy.

  Computes the mean absolute difference between ``real_im`` and
  ``reconstructed_im`` and scales it by ``lambda_weight``.
  """
  mean_abs_error = (real_im - reconstructed_im).abs().mean()
  return mean_abs_error * lambda_weight

# Trainer

In [0]:
from tqdm.notebook import tqdm

class GANTrainer:
  """Trains two generators and two discriminators on unpaired image batches.

  ``gen_a`` is applied to domain-B batches and its output is scored by
  ``disc_a`` (together with real domain-A batches); ``gen_b`` is applied to
  domain-A batches and scored by ``disc_b``. Each discriminator has its own
  Adam optimiser; both generators share one.

  Args:
    dataloader_a: iterable of domain-A image batches.
    dataloader_b: iterable of domain-B image batches.
    input_shape: shape tuple forwarded to ``Generator`` / ``Discriminator``.
  """

  def __init__(self, dataloader_a, dataloader_b, input_shape):
    # Dataloaders
    self.dataloader_a, self.dataloader_b = dataloader_a, dataloader_b

    # Models
    self.gen_a, self.gen_b = Generator(input_shape), Generator(input_shape)
    self.disc_a, self.disc_b = Discriminator(input_shape), Discriminator(input_shape)
    self.send_to_gpu()

    # Optimizers (created after the models were moved to the target device)
    gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
    self.gen_opt = torch.optim.Adam(gen_params, betas=(0.5, 0.99))
    self.disc_a_opt = torch.optim.Adam(self.disc_a.parameters(), betas=(0.5, 0.99))
    self.disc_b_opt = torch.optim.Adam(self.disc_b.parameters(), betas=(0.5, 0.99))

  def send_to_gpu(self):
    """Move all four networks to the module-level ``device``."""
    for model in (self.gen_a, self.gen_b, self.disc_a, self.disc_b):
      model.to(device=device)

  def set_trainable(self, disc_a=False, disc_b=False):
    """Toggle ``requires_grad`` per network.

    Each discriminator follows its own flag; the generators are trainable only
    when both discriminator flags are false.
    """
    gen = (not disc_a) and (not disc_b)
    requires_grad(self.gen_a, gen)
    requires_grad(self.gen_b, gen)
    requires_grad(self.disc_a, disc_a)
    requires_grad(self.disc_b, disc_b)

  def fit(self, nb_epochs=1):
    """Run ``nb_epochs`` passes over the paired dataloaders."""
    for epoch in range(nb_epochs):
      print("EPOCH {}".format(epoch + 1))
      self.one_epoch()

  def one_epoch(self):
    """Train for one pass over the zipped dataloaders and print mean losses.

    Per batch: update ``disc_a``, update ``disc_b``, then update both
    generators with the adversarial plus cycle-consistency objective. The
    printed values are averages over batches, since every per-batch loss is
    already a mean over its elements.
    """
    sum_disc_a, sum_disc_b, sum_gen, n_batches = 0.0, 0.0, 0.0, 0
    for real_a, real_b in tqdm(zip(self.dataloader_a, self.dataloader_b), total=len(self.dataloader_a)):
      # One transfer handles both the device move and the float32 cast.
      real_a = real_a.to(device=device, dtype=torch.float)
      real_b = real_b.to(device=device, dtype=torch.float)

      # Fakes for the discriminator updates. The generators are not updated in
      # this phase, so build them without a graph instead of back-propagating
      # through the generators for gradients that would only be discarded.
      with torch.no_grad():
        fake_a = self.gen_a(real_b)
        fake_b = self.gen_b(real_a)

      # Discriminators training
      ## disc_a
      self.disc_a_opt.zero_grad()
      real_disc_a_loss = real_mse_loss(self.disc_a(real_a))
      fake_disc_a_loss = fake_mse_loss(self.disc_a(fake_a))
      disc_a_loss = real_disc_a_loss + fake_disc_a_loss
      disc_a_loss.backward()
      self.disc_a_opt.step()
      # .item() stores a plain float; summing the tensors themselves would keep
      # every batch's autograd graph referenced for the whole epoch.
      sum_disc_a += disc_a_loss.item()

      ## disc_b
      self.disc_b_opt.zero_grad()
      real_disc_b_loss = real_mse_loss(self.disc_b(real_b))
      fake_disc_b_loss = fake_mse_loss(self.disc_b(fake_b))
      disc_b_loss = real_disc_b_loss + fake_disc_b_loss
      disc_b_loss.backward()
      self.disc_b_opt.step()
      sum_disc_b += disc_b_loss.item()

      # Generators training
      # Gradients that this backward pass leaves on the discriminators are
      # cleared by their zero_grad() calls before their next update.
      self.gen_opt.zero_grad()
      out_1 = self.gen_a(real_b)
      loss_1 = real_mse_loss(self.disc_a(out_1))
      out_2 = self.gen_b(out_1)
      loss_2 = cycle_consistency_loss(real_im=real_b, reconstructed_im=out_2, lambda_weight=10.0)

      out_3 = self.gen_b(real_a)
      loss_3 = real_mse_loss(self.disc_b(out_3))
      out_4 = self.gen_a(out_3)
      loss_4 = cycle_consistency_loss(real_im=real_a, reconstructed_im=out_4, lambda_weight=10.0)

      gen_total_loss = loss_1 + loss_2 + loss_3 + loss_4
      gen_total_loss.backward()
      self.gen_opt.step()
      sum_gen += gen_total_loss.item()

      n_batches += 1

    if n_batches == 0:
      # Nothing to average; avoid dividing by zero.
      print("No batches were produced by the dataloaders; nothing was trained.")
      return

    print("Loss:")
    print("Discriminator A --> {:.4f}".format(sum_disc_a / n_batches))
    print("Discriminator B --> {:.4f}".format(sum_disc_b / n_batches))
    print("Cycle           --> {:.4f} \n".format(sum_gen / n_batches))

# Data

In [15]:
# Mount Google Drive at /content/drive (Colab-only; asks for an authorization
# code interactively the first time it runs in a session).
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
# Datasets folder inside the mounted Drive, resolved from the current working directory.
drive_dir = os.path.join(os.getcwd(), 'drive')
root_path = os.path.join(drive_dir, 'My Drive', 'Datasets')

In [0]:
# Location of the dataset archive inside the datasets folder.
archive_name = 'monet2photo.zip'
zip_path = os.path.join(root_path, archive_name)

In [0]:
import zipfile
# Unpack the whole archive next to it, into the datasets folder.
with zipfile.ZipFile(zip_path, mode='r') as archive:
    archive.extractall(path=root_path)

In [0]:
# The extracted data sits in a nested 'monet2photo/monet2photo' folder with one
# sub-folder per training domain.
data_path = os.path.join(root_path, 'monet2photo', 'monet2photo')
train_a_path, train_b_path = (
    os.path.join(data_path, split) for split in ('trainA', 'trainB')
)

In [0]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from skimage import io, transform
import random

In [0]:
class DomainDataset(Dataset):
  """Dataset over the files of one directory, read as images on access.

  The file names are listed once at construction and shuffled in place. The
  reported length can be shortened with ``set_length`` (used to make two
  domains the same size).

  Args:
    train_path: directory whose entries are the dataset items.
    transform_ops: optional callable applied to each loaded image.
  """

  def __init__(self, train_path, transform_ops=None):
    self.train_path, self.transform_ops = train_path, transform_ops
    self.images = os.listdir(self.train_path)
    random.shuffle(self.images)
    self.length = len(self.images)

  def __len__(self):
    return self.length

  def __getitem__(self, idx):
    """Load the ``idx``-th file with ``io.imread`` and apply the transform, if any."""
    img_path = os.path.join(self.train_path, self.images[idx])
    img = io.imread(img_path)

    if self.transform_ops:
      img = self.transform_ops(img)

    return img

  def set_length(self, length):
    """Set the length reported by ``len()``, clamped to ``[0, number of files]``.

    Clamping keeps ``len()`` valid and prevents index errors from a length
    larger than the number of files actually listed.
    """
    self.length = max(0, min(length, len(self.images)))

In [0]:
class ToTensor(object):
  """Transform turning a 3-axis array into a channels-first tensor scaled by 1/255.

  Assumes the sample is a NumPy array laid out as (height, width, channels) -
  TODO confirm against the image loader.
  """

  def __call__(self, sample):
    channels_first = sample.transpose((2, 0, 1))  # move the last axis to the front
    scaled = channels_first / 255.0
    return torch.from_numpy(scaled)

In [0]:
# One shared transform pipeline, one dataset per domain.
tfms = transforms.Compose([ToTensor()])
ds_a = DomainDataset(train_a_path, transform_ops=tfms)
ds_b = DomainDataset(train_b_path, transform_ops=tfms)

# Truncate both datasets to the size of the smaller one so they pair up batch for batch.
shared_length = min(len(ds_a), len(ds_b))
ds_a.set_length(shared_length)
ds_b.set_length(shared_length)

In [0]:
# Both domains use the same batch size and are reshuffled independently each epoch.
batch_size = 4
loader_options = dict(batch_size=batch_size, shuffle=True)
train_a_loader = DataLoader(dataset=ds_a, **loader_options)
train_b_loader = DataLoader(dataset=ds_b, **loader_options)

# Putting everything together

In [0]:
# Channels-first shape of the training images, then the trainer wired to both loaders.
input_shape = (3, 256, 256)
trainer = GANTrainer(
    dataloader_a=train_a_loader,
    dataloader_b=train_b_loader,
    input_shape=input_shape,
)

In [0]:
# Train for a single epoch.
trainer.fit(nb_epochs=1)

EPOCH 1


HBox(children=(IntProgress(value=0, max=268), HTML(value='')))