In [39]:
###CONFIG####
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Compute device: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Dataset roots (main() reads TRAIN_DIR + "/X" and TRAIN_DIR + "/Y").
TRAIN_DIR = "data/train"
VAL_DIR = "data/val"

# Optimisation settings.
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
NUM_EPOCHS = 10
NUM_WORKERS = 2  # DataLoader worker processes

# Weights of the generator loss terms.
LAMBDA_CYCLE = 10
# The identity terms are still computed in train_fn when this is 0;
# drop them there to save compute if you disable identity loss.
LAMBDA_IDENTITY = 0.2

# Checkpointing switches and file names.
LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_GEN_X = "genX.pth.tar"
CHECKPOINT_GEN_Y = "genY.pth.tar"
CHECKPOINT_CRITIC_X = "criticX.pth.tar"
CHECKPOINT_CRITIC_Y = "criticY.pth.tar"

# Augmentation + normalisation pipeline applied to every (X, Y) sample.
# `additional_targets={"image0": "image"}` makes Compose treat the `image0`
# keyword like `image`, so XYDataset can pass its X and Y images in one call
# (transform(image=..., image0=...)). Both images then receive the same
# randomly sampled flip / jitter parameters.
# NOTE(review): recent albumentations versions check that all image targets
# share a shape before transforming — confirm this holds for unpaired images
# of different sizes.
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p = 0.1),
        # (pixel - 0.5*255) / (0.5*255): maps uint8 [0, 255] to [-1, 1], the same
        # range as the generator's tanh output (train_fn undoes it with *0.5 + 0.5).
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),  # HWC numpy array -> CHW torch tensor
    ],
    additional_targets={"image0": "image"},
)

In [40]:
###UTILS###
import random, os, numpy as np
import torch.nn as nn
import copy

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialise a model together with its optimizer to ``filename``.

    The file is a ``torch.save``-d dict with two entries:
    ``"state_dict"`` (``model.state_dict()``) and ``"optimizer"``
    (``optimizer.state_dict()``) — the layout ``load_checkpoint`` expects.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` in place from a ``save_checkpoint`` file.

    Tensors are mapped onto ``DEVICE`` while loading. After the optimizer
    state is restored, the learning rate of every param group is overwritten
    with ``lr``. The saved optimizer state carries the learning rate that was
    active when the checkpoint was written; without this override that stale
    value would silently be used.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # Param groups are plain dicts: force the caller's learning rate onto each one.
    for group in optimizer.param_groups:
        group.update(lr=lr)


def seed_everything(seed=42):
    """Seed every random number generator the training code can touch.

    Weight initialisation and DataLoader shuffling draw from torch's RNG, and
    Python's ``random`` / NumPy are seeded as well. With fixed seeds, two runs
    of the notebook start from the same state. cuDNN is also pinned to
    deterministic kernels and its autotuner is disabled, trading some speed
    for run-to-run repeatability.

    Note: ``PYTHONHASHSEED`` only affects interpreters started after it is set
    (e.g. spawned worker processes), not the already-running kernel.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # torch CPU generator, then the current CUDA device, then every CUDA device.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [41]:
###DATASET###
from PIL import Image
from torch.utils.data import Dataset

class XYDataset(Dataset):
    """Unpaired two-domain image dataset for CycleGAN-style training.

    Each item is an ``(X_img, Y_img)`` pair: one image from ``root_X`` and one
    from ``root_Y``. The folders may hold different numbers of images. The
    dataset length is the larger count, and the shorter list is cycled with a
    modulo index, so every image of the larger domain is visited once per epoch.

    Only regular, non-hidden files directly inside each root are used, sorted
    by name. Jupyter's ``.ipynb_checkpoints`` folder and other dot-entries are
    therefore ignored, and the X/Y pairing does not depend on filesystem order.

    Args:
        root_X: directory holding domain-X images.
        root_Y: directory holding domain-Y images.
        transform: optional callable invoked as
            ``transform(image=X_array, image0=Y_array)``. It must return a dict
            with ``"image"`` and ``"image0"`` keys, e.g. an albumentations
            ``Compose`` with ``additional_targets={"image0": "image"}``.

    Raises:
        ValueError: if either directory contains no usable files.
    """

    def __init__(self, root_X, root_Y, transform=None):
        self.root_X = root_X
        self.root_Y = root_Y
        self.transform = transform

        self.X_images = self._list_images(root_X)
        self.Y_images = self._list_images(root_Y)
        self.X_len = len(self.X_images)
        self.Y_len = len(self.Y_images)
        if self.X_len == 0 or self.Y_len == 0:
            # __getitem__ indexes with `% X_len` / `% Y_len`; an empty domain
            # would otherwise only surface later as a ZeroDivisionError.
            raise ValueError(
                f"XYDataset needs at least one image per domain "
                f"(found {self.X_len} in {root_X!r}, {self.Y_len} in {root_Y!r})"
            )
        self.length_dataset = max(self.X_len, self.Y_len)

    @staticmethod
    def _list_images(root):
        """Return the sorted names of regular, non-hidden files directly inside ``root``."""
        return sorted(
            name
            for name in os.listdir(root)
            if not name.startswith(".") and os.path.isfile(os.path.join(root, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB numpy array, closing the underlying file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        X_path = os.path.join(self.root_X, self.X_images[index % self.X_len])
        Y_path = os.path.join(self.root_Y, self.Y_images[index % self.Y_len])

        X_img = self._load_rgb(X_path)
        Y_img = self._load_rgb(Y_path)

        if self.transform:
            augmentations = self.transform(image=X_img, image0=Y_img)
            X_img = augmentations["image"]
            Y_img = augmentations["image0"]

        return X_img, Y_img

In [42]:
class BlockDisc(nn.Module):
    """Discriminator stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: convolution stride. 2 halves the spatial size; 1 shrinks each
            side by one pixel (4x4 kernel with padding 1).
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        layers = [
            # The bias is kept for checkpoint compatibility, even though the
            # following (non-affine) InstanceNorm subtracts the per-channel
            # mean and so cancels it.
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
            # InstanceNorm normalises each image on its own, so the statistics
            # never depend on the other samples in the batch (BatchNorm would
            # be unstable with the batch size of 1 used for training).
            nn.InstanceNorm2d(out_channels),
            # inplace=True overwrites the normalised activations instead of
            # allocating a new tensor, saving memory.
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out

In [43]:
class Discriminator(nn.Module):
    """PatchGAN discriminator: maps an image to a grid of per-patch realness scores in (0, 1).

    With the default ``features`` a 3x256x256 input yields a 1x30x30 map:
    256 -> 128 (initial, stride 2) -> 64 -> 32 (stride-2 blocks)
    -> 31 (last block, stride 1) -> 30 (final 4x4 conv).

    Args:
        in_channels: channels of the input image.
        features: channel widths of the successive stages. The first stage is
            a plain conv + LeakyReLU (no normalisation). Every later stage is a
            ``BlockDisc`` with stride 2, except the last one, which uses
            stride 1. A final 4x4 conv then reduces to a single channel.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Copy into a fresh list: accepts any sequence, and the tuple default
        # avoids a shared mutable default argument.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        prev_channels = features[0]
        last_idx = len(features) - 1
        for idx, feature in enumerate(features[1:], start=1):
            # Choose the stride by position, not by value, so repeated widths
            # such as (64, 256, 256) cannot give a middle stage stride 1.
            layers.append(BlockDisc(prev_channels, feature, stride=1 if idx == last_idx else 2))
            prev_channels = feature
        layers.append(
            nn.Conv2d(prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"),
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-patch scores in (0, 1) for the image batch ``x``.

        CycleGAN's discriminator judges a single image. Unlike a conditional
        (pix2pix) discriminator, it therefore concatenates no second input.
        """
        return torch.sigmoid(self.model(self.initial(x)))

In [44]:
def test():
    """Smoke-test the discriminator: a 5x3x256x256 batch must yield 5x1x30x30 patch scores."""
    x = torch.randn((5, 3, 256, 256))
    model = Discriminator(in_channels=3)
    with torch.no_grad():  # shape check only; no need to build an autograd graph
        preds = model(x)
    print(preds.shape)
    assert preds.shape == (5, 1, 30, 30), f"unexpected discriminator output shape {tuple(preds.shape)}"

In [45]:
# Run the discriminator shape check defined above (expected: torch.Size([5, 1, 30, 30])).
test()

torch.Size([5, 1, 30, 30])


In [46]:
class ConvBlock(nn.Module):
    """Generator building block: convolution -> InstanceNorm -> ReLU (or identity).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        down: if True use a reflect-padded ``nn.Conv2d``. Otherwise use an
            ``nn.ConvTranspose2d``, which the decoder uses for upsampling.
        act: if True, finish with an in-place ReLU; otherwise the normalised
            output is passed through unchanged.
        **kwargs: forwarded to the conv layer (kernel_size, stride, padding,
            output_padding, ...). Do not pass ``bias``: it is fixed to False
            because the following InstanceNorm removes any per-channel offset.
    """

    def __init__(self, in_channels, out_channels, down=True, act=True, **kwargs):
        super().__init__()
        if down:
            conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=False, padding_mode="reflect")
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs, bias=False)
        norm = nn.InstanceNorm2d(out_channels)
        activation = nn.ReLU(inplace=True) if act else nn.Identity()
        self.conv = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        return self.conv(x)

In [47]:
class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks (ReLU after the first only) wrapped in an identity skip connection.

    Shape-preserving: the output has the same channels, height and width as the input.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel_size=3, padding=1)
        second = ConvBlock(channels, channels, act=False, kernel_size=3, padding=1)
        self.resid = nn.Sequential(first, second)

    def forward(self, x):
        # Skip connection: the convs only learn a correction that is added to x.
        # This keeps gradients flowing through a deep stack of these blocks and
        # makes the identity mapping easy to represent.
        correction = self.resid(x)
        return x + correction

In [48]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator: image in, image out, values in (-1, 1).

    Layout:
      * 7x7 stem conv + ReLU;
      * two stride-2 downsampling ConvBlocks (nf -> 2nf -> 4nf channels);
      * ``num_residuals`` ResidualBlocks at 4nf channels;
      * two stride-2 transposed-conv ConvBlocks back to nf channels;
      * 7x7 conv to ``img_channels``, then tanh.

    The output has the same height and width as the input when both are
    divisible by 4.

    Args:
        img_channels: channels of the input and output image.
        num_features: channel width (nf) after the stem conv.
        num_residuals: number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        nf = num_features
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, nf, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )

        # Each encoder stage halves H and W and doubles the channel count.
        down_plan = [(nf, 2 * nf), (2 * nf, 4 * nf)]
        self.encoder = nn.ModuleList(
            ConvBlock(c_in, c_out, down=True, kernel_size=3, stride=2, padding=1)
            for c_in, c_out in down_plan
        )

        self.residuals = nn.Sequential(*(ResidualBlock(4 * nf) for _ in range(num_residuals)))

        # Mirror of the encoder; output_padding=1 makes each stage exactly double H and W.
        up_plan = [(4 * nf, 2 * nf), (2 * nf, nf)]
        self.decoder = nn.ModuleList(
            ConvBlock(c_in, c_out, down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
            for c_in, c_out in up_plan
        )

        self.last = nn.Conv2d(nf, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        stages = [self.initial, *self.encoder, self.residuals, *self.decoder]
        for stage in stages:
            x = stage(x)
        return torch.tanh(self.last(x))

In [49]:
def test():
    """Smoke-test the generator: the output must have the same shape as the input batch."""
    x = torch.randn((2, 3, 256, 256))
    # Use keywords: Generator(3, 9) would set num_features=9 (not num_residuals=9)
    # and silently test a much narrower network than the one main() trains.
    gen = Generator(img_channels=3, num_residuals=9)
    with torch.no_grad():  # shape check only; no need to build an autograd graph
        out = gen(x)
    print(out.shape)
    assert out.shape == x.shape, f"unexpected generator output shape {tuple(out.shape)}"

In [50]:
# Run the generator shape check defined above (expected: torch.Size([2, 3, 256, 256])).
test()

torch.Size([2, 3, 256, 256])


In [53]:

import sys
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image

def train_fn(disc_X, disc_Y, gen_Y, gen_X, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler):
    """Run one CycleGAN training epoch over ``loader``.

    Naming: ``gen_X`` maps domain Y -> X and ``gen_Y`` maps X -> Y.
    ``disc_X`` / ``disc_Y`` score how real an image of domain X / Y looks.
    Each batch from ``loader`` is unpacked as ``(X, Y)``, the order
    ``XYDataset.__getitem__`` returns.

    Per batch:
      1. Discriminator step: least-squares (MSE) GAN loss that pushes real
         images toward 1 and detached fakes toward 0. The loss is averaged
         over both discriminators.
      2. Generator step: adversarial loss (fakes toward 1)
         + LAMBDA_CYCLE * L1 cycle loss + LAMBDA_IDENTITY * L1 identity loss.

    Both steps run under ``torch.cuda.amp.autocast`` with their own GradScaler.
    Every 200 batches the current fakes are written to ``saved_images/``.
    """
    os.makedirs("saved_images", exist_ok=True)  # save_image does not create directories
    loop = tqdm(loader, leave=True)

    for idx, (X, Y) in enumerate(loop):
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)

        # ---- Discriminator step ------------------------------------------
        with torch.cuda.amp.autocast():
            fake_X = gen_X(Y)
            D_X_real = disc_X(X)
            # detach(): the discriminator loss must not backpropagate into gen_X.
            D_X_fake = disc_X(fake_X.detach())
            D_X_real_loss = mse(D_X_real, torch.ones_like(D_X_real))
            D_X_fake_loss = mse(D_X_fake, torch.zeros_like(D_X_fake))
            D_X_loss = D_X_real_loss + D_X_fake_loss

            fake_Y = gen_Y(X)
            D_Y_real = disc_Y(Y)
            D_Y_fake = disc_Y(fake_Y.detach())
            D_Y_real_loss = mse(D_Y_real, torch.ones_like(D_Y_real))
            D_Y_fake_loss = mse(D_Y_fake, torch.zeros_like(D_Y_fake))
            D_Y_loss = D_Y_real_loss + D_Y_fake_loss

            D_loss = (D_X_loss + D_Y_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator step ----------------------------------------------
        # Gradients from this backward also reach the discriminators' parameters.
        # opt_disc.zero_grad() clears them at the next iteration before they are used.
        with torch.cuda.amp.autocast():
            # adversarial loss: each generator wants its fakes scored as real (1)
            D_X_fake = disc_X(fake_X)
            D_Y_fake = disc_Y(fake_Y)
            loss_G_X = mse(D_X_fake, torch.ones_like(D_X_fake))
            loss_G_Y = mse(D_Y_fake, torch.ones_like(D_Y_fake))

            # cycle loss: Y -> X -> Y and X -> Y -> X should reconstruct the input
            cycle_Y = gen_Y(fake_X)
            cycle_X = gen_X(fake_Y)
            cycle_Y_loss = L1(Y, cycle_Y)
            cycle_X_loss = L1(X, cycle_X)

            # identity loss: a generator fed an image already in its target
            # domain should return it unchanged
            identity_Y = gen_Y(Y)
            identity_X = gen_X(X)
            identity_Y_loss = L1(Y, identity_Y)
            identity_X_loss = L1(X, identity_X)

            G_loss = (
                loss_G_X
                + loss_G_Y
                + cycle_Y_loss * LAMBDA_CYCLE
                + cycle_X_loss * LAMBDA_CYCLE
                + identity_Y_loss * LAMBDA_IDENTITY
                + identity_X_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            # Generator output is in [-1, 1] (tanh); map it to [0, 1] for saving.
            save_image(fake_X.detach() * 0.5 + 0.5, f"saved_images/X_{idx}.png")
            save_image(fake_Y.detach() * 0.5 + 0.5, f"saved_images/Y_{idx}.png")
def main():
    """Build the CycleGAN models and optimizers, then train for NUM_EPOCHS epochs.

    If LOAD_MODEL is set, checkpoints are restored before training. If
    SAVE_MODEL is set, all four models are written after every epoch. Each
    model is saved to, and loaded from, the checkpoint path carrying its own
    name (gen_X <-> CHECKPOINT_GEN_X, gen_Y <-> CHECKPOINT_GEN_Y, and likewise
    for the critics).
    """
    disc_X = Discriminator(in_channels=3).to(DEVICE)  # judges domain X
    disc_Y = Discriminator(in_channels=3).to(DEVICE)  # judges domain Y
    gen_Y = Generator(img_channels=3, num_residuals=9).to(DEVICE)  # X -> Y
    gen_X = Generator(img_channels=3, num_residuals=9).to(DEVICE)  # Y -> X

    # One optimizer per side: both discriminators share opt_disc, both generators share opt_gen.
    opt_disc = optim.Adam(
        list(disc_X.parameters()) + list(disc_Y.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Y.parameters()) + list(gen_X.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if LOAD_MODEL:
        # NOTE: earlier versions of this cell saved gen_X to CHECKPOINT_GEN_Y and
        # gen_Y to CHECKPOINT_GEN_X. Swap those two files before loading such checkpoints.
        load_checkpoint(CHECKPOINT_GEN_X, gen_X, opt_gen, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_GEN_Y, gen_Y, opt_gen, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_CRITIC_X, disc_X, opt_disc, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_CRITIC_Y, disc_Y, opt_disc, LEARNING_RATE)

    use_cuda = DEVICE == "cuda"
    dataset = XYDataset(root_X=TRAIN_DIR + "/X", root_Y=TRAIN_DIR + "/Y", transform=transforms)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=use_cuda,  # pinned host memory only helps host -> GPU copies
    )
    # TODO: validation is not wired up yet; build a loader over VAL_DIR when it is.

    # Loss scaling is only meaningful for CUDA autocast; disable it explicitly on CPU.
    g_scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
    d_scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    for epoch in range(NUM_EPOCHS):
        train_fn(disc_X, disc_Y, gen_Y, gen_X, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)

        if SAVE_MODEL:
            save_checkpoint(gen_X, opt_gen, filename=CHECKPOINT_GEN_X)
            save_checkpoint(gen_Y, opt_gen, filename=CHECKPOINT_GEN_Y)
            save_checkpoint(disc_X, opt_disc, filename=CHECKPOINT_CRITIC_X)
            save_checkpoint(disc_Y, opt_disc, filename=CHECKPOINT_CRITIC_Y)


# In a notebook kernel __name__ is "__main__", so running this cell starts training.
if __name__ == "__main__":
    main()

100%|██████████| 1334/1334 [05:30<00:00,  4.04it/s]
100%|██████████| 1334/1334 [05:29<00:00,  4.05it/s]
100%|██████████| 1334/1334 [05:29<00:00,  4.05it/s]
100%|██████████| 1334/1334 [05:29<00:00,  4.05it/s]
100%|██████████| 1334/1334 [05:28<00:00,  4.06it/s]
100%|██████████| 1334/1334 [05:28<00:00,  4.06it/s]
100%|██████████| 1334/1334 [05:29<00:00,  4.05it/s]
100%|██████████| 1334/1334 [05:28<00:00,  4.06it/s]
100%|██████████| 1334/1334 [05:28<00:00,  4.06it/s]
100%|██████████| 1334/1334 [05:29<00:00,  4.05it/s]


In [52]:
# Jupyter can leave a hidden .ipynb_checkpoints folder inside the data directories.
# Remove it so the dataset does not list it as a training image.
# `rmdir` fails if the folder is missing or not empty.
# NOTE(review): this cell ran (In [52]) before the training cell (In [53]).
# Under Restart & Run All it would run after training, so move it above the training cell.
!rmdir data/train/Y/.ipynb_checkpoints