In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and print the full path of every file found.
for folder, _subdirs, names in os.walk('/kaggle/input'):
    for full_path in (os.path.join(folder, name) for name in names):
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import torch
import torch.nn as nn

# -----------------------------------
# 1. U-Net Generator
# -----------------------------------
class UNetGenerator(nn.Module):
    """U-Net style generator mapping an image to an image of the same spatial size.

    Eight stride-2 encoder stages (``down1`` .. ``down8``) are mirrored by seven
    stride-2 decoder stages (``up1`` .. ``up7``) plus a final transposed
    convolution. After every decoder stage the result is concatenated along
    the channel axis with the encoder activation of the same depth, which is
    why each decoder stage after the first takes twice as many input channels.
    The output passes through ``tanh`` and therefore lies in [-1, 1].

    Args:
        input_channels: channels of the input tensor (default 1).
        output_channels: channels of the produced tensor (default 3).
    """

    def __init__(self, input_channels=1, output_channels=3):
        super().__init__()

        # (in_channels, out_channels, use_batchnorm) for down1 .. down8.
        encoder_plan = (
            (input_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, False),
        )
        for level, (c_in, c_out, use_norm) in enumerate(encoder_plan, start=1):
            setattr(self, f"down{level}", self._encoder_stage(c_in, c_out, normalize=use_norm))

        # (in_channels, out_channels, use_dropout) for up1 .. up7.
        # Inputs of up2+ are doubled by the skip concatenation in forward().
        decoder_plan = (
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        )
        for level, (c_in, c_out, use_dropout) in enumerate(decoder_plan, start=1):
            setattr(self, f"up{level}", self._decoder_stage(c_in, c_out, dropout=use_dropout))

        # 128 = 64 channels from up7 + 64 channels from the down1 skip.
        self.final = nn.ConvTranspose2d(128, output_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    @staticmethod
    def _encoder_stage(c_in, c_out, normalize=True):
        """Conv(k=4, s=2, p=1) -> optional BatchNorm -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
        norm = [nn.BatchNorm2d(c_out)] if normalize else []
        return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2))

    @staticmethod
    def _decoder_stage(c_in, c_out, dropout=False):
        """ConvTranspose(k=4, s=2, p=1) -> BatchNorm -> optional Dropout(0.5) -> ReLU."""
        deconv = nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
        batch_norm = nn.BatchNorm2d(c_out)
        regularizer = [nn.Dropout(0.5)] if dropout else []
        return nn.Sequential(deconv, batch_norm, *regularizer, nn.ReLU())

    def forward(self, x):
        """Encode ``x``, then decode while re-injecting encoder activations as skips."""
        # Encoder pass: remember every stage's activation.
        skips = []
        features = x
        for level in range(1, 9):
            features = getattr(self, f"down{level}")(features)
            skips.append(features)

        # The deepest activation is the decoder input, not a skip connection.
        features = skips.pop()

        # Decoder pass: up1 pairs with down7, up2 with down6, ... up7 with down1.
        for level in range(1, 8):
            features = getattr(self, f"up{level}")(features)
            features = torch.cat([features, skips.pop()], dim=1)

        return self.tanh(self.final(features))


# -----------------------------------
# 2. PatchGAN Discriminator
# -----------------------------------
class PatchGANDiscriminator(nn.Module):
    """Conditional PatchGAN discriminator scoring (gray, color) image pairs.

    The two images are concatenated along the channel axis and passed through
    four stride-2 convolutional blocks followed by a stride-1 convolution that
    produces one score per spatial patch.

    The network returns raw, unnormalised logits (no final Sigmoid). The
    training code in this notebook uses ``nn.BCEWithLogitsLoss``, which applies
    the sigmoid internally; keeping a Sigmoid here would apply it twice and
    confine the values fed to the loss to [0, 1]. Apply ``torch.sigmoid`` to
    the output if probabilities are needed.

    Args:
        input_channels: total channels after concatenating the gray and color
            inputs (default 4 = 1 + 3).
    """

    def __init__(self, input_channels=4):  # (Gray + Color Image)
        super(PatchGANDiscriminator, self).__init__()

        def discriminator_block(in_channels, out_channels, normalization=True):
            """Conv(k=4, s=2, p=1) -> optional BatchNorm -> LeakyReLU(0.2)."""
            layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2))
            return nn.Sequential(*layers)

        self.model = nn.Sequential(
            discriminator_block(input_channels, 64, normalization=False),
            discriminator_block(64, 128),
            discriminator_block(128, 256),
            discriminator_block(256, 512),
            # Patch-wise logits; deliberately not followed by nn.Sigmoid.
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        )

    def forward(self, gray_img, color_img):
        """Return per-patch logits for the pair ``(gray_img, color_img)``.

        Both tensors must share batch and spatial dimensions, since they are
        concatenated on ``dim=1``.
        """
        combined = torch.cat([gray_img, color_img], dim=1)
        return self.model(combined)


# # -----------------------------------
# # Testing the Models (Optional)
# # -----------------------------------
# if __name__ == "__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     # Test Generator
#     generator = UNetGenerator().to(device)
#     input_tensor = torch.randn(1, 1, 256, 256).to(device)  # 1 grayscale image
#     output_tensor = generator(input_tensor)
#     print(f"Generator Output Shape: {output_tensor.shape}")  # Should be (1, 3, 256, 256)

#     # Test Discriminator
#     discriminator = PatchGANDiscriminator().to(device)
#     gray_tensor = torch.randn(1, 1, 256, 256).to(device)  # 1 grayscale image
#     color_tensor = torch.randn(1, 3, 256, 256).to(device)  # 1 color image
#     output_disc = discriminator(gray_tensor, color_tensor)
#     print(f"Discriminator Output Shape: {output_disc.shape}")  # Should be (1, 1, 15, 15)


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
# Set device: use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
BATCH_SIZE = 8          # images per optimisation step
IMAGE_SIZE = 256        # images are resized to IMAGE_SIZE x IMAGE_SIZE
EPOCHS = 100            # full passes over the dataset
LEARNING_RATE = 0.0002  # Adam step size for both networks
BETA1 = 0.5             # Adam beta1 for both networks

# Reproducibility: seed PyTorch's global RNG so weight initialisation and
# DataLoader shuffling start from the same state on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Dataset Class
class ColorizationDataset(Dataset):
    """Paired grayscale / colour image dataset backed by two directories.

    A sample is identified by a file name that exists in both ``gray_dir``
    and ``color_dir``.

    Args:
        gray_dir: directory holding the grayscale inputs.
        color_dir: directory holding the colour targets (same file names).
        transform: optional callable applied independently to both images.
    """

    def __init__(self, gray_dir, color_dir, transform=None):
        self.gray_dir = gray_dir
        self.color_dir = color_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and may include
        # sub-directories. Sort for a stable index -> file mapping and keep
        # only regular files that have a same-named counterpart in color_dir,
        # so __getitem__ cannot fail on an unpaired entry in the middle of an
        # epoch.
        self.image_filenames = sorted(
            name
            for name in os.listdir(gray_dir)
            if os.path.isfile(os.path.join(gray_dir, name))
            and os.path.isfile(os.path.join(color_dir, name))
        )

    def __len__(self):
        """Number of usable (gray, color) pairs."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(gray_image, color_image)`` for the pair at position ``idx``."""
        filename = self.image_filenames[idx]
        gray_path = os.path.join(self.gray_dir, filename)
        color_path = os.path.join(self.color_dir, filename)

        # The context managers close the underlying files as soon as the
        # converted copies have been produced.
        with Image.open(gray_path) as gray_file:
            gray_image = gray_file.convert("L")  # Load as grayscale
        with Image.open(color_path) as color_file:
            color_image = color_file.convert("RGB")  # Load as RGB

        if self.transform:
            gray_image = self.transform(gray_image)
            color_image = self.transform(color_image)

        return gray_image, color_image

# Data Transformations: resize to a square, convert to a tensor, then
# normalise with mean 0.5 / std 0.5 (the single value is broadcast per channel).
_preprocessing_steps = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_preprocessing_steps)

# Load Dataset: paired gray / color folders, reshuffled every epoch
dataset = ColorizationDataset(
    gray_dir="/kaggle/input/landscape-image-colorization/landscape Images/gray",
    color_dir="/kaggle/input/landscape-image-colorization/landscape Images/color",
    transform=transform,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

# Initialize Models (generator first, then discriminator) on the chosen device
generator = UNetGenerator()
generator = generator.to(device)
discriminator = PatchGANDiscriminator()
discriminator = discriminator.to(device)

# Loss Functions: adversarial term plus pixel-wise L1 reconstruction term
criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()  # L1 loss for better realism

# Optimizers: both networks share the same Adam settings
_adam_settings = {"lr": LEARNING_RATE, "betas": (BETA1, 0.999)}
optimizer_G = optim.Adam(generator.parameters(), **_adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **_adam_settings)

# Training Loop
# Each batch performs one discriminator update followed by one generator update.
for epoch in range(EPOCHS):
    for gray_images, real_images in tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}"):

        gray_images, real_images = gray_images.to(device), real_images.to(device)

        # 1. Train Discriminator
        optimizer_D.zero_grad()

        # Real images
        output_real = discriminator(gray_images, real_images)

        # Build the targets from the discriminator output itself so that their
        # shape, dtype and device always match it, whatever IMAGE_SIZE or the
        # discriminator depth is (a hard-coded patch grid only fits one setup).
        real_labels = torch.ones_like(output_real)
        fake_labels = torch.zeros_like(output_real)

        loss_real = criterion_GAN(output_real, real_labels)

        # Fake images (detached so this step does not backpropagate into the generator)
        fake_images = generator(gray_images)
        output_fake = discriminator(gray_images, fake_images.detach())
        loss_fake = criterion_GAN(output_fake, fake_labels)

        # Compute total loss and update discriminator
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        # 2. Train Generator
        optimizer_G.zero_grad()

        # Score the same fakes again, this time with the graph intact, and ask
        # the discriminator to label them as real.
        output_fake = discriminator(gray_images, fake_images)
        loss_GAN = criterion_GAN(output_fake, real_labels)
        loss_L1 = criterion_L1(fake_images, real_images) * 100  # L1 loss for realism

        loss_G = loss_GAN + loss_L1
        loss_G.backward()
        optimizer_G.step()

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(generator.state_dict(), f"colorization_gan_{epoch+1}.pth")

# Save Final Model
torch.save(generator.state_dict(), "colorization_gan.pth")
print("Training Complete! Model Saved as colorization_gan.pth")


Epoch 1/100: 100%|██████████| 892/892 [03:23<00:00,  4.38it/s]
Epoch 2/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 3/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 4/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 5/100: 100%|██████████| 892/892 [03:26<00:00,  4.33it/s]
Epoch 6/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 7/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 8/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 9/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 10/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 11/100: 100%|██████████| 892/892 [03:26<00:00,  4.33it/s]
Epoch 12/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 13/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 14/100: 100%|██████████| 892/892 [03:26<00:00,  4.32it/s]
Epoch 15/100: 100%|██████████| 892/892 [03:26<00:00,  4.33it/s]
Epoch 17/100: 100%|██████████| 892/892 [03:26<00:

KeyboardInterrupt: 