# Libraries
Libraries required to run this notebook.

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from PIL import Image
import numpy as np
import os
import zipfile
import glob
from skimage import color
from matplotlib import pyplot as plt

import time

In [None]:
!pip install torchmetrics

from torchmetrics import functional

# Hyper Parameters
We can tune these parameters to modify the training process.

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 2e-4 # learning rate passed to the Adam optimizers
batch_size = 16 # batch size of the training DataLoaders
n_workers = 2 # number of worker processes of the training DataLoaders
img_size = 256 # the datasets resize every image to img_size x img_size
L1_lambda = 100 # weight of the L1 term in the generator loss
n_epochs = 100 # number of training epochs
n_samples = 10000 # number of images sampled from the dataset (split 80/20 into train/test)

# Networks

In [None]:
class PatchGAN(nn.Module):
    """Conditional patch discriminator built from strided 4x4 convolutions.

    The network outputs one raw score (logit) per image patch. The layers are
    stored in ``self.model`` in the order: conv + LeakyReLU, three
    conv + BatchNorm + LeakyReLU stages, and a final conv producing
    ``out_channels`` score maps.
    """

    def __init__(self, in_channels = 4, out_channels = 1):
        super().__init__()

        def conv(c_in, c_out, stride):
            # Every convolution shares kernel size, padding and padding mode
            return nn.Conv2d(c_in, c_out, kernel_size = 4, stride = stride,
                             padding = 1, padding_mode = 'reflect')

        # First stage has no normalization
        layers = [conv(in_channels, 64, 2), nn.LeakyReLU(0.2)]

        # (input channels, output channels, stride) of the normalized stages
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [conv(c_in, c_out, stride), nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2)]

        # Final projection to the patch score map
        layers.append(conv(512, out_channels, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        # Conditional GAN: the discriminator receives the generator input (x)
        # stacked channel-wise with the real or generated image (y)
        conditioned = torch.cat([x, y], dim = 1)
        return self.model(conditioned)

In [None]:
class Critic(nn.Module):
    """Conditional critic used for the WGAN training.

    Same layout as the PatchGAN discriminator, but the intermediate stages
    use affine InstanceNorm2d instead of BatchNorm2d. The output is a map of
    raw, unbounded scores.
    """

    def __init__(self, in_channels = 4, out_channels = 1):
        super().__init__()

        def conv(c_in, c_out, stride):
            # Every convolution shares kernel size, padding and padding mode
            return nn.Conv2d(c_in, c_out, kernel_size = 4, stride = stride,
                             padding = 1, padding_mode = 'reflect')

        # First stage has no normalization
        layers = [conv(in_channels, 64, 2), nn.LeakyReLU(0.2)]

        # (input channels, output channels, stride) of the normalized stages
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                conv(c_in, c_out, stride),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.LeakyReLU(0.2),
            ]

        # Final projection to the score map
        layers.append(conv(512, out_channels, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        # Conditional GAN: the critic receives the generator input (x)
        # stacked channel-wise with the real or generated image (y)
        conditioned = torch.cat([x, y], dim = 1)
        return self.model(conditioned)

In [None]:
class UNet(nn.Module):
    """U-Net generator: four encoder stages, a bottleneck and four decoder stages.

    Each stage is two 3x3 conv + BatchNorm + ReLU layers. Encoder outputs are
    concatenated to the upsampled decoder features (skip connections), and the
    final 1x1 convolution is followed by a sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self, in_channels = 1, out_channels = 3):
        super().__init__()

        def double_conv(c_in, c_out):
            # Two (3x3 conv -> batch norm -> ReLU) layers; spatial size is preserved
            return nn.Sequential(
                nn.Conv2d(in_channels = c_in, out_channels = c_out, kernel_size = 3, padding = 1),
                nn.BatchNorm2d(num_features = c_out),
                nn.ReLU(),
                nn.Conv2d(in_channels = c_out, out_channels = c_out, kernel_size = 3, padding = 1),
                nn.BatchNorm2d(num_features = c_out),
                nn.ReLU(),
            )

        def upsample(c_in, c_out):
            # Transposed convolution that doubles the spatial resolution
            return nn.ConvTranspose2d(in_channels = c_in, out_channels = c_out, kernel_size=2, stride=2)

        # Shared 2x2 max pooling used between encoder stages
        self.pool_layer = nn.MaxPool2d(kernel_size = 2, stride = 2)

        # Contracting path
        self.encoder_block_1 = double_conv(in_channels, 64)
        self.encoder_block_2 = double_conv(64, 128)
        self.encoder_block_3 = double_conv(128, 256)
        self.encoder_block_4 = double_conv(256, 512)

        # Bottleneck
        self.middle_block = double_conv(512, 1024)

        # Expanding path; each decoder input is the upsampled features
        # concatenated with the matching encoder output
        self.upsampling_block_4 = upsample(1024, 512)
        self.decoder_block_4 = double_conv(512 + 512, 512)

        self.upsampling_block_3 = upsample(512, 256)
        self.decoder_block_3 = double_conv(256 + 256, 256)

        self.upsampling_block_2 = upsample(256, 128)
        self.decoder_block_2 = double_conv(128 + 128, 128)

        self.upsampling_block_1 = upsample(128, 64)
        self.decoder_block_1 = double_conv(64 + 64, 64)

        # 1x1 convolution mapping the features to the output channels
        self.final_block = nn.Conv2d(in_channels = 64, out_channels = out_channels, kernel_size = 1)


    def forward(self, x):
        encoders = (
            self.encoder_block_1,
            self.encoder_block_2,
            self.encoder_block_3,
            self.encoder_block_4,
        )
        decoders = (
            (self.upsampling_block_4, self.decoder_block_4),
            (self.upsampling_block_3, self.decoder_block_3),
            (self.upsampling_block_2, self.decoder_block_2),
            (self.upsampling_block_1, self.decoder_block_1),
        )

        # Contracting path: keep every stage output for the skip connections
        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)
            features = self.pool_layer(features)

        features = self.middle_block(features)

        # Expanding path: the most recent encoder output is consumed first
        for upsample, decode in decoders:
            upsampled = upsample(features)
            features = decode(torch.cat([upsampled, skips.pop()], dim=1)) # skip connection

        return torch.sigmoid(self.final_block(features))

# Obtaining the Coco Dataset

In [None]:
!wget http://images.cocodataset.org/zips/train2017.zip -O "/content/train2017.zip"

In [None]:
# Extract the downloaded archive; the context manager closes the file
# even if the extraction fails part-way through
with zipfile.ZipFile('/content/train2017.zip', 'r') as zip_ref:
    zip_ref.extractall('/content')

# Defining our Datasets

In [None]:
path = "/content/train2017"
# List the paths of all the images of the dataset. glob returns them in
# arbitrary order, so they are sorted to make the seeded sampling below
# reproducible across runs and machines.
paths = sorted(glob.glob(path + "/*.jpg"))
np.random.seed(2052415)

# Draw n_samples distinct images, then split them 80% / 20% into train and test
paths_subset = np.random.choice(paths, n_samples, replace = False)
rand = np.random.permutation(n_samples)
n_train = int(n_samples * 0.8)
train_indexes = rand[:n_train]
test_indexes = rand[n_train:]

train_paths = paths_subset[train_indexes] # Contains the selected training image paths
test_paths = paths_subset[test_indexes] # Contains the selected test image paths

The GrayRGB dataset provides, for each image, a grayscale version to colorize together with its RGB ground truth.

In [None]:
class GrayRGB(Dataset):
    """Dataset yielding a grayscale image and its RGB ground truth.

    Each item is a dict with keys "gray" and "rgb" holding float32 tensors
    with values in [0, 1], resized to ``img_size`` x ``img_size``.
    """

    def __init__(self, paths):
        self.size = img_size
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        resize = transforms.Resize((self.size, self.size))
        to_tensor = transforms.ToTensor()

        # Load as RGB, resize and rescale to [0, 1]
        picture = resize(Image.open(self.paths[index]).convert("RGB"))
        rgb = np.array(picture).astype("float32") / 255 # values in [0, 1]

        # Grayscale version used as the generator input
        gray = color.rgb2gray(rgb).astype("float32")

        return {"gray" : to_tensor(gray), "rgb" : to_tensor(rgb)}

The LabRGB dataset provides, for each image, the luminance channel (L) and the corresponding a and b channels, which together form the image in the Lab color space.

In [None]:
class LabRGB(Dataset):
    """Dataset yielding the L channel and the ab channels of an image in Lab space.

    Each item is a dict with keys "L" (1 channel) and "ab" (2 channels),
    both float32 tensors rescaled to roughly [0, 1].
    """

    def __init__(self, paths):
        self.size = img_size
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        resize = transforms.Resize((self.size, self.size))
        to_tensor = transforms.ToTensor()

        # Load as RGB, resize and convert to the Lab color space
        picture = resize(Image.open(self.paths[index]).convert("RGB"))
        lab = color.rgb2lab(np.array(picture)).astype("float32")

        luminance = lab[:, :, :1] / 100 # values in [0, 1]
        chroma = (lab[:, :, 1:] + 128) / 255 # values in [0, 1]

        return {"L" : to_tensor(luminance), "ab" : to_tensor(chroma)}

# Visualization Functions
Function used to visualize some examples during training.

In [None]:
def visualize(x, y, y_fake, i, n_batch, lab = False):
    """Plot the grayscale input, the generated image and the real image side by side.

    Args:
        x: generator input for one sample (channel-first tensor).
        y: ground-truth target for the same sample.
        y_fake: generator output for the same sample.
        i: current iteration, shown in the figure title.
        n_batch: total number of iterations, shown in the figure title.
        lab: if True, x / y / y_fake are rescaled L / ab / ab channels that are
            mapped back to Lab ranges and converted to RGB before plotting.
    """
    plt.figure(figsize=(9, 4))
    plt.suptitle(f"\nIteration {i}/{n_batch}")

    if lab:
        # Undo the dataset scaling: L back to [0, 100], ab back to [-128, 127]
        x = x * 100
        y = y * 255 - 128
        y_fake = y_fake * 255 - 128

        lab_fake = torch.cat([x, y_fake], dim = 0).permute(1, 2, 0)
        generated_image = color.lab2rgb(lab_fake.cpu().detach())
        lab_real = torch.cat([x, y], dim = 0).permute(1, 2, 0)
        real_image = color.lab2rgb(lab_real.cpu().detach())
    else:
        # Channel-first tensors -> channel-last float32 arrays for imshow
        generated_image = y_fake.cpu().detach().permute(1,2,0).numpy().astype('float32')
        real_image = y.cpu().detach().permute(1,2,0).numpy().astype('float32')

    gray_image = x[0].cpu().detach().numpy().astype('float32')

    panels = (
        ("Grayscale Image", gray_image, {"cmap": 'gray'}),
        ("RGB Generated Image", generated_image, {}),
        ("RGB Real Image", real_image, {}),
    )
    for position, (title, image, options) in enumerate(panels, start=1):
        ax = plt.subplot(1, 3, position)
        ax.imshow(image, **options)
        ax.set_title(title)
        ax.axis("off")

    plt.show()

# Training

## Training functions
We optimize our networks with this function; its flags adapt it to the different training setups:
- ssim_loss: if True, an SSIM term (1 - SSIM) is added to the generator loss to improve its convergence
- lab: if True, training is performed in the Lab color space instead of the standard RGB

In [None]:
def train_step(disc, gen, dl, opt_disc, opt_gen, l1, BCE, d_scaler, g_scaler, ssim_loss = False, lab = False, display_every = 50):
    """Run one epoch of conditional-GAN training with mixed precision.

    For every batch the discriminator is updated first (real vs. detached
    fake), then the generator (adversarial loss + L1_lambda * L1, plus
    1 - SSIM when ``ssim_loss`` is set).

    Args:
        disc, gen: discriminator and generator networks.
        dl: DataLoader yielding dicts with "gray"/"rgb" or "L"/"ab" entries.
        opt_disc, opt_gen: optimizers of the two networks.
        l1, BCE: L1 and adversarial loss functions.
        d_scaler, g_scaler: gradient scalers for the two networks.
        ssim_loss: add the SSIM term to the generator loss.
        lab: read the Lab entries ("L", "ab") of each batch instead of the
            grayscale/RGB ones.
        display_every: show a sample and print the losses every this many iterations.
    """
    source_key, target_key = ("L", "ab") if lab else ("gray", "rgb")

    for step, batch in enumerate(tqdm(dl, leave = True), start=1):
        x = batch[source_key].to(device)
        y = batch[target_key].to(device)

        # --- Discriminator update ---
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            pred_real = disc(x, y)
            # Detach so this loss does not back-propagate into the generator
            pred_fake = disc(x, y_fake.detach())
            disc_real_loss = BCE(pred_real, torch.ones_like(pred_real))
            disc_fake_loss = BCE(pred_fake, torch.zeros_like(pred_fake))
            disc_loss = (disc_real_loss + disc_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # --- Generator update ---
        with torch.cuda.amp.autocast():
            # Re-score the fake with the freshly updated discriminator
            pred_fake = disc(x, y_fake)
            gen_fake_loss = BCE(pred_fake, torch.ones_like(pred_fake))
            L1 = l1(y_fake, y) * L1_lambda
            gen_loss = gen_fake_loss + L1
            if ssim_loss:
                ssim_value = functional.structural_similarity_index_measure(preds = y_fake, target = y)
                gen_loss += (1 - ssim_value)

        gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Periodic progress report on the first sample of the batch
        if step % display_every == 0:
            visualize(x[0], y[0], y_fake[0], step, len(dl), lab)
            print(f"\nGenerator Loss: {gen_loss} \nDiscriminator Loss: {disc_loss}")

In [None]:
def train(discriminator, generator, device, dataset, learning_rate, batch_size, n_workers, n_epochs, ssim_loss = False, lab = False):
    """Train a conditional GAN for ``n_epochs`` epochs.

    Builds the Adam optimizers, the losses, the shuffled DataLoader and the
    gradient scalers, then calls ``train_step`` once per epoch.

    Args:
        discriminator, generator: networks to train.
        device: not used directly by this function.
        dataset: dataset wrapped in the training DataLoader.
        learning_rate: learning rate of both Adam optimizers.
        batch_size, n_workers: DataLoader settings.
        n_epochs: number of epochs.
        ssim_loss, lab: forwarded to ``train_step``.
    """
    adam_betas = (0.5, 0.999)

    opt_disc = optim.Adam(discriminator.parameters(), lr = learning_rate, betas = adam_betas)
    opt_gen = optim.Adam(generator.parameters(), lr = learning_rate, betas = adam_betas)

    L1_loss = nn.L1Loss()
    BCE = nn.BCEWithLogitsLoss()

    train_dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = n_workers)

    # One gradient scaler per network for mixed-precision training
    d_scaler = torch.cuda.amp.GradScaler()
    g_scaler = torch.cuda.amp.GradScaler()

    for epoch_number in range(1, n_epochs + 1):
        print(f"\nEpoch {epoch_number}/{n_epochs}")
        train_step(
            discriminator, generator, train_dl,
            opt_disc, opt_gen,
            L1_loss, BCE,
            d_scaler, g_scaler,
            ssim_loss, lab,
        )

## Training with grayscale to RGB

Here we generate an RGB image (3 color channels) from a grayscale image (1 channel). Our network is a conditional GAN, so the discriminator input has 4 channels (3 for the RGB image and 1 for the grayscale image).

In [None]:
# Training set of (grayscale, RGB) pairs
GrayRGB_dataset = GrayRGB(train_paths)

# The discriminator sees grayscale (1 channel) + RGB (3 channels) = 4 channels
gray_discriminator = PatchGAN(in_channels = 4).to(device)
gray_generator = UNet(out_channels = 3).to(device)

# Train and measure the wall-clock training time
start_time = time.time()
train(gray_discriminator, gray_generator, device, GrayRGB_dataset, lr, batch_size, n_workers, n_epochs)
elapsed_time = time.time() - start_time
print("Training time: %.2f s" % elapsed_time)

# Save the weights
torch.save(gray_generator.state_dict(), '/content/gray_generator_weights.pt')
torch.save(gray_discriminator.state_dict(), '/content/gray_discriminator_weights.pt')

# Free GPU memory
torch.cuda.empty_cache()

## Training with Lab to RGB
Here the generator works in the Lab color space: the L channel represents the luminance, so we can use it as a sort of grayscale input. The generator then only has to produce 2 channels (a and b), and the final RGB image is obtained by converting from Lab to RGB. This benefits the training process.

In [None]:
# Training set of (L, ab) pairs in the Lab color space
LabRGB_dataset = LabRGB(train_paths)

# The discriminator sees L (1 channel) + ab (2 channels) = 3 channels
lab_discriminator = PatchGAN(in_channels = 3).to(device)
lab_generator = UNet(out_channels = 2).to(device)

# Train and measure the wall-clock training time
start_time = time.time()
train(lab_discriminator, lab_generator, device, LabRGB_dataset, lr, batch_size, n_workers, n_epochs, lab = True)
elapsed_time = time.time() - start_time
print("Training time: %.2f s" % elapsed_time)

# Save the weights (absolute path, matching the one the evaluation cell loads from)
torch.save(lab_generator.state_dict(), '/content/lab_generator_weights.pt')
torch.save(lab_discriminator.state_dict(), '/content/lab_discriminator_weights.pt')

# Free GPU memory
torch.cuda.empty_cache()

## Training with Lab to RGB with SSIM Loss

In [None]:
# Training set of (L, ab) pairs in the Lab color space
LabRGB_dataset = LabRGB(train_paths)

# The discriminator sees L (1 channel) + ab (2 channels) = 3 channels
ssim_lab_discriminator = PatchGAN(in_channels = 3).to(device)
ssim_lab_generator = UNet(out_channels = 2).to(device)

# The networks are sized for Lab data, so train on the Lab dataset with lab = True
train(ssim_lab_discriminator, ssim_lab_generator, device, LabRGB_dataset, lr, batch_size, n_workers, n_epochs, ssim_loss = True, lab = True)

# Save the weights of the networks trained in this cell
torch.save(ssim_lab_generator.state_dict(), '/content/lab_generator_ssim_loss_weights.pt')
torch.save(ssim_lab_discriminator.state_dict(), '/content/lab_discriminator_ssim_weights.pt')

# Free GPU memory
torch.cuda.empty_cache()

## WGAN
The main idea behind WGAN is to use a different loss function, based on the Wasserstein distance, which provides a more stable training process and improved generated results.

In [None]:
from torch.autograd import Variable
from torch import autograd

def calculate_gradient_penalty(critic, device, real_images, gray_im, fake_images):
    """Compute the WGAN-GP gradient penalty for a conditional critic.

    The critic is evaluated on random interpolations between real and fake
    images, and the penalty is the mean over the batch of
    ``(||grad||_2 - 1) ** 2``, where the norm is taken per sample over all
    channel and spatial dimensions.

    Args:
        critic: callable ``critic(gray_im, images)`` returning the scores.
        device: device on which the interpolation coefficients are created.
        real_images: batch of real target images.
        gray_im: batch of conditioning inputs given to the critic.
        fake_images: batch of generated images, same shape as ``real_images``.

    Returns:
        Scalar tensor holding the gradient penalty.
    """
    batch_size = real_images.shape[0]

    # One interpolation coefficient per sample, broadcast over the other dims
    eps_shape = [batch_size] + [1] * (real_images.dim() - 1)
    eps = torch.rand(eps_shape, device=device)

    # Detach so the penalty only differentiates through the critic, and mark
    # the interpolated images as the tensor the gradient is taken against
    interpolated = (eps * real_images + (1 - eps) * fake_images).detach().requires_grad_(True)

    # Critic scores of the interpolated examples
    score_interpolated = critic(gray_im, interpolated)

    # Gradient of the scores with respect to the interpolated examples;
    # create_graph keeps it differentiable so the penalty can be back-propagated
    gradients = torch.autograd.grad(outputs=score_interpolated,
                                    inputs=interpolated,
                                    grad_outputs=torch.ones_like(score_interpolated),
                                    create_graph=True,
                                    retain_graph=True)[0]

    # Flatten so the L2 norm is per sample over all channels and pixels,
    # not per pixel over the channel dimension only
    gradients = gradients.reshape(batch_size, -1)

    grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return grad_penalty

def W_train_step(critic, gen, device, dl, opt_critic, opt_gen, iter, lambda_term, lab = False, display_every = 50):
    """Run one epoch of WGAN-GP training.

    The critic is updated on every batch; the generator is updated once every
    ``iter`` batches.

    Args:
        critic, gen: critic and generator networks.
        device: device the batches are moved to.
        dl: DataLoader yielding dicts with "gray"/"rgb" or "L"/"ab" entries.
        opt_critic, opt_gen: optimizers of the two networks.
        iter: number of critic updates per generator update.
        lambda_term: weight of the gradient penalty in the critic loss.
        lab: read the Lab entries ("L", "ab") of each batch instead of the
            grayscale/RGB ones.
        display_every: show a sample and print the losses every this many
            iterations, as ``train_step`` does.
    """
    source_key, target_key = ("L", "ab") if lab else ("gray", "rgb")

    # Stays None until the first generator update, so the progress report
    # below never reads an unbound name
    g_loss = None

    for i, data in enumerate(tqdm(dl), start=1):
        gray = data[source_key].to(device)
        col = data[target_key].to(device)

        critic.zero_grad()
        gen.zero_grad()

        # Critic score on real images
        c_loss_real = critic(gray, col)

        # Critic score on fake images; detach so the generator is not updated here
        fake_images = gen(gray)
        c_loss_fake = critic(gray, fake_images.detach())

        # Gradient penalty on interpolations between real and fake images
        gradient_penalty = calculate_gradient_penalty(critic, device, col, gray, fake_images)

        d_loss = (c_loss_fake - c_loss_real).mean() + (lambda_term*gradient_penalty.mean())
        d_loss.backward()
        opt_critic.step()

        if i % iter == 0:
            # Generator update: drop the gradients accumulated during the critic step
            critic.zero_grad()
            gen.zero_grad()

            fake_images = gen(gray)
            g_loss = -critic(gray, fake_images)
            g_loss = g_loss.mean()
            g_loss.backward()
            opt_gen.step()

        # Periodic progress report on the first sample of the batch
        if i % display_every == 0:
            print(f'\nGenerator -> g_loss: {g_loss}')
            print(f'Critic -> loss_fake: {c_loss_fake.mean()}, loss_real: {c_loss_real.mean()}')
            visualize(gray[0], col[0], fake_images[0], i, len(dl), lab)

In [None]:
def W_train(critic, generator, device, dataset, learning_rate, batch_size, n_workers, n_epochs, lab=False):
    """Train a WGAN with gradient penalty for ``n_epochs`` epochs.

    Builds the Adam optimizers and the shuffled DataLoader, then calls
    ``W_train_step`` once per epoch with 5 critic updates per generator
    update and a gradient-penalty weight of 10.

    Args:
        critic, generator: networks to train.
        device: device forwarded to ``W_train_step``.
        dataset: dataset wrapped in the training DataLoader.
        learning_rate: learning rate of both Adam optimizers.
        batch_size, n_workers: DataLoader settings.
        n_epochs: number of epochs.
        lab: forwarded to ``W_train_step``.
    """
    adam_betas = (0.5, 0.999)
    critic_steps_per_gen_step = 5
    penalty_weight = 10

    opt_critic = optim.Adam(critic.parameters(), lr = learning_rate, betas = adam_betas)
    opt_gen = optim.Adam(generator.parameters(), lr = learning_rate, betas = adam_betas)

    train_dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = n_workers)

    for epoch_number in range(1, n_epochs + 1):
        print(f"\nEpoch {epoch_number}/{n_epochs}")
        W_train_step(
            critic, generator, device, train_dl,
            opt_critic, opt_gen,
            critic_steps_per_gen_step, penalty_weight, lab,
        )

## WGAN training with grayscale to RGB

In [None]:
# Training set of (grayscale, RGB) pairs
GrayRGB_dataset = GrayRGB(train_paths)

# The critic sees grayscale (1 channel) + RGB (3 channels) = 4 channels
wgan_gray_gen = UNet(out_channels = 3).to(device)
wgan_gray_crit = Critic(in_channels = 4).to(device)

W_train(wgan_gray_crit, wgan_gray_gen, device, GrayRGB_dataset, lr, batch_size, n_workers, n_epochs)

# Save the weights
torch.save(wgan_gray_gen.state_dict(), '/content/gray_generator_wgan_weights.pt')
torch.save(wgan_gray_crit.state_dict(), '/content/gray_critic_wgan_weights.pt')

# Free GPU memory
torch.cuda.empty_cache()

# Evaluation

In [None]:
def evaluate_network(dataloader, generator, discriminator, BCE, l1_loss, lab = False, wgan = False):
    """Evaluate a trained generator on a dataset and plot some results.

    Prints the generator loss, the MSE and the SSIM between generated and real
    images over the whole dataloader, then shows a 3x6 grid in which each row
    holds three generated images followed by the three matching real images.

    Args:
        dataloader: DataLoader yielding dicts with "gray"/"rgb" or "L"/"ab" entries.
        generator: trained generator.
        discriminator: trained discriminator (or critic when ``wgan`` is True).
        BCE: adversarial loss used when ``wgan`` is False.
        l1_loss: L1 loss used when ``wgan`` is False.
        lab: if True the data is in Lab space and is converted to RGB for plotting.
        wgan: if True the generator loss is the negated mean critic score.
    """
    # Both networks must be in evaluation mode so that normalization layers
    # use their running statistics instead of the batch statistics
    generator.eval()
    discriminator.eval()

    source_key, target_key = ("L", "ab") if lab else ("gray", "rgb")

    with torch.no_grad():
        inputs = []
        generated = []
        real = []
        prediction_fake = []
        for data in tqdm(dataloader):
            x = data[source_key].to(device)
            y = data[target_key].to(device)

            y_fake = generator(x)
            pred_fake = discriminator(x, y_fake)

            inputs.append(x)
            generated.append(y_fake)
            real.append(y)
            prediction_fake.append(pred_fake)

        # Stack the per-batch results along the batch dimension
        inputs = torch.cat(inputs, dim=0)
        generated = torch.cat(generated, dim=0)
        real = torch.cat(real, dim=0)
        prediction_fake = torch.cat(prediction_fake, dim=0)

        ssim = functional.structural_similarity_index_measure(generated, real).item()
        mse = torch.mean((generated - real) ** 2).item()

        if not wgan:
            gen_fake_loss = BCE(prediction_fake, torch.ones_like(prediction_fake)).detach().cpu().numpy()
            L1 = (l1_loss(generated, real) * L1_lambda).detach().cpu().numpy()
            gen_loss = gen_fake_loss + L1
        else:
            gen_loss = -prediction_fake
            gen_loss = gen_loss.mean()

        print(f"\nGenerator Loss: {gen_loss}")
        print(f"\nMSE: {mse}")
        print(f"\nSSIM: {ssim}")

        if lab:
            # Undo the dataset scaling: L back to [0, 100], ab back to [-128, 127]
            inputs = inputs * 100
            real = real * 255 - 128
            generated = generated * 255 - 128

            # Rebuild the full Lab images and convert them to RGB for plotting
            lab_fake = torch.cat([inputs, generated], dim = 1).permute(0, 2, 3, 1)
            generated = torch.tensor(color.lab2rgb(lab_fake.cpu())).permute(0, 3, 1, 2)
            lab_real = torch.cat([inputs, real], dim = 1).permute(0, 2, 3, 1)
            real = torch.tensor(color.lab2rgb(lab_real.cpu())).permute(0, 3, 1, 2)

        inputs = inputs.detach().cpu()
        generated = generated.detach().cpu()
        real = real.detach().cpu()

        fig, axes = plt.subplots(3, 6, figsize=(15, 6))
        axes = axes.ravel()
        fig.suptitle("Generated images VS Real images", fontsize = 15)
        n_images = generated.shape[0]
        for i, ax in enumerate(axes):
            ax.axis('off')
            # Columns 0-2 of each row show generated images, columns 3-5 the
            # real images with the same indexes
            show_generated = i % 6 < 3
            index = i if show_generated else i - 3
            if index >= n_images:
                # Fewer images than grid cells: leave the cell empty
                continue
            images = generated if show_generated else real
            ax.imshow(images[index].permute(1,2,0), aspect = 'equal')
In [None]:
# Adam coefficients used by the optimizers created in the evaluation cells below
beta1 = 0.5
beta2 = 0.999

# Loss functions passed to evaluate_network
L1_loss = nn.L1Loss()
BCE = nn.BCEWithLogitsLoss()

In [None]:
# Load the weights
gray_generator_tester = UNet().to(device)
gray_discriminator_tester = PatchGAN().to(device)

# map_location lets weights saved from a GPU run be loaded on a CPU-only runtime
gray_generator_tester.load_state_dict(torch.load('/content/gray_generator_weights.pt', map_location = device))
gray_discriminator_tester.load_state_dict(torch.load('/content/gray_discriminator_weights.pt', map_location = device))

# NOTE: these optimizers are not used by evaluate_network
opt_disc = optim.Adam(gray_discriminator_tester.parameters(), lr, betas = (beta1, beta2))
opt_gen = optim.Adam(gray_generator_tester.parameters(), lr, betas = (beta1, beta2))

# Held-out test split
gray_testing_ds = GrayRGB(test_paths)
gray_testing_dl = DataLoader(gray_testing_ds, batch_size = 16, shuffle = True)

print("From Grayscale to RGB\n")

evaluate_network(gray_testing_dl, gray_generator_tester, gray_discriminator_tester, BCE, L1_loss)

# Free GPU memory
torch.cuda.empty_cache()

In [None]:
# Load the weights
lab_generator_tester = UNet(out_channels = 2).to(device)
lab_discriminator_tester = PatchGAN(in_channels = 3).to(device)

# map_location lets weights saved from a GPU run be loaded on a CPU-only runtime
lab_generator_tester.load_state_dict(torch.load('/content/lab_generator_weights.pt', map_location = device))
lab_discriminator_tester.load_state_dict(torch.load('/content/lab_discriminator_weights.pt', map_location = device))

# NOTE: these optimizers are not used by evaluate_network
opt_disc = optim.Adam(lab_discriminator_tester.parameters(), lr, betas = (beta1, beta2))
opt_gen = optim.Adam(lab_generator_tester.parameters(), lr, betas = (beta1, beta2))

# Held-out test split in the Lab color space
lab_testing_ds = LabRGB(test_paths)
lab_testing_dl = DataLoader(lab_testing_ds, batch_size = 16, shuffle = True)

print("From Lab to RGB\n")

evaluate_network(lab_testing_dl, lab_generator_tester, lab_discriminator_tester, BCE, L1_loss, lab = True)

# Free GPU memory
torch.cuda.empty_cache()

In [None]:
# Load the weights
gray_generator_wgan_tester = UNet().to(device)
gray_discriminator_wgan_tester = Critic().to(device)

# map_location lets weights saved from a GPU run be loaded on a CPU-only runtime
gray_generator_wgan_tester.load_state_dict(torch.load('/content/gray_generator_wgan_weights.pt', map_location = device))
gray_discriminator_wgan_tester.load_state_dict(torch.load('/content/gray_critic_wgan_weights.pt', map_location = device))

# NOTE: these optimizers are not used by evaluate_network
opt_disc = optim.Adam(gray_discriminator_wgan_tester.parameters(), lr, betas = (beta1, beta2))
opt_gen = optim.Adam(gray_generator_wgan_tester.parameters(), lr, betas = (beta1, beta2))

# Held-out test split
gray_testing_wgan_ds = GrayRGB(test_paths)
gray_testing_wgan_dl = DataLoader(gray_testing_wgan_ds, batch_size = 16, shuffle = True)

print("From Grayscale to RGB WGAN\n")

evaluate_network(gray_testing_wgan_dl, gray_generator_wgan_tester, gray_discriminator_wgan_tester, BCE, L1_loss, wgan = True)

# Free GPU memory
torch.cuda.empty_cache()