In [1]:
from torchvision.datasets import Flowers102
import multiprocessing
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import pandas as pd


In [2]:
num_workers = multiprocessing.cpu_count()
print("Number of CPU cores available:", num_workers)
batch_size = 4

Number of CPU cores available: 4


In [3]:
class DataProvider(Dataset):

    def __init__(self, sample_dataset, label_dataset):
        self.sample_dataset = sample_dataset
        self.label_dataset = label_dataset

    def __len__(self):
        return len(self.sample_dataset)

    def __getitem__(self, idx):
        sample = self.sample_dataset[idx][0]
        label = self.label_dataset[idx][0]
        return (sample, label)


transform_Grayscale = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Load sample datasets rgb image
train_set_sample = Flowers102(root='.', split='train', download=True, transform=transform)
test_set_sample = Flowers102(root='.', split='test', download=True, transform=transform)
validation_set_sample = Flowers102(root='.', split='val', download=True, transform=transform)

# Load label datasets - grayscale images as labels)
train_set_label = Flowers102(root='.', split='train', download=True, transform=transform_Grayscale)
test_set_label = Flowers102(root='.', split='test', download=True, transform=transform_Grayscale)
validation_set_label = Flowers102(root='.', split='val', download=True, transform=transform_Grayscale)

train_set = DataProvider(train_set_sample, train_set_label)
test_set = DataProvider(test_set_sample, test_set_label)
validation_set = DataProvider(validation_set_sample, validation_set_label)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
validation_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to flowers-102/102flowers.tgz


100%|██████████| 344862509/344862509 [00:12<00:00, 27321536.92it/s]


Extracting flowers-102/102flowers.tgz to flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to flowers-102/imagelabels.mat


100%|██████████| 502/502 [00:00<00:00, 407418.85it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to flowers-102/setid.mat


100%|██████████| 14989/14989 [00:00<00:00, 9206094.99it/s]


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


# Generator - UNET

In [7]:
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class UNetGenerator(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNetGenerator, self).__init__()

        # Define encoder blocks
        self.encoder1 = EncoderBlock(in_channels, 64)
        self.encoder2 = EncoderBlock(64, 128)
        self.encoder3 = EncoderBlock(128, 256)
        self.encoder4 = EncoderBlock(256, 512)
        self.encoder5 = EncoderBlock(512, 1024)

        # Define decoder blocks=
        self.decoder1 = DecoderBlock(1024, 512)
        self.decoder2 = DecoderBlock(512 + 512, 256)
        self.decoder3 = DecoderBlock(256 + 256, 128)
        self.decoder4 = DecoderBlock(128 + 128, 64)
        self.decoder5 = DecoderBlock(64 + 64, out_channels)

    def forward(self, x):
        # Encoding
        x1 = self.encoder1(x)
        x2 = self.encoder2(x1)
        x3 = self.encoder3(x2)
        x4 = self.encoder4(x3)
        x5 = self.encoder5(x4)

        # Decoding with skip connections
        y = self.decoder1(x5)
        y = torch.cat([y, x4], dim=1)
        y = self.decoder2(y)
        y = torch.cat([y, x3], dim=1)
        y = self.decoder3(y)
        y = torch.cat([y, x2], dim=1)
        y = self.decoder4(y)
        y = torch.cat([y, x1], dim=1)
        y = self.decoder5(y)

        return y


# Critic - CNN

In [8]:
class VGG_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(VGG_block, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.block(x)
    

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # Define blocks
        self.layer1 = VGG_block(3, 64)
        self.layer2 = VGG_block(64, 128)
        self.layer3 = VGG_block(128, 256)
        self.layer4 = VGG_block(256, 512)
        self.layer5 = VGG_block(512, 1024)
        self.layer6 = VGG_block(1024, 1024)
        
        # Final convolutional layer
        self.final_conv = nn.Sequential(
            nn.Conv2d(1024, 1, kernel_size=3, stride=1, padding=1, bias=False)
        )

    def forward(self, img):
        x1 = self.layer1(img)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x5 = self.layer5(x4)
        x6 = self.layer6(x5)
        output = self.final_conv(x6)
        return output

In [None]:
# Initialize the generator and discriminator
generator = UNetGenerator(in_channels=1, out_channels=3)  
discriminator = Discriminator()

optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.0, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.0, 0.9))

criterion = nn.MSELoss()

generator.to(device)
discriminator.to(device)

# Pretraining

In [10]:
def pretrain_generator(generator, criterion, optimizer, train_loader):
    generator.train()  # Set the generator to training mode
    for epoch in range(100):
        for real_image, grey_image in train_loader:
            grey_image = grey_image.to(device)
            real_image = real_image.to(device)

            # Generate fake images
            fake_image = generator(grey_image)

            # Compute generator loss
            generator_loss = criterion(fake_image, real_image)

            # Backpropagation
            optimizer.zero_grad()
            generator_loss.backward()
            optimizer.step()
        if(epoch % 20 == 0):
            print(f'epoch number: {epoch}')

def pretrain_discriminator(discriminator, criterion, optimizer, train_loader):
    discriminator.train()  # Set the discriminator to training mode
    for epoch in range(100):
        for real_image, grey_image in train_loader:
            grey_image = grey_image.to(device)
            real_image = real_image.to(device)

            # Generate fake images
            fake_image = generator(grey_image).detach()

            # Compute discriminator loss
            real_output = discriminator(real_image)
            fake_output = discriminator(fake_image)
            discriminator_loss_real = criterion(real_output, torch.ones_like(real_output))
            discriminator_loss_fake = criterion(fake_output, torch.zeros_like(fake_output))
            discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2

            # Backpropagation
            optimizer.zero_grad()
            discriminator_loss.backward()
            optimizer.step()
        if(epoch % 20 == 0):
            print(f'epoch number: {epoch}')


# Pretraining the generator
pretrain_generator(generator, criterion, optimizer_G, train_loader)

# Pretraining the discriminator
pretrain_discriminator(discriminator, criterion, optimizer_D, train_loader)

torch.save(generator.state_dict(), f"MSE_generator_epoch.pt")
torch.save(discriminator.state_dict(), f"MSE_discriminator_epoch.pt")

epoch number: 0
epoch number: 20
epoch number: 40
epoch number: 60
epoch number: 80
epoch number: 0
epoch number: 20
epoch number: 40
epoch number: 60
epoch number: 80


In [12]:
def compute_gradient_penalty(discriminator, real_samples, fake_samples):
    # Generate random interpolation factors between real and fake samples
    alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(device)

    # Perform linear interpolation between real and fake samples
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples

    # Allow gradients for interpolation points
    interpolates.requires_grad_(True)

    # Pass interpolation points through the discriminator
    d_interpolates = discriminator(interpolates)

    # Create gradient tensors for each interpolation point
    grad_outputs = torch.ones_like(d_interpolates, requires_grad=False).to(device)

    # Compute gradients of discriminator outputs w.r.t. interpolation points
    gradients = torch.autograd.grad(outputs=d_interpolates, inputs=interpolates, grad_outputs=grad_outputs, create_graph=True, retain_graph=True)[0]

    # Flatten gradients to compute norm per sample
    gradients = gradients.view(gradients.size(0), -1)

    # Compute gradient penalty as per WGAN-GP formula
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty


def psnr(image_true, image_pred):
    mse = F.mse_loss(image_pred, image_true)
    max_pixel = 1.0
    return 20 * torch.log10(max_pixel / torch.sqrt(mse))


def validate(generator, criterion, validation_loader):
    generator.eval()  # Set the generator to evaluation mode
    total_loss = 0
    total_psnr = 0
    total_mae = 0

    with torch.no_grad():
        for real_image, grey_image in validation_loader:
            grey_image = grey_image.to(device)
            real_image = real_image.to(device)
            # Generate fake images
            fake_image = generator(grey_image)

            # Compute MSE loss
            loss = criterion(fake_image, real_image)
            # Accumulate total loss
            total_loss += loss.item()

            # Compute PSNR
            psnr_value = psnr(real_image, fake_image)
            total_psnr += psnr_value.item()


            # Compute MAE
            mae_value = F.l1_loss(fake_image, real_image)
            total_mae += mae_value.item()

    # Calculate average validation loss and other metrics
    avg_loss = total_loss / len(validation_loader)
    avg_psnr = total_psnr / len(validation_loader)
    # avg_ssim = total_ssim / len(validation_loader)
    avg_mae = total_mae / len(validation_loader)

    # return avg_loss, avg_psnr, avg_ssim, avg_mae
    return avg_loss, avg_psnr, avg_mae


def loadModel():
    generator.load_state_dict(torch.load('/kaggle/working/MSE_generator_epoch.pt'))
    discriminator.load_state_dict(torch.load('/kaggle/working/MSE_discriminator_epoch.pt'))
    print("Models found in both paths.")


In [14]:
# Hyperparameters
mini_batch = 64
batch_epoch = 16
gradient_penalty_lambda = 10
lambda_mse = 0.95
lambda_wgan = 0.05
n_critic = 2

# Training loop
num_epochs = 10
save_every_epoch = 2

# Lists to store losses
train_losses_discriminator = []
train_losses_generator = []
mse_losses = []
validation_losses = []
psnr_values = []
mae_values = []

In [None]:
loadModel()

# Training loop
for epoch in range(num_epochs):
    generator.train()
    discriminator.train()
    loss_G = 0
    mse_loss_total = 0

    for idx, (real_image, grey_image) in enumerate(train_loader):

        for _ in range(n_critic):
            # Randomly sample a mini batch from real data
            subset_indices = torch.randperm(len(train_set))[:mini_batch]

            # Create a SubsetRandomSampler using the subset indices
            subset_sampler = SubsetRandomSampler(subset_indices)

            critic_loader = DataLoader(train_set, batch_size=batch_size, sampler=subset_sampler, num_workers=num_workers)

            for i, (real_image_critic, grey_image_critic) in enumerate(critic_loader):
                grey_image_critic = grey_image_critic.to(device)
                real_image_critic = real_image_critic.to(device)

                optimizer_D.zero_grad()

                real_output_critic = discriminator(real_image_critic)
                fake_image_critic = generator(grey_image_critic)
                fake_output_critic = discriminator(fake_image_critic.detach())

                # Compute WGAN loss for discriminator with gradient penalty
                discriminator_loss = -(torch.mean(real_output_critic) - torch.mean(fake_output_critic))
                gradient_penalty = compute_gradient_penalty(discriminator, real_image_critic, fake_image_critic)
                discriminator_loss += gradient_penalty * gradient_penalty_lambda


                # Update discriminator weights
                discriminator_loss.backward()
                optimizer_D.step()

                # Apply weight clipping  - Needs to try with out it
                for p in discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)


        grey_image = grey_image.to(device)
        real_image = real_image.to(device)

        # Train Generator
        optimizer_G.zero_grad()
        fake_image = generator(grey_image)
        fake_output = discriminator(fake_image)

        # Compute WGAN loss and MSE Loss for generator
        generator_loss_wgan = -torch.mean(fake_output)
        generator_loss_mse = criterion(fake_image, real_image)
        generator_loss = lambda_wgan * generator_loss_wgan + lambda_mse * generator_loss_mse

        # Update generator weights
        generator_loss.backward()
        optimizer_G.step()


        train_losses_discriminator.append(discriminator_loss.item())
        train_losses_generator.append(generator_loss.item())
        mse_losses.append(generator_loss_mse.item())


        if idx % batch_epoch == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Batch [{idx}/{len(train_loader)}], Generator WGAN Loss: {generator_loss_wgan.item():.4f}, Generator MSE Loss: {generator_loss_mse.item():.4f}, Discriminator Loss: {discriminator_loss.item():.4f}")

    # Calculate validation loss
    validation_loss, validation_psnr, validation_mae = validate(generator, criterion, validation_loader)


    validation_losses.append(validation_loss)
    psnr_values.append(validation_psnr)
    # ssim_values.append(validation_ssim)
    mae_values.append(validation_mae)
    print(f"Epoch [{epoch}/{num_epochs}], Validation Loss: {validation_loss:.4f}")

    if epoch % save_every_epoch == 0:
        # Save model checkpoints
        torch.save(generator.state_dict(), f"generator_epoch.pt")
        torch.save(discriminator.state_dict(), f"discriminator_epoch.pt")

        # Save model checkpoints to Google Drive
#         torch.save(generator.state_dict(), '/content/drive/My Drive/generator_epoch.pt')
#         torch.save(discriminator.state_dict(), '/content/drive/My Drive/discriminator_epoch.pt')


        # Generate random image and its colorized version
        real_image, grey_image = random.choice(train_loader.dataset)
        grey_image = grey_image.unsqueeze(0).to(device)
        colorized_image = generator(grey_image)

        # Convert tensors to PIL images
        grey_image_pil = transforms.ToPILImage()(grey_image.squeeze().cpu())
        colorized_image_pil = transforms.ToPILImage()(colorized_image.squeeze().cpu().detach())
        original_image_pil = transforms.ToPILImage()(real_image.cpu())

        # Plot and save the images
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        axs[0].imshow(grey_image_pil, cmap='gray')
        axs[0].set_title("Grayscale Image")
        axs[0].axis("off")
        axs[1].imshow(colorized_image_pil)
        axs[1].set_title("Generated Colorized Image")
        axs[1].axis("off")
        axs[2].imshow(original_image_pil)
        axs[2].set_title("Original RGB Image")
        axs[2].axis("off")
        plt.savefig(os.path.join('/kaggle/working/', f"epoch_{epoch + 1}_colorized_image.png"))
#         plt.close()

# Plotting
plt.figure(figsize=(15, 15))

# Training Loss - Discriminator
plt.subplot(3, 2, 1)
plt.plot(train_losses_discriminator, label='Training Loss - Discriminator')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss - Discriminator')
plt.legend()
plt.grid(True)

# Training Loss - Generator
plt.subplot(3, 2, 2)
plt.plot(train_losses_generator, label='Training Loss - Generator')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss - Generator')
plt.legend()
plt.grid(True)

# MSE Loss
plt.subplot(3, 2, 3)
plt.plot(mse_losses, label='MSE Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('MSE Loss')
plt.legend()
plt.grid(True)

# Validation Loss
plt.subplot(3, 2, 4)
plt.plot(validation_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Validation Loss')
plt.legend()
plt.grid(True)

# PSNR
plt.subplot(3, 2, 5)
plt.plot(psnr_values, label='PSNR')
plt.xlabel('Epoch')
plt.ylabel('PSNR')
plt.title('PSNR')
plt.legend()
plt.grid(True)

# MAE
plt.subplot(3, 2, 6)
plt.plot(mae_values, label='MAE')
plt.xlabel('Epoch')
plt.ylabel('MAE')
plt.title('MAE')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig('performance_metrics.png')
plt.show()



Models found in both paths.
Epoch [0/10], Batch [0/255], Generator WGAN Loss: -0.0110, Generator MSE Loss: 0.0013, Discriminator Loss: 6.8930
Epoch [0/10], Batch [16/255], Generator WGAN Loss: -0.0873, Generator MSE Loss: 0.0032, Discriminator Loss: 0.8992
Epoch [0/10], Batch [32/255], Generator WGAN Loss: -0.0577, Generator MSE Loss: 0.0021, Discriminator Loss: 2.7579
Epoch [0/10], Batch [48/255], Generator WGAN Loss: -0.0684, Generator MSE Loss: 0.0023, Discriminator Loss: 1.6442
Epoch [0/10], Batch [64/255], Generator WGAN Loss: -0.0501, Generator MSE Loss: 0.0031, Discriminator Loss: 1.9743
Epoch [0/10], Batch [80/255], Generator WGAN Loss: -0.0489, Generator MSE Loss: 0.0019, Discriminator Loss: 0.1718
Epoch [0/10], Batch [96/255], Generator WGAN Loss: -0.0476, Generator MSE Loss: 0.0027, Discriminator Loss: 0.3673
Epoch [0/10], Batch [112/255], Generator WGAN Loss: -0.0487, Generator MSE Loss: 0.0017, Discriminator Loss: 0.3567
Epoch [0/10], Batch [128/255], Generator WGAN Loss: 

# Test

In [None]:
generator.eval()  # Set the generator to evaluation mode
for i, (real_image, gray_image) in enumerate(test_loader):

    with torch.no_grad():
        # Generate random image and its colorized version
        gray_image = gray_image.to(device)
        colorized_image = generator(gray_image)  # Remove the batch dimension
        colorized_image = colorized_image[0]  # Remove the batch dimension

        # Convert tensors to PIL images
        gray_image_pil = transforms.ToPILImage()(gray_image[0].cpu())
        colorized_image_pil = transforms.ToPILImage()(colorized_image.cpu().detach())
        original_image_pil = transforms.ToPILImage()(real_image[0].cpu())

        # Plot and save the images
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        axs[0].imshow(gray_image_pil, cmap='gray')
        axs[0].set_title("Grayscale Image")
        axs[0].axis("off")
        axs[1].imshow(colorized_image_pil)
        axs[1].set_title("Generated Colorized Image")
        axs[1].axis("off")
        axs[2].imshow(original_image_pil)
        axs[2].set_title("Original RGB Image")
        axs[2].axis("off")

        if i == 5:
            break


In [None]:
# Convert lists to DataFrames
train_losses_discriminator_df = pd.DataFrame({'train_loss_discriminator': train_losses_discriminator})
train_losses_generator_df = pd.DataFrame({'train_loss_generator': train_losses_generator})
mse_losses_df = pd.DataFrame({'mse_loss': mse_losses})
validation_losses_df = pd.DataFrame({'validation_loss': validation_losses})
psnr_values_df = pd.DataFrame({'psnr_value': psnr_values})
mae_values_df = pd.DataFrame({'mae_value': mae_values})

# Save DataFrames to CSV files
train_losses_discriminator_df.to_csv('/kaggle/working/train_losses_discriminator.csv', index=False)
train_losses_generator_df.to_csv('/kaggle/working/train_losses_generator.csv', index=False)
mse_losses_df.to_csv('/kaggle/working/mse_losses.csv', index=False)
validation_losses_df.to_csv('/kaggle/working/validation_losses.csv', index=False)
psnr_values_df.to_csv('/kaggle/working/psnr_values.csv', index=False)
mae_values_df.to_csv('/kaggle/working/mae_values.csv', index=False)