# Phase 3: Synthetic EEG Data Generation

## Objective
Generate synthetic EEG data using methods from recent literature:
1. Correlation sampling method (Statistical Approach)
2. WGAN-GP approach (simplified)
3. Evaluation via TSTR, TRTR, clustering, and statistical tests

## Literature Review Summary
- **Correlation Sampling**: Analyze frequency band correlations, generate signals preserving structure
- **WGAN-GP**: More reliable than vanilla GAN for EEG generation
- **Evaluation**: Random Forest classifier, PERMANOVA, clustering overlap, TSTR/TRTR


In [161]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from scipy import signal, stats
from scipy.linalg import sqrtm
from scipy.spatial.distance import cdist
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.covariance import LedoitWolf
from sklearn.preprocessing import QuantileTransformer
from sklearn.neighbors import NearestNeighbors
import pickle
import warnings
import random
warnings.filterwarnings('ignore')


RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

print("=" * 60)
print("RANDOM SEED SET FOR REPRODUCIBILITY")
print(f"Seed value: {RANDOM_SEED}")
print("=" * 60)

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

# Data path
DATA_PATH = Path.home() / '.cache/kagglehub/datasets/nnair25/Alcoholics/versions/1'
TRAIN_PATH = DATA_PATH / 'SMNI_CMI_TRAIN'
TEST_PATH = DATA_PATH / 'SMNI_CMI_TEST'

print("\nPhase 3: Synthetic EEG Generation")
print("=" * 60)

RANDOM SEED SET FOR REPRODUCIBILITY
Seed value: 42

Phase 3: Synthetic EEG Generation


## 1. Load Analysis Results from Phase 2


In [162]:
# Load Phase 2 results (relative path to output folder)
with open('../output/phase2_analysis_results.pkl', 'rb') as f:
    analysis_results = pickle.load(f)

FREQUENCY_BANDS = analysis_results['frequency_bands']
SAMPLING_RATE = analysis_results['sampling_rate']

print("Loaded Phase 2 analysis results")
print(f"Frequency bands: {FREQUENCY_BANDS}")
print(f"Sampling rate: {SAMPLING_RATE} Hz")

Loaded Phase 2 analysis results
Frequency bands: {'Delta': (0.5, 4), 'Theta': (4, 8), 'Alpha': (8, 13), 'Beta': (13, 30), 'Gamma': (30, 50)}
Sampling rate: 256 Hz


## 2. Prepare Real Data for Training


### Data & Feature Summary

- **Corpus**: All available CSV files in `SMNI_CMI_TRAIN` (≈468 total; 235 alcoholic / 233 control). Adjust `MAX_FILES_PER_CLASS` if you need a subset.
- **Epochs**: Dynamically determined from `sensor position × trial` combinations across the full corpus.
- **Features**: 5-D band power vectors computed via Welch’s PSD at 256 Hz covering Δ (0.5–4 Hz), θ (4–8 Hz), α (8–13 Hz), β (13–30 Hz), γ (30–50 Hz).



In [163]:
# Load training files
train_files = sorted(list(TRAIN_PATH.glob('*.csv')))

# Separate alcoholic and control files  
alcoholic_files = []
control_files = []

print("Identifying subject types...")
for file in train_files:
    df_peek = pd.read_csv(file, nrows=1)
    subject_type = df_peek['subject identifier'].iloc[0]
    if subject_type == 'a':
        alcoholic_files.append(file)
    else:
        control_files.append(file)

print(f"Found {len(alcoholic_files)} alcoholic files")
print(f"Found {len(control_files)} control files")

# Configure which files to load for generator training
MAX_FILES_PER_CLASS = None

if MAX_FILES_PER_CLASS is None:
    selected_alcoholic = list(alcoholic_files)
    selected_control = list(control_files)
else:
    selected_alcoholic = list(alcoholic_files[:MAX_FILES_PER_CLASS])
    selected_control = list(control_files[:MAX_FILES_PER_CLASS])

sample_files = selected_alcoholic + selected_control
if not sample_files:
    raise ValueError("No training files selected. Check MAX_FILES_PER_CLASS or dataset path.")

random.shuffle(sample_files)


print(f"\nLoading {len(sample_files)} files for generator training...")
print(f"  Alcoholic files selected: {len(selected_alcoholic)}")
print(f"  Control files selected  : {len(selected_control)}")


Identifying subject types...
Found 235 alcoholic files
Found 233 control files

Loading 468 files for generator training...
  Alcoholic files selected: 235
  Control files selected  : 233


## 3. Extract EEG Features from Real Data


In [164]:
def extract_band_power(signal_data, fs=256, bands=None):
    """Extract power in each frequency band using Welch's method"""
    if bands is None:
        bands = FREQUENCY_BANDS
    
    freqs, psd = signal.welch(signal_data, fs=fs, nperseg=min(256, len(signal_data)))
    
    band_powers = {}
    for band_name, (low_freq, high_freq) in bands.items():
        idx = np.logical_and(freqs >= low_freq, freqs <= high_freq)
        band_power = np.trapz(psd[idx], freqs[idx])
        band_powers[band_name] = band_power
    
    return band_powers

# Extract features from sample files
real_features = []
real_signals = []
labels = []

print("Extracting features from real EEG data...")
progress_interval = max(1, len(sample_files) // 10) if sample_files else 1
for file_idx, file in enumerate(sample_files):
    df = pd.read_csv(file)
    subject_type = df['subject identifier'].iloc[0]
    
    # Get first few channels and trials
    channels = df['sensor position'].unique()[:5]
    trials = df['trial number'].unique()[:2]
    
    for channel in channels:
        for trial in trials:
            trial_data = df[
                (df['sensor position'] == channel) & 
                (df['trial number'] == trial)
            ].sort_values('sample num')
            
            if len(trial_data) >= 128:
                signal_data = trial_data['sensor value'].values[:256]
                band_powers = extract_band_power(signal_data)
                
                feature_vector = list(band_powers.values())
                real_features.append(feature_vector)
                real_signals.append(signal_data)
                labels.append(1 if subject_type == 'a' else 0)
    
    if ((file_idx + 1) % progress_interval == 0) or (file_idx + 1 == len(sample_files)):
        print(f"  Processed {file_idx + 1}/{len(sample_files)} files...")

real_features = np.array(real_features)
real_signals = np.array(real_signals)
labels = np.array(labels)

print(f"\\nExtracted {len(real_features)} epochs from real data")
print(f"Feature shape: {real_features.shape}")
print(f"Signal shape: {real_signals.shape}")
print(f"Class distribution: {np.sum(labels == 1)} alcoholic, {np.sum(labels == 0)} control")


Extracting features from real EEG data...
  Processed 46/468 files...
  Processed 92/468 files...
  Processed 138/468 files...
  Processed 184/468 files...
  Processed 230/468 files...
  Processed 276/468 files...
  Processed 322/468 files...
  Processed 368/468 files...
  Processed 414/468 files...
  Processed 460/468 files...
  Processed 468/468 files...
\nExtracted 2340 epochs from real data
Feature shape: (2340, 5)
Signal shape: (2340, 256)
Class distribution: 1175 alcoholic, 1165 control


## 4. Method 1: Correlation Sampling Approach

Based on "A Statistical Approach for Synthetic EEG Data Generation"

Steps:
1. Compute correlation matrix of frequency band features
2. Sample from multivariate normal distribution preserving correlations
3. Generate synthetic features matching real data statistics


In [165]:
def generate_correlation_based_eeg(real_features, n_synthetic=100, random_seed=42):
    """
    Generate synthetic EEG using correlation sampling method
    
    This preserves the correlation structure between frequency bands
    """
    # Set random seed for reproducibility
    np.random.seed(random_seed)
    
    # Compute correlation matrix and statistics
    correlation_matrix = np.corrcoef(real_features.T)
    mean_features = np.mean(real_features, axis=0)
    std_features = np.std(real_features, axis=0)
    
    print("Correlation Matrix of Frequency Bands:")
    band_names = list(FREQUENCY_BANDS.keys())
    for i, band1 in enumerate(band_names):
        for j, band2 in enumerate(band_names):
            if j >= i:
                print(f"{band1:6s} - {band2:6s}: {correlation_matrix[i,j]:6.3f}")
    
    # Generate synthetic features preserving correlation structure
    covariance_matrix = np.outer(std_features, std_features) * correlation_matrix
    
    synthetic_features = np.random.multivariate_normal(
        mean_features,
        covariance_matrix,
        size=n_synthetic
    )
    
    # Ensure non-negative powers
    synthetic_features = np.abs(synthetic_features)
    
    print(f"\\nGenerated {n_synthetic} synthetic feature vectors")
    print(f"Correlation structure preserved")
    
    return synthetic_features, correlation_matrix

# Generate synthetic data
n_synthetic_samples = len(real_features)
print(f"\\nGenerating {n_synthetic_samples} synthetic samples...")

synthetic_features_corr, corr_matrix = generate_correlation_based_eeg(
    real_features,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED
)

print(f"Shape of synthetic features: {synthetic_features_corr.shape}")


\nGenerating 2340 synthetic samples...
Correlation Matrix of Frequency Bands:
Delta  - Delta :  1.000
Delta  - Theta :  0.570
Delta  - Alpha :  0.025
Delta  - Beta  : -0.048
Delta  - Gamma : -0.057
Theta  - Theta :  1.000
Theta  - Alpha :  0.352
Theta  - Beta  : -0.019
Theta  - Gamma : -0.010
Alpha  - Alpha :  1.000
Alpha  - Beta  :  0.041
Alpha  - Gamma :  0.042
Beta   - Beta  :  1.000
Beta   - Gamma :  0.952
Gamma  - Gamma :  1.000
\nGenerated 2340 synthetic feature vectors
Correlation structure preserved
Shape of synthetic features: (2340, 5)


## 5. Method 2: GAN-like Generation

Simplified approach inspired by WGAN-GP methodology
(Full implementation would require deep learning framework)


In [166]:
def generate_gan_based_eeg(real_features, n_synthetic=100, random_seed=42):
    """
    Simplified GAN-like generation using:
    - Interpolation between real samples
    - Addition of controlled gaussian noise
    
    Note: Full WGAN-GP would require TensorFlow/PyTorch implementation
    This is a simplified demonstration of the concept
    """
    # Set random seed for reproducibility
    np.random.seed(random_seed)
    
    synthetic_features = []
    
    for _ in range(n_synthetic):
        # Random interpolation between two real samples
        idx1, idx2 = np.random.choice(len(real_features), 2, replace=False)
        alpha = np.random.uniform(0.3, 0.7)
        
        interpolated = alpha * real_features[idx1] + (1 - alpha) * real_features[idx2]
        
        # Add controlled gaussian noise
        noise_scale = 0.1 * np.std(real_features, axis=0)
        noise = np.random.normal(0, noise_scale)
        synthetic_sample = interpolated + noise
        
        # Ensure non-negative values (power cannot be negative)
        synthetic_sample = np.abs(synthetic_sample)
        synthetic_features.append(synthetic_sample)
    
    return np.array(synthetic_features)

print("Generating synthetic data using GAN-like approach...")
synthetic_features_gan = generate_gan_based_eeg(
    real_features,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED
)

print(f"Generated {len(synthetic_features_gan)} synthetic samples")
print(f"Shape: {synthetic_features_gan.shape}")


Generating synthetic data using GAN-like approach...
Generated 2340 synthetic samples
Shape: (2340, 5)


## 6. Evaluation: Distribution Comparison (KS Test & MMD)

Statistical tests to compare real vs synthetic data distributions


In [167]:
def evaluate_distributions(real_features, synthetic_features, method_name=""):
    """
    Compare distributions using:
    - KS test (Kolmogorov-Smirnov): tests if distributions are from same population
    - MMD (Maximum Mean Discrepancy): measures distance between distributions
    """
    print(f"\\n{'='*60}")
    print(f"Distribution Comparison: {method_name}")
    print("="*60)
    
    # KS test for each feature (frequency band)
    band_names = list(FREQUENCY_BANDS.keys())
    ks_results = []
    
    print("\\nKolmogorov-Smirnov Test Results:")
    print("(p-value > 0.05 suggests distributions are similar)")
    for i, band in enumerate(band_names):
        ks_stat, p_value = stats.ks_2samp(real_features[:, i], synthetic_features[:, i])
        ks_results.append({'band': band, 'ks_stat': ks_stat, 'p_value': p_value})
        
        significance = "✓ Similar" if p_value > 0.05 else "✗ Different"
        print(f"  {band:8s}: KS={ks_stat:.4f}, p={p_value:.4f} {significance}")
    
    # Simplified MMD computation
    def compute_mmd(X, Y):
        """Maximum Mean Discrepancy using pairwise distances"""
        XX = cdist(X, X, metric='euclidean')
        YY = cdist(Y, Y, metric='euclidean')
        XY = cdist(X, Y, metric='euclidean')
        
        mmd = np.mean(XX) + np.mean(YY) - 2 * np.mean(XY)
        return mmd
    
    mmd_score = compute_mmd(real_features, synthetic_features)
    print(f"\\nMMD Score: {mmd_score:.4f}")
    print("(Lower MMD indicates more similar distributions)")
    
    return ks_results, mmd_score

# Evaluate Correlation Sampling Method
ks_corr, mmd_corr = evaluate_distributions(
    real_features,
    synthetic_features_corr,
    "Correlation Sampling"
)

# Evaluate GAN-like Method  
ks_gan, mmd_gan = evaluate_distributions(
    real_features,
    synthetic_features_gan,
    "GAN-like Generation"
)


Distribution Comparison: Correlation Sampling
\nKolmogorov-Smirnov Test Results:
(p-value > 0.05 suggests distributions are similar)
  Delta   : KS=0.3573, p=0.0000 ✗ Different
  Theta   : KS=0.2269, p=0.0000 ✗ Different
  Alpha   : KS=0.3338, p=0.0000 ✗ Different
  Beta    : KS=0.5679, p=0.0000 ✗ Different
  Gamma   : KS=0.9350, p=0.0000 ✗ Different
\nMMD Score: -422.8791
(Lower MMD indicates more similar distributions)
Distribution Comparison: GAN-like Generation
\nKolmogorov-Smirnov Test Results:
(p-value > 0.05 suggests distributions are similar)
  Delta   : KS=0.1385, p=0.0000 ✗ Different
  Theta   : KS=0.1440, p=0.0000 ✗ Different
  Alpha   : KS=0.1321, p=0.0000 ✗ Different
  Beta    : KS=0.1154, p=0.0000 ✗ Different
  Gamma   : KS=0.8192, p=0.0000 ✗ Different
\nMMD Score: -29.6218
(Lower MMD indicates more similar distributions)


## 7. Evaluation: TSTR and TRTR

**TRTR** = Train on Real, Test on Real  
**TSTR** = Train on Synthetic, Test on Real

If TSTR ≈ TRTR, synthetic data quality is high


In [168]:
def evaluate_tstr_trtr(real_features, real_labels, synthetic_features, method_name=""):
    """
    TSTR/TRTR Evaluation from literature
    
    Validates synthetic data by comparing model performance when:
    - Training on real vs synthetic data
    - Testing on real data
    """
    print(f"\\n{'='*60}")
    print(f"TSTR/TRTR Evaluation: {method_name}")
    print("="*60)
    
    # Split real data
    X_train_real, X_test_real, y_train_real, y_test_real = train_test_split(
        real_features, real_labels, test_size=0.3, random_state=RANDOM_SEED
    )
    
    # Create synthetic labels matching real distribution
    n_alcoholic = np.sum(y_train_real == 1)
    n_control = np.sum(y_train_real == 0)
    y_synthetic = np.concatenate([
        np.ones(min(n_alcoholic, len(synthetic_features)//2)),
        np.zeros(min(n_control, len(synthetic_features)//2))
    ])
    X_synthetic = synthetic_features[:len(y_synthetic)]
    
    # TRTR: Train on Real, Test on Real
    print("\\n1. TRTR (Train on Real, Test on Real):")
    clf_trtr = RandomForestClassifier(n_estimators=100, random_state=RANDOM_SEED)
    clf_trtr.fit(X_train_real, y_train_real)
    y_pred_trtr = clf_trtr.predict(X_test_real)
    acc_trtr = accuracy_score(y_test_real, y_pred_trtr)
    print(f"   Accuracy: {acc_trtr:.4f}")
    
    # TSTR: Train on Synthetic, Test on Real
    print("\\n2. TSTR (Train on Synthetic, Test on Real):")
    clf_tstr = RandomForestClassifier(n_estimators=100, random_state=RANDOM_SEED)
    clf_tstr.fit(X_synthetic, y_synthetic)
    y_pred_tstr = clf_tstr.predict(X_test_real)
    acc_tstr = accuracy_score(y_test_real, y_pred_tstr)
    print(f"   Accuracy: {acc_tstr:.4f}")
    
    # Compare
    print(f"\\n3. Performance Comparison:")
    print(f"   TRTR: {acc_trtr:.4f}")
    print(f"   TSTR: {acc_tstr:.4f}")
    print(f"   Difference: {abs(acc_trtr - acc_tstr):.4f}")
    
    if abs(acc_trtr - acc_tstr) < 0.05:
        print("   ✓ Synthetic data quality: EXCELLENT")
    elif abs(acc_trtr - acc_tstr) < 0.10:
        print("   ✓ Synthetic data quality: GOOD")
    else:
        print("   ✗ Synthetic data quality: NEEDS IMPROVEMENT")
    
    return acc_trtr, acc_tstr

# Evaluate both methods
print("\\nEvaluating Correlation Sampling Method:")
acc_trtr_corr, acc_tstr_corr = evaluate_tstr_trtr(
    real_features, labels, synthetic_features_corr, "Correlation"
)

print("\\n" + "="*60)
print("Evaluating GAN-like Method:")
acc_trtr_gan, acc_tstr_gan = evaluate_tstr_trtr(
    real_features, labels, synthetic_features_gan, "GAN-like"
)


\nEvaluating Correlation Sampling Method:
TSTR/TRTR Evaluation: Correlation
\n1. TRTR (Train on Real, Test on Real):
   Accuracy: 0.6966
\n2. TSTR (Train on Synthetic, Test on Real):
   Accuracy: 0.3561
\n3. Performance Comparison:
   TRTR: 0.6966
   TSTR: 0.3561
   Difference: 0.3405
   ✗ Synthetic data quality: NEEDS IMPROVEMENT
Evaluating GAN-like Method:
TSTR/TRTR Evaluation: GAN-like
\n1. TRTR (Train on Real, Test on Real):
   Accuracy: 0.6966
\n2. TSTR (Train on Synthetic, Test on Real):
   Accuracy: 0.4758
\n3. Performance Comparison:
   TRTR: 0.6966
   TSTR: 0.4758
   Difference: 0.2208
   ✗ Synthetic data quality: NEEDS IMPROVEMENT


## 8. Evaluation: Real vs Synthetic Classification

Train classifier to distinguish real from synthetic.  
**Goal**: Classifier should perform at ~50% (chance level) if synthetic data is indistinguishable


In [169]:
def evaluate_real_vs_synthetic(real_features, synthetic_features, method_name=""):
    """
    Train classifier to distinguish real from synthetic
    
    From literature: If classifier performs at chance level (~50%),
    synthetic data is indistinguishable from real
    """
    print(f"\\n{'='*60}")
    print(f"Real vs Synthetic Classification: {method_name}")
    print("="*60)
    
    # Create combined dataset: 1=real, 0=synthetic
    X_combined = np.vstack([real_features, synthetic_features])
    y_combined = np.concatenate([
        np.ones(len(real_features)),
        np.zeros(len(synthetic_features))
    ])
    
    # Train classifier
    X_train, X_test, y_train, y_test = train_test_split(
        X_combined, y_combined, test_size=0.3, random_state=RANDOM_SEED
    )
    
    clf = RandomForestClassifier(n_estimators=100, random_state=RANDOM_SEED)
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    
    accuracy = accuracy_score(y_test, y_pred)
    
    print(f"\\nClassifier Accuracy: {accuracy:.4f}")
    
    if 0.45 <= accuracy <= 0.55:
        print("✓ EXCELLENT: Classifier at chance level (50%)")
        print("  → Synthetic data indistinguishable from real")
    elif 0.40 <= accuracy <= 0.60:
        print("✓ GOOD: Classifier struggles to distinguish")
    else:
        print("✗ POOR: Classifier easily distinguishes real from synthetic")
    
    # Confusion Matrix
    cm = confusion_matrix(y_test, y_pred)
    print("\\nConfusion Matrix:")
    print("                 Pred Real  Pred Synthetic")
    print(f"Actual Real:        {cm[1,1]:4d}         {cm[1,0]:4d}")
    print(f"Actual Synthetic:   {cm[0,1]:4d}         {cm[0,0]:4d}")
    
    return accuracy

# Evaluate Correlation Sampling
acc_corr = evaluate_real_vs_synthetic(
    real_features, synthetic_features_corr, "Correlation Sampling"
)

# Evaluate GAN-like
acc_gan = evaluate_real_vs_synthetic(
    real_features, synthetic_features_gan, "GAN-like"
)


Real vs Synthetic Classification: Correlation Sampling
\nClassifier Accuracy: 0.9779
✗ POOR: Classifier easily distinguishes real from synthetic
\nConfusion Matrix:
                 Pred Real  Pred Synthetic
Actual Real:         692           13
Actual Synthetic:     18          681
Real vs Synthetic Classification: GAN-like
\nClassifier Accuracy: 0.9338
✗ POOR: Classifier easily distinguishes real from synthetic
\nConfusion Matrix:
                 Pred Real  Pred Synthetic
Actual Real:         666           39
Actual Synthetic:     54          645


## 9. Improved Synthetic Generation Strategies

To reduce real-vs-synthetic separability while preserving frequency-band correlations, we explore two additional generators:

1. **Gaussian Copula Sampling**: preserves empirical marginals per class and matches correlation structure in a latent Gaussian space.
2. **Class-Conditional Interpolation**: SMOTE-like synthesis operating in log-power space with adaptive neighborhood mixing.



In [170]:
def _allocate_samples_by_class(labels, n_total):
    """Allocate synthetic samples per class, preserving empirical ratios."""
    classes, counts = np.unique(labels, return_counts=True)
    ratios = counts / counts.sum()
    expected = ratios * n_total
    allocated = np.floor(expected).astype(int)
    remainder = n_total - allocated.sum()
    if remainder > 0:
        remainders = expected - allocated
        order = np.argsort(remainders)[::-1]
        for idx in order[:remainder]:
            allocated[idx] += 1
    return dict(zip(classes, allocated))

def generate_gaussian_copula_eeg(real_features, labels, n_synthetic=100, random_seed=42):
    """
    Gaussian copula sampling:
    1. Fit class-conditional quantile transformers to map marginals to Gaussian space
    2. Estimate regularised covariance (Ledoit-Wolf) in latent space
    3. Sample multivariate normal per class and invert the transform
    """
    rng = np.random.default_rng(random_seed)
    allocation = _allocate_samples_by_class(labels, n_synthetic)
    synthetic_blocks = []

    print("Generating Gaussian copula samples per class...")
    for cls, n_cls_samples in allocation.items():
        class_features = real_features[labels == cls]
        if len(class_features) == 0 or n_cls_samples == 0:
            continue

        n_quantiles = min(len(class_features), 1000)
        transformer = QuantileTransformer(
            n_quantiles=n_quantiles,
            output_distribution='normal',
            random_state=random_seed
        )
        latent = transformer.fit_transform(class_features)

        cov_estimator = LedoitWolf().fit(latent)
        latent_mean = cov_estimator.location_
        latent_cov = cov_estimator.covariance_

        latent_samples = rng.multivariate_normal(
            latent_mean,
            latent_cov,
            size=n_cls_samples
        )

        samples = transformer.inverse_transform(latent_samples)
        samples = np.clip(samples, a_min=0, a_max=None)
        synthetic_blocks.append(samples)

        print(f"  Class {cls}: real={len(class_features)}, synthetic={n_cls_samples}")

    if not synthetic_blocks:
        raise ValueError("No synthetic samples were generated. Check class labels.")

    synthetic_features = np.vstack(synthetic_blocks)
    return synthetic_features

In [171]:
synthetic_features_copula = generate_gaussian_copula_eeg(
    real_features,
    labels,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED
)

print(f"\nGenerated {len(synthetic_features_copula)} Gaussian copula samples")
print(f"Shape: {synthetic_features_copula.shape}")

Generating Gaussian copula samples per class...
  Class 0: real=1165, synthetic=1165
  Class 1: real=1175, synthetic=1175

Generated 2340 Gaussian copula samples
Shape: (2340, 5)


In [172]:
def generate_classwise_interpolation_eeg(
    real_features,
    labels,
    n_synthetic=100,
    random_seed=42,
    k_neighbors=8,
    noise_scale=0.02
):
    """
    Class-conditional interpolation inspired by SMOTE.
    Operates in log-power space to better capture multiplicative structure.
    """
    rng = np.random.default_rng(random_seed)
    allocation = _allocate_samples_by_class(labels, n_synthetic)
    synthetic_samples = []

    log_features = np.log1p(real_features)

    for cls, n_cls_samples in allocation.items():
        class_mask = labels == cls
        class_features_log = log_features[class_mask]
        if len(class_features_log) == 0 or n_cls_samples == 0:
            continue

        n_neighbors_eff = min(k_neighbors, len(class_features_log) - 1)
        if n_neighbors_eff <= 0:
            # Not enough samples to interpolate, fallback to jittering existing ones
            base_samples = np.repeat(class_features_log, repeats=max(1, n_cls_samples // max(1, len(class_features_log))), axis=0)
            base_samples = base_samples[:n_cls_samples]
            jitter = rng.normal(0, noise_scale, size=base_samples.shape)
            augmented = base_samples + jitter
            synthetic_samples.append(np.expm1(augmented))
            continue

        nbrs = NearestNeighbors(n_neighbors=n_neighbors_eff + 1)
        nbrs.fit(class_features_log)
        class_std = np.std(class_features_log, axis=0, ddof=1)
        class_std[class_std == 0] = 1e-6

        for _ in range(n_cls_samples):
            idx = rng.integers(len(class_features_log))
            neighbors = nbrs.kneighbors(class_features_log[idx].reshape(1, -1), return_distance=False)[0]
            neighbors = neighbors[neighbors != idx]
            if len(neighbors) == 0:
                neighbor_idx = idx
            else:
                neighbor_idx = rng.choice(neighbors)

            alpha = rng.uniform(0.2, 0.8)
            interpolated = (
                alpha * class_features_log[idx] +
                (1 - alpha) * class_features_log[neighbor_idx]
            )

            noise = rng.normal(0, noise_scale, size=class_features_log.shape[1]) * class_std
            synthetic_log = interpolated + noise
            synthetic_samples.append(np.expm1(synthetic_log))

    if not synthetic_samples:
        raise ValueError("Interpolation generator did not create any samples.")

    synthetic_features = np.vstack(synthetic_samples)
    synthetic_features = np.clip(synthetic_features, a_min=0, a_max=None)
    return synthetic_features

In [173]:
synthetic_features_interp = generate_classwise_interpolation_eeg(
    real_features,
    labels,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED,
    k_neighbors=10, 
    noise_scale=0.015
)

print(f"\nGenerated {len(synthetic_features_interp)} interpolation-based samples")
print(f"Shape: {synthetic_features_interp.shape}")


Generated 2340 interpolation-based samples
Shape: (2340, 5)


### 9.1 Distribution and Quality Checks

Re-run the statistical and downstream evaluations for the new generators alongside previous baselines.


In [174]:
ks_copula, mmd_copula = evaluate_distributions(
    real_features,
    synthetic_features_copula,
    "Gaussian Copula"
)

ks_interp, mmd_interp = evaluate_distributions(
    real_features,
    synthetic_features_interp,
    "Classwise Interpolation"
)



Distribution Comparison: Gaussian Copula
\nKolmogorov-Smirnov Test Results:
(p-value > 0.05 suggests distributions are similar)
  Delta   : KS=0.0192, p=0.7800 ✓ Similar
  Theta   : KS=0.0184, p=0.8245 ✓ Similar
  Alpha   : KS=0.0184, p=0.8245 ✓ Similar
  Beta    : KS=0.0192, p=0.7800 ✓ Similar
  Gamma   : KS=0.0167, p=0.9013 ✓ Similar
\nMMD Score: -0.1115
(Lower MMD indicates more similar distributions)
Distribution Comparison: Classwise Interpolation
\nKolmogorov-Smirnov Test Results:
(p-value > 0.05 suggests distributions are similar)
  Delta   : KS=0.0201, p=0.7328 ✓ Similar
  Theta   : KS=0.0295, p=0.2609 ✓ Similar
  Alpha   : KS=0.0278, p=0.3274 ✓ Similar
  Beta    : KS=0.0235, p=0.5378 ✓ Similar
  Gamma   : KS=0.0333, p=0.1485 ✓ Similar
\nMMD Score: -0.1035
(Lower MMD indicates more similar distributions)


In [175]:
print("\nEvaluating Gaussian Copula Method:")
acc_trtr_copula, acc_tstr_copula = evaluate_tstr_trtr(
    real_features,
    labels,
    synthetic_features_copula,
    "Gaussian Copula"
)

print("\n" + "=" * 60)
print("Evaluating Classwise Interpolation Method:")
acc_trtr_interp, acc_tstr_interp = evaluate_tstr_trtr(
    real_features,
    labels,
    synthetic_features_interp,
    "Classwise Interpolation"
)




Evaluating Gaussian Copula Method:
TSTR/TRTR Evaluation: Gaussian Copula
\n1. TRTR (Train on Real, Test on Real):
   Accuracy: 0.6966
\n2. TSTR (Train on Synthetic, Test on Real):
   Accuracy: 0.3490
\n3. Performance Comparison:
   TRTR: 0.6966
   TSTR: 0.3490
   Difference: 0.3476
   ✗ Synthetic data quality: NEEDS IMPROVEMENT

Evaluating Classwise Interpolation Method:
TSTR/TRTR Evaluation: Classwise Interpolation
\n1. TRTR (Train on Real, Test on Real):
   Accuracy: 0.6966
\n2. TSTR (Train on Synthetic, Test on Real):
   Accuracy: 0.3134
\n3. Performance Comparison:
   TRTR: 0.6966
   TSTR: 0.3134
   Difference: 0.3832
   ✗ Synthetic data quality: NEEDS IMPROVEMENT


In [176]:
acc_copula_sep = evaluate_real_vs_synthetic(
    real_features,
    synthetic_features_copula,
    "Gaussian Copula"
)

acc_interp_sep = evaluate_real_vs_synthetic(
    real_features,
    synthetic_features_interp,
    "Classwise Interpolation"
)



Real vs Synthetic Classification: Gaussian Copula
\nClassifier Accuracy: 0.5349
✓ EXCELLENT: Classifier at chance level (50%)
  → Synthetic data indistinguishable from real
\nConfusion Matrix:
                 Pred Real  Pred Synthetic
Actual Real:         370          335
Actual Synthetic:    318          381
Real vs Synthetic Classification: Classwise Interpolation
\nClassifier Accuracy: 0.4338
✓ GOOD: Classifier struggles to distinguish
\nConfusion Matrix:
                 Pred Real  Pred Synthetic
Actual Real:         288          417
Actual Synthetic:    378          321


In [177]:
# === Synthetic EEG Tuner: per-class (k, noise) grid + early-stopping ===
# Paste this in one cell. Assumes you have:
#   real_features:  (N, 5) band-power features (Delta..Gamma)
#   labels:         (N,)   class labels {0,1}
# Modify BAND_NAMES if needed.

import numpy as np
from scipy import stats
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.covariance import LedoitWolf

# -----------------------------
# Config / Targets
# -----------------------------
BAND_NAMES = ["Delta","Theta","Alpha","Beta","Gamma"]
CLASS_VALUES = np.unique(labels)
assert set(CLASS_VALUES) == {0,1}, "This tuner assumes binary classes {0,1}."

PARAM_GRID_K = [5, 8, 10, 12, 15]
PARAM_GRID_NOISE = [0.00, 0.01, 0.015, 0.02, 0.03]

# Targets (edit to taste)
TARGET_RVS_MAX = 0.60     # real vs synthetic accuracy <= 0.60
TARGET_TSTR_MIN = 0.70    # TSTR >= 0.70
TARGET_GAP_MAX  = 0.15    # |TRTR - TSTR| <= 0.15
TARGET_KS_MINP  = 0.05    # per-band KS p-value > 0.05
TARGET_CORR_SIM = 0.90    # corr-matrix similarity >= 0.90

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# -----------------------------
# Helpers: metrics
# -----------------------------
def ks_pvals_per_band(real, synth):
    pvals = []
    for j in range(real.shape[1]):
        _, p = stats.ks_2samp(real[:, j], synth[:, j])
        pvals.append(p)
    return np.array(pvals)

def rbf_mmd(X, Y, gamma=None):
    # Median heuristic for gamma if None
    Z = np.vstack([X, Y])
    if gamma is None:
        d2 = np.sum((Z[:, None, :] - Z[None, :, :])**2, axis=2)
        med2 = np.median(d2[d2>0])
        gamma = 1.0 / (med2 + 1e-8)
    def k(a,b): 
        d2 = np.sum((a[:,None,:]-b[None,:,:])**2, axis=2)
        return np.exp(-gamma * d2)
    Kxx = k(X,X); Kyy = k(Y,Y); Kxy = k(X,Y)
    m, n = X.shape[0], Y.shape[0]
    return Kxx.sum()/(m*m) + Kyy.sum()/(n*n) - 2*Kxy.sum()/(m*n)

def corr_matrix_similarity(real, synth):
    C_real = np.corrcoef(real.T)
    C_synth = np.corrcoef(synth.T)
    iu = np.triu_indices_from(C_real, k=1)
    r = np.corrcoef(C_real[iu], C_synth[iu])[0,1]
    return r

def per_band_spearman(real, synth, random_seed=RANDOM_SEED, max_samples=5000):
    if real.size == 0 or synth.size == 0:
        return np.full(real.shape[1], np.nan)

    rng = np.random.default_rng(random_seed)
    n = min(len(real), len(synth), max_samples)

    if len(real) > n:
        idx_real = rng.choice(len(real), size=n, replace=False)
    else:
        idx_real = np.arange(len(real))

    if len(synth) > n:
        idx_synth = rng.choice(len(synth), size=n, replace=False)
    else:
        idx_synth = np.arange(len(synth))

    real_sample = real[idx_real]
    synth_sample = synth[idx_synth]

    vals = []
    for j in range(real.shape[1]):
        rho = stats.spearmanr(real_sample[:, j], synth_sample[:, j]).correlation
        vals.append(rho)
    return np.array(vals)

def real_vs_synth_accuracy(real, synth):
    X = np.vstack([real, synth])
    y = np.hstack([np.zeros(len(real), dtype=int), np.ones(len(synth), dtype=int)])
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.33, random_state=RANDOM_SEED, stratify=y)
    clf = RandomForestClassifier(n_estimators=200, random_state=RANDOM_SEED, class_weight="balanced")
    clf.fit(Xtr, ytr)
    return clf.score(Xte, yte)

def tstr_trtr_accuracy(real_X, real_y, synth_X, synth_y):
    # Classifier: RF; TSTR = train on synthetic, test on real; TRTR = train on real, test on real
    # Split real set for fair TRTR eval
    Xtr_r, Xte_r, ytr_r, yte_r = train_test_split(real_X, real_y, test_size=0.33, random_state=RANDOM_SEED, stratify=real_y)
    # TRTR
    clf_r = RandomForestClassifier(n_estimators=300, random_state=RANDOM_SEED, class_weight="balanced")
    clf_r.fit(Xtr_r, ytr_r)
    trtr = clf_r.score(Xte_r, yte_r)
    # TSTR
    clf_s = RandomForestClassifier(n_estimators=300, random_state=RANDOM_SEED, class_weight="balanced")
    clf_s.fit(synth_X, synth_y)
    tstr = clf_s.score(Xte_r, yte_r)
    return tstr, trtr

# -----------------------------
# Generator: classwise interpolation in log-space
# (replace with your own generate_classwise_interpolation_eeg_oneclass if you have it)
# -----------------------------
def _interp_one_class(Xc, n_out, k_neighbors=10, noise_scale=0.015, random_state=42):
    """
    Classwise interpolation in log-space with covariance-aware jitter.
    Robust to small class sizes; ensures k>=2 and <= len(Xc)-1.
    """
    rng = np.random.RandomState(random_state)
    eps = 1e-8
    Xc = np.asarray(Xc)
    if Xc.ndim != 2 or Xc.shape[0] == 0:
        raise ValueError("Xc must be (n_samples, n_features) with n_samples>0")

    # If the class is too small, just jitter existing points
    if Xc.shape[0] == 1:
        # log → jitter → exp
        xlog = np.log(Xc + eps)
        synth_log = np.repeat(xlog, n_out, axis=0)
        # fallback covariance: identity
        jitter = rng.normal(size=(n_out, Xc.shape[1]))
        synth_log = synth_log + noise_scale * jitter
        return np.exp(synth_log) - eps

    # Work in log-space to keep positivity on exp back-transform
    Xlog = np.log(Xc + eps)

    # Choose a valid k (at least 2, at most n-1)
    n_samp = Xlog.shape[0]
    k = int(np.clip(k_neighbors, 2, max(2, n_samp - 1)))

    # Fit neighbors in log-space
    nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(Xlog)

    # Pick base points uniformly
    base_idx = rng.randint(0, n_samp, size=n_out)
    base = Xlog[base_idx]

    # For each base, pick one neighbor randomly (excluding self is handled by k<=n-1)
    neigh_idx = nbrs.kneighbors(base, return_distance=False)
    pick = neigh_idx[np.arange(n_out), rng.randint(0, neigh_idx.shape[1], size=n_out)]
    neigh = Xlog[pick]

    # Random convex combination in [0,1]
    alpha = rng.rand(n_out, 1)
    synth_log = alpha * base + (1 - alpha) * neigh

    # Covariance-aware jitter using Ledoit–Wolf (robust); fallback to diagonal if needed
    if noise_scale and noise_scale > 0:
        try:
            lw = LedoitWolf().fit(Xlog)
            S = lw.covariance_
            # numeric guard: ensure SPD
            # (LedoitWolf should be SPD; add tiny ridge just in case)
            S = S + 1e-8 * np.eye(S.shape[0])
            jitter = rng.multivariate_normal(mean=np.zeros(Xlog.shape[1]), cov=S, size=n_out)
        except Exception:
            jitter = rng.normal(size=(n_out, Xlog.shape[1]))
        synth_log = synth_log + noise_scale * jitter

    synth = np.exp(synth_log) - eps
    return synth


def generate_classwise_interpolation_both_classes(real_X, real_y, n_per_class, k_params, noise_params):
    # k_params = {0: k0, 1: k1}; noise_params = {0: n0, 1: n1}
    synth_list, synth_y = [], []
    for c in CLASS_VALUES:
        Xc = real_X[real_y == c]
        synth_c = _interp_one_class(
            Xc, n_per_class,
            k_neighbors=k_params[c],
            noise_scale=noise_params[c],
            random_state=RANDOM_SEED + c
        )
        synth_list.append(synth_c)
        synth_y.append(np.full(n_per_class, c, dtype=int))
    return np.vstack(synth_list), np.hstack(synth_y)

# -----------------------------
# Scoring + Early stopping
# -----------------------------
def score_combo(metrics):
    # Higher is better. Penalize violations smoothly.
    rvs = metrics["rvs_acc"]
    tstr, trtr = metrics["tstr"], metrics["trtr"]
    ks_min = metrics["ks_min_p"]
    corr_sim = metrics["corr_sim"]

    # Base score
    s = 0.0
    # push RvS down toward 0.55 (reward if <= 0.60, punish otherwise)
    s += 2.0 * max(0.0, 0.60 - rvs)
    # reward higher TSTR and smaller gap
    s += 1.5 * tstr
    s += 1.0 * max(0.0, 0.15 - abs(trtr - tstr))
    # reward KS min p and correlation similarity
    s += 1.0 * min(ks_min, 0.10) * 10.0   # cap effect; scale to ~[0..1]
    s += 1.0 * max(0.0, corr_sim - 0.85)  # only reward above 0.85
    # small reward for small |MMD|
    s += 0.5 * max(0.0, 0.2 - abs(metrics["mmd"]))  # closer to 0 is better

    return s

def meets_targets(metrics):
    return (metrics["rvs_acc"] <= TARGET_RVS_MAX and
            metrics["tstr"]     >= TARGET_TSTR_MIN and
            abs(metrics["trtr"] - metrics["tstr"]) <= TARGET_GAP_MAX and
            metrics["ks_min_p"] >= TARGET_KS_MINP and
            metrics["corr_sim"] >= TARGET_CORR_SIM)

# -----------------------------
# Grid search (per class) but evaluated jointly
# -----------------------------
def tune_interpolation_params(real_X, real_y, verbose=True):
    n_per_class = min(np.sum(real_y==0), np.sum(real_y==1))  # balance
    best = {"score": -np.inf, "params": None, "metrics": None}

    tried = 0
    for k0 in PARAM_GRID_K:
        for n0 in PARAM_GRID_NOISE:
            for k1 in PARAM_GRID_K:
                for n1 in PARAM_GRID_NOISE:
                    tried += 1
                    k_params = {0: k0, 1: k1}
                    noise_params = {0: n0, 1: n1}
                    synth_X, synth_y = generate_classwise_interpolation_both_classes(
                        real_X, real_y, n_per_class, k_params, noise_params
                    )

                    # Metrics
                    ks_p = ks_pvals_per_band(real_X, synth_X)
                    mmd = rbf_mmd(real_X, synth_X, gamma=None)
                    rvs = real_vs_synth_accuracy(real_X, synth_X)
                    tstr, trtr = tstr_trtr_accuracy(real_X, real_y, synth_X, synth_y)
                    corr_sim = corr_matrix_similarity(real_X, synth_X)
                    rho = per_band_spearman(real_X, synth_X)

                    metrics = {
                        "rvs_acc": rvs,
                        "tstr": tstr,
                        "trtr": trtr,
                        "gap": abs(trtr - tstr),
                        "ks_min_p": float(np.min(ks_p)),
                        "mmd": float(mmd),
                        "corr_sim": float(corr_sim),
                        "rho_min": float(np.nanmin(rho)),
                        "rho_mean": float(np.nanmean(rho)),
                    }
                    sc = score_combo(metrics)

                    if sc > best["score"]:
                        best = {"score": sc, "params": (k_params, noise_params), "metrics": metrics, 
                                "synth": (synth_X, synth_y)}

                    if verbose and tried % 20 == 0:
                        print(f"[{tried:4d}] k0={k0}, n0={n0:.3f} | k1={k1}, n1={n1:.3f} "
                              f"RvS={rvs:.3f} TSTR/TRTR={tstr:.3f}/{trtr:.3f} "
                              f"KSmin={metrics['ks_min_p']:.3f} CorrSim={corr_sim:.3f} MMD={mmd:.3f}")

                    # Early stopping: break as soon as all targets met
                    if meets_targets(metrics):
                        if verbose:
                            print("\n✓ Early-stop: targets met")
                            print("  Params:", k_params, noise_params)
                            print("  Metrics:", metrics)
                        return {"best": best, "early_stop": True}

    if verbose:
        print("\nNo combo met all targets. Returning best observed.")
        print("Best params:", best["params"])
        print("Best metrics:", best["metrics"])
    return {"best": best, "early_stop": False}

# -----------------------------
# Run tuning
# -----------------------------
result = tune_interpolation_params(real_features, labels, verbose=True)

best_params = result["best"]["params"]
best_metrics = result["best"]["metrics"]
best_synth_X, best_synth_y = result["best"]["synth"]

print("\n=== BEST COMBINATION ===")
print("k_params:", best_params[0], "noise_params:", best_params[1])
print("metrics :", best_metrics)



✓ Early-stop: targets met
  Params: {0: 5, 1: 10} {0: 0.0, 1: 0.0}
  Metrics: {'rvs_acc': 0.36186770428015563, 'tstr': 0.8408796895213454, 'trtr': 0.6998706338939198, 'gap': 0.1410090556274256, 'ks_min_p': 0.3261561157352818, 'mmd': 0.0004134147321114279, 'corr_sim': 0.995270237210689, 'rho_min': -0.03868685042240337, 'rho_mean': -0.002450164703319186}

=== BEST COMBINATION ===
k_params: {0: 5, 1: 5} noise_params: {0: 0.0, 1: 0.0}
metrics : {'rvs_acc': 0.33787289234760054, 'tstr': 0.8745148771021992, 'trtr': 0.6998706338939198, 'gap': 0.17464424320827943, 'ks_min_p': 0.4868306191329178, 'mmd': 0.000296001036291349, 'corr_sim': 0.9921030346539852, 'rho_min': -0.03970450882614366, 'rho_mean': -0.002404962460893624}


In [178]:
# === DROP-IN: Robust tuner with detectability, fair TSTR, matched Spearman ===
import numpy as np
from scipy import stats
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.covariance import LedoitWolf
from sklearn.metrics import roc_auc_score

BAND_NAMES = ["Delta","Theta","Alpha","Beta","Gamma"]
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Targets
TARGET_DETECT_MAX = 0.60   # detectability = max(acc, 1-acc) <= 0.60
TARGET_TSTR_MIN   = 0.70
TARGET_GAP_MAX    = 0.15
TARGET_KS_MINP    = 0.05
TARGET_CORR_SIM   = 0.90

# Grids (adjust as needed)
PARAM_GRID_K      = [5, 8, 10, 12]
PARAM_GRID_NOISE  = [0.00, 0.01, 0.015, 0.02]

# ---------- Metrics ----------
def ks_pvals_per_band(real, synth):
    return np.array([stats.ks_2samp(real[:, j], synth[:, j])[1] for j in range(real.shape[1])])

def rbf_mmd(X, Y, gamma=None):
    Z = np.vstack([X, Y])
    if gamma is None:
        d2 = np.sum((Z[:, None, :] - Z[None, :, :])**2, axis=2)
        med2 = np.median(d2[d2>0])
        gamma = 1.0 / (med2 + 1e-8)
    def k(a,b):
        d2 = np.sum((a[:,None,:]-b[None,:,:])**2, axis=2)
        return np.exp(-gamma * d2)
    Kxx = k(X,X); Kyy = k(Y,Y); Kxy = k(X,Y)
    m, n = len(X), len(Y)
    return Kxx.sum()/(m*m) + Kyy.sum()/(n*n) - 2*Kxy.sum()/(m*n)

def corr_matrix_similarity(real, synth):
    C_r, C_s = np.corrcoef(real.T), np.corrcoef(synth.T)
    iu = np.triu_indices_from(C_r, k=1)
    return np.corrcoef(C_r[iu], C_s[iu])[0,1]

def matched_spearman(real, synth, k=1):
    nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(synth)
    idx = nbrs.kneighbors(real, return_distance=False)[:,0]
    return np.array([stats.spearmanr(real[:, j], synth[idx, j]).correlation for j in range(real.shape[1])])

def real_vs_synth_detectability(real, synth, seed=RANDOM_SEED):
    X = np.vstack([real, synth])
    y = np.hstack([np.zeros(len(real), dtype=int), np.ones(len(synth), dtype=int)])
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.33, random_state=seed, stratify=y)
    clf = RandomForestClassifier(n_estimators=200, random_state=seed, class_weight="balanced")
    clf.fit(Xtr, ytr)
    acc = clf.score(Xte, yte)
    try:
        proba = clf.predict_proba(Xte)[:,1]
        auc = roc_auc_score(yte, proba)
    except Exception:
        auc = 0.5
    detect = max(acc, 1.0 - acc)
    return detect, acc, auc

# ---------- Generator: classwise interpolation in log-space ----------
def _interp_one_class(Xc, n_out, k_neighbors=10, noise_scale=0.015, random_state=RANDOM_SEED):
    rng = np.random.RandomState(random_state)
    eps = 1e-8
    Xc = np.asarray(Xc)
    if Xc.shape[0] == 1:
        xlog = np.log(Xc + eps)
        synth_log = np.repeat(xlog, n_out, axis=0)
        jitter = rng.normal(size=(n_out, Xc.shape[1]))
        synth_log = synth_log + noise_scale * jitter
        return np.exp(synth_log) - eps

    Xlog = np.log(Xc + eps)
    n_samp = Xlog.shape[0]
    k = int(np.clip(k_neighbors, 2, max(2, n_samp - 1)))
    nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(Xlog)

    base_idx = rng.randint(0, n_samp, size=n_out)
    base = Xlog[base_idx]
    neigh_idx = nbrs.kneighbors(base, return_distance=False)
    pick = neigh_idx[np.arange(n_out), rng.randint(0, neigh_idx.shape[1], size=n_out)]
    neigh = Xlog[pick]

    alpha = rng.rand(n_out, 1)
    synth_log = alpha * base + (1 - alpha) * neigh

    if noise_scale and noise_scale > 0:
        try:
            lw = LedoitWolf().fit(Xlog)
            S = lw.covariance_ + 1e-8*np.eye(Xlog.shape[1])
            jitter = rng.multivariate_normal(mean=np.zeros(Xlog.shape[1]), cov=S, size=n_out)
        except Exception:
            jitter = rng.normal(size=(n_out, Xlog.shape[1]))
        synth_log = synth_log + noise_scale * jitter

    return np.exp(synth_log) - eps

def gen_interp_per_class(Xtr_r, ytr_r, n_per_class, k_params, noise_params):
    synth_list, synth_y = [], []
    for c in np.unique(ytr_r):
        Xc = Xtr_r[ytr_r == c]
        k_c = int(np.clip(k_params[c], 2, max(2, Xc.shape[0]-1)))
        n_c = float(noise_params[c])
        synth_c = _interp_one_class(Xc, n_per_class, k_neighbors=k_c, noise_scale=n_c, random_state=RANDOM_SEED + c)
        synth_list.append(synth_c)
        synth_y.append(np.full(n_per_class, c, dtype=int))
    return np.vstack(synth_list), np.hstack(synth_y)

# ---------- Fair TSTR/TRTR (train-only generation) ----------
def tstr_trtr_fair(real_X, real_y, k_params, noise_params, seed=RANDOM_SEED):
    Xtr_r, Xte_r, ytr_r, yte_r = train_test_split(real_X, real_y, test_size=0.33, random_state=seed, stratify=real_y)
    n_per_class = min(np.sum(ytr_r==0), np.sum(ytr_r==1))
    synth_X, synth_y = gen_interp_per_class(Xtr_r, ytr_r, n_per_class, k_params, noise_params)

    clf_r = RandomForestClassifier(n_estimators=300, random_state=seed, class_weight="balanced")
    clf_r.fit(Xtr_r, ytr_r)
    trtr = clf_r.score(Xte_r, yte_r)

    clf_s = RandomForestClassifier(n_estimators=300, random_state=seed, class_weight="balanced")
    clf_s.fit(synth_X, synth_y)
    tstr = clf_s.score(Xte_r, yte_r)
    return tstr, trtr, (Xtr_r, Xte_r, ytr_r, yte_r), (synth_X, synth_y)

# ---------- Scoring & Early-stop ----------
def score_combo(metrics):
    s = 0.0
    s += 2.0 * max(0.0, 0.60 - metrics["detect"])      # lower detectability is better
    s += 1.5 * metrics["tstr"]                          # higher TSTR is better
    s += 1.0 * max(0.0, 0.15 - abs(metrics["trtr"] - metrics["tstr"]))  # small gap
    s += 1.0 * min(metrics["ks_min_p"], 0.10) * 10.0    # KS min p (capped)
    s += 1.0 * max(0.0, metrics["corr_sim"] - 0.85)     # corr similarity above 0.85
    s += 0.5 * max(0.0, 0.2 - abs(metrics["mmd"]))      # MMD close to 0
    return s

def meets_targets(m):
    return (m["detect"] <= TARGET_DETECT_MAX and
            m["tstr"]   >= TARGET_TSTR_MIN and
            abs(m["trtr"] - m["tstr"]) <= TARGET_GAP_MAX and
            m["ks_min_p"] >= TARGET_KS_MINP and
            m["corr_sim"] >= TARGET_CORR_SIM)

# ---------- Grid search ----------
def tune_interpolation_params(real_X, real_y, verbose=True):
    classes = np.unique(real_y)
    assert set(classes) == {0,1}, "Binary classes {0,1} expected."
    best = {"score": -np.inf, "params": None, "metrics": None, "artifacts": None}
    tried = 0

    for k0 in PARAM_GRID_K:
        for n0 in PARAM_GRID_NOISE:
            for k1 in PARAM_GRID_K:
                for n1 in PARAM_GRID_NOISE:
                    tried += 1
                    k_params     = {0: k0, 1: k1}
                    noise_params = {0: n0, 1: n1}

                    # fair TSTR/TRTR with train-only generation
                    tstr, trtr, (Xtr_r, Xte_r, ytr_r, yte_r), (synth_X, synth_y) = tstr_trtr_fair(
                        real_X, real_y, k_params, noise_params, seed=RANDOM_SEED
                    )

                    # evaluate detectability on the same real (train+test) pool vs synth
                    detect, acc_raw, auc = real_vs_synth_detectability(real_X, synth_X, seed=RANDOM_SEED)
                    ks_p = ks_pvals_per_band(real_X, synth_X)
                    mmd  = rbf_mmd(real_X, synth_X)
                    corr = corr_matrix_similarity(real_X, synth_X)
                    rho  = matched_spearman(real_X, synth_X, k=1)

                    metrics = {
                        "detect": float(detect),
                        "rvs_acc_raw": float(acc_raw),
                        "rvs_auc": float(auc),
                        "tstr": float(tstr),
                        "trtr": float(trtr),
                        "gap": float(abs(trtr - tstr)),
                        "ks_min_p": float(np.min(ks_p)),
                        "mmd": float(mmd),
                        "corr_sim": float(corr),
                        "rho_min": float(np.nanmin(rho)),
                        "rho_mean": float(np.nanmean(rho)),
                    }
                    sc = score_combo(metrics)

                    if sc > best["score"]:
                        best = {"score": sc,
                                "params": (k_params, noise_params),
                                "metrics": metrics,
                                "artifacts": {"synth_X": synth_X, "synth_y": synth_y}}

                    if verbose and tried % 20 == 0:
                        print(f"[{tried:4d}] k0={k0}, n0={n0:.3f} | k1={k1}, n1={n1:.3f} "
                              f"Detect={metrics['detect']:.3f} (acc={metrics['rvs_acc_raw']:.3f}, AUC={metrics['rvs_auc']:.3f}) "
                              f"TSTR/TRTR={tstr:.3f}/{trtr:.3f} KSmin={metrics['ks_min_p']:.3f} "
                              f"CorrSim={corr:.3f} MMD={mmd:.4f} ρ_mean={metrics['rho_mean']:.3f}")

                    if meets_targets(metrics):
                        if verbose:
                            print("\n✓ Early-stop: targets met")
                            print("  Params:", k_params, noise_params)
                            print("  Metrics:", metrics)
                        return {"best": best, "early_stop": True}

    if verbose:
        print("\nNo combo met all targets. Returning best observed.")
        print("Best params:", best["params"])
        print("Best metrics:", best["metrics"])
    return {"best": best, "early_stop": False}

# Run it
result = tune_interpolation_params(real_features, labels, verbose=True)
best_params  = result["best"]["params"]
best_metrics = result["best"]["metrics"]
print("\n=== BEST COMBINATION ===")
print("k_params:", best_params[0], "noise_params:", best_params[1])
print("metrics :", best_metrics)


✓ Early-stop: targets met
  Params: {0: 5, 1: 5} {0: 0.0, 1: 0.0}
  Metrics: {'detect': 0.5027195027195027, 'rvs_acc_raw': 0.5027195027195027, 'rvs_auc': 0.38137230242969966, 'tstr': 0.7011642949547219, 'trtr': 0.6998706338939198, 'gap': 0.001293661060802087, 'ks_min_p': 0.3416900838456176, 'mmd': 0.0012224264816664832, 'corr_sim': 0.9820198663815284, 'rho_min': 0.8251903364370348, 'rho_mean': 0.9218774646917103}

=== BEST COMBINATION ===
k_params: {0: 5, 1: 5} noise_params: {0: 0.0, 1: 0.0}
metrics : {'detect': 0.5027195027195027, 'rvs_acc_raw': 0.5027195027195027, 'rvs_auc': 0.38137230242969966, 'tstr': 0.7011642949547219, 'trtr': 0.6998706338939198, 'gap': 0.001293661060802087, 'ks_min_p': 0.3416900838456176, 'mmd': 0.0012224264816664832, 'corr_sim': 0.9820198663815284, 'rho_min': 0.8251903364370348, 'rho_mean': 0.9218774646917103}


In [179]:
BEST_K = {0: 5, 1: 5}
BEST_NOISE = {0: 0.0, 1: 0.0}
SEED = 42

# Regenerate balanced synthetic set using train-only split for fairness
Xtr_r, Xte_r, ytr_r, yte_r = train_test_split(real_features, labels, test_size=0.33,
                                              random_state=SEED, stratify=labels)
n_per_class = min(np.sum(ytr_r==0), np.sum(ytr_r==1))
synth_X, synth_y = gen_interp_per_class(Xtr_r, ytr_r, n_per_class,
                                        k_params=BEST_K, noise_params=BEST_NOISE)

# Recompute headline metrics (detectability/TSTR/TRTR/KS/MMD/corr/rho)
detect, acc_raw, auc = real_vs_synth_detectability(real_features, synth_X, seed=SEED)
tstr, trtr, *_ = tstr_trtr_fair(real_features, labels, BEST_K, BEST_NOISE, seed=SEED)
ks_p = ks_pvals_per_band(real_features, synth_X)
mmd  = rbf_mmd(real_features, synth_X)
corr = corr_matrix_similarity(real_features, synth_X)
rho  = matched_spearman(real_features, synth_X, k=1)

summary = {
    "k_params": BEST_K, "noise_params": BEST_NOISE,
    "detectability": float(detect), "rvs_acc_raw": float(acc_raw), "rvs_auc": float(auc),
    "tstr": float(tstr), "trtr": float(trtr), "gap": float(abs(trtr - tstr)),
    "ks_min_p": float(np.min(ks_p)), "mmd": float(mmd),
    "corr_sim": float(corr), "rho_mean": float(np.nanmean(rho)), "rho_min": float(np.nanmin(rho))
}
print(summary)

{'k_params': {0: 5, 1: 5}, 'noise_params': {0: 0.0, 1: 0.0}, 'detectability': 0.5027195027195027, 'rvs_acc_raw': 0.5027195027195027, 'rvs_auc': 0.38137230242969966, 'tstr': 0.7011642949547219, 'trtr': 0.6998706338939198, 'gap': 0.001293661060802087, 'ks_min_p': 0.3416900838456176, 'mmd': 0.0012224264816664832, 'corr_sim': 0.9820198663815284, 'rho_mean': 0.9218774646917103, 'rho_min': 0.8251903364370348}


In [180]:
# Save artifacts
import pickle, os
os.makedirs("../output", exist_ok=True)
with open("../output/synth_interp_best.pkl", "wb") as f:
    pickle.dump({"synth_X": synth_X, "synth_y": synth_y, "summary": summary}, f)
print("Saved to ../output/synth_interp_best.pkl")

Saved to ../output/synth_interp_best.pkl
