# WGAN-GP (Standard) + Adaptive-Discriminator Hooks (Template)

This notebook provides:
- A **standard WGAN-GP** training loop (critic + gradient penalty).
- A **pluggable "Adaptive Discriminator" controller** with clearly marked hooks so you can experiment with different discriminator/critic adjustment strategies (e.g., dynamic `n_critic`, LR, GP weight, architecture toggles, etc.).

> Notes: WGAN-GP uses a gradient penalty to enforce the 1-Lipschitz constraint instead of weight clipping (see WGAN-GP objective in the referenced survey paper). ÓàÄfileciteÓàÇturn0file0ÓàÅ


In [13]:
# Cell 1 ‚Äî Imports & Reproducibility
import os
import math
import random
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple
from pathlib import Path
import mne

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
print("Device:", DEVICE)


Device: cpu


In [14]:
def scale_to_tanh(x: np.ndarray, clip: float = 3.0, eps: float = 1e-6):
    """
    x: (C, T) numpy
    per-channel standardize, clip, then map to ~[-1, 1]
    """
    mean = x.mean(axis=-1, keepdims=True)
    std  = x.std(axis=-1, keepdims=True)
    std = np.maximum(std, eps)
    x = (x - mean) / std
    x = np.clip(x, -clip, clip) / clip
    return x

def descale_from_tanh(x: np.ndarray, original_mean: np.ndarray, original_std: np.ndarray, clip: float = 3.0):
    """
    x: (C, T) numpy in ~[-1, 1]
    reverse of scale_to_tanh
    """
    x = x * clip
    x = x * original_std + original_mean
    return x

In [15]:
# def load_gdf_files(data_dir: str, subject_ids: list[int]) -> list[str]:
#     files = []
#     for subject_id in subject_ids:
#         subject_str = f"S{subject_id:02d}"
#         subject_dir = os.path.join(data_dir, subject_str)
#         for file_name in os.listdir(subject_dir):
#             if file_name.endswith(".gdf"):
#                 files.append(os.path.join(subject_dir, file_name))
#     return files

In [16]:
# def to_epochs(raws: list[], tmin, tmax, resample_hz, wanted_labels: Tuple[str, str] = ("769", "770")) -> int:
#     return math.ceil((total_steps * batch_size) / dataset_size)

In [17]:
PLUEM_DIR = Path.cwd()
DATA_DIR = PLUEM_DIR.parent / "BCICIV_2b_gdf"
MODEL_DIR = PLUEM_DIR / "models"

MODEL_DIR.mkdir(parents=True, exist_ok=True)

print(f"DATA_DIR exists: {DATA_DIR.exists()}")
print(f"MODEL_DIR exists: {MODEL_DIR.exists()}")

files = list(DATA_DIR.glob("*.gdf"))
files.sort()
print(f"Found {len(files)} .gdf files.")
print("First 10 files:", [f.name for f in files[:10]])

DATA_DIR exists: True
MODEL_DIR exists: True
Found 45 .gdf files.
First 10 files: ['B0101T.gdf', 'B0102T.gdf', 'B0103T.gdf', 'B0104E.gdf', 'B0105E.gdf', 'B0201T.gdf', 'B0202T.gdf', 'B0203T.gdf', 'B0204E.gdf', 'B0205E.gdf']


In [18]:

# raws = []
# for file in files:
#     raw = mne.io.read_raw_gdf(file, preload=True, verbose="error")
#     raws.append(raw)
    
# type(raws)

In [19]:
def load_gdf_files(data_dir: str, resample_hz: int, mode: str, verbose: bool = False)  -> list[str]:
    if mode == "train":
        pattern = "*T.gdf"  # Training files
    elif mode == "eval":
        pattern = "*E.gdf"  # Evaluation files
    else:
        raise ValueError("mode should be 'train' or 'eval'")
    
    all_files = sorted([file for file in data_dir.glob(pattern)])
    
    if len(all_files) == 0:
        raise ValueError(f"No .gdf files found in {data_dir} with pattern {pattern}")
    
    raws = []   
    for file in all_files:
        if verbose:
            print("\n=== Reading:", file.name, "===")
        raw = mne.io.read_raw_gdf(file, preload=True, verbose="error")
        raw.pick("eeg")  
        raw.resample(resample_hz)
        raws.append(raw)
    
    return raws

In [20]:
def create_epoch(raw: mne.io.Raw, event_id: Dict[str, int], tmin: float, tmax: float) -> np.ndarray:
    events, event_dict = mne.events_from_annotations(raw, verbose="error")
    lh = str(event_id.get("LH"))
    rh = str(event_id.get("RH"))
    event_id = {"LH": event_dict.get(lh), "RH": event_dict.get(rh)}
    
    if event_id["LH"] is None or event_id["RH"] is None:
        print(f"At {raw.filenames[0].name} Event ID(lh({type(lh)}):{lh}) or RH({type(rh)}):{rh} not found in annotations: {event_dict}")
        return []
    
    print(f"From:{event_dict} to {event_id}")
    epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax, baseline=None, preload=True, verbose="error")
    return epochs

In [21]:
def make_dataset(raws: list[mne.io.Raw], tmin: float, tmax: float, event_id: Dict[str, int]) -> np.ndarray:
    if not raws:
        raise ValueError("The list of raws is empty.")
    
    epochs_list = []
    for i, raw in enumerate(raws):
        epochs = create_epoch(raw, event_id=event_id, tmin=tmin, tmax=tmax)
        
        if len(epochs) == 0:
            continue
        
        epochs_list.append(epochs.get_data())  # shape: (n_epochs, n_channels, n_times)
    
    if not epochs_list:
        raise ValueError("No epochs were created from the provided raw data.")
    
    dataset = np.concatenate(epochs_list, axis=0)  # shape: (total_epochs, n_channels, n_times)
    print(f"Dataset shape: {dataset.shape}")
    
    return dataset

In [22]:
# # From first
# def _find_gdf_files(root: Path) -> list:
#     return sorted([p for p in root.rglob("*.gdf")])

# def load_bciciv2b_epochs_by_labels(
#     root: Path,
#     file_glob: str = "*T.gdf",
#     tmin: float = 0.0,
#     tmax: float = 4.0,
#     resample_hz: int = 256,
#     picks: str = "eeg",
#     baseline: Optional[Tuple[float, float]] = None,
#     wanted_labels: Tuple[str, str] = ("769", "770"),  # left/right in BCICIV
#     verbose: bool = True,
# ) -> Tuple[np.ndarray, np.ndarray, int, Dict[str, int]]:
#     all_files = _find_gdf_files(root)
#     if file_glob:
#         import fnmatch
#         all_files = [f for f in all_files if fnmatch.fnmatch(f.name, file_glob)]

#     if len(all_files) == 0:
#         raise FileNotFoundError(f"No .gdf files found under {root} (file_glob={file_glob})")

#     X_list, y_list = [], []
#     last_sfreq = None

#     # Fixed class order: wanted_labels[0] -> class 0, wanted_labels[1] -> class 1
#     global_event_ids = {"left": wanted_labels[0], "right": wanted_labels[1]}

#     for gdf_path in all_files:
#         print("\n=== Reading:", gdf_path.name, "===")

#         raw = mne.io.read_raw_gdf(str(gdf_path), preload=True, verbose="ERROR")
#         raw.pick(picks)
#         raw.resample(resample_hz)

#         events, event_dict = mne.events_from_annotations(raw, verbose="ERROR")
#         if verbose:
#             print("Available annotation events:", event_dict)

#         # Convert wanted annotation labels -> internal event codes for THIS file
#         missing = [lab for lab in wanted_labels if lab not in event_dict]
#         if len(missing) > 0:
#             print(f"Skipping {gdf_path.name} (missing labels {missing})")
#             continue

#         event_id = {"left": event_dict[wanted_labels[0]], "right": event_dict[wanted_labels[1]]}
#         print("Using internal event codes:", event_id, "(for labels", wanted_labels, ")")

#         epochs = mne.Epochs(
#             raw,
#             events=events,
#             event_id=event_id,
#             tmin=tmin,
#             tmax=tmax,
#             baseline=baseline,
#             preload=True,
#             reject=None,
#             on_missing="warn",
#             verbose="ERROR",
#         )

#         if len(epochs) == 0:
#             print(f"No epochs created for {gdf_path.name}. Skipping.")
#             continue

#         data = epochs.get_data().astype(np.float32)  # (N, C, L)
#         labels = epochs.events[:, -1]  # internal codes (e.g., 4/5)

#         # Map internal codes to 0/1 consistently using event_id dict
#         code_to_idx = {int(event_id["left"]): 0, int(event_id["right"]): 1}
#         y = np.array([code_to_idx[int(c)] for c in labels], dtype=np.int64)

#         X_list.append(data)
#         y_list.append(y)
#         last_sfreq = int(raw.info["sfreq"])

#     if len(X_list) == 0:
#         raise RuntimeError(
#             "No epochs were created from any file. "
#             "Double-check dataset contents and wanted_labels."
#         )

#     X = np.concatenate(X_list, axis=0)
#     y = np.concatenate(y_list, axis=0)
#     return X, y, last_sfreq, global_event_ids

# # ---- Configure epoching ----
# FILE_GLOB = "*T.gdf"  # good default for training runs

# X_np, y_np, sfreq, USED_LABELS = load_bciciv2b_epochs_by_labels(
#     root=DATA_DIR,
#     file_glob=FILE_GLOB,
#     tmin=TMIN,
#     tmax=TMAX,
#     resample_hz=RESAMPLE_HZ,
#     picks="eeg",
#     baseline=None,
#     wanted_labels=("769", "770"),
#     verbose=True,
# )

# # Fix SEQ_LEN to be divisible by 16 (crop the last sample if needed)
# SEQ_LEN = X_np.shape[2]
# if SEQ_LEN % 16 != 0:
#     new_len = (SEQ_LEN // 16) * 16  # floor to nearest multiple of 16
#     print(f"Cropping SEQ_LEN from {SEQ_LEN} -> {new_len} to satisfy architecture constraint.")
#     X_np = X_np[:, :, :new_len]

# CHANNELS = X_np.shape[1]
# SEQ_LEN  = X_np.shape[2]
# print("CHANNELS:", CHANNELS, "SEQ_LEN:", SEQ_LEN)
# assert SEQ_LEN % 16 == 0

# # print("\nLoaded X:", X_np.shape, "y:", y_np.shape, "sfreq:", sfreq)
# # CHANNELS = X_np.shape[1]
# # SEQ_LEN  = X_np.shape[2]
# # print("CHANNELS:", CHANNELS, "SEQ_LEN:", SEQ_LEN, "(must be divisible by 16)")

In [23]:
RESAMPLE_HZ = 256
TMIN, TMAX = 0.0, 4.0
EVENT_ID = {"LH": 769, "RH": 770}

raws = load_gdf_files(DATA_DIR, RESAMPLE_HZ, verbose=True, mode="train")
X = make_dataset(raws, TMIN, TMAX, event_id=EVENT_ID)
print("Dataset shape:", X.shape)

SEQ_LEN = X.shape[2]
if SEQ_LEN % 16 != 0:
    new_len = (SEQ_LEN // 16) * 16  # floor to nearest multiple of 16
    print(f"Cropping SEQ_LEN from {SEQ_LEN} -> {new_len} to satisfy architecture constraint.")
    X = X[:, :, :new_len]
    
CHANNELS = X.shape[1]
SEQ_LEN  = X.shape[2]
print("CHANNELS:", CHANNELS, "SEQ_LEN:", SEQ_LEN)
assert SEQ_LEN % 16 == 0


=== Reading: B0101T.gdf ===

=== Reading: B0102T.gdf ===

=== Reading: B0103T.gdf ===

=== Reading: B0201T.gdf ===

=== Reading: B0202T.gdf ===

=== Reading: B0203T.gdf ===

=== Reading: B0301T.gdf ===

=== Reading: B0302T.gdf ===

=== Reading: B0303T.gdf ===

=== Reading: B0401T.gdf ===

=== Reading: B0402T.gdf ===

=== Reading: B0403T.gdf ===

=== Reading: B0501T.gdf ===

=== Reading: B0502T.gdf ===

=== Reading: B0503T.gdf ===

=== Reading: B0601T.gdf ===

=== Reading: B0602T.gdf ===

=== Reading: B0603T.gdf ===

=== Reading: B0701T.gdf ===

=== Reading: B0702T.gdf ===

=== Reading: B0703T.gdf ===

=== Reading: B0801T.gdf ===

=== Reading: B0802T.gdf ===

=== Reading: B0803T.gdf ===

=== Reading: B0901T.gdf ===

=== Reading: B0902T.gdf ===

=== Reading: B0903T.gdf ===
From:{np.str_('1023'): 1, np.str_('1077'): 2, np.str_('1078'): 3, np.str_('1079'): 4, np.str_('1081'): 5, np.str_('276'): 6, np.str_('277'): 7, np.str_('32766'): 8, np.str_('768'): 9, np.str_('769'): 10, np.str_('770'

In [24]:
# From first
class EEGTensorDataset(Dataset):
    def __init__(self, X: np.ndarray, y: Optional[np.ndarray] = None, zscore_per_channel: bool = True):
        """
        X: (N, C, L)
        y: optional (N,)
        zscore_per_channel:
          - global per-channel mean/std computed across (N,L) for each channel
        """
        assert X.ndim == 3
        self.X = X.astype(np.float32)
        self.y = None if y is None else y.astype(np.int64)

        if zscore_per_channel:
            mean = self.X.mean(axis=(0, 2), keepdims=True)
            std  = self.X.std(axis=(0, 2), keepdims=True) + 1e-6
            self.X = (self.X - mean) / std

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        x = torch.from_numpy(self.X[idx])  # (C, L)
        if self.y is None:
            return x
        return x, int(self.y[idx])

# Unconditional GAN: we only use x; y is available if you want conditional later
dataset = EEGTensorDataset(X, y=None, zscore_per_channel=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True, num_workers=0, pin_memory=(DEVICE=="cuda"))
print("Batches:", len(loader))

Batches: 57


In [25]:
# Cell 3 ‚Äî Generator & Critic (1D Conv, Standard WGAN-GP Style)
# This is a *standard* WGAN-GP setup: the "discriminator" is a *critic* with a linear output (no sigmoid).

def weights_init(m):
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

class Generator1D(nn.Module):
    def __init__(self, z_dim: int, out_channels: int, seq_len: int, base: int = 64):
        super().__init__()
        assert seq_len % 16 == 0, "For this template, seq_len should be divisible by 16."
        self.z_dim = z_dim
        self.out_channels = out_channels
        self.seq_len = seq_len

        # Project noise to a small temporal resolution then upsample x16 via ConvTranspose1d
        self.init_len = seq_len // 16
        self.fc = nn.Linear(z_dim, base * 8 * self.init_len)

        self.net = nn.Sequential(
            nn.ConvTranspose1d(base * 8, base * 4, kernel_size=4, stride=2, padding=1),  # x2
            nn.BatchNorm1d(base * 4),
            nn.ReLU(True),

            nn.ConvTranspose1d(base * 4, base * 2, kernel_size=4, stride=2, padding=1),  # x4
            nn.BatchNorm1d(base * 2),
            nn.ReLU(True),

            nn.ConvTranspose1d(base * 2, base, kernel_size=4, stride=2, padding=1),      # x8
            nn.BatchNorm1d(base),
            nn.ReLU(True),

            nn.ConvTranspose1d(base, out_channels, kernel_size=4, stride=2, padding=1),  # x16
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.fc(z)
        x = x.view(z.size(0), -1, self.init_len)
        x = self.net(x)
        return x

class Critic1D(nn.Module):
    def __init__(self, in_channels: int, seq_len: int, base: int = 64):
        super().__init__()
        assert seq_len % 16 == 0, "For this template, seq_len should be divisible by 16."

        self.net = nn.Sequential(
            nn.Conv1d(in_channels, base, kernel_size=4, stride=2, padding=1),   # /2
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(base, base * 2, kernel_size=4, stride=2, padding=1),      # /4
            nn.InstanceNorm1d(base * 2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(base * 2, base * 4, kernel_size=4, stride=2, padding=1),  # /8
            nn.InstanceNorm1d(base * 4, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(base * 4, base * 8, kernel_size=4, stride=2, padding=1),  # /16
            nn.InstanceNorm1d(base * 8, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.out = nn.Linear(base * 8 * (seq_len // 16), 1)

    def forward(self, x):
        h = self.net(x)
        h = h.view(x.size(0), -1)
        return self.out(h).view(-1)

# --- Model configs ---
Z_DIM = 128

G = Generator1D(z_dim=Z_DIM, out_channels=CHANNELS, seq_len=SEQ_LEN).to(DEVICE)
D = Critic1D(in_channels=CHANNELS, seq_len=SEQ_LEN).to(DEVICE)

G.apply(weights_init)
D.apply(weights_init)

print("G params:", sum(p.numel() for p in G.parameters())/1e6, "M")
print("D params:", sum(p.numel() for p in D.parameters())/1e6, "M")


G params: 4.918086 M
D params: 0.725185 M


In [26]:
# Cell 4 ‚Äî WGAN-GP Utilities (Gradient Penalty, Losses)
# WGAN-GP objective uses:
#   loss_D = E[D(fake)] - E[D(real)] + lambda_gp * (||‚àá_x_hat D(x_hat)||_2 - 1)^2
#   loss_G = -E[D(fake)]
# This matches the standard WGAN-GP formulation referenced in the survey paper. ÓàÄfileciteÓàÇturn0file0ÓàÅ

def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    bsz = real.size(0)
    eps = torch.rand(bsz, 1, 1, device=real.device)
    x_hat = eps * real + (1 - eps) * fake
    x_hat.requires_grad_(True)

    d_hat = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones_like(d_hat),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    grads = grads.view(bsz, -1)
    gp = ((grads.norm(2, dim=1) - 1.0) ** 2).mean()
    return gp

@torch.no_grad()
def sample_generator(generator: nn.Module, n: int, z_dim: int) -> torch.Tensor:
    z = torch.randn(n, z_dim, device=DEVICE)
    return generator(z).cpu()

In [27]:

def compute_gradient_vector(parameters) -> torch.Tensor:
    """Flatten all gradients into a single vector."""
    grads = []
    for p in parameters:
        if p.grad is not None:
            grads.append(p.grad.detach().view(-1))
    if len(grads) == 0:
        return None
    return torch.cat(grads)

def cosine_similarity_gradients(grad_vec1: torch.Tensor, grad_vec2: torch.Tensor) -> float:
    """Compute cosine similarity between two gradient vectors."""
    if grad_vec1 is None or grad_vec2 is None:
        return 0.0
    cos_sim = F.cosine_similarity(
        grad_vec1.unsqueeze(0), 
        grad_vec2.unsqueeze(0)
    ).item()
    return cos_sim

In [28]:
def update_n_critic(current_step: int, n_critic: int, n_critic_initial: int, n_critic_final: int, n_critic_rampup_steps: int) -> int:
    """Dynamically adjust the number of critic updates per generator update."""
    if current_step < n_critic_rampup_steps:
        n_critic = n_critic_initial + (n_critic_final - n_critic_initial) * (current_step / n_critic_rampup_steps)
        n_critic = int(round(n_critic))
    else:
        n_critic = n_critic_final
    return n_critic

In [None]:
# Cell 5 ‚Äî Adaptive Discriminator Controller (Gradient Direction Consistency)
# Uses cosine similarity between consecutive gradient vectors to measure training stability.
# High cosine similarity (‚âà1) = consistent gradient direction (stable training)
# Low/negative cosine similarity = oscillating gradients (potentially unstable)

@dataclass
class TrainState:
    step: int = 0
    epoch: int = 0
    n_critic: int = 5
    n_gen: int = 1
    lambda_gp: float = 10.0

    # Common diagnostics
    wasserstein_gap_ema: float = 0.0
    ema_beta: float = 0.99
    
    # Gradient consistency tracking
    d_grad_cos_sim: float = 0.0  # Discriminator gradient consistency
    g_grad_cos_sim: float = 0.0  # Generator gradient consistency

class AdaptiveDiscriminatorController:
    def __init__(
        self,
        cos_sim_threshold_high: float = 0.9,  # Too consistent ‚Üí reduce n_critic
        cos_sim_threshold_low: float = 0.3,   # Too inconsistent ‚Üí increase n_critic
        k_min: int | None = None,
        k_max: int | None = None,
        ema_beta: float = 0.95,
    ):
        self.cos_sim_threshold_high = cos_sim_threshold_high
        self.cos_sim_threshold_low = cos_sim_threshold_low
        self.k_min = k_min
        self.k_max = k_max
        self.ema_beta = ema_beta
        
        # Store previous gradient vectors
        self.prev_d_grad = None
        self.prev_g_grad = None
        
        # EMA of cosine similarities
        self.d_cos_sim_ema = None
        self.g_cos_sim_ema = None

    def on_batch_start(self, state: TrainState) -> None:
        pass

    def on_after_critic_update(
        self, 
        state: TrainState, 
        metrics: Dict[str, float], 
        optim_D: torch.optim.Optimizer,
        D: nn.Module = None,
    ) -> None:
        if D is None:
            return
        
        # Compute current gradient vector
        curr_d_grad = compute_gradient_vector(D.parameters())
        
        if curr_d_grad is not None and self.prev_d_grad is not None:
            cos_sim = cosine_similarity_gradients(curr_d_grad, self.prev_d_grad)
            
            # Update EMA
            if self.d_cos_sim_ema is None:
                self.d_cos_sim_ema = cos_sim
            else:
                self.d_cos_sim_ema = (
                    self.ema_beta * self.d_cos_sim_ema 
                    + (1 - self.ema_beta) * cos_sim
                )
            
            state.d_grad_cos_sim = self.d_cos_sim_ema
            metrics["d_grad_cos_sim"] = cos_sim
            metrics["d_grad_cos_sim_ema"] = self.d_cos_sim_ema
            
            # Adaptive control based on gradient consistency
            if self.d_cos_sim_ema > self.cos_sim_threshold_high:
                # Gradients too consistent ‚Üí D might be too strong
                if self.k_min is not None:
                    state.n_critic = max(self.k_min, state.n_critic - 1)
                else:
                    state.n_critic = max(1, state.n_critic - 1)
            elif self.d_cos_sim_ema < self.cos_sim_threshold_low:
                # Gradients too inconsistent ‚Üí D might need more updates
                if self.k_max is not None:
                    state.n_critic = min(self.k_max, state.n_critic + 1)
                else:
                    state.n_critic = state.n_critic + 1
        
        # Store for next iteration
        self.prev_d_grad = curr_d_grad.clone() if curr_d_grad is not None else None

    def on_after_generator_update(
        self, 
        state: TrainState, 
        metrics: Dict[str, float], 
        optim_G: torch.optim.Optimizer,
        G: nn.Module = None,
    ) -> None:
        if G is None:
            return
            
        curr_g_grad = compute_gradient_vector(G.parameters())
        
        if curr_g_grad is not None and self.prev_g_grad is not None:
            cos_sim = cosine_similarity_gradients(curr_g_grad, self.prev_g_grad)
            
            if self.g_cos_sim_ema is None:
                self.g_cos_sim_ema = cos_sim
            else:
                self.g_cos_sim_ema = (
                    self.ema_beta * self.g_cos_sim_ema 
                    + (1 - self.ema_beta) * cos_sim
                )
            
            state.g_grad_cos_sim = self.g_cos_sim_ema
            metrics["g_grad_cos_sim"] = cos_sim
            metrics["g_grad_cos_sim_ema"] = self.g_cos_sim_ema
            
        self.prev_g_grad = curr_g_grad.clone() if curr_g_grad is not None else None

controller = AdaptiveDiscriminatorController(k_max=None, k_min=None)

In [30]:
# Cell 6 ‚Äî Optimizers & Hyperparameters (Standard WGAN-GP Defaults)
LR = 1e-4
BETAS = (0.0, 0.9)   # standard WGAN-GP choice

optim_G = torch.optim.Adam(G.parameters(), lr=LR, betas=BETAS)
optim_D = torch.optim.Adam(D.parameters(), lr=LR, betas=BETAS)

state = TrainState(step=0, epoch=0, n_gen=2, n_critic=5, lambda_gp=10.0)

print("LR:", LR, "BETAS:", BETAS, "n_critic:", state.n_critic, "n_gen:", state.n_gen, "lambda_gp:", state.lambda_gp)

LR: 0.0001 BETAS: (0.0, 0.9) n_critic: 5 n_gen: 2 lambda_gp: 10.0


In [31]:
# Cell 7 ‚Äî Training Loop (WGAN-GP + Gradient Direction Consistency Adaptive Control)
# Logs:
#   - wasserstein gap: E[D(real)] - E[D(fake)]
#   - gp: gradient penalty term
#   - d_grad_cos_sim: discriminator gradient direction consistency
#   - g_grad_cos_sim: generator gradient direction consistency
#   - losses for D and G

RUN_TAG = "grad_cos_sim_2"
TAG_DIR = MODEL_DIR / RUN_TAG
CHECKPOINT_DIR = TAG_DIR / "checkpoints"

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
TAG_DIR.mkdir(parents=True, exist_ok=True)
print(f"Run directory: {TAG_DIR}")

def save_checkpoint(path, G, D, optim_G, optim_D, state, history):
    """Save full training checkpoint."""
    payload = {
        "G": G.state_dict(),
        "D": D.state_dict(),
        "optim_G": optim_G.state_dict(),
        "optim_D": optim_D.state_dict(),
        "state": state.__dict__,
        "history": dict(history),
    }
    torch.save(payload, path)

def load_checkpoint(path, G, D, optim_G=None, optim_D=None):
    """Load checkpoint and return state + history."""
    ckpt = torch.load(path, map_location=DEVICE)
    G.load_state_dict(ckpt["G"])
    D.load_state_dict(ckpt["D"])
    if optim_G is not None and "optim_G" in ckpt:
        optim_G.load_state_dict(ckpt["optim_G"])
    if optim_D is not None and "optim_D" in ckpt:
        optim_D.load_state_dict(ckpt["optim_D"])
    state = TrainState(**ckpt["state"])
    hist = ckpt.get("history", {})
    return state, hist

def train_wgan_gp(
    loader: DataLoader,
    G: nn.Module,
    D: nn.Module,
    optim_G: torch.optim.Optimizer,
    optim_D: torch.optim.Optimizer,
    state: TrainState,
    controller: AdaptiveDiscriminatorController,
    z_dim: int,
    epochs: int = 10,
    log_every: int = 50,
    save_every_steps: int = 500,
    best_metric: str = "gap_ema",
):
    G.train(); D.train()
    
    # History for plotting
    history = {
        "step": [], "loss_D": [], "loss_G": [], "gap": [], "gap_ema": [],
        "gp": [], "n_critic": [], "d_grad_cos_sim": [], "g_grad_cos_sim": [],
        "n_critic": [], "n_gen": [],
    }
    
    best_score = getattr(state, 'best_score', -1e18)
    
    for epoch in range(epochs):
        state.epoch = epoch
        for batch_idx, real in enumerate(loader):
            state.step += 1
            controller.on_batch_start(state)

            real = real.to(DEVICE)
            bsz = real.size(0)

            # -------------------------
            # Critic updates (n_critic)
            # -------------------------
            metrics_D = {}
            for _ in range(state.n_critic):
                z = torch.randn(bsz, z_dim, device=DEVICE)
                fake = G(z).detach()

                d_real = D(real).mean()
                d_fake = D(fake).mean()
                gap = (d_real - d_fake).item()

                gp = gradient_penalty(D, real, fake)
                loss_D = (d_fake - d_real) + state.lambda_gp * gp

                optim_D.zero_grad(set_to_none=True)
                loss_D.backward()
                
                # Call controller BEFORE optimizer step (gradients still exist)
                metrics_D = {
                    "d_real": float(d_real.item()),
                    "d_fake": float(d_fake.item()),
                    "gap": float(gap),
                    "gap_ema": float(state.wasserstein_gap_ema),
                    "gp": float(gp.item()),
                    "loss_D": float(loss_D.item()),
                }
                controller.on_after_critic_update(state, metrics_D, optim_D, D=D)
                
                optim_D.step()

                # Update EMA of the gap
                state.wasserstein_gap_ema = state.ema_beta * state.wasserstein_gap_ema + (1 - state.ema_beta) * gap
                metrics_D["gap_ema"] = float(state.wasserstein_gap_ema)

            # -------------------------
            # Generator update
            # -------------------------
            z = torch.randn(bsz, z_dim, device=DEVICE)
            fake = G(z)
            loss_G = -D(fake).mean()

            optim_G.zero_grad(set_to_none=True)
            loss_G.backward()
            
            metrics_G = {"loss_G": float(loss_G.item())}
            # Call controller BEFORE optimizer step (gradients still exist)
            controller.on_after_generator_update(state, metrics_G, optim_G, G=G)
            
            optim_G.step()

            # -------------------------
            # Log history
            # -------------------------
            history["step"].append(state.step)
            history["loss_D"].append(metrics_D.get("loss_D", 0))
            history["loss_G"].append(metrics_G.get("loss_G", 0))
            history["gap"].append(metrics_D.get("gap", 0))
            
            history["gap_ema"].append(metrics_D.get("gap_ema", 0))
            history["gp"].append(metrics_D.get("gp", 0))
            history["n_critic"].append(state.n_critic)
            history["n_gen"].append(state.n_gen)
            history["d_grad_cos_sim"].append(metrics_D.get("d_grad_cos_sim_ema", 0))
            history["g_grad_cos_sim"].append(metrics_G.get("g_grad_cos_sim_ema", 0))
            

            if (batch_idx % log_every) == 0:
                d_cos = metrics_D.get("d_grad_cos_sim_ema", 0)
                g_cos = metrics_G.get("g_grad_cos_sim_ema", 0)
                print(
                    f"[Epoch {epoch:03d}/{epochs:03d}] [Batch {batch_idx:04d}/{len(loader):04d}] "
                    f"[n_critic: {state.n_critic}] [n_gen: {state.n_gen}] "
                    f"[gap: {metrics_D['gap']:+.3f} | ema: {metrics_D['gap_ema']:+.3f}] "
                    f"[GP: {metrics_D['gp']:.3f}] "
                    f"[D_cos: {d_cos:+.3f}] [G_cos: {g_cos:+.3f}] "
                    f"[D: {metrics_D['loss_D']:+.3f}] [G: {metrics_G['loss_G']:+.3f}]"
                )

            # -------------------------
            # Save periodic checkpoint
            # -------------------------
            if save_every_steps > 0 and (state.step % save_every_steps) == 0:
                ckpt_path = CHECKPOINT_DIR / f"ckpt_{RUN_TAG}_step_{state.step}.pt"
                save_checkpoint(str(ckpt_path), G, D, optim_G, optim_D, state, history)
                print(f"üíæ Saved checkpoint: {ckpt_path}")

            # -------------------------
            # Save best checkpoint
            # -------------------------
            if metrics_D:
                score = float(metrics_D.get(best_metric, -1e18))
                if score > best_score:
                    best_score = score
                    best_path = CHECKPOINT_DIR / f"best_{RUN_TAG}.pt"
                    save_checkpoint(str(best_path), G, D, optim_G, optim_D, state, history)
                    print(f"üèÜ New BEST ({best_metric}={score:.4f}) -> {best_path}")
    
    # Save final checkpoint
    final_path = TAG_DIR / f"final_{RUN_TAG}.pt"
    save_checkpoint(str(final_path), G, D, optim_G, optim_D, state, history)
    print(f"‚úÖ Saved final checkpoint: {final_path}")
    
    return history

history = train_wgan_gp(
    loader, G, D, optim_G, optim_D, state, controller, 
    z_dim=Z_DIM, epochs=1000, log_every=50, save_every_steps=500
)

Run directory: /Users/ratchanonkhongsawi/Desktop/CMKL/3rd/S2/Research/gans_eeg/gans_eeg/pluem/models/grad_cos_sim_2
[Epoch 000/1000] [Batch 0000/0057] [n_critic: 1] [n_gen: 2] [gap: +0.602 | ema: +0.018] [GP: 231.649] [D_cos: +0.975] [G_cos: +0.000] [D: +2315.886] [G: +1.056]
üèÜ New BEST (gap_ema=0.0177) -> /Users/ratchanonkhongsawi/Desktop/CMKL/3rd/S2/Research/gans_eeg/gans_eeg/pluem/models/grad_cos_sim_2/checkpoints/best_grad_cos_sim_2.pt


KeyboardInterrupt: 

In [None]:
# Cell 8 ‚Äî Quick Sanity Sample (Optional)
# This just samples a few generated sequences so you can inspect shapes.
with torch.no_grad():
    fake = sample_generator(G, n=4, z_dim=Z_DIM)
print("Generated batch shape:", tuple(fake.shape))  # (N, C, L)
print("Example stats:", fake.mean().item(), fake.std().item())


In [None]:
# Cell 9 ‚Äî Save Models (Inference-only weights)
# Save just the model weights for inference (smaller files, no optimizer state)

# Generator only (for generating new samples)
generator_path = TAG_DIR / f"generator_{RUN_TAG}.pt"
torch.save(G.state_dict(), generator_path)
print(f"‚úÖ Saved Generator: {generator_path}")

# Critic only (for evaluation if needed)
critic_path = TAG_DIR / f"critic_{RUN_TAG}.pt"
torch.save(D.state_dict(), critic_path)
print(f"‚úÖ Saved Critic: {critic_path}")

# Save training history as well
import json
history_path = TAG_DIR / f"history_{RUN_TAG}.json"
with open(history_path, "w") as f:
    json.dump(history, f)
print(f"‚úÖ Saved history: {history_path}")

print(f"\nüìÅ All models saved to: {TAG_DIR}")

In [None]:
# Cell 8 ‚Äî Training Curve Visualization (with Gradient Consistency)
import matplotlib.pyplot as plt

def plot_training_history(history: dict):
    if len(history.get("step", [])) == 0:
        print("History is empty. Train first.")
        return

    step = history["step"]
    fig, axes = plt.subplots(4, 2, figsize=(16, 12))

    # Losses
    axes[0, 0].plot(step, history["loss_G"], label="loss_G", alpha=0.8)
    axes[0, 0].plot(step, history["loss_D"], label="loss_D", alpha=0.8)
    axes[0, 0].set_title("Generator / Critic Loss")
    axes[0, 0].set_xlabel("step")
    axes[0, 0].set_ylabel("loss")
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    # Wasserstein gap
    axes[0, 1].plot(step, history["gap"], label="gap", alpha=0.5)
    axes[0, 1].plot(step, history["gap_ema"], label="gap_ema", linewidth=2)
    axes[0, 1].set_title("Wasserstein Gap")
    axes[0, 1].set_xlabel("step")
    axes[0, 1].set_ylabel("E[D(real)] - E[D(fake)]")
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # Gradient Penalty
    axes[1, 0].plot(step, history["gp"], label="GP", color="orange")
    axes[1, 0].set_title("Gradient Penalty")
    axes[1, 0].set_xlabel("step")
    axes[1, 0].set_ylabel("gp")
    axes[1, 0].grid(True)

    # Gradient Direction Consistency (Cosine Similarity)
    axes[1, 1].plot(step, history["d_grad_cos_sim"], label="D grad cos_sim", alpha=0.8)
    axes[1, 1].plot(step, history["g_grad_cos_sim"], label="G grad cos_sim", alpha=0.8)
    axes[1, 1].axhline(y=0.9, color='r', linestyle='--', alpha=0.5, label='high threshold')
    axes[1, 1].axhline(y=0.3, color='b', linestyle='--', alpha=0.5, label='low threshold')
    axes[1, 1].set_title("Gradient Direction Consistency (Cosine Similarity)")
    axes[1, 1].set_xlabel("step")
    axes[1, 1].set_ylabel("cosine similarity")
    axes[1, 1].set_ylim(-1.1, 1.1)
    axes[1, 1].legend()
    axes[1, 1].grid(True)

    # n_critic over time
    axes[2, 0].plot(step, history["n_critic"], label="n_critic", color="green", drawstyle='steps-post')
    axes[2, 0].set_title("Adaptive n_critic")
    axes[2, 0].set_xlabel("step")
    axes[2, 0].set_ylabel("n_critic")
    axes[2, 0].set_ylim(0, 8)
    axes[2, 0].grid(True)
    

    # Combined: D cosine sim vs n_critic
    ax2 = axes[2, 1]
    ax2.plot(step, history["d_grad_cos_sim"], label="D grad cos_sim", color="blue", alpha=0.7)
    ax2.set_xlabel("step")
    ax2.set_ylabel("D grad cosine similarity", color="blue")
    ax2.tick_params(axis='y', labelcolor="blue")
    ax2.set_ylim(-1.1, 1.1)
    
    ax2_twin = ax2.twinx()
    ax2_twin.plot(step, history["n_critic"], label="n_critic", color="green", alpha=0.7, drawstyle='steps-post')
    ax2_twin.set_ylabel("n_critic", color="green")
    ax2_twin.tick_params(axis='y', labelcolor="green")
    ax2_twin.set_ylim(0, 8)
    
    axes[2, 1].set_title("D Gradient Consistency vs n_critic")
    axes[2, 1].grid(True)
    
    # n_gen over time
    axes[3, 0].plot(step, history["n_gen"], label="n_gen", color="purple", drawstyle='steps-post')
    axes[3, 0].set_title("Adaptive n_gen")
    axes[3, 0].set_xlabel("step")
    axes[3, 0].set_ylabel("n_gen")
    axes[3, 0].set_ylim(0, 8)
    axes[3, 0].grid(True)
    
    # Combined: G cosine sim vs n_gen
    ax4 = axes[3, 1]
    ax4.plot(step, history["g_grad_cos_sim"], label="G grad cos_sim", color="blue", alpha=0.7)
    ax4.set_xlabel("step")
    ax4.set_ylabel("G grad cosine similarity", color="blue")
    ax4.tick_params(axis='y', labelcolor="blue")
    ax4.set_ylim(-1.1, 1.1)

    plt.tight_layout()
    plt.savefig(TAG_DIR / f"training_history_{RUN_TAG}.png")
    plt.show()

plot_training_history(history)