In [3]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
import numpy as np

# custom dataloader
from src.dataloader import DIV2KDataModule
# Generator network
from src.generator import Generator
# Discriminator network
from src.discriminator import Discriminator
# module to get VGG features for perceptual loss
from src.vgg_wrapper import VGGFeatureExtractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
# set up dataloader
dataloader = DIV2KDataModule(batch_size=32, num_workers=4)
dataloader.setup()

In [5]:
# set up feature extractor
vgg_extractor = VGGFeatureExtractor().to(device)
vgg_extractor.eval()

VGGFeatureExtractor(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), str

In [None]:
# Initialize Generator and Discriminator
generator = Generator(in_channels=3, out_channels=3, n_residual_blocks=16, upscale_factor=4).to(device)
discriminator = Discriminator(input_shape=(3, 256, 256)).to(device)

print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")

In [None]:
# Define Loss Functions
import torch.optim as optim

# Pixel-wise loss (MSE)
criterion_mse = nn.MSELoss()

# Adversarial loss (Binary Cross Entropy)
criterion_bce = nn.BCEWithLogitsLoss()

# Perceptual loss weight
lambda_pixel = 1e-2  # Weight for pixel-wise MSE loss
lambda_content = 1.0  # Weight for content (VGG) loss
lambda_adversarial = 1e-3  # Weight for adversarial loss

In [None]:
# Optimizers
lr_gen = 1e-4
lr_disc = 1e-4

optimizer_G = optim.Adam(generator.parameters(), lr=lr_gen, betas=(0.9, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.9, 0.999))

In [None]:
# Training Loop
from tqdm import tqdm
import time

def train_srgan(generator, discriminator, vgg_extractor, dataloader, 
                num_epochs=100, pretrain_epochs=10, device='cuda'):
    """
    Train SRGAN with two phases:
    1. Pre-training phase: Train generator with MSE loss only
    2. GAN training phase: Train with full perceptual loss (content + adversarial + pixel)
    """
    
    train_loader = dataloader.train_dataloader()
    val_loader = dataloader.val_dataloader()
    
    # Tracking metrics
    history = {
        'g_loss': [], 'd_loss': [], 'content_loss': [], 
        'adversarial_loss': [], 'pixel_loss': []
    }
    
    print("=" * 60)
    print("Starting Pre-training Phase (Generator with MSE loss only)")
    print("=" * 60)
    
    # Phase 1: Pre-train Generator with MSE loss
    for epoch in range(pretrain_epochs):
        generator.train()
        epoch_mse_loss = 0
        
        pbar = tqdm(train_loader, desc=f'Pre-train Epoch {epoch+1}/{pretrain_epochs}')
        for lr_imgs, hr_imgs in pbar:
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            
            # Generate SR images
            sr_imgs = generator(lr_imgs)
            
            # MSE loss
            loss_mse = criterion_mse(sr_imgs, hr_imgs)
            
            # Update generator
            optimizer_G.zero_grad()
            loss_mse.backward()
            optimizer_G.step()
            
            epoch_mse_loss += loss_mse.item()
            pbar.set_postfix({'MSE Loss': f'{loss_mse.item():.4f}'})
        
        avg_mse_loss = epoch_mse_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{pretrain_epochs} - MSE Loss: {avg_mse_loss:.4f}')
    
    print("\n" + "=" * 60)
    print("Starting GAN Training Phase (Full Perceptual Loss)")
    print("=" * 60)
    
    # Phase 2: Train with full perceptual loss
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        
        epoch_g_loss = 0
        epoch_d_loss = 0
        epoch_content_loss = 0
        epoch_adversarial_loss = 0
        epoch_pixel_loss = 0
        
        pbar = tqdm(train_loader, desc=f'GAN Epoch {epoch+1}/{num_epochs}')
        for lr_imgs, hr_imgs in pbar:
            batch_size = lr_imgs.size(0)
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            
            # Labels for adversarial loss
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)
            
            # ==========================================
            # Train Discriminator
            # ==========================================
            optimizer_D.zero_grad()
            
            # Generate SR images
            sr_imgs = generator(lr_imgs).detach()
            
            # Real images
            real_preds = discriminator(hr_imgs)
            loss_real = criterion_bce(real_preds, real_labels)
            
            # Fake images
            fake_preds = discriminator(sr_imgs)
            loss_fake = criterion_bce(fake_preds, fake_labels)
            
            # Total discriminator loss
            loss_D = (loss_real + loss_fake) / 2
            
            loss_D.backward()
            optimizer_D.step()
            
            # ==========================================
            # Train Generator
            # ==========================================
            optimizer_G.zero_grad()
            
            # Generate SR images
            sr_imgs = generator(lr_imgs)
            
            # 1. Adversarial loss (fool the discriminator)
            sr_preds = discriminator(sr_imgs)
            loss_adv = criterion_bce(sr_preds, real_labels)
            
            # 2. Content loss (VGG perceptual loss)
            sr_features = vgg_extractor(sr_imgs)
            hr_features = vgg_extractor(hr_imgs).detach()
            loss_content = criterion_mse(sr_features, hr_features)
            
            # 3. Pixel-wise MSE loss
            loss_pixel = criterion_mse(sr_imgs, hr_imgs)
            
            # Total generator loss (perceptual loss)
            loss_G = (lambda_content * loss_content + 
                     lambda_adversarial * loss_adv + 
                     lambda_pixel * loss_pixel)
            
            loss_G.backward()
            optimizer_G.step()
            
            # Track losses
            epoch_g_loss += loss_G.item()
            epoch_d_loss += loss_D.item()
            epoch_content_loss += loss_content.item()
            epoch_adversarial_loss += loss_adv.item()
            epoch_pixel_loss += loss_pixel.item()
            
            pbar.set_postfix({
                'G_loss': f'{loss_G.item():.4f}',
                'D_loss': f'{loss_D.item():.4f}',
                'Content': f'{loss_content.item():.4f}'
            })
        
        # Average losses for the epoch
        avg_g_loss = epoch_g_loss / len(train_loader)
        avg_d_loss = epoch_d_loss / len(train_loader)
        avg_content_loss = epoch_content_loss / len(train_loader)
        avg_adversarial_loss = epoch_adversarial_loss / len(train_loader)
        avg_pixel_loss = epoch_pixel_loss / len(train_loader)
        
        # Store history
        history['g_loss'].append(avg_g_loss)
        history['d_loss'].append(avg_d_loss)
        history['content_loss'].append(avg_content_loss)
        history['adversarial_loss'].append(avg_adversarial_loss)
        history['pixel_loss'].append(avg_pixel_loss)
        
        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'  G_loss: {avg_g_loss:.4f} | D_loss: {avg_d_loss:.4f}')
        print(f'  Content: {avg_content_loss:.4f} | Adversarial: {avg_adversarial_loss:.4f} | Pixel: {avg_pixel_loss:.4f}')
        
        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'history': history
            }, f'srgan_checkpoint_epoch_{epoch+1}.pth')
            print(f'Checkpoint saved: srgan_checkpoint_epoch_{epoch+1}.pth')
    
    return history

In [None]:
# Start Training
# Adjust num_epochs and pretrain_epochs based on your needs
history = train_srgan(
    generator=generator,
    discriminator=discriminator,
    vgg_extractor=vgg_extractor,
    dataloader=dataloader,
    num_epochs=50,  # Number of GAN training epochs
    pretrain_epochs=5,  # Number of pre-training epochs with MSE only
    device=device
)

In [None]:
# Plot Training Losses
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Generator and Discriminator Loss
axes[0, 0].plot(history['g_loss'], label='Generator Loss', color='blue')
axes[0, 0].plot(history['d_loss'], label='Discriminator Loss', color='red')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Generator vs Discriminator Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Content Loss (VGG Perceptual Loss)
axes[0, 1].plot(history['content_loss'], label='Content Loss', color='green')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].set_title('Content (VGG) Loss')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Adversarial Loss
axes[1, 0].plot(history['adversarial_loss'], label='Adversarial Loss', color='orange')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].set_title('Adversarial Loss')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Pixel Loss
axes[1, 1].plot(history['pixel_loss'], label='Pixel (MSE) Loss', color='purple')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Loss')
axes[1, 1].set_title('Pixel-wise MSE Loss')
axes[1, 1].legend()
axes[1, 1].grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Visualize Results
def visualize_results(generator, dataloader, device, num_samples=4):
    """Visualize LR, SR, and HR images side by side"""
    generator.eval()
    val_loader = dataloader.val_dataloader()
    
    # Get a batch
    lr_imgs, hr_imgs = next(iter(val_loader))
    lr_imgs = lr_imgs[:num_samples].to(device)
    hr_imgs = hr_imgs[:num_samples].to(device)
    
    with torch.no_grad():
        sr_imgs = generator(lr_imgs)
    
    # Convert to numpy for visualization
    lr_imgs = lr_imgs.cpu().numpy().transpose(0, 2, 3, 1)
    sr_imgs = sr_imgs.cpu().numpy().transpose(0, 2, 3, 1)
    hr_imgs = hr_imgs.cpu().numpy().transpose(0, 2, 3, 1)
    
    # Clip values to [0, 1]
    sr_imgs = np.clip(sr_imgs, 0, 1)
    
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))
    
    for i in range(num_samples):
        # Low Resolution
        axes[i, 0].imshow(lr_imgs[i])
        axes[i, 0].set_title('Low Resolution (Input)')
        axes[i, 0].axis('off')
        
        # Super Resolution (Generated)
        axes[i, 1].imshow(sr_imgs[i])
        axes[i, 1].set_title('Super Resolution (Generated)')
        axes[i, 1].axis('off')
        
        # High Resolution (Ground Truth)
        axes[i, 2].imshow(hr_imgs[i])
        axes[i, 2].set_title('High Resolution (Ground Truth)')
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize some results
visualize_results(generator, dataloader, device, num_samples=4)

## SRGAN Training Setup

### Perceptual Loss Components:

1. **Content Loss (VGG)**: MSE between VGG features of SR and HR images
   - Weight: `lambda_content = 1.0`
   - Uses pre-trained VGG19 features from layer 36 (conv5_4)

2. **Adversarial Loss**: BCE loss to fool the discriminator
   - Weight: `lambda_adversarial = 1e-3`
   - Encourages photorealistic images

3. **Pixel Loss (MSE)**: Direct pixel-wise MSE between SR and HR
   - Weight: `lambda_pixel = 1e-2`
   - Helps with initial convergence

### Training Strategy:

**Phase 1 - Pre-training (MSE only):**
- Train generator with pixel-wise MSE loss only
- Helps stabilize initial training
- Default: 5 epochs

**Phase 2 - GAN Training (Full Perceptual Loss):**
- Alternate between discriminator and generator training
- Generator uses combined perceptual loss
- Discriminator learns to distinguish real vs generated images
- Checkpoints saved every 10 epochs

### Hyperparameters:
- Learning rate: 1e-4 (both G and D)
- Optimizer: Adam (β1=0.9, β2=0.999)
- Batch size: 32