Modified from:https://github.com/Lornatang/CycleGAN-PyTorch

I converted the code from a set of .py files into a single .ipynb notebook that can be run on Colab.

Please refer to my slides (PPT) for how to organize the training images on your Google Drive.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import datasets
import torchvision.utils as vutils
import torch.utils.data as Data
from torch.utils.data import Dataset

from PIL import Image
import random
import itertools
import glob
import os

In [None]:
# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
  # On GPU, also report the name of the first CUDA device.
  print(device, torch.cuda.get_device_name(0))
else:
  print(device)

cuda Tesla T4


Connect to Google Drive so the data loader can read the training images.

If you train on your own PC with Anaconda instead of Colab:
1. Do not run drive.mount("/content/gdrive", force_remount=True).
2. Point the dataset at a local folder instead, e.g. train_dataset = datasets.ImageFolder(root = "C:/Users/ADMIN/Google 雲端硬碟/Image folders/train", transform = transformer)

In [None]:
# Colab-only: mount Google Drive at /content/gdrive so the training images are reachable as files.
# force_remount=True requests a fresh mount even if the drive was mounted earlier in this session.
from google.colab import drive
drive.mount("/content/gdrive", force_remount=True)

Mounted at /content/gdrive


In [None]:
# Side length (pixels) of the square crops fed to the networks; the transform
# resizes to int(image_size*1.12) and then random-crops to image_size.
image_size = 256 
# Number of A/B image pairs per DataLoader batch.
batch_size = 4

In [None]:
class ImageDataset(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>/A`` and ``<root>/<mode>/B``.

    Each item is a dict ``{"A": image_a, "B": image_b}``. The dataset length is
    the size of the larger domain; the smaller domain is cycled with a modulo.

    Args:
        root: Directory containing the ``<mode>/A`` and ``<mode>/B`` folders.
        transform: Optional callable applied to every loaded PIL image. When
            ``None`` the RGB PIL images are returned unchanged.
        unaligned: If true, the B image is drawn uniformly at random instead
            of being paired with A by index.
        mode: Sub-folder name under ``root`` (``"train"`` by default).

    Raises:
        FileNotFoundError: If either domain folder contains no matching file.
    """

    def __init__(self, root, transform=None, unaligned=False, mode="train"):
        self.transform = transform
        self.unaligned = unaligned

        # "*.*" matches any file name that has an extension; sorting makes the
        # index -> file mapping reproducible between runs.
        self.files_A = sorted(glob.glob(os.path.join(root, f"{mode}/A") + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, f"{mode}/B") + "/*.*"))

        # Fail early with a readable message instead of a ZeroDivisionError /
        # ValueError deep inside __getitem__ when a folder is empty or misnamed.
        for domain, files in (("A", self.files_A), ("B", self.files_B)):
            if not files:
                raise FileNotFoundError(
                    f"No image files found in {os.path.join(root, mode, domain)!r}")

    def _load(self, path):
        """Open ``path`` as a 3-channel RGB image and apply the transform, if any."""
        # convert("RGB") makes grayscale / RGBA / palette files compatible with
        # a 3-channel pipeline; the context manager closes the file handle.
        with Image.open(path) as image:
            image = image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __getitem__(self, index):
        item_A = self._load(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            # Unpaired sampling: any B image may accompany this A image.
            item_B = self._load(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            item_B = self._load(self.files_B[index % len(self.files_B)])

        return {"A": item_A, "B": item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [None]:
# Training-time preprocessing / augmentation applied to every loaded image.
transformer = transforms.Compose([
  # Resize to 112% of the target size so the random crop below has room to shift.
  # InterpolationMode.BICUBIC replaces the PIL integer constant, which torchvision
  # deprecated ("Argument interpolation should be of type InterpolationMode instead of int").
  transforms.Resize(int(image_size*1.12), transforms.InterpolationMode.BICUBIC),
  transforms.RandomCrop(image_size),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  # (x - 0.5) / 0.5 per channel; the generators end in Tanh, so this puts the
  # real images on the same [-1, 1] scale as the generated ones.
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
# Build the training set from <folder>/train/A and <folder>/train/B on the mounted drive.
# unaligned=True: the B image of each item is drawn at random rather than paired by index.
dataset = ImageDataset("/content/gdrive/MyDrive/CycleGAN Img folder", transform = transformer, unaligned=True)

In [None]:
# Shuffled mini-batches of {"A": ..., "B": ...} dicts; pin_memory=True is passed
# to the DataLoader (intended to speed up host-to-GPU copies).
dataloader = Data.DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

Test: load one batch of images from the dataset.

In [None]:
# Fetch only the first batch and stop; `data` keeps that batch for inspection below.
for i, data in enumerate(dataloader):
  break;

In [None]:
# Sanity check: print the tensor shapes of the A and B halves of the fetched batch.
print(data["A"].shape,data["B"].shape)

torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])


Utility functions

In [None]:
class ReplayBuffer:
  """Fixed-capacity pool of previously seen samples.

  Until the pool is full every incoming sample is stored and passed straight
  through. Once full, each incoming sample is, with probability 0.5, swapped
  with a randomly chosen stored sample (the stored one is returned), and
  otherwise returned unchanged without being stored.
  """

  def __init__(self, max_size=50):
    assert (max_size > 0), "Empty buffer or trying to create a black hole. Be careful."
    self.max_size = max_size
    self.data = []

  def push_and_pop(self, data):
    """Feed a batch through the pool and return a batch of the same length.

    Args:
      data: Tensor whose first dimension indexes the samples.

    Returns:
      Tensor made by concatenating one sample per input sample along dim 0.
    """
    batch_out = []
    for sample in data.data:
      # Restore the leading batch dimension so samples can be concatenated later.
      sample = torch.unsqueeze(sample, 0)

      if len(self.data) < self.max_size:
        # Still filling up: remember the sample and hand it back unchanged.
        self.data.append(sample)
        batch_out.append(sample)
        continue

      if random.uniform(0, 1) > 0.5:
        # Swap: return an older sample and store the new one in its slot.
        slot = random.randint(0, self.max_size - 1)
        batch_out.append(self.data[slot].clone())
        self.data[slot] = sample
      else:
        batch_out.append(sample)

    return torch.cat(batch_out)

In [None]:
# Custom weight initialisation, meant to be passed to Module.apply() on netG and netD.
def weights_init(m):
  """Initialise one module in place, chosen by its class name.

  Classes whose name contains "Conv" get weights drawn from N(0.0, 0.02).
  Classes whose name contains "BatchNorm" get weights from N(1.0, 0.02) and
  zero biases. Every other module is left untouched.
  """
  layer_type = type(m).__name__
  if "Conv" in layer_type:
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
    return
  if "BatchNorm" in layer_type:
    torch.nn.init.normal_(m.weight, 1.0, 0.02)
    torch.nn.init.zeros_(m.bias)

Define the CycleGAN neural networks

In [None]:
class Discriminator(nn.Module):
  """Convolutional critic that maps a 3-channel image to one score per sample.

  Five 4x4 convolutions (three with stride 2, two with stride 1) produce a map
  of per-patch scores, which `forward` averages into a single value per image.
  """

  def __init__(self):
    super().__init__()

    # First stage has no normalisation layer.
    layers = [
      nn.Conv2d(3, 64, 4, stride=2, padding=1),
      nn.LeakyReLU(0.2, inplace=True),
    ]

    # Conv -> InstanceNorm -> LeakyReLU stages: (in_channels, out_channels, stride).
    for in_ch, out_ch, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
      layers += [
        nn.Conv2d(in_ch, out_ch, 4, stride=stride, padding=1),
        nn.InstanceNorm2d(out_ch),
        nn.LeakyReLU(0.2, inplace=True),
      ]

    # Final projection to a single-channel score map.
    layers.append(nn.Conv2d(512, 1, 4, padding=1))

    self.main = nn.Sequential(*layers)

  def forward(self, x):
    """Return a (batch, 1) tensor: the spatial mean of the score map."""
    score_map = self.main(x)
    # Pooling window equals the full spatial extent, i.e. a global average.
    pooled = F.avg_pool2d(score_map, score_map.size()[2:])
    return torch.flatten(pooled, 1)

In [None]:
class Generator(nn.Module):
  """Image-to-image network: 7x7 stem, two stride-2 downsampling convolutions,
  nine residual blocks at 256 channels, two stride-2 transposed convolutions,
  and a 7x7 output convolution followed by Tanh.
  """

  def __init__(self):
    super().__init__()

    # Initial convolution block
    stem = [
      nn.ReflectionPad2d(3),
      nn.Conv2d(3, 64, 7),
      nn.InstanceNorm2d(64),
      nn.ReLU(inplace=True),
    ]

    # Downsampling: 64 -> 128 -> 256 channels, halving the resolution each time
    down = []
    for in_ch, out_ch in ((64, 128), (128, 256)):
      down += [
        nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
        nn.InstanceNorm2d(out_ch),
        nn.ReLU(inplace=True),
      ]

    # Residual blocks
    trunk = [ResidualBlock(256) for _ in range(9)]

    # Upsampling: 256 -> 128 -> 64 channels, doubling the resolution each time
    up = []
    for in_ch, out_ch in ((256, 128), (128, 64)):
      up += [
        nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1),
        nn.InstanceNorm2d(out_ch),
        nn.ReLU(inplace=True),
      ]

    # Output layer: back to 3 channels, squashed into (-1, 1) by Tanh
    head = [
      nn.ReflectionPad2d(3),
      nn.Conv2d(64, 3, 7),
      nn.Tanh(),
    ]

    self.main = nn.Sequential(*stem, *down, *trunk, *up, *head)

  def forward(self, x):
    """Run `x` through the whole stack and return the generated image tensor."""
    return self.main(x)

In [None]:
class ResidualBlock(nn.Module):
  """Residual unit that keeps the channel count and spatial size unchanged.

  The branch is (reflection pad, 3x3 conv, instance norm) twice with a ReLU in
  between; its output is added to the block input.
  """

  def __init__(self, in_channels):
    super().__init__()

    branch = [
      nn.ReflectionPad2d(1),
      nn.Conv2d(in_channels, in_channels, 3),
      nn.InstanceNorm2d(in_channels),
      nn.ReLU(inplace=True),
      nn.ReflectionPad2d(1),
      nn.Conv2d(in_channels, in_channels, 3),
      nn.InstanceNorm2d(in_channels),
    ]
    self.res = nn.Sequential(*branch)

  def forward(self, x):
    """Return the input plus the output of the convolutional branch."""
    return x + self.res(x)

Optimizer

In [None]:
class DecayLR:
  """Learning-rate multiplier: 1.0 until `decay_epochs`, then a linear ramp
  that reaches 0.0 at `epochs`.

  Args:
    epochs: Total number of training epochs.
    offset: Value added to the epoch passed to `step`.
    decay_epochs: Epoch at which the linear decay starts; must be < `epochs`.
  """

  def __init__(self, epochs, offset, decay_epochs):
    assert (epochs - decay_epochs > 0), "Decay must start before the training session ends!"
    self.epochs = epochs
    self.offset = offset
    self.decay_epochs = decay_epochs

  def step(self, epoch):
    """Return the multiplier for `epoch`."""
    # Epochs elapsed since the decay started (never negative).
    elapsed = max(0, epoch + self.offset - self.decay_epochs)
    # Length of the decay phase.
    span = self.epochs - self.decay_epochs
    return 1.0 - elapsed / span

Prepare for training

In [None]:
# create model
# Two generators (A->B and B->A) and one discriminator per domain, all moved to `device`.
netG_A2B = Generator().to(device)
netG_B2A = Generator().to(device)
netD_A = Discriminator().to(device)
netD_B = Discriminator().to(device)

In [None]:
from torchsummary import summary
# Print a layer-by-layer summary of one generator.
# device.type ("cuda" or "cpu") follows the device selected earlier instead of a
# hard-coded 'cuda', so this cell also works when the notebook fell back to the CPU.
# The input size and batch size reuse the notebook-level settings.
summary(netG_A2B, (3, image_size, image_size), batch_size=batch_size, device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ReflectionPad2d-1           [4, 3, 262, 262]               0
            Conv2d-2          [4, 64, 256, 256]           9,472
    InstanceNorm2d-3          [4, 64, 256, 256]               0
              ReLU-4          [4, 64, 256, 256]               0
            Conv2d-5         [4, 128, 128, 128]          73,856
    InstanceNorm2d-6         [4, 128, 128, 128]               0
              ReLU-7         [4, 128, 128, 128]               0
            Conv2d-8           [4, 256, 64, 64]         295,168
    InstanceNorm2d-9           [4, 256, 64, 64]               0
             ReLU-10           [4, 256, 64, 64]               0
  ReflectionPad2d-11           [4, 256, 66, 66]               0
           Conv2d-12           [4, 256, 64, 64]         590,080
   InstanceNorm2d-13           [4, 256, 64, 64]               0
             ReLU-14           [4, 256,

In [None]:
# Apply the custom initialisation to every sub-module of all four networks.
netG_A2B.apply(weights_init)
netG_B2A.apply(weights_init)
netD_A.apply(weights_init)
netD_B.apply(weights_init)
# Final statement is a plain print so the notebook does not echo the last model's repr.
print()




In [None]:
# define loss function (adversarial_loss) and optimizer
# cycle_loss:       L1 between a twice-translated image and the original real image.
# identity_loss:    L1 between a generator's output and its input when fed an image of its target domain.
# adversarial_loss: MSE between discriminator scores and the 1/0 labels.
cycle_loss = torch.nn.L1Loss().to(device)
identity_loss = torch.nn.L1Loss().to(device)
adversarial_loss = torch.nn.MSELoss().to(device)

In [None]:
# Base learning rate shared by all three Adam optimizers.
lr = 0.0001
epochs = 5 # change to larger number, 50000, for real training
# Epoch at which the linear learning-rate decay starts; must be smaller than `epochs`.
decay_epochs = 2 #change to 100

In [None]:
# Optimizers
# One Adam optimizer updates both generators together; each discriminator has its own.
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),lr=lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

# Shared schedule: the multiplier from DecayLR.step stays at 1.0 until `decay_epochs`
# and then falls linearly to 0.0 at `epochs`. LambdaLR scales the base lr by it.
lr_lambda = DecayLR(epochs, 0, decay_epochs).step
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_lambda)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_lambda)

In [None]:
# Containers for loss histories.
# NOTE(review): nothing in the visible notebook appends to these lists — confirm whether they are still needed.
g_losses = []
d_losses = []

identity_losses = []
gan_losses = []
cycle_losses = []

In [None]:
# History pools (default capacity 50) of generated images, one per domain;
# the discriminator updates draw their fake samples through these.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

In [None]:
# Root folder for sample images.
outf = "out"
# Run name used in the output and weight paths.
# NOTE(review): this rebinds `dataset`, which previously held the ImageDataset instance;
# the existing `dataloader` keeps working, but re-running the DataLoader cell afterwards would not.
dataset = "TomJerry"

In [None]:
# Create the sample-image folders for both domains.
# exist_ok=True makes re-running the cell harmless and guarantees that "B" is
# still created when "A" already exists; genuine failures (e.g. permission
# errors) are no longer silently swallowed.
os.makedirs(os.path.join(outf, dataset, "A"), exist_ok=True)
os.makedirs(os.path.join(outf, dataset, "B"), exist_ok=True)

In [None]:
# Create the checkpoint folder; exist_ok=True tolerates re-runs while letting
# real errors (e.g. permission problems) surface instead of being ignored.
os.makedirs(os.path.join("weights", dataset), exist_ok=True)

# Step-by-step training of one batch
You can skip this section and jump straight to "The main training loop".

In [None]:
# Fetch only the first batch and stop; `data` holds it for the step-by-step cells below.
for i, data in enumerate(dataloader):
  break;

In [None]:
# get batch size data
# Move both halves of the batch to the compute device.
real_image_A = data['A'].to(device)
real_image_B = data['B'].to(device)
# NOTE: this rebinds the notebook-level `batch_size` to the size of this particular batch.
batch_size = real_image_A.size(0)

In [None]:
# real data label is 1, fake data label is 0.
# Shape (batch_size, 1) matches the discriminator output, which is flattened to one score per sample.
real_label = torch.full((batch_size, 1), 1, device=device, dtype=torch.float32)
fake_label = torch.full((batch_size, 1), 0, device=device, dtype=torch.float32)

(1) Update G network: Generators A2B and B2A

In [None]:
# Clear the gradients of both generators (netG_A2B and netG_B2A share optimizer_G).
optimizer_G.zero_grad()

In [None]:
# Identity loss
# G_B2A(A) should equal A if real A is fed
identity_image_A = netG_B2A(real_image_A)
# L1 distance to the input, weighted by 5.0.
loss_identity_A = identity_loss(identity_image_A, real_image_A) * 5.0
print(real_image_A.shape, identity_image_A.shape, loss_identity_A)

torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256]) tensor(3.7666, device='cuda:0', grad_fn=<MulBackward0>)


In [None]:
# G_A2B(B) should equal B if real B is fed
identity_image_B = netG_A2B(real_image_B)
# L1 distance to the input, weighted by 5.0.
loss_identity_B = identity_loss(identity_image_B, real_image_B) * 5.0
print(real_image_B.shape, identity_image_B.shape, loss_identity_B)

torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256]) tensor(2.5936, device='cuda:0', grad_fn=<MulBackward0>)


In [None]:
# GAN loss
# GAN loss D_A(G_B2A(B)): translate real B into a fake A and score it with D_A.
fake_image_A = netG_B2A(real_image_B)
fake_output_A = netD_A(fake_image_A)
# The generator is trained against the *real* label, i.e. to make D_A output 1.
loss_GAN_B2A = adversarial_loss(fake_output_A, real_label)
# GAN loss D_B(G_A2B(A)): translate real A into a fake B and score it with D_B.
fake_image_B = netG_A2B(real_image_A)
fake_output_B = netD_B(fake_image_B)
loss_GAN_A2B = adversarial_loss(fake_output_B, real_label)

In [None]:
# Cycle loss
# A -> fake B -> recovered A should match the original A (L1, weighted by 10.0).
recovered_image_A = netG_B2A(fake_image_B)
loss_cycle_ABA = cycle_loss(recovered_image_A, real_image_A) * 10.0

# B -> fake A -> recovered B should match the original B (L1, weighted by 10.0).
recovered_image_B = netG_A2B(fake_image_A)
loss_cycle_BAB = cycle_loss(recovered_image_B, real_image_B) * 10.0

In [None]:
# Combined loss and calculate gradients
# Total generator objective: identity + adversarial + cycle terms (already weighted above).
errG = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB

# Calculate gradients for G_A2B and G_B2A
errG.backward()
# Update the weights of both generators
optimizer_G.step()

(2) Update D network: Discriminator A

In [None]:
# Set D_A gradients to zero (this also discards gradients left over from the generator's backward pass).
optimizer_D_A.zero_grad()

In [None]:
# Real A image loss: D_A should score real A images as 1.
real_output_A = netD_A(real_image_A)
errD_real_A = adversarial_loss(real_output_A, real_label)

In [None]:
# Fake A image loss: D_A should score generated A images as 0.
# The replay buffer may substitute previously generated images for the current ones.
fake_image_A = fake_A_buffer.push_and_pop(fake_image_A)
# detach() so the backward pass of this loss does not flow into the generator.
fake_output_A = netD_A(fake_image_A.detach())
errD_fake_A = adversarial_loss(fake_output_A, fake_label)

In [None]:
# Combined loss and calculate gradients
# Average of the real and fake terms.
errD_A = (errD_real_A + errD_fake_A) / 2

# Calculate gradients for D_A
errD_A.backward()
# Update D_A weights
optimizer_D_A.step()

(3) Update D network: Discriminator B

In [None]:
# Set D_B gradients to zero
optimizer_D_B.zero_grad()

# Real B image loss: D_B should score real B images as 1.
real_output_B = netD_B(real_image_B)
errD_real_B = adversarial_loss(real_output_B, real_label)

# Fake B image loss: D_B should score generated B images as 0.
# The replay buffer may substitute previously generated images for the current ones.
fake_image_B = fake_B_buffer.push_and_pop(fake_image_B)
# detach() so the backward pass of this loss does not flow into the generator.
fake_output_B = netD_B(fake_image_B.detach())
errD_fake_B = adversarial_loss(fake_output_B, fake_label)

# Combined loss and calculate gradients
# Average of the real and fake terms.
errD_B = (errD_real_B + errD_fake_B) / 2

# Calculate gradients for D_B
errD_B.backward()
# Update D_B weights
optimizer_D_B.step()

In [None]:
# Save the current batch of real A images as one image grid (normalize=True rescales values for display).
vutils.save_image(real_image_A, f"{outf}/{dataset}/A/real_samples.png",normalize=True)

In [None]:
# Save the current batch of real B images as one image grid (normalize=True rescales values for display).
vutils.save_image(real_image_B, f"{outf}/{dataset}/B/real_samples.png",normalize=True)

In [None]:
# Epoch index used in the sample and checkpoint file names of this step-by-step walkthrough.
epoch=0

In [None]:
# Generate sample translations for visual inspection only.
# torch.no_grad() avoids building an autograd graph for these forward passes
# (the results were previously detached only after the graph had been built).
with torch.no_grad():
  # Map the Tanh output from [-1, 1] to [0, 1].
  fake_image_A = 0.5 * (netG_B2A(real_image_B) + 1.0)
  fake_image_B = 0.5 * (netG_A2B(real_image_A) + 1.0)
vutils.save_image(fake_image_A, f"{outf}/{dataset}/A/fake_samples_epoch_{epoch}.png",normalize=True)
vutils.save_image(fake_image_B,f"{outf}/{dataset}/B/fake_samples_epoch_{epoch}.png",normalize=True)

In [None]:
# Checkpoint: save the weights (state_dict) of all four networks, tagged with the epoch index.
torch.save(netG_A2B.state_dict(), f"weights/{dataset}/netG_A2B_epoch_{epoch}.pth")
torch.save(netG_B2A.state_dict(), f"weights/{dataset}/netG_B2A_epoch_{epoch}.pth")
torch.save(netD_A.state_dict(), f"weights/{dataset}/netD_A_epoch_{epoch}.pth")
torch.save(netD_B.state_dict(), f"weights/{dataset}/netD_B_epoch_{epoch}.pth")

# The main training loop

In [None]:
##############################################
# (1) Update G network: Generators A2B and B2A
##############################################
def UpdateG(real_image_A, real_image_B, real_label, fake_label):
  """Run one optimisation step on both generators.

  The objective is the sum of identity losses (weight 5.0), adversarial losses
  against `real_label`, and cycle-consistency losses (weight 10.0).

  Uses the notebook-level networks, loss modules and `optimizer_G`.

  Args:
    real_image_A: Batch of real images from domain A.
    real_image_B: Batch of real images from domain B.
    real_label: Target tensor the discriminator scores of the fakes are compared to.
    fake_label: Unused here; kept for a uniform call signature.

  Returns:
    Tuple (fake_image_A, fake_image_B): G_B2A(real B) and G_A2B(real A),
    still attached to the autograd graph.
  """
  optimizer_G.zero_grad()

  # Identity terms: feeding a generator an image of its own target domain
  # should leave that image unchanged.
  same_A = netG_B2A(real_image_A)
  idt_A = identity_loss(same_A, real_image_A) * 5.0
  same_B = netG_A2B(real_image_B)
  idt_B = identity_loss(same_B, real_image_B) * 5.0

  # Adversarial terms: each generator tries to make the discriminator of its
  # target domain output the real label for the translated images.
  fake_image_A = netG_B2A(real_image_B)
  gan_B2A = adversarial_loss(netD_A(fake_image_A), real_label)
  fake_image_B = netG_A2B(real_image_A)
  gan_A2B = adversarial_loss(netD_B(fake_image_B), real_label)

  # Cycle terms: translating there and back should reproduce the original.
  cyc_ABA = cycle_loss(netG_B2A(fake_image_B), real_image_A) * 10.0
  cyc_BAB = cycle_loss(netG_A2B(fake_image_A), real_image_B) * 10.0

  # Total objective, backward pass and weight update for both generators.
  errG = idt_A + idt_B + gan_A2B + gan_B2A + cyc_ABA + cyc_BAB
  errG.backward()
  optimizer_G.step()

  return fake_image_A, fake_image_B

In [None]:
##############################################
# (2) Update D network: Discriminator A
##############################################
def UpdateD_A(fake_image_A):
  """Run one optimisation step on discriminator A.

  Reads `real_image_A`, `real_label` and `fake_label` from the notebook's
  global scope, so those must be set for the current batch before calling.

  Args:
    fake_image_A: Batch of generated A images; it is passed through
      `fake_A_buffer`, which may swap in previously generated images.
  """
  optimizer_D_A.zero_grad()

  # Real images should be scored as real.
  loss_real = adversarial_loss(netD_A(real_image_A), real_label)

  # Generated images (possibly replayed from the buffer) should be scored as
  # fake; detach() keeps this loss from back-propagating into the generator.
  replayed = fake_A_buffer.push_and_pop(fake_image_A)
  loss_fake = adversarial_loss(netD_A(replayed.detach()), fake_label)

  # Average the two terms, back-propagate and update D_A.
  errD_A = (loss_real + loss_fake) / 2
  errD_A.backward()
  optimizer_D_A.step()

In [None]:
##############################################
# (3) Update D network: Discriminator B
##############################################
def UpdateD_B(fake_image_B):
  """Run one optimisation step on discriminator B.

  Reads `real_image_B`, `real_label` and `fake_label` from the notebook's
  global scope, so those must be set for the current batch before calling.

  Args:
    fake_image_B: Batch of generated B images; it is passed through
      `fake_B_buffer`, which may swap in previously generated images.
  """
  optimizer_D_B.zero_grad()

  # Real images should be scored as real.
  loss_real = adversarial_loss(netD_B(real_image_B), real_label)

  # Generated images (possibly replayed from the buffer) should be scored as
  # fake; detach() keeps this loss from back-propagating into the generator.
  replayed = fake_B_buffer.push_and_pop(fake_image_B)
  loss_fake = adversarial_loss(netD_B(replayed.detach()), fake_label)

  # Average the two terms, back-propagate and update D_B.
  errD_B = (loss_real + loss_fake) / 2
  errD_B.backward()
  optimizer_D_B.step()

In [None]:
# Sample images are written every `print_freq` iterations (1 = every iteration).
print_freq = 1 

In [None]:
# Main training loop: for every batch update the generators, then each
# discriminator; periodically dump sample images; checkpoint and step the
# learning-rate schedulers once per epoch.
for epoch in range(0, epochs):
  for i, data in enumerate(dataloader):
    # Move the current batch to the compute device.
    real_image_A = data["A"].to(device)
    real_image_B = data["B"].to(device)
    # Size the labels from the batch actually received: the last batch of an
    # epoch can be smaller than the configured batch size, and labels of the
    # wrong length make the MSE loss fail with a shape mismatch.
    current_batch_size = real_image_A.size(0)

    # real data label is 1, fake data label is 0.
    real_label = torch.full((current_batch_size, 1), 1, device=device, dtype=torch.float32)
    fake_label = torch.full((current_batch_size, 1), 0, device=device, dtype=torch.float32)

    # (1) Update G network: Generators A2B and B2A
    fake_image_A, fake_image_B = UpdateG(real_image_A, real_image_B, real_label, fake_label)

    # (2) Update D network: Discriminator A
    UpdateD_A(fake_image_A)

    # (3) Update D network: Discriminator B
    UpdateD_B(fake_image_B)

    if i % print_freq == 0:
      vutils.save_image(real_image_A, f"{outf}/{dataset}/A/real_samples.png",normalize=True)
      vutils.save_image(real_image_B, f"{outf}/{dataset}/B/real_samples.png",normalize=True)
      # Sample translations are for inspection only, so skip autograd bookkeeping.
      with torch.no_grad():
        # Map the Tanh output from [-1, 1] to [0, 1].
        fake_image_A = 0.5 * (netG_B2A(real_image_B) + 1.0)
        fake_image_B = 0.5 * (netG_A2B(real_image_A) + 1.0)
      vutils.save_image(fake_image_A, f"{outf}/{dataset}/A/fake_samples_epoch_{epoch}.png",normalize=True)
      vutils.save_image(fake_image_B,f"{outf}/{dataset}/B/fake_samples_epoch_{epoch}.png",normalize=True)

  # do check pointing every epoch
  torch.save(netG_A2B.state_dict(), f"weights/{dataset}/netG_A2B_epoch_{epoch}.pth")
  torch.save(netG_B2A.state_dict(), f"weights/{dataset}/netG_B2A_epoch_{epoch}.pth")
  torch.save(netD_A.state_dict(), f"weights/{dataset}/netD_A_epoch_{epoch}.pth")
  torch.save(netD_B.state_dict(), f"weights/{dataset}/netD_B_epoch_{epoch}.pth")

  # Update learning rates
  lr_scheduler_G.step()
  lr_scheduler_D_A.step()
  lr_scheduler_D_B.step()

# Save the final weights of all four networks under fixed (epoch-independent) file names.
torch.save(netG_A2B.state_dict(), f"weights/{dataset}/netG_A2B.pth")
torch.save(netG_B2A.state_dict(), f"weights/{dataset}/netG_B2A.pth")
torch.save(netD_A.state_dict(), f"weights/{dataset}/netD_A.pth")
torch.save(netD_B.state_dict(), f"weights/{dataset}/netD_B.pth")

  return F.mse_loss(input, target, reduction=self.reduction)
