# WGAN-GP (Standard) + Adaptive-Discriminator Hooks (Template)

This notebook provides:
- A **standard WGAN-GP** training loop (critic + gradient penalty).
- A **pluggable "Adaptive Discriminator" controller** with clearly marked hooks so you can experiment with different discriminator/critic adjustment strategies (e.g., dynamic `n_critic`, LR, GP weight, architecture toggles, etc.).

> Notes: WGAN-GP uses a gradient penalty to enforce the 1-Lipschitz constraint instead of weight clipping (see WGAN-GP objective in the referenced survey paper). ÓàÄfileciteÓàÇturn0file0ÓàÅ


In [1]:
# Cell 1 ‚Äî Imports & Reproducibility
import os
import math
import random
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple
from pathlib import Path
import mne

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
print("Device:", DEVICE)


Device: cpu


In [2]:
def save_models(tag_dir: Path, run_tag: str, G: nn.Module, D: nn.Module, history: Dict[str, list]):
    # Generator only (for generating new samples)
    generator_path = tag_dir / f"generator_{run_tag}.pt"
    torch.save(G.state_dict(), generator_path)
    print(f"‚úÖ Saved Generator: {generator_path}")

    # Critic only (for evaluation if needed)
    critic_path = tag_dir / f"critic_{run_tag}.pt"
    torch.save(D.state_dict(), critic_path)
    print(f"‚úÖ Saved Critic: {critic_path}")

    # Save training history as well
    import json
    history_path = tag_dir / f"history_{run_tag}.json"
    with open(history_path, "w") as f:
        json.dump(history, f)
    print(f"‚úÖ Saved history: {history_path}")

    print(f"\nüìÅ All models saved to: {tag_dir}")

In [3]:
def plot_training_history(history: dict, TAG_DIR: Path , RUN_TAG: str = "default"):
    if len(history.get("step", [])) == 0:
        print("History is empty. Train first.")
        return

    step = history["step"]
    fig, axes = plt.subplots(4, 2, figsize=(16, 12))

    # Losses
    axes[0, 0].plot(step, history["loss_G"], label="loss_G", alpha=0.8)
    axes[0, 0].plot(step, history["loss_D"], label="loss_D", alpha=0.8)
    axes[0, 0].set_title("Generator / Critic Loss")
    axes[0, 0].set_xlabel("step")
    axes[0, 0].set_ylabel("loss")
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    # Wasserstein gap
    axes[0, 1].plot(step, history["gap"], label="gap", alpha=0.5)
    axes[0, 1].plot(step, history["gap_ema"], label="gap_ema", linewidth=2)
    axes[0, 1].set_title("Wasserstein Gap")
    axes[0, 1].set_xlabel("step")
    axes[0, 1].set_ylabel("E[D(real)] - E[D(fake)]")
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # Gradient Penalty
    axes[1, 0].plot(step, history["gp"], label="GP", color="orange")
    axes[1, 0].set_title("Gradient Penalty")
    axes[1, 0].set_xlabel("step")
    axes[1, 0].set_ylabel("gp")
    axes[1, 0].grid(True)

    # Gradient Direction Consistency (Cosine Similarity)
    axes[1, 1].plot(step, history["d_grad_cos_sim"], label="D grad cos_sim", alpha=0.8)
    axes[1, 1].plot(step, history["g_grad_cos_sim"], label="G grad cos_sim", alpha=0.8)
    axes[1, 1].axhline(y=0.9, color='r', linestyle='--', alpha=0.5, label='high threshold')
    axes[1, 1].axhline(y=0.3, color='b', linestyle='--', alpha=0.5, label='low threshold')
    axes[1, 1].set_title("Gradient Direction Consistency (Cosine Similarity)")
    axes[1, 1].set_xlabel("step")
    axes[1, 1].set_ylabel("cosine similarity")
    axes[1, 1].set_ylim(-1.1, 1.1)
    axes[1, 1].legend()
    axes[1, 1].grid(True)

    # n_critic over time - auto-scale with padding
    n_critic_data = history["n_critic"]
    axes[2, 0].plot(step, n_critic_data, label="n_critic", color="green", drawstyle='steps-post', linewidth=2)
    axes[2, 0].set_title("Adaptive n_critic")
    axes[2, 0].set_xlabel("step")
    axes[2, 0].set_ylabel("n_critic")
    # Auto-scale with padding to make constant lines visible
    n_min, n_max = min(n_critic_data), max(n_critic_data)
    if n_min == n_max:  # Constant line case
        axes[2, 0].set_ylim(n_min - 1, n_max + 1)
        axes[2, 0].axhline(y=n_min, color='green', linestyle='-', alpha=0.3, linewidth=10)  # Highlight flat line
    else:
        axes[2, 0].set_ylim(max(0, n_min - 1), n_max + 1)
    axes[2, 0].grid(True)
    axes[2, 0].legend()

    # Combined: D cosine sim vs n_critic
    ax2 = axes[2, 1]
    ax2.plot(step, history["d_grad_cos_sim"], label="D grad cos_sim", color="blue", alpha=0.7)
    ax2.set_xlabel("step")
    ax2.set_ylabel("D grad cosine similarity", color="blue")
    ax2.tick_params(axis='y', labelcolor="blue")
    ax2.set_ylim(-1.1, 1.1)
    
    ax2_twin = ax2.twinx()
    ax2_twin.plot(step, n_critic_data, label="n_critic", color="green", alpha=0.7, drawstyle='steps-post', linewidth=2)
    ax2_twin.set_ylabel("n_critic", color="green")
    ax2_twin.tick_params(axis='y', labelcolor="green")
    if n_min == n_max:
        ax2_twin.set_ylim(n_min - 1, n_max + 1)
    else:
        ax2_twin.set_ylim(max(0, n_min - 1), n_max + 1)
    
    axes[2, 1].set_title("D Gradient Consistency vs n_critic")
    axes[2, 1].grid(True)
    
    # n_gen over time - auto-scale with padding
    n_gen_data = history["n_gen"]
    axes[3, 0].plot(step, n_gen_data, label="n_gen", color="purple", drawstyle='steps-post', linewidth=2)
    axes[3, 0].set_title("Adaptive n_gen")
    axes[3, 0].set_xlabel("step")
    axes[3, 0].set_ylabel("n_gen")
    g_min, g_max = min(n_gen_data), max(n_gen_data)
    if g_min == g_max:  # Constant line case
        axes[3, 0].set_ylim(g_min - 1, g_max + 1)
        axes[3, 0].axhline(y=g_min, color='purple', linestyle='-', alpha=0.3, linewidth=10)
    else:
        axes[3, 0].set_ylim(max(0, g_min - 1), g_max + 1)
    axes[3, 0].grid(True)
    axes[3, 0].legend()
    
    # Combined: G cosine sim vs n_gen
    ax4 = axes[3, 1]
    ax4.plot(step, history["g_grad_cos_sim"], label="G grad cos_sim", color="blue", alpha=0.7)
    ax4.set_xlabel("step")
    ax4.set_ylabel("G grad cosine similarity", color="blue")
    ax4.tick_params(axis='y', labelcolor="blue")
    ax4.set_ylim(-1.1, 1.1)
    ax4.grid(True)

    plt.tight_layout()
    plt.savefig(TAG_DIR / f"training_history_{RUN_TAG}.png")
    plt.show()

In [4]:
def plot_comparison(history_adaptive: dict, history_normal: dict, TAG_DIR: Path, RUN_TAG: str = "comparison"):
    """
    Overlay two training histories to compare adaptive vs normal WGAN-GP.
    
    Args:
        history_adaptive: History dict from adaptive (cosine similarity) training
        history_normal: History dict from normal training
        TAG_DIR: Directory to save the plot
        RUN_TAG: Tag for the saved file name
    """
    if len(history_adaptive.get("step", [])) == 0 or len(history_normal.get("step", [])) == 0:
        print("One or both histories are empty. Train first.")
        return

    fig, axes = plt.subplots(4, 2, figsize=(16, 14))
    
    step_a = history_adaptive["step"]
    step_n = history_normal["step"]

    # ---- Row 0: Losses ----
    # Generator Loss
    axes[0, 0].plot(step_a, history_adaptive["loss_G"], label="Adaptive G", alpha=0.8, color="blue")
    axes[0, 0].plot(step_n, history_normal["loss_G"], label="Normal G", alpha=0.8, color="red", linestyle="--")
    axes[0, 0].set_title("Generator Loss Comparison")
    axes[0, 0].set_xlabel("step")
    axes[0, 0].set_ylabel("loss_G")
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    # Critic Loss
    axes[0, 1].plot(step_a, history_adaptive["loss_D"], label="Adaptive D", alpha=0.8, color="blue")
    axes[0, 1].plot(step_n, history_normal["loss_D"], label="Normal D", alpha=0.8, color="red", linestyle="--")
    axes[0, 1].set_title("Critic Loss Comparison")
    axes[0, 1].set_xlabel("step")
    axes[0, 1].set_ylabel("loss_D")
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # ---- Row 1: Wasserstein Gap ----
    # Raw gap
    axes[1, 0].plot(step_a, history_adaptive["gap"], label="Adaptive gap", alpha=0.5, color="blue")
    axes[1, 0].plot(step_n, history_normal["gap"], label="Normal gap", alpha=0.5, color="red")
    axes[1, 0].plot(step_a, history_adaptive["gap_ema"], label="Adaptive EMA", linewidth=2, color="darkblue")
    axes[1, 0].plot(step_n, history_normal["gap_ema"], label="Normal EMA", linewidth=2, color="darkred", linestyle="--")
    axes[1, 0].set_title("Wasserstein Gap Comparison")
    axes[1, 0].set_xlabel("step")
    axes[1, 0].set_ylabel("E[D(real)] - E[D(fake)]")
    axes[1, 0].legend()
    axes[1, 0].grid(True)

    # Gradient Penalty
    axes[1, 1].plot(step_a, history_adaptive["gp"], label="Adaptive GP", alpha=0.7, color="blue")
    axes[1, 1].plot(step_n, history_normal["gp"], label="Normal GP", alpha=0.7, color="red", linestyle="--")
    axes[1, 1].set_title("Gradient Penalty Comparison")
    axes[1, 1].set_xlabel("step")
    axes[1, 1].set_ylabel("gp")
    axes[1, 1].legend()
    axes[1, 1].grid(True)

    # ---- Row 2: Gradient Cosine Similarity ----
    # D grad cos_sim
    axes[2, 0].plot(step_a, history_adaptive["d_grad_cos_sim"], label="Adaptive D cos_sim", alpha=0.8, color="blue")
    axes[2, 0].plot(step_n, history_normal["d_grad_cos_sim"], label="Normal D cos_sim", alpha=0.8, color="red", linestyle="--")
    axes[2, 0].axhline(y=0.9, color='green', linestyle=':', alpha=0.5, label='high threshold')
    axes[2, 0].axhline(y=0.3, color='orange', linestyle=':', alpha=0.5, label='low threshold')
    axes[2, 0].set_title("Discriminator Gradient Consistency")
    axes[2, 0].set_xlabel("step")
    axes[2, 0].set_ylabel("cosine similarity")
    axes[2, 0].set_ylim(-1.1, 1.1)
    axes[2, 0].legend()
    axes[2, 0].grid(True)

    # G grad cos_sim
    axes[2, 1].plot(step_a, history_adaptive["g_grad_cos_sim"], label="Adaptive G cos_sim", alpha=0.8, color="blue")
    axes[2, 1].plot(step_n, history_normal["g_grad_cos_sim"], label="Normal G cos_sim", alpha=0.8, color="red", linestyle="--")
    axes[2, 1].set_title("Generator Gradient Consistency")
    axes[2, 1].set_xlabel("step")
    axes[2, 1].set_ylabel("cosine similarity")
    axes[2, 1].set_ylim(-1.1, 1.1)
    axes[2, 1].legend()
    axes[2, 1].grid(True)

    # ---- Row 3: n_critic and n_gen ----
    # n_critic
    axes[3, 0].plot(step_a, history_adaptive["n_critic"], label="Adaptive n_critic", color="blue", drawstyle='steps-post', linewidth=2)
    axes[3, 0].plot(step_n, history_normal["n_critic"], label="Normal n_critic", color="red", drawstyle='steps-post', linewidth=2, linestyle="--")
    axes[3, 0].set_title("n_critic Over Time")
    axes[3, 0].set_xlabel("step")
    axes[3, 0].set_ylabel("n_critic")
    all_n_critic = history_adaptive["n_critic"] + history_normal["n_critic"]
    axes[3, 0].set_ylim(max(0, min(all_n_critic) - 1), max(all_n_critic) + 1)
    axes[3, 0].legend()
    axes[3, 0].grid(True)

    # n_gen
    axes[3, 1].plot(step_a, history_adaptive["n_gen"], label="Adaptive n_gen", color="blue", drawstyle='steps-post', linewidth=2)
    axes[3, 1].plot(step_n, history_normal["n_gen"], label="Normal n_gen", color="red", drawstyle='steps-post', linewidth=2, linestyle="--")
    axes[3, 1].set_title("n_gen Over Time")
    axes[3, 1].set_xlabel("step")
    axes[3, 1].set_ylabel("n_gen")
    all_n_gen = history_adaptive["n_gen"] + history_normal["n_gen"]
    axes[3, 1].set_ylim(max(0, min(all_n_gen) - 1), max(all_n_gen) + 1)
    axes[3, 1].legend()
    axes[3, 1].grid(True)

    plt.suptitle("Adaptive (Cosine Similarity) vs Normal WGAN-GP", fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(TAG_DIR / f"comparison_{RUN_TAG}.png", dpi=150)
    plt.show()
    
    # Print summary statistics
    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)
    print(f"{'Metric':<25} {'Adaptive':>15} {'Normal':>15}")
    print("-"*60)
    print(f"{'Final loss_G':<25} {history_adaptive['loss_G'][-1]:>15.4f} {history_normal['loss_G'][-1]:>15.4f}")
    print(f"{'Final loss_D':<25} {history_adaptive['loss_D'][-1]:>15.4f} {history_normal['loss_D'][-1]:>15.4f}")
    print(f"{'Final gap_ema':<25} {history_adaptive['gap_ema'][-1]:>15.4f} {history_normal['gap_ema'][-1]:>15.4f}")
    print(f"{'Mean GP':<25} {np.mean(history_adaptive['gp']):>15.4f} {np.mean(history_normal['gp']):>15.4f}")
    print(f"{'Mean D cos_sim':<25} {np.mean(history_adaptive['d_grad_cos_sim']):>15.4f} {np.mean(history_normal['d_grad_cos_sim']):>15.4f}")
    print(f"{'Mean G cos_sim':<25} {np.mean(history_adaptive['g_grad_cos_sim']):>15.4f} {np.mean(history_normal['g_grad_cos_sim']):>15.4f}")
    print(f"{'Final n_critic':<25} {history_adaptive['n_critic'][-1]:>15} {history_normal['n_critic'][-1]:>15}")
    print("="*60)

In [5]:

def compute_gradient_vector(parameters) -> torch.Tensor:
    """Flatten all gradients into a single vector."""
    grads = []
    for p in parameters:
        if p.grad is not None:
            grads.append(p.grad.detach().view(-1))
    if len(grads) == 0:
        return None
    return torch.cat(grads)

def cosine_similarity_gradients(grad_vec1: torch.Tensor, grad_vec2: torch.Tensor) -> float:
    """Compute cosine similarity between two gradient vectors."""
    if grad_vec1 is None or grad_vec2 is None:
        return 0.0
    cos_sim = F.cosine_similarity(
        grad_vec1.unsqueeze(0), 
        grad_vec2.unsqueeze(0)
    ).item()
    return cos_sim

In [6]:
def scale_to_tanh(x: np.ndarray, clip: float = 3.0, eps: float = 1e-6):
    """
    x: (C, T) numpy
    per-channel standardize, clip, then map to ~[-1, 1]
    """
    mean = x.mean(axis=-1, keepdims=True)
    std  = x.std(axis=-1, keepdims=True)
    std = np.maximum(std, eps)
    x = (x - mean) / std
    x = np.clip(x, -clip, clip) / clip
    return x

def descale_from_tanh(x: np.ndarray, original_mean: np.ndarray, original_std: np.ndarray, clip: float = 3.0):
    """
    x: (C, T) numpy in ~[-1, 1]
    reverse of scale_to_tanh
    """
    x = x * clip
    x = x * original_std + original_mean
    return x

In [7]:
# Cell 3 ‚Äî Generator & Critic (1D Conv, Standard WGAN-GP Style)
# This is a *standard* WGAN-GP setup: the "discriminator" is a *critic* with a linear output (no sigmoid).

def weights_init(m):
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

class Generator1D(nn.Module):
    def __init__(self, z_dim: int, out_channels: int, seq_len: int, base: int = 64):
        super().__init__()
        assert seq_len % 16 == 0, "For this template, seq_len should be divisible by 16."
        self.z_dim = z_dim
        self.out_channels = out_channels
        self.seq_len = seq_len

        # Project noise to a small temporal resolution then upsample x16 via ConvTranspose1d
        self.init_len = seq_len // 16
        self.fc = nn.Linear(z_dim, base * 8 * self.init_len)

        self.net = nn.Sequential(
            nn.ConvTranspose1d(base * 8, base * 4, kernel_size=4, stride=2, padding=1),  # x2
            nn.BatchNorm1d(base * 4),
            nn.ReLU(True),

            nn.ConvTranspose1d(base * 4, base * 2, kernel_size=4, stride=2, padding=1),  # x4
            nn.BatchNorm1d(base * 2),
            nn.ReLU(True),

            nn.ConvTranspose1d(base * 2, base, kernel_size=4, stride=2, padding=1),      # x8
            nn.BatchNorm1d(base),
            nn.ReLU(True),

            nn.ConvTranspose1d(base, out_channels, kernel_size=4, stride=2, padding=1),  # x16
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.fc(z)
        x = x.view(z.size(0), -1, self.init_len)
        x = self.net(x)
        return x

class Critic1D(nn.Module):
    def __init__(self, in_channels: int, seq_len: int, base: int = 64):
        super().__init__()
        assert seq_len % 16 == 0, "For this template, seq_len should be divisible by 16."

        self.net = nn.Sequential(
            nn.Conv1d(in_channels, base, kernel_size=4, stride=2, padding=1),   # /2
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(base, base * 2, kernel_size=4, stride=2, padding=1),      # /4
            nn.InstanceNorm1d(base * 2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(base * 2, base * 4, kernel_size=4, stride=2, padding=1),  # /8
            nn.InstanceNorm1d(base * 4, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(base * 4, base * 8, kernel_size=4, stride=2, padding=1),  # /16
            nn.InstanceNorm1d(base * 8, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.out = nn.Linear(base * 8 * (seq_len // 16), 1)

    def forward(self, x):
        h = self.net(x)
        h = h.view(x.size(0), -1)
        return self.out(h).view(-1)

In [8]:
# Cell 4 ‚Äî WGAN-GP Utilities (Gradient Penalty, Losses)
# WGAN-GP objective uses:
#   loss_D = E[D(fake)] - E[D(real)] + lambda_gp * (||‚àá_x_hat D(x_hat)||_2 - 1)^2
#   loss_G = -E[D(fake)]
# This matches the standard WGAN-GP formulation referenced in the survey paper. ÓàÄfileciteÓàÇturn0file0ÓàÅ

def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    bsz = real.size(0)
    eps = torch.rand(bsz, 1, 1, device=real.device)
    x_hat = eps * real + (1 - eps) * fake
    x_hat.requires_grad_(True)

    d_hat = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones_like(d_hat),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    grads = grads.view(bsz, -1)
    gp = ((grads.norm(2, dim=1) - 1.0) ** 2).mean()
    return gp

@torch.no_grad()
def sample_generator(generator: nn.Module, n: int, z_dim: int) -> torch.Tensor:
    z = torch.randn(n, z_dim, device=DEVICE)
    return generator(z).cpu()

In [9]:
# Cell 5 ‚Äî Adaptive Discriminator Controller (Gradient Direction Consistency)
# Uses cosine similarity between consecutive gradient vectors to measure training stability.
# High cosine similarity (‚âà1) = consistent gradient direction (stable training)
# Low/negative cosine similarity = oscillating gradients (potentially unstable)
#
# NOTE: Cosine similarity is now computed in update_critic/update_generator.
# This controller only adjusts n_critic based on the EMA values in state.

@dataclass
class TrainState:
    step: int = 0
    epoch: int = 0
    n_critic: int = 5
    n_gen: int = 1
    lambda_gp: float = 10.0

    # Common diagnostics
    wasserstein_gap_ema: float = 0.0
    ema_beta: float = 0.99
    
    # Gradient consistency tracking
    d_grad_cos_sim: float = 0.0  # Discriminator gradient consistency
    g_grad_cos_sim: float = 0.0  # Generator gradient consistency
    
    # EMA of cosine similarities
    d_cos_sim_ema: float | None = None
    g_cos_sim_ema: float | None = None
    
    # Store previous gradient vectors
    prev_d_grad: torch.Tensor | None = None
    prev_g_grad: torch.Tensor | None = None

class AdaptiveDiscriminatorController:
    """
    Adjusts n_critic based on gradient cosine similarity EMA stored in state.
    Cosine similarity is computed in update_critic/update_generator functions.
    """
    def __init__(self,
        cos_sim_threshold_high: float = 0.9,  # Too consistent ‚Üí reduce n_critic
        cos_sim_threshold_low: float = 0.3,   # Too inconsistent ‚Üí increase n_critic
        k_min: int = 2,
        k_max: int = 10,
        state: TrainState = None,
    ):
        self.cos_sim_threshold_high = cos_sim_threshold_high
        self.cos_sim_threshold_low = cos_sim_threshold_low
        self.k_min = k_min
        self.k_max = k_max
        self.state = state if state is not None else TrainState()

    def on_batch_start(self, state: TrainState) -> None:
        pass

    def on_after_critic_update(self, optim_D: torch.optim.Optimizer, D: nn.Module = None) -> None:
        """Adjust n_critic based on D gradient cosine similarity EMA."""
        if self.state.d_cos_sim_ema is None:
            return
            
        # Adaptive control based on gradient consistency
        if self.state.d_cos_sim_ema > self.cos_sim_threshold_high:
            # Gradients too consistent ‚Üí D might be too strong, reduce training
            self.state.n_critic = max(self.k_min, self.state.n_critic - 1)
        elif self.state.d_cos_sim_ema < self.cos_sim_threshold_low:
            # Gradients too inconsistent ‚Üí D needs more updates
            self.state.n_critic = min(self.k_max, self.state.n_critic + 1)

    def on_after_generator_update(self, optim_G: torch.optim.Optimizer, G: nn.Module = None) -> None:
        """Optional: Could adjust n_gen based on G gradient consistency."""
        pass  # Currently no adjustment for generator

In [10]:
def update_critic(z, real, G, D, optim_D, state, controller: Optional[AdaptiveDiscriminatorController] = None) -> Dict[str, float]:
    fake = G(z).detach() # generate fake samples without gradients
    
    d_real = D(real).mean()
    d_fake = D(fake).mean()
    gap = (d_real - d_fake).item()
    
    gp = gradient_penalty(D, real, fake)
    loss_D = (d_fake - d_real) + state.lambda_gp * gp
    
    optim_D.zero_grad(set_to_none=True)
    loss_D.backward()
    
    # Always compute cosine similarity (for both adaptive and normal training)
    curr_d_grad = compute_gradient_vector(D.parameters())
    if curr_d_grad is not None and state.prev_d_grad is not None:
        cos_sim = cosine_similarity_gradients(curr_d_grad, state.prev_d_grad)
        # Update EMA
        if state.d_cos_sim_ema is None:
            state.d_cos_sim_ema = cos_sim
        else:
            state.d_cos_sim_ema = state.ema_beta * state.d_cos_sim_ema + (1 - state.ema_beta) * cos_sim
        state.d_grad_cos_sim = state.d_cos_sim_ema
    state.prev_d_grad = curr_d_grad.clone() if curr_d_grad is not None else None
    
    # Let controller do adaptive adjustments (if present)
    if controller is not None:
        controller.on_after_critic_update(optim_D, D=D)
    
    optim_D.step()
    
    return {
        "d_real": float(d_real.item()),
        "d_fake": float(d_fake.item()),
        "gap": float(gap),
        "gp": float(gp.item()),
        "loss_D": float(loss_D.item()),
    }

In [11]:
def update_generator(z, optim_G, G, D, state, controller: Optional[AdaptiveDiscriminatorController] = None) -> float:
    
    fake = G(z)
    loss_G = -D(fake).mean()
    
    optim_G.zero_grad(set_to_none=True)
    loss_G.backward()
    
    # Always compute cosine similarity (for both adaptive and normal training)
    curr_g_grad = compute_gradient_vector(G.parameters())
    if curr_g_grad is not None and state.prev_g_grad is not None:
        cos_sim = cosine_similarity_gradients(curr_g_grad, state.prev_g_grad)
        # Update EMA
        if state.g_cos_sim_ema is None:
            state.g_cos_sim_ema = cos_sim
        else:
            state.g_cos_sim_ema = state.ema_beta * state.g_cos_sim_ema + (1 - state.ema_beta) * cos_sim
        state.g_grad_cos_sim = state.g_cos_sim_ema
    state.prev_g_grad = curr_g_grad.clone() if curr_g_grad is not None else None
    
    # Let controller do adaptive adjustments (if present)
    if controller is not None:
        controller.on_after_generator_update(optim_G, G=G)
        
    optim_G.step()
    
    return {"loss_G": float(loss_G.item())}

In [12]:
def save_checkpoint(path, G, D, optim_G, optim_D, state, history):
    """Save full training checkpoint."""
    payload = {
        "G": G.state_dict(),
        "D": D.state_dict(),
        "optim_G": optim_G.state_dict(),
        "optim_D": optim_D.state_dict(),
        "state": state.__dict__,
        "history": dict(history),
    }
    torch.save(payload, path)

In [13]:
def load_checkpoint(path, G, D, optim_G=None, optim_D=None):
    """Load checkpoint and return state + history."""
    ckpt = torch.load(path, map_location=DEVICE)
    G.load_state_dict(ckpt["G"])
    D.load_state_dict(ckpt["D"])
    if optim_G is not None and "optim_G" in ckpt:
        optim_G.load_state_dict(ckpt["optim_G"])
    if optim_D is not None and "optim_D" in ckpt:
        optim_D.load_state_dict(ckpt["optim_D"])
    state = TrainState(**ckpt["state"])
    hist = ckpt.get("history", {})
    return state, hist

In [14]:
def initialize(
    Z_DIM: int, 
    CHANNELS: int, 
    SEQ_LEN: int, 
    LR_G: float, 
    LR_D: float,
    BETAS: Tuple[float, float], 
    DEVICE: str,
    kmax: int = 10,
    kmin: int = 2,
    t_high: float = 0.9,
    t_low: float = 0.3,
):
    G = Generator1D(z_dim=Z_DIM, out_channels=CHANNELS, seq_len=SEQ_LEN).to(DEVICE)
    D = Critic1D(in_channels=CHANNELS, seq_len=SEQ_LEN).to(DEVICE)
    
    G.apply(weights_init) # need checked
    D.apply(weights_init) # need checked
    
    optim_G = torch.optim.Adam(G.parameters(), lr=LR_G, betas=BETAS)
    optim_D = torch.optim.Adam(D.parameters(), lr=LR_D, betas=BETAS)
    
    state = TrainState(step=0, epoch=0, n_gen=1, n_critic=5, lambda_gp=10.0)

    controller = AdaptiveDiscriminatorController(k_max=kmax, k_min=kmin, state=state, cos_sim_threshold_high=t_high, cos_sim_threshold_low=t_low)
    
    print(f"LR_G: {LR_G}, LR_D: {LR_D}, BETAS: {BETAS}")
    print(f"n_critic: {state.n_critic}, n_gen: {state.n_gen}, lambda_gp: {state.lambda_gp}")
    
    return G, D, optim_G, optim_D, state, controller

In [None]:
def train_wgan_gp(
    loader: DataLoader,
    G: nn.Module,
    D: nn.Module,
    optim_G: torch.optim.Optimizer,
    optim_D: torch.optim.Optimizer,
    state: TrainState,
    z_dim: int,
    controller: Optional[AdaptiveDiscriminatorController] = None,
    epochs: int = 10,
    log_every: int = 50,
    save_every_steps: int = 500,
    best_metric: str = "gap_ema",
    model_dir: str = MODEL_DIR,
    run_tag: str = "wgan_gp_adaptiveD",
):
    CHECKPOINT_DIR = Path(model_dir) / "checkpoints"
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    TAG_DIR = Path(model_dir) / run_tag
    TAG_DIR.mkdir(parents=True, exist_ok=True)
    RUN_TAG = run_tag
    
    G.train(); D.train()
    
    # History for plotting
    history = {
        "step": [], "loss_D": [], "loss_G": [], "gap": [], "gap_ema": [],
        "gp": [], "n_critic": [], "d_grad_cos_sim": [], "g_grad_cos_sim": [],
        "n_critic": [], "n_gen": [],
    }
    
    best_score = getattr(state, 'best_score', -1e18)
    
    for epoch in range(epochs):
        state.epoch = epoch
        for batch_idx, real in enumerate(loader):
            state.step += 1
            if controller is not None:
                controller.on_batch_start(state)
                
            real = real.to(DEVICE)
            bsz = real.size(0)

            # -------------------------
            # Critic updates (n_critic)
            # -------------------------
            metrics_D = {}
            for _ in range(state.n_critic):
                z = torch.randn(bsz, z_dim, device=DEVICE)
                metrics_D = update_critic(z, real, G, D, optim_D, state, controller)

                # Update EMA of the gap
                state.wasserstein_gap_ema = state.ema_beta * state.wasserstein_gap_ema + (1 - state.ema_beta) * metrics_D["gap"]
                metrics_D["gap_ema"] = float(state.wasserstein_gap_ema) 

            # -------------------------
            # Generator update
            # -------------------------
            z = torch.randn(bsz, z_dim, device=DEVICE)
            metrics_G = update_generator(z, optim_G, G=G, D=D, state=state, controller=controller)

            # -------------------------
            # Log history (read cosine similarity from state, not metrics)
            # -------------------------
            history["step"].append(state.step)
            history["loss_D"].append(metrics_D.get("loss_D", 0))
            history["loss_G"].append(metrics_G.get("loss_G", 0))
            history["gap"].append(metrics_D.get("gap", 0))
            
            history["gap_ema"].append(metrics_D.get("gap_ema", 0))
            history["gp"].append(metrics_D.get("gp", 0))
            history["n_critic"].append(state.n_critic)
            history["n_gen"].append(state.n_gen)
            # Read from state where the controller stores the values
            history["d_grad_cos_sim"].append(state.d_grad_cos_sim)
            history["g_grad_cos_sim"].append(state.g_grad_cos_sim)

            if (batch_idx % log_every) == 0:
                d_cos = state.d_grad_cos_sim
                g_cos = state.g_grad_cos_sim
                print(
                    f"[Epoch {epoch:03d}/{epochs:03d}] [Batch {batch_idx:04d}/{len(loader):04d}] "
                    f"[n_critic: {state.n_critic}] [n_gen: {state.n_gen}] "
                    f"[gap: {metrics_D['gap']:+.3f} | ema: {metrics_D['gap_ema']:+.3f}] "
                    f"[GP: {metrics_D['gp']:.3f}] "
                    f"[D_cos: {d_cos:+.3f}] [G_cos: {g_cos:+.3f}] "
                    f"[D: {metrics_D['loss_D']:+.3f}] [G: {metrics_G['loss_G']:+.3f}]"
                )

            # -------------------------
            # Save periodic checkpoint
            # -------------------------
            if save_every_steps > 0 and (state.step % save_every_steps) == 0:
                ckpt_path = CHECKPOINT_DIR / f"ckpt_{RUN_TAG}_step_{state.step}.pt"
                save_checkpoint(str(ckpt_path), G, D, optim_G, optim_D, state, history)
                print(f"üíæ Saved checkpoint: {ckpt_path}")

            # -------------------------
            # Save best checkpoint
            # -------------------------
            if metrics_D:
                score = float(metrics_D.get(best_metric, -1e18))
                if score > best_score:
                    best_score = score
                    best_path = CHECKPOINT_DIR / f"best_{RUN_TAG}.pt"
                    save_checkpoint(str(best_path), G, D, optim_G, optim_D, state, history)
                    print(f"üèÜ New BEST ({best_metric}={score:.4f}) -> {best_path}")
    
    # Save final checkpoint
    final_path = TAG_DIR / f"final_{RUN_TAG}.pt"
    save_checkpoint(str(final_path), G, D, optim_G, optim_D, state, history)
    print(f"‚úÖ Saved final checkpoint: {final_path}")
    
    return history

NameError: name 'MODEL_DIR' is not defined

In [16]:
def train_both(init_cos, init_og, tag, tag_dir, epochs):
    cos_tag = f"cos:{tag}"
    cos_dir = tag_dir / cos_tag
    cos_checkpoints = cos_dir / "checkpoints"

    cos_checkpoints.mkdir(parents=True, exist_ok=True)
    cos_dir.mkdir(parents=True, exist_ok=True)
    print(f"Run directory: {cos_tag}")
    
    G_cos, D_cos, optim_G_cos, optim_D_cos, state_cos, controller = initialize(
        Z_DIM=init_cos["z_dim"], CHANNELS=init_cos["channels"], SEQ_LEN=init_cos["seq_len"], LR_D=init_cos["lr_D"], LR_G=init_cos["lr_G"], BETAS=init_cos["betas"], DEVICE=init_cos["device"],
        kmax=init_cos["kmax"], kmin=init_cos["kmin"], t_high=init_cos["t_high"], t_low=init_cos["t_low"]
    )

    history_cos = train_wgan_gp(
        loader, G_cos , D_cos, optim_G_cos, optim_D_cos, state_cos,
        z_dim=init_cos["z_dim"], controller=controller,
        epochs=epochs, log_every=50, save_every_steps=500
    )
    
    save_models(cos_dir, cos_tag, G_cos, D_cos, history_cos)
    
    og_tag = f"og:{tag}"
    og_dir = tag_dir / og_tag
    og_checkpoints = og_dir / "checkpoints"

    og_checkpoints.mkdir(parents=True, exist_ok=True)
    og_dir.mkdir(parents=True, exist_ok=True)
    print(f"Run directory: {og_tag}")
    
    G_og, D_og, optim_G_og, optim_D_og, state_og, _ = initialize(
        Z_DIM=init_og["z_dim"], CHANNELS=init_og["channels"], SEQ_LEN=init_og["seq_len"], LR_D=init_og["lr_D"], LR_G=init_og["lr_G"], BETAS=init_og["betas"], DEVICE=init_og["device"]
    )

    history_og = train_wgan_gp(
        loader, G_og , D_og, optim_G_og, optim_D_og, state_og,
        z_dim=init_og["z_dim"], controller=None,  # No controller for normal training
        epochs=epochs, log_every=50, save_every_steps=500
    )
    
    save_models(og_dir, og_tag, G_og, D_og, history_og)
    
    # Return histories, directories, and tags
    return {
        "cos": {"history": history_cos, "dir": cos_dir, "tag": cos_tag},
        "og": {"history": history_og, "dir": og_dir, "tag": og_tag},
        "tag_dir": tag_dir,  # Parent directory for comparison plots
        "run_tag": tag,
    }

----

# Dataset Fuc

In [None]:
def load_gdf_files(data_dir: str, resample_hz: int, mode: str, verbose: bool = False)  -> list[str]:
    if mode == "train":
        pattern = "*T.gdf"  # Training files
    elif mode == "eval":
        pattern = "*E.gdf"  # Evaluation files
    else:
        raise ValueError("mode should be 'train' or 'eval'")
    
    all_files = sorted([file for file in data_dir.glob(pattern)])
    
    if len(all_files) == 0:
        raise ValueError(f"No .gdf files found in {data_dir} with pattern {pattern}")
    
    raws = []   
    for file in all_files:
        raw = mne.io.read_raw_gdf(file, preload=True, verbose="error")
        raw.pick("eeg")  
        raw.resample(resample_hz)
        raws.append(raw)
    
    return raws

In [18]:
def create_epoch(raw: mne.io.Raw, event_id: Dict[str, int], tmin: float, tmax: float) -> np.ndarray:
    events, event_dict = mne.events_from_annotations(raw, verbose="error")
    lh = str(event_id.get("LH"))
    rh = str(event_id.get("RH"))
    event_id = {"LH": event_dict.get(lh), "RH": event_dict.get(rh)}
    
    if event_id["LH"] is None or event_id["RH"] is None:
        print(f"At {raw.filenames[0].name} Event ID(lh({type(lh)}):{lh}) or RH({type(rh)}):{rh} not found in annotations: {event_dict}")
        return []
    
    print(f"From:{event_dict} to {event_id}")
    epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax, baseline=None, preload=True, verbose="error")
    return epochs

In [2]:
def make_dataset(raws: list[mne.io.Raw], tmin: float, tmax: float, event_id: Dict[str, int]) -> np.ndarray:
    if not raws:
        raise ValueError("The list of raws is empty.")
    
    epochs_list = []
    for i, raw in enumerate(raws):
        epochs = create_epoch(raw, event_id=event_id, tmin=tmin, tmax=tmax)
        
        if len(epochs) == 0:
            continue
        
        epochs_list.append(epochs.get_data())  # shape: (n_epochs, n_channels, n_times)
    
    if not epochs_list:
        raise ValueError("No epochs were created from the provided raw data.")
    
    dataset = np.concatenate(epochs_list, axis=0)  # shape: (total_epochs, n_channels, n_times)
    print(f"Dataset shape: {dataset.shape}")
    
    return dataset

NameError: name 'mne' is not defined

In [20]:
# From first
class EEGTensorDataset(Dataset):
    def __init__(self, X: np.ndarray, y: Optional[np.ndarray] = None, zscore_per_channel: bool = True):
        """
        X: (N, C, L)
        y: optional (N,)
        zscore_per_channel:
          - global per-channel mean/std computed across (N,L) for each channel
        """
        assert X.ndim == 3
        self.X = X.astype(np.float32)
        self.y = None if y is None else y.astype(np.int64)

        if zscore_per_channel:
            mean = self.X.mean(axis=(0, 2), keepdims=True)
            std  = self.X.std(axis=(0, 2), keepdims=True) + 1e-6
            self.X = (self.X - mean) / std

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        x = torch.from_numpy(self.X[idx])  # (C, L)
        if self.y is None:
            return x
        return x, int(self.y[idx])

# Unconditional GAN: we only use x; y is available if you want conditional later

----

# Load data

In [21]:
PLUEM_DIR = Path.cwd()
DATA_DIR = PLUEM_DIR.parent / "BCICIV_2b_gdf"
MODEL_DIR = PLUEM_DIR / "models"

MODEL_DIR.mkdir(parents=True, exist_ok=True)

print(f"DATA_DIR exists: {DATA_DIR.exists()}")
print(f"MODEL_DIR exists: {MODEL_DIR.exists()}")

files = list(DATA_DIR.glob("*.gdf"))
files.sort()
print(f"Found {len(files)} .gdf files.")
print("First 10 files:", [f.name for f in files[:10]])

DATA_DIR exists: True
MODEL_DIR exists: True
Found 45 .gdf files.
First 10 files: ['B0101T.gdf', 'B0102T.gdf', 'B0103T.gdf', 'B0104E.gdf', 'B0105E.gdf', 'B0201T.gdf', 'B0202T.gdf', 'B0203T.gdf', 'B0204E.gdf', 'B0205E.gdf']


In [None]:
one_file = files[0]
raw = mne.io.read_raw_gdf(one_file, preload=True, verbose="error")
raw.pick("eeg")  # Keep only EEG channels

<Info | 8 non-empty values
 bads: []
 ch_names: EEG:C3, EEG:Cz, EEG:C4, EOG:ch01, EOG:ch02, EOG:ch03
 chs: 6 EEG
 custom_ref_applied: False
 highpass: 0.0 Hz
 lowpass: 50.0 Hz
 meas_date: 2005-10-25 09:35:11 UTC
 nchan: 6
 projs: []
 sfreq: 100.0 Hz
 subject_info: <subject_info | his_id: B01, sex: 0, last_name: X, birthday: 1984-12-01>
>
Using qt as 2D backend.


<mne_qt_browser._pg_figure.MNEQtBrowser(0x30f9ae2d0) at 0x31d002b00>

Channels marked as bad:
[np.str_('EEG:Cz')]
Attempting to create new mne-python configuration file:
/Users/ratchanonkhongsawi/.mne/mne-python.json
Could not read the /Users/ratchanonkhongsawi/.mne/mne-python.json json file during the writing. Assuming it is empty. Got: Expecting value: line 1 column 1 (char 0)


In [None]:
RESAMPLE_HZ = 256
TMIN, TMAX = 0.0, 4.0
EVENT_ID = {"LH": "769", "RH": "770"}

raws = load_gdf_files(DATA_DIR, RESAMPLE_HZ, verbose=True, mode="train")
X = make_dataset(raws, TMIN, TMAX, event_id=EVENT_ID)
print("Dataset shape:", X.shape)

SEQ_LEN = X.shape[2]
if SEQ_LEN % 16 != 0:
    new_len = (SEQ_LEN // 16) * 16  # floor to nearest multiple of 16
    print(f"Cropping SEQ_LEN from {SEQ_LEN} -> {new_len} to satisfy architecture constraint.")
    X = X[:, :, :new_len]
    
CHANNELS = X.shape[1]
SEQ_LEN  = X.shape[2]
print("CHANNELS:", CHANNELS, "SEQ_LEN:", SEQ_LEN)
assert SEQ_LEN % 16 == 0


=== Reading: B0101T.gdf ===

=== Reading: B0102T.gdf ===

=== Reading: B0103T.gdf ===

=== Reading: B0201T.gdf ===

=== Reading: B0202T.gdf ===

=== Reading: B0203T.gdf ===

=== Reading: B0301T.gdf ===

=== Reading: B0302T.gdf ===

=== Reading: B0303T.gdf ===

=== Reading: B0401T.gdf ===

=== Reading: B0402T.gdf ===

=== Reading: B0403T.gdf ===

=== Reading: B0501T.gdf ===

=== Reading: B0502T.gdf ===

=== Reading: B0503T.gdf ===

=== Reading: B0601T.gdf ===

=== Reading: B0602T.gdf ===

=== Reading: B0603T.gdf ===

=== Reading: B0701T.gdf ===

=== Reading: B0702T.gdf ===


KeyboardInterrupt: 

: 

In [None]:
dataset = EEGTensorDataset(X, y=None, zscore_per_channel=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True, num_workers=0, pin_memory=(DEVICE=="cuda"))
print("Batches:", len(loader))

Batches: 57


In [None]:
# --- Model configs ---


# G = Generator1D(z_dim=Z_DIM, out_channels=CHANNELS, seq_len=SEQ_LEN).to(DEVICE)
# D = Critic1D(in_channels=CHANNELS, seq_len=SEQ_LEN).to(DEVICE)

# G.apply(weights_init) # need checked
# D.apply(weights_init) # need checked



In [None]:
Z_DIM = 128

LR_G = 1e-4  # Generator learning rate
LR_D = [5e-5, 2e-5]  # Discriminator/Critic learning rate
BETAS = (0.2, 0.9)   # standard WGAN-GP choice

KMAX = 10
KMIN = 2
T_HIGH = 0.9
T_LOW = 0.0

EPOCH = 1
TAG = "test"

tag_dir = MODEL_DIR / TAG
tag_dir.mkdir(parents=True, exist_ok=True)

# print("G params:", sum(p.numel() for p in G.parameters())/1e6, "M")
# print("D params:", sum(p.numel() for p in D.parameters())/1e6, "M")

In [None]:

init_cos = {
    "z_dim": Z_DIM,
    "channels": CHANNELS,
    "seq_len": SEQ_LEN,
    "lr_D": None,  # to be set in loop
    "lr_G": LR_G,
    "betas": BETAS,
    "device": DEVICE,
    "kmax": KMAX,
    "kmin": KMIN,
    "t_high": T_HIGH,
    "t_low": T_LOW,
}

init_og = {
    "z_dim": Z_DIM,
    "channels": CHANNELS,
    "seq_len": SEQ_LEN,
    "lr_D": None,  # to be set in loop
    "lr_G": LR_G,
    "betas": BETAS,
    "device": DEVICE,
}

results = []  # to store results for different LR_D [5e-5, 2e-5]
for lr_D in LR_D:
    run_tag = f"{TAG}_lr_D_{lr_D}"
    init_cos["lr_D"] = lr_D
    init_og["lr_D"] = lr_D
    result = train_both(
        init_cos=init_cos,
        init_og=init_og,
        tag=run_tag,
        tag_dir=tag_dir,
        epochs=EPOCH,
    )
    results.append(result)

Run directory: cos:test
LR_G: 0.0001, LR_D: 5e-05, BETAS: (0.0, 0.9)
n_critic: 5, n_gen: 1, lambda_gp: 10.0


KeyboardInterrupt: 

In [None]:
# Cell 8 ‚Äî Training Curve Visualization (with Gradient Consistency)
# LR_D = 5e-5
r0 = results[0]

# Plot adaptive (cosine) - saves to cos_dir
plot_training_history(r0["cos"]["history"], r0["cos"]["dir"], r0["cos"]["tag"])

In [None]:
# Plot normal (og) - saves to og_dir
plot_training_history(r0["og"]["history"], r0["og"]["dir"], r0["og"]["tag"])

In [None]:
# Comparison plot - saves to parent tag_dir
plot_comparison(
    r0["cos"]["history"],
    r0["og"]["history"],
    r0["tag_dir"],
    f"comparison_{r0['run_tag']}"
)

---

In [None]:
# LR_D = 2e-5
r1 = results[1]

# Plot adaptive (cosine) - saves to cos_dir
plot_training_history(r1["cos"]["history"], r1["cos"]["dir"], r1["cos"]["tag"])

In [None]:
# Plot normal (og) - saves to og_dir
plot_training_history(r1["og"]["history"], r1["og"]["dir"], r1["og"]["tag"])

In [None]:
# Comparison plot - saves to parent tag_dir
plot_comparison(
    r1["cos"]["history"],
    r1["og"]["history"],
    r1["tag_dir"],
    f"comparison_{r1['run_tag']}"
)

---

metric: 

### Normal