In [1]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import os
import math
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, random_split, DataLoader

import torchvision.models as models
import torchvision.transforms as transforms

from tqdm.notebook import tqdm
import itertools
import time
import shutil
from torch.utils.data import Dataset, random_split, DataLoader

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Using {device} device')

# Kaggle input folders: real photographs and Monet paintings.
path_photo = "/kaggle/input/gan-getting-started/photo_jpg/"
path_monet = "/kaggle/input/gan-getting-started/monet_jpg/"

Using cuda device


In [3]:
class ImageDataset(Dataset):
    """Unpaired photo / Monet dataset.

    Each item is a ``(photo, monet)`` pair of transformed images. The Monet
    image is selected by the requested index; the photo is drawn uniformly at
    random from *all* photos, so the two images in a pair are unrelated.

    Args:
        path_monet: Directory containing the Monet images.
        path_photo: Directory containing the photo images.
        size: Target ``(height, width)`` passed to ``transforms.Resize``.
        normalize: When true, apply ``Normalize`` with mean 0.5 and std 0.5
            per channel after ``ToTensor``.
    """

    def __init__(self, path_monet, path_photo, size=(256, 256), normalize=True):
        super().__init__()
        self.monet_dir = path_monet
        self.photo_dir = path_photo

        # Both variants share Resize + ToTensor; normalisation is optional.
        steps = [transforms.Resize(size), transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)

        # os.listdir order is arbitrary; sorting makes index -> file stable
        # across runs and machines.
        self.monet_idx = dict(enumerate(sorted(os.listdir(self.monet_dir))))
        self.photo_idx = dict(enumerate(sorted(os.listdir(self.photo_dir))))

    def _random_photo_index(self):
        """Return a uniformly random key of ``photo_idx``.

        The range is the number of photos (not the number of Monet images),
        so every photo can be selected.
        """
        return int(np.random.randint(0, len(self.photo_idx)))

    def _load(self, path):
        """Open ``path``, force 3-channel RGB and apply the transform."""
        # The 3-channel Normalize would fail on grayscale / RGBA files.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, idx):
        photo_path = os.path.join(self.photo_dir, self.photo_idx[self._random_photo_index()])
        monet_path = os.path.join(self.monet_dir, self.monet_idx[idx])
        photo_img = self._load(photo_path)
        monet_img = self._load(monet_path)
        return photo_img, monet_img

    def __len__(self):
        # One pass covers the smaller of the two domains.
        return min(len(self.monet_idx), len(self.photo_idx))

# Training data: unpaired (photo, monet) batches of four images.
img_ds = ImageDataset(path_monet, path_photo)
# shuffle=True so the Monet images are not presented in the same order (and
# grouped into the same batches) on every epoch.
img_dl = DataLoader(img_ds, batch_size=4, shuffle=True, pin_memory=True)


def reverse_normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
    """Undo a per-channel ``Normalize(mean, std)`` in place.

    Computes ``img * std + mean`` channel-wise. Works for a single image of
    shape ``(C, H, W)`` and for a batch of shape ``(N, C, H, W)``.

    Args:
        img: Tensor to de-normalise; it is modified in place.
        mean: Per-channel means that were subtracted during normalisation.
        std: Per-channel standard deviations that were divided out.

    Returns:
        The same tensor object ``img`` after the in-place update.
    """
    # Shape (C, 1, 1) broadcasts over H and W, and over a leading batch axis.
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    img.mul_(std_t).add_(mean_t)
    return img


In [4]:
def Upsample(in_ch, out_ch, use_dropout=True, dropout_ratio=0.5):
    """Build a transposed-convolution upsampling stage.

    The stage is ConvTranspose2d (kernel 3, stride 2, padding 1,
    output_padding 1, which doubles height and width) -> InstanceNorm2d ->
    optional Dropout -> GELU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        use_dropout: Insert ``nn.Dropout`` between the norm and the activation.
        dropout_ratio: Drop probability used when ``use_dropout`` is true.

    Returns:
        An ``nn.Sequential`` holding the layers in the order above.
    """
    stages = [
        nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1),
        nn.InstanceNorm2d(out_ch),
    ]
    if use_dropout:
        stages.append(nn.Dropout(dropout_ratio))
    stages.append(nn.GELU())
    return nn.Sequential(*stages)
# define a convolutional layer with options for padding and more
def Convlayer(in_ch, out_ch, kernel_size=3, stride=2, use_leaky=True, use_inst_norm=True, use_pad=True):
    """Build a convolution -> normalisation -> activation stage.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        use_leaky: Use ``LeakyReLU(0.2, inplace=True)`` when true, ``GELU``
            otherwise.
        use_inst_norm: Use ``InstanceNorm2d`` when true, ``BatchNorm2d``
            otherwise.
        use_pad: Zero-pad the convolution by 1 when true, by 0 otherwise
            (callers that pad explicitly beforehand pass ``False``).

    Returns:
        An ``nn.Sequential`` of ``(conv, norm, activation)``.
    """
    padding = 1 if use_pad else 0
    conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=True)
    norm = nn.InstanceNorm2d(out_ch) if use_inst_norm else nn.BatchNorm2d(out_ch)
    actv = nn.LeakyReLU(negative_slope=0.2, inplace=True) if use_leaky else nn.GELU()
    return nn.Sequential(conv, norm, actv)
class Resblock(nn.Module):
    """Residual block: ``x + F(x)`` with two reflection-padded 3x3 convolutions.

    ``F`` is ReflectionPad2d(1) -> Convlayer(3x3, stride 1, GELU, instance
    norm, no conv padding) -> optional Dropout -> ReflectionPad2d(1) ->
    Conv2d(3x3) -> InstanceNorm2d. Channel count and spatial size are
    preserved, so the skip connection can be a plain addition.

    Args:
        in_features: Number of input (and output) channels.
        use_dropout: Insert ``nn.Dropout`` after the first convolution stage.
            Note that disabling it shifts the indices of the later layers
            inside ``self.res`` (and therefore their ``state_dict`` keys).
        dropout_ratio: Drop probability used when ``use_dropout`` is true.
    """

    def __init__(self, in_features, use_dropout=True, dropout_ratio=0.5):
        super().__init__()
        layers = [
            nn.ReflectionPad2d(1),
            Convlayer(in_features, in_features, 3, 1, False, use_pad=False),
        ]
        # Honour the flag; previously dropout was added unconditionally.
        if use_dropout:
            layers.append(nn.Dropout(dropout_ratio))
        layers.append(nn.ReflectionPad2d(1))
        layers.append(nn.Conv2d(in_features, in_features, 3, 1, padding=0, bias=True))
        layers.append(nn.InstanceNorm2d(in_features))
        self.res = nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the residual branch output."""
        return x + self.res(x)
class Generator(nn.Module):
    """Image-to-image generator: downsample, residual blocks, upsample, tanh.

    Layout, in order: a reflection-padded 7x7 convolution to 64 channels, two
    stride-2 convolutions (64 -> 128 -> 256), ``num_res_blocks`` residual
    blocks at 256 channels, two ``Upsample`` stages (256 -> 128 -> 64), and a
    reflection-padded 7x7 convolution to ``out_ch`` followed by ``Tanh``.

    Args:
        in_ch: Number of channels of the input image.
        out_ch: Number of channels of the generated image.
        num_res_blocks: How many ``Resblock`` modules sit in the middle.
    """

    def __init__(self, in_ch, out_ch, num_res_blocks=6):
        super().__init__()
        # Modules are created in the same order in which they are applied.
        encoder = [
            nn.ReflectionPad2d(3),
            Convlayer(in_ch, 64, 7, 1, False, True, False),
            Convlayer(64, 128, 3, 2, False),
            Convlayer(128, 256, 3, 2, False),
        ]
        bottleneck = [Resblock(256) for _ in range(num_res_blocks)]
        decoder = [
            Upsample(256, 128),
            Upsample(128, 64),
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, out_ch, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*encoder, *bottleneck, *decoder)

    def forward(self, x):
        """Run ``x`` through the full generator stack."""
        return self.gen(x)
class Discriminator(nn.Module):
    """Convolutional discriminator producing a one-channel map of scores.

    Starts with a stride-2 4x4 convolution to 64 channels and LeakyReLU, then
    ``num_layers - 1`` ``Convlayer`` stages that double the channel count
    each time (stride 2, except the last one which uses stride 1), and ends
    with a 4x4 convolution to a single channel. No sigmoid is applied.

    Args:
        in_ch: Number of channels of the input image.
        num_layers: Total number of feature stages, including the first
            convolution. The default of 4 yields 64 -> 128 -> 256 -> 512.
    """

    def __init__(self, in_ch, num_layers=4):
        super().__init__()
        model = [
            nn.Conv2d(in_ch, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        out_chs = 64
        for i in range(1, num_layers):
            in_chs = out_chs
            out_chs = in_chs * 2
            # Only the last feature stage keeps the spatial resolution.
            stride = 1 if i == num_layers - 1 else 2
            model.append(Convlayer(in_chs, out_chs, 4, stride))
        # Feed the head with the real width of the last stage instead of a
        # hard-coded 512, so any num_layers builds a consistent network.
        model.append(nn.Conv2d(out_chs, 1, kernel_size=4, stride=1, padding=1))
        self.disc = nn.Sequential(*model)

    def forward(self, x):
        """Return the raw score map for ``x``."""
        return self.disc(x)
    
def init_weights(net, init_type='normal', std=0.02):
    """Re-initialise the layers of ``net`` in place.

    Every sub-module whose class name contains ``Conv`` or ``Linear`` gets
    weights drawn from N(0, std) and a zero bias (when it has one). Every
    ``BatchNorm2d`` gets weights drawn from N(1, std) and a zero bias. Other
    modules are left untouched.

    Args:
        net: Module to initialise; visited recursively via ``net.apply``.
        init_type: Accepted for interface compatibility; this implementation
            does not consult it.
        std: Standard deviation of the normal distributions.
    """
    def init_func(m):
        layer_kind = m.__class__.__name__
        conv_or_linear = 'Conv' in layer_kind or 'Linear' in layer_kind
        if conv_or_linear and hasattr(m, 'weight'):
            init.normal_(m.weight.data, 0.0, std)
            bias = getattr(m, 'bias', None)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in layer_kind:
            init.normal_(m.weight.data, 1.0, std)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)

def update_req_grad(models, requires_grad=True):
    """Set ``requires_grad`` on every parameter of every model in ``models``.

    Args:
        models: Iterable of modules whose parameters should be updated.
        requires_grad: Value assigned to each parameter's ``requires_grad``.
    """
    every_param = itertools.chain.from_iterable(net.parameters() for net in models)
    for weight in every_param:
        weight.requires_grad = requires_grad
class sample_fake(object):
    """Bounded history buffer of previously generated items.

    While fewer than ``max_imgs`` items have been stored, each incoming item
    is stored and returned unchanged. Once the buffer is full, each incoming
    item is, with probability one half, swapped with a randomly chosen stored
    item (the stored one is returned); otherwise it is returned as is and the
    buffer is left untouched.

    Args:
        max_imgs: Capacity of the buffer.
    """

    def __init__(self, max_imgs=50):
        self.max_imgs = max_imgs
        self.cur_img = 0
        self.imgs = list()

    def __call__(self, imgs):
        """Return one output item per input item, in order."""
        return [self._exchange(candidate) for candidate in imgs]

    def _exchange(self, candidate):
        """Store, swap or pass through a single item (see class docstring)."""
        # Filling phase: keep everything until the buffer is full.
        if self.cur_img < self.max_imgs:
            self.imgs.append(candidate)
            self.cur_img += 1
            return candidate
        # Full buffer: coin flip between passing through and swapping.
        if np.random.ranf() <= 0.5:
            return candidate
        slot = np.random.randint(0, self.max_imgs)
        evicted = self.imgs[slot]
        self.imgs[slot] = candidate
        return evicted
class lr_sched():
    """Learning-rate multiplier: constant, then linear decay to zero.

    ``step`` returns 1.0 up to and including ``decay_epochs`` and then falls
    linearly, reaching 0.0 at ``total_epochs``. The value never drops below
    0.0, so it is safe to use as a ``LambdaLR`` factor even when stepped past
    ``total_epochs``.

    Args:
        decay_epochs: Last epoch index that still uses the full rate.
        total_epochs: Epoch index at which the multiplier reaches zero.
    """

    def __init__(self, decay_epochs=50, total_epochs=100):
        self.decay_epochs = decay_epochs
        self.total_epochs = total_epochs

    def step(self, epoch_num):
        """Return the multiplier for ``epoch_num`` as a float in [0.0, 1.0]."""
        if epoch_num <= self.decay_epochs:
            return 1.0
        decay_span = self.total_epochs - self.decay_epochs
        # No room left to decay: avoid dividing by zero (or by a negative
        # span, which would make the rate grow) and treat it as finished.
        if decay_span <= 0:
            return 0.0
        fract = (epoch_num - self.decay_epochs) / decay_span
        # Clamp so stepping beyond total_epochs cannot yield a negative rate.
        return max(0.0, 1.0 - fract)

# container for per-epoch training metrics
class AvgStats(object):
    """Two parallel lists: recorded losses and the value logged with each.

    Attributes are ``losses`` and ``its``; entry ``k`` of both lists comes
    from the ``k``-th call to ``append``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.losses, self.its = [], []

    def append(self, loss, it):
        """Record ``loss`` and its companion value ``it``."""
        for series, value in ((self.losses, loss), (self.its, it)):
            series.append(value)

# defines the Cycle GAN and its training loop
class CycleGAN(object):
    """CycleGAN between the photo and Monet domains, with its training loop.

    Holds two generators (``gen_mtp``: Monet -> photo, ``gen_ptm``: photo ->
    Monet), two discriminators (``desc_m`` for Monet-domain images,
    ``desc_p`` for photo-domain images), one Adam optimiser per group, a
    linear-decay LR schedule for each optimiser and a history buffer of
    generated batches per domain.

    Args:
        in_ch: Channels of the input images (also used for the discriminators).
        out_ch: Channels of the generated images.
        epochs: Number of training epochs.
        device: Torch device the networks and batches are moved to.
        start_lr: Initial learning rate of both optimisers.
        lmbda: Weight of the cycle-consistency loss.
        idt_coef: Extra factor (on top of ``lmbda``) for the identity loss.
        decay_epoch: Epoch after which the LR starts decaying; values <= 0
            select ``epochs // 2``.
    """

    def __init__(self, in_ch, out_ch, epochs, device, start_lr=2e-4, lmbda=10, idt_coef=0.5, decay_epoch=0):
        self.epochs = epochs
        # Without an explicit start, keep the LR flat for the first half.
        self.decay_epoch = decay_epoch if decay_epoch > 0 else int(self.epochs/2)
        self.lmbda = lmbda
        self.idt_coef = idt_coef
        self.device = device
        self.gen_mtp = Generator(in_ch, out_ch)
        self.gen_ptm = Generator(in_ch, out_ch)
        self.desc_m = Discriminator(in_ch)
        self.desc_p = Discriminator(in_ch)
        self.init_models()
        # MSE against 1/0 targets is the adversarial criterion; L1 is used
        # for the cycle and identity terms.
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()
        self.adam_gen = torch.optim.Adam(itertools.chain(self.gen_mtp.parameters(), self.gen_ptm.parameters()),
                                         lr = start_lr, betas=(0.5, 0.999))
        self.adam_desc = torch.optim.Adam(itertools.chain(self.desc_m.parameters(), self.desc_p.parameters()),
                                          lr=start_lr, betas=(0.5, 0.999))
        self.sample_monet = sample_fake()
        self.sample_photo = sample_fake()
        gen_lr = lr_sched(self.decay_epoch, self.epochs)
        desc_lr = lr_sched(self.decay_epoch, self.epochs)
        self.gen_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_gen, gen_lr.step)
        self.desc_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_desc, desc_lr.step)
        self.gen_stats = AvgStats()
        self.desc_stats = AvgStats()

    def init_models(self):
        """Initialise the weights of all four networks and move them to ``self.device``."""
        init_weights(self.gen_mtp)
        init_weights(self.gen_ptm)
        init_weights(self.desc_m)
        init_weights(self.desc_p)
        self.gen_mtp = self.gen_mtp.to(self.device)
        self.gen_ptm = self.gen_ptm.to(self.device)
        self.desc_m = self.desc_m.to(self.device)
        self.desc_p = self.desc_p.to(self.device)

    def train(self, photo_dl):
        """Train all networks for ``self.epochs`` epochs.

        Args:
            photo_dl: Iterable with a length that yields ``(photo, monet)``
                image batches.

        Per-epoch average losses and the epoch duration in seconds are
        appended to ``self.gen_stats`` and ``self.desc_stats``, and a summary
        line is printed after every epoch.
        """
        num_batches = len(photo_dl)
        # Make sure dropout is active even if a network was switched to
        # eval mode earlier (e.g. for inference).
        for net in (self.gen_mtp, self.gen_ptm, self.desc_m, self.desc_p):
            net.train()

        for epoch in range(self.epochs):
            start_time = time.time()
            avg_gen_loss = 0.0
            avg_desc_loss = 0.0
            t = tqdm(photo_dl, leave=False, total=num_batches)
            for photo_real, monet_real in t:
                # Use the device this object was configured with rather than
                # a notebook-level global.
                photo_img = photo_real.to(self.device)
                monet_img = monet_real.to(self.device)

                # ---- generator update (discriminators frozen) ----
                update_req_grad([self.desc_m, self.desc_p], False)
                self.adam_gen.zero_grad()

                # forward pass through generator
                fake_photo = self.gen_mtp(monet_img)
                fake_monet = self.gen_ptm(photo_img)

                cycl_monet = self.gen_ptm(fake_photo)
                cycl_photo = self.gen_mtp(fake_monet)

                id_monet = self.gen_ptm(monet_img)
                id_photo = self.gen_mtp(photo_img)

                # generator losses
                idt_loss_monet = self.l1_loss(id_monet, monet_img) * self.lmbda * self.idt_coef
                idt_loss_photo = self.l1_loss(id_photo, photo_img) * self.lmbda * self.idt_coef

                cycle_loss_monet = self.l1_loss(cycl_monet, monet_img) * self.lmbda
                cycle_loss_photo = self.l1_loss(cycl_photo, photo_img) * self.lmbda

                monet_desc = self.desc_m(fake_monet)
                photo_desc = self.desc_p(fake_photo)

                # The generators want their fakes to be scored as real (1).
                # Targets are built per output so their shapes always match.
                adv_loss_monet = self.mse_loss(monet_desc, torch.ones_like(monet_desc))
                adv_loss_photo = self.mse_loss(photo_desc, torch.ones_like(photo_desc))

                # total generator loss
                total_gen_loss = cycle_loss_monet + adv_loss_monet\
                              + cycle_loss_photo + adv_loss_photo\
                              + idt_loss_monet + idt_loss_photo

                avg_gen_loss += total_gen_loss.item()

                # backward pass
                total_gen_loss.backward()
                self.adam_gen.step()

                # ---- discriminator update ----
                update_req_grad([self.desc_m, self.desc_p], True)
                self.adam_desc.zero_grad()

                # Route the fresh fakes through the history buffers; what
                # comes back may be an older batch. The buffers hold NumPy
                # copies on the host, detached from the generator graph.
                fake_monet = self.sample_monet([fake_monet.detach().cpu().numpy()])[0]
                fake_photo = self.sample_photo([fake_photo.detach().cpu().numpy()])[0]
                fake_monet = torch.tensor(fake_monet).to(self.device)
                fake_photo = torch.tensor(fake_photo).to(self.device)

                monet_desc_real = self.desc_m(monet_img)
                monet_desc_fake = self.desc_m(fake_monet)
                photo_desc_real = self.desc_p(photo_img)
                photo_desc_fake = self.desc_p(fake_photo)

                # discriminator losses: real -> 1, fake -> 0. A buffered batch
                # can differ in size from the current one, hence one target
                # per output tensor.
                monet_desc_real_loss = self.mse_loss(monet_desc_real, torch.ones_like(monet_desc_real))
                monet_desc_fake_loss = self.mse_loss(monet_desc_fake, torch.zeros_like(monet_desc_fake))
                photo_desc_real_loss = self.mse_loss(photo_desc_real, torch.ones_like(photo_desc_real))
                photo_desc_fake_loss = self.mse_loss(photo_desc_fake, torch.zeros_like(photo_desc_fake))

                monet_desc_loss = (monet_desc_real_loss + monet_desc_fake_loss) / 2
                photo_desc_loss = (photo_desc_real_loss + photo_desc_fake_loss) / 2
                total_desc_loss = monet_desc_loss + photo_desc_loss
                avg_desc_loss += total_desc_loss.item()

                # backward
                monet_desc_loss.backward()
                photo_desc_loss.backward()
                self.adam_desc.step()

                t.set_postfix(gen_loss=total_gen_loss.item(), desc_loss=total_desc_loss.item())

            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            avg_gen_loss /= max(num_batches, 1)
            avg_desc_loss /= max(num_batches, 1)
            time_req = time.time() - start_time

            self.gen_stats.append(avg_gen_loss, time_req)
            self.desc_stats.append(avg_desc_loss, time_req)

            print("Epoch %d  -  Generator Loss:%f  -  Discriminator Loss:%f" % (epoch+1, avg_gen_loss, avg_desc_loss))

            self.gen_lr_sched.step()
            self.desc_lr_sched.step()

# Build the model for 3-channel images in both domains and train it for
# 50 epochs on the unpaired loader.
gan = CycleGAN(in_ch=3, out_ch=3, epochs=50, device=device)
gan.train(photo_dl=img_dl)
class PhotoDataset(Dataset):
    """Dataset over the image files of a single directory.

    Item ``idx`` is the transformed image whose file name is stored under
    key ``idx`` of ``photo_idx`` (keys follow the ``os.listdir`` order).

    Args:
        photo_dir: Directory containing the images.
        size: Target ``(height, width)`` passed to ``transforms.Resize``.
        normalize: When true, apply ``Normalize`` with mean 0.5 and std 0.5
            per channel after ``ToTensor``.
    """

    def __init__(self, photo_dir, size=(256, 256), normalize=True):
        super().__init__()
        self.photo_dir = photo_dir

        pipeline = [transforms.Resize(size), transforms.ToTensor()]
        if normalize:
            pipeline.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(pipeline)

        # position in the directory listing -> file name
        self.photo_idx = dict(enumerate(os.listdir(self.photo_dir)))

    def __getitem__(self, idx):
        file_name = self.photo_idx[idx]
        return self.transform(Image.open(os.path.join(self.photo_dir, file_name)))

    def __len__(self):
        return len(self.photo_idx)
# Loader over every photo, one image per batch so that each prediction maps
# to exactly one output file.
ph_ds = PhotoDataset(path_photo)
ph_dl = DataLoader(ph_ds, batch_size=1, pin_memory=True)
trans = transforms.ToPILImage()

# Create the output directory once instead of on every loop iteration.
os.makedirs("/kaggle/working/images", exist_ok=True)

# Switch the photo -> Monet generator to inference mode. Its Dropout layers
# would otherwise stay active and randomly zero activations in every image.
gan.gen_ptm.eval()

# use model to create a monet style image
t = tqdm(ph_dl, leave=False, total=len(ph_dl))
for i, photo in enumerate(t):
    with torch.no_grad():
        pred_monet = gan.gen_ptm(photo.to(device)).cpu().detach()

    # Map the tanh output from [-1, 1] back to [0, 1] before converting.
    pred_monet = reverse_normalize(pred_monet)
    img = trans(pred_monet[0]).convert("RGB")

    # Files are numbered from 1 in loader order.
    img.save("/kaggle/working/images/" + str(i+1) + ".jpg")
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/working/images")

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 1  -  Generator Loss:13.287411  -  Discriminator Loss:1.688427


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 2  -  Generator Loss:10.359746  -  Discriminator Loss:0.552276


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 3  -  Generator Loss:9.493285  -  Discriminator Loss:0.475858


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 4  -  Generator Loss:9.073657  -  Discriminator Loss:0.445538


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 5  -  Generator Loss:8.923146  -  Discriminator Loss:0.451132


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 6  -  Generator Loss:8.172950  -  Discriminator Loss:0.414949


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 7  -  Generator Loss:8.203859  -  Discriminator Loss:0.427397


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 8  -  Generator Loss:8.178611  -  Discriminator Loss:0.394040


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 9  -  Generator Loss:7.797466  -  Discriminator Loss:0.413690


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 10  -  Generator Loss:7.914119  -  Discriminator Loss:0.388582


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 11  -  Generator Loss:7.658181  -  Discriminator Loss:0.388808


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 12  -  Generator Loss:7.641524  -  Discriminator Loss:0.388718


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 13  -  Generator Loss:7.511680  -  Discriminator Loss:0.351836


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 14  -  Generator Loss:7.439435  -  Discriminator Loss:0.400211


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 15  -  Generator Loss:7.474471  -  Discriminator Loss:0.381278


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 16  -  Generator Loss:7.504125  -  Discriminator Loss:0.345493


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 17  -  Generator Loss:7.298208  -  Discriminator Loss:0.348721


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 18  -  Generator Loss:7.195888  -  Discriminator Loss:0.371756


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 19  -  Generator Loss:7.132397  -  Discriminator Loss:0.352342


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 20  -  Generator Loss:7.086496  -  Discriminator Loss:0.359011


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 21  -  Generator Loss:6.968605  -  Discriminator Loss:0.413057


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 22  -  Generator Loss:6.994590  -  Discriminator Loss:0.378180


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 23  -  Generator Loss:6.840165  -  Discriminator Loss:0.376455


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 24  -  Generator Loss:6.794669  -  Discriminator Loss:0.370511


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 25  -  Generator Loss:6.810408  -  Discriminator Loss:0.383968


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 26  -  Generator Loss:6.929216  -  Discriminator Loss:0.371126


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 27  -  Generator Loss:6.763895  -  Discriminator Loss:0.361078


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 28  -  Generator Loss:6.643391  -  Discriminator Loss:0.334059


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 29  -  Generator Loss:6.676039  -  Discriminator Loss:0.325343


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 30  -  Generator Loss:6.596513  -  Discriminator Loss:0.326014


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 31  -  Generator Loss:6.445706  -  Discriminator Loss:0.289475


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 32  -  Generator Loss:6.372656  -  Discriminator Loss:0.304105


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 33  -  Generator Loss:6.201096  -  Discriminator Loss:0.288346


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 34  -  Generator Loss:6.315585  -  Discriminator Loss:0.279667


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 35  -  Generator Loss:6.303044  -  Discriminator Loss:0.273097


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 36  -  Generator Loss:6.279340  -  Discriminator Loss:0.274963


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 37  -  Generator Loss:6.252329  -  Discriminator Loss:0.267557


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 38  -  Generator Loss:6.175430  -  Discriminator Loss:0.281767


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 39  -  Generator Loss:6.091350  -  Discriminator Loss:0.249669


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 40  -  Generator Loss:5.921401  -  Discriminator Loss:0.262183


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 41  -  Generator Loss:5.923071  -  Discriminator Loss:0.240175


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 42  -  Generator Loss:5.982737  -  Discriminator Loss:0.236160


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 43  -  Generator Loss:5.936587  -  Discriminator Loss:0.221710


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 44  -  Generator Loss:5.765211  -  Discriminator Loss:0.228628


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 45  -  Generator Loss:5.903244  -  Discriminator Loss:0.231567


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 46  -  Generator Loss:5.787498  -  Discriminator Loss:0.224053


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 47  -  Generator Loss:5.782575  -  Discriminator Loss:0.206082


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 48  -  Generator Loss:5.719639  -  Discriminator Loss:0.207906


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 49  -  Generator Loss:5.721194  -  Discriminator Loss:0.200132


  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 50  -  Generator Loss:5.718550  -  Discriminator Loss:0.192367


  0%|          | 0/7038 [00:00<?, ?it/s]

'/kaggle/working/images.zip'