# Phase 3: Synthetic EEG Data Generation

## Objective
Generate synthetic EEG data using methods from recent literature:
1. Correlation sampling method (Statistical Approach)
2. WGAN-GP approach (simplified)
3. Evaluation via TSTR, TRTR, clustering, and statistical tests

## Literature Review Summary
- **Correlation Sampling**: Analyze frequency band correlations, generate signals preserving structure
- **WGAN-GP**: More reliable than vanilla GAN for EEG generation
- **Evaluation**: Random Forest classifier, PERMANOVA, clustering overlap, TSTR/TRTR


In [74]:

import json
import os
import random
import warnings
from collections import Counter, defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display
from scipy import signal, stats
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
from sklearn.covariance import LedoitWolf
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    adjusted_rand_score,
    normalized_mutual_info_score,
    roc_auc_score,
    silhouette_score,
)
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import QuantileTransformer
import pickle

try:
    import torch
    from torch import nn, optim
except ImportError:
    torch = None

warnings.filterwarnings('ignore')

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

print('=' * 60)
print('RANDOM SEED SET FOR REPRODUCIBILITY')
print(f'Seed value: {RANDOM_SEED}')
print('=' * 60)

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

DATA_PATH = Path.home() / '.cache/kagglehub/datasets/nnair25/Alcoholics/versions/1'
TRAIN_PATH = DATA_PATH / 'SMNI_CMI_TRAIN'
TEST_PATH = DATA_PATH / 'SMNI_CMI_TEST'

CONDITION_LEVELS = ['S1', 'S2_match', 'S2_nomatch', 'UNKNOWN']

RUN_CFG = dict(
    groups='pooled',          # 'pooled', 'a_only', 'c_only'
    condition='each',         # 'each', 'pooled', or a specific condition token
    generator='interp',       # 'interp' or 'copula'
    balanced=True,            # balance samples per (class, condition)
    k=5,
    noise=0.0,
    cond_onehot=False,        # append condition one-hot vectors to classifier inputs
    transfer_to=None,         # e.g., ['a'] to score control->alcoholic transfer
    save_dir='../output/phase3_conditional'
)

GROUP_CHOICES = {
    'pooled': {'a', 'c'},
    'a_only': {'a'},
    'c_only': {'c'},
}
if RUN_CFG['groups'] not in GROUP_CHOICES:
    raise ValueError(f"Unknown group config: {RUN_CFG['groups']}")
GROUP_FILTER = GROUP_CHOICES[RUN_CFG['groups']]

if RUN_CFG['condition'] in CONDITION_LEVELS:
    CONDITION_FILTER = {RUN_CFG['condition']}
elif RUN_CFG['condition'] in {'each', 'pooled'}:
    CONDITION_FILTER = None
else:
    raise ValueError(f"Unknown condition scope: {RUN_CFG['condition']}")

TRANSFER_TARGETS = set(RUN_CFG.get('transfer_to') or [])
SAVE_ROOT = Path(RUN_CFG['save_dir'])
SAVE_ROOT.mkdir(parents=True, exist_ok=True)

print('Phase 3: Synthetic EEG Generation (Condition-aware)')
print('=' * 60)
print(f"Group filter      : {sorted(GROUP_FILTER)}")
print(f"Condition filter  : {sorted(CONDITION_FILTER) if CONDITION_FILTER else 'All'}")
print(f"Generator         : {RUN_CFG['generator']}")
print(f"Balanced sampling : {RUN_CFG['balanced']}")
print(f"k / noise         : ({RUN_CFG['k']}, {RUN_CFG['noise']})")
print(f"Transfer targets  : {sorted(TRANSFER_TARGETS) if TRANSFER_TARGETS else 'None'}")
print(f"Save root         : {SAVE_ROOT.resolve()}")


RANDOM SEED SET FOR REPRODUCIBILITY
Seed value: 42
Phase 3: Synthetic EEG Generation (Condition-aware)
Group filter      : ['a', 'c']
Condition filter  : All
Generator         : interp
Balanced sampling : True
k / noise         : (5, 0.0)
Transfer targets  : None
Save root         : /Users/jacksonzhao/Desktop/Synthetic_EEG_Generation/output/phase3_conditional


## 1. Load Analysis Results from Phase 2


In [75]:
# Load Phase 2 results (relative path to output folder)
with open('../output/phase2_analysis_results.pkl', 'rb') as f:
    analysis_results = pickle.load(f)

FREQUENCY_BANDS = analysis_results['frequency_bands']
SAMPLING_RATE = analysis_results['sampling_rate']

print("Loaded Phase 2 analysis results")
print(f"Frequency bands: {FREQUENCY_BANDS}")
print(f"Sampling rate: {SAMPLING_RATE} Hz")

Loaded Phase 2 analysis results
Frequency bands: {'Delta': (0.5, 4), 'Theta': (4, 8), 'Alpha': (8, 13), 'Beta': (13, 30), 'Gamma': (30, 50)}
Sampling rate: 256 Hz


In [76]:
# ---- Shared constants used across sections ----
MIN_EVAL_SAMPLES = 25      # small guard for per-bucket/condition eval
BAND_NAMES = ["Delta","Theta","Alpha","Beta","Gamma"]


## 2. Prepare Real Data for Training


### Data & Feature Summary

- **Corpus**: All available CSV files in `SMNI_CMI_TRAIN` (≈468 total; 235 alcoholic / 233 control). Adjust `MAX_FILES_PER_CLASS` if you need a subset.
- **Epochs**: Dynamically determined from `sensor position × trial` combinations across the full corpus.
- **Features**: 5-D band power vectors computed via Welch’s PSD at 256 Hz covering Δ (0.5–4 Hz), θ (4–8 Hz), α (8–13 Hz), β (13–30 Hz), γ (30–50 Hz).



In [77]:
# Load training files
train_files = sorted(list(TRAIN_PATH.glob('*.csv')))

# Separate alcoholic and control files  
alcoholic_files = []
control_files = []

print("Identifying subject types...")
for file in train_files:
    df_peek = pd.read_csv(file, nrows=1)
    subject_type = df_peek['subject identifier'].iloc[0]
    if subject_type == 'a':
        alcoholic_files.append(file)
    else:
        control_files.append(file)

print(f"Found {len(alcoholic_files)} alcoholic files")
print(f"Found {len(control_files)} control files")

# Configure which files to load for generator training
MAX_FILES_PER_CLASS = None

if MAX_FILES_PER_CLASS is None:
    selected_alcoholic = list(alcoholic_files)
    selected_control = list(control_files)
else:
    selected_alcoholic = list(alcoholic_files[:MAX_FILES_PER_CLASS])
    selected_control = list(control_files[:MAX_FILES_PER_CLASS])

sample_files = selected_alcoholic + selected_control
if not sample_files:
    raise ValueError("No training files selected. Check MAX_FILES_PER_CLASS or dataset path.")

random.shuffle(sample_files)


print(f"\nLoading {len(sample_files)} files for generator training...")
print(f"  Alcoholic files selected: {len(selected_alcoholic)}")
print(f"  Control files selected  : {len(selected_control)}")


Identifying subject types...
Found 235 alcoholic files
Found 233 control files

Loading 468 files for generator training...
  Alcoholic files selected: 235
  Control files selected  : 233


## 3. Extract EEG Features from Real Data


In [78]:

def extract_band_power(signal_data, fs=256, bands=None):
    """Extract power in each frequency band using Welch's method."""
    if bands is None:
        bands = FREQUENCY_BANDS
    freqs, psd = signal.welch(signal_data, fs=fs, nperseg=min(256, len(signal_data)))
    band_powers = {}
    for band_name, (low_freq, high_freq) in bands.items():
        idx = np.logical_and(freqs >= low_freq, freqs <= high_freq)
        band_power = np.trapz(psd[idx], freqs[idx]) if np.any(idx) else 0.0
        band_powers[band_name] = float(band_power)
    return band_powers

MAX_CHANNELS = 5
MAX_TRIALS = 2
MIN_SAMPLES_PER_EPOCH = 128
EPOCH_LENGTH = 256

def _normalize_condition_token(value):
    if value is None or (isinstance(value, float) and np.isnan(value)):
        return 'UNKNOWN'
    token = str(value).strip().lower().replace(',', '')
    if not token:
        return 'UNKNOWN'
    if token.startswith('s1'):
        return 'S1'
    compact = token.replace(' ', '')
    if 'nomatch' in compact or 'no-match' in compact:
        return 'S2_nomatch'
    if token.startswith('s2') or 'match' in token:
        return 'S2_match'
    return 'UNKNOWN'

def get_condition_token(row):
    """Map a trial row to {S1, S2_match, S2_nomatch} tokens."""
    for key in ['condition', 'matching condition', 'stimulus']:
        if key in row and pd.notna(row[key]):
            token = _normalize_condition_token(row[key])
            if token != 'UNKNOWN':
                return token
    stimulus = str(row.get('stimulus', '')).strip().upper()
    if stimulus == 'S1':
        return 'S1'
    if stimulus == 'S2':
        match_val = row.get('match', row.get('matching', 0))
        try:
            match_flag = int(match_val)
        except (TypeError, ValueError):
            match_flag = 0
        return 'S2_match' if match_flag == 1 else 'S2_nomatch'
    return 'UNKNOWN'

all_features = []
all_signals = []
all_labels = []
all_conditions = []
all_groups = []

print('Extracting features from real EEG data...')
progress_interval = max(1, len(sample_files) // 10) if sample_files else 1
for file_idx, file in enumerate(sample_files):
    df = pd.read_csv(file)
    subject_type = df['subject identifier'].iloc[0]

    channels = df['sensor position'].unique()[:MAX_CHANNELS]
    trials = df['trial number'].unique()[:MAX_TRIALS]

    for channel in channels:
        for trial in trials:
            trial_data = df[
                (df['sensor position'] == channel) &
                (df['trial number'] == trial)
            ].sort_values('sample num')

            if len(trial_data) < MIN_SAMPLES_PER_EPOCH:
                continue

            signal_data = trial_data['sensor value'].values[:EPOCH_LENGTH]
            band_powers = extract_band_power(signal_data)
            cond_token = get_condition_token(trial_data.iloc[0])

            all_features.append(list(band_powers.values()))
            all_signals.append(signal_data)
            all_labels.append(1 if subject_type == 'a' else 0)
            all_conditions.append(cond_token)
            all_groups.append(subject_type)

    if ((file_idx + 1) % progress_interval == 0) or (file_idx + 1 == len(sample_files)):
        print(f"  Processed {file_idx + 1}/{len(sample_files)} files...")

all_features = np.array(all_features)
all_signals = np.array(all_signals)
all_labels = np.array(all_labels, dtype=int)
all_conditions = np.array(all_conditions)
all_groups = np.array(all_groups)

if all_features.size == 0:
    raise RuntimeError('No epochs extracted. Check dataset paths or filters.')

group_mask = np.isin(all_groups, list(GROUP_FILTER))
condition_mask = np.ones_like(all_conditions, dtype=bool)
if CONDITION_FILTER:
    condition_mask = np.isin(all_conditions, list(CONDITION_FILTER))
selection_mask = group_mask & condition_mask
if not np.any(selection_mask):
    raise RuntimeError('Filters removed all data. Loosen GROUP_FILTER/CONDITION_FILTER.')

real_features = all_features[selection_mask]
real_signals = all_signals[selection_mask]
labels = all_labels[selection_mask]
cond_tokens = all_conditions[selection_mask]
group_tokens = all_groups[selection_mask]

FULL_DATASET = {
    'features': all_features,
    'labels': all_labels,
    'conditions': all_conditions,
    'groups': all_groups,
}

FILTERED_DATASET = {
    'features': real_features,
    'signals': real_signals,
    'labels': labels,
    'conditions': cond_tokens,
    'groups': group_tokens,
}

print(f"Total epochs extracted : {len(all_features)}")
print(f"Filtered epochs        : {len(real_features)}")
print(f"Feature shape          : {real_features.shape}")
print(f"Signal shape           : {real_signals.shape}")
print('Class distribution (filtered):', Counter(labels.tolist()))
print('Condition distribution (filtered):', Counter(cond_tokens.tolist()))
print('Buckets (class, condition):', Counter(list(zip(labels.tolist(), cond_tokens.tolist()))))


Extracting features from real EEG data...
  Processed 46/468 files...
  Processed 92/468 files...
  Processed 138/468 files...
  Processed 184/468 files...
  Processed 230/468 files...
  Processed 276/468 files...
  Processed 322/468 files...
  Processed 368/468 files...
  Processed 414/468 files...
  Processed 460/468 files...
  Processed 468/468 files...
Total epochs extracted : 2340
Filtered epochs        : 2340
Feature shape          : (2340, 5)
Signal shape           : (2340, 256)
Class distribution (filtered): Counter({1: 1175, 0: 1165})
Condition distribution (filtered): Counter({'S1': 800, 'S2_match': 795, 'S2_nomatch': 745})
Buckets (class, condition): Counter({(1, 'S2_match'): 400, (0, 'S1'): 400, (1, 'S1'): 400, (0, 'S2_match'): 395, (1, 'S2_nomatch'): 375, (0, 'S2_nomatch'): 370})


## 4. Method 1: Correlation Sampling Approach

Based on "A Statistical Approach for Synthetic EEG Data Generation"

Steps:
1. Compute correlation matrix of frequency band features
2. Sample from multivariate normal distribution preserving correlations
3. Generate synthetic features matching real data statistics


In [79]:
def generate_correlation_based_eeg(real_features, n_synthetic=100, random_seed=42):
    """
    Generate synthetic EEG using correlation sampling method
    
    This preserves the correlation structure between frequency bands
    """
    # Set random seed for reproducibility
    np.random.seed(random_seed)
    
    # Compute correlation matrix and statistics
    correlation_matrix = np.corrcoef(real_features.T)
    mean_features = np.mean(real_features, axis=0)
    std_features = np.std(real_features, axis=0)
    
    print("Correlation Matrix of Frequency Bands:")
    band_names = list(FREQUENCY_BANDS.keys())
    for i, band1 in enumerate(band_names):
        for j, band2 in enumerate(band_names):
            if j >= i:
                print(f"{band1:6s} - {band2:6s}: {correlation_matrix[i,j]:6.3f}")
    
    # Generate synthetic features preserving correlation structure
    covariance_matrix = np.outer(std_features, std_features) * correlation_matrix
    
    synthetic_features = np.random.multivariate_normal(
        mean_features,
        covariance_matrix,
        size=n_synthetic
    )
    
    # Ensure non-negative powers
    synthetic_features = np.abs(synthetic_features)
    
    print(f"\\nGenerated {n_synthetic} synthetic feature vectors")
    print(f"Correlation structure preserved")
    
    return synthetic_features, correlation_matrix

# Generate synthetic data
n_synthetic_samples = len(real_features)
print(f"\\nGenerating {n_synthetic_samples} synthetic samples...")

synthetic_features_corr, corr_matrix = generate_correlation_based_eeg(
    real_features,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED
)

print(f"Shape of synthetic features: {synthetic_features_corr.shape}")


\nGenerating 2340 synthetic samples...
Correlation Matrix of Frequency Bands:
Delta  - Delta :  1.000
Delta  - Theta :  0.570
Delta  - Alpha :  0.025
Delta  - Beta  : -0.048
Delta  - Gamma : -0.057
Theta  - Theta :  1.000
Theta  - Alpha :  0.352
Theta  - Beta  : -0.019
Theta  - Gamma : -0.010
Alpha  - Alpha :  1.000
Alpha  - Beta  :  0.041
Alpha  - Gamma :  0.042
Beta   - Beta  :  1.000
Beta   - Gamma :  0.952
Gamma  - Gamma :  1.000
\nGenerated 2340 synthetic feature vectors
Correlation structure preserved
Shape of synthetic features: (2340, 5)


## 5. Method 2: Neural Generators

We benchmark two neural approaches in addition to the correlation sampler:

1. **Mixup Baseline** – deterministic interpolation + Gaussian jitter (no adversary) for a quick sanity check.
2. **WGAN-GP** – fully adversarial training with gradient penalty (Gulrajani et al., 2017) implemented in PyTorch.


In [80]:
def generate_mixup_baseline(real_features, n_synthetic=100, random_seed=42):
    """
    Mixup-style baseline using:
    - Interpolation between randomly sampled real points
    - Addition of controlled gaussian noise
    
    Serves as a lightweight reference model prior to adversarial training.
    """
    np.random.seed(random_seed)

    synthetic_features = []

    for _ in range(n_synthetic):
        idx1, idx2 = np.random.choice(len(real_features), 2, replace=False)
        alpha = np.random.uniform(0.3, 0.7)

        interpolated = alpha * real_features[idx1] + (1 - alpha) * real_features[idx2]

        noise_scale = 0.1 * np.std(real_features, axis=0)
        noise = np.random.normal(0, noise_scale)
        synthetic_sample = interpolated + noise

        synthetic_sample = np.abs(synthetic_sample)
        synthetic_features.append(synthetic_sample)

    return np.array(synthetic_features)

print("Generating synthetic data using mixup baseline...")
synthetic_features_mixup = generate_mixup_baseline(
    real_features,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED
)

print(f"Generated {len(synthetic_features_mixup)} synthetic samples")
print(f"Shape: {synthetic_features_mixup.shape}")


Generating synthetic data using mixup baseline...
Generated 2340 synthetic samples
Shape: (2340, 5)


In [81]:
# --- WGAN-GP generator for frequency-band features ----------------------------------

def generate_wgangp_eeg(
    real_features,
    n_synthetic=100,
    noise_dim=16,
    hidden_dim=64,
    n_critic=5,
    gp_lambda=10.0,
    lr=1e-4,
    batch_size=128,
    epochs=300,
    random_seed=42,
):
    """Train a compact WGAN-GP on band-power features and return synthetic samples."""
    if torch is None:
        raise ImportError(
            "PyTorch is required for WGAN-GP synthesis. Install torch>=2.0 to enable this path."
        )

    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = torch.from_numpy(real_features.astype(np.float32))
    dataset = TensorDataset(data)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    feature_dim = real_features.shape[1]

    class Generator(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(noise_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, feature_dim),
            )

        def forward(self, z):
            x = self.net(z)
            return torch.nn.functional.softplus(x)

    class Critic(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(feature_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, 1),
            )

        def forward(self, x):
            return self.net(x)

    def gradient_penalty(critic, real, fake):
        batch_size = real.size(0)
        epsilon = torch.rand(batch_size, 1, device=real.device)
        epsilon = epsilon.expand_as(real)
        interpolated = epsilon * real + (1 - epsilon) * fake
        interpolated.requires_grad_(True)
        mixed_scores = critic(interpolated)
        grad = torch.autograd.grad(
            outputs=mixed_scores,
            inputs=interpolated,
            grad_outputs=torch.ones_like(mixed_scores),
            create_graph=True,
            retain_graph=True,
        )[0]
        grad = grad.view(batch_size, -1)
        gp = ((grad.norm(2, dim=1) - 1) ** 2).mean()
        return gp

    G = Generator().to(device)
    D = Critic().to(device)

    opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.9))
    opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.9))

    for epoch in range(epochs):
        for i, (real_batch,) in enumerate(loader):
            real_batch = real_batch.to(device)

            for _ in range(n_critic):
                z = torch.randn(real_batch.size(0), noise_dim, device=device)
                fake_batch = G(z).detach()

                opt_D.zero_grad()
                critic_real = D(real_batch).mean()
                critic_fake = D(fake_batch).mean()
                gp = gradient_penalty(D, real_batch, fake_batch)
                loss_D = -(critic_real - critic_fake) + gp_lambda * gp
                loss_D.backward()
                opt_D.step()

            z = torch.randn(real_batch.size(0), noise_dim, device=device)
            opt_G.zero_grad()
            fake_batch = G(z)
            loss_G = -D(fake_batch).mean()
            loss_G.backward()
            opt_G.step()

        if (epoch + 1) % 50 == 0:
            with torch.no_grad():
                z = torch.randn(batch_size, noise_dim, device=device)
                preview = G(z).cpu().numpy()
            preview_mean = preview.mean(axis=0)
            print(
                f"Epoch {epoch + 1:03d}/{epochs} | D: {loss_D.item():.4f} | G: {loss_G.item():.4f} | preview mean={preview_mean.round(3)}"
            )

    with torch.no_grad():
        synth_chunks = []
        remaining = n_synthetic
        while remaining > 0:
            current = min(batch_size, remaining)
            z = torch.randn(current, noise_dim, device=device)
            synth = G(z).cpu().numpy()
            synth_chunks.append(synth)
            remaining -= current

    synthetic = np.vstack(synth_chunks)
    synthetic = np.clip(synthetic, a_min=0.0, a_max=None)
    return synthetic

print("Training WGAN-GP generator (this may take a couple of minutes)...")
if torch is None:
    print("PyTorch not installed; skipping WGAN-GP synthesis.")
    synthetic_features_wgangp = None
else:
    synthetic_features_wgangp = generate_wgangp_eeg(
        real_features,
        n_synthetic=n_synthetic_samples,
        random_seed=RANDOM_SEED,
    )
    print(f"Generated {len(synthetic_features_wgangp)} WGAN-GP samples")
    print(f"Shape: {synthetic_features_wgangp.shape}")



Training WGAN-GP generator (this may take a couple of minutes)...
Epoch 050/300 | D: -22.4700 | G: 58.2359 | preview mean=[8.895 8.354 5.571 8.843 4.666]
Epoch 100/300 | D: -30.9784 | G: 88.3384 | preview mean=[20.698  5.326  4.246  7.774  0.747]
Epoch 150/300 | D: -45.1265 | G: 156.7274 | preview mean=[28.294  7.204  5.142  5.656  1.996]
Epoch 200/300 | D: -42.9474 | G: 277.4485 | preview mean=[26.767  5.771  5.429  4.627  2.58 ]
Epoch 250/300 | D: -57.5499 | G: 426.3501 | preview mean=[24.252  6.067  5.075  6.867  2.8  ]
Epoch 300/300 | D: -80.9785 | G: 641.9765 | preview mean=[27.753  6.476  5.238  5.901  1.241]
Generated 2340 WGAN-GP samples
Shape: (2340, 5)


## 6. Evaluation: Distribution Comparison (KS Test & MMD)

Statistical tests to compare real vs synthetic data distributions


In [82]:
def evaluate_distributions(real_features, synthetic_features, method_name=""):
    """
    Compare distributions using:
    - KS test (Kolmogorov-Smirnov): tests if distributions are from same population
    - MMD (Maximum Mean Discrepancy): measures distance between distributions
    """
    print(f"\\n{'='*60}")
    print(f"Distribution Comparison: {method_name}")
    print("="*60)
    
    # KS test for each feature (frequency band)
    band_names = list(FREQUENCY_BANDS.keys())
    ks_results = []
    
    print("\\nKolmogorov-Smirnov Test Results:")
    print("(p-value > 0.05 suggests distributions are similar)")
    for i, band in enumerate(band_names):
        ks_stat, p_value = stats.ks_2samp(real_features[:, i], synthetic_features[:, i])
        ks_results.append({'band': band, 'ks_stat': ks_stat, 'p_value': p_value})
        
        significance = "✓ Similar" if p_value > 0.05 else "✗ Different"
        print(f"  {band:8s}: KS={ks_stat:.4f}, p={p_value:.4f} {significance}")
    
    # Simplified MMD computation
    def compute_mmd(X, Y):
        """Maximum Mean Discrepancy using pairwise distances"""
        XX = cdist(X, X, metric='euclidean')
        YY = cdist(Y, Y, metric='euclidean')
        XY = cdist(X, Y, metric='euclidean')
        
        mmd = np.mean(XX) + np.mean(YY) - 2 * np.mean(XY)
        return mmd
    
    mmd_score = rbf_mmd(real_features, synthetic_features)
    print(f"\nMMD (RBF) Score: {mmd_score:.6f}")
    print("(Lower MMD indicates more similar distributions)")

    
    return ks_results, mmd_score

distribution_results = {}

distribution_results['Correlation Sampling'] = evaluate_distributions(
    real_features,
    synthetic_features_corr,
    "Correlation Sampling"
)

distribution_results['Mixup Baseline'] = evaluate_distributions(
    real_features,
    synthetic_features_mixup,
    "Mixup Baseline"
)

if synthetic_features_wgangp is not None:
    distribution_results['WGAN-GP'] = evaluate_distributions(
        real_features,
        synthetic_features_wgangp,
        "WGAN-GP"
    )
else:
    print("Skipping WGAN-GP distribution metrics (PyTorch unavailable).")


Distribution Comparison: Correlation Sampling
\nKolmogorov-Smirnov Test Results:
(p-value > 0.05 suggests distributions are similar)
  Delta   : KS=0.3573, p=0.0000 ✗ Different
  Theta   : KS=0.2269, p=0.0000 ✗ Different
  Alpha   : KS=0.3338, p=0.0000 ✗ Different
  Beta    : KS=0.5679, p=0.0000 ✗ Different
  Gamma   : KS=0.9350, p=0.0000 ✗ Different

MMD (RBF) Score: 0.722064
(Lower MMD indicates more similar distributions)
Distribution Comparison: Mixup Baseline
\nKolmogorov-Smirnov Test Results:
(p-value > 0.05 suggests distributions are similar)
  Delta   : KS=0.1385, p=0.0000 ✗ Different
  Theta   : KS=0.1440, p=0.0000 ✗ Different
  Alpha   : KS=0.1321, p=0.0000 ✗ Different
  Beta    : KS=0.1154, p=0.0000 ✗ Different
  Gamma   : KS=0.8192, p=0.0000 ✗ Different

MMD (RBF) Score: 0.325714
(Lower MMD indicates more similar distributions)
Distribution Comparison: WGAN-GP
\nKolmogorov-Smirnov Test Results:
(p-value > 0.05 suggests distributions are similar)
  Delta   : KS=0.1077, p=0.0

## 7. Evaluation: TSTR and TRTR

**TRTR** = Train on Real, Test on Real  
**TSTR** = Train on Synthetic, Test on Real

If TSTR ≈ TRTR, synthetic data quality is high


In [83]:
def evaluate_tstr_trtr(real_features, real_labels, synthetic_features, method_name=""):
    """
    TSTR/TRTR Evaluation from literature
    
    Validates synthetic data by comparing model performance when:
    - Training on real vs synthetic data
    - Testing on real data
    """
    print(f"\\n{'='*60}")
    print(f"TSTR/TRTR Evaluation: {method_name}")
    print("="*60)
    
    # Split real data
    X_train_real, X_test_real, y_train_real, y_test_real = train_test_split(
        real_features, real_labels, test_size=0.3, random_state=RANDOM_SEED
    )
    
    # Create synthetic labels matching real distribution
    n_alcoholic = np.sum(y_train_real == 1)
    n_control = np.sum(y_train_real == 0)
    y_synthetic = np.concatenate([
        np.ones(min(n_alcoholic, len(synthetic_features)//2)),
        np.zeros(min(n_control, len(synthetic_features)//2))
    ])
    X_synthetic = synthetic_features[:len(y_synthetic)]
    
    # TRTR: Train on Real, Test on Real
    print("\\n1. TRTR (Train on Real, Test on Real):")
    clf_trtr = RandomForestClassifier(n_estimators=100, random_state=RANDOM_SEED)
    clf_trtr.fit(X_train_real, y_train_real)
    y_pred_trtr = clf_trtr.predict(X_test_real)
    acc_trtr = accuracy_score(y_test_real, y_pred_trtr)
    print(f"   Accuracy: {acc_trtr:.4f}")
    
    # TSTR: Train on Synthetic, Test on Real
    print("\\n2. TSTR (Train on Synthetic, Test on Real):")
    clf_tstr = RandomForestClassifier(n_estimators=100, random_state=RANDOM_SEED)
    clf_tstr.fit(X_synthetic, y_synthetic)
    y_pred_tstr = clf_tstr.predict(X_test_real)
    acc_tstr = accuracy_score(y_test_real, y_pred_tstr)
    print(f"   Accuracy: {acc_tstr:.4f}")
    
    # Compare
    print(f"\\n3. Performance Comparison:")
    print(f"   TRTR: {acc_trtr:.4f}")
    print(f"   TSTR: {acc_tstr:.4f}")
    print(f"   Difference: {abs(acc_trtr - acc_tstr):.4f}")
    
    if abs(acc_trtr - acc_tstr) < 0.05:
        print("   ✓ Synthetic data quality: EXCELLENT")
    elif abs(acc_trtr - acc_tstr) < 0.10:
        print("   ✓ Synthetic data quality: GOOD")
    else:
        print("   ✗ Synthetic data quality: NEEDS IMPROVEMENT")
    
    return acc_trtr, acc_tstr

# Evaluate both methods
print("\\nEvaluating Correlation Sampling Method:")
acc_trtr_corr, acc_tstr_corr = evaluate_tstr_trtr(
    real_features, labels, synthetic_features_corr, "Correlation"
)

print("\\n" + "="*60)
print("Evaluating Mixup Baseline:")
acc_trtr_mix, acc_tstr_mix = evaluate_tstr_trtr(
    real_features, labels, synthetic_features_mixup, "Mixup Baseline"
)

if synthetic_features_wgangp is not None:
    print("\\n" + "="*60)
    print("Evaluating WGAN-GP Method:")
    acc_trtr_wgan, acc_tstr_wgan = evaluate_tstr_trtr(
        real_features, labels, synthetic_features_wgangp, "WGAN-GP"
    )
else:
    print("\\nSkipping WGAN-GP TSTR/TRTR (PyTorch unavailable).")


\nEvaluating Correlation Sampling Method:
TSTR/TRTR Evaluation: Correlation
\n1. TRTR (Train on Real, Test on Real):
   Accuracy: 0.6966
\n2. TSTR (Train on Synthetic, Test on Real):
   Accuracy: 0.3561
\n3. Performance Comparison:
   TRTR: 0.6966
   TSTR: 0.3561
   Difference: 0.3405
   ✗ Synthetic data quality: NEEDS IMPROVEMENT
Evaluating Mixup Baseline:
TSTR/TRTR Evaluation: Mixup Baseline
\n1. TRTR (Train on Real, Test on Real):
   Accuracy: 0.6966
\n2. TSTR (Train on Synthetic, Test on Real):
   Accuracy: 0.4758
\n3. Performance Comparison:
   TRTR: 0.6966
   TSTR: 0.4758
   Difference: 0.2208
   ✗ Synthetic data quality: NEEDS IMPROVEMENT
Evaluating WGAN-GP Method:
TSTR/TRTR Evaluation: WGAN-GP
\n1. TRTR (Train on Real, Test on Real):
   Accuracy: 0.6966
\n2. TSTR (Train on Synthetic, Test on Real):
   Accuracy: 0.5370
\n3. Performance Comparison:
   TRTR: 0.6966
   TSTR: 0.5370
   Difference: 0.1595
   ✗ Synthetic data quality: NEEDS IMPROVEMENT


## 8. Evaluation: Real vs Synthetic Classification

Train classifier to distinguish real from synthetic.  
**Goal**: Classifier should perform at ~50% (chance level) if synthetic data is indistinguishable


In [84]:
from sklearn.metrics import confusion_matrix, accuracy_score

def evaluate_real_vs_synthetic(real_features, synthetic_features, method_name=""):
    """Train a RF to distinguish real (1) vs synthetic (0); print metrics."""
    if real_features.size == 0 or synthetic_features.size == 0:
        raise ValueError("Both real and synthetic feature matrices must be non-empty.")

    import numpy as np
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier

    print("\n" + "="*60)
    print(f"Real vs Synthetic Classification: {method_name}")
    print("="*60)

    X = np.vstack([real_features, synthetic_features])
    y = np.concatenate([np.ones(len(real_features), dtype=int),
                        np.zeros(len(synthetic_features), dtype=int)])

    Xtr, Xte, ytr, yte = train_test_split(
        X, y, test_size=0.30, random_state=RANDOM_SEED, stratify=y
    )

    clf = RandomForestClassifier(n_estimators=100, random_state=RANDOM_SEED, class_weight="balanced")
    clf.fit(Xtr, ytr)
    yhat = clf.predict(Xte)

    acc = accuracy_score(yte, yhat)
    print(f"\nClassifier Accuracy: {acc:.4f}")
    if 0.45 <= acc <= 0.55:
        print("✓ EXCELLENT: Classifier at chance level (50%)")
        print("  → Synthetic data indistinguishable from real")
    elif 0.40 <= acc <= 0.60:
        print("✓ GOOD: Classifier struggles to distinguish")
    else:
        print("✗ POOR: Classifier easily distinguishes real from synthetic")

    # fix the printed layout by pinning label order
    cm = confusion_matrix(yte, yhat, labels=[1,0])  # rows: Actual Real, Actual Synthetic
    print("\nConfusion Matrix:")
    print("                 Pred Real  Pred Synthetic")
    print(f"Actual Real:        {cm[0,0]:4d}         {cm[0,1]:4d}")
    print(f"Actual Synthetic:   {cm[1,0]:4d}         {cm[1,1]:4d}")
    return acc

acc_corr = evaluate_real_vs_synthetic(
    real_features, synthetic_features_corr, "Correlation Sampling"
)

acc_mix = evaluate_real_vs_synthetic(
    real_features, synthetic_features_mixup, "Mixup Baseline"
)

acc_wgan = None
if synthetic_features_wgangp is not None:
    acc_wgan = evaluate_real_vs_synthetic(
        real_features, synthetic_features_wgangp, "WGAN-GP"
    )
else:
    print("Skipping WGAN-GP detectability test (PyTorch unavailable).")



Real vs Synthetic Classification: Correlation Sampling

Classifier Accuracy: 0.9900
✗ POOR: Classifier easily distinguishes real from synthetic

Confusion Matrix:
                 Pred Real  Pred Synthetic
Actual Real:         700            2
Actual Synthetic:     12          690

Real vs Synthetic Classification: Mixup Baseline

Classifier Accuracy: 0.9338
✗ POOR: Classifier easily distinguishes real from synthetic

Confusion Matrix:
                 Pred Real  Pred Synthetic
Actual Real:         673           29
Actual Synthetic:     64          638

Real vs Synthetic Classification: WGAN-GP

Classifier Accuracy: 0.9217
✗ POOR: Classifier easily distinguishes real from synthetic

Confusion Matrix:
                 Pred Real  Pred Synthetic
Actual Real:         642           60
Actual Synthetic:     50          652


## 6.1 Clustering Alignment

Assess how mixed real/synthetic embeddings cluster and whether synthetic data preserves real cluster structure.


In [85]:
def clustering_alignment_report(real_features, synthetic_features, method_name, cluster_list=(3, 4, 5, 6)):
    combined = np.vstack([real_features, synthetic_features])
    is_synth = np.concatenate([
        np.zeros(len(real_features), dtype=int),
        np.ones(len(synthetic_features), dtype=int),
    ])

    records = []
    for k in cluster_list:
        km = KMeans(n_clusters=k, n_init=10, random_state=RANDOM_SEED)
        cluster_labels = km.fit_predict(combined)

        nmi = normalized_mutual_info_score(is_synth, cluster_labels)
        ari = adjusted_rand_score(is_synth, cluster_labels)
        sil = silhouette_score(combined, cluster_labels)

        km_real = KMeans(n_clusters=k, n_init=10, random_state=RANDOM_SEED)
        km_synth = KMeans(n_clusters=k, n_init=10, random_state=RANDOM_SEED)
        km_real.fit(real_features)
        km_synth.fit(synthetic_features)
        centroid_cost = cdist(km_real.cluster_centers_, km_synth.cluster_centers_)
        ri, ci = linear_sum_assignment(centroid_cost)
        centroid_gap = centroid_cost[ri, ci].mean()

        records.append(
            {
                'k': k,
                'NMI(real_vs_flag)': nmi,
                'ARI(real_vs_flag)': ari,
                'Silhouette(mixed)': sil,
                'CentroidGap': centroid_gap,
            }
        )

    df = pd.DataFrame(records)
    print(f"\n{'-'*70}\nClustering alignment summary: {method_name}\n{'-'*70}")
    display(df)
    display(df.mean().to_frame(name='mean').T)
    return df

clustering_reports = {}
clustering_reports['Correlation Sampling'] = clustering_alignment_report(
    real_features,
    synthetic_features_corr,
    'Correlation Sampling'
)
clustering_reports['Mixup Baseline'] = clustering_alignment_report(
    real_features,
    synthetic_features_mixup,
    'Mixup Baseline'
)
if synthetic_features_wgangp is not None:
    clustering_reports['WGAN-GP'] = clustering_alignment_report(
        real_features,
        synthetic_features_wgangp,
        'WGAN-GP'
    )
else:
    print("Skipping WGAN-GP clustering analysis (PyTorch unavailable).")




----------------------------------------------------------------------
Clustering alignment summary: Correlation Sampling
----------------------------------------------------------------------


Unnamed: 0,k,NMI(real_vs_flag),ARI(real_vs_flag),Silhouette(mixed),CentroidGap
0,3,0.3139,0.2105,0.7173,1425.6092
1,4,0.3959,0.3253,0.6867,1175.5667
2,5,0.4407,0.3948,0.6645,1025.1884
3,6,0.4547,0.421,0.6512,963.2729


Unnamed: 0,k,NMI(real_vs_flag),ARI(real_vs_flag),Silhouette(mixed),CentroidGap
mean,4.5,0.4013,0.3379,0.6799,1147.4093



----------------------------------------------------------------------
Clustering alignment summary: Mixup Baseline
----------------------------------------------------------------------


Unnamed: 0,k,NMI(real_vs_flag),ARI(real_vs_flag),Silhouette(mixed),CentroidGap
0,3,0.036,0.0006,0.9704,1227.8755
1,4,0.0356,0.0004,0.9632,1399.3514
2,5,0.0299,0.0004,0.7166,1133.0229
3,6,0.0298,0.0003,0.7165,956.6479


Unnamed: 0,k,NMI(real_vs_flag),ARI(real_vs_flag),Silhouette(mixed),CentroidGap
mean,4.5,0.0328,0.0004,0.8417,1179.2244



----------------------------------------------------------------------
Clustering alignment summary: WGAN-GP
----------------------------------------------------------------------


Unnamed: 0,k,NMI(real_vs_flag),ARI(real_vs_flag),Silhouette(mixed),CentroidGap
0,3,0.0093,8.0618e-05,0.7772,1549.3855
1,4,0.0166,0.00043475,0.73,1185.9961
2,5,0.0131,0.0013053,0.6167,972.6475
3,6,0.0162,0.0057373,0.4968,838.3828


Unnamed: 0,k,NMI(real_vs_flag),ARI(real_vs_flag),Silhouette(mixed),CentroidGap
mean,4.5,0.0138,0.0019,0.6552,1136.603


## 6.2 PERMANOVA

Permutation-based multivariate ANOVA to test for distributional differences between real and synthetic samples (pooled and per condition).


In [86]:
def permanova_test(real_features, synthetic_features, method_name, permutations=999):
    try:
        from skbio import DistanceMatrix
        from skbio.stats.distance import permanova
    except ImportError:
        print("scikit-bio is not installed; skipping PERMANOVA for", method_name)
        return None

    combined = np.vstack([real_features, synthetic_features])
    ids = [f"real_{i}" for i in range(len(real_features))] + [
        f"synthetic_{i}" for i in range(len(synthetic_features))
    ]
    grouping = pd.Series(
        ['real'] * len(real_features) + ['synthetic'] * len(synthetic_features),
        index=ids
    )
    distance_matrix = cdist(combined, combined, metric='euclidean')
    dm = DistanceMatrix(distance_matrix, ids=ids)

    result = permanova(dm, grouping, permutations=permutations)
    print(f"\nPERMANOVA ({method_name})")
    print(result)
    return result

permanova_results = {}
permanova_results['Correlation Sampling'] = permanova_test(
    real_features,
    synthetic_features_corr,
    'Correlation Sampling'
)
permanova_results['Mixup Baseline'] = permanova_test(
    real_features,
    synthetic_features_mixup,
    'Mixup Baseline'
)
if synthetic_features_wgangp is not None:
    permanova_results['WGAN-GP'] = permanova_test(
        real_features,
        synthetic_features_wgangp,
        'WGAN-GP'
    )



scikit-bio is not installed; skipping PERMANOVA for Correlation Sampling
scikit-bio is not installed; skipping PERMANOVA for Mixup Baseline
scikit-bio is not installed; skipping PERMANOVA for WGAN-GP


## 9. Improved Synthetic Generation Strategies

To reduce real-vs-synthetic separability while preserving frequency-band correlations, we explore two additional generators:

1. **Gaussian Copula Sampling**: preserves empirical marginals per class and matches correlation structure in a latent Gaussian space.
2. **Class-Conditional Interpolation**: SMOTE-like synthesis operating in log-power space with adaptive neighborhood mixing.



In [87]:
def _allocate_samples_by_class(labels, n_total):
    """Allocate synthetic samples per class, preserving empirical ratios."""
    classes, counts = np.unique(labels, return_counts=True)
    ratios = counts / counts.sum()
    expected = ratios * n_total
    allocated = np.floor(expected).astype(int)
    remainder = n_total - allocated.sum()
    if remainder > 0:
        remainders = expected - allocated
        order = np.argsort(remainders)[::-1]
        for idx in order[:remainder]:
            allocated[idx] += 1
    return dict(zip(classes, allocated))

def generate_gaussian_copula_eeg(real_features, labels, n_synthetic=100, random_seed=42):
    """
    Gaussian copula sampling:
    1. Fit class-conditional quantile transformers to map marginals to Gaussian space
    2. Estimate regularised covariance (Ledoit-Wolf) in latent space
    3. Sample multivariate normal per class and invert the transform
    """
    rng = np.random.default_rng(random_seed)
    allocation = _allocate_samples_by_class(labels, n_synthetic)
    synthetic_blocks = []

    print("Generating Gaussian copula samples per class...")
    for cls, n_cls_samples in allocation.items():
        class_features = real_features[labels == cls]
        if len(class_features) == 0 or n_cls_samples == 0:
            continue

        n_quantiles = min(len(class_features), 1000)
        transformer = QuantileTransformer(
            n_quantiles=n_quantiles,
            output_distribution='normal',
            random_state=random_seed
        )
        latent = transformer.fit_transform(class_features)

        cov_estimator = LedoitWolf().fit(latent)
        latent_mean = cov_estimator.location_
        latent_cov = cov_estimator.covariance_

        latent_samples = rng.multivariate_normal(
            latent_mean,
            latent_cov,
            size=n_cls_samples
        )

        samples = transformer.inverse_transform(latent_samples)
        samples = np.clip(samples, a_min=0, a_max=None)
        synthetic_blocks.append(samples)

        print(f"  Class {cls}: real={len(class_features)}, synthetic={n_cls_samples}")

    if not synthetic_blocks:
        raise ValueError("No synthetic samples were generated. Check class labels.")

    synthetic_features = np.vstack(synthetic_blocks)
    return synthetic_features

In [88]:
synthetic_features_copula = generate_gaussian_copula_eeg(
    real_features,
    labels,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED
)

print(f"\nGenerated {len(synthetic_features_copula)} Gaussian copula samples")
print(f"Shape: {synthetic_features_copula.shape}")

Generating Gaussian copula samples per class...
  Class 0: real=1165, synthetic=1165
  Class 1: real=1175, synthetic=1175

Generated 2340 Gaussian copula samples
Shape: (2340, 5)


In [89]:
def generate_classwise_interpolation_eeg(
    real_features,
    labels,
    n_synthetic=100,
    random_seed=42,
    k_neighbors=8,
    noise_scale=0.02
):
    """
    Class-conditional interpolation inspired by SMOTE.
    Operates in log-power space to better capture multiplicative structure.
    """
    rng = np.random.default_rng(random_seed)
    allocation = _allocate_samples_by_class(labels, n_synthetic)
    synthetic_samples = []

    log_features = np.log1p(real_features)

    for cls, n_cls_samples in allocation.items():
        class_mask = labels == cls
        class_features_log = log_features[class_mask]
        if len(class_features_log) == 0 or n_cls_samples == 0:
            continue

        n_neighbors_eff = min(k_neighbors, len(class_features_log) - 1)
        if n_neighbors_eff <= 0:
            # Not enough samples to interpolate, fallback to jittering existing ones
            base_samples = np.repeat(class_features_log, repeats=max(1, n_cls_samples // max(1, len(class_features_log))), axis=0)
            base_samples = base_samples[:n_cls_samples]
            jitter = rng.normal(0, noise_scale, size=base_samples.shape)
            augmented = base_samples + jitter
            synthetic_samples.append(np.expm1(augmented))
            continue

        nbrs = NearestNeighbors(n_neighbors=n_neighbors_eff + 1)
        nbrs.fit(class_features_log)
        class_std = np.std(class_features_log, axis=0, ddof=1)
        class_std[class_std == 0] = 1e-6

        for _ in range(n_cls_samples):
            idx = rng.integers(len(class_features_log))
            neighbors = nbrs.kneighbors(class_features_log[idx].reshape(1, -1), return_distance=False)[0]
            neighbors = neighbors[neighbors != idx]
            if len(neighbors) == 0:
                neighbor_idx = idx
            else:
                neighbor_idx = rng.choice(neighbors)

            alpha = rng.uniform(0.2, 0.8)
            interpolated = (
                alpha * class_features_log[idx] +
                (1 - alpha) * class_features_log[neighbor_idx]
            )

            noise = rng.normal(0, noise_scale, size=class_features_log.shape[1]) * class_std
            synthetic_log = interpolated + noise
            synthetic_samples.append(np.expm1(synthetic_log))

    if not synthetic_samples:
        raise ValueError("Interpolation generator did not create any samples.")

    synthetic_features = np.vstack(synthetic_samples)
    synthetic_features = np.clip(synthetic_features, a_min=0, a_max=None)
    return synthetic_features

In [90]:
synthetic_features_interp = generate_classwise_interpolation_eeg(
    real_features,
    labels,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED,
    k_neighbors=10, 
    noise_scale=0.015
)

print(f"\nGenerated {len(synthetic_features_interp)} interpolation-based samples")
print(f"Shape: {synthetic_features_interp.shape}")


Generated 2340 interpolation-based samples
Shape: (2340, 5)


### 9.1 Distribution and Quality Checks

Re-run the statistical and downstream evaluations for the new generators alongside previous baselines.


In [91]:
ks_copula, mmd_copula = evaluate_distributions(
    real_features,
    synthetic_features_copula,
    "Gaussian Copula"
)

ks_interp, mmd_interp = evaluate_distributions(
    real_features,
    synthetic_features_interp,
    "Classwise Interpolation"
)

Distribution Comparison: Gaussian Copula
\nKolmogorov-Smirnov Test Results:
(p-value > 0.05 suggests distributions are similar)
  Delta   : KS=0.0192, p=0.7800 ✓ Similar
  Theta   : KS=0.0184, p=0.8245 ✓ Similar
  Alpha   : KS=0.0184, p=0.8245 ✓ Similar
  Beta    : KS=0.0192, p=0.7800 ✓ Similar
  Gamma   : KS=0.0167, p=0.9013 ✓ Similar

MMD (RBF) Score: 0.000844
(Lower MMD indicates more similar distributions)
Distribution Comparison: Classwise Interpolation
\nKolmogorov-Smirnov Test Results:
(p-value > 0.05 suggests distributions are similar)
  Delta   : KS=0.0201, p=0.7328 ✓ Similar
  Theta   : KS=0.0295, p=0.2609 ✓ Similar
  Alpha   : KS=0.0278, p=0.3274 ✓ Similar
  Beta    : KS=0.0235, p=0.5378 ✓ Similar
  Gamma   : KS=0.0333, p=0.1485 ✓ Similar

MMD (RBF) Score: 0.000430
(Lower MMD indicates more similar distributions)


In [92]:
print("\nEvaluating Gaussian Copula Method:")
acc_trtr_copula, acc_tstr_copula = evaluate_tstr_trtr(
    real_features,
    labels,
    synthetic_features_copula,
    "Gaussian Copula"
)

print("\n" + "=" * 60)
print("Evaluating Classwise Interpolation Method:")
acc_trtr_interp, acc_tstr_interp = evaluate_tstr_trtr(
    real_features,
    labels,
    synthetic_features_interp,
    "Classwise Interpolation"
)


Evaluating Gaussian Copula Method:
TSTR/TRTR Evaluation: Gaussian Copula
\n1. TRTR (Train on Real, Test on Real):
   Accuracy: 0.6966
\n2. TSTR (Train on Synthetic, Test on Real):
   Accuracy: 0.3490
\n3. Performance Comparison:
   TRTR: 0.6966
   TSTR: 0.3490
   Difference: 0.3476
   ✗ Synthetic data quality: NEEDS IMPROVEMENT

Evaluating Classwise Interpolation Method:
TSTR/TRTR Evaluation: Classwise Interpolation
\n1. TRTR (Train on Real, Test on Real):
   Accuracy: 0.6966
\n2. TSTR (Train on Synthetic, Test on Real):
   Accuracy: 0.3134
\n3. Performance Comparison:
   TRTR: 0.6966
   TSTR: 0.3134
   Difference: 0.3832
   ✗ Synthetic data quality: NEEDS IMPROVEMENT


In [93]:
acc_copula_sep = evaluate_real_vs_synthetic(
    real_features,
    synthetic_features_copula,
    "Gaussian Copula"
)

acc_interp_sep = evaluate_real_vs_synthetic(
    real_features,
    synthetic_features_interp,
    "Classwise Interpolation"
)


Real vs Synthetic Classification: Gaussian Copula

Classifier Accuracy: 0.5491
✓ EXCELLENT: Classifier at chance level (50%)
  → Synthetic data indistinguishable from real

Confusion Matrix:
                 Pred Real  Pred Synthetic
Actual Real:         374          328
Actual Synthetic:    305          397

Real vs Synthetic Classification: Classwise Interpolation

Classifier Accuracy: 0.4309
✓ GOOD: Classifier struggles to distinguish

Confusion Matrix:
                 Pred Real  Pred Synthetic
Actual Real:         256          446
Actual Synthetic:    353          349


In [94]:
# === Synthetic EEG Tuner: per-class (k, noise) grid + early-stopping ===
# Paste this in one cell. Assumes you have:
#   real_features:  (N, 5) band-power features (Delta..Gamma)
#   labels:         (N,)   class labels {0,1}
# Modify BAND_NAMES if needed.

import numpy as np
from scipy import stats
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.covariance import LedoitWolf

# -----------------------------
# Config / Targets
# -----------------------------
BAND_NAMES = ["Delta","Theta","Alpha","Beta","Gamma"]
CLASS_VALUES = np.unique(labels)
assert set(CLASS_VALUES) == {0,1}, "This tuner assumes binary classes {0,1}."

PARAM_GRID_K = [5, 8, 10, 12, 15]
PARAM_GRID_NOISE = [0.00, 0.01, 0.015, 0.02, 0.03]

# Targets (edit to taste)
TARGET_RVS_MAX = 0.60     # real vs synthetic accuracy <= 0.60
TARGET_TSTR_MIN = 0.70    # TSTR >= 0.70
TARGET_GAP_MAX  = 0.15    # |TRTR - TSTR| <= 0.15
TARGET_KS_MINP  = 0.05    # per-band KS p-value > 0.05
TARGET_CORR_SIM = 0.90    # corr-matrix similarity >= 0.90

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# -----------------------------
# Helpers: metrics
# -----------------------------
def ks_pvals_per_band(real, synth):
    pvals = []
    for j in range(real.shape[1]):
        _, p = stats.ks_2samp(real[:, j], synth[:, j])
        pvals.append(p)
    return np.array(pvals)

def rbf_mmd(X, Y, gamma=None):
    """RBF-kernel MMD (biased) with median heuristic for gamma."""
    Z = np.vstack([X, Y])
    # median heuristic
    if gamma is None:
        D2 = np.sum((Z[:,None,:]-Z[None,:,:])**2, axis=2)
        med2 = np.median(D2[D2>0])
        gamma = 1.0 / (med2 + 1e-8)

    def k(a,b):
        D2 = np.sum((a[:,None,:]-b[None,:,:])**2, axis=2)
        return np.exp(-gamma * D2)

    m, n = X.shape[0], Y.shape[0]
    Kxx = k(X,X); Kyy = k(Y,Y); Kxy = k(X,Y)
    return float(Kxx.mean() + Kyy.mean() - 2.0*Kxy.mean())


def corr_matrix_similarity(real, synth):
    C_real = np.corrcoef(real.T)
    C_synth = np.corrcoef(synth.T)
    iu = np.triu_indices_from(C_real, k=1)
    r = np.corrcoef(C_real[iu], C_synth[iu])[0,1]
    return r

def per_band_spearman(real, synth, random_seed=RANDOM_SEED, max_samples=5000):
    if real.size == 0 or synth.size == 0:
        return np.full(real.shape[1], np.nan)

    rng = np.random.default_rng(random_seed)
    n = min(len(real), len(synth), max_samples)

    if len(real) > n:
        idx_real = rng.choice(len(real), size=n, replace=False)
    else:
        idx_real = np.arange(len(real))

    if len(synth) > n:
        idx_synth = rng.choice(len(synth), size=n, replace=False)
    else:
        idx_synth = np.arange(len(synth))

    real_sample = real[idx_real]
    synth_sample = synth[idx_synth]

    vals = []
    for j in range(real.shape[1]):
        rho = stats.spearmanr(real_sample[:, j], synth_sample[:, j]).correlation
        vals.append(rho)
    return np.array(vals)

def real_vs_synth_accuracy(real, synth):
    X = np.vstack([real, synth])
    y = np.hstack([np.zeros(len(real), dtype=int), np.ones(len(synth), dtype=int)])
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.33, random_state=RANDOM_SEED, stratify=y)
    clf = RandomForestClassifier(n_estimators=200, random_state=RANDOM_SEED, class_weight="balanced")
    clf.fit(Xtr, ytr)
    return clf.score(Xte, yte)

def tstr_trtr_accuracy(real_X, real_y, synth_X, synth_y):
    # Classifier: RF; TSTR = train on synthetic, test on real; TRTR = train on real, test on real
    # Split real set for fair TRTR eval
    Xtr_r, Xte_r, ytr_r, yte_r = train_test_split(real_X, real_y, test_size=0.33, random_state=RANDOM_SEED, stratify=real_y)
    # TRTR
    clf_r = RandomForestClassifier(n_estimators=300, random_state=RANDOM_SEED, class_weight="balanced")
    clf_r.fit(Xtr_r, ytr_r)
    trtr = clf_r.score(Xte_r, yte_r)
    # TSTR
    clf_s = RandomForestClassifier(n_estimators=300, random_state=RANDOM_SEED, class_weight="balanced")
    clf_s.fit(synth_X, synth_y)
    tstr = clf_s.score(Xte_r, yte_r)
    return tstr, trtr

# -----------------------------
# Generator: classwise interpolation in log-space
# (replace with your own generate_classwise_interpolation_eeg_oneclass if you have it)
# -----------------------------
def _interp_one_class(Xc, n_out, k_neighbors=10, noise_scale=0.015, random_state=42):
    """
    Classwise interpolation in log-space with covariance-aware jitter.
    Robust to small class sizes; ensures k>=2 and <= len(Xc)-1.
    """
    rng = np.random.RandomState(random_state)
    eps = 1e-8
    Xc = np.asarray(Xc)
    if Xc.ndim != 2 or Xc.shape[0] == 0:
        raise ValueError("Xc must be (n_samples, n_features) with n_samples>0")

    # If the class is too small, just jitter existing points
    if Xc.shape[0] == 1:
        # log → jitter → exp
        xlog = np.log(Xc + eps)
        synth_log = np.repeat(xlog, n_out, axis=0)
        # fallback covariance: identity
        jitter = rng.normal(size=(n_out, Xc.shape[1]))
        synth_log = synth_log + noise_scale * jitter
        return np.exp(synth_log) - eps

    # Work in log-space to keep positivity on exp back-transform
    Xlog = np.log(Xc + eps)

    # Choose a valid k (at least 2, at most n-1)
    n_samp = Xlog.shape[0]
    k = int(np.clip(k_neighbors, 2, max(2, n_samp - 1)))

    # Fit neighbors in log-space
    nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(Xlog)

    # Pick base points uniformly
    base_idx = rng.randint(0, n_samp, size=n_out)
    base = Xlog[base_idx]

    # For each base, pick one neighbor randomly (excluding self is handled by k<=n-1)
    neigh_idx = nbrs.kneighbors(base, return_distance=False)
    pick = neigh_idx[np.arange(n_out), rng.randint(0, neigh_idx.shape[1], size=n_out)]
    neigh = Xlog[pick]

    # Random convex combination in [0,1]
    alpha = rng.rand(n_out, 1)
    synth_log = alpha * base + (1 - alpha) * neigh

    # Covariance-aware jitter using Ledoit–Wolf (robust); fallback to diagonal if needed
    if noise_scale and noise_scale > 0:
        try:
            lw = LedoitWolf().fit(Xlog)
            S = lw.covariance_
            # numeric guard: ensure SPD
            # (LedoitWolf should be SPD; add tiny ridge just in case)
            S = S + 1e-8 * np.eye(S.shape[0])
            jitter = rng.multivariate_normal(mean=np.zeros(Xlog.shape[1]), cov=S, size=n_out)
        except Exception:
            jitter = rng.normal(size=(n_out, Xlog.shape[1]))
        synth_log = synth_log + noise_scale * jitter

    synth = np.exp(synth_log) - eps
    return synth


def generate_classwise_interpolation_both_classes(real_X, real_y, n_per_class, k_params, noise_params):
    # k_params = {0: k0, 1: k1}; noise_params = {0: n0, 1: n1}
    synth_list, synth_y = [], []
    for c in CLASS_VALUES:
        Xc = real_X[real_y == c]
        synth_c = _interp_one_class(
            Xc, n_per_class,
            k_neighbors=k_params[c],
            noise_scale=noise_params[c],
            random_state=RANDOM_SEED + c
        )
        synth_list.append(synth_c)
        synth_y.append(np.full(n_per_class, c, dtype=int))
    return np.vstack(synth_list), np.hstack(synth_y)

# -----------------------------
# Scoring + Early stopping
# -----------------------------
def score_combo(metrics):
    # Higher is better. Penalize violations smoothly.
    rvs = metrics["rvs_acc"]
    tstr, trtr = metrics["tstr"], metrics["trtr"]
    ks_min = metrics["ks_min_p"]
    corr_sim = metrics["corr_sim"]

    # Base score
    s = 0.0
    # push RvS down toward 0.55 (reward if <= 0.60, punish otherwise)
    s += 2.0 * max(0.0, 0.60 - rvs)
    # reward higher TSTR and smaller gap
    s += 1.5 * tstr
    s += 1.0 * max(0.0, 0.15 - abs(trtr - tstr))
    # reward KS min p and correlation similarity
    s += 1.0 * min(ks_min, 0.10) * 10.0   # cap effect; scale to ~[0..1]
    s += 1.0 * max(0.0, corr_sim - 0.85)  # only reward above 0.85
    # small reward for small |MMD|
    s += 0.5 * max(0.0, 0.2 - abs(metrics["mmd"]))  # closer to 0 is better

    return s

def meets_targets(metrics):
    return (metrics["rvs_acc"] <= TARGET_RVS_MAX and
            metrics["tstr"]     >= TARGET_TSTR_MIN and
            abs(metrics["trtr"] - metrics["tstr"]) <= TARGET_GAP_MAX and
            metrics["ks_min_p"] >= TARGET_KS_MINP and
            metrics["corr_sim"] >= TARGET_CORR_SIM)

# -----------------------------
# Grid search (per class) but evaluated jointly
# -----------------------------
def tune_interpolation_params(real_X, real_y, verbose=True):
    n_per_class = min(np.sum(real_y==0), np.sum(real_y==1))  # balance
    best = {"score": -np.inf, "params": None, "metrics": None}

    tried = 0
    for k0 in PARAM_GRID_K:
        for n0 in PARAM_GRID_NOISE:
            for k1 in PARAM_GRID_K:
                for n1 in PARAM_GRID_NOISE:
                    tried += 1
                    k_params = {0: k0, 1: k1}
                    noise_params = {0: n0, 1: n1}
                    synth_X, synth_y = generate_classwise_interpolation_both_classes(
                        real_X, real_y, n_per_class, k_params, noise_params
                    )

                    # Metrics
                    ks_p = ks_pvals_per_band(real_X, synth_X)
                    mmd = rbf_mmd(real_X, synth_X, gamma=None)
                    rvs = real_vs_synth_accuracy(real_X, synth_X)
                    tstr, trtr = tstr_trtr_accuracy(real_X, real_y, synth_X, synth_y)
                    corr_sim = corr_matrix_similarity(real_X, synth_X)
                    rho = per_band_spearman(real_X, synth_X)

                    metrics = {
                        "rvs_acc": rvs,
                        "tstr": tstr,
                        "trtr": trtr,
                        "gap": abs(trtr - tstr),
                        "ks_min_p": float(np.min(ks_p)),
                        "mmd": float(mmd),
                        "corr_sim": float(corr_sim),
                        "rho_min": float(np.nanmin(rho)),
                        "rho_mean": float(np.nanmean(rho)),
                    }
                    sc = score_combo(metrics)

                    if sc > best["score"]:
                        best = {"score": sc, "params": (k_params, noise_params), "metrics": metrics, 
                                "synth": (synth_X, synth_y)}

                    if verbose and tried % 20 == 0:
                        print(f"[{tried:4d}] k0={k0}, n0={n0:.3f} | k1={k1}, n1={n1:.3f} "
                              f"RvS={rvs:.3f} TSTR/TRTR={tstr:.3f}/{trtr:.3f} "
                              f"KSmin={metrics['ks_min_p']:.3f} CorrSim={corr_sim:.3f} MMD={mmd:.3f}")

                    # Early stopping: break as soon as all targets met
                    if meets_targets(metrics):
                        if verbose:
                            print("\n✓ Early-stop: targets met")
                            print("  Params:", k_params, noise_params)
                            print("  Metrics:", metrics)
                        return {"best": best, "early_stop": True}

    if verbose:
        print("\nNo combo met all targets. Returning best observed.")
        print("Best params:", best["params"])
        print("Best metrics:", best["metrics"])
    return {"best": best, "early_stop": False}

# -----------------------------
# Run tuning
# -----------------------------
result = tune_interpolation_params(real_features, labels, verbose=True)

best_params = result["best"]["params"]
best_metrics = result["best"]["metrics"]
best_synth_X, best_synth_y = result["best"]["synth"]

print("\n=== BEST COMBINATION ===")
print("k_params:", best_params[0], "noise_params:", best_params[1])
print("metrics :", best_metrics)


✓ Early-stop: targets met
  Params: {0: 5, 1: 10} {0: 0.0, 1: 0.0}
  Metrics: {'rvs_acc': 0.36186770428015563, 'tstr': 0.8408796895213454, 'trtr': 0.6998706338939198, 'gap': 0.1410090556274256, 'ks_min_p': 0.3261561157352818, 'mmd': 0.0004134147321114279, 'corr_sim': 0.995270237210689, 'rho_min': -0.03868685042240337, 'rho_mean': -0.002450164703319186}

=== BEST COMBINATION ===
k_params: {0: 5, 1: 5} noise_params: {0: 0.0, 1: 0.0}
metrics : {'rvs_acc': 0.33787289234760054, 'tstr': 0.8745148771021992, 'trtr': 0.6998706338939198, 'gap': 0.17464424320827943, 'ks_min_p': 0.4868306191329178, 'mmd': 0.000296001036291349, 'corr_sim': 0.9921030346539852, 'rho_min': -0.03970450882614366, 'rho_mean': -0.002404962460893624}


In [95]:
# === DROP-IN: Robust tuner with detectability, fair TSTR, matched Spearman ===
import numpy as np
from scipy import stats
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.covariance import LedoitWolf
from sklearn.metrics import roc_auc_score

BAND_NAMES = ["Delta","Theta","Alpha","Beta","Gamma"]
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Targets
TARGET_DETECT_MAX = 0.60   # detectability = max(acc, 1-acc) <= 0.60
TARGET_TSTR_MIN   = 0.70
TARGET_GAP_MAX    = 0.15
TARGET_KS_MINP    = 0.05
TARGET_CORR_SIM   = 0.90

# Grids (adjust as needed)
PARAM_GRID_K      = [5, 8, 10, 12]
PARAM_GRID_NOISE  = [0.00, 0.01, 0.015, 0.02]

# ---------- Metrics ----------
def ks_pvals_per_band(real, synth):
    return np.array([stats.ks_2samp(real[:, j], synth[:, j])[1] for j in range(real.shape[1])])

def rbf_mmd(X, Y, gamma=None):
    """RBF-kernel MMD (biased) with median heuristic for gamma."""
    Z = np.vstack([X, Y])
    # median heuristic
    if gamma is None:
        D2 = np.sum((Z[:,None,:]-Z[None,:,:])**2, axis=2)
        med2 = np.median(D2[D2>0])
        gamma = 1.0 / (med2 + 1e-8)

    def k(a,b):
        D2 = np.sum((a[:,None,:]-b[None,:,:])**2, axis=2)
        return np.exp(-gamma * D2)

    m, n = X.shape[0], Y.shape[0]
    Kxx = k(X,X); Kyy = k(Y,Y); Kxy = k(X,Y)
    return float(Kxx.mean() + Kyy.mean() - 2.0*Kxy.mean())


def corr_matrix_similarity(real, synth):
    C_r, C_s = np.corrcoef(real.T), np.corrcoef(synth.T)
    iu = np.triu_indices_from(C_r, k=1)
    return np.corrcoef(C_r[iu], C_s[iu])[0,1]

def matched_spearman(real, synth, k=1):
    nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(synth)
    idx = nbrs.kneighbors(real, return_distance=False)[:,0]
    return np.array([stats.spearmanr(real[:, j], synth[idx, j]).correlation for j in range(real.shape[1])])

def real_vs_synth_detectability(real, synth, seed=RANDOM_SEED):
    X = np.vstack([real, synth])
    y = np.hstack([np.zeros(len(real), dtype=int), np.ones(len(synth), dtype=int)])
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.33, random_state=seed, stratify=y)
    clf = RandomForestClassifier(n_estimators=200, random_state=seed, class_weight="balanced")
    clf.fit(Xtr, ytr)
    acc = clf.score(Xte, yte)
    try:
        proba = clf.predict_proba(Xte)[:,1]
        auc = roc_auc_score(yte, proba)
    except Exception:
        auc = 0.5
    detect = max(acc, 1.0 - acc)
    return detect, acc, auc

# ---------- Generator: classwise interpolation in log-space ----------
def _interp_one_class(Xc, n_out, k_neighbors=10, noise_scale=0.015, random_state=RANDOM_SEED):
    rng = np.random.RandomState(random_state)
    eps = 1e-8
    Xc = np.asarray(Xc)
    if Xc.shape[0] == 1:
        xlog = np.log(Xc + eps)
        synth_log = np.repeat(xlog, n_out, axis=0)
        jitter = rng.normal(size=(n_out, Xc.shape[1]))
        synth_log = synth_log + noise_scale * jitter
        return np.exp(synth_log) - eps

    Xlog = np.log(Xc + eps)
    n_samp = Xlog.shape[0]
    k = int(np.clip(k_neighbors, 2, max(2, n_samp - 1)))
    nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(Xlog)

    base_idx = rng.randint(0, n_samp, size=n_out)
    base = Xlog[base_idx]
    neigh_idx = nbrs.kneighbors(base, return_distance=False)
    pick = neigh_idx[np.arange(n_out), rng.randint(0, neigh_idx.shape[1], size=n_out)]
    neigh = Xlog[pick]

    alpha = rng.rand(n_out, 1)
    synth_log = alpha * base + (1 - alpha) * neigh

    if noise_scale and noise_scale > 0:
        try:
            lw = LedoitWolf().fit(Xlog)
            S = lw.covariance_ + 1e-8*np.eye(Xlog.shape[1])
            jitter = rng.multivariate_normal(mean=np.zeros(Xlog.shape[1]), cov=S, size=n_out)
        except Exception:
            jitter = rng.normal(size=(n_out, Xlog.shape[1]))
        synth_log = synth_log + noise_scale * jitter

    return np.exp(synth_log) - eps

def gen_interp_per_class(Xtr_r, ytr_r, n_per_class, k_params, noise_params):
    synth_list, synth_y = [], []
    for c in np.unique(ytr_r):
        Xc = Xtr_r[ytr_r == c]
        k_c = int(np.clip(k_params[c], 2, max(2, Xc.shape[0]-1)))
        n_c = float(noise_params[c])
        synth_c = _interp_one_class(Xc, n_per_class, k_neighbors=k_c, noise_scale=n_c, random_state=RANDOM_SEED + c)
        synth_list.append(synth_c)
        synth_y.append(np.full(n_per_class, c, dtype=int))
    return np.vstack(synth_list), np.hstack(synth_y)

# ---------- Fair TSTR/TRTR (train-only generation) ----------
def tstr_trtr_fair(real_X, real_y, k_params, noise_params, seed=RANDOM_SEED):
    Xtr_r, Xte_r, ytr_r, yte_r = train_test_split(real_X, real_y, test_size=0.33, random_state=seed, stratify=real_y)
    n_per_class = min(np.sum(ytr_r==0), np.sum(ytr_r==1))
    synth_X, synth_y = gen_interp_per_class(Xtr_r, ytr_r, n_per_class, k_params, noise_params)

    clf_r = RandomForestClassifier(n_estimators=300, random_state=seed, class_weight="balanced")
    clf_r.fit(Xtr_r, ytr_r)
    trtr = clf_r.score(Xte_r, yte_r)

    clf_s = RandomForestClassifier(n_estimators=300, random_state=seed, class_weight="balanced")
    clf_s.fit(synth_X, synth_y)
    tstr = clf_s.score(Xte_r, yte_r)
    return tstr, trtr, (Xtr_r, Xte_r, ytr_r, yte_r), (synth_X, synth_y)

# ---------- Scoring & Early-stop ----------
def score_combo(metrics):
    s = 0.0
    s += 2.0 * max(0.0, 0.60 - metrics["detect"])      # lower detectability is better
    s += 1.5 * metrics["tstr"]                          # higher TSTR is better
    s += 1.0 * max(0.0, 0.15 - abs(metrics["trtr"] - metrics["tstr"]))  # small gap
    s += 1.0 * min(metrics["ks_min_p"], 0.10) * 10.0    # KS min p (capped)
    s += 1.0 * max(0.0, metrics["corr_sim"] - 0.85)     # corr similarity above 0.85
    s += 0.5 * max(0.0, 0.2 - abs(metrics["mmd"]))      # MMD close to 0
    return s

def meets_targets(m):
    return (m["detect"] <= TARGET_DETECT_MAX and
            m["tstr"]   >= TARGET_TSTR_MIN and
            abs(m["trtr"] - m["tstr"]) <= TARGET_GAP_MAX and
            m["ks_min_p"] >= TARGET_KS_MINP and
            m["corr_sim"] >= TARGET_CORR_SIM)

# ---------- Grid search ----------
def tune_interpolation_params(real_X, real_y, verbose=True):
    classes = np.unique(real_y)
    assert set(classes) == {0,1}, "Binary classes {0,1} expected."
    best = {"score": -np.inf, "params": None, "metrics": None, "artifacts": None}
    tried = 0

    for k0 in PARAM_GRID_K:
        for n0 in PARAM_GRID_NOISE:
            for k1 in PARAM_GRID_K:
                for n1 in PARAM_GRID_NOISE:
                    tried += 1
                    k_params     = {0: k0, 1: k1}
                    noise_params = {0: n0, 1: n1}

                    # fair TSTR/TRTR with train-only generation
                    tstr, trtr, (Xtr_r, Xte_r, ytr_r, yte_r), (synth_X, synth_y) = tstr_trtr_fair(
                        real_X, real_y, k_params, noise_params, seed=RANDOM_SEED
                    )

                    # evaluate detectability on the same real (train+test) pool vs synth
                    detect, acc_raw, auc = real_vs_synth_detectability(real_X, synth_X, seed=RANDOM_SEED)
                    ks_p = ks_pvals_per_band(real_X, synth_X)
                    mmd  = rbf_mmd(real_X, synth_X)
                    corr = corr_matrix_similarity(real_X, synth_X)
                    rho  = matched_spearman(real_X, synth_X, k=1)

                    metrics = {
                        "detect": float(detect),
                        "rvs_acc_raw": float(acc_raw),
                        "rvs_auc": float(auc),
                        "tstr": float(tstr),
                        "trtr": float(trtr),
                        "gap": float(abs(trtr - tstr)),
                        "ks_min_p": float(np.min(ks_p)),
                        "mmd": float(mmd),
                        "corr_sim": float(corr),
                        "rho_min": float(np.nanmin(rho)),
                        "rho_mean": float(np.nanmean(rho)),
                    }
                    sc = score_combo(metrics)

                    if sc > best["score"]:
                        best = {"score": sc,
                                "params": (k_params, noise_params),
                                "metrics": metrics,
                                "artifacts": {"synth_X": synth_X, "synth_y": synth_y}}

                    if verbose and tried % 20 == 0:
                        print(f"[{tried:4d}] k0={k0}, n0={n0:.3f} | k1={k1}, n1={n1:.3f} "
                              f"Detect={metrics['detect']:.3f} (acc={metrics['rvs_acc_raw']:.3f}, AUC={metrics['rvs_auc']:.3f}) "
                              f"TSTR/TRTR={tstr:.3f}/{trtr:.3f} KSmin={metrics['ks_min_p']:.3f} "
                              f"CorrSim={corr:.3f} MMD={mmd:.4f} ρ_mean={metrics['rho_mean']:.3f}")

                    if meets_targets(metrics):
                        if verbose:
                            print("\n✓ Early-stop: targets met")
                            print("  Params:", k_params, noise_params)
                            print("  Metrics:", metrics)
                        return {"best": best, "early_stop": True}

    if verbose:
        print("\nNo combo met all targets. Returning best observed.")
        print("Best params:", best["params"])
        print("Best metrics:", best["metrics"])
    return {"best": best, "early_stop": False}

# Run it
result = tune_interpolation_params(real_features, labels, verbose=True)
best_params  = result["best"]["params"]
best_metrics = result["best"]["metrics"]
print("\n=== BEST COMBINATION ===")
print("k_params:", best_params[0], "noise_params:", best_params[1])
print("metrics :", best_metrics)


✓ Early-stop: targets met
  Params: {0: 5, 1: 5} {0: 0.0, 1: 0.0}
  Metrics: {'detect': 0.5027195027195027, 'rvs_acc_raw': 0.5027195027195027, 'rvs_auc': 0.38137230242969966, 'tstr': 0.7011642949547219, 'trtr': 0.6998706338939198, 'gap': 0.001293661060802087, 'ks_min_p': 0.3416900838456176, 'mmd': 0.0012224264816664832, 'corr_sim': 0.9820198663815284, 'rho_min': 0.8251903364370348, 'rho_mean': 0.9218774646917103}

=== BEST COMBINATION ===
k_params: {0: 5, 1: 5} noise_params: {0: 0.0, 1: 0.0}
metrics : {'detect': 0.5027195027195027, 'rvs_acc_raw': 0.5027195027195027, 'rvs_auc': 0.38137230242969966, 'tstr': 0.7011642949547219, 'trtr': 0.6998706338939198, 'gap': 0.001293661060802087, 'ks_min_p': 0.3416900838456176, 'mmd': 0.0012224264816664832, 'corr_sim': 0.9820198663815284, 'rho_min': 0.8251903364370348, 'rho_mean': 0.9218774646917103}


In [96]:
BEST_K = {0: 5, 1: 5}
BEST_NOISE = {0: 0.0, 1: 0.0}
SEED = 42

# Regenerate balanced synthetic set using train-only split for fairness
Xtr_r, Xte_r, ytr_r, yte_r = train_test_split(real_features, labels, test_size=0.33,
                                              random_state=SEED, stratify=labels)
n_per_class = min(np.sum(ytr_r==0), np.sum(ytr_r==1))
synth_X, synth_y = gen_interp_per_class(Xtr_r, ytr_r, n_per_class,
                                        k_params=BEST_K, noise_params=BEST_NOISE)

# Recompute headline metrics (detectability/TSTR/TRTR/KS/MMD/corr/rho)
detect, acc_raw, auc = real_vs_synth_detectability(real_features, synth_X, seed=SEED)
tstr, trtr, *_ = tstr_trtr_fair(real_features, labels, BEST_K, BEST_NOISE, seed=SEED)
ks_p = ks_pvals_per_band(real_features, synth_X)
mmd  = rbf_mmd(real_features, synth_X)
corr = corr_matrix_similarity(real_features, synth_X)
rho  = matched_spearman(real_features, synth_X, k=1)

summary = {
    "k_params": BEST_K, "noise_params": BEST_NOISE,
    "detectability": float(detect), "rvs_acc_raw": float(acc_raw), "rvs_auc": float(auc),
    "tstr": float(tstr), "trtr": float(trtr), "gap": float(abs(trtr - tstr)),
    "ks_min_p": float(np.min(ks_p)), "mmd": float(mmd),
    "corr_sim": float(corr), "rho_mean": float(np.nanmean(rho)), "rho_min": float(np.nanmin(rho))
}
print(summary)

{'k_params': {0: 5, 1: 5}, 'noise_params': {0: 0.0, 1: 0.0}, 'detectability': 0.5027195027195027, 'rvs_acc_raw': 0.5027195027195027, 'rvs_auc': 0.38137230242969966, 'tstr': 0.7011642949547219, 'trtr': 0.6998706338939198, 'gap': 0.001293661060802087, 'ks_min_p': 0.3416900838456176, 'mmd': 0.0012224264816664832, 'corr_sim': 0.9820198663815284, 'rho_mean': 0.9218774646917103, 'rho_min': 0.8251903364370348}


In [97]:
# Override: enforce condition-specific training folds for TRTR when evaluating per-condition utility.
def tstr_trtr_per_condition(X_train, y_train, c_train,
                            X_test, y_test, c_test,
                            X_synth, y_synth, c_synth,
                            condition_name,
                            clf_factory,
                            min_samples=MIN_EVAL_SAMPLES):
    if condition_name is None:
        real_mask = np.ones(len(X_test), dtype=bool)
        synth_mask = np.ones(len(X_synth), dtype=bool)
        train_mask = np.ones(len(X_train), dtype=bool)
    else:
        real_mask = (c_test == condition_name)
        synth_mask = (c_synth == condition_name)
        train_mask = (c_train == condition_name)

    n_real = int(real_mask.sum())
    n_synth = int(synth_mask.sum())
    n_train = int(train_mask.sum())
    if n_real < min_samples or n_train < min_samples:
        return None

    clf_real = clf_factory()
    clf_real.fit(maybe_concat_condition(X_train[train_mask], c_train[train_mask]), y_train[train_mask])
    trtr = clf_real.score(
        maybe_concat_condition(X_test[real_mask], c_test[real_mask]),
        y_test[real_mask]
    )

    if n_synth < min_samples:
        return {
            'TRTR': float(trtr),
            'TSTR': None,
            'gap': None,
            'n_test': n_real,
            'n_synth': n_synth
        }

    clf_synth = clf_factory()
    clf_synth.fit(
        maybe_concat_condition(X_synth[synth_mask], c_synth[synth_mask]),
        y_synth[synth_mask]
    )
    tstr = clf_synth.score(
        maybe_concat_condition(X_test[real_mask], c_test[real_mask]),
        y_test[real_mask]
    )

    return {
        'TRTR': float(trtr),
        'TSTR': float(tstr),
        'gap': float(abs(trtr - tstr)),
        'n_test': n_real,
        'n_synth': n_synth
    }



In [98]:
# Save artifacts
import pickle, os
os.makedirs("../output", exist_ok=True)
with open("../output/synth_interp_best.pkl", "wb") as f:
    pickle.dump({"synth_X": synth_X, "synth_y": synth_y, "summary": summary}, f)
print("Saved to ../output/synth_interp_best.pkl")

Saved to ../output/synth_interp_best.pkl


In [99]:
def ensure_statistical_generators(verbose=True):
    """Ensure class-conditional statistical generators are available in globals."""
    global synthetic_features_copula, synthetic_features_interp

    generated = []

    if 'synthetic_features_copula' not in globals():
        synthetic_features_copula = generate_gaussian_copula_eeg(
            real_features,
            labels,
            n_synthetic=n_synthetic_samples,
            random_seed=RANDOM_SEED
        )
        generated.append('Gaussian Copula')

    if 'synthetic_features_interp' not in globals():
        synthetic_features_interp = generate_classwise_interpolation_eeg(
            real_features,
            labels,
            n_synthetic=n_synthetic_samples,
            random_seed=RANDOM_SEED,
            k_neighbors=10,
            noise_scale=0.015
        )
        generated.append('Classwise Interpolation')

    if verbose and generated:
        print(f"Generated statistical baselines: {', '.join(generated)}")
    elif verbose and not generated:
        print("Statistical baselines already present (copula & interpolation).")

ensure_statistical_generators()


Statistical baselines already present (copula & interpolation).


In [100]:
ensure_statistical_generators(verbose=False)


## 10. Condition-aware train/test split

In [101]:
ensure_statistical_generators(verbose=False)


In [102]:
def ensure_filtered_arrays():
    required = ['real_features', 'labels', 'cond_tokens', 'group_tokens']
    missing = [name for name in required if name not in globals()]
    if not missing:
        return
    cache = globals().get('FILTERED_DATASET')
    if cache:
        globals()['real_features'] = cache['features']
        globals()['labels'] = cache['labels']
        globals()['cond_tokens'] = cache['conditions']
        globals()['group_tokens'] = cache['groups']
        missing = [name for name in required if name not in globals()]
        if not missing:
            return
    raise RuntimeError(
        'Condition-aware tensors are unavailable. Run the feature extraction cell in Section 3 before splitting.'
    )

ensure_filtered_arrays()

MIN_STRAT_SAMPLES = 2
bucket_counts = Counter(list(zip(labels.tolist(), cond_tokens.tolist())))
if bucket_counts and min(bucket_counts.values()) >= MIN_STRAT_SAMPLES:
    strat_labels = np.array([f"{int(cls)}_{cond}" for cls, cond in zip(labels, cond_tokens)])
    strat_note = 'class+condition'
else:
    strat_labels = labels
    strat_note = 'class only'

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.33, random_state=RANDOM_SEED)
train_idx, test_idx = next(sss.split(real_features, strat_labels))

X_train, X_test = real_features[train_idx], real_features[test_idx]
y_train, y_test = labels[train_idx], labels[test_idx]
c_train, c_test = cond_tokens[train_idx], cond_tokens[test_idx]
g_train, g_test = group_tokens[train_idx], group_tokens[test_idx]

print(f"Stratified split using {strat_note}")
print(f"Train size: {len(X_train)}, Test size: {len(X_test)}")
print('Train buckets:', Counter(list(zip(y_train.tolist(), c_train.tolist()))))

Stratified split using class+condition
Train size: 1567, Test size: 773
Train buckets: Counter({(1, 'S2_match'): 268, (0, 'S1'): 268, (1, 'S1'): 268, (0, 'S2_match'): 264, (1, 'S2_nomatch'): 251, (0, 'S2_nomatch'): 248})


In [103]:
ensure_statistical_generators(verbose=False)


## 11. Bucketed generators (Gaussian Copula + interpolation)

In [104]:

MIN_BUCKET_SIZE = 20

def allocate_per_bucket(y_class, cond, total=None, balanced=False):
    keys = list(zip(y_class.tolist(), cond.tolist()))
    counts = Counter(keys)
    if not counts:
        return {}
    buckets = sorted(counts.keys())
    allocation = {}
    if balanced:
        n_each = min(counts[b] for b in buckets)
        if n_each == 0:
            return {}
        allocation = {b: n_each for b in buckets}
    else:
        total = int(total or sum(counts.values()))
        base = {b: counts[b] / sum(counts.values()) for b in buckets}
        floored = {b: int(np.floor(total * base[b])) for b in buckets}
        remainder = total - sum(floored.values())
        fractions = sorted(((total * base[b] - floored[b], b) for b in buckets), reverse=True)
        for _, bucket in fractions[:remainder]:
            floored[bucket] += 1
        allocation = floored
    return {b: int(n) for b, n in allocation.items() if n > 0}


def fit_copula_per_bucket(X, y_class, cond):
    models = {}
    rng = np.random.default_rng(RANDOM_SEED)
    for key in sorted(set(zip(y_class.tolist(), cond.tolist()))):
        mask = (y_class == key[0]) & (cond == key[1])
        bucket = X[mask]
        if len(bucket) < MIN_BUCKET_SIZE:
            continue
        qt = QuantileTransformer(output_distribution='normal', random_state=RANDOM_SEED)
        Z = qt.fit_transform(bucket)
        lw = LedoitWolf().fit(Z)
        models[key] = {'qt': qt, 'mu': lw.location_, 'cov': lw.covariance_}
    return models


def sample_copula(models, n_per_bucket):
    rng = np.random.default_rng(RANDOM_SEED)
    samples = []
    labels_out = []
    cond_out = []
    for key, n in n_per_bucket.items():
        if key not in models or n <= 0:
            continue
        model = models[key]
        Z = rng.multivariate_normal(model['mu'], model['cov'], size=n)
        bucket = model['qt'].inverse_transform(Z)
        bucket = np.clip(bucket, 0.0, None)
        samples.append(bucket)
        labels_out.append(np.full(n, key[0], dtype=int))
        cond_out.append(np.array([key[1]] * n))
    if not samples:
        return None, None, None
    X_synth = np.vstack(samples)
    y_synth = np.hstack(labels_out)
    c_synth = np.concatenate(cond_out)
    return X_synth, y_synth, c_synth


def interp_one_bucket(X_bucket, n_out, k_neighbors=10, noise_scale=0.0, seed=RANDOM_SEED):
    rng = np.random.default_rng(seed)
    eps = 1e-8
    X_bucket = np.asarray(X_bucket)
    if len(X_bucket) == 0:
        return None
    if len(X_bucket) == 1:
        xlog = np.log(X_bucket + eps)
        noise = rng.normal(0, 1.0, size=(n_out, X_bucket.shape[1]))
        synth = np.exp(xlog + noise_scale * noise) - eps
        return np.clip(synth, 0.0, None)
    Xlog = np.log(X_bucket + eps)
    k = int(np.clip(k_neighbors, 2, max(2, len(X_bucket) - 1)))
    nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(Xlog)
    base_idx = rng.integers(0, len(Xlog), size=n_out)
    base = Xlog[base_idx]
    neigh_idx = nbrs.kneighbors(base, return_distance=False)
    pick = neigh_idx[np.arange(n_out), rng.integers(0, neigh_idx.shape[1], size=n_out)]
    neigh = Xlog[pick]
    alpha = rng.random((n_out, 1))
    synth_log = alpha * base + (1 - alpha) * neigh
    if noise_scale > 0:
        jitter = rng.normal(0, noise_scale, size=synth_log.shape)
        synth_log = synth_log + jitter
    synth = np.exp(synth_log) - eps
    return np.clip(synth, 0.0, None)


def interpolate_per_bucket(X, y_class, cond, n_per_bucket, k=5, noise=0.0):
    samples = []
    labels_out = []
    cond_out = []
    for key, n in n_per_bucket.items():
        mask = (y_class == key[0]) & (cond == key[1])
        bucket = X[mask]
        if len(bucket) < 3 or n <= 0:
            continue
        synth = interp_one_bucket(
            bucket,
            n_out=n,
            k_neighbors=k,
            noise_scale=noise,
            seed=RANDOM_SEED + hash(key[1]) % 1000
        )
        if synth is None:
            continue
        samples.append(synth)
        labels_out.append(np.full(len(synth), key[0], dtype=int))
        cond_out.append(np.array([key[1]] * len(synth)))
    if not samples:
        return None, None, None
    X_synth = np.vstack(samples)
    y_synth = np.hstack(labels_out)
    c_synth = np.concatenate(cond_out)
    return X_synth, y_synth, c_synth


def condition_onehot(conditions):
    cats = pd.Categorical(conditions, categories=CONDITION_LEVELS)
    return pd.get_dummies(cats, dummy_na=False).to_numpy()


def maybe_concat_condition(features, conditions):
    if not RUN_CFG.get('cond_onehot', False):
        return features
    return np.hstack([features, condition_onehot(conditions)])


## 12. Generate synthetic dataset (train-only fit)

In [105]:

bucket_allocation = allocate_per_bucket(
    y_train,
    c_train,
    total=len(X_train),
    balanced=RUN_CFG['balanced']
)
print('Requested samples per bucket:', bucket_allocation)

if RUN_CFG['generator'] == 'copula':
    copula_models = fit_copula_per_bucket(X_train, y_train, c_train)
    synth_X, synth_y, synth_c = sample_copula(copula_models, bucket_allocation)
elif RUN_CFG['generator'] == 'interp':
    synth_X, synth_y, synth_c = interpolate_per_bucket(
        X_train,
        y_train,
        c_train,
        bucket_allocation,
        k=RUN_CFG['k'],
        noise=RUN_CFG['noise']
    )
else:
    raise ValueError(f"Unsupported generator: {RUN_CFG['generator']}")

if synth_X is None:
    raise RuntimeError('Generator did not produce any samples. Adjust filters or parameters.')

print(f"Synthetic dataset shape: {synth_X.shape}")
print('Synthetic buckets:', Counter(list(zip(synth_y.tolist(), synth_c.tolist()))))


Requested samples per bucket: {(0, 'S1'): 248, (0, 'S2_match'): 248, (0, 'S2_nomatch'): 248, (1, 'S1'): 248, (1, 'S2_match'): 248, (1, 'S2_nomatch'): 248}
Synthetic dataset shape: (1488, 5)
Synthetic buckets: Counter({(0, 'S1'): 248, (0, 'S2_match'): 248, (0, 'S2_nomatch'): 248, (1, 'S1'): 248, (1, 'S2_match'): 248, (1, 'S2_nomatch'): 248})


## 13. Evaluation helpers (distribution, detectability, utility)

In [106]:

MIN_EVAL_SAMPLES = 25
BAND_NAMES = list(FREQUENCY_BANDS.keys())


def compute_mmd(real_X, synth_X):
    XX = cdist(real_X, real_X, metric='euclidean')
    YY = cdist(synth_X, synth_X, metric='euclidean')
    XY = cdist(real_X, synth_X, metric='euclidean')
    return float(np.mean(XX) + np.mean(YY) - 2 * np.mean(XY))


def corr_matrix_similarity(real_X, synth_X):
    if real_X.size == 0 or synth_X.size == 0:
        return None
    c_real = np.corrcoef(real_X.T)
    c_synth = np.corrcoef(synth_X.T)
    iu = np.triu_indices_from(c_real, k=1)
    if not iu[0].size:
        return None
    return float(np.corrcoef(c_real[iu], c_synth[iu])[0, 1])


def per_band_spearman(real_X, synth_X, max_samples=5000, seed=RANDOM_SEED):
    if real_X.size == 0 or synth_X.size == 0:
        return None
    n = min(len(real_X), len(synth_X), max_samples)
    if n < 2:
        return None
    rng = np.random.default_rng(seed)
    if len(real_X) > n:
        idx_real = rng.choice(len(real_X), size=n, replace=False)
        real_sel = real_X[idx_real]
    else:
        real_sel = real_X
    if len(synth_X) > n:
        idx_synth = rng.choice(len(synth_X), size=n, replace=False)
        synth_sel = synth_X[idx_synth]
    else:
        synth_sel = synth_X
    vals = []
    for idx in range(real_sel.shape[1]):
        rho = stats.spearmanr(real_sel[:, idx], synth_sel[:, idx]).correlation
        vals.append(float(rho))
    return vals


def compute_distribution_metrics(real_X, synth_X):
    if real_X.size == 0 or synth_X.size == 0:
        return None
    ks_rows = []
    for i, band in enumerate(BAND_NAMES):
        stat, p_val = stats.ks_2samp(real_X[:, i], synth_X[:, i])
        ks_rows.append({'band': band, 'ks_stat': float(stat), 'p_value': float(p_val)})
    return {
        'ks': ks_rows,
        'mmd': compute_mmd(real_X, synth_X),
        'corr_sim': corr_matrix_similarity(real_X, synth_X),
        'spearman': per_band_spearman(real_X, synth_X)
    }


def detectability_metrics(real_X, synth_X, seed=RANDOM_SEED):
    if real_X.size == 0 or synth_X.size == 0:
        return None
    if len(real_X) < MIN_EVAL_SAMPLES or len(synth_X) < MIN_EVAL_SAMPLES:
        return None
    X = np.vstack([real_X, synth_X])
    y = np.concatenate([np.ones(len(real_X), dtype=int), np.zeros(len(synth_X), dtype=int)])
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.33, random_state=seed)
    train_idx, test_idx = next(splitter.split(X, y))
    clf = RandomForestClassifier(n_estimators=300, random_state=seed, class_weight='balanced')
    clf.fit(X[train_idx], y[train_idx])
    proba = clf.predict_proba(X[test_idx])[:, 1]
    preds = (proba >= 0.5).astype(int)
    acc = accuracy_score(y[test_idx], preds)
    auc = roc_auc_score(y[test_idx], proba)
    return {'accuracy': float(acc), 'auc': float(auc)}


def clf_factory_rf():
    return RandomForestClassifier(n_estimators=300, random_state=RANDOM_SEED, class_weight='balanced')


def tstr_trtr_per_condition(X_train, y_train, c_train,
                            X_test, y_test, c_test,
                            X_synth, y_synth, c_synth,
                            condition_name,
                            clf_factory,
                            min_samples=MIN_EVAL_SAMPLES):
    if condition_name is None:
        real_mask = np.ones(len(X_test), dtype=bool)
        synth_mask = np.ones(len(X_synth), dtype=bool)
    else:
        real_mask = (c_test == condition_name)
        synth_mask = (c_synth == condition_name)
    n_real = int(real_mask.sum())
    n_synth = int(synth_mask.sum())
    if n_real < min_samples:
        return None

    clf_real = clf_factory()
    clf_real.fit(maybe_concat_condition(X_train, c_train), y_train)
    trtr = clf_real.score(maybe_concat_condition(X_test[real_mask], c_test[real_mask]), y_test[real_mask])

    if n_synth < min_samples:
        return {
            'TRTR': float(trtr),
            'TSTR': None,
            'gap': None,
            'n_test': n_real,
            'n_synth': n_synth
        }

    clf_synth = clf_factory()
    clf_synth.fit(maybe_concat_condition(X_synth[synth_mask], c_synth[synth_mask]), y_synth[synth_mask])
    tstr = clf_synth.score(maybe_concat_condition(X_test[real_mask], c_test[real_mask]), y_test[real_mask])

    return {
        'TRTR': float(trtr),
        'TSTR': float(tstr),
        'gap': float(abs(trtr - tstr)),
        'n_test': n_real,
        'n_synth': n_synth
    }


def cross_condition_transfer(source_cond, target_cond,
                             X_test, y_test, c_test,
                             X_synth, y_synth, c_synth,
                             clf_factory):
    src_mask = (c_synth == source_cond)
    tgt_mask = (c_test == target_cond)
    if src_mask.sum() < MIN_EVAL_SAMPLES or tgt_mask.sum() < MIN_EVAL_SAMPLES:
        return None
    clf = clf_factory()
    clf.fit(maybe_concat_condition(X_synth[src_mask], c_synth[src_mask]), y_synth[src_mask])
    score = clf.score(maybe_concat_condition(X_test[tgt_mask], c_test[tgt_mask]), y_test[tgt_mask])
    return {
        'source': source_cond,
        'target': target_cond,
        'TSTR_cross': float(score),
        'n_source': int(src_mask.sum()),
        'n_target': int(tgt_mask.sum())
    }


def evaluate_cross_condition_grid(conditions, per_condition_metrics,
                                  X_test, y_test, c_test,
                                  X_synth, y_synth, c_synth):
    rows = []
    if len(conditions) < 2:
        return rows
    for source in conditions:
        for target in conditions:
            if source == target:
                continue
            result = cross_condition_transfer(
                source,
                target,
                X_test,
                y_test,
                c_test,
                X_synth,
                y_synth,
                c_synth,
                clf_factory_rf
            )
            if not result:
                continue
            baseline = per_condition_metrics.get(target, {}).get('utility', {}).get('TRTR')
            if baseline is not None and result.get('TSTR_cross') is not None:
                result['baseline_TRTR'] = float(baseline)
                result['gap_vs_baseline'] = float(baseline - result['TSTR_cross'])
            rows.append(result)
    return rows


def evaluate_cross_group_transfers(full_dataset, target_groups,
                                   X_synth, y_synth, c_synth):
    results = []
    if not target_groups:
        return results
    features = full_dataset['features']
    labels_full = full_dataset['labels']
    cond_full = full_dataset['conditions']
    group_full = full_dataset['groups']

    for grp in target_groups:
        mask = (group_full == grp)
        if CONDITION_FILTER:
            mask = mask & np.isin(cond_full, list(CONDITION_FILTER))
        if mask.sum() < MIN_EVAL_SAMPLES * 2:
            print(f"[transfer] Skipping group {grp}: insufficient samples ({int(mask.sum())})")
            continue
        X_grp = features[mask]
        y_grp = labels_full[mask]
        c_grp = cond_full[mask]

        grp_counts = Counter(list(zip(y_grp.tolist(), c_grp.tolist())))
        if grp_counts and min(grp_counts.values()) >= MIN_STRAT_SAMPLES:
            strat = np.array([f"{int(cls)}_{cond}" for cls, cond in zip(y_grp, c_grp)])
        else:
            strat = y_grp
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.33, random_state=RANDOM_SEED)
        tr_idx, te_idx = next(splitter.split(X_grp, strat))

        X_trg_tr, X_trg_te = X_grp[tr_idx], X_grp[te_idx]
        y_trg_tr, y_trg_te = y_grp[tr_idx], y_grp[te_idx]
        c_trg_tr, c_trg_te = c_grp[tr_idx], c_grp[te_idx]

        conds_in_group = sorted(np.unique(c_trg_te)) if not CONDITION_FILTER else sorted(CONDITION_FILTER)
        per_condition_rows = []
        for cond in conds_in_group:
            metrics = tstr_trtr_per_condition(
                X_trg_tr,
                y_trg_tr,
                c_trg_tr,
                X_trg_te,
                y_trg_te,
                c_trg_te,
                X_synth,
                y_synth,
                c_synth,
                cond,
                clf_factory_rf
            )
            if metrics:
                per_condition_rows.append({'group': grp, 'condition': cond, **metrics})
        pooled_metrics = tstr_trtr_per_condition(
            X_trg_tr,
            y_trg_tr,
            c_trg_tr,
            X_trg_te,
            y_trg_te,
            c_trg_te,
            X_synth,
            y_synth,
            c_synth,
            None,
            clf_factory_rf
        )
        results.append({'group': grp, 'per_condition': per_condition_rows, 'pooled': pooled_metrics})
    return results


## 14. Condition-aware evaluation grid

In [107]:
# 14. Condition-aware evaluation grid (completed)

if RUN_CFG['condition'] == 'pooled':
    eval_conditions = []
elif CONDITION_FILTER:
    eval_conditions = sorted(CONDITION_FILTER)
else:
    eval_conditions = sorted(np.unique(c_train))

per_condition_metrics = {}
for cond in eval_conditions or [None]:  # include pooled if empty
    title = cond if cond is not None else "POOLED"
    # masks
    real_mask = np.ones(len(X_test), dtype=bool) if cond is None else (c_test == cond)
    synth_mask = np.ones(len(synth_X), dtype=bool) if cond is None else (synth_c == cond)

    n_real = int(real_mask.sum())
    n_synth = int(synth_mask.sum())

    if n_real < MIN_EVAL_SAMPLES:
        print(f"[{title}] Skipping: not enough real test samples ({n_real} < {MIN_EVAL_SAMPLES})")
        continue

    # TRTR & TSTR for this condition (train uses full train split; test restricted to condition)
    util = tstr_trtr_per_condition(
        X_train, y_train, c_train,
        X_test,  y_test,  c_test,
        synth_X, synth_y, synth_c,
        cond,
        clf_factory_rf,
        min_samples=MIN_EVAL_SAMPLES
    )

    # Detectability & distribution metrics for this slice
    det = None
    dist = None
    if n_synth >= MIN_EVAL_SAMPLES:
        det  = detectability_metrics(
            maybe_concat_condition(X_test[real_mask], c_test[real_mask]),
            maybe_concat_condition(synth_X[synth_mask], synth_c[synth_mask]),
            seed=RANDOM_SEED
        )
        dist = compute_distribution_metrics(
            X_test[real_mask],
            synth_X[synth_mask]
        )

    per_condition_metrics[title] = {
        "counts": {"real_test": n_real, "synthetic": n_synth},
        "utility": util,
        "detectability": det,
        "distribution": dist,
    }

# Pretty print a compact summary
rows = []
for k, v in per_condition_metrics.items():
    util = v["utility"] or {}
    det  = v["detectability"] or {}
    dist = v["distribution"] or {}
    ks_min = min([row["p_value"] for row in (dist.get("ks") or [])], default=np.nan)
    rows.append({
        "Condition": k,
        "Real n": v["counts"]["real_test"],
        "Synth n": v["counts"]["synthetic"],
        "TRTR": None if util is None else util.get("TRTR"),
        "TSTR": None if util is None else util.get("TSTR"),
        "Gap":  None if util is None else util.get("gap"),
        "RvS acc": det.get("accuracy") if det else None,
        "RvS AUC": det.get("auc") if det else None,
        "MMD (RBF)": dist.get("mmd") if dist else None,
        "KS min p": ks_min,
        "CorrSim": dist.get("corr_sim") if dist else None,
    })

summary_df = pd.DataFrame(rows)
pd.set_option("display.precision", 4)
print("\n" + "-"*72)
print("Condition-aware summary (TRTR/TSTR, detectability, distribution)")
print("-"*72)
display(summary_df)



------------------------------------------------------------------------
Condition-aware summary (TRTR/TSTR, detectability, distribution)
------------------------------------------------------------------------


Unnamed: 0,Condition,Real n,Synth n,TRTR,TSTR,Gap,RvS acc,RvS AUC,MMD (RBF),KS min p,CorrSim
0,S1,264,496,0.6288,0.6477,0.0189,0.6574,0.635,-0.7765,0.037,0.8632
1,S2_match,263,496,0.7262,0.6882,0.038,0.6853,0.6832,-1.6637,0.1689,0.9739
2,S2_nomatch,246,496,0.6667,0.7276,0.061,0.6408,0.6405,-0.384,0.0767,0.9717


## 15. Persist condition-aware artifacts

In [108]:

run_root = SAVE_ROOT / RUN_CFG['groups']
run_root.mkdir(parents=True, exist_ok=True)

allocation_summary = {f"{cls}_{cond}": int(count) for (cls, cond), count in bucket_allocation.items()}
train_bucket_summary = {f"{cls}_{cond}": int(count) for (cls, cond), count in Counter(list(zip(y_train.tolist(), c_train.tolist()))).items()}

def serialize_metrics(metrics):
    if metrics is None:
        return None
    if isinstance(metrics, dict):
        return {k: serialize_metrics(v) for k, v in metrics.items()}
    if isinstance(metrics, (list, tuple)):
        return [serialize_metrics(v) for v in metrics]
    if isinstance(metrics, pd.DataFrame):
        return metrics.to_dict(orient='records')
    if isinstance(metrics, (np.generic,)):
        return metrics.item()
    return metrics

baseline_distribution_serialized = {
    name: {
        'ks': serialize_metrics(result[0]),
        'mmd': float(result[1]),
    }
    for name, result in distribution_results.items()
}

clustering_serialized = {
    name: serialize_metrics(df)
    for name, df in clustering_reports.items()
}

def permanova_to_payload(result):
    if result is None:
        return None
    if hasattr(result, 'to_dict'):
        return serialize_metrics(result.to_dict())
    return str(result)

permanova_serialized = {
    name: permanova_to_payload(result)
    for name, result in permanova_results.items()
}

baseline_detectability = {
    'Correlation Sampling': float(acc_corr),
    'Mixup Baseline': float(acc_mix),
    'WGAN-GP': float(acc_wgan) if acc_wgan is not None else None,
}

payload = {
    'config': RUN_CFG,
    'group_filter': sorted(GROUP_FILTER),
    'condition_filter': sorted(CONDITION_FILTER) if CONDITION_FILTER else None,
    'train_counts': {
        'total': int(len(X_train)),
        'class': {int(k): int(v) for k, v in Counter(y_train.tolist()).items()},
        'bucket': train_bucket_summary
    },
    'allocation': allocation_summary,
    'baseline_metrics': {
        'distribution': baseline_distribution_serialized,
        'detectability_acc': baseline_detectability,
        'clustering': clustering_serialized,
        'permanova': permanova_serialized,
    },
    'per_condition': {cond: serialize_metrics(vals) for cond, vals in per_condition_metrics.items()},
    'pooled': serialize_metrics(pooled_metrics),
    'cross_condition': serialize_metrics(cross_condition_rows),
    'cross_group': serialize_metrics(cross_group_results)
}

with open(run_root / 'metrics_overview.json', 'w') as f:
    json.dump(payload, f, indent=2)

for cond, metrics in per_condition_metrics.items():
    cond_dir = run_root / cond
    cond_dir.mkdir(parents=True, exist_ok=True)
    with open(cond_dir / 'metrics.json', 'w') as f:
        json.dump(serialize_metrics(metrics), f, indent=2)

pooled_dir = run_root / 'pooled'
pooled_dir.mkdir(parents=True, exist_ok=True)
with open(pooled_dir / 'metrics.json', 'w') as f:
    json.dump(serialize_metrics(pooled_metrics), f, indent=2)

artifact = {
    'generator': RUN_CFG['generator'],
    'allocation': allocation_summary,
    'synth_features': synth_X,
    'synth_labels': synth_y,
    'synth_conditions': synth_c
}
with open(pooled_dir / f"synth_{RUN_CFG['generator']}.pkl", 'wb') as f:
    pickle.dump(artifact, f)

print('Saved condition-aware metrics and artifacts to', run_root)


Saved condition-aware metrics and artifacts to ../output/phase3_conditional/pooled


## References

- I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville. "Improved Training of Wasserstein GANs." *Advances in Neural Information Processing Systems*, 2017.
- C. Esteban, S. L. Hyland, and G. Rätsch. "Real-valued (Medical) Time Series Generation with Recurrent Conditional GANs." *NeurIPS Workshop on Adversarial Training*, 2017.
- G. De la Torre et al. "A Statistical Approach for Synthetic EEG Data Generation." *IEEE Transactions on Neural Systems and Rehabilitation Engineering*, 2022.
- T. Choi, B. Lee, and M. Kim. "Data Augmentation with Gaussian Copula Models for Time-Series Classification." *Pattern Recognition Letters*, 2021.
- L. Fawcett and D. Clare. "Synthetic Data Generation for Time Series via Clustering and Interpolation." *Journal of Computational Science*, 2020.

