# Phase 4: Evaluation Pipeline over Synthetic Data

## 4.0. Path & Model setup

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import os
import pickle

from scipy.stats import ks_2samp, spearmanr
from scipy import stats
from scipy.spatial.distance import cdist

from collections import Counter

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler, QuantileTransformer
from sklearn.neighbors import NearestNeighbors
from sklearn.covariance import LedoitWolf

import matplotlib.pyplot as plt
import seaborn as sns

import json

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

MIN_BUCKET_SIZE = 20

DATA_DIR = Path("../output/band_extraction")
BASE_SYN_DIR = Path("../output/synthetic_generation")
EVAL_BASE = Path("../output/synthetic_evaluation")
EVAL_BASE.mkdir(parents=True, exist_ok=True)

MODEL_INFO = {
    "mixup":  "Mixup",
    "corr":   "Correlation Sampling",
    "wgangp": "WGAN-GP",
    "copula": "Gaussian Copula",
    "interp": "Classwise Interpolation",
}

BAND_COLS = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]

CANONICAL_CONDITIONS = ["S1", "S2_match", "S2_nomatch"]

### 4.0.1. Normalize condtions of real and synthetic data

In [2]:
def load_model_data(model_key: str):
    """
    Load real and synthetic CSVs for a given model.
    Returns:
      real_X, real_y, real_c, syn_X, syn_y, syn_c
    """
    folder = BASE_SYN_DIR / model_key

    real_fp = folder / f"{model_key}_real.csv"
    syn_fp  = folder / f"{model_key}_syn.csv"

    real_df = pd.read_csv(real_fp)
    syn_df  = pd.read_csv(syn_fp)

    # detect condition column name
    cond_col_real = "condition" if "condition" in real_df.columns else "matching_condition"
    cond_col_syn  = "condition" if "condition" in syn_df.columns  else "matching_condition"

    # Basic checks for band + label + condition
    missing_real = [col for col in BAND_COLS + ["label", cond_col_real] if col not in real_df.columns]
    missing_syn  = [col for col in BAND_COLS + ["label", cond_col_syn] if col not in syn_df.columns]
    if missing_real:
        print(f"[WARN] {model_key}_real missing columns: {missing_real}")
    if missing_syn:
        print(f"[WARN] {model_key}_syn missing columns: {missing_syn}")

    # Features
    real_X = real_df[BAND_COLS].to_numpy()
    syn_X  = syn_df[BAND_COLS].to_numpy()

    # Labels
    real_y = real_df["label"].to_numpy().astype(int)
    syn_y  = syn_df["label"].to_numpy().astype(int)

    # condition normalization
    def normalize_condition(series):
        def norm(x):
            if pd.isna(x):
                return "UNKNOWN"
            s = str(x)

            # strip whitespace + remove commas
            s = s.strip()
            s = s.replace(",", "")
            s = " ".join(s.split())   # collapse multiple spaces
            s_low = s.lower()

            # Heuristic mapping
            if s_low.startswith("s1"):
                return "S1"
            if "s2" in s_low and "match" in s_low and "nomatch" not in s_low:
                return "S2_match"
            if "s2" in s_low and "nomatch" in s_low:
                return "S2_nomatch"

            if s in CANONICAL_CONDITIONS:
                return s
            return s

        return series.map(norm).to_numpy()

    real_c = normalize_condition(real_df[cond_col_real])
    syn_c  = normalize_condition(syn_df[cond_col_syn])

    # counts per condition
    print(f"\n[{model_key}] real condition counts:",
          dict(zip(*np.unique(real_c, return_counts=True))))
    print(f"[{model_key}] syn  condition counts:",
          dict(zip(*np.unique(syn_c, return_counts=True))))

    return real_X, real_y, real_c, syn_X, syn_y, syn_c

## 4.1 Real vs Synthetic EEG Evaluation

### 4.1.1. Visualization Helper Functions
- PCA
- t-SNE

In [3]:
def compute_embeddings(real_features, synthetic_features, max_points=2000):
    if len(real_features) < 10 or len(synthetic_features) < 10:
        return None

    n_sub = min(max_points, len(real_features), len(synthetic_features))
    idx_real = np.random.choice(len(real_features), n_sub, replace=False)
    idx_syn  = np.random.choice(len(synthetic_features), n_sub, replace=False)

    X_real = real_features[idx_real]
    X_syn  = synthetic_features[idx_syn]

    X = np.vstack([X_real, X_syn])
    labels = np.array(["Real"] * n_sub + ["Synthetic"] * n_sub)

    # Standardize
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # PCA
    pca = PCA(n_components=2, random_state=RANDOM_SEED)
    X_pca = pca.fit_transform(X_scaled)

    # t-SNE
    tsne = TSNE(
        n_components=2,
        perplexity=min(30, max(5, n_sub // 5)),
        random_state=RANDOM_SEED,
        init="pca",
        learning_rate="auto",
    )
    X_tsne = tsne.fit_transform(X_scaled)

    return {
        "X_pca": X_pca,
        "X_tsne": X_tsne,
        "labels": labels,
    }

In [4]:
def plot_embeddings(embeddings, title, save_path=None, show=True):
    if embeddings is None:
        print(f"[PLOT] Skipped {title} (not enough samples).")
        return

    X_pca = embeddings["X_pca"]
    X_tsne = embeddings["X_tsne"]
    labels = embeddings["labels"]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    sns.scatterplot(
        x=X_pca[:, 0], y=X_pca[:, 1], hue=labels,
        alpha=0.5, s=12, ax=axes[0], edgecolor=None
    )
    axes[0].set_title(f"PCA – {title}")

    sns.scatterplot(
        x=X_tsne[:, 0], y=X_tsne[:, 1], hue=labels,
        alpha=0.5, s=12, ax=axes[1], edgecolor=None, legend=False
    )
    axes[1].set_title(f"t-SNE – {title}")

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=150)
        print(f"[SAVED PLOT] {save_path}")
    if show:
        plt.show()
    else:
        plt.close(fig)

### 4.1.2. KS and MMD Helper Functions

In [5]:
def evaluate_distribution_metrics(real_features, synthetic_features, model_name, title_suffix="GLOBAL"):
    print(f"\nDISTRIBUTION METRICS — {model_name} [{title_suffix}]")

    ks_results = []
    for i, band in enumerate(BAND_COLS):
        ks_stat, p = ks_2samp(real_features[:, i], synthetic_features[:, i])
        flag = "✓ Similar" if p > 0.05 else "✗ Different"
        print(f"{band:8s}: KS={ks_stat:.4f}, p={p:.4f}   {flag}")
        ks_results.append((band, ks_stat, p))

    # Subsample for MMD
    n = min(1000, len(real_features), len(synthetic_features))
    X = real_features[:n]
    Y = synthetic_features[:n]

    XX = cdist(X, X)
    YY = cdist(Y, Y)
    XY = cdist(X, Y)

    mmd = XX.mean() + YY.mean() - 2 * XY.mean()
    print(f"MMD: {mmd:.6f}")

    return ks_results, mmd

### 4.1.3. Real vs Synthetic Classification Helper Functions

In [6]:
def evaluate_real_vs_syn(real_X, syn_X, model_name, title_suffix="GLOBAL"):
    print(f"\nREAL-vs-SYN CLASSIFIER — {model_name} [{title_suffix}]")

    X = np.vstack([real_X, syn_X])
    y = np.concatenate([np.ones(len(real_X)), np.zeros(len(syn_X))])

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=RANDOM_SEED
    )

    clf = RandomForestClassifier(n_estimators=200, random_state=RANDOM_SEED)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    print("Accuracy:", round(acc, 4))

    return acc

### 4.1.4. TSTR VS. TRTR Helper Functions

In [7]:
def evaluate_tstr_trtr(real_X, real_y, syn_X, syn_y, model_name, title_suffix="GLOBAL"):
    print(f"\nTSTR / TRTR — {model_name} [{title_suffix}]")

    # Train on real
    X_train_real, X_test_real, y_train_real, y_test_real = train_test_split(
        real_X, real_y, test_size=0.3, random_state=RANDOM_SEED
    )

    # TRTR
    clf_real = RandomForestClassifier(n_estimators=200, random_state=RANDOM_SEED)
    clf_real.fit(X_train_real, y_train_real)
    acc_trtr = clf_real.score(X_test_real, y_test_real)

    # Match synthetic train size to real train labels
    syn_X_trim = syn_X[:len(y_train_real)]
    syn_y_trim = syn_y[:len(y_train_real)]

    # TSTR
    clf_syn = RandomForestClassifier(n_estimators=200, random_state=RANDOM_SEED)
    clf_syn.fit(syn_X_trim, syn_y_trim)
    acc_tstr = clf_syn.score(X_test_real, y_test_real)

    gap = abs(acc_trtr - acc_tstr)

    print(f"TRTR: {acc_trtr:.4f}")
    print(f"TSTR: {acc_tstr:.4f}")
    print(f"Gap : {gap:.4f}")

    return acc_trtr, acc_tstr, gap

### 4.1.5. Label-Stratified Helper

In [8]:
def evaluate_condition(real_X, real_y, real_c, syn_X, syn_y, syn_c, model_name, cond):
    print(f"\nCONDITION {cond} — {model_name}")

    mask_r = (real_c == cond)
    mask_s = (syn_c == cond)

    n_real = mask_r.sum()
    n_syn  = mask_s.sum()
    print(f"  Samples: real={n_real}, syn={n_syn}")

    if n_real < 50 or n_syn < 50:
        print(f"  [Skip] Not enough samples for {cond}")
        return None

    Xr = real_X[mask_r]
    yr = real_y[mask_r]
    Xs = syn_X[mask_s]
    ys = syn_y[mask_s]

    # Embeddings (no saving here)
    emb = compute_embeddings(Xr, Xs)

    # KS + MMD
    ks, mmd = evaluate_distribution_metrics(Xr, Xs, model_name, cond)

    # Real-vs-Syn classifier
    acc = evaluate_real_vs_syn(Xr, Xs, model_name, cond)

    # TSTR/TRTR
    trtr, tstr, gap = evaluate_tstr_trtr(Xr, yr, Xs, ys, model_name, cond)

    return {
        "ks": ks,
        "mmd": mmd,
        "rvs_acc": acc,
        "trtr": trtr,
        "tstr": tstr,
        "gap": gap,
        "N_real": int(n_real),
        "N_syn": int(n_syn),
        "embeddings": emb,
        "real_X": Xr,
        "real_y": yr,
        "syn_X": Xs,
        "syn_y": ys,
    }

### 4.1.5. Condition-conditional Evaluation Functions

In [9]:
def evaluate_by_label(real_X, syn_X, real_y, syn_y, model_name, scope_tag):
    label_results = {}

    for label_val in [0, 1]:
        mask_r = (real_y == label_val)
        mask_s = (syn_y == label_val)

        n_real = mask_r.sum()
        n_syn  = mask_s.sum()
        label_name = "control" if label_val == 0 else "alcoholic"

        print(f"\nLABEL {label_name.capitalize()} (label={label_val}) — "
              f"{model_name} [{scope_tag}-BY_LABEL]")
        print(f"  Samples: real={n_real}, syn={n_syn}")

        if n_real < 50 or n_syn < 50:
            print(f"  [Skip] Not enough samples for {scope_tag}-{label_name}")
            continue

        Xr = real_X[mask_r]
        Xs = syn_X[mask_s]

        # Embeddings
        emb = compute_embeddings(Xr, Xs)

        # KS + MMD
        ks, mmd = evaluate_distribution_metrics(Xr, Xs, model_name,
                                                f"{scope_tag}-BY_LABEL-{label_name.capitalize()}")

        # Real-vs-Syn classifier
        acc = evaluate_real_vs_syn(Xr, Xs, model_name,
                                   f"{scope_tag}-BY_LABEL-{label_name.capitalize()}")

        label_results[label_val] = {
            "ks": ks,
            "mmd": mmd,
            "rvs_acc": acc,
            "N_real": int(n_real),
            "N_syn": int(n_syn),
            "embeddings": emb,
        }

    return label_results

### 4.1.6. Master Function for Synthetic EEG Evaluation of Each Model

In [10]:
def evaluate_model_step_by_step(model_key):
    model_name = MODEL_INFO[model_key]
    print(f"\n{'#' * 80}")
    print(f"EVALUATING MODEL: {model_name} ({model_key})")
    print("#" * 80)

    # Load
    real_X, real_y, real_c, syn_X, syn_y, syn_c = load_model_data(model_key)

    # Global counts per condition
    for cond in CANONICAL_CONDITIONS:
        n_r = np.sum(real_c == cond)
        n_s = np.sum(syn_c == cond)
        print(f"  [{cond}] counts — real={n_r}, syn={n_s}")

    global_emb = compute_embeddings(real_X, syn_X)
    ks_g, mmd_g = evaluate_distribution_metrics(real_X, syn_X, model_name, "GLOBAL")
    rvs_g = evaluate_real_vs_syn(real_X, syn_X, model_name, "GLOBAL")
    trtr_g, tstr_g, gap_g = evaluate_tstr_trtr(real_X, real_y, syn_X, syn_y, model_name, "GLOBAL")

    global_results = {
        "ks": ks_g,
        "mmd": mmd_g,
        "rvs_acc": rvs_g,
        "trtr": trtr_g,
        "tstr": tstr_g,
        "gap": gap_g,
        "N_real": int(len(real_X)),
        "N_syn": int(len(syn_X)),
        "embeddings": global_emb,
        "real_X": real_X,   # keep for by-label global eval
        "real_y": real_y,
        "syn_X": syn_X,
        "syn_y": syn_y,
    }

    cond_results = {}
    for cond in CANONICAL_CONDITIONS:
        res = evaluate_condition(
            real_X, real_y, real_c,
            syn_X, syn_y, syn_c,
            model_name, cond
        )
        if res is not None:
            cond_results[cond] = res

    by_label_global = evaluate_by_label(
        real_X, syn_X, real_y, syn_y, model_name, "GLOBAL"
    )

    by_label_condition = {}
    for cond, cres in cond_results.items():
        Xr = cres["real_X"]
        yr = cres["real_y"]
        Xs = cres["syn_X"]
        ys = cres["syn_y"]
        by_label_condition[cond] = evaluate_by_label(
            Xr, Xs, yr, ys, model_name, cond
        )

    return {
        "model_key": model_key,
        "model_name": model_name,
        "global": global_results,
        "condition": cond_results,
        "by_label_global": by_label_global,
        "by_label_condition": by_label_condition,
    }

## 4.2. Each Synthetic Evaluation from Models

### 4.2.1. Mixup Baseline Model Evaluation

In [11]:
mixup_results = evaluate_model_step_by_step("mixup")


################################################################################
EVALUATING MODEL: Mixup (mixup)
################################################################################

[mixup] real condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
[mixup] syn  condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
  [S1] counts — real=10240, syn=10240
  [S2_match] counts — real=10176, syn=10176
  [S2_nomatch] counts — real=9920, syn=9920

DISTRIBUTION METRICS — Mixup [GLOBAL]
Delta   : KS=0.1381, p=0.0000   ✗ Different
Theta   : KS=0.1263, p=0.0000   ✗ Different
Alpha   : KS=0.1324, p=0.0000   ✗ Different
Beta    : KS=0.1212, p=0.0000   ✗ Different
Gamma   : KS=0.1361, p=0.0000   ✗ Different
MMD: -0.343235

REAL-vs-SYN CLASSIFIER — Mixup [GLOBAL]
Accuracy: 0.6879

TSTR / TRTR — Mixup [GLOBAL]
TRTR: 0.7527
TSTR: 0.4828
Gap : 0.2699

CONDITION S1 — Mixup
  Samples: real=10240, syn=10240

DISTRIBUTION METRICS — Mixup [S1]
Delta   : KS=0.1423

### 4.2.2. Correlation Sampling Model Evaluation

In [12]:
corr_results = evaluate_model_step_by_step("corr")


################################################################################
EVALUATING MODEL: Correlation Sampling (corr)
################################################################################

[corr] real condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
[corr] syn  condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
  [S1] counts — real=10240, syn=10240
  [S2_match] counts — real=10176, syn=10176
  [S2_nomatch] counts — real=9920, syn=9920

DISTRIBUTION METRICS — Correlation Sampling [GLOBAL]
Delta   : KS=0.2536, p=0.0000   ✗ Different
Theta   : KS=0.2176, p=0.0000   ✗ Different
Alpha   : KS=0.2360, p=0.0000   ✗ Different
Beta    : KS=0.1644, p=0.0000   ✗ Different
Gamma   : KS=0.2495, p=0.0000   ✗ Different
MMD: -0.428586

REAL-vs-SYN CLASSIFIER — Correlation Sampling [GLOBAL]
Accuracy: 0.8977

TSTR / TRTR — Correlation Sampling [GLOBAL]
TRTR: 0.7527
TSTR: 0.4913
Gap : 0.2614

CONDITION S1 — Correlation Sampling
  Samples: real

### 4.2.3. WGAN-GP Model Evaluation

In [13]:
wgangp_results = evaluate_model_step_by_step("wgangp")


################################################################################
EVALUATING MODEL: WGAN-GP (wgangp)
################################################################################

[wgangp] real condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
[wgangp] syn  condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
  [S1] counts — real=10240, syn=10240
  [S2_match] counts — real=10176, syn=10176
  [S2_nomatch] counts — real=9920, syn=9920

DISTRIBUTION METRICS — WGAN-GP [GLOBAL]
Delta   : KS=0.0138, p=0.0059   ✗ Different
Theta   : KS=0.0409, p=0.0000   ✗ Different
Alpha   : KS=0.0275, p=0.0000   ✗ Different
Beta    : KS=0.0315, p=0.0000   ✗ Different
Gamma   : KS=0.0480, p=0.0000   ✗ Different
MMD: -0.241229

REAL-vs-SYN CLASSIFIER — WGAN-GP [GLOBAL]
Accuracy: 0.6598

TSTR / TRTR — WGAN-GP [GLOBAL]
TRTR: 0.7527
TSTR: 0.5036
Gap : 0.2491

CONDITION S1 — WGAN-GP
  Samples: real=10240, syn=10240

DISTRIBUTION METRICS — WGAN-GP [S1]
Delt

### 4.2.4. Gaussian Copula Sampling Model Evaluation

In [14]:
copula_results = evaluate_model_step_by_step("copula")


################################################################################
EVALUATING MODEL: Gaussian Copula (copula)
################################################################################

[copula] real condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
[copula] syn  condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
  [S1] counts — real=10240, syn=10240
  [S2_match] counts — real=10176, syn=10176
  [S2_nomatch] counts — real=9920, syn=9920

DISTRIBUTION METRICS — Gaussian Copula [GLOBAL]
Delta   : KS=0.0546, p=0.0000   ✗ Different
Theta   : KS=0.0476, p=0.0000   ✗ Different
Alpha   : KS=0.0443, p=0.0000   ✗ Different
Beta    : KS=0.0560, p=0.0000   ✗ Different
Gamma   : KS=0.0541, p=0.0000   ✗ Different
MMD: -0.267701

REAL-vs-SYN CLASSIFIER — Gaussian Copula [GLOBAL]
Accuracy: 0.6537

TSTR / TRTR — Gaussian Copula [GLOBAL]
TRTR: 0.7527
TSTR: 0.6815
Gap : 0.0712

CONDITION S1 — Gaussian Copula
  Samples: real=10240, syn=10240



### 4.2.5. Classwise Interpolation Model Evaluation

In [15]:
interp_results = evaluate_model_step_by_step("interp")


################################################################################
EVALUATING MODEL: Classwise Interpolation (interp)
################################################################################

[interp] real condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
[interp] syn  condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
  [S1] counts — real=10240, syn=10240
  [S2_match] counts — real=10176, syn=10176
  [S2_nomatch] counts — real=9920, syn=9920

DISTRIBUTION METRICS — Classwise Interpolation [GLOBAL]
Delta   : KS=0.0089, p=0.1765   ✓ Similar
Theta   : KS=0.0101, p=0.0907   ✓ Similar
Alpha   : KS=0.0084, p=0.2289   ✓ Similar
Beta    : KS=0.0119, p=0.0277   ✗ Different
Gamma   : KS=0.0129, p=0.0125   ✗ Different
MMD: -0.248633

REAL-vs-SYN CLASSIFIER — Classwise Interpolation [GLOBAL]
Accuracy: 0.5083

TSTR / TRTR — Classwise Interpolation [GLOBAL]
TRTR: 0.7527
TSTR: 0.7917
Gap : 0.0390

CONDITION S1 — Classwise Interpolation


## 4.3. Save evaluation metrics

In [16]:
def _to_py(obj):
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (list, tuple)):
        return [_to_py(x) for x in obj]
    if isinstance(obj, dict):
        return {k: _to_py(v) for k, v in obj.items()}
    return obj

In [17]:
def save_model_results(model_results: dict):
    model_key  = model_results["model_key"]
    model_name = model_results["model_name"]

    model_dir   = EVAL_BASE / model_key
    metrics_dir = model_dir / "metrics"
    plots_dir   = model_dir / "plots"
    metrics_dir.mkdir(parents=True, exist_ok=True)
    plots_dir.mkdir(parents=True, exist_ok=True)

    # GLOBAL metrics (pooled labels)
    g = model_results["global"]
    global_payload = {
        "model_key": model_key,
        "model_name": model_name,
        "scope": "global",
        "condition": None,
        "distribution": {
            "ks": [
                {"band": band, "ks": float(ks_val), "p": float(p_val)}
                for (band, ks_val, p_val) in g["ks"]
            ],
            "mmd": float(g["mmd"]),
        },
        "classification": {
            "real_vs_syn_acc": float(g["rvs_acc"]),
            "trtr": float(g["trtr"]),
            "tstr": float(g["tstr"]),
            "gap": float(g["gap"]),
        },
        "sample_sizes": {
            "n_real": int(g["N_real"]),
            "n_syn": int(g["N_syn"]),
        },
    }

    with open(metrics_dir / "global.json", "w") as f:
        json.dump(_to_py(global_payload), f, indent=2)

    # GLOBAL plot
    plot_embeddings(
        g["embeddings"],
        title=f"{model_name} [GLOBAL]",
        save_path=plots_dir / "global_pca_tsne.png",
        show=False,
    )

    # CONDITION metrics (pooled labels) + plots
    cond_results = model_results["condition"]

    for cond in CANONICAL_CONDITIONS:
        c_res = cond_results.get(cond)
        if c_res is None:
            continue

        cond_payload = {
            "model_key": model_key,
            "model_name": model_name,
            "scope": "condition",
            "condition": cond,
            "distribution": {
                "ks": [
                    {"band": band, "ks": float(ks_val), "p": float(p_val)}
                    for (band, ks_val, p_val) in c_res["ks"]
                ],
                "mmd": float(c_res["mmd"]),
            },
            "classification": {
                "real_vs_syn_acc": float(c_res["rvs_acc"]),
                "trtr": float(c_res["trtr"]),
                "tstr": float(c_res["tstr"]),
                "gap": float(c_res["gap"]),
            },
            "sample_sizes": {
                "n_real": int(c_res["N_real"]),
                "n_syn": int(c_res["N_syn"]),
            },
        }

        fname = f"{cond}.json"
        with open(metrics_dir / fname, "w") as f:
            json.dump(_to_py(cond_payload), f, indent=2)

        plot_embeddings(
            c_res["embeddings"],
            title=f"{model_name} [{cond}]",
            save_path=plots_dir / f"{cond}_pca_tsne.png",
            show=False,
        )

    # GLOBAL BY LABEL metrics + plots
    by_label_global = model_results.get("by_label_global", {})

    for label_val, lbl_res in by_label_global.items():
        label_val = int(label_val)
        label_name = "control" if label_val == 0 else "alcoholic"

        payload_lbl = {
            "model_key": model_key,
            "model_name": model_name,
            "scope": "global_by_label",
            "condition": None,
            "label": label_name,
            "label_value": label_val,
            "distribution": {
                "ks": [
                    {"band": band, "ks": float(ks_val), "p": float(p_val)}
                    for (band, ks_val, p_val) in lbl_res["ks"]
                ],
                "mmd": float(lbl_res["mmd"]),
            },
            "classification": {
                "real_vs_syn_acc": float(lbl_res["rvs_acc"]),
            },
            "sample_sizes": {
                "n_real": int(lbl_res["N_real"]),
                "n_syn": int(lbl_res["N_syn"]),
            },
        }

        fname = f"global_{label_name}.json"
        with open(metrics_dir / fname, "w") as f:
            json.dump(_to_py(payload_lbl), f, indent=2)

        plot_embeddings(
            lbl_res["embeddings"],
            title=f"{model_name} [GLOBAL-{label_name}]",
            save_path=plots_dir / f"global_{label_name}_pca_tsne.png",
            show=False,
        )

    # CONDITION BY LABEL metrics + plots
    by_label_condition = model_results.get("by_label_condition", {})

    for cond, label_dict in by_label_condition.items():
        for label_val, lbl_res in label_dict.items():
            label_val = int(label_val)
            label_name = "control" if label_val == 0 else "alcoholic"

            payload_lbl_cond = {
                "model_key": model_key,
                "model_name": model_name,
                "scope": "condition_by_label",
                "condition": cond,
                "label": label_name,
                "label_value": label_val,
                "distribution": {
                    "ks": [
                        {"band": band, "ks": float(ks_val), "p": float(p_val)}
                        for (band, ks_val, p_val) in lbl_res["ks"]
                    ],
                    "mmd": float(lbl_res["mmd"]),
                },
                "classification": {
                    "real_vs_syn_acc": float(lbl_res["rvs_acc"]),
                },
                "sample_sizes": {
                    "n_real": int(lbl_res["N_real"]),
                    "n_syn": int(lbl_res["N_syn"]),
                },
            }

            fname = f"{cond}_{label_name}.json"
            with open(metrics_dir / fname, "w") as f:
                json.dump(_to_py(payload_lbl_cond), f, indent=2)

            plot_embeddings(
                lbl_res["embeddings"],
                title=f"{model_name} [{cond}-{label_name}]",
                save_path=plots_dir / f"{cond}_{label_name}_pca_tsne.png",
                show=False,
            )

    print(f"[SAVED] Metrics + plots for {model_name}: {model_dir}")

In [18]:
all_results = {
    "mixup": mixup_results,
    "corr": corr_results,
    "wgangp": wgangp_results,
    "copula": copula_results,
    "interp": interp_results,
}

for key, res in all_results.items():
    print(f"Saving: {key}")
    save_model_results(res)

Saving: mixup
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/global_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/S1_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/S2_match_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/S2_nomatch_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/global_control_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/global_alcoholic_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/S1_control_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/S1_alcoholic_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/S2_match_control_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/S2_match_alcoholic_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/S2_nomatch_control_pca_tsne.png
[SAVED PLOT] ../output/synthetic_evaluation/mixup/plots/S2_nomatch_alcoholic_pca_tsne.

## 4.4. Training and Testing

### 4.4.1. Control-only Interpolation Experiment (c-only model → test on c and a)

In [19]:
def run_interp_control_only_experiment():
    """
    1. 'Train' / evaluate the Interp model on controls (c) only:
       - Compare real_c vs syn_c (distribution + classifier).
    2. Then test how this control-only model behaves on alcoholic data:
       - Compare real_a vs syn_c (distribution + classifier).
    """
    model_key = "interp"
    model_name = MODEL_INFO[model_key]

    # Load full Interp real/syn data
    real_X, real_y, real_c, syn_X, syn_y, syn_c = load_model_data(model_key)

    # Masks
    mask_ctrl_real = (real_y == 0)
    mask_alc_real  = (real_y == 1)
    mask_ctrl_syn  = (syn_y  == 0)

    X_ctrl_real = real_X[mask_ctrl_real]
    X_alc_real  = real_X[mask_alc_real]
    X_ctrl_syn  = syn_X[mask_ctrl_syn]

    print("\n" + "=" * 80)
    print("CONTROL-ONLY INTERP EXPERIMENT")
    print("=" * 80)
    print(f"Control-only subset: real_c = {X_ctrl_real.shape[0]}, "
          f"syn_c = {X_ctrl_syn.shape[0]}, real_a = {X_alc_real.shape[0]}")

    results = {}

    # In-distribution: how well does Interp model controls
    print("\n[1] In-distribution: REAL(control) vs SYN(control)")
    emb_c_in = compute_embeddings(X_ctrl_real, X_ctrl_syn)
    ks_c_in, mmd_c_in = evaluate_distribution_metrics(
        X_ctrl_real, X_ctrl_syn,
        model_name,
        title_suffix="CONTROL_ONLY real_c vs syn_c"
    )
    acc_c_in = evaluate_real_vs_syn(
        X_ctrl_real, X_ctrl_syn,
        model_name,
        title_suffix="CONTROL_ONLY real_c vs syn_c"
    )

    results["control_in_dist"] = {
        "ks": ks_c_in,
        "mmd": mmd_c_in,
        "rvs_acc": acc_c_in,
        "N_real_c": int(X_ctrl_real.shape[0]),
        "N_syn_c": int(X_ctrl_syn.shape[0]),
        "embeddings": emb_c_in,
    }

    # how different are alcoholic trials from the control-only model?
    print("\n[2] Out-of-distribution: REAL(alcoholic) vs SYN(control)")
    emb_a_vs_c = compute_embeddings(X_alc_real, X_ctrl_syn)
    ks_a_vs_c, mmd_a_vs_c = evaluate_distribution_metrics(
        X_alc_real, X_ctrl_syn,
        model_name,
        title_suffix="CONTROL_ONLY real_a vs syn_c"
    )
    acc_a_vs_c = evaluate_real_vs_syn(
        X_alc_real, X_ctrl_syn,
        model_name,
        title_suffix="CONTROL_ONLY real_a vs syn_c"
    )

    results["alc_vs_control_model"] = {
        "ks": ks_a_vs_c,
        "mmd": mmd_a_vs_c,
        "rvs_acc": acc_a_vs_c,
        "N_real_a": int(X_alc_real.shape[0]),
        "N_syn_c": int(X_ctrl_syn.shape[0]),
        "embeddings": emb_a_vs_c,
    }

    return results

### 4.4.2. Per-condition Interpolation experiments (S1, S2_match, S2_nomatch)

In [20]:
def run_interp_per_condition_experiments():
    """
    Implements the scholar's suggestion of per-condition models by
    evaluating Interp separately within each condition:

      - condition S1:    real_S1 vs syn_S1
      - condition S2_match: real_S2_match vs syn_S2_match
      - condition S2_nomatch: real_S2_nomatch vs syn_S2_nomatch

    This reuses the same Interp synthetic data but isolates the evaluation
    to each condition, making results fully condition-conditional.
    """
    model_key = "interp"
    model_name = MODEL_INFO[model_key]

    real_X, real_y, real_c, syn_X, syn_y, syn_c = load_model_data(model_key)

    per_cond_results = {}

    for cond in CANONICAL_CONDITIONS:
        print("\n" + "=" * 80)
        print(f"PER-CONDITION INTERP EXPERIMENT — CONDITION = {cond}")
        print("=" * 80)

        cres = evaluate_condition(
            real_X, real_y, real_c,
            syn_X, syn_y, syn_c,
            model_name,
            cond
        )

        if cres is None:
            print(f"[Skip] Not enough samples for condition {cond}")
            continue

        # also do label-wise decomposition within this condition
        Xr = cres["real_X"]
        yr = cres["real_y"]
        Xs = cres["syn_X"]
        ys = cres["syn_y"]

        by_label = evaluate_by_label(
            Xr, Xs, yr, ys,
            model_name,
            scope_tag=f"{cond}"
        )

        per_cond_results[cond] = {
            "condition_summary": cres,
            "by_label": by_label,
        }

    return per_cond_results

In [21]:
interp_results = evaluate_model_step_by_step("interp")


################################################################################
EVALUATING MODEL: Classwise Interpolation (interp)
################################################################################

[interp] real condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
[interp] syn  condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
  [S1] counts — real=10240, syn=10240
  [S2_match] counts — real=10176, syn=10176
  [S2_nomatch] counts — real=9920, syn=9920

DISTRIBUTION METRICS — Classwise Interpolation [GLOBAL]
Delta   : KS=0.0089, p=0.1765   ✓ Similar
Theta   : KS=0.0101, p=0.0907   ✓ Similar
Alpha   : KS=0.0084, p=0.2289   ✓ Similar
Beta    : KS=0.0119, p=0.0277   ✗ Different
Gamma   : KS=0.0129, p=0.0125   ✗ Different
MMD: -0.248633

REAL-vs-SYN CLASSIFIER — Classwise Interpolation [GLOBAL]
Accuracy: 0.5083

TSTR / TRTR — Classwise Interpolation [GLOBAL]
TRTR: 0.7527
TSTR: 0.7917
Gap : 0.0390

CONDITION S1 — Classwise Interpolation


In [22]:
# Control-only experiment: c-only model, then test on a
interp_control_only_results = run_interp_control_only_experiment()


[interp] real condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
[interp] syn  condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}

CONTROL-ONLY INTERP EXPERIMENT
Control-only subset: real_c = 15104, syn_c = 15104, real_a = 15232

[1] In-distribution: REAL(control) vs SYN(control)

DISTRIBUTION METRICS — Classwise Interpolation [CONTROL_ONLY real_c vs syn_c]
Delta   : KS=0.0101, p=0.4260   ✓ Similar
Theta   : KS=0.0107, p=0.3549   ✓ Similar
Alpha   : KS=0.0080, p=0.7144   ✓ Similar
Beta    : KS=0.0147, p=0.0758   ✓ Similar
Gamma   : KS=0.0233, p=0.0005   ✗ Different
MMD: -0.173354

REAL-vs-SYN CLASSIFIER — Classwise Interpolation [CONTROL_ONLY real_c vs syn_c]
Accuracy: 0.4999

[2] Out-of-distribution: REAL(alcoholic) vs SYN(control)

DISTRIBUTION METRICS — Classwise Interpolation [CONTROL_ONLY real_a vs syn_c]
Delta   : KS=0.2685, p=0.0000   ✗ Different
Theta   : KS=0.3184, p=0.0000   ✗ Different
Alpha   : KS=0.2645, p=0.0000   ✗ Different
Beta 

In [23]:
# Per-condition Interp experiments
interp_per_condition_results = run_interp_per_condition_experiments()


[interp] real condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}
[interp] syn  condition counts: {'S1': 10240, 'S2_match': 10176, 'S2_nomatch': 9920}

PER-CONDITION INTERP EXPERIMENT — CONDITION = S1

CONDITION S1 — Classwise Interpolation
  Samples: real=10240, syn=10240

DISTRIBUTION METRICS — Classwise Interpolation [S1]
Delta   : KS=0.0093, p=0.7666   ✓ Similar
Theta   : KS=0.0123, p=0.4170   ✓ Similar
Alpha   : KS=0.0100, p=0.6860   ✓ Similar
Beta    : KS=0.0142, p=0.2537   ✓ Similar
Gamma   : KS=0.0254, p=0.0027   ✗ Different
MMD: -0.205496

REAL-vs-SYN CLASSIFIER — Classwise Interpolation [S1]
Accuracy: 0.4557

TSTR / TRTR — Classwise Interpolation [S1]
TRTR: 0.7122
TSTR: 0.7887
Gap : 0.0765

LABEL Control (label=0) — Classwise Interpolation [S1-BY_LABEL]
  Samples: real=5120, syn=5120

DISTRIBUTION METRICS — Classwise Interpolation [S1-BY_LABEL-Control]
Delta   : KS=0.0152, p=0.5923   ✓ Similar
Theta   : KS=0.0221, p=0.1651   ✓ Similar
Alpha   : KS=0.0107, 

## 4.5. Save Best Baseline Model

In [32]:
# Define the "best baseline interp" bundle
best_interp_base = {
    "model_key": "interp",
    "model_name": MODEL_INFO["interp"],
    "interp_results": interp_results,
    "interp_control_only_results": interp_control_only_results,
    "interp_per_condition_results": interp_per_condition_results,
}

In [33]:
# Make sure the directory exists
os.makedirs("../output/model_tuning", exist_ok=True)

In [34]:
# Save to the requested path
with open("../output/model_tuning/best_interp_base.pkl", "wb") as f:
    pickle.dump(best_interp_base, f)
print("Saved best baseline Interp Model to output/model_tuning/best_interp_base.pkl")

Saved best baseline Interp Model to output/model_tuning/best_interp_base.pkl
