<a href="https://colab.research.google.com/github/antonrufino/colab-notebooks/blob/main/CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

An implementation of CycleGAN from [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593). The models have been modified to fit the constraints of Google Colab's GPUs.

In [None]:
import os

import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
from torchvision.utils import make_grid, save_image

from PIL import Image
from IPython.display import display

# Use the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
print(device)

cuda:0


In [None]:
%load_ext tensorboard

# Model

## Residual Block

In [None]:
class ResidualBlock(nn.Module):
  """Residual block: two 3x3 conv + instance-norm stages around a skip connection.

  The input is added to the output of the second normalisation layer and the
  sum is passed through a ReLU, so the output has the same shape as the input
  and is non-negative.

  Args:
    channels: Number of input and output feature channels. Defaults to 256,
      the width this block was originally hard-coded to.
  """

  def __init__(self, channels=256):
    super().__init__()

    # Attribute names (conv1/bn1/conv2/bn2) are kept as-is so state dicts saved
    # from earlier versions of this class still load. Note that the "bn"
    # layers are instance norms, not batch norms.
    self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
    self.bn1 = nn.InstanceNorm2d(channels)
    self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
    self.bn2 = nn.InstanceNorm2d(channels)

  def forward(self, x):
    """Return relu(x + norm(conv(relu(norm(conv(x))))))."""
    residual = F.relu(self.bn1(self.conv1(x)))
    residual = self.bn2(self.conv2(residual))

    # padding=1 with a 3x3 kernel keeps the spatial size, so the sum is valid.
    return F.relu(x + residual)


## Generator

In [None]:
class Generator(nn.Module):
  """Image-to-image generator: downsampling stem, residual body, upsampling head.

  `down` applies a 7x7 convolution followed by two stride-2 convolutions
  (3 -> 64 -> 128 -> 256 channels), `body` stacks six ResidualBlock modules,
  and `up` applies two stride-2 transposed convolutions followed by a 7x7
  convolution back to 3 channels. The result is squashed with tanh and
  rescaled into [0, 1].
  """

  def __init__(self):
    super().__init__()

    # Layers are created in the same order, and at the same Sequential
    # indices, as the state-dict keys of saved checkpoints expect.
    encoder_layers = [
        nn.ReflectionPad2d(3),
        nn.Conv2d(3, 64, kernel_size=7),
        nn.InstanceNorm2d(64),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
        nn.InstanceNorm2d(128),
        nn.ReLU(),
        nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
        nn.InstanceNorm2d(256),
        nn.ReLU(),
    ]
    self.down = nn.Sequential(*encoder_layers)

    self.body = nn.Sequential(*(ResidualBlock() for _ in range(6)))

    decoder_layers = [
        nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.InstanceNorm2d(128),
        nn.ReLU(),
        nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.InstanceNorm2d(64),
        nn.ReLU(),
        nn.ReflectionPad2d(3),
        nn.Conv2d(64, 3, kernel_size=7),
        # No ReLU after the final norm: tanh in forward() is the output
        # non-linearity.
        nn.InstanceNorm2d(3),
    ]
    self.up = nn.Sequential(*decoder_layers)

  def forward(self, x):
    """Translate a batch of images; output values lie in [0, 1]."""
    features = self.body(self.down(x))
    squashed = torch.tanh(self.up(features))

    # tanh yields [-1, 1]; shift and scale so the output is in [0, 1].
    return (squashed + 1) / 2

## Discriminator

In [None]:
class Discriminator(nn.Module):
  """Convolutional discriminator producing a map of per-patch scores in (0, 1).

  Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels, batch
  norm on all but the first, LeakyReLU(0.2) after each) are followed by a
  final 4x4 convolution down to a single channel and a sigmoid.
  """

  def __init__(self):
    super().__init__()

    # The first stage has no normalisation layer.
    stages = [nn.Conv2d(3, 64, kernel_size=4, stride=2), nn.LeakyReLU(0.2)]

    # Remaining strided stages share the conv -> batch norm -> LeakyReLU shape.
    for in_channels, out_channels in ((64, 128), (128, 256), (256, 512)):
      stages.append(nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2))
      stages.append(nn.BatchNorm2d(out_channels))
      stages.append(nn.LeakyReLU(0.2))

    # Collapse to a single score channel.
    stages.append(nn.Conv2d(512, 1, kernel_size=4))

    self.network = nn.Sequential(*stages)

  def forward(self, x):
    """Return sigmoid-activated scores for the batch of images `x`."""
    return torch.sigmoid(self.network(x))

# Losses

In [None]:
class DiscriminatorLoss(nn.Module):
  """Least-squares discriminator loss: real scores pulled to 1, fake scores to 0."""

  def __init__(self):
    super().__init__()

  def forward(self, d, x_fake, x_real):
    """Return the average of the real-sample and fake-sample squared errors.

    Args:
      d: Callable discriminator mapping a batch to scores.
      x_fake: Batch of generated samples.
      x_real: Batch of real samples.
    """
    real_term = (d(x_real) - 1).pow(2).mean()
    fake_term = d(x_fake).pow(2).mean()
    return (real_term + fake_term) / 2

class GeneratorLoss(nn.Module):
  """Least-squares adversarial loss for a generator."""

  def __init__(self):
    super().__init__()

  def forward(self, g, d, x):
    """Return the mean squared distance of d(g(x)) from 1.

    Args:
      g: Callable generator applied to `x`.
      d: Callable discriminator scoring the generated batch.
      x: Batch of source-domain samples.
    """
    scores = d(g(x))
    return (scores - 1).pow(2).mean()

class CycleLoss(nn.Module):
  """Cycle-consistency loss, measured as mean squared reconstruction error."""

  def __init__(self):
    super().__init__()

  def forward(self, g, f, x, y):
    """Return the sum of the x -> g -> f and y -> f -> g reconstruction errors.

    Args:
      g: Callable applied first to `x` and last to `y`.
      f: Callable applied first to `y` and last to `x`.
      x: Batch that should be recovered by f(g(x)).
      y: Batch that should be recovered by g(f(y)).
    """
    forward_error = (f(g(x)) - x).pow(2).mean()
    backward_error = (g(f(y)) - y).pow(2).mean()
    return forward_error + backward_error

# Style Transfer

## Data

In [None]:
!wget https://raw.githubusercontent.com/junyanz/pytorch-CycleGAN-and-pix2pix/master/datasets/download_cyclegan_dataset.sh
!mkdir datasets
!bash ./download_cyclegan_dataset.sh vangogh2photo

In [None]:
class ImageDataset(data.Dataset):
  """Dataset that serves every file in a directory as a resized RGB tensor.

  Args:
    img_dir: Directory whose entries are all treated as image files.
    height: Target image height passed to `transforms.Resize`.
    width: Target image width passed to `transforms.Resize`.
  """

  def __init__(self, img_dir, height, width):
    self.img_dir = img_dir
    # os.listdir returns entries in arbitrary order; sorting makes the mapping
    # from index to file reproducible across runs and machines.
    self.file_list = sorted(os.listdir(img_dir))
    self.pipeline = transforms.Compose([
      transforms.Resize((height, width)),
      transforms.ToTensor()
    ])

  def __len__(self):
    """Return the number of files found in `img_dir`."""
    return len(self.file_list)

  def __getitem__(self, idx):
    """Load image `idx` and return it as a 3-channel tensor on `device`."""
    path = os.path.join(self.img_dir, self.file_list[idx])

    # The context manager closes the file handle as soon as the pixels have
    # been read. convert('RGB') guarantees three channels even for grayscale
    # or RGBA files, which the 3-channel models cannot consume otherwise.
    with Image.open(path) as img:
      tensor = self.pipeline(img.convert('RGB'))

    return tensor.to(device)

# Two unpaired image collections, both resized to 128x128.
# Presumably trainB holds the photographs and trainA the paintings -- confirm
# against the downloaded dataset layout.
x_dataset = ImageDataset('./datasets/vangogh2photo/trainB', 128, 128)
y_dataset = ImageDataset('./datasets/vangogh2photo/trainA', 128, 128)

# One shuffled, single-image loader per collection.
x_dataloader, y_dataloader = (
    data.DataLoader(dataset, batch_size=1, shuffle=True)
    for dataset in (x_dataset, y_dataset)
)

# Sanity check: show the first sample tensor.
print(x_dataset[0])

# Training

In [None]:
# Number of passes the training cell makes over the paired data loaders.
EPOCHS = 20
# Weight applied to the cycle-consistency term in the generator objective.
LAMBDA = 10
# Image size in pixels -- presumably meant to mirror the 128x128 used when the
# datasets are built; TODO confirm where these two values are consumed.
IMG_HEIGHT = 128
IMG_WIDTH = 128

# When False, the training cell is skipped entirely.
TRAIN_MODEL = True
# Directory for TensorBoard event files; it sits under the `logs` root that
# the %tensorboard magic is pointed at.
TENSORBOARD_DIR = 'logs/test1'
# Files that receive the state dicts of g, f, d_x and d_y respectively.
PRETRAINED_G_MODEL_PATH = 'pretrained-g.pt'
PRETRAINED_F_MODEL_PATH = 'pretrained-f.pt'
PRETRAINED_DX_MODEL_PATH = 'pretrained-dx.pt'
PRETRAINED_DY_MODEL_PATH = 'pretrained-dy.pt'


In [None]:
%tensorboard --logdir logs

In [None]:
# !kill 1205

In [None]:
if TRAIN_MODEL:
  # Imported here so this cell stays self-contained.
  from collections import deque

  # g maps x-domain samples to the y domain and f maps back; d_x and d_y
  # score samples of the x and y domains respectively.
  g = Generator().to(device)
  f = Generator().to(device)
  d_x = Discriminator().to(device)
  d_y = Discriminator().to(device)

  # Smoke test: print each network's output shape for a dummy batch. No
  # gradients are needed, so skip building autograd graphs.
  with torch.no_grad():
    for model in (g, f, d_x, d_y):
      print(model(torch.randn(2, 3, 100, 100).to(device)).shape)

  loss_gen = GeneratorLoss()
  loss_dis = DiscriminatorLoss()
  loss_cyc = CycleLoss()

  adam_g = optim.Adam(g.parameters(), 2e-4)
  adam_f = optim.Adam(f.parameters(), 2e-4)
  adam_d_x = optim.Adam(d_x.parameters(), 2e-4)
  adam_d_y = optim.Adam(d_y.parameters(), 2e-4)

  # Rolling buffers holding the 50 most recent generated images; the
  # discriminators are trained against the whole buffer, not just the newest
  # fake. deque(maxlen=...) evicts the oldest entry automatically.
  g_history = deque(maxlen=50)
  f_history = deque(maxlen=50)

  writer = SummaryWriter(TENSORBOARD_DIR)

  # (model, destination) pairs written at every checkpoint.
  checkpoints = (
      (g, PRETRAINED_G_MODEL_PATH),
      (f, PRETRAINED_F_MODEL_PATH),
      (d_x, PRETRAINED_DX_MODEL_PATH),
      (d_y, PRETRAINED_DY_MODEL_PATH),
  )

  running_loss_dis = 0.0
  running_loss_gen = 0.0

  # Fixed image used to visualise progress in TensorBoard.
  pipeline = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor()
  ])
  with Image.open('/content/datasets/vangogh2photo/testB/2014-08-01 17:41:55.jpg') as test_img:
    test_tensor = pipeline(test_img).unsqueeze(0).to(device)

  count = 0
  for epoch in range(EPOCHS):
    # zip stops as soon as the shorter loader is exhausted, which ends the
    # epoch exactly like catching StopIteration from paired next() calls.
    for x_sample, y_sample in zip(x_dataloader, y_dataloader):
      count += 1

      # ---- Discriminator update ----
      adam_d_x.zero_grad()
      adam_d_y.zero_grad()

      # The fakes fed to the discriminators must not back-propagate into the
      # generators. Producing them under no_grad means the history buffers
      # hold plain tensors instead of keeping up to 100 autograd graphs alive.
      with torch.no_grad():
        g_history.append(g(x_sample))
        f_history.append(f(y_sample))

      y_stacked = torch.cat(tuple(g_history), 0)
      x_stacked = torch.cat(tuple(f_history), 0)

      loss = loss_dis(d_y, y_stacked, y_sample)
      loss += loss_dis(d_x, x_stacked, x_sample)

      loss.backward()

      running_loss_dis += loss.item()

      adam_d_x.step()
      adam_d_y.step()

      # ---- Generator update ----
      adam_g.zero_grad()
      adam_f.zero_grad()

      loss = loss_gen(g, d_y, x_sample)
      loss += loss_gen(f, d_x, y_sample)
      loss += LAMBDA * loss_cyc(g, f, x_sample, y_sample)

      # This also deposits gradients on d_x / d_y; they are discarded by the
      # zero_grad() calls at the start of the next discriminator update.
      loss.backward()

      running_loss_gen += loss.item()

      adam_g.step()
      adam_f.step()

      # ---- Periodic logging and checkpointing (every 100 steps) ----
      if count % 100 == 0:
        writer.add_scalar('discriminator loss', running_loss_dis / 100, count)
        writer.add_scalar('generator loss', running_loss_gen / 100, count)

        running_loss_dis = 0.0
        running_loss_gen = 0.0

        # Monitoring-only forward passes: no gradients required.
        with torch.no_grad():
          pic_dis = d_x(test_tensor)[0]
          pic = g(test_tensor)
          grid = make_grid([test_tensor[0], pic[0], f(pic)[0]])

        writer.add_scalar('discriminator output on test image', torch.mean(pic_dis).item(), count)

        for model, path in checkpoints:
          torch.save(model.state_dict(), path)

        # Grid shows: input image, g(input), and the reconstruction f(g(input)).
        writer.add_image('generated sample', grid, global_step=count)

  # Final checkpoint after the last epoch.
  for model, path in checkpoints:
    torch.save(model.state_dict(), path)

  # Flush pending events to disk and release the log file.
  writer.close()

## Run pretrained model

In [None]:
# INPUT_IMG_PATH = '/content/datasets/vangogh2photo/testA/00001.jpg'

# g = Generator().to(device)
# g.load_state_dict(torch.load(PRETRAINED_G_MODEL_PATH))

# input_img = Image.open(INPUT_IMG_PATH)
# input_img = transforms.ToTensor()(input_img).to(device).unsqueeze(0)
# output_img = g(input_img)[0]
# save_image(output_img, './output.png')