# Cycle GAN Implementation

## Dependencies

In [1]:
import os
import random
import numpy as np
from tqdm import tqdm
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image



## Utilities

In [2]:
def save_checkpoint(filename, model, optimizer):
    """Write a model's and an optimizer's state to ``filename``.

    The file holds a dict with two keys: ``"state_dict"`` (the model's
    ``state_dict()``) and ``"optimizer"`` (the optimizer's ``state_dict()``).

    Args:
        filename: Destination of the checkpoint. When it is a path, any
            missing parent directory is created first.
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a torch optimizer).
    """
    print("=> Saving checkpoint")

    # torch.save does not create missing directories, so a path such as
    # "checkpoints/x.pth.tar" would otherwise fail on a fresh working dir.
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(filename))
        if parent:
            os.makedirs(parent, exist_ok=True)

    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(filename, model, optimizer, lr, device):
    """Restore a model and an optimizer in place from a checkpoint file.

    The file is expected to hold a dict with the keys ``"state_dict"`` and
    ``"optimizer"``.

    Args:
        filename: Checkpoint to read.
        model: Receives the stored ``"state_dict"`` entry.
        optimizer: Receives the stored ``"optimizer"`` entry.
        lr: Learning rate written into every parameter group after loading.
        device: Passed to ``torch.load`` as ``map_location``.
    """
    print("=> Loading checkpoint")
    stored = torch.load(filename, map_location=device)

    model.load_state_dict(stored["state_dict"])
    optimizer.load_state_dict(stored["optimizer"])

    # Loading the optimizer state also restores its saved learning rate;
    # overwrite it with the caller's value.
    for group in optimizer.param_groups:
        group.update(lr=lr)


def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN settings.

    Args:
        seed: Value given to every seeding call and, as a string, stored in
            the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Turn off cuDNN benchmarking and request its deterministic mode.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Dataset

In [3]:
class CustomDataset(Dataset):
    """Dataset yielding one image from each of two folders per index.

    The two folders may hold different numbers of files; the shorter listing
    is cycled with a modulo so every index up to the longer length is valid.

    Args:
        root1: Directory holding the first set of images.
        root2: Directory holding the second set of images.
        trans: Optional callable invoked as ``trans(image=..., image0=...)``
            that returns a mapping with the keys ``"image"`` and
            ``"image0"``.
    """

    def __init__(self, root1, root2, trans=None):
        self.root1 = root1
        self.root2 = root2
        self.trans = trans

        # os.listdir returns names in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.images1 = sorted(os.listdir(root1))
        self.images2 = sorted(os.listdir(root2))

        self.len1 = len(self.images1)
        self.len2 = len(self.images2)

    def __len__(self):
        """Return the size of the larger of the two folders."""
        return max(self.len1, self.len2)

    def __getitem__(self, index):
        """Return the ``(img1, img2)`` pair for ``index``.

        Both images are loaded as RGB arrays; when ``trans`` is set, its
        outputs are returned instead of the raw arrays.
        """
        # Wrap around so the shorter folder is reused.
        img1 = self.images1[index % self.len1]
        img2 = self.images2[index % self.len2]

        path1 = os.path.join(self.root1, img1)
        path2 = os.path.join(self.root2, img2)

        img1 = np.array(Image.open(path1).convert('RGB'))
        img2 = np.array(Image.open(path2).convert('RGB'))

        if self.trans:
            augments = self.trans(image=img1, image0=img2)
            img1 = augments["image"]
            img2 = augments["image0"]

        return img1, img2

## Discriminator

In [4]:
class CNNBlockDisc(nn.Module):
    """Discriminator block: 4x4 conv, instance norm, LeakyReLU(0.2).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        convolution = nn.Conv2d(
            in_channels, out_channels, 4, stride, 1,
            bias=True, padding_mode="reflect",
        )
        normalise = nn.InstanceNorm2d(out_channels)
        activate = nn.LeakyReLU(0.2)

        self.conv = nn.Sequential(convolution, normalise, activate)

    def forward(self, x):
        """Run ``x`` through conv, norm and activation."""
        out = self.conv(x)
        return out


class Discriminator(nn.Module):
    """Convolutional discriminator returning a sigmoid-squashed score map.

    Args:
        in_channels: Channels of the input image.
        features: Sequence of channel widths. The first entry sizes the
            initial convolution; each later entry adds one ``CNNBlockDisc``.
    """

    def __init__(self, in_channels, features):
        super().__init__()

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2,
                      padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        # Every block halves the resolution except the last, which uses
        # stride 1.
        final_pos = len(features) - 1
        blocks = [
            CNNBlockDisc(
                in_channels=features[pos - 1],
                out_channels=features[pos],
                stride=2 if pos != final_pos else 1,
            )
            for pos in range(1, len(features))
        ]

        # Single-channel head producing the score map.
        head = nn.Conv2d(features[-1], 1, kernel_size=4, stride=1,
                         padding=1, padding_mode="reflect")

        self.model = nn.Sequential(*blocks, head)

    def forward(self, x):
        """Return per-location scores in the open interval (0, 1)."""
        return torch.sigmoid(self.model(self.initial(x)))


def test():
    """Smoke test: print the discriminator output shape for a random batch."""
    batch = torch.randn(32, 3, 256, 256)
    discriminator = Discriminator(3, [64, 128, 256, 512])
    scores = discriminator(batch)
    print(scores.shape)


# test()

## Generator

In [5]:
class CNNBlockGen(nn.Module):
    """Generator block: (transposed) conv, instance norm, optional ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        down: Use a reflect-padded ``nn.Conv2d`` when true, otherwise an
            ``nn.ConvTranspose2d``.
        use_act: Append an in-place ReLU when true, otherwise an identity.
        **kwargs: Forwarded to the chosen convolution (kernel size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True,
                 **kwargs):
        super().__init__()

        if down:
            conv = nn.Conv2d(in_channels, out_channels,
                             padding_mode="reflect", **kwargs)
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        norm = nn.InstanceNorm2d(out_channels)

        if use_act:
            act = nn.ReLU(inplace=True)
        else:
            act = nn.Identity()

        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Apply the block to ``x``."""
        out = self.block(x)
        return out


class ResBlock(nn.Module):
    """Residual block: two 3x3 ``CNNBlockGen`` layers plus a skip connection.

    Args:
        channels: Channel count, identical for input and output.
    """

    def __init__(self, channels):
        super().__init__()

        conv_cfg = dict(kernel_size=3, padding=1)
        with_act = CNNBlockGen(channels, channels, **conv_cfg)
        # The second layer has no activation before the skip sum.
        without_act = CNNBlockGen(channels, channels, use_act=False,
                                  **conv_cfg)

        self.block = nn.Sequential(with_act, without_act)

    def forward(self, x):
        """Return ``x`` plus the block's output for ``x``."""
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """Encoder / residual / decoder generator with a tanh output.

    Layout: a 7x7 stem, two stride-2 down blocks, ``num_residuals`` residual
    blocks, two stride-2 transposed-conv up blocks, and a 7x7 output
    convolution followed by ``tanh``.

    Args:
        image_channels: Channels of both the input and the output image.
        num_features: Base channel width; the bottleneck uses four times it.
        num_residuals: Number of ``ResBlock`` layers in the bottleneck.
    """

    def __init__(self, image_channels, num_features, num_residuals):
        super().__init__()

        width_1x = num_features * 1
        width_2x = num_features * 2
        width_4x = num_features * 4

        self.initial = nn.Sequential(
            nn.Conv2d(image_channels, num_features, kernel_size=7, stride=1,
                      padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )

        down_cfg = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList([
            CNNBlockGen(width_1x, width_2x, **down_cfg),
            CNNBlockGen(width_2x, width_4x, **down_cfg),
        ])

        bottleneck = [ResBlock(width_4x) for _ in range(num_residuals)]
        self.residual_blocks = nn.Sequential(*bottleneck)

        up_cfg = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList([
            CNNBlockGen(width_4x, width_2x, down=False, **up_cfg),
            CNNBlockGen(width_2x, width_1x, down=False, **up_cfg),
        ])

        self.last = nn.Sequential(
            nn.Conv2d(num_features, image_channels, kernel_size=7, stride=1,
                      padding=3, padding_mode="reflect"),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map an image batch to a same-channel batch with values in [-1, 1]."""
        out = self.initial(x)
        stages = (*self.down_blocks, self.residual_blocks, *self.up_blocks)
        for stage in stages:
            out = stage(out)
        return self.last(out)


def test():
    """Smoke test: print the generator output shape for a random batch."""
    batch = torch.randn(32, 3, 256, 256)
    generator = Generator(3, 64, 9)
    output = generator(batch)
    print(output.shape)


# test()

## Cycle GAN

In [6]:
class CycleGAN:
    """Build and train a CycleGAN between two unpaired image folders.

    Domain X is read from ``ROOT_PATH_1`` and domain Y from ``ROOT_PATH_2``.
    ``genX`` is fed Y images and produces X-domain images; ``genY`` does the
    opposite. ``dscX`` / ``dscY`` score images of their own domain. Both
    discriminators share one Adam optimizer, as do both generators.

    Args:
        save: When true, write all four checkpoints after every epoch.
        load: When true, restore all four checkpoints before training.
        epochs: Number of passes over the dataloader made by ``train``.
    """

    def __init__(self, save=True, load=False, epochs=10):
        self.DEVICE = torch.device("cuda") if torch.cuda.is_available() \
            else torch.device("cpu")
        # Mixed precision is only meaningful on CUDA; on CPU the autocast
        # contexts and grad scalers below are created disabled.
        self.USE_AMP = self.DEVICE.type == "cuda"

        self.SAVE = save
        self.LOAD = load
        self.EPOCHS = epochs

        # NOTE: NUM_EPOCHS is not read by the training loop (EPOCHS is); it
        # is kept only so existing attribute access keeps working.
        self.NUM_EPOCHS = 10
        self.LEARNING_RATE = 1e-5
        self.LAMBDA_CYCLE = 10.0
        # Identity loss is skipped entirely while this weight is 0.
        self.LAMBDA_IDENTITY = 0.0

        self.CHECKPOINT_DSC_X = "checkpoints/dscx.pth.tar"
        self.CHECKPOINT_DSC_Y = "checkpoints/dscy.pth.tar"
        self.CHECKPOINT_GEN_X = "checkpoints/genx.pth.tar"
        self.CHECKPOINT_GEN_Y = "checkpoints/geny.pth.tar"
        self.IMG_SAVE_PATH = "saved_images"

        try:
            os.mkdir(self.IMG_SAVE_PATH)
            print("'Saved Images' Directory created.")
        except FileExistsError:
            print("'Saved Images' Directory already exists.")

        IMAGE_CHANNELS = 3
        DSC_FEATURES = [64, 128, 256, 512]
        GEN_FEATURES = 64
        NUM_RESIDUALS = 9

        BATCH_SIZE = 1
        NUM_WORKERS = 4
        ROOT_PATH_1 = "../input/gan-getting-started/photo_jpg"
        ROOT_PATH_2 = "../input/gan-getting-started/monet_jpg"

        # Resize, random flip, scale pixels to [-1, 1] (matching the
        # generator's tanh range) and convert to tensors. "image0" lets the
        # second image go through the same pipeline call.
        transforms = A.Compose(
            [
                A.Resize(width=256, height=256),
                A.HorizontalFlip(p=0.5),
                A.Normalize(
                    mean=[0.5 for _ in range(IMAGE_CHANNELS)],
                    std=[0.5 for _ in range(IMAGE_CHANNELS)],
                    max_pixel_value=255,
                ),
                ToTensorV2(),
            ],
            additional_targets={"image0": "image"},
        )

        self.dscX = Discriminator(IMAGE_CHANNELS, DSC_FEATURES).to(self.DEVICE)
        self.dscY = Discriminator(IMAGE_CHANNELS, DSC_FEATURES).to(self.DEVICE)
        self.genX = Generator(IMAGE_CHANNELS, GEN_FEATURES,
                              NUM_RESIDUALS).to(self.DEVICE)
        self.genY = Generator(IMAGE_CHANNELS, GEN_FEATURES,
                              NUM_RESIDUALS).to(self.DEVICE)

        self.opt_dsc = optim.Adam(
            list(self.dscX.parameters()) + list(self.dscY.parameters()),
            lr=self.LEARNING_RATE,
            betas=(0.5, 0.999),
        )

        self.opt_gen = optim.Adam(
            list(self.genX.parameters()) + list(self.genY.parameters()),
            lr=self.LEARNING_RATE,
            betas=(0.5, 0.999),
        )

        self.l1 = nn.L1Loss()
        self.mse = nn.MSELoss()

        dataset = CustomDataset(
            root1=ROOT_PATH_1,
            root2=ROOT_PATH_2,
            trans=transforms
        )

        self.dataloader = DataLoader(
            shuffle=True,
            pin_memory=True,
            dataset=dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
        )

        self.gen_scaler = torch.cuda.amp.GradScaler(enabled=self.USE_AMP)
        self.dsc_scaler = torch.cuda.amp.GradScaler(enabled=self.USE_AMP)

    def train(self):
        """Run ``EPOCHS`` epochs, loading and saving checkpoints as configured."""
        if self.LOAD:
            self.load_model()

        # Seed once per run. Re-seeding at the start of every epoch would
        # replay the same shuffle order and augmentation draws each epoch.
        seed_everything()

        for _ in range(self.EPOCHS):
            self.train_model()

            if self.SAVE:
                self.save_model()

    def train_model(self):
        """Train both discriminators and both generators for one epoch.

        Each batch performs one discriminator update followed by one
        generator update. Every 100 batches the current fakes are written to
        ``IMG_SAVE_PATH``; the progress bar shows the running mean of
        ``dscY``'s scores on real and fake Y images.
        """
        Y_real_log, Y_fake_log = 0, 0
        looper = tqdm(self.dataloader)

        for idx, (X, Y) in enumerate(looper):
            real_X, real_Y = X.to(self.DEVICE), Y.to(self.DEVICE)

            # Discriminator Training
            with torch.cuda.amp.autocast(enabled=self.USE_AMP):
                fake_X = self.genX(real_Y)
                dscX_real = self.dscX(real_X)
                # Detach so the discriminator loss does not backpropagate
                # into the generators; their graph stays intact for the
                # generator step below without needing retain_graph.
                dscX_fake = self.dscX(fake_X.detach())
                dscX_real_loss = self.mse(
                    dscX_real, torch.ones_like(dscX_real))
                dscX_fake_loss = self.mse(
                    dscX_fake, torch.zeros_like(dscX_fake))
                dscX_loss = dscX_real_loss + dscX_fake_loss

                fake_Y = self.genY(real_X)
                dscY_real = self.dscY(real_Y)
                dscY_fake = self.dscY(fake_Y.detach())
                dscY_real_loss = self.mse(
                    dscY_real, torch.ones_like(dscY_real))
                dscY_fake_loss = self.mse(
                    dscY_fake, torch.zeros_like(dscY_fake))
                dscY_loss = dscY_real_loss + dscY_fake_loss

                dsc_loss = dscX_loss + dscY_loss

                # For logging purposes
                Y_real_log += dscY_real.mean().item()
                Y_fake_log += dscY_fake.mean().item()

            self.opt_dsc.zero_grad()
            self.dsc_scaler.scale(dsc_loss).backward()
            self.dsc_scaler.step(self.opt_dsc)
            self.dsc_scaler.update()

            # Generator Training
            with torch.cuda.amp.autocast(enabled=self.USE_AMP):
                # Adversarial Loss: score the (non-detached) fakes with the
                # freshly updated discriminators and push them towards 1.
                dscX_fake = self.dscX(fake_X)
                dscY_fake = self.dscY(fake_Y)
                genX_loss = self.mse(dscX_fake, torch.ones_like(dscX_fake))
                genY_loss = self.mse(dscY_fake, torch.ones_like(dscY_fake))

                # Cycle Loss: translating a fake back should recover the
                # original image.
                cycleX = self.genX(fake_Y)
                cycleY = self.genY(fake_X)
                cycleX_loss = self.l1(real_X, cycleX)
                cycleY_loss = self.l1(real_Y, cycleY)

                gen_loss = (
                    genX_loss + genY_loss +
                    (cycleX_loss + cycleY_loss) * self.LAMBDA_CYCLE
                )

                # Identity Loss: a generator fed an image already in its
                # output domain should leave it unchanged. Skipped (and the
                # two extra forward passes avoided) while the weight is 0.
                if self.LAMBDA_IDENTITY > 0:
                    identityX = self.genX(real_X)
                    identityY = self.genY(real_Y)
                    identityX_loss = self.l1(real_X, identityX)
                    identityY_loss = self.l1(real_Y, identityY)
                    gen_loss = gen_loss + (
                        (identityX_loss + identityY_loss)
                        * self.LAMBDA_IDENTITY
                    )

            self.opt_gen.zero_grad()
            self.gen_scaler.scale(gen_loss).backward()
            self.gen_scaler.step(self.opt_gen)
            self.gen_scaler.update()

            # Saving images and logging
            if idx % 100 == 0:
                # Undo the [-1, 1] normalisation before writing to disk.
                save_image(fake_X * 0.5 + 0.5,
                           f"{self.IMG_SAVE_PATH}/X{idx}.jpg")
                save_image(fake_Y * 0.5 + 0.5,
                           f"{self.IMG_SAVE_PATH}/Y{idx}.jpg")

            looper.set_postfix(
                Y_real=Y_real_log / (idx + 1),
                Y_fake=Y_fake_log / (idx + 1),
            )

    def load_model(self):
        """Restore all four networks and their optimizers from disk."""
        load_checkpoint(self.CHECKPOINT_DSC_X, self.dscX,
                        self.opt_dsc, self.LEARNING_RATE, self.DEVICE)
        load_checkpoint(self.CHECKPOINT_DSC_Y, self.dscY,
                        self.opt_dsc, self.LEARNING_RATE, self.DEVICE)
        load_checkpoint(self.CHECKPOINT_GEN_X, self.genX,
                        self.opt_gen, self.LEARNING_RATE, self.DEVICE)
        load_checkpoint(self.CHECKPOINT_GEN_Y, self.genY,
                        self.opt_gen, self.LEARNING_RATE, self.DEVICE)

    def save_model(self):
        """Write all four networks and their optimizers to disk."""
        targets = (
            (self.CHECKPOINT_DSC_X, self.dscX, self.opt_dsc),
            (self.CHECKPOINT_DSC_Y, self.dscY, self.opt_dsc),
            (self.CHECKPOINT_GEN_X, self.genX, self.opt_gen),
            (self.CHECKPOINT_GEN_Y, self.genY, self.opt_gen),
        )
        for path, network, optimizer in targets:
            # Make sure the target directory exists so a full epoch of
            # training is not lost to a missing "checkpoints" folder.
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            save_checkpoint(path, network, optimizer)

## Execution

In [7]:
# Guard the entry point: the DataLoader starts worker processes, and start
# methods that re-import the main module would otherwise launch training
# again inside every worker.
if __name__ == "__main__":
    model = CycleGAN()
    model.train()

'Saved Images' Directory already exists.


  1%|          | 82/7038 [01:17<1:49:04,  1.06it/s, Y_fake=nan, Y_real=0.503]


RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
