<a href="https://colab.research.google.com/github/aimldlnlp/RKK303-Computer-Vision-Final-Project/blob/main/Image_Colorization_using_GANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Data Preparation


In [2]:
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: (x * 2 - 1))  # Normalize to [-1, 1]
])

dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataset = torch.utils.data.Subset(dataset, range(0, 5000))  # Use the first 5000 images
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 41.5MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


# Convert to grayscale (L channel)


In [3]:
def rgb_to_grayscale(img):
    return img.mean(dim=1, keepdim=True)  # Average across color channels

# Generator

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 2, 4, stride=2, padding=1),  # Output is ab channels
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

# Discriminator


In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# Initialize Models

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Optimizers and Loss

In [7]:
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# Training Loop

In [9]:
import os# Directory to save images
output_dir = "output_images"
os.makedirs(output_dir, exist_ok=True)

In [12]:
num_epochs = 50
for epoch in range(num_epochs):
    for real_images, _ in data_loader:
        real_images = real_images.to(device)
        gray_images = rgb_to_grayscale(real_images).to(device)

        # Prepare real and fake labels
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Train Discriminator
        fake_images = generator(gray_images).detach()
        fake_ab_images = torch.cat((gray_images, fake_images), dim=1)

        real_ab_images = real_images
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_ab_images), real_labels)
        fake_loss = criterion(discriminator(fake_ab_images), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        fake_images = generator(gray_images)
        fake_ab_images = torch.cat((gray_images, fake_images), dim=1)
        g_loss = criterion(discriminator(fake_ab_images), real_labels)
        g_loss.backward()
        optimizer_G.step()

    # Save models after each epoch
    # torch.save(generator.state_dict(), 'generator.pth')  # Save Generator
    # torch.save(discriminator.state_dict(), 'discriminator.pth')  # Save Discriminator

    # Save models and optimizers
    torch.save(generator.state_dict(), f'generator_epoch_{epoch+1}.pth')
    torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch+1}.pth')

    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

    # Save sample images
    if (epoch + 1) % 1 == 0:
        generated = (fake_images.detach().cpu() + 1) / 2  # Rescale to [0, 1]
        plt.figure(figsize=(10, 5))
        for i in range(8):
            plt.subplot(2, 8, i + 1)
            plt.imshow(gray_images[i][0].cpu(), cmap='gray')
            plt.axis('off')
            plt.subplot(2, 8, i + 9)
        #     plt.imshow(generated[i].permute(1, 2, 0))
            plt.axis('off')

        # Save the figure
        output_path = os.path.join(output_dir, f"epoch_{epoch + 1}.png")
        plt.savefig(output_path)
        plt.close()  # Close the figure to free memory
        print(f"Sample images saved to {output_path}")

Epoch [1/50], d_loss: 0.0190, g_loss: 6.4500
Sample images saved to output_images/epoch_1.png


KeyboardInterrupt: 

# Resuming Training

In [None]:
# generator.load_state_dict(torch.load('generator.pth'))
# discriminator.load_state_dict(torch.load('discriminator.pth'))

In [None]:
# # Load full checkpoint
# checkpoint = torch.load('checkpoint.pth')

# generator.load_state_dict(checkpoint['generator_state_dict'])
# discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

# optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
# optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

# # Resume training from the saved epoch
# start_epoch = checkpoint['epoch'] + 1