In [None]:
!pip install numpy torch torchvision pillow tqdm matplotlib


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image,UnidentifiedImageError
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("aayush9753/image-colorization-dataset")

print("Path to dataset files:", path)

Path to dataset files: /root/.cache/kagglehub/datasets/aayush9753/image-colorization-dataset/versions/1


In [None]:
'''from google.colab import drive
drive.mount('/content/drive')'''

"from google.colab import drive\ndrive.mount('/content/drive')"

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class GANLoss(nn.Module):
    def __init__(self, loss_mode="lsgan"):
        super(GANLoss, self).__init__()
        if loss_mode == "lsgan":
            self.loss = nn.MSELoss()
        elif loss_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError(f"Loss mode {loss_mode} not implemented.")

    def forward(self, pred, target_is_real):
        target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
        return self.loss(pred, target)

Using device: cuda


In [None]:
class ImageColorizationDataset(Dataset):
    def __init__(self, black_dir, color_dir, transform_bw=None, transform_color=None):
        self.black_dir = black_dir
        self.color_dir = color_dir
        self.transform_bw = transform_bw
        self.transform_color = transform_color
        self.image_files = os.listdir(black_dir)  # Ensure this matches between directories

        # Check if corresponding color images exist
        self.image_files = [
            img for img in self.image_files
            if os.path.isfile(os.path.join(color_dir, img))
        ]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        black_path = os.path.join(self.black_dir, self.image_files[index])
        color_path = os.path.join(self.color_dir, self.image_files[index])

        try:
            black_image = Image.open(black_path).convert("L")
            color_image = Image.open(color_path).convert("RGB")

            if self.transform_bw:
                black_image = self.transform_bw(black_image)
            if self.transform_color:
                color_image = self.transform_color(color_image)

            return black_image, color_image
        except (FileNotFoundError, UnidentifiedImageError) as e:
            print(f"Skipping file {self.image_files[index]} due to error: {e}")
            # Handle the skipped file gracefully by retrying with another index
            return self.__getitem__((index + 1) % self.__len__())  # Circular retry

# Paths
train_black = "/root/.cache/kagglehub/datasets/aayush9753/image-colorization-dataset/versions/1/data/train_black"
train_color = "/root/.cache/kagglehub/datasets/aayush9753/image-colorization-dataset/versions/1/data/train_color"
test_black = "/root/.cache/kagglehub/datasets/aayush9753/image-colorization-dataset/versions/1/data/test_black"
test_color = "/root/.cache/kagglehub/datasets/aayush9753/image-colorization-dataset/versions/1/data/test_color"
output_path = "/root/.cache/kagglehub/datasets/aayush9753/image-colorization-dataset/versions/1/data"
os.makedirs(output_path, exist_ok=True)
# Transforms
transform_bw = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
transform_color = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# Datasets and Loaders
train_dataset = ImageColorizationDataset(train_black, train_color, transform_bw, transform_color)
test_dataset = ImageColorizationDataset(test_black, test_color, transform_bw, transform_color)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
#with downsampling

In [None]:
class ResnetBlock(nn.Module):
    def __init__(self, dim, norm_layer=nn.BatchNorm2d, use_dropout=False, padding_type='reflect'):
        super(ResnetBlock, self).__init__()

        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=True),
            norm_layer(dim),
            nn.ReLU(True)
        ]

        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        conv_block += [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            norm_layer(dim)
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        super(ResnetGenerator, self).__init__()
        assert (n_blocks >= 0)
        use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsampling
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # ResNet blocks
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_dropout=use_dropout, padding_type=padding_type)]

        # Upsampling
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [None]:
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                         norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True)]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
                     norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True)]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)


In [61]:
input_nc = 1  # Grayscale input
output_nc = 3  # RGB output
ngf, ndf = 64, 64

netG = ResnetGenerator(input_nc, output_nc, ngf).to(device)
netD = NLayerDiscriminator(output_nc, ndf).to(device)

generator_path = f"{output_path}/generator7_epoch2.pth"
discriminator_path = f"{output_path}/discriminator7_epoch2.pth"

if os.path.exists(generator_path) and os.path.exists(discriminator_path):
    print("Loading saved model checkpoints...")
    map_location = torch.device('cpu') if not torch.cuda.is_available() else None
    netG.load_state_dict(torch.load(generator_path, map_location=map_location))
    netD.load_state_dict(torch.load(discriminator_path, map_location=map_location))
    '''netG.load_state_dict(torch.load(generator_path))
    netD.load_state_dict(torch.load(discriminator_path))'''
else:
    print("No checkpoints found. Starting training from scratch.")

# Define loss and optimizers
gan_loss = GANLoss("lsgan").to(device)
optimizer_G = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))


Loading saved model checkpoints...


  netG.load_state_dict(torch.load(generator_path, map_location=map_location))
  netD.load_state_dict(torch.load(discriminator_path, map_location=map_location))


In [None]:

# Training Loop
epochs = 5
for epoch in range(epochs):
    g_loss_epoch, d_loss_epoch = 0.0, 0.0
    for i, (bw, color) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")):
        bw, color = bw.to(device), color.to(device)

        # Update Discriminator
        optimizer_D.zero_grad()
        fake_color = netG(bw)
        real_loss = gan_loss(netD(color), True)
        fake_loss = gan_loss(netD(fake_color.detach()), False)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Update Generator
        optimizer_G.zero_grad()
        g_loss = gan_loss(netD(fake_color), True)
        g_loss.backward()
        optimizer_G.step()

        g_loss_epoch += g_loss.item()
        d_loss_epoch += d_loss.item()

        if i % 10 == 0:  # Print every 10 batches
            print(f"Batch {i}/{len(train_loader)} - G Loss: {g_loss.item():.4f}, D Loss: {d_loss.item():.4f}")

    print(f"Epoch {epoch + 1} Completed. G Loss: {g_loss_epoch / len(train_loader):.4f}, D Loss: {d_loss_epoch / len(train_loader):.4f}")

    # Save model checkpoints
    torch.save(netG.state_dict(), f"{output_path}/generator7_epoch{epoch+1}.pth")
    torch.save(netD.state_dict(), f"{output_path}/discriminator7_epoch{epoch+1}.pth")


In [None]:
  # Save model checkpoints
  torch.save(netG.state_dict(), f"{output_path}/generator7_epoch.pth")
  torch.save(netD.state_dict(), f"{output_path}/discriminator7_epoch.pth")

In [None]:
def generate_colorized_image(netG, bw_image_path, output_filename):
    # Load the model in evaluation mode
    netG.eval()

    # Open and transform the black and white image
    bw_image = Image.open(bw_image_path).convert("L")
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])
    bw_image_tensor = transform(bw_image).unsqueeze(0).to(device)

    # Generate colorized image
    with torch.no_grad():
        color_image_tensor = netG(bw_image_tensor)

    # Denormalize and save the image
    color_image_np = color_image_tensor.squeeze().permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5

    # Create full output path
    full_output_path = os.path.join(output_path, output_filename)

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title("Original Black & White")
    plt.imshow(bw_image, cmap='gray')
    plt.subplot(1, 2, 2)
    plt.title("Colorized Image")
    plt.imshow(color_image_np)
    plt.tight_layout()
    plt.savefig(full_output_path)
    plt.close()

In [None]:
generate_colorized_image(netG, "/root/.cache/kagglehub/datasets/aayush9753/image-colorization-dataset/versions/1/data/test_black/WhatsApp Image 2024-12-13 at 20.19.52_c4391885.jpg", "/root/.cache/kagglehub/datasets/aayush9753/image-colorization-dataset/versions/1/colorized_output.jpg")

In [None]:
'''#without downsampling
class ResnetBlockNoDownsampling(nn.Module):
    def __init__(self, dim, norm_layer=nn.BatchNorm2d, use_dropout=False, padding_type='reflect'):
        super(ResnetBlockNoDownsampling, self).__init__()

        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=True),
            norm_layer(dim),
            nn.ReLU(True)
        ]

        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        conv_block += [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            norm_layer(dim)
        ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class ResnetGeneratorNoDownsampling(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        super(ResnetGeneratorNoDownsampling, self).__init__()
        assert (n_blocks >= 0)
        use_bias = norm_layer == nn.InstanceNorm2d

        # Initial convolutional layer
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True)
        ]

        # ResNet blocks
        for i in range(n_blocks):
            model += [ResnetBlockNoDownsampling(ngf, norm_layer=norm_layer, use_dropout=use_dropout, padding_type=padding_type)]

        # Final convolutional layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
'''

"\nclass ResnetBlockNoDownsampling(nn.Module):\n    def __init__(self, dim, norm_layer=nn.BatchNorm2d, use_dropout=False, padding_type='reflect'):\n        super(ResnetBlockNoDownsampling, self).__init__()\n\n        conv_block = []\n        p = 0\n        if padding_type == 'reflect':\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == 'replicate':\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == 'zero':\n            p = 1\n\n        conv_block += [\n            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=True),\n            norm_layer(dim),\n            nn.ReLU(True)\n        ]\n\n        if use_dropout:\n            conv_block += [nn.Dropout(0.5)]\n\n        conv_block += [\n            nn.ReflectionPad2d(1),\n            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),\n            norm_layer(dim)\n        ]\n\n        self.conv_block = nn.Sequential(*conv_block)\n\n    def forward(self, x):\n        retu