In [None]:
# TODO
# Save/Load Checkpoint
# Save/Load Model
# Look into training optimization
# - GradScaler
# - NumWorkers during loading

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os

from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

# Discriminator Architecture (Pix2Pix)

The `Discriminator` class implements a **PatchGAN-based** discriminator, commonly used in the **Pix2Pix** model for image-to-image translation. It classifies whether each patch in an image is real or fake.

## Architecture Overview

- The discriminator takes as input a **pair of images** `(x, y)`, where:
  - `x` is the input image.
  - `y` is the target/generated image.
  - These two images are concatenated along the channel dimension (`dim=1`).
  
- It consists of:
  1. **An Initial Convolution Layer**:
     - Uses a `Conv2d` layer with:
       - Input channels: `in_channels * 2` (since both `x` and `y` are concatenated).
       - Output channels: `features[0]` (typically `64`).
       - Kernel size: `4x4`.
       - Stride: `2` (reduces spatial dimensions).
       - No Batch Normalization (as in the Pix2Pix paper).
       - Activation: `LeakyReLU(0.2)`.
  
  2. **A Series of CNN Blocks (`CNNBlock`)**:
     - Each block consists of:
       - A `Conv2d` layer with **Batch Normalization** and **LeakyReLU(0.2)**.
       - Stride is `2` for all layers **except** the last one.
       - Channels progress as defined in `features=[64, 128, 256, 512]`.
  
  3. **Final Convolution Layer**:
     - A single **`Conv2d`** layer with:
       - Output channels: `1` (the discriminator outputs a single-channel map of raw per-patch scores).
       - Kernel size: `4x4`.
       - Stride: `1`.
       - Padding: `1`.
       - No activation function (raw logits output).

## Notes

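- The discriminator **does not output a single value** but rather a **feature map** of logits, one per receptive-field patch. Applying a sigmoid to a value gives the probability that the corresponding patch is real; during training `nn.BCEWithLogitsLoss` applies that sigmoid internally. With the default `features` and a `256x256` input, the map is `30x30`.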
- The `padding_mode="reflect"` is used to **reduce artifacts**.

In [29]:
class CNNBlock(nn.Module):
    """Convolution block used by the PatchGAN discriminator.

    Applies ``Conv2d`` (reflect padding, no bias) -> ``BatchNorm2d`` -> ``LeakyReLU(0.2)``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size (default 4).
        stride: Convolution stride (default 2, which roughly halves the spatial size).
        padding: Padding added on each side (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        super().__init__()

        self.conv = nn.Sequential(
            # bias=False: the following BatchNorm has its own learned shift
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False, padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    """PatchGAN discriminator conditioned on the input image (Pix2Pix).

    Args:
        in_channels: Channels of each of ``x`` and ``y``. The network sees their
            channel-wise concatenation, i.e. ``2 * in_channels`` channels.
        features: Output channels of successive layers. ``features[0]`` is used by
            the initial layer (no BatchNorm); every following entry adds a
            ``CNNBlock`` with stride 2, except the last one, which uses stride 1.

    ``forward(x, y)`` returns a 1-channel map of raw logits (no sigmoid applied).
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()

        self.initial = nn.Sequential(
            # Channels are doubled because the discriminator takes the (x, y) pair as input.
            # The initial layer has no batch normalization, hence it is built separately.
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        convolution_layers = []
        channels = features[0]
        last_idx = len(features) - 1

        for idx, feature in enumerate(features[1:], start=1):
            # Pick the stride by position, not by value: comparing against
            # features[-1] would give stride 1 to every layer sharing the last width.
            convolution_layers.append(
                CNNBlock(channels, feature, stride=1 if idx == last_idx else 2)
            )
            channels = feature

        # Projection to one logit per patch; no activation (BCEWithLogitsLoss expects logits)
        convolution_layers.append(
            nn.Conv2d(channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )

        self.model = nn.Sequential(*convolution_layers)

    def forward(self, x, y):
        # Condition on the input image: stack x and y along the channel dimension
        xy = torch.cat([x, y], dim=1)
        return self.model(self.initial(xy))

# Generator Architecture (UNet-based)

The `Generator` class implements a **UNet-based** architecture, commonly used in **Pix2Pix** for image-to-image translation. It follows an **encoder-decoder** structure with **skip connections** to retain spatial information lost during downsampling.

## Architecture Overview

The generator consists of the following components:

1. **Initial Downsampling Block**:
   - A `Conv2d` layer with:
     - Input channels: `in_channels` (default: `3` for RGB).
     - Output channels: `features` (default: `64`).
     - Kernel size: `4x4`, stride `2`, padding `1`.
     - `LeakyReLU(0.2)` activation.
   - No batch normalization is applied in this layer.

2. **Series of Encoder Blocks**:
   - Each encoder block (`UNetBlock`) performs **downsampling** using:
     - A `Conv2d` layer with `stride=2` to reduce spatial dimensions.
     - Batch Normalization.
     - `LeakyReLU(0.2)` activation.
   - Skip connections store intermediate outputs for later use in the decoder.

3. **Bottleneck Layer**:
   - A `Conv2d` layer with:
     - Kernel size: `4x4`, stride `2`, padding `1`.
     - `ReLU` activation (no batch normalization).
   - A **bottleneck upsampling** block (a `UNetBlock` with `down_sample=False`) follows, performing:
     - Transposed convolution (`ConvTranspose2d`, stride `2`).
     - Batch Normalization.
     - `ReLU` activation.
     - **Dropout (`0.5`)**.

4. **Series of Decoder Blocks**:
   - Each decoder block (`UNetBlock`) performs **upsampling** using:
     - A `ConvTranspose2d` layer with `stride=2` to increase spatial dimensions.
     - Batch Normalization.
     - `ReLU` activation.
     - **Dropout** applied for the first `decoder_dropout_range` layers.
   - Decoder blocks take input from both:
     - The previous decoder layer.
     - The **corresponding encoder output** (via **skip connections**).

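5. **Final Upsampling Block**:
   - Takes the last decoder output concatenated with the initial layer's output (the top-level skip connection).
   - A `ConvTranspose2d` layer with:
     - Output channels: `in_channels` (the generated image has the same number of channels as the input image).
   - A `Tanh` activation that squashes output values into the range `[-1, 1]`.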

## Notes

- **Skip Connections** help retain high-resolution details lost during downsampling.
- The **use of dropout** in certain decoder layers acts as regularization.
- The generator **learns a mapping** from input images (`x`) to target images (`y`), making it suitable for **image-to-image translation**.

In [100]:
class UNetBlock(nn.Module):
    """One U-Net stage: conv -> BatchNorm -> activation [-> Dropout].

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        down_sample: If True, a stride-2 ``Conv2d`` (reflect padding) halves the
            spatial size; otherwise a stride-2 ``ConvTranspose2d`` doubles it.
        activation: ``"relu"`` selects ``ReLU``; any other value selects ``LeakyReLU(0.2)``.
        use_dropout: If True, append ``Dropout(0.5)``.
    """

    def __init__(self, in_channels, out_channels, down_sample=True, activation="relu", use_dropout=False):
        super().__init__()

        if down_sample:
            conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)

        layers = [
            conv,
            # The conv has no bias because BatchNorm's learned shift takes its place
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if activation == "relu" else nn.LeakyReLU(0.2),
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class Generator(nn.Module):
    """U-Net generator for Pix2Pix.

    Args:
        in_channels: Channels of the input image; the output has the same count.
        features: Base channel width; every multiplier below is scaled by it.
        encoder_blocks: Sequence of ``(in_mult, out_mult)`` pairs, one per
            down-sampling ``UNetBlock`` after the initial layer.
        decoder_blocks: Sequence of ``(in_mult, out_mult)`` pairs, one per
            up-sampling ``UNetBlock``. Each decoder receives the previous output
            concatenated with the mirrored encoder output, so its actual input
            width is ``features * in_mult * 2``.
        decoder_dropout_range: Number of leading decoders that use ``Dropout(0.5)``.

    The bottleneck width follows the last encoder's ``out_mult``. The final
    layer's input width follows the last decoder's ``out_mult`` plus the
    initial layer's ``features`` channels. Input height/width must be divisible
    by ``2 ** (len(encoder_blocks) + 2)`` so that skip tensors line up.

    Raises:
        ValueError: if ``encoder_blocks`` and ``decoder_blocks`` differ in
            length, since each encoder output feeds exactly one decoder.
    """

    def __init__(self, in_channels=3, features=64, encoder_blocks=(), decoder_blocks=(), decoder_dropout_range=0):
        super().__init__()
        encoder_blocks = list(encoder_blocks)
        decoder_blocks = list(decoder_blocks)
        if len(encoder_blocks) != len(decoder_blocks):
            raise ValueError(
                "Generator needs one decoder per encoder for the skip connections, "
                f"got {len(encoder_blocks)} encoder(s) and {len(decoder_blocks)} decoder(s)"
            )

        # Initial downward block with no batch normalization
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        # Series of downsampling encoders
        self.encoders = nn.ModuleList(
            UNetBlock(features * i, features * o, down_sample=True, activation="leaky", use_dropout=False)
            for i, o in encoder_blocks
        )

        # Bottleneck width = output width of the deepest encoder (or of the initial layer)
        bottleneck_mult = encoder_blocks[-1][1] if encoder_blocks else 1
        bottleneck_channels = features * bottleneck_mult

        # NOTE(review): bias=False with no following BatchNorm drops a learnable offset;
        # kept as-is so existing checkpoints still load.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(bottleneck_channels, bottleneck_channels, 4, 2, 1, bias=False),
            nn.ReLU(),
        )

        self.bottleneck_up = nn.Sequential(
            UNetBlock(bottleneck_channels, bottleneck_channels, down_sample=False, activation="relu", use_dropout=True)
        )

        # Series of upsampling decoders; in_mult is doubled for the skip concatenation
        self.decoders = nn.ModuleList(
            UNetBlock(
                features * i * 2,
                features * o,
                down_sample=False,
                activation="relu",
                use_dropout=idx < decoder_dropout_range,
            )
            for idx, (i, o) in enumerate(decoder_blocks)
        )

        # Final upsampling block: last decoder output + initial-layer skip (features channels)
        last_up_mult = decoder_blocks[-1][1] if decoder_blocks else bottleneck_mult
        self.final = nn.Sequential(
            nn.ConvTranspose2d(features * (last_up_mult + 1), in_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        d = self.initial(x)
        skip_connections = [d]

        for encoder in self.encoders:
            d = encoder(d)
            skip_connections.append(d)

        u = self.bottleneck_up(self.bottleneck(d))

        # Deepest encoder output pairs with the first decoder; the initial layer's
        # output (skip_connections[0]) is reserved for the final layer.
        for decoder, skip in zip(self.decoders, reversed(skip_connections[1:])):
            u = decoder(torch.cat([u, skip], dim=1))

        # Connect top level encoder with decoder
        return self.final(torch.cat([u, skip_connections[0]], dim=1))

In [124]:
def test():
    """Smoke-test the notebook's Generator configuration on one random 256x256 image.

    Asserts the output has the input's shape and lies in tanh's [-1, 1] range,
    then prints the shape.
    """
    x = torch.randn((1, 3, 256, 256))
    model = Generator(
        in_channels=3,
        features=64,
        encoder_blocks=[(1, 2), (2, 4), (4, 8), (8, 8), (8, 8), (8, 8)],
        decoder_blocks=[(8, 8), (8, 8), (8, 8), (8, 4), (4, 2), (2, 1)],
        decoder_dropout_range=2,
    )
    # Shape check only: no gradients needed
    with torch.no_grad():
        preds = model(x)
    assert preds.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(preds.shape)}"
    assert preds.abs().max().item() <= 1.0, "generator output left tanh's [-1, 1] range"
    print(preds.shape)

test()

torch.Size([1, 3, 256, 256])


In [131]:
class MapDataset(Dataset):
    """Paired-image dataset where each file holds input and target side by side.

    Pixel columns ``[:split_at]`` are the input ``x`` and ``[split_at:]`` the
    target ``y``. Both halves are resized to ``img_size`` x ``img_size`` and
    scaled from ToTensor's ``[0, 1]`` to ``[-1, 1]``, the range of the
    generator's ``tanh`` output (``save_some_examples`` undoes it with ``* 0.5 + 0.5``).

    Args:
        root_dir: Directory containing the paired image files.
        img_size: Side length both halves are resized to (default 256).
        split_at: Column where the image is split (default 600; presumably half
            the width of the 1200-pixel-wide map images — TODO confirm for other data).
    """

    def __init__(self, root_dir, img_size=256, split_at=600):
        self.root_dir = root_dir
        # Sorted so indices are stable across runs and platforms (os.listdir order
        # is arbitrary); sub-directories and hidden files (e.g. .DS_Store) are skipped.
        self.file_lists = sorted(
            name
            for name in os.listdir(self.root_dir)
            if not name.startswith(".") and os.path.isfile(os.path.join(self.root_dir, name))
        )
        self.split_at = split_at
        self.transformer = transforms.ToTensor()
        # Built once here instead of on every __getitem__ call
        self.resize_transform = transforms.Resize((img_size, img_size))

    def __len__(self):
        return len(self.file_lists)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.file_lists[index])
        # The context manager closes the file handle; convert("RGB") guarantees
        # 3 channels even for grayscale/RGBA/palette files.
        with Image.open(img_path) as img:
            img_tensor = self.transformer(img.convert("RGB"))

        # ToTensor yields [0, 1]; map to [-1, 1] to match the generator's tanh output
        img_tensor = img_tensor * 2.0 - 1.0

        # Split the image into the input (left) and target (right) halves
        img_x = img_tensor[:, :, :self.split_at]
        img_y = img_tensor[:, :, self.split_at:]

        # augmentation can be performed here if needed

        return self.resize_transform(img_x), self.resize_transform(img_y)

In [135]:
def save_some_examples(gen, val_loader, epoch, folder):
    """Save the generator's output for the first validation batch as image grids.

    Writes ``y_gen_{epoch}.png`` (generated), ``input_{epoch}.png`` and
    ``label_{epoch}.png`` (ground truth) into ``folder``, creating it if needed.
    Tensors are mapped from ``[-1, 1]`` back to ``[0, 1]`` before saving.

    Args:
        gen: Generator; switched to eval mode for inference, and its previous
            train/eval mode is restored afterwards.
        val_loader: Loader yielding ``(x, y)`` batches (assumed scaled to [-1, 1]).
        epoch: Number used in the output file names.
        folder: Output directory.
    """
    # Use the generator's own device rather than a global defined in a later cell
    model_device = next(gen.parameters()).device
    x, y = next(iter(val_loader))
    x, y = x.to(model_device), y.to(model_device)
    os.makedirs(folder, exist_ok=True)

    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
        # Undo the [-1, 1] scaling: t * 0.5 + 0.5 maps it back to [0, 1]
        save_image(y_fake * 0.5 + 0.5, os.path.join(folder, f"y_gen_{epoch}.png"))
        save_image(x * 0.5 + 0.5, os.path.join(folder, f"input_{epoch}.png"))
        # Saved on every call: if the loader shuffles, each input batch needs its own labels
        save_image(y * 0.5 + 0.5, os.path.join(folder, f"label_{epoch}.png"))
    finally:
        gen.train(was_training)

In [None]:
# ---- Configuration ----
device = "cuda" if torch.cuda.is_available() else "cpu"
train_file_dir = "./dataset/train"
val_file_dir = "./dataset/val"
eval_dir = "evaluation"  # where save_some_examples writes sample images
seed = 42
lr = 2e-4
batch_size = 16
img_size = 256  # MapDataset resizes both halves to 256x256
img_channels = 3
l1_lambda = 100
lambda_gp = 10  # leftover from the WGAN-GP experiment; unused with the BCE objective below
num_epochs = 500

# Seed weight initialisation and data shuffling so runs are reproducible
torch.manual_seed(seed)

# ---- Models and optimizers ----
disc = Discriminator(
    in_channels=img_channels,
    features=[64, 128, 256, 512],
).to(device)
gen = Generator(
    in_channels=img_channels,
    features=64,
    encoder_blocks=[(1, 2), (2, 4), (4, 8), (8, 8), (8, 8), (8, 8)],
    decoder_blocks=[(8, 8), (8, 8), (8, 8), (8, 4), (4, 2), (2, 1)],
    decoder_dropout_range=2,
).to(device)

disc_opt = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
gen_opt = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))

# Wasserstein produced bad result with PatchGAN
BCE_loss = nn.BCEWithLogitsLoss()
L1_loss = nn.L1Loss()

# ---- Data ----
train_dataset = MapDataset(root_dir=train_file_dir)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = MapDataset(root_dir=val_file_dir)
# No shuffling: save_some_examples uses the first batch, so a fixed order keeps
# the saved samples comparable from one epoch to the next.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

os.makedirs(eval_dir, exist_ok=True)

# ---- Training ----
for epoch in range(num_epochs):
    loop = tqdm(train_loader, leave=True, desc=f"Epoch {epoch + 1}/{num_epochs}")

    for i, (x, y) in enumerate(loop):
        x, y = x.to(device).float(), y.to(device).float()

        # Train Discriminator: real pairs -> 1, fake pairs -> 0.
        # y_fake is detached so this step does not backpropagate into the generator.
        y_fake = gen(x)
        d_real = disc(x, y)
        d_fake = disc(x, y_fake.detach())
        d_real_loss = BCE_loss(d_real, torch.ones_like(d_real))
        d_fake_loss = BCE_loss(d_fake, torch.zeros_like(d_fake))
        d_loss = (d_real_loss + d_fake_loss) / 2

        disc_opt.zero_grad()
        d_loss.backward()
        disc_opt.step()

        # Train Generator: fool the discriminator plus weighted L1 reconstruction.
        # g_loss.backward() also fills disc's .grad; disc_opt.zero_grad() clears
        # it before the next discriminator update.
        d_fake = disc(x, y_fake)
        g_fake_loss = BCE_loss(d_fake, torch.ones_like(d_fake))
        l1 = L1_loss(y_fake, y) * l1_lambda
        g_loss = g_fake_loss + l1

        gen_opt.zero_grad()
        g_loss.backward()
        gen_opt.step()

        # Periodic readout (.item() forces a device sync, so not every step)
        if i % 10 == 0:
            loop.set_postfix(d_loss=d_loss.item(), g_loss=g_loss.item())

    save_some_examples(gen, val_loader, epoch, folder=eval_dir)