# Phase 5: Hyperparameter tuning for Classwise Interpolation

## 5.0. Path & Model Setup

In [78]:
import numpy as np
import pandas as pd
from pathlib import Path
from itertools import product
import json
import time
import random

from scipy.stats import ks_2samp, spearmanr
from scipy.spatial.distance import cdist

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

DATA_DIR = Path("../output/band_extraction")
SYN_BASE_DIR = Path("../output/synthetic_generation")
EVAL_BASE_DIR = Path("../output/model_tuning")
EVAL_BASE_DIR.mkdir(parents=True, exist_ok=True)

BAND_COLS = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
CANONICAL_CONDITIONS = ["S1", "S2_match", "S2_nomatch"]
LABEL_COL = "label"
COND_COL  = "condition"
SOURCE_COL = "source"

MODEL_INFO = {"interp": "Classwise Interpolation"}

model_key  = "interp"
model_name = MODEL_INFO[model_key]

interp_dir = SYN_BASE_DIR / model_key

## 5.1. Base real data and baseline best model

### 5.1.1. Best Model Data Generation Check

In [79]:
# Baseline Phase-4 outputs
interp_real_fp = interp_dir / f"{MODEL_KEY}_real.csv"
interp_syn_fp  = interp_dir / f"{MODEL_KEY}_syn.csv"

interp_real = pd.read_csv(interp_real_fp)
interp_syn  = pd.read_csv(interp_syn_fp)

print("Baseline REAL shape:", interp_real.shape)
print("Baseline SYN  shape:", interp_syn.shape)

Baseline REAL shape: (30336, 9)
Baseline SYN  shape: (30336, 9)


In [80]:
print("Baseline real data shape:", interp_real.shape)
print("Baseline synthetic data shape:", interp_syn.shape)

Baseline real data shape: (30336, 9)
Baseline synthetic data shape: (30336, 9)


In [81]:
interp_real.head()

Unnamed: 0,Delta,Theta,Alpha,Beta,Gamma,total_power,label,condition,source
0,0.292732,0.249902,-0.610662,0.411409,1.345733,0.284397,1,S1 obj,real
1,0.37569,0.290176,-0.585701,2.490218,3.991401,0.927912,1,S1 obj,real
2,-0.300488,0.329996,-0.448042,4.00691,4.693634,2.271983,1,S1 obj,real
3,-0.12415,0.066128,-0.374884,4.00691,4.693634,2.648213,1,S1 obj,real
4,-0.037956,-0.381335,-0.659416,-0.1558,0.179559,-0.249257,1,S1 obj,real


In [82]:
real_features = interp_real[BAND_COLS].to_numpy()
real_labels = interp_real["label"].to_numpy()
real_conds = interp_real["condition"].to_numpy()

In [83]:
print("Real condition counts:", interp_real[COND_COL].value_counts())
print("Syn  condition counts:", interp_syn[COND_COL].value_counts())

Real condition counts: condition
S1 obj         10240
S2 match       10176
S2 nomatch,     9920
Name: count, dtype: int64
Syn  condition counts: condition
S1 obj         10240
S2 match       10176
S2 nomatch,     9920
Name: count, dtype: int64


## 5.2. Define Evaluation Metrics

In [84]:
def get_condition_slice(df: pd.DataFrame, cond_tag: str):
    mask = (df[COND_COL] == cond_tag)
    df_cond = df.loc[mask].reset_index(drop=True)

    X_cond = df_cond[BAND_COLS].to_numpy()
    y_cond = df_cond[LABEL_COL].to_numpy()

    return X_cond, y_cond, df_cond

In [85]:
class InterpGenerator(nn.Module):
    def __init__(self, input_dim=6, hidden_dim=32, output_dim=6):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)

In [90]:
def fine_tune_interp_model_for_condition(
    cond_tag: str,
    X_cond_real: np.ndarray,
    n_epochs: int = 80,
    patience: int = 8,
    hyper_grid: dict | None = None,
    noise_std: float = 0.05,
):
    """
    Fine-tune an INTERP-style generator on REAL data for a single condition.

    X_cond_real: (N, 6) normalized band features for this condition.
    Early stopping is based on reconstruction loss.
    Hyperparameter selection is based on best (lowest) final loss.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if hyper_grid is None:
        hyper_grid = {
            "hidden_dim":   [32, 64],
            "lr":           [1e-3, 5e-4],
            "weight_decay": [0.0, 1e-4],
            "batch_size":   [128, 256],
        }

    X_tensor = torch.from_numpy(X_cond_real.astype(np.float32))

    best_cfg   = None
    best_score = np.inf
    best_state = None

    # Iterate over hyperparameter combos
    for hd, lr, wd, bs in product(
        hyper_grid["hidden_dim"],
        hyper_grid["lr"],
        hyper_grid["weight_decay"],
        hyper_grid["batch_size"],
    ):
        print(f"\n[{cond_tag}] Trying config: hidden_dim={hd}, lr={lr}, wd={wd}, bs={bs}")

        dataset = TensorDataset(X_tensor)
        loader  = DataLoader(dataset, batch_size=bs, shuffle=True)

        gen = InterpGenerator(
            input_dim=len(BAND_COLS),
            hidden_dim=hd,
            output_dim=len(BAND_COLS),
        ).to(device)

        optimizer = optim.Adam(gen.parameters(), lr=lr, weight_decay=wd)
        criterion = nn.MSELoss()

        best_epoch_loss = np.inf
        epochs_no_improve = 0

        for epoch in range(1, n_epochs + 1):
            gen.train()
            total_loss = 0.0

            for (x_batch,) in loader:
                x_batch = x_batch.to(device)

                if noise_std > 0.0:
                    noise = torch.randn_like(x_batch) * noise_std
                    x_noisy = x_batch + noise
                else:
                    x_noisy = x_batch

                optimizer.zero_grad()
                x_hat = gen(x_noisy)
                loss  = criterion(x_hat, x_batch)
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * x_batch.size(0)

            avg_loss = total_loss / len(dataset)
            print(f"[{cond_tag}] Epoch {epoch:03d} | recon loss={avg_loss:.6f}")

            # Early stopping on reconstruction loss
            if avg_loss + 1e-6 < best_epoch_loss:
                best_epoch_loss = avg_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"[{cond_tag}] Early stopping at epoch {epoch}")
                    break

        # After training for this config, select best on loss
        if best_epoch_loss < best_score:
            best_score = best_epoch_loss
            best_cfg   = {
                "hidden_dim": hd,
                "lr": lr,
                "weight_decay": wd,
                "batch_size": bs,
                "best_epoch_loss": best_epoch_loss,
            }
            best_state = gen.state_dict()

    # Rebuild generator with best config and load its weights
    print(f"\n[{cond_tag}] BEST CONFIG:", best_cfg)

    best_gen = InterpGenerator(
        input_dim=len(BAND_COLS),
        hidden_dim=best_cfg["hidden_dim"],
        output_dim=len(BAND_COLS),
    ).to(device)
    best_gen.load_state_dict(best_state)

    # Save checkpoint for this condition
    ft_ckpt = interp_dir / f"{MODEL_KEY}_{cond_tag}_finetuned.pt"
    torch.save(best_gen.state_dict(), ft_ckpt)
    print(f"[{cond_tag}] Fine-tuned weights saved to {ft_ckpt}")

    return best_gen, best_cfg


In [91]:
@torch.no_grad()
def generate_synthetic_from_gen(
    gen: nn.Module,
    X_real_cond: np.ndarray,
    n_samples: int,
    noise_std: float = 0.10,
):
    """
    Generate synthetic samples by taking real vectors (as anchors)
    and adding noise before passing through the generator.
    This preserves the real distribution structure but lets
    the generator learn a denoising + interpolation mapping.
    """
    device = next(gen.parameters()).device

    N_real = X_real_cond.shape[0]
    # Sample indices with replacement to get n_samples anchors
    idx = np.random.randint(0, N_real, size=n_samples)
    X_anchor = X_real_cond[idx]

    X_anchor_tensor = torch.from_numpy(X_anchor.astype(np.float32)).to(device)

    if noise_std > 0.0:
        noise = torch.randn_like(X_anchor_tensor) * noise_std
        X_noisy = X_anchor_tensor + noise
    else:
        X_noisy = X_anchor_tensor

    gen.eval()
    X_syn_tensor = gen(X_noisy)
    X_syn = X_syn_tensor.cpu().numpy()

    return X_syn, idx


In [92]:
def fine_tune_and_regenerate_all_conditions(
    interp_real: pd.DataFrame,
    interp_syn: pd.DataFrame,
    noise_std_train: float = 0.05,
    noise_std_gen: float = 0.10,
):
    """
    For each condition (S1, S2_match, S2_nomatch):
      1. Take REAL normalized features (Phase-4 baseline).
      2. Fine-tune an INTERP generator on this condition's real data.
      3. Regenerate the SAME number of synthetic samples as in the baseline,
         using the trained generator and real anchors.
      4. Return a new syn DataFrame with same columns but source='synthetic_finetuned'.
    """

    all_syn_rows = []

    for cond in CANONICAL_CONDITIONS:
        print("\n" + "#" * 80)
        print(f"FINE-TUNING & REGENERATING — CONDITION = {cond}")
        print("#" * 80)

        # Real subset for this condition
        X_real_cond, y_real_cond, df_real_cond = get_condition_slice(interp_real, cond)
        n_real_cond = X_real_cond.shape[0]

        # Baseline synthetic count for this condition
        _, _, df_syn_cond = get_condition_slice(interp_syn, cond)
        n_syn_cond = df_syn_cond.shape[0]

        print(f"[{cond}] N_real={n_real_cond}, N_syn_baseline={n_syn_cond}")

        # 1) Fine-tune generator on REAL
        gen, cfg = fine_tune_interp_model_for_condition(
            cond_tag=cond,
            X_cond_real=X_real_cond,
            n_epochs=80,
            patience=8,
            hyper_grid=None,           # use default grid
            noise_std=noise_std_train, # noise for denoising-style training
        )

        # 2) Regenerate synthetic data with same count
        X_syn_new, idx_used = generate_synthetic_from_gen(
            gen,
            X_real_cond,
            n_samples=n_syn_cond,
            noise_std=noise_std_gen,
        )

        # Build DataFrame
        df_syn_new = pd.DataFrame(X_syn_new, columns=BAND_COLS)

        # Copy labels and condition from the real anchors that we sampled
        df_syn_new[LABEL_COL] = y_real_cond[idx_used]
        df_syn_new[COND_COL]  = cond
        df_syn_new["source"]  = "synthetic_finetuned"

        all_syn_rows.append(df_syn_new)

    interp_syn_finetuned = pd.concat(all_syn_rows, axis=0).reset_index(drop=True)

    print("\nALL CONDITIONS DONE.")
    print("Finetuned synthetic shape:", interp_syn_finetuned.shape)

    # Save to disk for Phase-5 evaluation
    out_fp = interp_dir / f"{MODEL_KEY}_syn_finetuned.csv"
    interp_syn_finetuned.to_csv(out_fp, index=False)
    print(f"Saved finetuned synthetic to {out_fp}")

    return interp_syn_finetuned


In [93]:
interp_syn_finetuned = fine_tune_and_regenerate_all_conditions(
    interp_real=interp_real,
    interp_syn=interp_syn,
    noise_std_train=0.05,
    noise_std_gen=0.10,
)


################################################################################
FINE-TUNING & REGENERATING — CONDITION = S1
################################################################################
[S1] N_real=0, N_syn_baseline=0

[S1] Trying config: hidden_dim=32, lr=0.001, wd=0.0, bs=128


ValueError: num_samples should be a positive integer value, but got num_samples=0

### 5.2.1. MMD Helper

In [63]:
def linear_mmd(X_real: np.ndarray, X_syn: np.ndarray) -> float:
    
    Xr = X_real
    Xs = X_syn
    n = Xr.shape[0]
    m = Xs.shape[0]

    # means
    mean_r = Xr.mean(axis=0, keepdims=True)
    mean_s = Xs.mean(axis=0, keepdims=True)

    # linear kernel MMD^2
    mmd2 = np.sum((mean_r - mean_s) ** 2)
    return float(mmd2)

### 5.2.2. Distribution Metrics (KS per band + MMD)

In [64]:
def evaluate_distribution_metrics(X_real: np.ndarray, X_syn: np.ndarray, band_cols=BAND_COLS, model_name: str = "", title_suffix: str = ""):
    
    assert X_real.shape[1] == len(band_cols)
    assert X_syn.shape[1]  == len(band_cols)

    print(f"\nDISTRIBUTION METRICS — {model_name} [{title_suffix}]")

    ks_results = []
    for i, band in enumerate(band_cols):
        r = X_real[:, i]
        s = X_syn[:, i]
        ks_stat, p_val = ks_2samp(r, s)

        similar_flag = "✓ Similar" if p_val >= 0.05 else "✗ Different"
        print(f"{band:<7}: KS={ks_stat:.4f}, p={p_val:.4f}   {similar_flag}")

        ks_results.append({
            "band": band,
            "ks": float(ks_stat),
            "p": float(p_val),
        })

    mmd_val = linear_mmd(X_real, X_syn)
    print(f"MMD: {mmd_val:.6f}")

    return ks_results, mmd_val

### 5.2.3. Real VS Synthetic Random Forest Classifier

In [65]:
def evaluate_real_vs_syn(X_real: np.ndarray, X_syn: np.ndarray, model_name: str = "", title_suffix: str = "", n_repeats: int = 3):

    X = np.vstack([X_real, X_syn])
    y = np.concatenate([
        np.zeros(X_real.shape[0], dtype=int),
        np.ones(X_syn.shape[0], dtype=int),
    ])

    accs = []
    for rep in range(n_repeats):
        X_tr, X_te, y_tr, y_te = train_test_split(
            X, y, test_size=0.3, random_state=RANDOM_SEED + rep, stratify=y
        )
        clf = RandomForestClassifier(
            n_estimators=200,
            max_depth=None,
            random_state=RANDOM_SEED + rep,
            n_jobs=-1,
        )
        clf.fit(X_tr, y_tr)
        acc = clf.score(X_te, y_te)
        accs.append(acc)

    acc_mean = float(np.mean(accs))
    print(f"\nREAL-vs-SYN CLASSIFIER — {model_name} [{title_suffix}]")
    print(f"Accuracy: {acc_mean:.4f}")

    return acc_mean

### 5.2.4. TSTR/TRTR Random Forest Metrics

In [66]:
def evaluate_tstr_trtr(X_real: np.ndarray, y_real: np.ndarray, X_syn: np.ndarray, y_syn: np.ndarray, model_name: str = "", title_suffix: str = ""):

    Xr_tr, Xr_te, yr_tr, yr_te = train_test_split(
        X_real, y_real, test_size=0.3,
        random_state=RANDOM_SEED, stratify=y_real
    )

    clf = RandomForestClassifier(
        n_estimators=200,
        max_depth=None,
        random_state=RANDOM_SEED,
        n_jobs=-1,
    )
    clf.fit(Xr_tr, yr_tr)

    trtr_acc = clf.score(Xr_te, yr_te)
    tstr_acc = clf.score(X_syn, y_syn)

    print(f"\nTSTR / TRTR — {model_name} [{title_suffix}]")
    print(f"TRTR: {trtr_acc:.4f}")
    print(f"TSTR: {tstr_acc:.4f}")
    print(f"Gap : {tstr_acc - trtr_acc:.4f}")

    return float(trtr_acc), float(tstr_acc), float(tstr_acc - trtr_acc)

## 5.3. Classwise Interpolation Model Fine Tune

### 5.3.1. Interpolation Model Generator

In [67]:
class InterpGenerator(nn.Module):
    def __init__(self, input_dim=6, hidden_dim=32, output_dim=6):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)

### 5.3.2. Fine Tune for a Single Condition with Grid Search and Early Stopping

In [68]:
def fine_tune_interp_model_for_condition(cond_tag: str, X_cond_real: np.ndarray, n_epochs: int = 100, patience: int = 10, hyper_grid=None, noise_std: float = 0.0):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_dim = X_cond_real.shape[1]

    if hyper_grid is None:
        hyper_grid = {
            "hidden_dim": [32, 64],
            "lr": [1e-3, 5e-4],
            "batch_size": [128, 256],
        }

    # reproducible shuffle / split indices
    n = X_cond_real.shape[0]
    indices = np.arange(n)
    np.random.shuffle(indices)

    # 80% train, 20% val
    split = int(0.8 * n)
    train_idx = indices[:split]
    val_idx   = indices[split:]

    X_train = X_cond_real[train_idx]
    X_val   = X_cond_real[val_idx]

    X_train_t = torch.from_numpy(X_train.astype(np.float32))
    X_val_t   = torch.from_numpy(X_val.astype(np.float32))

    best_global = {
        "val_loss": float("inf"),
        "config":   None,
        "state":    None,
    }

    # Grid search
    for h_dim in hyper_grid["hidden_dim"]:
        for lr in hyper_grid["lr"]:
            for bsz in hyper_grid["batch_size"]:
                print(f"\n[{cond_tag}] Trying config: hidden_dim={h_dim}, lr={lr}, batch_size={bsz}")

                gen = InterpGenerator(input_dim=input_dim, hidden_dim=h_dim, output_dim=input_dim).to(device)

                # load baseline checkpoint if exists
                baseline_ckpt = interp_dir / f"{model_key}_{cond_tag}_baseline.pt"
                if baseline_ckpt.exists():
                    print(f"  Loading baseline weights from {baseline_ckpt}")
                    gen.load_state_dict(torch.load(baseline_ckpt, map_location=device))
                else:
                    print("  No baseline checkpoint found. Starting from scratch (for this config).")

                optimizer = optim.Adam(gen.parameters(), lr=lr)
                criterion = nn.MSELoss()

                train_ds = TensorDataset(X_train_t)
                train_loader = DataLoader(train_ds, batch_size=bsz, shuffle=True)

                X_val_t_device = X_val_t.to(device)

                best_val_for_cfg = float("inf")
                epochs_no_improve = 0

                for epoch in range(n_epochs):
                    gen.train()
                    total_train_loss = 0.0

                    for (xb,) in train_loader:
                        xb = xb.to(device)

                        # denoising-style noise
                        if noise_std > 0:
                            noise = torch.randn_like(xb) * noise_std
                            xb_in = xb + noise
                        else:
                            xb_in = xb

                        optimizer.zero_grad()
                        x_hat = gen(xb_in)
                        loss = criterion(x_hat, xb)
                        loss.backward()
                        optimizer.step()

                        total_train_loss += loss.item() * xb.size(0)

                    train_loss = total_train_loss / len(train_ds)

                    # validate
                    gen.eval()
                    with torch.no_grad():
                        if noise_std > 0:
                            noise_val = torch.randn_like(X_val_t_device) * noise_std
                            val_in = X_val_t_device + noise_val
                        else:
                            val_in = X_val_t_device

                        val_hat = gen(val_in)
                        val_loss = criterion(val_hat, X_val_t_device).item()

                    print(f"[{cond_tag}] Epoch {epoch+1}/{n_epochs} "
                          f"(h={h_dim}, lr={lr}, bsz={bsz}) "
                          f"- train: {train_loss:.4f}, val: {val_loss:.4f}")

                    # early stopping check
                    if val_loss < best_val_for_cfg - 1e-4:
                        best_val_for_cfg = val_loss
                        epochs_no_improve = 0

                        # snapshot weights
                        best_state_for_cfg = {
                            "state_dict": gen.state_dict(),
                            "hidden_dim": h_dim,
                            "lr": lr,
                            "batch_size": bsz,
                        }
                    else:
                        epochs_no_improve += 1
                        if epochs_no_improve >= patience:
                            print(f"[{cond_tag}] Early stopping (no improvement for {patience} epochs).")
                            break

                # compare with global best
                if best_val_for_cfg < best_global["val_loss"]:
                    best_global["val_loss"] = best_val_for_cfg
                    best_global["config"]   = {
                        "hidden_dim": h_dim,
                        "lr": lr,
                        "batch_size": bsz,
                    }
                    best_global["state"]    = best_state_for_cfg["state_dict"]

    # Build final generator with best config and load best weights
    best_cfg = best_global["config"]
    print(f"\n[{cond_tag}] BEST CONFIG: {best_cfg}, best_val_loss={best_global['val_loss']:.6f}")

    best_gen = InterpGenerator(
        input_dim=input_dim,
        hidden_dim=best_cfg["hidden_dim"],
        output_dim=input_dim
    ).to(device)
    best_gen.load_state_dict(best_global["state"])

    # Save fine-tuned weights
    ft_ckpt = interp_dir / f"{model_key}_{cond_tag}_finetuned.pt"
    torch.save(best_gen.state_dict(), ft_ckpt)
    print(f"[{cond_tag}] Fine-tuned weights saved to {ft_ckpt}")

    return best_gen, best_cfg

## 5.4. Generate New Synthetic Data from Fined Tuned Model

In [69]:
def generate_synthetic_from_gen(gen: nn.Module, X_real_cond: np.ndarray, n_samples: int, noise_std: float = 0.1):

    device = next(gen.parameters()).device
    gen.eval()

    idx = np.random.choice(X_real_cond.shape[0], size=n_samples, replace=True)
    X_base = X_real_cond[idx]

    with torch.no_grad():
        x = torch.from_numpy(X_base.astype(np.float32)).to(device)
        z = torch.randn_like(x) * noise_std
        x_in = x + z
        x_out = gen(x_in).cpu().numpy()

    return x_out, idx

In [70]:
def fine_tune_and_regenerate_all_conditions(interp_real: pd.DataFrame, interp_syn: pd.DataFrame, noise_std_train: float = 0.05, noise_std_gen: float = 0.10):

    all_syn_rows = []

    for cond in CANONICAL_CONDITIONS:
        print("\n" + "#" * 80)
        print(f"FINE-TUNE FOR CONDITION: {cond}")
        print("#" * 80)

        # real subset for this condition
        mask_real_cond = (interp_real[COND_COL] == cond)
        df_real_cond = interp_real.loc[mask_real_cond].reset_index(drop=True)
        X_real_cond = df_real_cond[BAND_COLS].to_numpy()
        y_real_cond = df_real_cond[LABEL_COL].to_numpy()

        # baseline syn subset (to match sample size & label distribution)
        mask_syn_cond = (interp_syn[COND_COL] == cond)
        df_syn_cond = interp_syn.loc[mask_syn_cond].reset_index(drop=True)
        n_syn_cond  = df_syn_cond.shape[0]

        print(f"Condition {cond}: N_real={X_real_cond.shape[0]}, N_syn={n_syn_cond}")

        # fine-tune
        gen, cfg = fine_tune_interp_model_for_condition(
            cond_tag=cond,
            X_cond_real=X_real_cond,
            n_epochs=80,
            patience=8,
            hyper_grid=None,          # default grid
            noise_std=noise_std_train # noise in training (denoising style)
        )

        # generate new synthetic
        X_syn_new, idx_used = generate_synthetic_from_gen(
            gen,
            X_real_cond,
            n_samples=n_syn_cond,
            noise_std=noise_std_gen,
        )

        # For labels, reuse labels from the sampled real rows
        y_syn_new = y_real_cond[idx_used]

        df_syn_new = pd.DataFrame(X_syn_new, columns=BAND_COLS)
        df_syn_new[LABEL_COL] = y_syn_new
        df_syn_new[COND_COL]  = cond
        df_syn_new[SOURCE_COL] = "synthetic_finetuned"

        all_syn_rows.append(df_syn_new)

    interp_syn_ft = pd.concat(all_syn_rows, axis=0, ignore_index=True)
    print("\nNew fine-tuned synthetic df shape:", interp_syn_ft.shape)

    # Save for later if needed
    out_fp = interp_dir / f"{model_key}_syn_finetuned.csv"
    interp_syn_ft.to_csv(out_fp, index=False)
    print(f"Saved fine-tuned synthetic data to {out_fp}")

    return interp_syn_ft

In [72]:
interp_syn_finetuned = fine_tune_and_regenerate_all_conditions(
    interp_real=real_df,
    interp_syn=syn_df,
    noise_std_train=0.05,
    noise_std_gen=0.10,
)


################################################################################
FINE-TUNE FOR CONDITION: S1
################################################################################
Condition S1: N_real=0, N_syn=0

[S1] Trying config: hidden_dim=32, lr=0.001, batch_size=128
  No baseline checkpoint found. Starting from scratch (for this config).


ValueError: num_samples should be a positive integer value, but got num_samples=0