### Loading Packages

In [None]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
import pandas as pd
from pathlib import Path
from src.generators.tvae_generator import TVAESynthesizerWrapper
from src.generators.ctgan_generator import CTGANSynthesizerWrapper
from src.generators.ctabgan_generator_tofix import CTABGANSynthesizerWrapper
from src.generators.great_generator import GREATSynthesizerWrapper
from src.generators.rtf_generator import RTFGeneratorWrapper
from src.utils.postprocess import match_format

In [2]:
import time
import numpy as np
import pandas as pd
from typing import Optional, List, Dict
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.model_selection import train_test_split

### Loading the Data

In [None]:
# Load the raw stroke dataset and drop the `id` column so it is not used as a feature.
# NOTE(review): this reads ../data/raw/stroke.csv, while the CTABGAN config for "stroke"
# points at ../data/processed/stroke_train.csv — confirm which file the generators should use.
dataset_name = "stroke"
df = pd.read_csv("../data/raw/stroke.csv").drop(columns=["id"])

# Output log file for the fine-tuning results, named after the dataset.
# NOTE(review): no cell visible here writes to log_path yet.
log_path = Path(f"../results/logs/synthetic_finetuning_log_{dataset_name}.csv")
log_path.parent.mkdir(parents=True, exist_ok=True)

### Generators and Hyperparameter Grids

In [None]:
# Generator wrappers, keyed by the names used in `param_grids` and `models_to_test`.
# Every entry is a zero-argument factory: `_call_wrapper` calls it once per grid configuration,
# so each configuration gets a freshly constructed wrapper (no state a wrapper might keep between
# fits can carry over), and wrappers for models that are not being tested are never built.
generators = {
    "tvae": lambda: TVAESynthesizerWrapper(output_dir="../data/synthetic/tvae"),
    "ctgan": lambda: CTGANSynthesizerWrapper(output_dir="../data/synthetic/ctgan"),
    "ctabgan": lambda: CTABGANSynthesizerWrapper(output_dir="../data/synthetic/ctabgan", num_experiments=1),
    "great": lambda: GREATSynthesizerWrapper(output_dir="../data/synthetic/great"),
    "rtf": lambda: RTFGeneratorWrapper(output_dir="../data/synthetic/rtf"),
}

# Light hyperparameter grids: one list of keyword-argument dicts per generator.
# Regular grids are written as Cartesian products; irregular ones are listed row by row.
param_grids = {
    # (epochs, batch_size) pairs, each tried with pac=1 and then pac=2
    "ctgan": [
        {
            "epochs": n_epochs,
            "batch_size": batch,
            "embedding_dim": 128,
            "pac": pac,
            "generator_lr": 2e-4,
            "discriminator_lr": 2e-4,
        }
        for n_epochs, batch in ((100, 128), (100, 256), (300, 256))
        for pac in (1, 2)
    ],
    # hidden width 128 then 64, each at 100 and then 200 epochs
    "tvae": [
        {"epochs": n_epochs, "compress_dims": (width, width), "decompress_dims": (width, width)}
        for width in (128, 64)
        for n_epochs in (100, 200)
    ],
    "ctabgan": [
        {"epochs": 100},
        {"epochs": 150},
        {"epochs": 200},
        {"epochs": 100, "generator_dim": (128, 128), "discriminator_dim": (128, 128)},
        {"epochs": 150, "generator_dim": (128, 128), "discriminator_dim": (128, 128)},
        {"epochs": 100, "generator_dim": (256, 256), "discriminator_dim": (256, 256)},
    ],
    # distilgpt2 at 40 then 80 epochs, each with batch size 32 and then 64
    "great": [
        {"llm": "distilgpt2", "batch_size": batch, "epochs": n_epochs, "save_steps": 10000, "guided_sampling": False}
        for n_epochs in (40, 80)
        for batch in (32, 64)
    ],
    # rows of (batch_size, epochs, gradient_accumulation_steps, mask_rate); all log every 100 steps
    "rtf": [
        {
            "batch_size": batch,
            "epochs": n_epochs,
            "gradient_accumulation_steps": accum,
            "mask_rate": mask,
            "logging_steps": 100,
        }
        for batch, n_epochs, accum, mask in (
            (64, 30, 4, 0.00),
            (32, 30, 8, 0.05),
            (96, 20, 2, 0.00),
            (32, 40, 8, 0.00),
            (64, 30, 4, 0.10),
            (64, 25, 4, 0.05),
        )
    ],
}

# CTABGAN column specifications, keyed by lower-case dataset name (the key `_call_wrapper` looks up).
def _ctabgan_spec(train_csv, categorical, general, integer, target):
    """Return one CTABGAN config dict.

    The log, mixed and non-categorical column groups are empty for every dataset here; fresh
    containers are created on each call so no two configs share a list or dict.
    """
    return {
        "raw_csv_path": train_csv,
        "categorical_columns": categorical,
        "log_columns": [],
        "mixed_columns": {},
        "general_columns": general,
        "non_categorical_columns": [],
        "integer_columns": integer,
        "problem_type": {"Classification": target},
    }


ctabgan_configs = {
    "diabetes": _ctabgan_spec(
        "../data/processed/diabetes_train.csv",
        categorical=['gender', 'hypertension', 'heart_disease', 'smoking_history', 'diabetes'],
        general=['bmi', 'HbA1c_level'],
        integer=['age', 'blood_glucose_level'],
        target='diabetes',
    ),
    "stroke": _ctabgan_spec(
        "../data/processed/stroke_train.csv",
        categorical=['gender', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'smoking_status', 'stroke'],
        general=['bmi'],
        # NOTE(review): avg_glucose_level is declared as an integer column; if the processed CSV
        # stores it with decimals, confirm that integer treatment is intended.
        integer=['age', 'avg_glucose_level'],
        target='stroke',
    ),
    "cirrhosis": _ctabgan_spec(
        "../data/processed/cirrhosis_train.csv",
        categorical=['Sex', 'Ascites', 'Hepatomegaly', 'Spiders', 'Edema', 'Drug', 'Status', 'Stage'],
        general=['Cholesterol', 'Albumin', 'Copper', 'Alk_Phos', 'SGOT', 'Tryglicerides', 'Platelets', 'Prothrombin'],
        integer=['N_Days', 'Age'],
        target='Status',
    ),
}


In [None]:
# Evaluation helpers: train-on-synthetic / test-on-real (TSTR) utility and a quick fidelity score

def _ml_utility_valsplit(
    real_df: pd.DataFrame,
    synth_df: pd.DataFrame,
    target_col: str,
    test_size: float = 0.30,
    random_state: int = 42,
    max_synth_rows: int = 5000,
) -> Optional[float]:
    """Train on synthetic data and score accuracy on a single real holdout split (TSTR).

    Only numeric feature columns present in *both* frames are used, in the real frame's column
    order, so the classifier is fit and evaluated on the same feature set even when a generator
    emits a column as non-numeric or in a different order.

    Args:
        real_df: Real data; a stratified ``test_size`` fraction is used as the test set.
        synth_df: Synthetic data used to fit the classifier.
        target_col: Label column; must be present in both frames.
        test_size: Fraction of ``real_df`` held out for scoring.
        random_state: Seed for the split, the synthetic subsample and the random forest.
        max_synth_rows: Cap on synthetic training rows (random subsample) for speed.

    Returns:
        Holdout accuracy in [0, 1], or ``None`` when the target is missing from either frame,
        there are no shared numeric features, or the real target has fewer than 2 or more than
        20 distinct values.
    """
    if target_col not in real_df.columns or target_col not in synth_df.columns:
        return None

    # numeric-only features (keep consistent with the K-fold evaluator)
    X_real = real_df.drop(columns=[target_col]).select_dtypes(include=[np.number])
    y_real = real_df[target_col]
    X_syn = synth_df.drop(columns=[target_col]).select_dtypes(include=[np.number])
    y_syn = synth_df[target_col]

    # Align feature sets: otherwise predict() rejects the real rows whenever the synthetic
    # numeric columns differ (missing, non-numeric, or reordered).
    shared_cols = [c for c in X_real.columns if c in X_syn.columns]
    X_real, X_syn = X_real[shared_cols], X_syn[shared_cols]

    # guards: need at least one feature and a classification-sized label set
    if X_real.shape[1] == 0 or y_real.nunique() < 2 or y_real.nunique() > 20:
        return None

    # downsample synthetic for speed
    if len(X_syn) > max_synth_rows:
        idx = np.random.RandomState(random_state).choice(len(X_syn), size=max_synth_rows, replace=False)
        X_syn = X_syn.iloc[idx]
        y_syn = y_syn.iloc[idx]

    # Only the real holdout is needed: the classifier is trained on synthetic rows.
    # NOTE(review): missing values in the real features are not imputed; confirm the installed
    # scikit-learn version's random forest accepts NaN inputs.
    _, X_te, _, y_te = train_test_split(
        X_real, y_real, test_size=test_size, random_state=random_state, stratify=y_real
    )

    # Train on synthetic, test on real holdout
    clf = RandomForestClassifier(n_estimators=50, max_depth=6, random_state=random_state, n_jobs=-1)
    clf.fit(X_syn, y_syn)
    pred = clf.predict(X_te)
    return float(accuracy_score(y_te, pred))


def _ml_utility_kfold(
    real_df: pd.DataFrame,
    synth_df: pd.DataFrame,
    target_col: str,
    n_splits: int = 5,
    random_state: int = 42,
    max_synth_rows_per_fold: int = 3000,
) -> Optional[dict]:
    """Train on synthetic data once, then score it on K disjoint real folds (TSTR).

    Only numeric feature columns present in *both* frames are used, in the real frame's column
    order, so fitting and prediction always see the same feature set.

    Args:
        real_df: Real data, partitioned by ``StratifiedKFold`` into evaluation folds.
        synth_df: Synthetic data used to fit the classifier.
        target_col: Label column; must be present in both frames.
        n_splits: Number of real evaluation folds.
        random_state: Seed for fold shuffling, the synthetic subsample and the random forest.
        max_synth_rows_per_fold: Cap on synthetic training rows; a single random subsample is
            drawn and shared by all folds.

    Returns:
        ``{"mean", "std", "n_splits"}`` accuracy summary (``std`` uses ``ddof=1``), or ``None``
        when the target is missing from either frame, there are no shared numeric features, or
        the real target has fewer than 2 or more than 20 distinct values.
    """
    if target_col not in real_df.columns or target_col not in synth_df.columns:
        return None

    X_real = real_df.drop(columns=[target_col]).select_dtypes(include=[np.number])
    y_real = real_df[target_col]
    X_syn = synth_df.drop(columns=[target_col]).select_dtypes(include=[np.number])
    y_syn = synth_df[target_col]

    # Align feature sets: otherwise predict() rejects the real folds whenever the synthetic
    # numeric columns differ (missing, non-numeric, or reordered).
    shared_cols = [c for c in X_real.columns if c in X_syn.columns]
    X_real, X_syn = X_real[shared_cols], X_syn[shared_cols]

    if X_real.shape[1] == 0 or y_real.nunique() < 2 or y_real.nunique() > 20:
        return None

    # downsample synthetic once (held fixed across folds for fairness/speed)
    if len(X_syn) > max_synth_rows_per_fold:
        idx = np.random.RandomState(random_state).choice(len(X_syn), size=max_synth_rows_per_fold, replace=False)
        X_syn = X_syn.iloc[idx]
        y_syn = y_syn.iloc[idx]

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    folds = list(skf.split(X_real, y_real))

    # The training data (synthetic) and the seed are identical for every fold, so one fitted
    # forest is reused; refitting per fold would rebuild the same model n_splits times.
    clf = RandomForestClassifier(n_estimators=50, max_depth=6, random_state=random_state, n_jobs=-1)
    clf.fit(X_syn, y_syn)

    fold_scores: List[float] = [
        float(accuracy_score(y_real.iloc[test_idx], clf.predict(X_real.iloc[test_idx])))
        for _, test_idx in folds
    ]

    return {
        "mean": float(np.mean(fold_scores)),
        "std": float(np.std(fold_scores, ddof=1)) if len(fold_scores) > 1 else 0.0,
        "n_splits": n_splits,
    }


def evaluate_quick_with_cv(
    real_data: pd.DataFrame,
    synthetic_data: pd.DataFrame,
    target_col: Optional[str] = None,
    *,
    use_kfold: bool = True,
    n_splits: int = 5,
    random_state: int = 42,
    max_synth_rows: int = 5000,
) -> dict:
    """Cheap fidelity + utility score for one synthetic sample, capped at 1.0.

    Components:
      * shape (0.30): full credit iff both frames have the same number of columns;
      * means (0.40): ``0.40 * max(0, 1 - mean relative error)`` of the real numeric columns'
        means vs. the synthetic ones (columns with a zero or all-missing real mean are skipped);
      * ML utility (0.30): ``0.30 *`` train-on-synthetic / test-on-real accuracy, from K-fold
        (``use_kfold=True``) or a single holdout split; only computed when ``target_col`` is set.

    Args:
        real_data: Real reference data.
        synthetic_data: Synthetic sample to score.
        target_col: Label column for the ML-utility component; falsy skips it.
        use_kfold: Use the K-fold evaluator (True) or a single holdout split (False).
        n_splits: Number of real folds in K-fold mode.
        random_state: Seed forwarded to the utility evaluators.
        max_synth_rows: Cap on synthetic rows used to fit the utility classifier, in both modes.

    Returns:
        Dict with ``overall_quick_score``, ``shape_component``, ``means_component``,
        ``ml_component`` (0.0 when not computed) and ``ml_cv`` (K-fold summary dict or ``None``).
    """
    # 1) shape (30%)
    shape_score = 0.30 if real_data.shape[1] == synthetic_data.shape[1] else 0.0

    # 2) numeric column means (40%)
    mean_score = 0.0
    diffs = []
    for col in real_data.select_dtypes(include=[np.number]).columns:
        if col not in synthetic_data.columns:
            continue
        rm = real_data[col].mean()
        if pd.isna(rm) or rm == 0:
            # relative error is undefined for a zero (or all-missing) real mean
            continue
        # Generators may emit numbers as text (object dtype); coerce instead of letting
        # `.mean()` raise. Unparseable entries become NaN and are ignored by `.mean()`.
        sm = pd.to_numeric(synthetic_data[col], errors="coerce").mean()
        # no usable synthetic values at all -> count as a 100% relative error
        diffs.append(1.0 if pd.isna(sm) else abs(rm - sm) / abs(rm))
    if diffs:
        mean_score = 0.40 * max(0.0, 1.0 - float(np.mean(diffs)))

    # 3) ML utility (30%): TSTR on val split or K-fold
    ml_score = None
    ml_cv = None
    if target_col:
        if use_kfold:
            ml_cv = _ml_utility_kfold(
                real_data,
                synthetic_data,
                target_col,
                n_splits=n_splits,
                random_state=random_state,
                # honour the caller's synthetic-row cap in K-fold mode as well
                max_synth_rows_per_fold=max_synth_rows,
            )
            if ml_cv is not None:
                ml_score = 0.30 * ml_cv["mean"]
        else:
            acc = _ml_utility_valsplit(real_data, synthetic_data, target_col, random_state=random_state, max_synth_rows=max_synth_rows)
            if acc is not None:
                ml_score = 0.30 * acc

    overall = min(1.0, float(shape_score + mean_score + (ml_score or 0.0)))

    return {
        "overall_quick_score": overall,
        "shape_component": shape_score,
        "means_component": mean_score,
        "ml_component": ml_score if ml_score is not None else 0.0,
        "ml_cv": ml_cv,
    }


# Wrapper Calling
def _call_wrapper(generator_obj_or_factory, name, df, dataset_name, *, params, ctabgan_configs):
    wrapper = generator_obj_or_factory() if callable(generator_obj_or_factory) else generator_obj_or_factory
    kwargs = {}
    if name.lower() == "ctabgan" and ctabgan_configs is not None:
        cfg = ctabgan_configs.get(dataset_name.lower())
        if cfg is None:
            raise ValueError(f"No CTABGAN config for dataset '{dataset_name}'.")
        kwargs["ctabgan_config"] = cfg

    try:
        return wrapper.fit_and_generate(df, dataset_name, **params, **kwargs)
    except TypeError:
        try:
            return wrapper.fit_and_generate(df, dataset_name, synth_params=params, **kwargs)
        except TypeError:
            print(f"  Note: '{name}' wrapper did not accept params; using defaults.")
            return wrapper.fit_and_generate(df, dataset_name, **kwargs)


# Grid search, scoring each configuration with a holdout split or K-fold TSTR utility
def _ml_rank_value(row: dict) -> float:
    """ML utility used for display and tie-breaking: K-fold mean accuracy if present, else the weighted ML component."""
    return row.get("ml_cv_mean", row["ml_comp"])


def light_grid_search_with_wrappers(
    df: pd.DataFrame,
    dataset_name: str,
    generators: Dict[str, object],
    param_grids: Dict[str, List[dict]],
    *,
    target_col: Optional[str] = None,
    n_samples: int = 1000,
    ctabgan_configs: Optional[dict] = None,
    models_to_test: Optional[List[str]] = None,
    # new options:
    use_kfold: bool = True,
    n_splits: int = 5,
    random_state: int = 42,
    max_synth_rows_for_eval: int = 5000,
    eval_df: Optional[pd.DataFrame] = None,
):
    """Run a light grid search over generator wrappers (Train on Synthetic, Test on Real).

    For each generator in ``models_to_test`` that has both a wrapper and a param grid, every
    configuration is fitted on ``df`` via ``_call_wrapper``; its output is subsampled to at most
    ``n_samples`` rows and scored with ``evaluate_quick_with_cv``. Configurations that raise are
    reported and skipped so one failure does not abort the sweep.

    Args:
        df: Real training data passed to every wrapper.
        dataset_name: Dataset identifier passed to the wrappers (also used for the CTABGAN config lookup).
        generators: Name -> wrapper instance or zero-argument wrapper factory.
        param_grids: Name -> list of hyperparameter dicts.
        target_col: Label column for the ML-utility component; ``None`` skips it.
        n_samples: Maximum number of synthetic rows scored per configuration.
        ctabgan_configs: Dataset name -> CTABGAN config (only used for ``"ctabgan"``).
        models_to_test: Generator names to run, in order; defaults to all keys of ``generators``.
        use_kfold: Use K-fold utility (True) or a single holdout split (False).
        n_splits: Number of real folds when ``use_kfold`` is True.
        random_state: Seed for subsampling and evaluation.
        max_synth_rows_for_eval: Cap on synthetic rows used to fit the utility classifier.
        eval_df: Real data to score synthetic samples against. Defaults to ``df``; note that with
            the default, utility is measured on rows the generator was trained on, which likely
            overstates it — pass a held-out real split for an unbiased comparison.

    Returns:
        List of result dicts, sorted best-first by ``quick_score`` and then by ML utility.
    """
    if models_to_test is None:
        models_to_test = list(generators.keys())

    # Real reference data for scoring; loop-invariant, so prepared once.
    real_eval = (df if eval_df is None else eval_df).reset_index(drop=True)

    results = []
    print(f"Grid search on dataset '{dataset_name}' | models: {models_to_test}")

    for name in models_to_test:
        if name not in generators:
            print(f"Skipping '{name}': not in generators dict.")
            continue
        if name not in param_grids:
            print(f"Skipping '{name}': no param grid provided.")
            continue

        grid = param_grids[name]
        print(f"\n=== {name.upper()} ===")
        for i, params in enumerate(grid, start=1):
            print(f"[{i}/{len(grid)}] params={params}")
            t0 = time.perf_counter()
            try:
                synth_df, _stats = _call_wrapper(
                    generators[name], name, df, dataset_name,
                    params=params, ctabgan_configs=ctabgan_configs,
                )

                # Score at most n_samples synthetic rows to keep evaluation fast.
                if len(synth_df) > n_samples:
                    synth_df_eval = synth_df.sample(n_samples, random_state=random_state).reset_index(drop=True)
                else:
                    synth_df_eval = synth_df.reset_index(drop=True)

                eval_out = evaluate_quick_with_cv(
                    real_eval,
                    synth_df_eval,
                    target_col=target_col,
                    use_kfold=use_kfold,
                    n_splits=n_splits,
                    random_state=random_state,
                    max_synth_rows=max_synth_rows_for_eval,
                )
                # covers fitting, generation and evaluation
                elapsed = time.perf_counter() - t0

                row = {
                    "generator": name,
                    "params": params,
                    "quick_score": round(eval_out["overall_quick_score"], 4),
                    "shape_comp": round(eval_out["shape_component"], 4),
                    "means_comp": round(eval_out["means_component"], 4),
                    "ml_comp": round(eval_out["ml_component"], 4),
                    "elapsed_s": round(elapsed, 2),
                }
                if eval_out["ml_cv"] is not None:
                    row.update({
                        "ml_cv_mean": round(eval_out["ml_cv"]["mean"], 4),
                        "ml_cv_std": round(eval_out["ml_cv"]["std"], 4),
                        "ml_cv_folds": eval_out["ml_cv"]["n_splits"],
                    })

                results.append(row)
                print(f"  -> quick_score={row['quick_score']:.3f} | ml={_ml_rank_value(row):.3f} | time={row['elapsed_s']}s")

            except Exception as e:
                # Best effort: a failing configuration is reported and the sweep continues.
                print(f"  FAILED: {type(e).__name__}: {str(e)[:200]}")

    # Sort best-first by quick_score then ml utility
    results = sorted(results, key=lambda r: (-r["quick_score"], -_ml_rank_value(r)))

    print("\nTOP 5:")
    for idx, r in enumerate(results[:5], 1):
        print(f"#{idx} {r['generator'].upper()} | quick={r['quick_score']:.3f} | ml={_ml_rank_value(r):.3f} | params={r['params']}")

    return results


### Running Light Grid Search

In [None]:
# `models_to_test` was previously passed twice (a SyntaxError: keyword argument repeated).
# Keep the full list here and choose the subset to run explicitly.
ALL_MODELS = ["ctgan", "tvae", "ctabgan", "great", "rtf"]
models_to_run = ["ctabgan"]  # set to ALL_MODELS for the full sweep

results = light_grid_search_with_wrappers(
    df,
    dataset_name,
    generators=generators,
    param_grids=param_grids,
    target_col="stroke",
    n_samples=1000,
    ctabgan_configs=ctabgan_configs,
    models_to_test=models_to_run,
    use_kfold=True,
    n_splits=5,
    random_state=42,
)
