In [None]:
# @title Image Colorization GAN (Complete Working Version)
# @markdown ### Setup
%%capture
!pip install torch torchvision opencv-python tqdm matplotlib --quiet
!nvidia-smi  # Verify GPU

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
import os
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from google.colab import drive

# Mount Google Drive so the dataset, sample-image and checkpoint folders under
# /content/drive/MyDrive/MINIPROJECT (paths set in Config) are reachable.
drive.mount('/content/drive')

# @title Configuration
class Config:
    """Global hyper-parameters and paths, read as class attributes (``Config.x``)."""

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_epochs = 30
    batch_size = 16
    # The U-Net generator halves the resolution 8 times (256 -> 1 at the
    # bottleneck), so the side length must be a multiple of 256. The
    # discriminator is fully convolutional and imposes no fixed size.
    image_size = 256
    dataset_path = "/content/drive/MyDrive/MINIPROJECT/train2017"
    save_dir = "/content/drive/MyDrive/MINIPROJECT/colorization_results/saved_images"
    checkpoint_dir = "/content/drive/MyDrive/MINIPROJECT/colorization_results/checkpoints"

    @classmethod
    def initialize_dirs(cls):
        """Create the sample-image and checkpoint directories (no-op if present)."""
        for directory in (cls.save_dir, cls.checkpoint_dir):
            os.makedirs(directory, exist_ok=True)

# Initialize directories
Config.initialize_dirs()

# Mixed precision training. Loss scaling only applies on CUDA; passing
# `enabled` explicitly avoids the warning GradScaler emits when it is built
# without a GPU (it would disable itself in that case anyway).
scaler = torch.cuda.amp.GradScaler(enabled=Config.device.type == 'cuda')

# @title Dataset Class
class ColorizationDataset(Dataset):
    """Images from a folder, returned as normalised (L, ab) Lab tensor pairs.

    Each image is converted to RGB, resized to ``image_size`` x ``image_size``
    and converted to CIE Lab with OpenCV on float32 data (where L lies in
    [0, 100]). Items are returned as:

    * ``L``  -- shape (1, H, W), scaled as ``L / 50 - 1`` (i.e. [-1, 1]);
    * ``ab`` -- shape (2, H, W), scaled as ``ab / 110`` (roughly [-1, 1]).

    Args:
        root_dir: folder whose top-level ``.png``/``.jpg``/``.jpeg`` files
            (any letter case) form the dataset; sub-folders are not searched.
        image_size: output side length; defaults to ``Config.image_size``.

    Raises:
        ValueError: if ``root_dir`` contains no matching image files.
    """

    # Compared against the lower-cased suffix, so upper-case variants match too.
    VALID_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # How many files (starting at the requested index) to try before giving up.
    MAX_RETRIES = 10

    def __init__(self, root_dir, image_size=None):
        self.root_dir = Path(root_dir)
        self.image_size = Config.image_size if image_size is None else image_size
        # Sorted so the index -> file mapping is stable across runs and workers.
        self.image_files = sorted(
            f for f in self.root_dir.glob('*')
            if f.is_file() and f.suffix.lower() in self.VALID_EXTENSIONS
        )

        if not self.image_files:
            available = {f.suffix for f in self.root_dir.glob('*')}
            raise ValueError(f"No valid images found. Found extensions: {available}")

        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_files)

    def _load_lab(self, path):
        """Load ``path`` and return its normalised ``(L, ab)`` tensors."""
        with Image.open(str(path)) as img:
            rgb = self.transform(img.convert('RGB'))  # (3, H, W) float32 in [0, 1]
        rgb = np.ascontiguousarray(rgb.permute(1, 2, 0).numpy())
        lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
        L = torch.from_numpy(np.ascontiguousarray(lab[:, :, 0])) / 50.0 - 1.0
        ab = torch.from_numpy(np.ascontiguousarray(lab[:, :, 1:].transpose(2, 0, 1))) / 110.0
        return L.unsqueeze(0), ab

    def __getitem__(self, idx):
        """Return ``(L, ab)`` for item ``idx``.

        If the file cannot be decoded, the following files (wrapping around)
        are tried, up to ``MAX_RETRIES`` in total. If all of them fail, a
        neutral sample is returned.
        """
        n_files = len(self.image_files)
        for offset in range(min(self.MAX_RETRIES, n_files)):
            path = self.image_files[(idx + offset) % n_files]
            try:
                return self._load_lab(path)
            except Exception as e:  # best-effort: skip unreadable/corrupt files
                print(f"Error loading {path}: {e}")
        # Every attempt failed: return mid-grey L with zero ab (no colour)
        # rather than random noise, so no bogus colour targets reach training.
        size = self.image_size
        return torch.zeros(1, size, size), torch.zeros(2, size, size)

# @title Model Architecture
class UNetDown(nn.Module):
    """Encoder stage of the U-Net: a stride-2 4x4 conv that halves H and W.

    Layers, in order: Conv2d (no bias), optional InstanceNorm2d,
    LeakyReLU(0.2), and Dropout only when ``dropout`` is non-zero.

    Args:
        in_size: input channels.
        out_size: output channels.
        normalize: whether to insert InstanceNorm2d after the conv.
        dropout: dropout probability; 0 (falsy) adds no dropout layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        stage = [nn.Conv2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False)]
        if normalize:
            stage.append(nn.InstanceNorm2d(out_size))
        stage.append(nn.LeakyReLU(0.2))
        if dropout:
            stage.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*stage)

    def forward(self, x):
        out = self.model(x)
        return out

class UNetUp(nn.Module):
    """Decoder stage of the U-Net: 2x transposed-conv upsample, then skip concat.

    Args:
        in_size: channels of the incoming decoder tensor.
        out_size: channels produced by the transposed conv (before concatenation).
        dropout: dropout probability; 0 (falsy) places ``nn.Identity`` in the
            last slot instead, so the Sequential always has four entries.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        tail = nn.Dropout(dropout) if dropout else nn.Identity()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
            tail,
        )

    def forward(self, x, skip_input):
        upsampled = self.model(x)
        # Channel-wise concat: result has out_size + skip_input channels.
        return torch.cat([upsampled, skip_input], dim=1)

class GeneratorUNet(nn.Module):
    """U-Net generator mapping a 1-channel L image to a 2-channel ab image.

    Eight UNetDown stages halve the spatial size each time (256x256 -> 1x1
    for a 256 input). Seven UNetUp stages double it back, each concatenating
    the matching encoder feature map. ``final`` performs the last 2x
    upsample to the input resolution and a tanh, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Encoder: 1 -> 64 -> 128 -> 256 -> 512 (x5) channels.
        self.down1 = UNetDown(in_size=1, out_size=64, normalize=False)
        self.down2 = UNetDown(in_size=64, out_size=128)
        self.down3 = UNetDown(in_size=128, out_size=256)
        self.down4 = UNetDown(in_size=256, out_size=512, dropout=0.5)
        self.down5 = UNetDown(in_size=512, out_size=512, dropout=0.5)
        self.down6 = UNetDown(in_size=512, out_size=512, dropout=0.5)
        self.down7 = UNetDown(in_size=512, out_size=512, dropout=0.5)
        # Innermost stage produces a 1x1 map for 256 inputs, so no InstanceNorm.
        self.down8 = UNetDown(in_size=512, out_size=512, normalize=False, dropout=0.5)

        # Decoder: input width doubles after each skip concatenation.
        self.up1 = UNetUp(in_size=512, out_size=512, dropout=0.5)
        self.up2 = UNetUp(in_size=1024, out_size=512, dropout=0.5)
        self.up3 = UNetUp(in_size=1024, out_size=512, dropout=0.5)
        self.up4 = UNetUp(in_size=1024, out_size=512, dropout=0.5)
        self.up5 = UNetUp(in_size=1024, out_size=256)
        self.up6 = UNetUp(in_size=512, out_size=128)
        self.up7 = UNetUp(in_size=256, out_size=64)

        # Upsample x2, pad left/top by one so the 4x4 conv (padding 1) keeps the size.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(in_channels=128, out_channels=2, kernel_size=4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoders = (self.down1, self.down2, self.down3, self.down4,
                    self.down5, self.down6, self.down7, self.down8)
        decoders = (self.up1, self.up2, self.up3, self.up4,
                    self.up5, self.up6, self.up7)

        skips = []
        feat = x
        for encode in encoders:
            feat = encode(feat)
            skips.append(feat)

        # The deepest map (d8) seeds the decoder; the remaining maps are
        # consumed deepest-first as skip connections (d7, d6, ..., d1).
        feat = skips.pop()
        for decode in decoders:
            feat = decode(feat, skips.pop())
        return self.final(feat)

# @title Fixed Discriminator Architecture
class Discriminator(nn.Module):
    """PatchGAN-style discriminator that scores (L, ab) pairs.

    ``img_A`` and ``img_B`` are concatenated on the channel axis. The result
    passes through four stride-2 conv blocks and a final 4x4 conv, giving a
    one-channel map of per-patch scores instead of a single scalar.

    For a 256x256 input the map is 17x17: the four blocks reduce 256 -> 16,
    ``ZeroPad2d(1)`` grows it to 18, and the 4x4 conv with padding 1 yields
    18 + 2 - 4 + 1 = 17. The network is fully convolutional, so the map size
    follows the input size; build loss targets from the output's shape
    rather than hard-coding it.

    Args:
        in_channels: channels of ``img_A`` and ``img_B`` combined; the
            default 3 fits a 1-channel L input plus a 2-channel ab map.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        def block(in_filters, out_filters, normalization=True):
            """Stride-2 4x4 conv (halves H and W), optional InstanceNorm, LeakyReLU."""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Layer order is kept fixed so state_dict keys (model.0, model.2, ...)
        # stay compatible with previously saved weights.
        self.model = nn.Sequential(
            *block(in_channels, 64, normalization=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
            nn.ZeroPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False),
        )

    def forward(self, img_A, img_B):
        """Return patch scores of shape (N, 1, h, w) for the concatenated pair."""
        return self.model(torch.cat((img_A, img_B), 1))

# @title Training Setup
def initialize_training():
    """Build the data loader, models, optimizers and losses.

    The dataset path is checked first, so a missing or unmounted Drive
    folder fails fast before any model is allocated on the device.

    Returns:
        ``(generator, discriminator, optimizer_G, optimizer_D,
        criterion_GAN, criterion_L1, dataloader)``.

    Raises:
        FileNotFoundError: if ``Config.dataset_path`` does not exist.
        ValueError: raised by ``ColorizationDataset`` when the folder holds
            no usable images.
    """
    if not Path(Config.dataset_path).exists():
        raise FileNotFoundError(f"Path not found: {Config.dataset_path}")

    dataset = ColorizationDataset(Config.dataset_path)
    dataloader = DataLoader(
        dataset,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=2,
        # Pinned host memory only helps (and only avoids a warning) with a GPU.
        pin_memory=Config.device.type == 'cuda',
    )

    generator = GeneratorUNet().to(Config.device)
    discriminator = Discriminator().to(Config.device)
    # The usual pix2pix Adam settings: lr 2e-4, beta1 0.5.
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion_GAN = nn.MSELoss()  # least-squares adversarial objective
    criterion_L1 = nn.L1Loss()
    return generator, discriminator, optimizer_G, optimizer_D, criterion_GAN, criterion_L1, dataloader

# @title Training Loop
def _lab_to_rgb(L_norm, ab_norm):
    """Undo the dataset's Lab normalisation and convert one sample to RGB.

    ``ColorizationDataset`` stores ``L / 50 - 1`` and ``ab / 110``; this
    reverses that and converts with OpenCV on float32 data, where Lab keeps
    its natural ranges (L in [0, 100], signed a/b).

    Args:
        L_norm: normalised lightness tensor of shape (1, H, W).
        ab_norm: normalised chroma tensor of shape (2, H, W).

    Returns:
        ``(lightness, rgb)``: an (H, W) array in [0, 100] and an (H, W, 3)
        float32 RGB array clipped to [0, 1].
    """
    lightness = (L_norm.detach().float().cpu().numpy()[0] + 1.0) * 50.0
    ab = ab_norm.detach().float().cpu().numpy().transpose(1, 2, 0) * 110.0
    lab = np.concatenate([lightness[..., np.newaxis], ab], axis=2).astype(np.float32)
    # Must stay float32: uint8 Lab in OpenCV expects L scaled to 0..255 and
    # a/b offset by 128, and a uint8 cast would wrap the negative a/b values.
    rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    return lightness, np.clip(rgb, 0.0, 1.0)


def _train_step(generator, discriminator, optimizer_G, optimizer_D,
                criterion_GAN, criterion_L1, real_A, real_B, lambda_l1):
    """Run one discriminator update followed by one generator update.

    Uses the module-level AMP ``scaler``; autocast is enabled only on CUDA.

    Returns:
        ``(loss_D, loss_G, loss_L1)`` as Python floats.
    """
    device_type = real_A.device.type
    use_amp = device_type == 'cuda'

    # ======== Train Discriminator ========
    optimizer_D.zero_grad()
    with torch.autocast(device_type=device_type, enabled=use_amp):
        fake_B = generator(real_A)
        pred_real = discriminator(real_A, real_B)
        # Targets take their shape from the patch map itself, so no extra
        # discriminator pass is needed to discover its output size.
        loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
        pred_fake = discriminator(real_A, fake_B.detach())
        loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
        loss_D = (loss_real + loss_fake) * 0.5
    scaler.scale(loss_D).backward()
    scaler.step(optimizer_D)

    # ======== Train Generator ========
    # loss_G also writes gradients into the discriminator's parameters; they
    # are cleared by optimizer_D.zero_grad() before D's next update.
    optimizer_G.zero_grad()
    with torch.autocast(device_type=device_type, enabled=use_amp):
        pred_fake = discriminator(real_A, fake_B)
        loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        loss_L1 = criterion_L1(fake_B, real_B) * lambda_l1
        loss_G = loss_GAN + loss_L1
    scaler.scale(loss_G).backward()
    scaler.step(optimizer_G)
    # A single scaler update per iteration, after every optimizer has stepped.
    scaler.update()

    return loss_D.item(), loss_G.item(), loss_L1.item()


def _save_sample(generator, real_A, epoch):
    """Colorize the first image of ``real_A``, save the figure to ``Config.save_dir`` and show it.

    The generator is run in whatever mode it is currently in (train mode
    during training, so decoder dropout is active).
    """
    with torch.no_grad():
        fake_B = generator(real_A[:1])
    lightness, colorized = _lab_to_rgb(real_A[0], fake_B[0])

    fig = plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    # Single-channel display over the fixed L range [0, 100].
    plt.imshow(lightness, cmap='gray', vmin=0, vmax=100)
    plt.title("Input (L)")

    plt.subplot(1, 2, 2)
    plt.imshow(colorized)
    plt.title(f"Colorized (Epoch {epoch})")
    plt.savefig(f"{Config.save_dir}/epoch_{epoch}.png")
    plt.show()
    plt.close(fig)  # don't accumulate open figures across epochs


def train_model(lambda_l1=100.0):
    """Train the colorization GAN for ``Config.num_epochs`` epochs.

    After every epoch the generator weights are written to
    ``{Config.checkpoint_dir}/generator_{epoch}.pth``. On every 5th epoch and
    the last one, a sample colorization is also saved and shown.

    Args:
        lambda_l1: weight of the L1 reconstruction term in the generator
            loss (default 100).
    """
    (generator, discriminator, optimizer_G, optimizer_D,
     criterion_GAN, criterion_L1, dataloader) = initialize_training()

    for epoch in range(Config.num_epochs):
        generator.train()
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{Config.num_epochs}')

        for real_A, real_B in progress_bar:
            real_A, real_B = real_A.to(Config.device), real_B.to(Config.device)
            loss_D, loss_G, loss_L1 = _train_step(
                generator, discriminator, optimizer_G, optimizer_D,
                criterion_GAN, criterion_L1, real_A, real_B, lambda_l1,
            )
            progress_bar.set_postfix({'D_loss': loss_D, 'G_loss': loss_G, 'L1': loss_L1})

        # ======== Save model checkpoints and sample images ========
        torch.save(generator.state_dict(), f"{Config.checkpoint_dir}/generator_{epoch}.pth")

        if epoch % 5 == 0 or epoch == Config.num_epochs - 1:
            # real_A still holds the last batch of this epoch.
            _save_sample(generator, real_A, epoch)




# @title Start Training
# Echo the key settings before the long training run, so a wrong device or
# dataset path is visible immediately in the cell output.
print("=== System Check ===")
for label, value in (
    ("Device", Config.device),
    ("Dataset path", Config.dataset_path),
    ("Image size", Config.image_size),
    ("Batch size", Config.batch_size),
):
    print(f"{label}: {value}")

train_model()



In [None]:
!ls "/content/drive/MyDrive/MINIPROJECT/train2017" | head -n 5  # Show first 5 files

000000061409.jpg
000000061439.jpg
000000061463.jpg
000000061492.jpg
000000061503.jpg
