In [1]:
from torchvision.datasets import Flowers102
import multiprocessing
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import pandas as pd

In [2]:
# DataLoader worker processes: one per CPU core reported by multiprocessing.
num_workers = multiprocessing.cpu_count()
print("Number of CPU cores available:", num_workers)

batch_size = 4  # images per mini-batch, shared by every DataLoader below

Number of CPU cores available: 4


In [3]:
class DataProvider(Dataset):
    """Zip two index-aligned datasets into ``(sample, label)`` pairs.

    Each wrapped dataset is indexed with the same ``idx`` and is expected to
    return a tuple (torchvision datasets return ``(image, target)``); only
    element 0 of each tuple is kept.

    Args:
        sample_dataset: dataset providing the first element of every pair.
        label_dataset: dataset providing the second element of every pair.
    """

    def __init__(self, sample_dataset, label_dataset):
        self.sample_dataset = sample_dataset
        self.label_dataset = label_dataset

    def __len__(self):
        # The length comes from the sample dataset alone; both datasets are
        # assumed to be the same size.
        return len(self.sample_dataset)

    def __getitem__(self, idx):
        # Keep element 0 of each wrapped item (for torchvision: the transformed
        # image, dropping the class index). Sample is fetched before label.
        return self.sample_dataset[idx][0], self.label_dataset[idx][0]


# Both pipelines resize to 256x256 and convert to float tensors with ToTensor;
# the grayscale pipeline additionally collapses the image to one channel.
transform_Grayscale = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# "Sample" datasets: RGB version of each split (the generator's targets).
train_set_sample, test_set_sample, validation_set_sample = (
    Flowers102(root='.', split=split_name, download=True, transform=transform)
    for split_name in ('train', 'test', 'val')
)

# "Label" datasets: grayscale version of the same images. Despite the name,
# the training loop feeds these to the generator as its *input*.
# NOTE(review): every image is decoded from disk twice (once per transform);
# deriving the grayscale view from the RGB tensor would halve that work.
train_set_label, test_set_label, validation_set_label = (
    Flowers102(root='.', split=split_name, download=True, transform=transform_Grayscale)
    for split_name in ('train', 'test', 'val')
)

# Pair the two views index-by-index: each item is (rgb_image, grayscale_image).
train_set = DataProvider(train_set_sample, train_set_label)
test_set = DataProvider(test_set_sample, test_set_label)
validation_set = DataProvider(validation_set_sample, validation_set_label)

# All three loaders share the same settings (shuffled, including test/val).
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers)
train_loader = DataLoader(train_set, **loader_options)
test_loader = DataLoader(test_set, **loader_options)
validation_loader = DataLoader(validation_set, **loader_options)

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to flowers-102/102flowers.tgz


100%|██████████| 344862509/344862509 [00:01<00:00, 183246190.29it/s]


Extracting flowers-102/102flowers.tgz to flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to flowers-102/imagelabels.mat


100%|██████████| 502/502 [00:00<00:00, 643109.53it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to flowers-102/setid.mat


100%|██████████| 14989/14989 [00:00<00:00, 12013839.61it/s]


In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

Device: cuda


# Generator: U-Net

In [5]:
class VGGBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; height and width are preserved.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        dropout_rate: when > 0, a Dropout layer is appended after the second stage.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.0):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            # Conv bias is redundant: the following BatchNorm re-centres activations.
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        if dropout_rate > 0.0:
            stages.append(nn.Dropout(dropout_rate))

        # The attribute name and layer order define the state_dict keys
        # (conv.0, conv.1, ...); keep them stable so saved checkpoints load.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UNetGenerator(nn.Module):
    """U-Net generator: 1-channel (grayscale) image in, 3-channel (RGB) image out.

    Encoder: six VGGBlock stages (1 -> 64 -> 128 -> 256 -> 512 -> 512 -> 1024
    channels) separated by 2x2 max-pooling, i.e. five halvings of height/width.
    Decoder: five steps of bilinear upsampling, concatenation with the encoder
    feature map of matching resolution (skip connection), then a VGGBlock.
    A final 1x1 convolution projects 64 channels to 3. No output activation is
    applied, so output values are not constrained to [0, 1].

    Each upsampling step resizes to the exact spatial size of its skip
    connection. When the skip is exactly twice as large (e.g. the 256x256
    inputs used in this notebook) this equals a fixed 2x bilinear upsample.
    Other sizes (at least 32 pixels per side) now also work; previously a
    fixed 2x upsample of a floor-pooled map could mismatch its skip and make
    torch.cat fail.
    """

    def __init__(self):
        super(UNetGenerator, self).__init__()

        # Pooling and upsampling
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Fixed 2x upsampler (parameter-free). forward() uses size-matched
        # interpolation instead; kept so external references keep working.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder: Downsampling path
        self.encoder_block1 = VGGBlock(1, 64)
        self.encoder_block2 = VGGBlock(64, 128)
        self.encoder_block3 = VGGBlock(128, 256)
        self.encoder_block4 = VGGBlock(256, 512)
        self.encoder_block5 = VGGBlock(512, 512)
        self.encoder_block6 = VGGBlock(512, 1024)

        # Decoder: Upsampling path (input channels = upsampled + skip channels)
        self.decoder_block5 = VGGBlock(512 + 1024, 512)
        self.decoder_block4 = VGGBlock(512 + 512, 256)
        self.decoder_block3 = VGGBlock(256 + 256, 256)
        self.decoder_block2 = VGGBlock(128 + 256, 128)
        self.decoder_block1 = VGGBlock(128 + 64, 64)

        # Final convolution
        self.conv_last = nn.Conv2d(64, 3, kernel_size=1)

    @staticmethod
    def _upsample_cat(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s height/width and concat on channels."""
        x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        # Downsample
        conv1 = self.encoder_block1(x)
        conv2 = self.encoder_block2(self.maxpool(conv1))
        conv3 = self.encoder_block3(self.maxpool(conv2))
        conv4 = self.encoder_block4(self.maxpool(conv3))
        conv5 = self.encoder_block5(self.maxpool(conv4))
        x = self.encoder_block6(self.maxpool(conv5))

        # Upsample, concatenate the matching skip connection, refine
        x = self.decoder_block5(self._upsample_cat(x, conv5))
        x = self.decoder_block4(self._upsample_cat(x, conv4))
        x = self.decoder_block3(self._upsample_cat(x, conv3))
        x = self.decoder_block2(self._upsample_cat(x, conv2))
        x = self.decoder_block1(self._upsample_cat(x, conv1))

        # Final convolution
        return self.conv_last(x)

# Discriminator - CNN

In [6]:
class VGG_block(nn.Module):
    """Discriminator stage: two (3x3 conv -> InstanceNorm -> LeakyReLU(0.2)) layers,
    optional Dropout, then 2x2 max-pooling, which halves height and width.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        dropout_rate: when > 0, Dropout is inserted just before the pooling layer.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.0):
        super().__init__()
        modules = []
        for conv_in in (in_channels, out_channels):
            modules.extend([
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.InstanceNorm2d(out_channels, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        if dropout_rate > 0.0:
            modules.append(nn.Dropout(dropout_rate))
        modules.append(nn.MaxPool2d(2))

        # Layer order fixes the state_dict keys (block.0, block.1, ...); keep it
        # stable so previously saved checkpoints still load.
        self.block = nn.Sequential(*modules)

    def forward(self, x):
        return self.block(x)
    
    
class Discriminator(nn.Module):
    """VGG-style convolutional discriminator for 3-channel images.

    Six VGG_block stages (3 -> 64 -> 128 -> 256 -> 512 -> 1024 -> 1024
    channels) each end in 2x2 max-pooling, so height and width shrink 64x.
    A final 3x3 convolution maps the features to a 1-channel score map
    (4x4 for the 256x256 inputs used in this notebook).

    forward() returns raw, unbounded logits, not probabilities. This
    notebook's adversarial loss is nn.BCEWithLogitsLoss, which applies the
    sigmoid internally. Applying torch.sigmoid here as well fed values in
    (0, 1) into a second sigmoid. That confined the effective probability to
    (0.5, 0.73), so the discriminator could never score an image as fake.
    Use torch.sigmoid(discriminator(x)) when probabilities are needed.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.layer1 = VGG_block(3, 64)
        self.layer2 = VGG_block(64, 128)
        self.layer3 = VGG_block(128, 256)
        self.layer4 = VGG_block(256, 512)
        self.layer5 = VGG_block(512, 1024)
        self.layer6 = VGG_block(1024, 1024)

        # Final convolutional layer (wrapped in Sequential so the state_dict
        # key stays 'final_conv.0.weight').
        self.final_conv = nn.Sequential(
            nn.Conv2d(1024, 1, kernel_size=3, stride=1, padding=1, bias=False)
        )

    def forward(self, img):
        x = img
        for stage in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5, self.layer6):
            x = stage(x)
        # Raw logits; see class docstring for why no sigmoid is applied here.
        return self.final_conv(x)

In [None]:
# Build both networks and move them to the training device *before* creating
# the optimizers, which is the order torch.optim recommends. Module.to() moves
# parameters in place, so the result is the same as moving afterwards.
generator = UNetGenerator().to(device)  # grayscale (1 channel) in, colour (3 channels) out
discriminator = Discriminator().to(device)

# Adam with lr=2e-4 and beta1=0.5 for both networks.
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# BCEWithLogitsLoss applies the sigmoid itself, so it must receive raw
# discriminator scores (logits). MSE is the pixel-wise reconstruction term.
adversarial_loss = nn.BCEWithLogitsLoss()
mse_loss = nn.MSELoss()


In [8]:
def psnr(image_true, image_pred):
    """Peak signal-to-noise ratio in dB, assuming a peak pixel value of 1.0.

    The MSE is averaged over every element of the inputs, so a batch yields a
    single PSNR (not a mean of per-image PSNRs). Identical inputs give +inf.
    """
    peak_value = 1.0
    root_mse = torch.sqrt(F.mse_loss(image_pred, image_true))
    return 20 * torch.log10(peak_value / root_mse)

def ssim(img1, img2, C1=0.01**2, C2=0.03**2):
    """Global SSIM between two image batches, one score per image and channel.

    Unlike the usual sliding-window SSIM, statistics are taken over the whole
    image plane (dims 2 and 3), so (N, C, H, W) inputs give an (N, C) result.
    The defaults for C1/C2 are the standard (0.01*L)**2 and (0.03*L)**2
    stabilisers for a data range L = 1.

    Both variances and the covariance use the same population (divide-by-N)
    estimator, so ssim(x, x) == 1 as SSIM requires. Mixing an unbiased std
    with a biased covariance made identical images score
    (2v + C2) / (2v*N/(N-1) + C2) < 1.
    """
    mean1 = img1.mean([2, 3])
    mean2 = img2.mean([2, 3])

    # Deviations from each plane's mean (keepdim so they broadcast over H, W).
    dev1 = img1 - img1.mean([2, 3], keepdim=True)
    dev2 = img2 - img2.mean([2, 3], keepdim=True)

    var1 = (dev1 * dev1).mean([2, 3])
    var2 = (dev2 * dev2).mean([2, 3])
    cov12 = (dev1 * dev2).mean([2, 3])

    ssim_n = (2 * mean1 * mean2 + C1) * (2 * cov12 + C2)
    ssim_d = (mean1**2 + mean2**2 + C1) * (var1 + var2 + C2)

    return ssim_n / ssim_d

def validate(generator, criterion, validation_loader):
    """Evaluate the generator on batches of (RGB target, grayscale input) pairs.

    The generator is switched to eval mode and left there. Every metric is
    computed per batch, then averaged over the number of batches.

    Args:
        generator: model mapping grayscale batches to RGB batches.
        criterion: loss called as ``criterion(prediction, target)``.
        validation_loader: iterable of ``(rgb, gray)`` batches supporting ``len()``.

    Returns:
        Tuple ``(mean_loss, mean_psnr, mean_ssim, mean_mae)`` of Python floats.
    """
    generator.eval()
    loss_sum = psnr_sum = ssim_sum = mae_sum = 0

    with torch.no_grad():
        for target_rgb, input_gray in validation_loader:
            input_gray = input_gray.to(device)
            target_rgb = target_rgb.to(device)
            prediction = generator(input_gray)

            loss_sum += criterion(prediction, target_rgb).item()
            psnr_sum += psnr(target_rgb, prediction).item()
            # ssim() yields one score per image/channel; reduce it to a scalar.
            ssim_sum += ssim(target_rgb, prediction).mean().item()
            mae_sum += F.l1_loss(prediction, target_rgb).item()

    n_batches = len(validation_loader)
    return (loss_sum / n_batches,
            psnr_sum / n_batches,
            ssim_sum / n_batches,
            mae_sum / n_batches)


def loadModel(generator_path='/kaggle/working/MSE_generator_epoch.pt',
              discriminator_path='/kaggle/working/MSE_discriminator_epoch.pt'):
    """Load saved state_dicts into the global ``generator`` and ``discriminator``.

    Tensors are mapped onto the current ``device``. Without ``map_location``,
    torch.load recreates tensors on the device they were saved from, so a
    GPU-trained checkpoint fails to load on the CPU fallback.

    NOTE(review): the training cells save to 'generator_epoch.pt' and
    'discriminator_epoch.pt' in the working directory, not to these MSE_*
    defaults. Pass the paths explicitly to load those files.

    Security: torch.load unpickles the file; only load trusted checkpoints.

    Args:
        generator_path: path of the generator state_dict file.
        discriminator_path: path of the discriminator state_dict file.

    Raises:
        FileNotFoundError: if a checkpoint file does not exist.
        RuntimeError: if a state_dict does not match the model architecture.
    """
    generator.load_state_dict(torch.load(generator_path, map_location=device))
    discriminator.load_state_dict(torch.load(discriminator_path, map_location=device))
    print("Models found in both paths.")


In [9]:
# --- Hyperparameters ---
batch_epoch = 16    # print a progress line every `batch_epoch` training batches
lambda_adv = 0.01   # weight of the adversarial term in the generator loss
lambda_mse = 0.99   # weight of the pixel-wise MSE term in the generator loss

# --- Training schedule ---
num_epochs = 350
save_every_epoch = 30   # checkpoint + sample image every N epochs

# --- Loss history, appended once per training batch ---
train_losses_discriminator, train_losses_generator, mse_losses = [], [], []

# --- Validation history, appended once per epoch ---
validation_losses, psnr_values, mae_values, ssim_values = [], [], [], []

In [None]:
# NOTE(review): these running totals are never updated below; the per-batch
# history lists from the hyperparameter cell are used instead.
total_discriminator_loss = 0.0
total_generator_loss = 0.0
total_mse_loss = 0.0

# Training loop
for epoch in range(num_epochs):
    generator.train()
    discriminator.train()

    # DataProvider yields (RGB image, grayscale image) pairs: the RGB image is
    # the target, the grayscale image is the generator's input.
    for idx, (real_image, gray_image) in enumerate(train_loader):
        gray_image = gray_image.to(device)
        real_image = real_image.to(device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Real RGB images -> target 1. ones_like/zeros_like already match the
        # prediction's device and dtype, so no extra .to(device) is needed.
        real_predictions = discriminator(real_image)
        real_labels = torch.ones_like(real_predictions)
        real_loss = adversarial_loss(real_predictions, real_labels)

        # Generated images -> target 0. detach() stops this loss from
        # back-propagating into the generator.
        fake_color_image = generator(gray_image).detach()
        fake_predictions = discriminator(fake_color_image)
        fake_labels = torch.zeros_like(fake_predictions)
        fake_loss = adversarial_loss(fake_predictions, fake_labels)

        discriminator_loss = real_loss + fake_loss
        discriminator_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()
        fake_image = generator(gray_image)
        fake_predictions = discriminator(fake_image)

        # Adversarial term: the generator wants its output scored as real (target 1).
        real_labels = torch.ones_like(fake_predictions)
        generator_loss_adv = adversarial_loss(fake_predictions, real_labels)

        # Reconstruction term: pixel-wise MSE against the RGB ground truth.
        generator_loss_mse = mse_loss(fake_image, real_image)

        generator_loss = lambda_adv * generator_loss_adv + lambda_mse * generator_loss_mse
        generator_loss.backward()
        optimizer_G.step()

        # Per-batch history (one entry per training iteration).
        train_losses_discriminator.append(discriminator_loss.item())
        train_losses_generator.append(generator_loss.item())
        mse_losses.append(generator_loss_mse.item())

        if idx % batch_epoch == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Batch [{idx}/{len(train_loader)}], Generator GAN Loss: {generator_loss.item():.4f}, Generator MSE Loss: {generator_loss_mse.item():.4f}, Discriminator Loss: {discriminator_loss.item():.4f}")

    # Per-epoch validation metrics (validate() also puts the generator in eval mode).
    validation_loss, validation_psnr, validation_ssim, validation_mae = validate(generator, mse_loss, validation_loader)

    validation_losses.append(validation_loss)
    psnr_values.append(validation_psnr)
    mae_values.append(validation_mae)
    ssim_values.append(validation_ssim)
    print(f"Epoch [{epoch}/{num_epochs}], Validation Loss: {validation_loss:.4f}")

    if epoch % save_every_epoch == 0:
        # Save model checkpoints (the same files are overwritten at every save).
        torch.save(generator.state_dict(), "generator_epoch.pt")
        torch.save(discriminator.state_dict(), "discriminator_epoch.pt")

        # Colorize one random training image for visual inspection. This is
        # inference only: eval mode, and no autograd graph.
        real_image, grey_image = random.choice(train_loader.dataset)
        grey_image = grey_image.unsqueeze(0).to(device)
        generator.eval()
        with torch.no_grad():
            colorized_image = generator(grey_image)

        # The generator output is unbounded. Clamp to [0, 1], because
        # ToPILImage scales floats by 255 and casts to uint8, which wraps
        # out-of-range values.
        grey_image_pil = transforms.ToPILImage()(grey_image.squeeze().cpu())
        colorized_image_pil = transforms.ToPILImage()(colorized_image.squeeze().clamp(0, 1).cpu())
        original_image_pil = transforms.ToPILImage()(real_image.cpu())

        # Plot and save the images, then release the figure so figures do not
        # accumulate in memory over the run.
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        axs[0].imshow(grey_image_pil, cmap='gray')
        axs[0].set_title("Grayscale Image")
        axs[0].axis("off")
        axs[1].imshow(colorized_image_pil)
        axs[1].set_title("Generated Colorized Image")
        axs[1].axis("off")
        axs[2].imshow(original_image_pil)
        axs[2].set_title("Original RGB Image")
        axs[2].axis("off")
        plt.savefig(os.path.join('/kaggle/working/', f"epoch_{epoch + 1}_colorized_image.png"))
        plt.show()
        plt.close(fig)


# Plotting: one panel per recorded series.
# The three training-loss series are appended once per training batch, so
# their x-axis is the iteration number. The validation metrics are appended
# once per epoch.
metric_panels = [
    # (series, name used for legend and title, x-label, y-label)
    (train_losses_discriminator, 'Training Loss - Discriminator', 'Iteration', 'Loss'),
    (train_losses_generator, 'Training Loss - Generator', 'Iteration', 'Loss'),
    (mse_losses, 'MSE Loss', 'Iteration', 'Loss'),
    (validation_losses, 'Validation Loss', 'Epoch', 'Loss'),
    (psnr_values, 'PSNR', 'Epoch', 'PSNR'),
    (mae_values, 'MAE', 'Epoch', 'MAE'),
    (ssim_values, 'SSIM', 'Epoch', 'SSIM'),
]

fig, axes = plt.subplots(4, 2, figsize=(15, 15))
for ax, (series, name, xlabel, ylabel) in zip(axes.flat, metric_panels):
    ax.plot(series, label=name)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(name)
    ax.legend()
    ax.grid(True)

# Seven panels on a 4x2 grid: remove the unused eighth axis.
fig.delaxes(axes[-1, -1])

fig.tight_layout()
fig.savefig('performance_metrics.png')
plt.show()

In [17]:
# Persist the final weights. These file names match the periodic checkpoints
# written during training, so those files are overwritten.
for net, checkpoint_path in ((generator, "generator_epoch.pt"),
                             (discriminator, "discriminator_epoch.pt")):
    torch.save(net.state_dict(), checkpoint_path)

# Test

In [None]:
# Visual check on the test split: colorize batches from the (shuffled) test
# loader and plot the first image of each, stopping after 31 batches (i == 30).
generator.eval()  # Set the generator to evaluation mode (BatchNorm uses running stats)
with torch.no_grad():
    for i, (real_image, gray_image) in enumerate(test_loader):
        gray_image = gray_image.to(device)
        # First image of the batch only. Clamp to [0, 1]: the generator output
        # is unbounded, and ToPILImage would wrap out-of-range values when
        # converting to 8-bit.
        colorized_image = generator(gray_image)[0].clamp(0, 1)

        # Convert tensors to PIL images
        gray_image_pil = transforms.ToPILImage()(gray_image[0].cpu())
        colorized_image_pil = transforms.ToPILImage()(colorized_image.cpu())
        original_image_pil = transforms.ToPILImage()(real_image[0].cpu())

        # Plot the images. Show and close each figure so 31 figures are not
        # all kept open at once.
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        axs[0].imshow(gray_image_pil, cmap='gray')
        axs[0].set_title("Grayscale Image")
        axs[0].axis("off")
        axs[1].imshow(colorized_image_pil)
        axs[1].set_title("Generated Colorized Image")
        axs[1].axis("off")
        axs[2].imshow(original_image_pil)
        axs[2].set_title("Original RGB Image")
        axs[2].axis("off")
        plt.show()
        plt.close(fig)

        if i == 30:
            break