In [None]:
import itertools
import os
import shutil

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

In [None]:
# Showing images 

def show_images(img, nrow=5, title=""):
    """Display a batch of normalised image tensors as one grid figure.

    Args:
        img: Batch tensor handed to ``torchvision.utils.make_grid``. Values are
            mapped back with ``x * 0.5 + 0.5`` before plotting.
        nrow: Number of images per grid row.
        title: Title drawn above the grid.
    """
    # Work on a CPU copy detached from the graph and undo the 0.5/0.5 normalisation.
    denormalised = img.detach().cpu() * 0.5 + 0.5
    grid_chw = torchvision.utils.make_grid(denormalised, nrow=nrow)
    # Move the channel axis last, which is the layout plt.imshow is given here.
    grid_hwc = grid_chw.permute(1, 2, 0)

    plt.figure(figsize=(8, 5))
    plt.imshow(grid_hwc)
    plt.axis("off")
    plt.title(title)
    plt.show()

In [None]:
# Creating dataset class

class MonetPhotoDataset(Dataset):
    """Unpaired dataset yielding one Monet painting and one photo per index.

    The two folders may hold different numbers of images. The dataset length is
    the larger of the two counts and the shorter list is cycled with a modulo,
    so every image of the longer list is visited once per epoch.

    Args:
        folder_path_monets: Directory containing the Monet images.
        folder_path_photos: Directory containing the photo images.
        transform: Optional callable invoked as
            ``transform(image=monet, image0=photo)`` that returns a mapping with
            the keys ``"image"`` and ``"image0"``.
    """

    # Only files with these (case-sensitive) suffixes are picked up.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, folder_path_monets, folder_path_photos, transform=None):
        self.folder_path_monets = folder_path_monets
        self.folder_path_photos = folder_path_photos
        self.transform = transform
        self.monets = self._list_images(folder_path_monets)
        self.photos = self._list_images(folder_path_photos)

    @classmethod
    def _list_images(cls, folder_path):
        """Return the image paths in ``folder_path`` in a deterministic order."""
        # os.listdir order is arbitrary; sorting keeps index -> file stable across runs.
        return sorted(
            os.path.join(folder_path, f)
            for f in os.listdir(folder_path)
            if f.endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _load_rgb(path):
        """Read an image file as a float32 RGB array."""
        return np.array(Image.open(path).convert('RGB')).astype('float32')

    def __getitem__(self, idx):
        """Return the ``(monet, photo)`` pair for ``idx``, transformed if configured."""
        monet_path = self.monets[idx % len(self.monets)]
        photo_path = self.photos[idx % len(self.photos)]

        monet = self._load_rgb(monet_path)
        photo = self._load_rgb(photo_path)

        # Use the transform given to THIS dataset; the previous code reached for
        # a module-level `transform`, ignoring the constructor argument.
        if self.transform is not None:
            augmentations = self.transform(image=monet, image0=photo)
            monet = augmentations["image"]
            photo = augmentations["image0"]

        return monet, photo

    def __len__(self):
        """Number of samples per epoch: the size of the larger image folder."""
        return max(len(self.monets), len(self.photos))

In [None]:
# Device configuration
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters
# num_epochs also fixes the end point of the learning-rate decay defined later.

num_epochs = 35
lr_disc = 2e-4
lr_gen = 2e-4
num_workers = 2
batch_size = 8
# Training images are resized to image_size x image_size by the transform below.
image_size = 128
# NOTE(review): channels_img is not referenced in the visible cells; the models
# are built with a literal in_channels=3.
channels_img = 3

# Weights applied to the cycle-consistency and identity L1 terms of the generator loss.
LAMBDA_CYCLE = 10.0
LAMBDA_IDENTITY = 5.0
# Master switch for the periodic checkpoint writes in the training loop.
SAVE_MODEL = True

# Checkpoint file names (relative to the working directory), one per network.
CHECKPOINT_DISC_M = "discM.pth.tar"
CHECKPOINT_DISC_P = "discP.pth.tar"
CHECKPOINT_GEN_M = "genM.pth.tar"
CHECKPOINT_GEN_P = "genP.pth.tar"


# Transformations
# One pipeline is applied to both images of a pair: the dataset calls it as
# transform(image=monet, image0=photo), and "image0" is registered below as an
# additional image target. Normalize uses mean=std=0.5 with max_pixel_value=255,
# which show_images later undoes with x * 0.5 + 0.5.

transform = A.Compose(
    [
        A.Resize(image_size, image_size),
        A.HorizontalFlip(),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ], additional_targets={"image0": "image"},
)

# Dataset of (monet, photo) pairs read from the Kaggle competition input folders,
# served in shuffled mini-batches.
train_dataset = MonetPhotoDataset("/kaggle/input/gan-getting-started/monet_jpg", "/kaggle/input/gan-getting-started/photo_jpg", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
# len(train_dataset) is max(#monets, #photos), not the Monet count, so report
# each folder's size and the per-epoch sample count separately.
print(
    f"Monet images: {len(train_dataset.monets)} | "
    f"Photo images: {len(train_dataset.photos)} | "
    f"Samples per epoch: {len(train_dataset)}"
)

In [None]:
# Pull one shuffled batch and preview the first five images of each domain.
monet_img, photo_img = next(iter(train_loader))

for preview_batch, preview_title in (
    (monet_img, "Monet Images"),
    (photo_img, "Normal Images"),
):
    show_images(preview_batch[:5], nrow=5, title=preview_title)

In [None]:
# Discriminator Model

class BlockDisc(nn.Module):
    """Discriminator building block: 4x4 conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        stride: Convolution stride (2 by default).
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=True,
                padding_mode="reflect",
            ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        # Keep the attribute name and layer order: checkpoint keys depend on them.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        out = self.conv(x)
        return out


class Discriminator(nn.Module):
    """Convolutional discriminator returning a 1-channel map of scores in (0, 1).

    Args:
        in_channels: Channels of the input image.
        features: Channel widths of the successive conv stages. The first width
            is used by the un-normalised initial layer; every later width gets a
            ``BlockDisc``. All blocks use stride 2 except the last, which uses
            stride 1.
    """

    def __init__(self, in_channels, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default + local copy: never share a list between instances.
        features = list(features)

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            # The first layer deliberately has no InstanceNorm (as noted by the
            # original author, following the paper).
            nn.LeakyReLU(0.2),
        )

        in_channel = features[0]
        layers = []
        last_index = len(features) - 1

        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 block by POSITION. Comparing widths by value
            # (feature == features[-1]) mislabels earlier layers whenever a
            # width is repeated, e.g. features=[64, 128, 512, 512].
            stride = 1 if index == last_index else 2
            layers.append(BlockDisc(in_channel, feature, stride=stride))
            in_channel = feature

        # Project to a single-channel score map.
        layers.append(
            nn.Conv2d(in_channel, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return sigmoid-squashed scores for the input batch ``x``."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

In [None]:
# Generator Model

class BlockGen(nn.Module):
    """Generator building block: (transposed) conv -> InstanceNorm -> ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        encoder: When True use a reflect-padded ``nn.Conv2d``; otherwise use an
            ``nn.ConvTranspose2d``.
        **kwargs: Forwarded to the convolution (kernel_size, stride, padding,
            output_padding, ...).
    """

    def __init__(self, in_channels, out_channels, encoder=True, **kwargs):
        super().__init__()
        if encoder:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        # Keep the attribute name and layer order: checkpoint keys depend on them.
        self.conv = nn.Sequential(
            conv_layer,
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        out = self.conv(x)
        return out


class ResidualBlock(nn.Module):
    """Two 3x3 conv + InstanceNorm stages wrapped in a skip connection.

    Args:
        channels: Channel count, identical for input and output.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # Shape-preserving convolution: kernel 3, stride 1, reflect padding 1.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1, padding_mode="reflect")

        # Keep the attribute name and layer order: checkpoint keys depend on them.
        self.block = nn.Sequential(
            conv3x3(),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        """Return ``x`` plus the block's transformation of ``x``."""
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator with tanh output.

    Args:
        in_channels: Channels of the input image; the output has the same count.
        num_features: Channel width after the first 7x7 convolution.
        num_residuals: Number of ``ResidualBlock`` modules in the bottleneck.
    """

    def __init__(self, in_channels, num_features=64, num_residuals=6):
        super().__init__()
        width_1x = num_features
        width_2x = num_features * 2
        width_4x = num_features * 4
        # Settings shared by every stride-2 down/up-sampling block.
        resample = dict(kernel_size=3, stride=2, padding=1)

        # 7x7 stem that keeps the spatial size.
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, width_1x, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(width_1x),
            nn.ReLU(inplace=True),
        )

        # Two stride-2 convolutions: spatial size / 4, channels x 4.
        self.enc_blocks = nn.ModuleList([
            BlockGen(width_1x, width_2x, encoder=True, **resample),
            BlockGen(width_2x, width_4x, encoder=True, **resample),
        ])

        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(width_4x) for _ in range(num_residuals))
        )

        # Two stride-2 transposed convolutions. output_padding=1 makes each one
        # exactly double the spatial size (2n instead of 2n - 1), so the output
        # matches the input resolution.
        self.dec_blocks = nn.ModuleList([
            BlockGen(width_4x, width_2x, encoder=False, output_padding=1, **resample),
            BlockGen(width_2x, width_1x, encoder=False, output_padding=1, **resample),
        ])

        # 7x7 projection back to image channels, squashed to [-1, 1] by tanh.
        self.last = nn.Sequential(
            nn.Conv2d(width_1x, in_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate the image batch ``x``; the result has the same shape as ``x``."""
        x = self.initial(x)

        # Encoder blocks, then the residual bottleneck, then decoder blocks.
        for stage in [*self.enc_blocks, self.residual_blocks, *self.dec_blocks]:
            x = stage(x)

        return self.last(x)

def gen_test():
    """Smoke-test the generator: output shape must equal input shape.

    Prints the output shape and raises ``AssertionError`` when it differs from
    the shape of the random input batch.
    """
    x = torch.randn((10, 3, 128, 128))
    gen = Generator(in_channels=3, num_residuals=6)
    # Only the shape matters here, so skip building the autograd graph.
    with torch.no_grad():
        img = gen(x)
    print(img.shape)
    assert img.shape == x.shape, f"generator changed the shape: {tuple(x.shape)} -> {tuple(img.shape)}"

gen_test()

In [None]:
# Plot for losses
def plot(g_loss, d_loss):
    """Draw the generator and discriminator loss curves on one figure.

    Args:
        g_loss: Sequence of generator loss values, one per training iteration.
        d_loss: Sequence of discriminator loss values, one per training iteration.
    """
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")

    for series, series_label in ((g_loss, "Generator"), (d_loss, "Discriminator")):
        plt.plot(series, label=series_label)

    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
    
# Saving the models
def save_checkpoint(model, optimizer, file_name = "my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``file_name``.

    Args:
        model: Module whose ``state_dict()`` is stored under ``"state_dict"``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``"optimizer"``.
        file_name: Destination path passed to ``torch.save``.
    """
    print("SAVING CHECKPOINT !")

    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        file_name,
    )


In [None]:
print(f"Device: {DEVICE}")

# Enable cuDNN's convolution autotuner.
torch.backends.cudnn.benchmark = True

# Naming: *_M networks deal with Monet images, *_P networks with photos.
# gen_M turns photos into Monets and disc_M judges Monets (and vice versa).
disc_P = Discriminator(in_channels=3).to(DEVICE)
disc_M = Discriminator(in_channels=3).to(DEVICE)
gen_P = Generator(in_channels=3, num_residuals=9).to(DEVICE)
gen_M = Generator(in_channels=3, num_residuals=9).to(DEVICE)

# One optimizer per role, each covering both networks of that role. Both
# parameter lists are built the same way (itertools is imported with the other
# imports at the top of the notebook).
opt_disc = optim.Adam(
    itertools.chain(disc_P.parameters(), disc_M.parameters()),
    lr=lr_disc,
    betas=(0.5, 0.999)
)
opt_gen = optim.Adam(
    itertools.chain(gen_P.parameters(), gen_M.parameters()),
    lr=lr_gen,
    betas=(0.5, 0.999)
)

# L1 is used for the cycle and identity terms, MSE for the adversarial terms.
L1_loss = nn.L1Loss()
mse_loss = nn.MSELoss()

In [None]:
# Epoch index after which the learning rate starts to decay.
decay_epoch = 20


def lambda_func(epoch):
    """Learning-rate multiplier: 1 up to ``decay_epoch``, then a linear ramp down.

    The ramp is scaled so the multiplier would reach 0 at ``epoch == num_epochs``.
    """
    epochs_into_decay = max(0, epoch - decay_epoch)
    return 1 - epochs_into_decay / (num_epochs - decay_epoch)


# The same multiplier schedule drives both optimizers.
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(opt_disc, lr_lambda=lambda_func)
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(opt_gen, lr_lambda=lambda_func)

In [None]:
# Training

# One gradient scaler per optimizer for the mixed-precision (autocast) passes.
d_scaler = torch.cuda.amp.GradScaler()
g_scaler = torch.cuda.amp.GradScaler()


# Per-iteration loss history, consumed by plot() after training.
D_LOSS = []
G_LOSS = []
for epoch in range(num_epochs):
    loop = tqdm(train_loader, leave=True)
    # --------------------------------------------------------------------------------------------------
    print(f"EPOCH {epoch+1}")
    for idx, batch in enumerate(loop):
        m, p = batch
        m, p = m.to(DEVICE), p.to(DEVICE)

        # >>> DISCRIMINATOR

        with torch.cuda.amp.autocast():
            # FIRST DISCRIMINATOR
            # Fakes are detached so the discriminator loss does not
            # back-propagate into the generators.
            fake_m = gen_M(p)
            d_m_real = disc_M(m)
            d_m_fake = disc_M(fake_m.detach())

            # Target 1 for real images, 0 for generated ones.
            d_m_real_loss = mse_loss(d_m_real, torch.ones_like(d_m_real))
            d_m_fake_loss = mse_loss(d_m_fake, torch.zeros_like(d_m_fake))
            d_m_loss = d_m_real_loss + d_m_fake_loss

            # SECOND DISCRIMINATOR
            fake_p = gen_P(m)
            d_p_real = disc_P(p)
            d_p_fake = disc_P(fake_p.detach())

            d_p_real_loss = mse_loss(d_p_real, torch.ones_like(d_p_real))
            d_p_fake_loss = mse_loss(d_p_fake, torch.zeros_like(d_p_fake))
            d_p_loss = d_p_real_loss + d_p_fake_loss

            # PUT IT TOGETHER
            d_loss = (d_m_loss + d_p_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(d_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # >>> GENERATOR

        with torch.cuda.amp.autocast():
            # Adversarial Loss
            # Re-score the (non-detached) fakes with the just-updated
            # discriminators; the generators want them classified as real (1).
            d_m_fake = disc_M(fake_m)
            d_p_fake = disc_P(fake_p)
            g_m_loss = mse_loss(d_m_fake, torch.ones_like(d_m_fake))
            g_p_loss = mse_loss(d_p_fake, torch.ones_like(d_p_fake))

            # Cycle Loss
            # Translating a fake back must reproduce the image it came from:
            # fake_m came from p, fake_p came from m.
            cycle_m = gen_P(fake_m)
            cycle_p = gen_M(fake_p)

            cycle_m_loss = L1_loss(m, cycle_m)
            cycle_p_loss = L1_loss(p, cycle_p)

            # Identity Loss
            # A generator fed an image already in its target domain should
            # return it unchanged.
            identity_m = gen_M(m)
            identity_p = gen_P(p)

            identity_m_loss = L1_loss(m, identity_m)
            identity_p_loss = L1_loss(p, identity_p)

            # PUT ALL TOGETHER
            g_loss = (
                    g_m_loss
                    + g_p_loss
                    + cycle_m_loss * LAMBDA_CYCLE
                    + cycle_p_loss * LAMBDA_CYCLE
                    + identity_m_loss * LAMBDA_IDENTITY
                    + identity_p_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(g_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        D_LOSS.append(d_loss.item())
        G_LOSS.append(g_loss.item())

        # Refresh the progress-bar readout only every 200 iterations.
        if idx % 200 == 0:
            loop.set_postfix(
                d_loss=d_loss.item(),
                g_loss=g_loss.item(),
            )

    # End-of-epoch preview using the last batch of the epoch.
    with torch.no_grad():
        show_images(img=p[:5], nrow=5, title="Real Photos")
        # fake_m was produced under autocast; cast it before plotting.
        fake_m=fake_m.type(torch.float64)
        show_images(img=fake_m[:5], nrow=5, title="Monets")
    
    lr_scheduler_D.step()
    lr_scheduler_G.step()
    
    # --------------------------------------------------------------------------------------------------
    # Checkpoint every 5th epoch AND after the final epoch. With the previous
    # `epoch % 5 == 0` test alone, the weights of the last epochs (e.g. epochs
    # 32-35 of 35) were never written to disk.
    is_last_epoch = epoch == num_epochs - 1
    if SAVE_MODEL and (epoch % 5 == 0 or is_last_epoch):
        # opt_disc / opt_gen each cover both networks of their role, so the same
        # optimizer state is stored alongside both corresponding models.
        save_checkpoint(disc_M, opt_disc, file_name=CHECKPOINT_DISC_M)
        save_checkpoint(disc_P, opt_disc, file_name=CHECKPOINT_DISC_P)
        save_checkpoint(gen_M, opt_gen, file_name=CHECKPOINT_GEN_M)
        save_checkpoint(gen_P, opt_gen, file_name=CHECKPOINT_GEN_P)

In [None]:
# Plot the per-iteration loss histories collected during training.
plot(g_loss=G_LOSS, d_loss=D_LOSS)

In [None]:
# Folder that receives the generated images (relative to the working directory).
save_dir = '../images'
# exist_ok=True replaces the separate exists() check, so re-running the cell is
# harmless and there is no gap between checking and creating.
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Inference-time preprocessing: tensor conversion plus the same 0.5/0.5
# normalisation used for training. There is no resize step here, so photos are
# fed to the generator at their stored resolution.
generate_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

to_image = transforms.ToPILImage()

gen_M.eval()

data_dir = "/kaggle/input/gan-getting-started/photo_jpg"

files = [os.path.join(data_dir, name) for name in os.listdir(data_dir)]

# Tensor type matching the device the generator was moved to.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

# Inference only: without no_grad every forward pass would build and keep an
# autograd graph for nothing.
with torch.no_grad():
    for i in range(0, len(files), batch_size):
        # read images
        imgs = []
        for j in range(i, min(len(files), i+batch_size)):
            # Force 3 channels so greyscale / RGBA files cannot break the
            # 3-channel Normalize and generator input.
            img = Image.open(files[j]).convert('RGB')
            img = generate_transforms(img)
            imgs.append(img)
        imgs = torch.stack(imgs, 0).type(Tensor)

        # generate
        fake_imgs = gen_M(imgs).detach().cpu()

        # save
        for j in range(fake_imgs.size(0)):
            img = fake_imgs[j].squeeze().permute(1, 2, 0)
            img_arr = img.numpy()
            # Per-image min-max stretch to 0..255. A constant image has zero
            # range; map it to all zeros instead of dividing by zero.
            arr_min = np.min(img_arr)
            arr_range = np.max(img_arr) - arr_min
            if arr_range > 0:
                img_arr = (img_arr - arr_min) * 255 / arr_range
            else:
                img_arr = np.zeros_like(img_arr)
            img_arr = img_arr.astype(np.uint8)

            img = to_image(img_arr)
            # fake_imgs[j] corresponds to files[i+j]; keep the source file name.
            _, name = os.path.split(files[i+j])
            img.save(os.path.join(save_dir, name))

In [None]:
# Zip the folder the images were actually written to. The previous hard-coded
# "/kaggle/images" only matched save_dir ('../images') when the working
# directory happened to be /kaggle/working. shutil is imported with the other
# imports at the top of the notebook.
shutil.make_archive("/kaggle/working/images", 'zip', save_dir)