# Paint Like Monet? AI Art Challenge with GANs

In this competition, our aim is to use Generative Adversarial Networks (GANs) to reproduce Claude Monet's distinctive artistic style in digital images. A GAN consists of a generator and a discriminator: the generator creates images in the style of Monet, while the discriminator tries to distinguish authentic paintings from generated ones. Our objective is to produce between 7,000 and 10,000 Monet-style images, which are scored with the **MiFID (Memorization-informed Fréchet Inception Distance)** metric (lower is better).

GANs are a powerful class of generative models built from two neural networks trained against each other. The generator learns to produce realistic samples, while the discriminator learns to tell real samples from generated ones. This back-and-forth pushes the generator toward increasingly realistic output.
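To make the scoring metric more concrete, here is a minimal NumPy sketch of a MiFID-style score computed from precomputed feature vectors. It is an illustration under assumptions, not the official scorer: the real metric uses Inception network features (not computed here), the penalty follows our reading of the competition's evaluation description (FID divided by the thresholded mean minimum cosine distance between generated and real features), the same reference set is used for both terms, and `frechet_distance`, `memorization_distance`, `mifid` and the threshold `eps` are hypothetical names and parameters.

```python
import numpy as np


def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two feature sets shaped [N, D]."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.atleast_2d(np.cov(feats_a, rowvar=False))
    cov_b = np.atleast_2d(np.cov(feats_b, rowvar=False))
    # Tr((cov_a @ cov_b)^(1/2)) = sum of square roots of the eigenvalues of cov_a @ cov_b,
    # which are real and non-negative for covariance matrices (up to round-off).
    eigvals = np.linalg.eigvals(cov_a @ cov_b).real
    tr_sqrt = np.sqrt(np.clip(eigvals, 0.0, None)).sum()
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_sqrt)


def memorization_distance(feats_train, feats_gen):
    """Average over generated samples of the smallest cosine distance to any training sample."""
    a = feats_gen / np.linalg.norm(feats_gen, axis=1, keepdims=True)
    b = feats_train / np.linalg.norm(feats_train, axis=1, keepdims=True)
    cos_dist = np.clip(1.0 - a @ b.T, 0.0, 2.0)   # shape [N_gen, N_train]
    return float(cos_dist.min(axis=1).mean())


def mifid(feats_train, feats_gen, eps):
    """FID, divided by the memorization distance when that distance falls below eps."""
    fid = frechet_distance(feats_train, feats_gen)
    d = memorization_distance(feats_train, feats_gen)
    d_thr = d if d < eps else 1.0
    return fid / d_thr if d_thr > 0 else float("inf")


rng = np.random.default_rng(0)
real = rng.normal(size=(500, 4))
shift = np.array([1.0, 2.0, 0.0, 0.0])
print(frechet_distance(real, real + shift))   # close to |shift|^2 = 5
```

Shifting every feature vector leaves the covariance unchanged, so the distance reduces to the squared norm of the shift. Rescaling the real features (`2 * real`) keeps their directions, so the memorization distance collapses to about zero and the score blows up, which is exactly the "copying the training set" behavior the penalty is meant to catch.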

## DATA

The dataset is organized into four directories: `monet_tfrec`, `photo_tfrec`, `monet_jpg`, and `photo_jpg`. The `monet_tfrec` and `monet_jpg` directories contain the same paintings, and `photo_tfrec` and `photo_jpg` contain the same photos.

The competition recommends TFRecords as an opportunity to get familiar with that data format, but JPEG versions are also provided for convenience. This notebook uses the JPEG files.

Train your model on the images in the Monet directories; these paintings are the examples from which it learns the style.

The photo directories contain photos to be rendered in Monet's style. The task is to transform these photos and submit the resulting JPEGs in a single zip file containing no more than 10,000 images.

It's worth noting that Monet-style art can be crafted from scratch using alternative GAN architectures such as DCGAN. Therefore, the submitted image files do not necessarily have to be transformed photos from the provided dataset.

https://www.kaggle.com/competitions/gan-getting-started/data

In [None]:
import os
import math
import time
import shutil

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, random_split, DataLoader

import torchvision.models as models
import torchvision.transforms as transforms

from tqdm.notebook import tqdm
import itertools

In [None]:
BASE_INPUT_PATH = '/kaggle/input/gan-getting-started/'

MONET_PATH = os.path.join(BASE_INPUT_PATH, "monet_jpg")
PHOTO_PATH = os.path.join(BASE_INPUT_PATH, "photo_jpg")
OUTPUT_PATH = '/kaggle/images'   # generated images are written here before zipping

print(f"Monet paintings count: {len(os.listdir(MONET_PATH))}")
print(f"Photos count: {len(os.listdir(PHOTO_PATH))}")

## Exploratory Data Analysis (EDA)

* Monet paintings and photos are 256x256 RGB images in JPEG format.

In [None]:
def display_images_grid(directory_path, num_samples=9):
    """
    Display the first `num_samples` images found in `directory_path` in a grid.
    """
    image_files = os.listdir(directory_path)[:num_samples]
    
    num_cols = int(math.sqrt(num_samples))
    num_rows = math.ceil(num_samples / num_cols)
    
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12))
    # Flatten to 1-D so indexing works for every grid shape (1x1, Nx1 or NxM)
    axes = np.atleast_1d(axes).ravel()
    
    for ax in axes:
        ax.axis('off')
    
    for ax, image_name in zip(axes, image_files):
        img = cv2.imread(os.path.join(directory_path, image_name))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
        ax.imshow(img)
    
    plt.tight_layout()
    plt.show()

### Sample Monet Images

In [None]:
display_images_grid(MONET_PATH)

### Sample Photos

In [None]:
display_images_grid(PHOTO_PATH)

## Model Building

A Generative Adversarial Network (GAN) is a model in which a generator and a discriminator compete, driving the generator to produce realistic data. In this project, we use CycleGAN, a GAN variant designed for image-to-image translation when paired training examples are not available. CycleGAN uses two generators and two discriminators to translate images between two domains while enforcing cycle consistency: an image translated to the other domain and back again should closely match the original.

This approach is widely used for tasks such as style transfer and image transformation.

### Dataset

We define a dataset class that reads from the Monet and photo directories. It loads each image, resizes it and optionally normalizes it to [-1, 1], so that every input to the network has the same format. Because the two domains are unpaired, `__getitem__` returns the photo at the requested index together with a *randomly chosen* Monet painting, both as tensors. `__len__` returns the size of the smaller of the two sets, which determines how many photo/painting pairs make up one epoch.

In [None]:
class ImageDataset(Dataset):
    """
    Unpaired dataset: each item is (photo, randomly chosen Monet painting).
    """
    def __init__(self, path_monet, path_photo, size=(256, 256), normalize=True):
        super().__init__()
        self.monet_dir = path_monet
        self.photo_dir = path_photo
        self.monet_files = os.listdir(self.monet_dir)
        self.photo_files = os.listdir(self.photo_dir)
        
        steps = [transforms.Resize(size), transforms.ToTensor()]
        if normalize:
            # Map pixel values from [0, 1] to [-1, 1] to match the generator's Tanh output
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)
            
    def __getitem__(self, idx):
        """
        Return the photo at `idx` together with a randomly chosen Monet painting.
        """
        photo_path = os.path.join(self.photo_dir, self.photo_files[idx % len(self.photo_files)])
        rand_idx = np.random.randint(0, len(self.monet_files))
        monet_path = os.path.join(self.monet_dir, self.monet_files[rand_idx])
        
        photo_img = self.transform(Image.open(photo_path).convert("RGB"))
        monet_img = self.transform(Image.open(monet_path).convert("RGB"))
        
        return photo_img, monet_img

    def __len__(self):
        # One epoch covers as many items as the smaller of the two image sets
        return min(len(self.monet_files), len(self.photo_files))
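
One consequence of this indexing is easy to miss: the DataLoader only requests indices below `len(dataset)`, so each epoch visits only the first `len(dataset)` photos in directory-listing order (shuffling changes their order, not which photos are used). The dependency-free sketch below mimics that rule; `epoch_pairs` is a hypothetical helper, and the counts 300 and 7038 are assumed for illustration, so substitute the actual directory sizes.

```python
import random


def epoch_pairs(n_monet, n_photo, seed=0):
    """Mimic ImageDataset indexing for one epoch: a list of (photo_idx, monet_idx) pairs."""
    rng = random.Random(seed)
    epoch_len = min(n_monet, n_photo)                 # same rule as __len__
    return [(idx % n_photo, rng.randrange(n_monet))   # photo by index, Monet at random
            for idx in range(epoch_len)]


# Assumed counts for illustration only
pairs = epoch_pairs(n_monet=300, n_photo=7038)
photos_seen = {p for p, _ in pairs}
print(len(pairs), min(photos_seen), max(photos_seen))   # 300 0 299
```

If every photo should contribute during training, one option is to base `__len__` on the photo directory instead, at the cost of much longer epochs.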


In [None]:
image_ds = ImageDataset(MONET_PATH, PHOTO_PATH)
print(f"Length: {len(image_ds)}")

image_dl = DataLoader(image_ds, batch_size=1, shuffle=True, pin_memory=True)

# Use the GPU if one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### CycleGAN

**CycleGAN** is a type of neural network used for image-to-image translation tasks. It can transform images from one domain to another without needing paired examples for training. Instead of pairs, it uses cycle-consistency, where images translated back and forth between domains should resemble the originals. This approach makes it useful for various tasks like style transfer or changing seasons in images.
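The cycle-consistency idea can be checked on a toy problem. In the NumPy sketch below, an invertible linear map stands in for the photo-to-Monet generator G, and two candidate maps stand in for the Monet-to-photo generator F. The linear maps are purely illustrative assumptions (CycleGAN's G and F are neural networks); the L1 cycle loss mirrors the idea of the `nn.L1Loss` terms used later in training.

```python
import numpy as np

rng = np.random.default_rng(0)
photos = rng.normal(size=(8, 3))   # 8 toy "photos", 3 features each

# Toy photo -> Monet mapping G: an invertible linear map standing in for a generator
A = np.array([[2.0, 0.0, 0.0],
              [0.0, 0.5, 0.0],
              [1.0, 0.0, 1.0]])


def G(x):
    return x @ A.T


def F_inverse(y):
    # Monet -> photo mapping that exactly undoes G: cycle-consistent
    return y @ np.linalg.inv(A).T


def F_identity(y):
    # Monet -> photo mapping that ignores G: not cycle-consistent
    return y


def cycle_loss(x, G, F):
    """L1 cycle-consistency loss: mean |F(G(x)) - x|."""
    return float(np.abs(F(G(x)) - x).mean())


print(cycle_loss(photos, G, F_inverse))    # approximately 0
print(cycle_loss(photos, G, F_identity))   # clearly greater than 0
```

Only the mapping that undoes G drives the loss to zero, which is the pressure the cycle term puts on the two generators.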

The helper below reverses `transforms.Normalize`, mapping pixel values from [-1, 1] back to [0, 1] so that images can be displayed and saved.

In [None]:
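def reverse_normalize(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """
    Undo transforms.Normalize on a batch of images shaped [B, C, H, W].
    Returns a new tensor with values clamped to [0, 1].
    """
    # Reshape the statistics to [1, C, 1, 1] so they apply per channel to every image
    mean = torch.tensor(mean, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
    std = torch.tensor(std, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
    return (img * std + mean).clamp(0, 1)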
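
The normalization round trip can be checked without PyTorch. This NumPy sketch assumes the same per-channel mean and std of 0.5; `normalize` and `denormalize` are hypothetical helper names. Broadcasting the statistics with shape (1, C, 1, 1) applies them per channel across the whole batch, independent of the batch size.

```python
import numpy as np

MEAN = np.array([0.5, 0.5, 0.5]).reshape(1, -1, 1, 1)   # broadcast over [B, C, H, W]
STD = np.array([0.5, 0.5, 0.5]).reshape(1, -1, 1, 1)


def normalize(batch):
    return (batch - MEAN) / STD    # [0, 1] -> [-1, 1], as transforms.Normalize does


def denormalize(batch):
    return batch * STD + MEAN      # [-1, 1] -> [0, 1]


rng = np.random.default_rng(0)
batch = rng.random((4, 3, 8, 8))   # 4 tiny RGB images with values in [0, 1)
z = normalize(batch)
print(z.min() >= -1, z.max() <= 1)          # True True
print(np.allclose(denormalize(z), batch))   # True
```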

In [None]:
def Upsample(in_ch, out_ch, use_dropout=True, dropout_ratio=0.5):
    """
    2x upsampling block: transposed conv -> InstanceNorm -> (optional Dropout) -> GELU.
    """
    layers = [
        # kernel 3, stride 2, padding 1, output_padding 1 exactly doubles height and width
        nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1),
        nn.InstanceNorm2d(out_ch),
    ]
    if use_dropout:
        layers.append(nn.Dropout(dropout_ratio))
    layers.append(nn.GELU())
    return nn.Sequential(*layers)

In [None]:
def Convlayer(in_ch, out_ch, kernel_size=3, stride=2, use_leaky=True, use_inst_norm=True, use_pad=True):
    """
    Conv -> normalization -> activation block.
    use_pad: zero padding of 1 (pass False when a ReflectionPad2d precedes the block).
    use_leaky: LeakyReLU(0.2) if True, otherwise GELU.
    use_inst_norm: InstanceNorm2d if True, otherwise BatchNorm2d.
    """
    padding = 1 if use_pad else 0
    conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=True)
    actv = nn.LeakyReLU(negative_slope=0.2, inplace=True) if use_leaky else nn.GELU()
    norm = nn.InstanceNorm2d(out_ch) if use_inst_norm else nn.BatchNorm2d(out_ch)
    return nn.Sequential(conv, norm, actv)

In [None]:
class Resblock(nn.Module):
    """
    Residual block: two reflection-padded 3x3 convs with InstanceNorm; output = x + F(x).
    """
    def __init__(self, in_features, use_dropout=True, dropout_ratio=0.5):
        super().__init__()
        layers = list()
        layers.append(nn.ReflectionPad2d(1))
        layers.append(Convlayer(in_features, in_features, 3, 1, False, use_pad=False))
        if use_dropout:
            layers.append(nn.Dropout(dropout_ratio))
        layers.append(nn.ReflectionPad2d(1))
        layers.append(nn.Conv2d(in_features, in_features, 3, 1, padding=0, bias=True))
        layers.append(nn.InstanceNorm2d(in_features))
        self.res = nn.Sequential(*layers)

    def forward(self, x):
        # Skip connection: the block learns a residual on top of its input
        return x + self.res(x)

In [None]:
class Generator(nn.Module):
    def __init__(self, in_ch, out_ch, num_res_blocks=6):
        super().__init__()
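        model = list()
        # Initial 7x7 convolution (reflection padding keeps the spatial size)
        model.append(nn.ReflectionPad2d(3))
        model.append(Convlayer(in_ch, 64, 7, 1, False, True, False))
        # Downsampling: 256 -> 128 -> 64, channels 64 -> 128 -> 256
        model.append(Convlayer(64, 128, 3, 2, False))
        model.append(Convlayer(128, 256, 3, 2, False))
        # Residual blocks at the bottleneck resolution
        for _ in range(num_res_blocks):
            model.append(Resblock(256))
        # Upsampling back to the input resolution
        model.append(Upsample(256, 128))
        model.append(Upsample(128, 64))
        # Output layer: map to RGB in [-1, 1] with Tanh
        model.append(nn.ReflectionPad2d(3))
        model.append(nn.Conv2d(64, out_ch, kernel_size=7, padding=0))
        model.append(nn.Tanh())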

        self.gen = nn.Sequential(*model)

    def forward(self, x):
        return self.gen(x)
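
The generator is built so that its output has the same spatial size as its input. The sketch below traces a 256x256 image through each layer using the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1); `conv_out` and `conv_transpose_out` are hypothetical helper names.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Spatial size after nn.ConvTranspose2d (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


size = 256
size = conv_out(size + 2 * 3, kernel=7)                # ReflectionPad2d(3) + 7x7 conv -> 256
size = conv_out(size, kernel=3, stride=2, padding=1)   # -> 128
size = conv_out(size, kernel=3, stride=2, padding=1)   # -> 64
bottleneck = size
# Each residual block pads by 1 and applies a 3x3 conv, so it keeps the size
assert conv_out(bottleneck + 2, kernel=3) == bottleneck
size = conv_transpose_out(size, 3, stride=2, padding=1, output_padding=1)   # -> 128
size = conv_transpose_out(size, 3, stride=2, padding=1, output_padding=1)   # -> 256
size = conv_out(size + 2 * 3, kernel=7)                # ReflectionPad2d(3) + 7x7 conv -> 256
print(bottleneck, size)   # 64 256
```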

In [None]:
class Discriminator(nn.Module):
    """
    PatchGAN-style discriminator: outputs a grid of real/fake scores, one per image patch.
    """
    def __init__(self, in_ch, num_layers=4):
        super().__init__()
        model = list()
        model.append(nn.Conv2d(in_ch, 64, 4, stride=2, padding=1))
        model.append(nn.LeakyReLU(0.2, inplace=True))
        for i in range(1, num_layers):
            in_chs = 64 * 2**(i-1)
            out_chs = in_chs * 2
            # The last block keeps the spatial size (stride 1); earlier blocks halve it
            stride = 1 if i == num_layers - 1 else 2
            model.append(Convlayer(in_chs, out_chs, 4, stride))
        # Channel count after the loop (512 for the default num_layers=4)
        final_ch = 64 * 2**(num_layers - 1)
        model.append(nn.Conv2d(final_ch, 1, kernel_size=4, stride=1, padding=1))
        self.disc = nn.Sequential(*model)

    def forward(self, x):
        return self.disc(x)
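
With the default `num_layers=4`, the discriminator does not return a single score per image but a grid of scores. The sketch below computes that grid size for a 256x256 input, plus the receptive field of one score (how many input pixels it can see), using the `nn.Conv2d` output-size formula. The `conv_out` helper and the `disc_convs` table are written for illustration and assume the layer settings defined above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each conv in Discriminator(in_ch, num_layers=4)
disc_convs = [(4, 2, 1),   # first conv, 3 -> 64 channels
              (4, 2, 1),   # Convlayer 64 -> 128
              (4, 2, 1),   # Convlayer 128 -> 256
              (4, 1, 1),   # Convlayer 256 -> 512 (last block, stride 1)
              (4, 1, 1)]   # final conv 512 -> 1

size = 256
sizes = []
for k, s, p in disc_convs:
    size = conv_out(size, k, s, p)
    sizes.append(size)

# Receptive field of one output score, walking back from the last layer
rf = 1
for k, s, _ in reversed(disc_convs):
    rf = (rf - 1) * s + k

print(sizes, rf)   # [128, 64, 32, 31, 30] 70
```

Each of the 30x30 scores therefore judges an overlapping 70x70 patch, which is why the real/fake targets in training are built with `torch.ones(...)`/`torch.zeros(...)` of the discriminator's output size rather than a single label.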

In [None]:
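def init_weights(net, init_type='normal', std=0.02):
    """
    Initialize Conv/Linear weights from N(0, std) and their biases to 0
    (BatchNorm weights from N(1, std)). Only the 'normal' scheme is implemented;
    init_type is currently ignored.
    """
    def init_func(m):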
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            init.normal_(m.weight.data, 0.0, std)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, std)
            init.constant_(m.bias.data, 0.0)
    net.apply(init_func)

The "sample fake" mechanism is the image buffer (history pool) used in CycleGAN: instead of updating the discriminators only on the images the generators have just produced, we update them on a mix of new and previously generated images. This keeps the discriminators from overfitting to the generators' most recent output and reduces oscillation between the two networks, which makes training more stable.

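The class below stores up to 50 generated images. Until the buffer is full, each new image is stored and returned unchanged. After that, each new image has a 50% chance of being swapped for a randomly chosen stored image (the stored one is returned and the new one takes its place); otherwise the new image is returned directly.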

In [None]:
class sample_fake(object):
    def __init__(self, max_imgs=50):
        self.max_imgs = max_imgs
        self.cur_img = 0
        self.imgs = list()

    def __call__(self, imgs):
        ret = list()
        for img in imgs:
            if self.cur_img < self.max_imgs:
                # Buffer not full yet: store the new image and return it unchanged
                self.imgs.append(img)
                ret.append(img)
                self.cur_img += 1
            else:
                if np.random.rand() > 0.5:
                    # Return a random older image and store the new one in its place
                    idx = np.random.randint(0, self.max_imgs)
                    ret.append(self.imgs[idx])
                    self.imgs[idx] = img
                else:
                    # Return the new image without touching the buffer
                    ret.append(img)
        return ret
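
The buffer's behavior is easier to check with a seeded, dependency-free re-implementation of the same rule. In this sketch `ImagePool` is a hypothetical name, integers stand in for images, and Python's `random.Random` replaces `np.random` so the run is reproducible.

```python
import random


class ImagePool:
    """Same replacement rule as sample_fake, with a seedable RNG."""

    def __init__(self, max_imgs=50, seed=None):
        self.max_imgs = max_imgs
        self.imgs = []
        self.rng = random.Random(seed)

    def __call__(self, imgs):
        ret = []
        for img in imgs:
            if len(self.imgs) < self.max_imgs:
                # Buffer not full yet: store the new image and pass it through
                self.imgs.append(img)
                ret.append(img)
            elif self.rng.random() > 0.5:
                # Return an older image and replace it with the new one
                idx = self.rng.randrange(self.max_imgs)
                ret.append(self.imgs[idx])
                self.imgs[idx] = img
            else:
                # Pass the new image through unchanged
                ret.append(img)
        return ret


pool = ImagePool(max_imgs=5, seed=0)
outputs = [pool([step])[0] for step in range(100)]   # step number = "image" produced at that step
print(outputs[:5])   # [0, 1, 2, 3, 4] while the buffer fills
```

Once the buffer is full, some outputs are older than the current step, so the discriminator keeps seeing fakes from earlier generator states.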

In [None]:

def update_req_grad(nets, requires_grad=True):
    """
    Enable or disable gradient computation for every parameter of the given networks.
    """
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad

Learning rate scheduling adjusts the learning rate during training. Here we use a linear decay schedule, as in the original CycleGAN setup: the learning rate stays at its initial value for the first `decay_epochs` epochs and then decreases linearly toward zero, which it would reach at epoch `total_epochs`. Smaller steps late in training help the optimizers settle instead of oscillating.

In [None]:
class learning_rate_scheduling():
    def __init__(self, decay_epochs=100, total_epochs=200):
        self.decay_epochs = decay_epochs
        self.total_epochs = total_epochs

    def step(self, epoch_num):
        if epoch_num <= self.decay_epochs:
            return 1.0
        else:
            fract = (epoch_num - self.decay_epochs)  / (self.total_epochs - self.decay_epochs)
            return 1.0 - fract
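
`torch.optim.lr_scheduler.LambdaLR` sets the learning rate to `base_lr * lr_lambda(n)`, where `n` counts the scheduler's `step()` calls; in this notebook `step()` is called once per epoch. The pure-Python sketch below reproduces the multiplier rule of `learning_rate_scheduling.step` (as the hypothetical function `linear_decay_multiplier`) and prints the resulting schedule for an assumed 10-epoch run that starts decaying after epoch 5.

```python
def linear_decay_multiplier(epoch, decay_epochs, total_epochs):
    """Same rule as learning_rate_scheduling.step: 1.0, then linear decay toward 0."""
    if epoch <= decay_epochs:
        return 1.0
    return 1.0 - (epoch - decay_epochs) / (total_epochs - decay_epochs)


base_lr = 2e-4   # start_lr used by the CycleGAN class
schedule = [base_lr * linear_decay_multiplier(e, decay_epochs=5, total_epochs=10)
            for e in range(10)]
for e, lr in enumerate(schedule):
    print(f"epoch {e}: lr = {lr:.2e}")
```

Note that `CycleGAN` uses `decay_epoch = epochs // 2` by default, so the short 5-epoch run below starts decaying after epoch index 2.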

The **AvgStats** class records, for each epoch, the average loss and the time that epoch took. These records are used to plot loss curves and monitor training progress.

In [None]:
class AvgStats(object):
    """
    Per-epoch training metrics: average loss and epoch duration in seconds.
    """
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.losses = []
        self.epoch_times = []
        
    def append(self, loss, epoch_time):
        self.losses.append(loss)
        self.epoch_times.append(epoch_time)

The following class wraps the **CycleGAN** models and their training loop. It trains two generators (`gen_ptm`: photo → Monet, `gen_mtp`: Monet → photo) and two discriminators (`desc_m` for Monet images, `desc_p` for photos), applies the learning rate schedule, and records training statistics. In each iteration, the generators are updated first while the discriminators are frozen, and then the discriminators are updated.
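For reference, the losses computed in `train()` can be summarized as follows. Write $G$ for `gen_ptm` (photo → Monet), $F$ for `gen_mtp` (Monet → photo), $D_M$ for `desc_m`, $D_P$ for `desc_p`, $p$ for a photo, $m$ for a Monet painting, $\lambda$ for `lmbda` and $\alpha$ for `idt_coef`. Here $\lVert\cdot\rVert_1$ stands for `nn.L1Loss` (mean absolute error), each squared term is `nn.MSELoss` averaged over the discriminator's output grid, and $\tilde m$, $\tilde p$ are fakes drawn from the image buffers. This is a reading of the code below, not an independent definition.

$$
\begin{aligned}
\mathcal{L}_{G,F} ={}& \big(D_M(G(p)) - 1\big)^2 + \big(D_P(F(m)) - 1\big)^2 \\
&+ \lambda\big(\lVert G(F(m)) - m\rVert_1 + \lVert F(G(p)) - p\rVert_1\big) \\
&+ \lambda\alpha\big(\lVert G(m) - m\rVert_1 + \lVert F(p) - p\rVert_1\big)
\end{aligned}
$$

$$
\mathcal{L}_{D_M} = \tfrac{1}{2}\Big(\big(D_M(m) - 1\big)^2 + D_M(\tilde m)^2\Big), \qquad
\mathcal{L}_{D_P} = \tfrac{1}{2}\Big(\big(D_P(p) - 1\big)^2 + D_P(\tilde p)^2\Big)
$$

With the defaults $\lambda = 10$ and $\alpha = 0.5$, the cycle terms carry weight 10 and the identity terms weight 5, relative to weight 1 for the adversarial terms.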

In [None]:
class CycleGAN(object):
    def __init__(self, in_ch, out_ch, epochs, start_lr=2e-4, lmbda=10, idt_coef=0.5, decay_epoch=0):
        self.epochs = epochs
        self.decay_epoch = decay_epoch if decay_epoch > 0 else int(self.epochs/2)
        self.lmbda = lmbda
        self.idt_coef = idt_coef
        self.device = torch.device(device)
        self.gen_mtp = Generator(in_ch, out_ch)
        self.gen_ptm = Generator(in_ch, out_ch)
        self.desc_m = Discriminator(in_ch)
        self.desc_p = Discriminator(in_ch)
        self.init_models()
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()
        self.adam_gen = torch.optim.Adam(itertools.chain(self.gen_mtp.parameters(), self.gen_ptm.parameters()),
                                         lr = start_lr, betas=(0.5, 0.999))
        self.adam_desc = torch.optim.Adam(itertools.chain(self.desc_m.parameters(), self.desc_p.parameters()),
                                          lr=start_lr, betas=(0.5, 0.999))
        self.sample_monet = sample_fake()
        self.sample_photo = sample_fake()
        gen_lr = learning_rate_scheduling(self.decay_epoch, self.epochs)
        desc_lr = learning_rate_scheduling(self.decay_epoch, self.epochs)
        self.gen_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_gen, gen_lr.step)
        self.desc_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_desc, desc_lr.step)
        self.gen_stats = AvgStats()
        self.desc_stats = AvgStats()
        
    def init_models(self):
        init_weights(self.gen_mtp)
        init_weights(self.gen_ptm)
        init_weights(self.desc_m)
        init_weights(self.desc_p)
        self.gen_mtp = self.gen_mtp.to(self.device)
        self.gen_ptm = self.gen_ptm.to(self.device)
        self.desc_m = self.desc_m.to(self.device)
        self.desc_p = self.desc_p.to(self.device)
        
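    def train(self, photo_dl):
        # Put all networks in training mode (enables dropout) in case eval() was called earlier
        for net in (self.gen_mtp, self.gen_ptm, self.desc_m, self.desc_p):
            net.train()
        for epoch in range(self.epochs):
            start_time = time.time()
            avg_gen_loss = 0.0
            avg_desc_loss = 0.0
            t = tqdm(photo_dl, leave=False, total=len(photo_dl))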
            for i, (photo_real, monet_real) in enumerate(t):
                photo_img, monet_img = photo_real.to(self.device), monet_real.to(self.device)
                update_req_grad([self.desc_m, self.desc_p], False)
                self.adam_gen.zero_grad()

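                # Forward pass: translate to the other domain (fake_*), translate back (cycl_*),
                # and map each image into its own domain (id_*) for the identity loss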
                fake_photo = self.gen_mtp(monet_img)
                fake_monet = self.gen_ptm(photo_img)

                cycl_monet = self.gen_ptm(fake_photo)
                cycl_photo = self.gen_mtp(fake_monet)

                id_monet = self.gen_ptm(monet_img)
                id_photo = self.gen_mtp(photo_img)

                # Generator losses: L1 identity terms (weight lmbda * idt_coef) and cycle terms (weight lmbda)
                idt_loss_monet = self.l1_loss(id_monet, monet_img) * self.lmbda * self.idt_coef
                idt_loss_photo = self.l1_loss(id_photo, photo_img) * self.lmbda * self.idt_coef

                cycle_loss_monet = self.l1_loss(cycl_monet, monet_img) * self.lmbda
                cycle_loss_photo = self.l1_loss(cycl_photo, photo_img) * self.lmbda

                # Adversarial (least-squares) losses: the generators want D to output 1 for their fakes
                monet_desc = self.desc_m(fake_monet)
                photo_desc = self.desc_p(fake_photo)

                real = torch.ones(monet_desc.size()).to(self.device)

                adv_loss_monet = self.mse_loss(monet_desc, real)
                adv_loss_photo = self.mse_loss(photo_desc, real)

                # total generator loss
                total_gen_loss = cycle_loss_monet + adv_loss_monet\
                              + cycle_loss_photo + adv_loss_photo\
                              + idt_loss_monet + idt_loss_photo
                
                avg_gen_loss += total_gen_loss.item()

                # backward pass
                total_gen_loss.backward()
                self.adam_gen.step()

                # Discriminator update: re-enable gradients and draw fakes from the image buffers
                update_req_grad([self.desc_m, self.desc_p], True)
                self.adam_desc.zero_grad()

                fake_monet = self.sample_monet([fake_monet.detach().cpu().numpy()])[0]
                fake_photo = self.sample_photo([fake_photo.detach().cpu().numpy()])[0]
                fake_monet = torch.tensor(fake_monet).to(self.device)
                fake_photo = torch.tensor(fake_photo).to(self.device)

                monet_desc_real = self.desc_m(monet_img)
                monet_desc_fake = self.desc_m(fake_monet)
                photo_desc_real = self.desc_p(photo_img)
                photo_desc_fake = self.desc_p(fake_photo)

                real = torch.ones(monet_desc_real.size()).to(self.device)
                fake = torch.zeros(monet_desc_fake.size()).to(self.device)

                # Discriminator losses (least squares): real -> 1, fake -> 0, averaged per discriminator
                monet_desc_real_loss = self.mse_loss(monet_desc_real, real)
                monet_desc_fake_loss = self.mse_loss(monet_desc_fake, fake)
                photo_desc_real_loss = self.mse_loss(photo_desc_real, real)
                photo_desc_fake_loss = self.mse_loss(photo_desc_fake, fake)

                monet_desc_loss = (monet_desc_real_loss + monet_desc_fake_loss) / 2
                photo_desc_loss = (photo_desc_real_loss + photo_desc_fake_loss) / 2
                total_desc_loss = monet_desc_loss + photo_desc_loss
                avg_desc_loss += total_desc_loss.item()

                # backward
                monet_desc_loss.backward()
                photo_desc_loss.backward()
                self.adam_desc.step()
                
                t.set_postfix(gen_loss=total_gen_loss.item(), desc_loss=total_desc_loss.item())
            
            avg_gen_loss /= len(photo_dl)
            avg_desc_loss /= len(photo_dl)
            time_req = time.time() - start_time
            
            self.gen_stats.append(avg_gen_loss, time_req)
            self.desc_stats.append(avg_desc_loss, time_req)
            
            print(f"Epoch {epoch+1}  -  Generator Loss: {avg_gen_loss:.4f}  -  "
                  f"Discriminator Loss: {avg_desc_loss:.4f}  -  Time: {time_req:.1f}s")
      
            self.gen_lr_sched.step()
            self.desc_lr_sched.step()

### Training

With all the building blocks defined, we can train the model. We create a `CycleGAN` instance and start training. The final submission is meant to use 50 epochs; the cell below defaults to 5 epochs to keep the run short.

In [None]:
# epochs = 50 # for final submission
epochs = 5 # to make it faster

gan = CycleGAN(3, 3, epochs)
gan.train(image_dl)

In [None]:
plt.xlabel("Epochs")
plt.ylabel("Losses")
plt.plot(gan.gen_stats.losses, 'r', label='Generator Loss')
plt.plot(gan.desc_stats.losses, 'b', label='Discriminator Loss')
plt.legend()
plt.show()

Both losses are expected to decrease and then level off as training progresses. Because GAN losses are only a rough indicator of output quality, we also check the results visually by plotting a few photos alongside the Monet-style versions generated by our model.

In [None]:
gan.gen_ptm.eval()        # disable dropout for inference
batches = iter(image_dl)  # create the iterator once so each step yields a different photo

_, ax = plt.subplots(2, 4, figsize=(12, 8))
for i in range(4):
    photo_img, _ = next(batches)
    with torch.no_grad():
        pred_monet = gan.gen_ptm(photo_img.to(device)).cpu()
    
    photo_img = reverse_normalize(photo_img)
    pred_monet = reverse_normalize(pred_monet)
    
    ax[0, i].imshow(photo_img[0].permute(1, 2, 0))
    ax[1, i].imshow(pred_monet[0].permute(1, 2, 0))
    ax[0, i].set_title("Photo")
    ax[1, i].set_title("Monet")
    ax[0, i].axis("off")
    ax[1, i].axis("off")

plt.show()

## Results

Now that the model is trained, we use it to translate the provided photos into Monet-style images and prepare a submission for the competition.

In [None]:
class PhotoDataset(Dataset):
    """
    Photos only, in a fixed order, for generating the submission images.
    """
    def __init__(self, photo_dir, size=(256, 256), normalize=True):
        super().__init__()
        self.photo_dir = photo_dir
        self.photo_files = sorted(os.listdir(self.photo_dir))
        
        steps = [transforms.Resize(size), transforms.ToTensor()]
        if normalize:
            # Same [-1, 1] scaling as used during training
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)

    def __getitem__(self, idx):
        photo_path = os.path.join(self.photo_dir, self.photo_files[idx])
        return self.transform(Image.open(photo_path).convert("RGB"))

    def __len__(self):
        return len(self.photo_files)

In [None]:
photo_ds = PhotoDataset(PHOTO_PATH)
photo_dl = DataLoader(photo_ds, batch_size=1, pin_memory=True)

In [None]:
# Create the output directory if it does not exist yet
os.makedirs(OUTPUT_PATH, exist_ok=True)

trans = transforms.ToPILImage()

gan.gen_ptm.eval()  # disable dropout so each photo is translated deterministically
with torch.no_grad():
    for i, photo in enumerate(photo_dl, start=1):
        pred_monet = gan.gen_ptm(photo.to(device)).cpu()
        pred_monet = reverse_normalize(pred_monet)
        img = trans(pred_monet[0]).convert("RGB")
        img.save(os.path.join(OUTPUT_PATH, f'{i}.jpg'))

In [None]:
display_images_grid(OUTPUT_PATH)

In [None]:
shutil.make_archive("/kaggle/working/images", 'zip', OUTPUT_PATH)
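
Before submitting, it can be worth checking that the archive holds the expected number of images (7,000 to 10,000, per the competition description above). The stdlib-only sketch below counts `.jpg`/`.jpeg` entries in a zip file; `count_jpegs` and `check_submission` are hypothetical helpers, and the self-check runs on a throwaway archive with made-up file names.

```python
import os
import tempfile
import zipfile


def count_jpegs(zip_path):
    """Count .jpg/.jpeg entries (case-insensitive) inside a zip archive."""
    with zipfile.ZipFile(zip_path) as zf:
        return sum(1 for name in zf.namelist()
                   if name.lower().endswith((".jpg", ".jpeg")))


def check_submission(zip_path, min_images=7000, max_images=10000):
    """Return (count, ok), where ok means min_images <= count <= max_images."""
    n = count_jpegs(zip_path)
    return n, min_images <= n <= max_images


# Quick self-check on a throwaway archive (hypothetical file names)
with tempfile.TemporaryDirectory() as tmp:
    demo = os.path.join(tmp, "demo.zip")
    with zipfile.ZipFile(demo, "w") as zf:
        for name in ("1.jpg", "2.JPG", "notes.txt"):
            zf.writestr(name, b"")
    demo_result = check_submission(demo, min_images=1, max_images=3)
    demo_default = check_submission(demo)
print(demo_result)   # (2, True)

# On Kaggle, after the archive cell above:
# print(check_submission("/kaggle/working/images.zip"))
```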