In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import make_grid, save_image

# Seed PyTorch's global random number generator (fixed seed 42) for reproducibility
torch.manual_seed(42)

# Select the computational device: CUDA when available, otherwise the CPU.
# NOTE(review): only the single default "cuda" device is selected here; nothing in
# this cell spreads work over a second GPU, so a Kaggle "T4 x2" session would use one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# --- CONFIGURATION & HYPERPARAMETERS ---

# Image settings (pixels per side and number of colour channels)
IMG_WIDTH = 256
IMG_HEIGHT = 256
CHANNELS_IMG = 3

# Training settings
BATCH_SIZE = 16          # Number of image pairs per DataLoader batch
LEARNING_RATE = 2e-4     # Adam learning rate, shared by generator and discriminator
BETA1 = 0.5              # First entry of Adam's `betas` tuple
BETA2 = 0.999            # Second entry of Adam's `betas` tuple
NUM_EPOCHS = 100         # Total passes through the entire dataset
L1_LAMBDA = 100          # Multiplier applied to the L1 reconstruction loss in the generator objective

# Path to your dataset on Kaggle (update this based on your selected data).
# The dataset class lists image files directly inside this directory (it does not
# recurse), so point it at the folder that actually contains the images.
# Example for CMP Facades: '/kaggle/input/facades-dataset/facades/train/'
DATA_PATH = "/kaggle/input/notebooks/japanajitsinghgandhi/pix2pix-facades-dataset"

In [None]:
class FacadesDataset(Dataset):
    """Paired image dataset stored as side-by-side composites.

    Every image file found directly inside ``root_dir`` is expected to hold two
    pictures next to each other: the left half is the target and the right half
    is the generator input. Each half is resized to ``image_size`` and scaled
    to the range [-1, 1].

    Args:
        root_dir: Directory containing the composite images (not searched
            recursively).
        image_size: ``(height, width)`` each half is resized to. Defaults to
            ``(256, 256)``.

    Indexing returns a tuple ``(input_tensor, target_tensor)``.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, image_size=(256, 256)):
        self.root_dir = root_dir
        self.list_files = self._list_images(root_dir)
        # Built once here instead of on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @staticmethod
    def _list_images(root_dir):
        """Return the image file names in ``root_dir`` in sorted order.

        ``os.listdir`` returns entries in arbitrary order; sorting makes the
        index -> file mapping stable from run to run.
        """
        return sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(FacadesDataset.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.list_files[index])

        # convert() loads the pixel data into a new image, so the underlying
        # file can be closed as soon as the `with` block exits.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")

        # Split at the horizontal midpoint rather than a fixed 256 px so that
        # composites of any width are divided correctly (identical to the old
        # behaviour for 512-pixel-wide files).
        width, height = image.size
        half = width // 2
        target_img = image.crop((0, 0, half, height))      # left half
        input_img = image.crop((half, 0, width, height))   # right half

        return self.transform(input_img), self.transform(target_img)

# Build the dataset and loader from the configured DATA_PATH (single source of truth
# for the dataset location instead of a second copy of the path literal).
train_dataset = FacadesDataset(root_dir=DATA_PATH)

# Fail early with an actionable message: an empty folder would otherwise only
# surface as an opaque sampler error when the shuffling DataLoader is created.
if len(train_dataset) == 0:
    raise FileNotFoundError(
        f"No .png/.jpg/.jpeg files found directly inside {DATA_PATH!r}; "
        "point DATA_PATH at the folder that contains the images."
    )

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"✅ Ready! Loaded {len(train_dataset)} image pairs.")

In [None]:
# Helper for Downsampling (Encoder)
def down_block(in_c, out_c, normalize=True, dropout=0.0):
    """Build one encoder stage that halves the spatial resolution.

    The stage is a 4x4, stride-2, bias-free convolution, optionally followed
    by batch normalisation, then LeakyReLU(0.2), then optionally dropout.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        normalize: Insert ``BatchNorm2d`` after the convolution when true.
        dropout: Dropout probability; a falsy value (e.g. 0.0) adds no dropout layer.

    Returns:
        An ``nn.Sequential`` containing the layers in the order above.
    """
    conv = nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False)
    maybe_norm = (nn.BatchNorm2d(out_c),) if normalize else ()
    activation = nn.LeakyReLU(0.2)
    maybe_drop = (nn.Dropout(dropout),) if dropout else ()
    return nn.Sequential(conv, *maybe_norm, activation, *maybe_drop)

# Helper for Upsampling (Decoder)
def up_block(in_c, out_c, dropout=0.0):
    """Build one decoder stage that doubles the spatial resolution.

    The stage is a 4x4, stride-2, bias-free transposed convolution followed by
    batch normalisation and an in-place ReLU, plus dropout when requested.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        dropout: Dropout probability; a falsy value (e.g. 0.0) adds no dropout layer.

    Returns:
        An ``nn.Sequential`` containing the layers in the order above.
    """
    deconv = nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False)
    norm = nn.BatchNorm2d(out_c)
    activation = nn.ReLU(inplace=True)
    maybe_drop = (nn.Dropout(dropout),) if dropout else ()
    return nn.Sequential(deconv, norm, activation, *maybe_drop)

class Generator(nn.Module):
    """U-Net style generator with skip connections.

    Eight stride-2 encoder stages (``d1``..``d8``) are mirrored by seven
    decoder stages (``u1``..``u7``) and a final transposed convolution. From
    ``u2`` onwards every decoder stage receives the previous decoder output
    concatenated, along the channel axis, with the encoder output of matching
    resolution. The output passes through ``Tanh`` and therefore lies in
    [-1, 1]. For a 256x256 input the bottleneck is 1x1.
    """

    def __init__(self):
        super().__init__()

        # Encoder (contracting path): (in_channels, out_channels, extra options).
        encoder_cfg = (
            (3, 64, {"normalize": False}),    # 128x128
            (64, 128, {}),                    # 64x64
            (128, 256, {}),                   # 32x32
            (256, 512, {"dropout": 0.5}),     # 16x16
            (512, 512, {"dropout": 0.5}),     # 8x8
            (512, 512, {"dropout": 0.5}),     # 4x4
            (512, 512, {"dropout": 0.5}),     # 2x2
            (512, 512, {"normalize": False}), # 1x1
        )
        for stage, (c_in, c_out, extra) in enumerate(encoder_cfg, start=1):
            setattr(self, f"d{stage}", down_block(c_in, c_out, **extra))

        # Decoder (expanding path): (in_channels, out_channels, dropout).
        # Input widths from u2 on are doubled by the concatenated skip tensor.
        decoder_cfg = (
            (512, 512, 0.5),    # 2x2
            (1024, 512, 0.5),   # 4x4
            (1024, 512, 0.5),   # 8x8
            (1024, 512, 0.5),   # 16x16
            (1024, 256, 0.0),   # 32x32
            (512, 128, 0.0),    # 64x64
            (256, 64, 0.0),     # 128x128
        )
        for stage, (c_in, c_out, p_drop) in enumerate(decoder_cfg, start=1):
            setattr(self, f"u{stage}", up_block(c_in, c_out, dropout=p_drop))

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # Output range [-1, 1]
        )

    def forward(self, x):
        # Encoder: keep every intermediate activation for the skip connections.
        skips = []
        feat = x
        for stage in range(1, 9):
            feat = getattr(self, f"d{stage}")(feat)
            skips.append(feat)

        # The bottleneck (d8 output) feeds u1 directly, without a skip.
        feat = self.u1(skips.pop())

        # u2..u7 consume the encoder outputs in reverse order (d7 down to d2).
        for stage in range(2, 8):
            feat = getattr(self, f"u{stage}")(torch.cat([feat, skips.pop()], 1))

        # The remaining skip is the d1 output.
        return self.final(torch.cat([feat, skips.pop()], 1))


# Instantiate the generator and move it to the selected device
gen = Generator()
gen = gen.to(device)

In [None]:
# --- LOSS FUNCTIONS ---
# Binary Cross Entropy with Logits for the Adversarial Loss
# This measures how well the Generator fools the Discriminator
BCE_LOSS = nn.BCEWithLogitsLoss()

# L1 Loss for the Reconstruction Loss
# This measures how physically similar the generated image is to the real target
L1_LOSS = nn.L1Loss()

# --- OPTIMIZERS ---
# We use the Adam optimizer, which is the standard for GANs
gen_optimizer = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))
disc_optimizer = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))

# --- UPDATED: SCALER FOR MIXED PRECISION ---
# Using the new PyTorch 2.x syntax
scaler_gen = torch.amp.GradScaler('cuda')
scaler_disc = torch.amp.GradScaler('cuda')

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # PatchGAN takes both the Input Image and the Target Image concatenated
        # So in_channels = 3 (input) + 3 (target) = 6
        self.model = nn.Sequential(
            down_block(6, 64, normalize=False), # 128x128
            down_block(64, 128),               # 64x64
            down_block(128, 256),              # 32x32
            down_block(256, 512),              # 31x31 (stride=1 in final layers)
            nn.Conv2d(512, 1, 4, padding=1)    # 30x30 output patch
        )

    def forward(self, x, y):
        # x is the input sketch, y is the generated/real image
        # We concatenate them along the channel dimension
        input_data = torch.cat([x, y], dim=1)
        return self.model(input_data)

# Initialize
disc = Discriminator().to(device)

In [None]:
def show_progress(gen, input_image, target_image, epoch):
    """Plot input, target and generated image for the first sample of a batch.

    The generator is switched to evaluation mode for the forward pass and
    restored to whatever mode it was in before the call, even if plotting or
    inference raises.

    Args:
        gen: Generator module; called as ``gen(batch)``.
        input_image: Batch of input images normalised to [-1, 1]
            (indexed as ``input_image[0]``, channels-first per sample).
        target_image: Batch of target images, same layout as ``input_image``.
        epoch: Value shown in the subplot titles.
    """
    was_training = gen.training
    gen.eval()  # Set to evaluation mode
    try:
        with torch.no_grad():
            # Only the first sample is displayed, so only that sample is generated.
            fake_image = gen(input_image[:1])

        panels = (
            ('Input Sketch', input_image[0]),
            ('Target Photo', target_image[0]),
            ('Generated Photo', fake_image[0]),
        )

        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        for ax, (title, img) in zip(axes, panels):
            # Denormalize from [-1, 1] back to [0, 1]; cast to float32 so that
            # reduced-precision tensors can be converted for plotting.
            img = img.detach().float() * 0.5 + 0.5
            ax.set_title(f"{title} (Epoch {epoch})")
            ax.imshow(img.permute(1, 2, 0).cpu().numpy())
            ax.axis('off')

        plt.show()
        # Release the figure so repeated per-epoch calls do not accumulate figures.
        plt.close(fig)
    finally:
        gen.train(was_training)  # Restore the previous mode

In [None]:
# Mixed precision is only used on CUDA; on CPU autocast is explicitly disabled.
amp_enabled = device.type == "cuda"

for epoch in range(NUM_EPOCHS):
    loop = tqdm(train_loader, leave=True)

    for idx, (input_img, target_img) in enumerate(loop):
        input_img, target_img = input_img.to(device), target_img.to(device)

        # --- TRAIN DISCRIMINATOR ---
        disc_optimizer.zero_grad()
        with torch.amp.autocast('cuda', enabled=amp_enabled):
            fake_img = gen(input_img)
            disc_real = disc(input_img, target_img)
            # detach(): the discriminator update must not backpropagate into the generator
            disc_fake = disc(input_img, fake_img.detach())

            disc_real_loss = BCE_LOSS(disc_real, torch.ones_like(disc_real))
            disc_fake_loss = BCE_LOSS(disc_fake, torch.zeros_like(disc_fake))
            disc_loss = (disc_real_loss + disc_fake_loss) / 2

        # Half-precision gradients can underflow to zero, so the loss is scaled
        # before backward; step() unscales the gradients and skips the update if
        # they contain inf/NaN, and update() adapts the scale factor.
        scaler_disc.scale(disc_loss).backward()
        scaler_disc.step(disc_optimizer)
        scaler_disc.update()

        # --- TRAIN GENERATOR ---
        gen_optimizer.zero_grad()
        with torch.amp.autocast('cuda', enabled=amp_enabled):
            # Re-score the fake with the freshly updated discriminator, this time
            # keeping the graph back into the generator.
            disc_fake = disc(input_img, fake_img)
            gen_fake_loss = BCE_LOSS(disc_fake, torch.ones_like(disc_fake))
            l1_loss = L1_LOSS(fake_img, target_img) * L1_LAMBDA
            gen_loss = gen_fake_loss + l1_loss

        # Gradients that this backward pass leaves on the discriminator are cleared
        # by disc_optimizer.zero_grad() at the start of the next iteration.
        scaler_gen.scale(gen_loss).backward()
        scaler_gen.step(gen_optimizer)
        scaler_gen.update()

        if idx % 10 == 0:
            loop.set_postfix(D_loss=disc_loss.item(), G_loss=gen_loss.item())

    # Show results at the end of each epoch (uses the last batch of the epoch)
    show_progress(gen, input_img, target_img, epoch)