# WGAN-GP: Video Frame Generation using Wasserstein GAN with Gradient Penalty

**Authors:** Fernando Campa and Gabriel Vanderklok  
**Course:** CST-435 Deep Learning  
**Date:** 12/11/2025

---

## 1. Problem Statement

Build a **Wasserstein GAN with Gradient Penalty (WGAN-GP)** that generates realistic video frame images at 200x200 resolution. The training set consists of every 10th frame extracted from the input video.

**WGAN-GP advantages over traditional GANs:**
- More stable training dynamics
- A critic loss that tracks sample quality, so it is meaningful to monitor
- Mitigates mode collapse and vanishing gradients

## 2. Algorithm

### 2.1 Wasserstein GAN Theory

WGAN replaces the binary cross-entropy objective of the original GAN with an estimate of the **Wasserstein distance** (Earth Mover's Distance) between the real and generated distributions.

**Traditional GAN Loss:**
$$\min_G \max_D \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

**WGAN Loss:**
$$\min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim p_{data}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))]$$

Where $\mathcal{D}$ is the set of 1-Lipschitz functions.
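
To make the Earth Mover's intuition concrete, the sketch below computes the Wasserstein-1 distance between two small one-dimensional samples. It relies on a standard fact: for equal-sized 1-D samples, the cheapest transport plan pairs the sorted values. `wasserstein_1d` is a hypothetical helper used only for illustration; the training code estimates the distance with the critic instead.

```python
def wasserstein_1d(xs, ys):
    """Wasserstein-1 distance between two equal-sized 1-D samples.

    In one dimension the cheapest way to move mass pairs the i-th smallest
    value of one sample with the i-th smallest value of the other.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal-sized samples")
    pairs = zip(sorted(xs), sorted(ys))
    return sum(abs(x - y) for x, y in pairs) / len(xs)


real = [0.0, 1.0, 2.0]
far_fake = [3.0, 4.0, 5.0]    # generator output far from the data
near_fake = [0.5, 1.5, 2.5]   # generator output close to the data

print(wasserstein_1d(real, far_fake))   # 3.0
print(wasserstein_1d(real, near_fake))  # 0.5
```

The distance shrinks smoothly as the generated sample moves toward the real one, even when the two samples do not overlap. That smoothness is what gives the generator a usable training signal.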

### 2.2 Gradient Penalty

WGAN-GP enforces the Lipschitz constraint softly by adding a **gradient penalty** to the critic loss:

$$\mathcal{L}_{GP} = \lambda \mathbb{E}_{\hat{x} \sim p_{\hat{x}}}[(||\nabla_{\hat{x}} D(\hat{x})||_2 - 1)^2]$$

Where $\hat{x} = \alpha x_{real} + (1 - \alpha) x_{fake}, \quad \alpha \sim U[0,1]$
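
Putting the two pieces together, the critic minimizes the negated Wasserstein objective plus the penalty, and the generator minimizes the negated critic score of its samples. The symbols $\mathcal{L}_{critic}$ and $\mathcal{L}_{G}$ are introduced here for illustration only:

$$\mathcal{L}_{critic} = \mathbb{E}_{z \sim p_z}[D(G(z))] - \mathbb{E}_{x \sim p_{data}}[D(x)] + \mathcal{L}_{GP}$$

$$\mathcal{L}_{G} = -\mathbb{E}_{z \sim p_z}[D(G(z))]$$

These correspond term by term to `c_loss` and `g_loss` in the training loop of Section 5.2, where the factor $\lambda$ inside $\mathcal{L}_{GP}$ appears as `config.LAMBDA_GP`.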

### 2.3 Training Algorithm

```
For each epoch:
    For each batch:
        # Train Critic (5 iterations)
        For k = 1 to 5:
            1. Sample real images
            2. Generate fake images
            3. Compute Wasserstein loss
            4. Compute gradient penalty
            5. Update critic
        
        # Train Generator (1 iteration)
        1. Generate fake images
        2. Compute loss
        3. Update generator
```
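
As a minimal sketch of the update order described above, the generator below yields one label per optimizer step. `update_schedule` is a hypothetical helper, and the batch count of 3 is chosen only to keep the output short.

```python
def update_schedule(num_batches, critic_iterations=5):
    """Yield the order of optimizer steps for one epoch."""
    for batch in range(num_batches):
        # The critic is updated several times on each batch...
        for _ in range(critic_iterations):
            yield ("critic", batch)
        # ...before the generator takes a single step.
        yield ("generator", batch)


steps = list(update_schedule(num_batches=3))
critic_steps = sum(1 for name, _ in steps if name == "critic")
generator_steps = sum(1 for name, _ in steps if name == "generator")
print(critic_steps, generator_steps)  # 15 3
```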

### 2.4 Hyperparameters

| Parameter | Value | Rationale |
|-----------|-------|----------|
| λ (GP) | 10 | WGAN-GP paper standard |
| Critic iterations | 5 | Stable gradients |
| Learning rate | 0.0001 | Stability |
| Adam β1, β2 | 0.0, 0.9 | WGAN recommended |
| Latent dimension | 128 | Image complexity |

## 3. Implementation

### 3.1 Import Libraries

In [None]:
# Core libraries
import os
import sys
import numpy as np
from pathlib import Path

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Image processing
import cv2
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image, make_grid

# Visualization
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Progress tracking
from tqdm.notebook import tqdm

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

### 3.2 Configuration

In [None]:
# Configuration class to store all hyperparameters
class Config:
    """Hyperparameters and settings for WGAN-GP training."""
    
    # Image settings
    IMAGE_SIZE = 200          # Output image resolution (200x200)
    CHANNELS = 3              # RGB images
    
    # Model architecture
    LATENT_DIM = 128          # Size of random noise vector (z)
    GEN_FEATURES = 64         # Base feature maps in generator
    CRITIC_FEATURES = 64      # Base feature maps in critic
    
    # Training settings
    BATCH_SIZE = 16           # Images per batch
    NUM_EPOCHS = 500          # Total training epochs
    LEARNING_RATE = 0.0001    # Adam learning rate
    BETA1 = 0.0               # Adam beta1 (momentum)
    BETA2 = 0.9               # Adam beta2
    
    # WGAN-GP specific
    CRITIC_ITERATIONS = 5     # Critic updates per generator update
    LAMBDA_GP = 10            # Gradient penalty coefficient
    
    # Data settings
    FRAME_INTERVAL = 10       # Extract every Nth frame from video
    
    # Paths
    VIDEO_PATH = "front.mp4"
    OUTPUT_DIR = "output"
    FRAMES_DIR = "output/frames"
    SAMPLES_DIR = "output/samples"
    CHECKPOINTS_DIR = "output/checkpoints"

# Initialize configuration
config = Config()

# Create output directories
os.makedirs(config.FRAMES_DIR, exist_ok=True)
os.makedirs(config.SAMPLES_DIR, exist_ok=True)
os.makedirs(config.CHECKPOINTS_DIR, exist_ok=True)

# Set device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

### 3.3 Frame Extraction

Consecutive video frames are nearly identical, so only every 10th frame is kept. This reduces redundancy in the training set while preserving its visual diversity.
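
The sampling rule keeps frame `i` whenever `i % interval == 0`. The sketch below lists the kept indices for a made-up 95-frame video; `kept_frame_indices` is a hypothetical helper and not part of the notebook.

```python
import math


def kept_frame_indices(total_frames, interval=10):
    """Indices of the frames that are saved when keeping every `interval`-th frame."""
    return list(range(0, total_frames, interval))


indices = kept_frame_indices(total_frames=95, interval=10)
print(indices)       # [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
print(len(indices))  # 10

# The count is ceil(total / interval), not floor, because frame 0 is always kept.
assert len(indices) == math.ceil(95 / 10)
```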

In [None]:
def extract_frames(video_path: str, output_dir: str, interval: int = 10) -> int:
    """
    Extract every nth frame from a video file.
    
    Args:
        video_path: Path to input video file
        output_dir: Directory to save extracted frames
        interval: Extract every nth frame (default: 10)
    
    Returns:
        Number of frames extracted
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Open video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")
    
    # Get video properties
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    print(f"Video Properties:")
    print(f"  - Total frames: {total_frames}")
    print(f"  - FPS: {fps:.2f}")
    print(f"  - Resolution: {width}x{height}")
    duration = total_frames / fps if fps > 0 else float("nan")  # some containers report 0 FPS
    print(f"  - Duration: {duration:.2f} seconds")
    print(f"\nExtracting every {interval}th frame...")
    
    frame_count = 0
    saved_count = 0
    
    # Progress bar (frames 0, interval, 2*interval, ... are kept, so round up)
    pbar = tqdm(total=(total_frames + interval - 1) // interval, desc="Extracting frames")
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        
        if frame_count % interval == 0:
            # Convert BGR (OpenCV) to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(frame_rgb)
            
            # Save frame as PNG
            save_path = os.path.join(output_dir, f"frame_{saved_count:06d}.png")
            img.save(save_path)
            saved_count += 1
            pbar.update(1)
        
        frame_count += 1
    
    cap.release()
    pbar.close()
    
    print(f"\nExtracted {saved_count} frames to {output_dir}")
    return saved_count

In [None]:
# Extract frames from video
num_frames = extract_frames(config.VIDEO_PATH, config.FRAMES_DIR, config.FRAME_INTERVAL)

### 3.4 Dataset and DataLoader

In [None]:
class FrameDataset(Dataset):
    """
    Custom Dataset for loading extracted video frames.
    
    Applies transformations:
    - Resize to target dimensions
    - Convert to tensor
    - Normalize to [-1, 1] range (required for Tanh output)
    """
    
    def __init__(self, frames_dir: str, image_size: int = 200):
        """
        Args:
            frames_dir: Directory containing frame images
            image_size: Target size for resizing (default: 200)
        """
        self.frames_dir = frames_dir
        
        # Get list of image files
        self.image_files = sorted([
            f for f in os.listdir(frames_dir) 
            if f.endswith(('.png', '.jpg', '.jpeg'))
        ])
        
        if len(self.image_files) == 0:
            raise ValueError(f"No images found in {frames_dir}")
        
        # Define image transformations
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),  # Resize to image_size x image_size
            transforms.ToTensor(),                         # Convert to tensor [0, 1]
            transforms.Normalize([0.5, 0.5, 0.5],          # Normalize to [-1, 1]
                                 [0.5, 0.5, 0.5])
        ])
    
    def __len__(self):
        """Return total number of images."""
        return len(self.image_files)
    
    def __getitem__(self, idx):
        """
        Load and transform a single image.
        
        Args:
            idx: Index of image to load
        
        Returns:
            Transformed image tensor
        """
        img_path = os.path.join(self.frames_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        return self.transform(image)

In [None]:
# Create dataset and dataloader
dataset = FrameDataset(config.FRAMES_DIR, config.IMAGE_SIZE)

dataloader = DataLoader(
    dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,              # Shuffle for training
    num_workers=0,             # Set to 0 for Windows compatibility
    pin_memory=(device.type == 'cuda')
)

print(f"Dataset size: {len(dataset)} images")
print(f"Batches per epoch: {len(dataloader)}")
print(f"Batch size: {config.BATCH_SIZE}")
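
The value ranges involved in the transform pipeline can be checked with plain arithmetic. The sketch below mimics `ToTensor` followed by `Normalize(0.5, 0.5)` on a single 8-bit pixel value, plus the inverse mapping used for display. `normalize_pixel` and `denormalize` are hypothetical helpers for illustration.

```python
def normalize_pixel(value_0_255, mean=0.5, std=0.5):
    """Mimic ToTensor (scale to [0, 1]) followed by Normalize(mean, std)."""
    scaled = value_0_255 / 255.0
    return (scaled - mean) / std


def denormalize(value):
    """Map a value in [-1, 1] back to [0, 1] for display."""
    return (value + 1) / 2


print(normalize_pixel(0))                 # -1.0
print(normalize_pixel(255))               # 1.0
print(denormalize(normalize_pixel(255)))  # 1.0
```

Because the generator ends in `Tanh`, its outputs live in the same [-1, 1] range as the normalized training frames.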

### 3.5 Visualize Training Data

In [None]:
def show_images(images, title="Images", nrow=4):
    """
    Display a grid of images.
    
    Args:
        images: Tensor of images (N, C, H, W) in range [-1, 1]
        title: Plot title
        nrow: Number of images per row
    """
    # Denormalize from [-1, 1] to [0, 1]
    images = (images + 1) / 2
    images = images.clamp(0, 1)
    
    # Create grid
    grid = make_grid(images, nrow=nrow, padding=2)
    
    # Convert to numpy for matplotlib
    grid_np = grid.permute(1, 2, 0).cpu().numpy()
    
    # Display
    plt.figure(figsize=(12, 12))
    plt.imshow(grid_np)
    plt.title(title)
    plt.axis('off')
    plt.show()

# Get a batch of real images
real_batch = next(iter(dataloader))
show_images(real_batch[:16], "Sample Training Images (Real Frames)", nrow=4)

## 4. Neural Network Architecture

### 4.1 Generator Network

Transforms a 128-dimensional latent vector into a 200×200 RGB image via progressive upsampling.

**Architecture:** 5×5 → 10×10 → 25×25 → 50×50 → 100×100 → 200×200
- Batch Normalization for stability
- LeakyReLU activation
- Tanh output ([-1, 1] range)
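
The size progression can be traced without building the network. In this sketch, which assumes the layer layout used in the `Generator` class, only the upsampling steps change the side length, because a 3x3 convolution with stride 1 and padding 1 preserves it. `generator_spatial_sizes` is a hypothetical helper.

```python
def generator_spatial_sizes(init_size=5):
    """Trace the feature-map side length through the generator's upsampling steps.

    Each 3x3 convolution with stride 1 and padding 1 keeps the size unchanged,
    so only the upsampling layers matter here.
    """
    sizes = [init_size]
    # ("scale", 2) doubles the size; ("size", 25) resizes to a fixed target.
    steps = [("scale", 2), ("size", 25), ("scale", 2), ("scale", 2), ("scale", 2)]
    for kind, value in steps:
        sizes.append(sizes[-1] * value if kind == "scale" else value)
    return sizes


print(generator_spatial_sizes())  # [5, 10, 25, 50, 100, 200]

# Width of the first linear layer for features=64: 64 * 16 channels on a 5x5 grid.
print(64 * 16 * 5 * 5)  # 25600
```

The 10 to 25 step is the only one that is not a doubling, which is why that block resizes to a fixed target size instead of using a scale factor.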

In [None]:
class Generator(nn.Module):
    """
    Generator Network for WGAN-GP.
    
    Transforms a random latent vector into a 200x200 RGB image.
    Uses nearest-neighbor upsampling followed by 3x3 convolutions
    to increase the spatial resolution progressively.
    """
    
    def __init__(self, latent_dim: int, channels: int, features: int):
        """
        Initialize Generator network.
        
        Args:
            latent_dim: Size of input noise vector (128 in this notebook)
            channels: Number of output channels (3 for RGB)
            features: Base number of feature maps (64 in this notebook)
        """
        super().__init__()
        
        self.init_size = 5  # Initial spatial size before upsampling
        
        # Input layer: Transform latent vector to feature maps
        # Output shape: (batch, features*16, 5, 5)
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, features * 16 * self.init_size * self.init_size)
        )
        
        # Convolutional blocks with upsampling
        self.conv_blocks = nn.Sequential(
            # Block 1: 5x5 -> 10x10
            nn.BatchNorm2d(features * 16),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(features * 16, features * 8, 3, stride=1, padding=1),
            nn.BatchNorm2d(features * 8),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Block 2: 10x10 -> 25x25
            nn.Upsample(size=(25, 25)),
            nn.Conv2d(features * 8, features * 4, 3, stride=1, padding=1),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Block 3: 25x25 -> 50x50
            nn.Upsample(scale_factor=2),
            nn.Conv2d(features * 4, features * 2, 3, stride=1, padding=1),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Block 4: 50x50 -> 100x100
            nn.Upsample(scale_factor=2),
            nn.Conv2d(features * 2, features, 3, stride=1, padding=1),
            nn.BatchNorm2d(features),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Block 5: 100x100 -> 200x200 (Output layer)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(features, channels, 3, stride=1, padding=1),
            nn.Tanh()  # Output range: [-1, 1]
        )
    
    def forward(self, z):
        """
        Forward pass through generator.
        
        Args:
            z: Latent vector tensor of shape (batch_size, latent_dim)
        
        Returns:
            Generated images of shape (batch_size, 3, 200, 200)
        """
        # Transform latent vector to initial feature maps
        out = self.l1(z)
        out = out.view(out.shape[0], -1, self.init_size, self.init_size)
        
        # Apply convolutional blocks
        img = self.conv_blocks(out)
        return img

### 4.2 Critic (Discriminator) Network

Evaluates an image and outputs a scalar score (higher = more real).

**Architecture:** Progressive downsampling 200 → 100 → 50 → 25 → 12 → 6
- Instance Normalization instead of BatchNorm (the gradient penalty is computed per sample, and batch statistics would couple the samples in a batch)
- LeakyReLU activation
- Linear output (no sigmoid)
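
The downsampling sequence follows from the standard convolution output-size formula, floor((n + 2p - k) / s) + 1. The sketch below applies it with the critic's kernel size 4, stride 2, and padding 1; `conv_output_size` is a hypothetical helper for illustration.

```python
def conv_output_size(size, kernel=4, stride=2, padding=1):
    """Output side length of a convolution: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


sizes = [200]
for _ in range(5):  # five stride-2 convolutions in the critic
    sizes.append(conv_output_size(sizes[-1]))

print(sizes)  # [200, 100, 50, 25, 12, 6]

# Input width of the final linear layer for features=64.
print(64 * 16 * sizes[-1] ** 2)  # 36864
```

The 25 to 12 step rounds down, so the final feature map is 6x6 and the linear layer expects `features * 16 * 6 * 6` inputs.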

In [None]:
class Critic(nn.Module):
    """
    Critic (Discriminator) Network for WGAN-GP.
    
    Takes an image and outputs a scalar score indicating
    how "real" the image appears. Higher scores = more real.
    
    Note: Called "Critic" in WGAN terminology because it doesn't
    output a probability (no sigmoid activation).
    """
    
    def __init__(self, channels: int, features: int):
        """
        Initialize Critic network.
        
        Args:
            channels: Number of input channels (3 for RGB)
            features: Base number of feature maps (64 in this notebook)
        """
        super().__init__()
        
        # Convolutional layers with strided convolutions for downsampling
        self.model = nn.Sequential(
            # Layer 1: 200x200 -> 100x100
            nn.Conv2d(channels, features, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Layer 2: 100x100 -> 50x50
            nn.Conv2d(features, features * 2, 4, stride=2, padding=1),
            nn.InstanceNorm2d(features * 2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Layer 3: 50x50 -> 25x25
            nn.Conv2d(features * 2, features * 4, 4, stride=2, padding=1),
            nn.InstanceNorm2d(features * 4, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Layer 4: 25x25 -> 12x12
            nn.Conv2d(features * 4, features * 8, 4, stride=2, padding=1),
            nn.InstanceNorm2d(features * 8, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Layer 5: 12x12 -> 6x6
            nn.Conv2d(features * 8, features * 16, 4, stride=2, padding=1),
            nn.InstanceNorm2d(features * 16, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Output layer: Flatten and produce single scalar
        # Input size: features*16 * 6 * 6
        self.fc = nn.Linear(features * 16 * 6 * 6, 1)
    
    def forward(self, img):
        """
        Forward pass through critic.
        
        Args:
            img: Image tensor of shape (batch_size, 3, 200, 200)
        
        Returns:
            Critic scores of shape (batch_size, 1)
        """
        out = self.model(img)
        out = out.view(out.shape[0], -1)  # Flatten
        validity = self.fc(out)
        return validity

### 4.3 Initialize Networks

In [None]:
# Initialize Generator and Critic
generator = Generator(
    latent_dim=config.LATENT_DIM,
    channels=config.CHANNELS,
    features=config.GEN_FEATURES
).to(device)

critic = Critic(
    channels=config.CHANNELS,
    features=config.CRITIC_FEATURES
).to(device)

# Count parameters
gen_params = sum(p.numel() for p in generator.parameters())
critic_params = sum(p.numel() for p in critic.parameters())

print("Generator Architecture:")
print(f"  - Parameters: {gen_params:,}")
print(f"  - Input: {config.LATENT_DIM}-dimensional latent vector")
print(f"  - Output: {config.CHANNELS}x{config.IMAGE_SIZE}x{config.IMAGE_SIZE} image")

print(f"\nCritic Architecture:")
print(f"  - Parameters: {critic_params:,}")
print(f"  - Input: {config.CHANNELS}x{config.IMAGE_SIZE}x{config.IMAGE_SIZE} image")
print(f"  - Output: Scalar score")

print(f"\nTotal parameters: {gen_params + critic_params:,}")

## 5. Training

### 5.1 Gradient Penalty Function

Enforces the Lipschitz constraint by penalizing the critic whenever the norm of its gradient with respect to interpolated inputs deviates from 1.

In [None]:
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """
    Compute gradient penalty for WGAN-GP.
    
    The gradient penalty encourages the critic to have gradients
    with norm close to 1, which enforces the Lipschitz constraint.
    
    Args:
        critic: Critic network
        real_samples: Batch of real images
        fake_samples: Batch of generated images
        device: Computing device (CPU/GPU)
    
    Returns:
        Gradient penalty scalar
    """
    batch_size = real_samples.size(0)
    
    # Random interpolation coefficient (uniform between 0 and 1)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    alpha = alpha.expand_as(real_samples)
    
    # Create interpolated samples between real and fake
    interpolated = alpha * real_samples + (1 - alpha) * fake_samples
    interpolated.requires_grad_(True)
    
    # Get critic scores for interpolated samples
    d_interpolated = critic(interpolated)
    
    # Compute gradients w.r.t. interpolated samples
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,  # build a graph of the gradient so the penalty is differentiable
    )[0]
    
    # Compute gradient penalty: (||grad|| - 1)^2
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty
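
A quick sanity check is possible with a toy critic whose gradient is known in closed form. For a linear critic D(x) = w · x, the gradient with respect to x is w everywhere, so the penalty should equal (||w|| - 1)^2 regardless of the inputs. The sketch below restates the penalty compactly so that it is self-contained; `gradient_penalty` and `toy_critic` are hypothetical names, and the 1x2x2 "images" are random placeholders.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake):
    """Mean of (||grad_x critic(x_hat)||_2 - 1)^2 over interpolated samples."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


# Toy linear critic on 1x2x2 inputs: D(x) = w . flatten(x), with ||w||_2 = 5.
toy_critic = nn.Sequential(nn.Flatten(), nn.Linear(4, 1, bias=False))
with torch.no_grad():
    toy_critic[1].weight.copy_(torch.tensor([[3.0, 4.0, 0.0, 0.0]]))

real = torch.randn(8, 1, 2, 2)
fake = torch.randn(8, 1, 2, 2)
gp = gradient_penalty(toy_critic, real, fake)
print(gp.item())  # approximately 16.0, i.e. (5 - 1) ** 2
```

Multiplying this value by λ = 10 gives the term that is added to the critic loss.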

### 5.2 Training Loop

In [None]:
# Initialize optimizers (Adam with WGAN-specific betas)
opt_generator = optim.Adam(
    generator.parameters(),
    lr=config.LEARNING_RATE,
    betas=(config.BETA1, config.BETA2)
)

opt_critic = optim.Adam(
    critic.parameters(),
    lr=config.LEARNING_RATE,
    betas=(config.BETA1, config.BETA2)
)

# Fixed noise for consistent visualization during training
fixed_noise = torch.randn(16, config.LATENT_DIM, device=device)

# Training history
history = {
    'g_losses': [],
    'c_losses': [],
    'epochs': []
}

print("Training Configuration:")
print(f"  - Epochs: {config.NUM_EPOCHS}")
print(f"  - Batch size: {config.BATCH_SIZE}")
print(f"  - Learning rate: {config.LEARNING_RATE}")
print(f"  - Critic iterations: {config.CRITIC_ITERATIONS}")
print(f"  - Gradient penalty λ: {config.LAMBDA_GP}")
print(f"  - Device: {device}")

In [None]:
# Main training loop
print(f"\nStarting WGAN-GP training for {config.NUM_EPOCHS} epochs...\n")

for epoch in range(config.NUM_EPOCHS):
    epoch_g_loss = 0
    epoch_c_loss = 0
    num_batches = 0
    
    # Progress bar for current epoch
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS}")
    
    for batch_idx, real_imgs in enumerate(pbar):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)
        
        # =====================
        # Train Critic
        # =====================
        # Train critic multiple times per generator iteration
        for _ in range(config.CRITIC_ITERATIONS):
            opt_critic.zero_grad()
            
            # Generate fake images
            z = torch.randn(batch_size, config.LATENT_DIM, device=device)
            fake_imgs = generator(z).detach()
            
            # Critic scores for real and fake images
            real_validity = critic(real_imgs)
            fake_validity = critic(fake_imgs)
            
            # Compute gradient penalty
            gradient_penalty = compute_gradient_penalty(
                critic, real_imgs, fake_imgs, device
            )
            
            # Wasserstein loss + gradient penalty
            # Critic wants: high scores for real, low scores for fake
            c_loss = (
                -torch.mean(real_validity) + 
                torch.mean(fake_validity) + 
                config.LAMBDA_GP * gradient_penalty
            )
            
            c_loss.backward()
            opt_critic.step()
        
        # =====================
        # Train Generator
        # =====================
        opt_generator.zero_grad()
        
        # Generate fake images
        z = torch.randn(batch_size, config.LATENT_DIM, device=device)
        fake_imgs = generator(z)
        
        # Generator loss: wants critic to give high scores to fakes
        g_loss = -torch.mean(critic(fake_imgs))
        
        g_loss.backward()
        opt_generator.step()
        
        # Track losses (c_loss is the value from the last critic iteration)
        epoch_g_loss += g_loss.item()
        epoch_c_loss += c_loss.item()
        num_batches += 1
        
        # Update progress bar
        pbar.set_postfix({
            'G_loss': f'{g_loss.item():.4f}',
            'C_loss': f'{c_loss.item():.4f}'
        })
    
    # Calculate average losses for epoch
    avg_g_loss = epoch_g_loss / num_batches
    avg_c_loss = epoch_c_loss / num_batches
    
    # Store history
    history['g_losses'].append(avg_g_loss)
    history['c_losses'].append(avg_c_loss)
    history['epochs'].append(epoch + 1)
    
    # Print epoch summary
    print(f"Epoch [{epoch+1}/{config.NUM_EPOCHS}] "
          f"G_loss: {avg_g_loss:.4f}, C_loss: {avg_c_loss:.4f}")
    
    # Save sample images after the first epoch and then every 10 epochs
    if (epoch + 1) % 10 == 0 or epoch == 0:
        generator.eval()
        with torch.no_grad():
            fake_samples = generator(fixed_noise)
            fake_samples = (fake_samples + 1) / 2  # Denormalize
            save_image(
                fake_samples,
                os.path.join(config.SAMPLES_DIR, f"epoch_{epoch+1:04d}.png"),
                nrow=4,
                normalize=False
            )
        generator.train()
    
    # Save checkpoint every 50 epochs
    if (epoch + 1) % 50 == 0:
        torch.save({
            'epoch': epoch,  # 0-based; the filename below uses the 1-based epoch number
            'generator_state_dict': generator.state_dict(),
            'critic_state_dict': critic.state_dict(),
            'opt_g_state_dict': opt_generator.state_dict(),
            'opt_c_state_dict': opt_critic.state_dict(),
            'history': history,
        }, os.path.join(config.CHECKPOINTS_DIR, f"checkpoint_epoch_{epoch+1}.pt"))

# Save final model
torch.save({
    'generator_state_dict': generator.state_dict(),
    'critic_state_dict': critic.state_dict(),
    'history': history,
    'config': {
        'latent_dim': config.LATENT_DIM,
        'channels': config.CHANNELS,
        'image_size': config.IMAGE_SIZE,
        'gen_features': config.GEN_FEATURES,
        'critic_features': config.CRITIC_FEATURES,
    }
}, os.path.join(config.CHECKPOINTS_DIR, "final_model.pt"))

print(f"\n{'='*50}")
print("Training Complete!")
print(f"Final model saved to: {config.CHECKPOINTS_DIR}/final_model.pt")
print(f"{'='*50}")

## 6. Analysis and Results

### 6.1 Training Loss Curves

In [None]:
# Plot training losses
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Generator loss
axes[0].plot(history['epochs'], history['g_losses'], 'b-', linewidth=1.5)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Generator Loss')
axes[0].set_title('Generator Loss Over Training')
axes[0].grid(True, alpha=0.3)

# Critic loss
axes[1].plot(history['epochs'], history['c_losses'], 'r-', linewidth=1.5)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Critic Loss')
axes[1].set_title('Critic Loss Over Training')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(config.OUTPUT_DIR, 'loss_curves.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"Final Generator Loss: {history['g_losses'][-1]:.4f}")
print(f"Final Critic Loss: {history['c_losses'][-1]:.4f}")
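
The critic loss plotted above includes the gradient penalty term. When reading the curves it can help to look at the Wasserstein estimate on its own, which is the mean critic score on real images minus the mean score on generated ones. The sketch below shows that calculation on made-up scores; `wasserstein_estimate` is a hypothetical helper, and the training loop in this notebook does not log this quantity separately.

```python
def wasserstein_estimate(real_scores, fake_scores):
    """Critic's estimate of the Wasserstein distance: mean D(real) - mean D(fake)."""
    mean_real = sum(real_scores) / len(real_scores)
    mean_fake = sum(fake_scores) / len(fake_scores)
    return mean_real - mean_fake


# Made-up critic scores for illustration only.
real_scores = [2.0, 3.0, 4.0]
fake_scores = [-1.0, 0.0, 1.0]
print(wasserstein_estimate(real_scores, fake_scores))  # 3.0
```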

### 6.2 Generated Samples

In [None]:
# Generate new images
generator.eval()

with torch.no_grad():
    # Generate 16 random images
    z = torch.randn(16, config.LATENT_DIM, device=device)
    generated_images = generator(z)

# Display generated images
show_images(generated_images, "Generated Images (Final Model)", nrow=4)

### 6.3 Training Progression

Let's visualize how generated images improved during training.

In [None]:
# Load and display training progression
import glob

sample_files = sorted(glob.glob(os.path.join(config.SAMPLES_DIR, "epoch_*.png")))

if sample_files:
    # Select samples at key epochs
    num_samples = min(10, len(sample_files))
    indices = np.linspace(0, len(sample_files)-1, num_samples, dtype=int)
    
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    
    for i, idx in enumerate(indices):
        img = Image.open(sample_files[idx])
        epoch_num = int(Path(sample_files[idx]).stem.split('_')[1])
        
        axes[i].imshow(img)
        axes[i].set_title(f'Epoch {epoch_num}')
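        axes[i].axis('off')
    
    # Hide any unused panels when fewer than 10 sample files exist
    for ax in axes[len(indices):]:
        ax.axis('off')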
    
    plt.suptitle('Training Progression: Generated Images Over Time', fontsize=14)
    plt.tight_layout()
    plt.savefig(os.path.join(config.OUTPUT_DIR, 'training_progression.png'), dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("No sample images found.")

### 6.4 Real vs Generated Comparison

In [None]:
# Compare real and generated images side by side
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Get real images
real_batch = next(iter(dataloader))[:4]

# Generate fake images
with torch.no_grad():
    z = torch.randn(4, config.LATENT_DIM, device=device)
    fake_batch = generator(z).cpu()

# Display real images (top row)
for i in range(4):
    img = (real_batch[i] + 1) / 2  # Denormalize
    img = img.permute(1, 2, 0).numpy()
    axes[0, i].imshow(img)
    axes[0, i].set_title('Real')
    axes[0, i].axis('off')

# Display generated images (bottom row)
for i in range(4):
    img = (fake_batch[i] + 1) / 2  # Denormalize
    img = img.permute(1, 2, 0).numpy()
    axes[1, i].imshow(img)
    axes[1, i].set_title('Generated')
    axes[1, i].axis('off')

plt.suptitle('Real vs Generated Images Comparison', fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(config.OUTPUT_DIR, 'real_vs_generated.png'), dpi=150, bbox_inches='tight')
plt.show()

### 6.5 Latent Space Interpolation

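Decoding points along a straight line between two latent vectors shows whether the generator has learned a smooth representation: gradual transitions suggest that nearby latent codes map to visually similar frames.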

In [None]:
# Interpolate between two random latent vectors
def interpolate(z1, z2, steps=10):
    """Linear interpolation between two latent vectors."""
    ratios = np.linspace(0, 1, steps)
    vectors = [(1 - r) * z1 + r * z2 for r in ratios]
    return torch.stack(vectors)

# Generate interpolation
torch.manual_seed(42)
z1 = torch.randn(1, config.LATENT_DIM, device=device)
z2 = torch.randn(1, config.LATENT_DIM, device=device)

z_interp = interpolate(z1.squeeze(), z2.squeeze(), steps=10).to(device)

with torch.no_grad():
    interp_images = generator(z_interp)

# Display interpolation
fig, axes = plt.subplots(1, 10, figsize=(20, 3))

for i in range(10):
    img = (interp_images[i] + 1) / 2
    img = img.cpu().permute(1, 2, 0).numpy()
    axes[i].imshow(img)
    axes[i].axis('off')
    if i == 0:
        axes[i].set_title('Start')
    elif i == 9:
        axes[i].set_title('End')

plt.suptitle('Latent Space Interpolation', fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(config.OUTPUT_DIR, 'latent_interpolation.png'), dpi=150, bbox_inches='tight')
plt.show()

## 7. Performance Summary

### 7.1 Model Statistics

In [None]:
# Print comprehensive summary
print("="*60)
print("WGAN-GP TRAINING SUMMARY")
print("="*60)

print("\n📊 Dataset:")
print(f"   - Source: {config.VIDEO_PATH}")
print(f"   - Frames extracted: {len(dataset)}")
print(f"   - Frame interval: Every {config.FRAME_INTERVAL}th frame")
print(f"   - Image size: {config.IMAGE_SIZE}x{config.IMAGE_SIZE}")

print("\n🏗️ Architecture:")
print(f"   - Generator parameters: {gen_params:,}")
print(f"   - Critic parameters: {critic_params:,}")
print(f"   - Total parameters: {gen_params + critic_params:,}")
print(f"   - Latent dimension: {config.LATENT_DIM}")

print("\n⚙️ Training Configuration:")
print(f"   - Epochs: {config.NUM_EPOCHS}")
print(f"   - Batch size: {config.BATCH_SIZE}")
print(f"   - Learning rate: {config.LEARNING_RATE}")
print(f"   - Critic iterations: {config.CRITIC_ITERATIONS}")
print(f"   - Gradient penalty λ: {config.LAMBDA_GP}")
print(f"   - Device: {device}")

print("\n📈 Final Results:")
print(f"   - Final Generator Loss: {history['g_losses'][-1]:.4f}")
print(f"   - Final Critic Loss: {history['c_losses'][-1]:.4f}")
print(f"   - Best Generator Loss: {min(history['g_losses']):.4f}")

print("\n📁 Output Files:")
print(f"   - Model: {config.CHECKPOINTS_DIR}/final_model.pt")
print(f"   - Samples: {config.SAMPLES_DIR}/")
print(f"   - Plots: output/loss_curves.png")

print("\n" + "="*60)

## 8. Conclusion

### 8.1 Summary

Successfully implemented **WGAN-GP** for video frame generation, with stable training, quality image generation at 200×200, and meaningful latent space representations.

### 8.2 Key Findings

- **Wasserstein loss** provides more meaningful gradients than binary cross-entropy
- **Gradient penalty** enforces the Lipschitz constraint without weight clipping
- **Instance normalization** in the critic suits WGAN-GP better than batch normalization, because the penalty is defined per sample and batch statistics would couple the samples in a batch
- **Training the critic more often** than the generator (5:1 ratio) improves stability

### 8.3 Future Improvements

**Rocket League training data:**
**🎮 For Rocket League image generation, training on multiple camera angles is critical:**
- **Ball cam** - follows the ball, dynamic perspectives
- **Car cam** - fixed to car, diverse car orientations
- **Different view angles** - varying heights, distances, and perspectives

This diversity would enable the model to generate cars in any orientation or position rather than from a limited set of viewpoints; the current single-perspective training restricts output variety.

**Additional enhancements:**
- **Progressive Growing GAN** for higher resolution outputs
- **Self-Attention layers** for better global coherence
- **Spectral Normalization** as alternative to gradient penalty
- **FID (Fréchet Inception Distance)** for quantitative evaluation