### Imports

In [None]:
import glob
import os
import random
from collections import namedtuple

import albumentations as A
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

### Convolution Blocks

#### Standard Convolution Block

In [None]:
class ConvBlock(nn.Module):
    """
    Conv2d -> InstanceNorm2d -> LeakyReLU(0.2) stack.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride,
            padding, padding_mode, ...).
    """

    def __init__(self, c_in: int, c_out: int, **kwargs):
        super().__init__()

        stages = [
            nn.Conv2d(c_in, c_out, **kwargs),
            nn.InstanceNorm2d(c_out),
            nn.LeakyReLU(0.2),
        ]
        # Attribute name is part of the state_dict key layout; keep it stable.
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        """Apply the convolution, normalisation and activation to ``x``."""
        out = self.conv_block(x)
        return out

#### Upsample Convolution Block

In [None]:
class UpsampleConvBlock(nn.Module):
    """
    2x bilinear upsampling followed by Conv2d -> InstanceNorm2d -> LeakyReLU(0.2).

    The block uses an interpolate-then-convolve layout; no transposed
    convolution is involved.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        **kwargs: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, c_in: int, c_out: int, **kwargs):
        super().__init__()

        stages = [
            nn.Conv2d(c_in, c_out, **kwargs),
            nn.InstanceNorm2d(c_out),
            nn.LeakyReLU(0.2),
        ]
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        """Double the spatial size of ``x`` and run the convolution stack on it."""
        upsampled = F.interpolate(x, scale_factor=2, mode="bilinear")
        return self.conv_block(upsampled)

#### Gated Convolution Block

In [None]:
class ConcatSiLU(nn.Module):
    """
    Activation that stacks SiLU(x) and SiLU(-x) on dim 1.

    The output therefore has twice as many entries along dim 1 as the input.
    """

    def forward(self, x: torch.Tensor):
        positive_branch = F.silu(x)
        negative_branch = F.silu(-x)
        return torch.cat((positive_branch, negative_branch), dim=1)

class GatedResidualConvBlock(nn.Module):
    """
    Residual block whose update is ``tanh(value) * sigmoid(gate)``.

    A first convolution (no bias) maps ``c_in -> c_hidden``; after instance
    normalisation and ``ConcatSiLU`` (which doubles the channels) a second
    convolution maps ``2*c_hidden -> 2*c_out``. That result is split in half
    along dim 1 into ``value`` and ``gate``.

    Args:
        c_in: input channels.
        c_out: channels of each of the value/gate halves.
        c_hidden: channels produced by the first convolution.
        **kwargs: forwarded to BOTH ``nn.Conv2d`` layers; both use reflect padding.

    Note:
        The gated update is added to ``x``, so its shape has to be compatible
        with the shape of ``x`` for the addition to succeed.
    """

    def __init__(self, c_in: int, c_out: int, c_hidden: int, **kwargs):
        super().__init__()

        # Layers are created in the same order as they are applied.
        squeeze = nn.Conv2d(c_in, c_hidden, bias=False, padding_mode="reflect", **kwargs)
        norm = nn.InstanceNorm2d(c_hidden)
        activation = ConcatSiLU()
        expand = nn.Conv2d(2*c_hidden, 2*c_out, padding_mode="reflect", **kwargs)

        self.conv_block = nn.Sequential(squeeze, norm, activation, expand)

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the gated update computed from ``x``."""
        value, gate = self.conv_block(x).chunk(2, dim=1)
        update = torch.tanh(value) * torch.sigmoid(gate)
        return x + update

### Discriminator

In [None]:
class Discriminator(nn.Module):
    def __init__(self, c_in: int = 3, features: "list | None" = None):
        """
        Discriminator neural network for CycleGAN Architecture

        Args:
            c_in: number of channels of the input image.
            features: channel widths of the successive convolution stages.
                ``None`` selects ``[64, 128, 256, 512]``. The sequence is
                copied, so later changes to the caller's list do not affect
                this module.

        Raises:
            ValueError: if ``features`` is empty.
        """
        super().__init__()

        # A None sentinel replaces the former mutable list default, which was
        # shared by every instance created without an explicit argument.
        widths = [64, 128, 256, 512] if features is None else list(features)
        if not widths:
            raise ValueError("features must contain at least one channel width")

        self.c_in = c_in
        self.features = widths

        self.discriminator = self._build_architecture()

    def _build_architecture(self):
        """Assemble the convolution stack described by ``self.features``."""
        layers = [
            # First stage: strided convolution without normalisation.
            nn.Sequential(
                nn.Conv2d(self.c_in, self.features[0], 4, 2, 1, padding_mode="reflect"),
                nn.LeakyReLU(0.2)
            )
        ]

        channel_pairs = list(zip(self.features[:-1], self.features[1:]))
        last_pair = len(channel_pairs) - 1
        for position, (c_in, c_out) in enumerate(channel_pairs):
            # Only the final ConvBlock uses stride 1. The choice is made by
            # position, not by comparing channel counts, so repeated widths
            # (e.g. [64, 128, 256, 256]) still downsample in the earlier blocks.
            stride = 1 if position == last_pair else 2
            layers.append(
                ConvBlock(c_in, c_out, kernel_size=4, stride=stride, padding=1, padding_mode="reflect")
            )

        # Single-channel score map.
        layers.append(nn.Conv2d(self.features[-1], 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Return the sigmoid of the score map, i.e. values in [0, 1]."""
        return torch.sigmoid(self.discriminator(x))

### Generator

In [None]:
class Generator(nn.Module):
    def __init__(self, image_channels: int = 3, num_features:int = 64, num_residuals: int = 9):
        """
        Generator neural network for CycleGAN Architecture

        Layout: 7x7 stem -> two stride-2 ConvBlocks -> ``num_residuals`` gated
        residual blocks with dilations cycling 1, 2, 4 -> two 2x upsample
        blocks -> 7x7 output convolution followed by tanh.

        Args:
            image_channels: channels of the input and output image.
            num_features: channels produced by the stem; the bottleneck uses
                four times as many.
            num_residuals: number of gated residual blocks. Any non-negative
                count is accepted; the dilation pattern repeats as needed.
        """
        super().__init__()

        residual_kernel = 3
        dilation_cycle = (1, 2, 4)
        # Cycling (instead of indexing a fixed 9-entry list) keeps the first
        # nine dilations unchanged while allowing num_residuals > 9.
        dilation_residual = [dilation_cycle[i % len(dilation_cycle)] for i in range(num_residuals)]
        # "Same" padding for a dilated 3x3 kernel.
        padding_residual = [d * (residual_kernel - 1) // 2 for d in dilation_residual]

        self.initial = nn.Sequential(
            nn.Conv2d(image_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.SiLU()
        )

        self.down_block = nn.Sequential(
            ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
            ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1)
        )

        bottleneck = num_features*4
        self.gated_residual_blocks = nn.Sequential(
            *[
                GatedResidualConvBlock(
                    bottleneck, bottleneck, bottleneck,
                    kernel_size=residual_kernel, stride=1, padding=padding, dilation=dilation
                )
                for padding, dilation in zip(padding_residual, dilation_residual)
            ]
        )

        self.up_block = nn.Sequential(
            UpsampleConvBlock(num_features*4, num_features*2, kernel_size=3, stride=1, padding=1),
            UpsampleConvBlock(num_features*2, num_features, kernel_size=3, stride=1, padding=1)
        )

        self.last = nn.Conv2d(num_features, image_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x: torch.Tensor):
        """Translate ``x`` and squash the result into [-1, 1] with tanh."""
        x = self.initial(x)
        x = self.down_block(x)
        x = self.gated_residual_blocks(x)
        x = self.up_block(x)
        x = self.last(x)

        return torch.tanh(x)

### Dataset

In [None]:
class Photo2MonetDataset(Dataset):
    """
    Dataset yielding ``(photo, monet)`` image pairs loaded from file paths.

    The dataset length equals the number of Monet paintings. The photo for a
    given index is taken at ``index % len(photos)``, so a photo list shorter
    than the painting list wraps around instead of raising ``IndexError``.

    Args:
        photos: paths of the photo images.
        monet: paths of the Monet images.
        transform: optional callable invoked as
            ``transform(image=photo, image0=monet)`` that returns a mapping
            with the keys ``"image"`` and ``"image0"``.
    """

    def __init__(self, photos: list, monet: list, transform = None):

        self.photos = photos
        self.monet = monet
        self.transform = transform

        self.photos_len = len(photos)
        self.monet_len = len(monet)

    def __len__(self):
        return self.monet_len

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB array and close the file handle promptly."""
        with Image.open(path) as image:
            # convert("RGB") makes grayscale/palette/RGBA files 3-channel too.
            return np.array(image.convert("RGB"))

    def __getitem__(self, index):
        if self.photos_len == 0:
            raise IndexError("Photo2MonetDataset has no photos to pair with the paintings")

        # Wrap the photo index so fewer photos than paintings is still valid.
        photo = self._load_rgb(self.photos[index % self.photos_len])
        monet = self._load_rgb(self.monet[index])

        if self.transform:
            transformed = self.transform(image=photo, image0=monet)
            photo = transformed["image"]
            monet = transformed["image0"]

        return photo, monet

### Training Loop

#### Hyperparameters

In [None]:
# Immutable bundle of the run configuration passed around the notebook.
Hyperparameters = namedtuple(
    "Hyperparameters",
    (
        "batch_size",       # batch size handed to the DataLoaders
        "lr",               # learning rate of both Adam optimizers
        "lambda_cycle",     # weight of the cycle-consistency loss
        "lambda_identity",  # weight of the identity loss
        "device",           # device that models and batches are moved to
        "save_path",        # directory the checkpoints are written to
        "load_model",       # restore checkpoints before training when true
        "visualize",        # write a sample grid after every epoch when true
        "train_model",      # run the training loop when true
    ),
)

#### Save Model

In [None]:
def save_model(filepath: str, model: nn.Module, optimizer: optim.Optimizer):
    """
    Write a checkpoint holding the model and optimizer state dicts.

    The file contains a dict with the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"``.

    Args:
        filepath: destination file; missing parent directories are created.
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
    """
    # Create the target directory up front so a finished epoch is not lost to
    # a missing-folder error when the checkpoint is written.
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)

    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()
        },
        filepath
    )

#### Visualize grid

In [None]:
def visualize_grid(images: torch.Tensor, generator_p2m: nn.Module, filename: str):
    """
    Run ``generator_p2m`` on ``images`` and save the outputs as one image grid.

    Args:
        images: batch fed to the generator.
        generator_p2m: generator used to produce the images.
        filename: destination image file; missing parent directories are created.
    """
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Pure inference: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        monet = generator_p2m(images)

    # Rescale 0.5*x + 0.5 (maps [-1, 1] onto [0, 1]) before saving.
    norm_monet = 0.5*monet+0.5

    grid = torchvision.utils.make_grid(norm_monet, nrow=4)

    torchvision.utils.save_image(grid.cpu(), filename)

#### Training Loop

In [None]:
def training_step(
    discriminator_m: nn.Module, generator_p2m: nn.Module, discriminator_p: nn.Module, generator_m2p: nn.Module,
    optimizer_discriminator: optim.Optimizer, optimizer_generator: optim.Optimizer, loader: DataLoader,
    d_scaler: torch.cuda.amp.grad_scaler.GradScaler, g_scaler: torch.cuda.amp.grad_scaler.GradScaler, params: Hyperparameters
    ):
    """
    Run one pass over ``loader``, updating discriminators and then generators.

    For every ``(photo, monet)`` batch:
      1. both discriminators are trained with an MSE loss to score real
         images as 1 and (detached) generated images as 0;
      2. both generators are trained on the sum of the adversarial loss
         (MSE towards 1), ``params.lambda_cycle`` times the L1 cycle loss and
         ``params.lambda_identity`` times the L1 identity loss.

    Forward passes run under CUDA autocast and the backward/step calls go
    through the supplied gradient scalers.

    Args:
        discriminator_m: discriminator scoring Monet images.
        generator_p2m: generator mapping photos to Monet images.
        discriminator_p: discriminator scoring photos.
        generator_m2p: generator mapping Monet images to photos.
        optimizer_discriminator: optimizer stepped for the discriminator loss.
        optimizer_generator: optimizer stepped for the generator loss.
        loader: iterable of ``(photo, monet)`` batches.
        d_scaler: gradient scaler for the discriminator update.
        g_scaler: gradient scaler for the generator update.
        params: run configuration; ``device``, ``lambda_cycle`` and
            ``lambda_identity`` are read here.
    """

    for photo, monet in tqdm(loader, leave=True):
        photo = photo.to(params.device)
        monet = monet.to(params.device)

        # Training discriminators
        with torch.cuda.amp.autocast_mode.autocast():
            # photo-to-monet pass
            fake_monet = generator_p2m(photo)
            disc_monet_real = discriminator_m(monet)
            # detach(): the discriminator loss must not reach the generator
            disc_monet_fake = discriminator_m(fake_monet.detach())
            disc_monet_real_loss = F.mse_loss(disc_monet_real, torch.ones_like(disc_monet_real))
            # Targets are shaped after the tensor they are compared with
            # (previously every target borrowed disc_monet_real's shape).
            disc_monet_fake_loss = F.mse_loss(disc_monet_fake, torch.zeros_like(disc_monet_fake))
            disc_monet_loss = disc_monet_real_loss + disc_monet_fake_loss

            # monet-to-photo pass
            fake_photo = generator_m2p(monet)
            disc_photo_real = discriminator_p(photo)
            disc_photo_fake = discriminator_p(fake_photo.detach())
            disc_photo_real_loss = F.mse_loss(disc_photo_real, torch.ones_like(disc_photo_real))
            disc_photo_fake_loss = F.mse_loss(disc_photo_fake, torch.zeros_like(disc_photo_fake))
            disc_photo_loss = disc_photo_real_loss + disc_photo_fake_loss

            # total discriminator loss
            disc_total_loss = disc_monet_loss + disc_photo_loss

        optimizer_discriminator.zero_grad()
        d_scaler.scale(disc_total_loss).backward()
        d_scaler.step(optimizer_discriminator)
        d_scaler.update()


        # Training generators
        with torch.cuda.amp.autocast_mode.autocast():
            # adversarial loss: generators want their fakes scored as real (1)
            disc_monet_fake = discriminator_m(fake_monet)
            disc_photo_fake = discriminator_p(fake_photo)

            gen_adversarial_loss = F.mse_loss(disc_monet_fake, torch.ones_like(disc_monet_fake)) + F.mse_loss(disc_photo_fake, torch.ones_like(disc_photo_fake))

            # cycle consistency loss: translating there and back should
            # reproduce the original image
            cycle_monet = generator_p2m(fake_photo)
            cycle_photo = generator_m2p(fake_monet)

            gen_cycle_loss = F.l1_loss(cycle_monet, monet) + F.l1_loss(cycle_photo, photo)

            # identity loss: an image already in the target domain should be
            # left unchanged by the generator producing that domain
            identity_monet = generator_p2m(monet)
            identity_photo = generator_m2p(photo)

            gen_identity_loss = F.l1_loss(identity_monet, monet) + F.l1_loss(identity_photo, photo)

            gen_total_loss = gen_adversarial_loss + params.lambda_cycle * gen_cycle_loss + params.lambda_identity * gen_identity_loss

        optimizer_generator.zero_grad()
        g_scaler.scale(gen_total_loss).backward()
        g_scaler.step(optimizer_generator)
        g_scaler.update()

In [None]:
def training_loop(discriminator_m: nn.Module, generator_p2m: nn.Module, discriminator_p: nn.Module, generator_m2p: nn.Module,
    loader: DataLoader, epochs: int, params: Hyperparameters):
    """
    Train the four networks for ``epochs`` epochs, checkpointing every epoch.

    One Adam optimizer (betas 0.5 / 0.999, ``params.lr``) drives both
    discriminators and another drives both generators. After each epoch an
    optional sample grid is written and all four networks are saved under
    ``params.save_path`` together with the state of their optimizer.

    Args:
        discriminator_m: discriminator scoring Monet images.
        generator_p2m: generator mapping photos to Monet images.
        discriminator_p: discriminator scoring photos.
        generator_m2p: generator mapping Monet images to photos.
        loader: iterable of ``(photo, monet)`` batches.
        epochs: number of passes over ``loader``.
        params: run configuration.
    """
    adam_betas = (0.5, 0.999)

    optimizer_discriminator = optim.Adam(
        [*discriminator_m.parameters(), *discriminator_p.parameters()],
        lr=params.lr,
        betas=adam_betas,
    )
    optimizer_generator = optim.Adam(
        [*generator_p2m.parameters(), *generator_m2p.parameters()],
        lr=params.lr,
        betas=adam_betas,
    )

    d_scaler = torch.cuda.amp.grad_scaler.GradScaler()
    g_scaler = torch.cuda.amp.grad_scaler.GradScaler()

    # A single photo batch is drawn once and reused for every epoch's grid.
    if params.visualize:
        images = next(iter(loader))[0].to(params.device)

    # (file stem, network, optimizer stored alongside it), in save order.
    checkpoint_plan = (
        ("discriminator_m", discriminator_m, optimizer_discriminator),
        ("generator_p2m", generator_p2m, optimizer_generator),
        ("discriminator_p", discriminator_p, optimizer_discriminator),
        ("generator_m2p", generator_m2p, optimizer_generator),
    )

    for epoch in range(epochs):
        print(f"Epoch: {epoch}")

        training_step(
            discriminator_m, generator_p2m, discriminator_p, generator_m2p,
            optimizer_discriminator, optimizer_generator, loader, d_scaler, g_scaler, params
        )

        # visualization
        if params.visualize:
            visualize_grid(images, generator_p2m, f"visualization/grid{epoch}.jpg")

        # save model
        for stem, network, optimizer in checkpoint_plan:
            save_model(f"{params.save_path}/{stem}.pt", network, optimizer)

### Setup

In [None]:
# Run configuration shared by training and sampling.
params = Hyperparameters(
    batch_size=4,
    lr=3e-4,
    lambda_cycle=10.0,
    lambda_identity=10.0,
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    save_path="model",
    load_model=False,
    visualize=True,
    train_model=True,
)

In [None]:
# Training-time preprocessing; "image0" is registered as a second image target
# so the photo and the painting go through the same pipeline call.
transform = A.Compose(
    [
        # (x / 255 - 0.5) / 0.5 -> pixel values in [-1, 1]
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255),
        A.HorizontalFlip(p=0.3),
        A.RandomResizedCrop(256, 256, p=0.2),
        ToTensorV2(),
    ],
    additional_targets=dict(image0="image"),
)

In [None]:
# glob returns paths in arbitrary, OS-dependent order; sorting makes the
# index -> file mapping of the dataset reproducible between runs and machines.
photos = sorted(glob.glob("photo_jpg/*.jpg"))
monet = sorted(glob.glob("monet_jpg/*.jpg"))

In [None]:
# Training pairs, served in shuffled mini-batches of params.batch_size.
dataset = Photo2MonetDataset(photos=photos, monet=monet, transform=transform)
loader = DataLoader(
    dataset=dataset,
    batch_size=params.batch_size,
    shuffle=True,
)

In [None]:
def load_checkpoint(checkpoint_path, map_location="cpu"):
    """
    Return the ``"model_state_dict"`` entry of a saved checkpoint.

    Args:
        checkpoint_path: checkpoint file containing a dict with a
            ``"model_state_dict"`` key.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written on a GPU machine can still be opened on a
            CPU-only one; ``load_state_dict`` copies the values into the
            module's own parameters afterwards.

    Returns:
        The stored model state dict.

    Raises:
        KeyError: if the checkpoint has no ``"model_state_dict"`` entry.
    """
    # NOTE(review): torch.load unpickles the file; only open trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    return checkpoint["model_state_dict"]

In [None]:
# Two discriminator/generator pairs, one per translation direction. They are
# created in the same order as before so weight initialisation consumes the
# random number generator identically.
discriminator_m, generator_p2m, discriminator_p, generator_m2p = (
    Discriminator().to(params.device),
    Generator().to(params.device),
    Discriminator().to(params.device),
    Generator().to(params.device),
)

In [None]:
if params.load_model:
    # Read the checkpoints from params.save_path, the directory the training
    # loop saves to, instead of a separately hard-coded "model/" folder.
    for network, stem in (
        (discriminator_m, "discriminator_m"),
        (generator_p2m, "generator_p2m"),
        (discriminator_p, "discriminator_p"),
        (generator_m2p, "generator_m2p"),
    ):
        network.load_state_dict(load_checkpoint(f"{params.save_path}/{stem}.pt"))

### Training

In [None]:
# Training is opt-in through the configuration flag.
if params.train_model:
    training_loop(
        discriminator_m=discriminator_m,
        generator_p2m=generator_p2m,
        discriminator_p=discriminator_p,
        generator_m2p=generator_m2p,
        loader=loader,
        epochs=40,
        params=params,
    )

Epoch: 0


0it [00:00, ?it/s]

Epoch: 1


0it [00:00, ?it/s]

Epoch: 2


0it [00:00, ?it/s]

Epoch: 3


0it [00:00, ?it/s]

Epoch: 4


0it [00:00, ?it/s]

Epoch: 5


0it [00:00, ?it/s]

Epoch: 6


0it [00:00, ?it/s]

Epoch: 7


0it [00:00, ?it/s]

Epoch: 8


0it [00:00, ?it/s]

Epoch: 9


0it [00:00, ?it/s]

Epoch: 10


0it [00:00, ?it/s]

Epoch: 11


0it [00:00, ?it/s]

Epoch: 12


0it [00:00, ?it/s]

Epoch: 13


0it [00:00, ?it/s]

Epoch: 14


0it [00:00, ?it/s]

Epoch: 15


0it [00:00, ?it/s]

Epoch: 16


0it [00:00, ?it/s]

Epoch: 17


0it [00:00, ?it/s]

### Sampling

In [None]:
class PhotoDataset(Dataset):
    """
    Dataset yielding ``(photo, name)`` for every photo path.

    ``name`` is the file name without directory and extension.

    Args:
        photos: paths of the photo images.
        transform: optional callable invoked as ``transform(image=photo)``
            that returns a mapping with the key ``"image"``.
    """

    def __init__(self, photos: list, transform = None):

        # The constructor no longer stores a ``monet`` attribute: it used to
        # read an unrelated module-level variable and failed with NameError
        # whenever that global was not defined.
        self.photos = photos
        self.transform = transform

    def __len__(self):
        return len(self.photos)

    @staticmethod
    def _stem(path):
        """Return the file name of ``path`` without directory and extension."""
        # os.path understands the platform's separators, unlike split("/").
        return os.path.splitext(os.path.basename(path))[0]

    def __getitem__(self, index):
        path = self.photos[index]
        with Image.open(path) as image:
            # convert("RGB") makes grayscale/palette/RGBA files 3-channel too.
            photo = np.array(image.convert("RGB"))

        if self.transform:
            transformed = self.transform(image=photo)
            photo = transformed["image"]

        return photo, self._stem(path)

In [None]:
# Sampling-time preprocessing: normalisation and tensor conversion only.
transform_photo = A.Compose(
    [
        # (x / 255 - 0.5) / 0.5 -> pixel values in [-1, 1]
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255),
        ToTensorV2(),
    ]
)

In [None]:
# Sorted so the photos are processed in a reproducible order (glob's own
# ordering is arbitrary and OS-dependent).
photos_full = sorted(glob.glob("photo_jpg/*.jpg"))

In [None]:
# Every photo exactly once, in list order, for inference.
photos_data = PhotoDataset(photos=photos_full, transform=transform_photo)
photos_loader = DataLoader(
    dataset=photos_data,
    batch_size=params.batch_size,
    shuffle=False,
)

In [None]:
# Restore the photo->Monet generator from the directory training saves to.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only
# one. NOTE(review): torch.load unpickles the file; only open trusted checkpoints.
model = torch.load(f"{params.save_path}/generator_p2m.pt", map_location=params.device)
generator_p2m = Generator().to(params.device)
generator_p2m.load_state_dict(model["model_state_dict"])

In [None]:
@torch.no_grad()
def sample(generator_p2m: nn.Module, loader: DataLoader, params: Hyperparameters):
    """
    Translate every batch of ``loader`` and save each result as a JPEG.

    Files are written to ``images/<name>_monet.jpg``; the ``images``
    directory is created when it does not exist yet.

    Args:
        generator_p2m: generator applied to every photo batch.
        loader: iterable of ``(photo_batch, names)`` pairs.
        params: run configuration; only ``device`` is read here.
    """
    # Without this the first save fails when the folder is missing.
    os.makedirs("images", exist_ok=True)

    for photo, name in tqdm(loader):
        photo = photo.to(params.device)
        fake_monet = generator_p2m(photo)

        # denormalization [-1, 1] -> [0, 1]
        fake_monet = 0.5*fake_monet.cpu() + 0.5

        for i in range(fake_monet.shape[0]):
            torchvision.utils.save_image(fake_monet[i, :, :, :], f"images/{name[i]}_monet.jpg")