# Phase 3: Synthetic EEG Data Generation

## 3.0. Load packages & setup

In [282]:
import os
import random
from pathlib import Path
from collections import defaultdict, Counter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import QuantileTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.cluster import KMeans
from sklearn.covariance import LedoitWolf
from sklearn.neighbors import NearestNeighbors
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Reproducibility: seed the Python, NumPy and PyTorch global RNGs up front.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Default size for every matplotlib figure in this notebook.
plt.rcParams.update({"figure.figsize": (6, 4)})

# Input features come from the band-extraction output folder; this phase writes to its own folder.
OUTPUT_ROOT_ = Path("../output")
DATA_DIR = OUTPUT_ROOT_ / "band_extraction"
OUT_DIR = OUTPUT_ROOT_ / "synthetic_generation"
OUT_DIR.mkdir(exist_ok=True, parents=True)
del OUTPUT_ROOT_

## 3.1. Processed data features check

### 3.1.1. Dataset overview

In [241]:
# Column roles in the band-feature table.
LABEL_COL = "subject_type"              # class label column (values "a"/"c")
CONDITION_COL = "matching_condition"    # experimental condition of each trial
SPLIT_COL = "dataset_split"             # "train" / "test" assignment

# Band-power feature columns (five bands plus their total).
BAND_COLS = ["Delta", "Theta", "Alpha", "Beta", "Gamma", "total_power"]

# Every expected column (metadata followed by features), in the order df_all is reindexed to.
META_COLS = [SPLIT_COL, "file_name", LABEL_COL, "subject_id", "channel", "trial", CONDITION_COL, *BAND_COLS]

In [242]:
# Load the per-segment band-power table from DATA_DIR (band-extraction output).
FEATURE_FP = DATA_DIR.joinpath("band_features_segments.csv")
df_all = pd.read_csv(FEATURE_FP)
print(f"Full data shape: {df_all.shape}")
df_all.head()

Full data shape: (60672, 13)


Unnamed: 0,dataset_split,file_name,subject_type,subject_id,channel,trial,matching_condition,Delta,Theta,Alpha,Beta,Gamma,total_power
0,train,Data1.csv,a,co2a0000364,FP1,0,S1 obj,20.048105,5.830134,0.854299,6.705598,6.848762,40.286898
1,train,Data1.csv,a,co2a0000364,FP2,0,S1 obj,21.769006,6.052321,1.013807,16.487621,15.773774,61.09653
2,train,Data1.csv,a,co2a0000364,F7,0,S1 obj,7.742259,6.272004,1.893497,39.119253,49.533282,104.560295
3,train,Data1.csv,a,co2a0000364,F8,0,S1 obj,11.400244,4.816262,2.360998,53.64694,44.50218,116.726624
4,train,Data1.csv,a,co2a0000364,AF1,0,S1 obj,13.188257,2.347635,0.54275,4.036543,2.914738,23.029923


In [243]:
# Sanity check: number of rows the original file assigns to each split.
print("Split counts:", df_all[SPLIT_COL].value_counts(), sep="\n")

Split counts:
dataset_split
test     30720
train    29952
Name: count, dtype: int64


In [244]:
print("Subject_type counts:", df_all[LABEL_COL].value_counts(), sep="\n")

Subject_type counts:
subject_type
a    30400
c    30272
Name: count, dtype: int64


In [245]:
# Note: one condition value is literally "S2 nomatch," (trailing comma) in the raw data.
print("Matching_condition counts:", df_all[CONDITION_COL].value_counts(), sep="\n")

Matching_condition counts:
matching_condition
S1 obj         20480
S2 match       20416
S2 nomatch,    19776
Name: count, dtype: int64


In [246]:
# Train / test frames from the split column shipped with the file (re-done subject-wise in 3.1.4).
df_train, df_test = (
    df_all.loc[df_all[SPLIT_COL] == split_name].reset_index(drop=True)
    for split_name in ("train", "test")
)
print(f"Train shape: {df_train.shape}")
print(f"Test shape : {df_test.shape}")

Train shape: (29952, 13)
Test shape : (30720, 13)


In [247]:
print("Train label x condition:", pd.crosstab(df_train[LABEL_COL], df_train[CONDITION_COL]), sep="\n")

Train label x condition:
matching_condition  S1 obj  S2 match  S2 nomatch,
subject_type                                     
a                     5120      5120         4800
c                     5120      5056         4736


In [248]:
print("Test label x condition:", pd.crosstab(df_test[LABEL_COL], df_test[CONDITION_COL]), sep="\n")

Test label x condition:
matching_condition  S1 obj  S2 match  S2 nomatch,
subject_type                                     
a                     5120      5120         5120
c                     5120      5120         5120


### 3.1.2. Basic checks for missing and non-finite feature values

In [249]:
# Compare df_all's columns with the expected schema (META_COLS).
# Extra columns are reported and dropped. Missing columns are fatal: every later cell
# indexes META_COLS / BAND_COLS, so fail here with a clear message instead of a KeyError later.
missing = [c for c in META_COLS if c not in df_all.columns]
extra = [c for c in df_all.columns if c not in META_COLS]

if extra:
    print("EXTRA columns:", extra)
if missing:
    raise KeyError(f"MISSING columns: {missing}")

# Keep exactly the expected columns, in META_COLS order.
df_all = df_all[META_COLS].copy()

In [250]:
# Coerce feature columns to numeric. Unparseable entries become NaN (not an error);
# they surface in the NaN check below.
df_all[BAND_COLS] = df_all[BAND_COLS].apply(pd.to_numeric, errors="coerce")

print("Data types:", df_all.dtypes, sep="\n")

Data types:
dataset_split          object
file_name              object
subject_type           object
subject_id             object
channel                object
trial                   int64
matching_condition     object
Delta                 float64
Theta                 float64
Alpha                 float64
Beta                  float64
Gamma                 float64
total_power           float64
dtype: object


In [251]:
# NaN counts per column, then the number of rows with a NaN in any feature column.
print("Missing values per column:", df_all.isna().sum(), sep="\n")
mask_nan_features = df_all[BAND_COLS].isnull().any(axis="columns")
n_nan_rows = mask_nan_features.sum()
print(f"Rows with NaN in any feature: {n_nan_rows}")

Missing values per column:
dataset_split         0
file_name             0
subject_type          0
subject_id            0
channel               0
trial                 0
matching_condition    0
Delta                 0
Theta                 0
Alpha                 0
Beta                  0
Gamma                 0
total_power           0
dtype: int64
Rows with NaN in any feature: 0


In [252]:
# Rows where any feature is +/-inf or NaN.
mask_nonfinite = np.logical_not(np.isfinite(df_all[BAND_COLS].to_numpy()).all(axis=1))
n_nonfinite = mask_nonfinite.sum()
print(f"Rows with non-finite feature values: {n_nonfinite}")

Rows with non-finite feature values: 0


### 3.1.3. Outlier check and clipping

In [253]:
# Per-feature 1st / 99th percentile bounds, used to clip extreme outliers
# (which would otherwise dominate the covariance estimates used by the generators).
# NOTE(review): these bounds are computed on ALL rows (train + test), so test-set
# statistics leak into preprocessing. Consider computing them on training rows only.
q_low, q_high = (df_all[BAND_COLS].quantile(q) for q in (0.01, 0.99))

print("1st percentile for features:\n", q_low)
print("99th percentile for features:\n", q_high)

1st percentile for features:
 Delta          0.254321
Theta          0.126250
Alpha          0.108306
Beta           0.191547
Gamma          0.057350
total_power    1.368929
Name: 0.01, dtype: float64
99th percentile for features:
 Delta          125.841570
Theta           37.735722
Alpha           42.706482
Beta            23.624554
Gamma           18.142717
total_power    187.203228
Name: 0.99, dtype: float64


In [254]:
# Clip every feature column to its [1st, 99th] percentile bounds, on the full dataset,
# before any re-splitting.
# NOTE(review): the raw rows suggest total_power is the sum of the five bands. Clipping it
# independently means that identity no longer holds after this step; confirm this is acceptable.
df_all[BAND_COLS] = df_all[BAND_COLS].clip(q_low, q_high, axis="columns")
print("After clipping, feature summary (all):")
df_all[BAND_COLS].describe(percentiles=[0.01, 0.5, 0.99])

After clipping, feature summary (all):


Unnamed: 0,Delta,Theta,Alpha,Beta,Gamma,total_power
count,60672.0,60672.0,60672.0,60672.0,60672.0,60672.0
mean,14.194337,5.21076,5.2124,4.203085,1.73345,31.266163
std,19.77333,6.28368,7.203401,4.136839,2.747217,31.317794
min,0.254321,0.12625,0.108306,0.191547,0.05735,1.368929
1%,0.254393,0.126368,0.1084,0.191558,0.057352,1.368965
50%,7.512234,3.138961,2.681088,2.934805,0.805226,22.106393
99%,125.82675,37.729023,42.705766,23.622675,18.129812,187.112226
max,125.84157,37.735722,42.706482,23.624554,18.142717,187.203228


### 3.1.4. Subject-wise train/test split

In [255]:
# Leakage check on the ORIGINAL split: a subject present in both splits lets a model
# recognise individuals instead of classes.
train_subj, test_subj = set(df_train["subject_id"]), set(df_test["subject_id"])
overlap = train_subj.intersection(test_subj)

print(f"# unique subjects in train: {len(train_subj)}")
print(f"# unique subjects in test: {len(test_subj)}")
print(f"# overlapping subjects: {len(overlap)}")

if len(overlap) > 0:
    print("WARNING: Some subject_ids appear in BOTH train and test!")

# unique subjects in train: 16
# unique subjects in test: 16
# overlapping subjects: 16


In [256]:
# Subject-wise split, step 1: shuffle the unique subject IDs reproducibly.
# Seeded from the notebook-wide RANDOM_SEED (previously a hard-coded 42) so that a
# single config value controls every random choice.
# NOTE(review): the split is not stratified by subject_type; with few subjects, a
# different seed could give an unbalanced class ratio per split.
rng = np.random.default_rng(RANDOM_SEED)
subjects = rng.permutation(df_all["subject_id"].unique())

In [257]:
# Subject-wise split, step 2: first half of the shuffled subjects -> train, rest -> test.
n_train = len(subjects) // 2
train_subjects, test_subjects = set(subjects[:n_train]), set(subjects[n_train:])

df_train, df_test = (
    df_all.loc[df_all["subject_id"].isin(subject_group)].reset_index(drop=True)
    for subject_group in (train_subjects, test_subjects)
)

print(f"Train subjects: {len(train_subjects)}")
print(f"Test subjects : {len(test_subjects)}")
print(f"Overlap       : {len(train_subjects & test_subjects)}")

Train subjects: 8
Test subjects : 8
Overlap       : 0


### 3.1.5. Check cleaned data before synthetic generation

In [258]:
# Record the subject-wise split in dataset_split, then rebuild the train/test frames
# so they carry the updated column.
in_train = df_all["subject_id"].isin(train_subjects)
df_all[SPLIT_COL] = np.where(in_train, "train", "test")
df_train = df_all.loc[in_train].reset_index(drop=True)
df_test = df_all.loc[df_all["subject_id"].isin(test_subjects)].reset_index(drop=True)
print("Split counts with new split:", df_all[SPLIT_COL].value_counts(), sep="\n")

Split counts with new split:
dataset_split
train    30336
test     30336
Name: count, dtype: int64


In [259]:
print(f"Train shape with new split: {df_train.shape}")
print(f"Test shape with new split: {df_test.shape}")

Train shape with new split: (30336, 13)
Test shape with new split: (30336, 13)


In [260]:
# Class x condition balance within each subject-wise split.
print("Train label x condition with new split:", pd.crosstab(df_train[LABEL_COL], df_train[CONDITION_COL]), sep="\n")
print("Test label x condition with new split:", pd.crosstab(df_test[LABEL_COL], df_test[CONDITION_COL]), sep="\n")

Train label x condition with new split:
matching_condition  S1 obj  S2 match  S2 nomatch,
subject_type                                     
a                     5120      5120         4992
c                     5120      5056         4928
Test label x condition with new split:
matching_condition  S1 obj  S2 match  S2 nomatch,
subject_type                                     
a                     5120      5120         4928
c                     5120      5120         4928


### 3.1.6. Encode labels & standardize features

In [261]:
# Binary labels: "c" -> 0, "a" -> 1.
label_map = {"c": 0, "a": 1}

# Series.map turns any value missing from label_map into NaN (and y into float),
# which would only fail much later inside a model. Reject unknown labels up front.
unmapped = (set(df_train[LABEL_COL]) | set(df_test[LABEL_COL])) - set(label_map)
if unmapped:
    raise ValueError(f"Unmapped {LABEL_COL} values: {unmapped}")

y_train = df_train[LABEL_COL].map(label_map).values
y_test = df_test[LABEL_COL].map(label_map).values

X_train_raw = df_train[BAND_COLS].values
X_test_raw = df_test[BAND_COLS].values

# Fit the scaler on training subjects only; the test set (and later every synthetic
# set) is transformed with the training statistics.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

conds_train = df_train[CONDITION_COL].values
conds_test = df_test[CONDITION_COL].values

print("Final matrices ready for synthetic generation:")
print("X_train shape:", X_train.shape)
print("X_test shape:", X_test.shape)
print("y_train distribution:", Counter(y_train))
print("y_test distribution:", Counter(y_test))
print("Train condition counts:", Counter(conds_train))
print("Test condition counts:", Counter(conds_test))

Final matrices ready for synthetic generation:
X_train shape: (30336, 6)
X_test shape: (30336, 6)
y_train distribution: Counter({1: 15232, 0: 15104})
y_test distribution: Counter({1: 15168, 0: 15168})
Train condition counts: Counter({'S1 obj': 10240, 'S2 match': 10176, 'S2 nomatch,': 9920})
Test condition counts: Counter({'S1 obj': 10240, 'S2 match': 10240, 'S2 nomatch,': 9856})


## 3.2. Synthetic Generation

### 3.2.1. Method 0: Mixup Baseline Generator

In [262]:
def generate_mixup_baseline(real_features, n_synthetic=None, random_seed=42):
    """
    Mixup-style baseline for synthetic EEG feature generation.

    Each synthetic row is a convex combination of two distinct, uniformly chosen
    real rows (mixing weight alpha ~ U(0.3, 0.7)). Gaussian noise with per-feature
    std equal to 10% of the real feature's std is then added, and the result is
    clipped at 0 so band powers stay non-negative. Rows are mixed regardless of
    class, so the outputs carry no class information.

    Parameters
    ----------
    real_features : array-like, shape (n_samples, n_features)
        Real feature matrix (raw band powers here). Needs at least 2 rows.
    n_synthetic : int or None
        Number of rows to generate; defaults to ``n_samples``.
    random_seed : int
        Seed for a private NumPy ``RandomState``. The global NumPy RNG is not
        touched, so later cells' random draws do not depend on this call.

    Returns
    -------
    np.ndarray, shape (n_synthetic, n_features), float64, non-negative.

    Raises
    ------
    ValueError
        If ``real_features`` is not 2-D, or has fewer than 2 rows while
        ``n_synthetic > 0``.
    """
    rs = np.random.RandomState(random_seed)

    real_features = np.asarray(real_features, dtype=float)
    if real_features.ndim != 2:
        raise ValueError("real_features must be a 2-D array (n_samples, n_features)")
    n_samples, n_features = real_features.shape

    if n_synthetic is None:
        n_synthetic = n_samples
    if n_synthetic == 0:
        return np.zeros((0, n_features))
    if n_samples < 2:
        raise ValueError("mixup needs at least 2 real samples to interpolate between")

    # Noise scale: 10% of each feature's (population) std.
    noise_scale = 0.1 * np.std(real_features, axis=0)

    # One ordered pair of DISTINCT indices per synthetic row, uniform over all pairs:
    # draw the partner from the n-1 other indices and shift it past idx1.
    idx1 = rs.randint(n_samples, size=n_synthetic)
    idx2 = rs.randint(n_samples - 1, size=n_synthetic)
    idx2 += idx2 >= idx1

    alpha = rs.uniform(0.3, 0.7, size=(n_synthetic, 1))
    interpolated = alpha * real_features[idx1] + (1 - alpha) * real_features[idx2]

    noise = rs.normal(0.0, noise_scale, size=(n_synthetic, n_features))

    # Band powers cannot be negative.
    return np.clip(interpolated + noise, 0, None)

In [263]:
# Mixup baseline on the (clipped) training band powers: one synthetic row per real row.
real_features = df_train[BAND_COLS].to_numpy()
n_synthetic_samples = len(real_features)

print(f"Generating {n_synthetic_samples} mixup synthetic samples...")

synthetic_features_mixup = generate_mixup_baseline(
    real_features=real_features,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED,
)
print(f"Shape: {synthetic_features_mixup.shape}")

Generating 30336 mixup synthetic samples...
Shape: (30336, 6)


In [264]:
# Standardize with the scaler fit on real training data and attach a shuffled copy of y_train.
# NOTE: mixup mixes rows across classes, so these labels reproduce only the real class
# balance. They carry no information about the synthetic features themselves.
perm = np.random.permutation(len(y_train))
X_syn_mixup = scaler.transform(synthetic_features_mixup)[perm]
y_syn_mixup = y_train[perm]

### 3.2.2. Method 1: Correlation Sampling Generator

In [265]:
def generate_correlation_based_eeg(real_features, band_names, n_synthetic=None, random_seed=42):
    """
    Generate synthetic EEG band features by sampling a multivariate normal that
    matches the real features' means, standard deviations and correlations.

    Parameters
    ----------
    real_features : array-like, shape (n_samples, n_bands)
        Real feature matrix (raw band powers here).
    band_names : list of str
        Column names in the same order as ``real_features`` (used for logging only).
    n_synthetic : int or None
        Number of synthetic rows; defaults to ``n_samples``.
    random_seed : int
        Seed for a private NumPy ``RandomState`` (same draws as the legacy
        ``np.random.seed`` + ``np.random.multivariate_normal``). The global NumPy
        RNG is not touched.

    Returns
    -------
    synthetic_features : np.ndarray, shape (n_synthetic, n_bands)
        Samples clipped at 0 (band powers are non-negative). Because of the
        clipping, means and correlations are only approximately preserved.
    correlation_matrix : np.ndarray, shape (n_bands, n_bands)
        Empirical correlation matrix. Zero-variance features get correlation 0
        with every other feature and 1 on the diagonal, instead of NaN.
    covariance_matrix : np.ndarray, shape (n_bands, n_bands)
        ``D @ Corr @ D`` with ``D = diag(std)`` (population std, ddof=0).
    """
    rs = np.random.RandomState(random_seed)
    real_features = np.asarray(real_features, dtype=float)

    if n_synthetic is None:
        n_synthetic = real_features.shape[0]

    mean_features = np.mean(real_features, axis=0)
    std_features = np.std(real_features, axis=0)

    # np.corrcoef divides by each feature's std. A constant feature gives 0/0 = NaN,
    # which would propagate into the covariance and break sampling. Treat such
    # features as uncorrelated instead.
    with np.errstate(divide="ignore", invalid="ignore"):
        correlation_matrix = np.atleast_2d(np.corrcoef(real_features.T))
    degenerate = np.isnan(np.diag(correlation_matrix))
    correlation_matrix = np.nan_to_num(correlation_matrix, nan=0.0)
    correlation_matrix[degenerate, degenerate] = 1.0

    print("Correlation Matrix of Frequency Bands:")
    for i, band1 in enumerate(band_names):
        for j, band2 in enumerate(band_names):
            if j >= i:
                print(f"{band1:11s} - {band2:11s}: {correlation_matrix[i, j]:7.3f}")

    # construct covariance matrix: cov = D * Corr * D (elementwise outer-product form)
    covariance_matrix = np.outer(std_features, std_features) * correlation_matrix

    # Generate synthetic features preserving the correlation structure
    synthetic_features = rs.multivariate_normal(
        mean_features,
        covariance_matrix,
        size=n_synthetic,
    )
    # Ensure non-negative powers (since band powers should be >= 0)
    synthetic_features = np.clip(synthetic_features, a_min=0.0, a_max=None)
    print(f"Generated {n_synthetic} synthetic feature vectors")
    print("Correlation structure preserved (in expectation)")
    return synthetic_features, correlation_matrix, covariance_matrix

In [266]:
# Correlation sampling on the clipped, subject-wise TRAIN band powers,
# producing as many synthetic rows as there are real training rows.
real_features = df_train[BAND_COLS].to_numpy()
band_names = BAND_COLS
n_synthetic_samples = len(real_features)

print(f"Generating {n_synthetic_samples} synthetic samples using correlation sampling")

corr_outputs = generate_correlation_based_eeg(
    real_features=real_features,
    band_names=band_names,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED,
)
synthetic_features_corr_raw, corr_matrix_corr, cov_corr = corr_outputs
del corr_outputs
print(f"Shape of synthetic features (raw): {synthetic_features_corr_raw.shape}")

Generating 30336 synthetic samples using correlation sampling
Correlation Matrix of Frequency Bands:
Delta       - Delta      :   1.000
Delta       - Theta      :   0.634
Delta       - Alpha      :   0.282
Delta       - Beta       :   0.264
Delta       - Gamma      :   0.108
Delta       - total_power:   0.880
Theta       - Theta      :   1.000
Theta       - Alpha      :   0.365
Theta       - Beta       :   0.291
Theta       - Gamma      :   0.110
Theta       - total_power:   0.719
Alpha       - Alpha      :   1.000
Alpha       - Beta       :   0.305
Alpha       - Gamma      :   0.056
Alpha       - total_power:   0.505
Beta        - Beta       :   1.000
Beta        - Gamma      :   0.720
Beta        - total_power:   0.573
Gamma       - Gamma      :   1.000
Gamma       - total_power:   0.391
total_power - total_power:   1.000
Generated 30336 synthetic feature vectors
Correlation structure preserved (in expectation)
Shape of synthetic features (raw): (30336, 6)


In [267]:
# Map synthetic rows into the same standardized space as X_train (scaler fit on real train data).
X_syn_corr = scaler.transform(synthetic_features_corr_raw)
print(f"X_syn_corr shape: {X_syn_corr.shape}")
print("First 5 standardized synthetic samples:", X_syn_corr[:5], sep="\n")

X_syn_corr shape: (30336, 6)
First 5 standardized synthetic samples:
[[-0.57904767  0.55278216 -0.74434845 -0.28479666  0.14378794 -0.4558341 ]
 [-0.67371446 -0.80425419 -0.74434845 -1.01361852 -0.68446539 -0.96143269]
 [-0.67371446 -0.3452406   2.02880318  0.04906978 -0.1556162  -0.01796368]
 [ 0.33528254  0.49086969  0.20279222  1.81970979  2.63919986  1.08293578]
 [ 0.54250611  0.71968462  0.91355495 -0.60091889 -0.50811442  0.51950002]]


In [268]:
# Attach a jointly-shuffled copy of y_train so the synthetic set has the real class balance.
# NOTE: the correlation generator is label-agnostic, so these labels are independent of
# the synthetic features (they only reproduce the class proportions).
perm = np.random.permutation(len(y_train))
X_syn_corr = X_syn_corr[perm]
y_syn_corr = y_train[perm]

print("Synthetic label distribution:", Counter(y_syn_corr))

Synthetic label distribution: Counter({1: 15232, 0: 15104})


### 3.2.3. Method 2: WGAN-GP Synthetic Generator

In [None]:
# Same clipped training band powers as the other methods, cast to float32 for PyTorch.
real_features = df_train[BAND_COLS].to_numpy(dtype=np.float32)
n_synthetic_samples = len(real_features)

print(f"Real features shape for WGAN-GP: {real_features.shape}")
print(f"Number of synthetic samples to generate: {n_synthetic_samples}")

Real features shape for WGAN-GP: (30336, 6)
Number of synthetic samples to generate: 30336


In [None]:
def generate_wgangp_eeg(real_features, n_synthetic=100, noise_dim=16, hidden_dim=64, n_critic=5, gp_lambda=10.0, lr=1e-4, batch_size=128, epochs=300, random_seed=42):
    """
    Train a compact WGAN-GP on band-power features and return synthetic samples.

    Parameters
    ----------
    real_features : array-like, shape (n_samples, n_features)
        Real feature matrix (here (30336, 6)). It is used as-is, i.e. the networks
        are trained on the raw (unstandardised) feature scale.
    n_synthetic : int
        Number of synthetic rows to return (0 is allowed).
    noise_dim, hidden_dim : int
        Latent size and hidden-layer width of both networks.
    n_critic : int
        Critic updates per generator update (all on the same real batch).
    gp_lambda : float
        Weight of the gradient penalty.
    lr, batch_size, epochs :
        Adam learning rate (betas=(0.5, 0.9)), mini-batch size, number of epochs.
    random_seed : int
        Seed for torch's global RNG. Weight init, DataLoader shuffling and latent
        noise all draw from it.

    Returns
    -------
    np.ndarray, shape (n_synthetic, n_features), float32, non-negative.

    Raises
    ------
    ValueError
        If ``real_features`` is not a non-empty 2-D array.
    """
    # Only torch's RNG is used below; NumPy's global RNG is deliberately left untouched.
    torch.manual_seed(random_seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # real_features already clipped and cleaned; copy to a float32 array for torch.
    real_features = np.array(real_features, dtype=np.float32)
    if real_features.ndim != 2 or real_features.shape[0] == 0:
        raise ValueError("real_features must be a non-empty 2-D array (n_samples, n_features)")
    n_rows, feature_dim = real_features.shape  # feature_dim: 6 here (5 bands + total_power)

    loader = DataLoader(
        TensorDataset(torch.from_numpy(real_features)),
        batch_size=batch_size,
        shuffle=True,
        # Dropping the ragged last batch is fine with plenty of data. With fewer rows
        # than batch_size it would leave the loader EMPTY: no training at all, then an
        # UnboundLocalError on loss_D at the first logging step. Keep it in that case.
        drop_last=n_rows >= batch_size,
    )

    class Generator(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(noise_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, feature_dim),
            )

        def forward(self, z):
            x = self.net(z)
            # softplus keeps outputs positive but smooth
            return torch.nn.functional.softplus(x)

    class Critic(nn.Module):
        # LayerNorm (not BatchNorm) keeps the per-sample gradient penalty well defined.
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(feature_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, 1),
            )

        def forward(self, x):
            return self.net(x)

    def gradient_penalty(critic, real, fake):
        """WGAN-GP penalty: mean (||grad_x critic(x_hat)||_2 - 1)^2 on random interpolates."""
        n = real.size(0)
        epsilon = torch.rand(n, 1, device=real.device).expand_as(real)
        interpolated = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)
        mixed_scores = critic(interpolated)
        grad = torch.autograd.grad(
            outputs=mixed_scores,
            inputs=interpolated,
            grad_outputs=torch.ones_like(mixed_scores),
            create_graph=True,   # the penalty itself must be differentiable w.r.t. critic weights
            retain_graph=True,
        )[0]
        grad = grad.view(n, -1)
        return ((grad.norm(2, dim=1) - 1) ** 2).mean()

    G = Generator().to(device)
    D = Critic().to(device)

    opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.9))
    opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.9))

    for epoch in range(epochs):
        for (real_batch,) in loader:
            real_batch = real_batch.to(device)

            # Critic updates
            for _ in range(n_critic):
                z = torch.randn(real_batch.size(0), noise_dim, device=device)
                fake_batch = G(z).detach()

                opt_D.zero_grad()
                critic_real = D(real_batch).mean()
                critic_fake = D(fake_batch).mean()
                gp = gradient_penalty(D, real_batch, fake_batch)
                loss_D = -(critic_real - critic_fake) + gp_lambda * gp
                loss_D.backward()
                opt_D.step()

            # Generator update
            z = torch.randn(real_batch.size(0), noise_dim, device=device)
            opt_G.zero_grad()
            fake_batch = G(z)
            loss_G = -D(fake_batch).mean()
            loss_G.backward()
            opt_G.step()

        # simple monitoring every 50 epochs (losses are from the last batch of the epoch)
        if (epoch + 1) % 50 == 0:
            with torch.no_grad():
                z = torch.randn(batch_size, noise_dim, device=device)
                preview = G(z).cpu().numpy()
            preview_mean = preview.mean(axis=0)
            print(
                f"Epoch {epoch + 1:03d}/{epochs} | "
                f"D: {loss_D.item():.4f} | G: {loss_G.item():.4f} | "
                f"preview mean={np.round(preview_mean, 3)}"
            )

    # Generate n_synthetic samples in batch_size chunks
    G.eval()
    synth_chunks = []
    with torch.no_grad():
        remaining = n_synthetic
        while remaining > 0:
            current = min(batch_size, remaining)
            z = torch.randn(current, noise_dim, device=device)
            synth_chunks.append(G(z).cpu().numpy())
            remaining -= current

    # np.vstack([]) raises, so handle n_synthetic == 0 explicitly.
    if not synth_chunks:
        return np.empty((0, feature_dim), dtype=np.float32)

    synthetic = np.vstack(synth_chunks)
    return np.clip(synthetic, a_min=0.0, a_max=None)

In [271]:
# Train the WGAN-GP on the float32 training band powers and sample as many rows as real ones.
# (torch is imported unconditionally in the setup cell, so the former `torch is None`
# fallback branch could never run and has been removed.)
print("Training WGAN-GP generator on band-power features...")
synthetic_features_wgangp = generate_wgangp_eeg(
    real_features=real_features,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED,
)
print(f"Generated {len(synthetic_features_wgangp)} WGAN-GP samples")
print(f"Shape: {synthetic_features_wgangp.shape}")

Training WGAN-GP generator on band-power features...


Epoch 050/300 | D: -0.9164 | G: 101.8643 | preview mean=[12.439  4.572  4.873  5.001  2.099 29.57 ]
Epoch 100/300 | D: -0.9513 | G: 95.1498 | preview mean=[12.229  3.953  5.142  4.606  2.421 29.934]
Epoch 150/300 | D: -1.1573 | G: 95.1559 | preview mean=[14.428  4.374  4.736  4.736  1.577 30.608]
Epoch 200/300 | D: -4.8910 | G: 86.5985 | preview mean=[11.266  3.649  3.564  5.073  3.035 26.998]
Epoch 250/300 | D: 2.6741 | G: 85.2702 | preview mean=[11.784  5.169  4.427  4.363  2.219 29.297]
Epoch 300/300 | D: -1.0484 | G: 77.1386 | preview mean=[16.441  4.64   4.221  5.543  2.885 34.137]
Generated 30336 WGAN-GP samples
Shape: (30336, 6)


In [272]:
# Standardize with the real-train scaler and attach a jointly-shuffled copy of y_train.
# NOTE: generate_wgangp_eeg is unconditional (it never sees labels), so these labels only
# reproduce the real class balance. They say nothing about the features.
perm = np.random.permutation(len(y_train))
X_syn_wgangp = scaler.transform(synthetic_features_wgangp)[perm]
y_syn_wgangp = y_train[perm]

print(f"X_syn_wgangp shape: {X_syn_wgangp.shape}")
print(f"y_syn_wgangp distribution: {Counter(y_syn_wgangp)}")

X_syn_wgangp shape: (30336, 6)
y_syn_wgangp distribution: Counter({1: 15232, 0: 15104})


### 3.2.4. Method 3: Gaussian Copula Sampling Generator

In [278]:
def _allocate_samples_by_class(labels, n_total):
    """
    Allocate synthetic samples per class, preserving empirical ratios.
    Returns a dict: {class_label: n_synth_for_that_class}
    """
    classes, counts = np.unique(labels, return_counts=True)
    ratios = counts / counts.sum()
    expected = ratios * n_total

    allocated = np.floor(expected).astype(int)
    remainder = n_total - allocated.sum()

    if remainder > 0:
        remainders = expected - allocated
        order = np.argsort(remainders)[::-1]
        for idx in order[:remainder]:
            allocated[idx] += 1

    return dict(zip(classes, allocated))

In [279]:
def generate_gaussian_copula_eeg(real_features, labels, n_synthetic=100, random_seed=42, return_labels=False):
    """
    Class-conditional Gaussian copula sampling.

    For each class (0/1):
    1. Fit a quantile transformer that maps every marginal to N(0, 1).
    2. Estimate a regularised (Ledoit-Wolf) mean/covariance in that latent space.
    3. Sample a multivariate normal per class and invert the transform.
    4. Clip to non-negative band powers.

    Parameters
    ----------
    real_features : array-like, shape (n_samples, n_features)
    labels : array-like, shape (n_samples,)
        Class label of each real row.
    n_synthetic : int
        Total synthetic rows, split across classes in proportion to their
        frequency in ``labels``.
    random_seed : int
        Seed for the sampling Generator and for the QuantileTransformer.
    return_labels : bool, default False
        If True, also return the class label of every synthetic row.

    Returns
    -------
    synthetic_features : np.ndarray, shape (n_synthetic, n_features)
        Rows are stacked in contiguous per-class blocks, in the iteration order of
        ``_allocate_samples_by_class`` (sorted class labels). This is NOT the row
        order of ``labels``.
    synthetic_labels : np.ndarray, shape (n_synthetic,)
        Only when ``return_labels=True``: the class of each returned row.

    Raises
    ------
    ValueError
        If ``labels`` and ``real_features`` have different lengths, or no samples
        were generated.
    """
    real_features = np.asarray(real_features)
    labels = np.asarray(labels)
    if len(labels) != len(real_features):
        raise ValueError(
            f"labels ({len(labels)}) and real_features ({len(real_features)}) differ in length"
        )

    rng = np.random.default_rng(random_seed)
    allocation = _allocate_samples_by_class(labels, n_synthetic)
    synthetic_blocks = []
    label_blocks = []

    print("Generating Gaussian copula samples per class...")
    for cls, n_cls_samples in allocation.items():
        class_features = real_features[labels == cls]
        if len(class_features) == 0 or n_cls_samples == 0:
            continue

        # quantile transformer to approximate Gaussian marginals
        n_quantiles = min(len(class_features), 1000)
        transformer = QuantileTransformer(
            n_quantiles=n_quantiles,
            output_distribution="normal",
            random_state=random_seed,
        )
        latent = transformer.fit_transform(class_features)

        # Ledoit-Wolf shrinkage for a stable covariance estimate
        cov_estimator = LedoitWolf().fit(latent)
        latent_mean = cov_estimator.location_
        latent_cov = cov_estimator.covariance_

        # sample in latent Gaussian space
        latent_samples = rng.multivariate_normal(
            latent_mean,
            latent_cov,
            size=n_cls_samples,
        )

        # invert back to band-power space
        samples = transformer.inverse_transform(latent_samples)

        # enforce non-negativity for power features
        samples = np.clip(samples, a_min=0, a_max=None)
        synthetic_blocks.append(samples)
        label_blocks.append(np.full(n_cls_samples, cls, dtype=labels.dtype))

        print(f"  Class {cls}: real={len(class_features)}, synthetic={n_cls_samples}")

    if not synthetic_blocks:
        raise ValueError("No synthetic samples were generated. Check class labels.")

    synthetic_features = np.vstack(synthetic_blocks)
    if return_labels:
        return synthetic_features, np.concatenate(label_blocks)
    return synthetic_features

In [280]:
# Class-conditional copula samples, allocated across classes by the y_train class ratio.
copula_kwargs = dict(
    real_features=real_features,
    labels=y_train,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED,
)
synthetic_features_copula = generate_gaussian_copula_eeg(**copula_kwargs)
del copula_kwargs

print(f"Generated {len(synthetic_features_copula)} Gaussian copula samples")
print(f"Shape: {synthetic_features_copula.shape}")

Generating Gaussian copula samples per class...
  Class 0: real=15104, synthetic=15104
  Class 1: real=15232, synthetic=15232
Generated 30336 Gaussian copula samples
Shape: (30336, 6)


In [281]:
# Standardize with the same scaler used for X_train / X_test
X_syn_copula = scaler.transform(synthetic_features_copula)

# generate_gaussian_copula_eeg stacks one contiguous block per class, in the order
# _allocate_samples_by_class(y_train, n) iterates (sorted class labels). Rebuild the labels
# from that same allocation. Copying y_train (whose row order is unrelated) would pair
# the class-conditional rows with the wrong labels and erase the class conditioning.
copula_allocation = _allocate_samples_by_class(y_train, len(synthetic_features_copula))
y_syn_copula = np.repeat(list(copula_allocation.keys()), list(copula_allocation.values()))
assert len(y_syn_copula) == len(X_syn_copula), "synthetic label/feature count mismatch"

# Shuffle features and labels jointly with a dedicated seeded RNG (reproducible order).
perm = np.random.default_rng(RANDOM_SEED).permutation(len(y_syn_copula))
X_syn_copula = X_syn_copula[perm]
y_syn_copula = y_syn_copula[perm]

print("X_syn_copula shape:", X_syn_copula.shape)
print("y_syn_copula distribution:", Counter(y_syn_copula))

X_syn_copula shape: (30336, 6)
y_syn_copula distribution: Counter({1: 15232, 0: 15104})


### 3.2.5. Method 4: Classwise Interpolation Generator

In [283]:
def generate_classwise_interpolation_eeg(real_features, labels, n_synthetic=100, random_seed=42, k_neighbors=8, noise_scale=0.02, return_labels=False):
    """
    Class-conditional interpolation inspired by SMOTE, in log-power space.

    Each synthetic row is built from one class at a time:
    1. Pick a random anchor row.
    2. Pick one of its ``k_neighbors`` nearest same-class neighbours (in ``log1p`` space).
    3. Interpolate between the two with weight U(0.2, 0.8).
    4. Add Gaussian noise with std ``noise_scale * per-feature log-space std``.
    5. Map back with ``expm1``.
    Working in log space better captures the multiplicative structure of powers.

    Parameters
    ----------
    real_features : array-like, shape (N, n_bands) – here (N, 6)
        Expected non-negative (``log1p`` is applied).
    labels : array-like, shape (N,)
        Class label of each real row (0/1 here).
    n_synthetic : int
        Total synthetic rows, split across classes proportionally to frequency.
    random_seed : int
        Seed for the private NumPy Generator.
    k_neighbors : int
        Neighbourhood size, capped at (class size - 1).
    noise_scale : float
        Noise std as a fraction of the per-feature log-space std. When a class has
        too few rows for neighbours, it is used as an absolute log-space std instead.
    return_labels : bool, default False
        If True, also return the class label of every synthetic row.

    Returns
    -------
    synthetic_features : np.ndarray, shape (n_synthetic, n_bands)
        Clipped at 0. Stacked in contiguous per-class blocks in sorted class order
        (as iterated by ``_allocate_samples_by_class``), NOT in the order of ``labels``.
    synthetic_labels : np.ndarray, shape (n_synthetic,)
        Only when ``return_labels=True``.

    Raises
    ------
    ValueError
        On a labels/features length mismatch, or if nothing was generated.
    """
    real_features = np.asarray(real_features, dtype=float)
    labels = np.asarray(labels)
    if len(labels) != len(real_features):
        raise ValueError(
            f"labels ({len(labels)}) and real_features ({len(real_features)}) differ in length"
        )

    rng = np.random.default_rng(random_seed)
    allocation = _allocate_samples_by_class(labels, n_synthetic)
    synthetic_samples = []
    label_blocks = []

    # work in log-power space
    log_features = np.log1p(real_features)

    for cls, n_cls_samples in allocation.items():
        class_features_log = log_features[labels == cls]
        n_class = len(class_features_log)
        if n_class == 0 or n_cls_samples == 0:
            continue

        # effective neighbors (avoid k > n-1)
        n_neighbors_eff = min(k_neighbors, n_class - 1)
        if n_neighbors_eff <= 0:
            # Fallback: jitter existing rows, cycling through them so that EXACTLY
            # n_cls_samples rows are produced. The old np.repeat-based version could
            # return fewer rows when k_neighbors <= 0 and the class had several rows.
            base_samples = class_features_log[np.arange(n_cls_samples) % n_class]
            jitter = rng.normal(0, noise_scale, size=base_samples.shape)
            synthetic_samples.append(np.expm1(base_samples + jitter))
            label_blocks.append(np.full(n_cls_samples, cls, dtype=labels.dtype))
            continue

        nbrs = NearestNeighbors(n_neighbors=n_neighbors_eff + 1)
        nbrs.fit(class_features_log)
        # Query every row's neighbourhood ONCE, instead of one kneighbors call per
        # synthetic sample (tens of thousands of calls for this dataset).
        neighbor_table = nbrs.kneighbors(class_features_log, return_distance=False)

        class_std = np.std(class_features_log, axis=0, ddof=1)
        class_std[class_std == 0] = 1e-6

        for _ in range(n_cls_samples):
            # pick anchor and one of its neighbours (excluding itself)
            idx = rng.integers(n_class)
            neighbors = neighbor_table[idx]
            neighbors = neighbors[neighbors != idx]
            neighbor_idx = idx if len(neighbors) == 0 else rng.choice(neighbors)

            # interpolate in log-space
            alpha = rng.uniform(0.2, 0.8)
            interpolated = (
                alpha * class_features_log[idx]
                + (1 - alpha) * class_features_log[neighbor_idx]
            )

            # add small Gaussian noise in log-space, scaled per feature
            noise = rng.normal(0, noise_scale, size=class_features_log.shape[1]) * class_std
            synthetic_samples.append(np.expm1(interpolated + noise))

        label_blocks.append(np.full(n_cls_samples, cls, dtype=labels.dtype))

    if not synthetic_samples:
        raise ValueError("Interpolation generator did not create any samples.")

    # Noise can push log1p values below 0, giving negative powers after expm1; clip them.
    synthetic_features = np.clip(np.vstack(synthetic_samples), a_min=0, a_max=None)
    if return_labels:
        return synthetic_features, np.concatenate(label_blocks)
    return synthetic_features

In [284]:
# Class-wise log-space interpolation: 10 neighbours, noise at 1.5% of the per-feature log std.
interp_kwargs = dict(
    real_features=real_features,
    labels=y_train,
    n_synthetic=n_synthetic_samples,
    random_seed=RANDOM_SEED,
    k_neighbors=10,
    noise_scale=0.015,
)
synthetic_features_interp = generate_classwise_interpolation_eeg(**interp_kwargs)
del interp_kwargs

print(f"Generated {len(synthetic_features_interp)} interpolation-based samples")
print(f"Shape: {synthetic_features_interp.shape}")

Generated 30336 interpolation-based samples
Shape: (30336, 6)


In [None]:
# Standardize using the same scaler
X_syn_interp = scaler.transform(synthetic_features_interp)

# generate_classwise_interpolation_eeg emits one contiguous block per class, in the order
# _allocate_samples_by_class(y_train, n) iterates (sorted class labels). Rebuild the labels
# from that allocation. Reusing y_train's unrelated row order would misalign labels and rows.
interp_allocation = _allocate_samples_by_class(y_train, len(synthetic_features_interp))
y_syn_interp = np.repeat(list(interp_allocation.keys()), list(interp_allocation.values()))
assert len(y_syn_interp) == len(X_syn_interp), "synthetic label/feature count mismatch"

# Shuffle features and labels jointly with a dedicated seeded RNG (reproducible order).
perm = np.random.default_rng(RANDOM_SEED).permutation(len(y_syn_interp))
X_syn_interp = X_syn_interp[perm]
y_syn_interp = y_syn_interp[perm]

print("X_syn_interp shape:", X_syn_interp.shape)
print("y_syn_interp distribution:", Counter(y_syn_interp))

X_syn_interp shape: (30336, 6)
y_syn_interp distribution: Counter({1: 15232, 0: 15104})
