<a href="https://colab.research.google.com/github/apester/IME/blob/main/Lab12a_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision pillow


In [None]:
import os
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image

# Fix the seed of every RNG this notebook draws from (Python, NumPy, torch),
# in that order, so runs are repeatable.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
#########################
# 1. Synthetic Dataset  #
#########################

class SyntheticStyleDataset(Dataset):
    """
    A synthetic paired dataset for image-to-image style transfer.

    Every item is built from a uniform-noise RGB image whose RNG is seeded by
    the item index, so the same index always yields the same pair:

    * Domain A (input): the noise image with +100 added to the red channel.
    * Domain B (target): the same noise with -50 on red and +50 on blue,
      i.e. a "blue-shifted" styled version of domain A.

    Images are ``image_size`` x ``image_size`` with 3 channels and are returned
    as float tensors normalized to [-1, 1].

    Args:
        num_samples: Number of items reported by ``__len__``.
        image_size: Height and width of each square image.
    """

    def __init__(self, num_samples=100, image_size=64):
        super().__init__()
        self.num_samples = num_samples
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # HWC uint8 in [0, 255] -> CHW float in [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [0, 1] -> [-1, 1]
        ])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``{"A": tensor, "B": tensor}`` for sample ``idx``."""
        # Use a private, index-seeded RNG rather than np.random.seed(idx):
        # reseeding the global NumPy RNG here would reset the notebook-wide
        # seed every time an item is fetched. A legacy RandomState seeded with
        # idx produces exactly the same stream as the global seed would.
        rng = np.random.RandomState(idx)
        noise = (rng.rand(self.image_size, self.image_size, 3) * 255.0).astype(np.uint8)
        base = noise.astype(np.float32)

        # Domain A: add a red tint
        img_a = base.copy()
        img_a[..., 0] = np.clip(img_a[..., 0] + 100, 0, 255)  # increase red channel

        # Domain B: simulate style transfer by shifting red to blue
        img_b = base.copy()
        img_b[..., 0] = np.clip(img_b[..., 0] - 50, 0, 255)
        img_b[..., 2] = np.clip(img_b[..., 2] + 50, 0, 255)

        # Back to uint8 PIL images so the ToTensor/Normalize pipeline applies.
        return {
            "A": self.transform(Image.fromarray(img_a.astype(np.uint8))),
            "B": self.transform(Image.fromarray(img_b.astype(np.uint8))),
        }

# Build the paired synthetic dataset and a shuffling loader over it.
batch_size = 16
dataset = SyntheticStyleDataset(image_size=64, num_samples=200)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # load samples in the main process
)


In [None]:
##################################
# 2. Define Generator & Discriminator
##################################

def weights_init(m):
    """
    DCGAN-style weight initializer, intended for use with ``nn.Module.apply``.

    Dispatches on the module's class name:
      * names containing "Conv" (e.g. Conv2d, ConvTranspose2d): weights are
        drawn from N(0, 0.02); biases are left untouched.
      * names containing "BatchNorm": weights are drawn from N(1, 0.02) and
        biases are set to 0.
    Any other module is left unchanged.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias.data)

class Generator(nn.Module):
    """
    Encoder-decoder generator used as the pix2pix-style image translator.

    Three stride-2 convolutions each halve the spatial size. Three stride-2
    transposed convolutions then restore it. The encoder and decoder are
    chained directly, with no skip connections. A final Tanh keeps outputs
    in [-1, 1].

    Args:
        input_nc: Channels of the input image.
        output_nc: Channels of the generated image.
        n_features: Width of the first encoder stage; deeper stages use 2x and 4x.
    """

    def __init__(self, input_nc=3, output_nc=3, n_features=64):
        super().__init__()
        nf = n_features

        # Encoder: each stage halves height and width.
        self.down1 = self._encoder_stage(input_nc, nf, batch_norm=False)
        self.down2 = self._encoder_stage(nf, nf * 2)
        self.down3 = self._encoder_stage(nf * 2, nf * 4)

        # Decoder: each stage doubles height and width.
        self.up1 = self._decoder_stage(nf * 4, nf * 2)
        self.up2 = self._decoder_stage(nf * 2, nf)
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(nf, output_nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _encoder_stage(in_ch, out_ch, batch_norm=True):
        """Conv(k=4, s=2, p=1) -> optional BatchNorm -> LeakyReLU(0.2)."""
        layers = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    @staticmethod
    def _decoder_stage(in_ch, out_ch):
        """ConvTranspose(k=4, s=2, p=1) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        for stage in (self.down1, self.down2, self.down3,
                      self.up1, self.up2, self.up3):
            x = stage(x)
        return x

class Discriminator(nn.Module):
    """
    PatchGAN-style conditional discriminator.

    The conditioning image and the candidate output are concatenated along
    the channel axis. The result goes through three stride-2 convolutions and
    a final stride-1 convolution, and a Sigmoid turns it into a one-channel
    map of per-patch "real" probabilities. For 64x64 inputs the map is 7x7.

    Args:
        input_nc: Channels after concatenation (input + target); 3 + 3 by default.
        n_features: Width of the first stage; deeper stages use 2x and 4x.
    """

    def __init__(self, input_nc=6, n_features=64):
        super().__init__()
        nf = n_features

        def stage(in_ch, out_ch, normalize):
            # Conv(k=4, s=2, p=1) -> optional BatchNorm -> LeakyReLU(0.2)
            mods = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
            if normalize:
                mods.append(nn.BatchNorm2d(out_ch))
            mods.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*mods)

        self.layer1 = stage(input_nc, nf, normalize=False)
        self.layer2 = stage(nf, nf * 2, normalize=True)
        self.layer3 = stage(nf * 2, nf * 4, normalize=True)
        # Stride-1 head: a single output channel holding one score per patch.
        self.layer4 = nn.Sequential(
            nn.Conv2d(nf * 4, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, input_img, target_img):
        # Condition on the input by stacking it with the candidate image.
        x = torch.cat((input_img, target_img), dim=1)
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = layer(x)
        return x

# Build both networks on the target device. Tuple assignment evaluates left to
# right, so G is constructed before D.
netG, netD = Generator().to(device), Discriminator().to(device)

# Re-initialize weights DCGAN-style; apply() visits every submodule.
netG.apply(weights_init)
netD.apply(weights_init)

In [None]:
##############################
# 3. Losses, Optimizers, etc.
##############################

# Hyperparameters
lr = 0.0002
beta1 = 0.5
n_epochs = 5
lambda_L1 = 100  # weight of the L1 reconstruction term in the generator loss

# Adversarial loss on the discriminator's Sigmoid probabilities, plus an L1
# term pulling the generated image toward the target.
criterion_GAN = nn.BCELoss()
criterion_L1 = nn.L1Loss()

# Both networks share the same Adam settings.
adam_betas = (beta1, 0.999)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)


In [None]:
##########################################
# 4. Training Loop for Pix2Pix Style GAN
##########################################

print("Starting Training Loop...")
for epoch in range(n_epochs):
    epoch_start = time.time()
    for i, batch in enumerate(dataloader):
        # Get data and move to device
        real_A = batch["A"].to(device)  # input image (red-tinted)
        real_B = batch["B"].to(device)  # target image (blue-tinted)

        ############################
        # Update Generator: G
        ############################
        optimizerG.zero_grad()

        # Generate fake images from input A
        fake_B = netG(real_A)
        # Score the fake pair without detaching, so gradients flow back into G through D.
        pred_fake = netD(real_A, fake_B)

        # Real (1) / fake (0) label maps shaped exactly like the discriminator's
        # patch output. nn.BCELoss requires input and target shapes to match,
        # and this PatchGAN emits a 7x7 map for 64x64 images (64->32->16->8->7).
        # The labels are therefore derived from the prediction rather than
        # built with a fixed spatial size; this also handles short last batches.
        real_label = torch.ones_like(pred_fake)
        fake_label = torch.zeros_like(pred_fake)

        loss_GAN = criterion_GAN(pred_fake, real_label)
        # L1 loss between the generated image and the real image B
        loss_L1 = criterion_L1(fake_B, real_B)
        # Total generator loss
        loss_G = loss_GAN + lambda_L1 * loss_L1

        loss_G.backward()
        optimizerG.step()

        ############################
        # Update Discriminator: D
        ############################
        # Clears the gradients the generator pass left on D's parameters.
        optimizerD.zero_grad()
        # Compute loss with real images
        pred_real = netD(real_A, real_B)
        loss_D_real = criterion_GAN(pred_real, real_label)
        # Compute loss with fake images (detach so gradients don't flow to G)
        pred_fake = netD(real_A, fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake_label)
        # Total discriminator loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5

        loss_D.backward()
        optimizerD.step()

        # Print statistics every 10 iterations
        if i % 10 == 0:
            print(f"Epoch [{epoch+1}/{n_epochs}] Batch [{i}/{len(dataloader)}] "
                  f"Loss_G: {loss_G.item():.4f} Loss_D: {loss_D.item():.4f}")

    epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch+1} completed in {epoch_time:.2f} sec.")


In [None]:
#############################
# 5. Save Outputs and Model #
#############################

os.makedirs("output", exist_ok=True)
# Save a batch of input, output and target images for visual comparison
sample_batch = next(iter(dataloader))
real_A_sample = sample_batch["A"].to(device)
real_B_sample = sample_batch["B"].to(device)
# NOTE(review): netG is still in train mode here (nothing calls .eval()), so
# its BatchNorm layers normalize with this batch's statistics and also update
# their running stats before the checkpoint below is written. Confirm this is
# intended.
with torch.no_grad():
    fake_B_sample = netG(real_A_sample)

# Stack input (A), generated (fake B) and target (B) along the batch dim. With
# nrow equal to the number of samples in this batch, each grid row holds one
# domain. The actual batch length is used rather than the configured
# batch_size so rows stay aligned even if the batch is short.
n_samples = real_A_sample.size(0)
combined = torch.cat([real_A_sample, fake_B_sample, real_B_sample], dim=0)
vutils.save_image(combined, "output/sample_comparison.png", nrow=n_samples, normalize=True)
print("Saved sample image comparison to output/sample_comparison.png")

# Save the generator model checkpoint
torch.save(netG.state_dict(), "output/netG.pth")
print("Saved generator model weights to output/netG.pth")
