# Run all the cells

In [None]:
# ! pip install albumentations
from PIL import Image
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm import tqdm_notebook
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Adjust Batch and Num_epochs to match the available compute

In [None]:
# Kaggle input folder; the dataset code reads photo_jpg/ and monet_jpg/ inside it.
data_dir = '/kaggle/input/gan-getting-started'

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    Device = "cuda"
else:
    Device = "cpu"

# Training hyper-parameters.
Batch = 10              # photos (and Monet paintings) per mini-batch
Learning_rate = 1e-5    # used by both Adam optimisers
Lambda_identity = 5     # weight of the identity-mapping L1 terms
Lambda_cycle = 10       # weight of the cycle-consistency L1 terms
Num_workers = 4         # NOTE(review): not currently passed to the DataLoader
Num_epochs = 6
Load_model = True       # NOTE(review): not read anywhere in the visible code
Save_model = True       # NOTE(review): not read anywhere in the visible code

# Training augmentation. The photo ("image") and the Monet painting ("image0")
# of a sample go through one call, so both receive the same random parameters:
# resize to 256x256, random horizontal flip, then scale pixels from [0, 255]
# to [-1, 1] ((x / 255 - 0.5) / 0.5) to match the generator's tanh output.
transforms = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5] * 3, std=[0.5] * 3, max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={'image0': 'image'},
)

# Custom Dataset

In [None]:
class monet_dataset(Dataset):
    """Unpaired dataset that yields one photo and one Monet painting per index.

    Both folders are listed once at construction. The dataset length is the
    larger folder's entry count, and each folder is indexed modulo its own
    size, so the smaller folder is cycled through repeatedly.

    Args:
        root_img: folder containing the photos.
        root_monet: folder containing the Monet paintings.
        transform: optional callable invoked as
            ``transform(image=photo, image0=monet)`` and expected to return a
            mapping with ``"image"`` and ``"image0"`` keys.

    Items are ``(photo, monet)``. Without a transform, both are HxWx3 uint8
    numpy arrays decoded as RGB.
    """

    def __init__(self, root_img, root_monet, transform=None):
        self.root_img = root_img
        self.root_monet = root_monet
        self.transform = transform

        self.img_images = os.listdir(root_img)
        self.monet_images = os.listdir(root_monet)
        self.img_len, self.monet_len = len(self.img_images), len(self.monet_images)
        self.length_dataset = max(self.img_len, self.monet_len)

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def _read_rgb(path):
        """Decode the image at *path* into an RGB numpy array."""
        with Image.open(path) as picture:
            return np.array(picture.convert("RGB"))

    def __getitem__(self, index):
        photo_path = os.path.join(self.root_img, self.img_images[index % self.img_len])
        monet_path = os.path.join(self.root_monet, self.monet_images[index % self.monet_len])

        photo = self._read_rgb(photo_path)
        monet = self._read_rgb(monet_path)

        if not self.transform:
            return photo, monet

        augmented = self.transform(image=photo, image0=monet)
        return augmented["image"], augmented["image0"]

# Generator model for CycleGAN

In [None]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        down: True builds an ``nn.Conv2d`` with reflect padding; False builds an
            ``nn.ConvTranspose2d``.
        use_act: True ends with ``ReLU(inplace=True)``, False with ``nn.Identity``.
        **kwargs: forwarded to the conv layer (kernel_size, stride, padding,
            output_padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            # ConvTranspose2d only supports zero padding, so no padding_mode here.
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        norm = nn.InstanceNorm2d(out_channels)
        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()
        # Keep the Sequential under `conv` so state_dict keys stay conv.0 / conv.1 / conv.2.
        self.conv = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        out = self.conv(x)
        return out

class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 ConvBlocks (the second without activation) added to the input.

    Channel count and spatial size are preserved, so the skip connection is a
    plain element-wise sum.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel_size=3, padding=1)
        second = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Pipeline: 7x7 reflect-padded stem -> two stride-2 ConvBlocks (each halves
    H and W for even sizes) -> ``num_residuals`` ResidualBlocks -> two
    transposed-conv ConvBlocks (each doubles H and W) -> 7x7 conv back to
    ``img_channels`` -> tanh, so outputs lie in (-1, 1).

    Args:
        img_channels: channels of the input and output images.
        num_features: width after the stem; the bottleneck uses 4x this.
        num_residuals: number of residual blocks at the bottleneck.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        f1, f2, f4 = num_features, num_features * 2, num_features * 4

        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, f1, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(f1),
            nn.ReLU(inplace=True),
        )
        self.down_blocks = nn.ModuleList([
            ConvBlock(f1, f2, kernel_size=3, stride=2, padding=1),
            ConvBlock(f2, f4, kernel_size=3, stride=2, padding=1),
        ])
        self.res_blocks = nn.Sequential(*(ResidualBlock(f4) for _ in range(num_residuals)))
        self.up_blocks = nn.ModuleList([
            ConvBlock(f4, f2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            ConvBlock(f2, f1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
        ])
        self.last = nn.Conv2d(f1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        h = self.initial(x)
        for down in self.down_blocks:
            h = down(h)
        h = self.res_blocks(h)
        for up in self.up_blocks:
            h = up(h)
        return torch.tanh(self.last(h))

# Discriminator for CycleGAN

In [None]:
class Block(nn.Module):
    """Discriminator stage: 4x4 reflect-padded conv (padding 1) -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: conv stride (2 halves even spatial sizes; 1 shrinks them by one).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        layers = [conv, nn.InstanceNorm2d(out_channels), nn.LeakyReLU(0.2, inplace=True)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out


class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a 1-channel map of scores in (0, 1).

    Layout: stride-2 4x4 conv + LeakyReLU (no norm) -> one ``Block`` per
    remaining entry of ``features`` (stride 2, except stride 1 for the last
    entry) -> 4x4 conv to a single channel -> sigmoid.

    Args:
        in_channels: channels of the input image.
        features: channel widths of the successive stages; the first entry is
            the width of the initial conv. Any indexable sequence is accepted.
            The default is a tuple so it cannot be mutated across instances.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for i, feature in enumerate(features[1:], start=1):
            # Only the final stage keeps stride 1. Decide by position, not by value,
            # so repeated widths such as (64, 128, 512, 512) still downsample correctly.
            stride = 1 if i == last_index else 2
            layers.append(Block(prev_channels, feature, stride=stride))
            prev_channels = feature
        layers.append(nn.Conv2d(prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

# Function that trains the models for one epoch

In [None]:
def _discriminator_loss(disc, real, fake, mse):
    """Adversarial loss for one discriminator: real -> target 1, fake -> target 0.

    ``fake`` is detached so this loss sends no gradient into the generator.

    Returns:
        (loss, real_scores, fake_scores)
    """
    real_scores = disc(real)
    fake_scores = disc(fake.detach())
    loss = (mse(real_scores, torch.ones_like(real_scores))
            + mse(fake_scores, torch.zeros_like(fake_scores)))
    return loss, real_scores, fake_scores


def train_fn(disc_i, disc_m, gen_m, gen_i, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler, epoch):
    """Run one epoch of CycleGAN training over ``loader``.

    Each iteration does one discriminator update and then one generator update.
    Both run under CUDA autocast, each with its own GradScaler.

    Args:
        disc_i, disc_m: discriminators for the photo and Monet domains.
        gen_m: photo -> Monet generator. gen_i: Monet -> photo generator.
        loader: yields ``(photo_batch, monet_batch)`` pairs.
        opt_disc: optimiser over both discriminators.
        opt_gen: optimiser over both generators.
        l1: loss used for the cycle-consistency and identity terms.
        mse: loss used for the adversarial terms (targets 1 = real, 0 = fake).
        d_scaler, g_scaler: GradScalers for the discriminator / generator steps.
        epoch: used only to name the sample images and checkpoints.

    Side effects:
        Every 100 iterations, writes a fake-photo and a fake-Monet batch to
        ``saved_images/``, which must exist. Every 500 iterations, pickles all
        four models into the working directory, overwriting that epoch's files.
    """
    m_reals = 0
    m_fakes = 0
    loop = tqdm_notebook(loader, leave=True)

    for idx, (img, monet) in enumerate(loop):
        img = img.to(Device)
        monet = monet.to(Device)

        # ---- Discriminator update ----
        with torch.cuda.amp.autocast():
            fake_monet = gen_m(img)
            D_M_loss, D_M_real, D_M_fake = _discriminator_loss(disc_m, monet, fake_monet, mse)
            m_reals += D_M_real.mean().item()
            m_fakes += D_M_fake.mean().item()

            fake_img = gen_i(monet)
            D_I_loss, _, _ = _discriminator_loss(disc_i, img, fake_img, mse)

            D_loss = (D_I_loss + D_M_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator update ----
        with torch.cuda.amp.autocast():
            # Adversarial terms. Each fake is judged by the discriminator of the
            # domain it is meant to imitate, with target 1 ("real").
            D_M_fake = disc_m(fake_monet)
            D_I_fake = disc_i(fake_img)
            loss_G_M = mse(D_M_fake, torch.ones_like(D_M_fake))
            loss_G_I = mse(D_I_fake, torch.ones_like(D_I_fake))

            # Cycle consistency: Monet -> photo -> Monet and photo -> Monet -> photo.
            cycle_monet = gen_m(fake_img)
            cycle_img = gen_i(fake_monet)
            cycle_monet_loss = l1(monet, cycle_monet)
            cycle_img_loss = l1(img, cycle_img)

            # Identity: an image already in a generator's target domain should pass through unchanged.
            identity_monet = gen_m(monet)
            identity_img = gen_i(img)
            identity_monet_loss = l1(monet, identity_monet)
            identity_img_loss = l1(img, identity_img)

            G_loss = (
                loss_G_I
                + loss_G_M
                + Lambda_cycle * (cycle_img_loss + cycle_monet_loss)
                + Lambda_identity * (identity_img_loss + identity_monet_loss)
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 100 == 0:
            # Undo the [-1, 1] normalisation before writing PNGs.
            save_image(fake_img * 0.5 + 0.5, f"saved_images/photo_{epoch}_{idx}.png")
            save_image(fake_monet * 0.5 + 0.5, f"saved_images/monet_{epoch}_{idx}.png")
        if idx % 500 == 0:
            torch.save(gen_i, f'Generator_img_{epoch}.pth')
            torch.save(gen_m, f'Generator_monet_{epoch}.pth')
            torch.save(disc_i, f'Discriminator_img_{epoch}.pth')
            torch.save(disc_m, f'Discriminator_monet_{epoch}.pth')

        # Running means of the Monet discriminator's scores on real / fake images.
        loop.set_postfix(m_real=m_reals / (idx + 1), m_fake=m_fakes / (idx + 1))

# Instantiate the generators and discriminators outside the main function
# main() loads the dataset and trains the models for Num_epochs epochs

In [None]:
# One discriminator per domain (photo "I", Monet "M") and one generator per
# direction (gen_M: photo -> Monet, gen_I: Monet -> photo), built at module
# level and trained in place by main(). Construction order is disc_I, disc_M,
# gen_M, gen_I.
disc_I, disc_M = (Discriminator(in_channels=3).to(Device) for _ in range(2))
gen_M, gen_I = (Generator(img_channels=3).to(Device) for _ in range(2))

def main(disc_I, disc_M, gen_M, gen_I):
    """Train the four CycleGAN networks in place for ``Num_epochs`` epochs.

    Builds one Adam optimiser over both discriminators and one over both
    generators. Loads the photo/Monet dataset from ``data_dir`` with the
    training ``transforms`` and calls ``train_fn`` once per epoch.

    Args:
        disc_I: discriminator for the photo domain.
        disc_M: discriminator for the Monet domain.
        gen_M: photo -> Monet generator.
        gen_I: Monet -> photo generator.
    """
    use_cuda = Device == "cuda"

    opt_disc = optim.Adam(
        list(disc_I.parameters()) + list(disc_M.parameters()),
        lr=Learning_rate,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_I.parameters()) + list(gen_M.parameters()),
        lr=Learning_rate,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    dataset = monet_dataset(
        root_img=os.path.join(data_dir, "photo_jpg"),
        root_monet=os.path.join(data_dir, "monet_jpg"),
        transform=transforms,
    )
    # Num_workers is not passed, so batches load in the main process
    # (DataLoader default num_workers=0). Pinned memory only helps host->GPU
    # copies, so it is enabled only when training on CUDA.
    loader = DataLoader(dataset, batch_size=Batch, shuffle=True, pin_memory=use_cuda)

    # Loss scaling applies only to CUDA mixed precision. Disabling it explicitly
    # on CPU avoids GradScaler's "CUDA is not available" warning; it would
    # disable itself anyway.
    g_scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
    d_scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    for epoch in range(Num_epochs):
        train_fn(disc_I, disc_M, gen_M, gen_I, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler, epoch)

# Create the directory for sample images and run the main function

In [None]:
# train_fn writes sample images into saved_images/ every 100 iterations.
# exist_ok=True avoids a separate exists() check and its check-then-create race.
os.makedirs('saved_images', exist_ok=True)
main(disc_I, disc_M, gen_M, gen_I)

# Note: I was able to train for 15 epochs, so I load the weights from the 15th epoch (file index 14, since epochs are zero-indexed)

In [None]:
# Load the trained photo -> Monet generator from a previous run.
# The checkpoint is a whole pickled nn.Module (saved via torch.save(model, ...)),
# so weights_only=False is required on PyTorch >= 2.6 (the argument exists from
# 1.13). Unpickling can run arbitrary code, so only load trusted files.
# map_location lets a GPU-saved checkpoint load in a CPU-only session.
# NOTE(review): train_fn writes checkpoints to the working directory, not
# saved_models/ - presumably they were moved or uploaded there; confirm the path.
gen_M = torch.load('saved_models/Generator_monet_14.pth', map_location=Device, weights_only=False)

# Checking number of files in photo_jpg directory

In [None]:
# Sanity check: list the photo folder and report how many entries it holds.
photo_dir = os.path.join(data_dir, "photo_jpg")
photo_names = os.listdir(photo_dir)
files = [os.path.join(photo_dir, photo_name) for photo_name in photo_names]
print("Number of files found", len(files))

# Create a dataloader with batch size 1 to save the results

In [None]:
# Inference preprocessing: no resize or flip, only the same [-1, 1]
# normalisation used in training. It gets its own name so the training
# `transforms` global, which main() reads, is not overwritten; re-running
# training after this cell would otherwise silently drop the resize/flip.
inference_transforms = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={'image0': 'image'},
)
dataset = monet_dataset(
    root_img=os.path.join(data_dir, "photo_jpg"),
    root_monet=os.path.join(data_dir, "monet_jpg"),
    transform=inference_transforms,
)

# shuffle=False visits photos in directory-listing order, so the output
# index assigned to each photo is deterministic for a given listing.
loader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=Device == "cuda")

# Generate Monet-style images from the photos and save each one individually

In [None]:
# Translate every photo with the Monet generator and write one PNG per photo.
# save_image does not create parent folders, so make sure the target exists.
os.makedirs("generated_images", exist_ok=True)
gen_M.eval()
loop = tqdm_notebook(loader, leave=True)
with torch.no_grad():  # inference only: don't build autograd graphs
    for idx, (img, _) in enumerate(loop):
        img = img.to(Device)
        monet = gen_M(img).cpu()
        monet = monet * 0.5 + 0.5  # map tanh output from [-1, 1] back to [0, 1]
        save_image(monet, f"generated_images/photo_to_monet_{idx}.png")

# Zip the generated images

In [None]:
import shutil

# Pack the contents of generated_images/ into generated_images.zip in the
# working directory; the resulting archive path is the cell's output.
shutil.make_archive(base_name='generated_images', format='zip', root_dir='generated_images')