In [1]:
from __future__ import annotations

import os
os.chdir("..")   # go from notebooks/ to project root

from flamekit.io_fronts import Case, load_fronts
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional

from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_selection import SequentialFeatureSelector
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import r2_score, mean_squared_error

# =========================
# USER SETTINGS
# =========================

BASE_DIR = Path(r"/media/alexandros/OS/Documents and Settings/alexp/Documents/Bachelor Thesis/Code/isocontours")
PHI = 0.40
LAT_SIZE = "100"
POST = True

TIME_STEPS = [200]
ISOLEVEL = 0.8
TARGET_VAR = "DW_FDS"

CLUSTER_ON_SPATIAL = False

CLUSTER_FEATURES_INCLUDE = {
    "curvature",
}

MODEL_FEATURES_INCLUDE = {
    "curvature",
    "dcurvdx",
    "dcurvdy",
    "tangential_strain_rate",
    "normal_strain_rate",
    "vorticity",
    "u_n",
    "u_t",
    "du_ndx",
    "du_ndy",
    "du_tdx",
    "du_tdy",
    "dTdx",
    "dTdy",
}

FEATURES_EXCLUDE = set()

CURVATURE_COLUMN = "curvature"
CURVATURE_BOUNDS = (-0.1, 0.1)  # low/high thresholds for 3 bins
N_CLUSTERS = len(CURVATURE_BOUNDS) + 1

MIN_CLUSTER_SAMPLES = 50
RANDOM_STATE = 0
TEST_SIZE = 0.25

# Backward selection
BACKWARD_N_FEATURES_GLOBAL = 4
BACKWARD_N_FEATURES_CLUSTER = 4
SFS_SCORING = "r2"
SFS_CV_SPLITS = 3

MODEL_PARAMS = dict(
    n_estimators=300,
    max_depth=None,
    random_state=RANDOM_STATE,
    n_jobs=-1,
)


# =========================
# PLOT SAVING (folder naming convention)
# =========================

OUTPUT_BASE_DIR = Path(r"/media/alexandros/OS/Documents and Settings/alexp/Documents/Bachelor Thesis/report_figures/results/backward_selection")

def _time_steps_tag(time_steps: list[int]) -> str:
    if not time_steps:
        return "t_none"
    if len(time_steps) == 1:
        return f"t_{time_steps[0]}"
    t_min = min(time_steps)
    t_max = max(time_steps)
    return f"t_{t_min}_to_{t_max}"

RUN_DIR = OUTPUT_BASE_DIR / f"lat_{LAT_SIZE}" / _time_steps_tag(TIME_STEPS)
TS_TAG = "_".join(map(str, TIME_STEPS))
SAVE_DIR = RUN_DIR / f"h_{LAT_SIZE}_t_{TS_TAG}_iso_{ISOLEVEL}"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

FIG_DPI = 300

def _safe_name(s: str) -> str:
    return "".join(ch if (ch.isalnum() or ch in "._-") else "_" for ch in s)

def save_png(stem: str, dpi: int = FIG_DPI):
    """Save current matplotlib figure as PNG to SAVE_DIR."""
    fname = SAVE_DIR / f"{_safe_name(stem)}.png"
    plt.gcf().savefig(fname, dpi=dpi, bbox_inches="tight", facecolor="white")

def save_then_show(stem: str, dpi: int = FIG_DPI):
    save_png(stem, dpi=dpi)
    plt.show()

print(f"[INFO] Saving outputs to: {SAVE_DIR}")


# =========================
# Utilities
# =========================

def _numeric_cols(df: pd.DataFrame) -> List[str]:
    return [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]

def resolve_features(
    df: pd.DataFrame,
    include: Optional[List[str] | set[str]],
    exclude: set[str],
) -> List[str]:
    num = set(_numeric_cols(df))
    feats = num if include is None else set(include).intersection(num)
    feats = feats.difference(exclude)
    return sorted(feats)

def intersect_feature_space(feature_sets: List[set[str]]) -> List[str]:
    if not feature_sets:
        return []
    common = set.intersection(*feature_sets)
    return sorted(common)

def run_backward_selection(
    X: np.ndarray,
    y: np.ndarray,
    feature_names: List[str],
    n_features_to_select: int,
    random_state: int,
    scoring: str,
    cv_splits: int,
) -> List[str]:
    if X.shape[1] == 0:
        return []

    n_features_to_select = int(min(n_features_to_select, X.shape[1]))
    cv = KFold(n_splits=cv_splits, shuffle=True, random_state=random_state)

    model = RandomForestRegressor(**MODEL_PARAMS)
    sfs = SequentialFeatureSelector(
        estimator=model,
        n_features_to_select=n_features_to_select,
        direction="backward",
        scoring=scoring,
        cv=cv,
        n_jobs=-1,
    )

    sfs.fit(X, y)
    mask = sfs.get_support()
    return [f for f, keep in zip(feature_names, mask) if keep]


def evaluate_selected_model(
    X: np.ndarray,
    y: np.ndarray,
    test_size: float,
    random_state: int,
) -> Tuple[float, float]:
    Xtr, Xte, ytr, yte = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    model = RandomForestRegressor(**MODEL_PARAMS)
    model.fit(Xtr, ytr)

    ypred = model.predict(Xte)
    r2 = float(r2_score(yte, ypred))
    rmse = float(np.sqrt(mean_squared_error(yte, ypred)))
    return r2, rmse



# ============================================================
# LOAD + POOL
# ============================================================

dfs = []
cluster_feature_sets_per_t: List[set[str]] = []
model_feature_sets_per_t: List[set[str]] = []

cluster_include = set(CLUSTER_FEATURES_INCLUDE)
if CLUSTER_ON_SPATIAL:
    cluster_include = cluster_include.union({"x", "y", "z"})

model_include = set(MODEL_FEATURES_INCLUDE)

for ts in TIME_STEPS:
    CASE = Case(
        base_dir=BASE_DIR,
        phi=PHI,
        lat_size=LAT_SIZE,
        time_step=ts,
        post=POST,
    )

    fronts = load_fronts(CASE, [ISOLEVEL])
    if ISOLEVEL not in fronts:
        raise ValueError(f"ISOLEVEL {ISOLEVEL} not found for timestep {ts}")

    df_t = fronts[ISOLEVEL].copy()
    df_t["c_iso"] = float(ISOLEVEL)
    df_t["timestep"] = int(ts)

    if TARGET_VAR not in df_t.columns:
        raise ValueError(f"TARGET_VAR '{TARGET_VAR}' not found for timestep {ts}")

    cl_feats_t = set(resolve_features(df_t, cluster_include, FEATURES_EXCLUDE))
    ml_feats_t = set(resolve_features(df_t, model_include, FEATURES_EXCLUDE))

    cluster_feature_sets_per_t.append(cl_feats_t)
    model_feature_sets_per_t.append(ml_feats_t)

    dfs.append(df_t)

df_all = pd.concat(dfs, ignore_index=True)
print(f"Pooled rows total: n={len(df_all)} across timesteps={TIME_STEPS}")

cluster_features = intersect_feature_space(cluster_feature_sets_per_t)
model_features = intersect_feature_space(model_feature_sets_per_t)

if len(cluster_features) == 0:
    raise ValueError("No common numeric CLUSTER features across requested timesteps. Adjust CLUSTER_FEATURES_INCLUDE.")
if len(model_features) == 0:
    raise ValueError("No common numeric MODEL features across requested timesteps. Adjust MODEL_FEATURES_INCLUDE.")

print(f"\nCluster features (common across all timesteps): {cluster_features}")
print(f"Model features (common across all timesteps):   {model_features}")

required = sorted(set(cluster_features).union(model_features).union({TARGET_VAR, CURVATURE_COLUMN}))
dfc = df_all.dropna(subset=required).copy()
print(f"\nAfter dropna on required (cluster+model+target): n={len(dfc)}")


# ============================================================
# GLOBAL BACKWARD SELECTION
# ============================================================

X_all = dfc[model_features].to_numpy()
y_all = dfc[TARGET_VAR].to_numpy()

selected_global = run_backward_selection(
    X=X_all,
    y=y_all,
    feature_names=model_features,
    n_features_to_select=BACKWARD_N_FEATURES_GLOBAL,
    random_state=RANDOM_STATE,
    scoring=SFS_SCORING,
    cv_splits=SFS_CV_SPLITS,
)

print("\nGlobal backward selection:")
print(selected_global)

r2_global = np.nan
rmse_global = np.nan
if selected_global:
    X_sel = dfc[selected_global].to_numpy()
    r2_global, rmse_global = evaluate_selected_model(
        X=X_sel,
        y=y_all,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
    )

summary_global = pd.DataFrame({
    "rank": list(range(1, len(selected_global) + 1)),
    "feature": selected_global,
})
summary_global["r2"] = r2_global
summary_global["rmse"] = rmse_global
summary_global_csv = SAVE_DIR / "backward_selected_global.csv"
summary_global.to_csv(summary_global_csv, index=False)
print(f"\nGlobal model accuracy: R2={r2_global:.4f} | RMSE={rmse_global:.6e}")
print(f"\n[INFO] Wrote global selection: {summary_global_csv}")


# ============================================================
# CURVATURE BINNING (predefined thresholds)
# ============================================================

curv_col = CURVATURE_COLUMN
if curv_col not in dfc.columns:
    raise ValueError(f"CURVATURE_COLUMN '{curv_col}' not found in dataframe")

if len(CURVATURE_BOUNDS) != 2:
    raise ValueError("CURVATURE_BOUNDS must contain exactly two values")
low, high = CURVATURE_BOUNDS
if low >= high:
    raise ValueError("CURVATURE_BOUNDS must be strictly increasing (low < high)")

dfc["cluster"] = np.digitize(
    dfc[curv_col].to_numpy(),
    bins=[low, high],
    right=True,
)

print(f"\nCurvature bins: {CURVATURE_BOUNDS} -> clusters 0..{N_CLUSTERS - 1}")


# ============================================================
# Per-cluster backward selection
# ============================================================

selected_rows: List[dict] = []
cluster_metrics_rows: List[dict] = []

for cl in range(N_CLUSTERS):
    sub = dfc[dfc["cluster"] == cl].copy()
    n_cl = len(sub)
    print(f"\n--- Cluster {cl} (CURVATURE BIN) | n={n_cl} ---")

    if n_cl < MIN_CLUSTER_SAMPLES:
        print(f"Skipping (n < MIN_CLUSTER_SAMPLES={MIN_CLUSTER_SAMPLES})")
        continue

    X = sub[model_features].to_numpy()
    y = sub[TARGET_VAR].to_numpy()

    selected = run_backward_selection(
        X=X,
        y=y,
        feature_names=model_features,
        n_features_to_select=BACKWARD_N_FEATURES_CLUSTER,
        random_state=RANDOM_STATE,
        scoring=SFS_SCORING,
        cv_splits=SFS_CV_SPLITS,
    )

    r2_cl = np.nan
    rmse_cl = np.nan
    if selected:
        X_sel = sub[selected].to_numpy()
        r2_cl, rmse_cl = evaluate_selected_model(
            X=X_sel,
            y=y,
            test_size=TEST_SIZE,
            random_state=RANDOM_STATE,
        )

    cluster_metrics_rows.append({
        "cluster": cl,
        "n_cluster": int(n_cl),
        "n_features": int(len(selected)),
        "r2": r2_cl,
        "rmse": rmse_cl,
    })

    print(f"Accuracy: R2={r2_cl:.4f} | RMSE={rmse_cl:.6e}")
    print("Selected features (backward order):")
    for i, f in enumerate(selected, 1):
        print(f"  {i:02d}. {f}")
        selected_rows.append({
            "cluster": cl,
            "rank": i,
            "feature": f,
            "n_cluster": int(n_cl),
        })

if selected_rows:
    summary_df = pd.DataFrame(selected_rows)
    summary_csv = SAVE_DIR / "backward_selected_per_cluster.csv"
    summary_df.to_csv(summary_csv, index=False)

    print("\nSummary (backward selected features per cluster):")
    print(summary_df.to_string(index=False))
    print(f"\n[INFO] Wrote per-cluster selection: {summary_csv}")
else:
    print("\nNo clusters met MIN_CLUSTER_SAMPLES; no per-cluster results to report.")

if cluster_metrics_rows:
    metrics_df = pd.DataFrame(cluster_metrics_rows)
    metrics_csv = SAVE_DIR / "backward_cluster_metrics.csv"
    metrics_df.to_csv(metrics_csv, index=False)
    print("\nPer-cluster model accuracy:")
    print(metrics_df.to_string(index=False))
    print(f"\n[INFO] Wrote per-cluster metrics: {metrics_csv}")

print(f"\nSaved outputs to: {SAVE_DIR}")


[INFO] Saving outputs to: /media/alexandros/OS/Documents and Settings/alexp/Documents/Bachelor Thesis/report_figures/results/backward_selection/lat_100/t_200/h_100_t_200_iso_0.8
Pooled rows total: n=3845 across timesteps=[200]

Cluster features (common across all timesteps): ['curvature']
Model features (common across all timesteps):   ['curvature', 'dTdx', 'dTdy', 'dcurvdx', 'dcurvdy', 'du_ndx', 'du_ndy', 'du_tdx', 'du_tdy', 'normal_strain_rate', 'tangential_strain_rate', 'u_n', 'u_t', 'vorticity']

After dropna on required (cluster+model+target): n=3845

Global backward selection:
['dTdx', 'dTdy', 'du_ndx', 'du_tdy']

Global model accuracy: R2=0.9182 | RMSE=1.039460e-01

[INFO] Wrote global selection: /media/alexandros/OS/Documents and Settings/alexp/Documents/Bachelor Thesis/report_figures/results/backward_selection/lat_100/t_200/h_100_t_200_iso_0.8/backward_selected_global.csv

Curvature bins: (-0.1, 0.1) -> clusters 0..2

--- Cluster 0 (CURVATURE BIN) | n=773 ---
Accuracy: R2=0.93