In [None]:
import numpy as np
import pandas as pd
import os
import cv2
from PIL import Image
from zipfile import ZipFile
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from copy import deepcopy
from skimage import io
from glob import glob
import torch
from torch import nn
import torchvision
from torchvision.utils import make_grid
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19_bn

In [None]:
monet_path = '../input/gan-getting-started/monet_jpg/'
photo_path = '../input/gan-getting-started/photo_jpg/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the RNGs this notebook draws from (numpy for the photo pairing in
# MonetData, torch for weight init and DataLoader shuffling) so that
# Restart & Run All is reproducible (up to GPU kernel nondeterminism).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Sort the globbed paths: glob order depends on the filesystem, and a fixed
# order is needed for the seeded random pairing and for the output image
# numbering (which follows photo_files) to be reproducible.
monet_files = np.array(sorted(glob(monet_path + '*.jpg')))
photo_files = np.array(sorted(glob(photo_path + '*.jpg')))
# Fail early with a clear message instead of an obscure error deep in the dataset.
assert len(monet_files) > 0, f'no .jpg files found in {monet_path}'
assert len(photo_files) > 0, f'no .jpg files found in {photo_path}'

# We build a dataset that yields pairs from two domains: the first item is a Monet painting and the second is an ordinary photo. Since there are fewer Monet paintings than photos, each painting is paired with a randomly chosen photo. Images from both domains are normalized to the range [-1, 1].

In [None]:
class MonetData(Dataset):
    """Unpaired training set yielding ``(monet_image, photo_image)`` tensor pairs.

    There are fewer Monet paintings than photos, so every painting index is
    paired with a photo drawn at random (with replacement). Both images are
    converted with ``ToTensor`` and ``Normalize(0.5, 0.5)``, i.e. from [0, 1]
    to [-1, 1].

    Args:
        monet_files: sequence of Monet image paths (indexed by dataset index).
        photo_files: sequence of photo paths to sample pairings from.
    """
    def __init__(self, monet_files, photo_files):
        # Stored as given and never mutated, so the caller's array stays intact.
        self.monet = monet_files
        self.photo = photo_files
        self.transforms = T.Compose([T.ToTensor(), T.Normalize(0.5, 0.5)])
        self.random_choice()

    def __len__(self):
        return len(self.monet)

    def random_choice(self):
        """Draw a new random photo (with replacement) for every Monet index.

        The Monet list is deliberately NOT shuffled here: this method runs in
        the middle of an epoch (see ``__getitem__``), and reshuffling would
        remap indices the sampler has not visited yet, so some paintings would
        be seen twice per epoch and others not at all. Visiting order is the
        sampler's job (``DataLoader(..., shuffle=True)``).
        """
        self.photo_files = np.random.choice(self.photo, len(self.monet))

    def __getitem__(self, idx):
        image_a = self.transforms(io.imread(self.monet[idx]))
        image_b = self.transforms(io.imread(self.photo_files[idx]))
        # Refresh the photo pairing once the last index has been loaded. Under a
        # shuffling sampler this happens at a random point in the epoch, which
        # is harmless because only the (already random) photo choice changes.
        # NOTE(review): with num_workers > 0 this runs on a worker's copy of the
        # dataset, so the refresh may not persist across epochs — confirm if
        # workers are enabled.
        if idx == len(self) - 1:
            self.random_choice()
        return image_a, image_b

In [None]:
# Keep the Dataset under its own name; `train_data` is the DataLoader used for training.
monet_dataset = MonetData(monet_files, photo_files)
train_data = DataLoader(monet_dataset, batch_size = 4, shuffle = True)

# Next we implement the generators as a U-Net (with a pretrained VGG19-BN encoder and residual blocks in the bottleneck) and the discriminators as a CNN.

In [None]:
class DecoderBlock(nn.Module):
    """U-Net decoder stage using a learned 2x transposed-conv upsampling.

    ``forward(feat, x)`` upsamples ``x`` (ConvTranspose2d k=4, s=2, p=1 doubles
    H and W), applies InstanceNorm and ReLU, then concatenates the skip feature
    ``feat`` along the channel axis, giving ``out_ch + feat.channels`` channels.
    """

    def __init__(self , in_ch , out_ch):
        super(DecoderBlock , self).__init__()
        # Attribute names form the state_dict keys; keep them stable.
        self.deconv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size = 4, stride = 2, padding = 1)
        self.relu = nn.ReLU()
        # Despite the name this is InstanceNorm (no learnable params by default).
        self.bn = nn.InstanceNorm2d(out_ch)

    def forward(self , feat , x):
        upsampled = self.relu(self.bn(self.deconv(x)))
        return torch.cat((upsampled, feat), dim = 1)

class DecoderBlock1(nn.Module):
    """U-Net decoder stage using fixed bilinear upsampling plus a 3x3 conv.

    ``forward(feat, x)``: 2x bilinear upsample of ``x`` -> Conv2d(in_ch, out_ch,
    3x3, same padding) -> InstanceNorm -> ReLU, then concatenation with the skip
    feature ``feat`` along the channel axis.
    """

    def __init__(self , in_ch , out_ch):
        super(DecoderBlock1 , self).__init__()
        # Parameter-free upsampling; named ``deconv`` to mirror DecoderBlock.
        self.deconv = nn.Upsample(scale_factor = 2, mode = 'bilinear')
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size = 3, padding = 1)
        self.relu = nn.ReLU()
        # Despite the name this is InstanceNorm, not BatchNorm.
        self.bn = nn.InstanceNorm2d(out_ch)

    def forward(self , feat , x):
        for layer in (self.deconv, self.conv, self.bn, self.relu):
            x = layer(x)
        return torch.cat((x, feat), dim = 1)

class ResidualBlock(nn.Module):
    """Residual block: conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm, plus identity skip.

    Channel count and spatial size are preserved (stride 1, reflect padding 1).
    """

    def __init__(self, in_ch):
        super(ResidualBlock, self).__init__()
        conv_kwargs = dict(kernel_size = 3, padding = 1, padding_mode = 'reflect')
        self.conv1 = nn.Conv2d(in_ch, in_ch, **conv_kwargs)
        # A single InstanceNorm module is applied after both convs; with the
        # default affine=False it has no parameters, so sharing it is harmless.
        self.bn = nn.InstanceNorm2d(in_ch)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_ch, in_ch, **conv_kwargs)

    def forward(self, x):
        residual = self.relu(self.bn(self.conv1(x)))
        residual = self.bn(self.conv2(residual))
        return x + residual

class UNet(nn.Module):
    """U-Net generator with a pretrained VGG19-BN encoder and a residual bottleneck.

    Args:
        out_channel: number of output channels; the output passes through tanh,
            so values lie in [-1, 1].

    presumably input H and W must be divisible by 16 (four 2x poolings before
    the bottleneck) for the skip connections to line up — TODO confirm.
    """

    def __init__(self , out_channel):
        super(UNet , self).__init__()
        self.out_ch = out_channel
        # Encoder: slices of torchvision's vgg19_bn().features. Every slice after
        # the first starts at one of VGG's max-pool layers, so f1..f4 have 64, 128,
        # 256 and 512 channels (matching the decoder widths below).
        # NOTE(review): newer torchvision deprecates ``pretrained=`` in favour of ``weights=``.
        # NOTE(review): VGG was pretrained on ImageNet-normalized input, while this
        # notebook feeds [-1, 1] images — confirm this is intended.
        self.feature_detector = vgg19_bn(pretrained = True).features
        self.enc1 = self.feature_detector[:6]
        self.enc2 = self.feature_detector[6:13]
        self.enc3 = self.feature_detector[13:26]
        self.enc4 = self.feature_detector[26:39]

        # Bottleneck: a 3x3 conv followed by four residual blocks, all 512 channels.
        self.bottle_neck = nn.Conv2d(512 , 512 , kernel_size = 3 , padding = 1 , stride = 1)
        self.res1 = ResidualBlock(512)
        self.res2 = ResidualBlock(512)
        self.res3 = ResidualBlock(512)
        self.res4 = ResidualBlock(512)

        # Decoder: each block upsamples 2x and concatenates its skip feature, so a
        # block's input width is (previous block's out_ch + that skip's channels).
        self.dec2 = DecoderBlock1(512 , 256)
        self.dec3 = DecoderBlock1(256 + 512 , 128)
        self.dec4 = DecoderBlock1(128 + 256 , 64)
        self.dec5 = DecoderBlock1(64 + 128 , 32)

        # 99 = 32 (dec5 conv output) + 64 (f1) + 3 (raw input image) concatenated in dec5.
        self.output_conv = nn.Conv2d(99 , out_channel , kernel_size = 7 , padding = 3 , stride = 1)
        self.tanh = nn.Tanh()

    def forward(self , x1):
        skips = []
        feat = x1
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            feat = encoder(feat)
            skips.append(feat)
        f1, f2, f3, f4 = skips

        # VGG's next max-pool (features[52]) takes f4 down to the bottleneck resolution.
        x = self.bottle_neck(self.feature_detector[52](f4))
        for res_block in (self.res1, self.res2, self.res3, self.res4):
            x = res_block(x)

        for decoder, skip in ((self.dec2, f4), (self.dec3, f3), (self.dec4, f2)):
            x = decoder(skip, x)
        # The final skip carries the raw input image alongside f1.
        x = self.dec5(torch.cat((f1 , x1) , dim = 1) , x)

        return self.tanh(self.output_conv(x))

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a grid of real/fake scores.

    Stages 1-4 each apply a stride-1 3x3 conv and a stride-2 3x3 conv (each
    followed by BatchNorm + LeakyReLU), halving H and W per stage. Stage 5 is a
    stride-1 3x3 conv unit followed by a 1x1 stride-2 conv to one channel. The
    output is an (N, 1, h, w) map of unbounded scores, e.g. 8x8 for 256x256 input.

    Layers are registered as conv{s}_{i}, bn{s}_{i}, relu{s}_{i} and conv5_2, in
    the same order as before, so state_dict keys and parameter shapes match
    earlier checkpoints.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # (in_channels, out_channels) of stages 1-4.
        stage_channels = ((3, 64), (64, 128), (128, 256), (256, 512))
        trunk = []
        for stage, (c_in, c_out) in enumerate(stage_channels, start = 1):
            trunk.append(self._add_conv_unit(f'{stage}_1', c_in, c_out, stride = 1))
            trunk.append(self._add_conv_unit(f'{stage}_2', c_out, c_out, stride = 2))
        trunk.append(self._add_conv_unit('5_1', 512, 512, stride = 1))
        # Plain tuple of name suffixes (not a registered module) consumed by forward().
        self._trunk = tuple(trunk)
        # padding MUST be 0 for a 1x1 kernel: any padding creates output locations
        # whose receptive field is only zero padding, so they equal the bias for
        # every input (padding=1 gave a 9x9 map for 256x256 images whose first row
        # and column carried no information but were still trained toward 1/0).
        self.conv5_2 = nn.Conv2d(512, 1, kernel_size = 1, padding = 0, stride = 2)

    def _add_conv_unit(self, suffix, in_ch, out_ch, stride):
        """Register conv{suffix} (3x3, padding 1), bn{suffix}, relu{suffix}; return ``suffix``."""
        setattr(self, f'conv{suffix}', nn.Conv2d(in_ch, out_ch, kernel_size = 3, padding = 1, stride = stride))
        setattr(self, f'bn{suffix}', nn.BatchNorm2d(out_ch))
        setattr(self, f'relu{suffix}', nn.LeakyReLU())
        return suffix

    def forward(self, x):
        for suffix in self._trunk:
            x = getattr(self, f'conv{suffix}')(x)
            x = getattr(self, f'bn{suffix}')(x)
            x = getattr(self, f'relu{suffix}')(x)
        return self.conv5_2(x)

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(3, 256, 256)):
    '''
    Plot the first `num_images` images of a batch as a grid with 5 per row.

    The tensor is mapped from [-1, 1] back to [0, 1] via (x + 1) / 2, detached,
    moved to the CPU and reshaped to (-1, *size) before plotting.
    '''
    rescaled = (image_tensor + 1) / 2
    batch = rescaled.detach().cpu().view(-1, *size)
    grid = make_grid(batch[:num_images], nrow=5)
    plt.figure(figsize = (30, 30))
    # make_grid returns CHW; imshow expects HWC.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

# In this step we define the loss functions needed to train the CycleGAN: the discriminator loss, the adversarial loss, the identity loss, and the cycle-consistency loss.

In [None]:
def get_disc_loss(disc_a, real_a, fake_a, criterion):
    """Discriminator loss: mean of the real-vs-1 and fake-vs-0 criterion terms.

    Real images are scored before fake ones. ``fake_a`` should already be
    detached from its generator (the training loop builds it under no_grad).
    """
    def _term(images, make_target):
        scores = disc_a(images)
        return criterion(scores, make_target(scores))

    return (_term(real_a, torch.ones_like) + _term(fake_a, torch.zeros_like)) / 2
def get_adv_loss(gen_ab, disc_b, real_a, criterion):
    """Adversarial loss for ``gen_ab``: score its translation of ``real_a`` against the 'real' target 1."""
    scores = disc_b(gen_ab(real_a))
    return criterion(scores, torch.ones_like(scores))
def get_id_loss(gen_ab, real_b, criterion):
    """Identity loss: ``gen_ab`` applied to an image already in its target domain should return it unchanged.

    The generator output must stay attached to the autograd graph. Detaching it
    (as before) made the loss constant with respect to the generator, so it
    could never train anything and ``backward()`` failed.
    """
    return criterion(gen_ab(real_b), real_b)
def get_consist_loss(gen_ab, real_b, fake_a, criterion):
    """Cycle-consistency loss: map ``fake_a`` back with ``gen_ab`` and compare with the original ``real_b``."""
    return criterion(gen_ab(fake_a), real_b)
def get_gen_loss(gen_ab, gen_ba, disc_a, disc_b, real_a, real_b, adv_crit, consist_crit, consist_lambda = 10, id_lambda = 0):
    """Total generator objective for one CycleGAN step.

    ``gen_ab`` maps domain a -> b and ``gen_ba`` maps b -> a.

    Args:
        gen_ab, gen_ba: the two generators.
        disc_a, disc_b: discriminators for domains a and b.
        real_a, real_b: batches from domains a and b.
        adv_crit: criterion comparing discriminator scores with the 'real' target 1.
        consist_crit: criterion for the cycle (and optional identity) reconstruction terms.
        consist_lambda: weight of each cycle-consistency term.
        id_lambda: weight of the optional identity term
            ``consist_crit(gen_ab(real_b), real_b) + consist_crit(gen_ba(real_a), real_a)``.
            The default 0 keeps the original objective and skips those extra forward passes.

    Returns:
        Loss tensor: both adversarial terms + weighted cycle terms (+ weighted
        identity term when ``id_lambda`` is non-zero). It is a scalar with
        mean-reducing criteria.
    """
    # Translate each domain once and reuse the results for the adversarial and
    # the cycle terms. They must NOT be detached: the cycle loss has to
    # backpropagate through both generators of each cycle (a->b->a, b->a->b).
    fake_b = gen_ab(real_a)
    fake_a = gen_ba(real_b)

    pred_b = disc_b(fake_b)
    pred_a = disc_a(fake_a)
    adv_loss_a = adv_crit(pred_b, torch.ones_like(pred_b))
    adv_loss_b = adv_crit(pred_a, torch.ones_like(pred_a))

    consist_loss_a = consist_crit(gen_ab(fake_a), real_b) * consist_lambda
    consist_loss_b = consist_crit(gen_ba(fake_b), real_a) * consist_lambda
    loss = adv_loss_a + adv_loss_b + consist_loss_a + consist_loss_b

    if id_lambda:
        id_loss = consist_crit(gen_ab(real_b), real_b) + consist_crit(gen_ba(real_a), real_a)
        loss = loss + id_lambda * id_loss
    return loss

# A CycleGAN has two generators and two discriminators. We use a separate optimizer for each discriminator and a single optimizer shared by both generators.

In [None]:
# Two generators (a -> b and b -> a) and one discriminator per domain,
# constructed in the order gen_ab, gen_ba, disc_a, disc_b.
gen_ab, gen_ba = UNet(3), UNet(3)
disc_a, disc_b = Discriminator(), Discriminator()

ADAM_KWARGS = dict(lr = 0.0002, betas = (0.5, 0.999))
disc_a_optim = torch.optim.Adam(disc_a.parameters(), **ADAM_KWARGS)
disc_b_optim = torch.optim.Adam(disc_b.parameters(), **ADAM_KWARGS)
# One optimizer updates both generators jointly.
gen_optim = torch.optim.Adam([*gen_ab.parameters(), *gen_ba.parameters()], **ADAM_KWARGS)

adv_criterion = nn.MSELoss()
consist_criterion = nn.L1Loss()

# In this section we train the CycleGAN; the following cell saves the generators' weights.

In [None]:
def _disc_step(disc, disc_optim, real, fake, adv_crit):
    """Take one optimizer step on ``disc`` and return its loss tensor.

    ``fake`` is expected to come from a generator run under ``torch.no_grad()``,
    so only the discriminator is part of this graph.
    """
    disc_optim.zero_grad()
    loss = get_disc_loss(disc, real, fake, adv_crit)
    # The graph is used exactly once, so no retain_graph.
    loss.backward()
    disc_optim.step()
    return loss


def train_cycle_gan(gen_ab, gen_ba, disc_a, disc_b, dsic_a_optim, disc_b_optim, gen_optim, train_data, adv_crit, consist_crit, epochs, device = device):
    """Train a CycleGAN for ``epochs`` passes over ``train_data``.

    Each batch ``(real_a, real_b)`` gives one step for each discriminator,
    followed by one joint generator step.
    After every epoch the mean losses are printed and the last batch's real and
    translated images are plotted.

    Args:
        gen_ab, gen_ba: generators a -> b and b -> a.
        disc_a, disc_b: discriminators for domains a and b.
        dsic_a_optim: optimizer for ``disc_a`` (name kept for keyword callers).
        disc_b_optim: optimizer for ``disc_b``.
        gen_optim: optimizer over both generators' parameters.
        train_data: iterable of ``(real_a, real_b)`` batches.
        adv_crit, consist_crit: adversarial and cycle-consistency criteria.
        epochs: number of epochs.
        device: device the models and batches are moved to.

    Raises:
        ValueError: if ``train_data`` yields no batches in an epoch.
    """
    # Bind the (misspelled) parameter to a local: a bare ``disc_a_optim`` in this
    # body would otherwise resolve to the notebook-global optimizer and silently
    # ignore the argument.
    disc_a_optim = dsic_a_optim
    for model in (gen_ab, gen_ba, disc_a, disc_b):
        model.to(device)
    for epoch in range(epochs):
        total_disc_loss = 0.0
        total_gen_loss = 0.0
        num_batches = 0
        for real_a, real_b in tqdm(train_data):
            real_a = real_a.to(device)
            real_b = real_b.to(device)

            # Discriminator updates only need the generators' outputs, not their graphs.
            with torch.no_grad():
                fake_a = gen_ba(real_b)
            disc_a_loss = _disc_step(disc_a, disc_a_optim, real_a, fake_a, adv_crit)
            with torch.no_grad():
                fake_b = gen_ab(real_a)
            disc_b_loss = _disc_step(disc_b, disc_b_optim, real_b, fake_b, adv_crit)

            gen_optim.zero_grad()
            gen_loss = get_gen_loss(gen_ab, gen_ba, disc_a, disc_b, real_a, real_b, adv_crit, consist_crit)
            gen_loss.backward()
            gen_optim.step()

            # Track the mean of both discriminators, not just disc_a.
            total_disc_loss += (disc_a_loss.item() + disc_b_loss.item()) / 2
            total_gen_loss += gen_loss.item()
            num_batches += 1
        if num_batches == 0:
            raise ValueError('train_data yielded no batches')
        print(f'epoch: {epoch}  gen_loss: {total_gen_loss/num_batches}  disc_loss: {total_disc_loss/num_batches}')
        # Last batch of the epoch: real images, then their translations.
        show_tensor_images(torch.cat([real_a, real_b]))
        show_tensor_images(torch.cat([fake_b, fake_a]))

In [None]:
EPOCHS = 40
train_cycle_gan(gen_ab, gen_ba, disc_a, disc_b, disc_a_optim, disc_b_optim, gen_optim,
                train_data, adv_criterion, consist_criterion, EPOCHS)

In [None]:
# Persist only the weights (state_dict) of both generators.
for model, checkpoint_path in ((gen_ba, './gen_ba.pth'), (gen_ab, './gen_ab.pth')):
    torch.save(model.state_dict(), checkpoint_path)

# Finally, we run the photo-to-Monet generator (`gen_ba`) on every photo and zip the generated images.

In [None]:
class TestData(Dataset):
    """Inference dataset: loads each image file and normalizes it from [0, 1] to [-1, 1]."""
    def __init__(self, file_names):
        self.files = file_names
        self.transforms = T.Compose([T.ToTensor(), T.Normalize(0.5, 0.5)])
    def __len__(self):
        return len(self.files)
    def __getitem__(self, idx):
        return self.transforms(io.imread(self.files[idx]))

In [None]:
# Default batch_size=1 and no shuffling: outputs keep the order of photo_files.
test_data = DataLoader(TestData(photo_files))

In [None]:
# Translate every photo with gen_ba (domain b -> a, i.e. photo -> Monet here),
# save each result as a JPEG, then bundle them into images.zip.
gen_ba.to(device)
# eval(): use the stored BatchNorm statistics of the VGG encoder instead of
# batch-of-1 statistics, and stop updating them while generating.
gen_ba.eval()
file_path = []
# no_grad(): inference needs no autograd graph.
with torch.no_grad():
    for ind, image in enumerate(tqdm(test_data)):
        fake = gen_ba(image.to(device))[0]
        # CHW tensor in [-1, 1] -> HWC uint8 in [0, 255].
        pixels = (fake.permute(1, 2, 0).cpu().numpy() * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
        path = f'./image_{ind}.jpg'
        Image.fromarray(pixels).save(path)
        file_path.append(path)

# `zip_file` rather than `zip`, to avoid shadowing the builtin.
with ZipFile('./images.zip', 'w') as zip_file:
    for path in file_path:
        zip_file.write(path)