In [1]:
# Import standard libraries and PyTorch modules for data handling, image processing, 
# neural network building, optimization, and training utilities.
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
from torch import optim
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from torchvision.utils import save_image
from torch_fidelity import calculate_metrics

from IPython.display import clear_output  # For clearing Jupyter output during training


### Hyperparameters

In [2]:
batch_size = 32  # Image pairs per batch handed to the DataLoaders
image_size = (256, 256)  # Size passed to transforms.Resize for both input and target images

lr = 1e-4  # Learning rate shared by the generator and discriminator Adam optimizers
betas = (0.0, 0.9)  # Adam (beta1, beta2) shared by both optimizers

disc_iter = 3  # Discriminator updates performed for every generator update in the training loop

lambda_gp = 10  # Weight of the gradient-penalty term added to the discriminator loss


## Creating Dataset and DataLoader


In [3]:
# Training split: each file holds an input/target image pair side by side
# (satellite photo and map, per the original notebook notes)
train_directory = "/kaggle/input/pix2pix-dataset/maps/maps/train"

# Validation split: used for the per-epoch validation losses and the final metrics
val_directory = "/kaggle/input/pix2pix-dataset/maps/maps/val"


In [4]:
# Preprocessing pipelines for the input (e.g., satellite) and target (e.g., map) halves.
# Both receive tensors (the dataset applies ToTensor before calling them) and finish by
# mapping pixel values from [0, 1] to [-1, 1], the range of the generator's Tanh output.
_HALF = [0.5, 0.5, 0.5]  # per-channel mean and std: (x - 0.5) / 0.5

input_transform = transforms.Compose([
    transforms.Resize(image_size),            # resize to the configured (height, width)
    transforms.ColorJitter(brightness=0.1),   # augmentation: random brightness change only
    transforms.Normalize(mean=_HALF, std=_HALF),
])

target_transform = transforms.Compose([
    transforms.Resize(image_size),            # same size as the input half
    transforms.Normalize(mean=_HALF, std=_HALF),
])


In [5]:
class map_dataset(Dataset):
    """
    Paired-image dataset in which every file stores input and target side by side.

    Each file in ``dir`` is loaded as a single image; its left half becomes the
    input (e.g., satellite photo) and its right half the target (e.g., map).
    """

    def __init__(self, dir, input_transform, target_transfrom):
        """
        Args:
            dir (str): Directory containing the dataset images.
            input_transform (callable or None): Transformations applied to the input half.
            target_transfrom (callable or None): Transformations applied to the target half.
                The misspelled parameter name is kept so existing keyword callers still work.
        """
        super().__init__()
        self.dir = dir
        self.input_transform = input_transform
        self.target_transform = target_transfrom
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.list_files = sorted(os.listdir(self.dir))

    def __len__(self):
        # Total number of image pairs (one per file)
        return len(self.list_files)

    def __getitem__(self, index):
        """
        Load one file and split it into an (input, target) tensor pair.

        Returns:
            tuple: (X, y) tensors, each transformed by its own pipeline when one is set.
        """
        image_path = os.path.join(self.dir, self.list_files[index])

        # The context manager closes the file handle once the pixels are decoded;
        # convert("RGB") guarantees the 3 channels the Normalize transforms expect.
        with Image.open(image_path) as img:
            image = transforms.ToTensor()(img.convert("RGB"))

        # Split at the horizontal midpoint instead of a hard-coded 600 px so that
        # concatenated images of any width work (identical result for 1200 px files).
        half = image.shape[-1] // 2
        X = image[:, :, :half]           # Left half (e.g., satellite image)
        y = image[:, :, half:2 * half]   # Right half (e.g., map image), same width as X

        if self.input_transform:
            X = self.input_transform(X)
        if self.target_transform:
            y = self.target_transform(y)

        return X, y


**Load training and validation datasets with corresponding transforms**

In [6]:
# Training pairs: augmented pipeline for the input half, plain pipeline for the target half.
train_dataset = map_dataset(
    dir=train_directory,
    input_transform=input_transform,
    target_transfrom=target_transform,
)

# Validation pairs: target_transform (resize + normalize, no ColorJitter) is passed for
# BOTH halves. NOTE(review): presumably deliberate so validation inputs are not randomly
# augmented -- confirm with the author.
val_dataset = map_dataset(
    dir=val_directory,
    input_transform=target_transform,
    target_transfrom=target_transform,
)


**Create DataLoaders for training and validation datasets**

In [7]:
# Training batches are reshuffled every epoch so that mini-batch composition (and the
# BatchNorm statistics computed from each batch) is not tied to the on-disk file order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation keeps a fixed order: the evaluation cell derives output file names from the
# batch index, and a stable first batch keeps the periodic sample grids comparable.
val_loader = DataLoader(val_dataset, batch_size=batch_size)


## Discriminator


In [8]:
class Dis_Block(nn.Module):
    """
    Single Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage of the discriminator.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the convolution.
        stride (int): Convolution stride.

    The convolution uses a 4x4 kernel with one pixel of reflection padding.
    """
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            padding_mode="reflect",
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        # Kept as `self.layer` with the same ordering so state_dict keys are unchanged.
        self.layer = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        out = self.layer(x)
        return out


In [9]:
class Discriminator(nn.Module):
    """
    PatchGAN-style discriminator for pix2pix.

    The input image and a candidate target image are concatenated along the
    channel axis and reduced by a stack of 4x4 convolutions to a single-channel
    map of per-patch scores (raw values, no final activation).

    Input:
        X (Tensor): Input image (e.g., satellite)
        y (Tensor): Target or generated image (e.g., map)

    Output:
        Tensor: Patch-level scores.
    """
    def __init__(self):
        super().__init__()
        # Spatial sizes in the comments are for 256x256 inputs.
        stem = [
            nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1, padding_mode="reflect"),  # 128x128
            nn.LeakyReLU(0.2),
        ]
        # (in_channels, out_channels, stride) -> 64x64, 32x32, 31x31
        body = [
            Dis_Block(c_in, c_out, stride)
            for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1))
        ]
        head = [
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"),  # 30x30
        ]
        self.layers = nn.Sequential(*stem, *body, *head)

    def forward(self, X, y):
        # 3 + 3 channels -> the 6 channels expected by the first convolution
        pair = torch.cat((X, y), dim=1)
        return self.layers(pair)


## Generator Model


In [10]:
class DownBlock(nn.Module):
    """
    Encoder stage of the U-Net generator.

    A 4x4, stride-2 convolution (padding 1) halves the spatial size; it is
    followed by BatchNorm2d and LeakyReLU(0.2).
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, X):
        downsampled = self.block(X)
        return downsampled


class UpBlock(nn.Module):
    """
    Decoder stage of the U-Net generator.

    A 4x4, stride-2 transposed convolution (padding 1) doubles the spatial size;
    it is followed by BatchNorm2d and ReLU. Dropout(0.5) is applied to the result
    only when the block was built with ``dropout=True``.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the transposed convolution.
        dropout (bool): If True, apply Dropout(0.5) after upsampling.
    """
    def __init__(self, in_channels, out_channels, dropout=False):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.block = nn.Sequential(*stages)
        self.dropout = dropout
        # The layer is always created; `self.dropout` decides whether forward uses it.
        self.dropout_layer = nn.Dropout(0.5)

    def forward(self, X):
        upsampled = self.block(X)
        if self.dropout:
            return self.dropout_layer(upsampled)
        return upsampled


In [11]:
class Generator(nn.Module):
    """
    U-Net generator for the pix2pix model.

    Eight stride-2 encoder stages reduce the input to a bottleneck; eight
    stride-2 decoder stages restore the resolution. From the second decoder
    stage on, each stage receives the previous decoder output concatenated
    (channel-wise) with the encoder output of matching resolution.

    Args:
        in_channels (int): Number of input channels (default is 3 for RGB images).

    Output:
        Tensor: 3-channel image with values in [-1, 1] (Tanh activation).
    """
    def __init__(self, in_channels=3):
        super().__init__()

        # ---- Encoder (spatial sizes are for 256x256 inputs) ----
        # First stage: no normalization, plain ReLU.
        self.de1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.ReLU(),
        )                                 # 128x128
        self.de2 = DownBlock(64, 128)     # 64x64
        self.de3 = DownBlock(128, 256)    # 32x32
        self.de4 = DownBlock(256, 512)    # 16x16
        self.de5 = DownBlock(512, 512)    # 8x8
        self.de6 = DownBlock(512, 512)    # 4x4
        self.de7 = DownBlock(512, 512)    # 2x2
        # Bottleneck: no normalization.
        self.de8 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )                                 # 1x1

        # ---- Decoder ----
        # Input channels are doubled from up2 onward because of the concatenated skip.
        self.up1 = UpBlock(512, 512, True)           # 2x2   (dropout)
        self.up2 = UpBlock(512 * 2, 512, True)       # 4x4   (dropout)
        self.up3 = UpBlock(512 * 2, 512, True)       # 8x8   (dropout)
        self.up4 = UpBlock(512 * 2, 512, True)       # 16x16 (dropout)
        self.up5 = UpBlock(512 * 2, 256)             # 32x32
        self.up6 = UpBlock(256 * 2, 128)             # 64x64
        self.up7 = UpBlock(128 * 2, 64)              # 128x128
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(64 * 2, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )                                            # 256x256

    def forward(self, x):
        # Encoder pass: remember every stage output for the skip connections.
        encoder_stages = (self.de1, self.de2, self.de3, self.de4,
                          self.de5, self.de6, self.de7, self.de8)
        skips = []
        features = x
        for stage in encoder_stages:
            features = stage(features)
            skips.append(features)

        # The deepest output is the bottleneck; it feeds up1 directly (no skip).
        features = self.up1(skips.pop())

        # Remaining decoders consume the skips from deepest (de7) to shallowest (de1).
        decoder_stages = (self.up2, self.up3, self.up4, self.up5,
                          self.up6, self.up7, self.up8)
        for stage in decoder_stages:
            features = stage(torch.cat([features, skips.pop()], dim=1))
        return features


## Training and Testing

In [12]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [13]:
# Networks, moved to the selected device (discriminator is built first, then the generator)
disc = Discriminator().to(device)
gen = Generator().to(device)

# One Adam optimizer per network; both share the same learning rate and betas
adam_settings = dict(lr=lr, betas=betas)
optimizer_disc = optim.Adam(disc.parameters(), **adam_settings)
optimizer_gen = optim.Adam(gen.parameters(), **adam_settings)

# Loss modules
criterion = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy on raw discriminator scores
l1 = nn.L1Loss()                    # mean absolute error between generated and real images


In [16]:
def gradient_penalty(critic, real, fake, input_image, device):
    """
    WGAN-GP gradient penalty for a conditional critic.

    Random per-sample interpolations between ``real`` and ``fake`` are scored by
    the critic, and the squared deviation of the gradient norm from 1 is averaged
    over the batch.

    Args:
        critic (callable): Called as ``critic(input_image, images)``.
        real (Tensor): Real target images; 4-D with the batch on the first axis.
        fake (Tensor): Generated images with the same shape as ``real``.
        input_image: Conditioning input, passed unchanged as the critic's first argument.
        device: Unused; kept for backward compatibility. The random interpolation
            weights are created on ``real.device``.

    Returns:
        Tensor: Scalar ``mean((||grad||_2 - 1) ** 2)``.
    """
    # Detach both endpoints so the penalty only ever back-propagates into the critic,
    # never into whatever graph produced the images (e.g. the generator).
    real = real.detach()
    fake = fake.detach()

    # One interpolation weight per sample, broadcast over channels and pixels;
    # created directly on the target device to avoid a host -> device copy.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interpolated = (real * alpha + fake * (1 - alpha)).requires_grad_(True)

    # Critic scores for the interpolated images
    mixed_scores = critic(input_image, interpolated)

    # Gradient of the scores w.r.t. the interpolated images; create_graph=True keeps
    # the result differentiable so the penalty itself can be back-propagated.
    gradients = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view) so non-contiguous gradient layouts are also accepted
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


In [None]:
def plot_losses(G_train_losses, D_train_losses, L1_train_losses,
                G_test_losses, D_test_losses, L1_test_losses):
    """
    Redraw the loss curves in place of the previous notebook output.

    Each argument is a sequence with one value per completed epoch, so the
    x-axis of the plot is the epoch index.

    Args:
        G_train_losses, D_train_losses, L1_train_losses: Training curves.
        G_test_losses, D_test_losses, L1_test_losses: Validation curves.
    """
    # Replace the previous plot instead of stacking a new one under it
    clear_output(wait=True)

    curves = (
        (G_train_losses, 'Generator Train Loss'),
        (D_train_losses, 'Discriminator Train Loss'),
        (L1_train_losses, 'L1 Train Loss'),
        (G_test_losses, 'Generator Test Loss'),
        (D_test_losses, 'Discriminator Test Loss'),
        (L1_test_losses, 'L1 Test Loss'),
    )

    plt.figure(figsize=(10, 6))
    for values, label in curves:
        plt.plot(values, label=label)

    # The lists hold one entry per epoch (not one per 25-epoch checkpoint),
    # so the axis is labelled accordingly.
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.title('Training and Validation Losses')
    plt.tight_layout()
    plt.show()


In [None]:
def save_checkpoint(val_loader, gen, epoch, device):
    """
    Save a comparison grid (input / generated / target) for the first validation batch.

    The grid is written to ``output_image_epoch{epoch}.png`` in the working directory.
    Up to four samples are shown, one per column.

    Args:
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        gen (nn.Module): Generator used to produce the middle row.
        epoch (int): Epoch number, used only in the output file name.
        device: Device the batch is moved to before running the generator.
    """
    # Remember the current mode so it can be restored instead of forcing train mode.
    was_training = gen.training
    gen.eval()
    with torch.no_grad():
        inputs, targets = next(iter(val_loader))
        inputs = inputs.to(device)
        targets = targets.to(device)
        generated = gen(inputs)
    gen.train(was_training)

    def to_displayable(batch, idx):
        # Images are in [-1, 1] (Normalize(0.5, 0.5) on the data, Tanh on the generator);
        # imshow expects floats in [0, 1], so rescale, clamp, and move channels last.
        image = (batch[idx].detach().cpu() + 1) / 2
        return image.clamp(0, 1).permute(1, 2, 0).numpy()

    rows = (("Input", inputs), ("Generated", generated), ("Target", targets))
    n_cols = min(4, inputs.size(0))  # the batch may hold fewer than 4 samples

    fig, axs = plt.subplots(len(rows), n_cols, figsize=(3 * n_cols, 9), squeeze=False)
    for row_idx, (row_label, batch) in enumerate(rows):
        for col_idx in range(n_cols):
            ax = axs[row_idx][col_idx]
            ax.imshow(to_displayable(batch, col_idx))
            # Hide ticks and frame but keep the axis itself on: ax.axis('off')
            # would also hide the row labels set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        axs[row_idx][0].set_ylabel(row_label)

    fig.tight_layout()
    fig.savefig(f"output_image_epoch{epoch}.png")  # Save comparison as image
    # The figure is intentionally left open so a following plt.show() can still render it
    # in the notebook; NOTE(review): call plt.close(fig) if this is reused in a plain script.


In [None]:
# Loss histories: one value is appended to each list per epoch
G_test_losses = []
D_test_losses = []
L1_test_losses = []
G_train_losses = []
D_train_losses = []
L1_train_losses = []

epochs = 200  # Total number of training epochs

for epoch in tqdm(range(epochs)):
    # Running sums of the per-batch losses for this epoch
    gen_test_loss, dis_test_loss, l1_test_loss = 0.0, 0.0, 0.0
    gen_train_loss, dis_train_loss, l1_train_loss = 0.0, 0.0, 0.0

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        # ---- Discriminator: disc_iter updates per generator update ----
        for _ in range(disc_iter):
            # Fake images are produced without building a graph through the generator
            with torch.no_grad():
                y_fake = gen(x)

            d_real = disc(x, y)
            d_fake = disc(x, y_fake)

            # WGAN-GP regularization term
            gp = gradient_penalty(disc, y, y_fake, x, device)

            # Wasserstein critic loss plus weighted gradient penalty
            d_loss = -(torch.mean(d_real) - torch.mean(d_fake)) + lambda_gp * gp

            optimizer_disc.zero_grad()
            d_loss.backward()
            optimizer_disc.step()

            dis_train_loss += d_loss.item()

        # ---- Generator: one update ----
        y_fake = gen(x)  # this time with gradients flowing into the generator
        d_fake = disc(x, y_fake)

        # Adversarial term: BCE-with-logits against 'real' labels.
        # NOTE(review): the discriminator above is trained with a Wasserstein loss while
        # this term is a logistic loss on the same raw scores -- confirm the mix is intended.
        gen_fakeloss = criterion(d_fake, torch.ones_like(d_fake))

        # L1 term (weight 100) pulls the output towards the ground-truth image
        L1_loss = l1(y_fake, y) * 100

        gen_total_loss = gen_fakeloss + L1_loss

        optimizer_gen.zero_grad()
        gen_total_loss.backward()
        optimizer_gen.step()

        gen_train_loss += gen_total_loss.item()
        l1_train_loss += L1_loss.item()

    # Epoch averages. The discriminator loss is accumulated disc_iter times per batch,
    # so it is averaged over batches * disc_iter (previously it was inflated by disc_iter).
    G_train_losses.append(gen_train_loss / len(train_loader))
    D_train_losses.append(dis_train_loss / (len(train_loader) * max(1, disc_iter)))
    L1_train_losses.append(l1_train_loss / len(train_loader))

    # ---- Validation ----
    gen.eval()
    disc.eval()
    with torch.inference_mode():
        for x, y in val_loader:
            x = x.to(device)
            y = y.to(device)

            y_fake = gen(x)

            # One discriminator pass per input is enough: in eval + inference mode the
            # scores are deterministic and no graph is recorded, so no detach is needed.
            d_real = disc(x, y)
            d_fake = disc(x, y_fake)

            # Discriminator validation loss (BCE-based).
            # NOTE(review): this is on a different scale than the Wasserstein training
            # loss plotted next to it; compare the two curves with care.
            d_realloss = criterion(d_real, torch.ones_like(d_real))
            d_fakeloss = criterion(d_fake, torch.zeros_like(d_fake))
            d_loss = (d_realloss + d_fakeloss) / 2
            dis_test_loss += d_loss.item()

            # Generator validation loss: same formula as in training
            gen_fakeloss = criterion(d_fake, torch.ones_like(d_fake))
            L1_loss = l1(y_fake, y) * 100
            gen_total_loss = gen_fakeloss + L1_loss

            gen_test_loss += gen_total_loss.item()
            l1_test_loss += L1_loss.item()

    gen.train()
    disc.train()

    # Average validation losses over the validation batches
    G_test_losses.append(gen_test_loss / len(val_loader))
    D_test_losses.append(dis_test_loss / len(val_loader))
    L1_test_losses.append(l1_test_loss / len(val_loader))

    # Save sample images and refresh the loss plot every 25 epochs, and also after the
    # last epoch so the final state of the model is always captured.
    if epoch % 25 == 0 or epoch == epochs - 1:
        save_checkpoint(val_loader, gen, epoch, device)
        plot_losses(G_train_losses, D_train_losses, L1_train_losses,
                    G_test_losses, D_test_losses, L1_test_losses)


In [None]:
# Persist the learned parameters (state_dict only) of the generator, then the discriminator.
weight_files = (
    (gen, "Generator_weights.pth"),
    (disc, "Discriminator_weights.pth"),
)
for network, filename in weight_files:
    torch.save(network.state_dict(), filename)


In [None]:
weights_path = "/kaggle/input/generator-weights/Generator_weights.pth"

# Map the stored tensors onto the GPU when CUDA is available, otherwise onto the CPU.
load_location = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
generator_state = torch.load(weights_path, map_location=load_location)
gen.load_state_dict(generator_state)

gen.eval()


In [19]:
# Folders that receive the real and generated images consumed by the FID computation
real_dir, gen_dir = "real_images", "gen_images"
for folder in (real_dir, gen_dir):
    os.makedirs(folder, exist_ok=True)  # existing folders (and their contents) are kept

# Put the generator in evaluation mode before producing predictions
gen.eval()

# Per-image metric accumulators
ssim_scores = []
psnr_scores = []
gram_diffs = []

def gram_matrix(tensor):
    """
    Per-sample Gram matrix of the channel features.

    Args:
        tensor (Tensor): 4-D tensor of shape (b, c, h, w).

    Returns:
        Tensor: Shape (b, c, c); channel-by-channel inner products over all
        pixels, divided by c * h * w.
    """
    batch, channels, height, width = tensor.size()
    flat = tensor.view(batch, channels, height * width)   # (b, c, h*w)
    gram = torch.bmm(flat, flat.permute(0, 2, 1))          # (b, c, c)
    normalizer = channels * height * width
    return gram / normalizer

with torch.no_grad():
    for i, (input_img, target_img) in enumerate(tqdm(val_loader)):
        input_img = input_img.to(device)       # [B, 3, H, W]
        target_img = target_img.to(device)     # [B, 3, H, W]

        pred_img = gen(input_img)              # [B, 3, H, W]

        # Move each batch to the CPU once instead of once per sample
        pred_batch = pred_img.cpu()
        real_batch = target_img.cpu()

        for b in range(pred_batch.size(0)):
            pred = pred_batch[b]
            real = real_batch[b]

            # Running index of this sample across the whole validation set
            sample_idx = i * val_loader.batch_size + b

            # Save for FID (denormalized from [-1, 1] to [0, 1])
            save_image((pred + 1) / 2, f"{gen_dir}/{sample_idx}.png")
            save_image((real + 1) / 2, f"{real_dir}/{sample_idx}.png")

            # Convert for metrics: denormalize, CHW -> HWC, clip to [0, 1] to avoid nan issues
            pred_np = np.clip(((pred.numpy() + 1) / 2).transpose(1, 2, 0), 0, 1)
            real_np = np.clip(((real.numpy() + 1) / 2).transpose(1, 2, 0), 0, 1)

            # SSIM & PSNR: both are given the data range explicitly so neither metric
            # depends on a range inferred from the array dtype or contents.
            ssim_scores.append(ssim(pred_np, real_np, data_range=1.0, channel_axis=-1))
            psnr_scores.append(psnr(real_np, pred_np, data_range=1.0))

            # Gram Matrix Distance (computed on the normalized [-1, 1] tensors)
            pred_gram = gram_matrix(pred.unsqueeze(0))
            real_gram = gram_matrix(real.unsqueeze(0))
            gram_diffs.append(F.mse_loss(pred_gram, real_gram).item())

# Compute FID between the folder of generated images and the folder of real images
fid_kwargs = {
    "input1": gen_dir,
    "input2": real_dir,
    "fid": True,
    # NOTE(review): `metrics` does not appear to be a documented calculate_metrics option;
    # `fid=True` looks like the switch that requests FID -- confirm before removing.
    "metrics": ['fid'],
    "verbose": False,
}
fid_result = calculate_metrics(**fid_kwargs)

# Report the averaged per-image metrics and the dataset-level FID
report_lines = [
    "\n--- Evaluation Metrics ---",
    f"Avg SSIM: {np.mean(ssim_scores):.4f}",
    f"Avg PSNR: {np.mean(psnr_scores):.2f} dB",
    f"Avg Gram Matrix Distance: {np.mean(gram_diffs):.6f}",
    f"FID: {fid_result['frechet_inception_distance']:.2f}",
]
for report_line in report_lines:
    print(report_line)


  0%|          | 0/35 [00:00<?, ?it/s]

Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 350MB/s]
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)



--- Evaluation Metrics ---
Avg SSIM: 0.7383
Avg PSNR: 26.93 dB
Avg Gram Matrix Distance: 0.000185
FID: 286.07
