In [1]:
import itertools
import numpy as np
import matplotlib.pyplot as plt
import argparse
import glob
import os

import torch
import tqdm
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

### Preparing hyper-parameters

In [2]:
# Every tunable hyper-parameter as (command-line flag, value type, default).
# The order here is the order in which the fields appear in the printed Namespace.
_HYPER_PARAMETERS = (
    ("--n_epoch", int, 200),
    ("--batch_size", int, 1),
    ("--lr", float, 2e-4),
    ("--decay_start", int, 100),
    ("--weight_identity", float, 5.0),
    ("--weight_cycle", float, 10.0),
    ("--image_size", int, 256),
    ("--beta1", float, 0.5),
)

parser = argparse.ArgumentParser()
for _flag, _value_type, _default in _HYPER_PARAMETERS:
  parser.add_argument(_flag, type=_value_type, default=_default)

# An explicit empty argument list keeps argparse from reading the notebook
# kernel's own sys.argv, so every option takes its default value.
opt = parser.parse_args(args=[])
print(opt)

Namespace(n_epoch=200, batch_size=1, lr=0.0002, decay_start=100, weight_identity=5.0, weight_cycle=10.0, image_size=256, beta1=0.5)


### Preparing Datasets

In [3]:
class ImageDataset(Dataset):
  """Unpaired two-domain image dataset (horse2zebra trainA / trainB).

  Each item is a dict ``{"A": imgA, "B": imgB}``: ``A`` is chosen
  deterministically from ``index`` (wrapping around the A file list) and ``B``
  is drawn uniformly at random from the B file list.

  Args:
    transform: callable applied to every opened PIL image. Although the
      default is ``None``, ``__getitem__`` calls it unconditionally and then
      unpacks ``imgB.size()`` into ``(C, H, W)``, so in practice it must be a
      callable that returns a 3-D tensor.
  """

  def __init__(self, transform=None):
    super().__init__()
    self.files_A = glob.glob("./drive/MyDrive/data/horse2zebra/trainA/*.jpg")
    self.files_B = glob.glob("./drive/MyDrive/data/horse2zebra/trainB/*.jpg")
    self.transform = transform

  def _load_image(self, path):
    """Open ``path``, apply ``self.transform`` and close the file again."""
    # The transform runs inside the ``with`` block so the pixel data is read
    # before the underlying file handle is released.
    with Image.open(path) as img:
      return self.transform(img)

  def __getitem__(self, index):
    """Return ``{"A": ..., "B": ...}`` for ``index``; ``B`` is a random draw."""
    # NOTE(review): only B is checked for 3 channels below; A is returned
    # as-is -- confirm trainA contains no grayscale images.
    imgA = self._load_image(self.files_A[index % len(self.files_A)])
    while True:
      # np.random.randint excludes its upper bound, so pass len(...) itself;
      # the previous "len - 1" could never pick the last file and raised
      # ValueError when the B list held a single file.
      random_index = np.random.randint(0, len(self.files_B))
      imgB = self._load_image(self.files_B[random_index])
      C, H, W = imgB.size()
      # Re-draw until the transformed B image has exactly three channels.
      if C == 3:
        break
    return {"A": imgA, "B": imgB}

  def __len__(self):
    """Number of items: the size of the larger of the two file lists."""
    return max(len(self.files_A), len(self.files_B))

class DecayLR(object):
  """Linear learning-rate decay factor.

  ``step(epoch)`` returns 1.0 until ``epoch + offset`` reaches
  ``decay_start_epoch`` and then falls linearly, reaching 0.0 when
  ``epoch + offset == n_epoch``.
  NOTE(review): presumably ``step`` is passed as ``lr_lambda`` to a
  ``LambdaLR`` scheduler -- confirm against the training cell.

  Args:
    n_epoch: total number of training epochs.
    offset: number of epochs added to ``epoch`` before computing the factor.
    decay_start_epoch: epoch at which the factor starts to decrease.

  Raises:
    ValueError: if ``n_epoch`` is not greater than ``decay_start_epoch``
      (the decay span would be zero or negative).
  """

  def __init__(self, n_epoch, offset, decay_start_epoch):
    if n_epoch <= decay_start_epoch:
      raise ValueError(
          "decay_start_epoch must be smaller than n_epoch "
          "(got n_epoch=%r, decay_start_epoch=%r)" % (n_epoch, decay_start_epoch))
    # The parameter is named ``n_epoch``; the attribute keeps the plural name
    # that ``step`` reads.
    self.n_epochs = n_epoch
    self.offset = offset
    self.decay_start_epoch = decay_start_epoch

  def step(self, epoch):
    """Return the learning-rate multiplier for ``epoch``."""
    # max(0, ...) keeps the factor at 1.0 before the decay phase begins.
    decayed_epochs = max(0, epoch + self.offset - self.decay_start_epoch)
    return 1.0 - decayed_epochs / (self.n_epochs - self.decay_start_epoch)

In [4]:
class ReplayBuffer(object):
  """Fixed-capacity pool of previously seen samples.

  While the pool is not full, every incoming sample is stored and returned
  unchanged. Once it is full, each incoming sample is, with probability 0.5,
  swapped with a randomly chosen stored sample (the stored one is returned and
  the new one takes its slot); otherwise the incoming sample is returned and
  the pool is left untouched.

  Args:
    max_size: maximum number of samples kept in the pool (must be positive).

  Raises:
    ValueError: if ``max_size`` is not positive.
  """

  def __init__(self, max_size=50):
    if max_size <= 0:
      raise ValueError("max_size must be positive, got %r" % (max_size,))
    self.max_size = max_size
    self.data = []

  def push_and_pop(self, data):
    """Feed a batch through the pool and return a batch of the same length.

    Args:
      data: batch whose ``.data`` attribute is iterated along its first
        dimension (a tensor in the visible usage).

    Returns:
      The selected samples concatenated along dimension 0.
    """
    to_return = []
    for element in data.data:
      # Restore the leading batch dimension dropped by the iteration.
      element = torch.unsqueeze(element, 0)
      if len(self.data) < self.max_size:
        self.data.append(element)
        to_return.append(element)
      else:
        if np.random.rand() > 0.5:
          # np.random.randint excludes its upper bound, so pass max_size
          # itself; "max_size - 1" never touched the last slot and raised
          # ValueError for max_size == 1.
          i = np.random.randint(0, self.max_size)
          to_return.append(self.data[i].clone())
          self.data[i] = element
        else:
          to_return.append(element)
    return torch.cat(to_return)

### Design of Generator

In [None]:
class ResidualBlock(nn.Module):
  """Residual unit: two pad/conv/instance-norm stages plus an identity skip.

  Both stages map ``in_channels`` to ``in_channels``; a ReLU sits between
  them. Each 3x3 convolution is preceded by a reflection pad of 1, so the
  spatial size of the input is preserved.
  """

  def __init__(self, in_channels):
    super().__init__()

    def conv_stage():
      # One "reflection pad -> 3x3 conv -> instance norm" stage.
      return [
          nn.ReflectionPad2d(1),
          nn.Conv2d(in_channels, in_channels, 3),
          nn.InstanceNorm2d(in_channels),
      ]

    # Stage 1, ReLU, stage 2 -- there is no activation after the second stage.
    self.conv_layers = nn.Sequential(
        *conv_stage(),
        nn.ReLU(inplace=True),
        *conv_stage(),
    )

  def forward(self, x):
    """Return ``conv_layers(x) + x``."""
    residual = self.conv_layers(x)
    # In-place add of the skip connection onto the freshly computed branch.
    return residual.add_(x)

In [5]:
class Generator(nn.Module):
  """Encoder -> residual transformer -> decoder image generator.

  The encoder applies a 7x7 convolution followed by two stride-2
  convolutions (64 -> 128 -> 256 channels); the transformer is a stack of
  residual blocks at 256 channels; the decoder mirrors the encoder with two
  stride-2 transposed convolutions and ends in a 7x7 convolution to 3
  channels followed by ``tanh``, so outputs lie in [-1, 1].

  Args:
    res_block: callable ``res_block(channels)`` returning the module used for
      each transformer stage (e.g. a residual-block class).
    in_channels: number of channels of the input image.
    n_res_blocks: number of transformer stages (default 9, matching the
      previous hard-coded stack).

  Note:
    Earlier revisions also registered an extra ``self.res_block`` module that
    ``forward`` never used. It has been removed, so checkpoints saved from
    such a revision contain additional ``res_block.*`` entries and must be
    loaded with ``strict=False``.
  """

  def __init__(self, res_block, in_channels=3, n_res_blocks=9):
    super().__init__()
    self.encoder = nn.Sequential(
        nn.ReflectionPad2d(3),
        nn.Conv2d(in_channels, 64, 7),
        nn.InstanceNorm2d(64),
        nn.ReLU(inplace=True),

        nn.Conv2d(64, 128, 3, stride=2, padding=1),
        nn.InstanceNorm2d(128),
        nn.ReLU(inplace=True),

        nn.Conv2d(128, 256, 3, stride=2, padding=1),
        nn.InstanceNorm2d(256),
        nn.ReLU(inplace=True)
    )
    #
    ### Transformer
    #
    # Only the blocks in this list take part in forward(); no stand-alone
    # block is registered, so every parameter of the model is actually used.
    self.transformer = nn.ModuleList(
        [res_block(256) for _ in range(n_res_blocks)]
    )

    #
    ### Decoder
    #
    self.decoder = nn.Sequential(
        nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
        nn.InstanceNorm2d(128),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
        nn.InstanceNorm2d(64),
        nn.ReLU(inplace=True),

        nn.ReflectionPad2d(3),
        nn.Conv2d(64, 3, 7),
        nn.Tanh()
    )

  def forward(self, x):
    """Run ``x`` through the encoder, every transformer block, and the decoder."""
    out = self.encoder(x)
    for func in self.transformer:
      out = func(out)
    out = self.decoder(out)
    return out

### Design of discriminator