Modified from:https://github.com/Lornatang/CycleGAN-PyTorch

A more complicted PyTorch implementation can be found at:
 https://github.com/aitorzip/PyTorch-CycleGAN

In [15]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

In [12]:
if(torch.cuda.is_available()):
  device = torch.device("cuda")
  print(device, torch.cuda.get_device_name(0))
else:
  device= torch.device("cpu")
  print(device)

cuda Tesla K80


Connect to Google drive to generate data loader. 

If you train using your own PC with Anaconda
1. do not run drive.mount ("/content/gdrive", force_remount=True)
2. train_dataset = datasets.ImageFolder(root = "C:/Users/ADMIN/Google 雲端硬碟/Image folders/train", transform = transformer) 

In [5]:
from google.colab import drive
drive.mount("/content/gdrive", force_remount=True)

Mounted at /content/gdrive


In [75]:
image_size = 256 
batch_size = 4

In [76]:
transformer = transforms.Compose([
  transforms.Resize((image_size, image_size)),       
  transforms.ToTensor(),                     
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] )])

In [77]:
train_dataset = datasets.ImageFolder(root = "/content/gdrive/MyDrive/CycleGAN Img folder", transform = transformer)

In [78]:
classes = train_dataset.classes
classes_index = train_dataset.class_to_idx
print(classes)
print(classes_index)

['A', 'B']
{'A': 0, 'B': 1}


In [79]:
dataloader = Data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True)

# Unilities

In [13]:
class ReplayBuffer:
  def __init__(self, max_size=50):
    assert (max_size > 0), "Empty buffer or trying to create a black hole. Be careful."
    self.max_size = max_size
    self.data = []

  def push_and_pop(self, data):
    to_return = []
    for element in data.data:
      element = torch.unsqueeze(element, 0)
      if len(self.data) < self.max_size:
        self.data.append(element)
        to_return.append(element)
      else:
        if random.uniform(0, 1) > 0.5:
          i = random.randint(0, self.max_size - 1)
          to_return.append(self.data[i].clone())
          self.data[i] = element
        else:
          to_return.append(element)
    return torch.cat(to_return)

In [14]:
# custom weights initialization called on netG and netD
def weights_init(m):
  classname = m.__class__.__name__
  if classname.find("Conv") != -1:
      torch.nn.init.normal_(m.weight, 0.0, 0.02)
  elif classname.find("BatchNorm") != -1:
      torch.nn.init.normal_(m.weight, 1.0, 0.02)
      torch.nn.init.zeros_(m.bias)

# Define CycleGAN NN

In [16]:
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    self.main = nn.Sequential(
      nn.Conv2d(3, 64, 4, stride=2, padding=1),
      nn.LeakyReLU(0.2, inplace=True),

      nn.Conv2d(64, 128, 4, stride=2, padding=1),
      nn.InstanceNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),

      nn.Conv2d(128, 256, 4, stride=2, padding=1),
      nn.InstanceNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),

      nn.Conv2d(256, 512, 4, padding=1),
      nn.InstanceNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),

      nn.Conv2d(512, 1, 4, padding=1),
    )

  def forward(self, x):
    x = self.main(x)
    x = F.avg_pool2d(x, x.size()[2:])
    x = torch.flatten(x, 1)
    return x

In [17]:
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
    self.main = nn.Sequential(
      # Initial convolution block
      nn.ReflectionPad2d(3),
      nn.Conv2d(3, 64, 7),
      nn.InstanceNorm2d(64),
      nn.ReLU(inplace=True),

      # Downsampling
      nn.Conv2d(64, 128, 3, stride=2, padding=1),
      nn.InstanceNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 256, 3, stride=2, padding=1),
      nn.InstanceNorm2d(256),
      nn.ReLU(inplace=True),

      # Residual blocks
      ResidualBlock(256),
      ResidualBlock(256),
      ResidualBlock(256),
      ResidualBlock(256),
      ResidualBlock(256),
      ResidualBlock(256),
      ResidualBlock(256),
      ResidualBlock(256),
      ResidualBlock(256),

      # Upsampling
      nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
      nn.InstanceNorm2d(128),
      nn.ReLU(inplace=True),
      nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
      nn.InstanceNorm2d(64),
      nn.ReLU(inplace=True),

      # Output layer
      nn.ReflectionPad2d(3),
      nn.Conv2d(64, 3, 7),
      nn.Tanh()
    )

  def forward(self, x):
    return self.main(x)

In [18]:
class ResidualBlock(nn.Module):
  def __init__(self, in_channels):
    super(ResidualBlock, self).__init__()

    self.res = nn.Sequential(nn.ReflectionPad2d(1),
          nn.Conv2d(in_channels, in_channels, 3),
          nn.InstanceNorm2d(in_channels),
          nn.ReLU(inplace=True),
          nn.ReflectionPad2d(1),
          nn.Conv2d(in_channels, in_channels, 3),
          nn.InstanceNorm2d(in_channels))

  def forward(self, x):
      return x + self.res(x)

#Optimizor

In [27]:
class DecayLR:
  def __init__(self, epochs, offset, decay_epochs):
    epoch_flag = epochs - decay_epochs
    assert (epoch_flag > 0), "Decay must start before the training session ends!"
    self.epochs = epochs
    self.offset = offset
    self.decay_epochs = decay_epochs

  def step(self, epoch):
    return 1.0 - max(0, epoch + self.offset - self.decay_epochs) / (
            self.epochs - self.decay_epochs)

# The main training loop

In [20]:
import argparse
import itertools
import os
import random

import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from tqdm import tqdm

In [21]:
# create model
netG_A2B = Generator().to(device)
netG_B2A = Generator().to(device)
netD_A = Discriminator().to(device)
netD_B = Discriminator().to(device)

In [22]:
netG_A2B.apply(weights_init)
netG_B2A.apply(weights_init)
netD_A.apply(weights_init)
netD_B.apply(weights_init)

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [23]:
# define loss function (adversarial_loss) and optimizer
cycle_loss = torch.nn.L1Loss().to(device)
identity_loss = torch.nn.L1Loss().to(device)
adversarial_loss = torch.nn.MSELoss().to(device)

In [34]:
lr = 0.0001
epochs = 5
decay_epochs = 2

In [31]:
# Optimizers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),lr=lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

lr_lambda = DecayLR(epochs, 0, decay_epochs).step
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_lambda)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_lambda)

In [32]:
g_losses = []
d_losses = []

identity_losses = []
gan_losses = []
cycle_losses = []

In [33]:
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

step by step send on batch to one training loop

In [80]:
for i, data in enumerate(dataloader):
  break;

In [81]:
print(len(data), data[0].shape)

2 torch.Size([4, 3, 256, 256])


In [82]:
# get batch size data
real_image_A = data[0].to(device)
real_image_B = data[1].to(device)
batch_size = real_image_A.size(0)

In [83]:
# real data label is 1, fake data label is 0.
real_label = torch.full((batch_size, 1), 1, device=device, dtype=torch.float32)
fake_label = torch.full((batch_size, 1), 0, device=device, dtype=torch.float32)

In [84]:
##############################################
# (1) Update G network: Generators A2B and B2A
##############################################

# Set G_A and G_B's gradients to zero
optimizer_G.zero_grad()
# Identity loss
# G_B2A(A) should equal A if real A is fed
identity_image_A = netG_B2A(real_image_A)
loss_identity_A = identity_loss(identity_image_A, real_image_A) * 5.0
# G_A2B(B) should equal B if real B is fed
identity_image_B = netG_A2B(real_image_B)
loss_identity_B = identity_loss(identity_image_B, real_image_B) * 5.0

RuntimeError: ignored

In [None]:
# save last check pointing
torch.save(netG_A2B.state_dict(), f"weights/{args.dataset}/netG_A2B.pth")
torch.save(netG_B2A.state_dict(), f"weights/{args.dataset}/netG_B2A.pth")
torch.save(netD_A.state_dict(), f"weights/{args.dataset}/netD_A.pth")
torch.save(netD_B.state_dict(), f"weights/{args.dataset}/netD_B.pth")