# WGAN-GP (Standard) + Adaptive-Discriminator Hooks (Template)

This notebook provides:
- A **standard WGAN-GP** training loop (critic + gradient penalty).
- A **pluggable "Adaptive Discriminator" controller** with clearly marked hooks so you can experiment with different discriminator/critic adjustment strategies (e.g., dynamic `n_critic`, LR, GP weight, architecture toggles, etc.).

> Note: WGAN-GP uses a gradient penalty to enforce the 1-Lipschitz constraint instead of weight clipping (see the WGAN-GP objective in the referenced survey paper).


In [None]:
# Cell 1 — Imports & Reproducibility
import os
import math
import random
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also switches cuDNN to ``deterministic=True`` / ``benchmark=False``.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Fix every RNG seed before any model or data is created, then report the device in use.
seed_everything(seed=42)
print("Device:", DEVICE)


In [None]:
# Cell 2 — Simple 1D Dataset Placeholder (Replace with your EEG loader)
# Expect samples shaped as (C, L) where:
#   C = number of channels (e.g., EEG electrodes or feature channels)
#   L = sequence length (time points)
#
# Replace this with your real EEG dataset.
class RandomEEGDataset(Dataset):
    """Placeholder dataset that yields Gaussian noise shaped like one EEG window.

    Every item is a freshly drawn ``torch.randn(channels, length)`` tensor. The
    index is only used for the range check, so reading the same index twice
    returns different noise.

    Args:
        n_samples: Number of items reported by ``len()``.
        channels: Size of the first dimension of each item (C).
        length: Size of the second dimension of each item (L).
    """

    def __init__(self, n_samples: int = 10_000, channels: int = 8, length: int = 256):
        self.n = n_samples
        self.c = channels
        self.l = length

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return one random ``(channels, length)`` tensor.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Without this check the sequence protocol never sees an IndexError, so
        # plain iteration (`for x in dataset`, `list(dataset)`) would never stop.
        if not -self.n <= idx < self.n:
            raise IndexError(
                f"index {idx} is out of range for a dataset of {self.n} samples"
            )
        # Random "EEG-like" noise as a placeholder
        return torch.randn(self.c, self.l)

# --- Configure data shape here ---
CHANNELS = 8
SEQ_LEN = 256

# 5000 placeholder windows served in shuffled batches of 64. drop_last discards the
# final partial batch so every batch the training loop sees has the same size.
dataset = RandomEEGDataset(
    n_samples=5000,
    channels=CHANNELS,
    length=SEQ_LEN,
)
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=(DEVICE == "cuda"),  # pinned host memory is only requested when DEVICE is "cuda"
)
print("Batches:", len(loader))


In [None]:
# Cell 3 — Generator & Critic (1D Conv, Standard WGAN-GP Style)
# This is a *standard* WGAN-GP setup: the "discriminator" is a *critic* with a linear output (no sigmoid).

def weights_init(m):
    """Initialise one module in place; intended for ``Module.apply``.

    Conv1d, ConvTranspose1d and Linear modules get weights drawn from a normal
    distribution (mean 0.0, std 0.02) and, when they have one, a zero bias.
    Any other module type is left untouched.
    """
    if not isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.02)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)

class Generator1D(nn.Module):
    """1-D transposed-convolution generator.

    A linear layer projects the noise vector to ``base * 8`` channels at
    ``seq_len // 16`` time steps. Four stride-2 transposed convolutions then
    each double the length, and a final ``Tanh`` bounds the output to [-1, 1].

    Args:
        z_dim: Size of the input noise vector.
        out_channels: Number of channels in the generated sequence.
        seq_len: Length of the generated sequence; must be divisible by 16.
        base: Channel width of the last hidden stage; earlier stages use
            8x, 4x and 2x this value.
    """

    def __init__(self, z_dim: int, out_channels: int, seq_len: int, base: int = 64):
        super().__init__()
        assert seq_len % 16 == 0, "For this template, seq_len should be divisible by 16."
        self.z_dim = z_dim
        self.out_channels = out_channels
        self.seq_len = seq_len

        # Starting temporal resolution: four x2 upsamplings bring it back to seq_len.
        self.init_len = seq_len // 16
        self.fc = nn.Linear(z_dim, base * 8 * self.init_len)

        # Hidden stages: (8b -> 4b), (4b -> 2b), (2b -> b), each ConvT + BatchNorm + ReLU.
        widths = (base * 8, base * 4, base * 2, base)
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(nn.ConvTranspose1d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm1d(c_out))
            layers.append(nn.ReLU(True))

        # Output stage: last x2 upsampling straight to the data channels, no normalisation.
        layers.append(nn.ConvTranspose1d(base, out_channels, kernel_size=4, stride=2, padding=1))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise of shape (N, z_dim) to sequences of shape (N, out_channels, seq_len)."""
        batch = z.size(0)
        seed_map = self.fc(z).view(batch, -1, self.init_len)
        return self.net(seed_map)

class Critic1D(nn.Module):
    """1-D convolutional critic with an unbounded scalar score per sample.

    Four stride-2 convolutions each halve the sequence length while the channel
    width grows from ``base`` to ``base * 8``. The first stage has no
    normalisation; the other three use affine ``InstanceNorm1d``. The flattened
    features feed a single linear unit with no output activation.

    Args:
        in_channels: Number of channels of the input sequence.
        seq_len: Length of the input sequence; must be divisible by 16.
        base: Channel width of the first convolution stage.
    """

    def __init__(self, in_channels: int, seq_len: int, base: int = 64):
        super().__init__()
        assert seq_len % 16 == 0, "For this template, seq_len should be divisible by 16."

        # Stage 1 (/2): convolution + activation only.
        stages = [
            nn.Conv1d(in_channels, base, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Stages 2-4 (/4, /8, /16): each doubles the width and adds InstanceNorm.
        width = base
        for _ in range(3):
            doubled = width * 2
            stages.extend(
                [
                    nn.Conv1d(width, doubled, kernel_size=4, stride=2, padding=1),
                    nn.InstanceNorm1d(doubled, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                ]
            )
            width = doubled
        self.net = nn.Sequential(*stages)

        # width == base * 8 here; the length has been divided by 16.
        self.out = nn.Linear(width * (seq_len // 16), 1)

    def forward(self, x):
        """Score a batch of shape (N, in_channels, seq_len); returns a tensor of shape (N,)."""
        feats = self.net(x)
        flat = feats.view(x.size(0), -1)
        scores = self.out(flat)
        return scores.view(-1)

# --- Model configs ---
Z_DIM = 128

# Both networks are built first and re-initialised afterwards, in this order,
# so the sequence of random draws stays the same from run to run.
G = Generator1D(
    z_dim=Z_DIM,
    out_channels=CHANNELS,
    seq_len=SEQ_LEN,
).to(DEVICE)
D = Critic1D(
    in_channels=CHANNELS,
    seq_len=SEQ_LEN,
).to(DEVICE)

G.apply(weights_init)
D.apply(weights_init)

# Parameter counts, reported in millions.
print("G params:", sum(map(torch.numel, G.parameters())) / 1e6, "M")
print("D params:", sum(map(torch.numel, D.parameters())) / 1e6, "M")


In [None]:
# Cell 4 — WGAN-GP Utilities (Gradient Penalty, Losses)
# WGAN-GP objective uses:
#   loss_D = E[D(fake)] - E[D(real)] + lambda_gp * E[(||∇_x_hat D(x_hat)||_2 - 1)^2]
#   loss_G = -E[D(fake)]
# where x_hat is a random interpolation between a real and a generated sample.
# This matches the standard WGAN-GP formulation referenced in the survey paper.

def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Return the WGAN-GP penalty, mean over the batch of (||grad||_2 - 1)^2.

    The gradient is that of the critic output with respect to random per-sample
    interpolations between ``real`` and ``fake``.

    Args:
        critic: Callable mapping a batch to one score per sample.
        real: Batch of real samples, shape (N, ...) with any number of
            trailing dimensions.
        fake: Batch of generated samples with the same shape as ``real``.

    Returns:
        A scalar tensor. It stays attached to the critic's graph
        (``create_graph=True``), so it can be back-propagated into the critic
        parameters.
    """
    bsz = real.size(0)

    # One mixing coefficient per sample, shaped (N, 1, ..., 1) so it broadcasts
    # over every non-batch dimension whatever the input rank is.
    eps_shape = (bsz,) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device, dtype=real.dtype)

    # Detach both endpoints: the penalty constrains the critic only and must not
    # send gradients back into whatever produced `fake`.
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)

    d_hat = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones_like(d_hat),
        create_graph=True,   # keep the graph so the penalty itself is differentiable
        retain_graph=True,
    )

    # reshape (not view): the returned gradient is not guaranteed to be contiguous.
    grads = grads.reshape(bsz, -1)
    return ((grads.norm(2, dim=1) - 1.0) ** 2).mean()

@torch.no_grad()
def sample_generator(generator: nn.Module, n: int, z_dim: int) -> torch.Tensor:
    """Draw ``n`` samples from ``generator`` in inference mode and return them on the CPU.

    The generator is switched to eval mode for the forward pass, so layers such
    as BatchNorm use their running statistics and do not update them. Every
    submodule's train/eval flag is restored afterwards. Runs under
    ``torch.no_grad()``.

    Args:
        generator: Module mapping noise of shape (n, z_dim) to samples.
        n: Number of samples to draw.
        z_dim: Size of each noise vector.

    Returns:
        The generator output for standard-normal noise, moved to the CPU.
    """
    # Sample on the device the generator lives on; fall back to the notebook-wide
    # DEVICE only for a generator that has no parameters at all.
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(DEVICE)

    # Remember each submodule's mode individually so that deliberately frozen
    # (eval) submodules are not flipped back to train on exit.
    saved_modes = [(module, module.training) for module in generator.modules()]
    generator.eval()
    try:
        z = torch.randn(n, z_dim, device=device)
        return generator(z).cpu()
    finally:
        for module, was_training in saved_modes:
            module.training = was_training


In [None]:
# Cell 5 — Adaptive Discriminator Controller (Hooks / Empty Spots)
# This is the *intended place* to experiment with "Adaptive D" logic.
# You can:
#   - change n_critic dynamically
#   - change critic LR / betas
#   - change lambda_gp
#   - freeze/unfreeze layers
#   - swap architectures
#   - etc.
#
# IMPORTANT: The default implementation below does *nothing* (standard WGAN-GP behavior).
# You should only edit the "TODO" sections.

@dataclass
class TrainState:
    """Mutable training state shared by ``train_wgan_gp`` and the controller hooks.

    The training loop writes ``step``, ``epoch`` and ``wasserstein_gap_ema`` and
    re-reads ``n_critic`` and ``lambda_gp`` on every batch, so hooks can change
    those two knobs while training is running.
    """
    # Incremented by train_wgan_gp once per batch drawn from the loader.
    step: int = 0
    # Set by train_wgan_gp to the current (0-based) epoch index.
    epoch: int = 0
    # Number of critic updates performed per batch (i.e. per generator update).
    n_critic: int = 5
    # Weight on the gradient-penalty term in the critic loss.
    lambda_gp: float = 10.0

    # Common diagnostics you may want to use
    # Exponential moving average of the gap E[D(real)] - E[D(fake)], updated after
    # every critic step. It starts at 0.0 with no bias correction, so early values
    # are pulled toward zero.
    wasserstein_gap_ema: float = 0.0
    # Decay of the moving average: new = ema_beta * old + (1 - ema_beta) * gap.
    ema_beta: float = 0.99

class AdaptiveDiscriminatorController:
    """Extension point for adaptive critic strategies; every hook is a no-op by default.

    The training loop calls the three ``on_*`` hooks at fixed points. Leaving
    them empty gives plain WGAN-GP behaviour.
    """

    def __init__(self):
        """Set up controller state (none by default).

        Store any stateful variables for your adaptive method here.
        """

    def on_batch_start(self, state: TrainState) -> None:
        """Hook called at the start of each batch, before any update.

        TODO (optional): adjust settings at the start of each batch.
        Example knobs:
            state.n_critic = ...
            state.lambda_gp = ...
        """
        return None

    def on_after_critic_update(self, state: TrainState, metrics: Dict[str, float], optim_D: torch.optim.Optimizer) -> None:
        """Hook called after each critic update.

        TODO (optional): react after each critic update.
        You can:
            - change optimizer learning rate: for g in optim_D.param_groups: g["lr"] = ...
            - change betas / weight decay
            - freeze layers based on metrics
        """
        return None

    def on_after_generator_update(self, state: TrainState, metrics: Dict[str, float], optim_G: torch.optim.Optimizer) -> None:
        """Hook called after the generator update.

        TODO (optional): react after generator update.
        """
        return None


# Default controller instance: does nothing, i.e. standard WGAN-GP.
controller = AdaptiveDiscriminatorController()


In [None]:
# Cell 6 — Optimizers & Hyperparameters (Standard WGAN-GP Defaults)
LR = 1e-4
BETAS = (0.0, 0.9)   # standard WGAN-GP choice

# Training state starts from scratch: 5 critic updates per generator update and
# a gradient-penalty weight of 10.
state = TrainState(
    step=0,
    epoch=0,
    n_critic=5,
    lambda_gp=10.0,
)

# Generator and critic use the same Adam settings.
optim_G = torch.optim.Adam(
    G.parameters(),
    lr=LR,
    betas=BETAS,
)
optim_D = torch.optim.Adam(
    D.parameters(),
    lr=LR,
    betas=BETAS,
)

print(
    "LR:", LR,
    "BETAS:", BETAS,
    "n_critic:", state.n_critic,
    "lambda_gp:", state.lambda_gp,
)


In [None]:
# Cell 7 — Training Loop (Standard WGAN-GP + Adaptive Hooks)
# Logs:
#   - wasserstein gap: E[D(real)] - E[D(fake)]  (higher generally means critic separates better)
#   - gp: gradient penalty term
#   - losses for D and G

def train_wgan_gp(
    loader: DataLoader,
    G: nn.Module,
    D: nn.Module,
    optim_G: torch.optim.Optimizer,
    optim_D: torch.optim.Optimizer,
    state: TrainState,
    controller: AdaptiveDiscriminatorController,
    z_dim: int,
    epochs: int = 10,
    log_every: int = 50,
):
    """Train a generator/critic pair with the WGAN-GP objective.

    For every batch from ``loader`` the critic is updated ``state.n_critic``
    times, each time on that same real batch with freshly sampled noise. The
    generator is then updated once. Controller hooks run at the start of each
    batch and after every critic and generator update. Tensors are created on
    and moved to the module-level ``DEVICE``.

    Args:
        loader: Yields batches of real samples. A batch may be a tensor, or a
            tuple/list whose first element is the tensor, e.g. ``(x, label)``.
        G: Generator mapping (N, z_dim) noise to samples.
        D: Critic mapping samples to one score per sample.
        optim_G: Optimizer over the generator parameters.
        optim_D: Optimizer over the critic parameters.
        state: Mutable training state; ``step``, ``epoch`` and
            ``wasserstein_gap_ema`` are written here, and ``n_critic`` /
            ``lambda_gp`` are read on every batch.
        controller: Object providing the three ``on_*`` hooks.
        z_dim: Size of the noise vector fed to ``G``.
        epochs: Number of passes over ``loader``.
        log_every: Print a log line whenever ``batch_idx % log_every == 0``.

    Returns:
        None. Models, optimizers and ``state`` are updated in place.
    """
    G.train()
    D.train()

    # A controller may set n_critic to 0. Seed the critic metrics so the log line
    # below never hits an unbound name; later batches reuse the last real values.
    nan = float("nan")
    metrics_D = {
        "d_real": nan,
        "d_fake": nan,
        "gap": nan,
        "gap_ema": float(state.wasserstein_gap_ema),
        "gp": nan,
        "loss_D": nan,
    }

    for epoch in range(epochs):
        state.epoch = epoch
        for batch_idx, real in enumerate(loader):
            state.step += 1
            controller.on_batch_start(state)

            # Datasets that return (sample, label, ...) collate to a tuple/list;
            # only the sample tensor is used here.
            if isinstance(real, (tuple, list)):
                real = real[0]
            real = real.to(DEVICE)
            bsz = real.size(0)

            # -------------------------
            # Critic updates (n_critic)
            # -------------------------
            for _ in range(state.n_critic):
                z = torch.randn(bsz, z_dim, device=DEVICE)
                # The generator is not trained in this phase, so skip building
                # its autograd graph instead of building and detaching it.
                with torch.no_grad():
                    fake = G(z)

                d_real = D(real).mean()
                d_fake = D(fake).mean()
                gap = (d_real - d_fake).item()  # Wasserstein gap estimate

                gp = gradient_penalty(D, real, fake)
                loss_D = (d_fake - d_real) + state.lambda_gp * gp

                optim_D.zero_grad(set_to_none=True)
                loss_D.backward()
                optim_D.step()

                # Update EMA of the gap (useful for adaptive control)
                state.wasserstein_gap_ema = (
                    state.ema_beta * state.wasserstein_gap_ema
                    + (1 - state.ema_beta) * gap
                )

                metrics_D = {
                    "d_real": float(d_real.item()),
                    "d_fake": float(d_fake.item()),
                    "gap": float(gap),
                    "gap_ema": float(state.wasserstein_gap_ema),
                    "gp": float(gp.item()),
                    "loss_D": float(loss_D.item()),
                }
                controller.on_after_critic_update(state, metrics_D, optim_D)

            # -------------------------
            # Generator update
            # -------------------------
            z = torch.randn(bsz, z_dim, device=DEVICE)
            fake = G(z)
            loss_G = -D(fake).mean()

            # Critic gradients produced by this backward pass are discarded by
            # optim_D.zero_grad() before the next critic step.
            optim_G.zero_grad(set_to_none=True)
            loss_G.backward()
            optim_G.step()

            metrics_G = {"loss_G": float(loss_G.item())}
            controller.on_after_generator_update(state, metrics_G, optim_G)

            if (batch_idx % log_every) == 0:
                print(
                    f"[Epoch {epoch:03d}/{epochs:03d}] [Batch {batch_idx:04d}/{len(loader):04d}] "
                    f"[n_critic: {state.n_critic}] "
                    f"[gap: {metrics_D['gap']:+.3f} | ema: {metrics_D['gap_ema']:+.3f}] "
                    f"[GP: {metrics_D['gp']:.3f}] "
                    f"[D: {metrics_D['loss_D']:+.3f}] [G: {metrics_G['loss_G']:+.3f}]"
                )

train_wgan_gp(loader, G, D, optim_G, optim_D, state, controller, z_dim=Z_DIM, epochs=2, log_every=100)


In [None]:
# Cell 8 — Quick Sanity Sample (Optional)
# This just samples a few generated sequences so you can inspect shapes.
# sample_generator is itself decorated with torch.no_grad(), so no extra context is needed.
fake = sample_generator(G, n=4, z_dim=Z_DIM)
sample_shape = tuple(fake.shape)  # (N, C, L)
print("Generated batch shape:", sample_shape)
print("Example stats:", fake.mean().item(), fake.std().item())
