# Phase 5: Hyperparameter tuning for Classwise Interpolation

## 5.0. Path & Model Setup

In [2]:
import json
import time
from itertools import product
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.stats import ks_2samp, spearmanr
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch import nn

# Seed NumPy's global RNG once so reruns draw the same random numbers.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Directory layout under ../output (relative paths).
DATA_DIR, SYN_BASE_DIR, EVAL_BASE_DIR = (
    Path("../output/band_extraction"),
    Path("../output/synthetic_generation"),
    Path("../output/model_tuning"),
)
# Make sure the tuning output directory exists before anything is written to it.
EVAL_BASE_DIR.mkdir(exist_ok=True, parents=True)

# MODEL_KEY is used below to build the baseline folder and file names.
MODEL_KEY, MODEL_NAME = "interp", "Classwise Interpolation"

# Column names of the band-power features and the canonical condition labels.
BAND_COLS = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
CANONICAL_CONDITIONS = ["S1", "S2_match", "S2_nomatch"]

## 5.1. Base real data and baseline best model

### 5.1.1. Band Extraction Real Data

In [None]:
# Read the band-feature table from DATA_DIR.
real_fp = DATA_DIR.joinpath("band_features_segments.csv")
real_df = pd.read_csv(real_fp)

# Report the table size, then show the first rows as the cell output.
print("Real band features shape:", real_df.shape)
real_df.head(5)

Real band features shape: (60672, 13)


Unnamed: 0,dataset_split,file_name,subject_type,subject_id,channel,trial,matching_condition,Delta,Theta,Alpha,Beta,Gamma,total_power
0,train,Data1.csv,a,co2a0000364,FP1,0,S1 obj,20.048105,5.830134,0.854299,6.705598,6.848762,40.286898
1,train,Data1.csv,a,co2a0000364,FP2,0,S1 obj,21.769006,6.052321,1.013807,16.487621,15.773774,61.09653
2,train,Data1.csv,a,co2a0000364,F7,0,S1 obj,7.742259,6.272004,1.893497,39.119253,49.533282,104.560295
3,train,Data1.csv,a,co2a0000364,F8,0,S1 obj,11.400244,4.816262,2.360998,53.64694,44.50218,116.726624
4,train,Data1.csv,a,co2a0000364,AF1,0,S1 obj,13.188257,2.347635,0.54275,4.036543,2.914738,23.029923


In [7]:
# Row count for each raw matching_condition value in the real data.
print(real_df["matching_condition"].value_counts())

matching_condition
S1 obj         20480
S2 match       20416
S2 nomatch,    19776
Name: count, dtype: int64


In [8]:
# Row count for each subject_type value in the real data.
print(real_df["subject_type"].value_counts())

subject_type
a    30400
c    30272
Name: count, dtype: int64


### 5.1.2. Best Model Data Generation Check

In [None]:
# Baseline files sit in a folder named after MODEL_KEY and share it as a file prefix.
interp_dir = SYN_BASE_DIR.joinpath(MODEL_KEY)
interp_real_fp = interp_dir.joinpath(f"{MODEL_KEY}_real.csv")
interp_syn_fp = interp_dir.joinpath(f"{MODEL_KEY}_syn.csv")

# Read the real CSV first, then the synthetic one.
interp_real, interp_syn = (pd.read_csv(fp) for fp in (interp_real_fp, interp_syn_fp))

In [15]:
# Compare the sizes of the baseline real and synthetic tables.
print("Baseline real data shape:", interp_real.shape)
print("Baseline synthetic data shape:", interp_syn.shape)

Baseline real data shape: (30336, 9)
Baseline synthetic data shape: (30336, 9)


In [17]:
# Preview the first rows of the baseline synthetic table.
interp_syn.head()

Unnamed: 0,Delta,Theta,Alpha,Beta,Gamma,total_power,label,condition,source
0,-0.350379,-0.085281,0.427328,-0.759346,-0.610462,-0.351327,0,"S2 nomatch,",synthetic
1,-0.588359,-0.144221,-0.700369,-0.781926,-0.549787,-0.731038,1,"S2 nomatch,",synthetic
2,5.21563,3.877514,1.045012,-0.324019,-0.240247,4.820642,0,S2 match,synthetic
3,-0.612762,-0.781888,-0.686589,-0.918365,-0.555892,-0.878805,0,S2 match,synthetic
4,0.19897,0.583152,-0.480207,-0.564385,-0.178585,-0.020233,1,"S2 nomatch,",synthetic


## 5.2. Condition Harmonization and 6D Feature Extraction

### 5.2.1. Harmonize conditions

In [22]:
# Map the verbose condition strings to canonical ones
def canonical_condition(cond: str) -> str:
    """Collapse a raw condition string onto one of the canonical names.

    Args:
        cond: Raw condition text, e.g. "S1 obj", "S2 match" or "S2 nomatch,".

    Returns:
        "S1" when the text starts with "S1"; otherwise "S2_nomatch" when it
        contains "nomatch", or "S2_match" when it contains "match" but no "no".

    Raises:
        ValueError: If none of the rules above applies.
    """
    if cond.startswith("S1"):
        return "S1"
    # "nomatch" contains both "no" and "match", so the two S2 rules can never
    # fire for the same input and may be tested in either order.
    if "nomatch" in cond:
        return "S2_nomatch"
    is_plain_match = "match" in cond and "no" not in cond
    if not is_plain_match:
        raise ValueError(f"Unknown condition: {cond}")
    return "S2_match"

In [23]:
# Add a "condition_canon" column to both baseline frames by passing every raw
# "condition" value through canonical_condition.
interp_real["condition_canon"] = interp_real["condition"].map(canonical_condition)
interp_syn["condition_canon"] = interp_syn["condition"].map(canonical_condition)

In [24]:
# Sanity check: row count per canonical condition in the real baseline data.
print("Real canonical condition counts:")
print(interp_real["condition_canon"].value_counts())

Real canonical condition counts:
condition_canon
S1            10240
S2_match      10176
S2_nomatch     9920
Name: count, dtype: int64


In [26]:
# Sanity check: row count per canonical condition in the synthetic baseline data.
print("Synthetic canonical condition counts:")
print(interp_syn["condition_canon"].value_counts())

Synthetic canonical condition counts:
condition_canon
S1            10240
S2_match      10176
S2_nomatch     9920
Name: count, dtype: int64


### 5.2.2. Band Power, Label, and Condition Extraction

In [None]:
# Pull the real baseline frame apart into three row-aligned NumPy arrays.
# NOTE(review): section 5.2 is titled "6D", but BAND_COLS lists five columns
# (total_power is not included) — confirm the intended feature set.
real_features = interp_real.loc[:, BAND_COLS].to_numpy()
real_labels = interp_real.loc[:, "label"].to_numpy()
real_conds = interp_real.loc[:, "condition_canon"].to_numpy()

In [None]:
# Same three-way split for the baseline synthetic frame (columns from BAND_COLS,
# then the "label" and "condition_canon" columns).
syn_features_baseline = interp_syn.loc[:, BAND_COLS].to_numpy()
syn_labels_baseline = interp_syn.loc[:, "label"].to_numpy()
syn_conds_baseline = interp_syn.loc[:, "condition_canon"].to_numpy()

## 5.3. Hyperparameter Tuning

### 5.3.1. Classwise Interpolation Generator via Condition Slicing

In [None]:
class InterpGenerator(nn.Module):
    """Three-layer fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Requires ``nn`` (``torch.nn``) to be in scope; it must be imported in the
    notebook's import cell, otherwise this cell raises ``NameError`` on a fresh
    kernel.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of both hidden layers.
        output_dim: Size of the last dimension of the output.

    NOTE(review): the default width of 6 differs from the five BAND_COLS columns
    extracted into ``real_features`` — confirm the intended feature width
    before feeding those arrays to this network.
    """

    def __init__(self, input_dim: int = 6, hidden_dim: int = 32, output_dim: int = 6):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.net(x)

In [29]:
def get_condition_slice(X, y, conds, target_cond: str):
    """Return the rows of ``X`` and ``y`` whose condition equals ``target_cond``.

    Args:
        X: Per-sample features; must support boolean-mask indexing on its
            first axis (e.g. a NumPy array).
        y: Per-sample labels, row-aligned with ``X``; same indexing requirement.
        conds: Per-sample condition names, row-aligned with ``X``. Any
            array-like is accepted.
        target_cond: Condition name to select.

    Returns:
        Tuple ``(X[mask], y[mask])`` where ``mask`` marks the matching rows.
    """
    # Coerce first: comparing a plain list with a string yields a single False
    # instead of an element-wise mask, which would silently select nothing.
    mask = np.asarray(conds) == target_cond
    return X[mask], y[mask]