In [None]:
import torch
import torch.nn as nn

class ResBlockUp(nn.Module):
    """Upsampling block: 2x nearest-neighbour resize, 3x3 conv, batch norm, ReLU.

    Despite the name, the block has no skip connection; it is a plain
    feed-forward stack stored in ``self.upsample``.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Spatial size doubles first; the padded 3x3 conv then keeps that size.
        stages = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.upsample = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` upsampled by 2 in both spatial dims with ``out_channels`` channels."""
        upsampled = self.upsample(x)
        return upsampled

class Generator(nn.Module):
    """Encoder / bottleneck / decoder network mapping 1-channel input to 3-channel output.

    Shape walk-through for a 32x32 input (each stride-2 conv halves the size,
    each ``ResBlockUp`` doubles it):
        encoder: [1,32,32] -> [64,16,16] -> [128,8,8]
        middle : [128,8,8] -> [128,8,8]
        decoder: [128,8,8] -> [64,16,16] -> [32,32,32] -> [3,32,32]

    The last layer is ``Tanh``, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Two stride-2 downsampling stages: conv -> batch norm -> ReLU
        down_layers = []
        for c_in, c_out in ((1, 64), (64, 128)):
            down_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.encoder = nn.Sequential(*down_layers)

        # Bottleneck keeps both channel count and resolution
        self.middle = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Two upsampling blocks, then a 3x3 conv projecting to 3 colour channels
        up_layers = [
            ResBlockUp(128, 64),
            ResBlockUp(64, 32),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Run ``x`` through encoder, middle and decoder in that order."""
        for stage in (self.encoder, self.middle, self.decoder):
            x = stage(x)
        return x

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring a (grayscale, colour) image pair.

    The two images are stacked on the channel axis (1 + 3 = 4 channels) and
    passed through three stride-2 convolutions, then flattened and mapped to a
    single sigmoid probability. The linear layer expects a 256x4x4 feature map,
    which is what 32x32 inputs produce after three halvings.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # 4 input channels: 1 (gray) + 3 (color); no batch norm on the first stage
            nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining downsampling stages: conv -> batch norm -> LeakyReLU
        for c_in, c_out in ((64, 128), (128, 256)):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Classification head on the flattened [256, 4, 4] feature map
        layers.extend([
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1),
            nn.Sigmoid(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, gray, color):
        """Return one probability per sample for the channel-wise concatenated pair."""
        pair = torch.cat((gray, color), dim=1)
        return self.model(pair)



In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, ToPILImage, Grayscale

# Custom Dataset that returns: grayscale image, color image
class CIFAR10GrayColor(Dataset):
    """CIFAR-10 wrapper yielding ``(grayscale, color)`` tensor pairs.

    The class label of the underlying dataset is discarded. The grayscale
    image is derived from the colour tensor on every access by converting it
    to PIL, applying ``Grayscale`` and converting back to a tensor.

    Args:
        train: forwarded to ``datasets.CIFAR10`` to select the split.
    """

    def __init__(self, train=True):
        # Converters used by __getitem__ to derive the grayscale input
        self.to_tensor = ToTensor()
        self.to_pil = ToPILImage()
        self.to_gray = transforms.Grayscale(num_output_channels=1)
        # Underlying colour dataset; downloaded into ./data when missing
        self.dataset = datasets.CIFAR10(
            root='./data',
            train=train,
            download=True,
            transform=transforms.ToTensor(),
        )

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(gray_img, color_img)`` for sample ``idx``; the label is dropped."""
        color_img = self.dataset[idx][0]  # [3,32,32]
        # tensor -> PIL -> single-channel PIL -> tensor  ([1,32,32])
        gray_img = self.to_tensor(self.to_gray(self.to_pil(color_img)))
        return gray_img, color_img

# Build the paired grayscale/colour datasets for both CIFAR-10 splits (train first)
train_dataset, test_dataset = (
    CIFAR10GrayColor(train=is_train) for is_train in (True, False)
)

# Batches of 64; only the training split is reshuffled
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Sanity check: pull one training batch and report the tensor shapes
gray_batch, color_batch = next(iter(train_loader))
print("Grayscale batch shape:", gray_batch.shape)  # [64, 1, 32, 32]
print("Color batch shape:", color_batch.shape)     # [64, 3, 32, 32]


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Training configuration ---
num_epochs = 50
# Epochs completed before this cell runs; only offsets the printed epoch
# labels. The networks are created fresh below, so counting starts at 0.
start_epoch = 0
criterion_gan = nn.BCELoss()   # adversarial loss on the discriminator's sigmoid output
criterion_l1 = nn.L1Loss()     # pixel-wise reconstruction loss

lr = 0.0002
beta1 = 0.5
beta2 = 0.999
lambda_l1 = 100  # weight of the L1 term relative to the GAN term

generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

num_batches = len(train_loader)

for epoch in range(num_epochs):
    # Same label for the progress line and the summary so the two always agree
    epoch_tag = f"[{start_epoch + epoch + 1}/{start_epoch + num_epochs}]"

    g_gan_loss_epoch = 0.0
    g_l1_loss_epoch = 0.0
    g_total_loss_epoch = 0.0
    d_loss_epoch = 0.0
    d_steps = 0  # number of discriminator updates actually performed this epoch

    for batch_idx, (gray, color) in enumerate(train_loader):
        gray = gray.to(device)
        color = color.to(device)

        real_labels = torch.full((gray.size(0), 1), 0.9, device=device)  # Label smoothing
        fake_labels = torch.zeros((gray.size(0), 1), device=device)

        # --- Train Discriminator every alternate step ---
        if batch_idx % 2 == 0:
            optimizer_d.zero_grad()
            # The generator is not updated here, so no graph is needed for the fake
            with torch.no_grad():
                fake_color = generator(gray)

            real_output = discriminator(gray, color)
            fake_output = discriminator(gray, fake_color)

            d_loss_real = criterion_gan(real_output, real_labels)
            d_loss_fake = criterion_gan(fake_output, fake_labels)
            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            optimizer_d.step()

            # Only real updates are accumulated; skipped batches no longer
            # contribute zeros that would halve the reported average.
            d_loss_epoch += d_loss.item()
            d_steps += 1

        # --- Train Generator (twice per batch) ---
        g_gan_loss = g_l1_loss = g_total_loss = 0.0
        for _ in range(2):
            optimizer_g.zero_grad()
            fake_color = generator(gray)
            output = discriminator(gray, fake_color)

            # The generator wants the discriminator to output the "real" label
            g_gan = criterion_gan(output, real_labels)
            g_l1 = criterion_l1(fake_color, color)
            g_total = g_gan + lambda_l1 * g_l1

            g_total.backward()
            optimizer_g.step()

            g_gan_loss += g_gan.item()
            g_l1_loss += g_l1.item()
            g_total_loss += g_total.item()

        # Average of the two generator updates for this batch
        g_gan_loss_epoch += g_gan_loss / 2
        g_l1_loss_epoch += g_l1_loss / 2
        g_total_loss_epoch += g_total_loss / 2

        # --- Print percentage progress ---
        percent = 100 * (batch_idx + 1) / num_batches
        print(f"\rEpoch {epoch_tag} Progress: {percent:.2f}%", end='', flush=True)

    # --- Epoch Summary ---
    avg_d = d_loss_epoch / max(d_steps, 1)  # mean over discriminator updates only
    avg_gg = g_gan_loss_epoch / num_batches
    avg_l1 = g_l1_loss_epoch / num_batches
    avg_gtotal = g_total_loss_epoch / num_batches

    print(f"\n\n=== Epoch {epoch_tag} Summary ===")
    print(f"Discriminator Loss: {avg_d:.4f}")
    print(f"Generator GAN Loss: {avg_gg:.4f}")
    print(f"Generator L1 Loss : {avg_l1:.4f}")
    print(f"Generator Total Loss (GAN + L1): {avg_gtotal:.4f}")
    print("=============================================")

    # --- Visualize Samples ---
    # Uses the first image of the last training batch of this epoch.
    generator.eval()
    with torch.no_grad():
        sample_gray = gray[0].unsqueeze(0).to(device)
        sample_real_color = color[0].cpu()
        sample_fake_color = generator(sample_gray).squeeze(0).cpu()

        # Repeat the single channel three times so it can be shown as an RGB image
        sample_gray_show = sample_gray.squeeze(0).repeat(3, 1, 1).cpu()

        fig, axs = plt.subplots(1, 3, figsize=(10, 3))
        axs[0].imshow(sample_real_color.permute(1, 2, 0))
        axs[0].set_title("Actual Colored Image")
        axs[0].axis('off')

        axs[1].imshow(sample_gray_show.permute(1, 2, 0), cmap='gray')
        axs[1].set_title("Grayscale Input")
        axs[1].axis('off')

        # Generator output is clipped to the displayable [0, 1] range
        axs[2].imshow(sample_fake_color.permute(1, 2, 0).clip(0, 1))
        axs[2].set_title("Recolored by Generator")
        axs[2].axis('off')

        plt.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure so one per epoch does not pile up
    generator.train()


In [None]:
import torch
import random
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

import matplotlib.pyplot as plt
import torch

def visualize_recolor_samples(generator, train_loader, test_loader, use_train=True, num_samples=4, device=None):
    """Plot real colour, grayscale input and generated colour for a few samples.

    The first batch of the selected loader is used; one 1x3 figure is shown
    per sample.

    Args:
        generator: ``nn.Module`` mapping a ``[B, 1, H, W]`` batch to colour images.
        train_loader: iterable of ``(gray, color)`` batches, used when ``use_train``.
        test_loader: iterable of ``(gray, color)`` batches, used otherwise.
        use_train: pick ``train_loader`` (True) or ``test_loader`` (False).
        num_samples: maximum number of samples to plot; if the batch holds
            fewer images, only those are plotted.
        device: device the inputs are moved to. When ``None``, the device of
            the generator's first parameter is used (inputs are left where
            they are if the generator has no parameters).

    Returns:
        None. Prints ``"Empty loader!"`` and returns if the loader yields nothing.
        The generator's train/eval mode is restored to what it was on entry.
    """
    loader = train_loader if use_train else test_loader

    try:
        gray_imgs, color_imgs = next(iter(loader))
    except StopIteration:
        print("Empty loader!")
        return

    # Follow the generator's own device when the caller did not specify one
    if device is None:
        first_param = next(generator.parameters(), None)
        if first_param is not None:
            device = first_param.device

    # Take only the first few samples (a negative request plots nothing)
    limit = max(num_samples, 0)
    gray_imgs = gray_imgs[:limit]
    color_imgs = color_imgs[:limit].cpu()

    # Ensure grayscale input has correct shape: [B, 1, H, W]
    if gray_imgs.ndim == 3:
        gray_imgs = gray_imgs.unsqueeze(1)  # add channel dimension if missing
    if device is not None:
        gray_imgs = gray_imgs.to(device)

    # Run inference in eval mode, then put the module back in its previous
    # mode even if the forward pass raises.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake_color_imgs = generator(gray_imgs).cpu()
    finally:
        generator.train(was_training)

    gray_imgs = gray_imgs.cpu()

    # Iterate over what is actually available, not over the requested count
    for i in range(gray_imgs.size(0)):
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))

        axs[0].imshow(color_imgs[i].permute(1, 2, 0).clip(0, 1))
        axs[0].set_title("Actual Colored Image")
        axs[0].axis('off')

        axs[1].imshow(gray_imgs[i][0], cmap='gray')
        axs[1].set_title("Grayscale Input")
        axs[1].axis('off')

        axs[2].imshow(fake_color_imgs[i].permute(1, 2, 0).clip(0, 1))
        axs[2].set_title("Recolored by Generator")
        axs[2].axis('off')

        plt.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across samples



In [None]:
# Plot up to 50 samples from the first test-set batch (one figure per sample)
visualize_recolor_samples(generator, train_loader, test_loader, use_train=False, num_samples=50, device=device)


In [None]:
# Save locally. `epoch` is the loop variable left over from the training loop
# that ran last, so the files are tagged with that loop's final epoch number.
# All files are written BEFORE the Colab download step so that a missing
# `google.colab` module cannot prevent the checkpoint from being saved.
generator_path = f'generator_epoch{epoch+1}.pth'
discriminator_path = f'discriminator_epoch{epoch+1}.pth'
checkpoint_path = f'model_checkpoint_epoch{epoch+1}.pth'

torch.save(generator.state_dict(), generator_path)
torch.save(discriminator.state_dict(), discriminator_path)

# Save checkpoint (weights plus optimizer states, enough to resume training)
torch.save({
    'epoch': epoch+1,
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizer_g_state_dict': optimizer_g.state_dict(),
    'optimizer_d_state_dict': optimizer_d.state_dict(),
}, checkpoint_path)

# Download to local machine (only possible when running inside Google Colab)
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; files were saved locally only.")
else:
    for saved_path in (generator_path, discriminator_path, checkpoint_path):
        files.download(saved_path)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Training configuration (continuation run) ---
# This cell reuses `generator`, `discriminator`, both optimizers and both loss
# objects created by an earlier cell; it fails with NameError on a fresh kernel
# unless the commented-out lines below are re-enabled.
num_epochs = 100
# Epochs completed before this cell runs; only offsets the printed epoch labels.
# NOTE(review): 100 is taken from the original progress line (`epoch+101`) -
# confirm it matches the number of epochs actually trained so far.
start_epoch = 100
# criterion_gan = nn.BCELoss()   #uncomment if running from scratch
# criterion_l1 = nn.L1Loss()

lr = 0.0002
beta1 = 0.5
beta2 = 0.999
lambda_l1 = 250  # weight of the L1 term relative to the GAN term

#generator = Generator().to(device)
#discriminator = Discriminator().to(device)

#optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
#optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

num_batches = len(train_loader)

for epoch in range(num_epochs):
    # Same label for the progress line and the summary so the two always agree
    epoch_tag = f"[{start_epoch + epoch + 1}/{start_epoch + num_epochs}]"

    g_gan_loss_epoch = 0.0
    g_l1_loss_epoch = 0.0
    g_total_loss_epoch = 0.0
    d_loss_epoch = 0.0
    d_steps = 0  # number of discriminator updates actually performed this epoch

    for batch_idx, (gray, color) in enumerate(train_loader):
        gray = gray.to(device)
        color = color.to(device)

        real_labels = torch.full((gray.size(0), 1), 0.9, device=device)  # Label smoothing
        fake_labels = torch.zeros((gray.size(0), 1), device=device)

        # --- Train Discriminator every alternate step ---
        if batch_idx % 2 == 0:
            optimizer_d.zero_grad()
            # The generator is not updated here, so no graph is needed for the fake
            with torch.no_grad():
                fake_color = generator(gray)

            real_output = discriminator(gray, color)
            fake_output = discriminator(gray, fake_color)

            d_loss_real = criterion_gan(real_output, real_labels)
            d_loss_fake = criterion_gan(fake_output, fake_labels)
            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            optimizer_d.step()

            # Only real updates are accumulated; skipped batches no longer
            # contribute zeros that would halve the reported average.
            d_loss_epoch += d_loss.item()
            d_steps += 1

        # --- Train Generator (twice per batch) ---
        g_gan_loss = g_l1_loss = g_total_loss = 0.0
        for _ in range(2):
            optimizer_g.zero_grad()
            fake_color = generator(gray)
            output = discriminator(gray, fake_color)

            # The generator wants the discriminator to output the "real" label
            g_gan = criterion_gan(output, real_labels)
            g_l1 = criterion_l1(fake_color, color)
            g_total = g_gan + lambda_l1 * g_l1

            g_total.backward()
            optimizer_g.step()

            g_gan_loss += g_gan.item()
            g_l1_loss += g_l1.item()
            g_total_loss += g_total.item()

        # Average of the two generator updates for this batch
        g_gan_loss_epoch += g_gan_loss / 2
        g_l1_loss_epoch += g_l1_loss / 2
        g_total_loss_epoch += g_total_loss / 2

        # --- Print percentage progress ---
        percent = 100 * (batch_idx + 1) / num_batches
        print(f"\rEpoch {epoch_tag} Progress: {percent:.2f}%", end='', flush=True)

    # --- Epoch Summary ---
    avg_d = d_loss_epoch / max(d_steps, 1)  # mean over discriminator updates only
    avg_gg = g_gan_loss_epoch / num_batches
    avg_l1 = g_l1_loss_epoch / num_batches
    avg_gtotal = g_total_loss_epoch / num_batches

    print(f"\n\n=== Epoch {epoch_tag} Summary ===")
    print(f"Discriminator Loss: {avg_d:.4f}")
    print(f"Generator GAN Loss: {avg_gg:.4f}")
    print(f"Generator L1 Loss : {avg_l1:.4f}")
    print(f"Generator Total Loss (GAN + L1): {avg_gtotal:.4f}")
    print("=============================================")

    # --- Visualize Samples ---
    # Uses the first image of the last training batch of this epoch.
    generator.eval()
    with torch.no_grad():
        sample_gray = gray[0].unsqueeze(0).to(device)
        sample_real_color = color[0].cpu()
        sample_fake_color = generator(sample_gray).squeeze(0).cpu()

        # Repeat the single channel three times so it can be shown as an RGB image
        sample_gray_show = sample_gray.squeeze(0).repeat(3, 1, 1).cpu()

        fig, axs = plt.subplots(1, 3, figsize=(10, 3))
        axs[0].imshow(sample_real_color.permute(1, 2, 0))
        axs[0].set_title("Actual Colored Image")
        axs[0].axis('off')

        axs[1].imshow(sample_gray_show.permute(1, 2, 0), cmap='gray')
        axs[1].set_title("Grayscale Input")
        axs[1].axis('off')

        # Generator output is clipped to the displayable [0, 1] range
        axs[2].imshow(sample_fake_color.permute(1, 2, 0).clip(0, 1))
        axs[2].set_title("Recolored by Generator")
        axs[2].axis('off')

        plt.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure so one per epoch does not pile up
    generator.train()

In [None]:
# Save locally. `epoch` is the loop variable left over from the training loop
# that ran last, so the files are tagged with that loop's final epoch number.
# All files are written BEFORE the Colab download step so that a missing
# `google.colab` module cannot prevent the checkpoint from being saved.
generator_path = f'generator_epoch{epoch+1}.pth'
discriminator_path = f'discriminator_epoch{epoch+1}.pth'
checkpoint_path = f'model_checkpoint_epoch{epoch+1}.pth'

torch.save(generator.state_dict(), generator_path)
torch.save(discriminator.state_dict(), discriminator_path)

# Save checkpoint (weights plus optimizer states, enough to resume training)
torch.save({
    'epoch': epoch+1,
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizer_g_state_dict': optimizer_g.state_dict(),
    'optimizer_d_state_dict': optimizer_d.state_dict(),
}, checkpoint_path)

# Download to local machine (only possible when running inside Google Colab)
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; files were saved locally only.")
else:
    for saved_path in (generator_path, discriminator_path, checkpoint_path):
        files.download(saved_path)