In [59]:
import torch
import sys
import os
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

In [60]:
class DiscBlock(nn.Module):
  """Conv2d -> InstanceNorm2d -> LeakyReLU stage used by the discriminator.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
    stride: stride of the 4x4 convolution.
  """

  def __init__(self, in_channels, out_channels, stride):
    super().__init__()
    stages = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode='reflect',
        ),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(negative_slope=0.2),
    ]
    # The attribute name is kept so that existing state_dict keys stay valid.
    self.conv = nn.Sequential(*stages)

  def forward(self, x):
    """Apply the convolution, normalization and activation to ``x``."""
    activated = self.conv(x)
    return activated

In [61]:
class Discrimenator(nn.Module):
  """Convolutional discriminator that emits a single-channel score map.

  The network is an initial stride-2 convolution followed by one
  ``DiscBlock`` per remaining entry of ``features`` and a final 1-channel
  convolution. The output is squashed to (0, 1) with a sigmoid.

  Args:
    in_channels: number of channels of the input images.
    features: sequence of channel widths. The first entry is the width of
      the initial convolution; each following entry adds one ``DiscBlock``.
      The block built from the LAST entry uses stride 1, all others stride 2.
  """

  def __init__(self, in_channels, features=(64, 128, 256, 512)):
    super(Discrimenator, self).__init__()
    # Copy into a list: avoids sharing a mutable default and accepts any
    # sequence (list or tuple) from the caller.
    features = list(features)

    self.initial = nn.Sequential(
        nn.Conv2d(in_channels, features[0], 4, 2, 1, padding_mode='reflect'),
        nn.LeakyReLU(0.2)
    )

    layers = []
    in_channels = features[0]
    last_position = len(features) - 1
    for position, feature in enumerate(features[1:], start=1):
      # Pick the stride by POSITION, not by comparing widths: comparing
      # values would give stride 1 to every block whose width happens to
      # equal the final width (e.g. features=[64, 128, 128]).
      stride = 1 if position == last_position else 2
      layers.append(DiscBlock(in_channels, feature, stride))
      in_channels = feature

    # Collapse to one score channel per spatial location.
    layers.append(nn.Conv2d(in_channels, 1, 4, 1, 1, padding_mode='reflect'))

    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Return the sigmoid-activated score map for the image batch ``x``."""
    x = self.initial(x)
    return torch.sigmoid(self.model(x))

In [62]:
def test_disc():
  """Smoke test: print the discriminator's output shape for a random batch."""
  batch = torch.rand(5, 3, 256, 256)
  critic = Discrimenator(in_channels=3)
  scores = critic(batch)
  print(scores.shape)

In [63]:
# Run the discriminator smoke test; it prints the shape of the predictions.
test_disc()

torch.Size([5, 1, 30, 30])


In [95]:
class GenBlock(nn.Module):
  """Convolution (or transposed convolution) + InstanceNorm + optional ReLU.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the block.
    down: when truthy use ``nn.Conv2d`` with reflect padding, otherwise
      use ``nn.ConvTranspose2d``.
    act: when truthy finish with an in-place ReLU, otherwise with identity.
    **kwargs: forwarded unchanged to the chosen convolution layer
      (kernel_size, stride, padding, ...).
  """

  def __init__(self, in_channels, out_channels, down=True, act=True, **kwargs):
    super().__init__()

    if down:
      conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **kwargs)
    else:
      conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

    norm_layer = nn.InstanceNorm2d(out_channels)

    if act:
      activation = nn.ReLU(inplace=True)
    else:
      activation = nn.Identity()

    # The attribute name is kept so that existing state_dict keys stay valid.
    self.conv = nn.Sequential(conv_layer, norm_layer, activation)

  def forward(self, x):
    """Apply the conv / norm / activation stack to ``x``."""
    result = self.conv(x)
    return result

In [96]:
class ResBlock(nn.Module):
  """Residual unit: two 3x3 ``GenBlock`` layers plus a skip connection.

  The second ``GenBlock`` is built with ``act=False``, so no ReLU is
  applied before the input is added back.

  Args:
    channels: channel count, identical for input and output.
  """

  def __init__(self, channels):
    super().__init__()
    conv_kwargs = dict(kernel_size=3, padding=1)
    first = GenBlock(channels, channels, **conv_kwargs)
    second = GenBlock(channels, channels, act=False, **conv_kwargs)
    self.block = nn.Sequential(first, second)

  def forward(self, x):
    """Return ``block(x) + x``."""
    residual = self.block(x)
    return residual + x

In [97]:
class Generator(nn.Module):
  """Image-to-image generator: stem, two down blocks, residual stack, two up blocks, head.

  Args:
    img_channels: channel count of both the input and the generated image.
    num_residuals: how many ``ResBlock`` units sit in the bottleneck.
    features: base channel width; the down blocks widen it to 2x and 4x
      and the up blocks narrow it back.
  """

  def __init__(self, img_channels, num_residuals=9, features=64):
    super().__init__()
    base, mid, deep = features, features * 2, features * 4

    # 7x7 stem that keeps the spatial size (stride 1, padding 3).
    self.initial = nn.Sequential(
        nn.Conv2d(img_channels, base, kernel_size=7, stride=1, padding=3, padding_mode='reflect'),
        nn.ReLU(inplace=True),
    )

    # Two stride-2 convolutions.
    down_kwargs = dict(kernel_size=4, stride=2, padding=1)
    self.down_blocks = nn.ModuleList([
        GenBlock(base, mid, **down_kwargs),
        GenBlock(mid, deep, **down_kwargs),
    ])

    res_stack = [ResBlock(deep) for _ in range(num_residuals)]
    self.residual_blocks = nn.Sequential(*res_stack)

    # Two stride-2 transposed convolutions mirroring the down path.
    up_kwargs = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.up_blocks = nn.ModuleList([
        GenBlock(deep, mid, **up_kwargs),
        GenBlock(mid, base, **up_kwargs),
    ])

    # 7x7 head mapping back to image channels.
    self.final = nn.Conv2d(base, img_channels, kernel_size=7, stride=1, padding=3, padding_mode='reflect')

  def forward(self, x):
    """Translate the image batch ``x``; the output is passed through tanh."""
    hidden = self.initial(x)

    for down in self.down_blocks:
      hidden = down(hidden)

    hidden = self.residual_blocks(hidden)

    for up in self.up_blocks:
      hidden = up(hidden)

    logits = self.final(hidden)
    return torch.tanh(logits)

In [67]:
def test_gen():
  """Smoke test: print the generator's output shape for a random batch."""
  batch = torch.rand(5, 3, 256, 256)
  translator = Generator(img_channels=3)
  generated = translator(batch)
  print(generated.shape)

In [68]:
# Run the generator smoke test; it prints the shape of the generated batch.
test_gen()

torch.Size([5, 3, 256, 256])


In [69]:
class HorseZebraDataset(Dataset):
  """Unpaired horse / zebra image dataset read from two folders.

  The two folders may hold different numbers of files. The dataset length
  is the size of the larger folder and the smaller folder is cycled with
  modulo indexing, so every image is visited and no index runs past the
  end of the shorter listing.

  Args:
    horse_root: directory containing the horse images.
    zebra_root: directory containing the zebra images.
    transforms: optional callable invoked as
      ``transforms(image=zebra, image0=horse)`` that returns a mapping
      with the keys ``'image'`` and ``'image0'``.
  """

  def __init__(self, horse_root, zebra_root, transforms=None):
    self.horse_root = horse_root
    self.zebra_root = zebra_root
    self.transforms = transforms

    # os.listdir returns entries in arbitrary order; sort so that a given
    # index maps to the same files on every run and every machine.
    self.horse_images = sorted(os.listdir(self.horse_root))
    self.zebra_images = sorted(os.listdir(self.zebra_root))

    self.horse_len = len(self.horse_images)
    self.zebra_len = len(self.zebra_images)
    # With an empty folder no (zebra, horse) pair can be formed at all.
    if self.horse_len and self.zebra_len:
      self.length = max(self.horse_len, self.zebra_len)
    else:
      self.length = 0

  def __len__(self):
    """Number of samples: the size of the larger of the two folders."""
    return self.length

  def __getitem__(self, index):
    """Return ``(zebra_img, horse_img)`` for ``index``.

    Raises:
      IndexError: if ``index`` is outside ``[-len(self), len(self))``.
    """
    if index < 0:
      index += self.length
    # Explicit bounds check: modulo indexing would otherwise never raise,
    # and plain ``for sample in dataset`` relies on IndexError to stop.
    if not 0 <= index < self.length:
      raise IndexError('HorseZebraDataset index out of range')

    horse_name = self.horse_images[index % self.horse_len]
    zebra_name = self.zebra_images[index % self.zebra_len]

    horse_path = os.path.join(self.horse_root, horse_name)
    zebra_path = os.path.join(self.zebra_root, zebra_name)

    horse_img = np.array(Image.open(horse_path).convert('RGB'))
    zebra_img = np.array(Image.open(zebra_path).convert('RGB'))

    if self.transforms:
      augmentations = self.transforms(image=zebra_img, image0=horse_img)
      zebra_img = augmentations['image']
      horse_img = augmentations['image0']

    return zebra_img, horse_img

In [108]:
# --- Run configuration -------------------------------------------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Root folders of the training and validation splits.
TRAIN_DIR = 'data/CycleGAN/horse2zebra/train'
VAL_DIR = 'data/CycleGAN/horse2zebra/val'
# Learning rate given to the Adam optimizers and re-applied after a checkpoint load.
LEARNING_RATE = 2e-4
# Batch size used by the train and validation DataLoaders.
BATCH_SIZE = 1
# Number of image channels; sizes the per-channel mean/std lists for A.Normalize.
IMG_CHANNELS = 3
# NOTE(review): presumably the square image side length in pixels — confirm where it is consumed.
IMG_SIZE = 256
# NOTE(review): presumably the DataLoader worker count — confirm it is actually wired in.
NUM_WORKERS = 4
# Weight intended for the identity-loss terms of the generator objective.
LAMBDA_IDENTITY = 0
# Weight multiplied onto each cycle-consistency L1 loss in the generator objective.
LAMBDA_CYCLE = 10
# Number of passes over the training loader.
EPOCHS = 10
# Whether to write checkpoints after every epoch / restore them before training.
SAVE_MODEL = True
LOAD_MODEL = False
# Checkpoint file locations, one per network.
CHECKPOINT_GEN_Z = 'checkpoints/CycleGAN/GEN_Z/gen_z.pth.tar'
CHECKPOINT_GEN_H = 'checkpoints/CycleGAN/GEN_H/gen_h.pth.tar'
CHECKPOINT_DISC_Z = 'checkpoints/CycleGAN/DISC_Z/disc_z.pth.tar'
CHECKPOINT_DISC_H = 'checkpoints/CycleGAN/DISC_H/disc_h.pth.tar'


In [90]:
# Shared augmentation pipeline. `additional_targets` registers `image0` as a
# second image so both inputs go through the same pipeline in a single call.
transforms = A.Compose([
    # Resize to the configured size instead of a hard-coded 256.
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    # Per-channel mean 0.5 / std 0.5.
    A.Normalize([0.5] * IMG_CHANNELS, [0.5] * IMG_CHANNELS),
    ToTensorV2()
], additional_targets={'image0': 'image'})

In [125]:
def save_checkpoint(model, optimizer, filename='my_checkpoint.pth.tar'):
  """Write the model and optimizer state dicts to ``filename``.

  The checkpoint is a dict with the keys ``'state_dict'`` (model weights)
  and ``'optimizer'`` (optimizer state). Missing parent directories of
  ``filename`` are created first, so nested checkpoint paths work on a
  fresh checkout.

  Args:
    model: module whose ``state_dict()`` is saved.
    optimizer: optimizer whose ``state_dict()`` is saved.
    filename: destination path (or any file-like object accepted by
      ``torch.save``).
  """
  print(f'=> Saving Checkpoint At {filename}')
  checkpoint = {
      'state_dict': model.state_dict(),
      'optimizer': optimizer.state_dict()
  }

  # torch.save does not create directories; only path-like targets have a
  # parent directory to prepare (file-like objects are passed through).
  if isinstance(filename, (str, os.PathLike)):
    parent = os.path.dirname(os.fspath(filename))
    if parent:
      os.makedirs(parent, exist_ok=True)

  torch.save(checkpoint, filename)

def load_checkpoint(filename, model, optimizer, lr):
  """Restore model and optimizer state from ``filename`` and set the learning rate.

  Args:
    filename: checkpoint path holding the keys ``'state_dict'`` and ``'optimizer'``.
    model: module that receives the stored weights.
    optimizer: optimizer that receives the stored optimizer state.
    lr: learning rate written into every parameter group after loading.
  """
  print(f'=> Loading Checkpoint From {filename}')
  stored = torch.load(filename, map_location=DEVICE)
  model.load_state_dict(stored['state_dict'])
  optimizer.load_state_dict(stored['optimizer'])

  # Loading the optimizer state also restores its saved hyperparameters, so
  # the requested learning rate is written back into each group afterwards.
  for group in optimizer.param_groups:
    group.update(lr=lr)

In [122]:
# Datasets for the two splits; each expects `horses/` and `zebras/` sub-folders.
train_dataset = HorseZebraDataset(
    horse_root=f'{TRAIN_DIR}/horses',
    zebra_root=f'{TRAIN_DIR}/zebras',
    transforms=transforms,
)
val_dataset = HorseZebraDataset(
    horse_root=f'{VAL_DIR}/horses',
    zebra_root=f'{VAL_DIR}/zebras',
    transforms=transforms,
)

# Only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [127]:
def train_step(gen_Z, gen_H, disc_Z, disc_H, gen_opt, disc_opt, g_scalar, d_scalar, L1, MSE, loader):
  """Train both generators and both discriminators for one pass over ``loader``.

  For every batch the discriminators are updated first (real images should
  score 1, detached generated images should score 0), then the generators
  are updated with the adversarial loss plus the cycle-consistency loss
  (and the identity loss when ``LAMBDA_IDENTITY`` is positive).

  Args:
    gen_Z: generator producing zebras from horses.
    gen_H: generator producing horses from zebras.
    disc_Z: discriminator scoring zebra images.
    disc_H: discriminator scoring horse images.
    gen_opt: optimizer holding the parameters of both generators.
    disc_opt: optimizer holding the parameters of both discriminators.
    g_scalar: gradient scaler used for the generator update.
    d_scalar: gradient scaler used for the discriminator update.
    L1: L1 criterion used for the cycle / identity losses.
    MSE: squared-error criterion used for the adversarial losses.
    loader: iterable yielding ``(zebra, horse)`` batches.
  """
  for zebra, horse in tqdm(loader, leave=True):
    zebra = zebra.to(DEVICE)
    horse = horse.to(DEVICE)

    # ---- Discriminator update ----
    with torch.cuda.amp.autocast():
      fake_horse = gen_H(zebra)
      fake_zebra = gen_Z(horse)

      D_H_real = disc_H(horse)
      D_Z_real = disc_Z(zebra)
      # detach(): the discriminator step must not back-propagate into the generators.
      D_H_fake = disc_H(fake_horse.detach())
      D_Z_fake = disc_Z(fake_zebra.detach())

      # Real images are pushed towards 1, generated images towards 0.
      # The fake term must score the FAKE prediction against zeros;
      # scoring the real prediction against ones again would leave the
      # discriminator without any signal about generated images.
      D_H_real_loss = MSE(D_H_real, torch.ones_like(D_H_real))
      D_H_fake_loss = MSE(D_H_fake, torch.zeros_like(D_H_fake))
      D_H_loss = D_H_fake_loss + D_H_real_loss

      D_Z_real_loss = MSE(D_Z_real, torch.ones_like(D_Z_real))
      D_Z_fake_loss = MSE(D_Z_fake, torch.zeros_like(D_Z_fake))
      D_Z_loss = D_Z_real_loss + D_Z_fake_loss

      D_loss = (D_H_loss + D_Z_loss) / 2

    disc_opt.zero_grad()
    d_scalar.scale(D_loss).backward()
    d_scalar.step(disc_opt)
    d_scalar.update()

    # ---- Generator update ----
    with torch.cuda.amp.autocast():
      # Re-score the fakes WITHOUT detach so gradients reach the generators;
      # the generators want the discriminators to output 1.
      D_Z_fake = disc_Z(fake_zebra)
      D_H_fake = disc_H(fake_horse)
      G_H_loss = MSE(D_H_fake, torch.ones_like(D_H_fake))
      G_Z_loss = MSE(D_Z_fake, torch.ones_like(D_Z_fake))

      # Cycle consistency: translating there and back should recover the input.
      cycle_zebra = gen_Z(fake_horse)
      cycle_horse = gen_H(fake_zebra)
      cycle_H_loss = L1(horse, cycle_horse)
      cycle_Z_loss = L1(zebra, cycle_zebra)

      G_loss = (
          G_H_loss +
          G_Z_loss +
          (cycle_H_loss * LAMBDA_CYCLE) +
          (cycle_Z_loss * LAMBDA_CYCLE)
      )

      # Identity loss: a generator fed an image of its own target domain
      # should return it unchanged. Skipped (no extra forward passes) when
      # the weight is zero.
      if LAMBDA_IDENTITY > 0:
        identity_H_loss = L1(horse, gen_H(horse))
        identity_Z_loss = L1(zebra, gen_Z(zebra))
        G_loss = G_loss + (identity_H_loss + identity_Z_loss) * LAMBDA_IDENTITY

    gen_opt.zero_grad()
    g_scalar.scale(G_loss).backward()
    g_scalar.step(gen_opt)
    g_scalar.update()


In [None]:
# Networks: one discriminator and one generator per image domain.
disc_H = Discrimenator(in_channels=3).to(DEVICE)
disc_Z = Discrimenator(in_channels=3).to(DEVICE)
gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)
gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)

# One Adam optimizer per network family, sharing the same hyperparameters.
adam_kwargs = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))
disc_opt = optim.Adam([*disc_H.parameters(), *disc_Z.parameters()], **adam_kwargs)
gen_opt = optim.Adam([*gen_H.parameters(), *gen_Z.parameters()], **adam_kwargs)

l1 = nn.L1Loss()
mse = nn.MSELoss()

# (checkpoint path, network, optimizer) triples, in load / save order.
checkpoint_plan = (
    (CHECKPOINT_DISC_H, disc_H, disc_opt),
    (CHECKPOINT_DISC_Z, disc_Z, disc_opt),
    (CHECKPOINT_GEN_H, gen_H, gen_opt),
    (CHECKPOINT_GEN_Z, gen_Z, gen_opt),
)

if LOAD_MODEL:
  for ckpt_path, net, opt in checkpoint_plan:
    load_checkpoint(ckpt_path, net, opt, LEARNING_RATE)

# Separate gradient scalers for the generator and discriminator updates.
g_scalar = torch.cuda.amp.GradScaler()
d_scalar = torch.cuda.amp.GradScaler()

for epoch in range(EPOCHS):
  train_step(gen_Z, gen_H, disc_Z, disc_H, gen_opt, disc_opt, g_scalar, d_scalar, l1, mse, train_loader)

  if SAVE_MODEL:
    for ckpt_path, net, opt in checkpoint_plan:
      save_checkpoint(net, opt, ckpt_path)
