In [2]:
import os
os.chdir("..")   # go from notebooks/ to project root
from pathlib import Path
from flamekit.io_fronts import Case, load_fronts
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

## Input Parameters

In [6]:
TIME_STEP = 200
PHI = 0.40
LAT_SIZE = "100"
ISOLEVELS = [0.6, 0.8]
N_ISOLEVELS = len(ISOLEVELS)
POST = True
BASE_DIR = Path("./isocontours")

In [7]:
from flamekit.io_fields import field_path
from flamekit.io_fronts import Case
# --- I/O (template style, with capitalized variables) ---
case = Case(
    base_dir=BASE_DIR,
    phi=PHI,
    lat_size=LAT_SIZE,
    time_step=TIME_STEP,
    post=POST,
)
dataset = load_fronts(case, ISOLEVELS)  # dict[iso] -> DataFrame


FileNotFoundError: [Errno 2] No such file or directory: 'isocontours/phi0.40/h400x100_ref/extracted_flame_front_post_200_iso_0.6.csv'

In [8]:
def interactive_front_3d_colored_plotly(
        dataset,
        c_val,
        var_x3d,
        var_y3d,
        var_z3d,
        color_var,
        x_col_phys="x",
        y_col_phys="y",
        colorscale="Inferno",
):
    """
    Interactive 2D/3D plot for a given isotherm/front, coloured by an arbitrary variable.
    """

    if isinstance(dataset, dict):
        sub_front = dataset[c_val]
    else:
        sub_front = dataset

    x_arr = sub_front[x_col_phys].to_numpy()
    y_arr = sub_front[y_col_phys].to_numpy()
    x3d_arr = sub_front[var_x3d].to_numpy()
    y3d_arr = sub_front[var_y3d].to_numpy()
    z3d_arr = sub_front[var_z3d].to_numpy()
    c_arr = sub_front[color_var].to_numpy()

    c_min = float(np.nanmin(c_arr))
    c_max = float(np.nanmax(c_arr))

    fig = make_subplots(
        rows=1,
        cols=2,
        specs=[[{"type": "xy"}, {"type": "scene"}]],
        subplot_titles=(
            f"Physical space (coloured by {color_var}) – c = {c_val}",
            f"State space 3D ({var_x3d}, {var_y3d}, {var_z3d}) – c = {c_val}",
        ),
    )

    sc_phys = go.Scatter(
        x=x_arr,
        y=y_arr,
        mode="markers",
        marker=dict(
            color=c_arr,
            colorscale=colorscale,
            cmin=c_min,
            cmax=c_max,
            size=6,
            showscale=False,
        ),
        name=f"Physical ({color_var})",
        hovertemplate=(
            f"{x_col_phys}=%{{x:.4f}}<br>"
            f"{y_col_phys}=%{{y:.4f}}<br>"
            f"{color_var}=%{{marker.color:.4e}}<extra></extra>"
        ),
    )
    fig.add_trace(sc_phys, row=1, col=1)

    sc_state3d = go.Scatter3d(
        x=x3d_arr,
        y=y3d_arr,
        z=z3d_arr,
        mode="markers",
        marker=dict(
            color=c_arr,
            colorscale=colorscale,
            cmin=c_min,
            cmax=c_max,
            size=4,
            showscale=True,
            colorbar=dict(
                title=color_var,
                x=1.02,
            ),
        ),
        name=f"State 3D ({color_var})",
        hovertemplate=(
            f"{var_x3d}=%{{x:.4e}}<br>"
            f"{var_y3d}=%{{y:.4e}}<br>"
            f"{var_z3d}=%{{z:.4e}}<br>"
            f"{color_var}=%{{marker.color:.4e}}<extra></extra>"
        ),
    )
    fig.add_trace(sc_state3d, row=1, col=2)

    idx0 = 0
    sel_phys = go.Scatter(
        x=[x_arr[idx0]],
        y=[y_arr[idx0]],
        mode="markers",
        marker=dict(
            size=14,
            symbol="circle-open",
            line=dict(width=2, color="black"),
        ),
        showlegend=False,
    )
    sel_state3d = go.Scatter3d(
        x=[x3d_arr[idx0]],
        y=[y3d_arr[idx0]],
        z=[z3d_arr[idx0]],
        mode="markers",
        marker=dict(
            size=8,
            symbol="circle-open",
            line=dict(width=3, color="black"),
        ),
        showlegend=False,
    )
    fig.add_trace(sel_phys, row=1, col=1)
    fig.add_trace(sel_state3d, row=1, col=2)

    fig = go.FigureWidget(fig)

    tr_phys = fig.data[0]
    tr_state3d = fig.data[1]
    tr_sel_phys = fig.data[2]
    tr_sel_state3d = fig.data[3]

    def update_selected(trace, points, selector):
        if not points.point_inds:
            return
        idx = int(points.point_inds[0])
        with fig.batch_update():
            tr_sel_phys.x = [x_arr[idx]]
            tr_sel_phys.y = [y_arr[idx]]
            tr_sel_state3d.x = [x3d_arr[idx]]
            tr_sel_state3d.y = [y3d_arr[idx]]
            tr_sel_state3d.z = [z3d_arr[idx]]

    tr_phys.on_click(update_selected)
    tr_state3d.on_click(update_selected)

    fig.update_xaxes(title_text=x_col_phys, row=1, col=1)
    fig.update_yaxes(title_text=y_col_phys, row=1, col=1)

    fig.update_scenes(
        xaxis_title=var_x3d,
        yaxis_title=var_y3d,
        zaxis_title=var_z3d,
        row=1,
        col=2,
    )

    fig.update_layout(
        height=650,
        width=1300,
        margin=dict(l=60, r=40, t=40, b=40),
    )

    return fig



In [9]:
fig_3d = interactive_front_3d_colored_plotly(
    dataset=dataset,
    c_val= 4.5,
    var_x3d="stretch_rate",
    var_y3d="curvature",
    var_z3d="DW_FDS",
    color_var="p",
    x_col_phys="x",
    y_col_phys="y",
    colorscale="Inferno",
)

fig_3d.show()

NameError: name 'dataset' is not defined

In [10]:
plt.scatter(
    dataset[3.5]["x"],
    dataset[3.5]["y"],
    c = abs(dataset[3.5]["gradT"]) - abs(dataset[3.5]["gradT_normal"]),
    cmap = "inferno"
)
plt.colorbar(label="|gradT| - (gradT·n)")
plt.xlabel("x")
plt.ylabel("y")
plt.axis("equal")
plt.show()


NameError: name 'dataset' is not defined

In [11]:
from __future__ import annotations

from flamekit.io_fields import field_path
from pathlib import Path
import numpy as np
import pandas as pd

from flamekit.io_fronts import Case, load_fronts

import plotly.graph_objects as go
import plotly.colors as pc


# ============================================================
# 1) LOAD MULTIPLE TIMESTEPS + MULTIPLE ISOTHERMS (ONE CASE)
# ============================================================
def load_fronts_multi_timesteps(
    base_dir: Path,
    phi: float,
    lat_size: str,
    timesteps: list[int],
    isolevels: list[float],
    post: bool = True,
    max_points_per_iso_per_ts: int | None = None,  # e.g. 80_000 for speed
    seed: int = 0,
) -> pd.DataFrame:
    """
    Returns one DataFrame with columns:
      ['timestep', 'isotherm', var_x, var_y, var_z, ...]
    created by concatenating dict[iso]->DF for each timestep.
    """
    rng = np.random.default_rng(seed)
    frames = []

    for ts in timesteps:
        case = Case(
            base_dir=base_dir,
            phi=phi,
            lat_size=lat_size,
            time_step=ts,
            post=post,
        )

        data_ts = load_fronts(case, isolevels)  # dict[iso] -> DataFrame

        for iso in isolevels:
            df = data_ts[iso].copy()
            df["timestep"] = int(ts)
            df["isotherm"] = float(iso)

            # optional downsample per (timestep, isotherm)
            if max_points_per_iso_per_ts is not None and len(df) > max_points_per_iso_per_ts:
                idx = rng.choice(df.index.to_numpy(), size=max_points_per_iso_per_ts, replace=False)
                df = df.loc[idx]

            frames.append(df)

    if not frames:
        raise ValueError("No data loaded. Check timesteps/isolevels and file availability.")

    return pd.concat(frames, ignore_index=True)


# ============================================================
# 2A) ONE STATIC 3D PLOT (ALL TIMESTEPS OVERLAID)
#     DISCRETE COLORS BY ISOTHERM
# ============================================================
def plot_state3d_discrete_isotherms_all_timesteps(
    df_all: pd.DataFrame,
    isolevels: list[float],
    var_x3d: str = "stretch_rate",
    var_y3d: str = "curvature",
    var_z3d: str = "DW_FDS",
    marker_size: int = 3,
):
    needed = {"timestep", "isotherm", var_x3d, var_y3d, var_z3d}
    missing = [c for c in needed if c not in df_all.columns]
    if missing:
        raise KeyError(f"Missing columns: {missing}")

    df = df_all[list(needed)].replace([np.inf, -np.inf], np.nan).dropna()

    # discrete color per isotherm (fixed across everything)
    palette = pc.qualitative.Plotly
    color_map = {float(iso): palette[i % len(palette)] for i, iso in enumerate(isolevels)}

    fig = go.Figure()

    for iso in isolevels:
        dfi = df[df["isotherm"] == float(iso)]
        if dfi.empty:
            continue

        fig.add_trace(
            go.Scatter3d(
                x=dfi[var_x3d],
                y=dfi[var_y3d],
                z=dfi[var_z3d],
                mode="markers",
                name=f"iso={float(iso):g}",
                marker=dict(size=marker_size, color=color_map[float(iso)]),
                hovertemplate=(
                    "isotherm=%{customdata[0]:g}<br>"
                    "timestep=%{customdata[1]}<br>"
                    f"{var_x3d}=%{{x:.4e}}<br>"
                    f"{var_y3d}=%{{y:.4e}}<br>"
                    f"{var_z3d}=%{{z:.4e}}<extra></extra>"
                ),
                customdata=np.c_[dfi["isotherm"].to_numpy(), dfi["timestep"].to_numpy()],
            )
        )

    fig.update_layout(
        height=750,
        width=1100,
        margin=dict(l=60, r=40, t=40, b=40),
        legend=dict(title="Isotherm"),
        scene=dict(
            xaxis_title=var_x3d,
            yaxis_title=var_y3d,
            zaxis_title=var_z3d,
        ),
    )
    return fig


# ============================================================
# 2B) ANIMATED 3D PLOT (SLIDER OVER TIMESTEPS)
#     DISCRETE COLORS BY ISOTHERM (constant across frames)
# ============================================================
def plot_state3d_discrete_isotherms_with_timestep_slider(
    df_all: pd.DataFrame,
    isolevels: list[float],
    timesteps: list[int],
    var_x3d: str = "stretch_rate",
    var_y3d: str = "curvature",
    var_z3d: str = "DW_FDS",
    marker_size: int = 3,
):
    needed = {"timestep", "isotherm", var_x3d, var_y3d, var_z3d}
    missing = [c for c in needed if c not in df_all.columns]
    if missing:
        raise KeyError(f"Missing columns: {missing}")

    df = df_all[list(needed)].replace([np.inf, -np.inf], np.nan).dropna()

    palette = pc.qualitative.Plotly
    color_map = {float(iso): palette[i % len(palette)] for i, iso in enumerate(isolevels)}

    # initial timestep
    ts0 = int(timesteps[0])

    def traces_for_timestep(ts: int):
        traces = []
        dfts = df[df["timestep"] == int(ts)]
        for iso in isolevels:
            dfi = dfts[dfts["isotherm"] == float(iso)]
            traces.append(
                go.Scatter3d(
                    x=dfi[var_x3d],
                    y=dfi[var_y3d],
                    z=dfi[var_z3d],
                    mode="markers",
                    name=f"iso={float(iso):g}",
                    marker=dict(size=marker_size, color=color_map[float(iso)]),
                    hovertemplate=(
                        "isotherm=%{customdata[0]:g}<br>"
                        "timestep=%{customdata[1]}<br>"
                        f"{var_x3d}=%{{x:.4e}}<br>"
                        f"{var_y3d}=%{{y:.4e}}<br>"
                        f"{var_z3d}=%{{z:.4e}}<extra></extra>"
                    ),
                    customdata=np.c_[dfi["isotherm"].to_numpy(), dfi["timestep"].to_numpy()],
                    showlegend=True,
                )
            )
        return traces

    fig = go.Figure(data=traces_for_timestep(ts0))

    frames = []
    for ts in timesteps:
        frames.append(go.Frame(name=str(int(ts)), data=traces_for_timestep(int(ts))))
    fig.frames = frames

    slider_steps = [
        dict(
            method="animate",
            args=[[str(int(ts))], {"mode": "immediate", "frame": {"duration": 0, "redraw": True}, "transition": {"duration": 0}}],
            label=str(int(ts)),
        )
        for ts in timesteps
    ]

    fig.update_layout(
        height=750,
        width=1100,
        margin=dict(l=60, r=40, t=40, b=40),
        legend=dict(title="Isotherm"),
        scene=dict(
            xaxis_title=var_x3d,
            yaxis_title=var_y3d,
            zaxis_title=var_z3d,
        ),
        updatemenus=[
            dict(
                type="buttons",
                direction="left",
                x=0.02,
                y=1.06,
                buttons=[
                    dict(label="Play", method="animate",
                         args=[None, {"fromcurrent": True, "frame": {"duration": 250, "redraw": True}, "transition": {"duration": 0}}]),
                    dict(label="Pause", method="animate",
                         args=[[None], {"mode": "immediate", "frame": {"duration": 0, "redraw": False}, "transition": {"duration": 0}}]),
                ],
            )
        ],
        sliders=[dict(active=0, x=0.02, y=1.02, len=0.96, steps=slider_steps, currentvalue={"prefix": "timestep: "})],
    )

    return fig


# ============================================================
# EXAMPLE: YOUR CASE SETTINGS
# ============================================================
BASE_DIR = Path("../isocontours")
PHI = 0.40
LAT_SIZE = "100"
POST = True
ISOLEVELS = [4.5]

TIMESTEPS = [200,210]
df_all = load_fronts_multi_timesteps(
    base_dir=BASE_DIR,
    phi=PHI,
    lat_size=LAT_SIZE,
    timesteps=TIMESTEPS,
    isolevels=ISOLEVELS,
    post=POST,
    max_points_per_iso_per_ts=80_000,  # set None to keep all points
)

# Option A: overlay all timesteps (hover shows timestep)
fig = plot_state3d_discrete_isotherms_all_timesteps(
    df_all=df_all,
    isolevels=ISOLEVELS,
    var_x3d="stretch_rate",
    var_y3d="curvature",
    var_z3d="DW_FDS",
    marker_size=3,
)
fig.show()

# Option B: timestep slider animation (uncomment if you prefer)
# fig_anim = plot_state3d_discrete_isotherms_with_timestep_slider(
#     df_all=df_all,
#     isolevels=ISOLEVELS,
#     timesteps=TIMESTEPS,
#     var_x3d="stretch_rate",
#     var_y3d="curvature",
#     var_z3d="DW_FDS",
#     marker_size=3,
# )
# fig_anim.show()


In [14]:
# ============================================================
# ADD: TRIM LOW/HIGH PERCENTILE OUTLIERS IN 3D (x,y,z) THEN PLOT
# - Removes points outside [q_low, q_high] for EACH of the 3 dims
# - Applies globally (across all timesteps/isotherms) by default
# ============================================================

def trim_outliers_by_percentile_3d(
    df_all: pd.DataFrame,
    var_x3d: str = "stretch_rate",
    var_y3d: str = "curvature",
    var_z3d: str = "DW_FDS",
    q_low: float = 0.01,     # remove lowest 1%
    q_high: float = 0.99,    # remove highest 1%
) -> pd.DataFrame:
    """
    Keeps only rows that fall within [q_low, q_high] quantiles for EACH dimension.
    This removes extreme tails in each axis (noise/outliers).

    Returns a filtered copy of df_all.
    """
    cols = [var_x3d, var_y3d, var_z3d, "timestep", "isotherm"]
    missing = [c for c in cols if c not in df_all.columns]
    if missing:
        raise KeyError(f"Missing columns: {missing}")

    df = df_all[cols].replace([np.inf, -np.inf], np.nan).dropna().copy()

    x_lo, x_hi = df[var_x3d].quantile([q_low, q_high]).to_numpy()
    y_lo, y_hi = df[var_y3d].quantile([q_low, q_high]).to_numpy()
    z_lo, z_hi = df[var_z3d].quantile([q_low, q_high]).to_numpy()

    mask = (
        df[var_x3d].between(x_lo, x_hi) &
        df[var_y3d].between(y_lo, y_hi) &
        df[var_z3d].between(z_lo, z_hi)
    )

    df_f = df.loc[mask].copy()

    removed = len(df) - len(df_f)
    print(
        f"Trimming percentiles per-dim: [{q_low:.3f}, {q_high:.3f}] -> "
        f"kept {len(df_f):,}/{len(df):,} (removed {removed:,})\n"
        f"  {var_x3d}: [{x_lo:.4e}, {x_hi:.4e}]\n"
        f"  {var_y3d}: [{y_lo:.4e}, {y_hi:.4e}]\n"
        f"  {var_z3d}: [{z_lo:.4e}, {z_hi:.4e}]"
    )

    return df_f


ISOLEVELS = [4, 4.5]
TIMESTEPS = [210, 200, 190, 180]  # <-- put your timesteps here
# -----------------------
# USE IT: FILTER THEN PLOT
# -----------------------
df_all_filt = trim_outliers_by_percentile_3d(
    df_all=df_all,
    var_x3d="stretch_rate",
    var_y3d="curvature",
    var_z3d="total_heat_conduction",
    q_low=0.01,      # change as desired (e.g., 0.02)
    q_high=0.99,     # change as desired (e.g., 0.98)
)

fig = plot_state3d_discrete_isotherms_all_timesteps(
    df_all=df_all_filt,
    isolevels=ISOLEVELS,
    var_x3d="stretch_rate",
    var_y3d="curvature",
    var_z3d="total_heat_conduction",
    marker_size=3,
)
fig.show()

# (Optional) same filtering works with the animated slider too:
# fig_anim = plot_state3d_discrete_isotherms_with_timestep_slider(
#     df_all=df_all_filt,
#     isolevels=ISOLEVELS,
#     timesteps=TIMESTEPS,
#     var_x3d="stretch_rate",
#     var_y3d="curvature",
#     var_z3d="DW_FDS",
#     marker_size=3,
# )
# fig_anim.show()


Trimming percentiles per-dim: [0.010, 0.990] -> kept 16,838/17,474 (removed 636)
  stretch_rate: [-3.4383e+00, 4.5987e+00]
  curvature: [-1.0928e+00, 8.7000e-01]
  total_heat_conduction: [-1.6350e+01, 3.4479e-01]


In [14]:
import numpy as np
import pandas as pd

from sklearn.model_selection import GroupKFold, KFold, cross_val_score, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.inspection import permutation_importance


def pick_best_third_variable(
    df_all: pd.DataFrame,
    target: str = "DW_FDS",
    base_features: tuple[str, str] = ("stretch_rate", "curvature"),
    group_col: str | None = "timestep",   # set None to disable grouping
    exclude_cols: set[str] | None = None,
    min_nonnull_frac: float = 0.90,
    n_splits: int = 5,
    random_state: int = 0,
    scoring: str = "r2",
    top_k: int = 10,
):
    """
    Finds the best "third variable" v from df_all to predict Sd using:
        Sd ~ f(stretch_rate, curvature, v)
    by evaluating CV score for each candidate v and selecting the best.

    Returns: best_var, results_df (sorted), baseline_mean, baseline_std
    """

    if exclude_cols is None:
        exclude_cols = set()
    # Common identifiers/coords you usually don't want as "physics variable"
    exclude_cols |= {
        "timestep", "isotherm", "x", "y", "z", "nx", "ny", "nz",
        "i", "j", "k", "id", "cell_id", "point_id",
    }

    # Keep numeric only (robust)
    numeric_cols = df_all.select_dtypes(include=[np.number]).columns.tolist()

    # Basic required
    for c in [target, *base_features]:
        if c not in df_all.columns:
            raise KeyError(f"Missing required column: {c}")

    # Build candidate list
    candidates = []
    for c in numeric_cols:
        if c == target:
            continue
        if c in base_features:
            continue
        if c in exclude_cols:
            continue
        candidates.append(c)

    if not candidates:
        raise ValueError("No candidate variables found. Check exclude list / numeric columns.")

    # Drop rows with NaN in target or base_features (we can impute candidates but not target)
    cols_need = [target, *base_features]
    df = df_all.copy()
    df = df.replace([np.inf, -np.inf], np.nan)
    df = df.dropna(subset=cols_need)

    y = df[target].to_numpy()

    # groups (avoid leakage across timesteps if you have strong temporal correlation)
    groups = None
    if group_col is not None and group_col in df.columns:
        groups = df[group_col].to_numpy()

    # Model: strong non-linear regressor, fast, robust
    model = HistGradientBoostingRegressor(
        max_depth=6,
        learning_rate=0.05,
        max_iter=400,
        random_state=random_state,
    )

    pipe = Pipeline([
        ("impute", SimpleImputer(strategy="median")),
        ("model", model),
    ])

    # CV splitter
    if groups is not None:
        cv = GroupKFold(n_splits=n_splits)
        cv_args = dict(cv=cv, groups=groups)
    else:
        cv = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        cv_args = dict(cv=cv)

    # Baseline: only (stretch, curvature)
    X_base = df[list(base_features)]
    base_scores = cross_val_score(pipe, X_base, y, scoring=scoring, **cv_args)
    baseline_mean = float(np.mean(base_scores))
    baseline_std = float(np.std(base_scores))

    rows = []
    for v in candidates:
        # Skip nearly-empty columns
        nonnull_frac = df[v].notna().mean()
        if nonnull_frac < min_nonnull_frac:
            continue

        # Skip constant columns
        if df[v].nunique(dropna=True) <= 2:
            continue

        X = df[list(base_features) + [v]]
        scores = cross_val_score(pipe, X, y, scoring=scoring, **cv_args)
        rows.append({
            "var": v,
            "mean_score": float(np.mean(scores)),
            "std_score": float(np.std(scores)),
            "delta_vs_baseline": float(np.mean(scores) - baseline_mean),
            "nonnull_frac": float(nonnull_frac),
        })

    if not rows:
        raise ValueError("All candidates were filtered out (nonnull/constant). Relax filters.")

    results = pd.DataFrame(rows).sort_values("mean_score", ascending=False).reset_index(drop=True)
    best_var = results.loc[0, "var"]

    print(f"Baseline ({base_features[0]}, {base_features[1]}) CV {scoring}: "
          f"{baseline_mean:.4f} ± {baseline_std:.4f}")
    print(f"Best third variable: {best_var}")
    print("\nTop candidates:")
    print(results.head(top_k))

    return best_var, results, baseline_mean, baseline_std


def train_final_model_with_best_var(
    df_all: pd.DataFrame,
    best_var: str,
    target: str = "DW_FDS",
    base_features: tuple[str, str] = ("stretch_rate", "curvature"),
    test_size: float = 0.2,
    random_state: int = 0,
):
    """
    Trains a final non-linear model with (stretch, curvature, best_var) and reports performance.
    """
    df = df_all.replace([np.inf, -np.inf], np.nan).dropna(subset=[target, *base_features, best_var]).copy()

    X = df[list(base_features) + [best_var]]
    y = df[target].to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, shuffle=True
    )

    model = HistGradientBoostingRegressor(
        max_depth=6,
        learning_rate=0.05,
        max_iter=600,
        random_state=random_state,
    )

    pipe = Pipeline([
        ("impute", SimpleImputer(strategy="median")),
        ("model", model),
    ])

    pipe.fit(X_train, y_train)
    y_pred = pipe.predict(X_test)

    r2 = r2_score(y_test, y_pred)
    rmse = mean_squared_error(y_test, y_pred, squared=False)
    print(f"\nFinal model with [{base_features[0]}, {base_features[1]}, {best_var}]")
    print(f"Holdout R2  = {r2:.4f}")
    print(f"Holdout RMSE = {rmse:.4e}")

    # Optional: permutation importance on the holdout set
    imp = permutation_importance(pipe, X_test, y_test, n_repeats=5, random_state=random_state, scoring="r2")
    imp_df = pd.DataFrame({
        "feature": X.columns,
        "perm_importance_mean": imp.importances_mean,
        "perm_importance_std": imp.importances_std
    }).sort_values("perm_importance_mean", ascending=False)

    print("\nPermutation importance (holdout):")
    print(imp_df)

    return pipe, imp_df


# ============================================================
# USAGE (assumes df_all already exists from your loader)
# ============================================================

best_var, results, base_mu, base_sig = pick_best_third_variable(
    df_all=df_all,
    target="DW_FDS",
    base_features=("stretch_rate", "curvature"),
    group_col="timestep",          # set None if you don't want grouping
    min_nonnull_frac=0.90,
    n_splits=3,
    scoring="r2",
    top_k=10,
)

final_model, importance_df = train_final_model_with_best_var(
    df_all=df_all,
    best_var=best_var,
    target="DW_FDS",
    base_features=("stretch_rate", "curvature"),
    test_size=0.2,
)


Baseline (stretch_rate, curvature) CV r2: 0.7121 ± 0.0119
Best third variable: H2O

Top candidates:
                      var  mean_score  std_score  delta_vs_baseline  \
0                     H2O    0.940829   0.021446           0.228704   
1     density_ratio_sigma    0.938813   0.018715           0.226688   
2                     HO2    0.931616   0.019845           0.219490   
3                      H2    0.929756   0.019406           0.217631   
4      O2_diffusion_total    0.929566   0.015883           0.217441   
5                      O2    0.929532   0.021287           0.217406   
6                 phi_loc    0.928630   0.019753           0.216505   
7                omega_OH    0.927339   0.017529           0.215214   
8   total_heat_conduction    0.923592   0.017697           0.211467   
9  heat_conduction_normal    0.922611   0.017972           0.210486   

   nonnull_frac  
0           1.0  
1           1.0  
2           1.0  
3           1.0  
4           1.0  
5         

TypeError: got an unexpected keyword argument 'squared'

In [3]:
import sys
sys.path.append("..")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler

from flamekit.autoencoder import AutoEncoder
from flamekit.GLV_autoencoder import SNMLP, MLP


# -----------------------
# DEVICE
# -----------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

# ============================================================
# BUILD X WITH ONLY [Sd, curvature, stretch_rate]
# ============================================================
FEATURES = ["DW_FDS", "curvature", "stretch_rate"]

# df_all must exist already (multi timestep + multi isotherm)
needed_meta = []
if "timestep" in df_all.columns:
    needed_meta.append("timestep")
if "isotherm" in df_all.columns:
    needed_meta.append("isotherm")

df_use = df_all[FEATURES + needed_meta].replace([np.inf, -np.inf], np.nan).dropna().copy()

# Keep labels for plots (optional)
isotherm_np = df_use["isotherm"].astype(float).to_numpy() if "isotherm" in df_use.columns else None
timestep_np = df_use["timestep"].astype(int).to_numpy() if "timestep" in df_use.columns else None

# Standardize features
scaler_X = StandardScaler()
X = scaler_X.fit_transform(df_use[FEATURES].to_numpy())
print("X shape:", X.shape)

X_torch = torch.from_numpy(X).float()
loader = DataLoader(TensorDataset(X_torch), batch_size=4096, shuffle=True, pin_memory=True)

num_points, D_orig = X.shape  # D_orig = 3


# ============================================================
# 1) STANDARD AUTOENCODER (LATENT=2)
# ============================================================
num_latent = 2
num_hidden = 64

ae = AutoEncoder(D_orig, num_latent, num_hidden).to(device)
optimizer = optim.AdamW(ae.parameters(), lr=1e-3, weight_decay=1e-2)
criterion = nn.MSELoss()

epochs = 4000
log_every = 250

use_amp = (device.type == "cuda")
if use_amp:
    from torch import amp
    torch.backends.cudnn.benchmark = True
    scaler_amp = amp.GradScaler("cuda")

ae.train()
for epoch in range(epochs):
    running = 0.0
    for (batch,) in loader:
        batch = batch.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        if use_amp:
            with amp.autocast("cuda"):
                decoded, encoded = ae(batch)
                loss = criterion(decoded, batch)
            scaler_amp.scale(loss).backward()
            scaler_amp.step(optimizer)
            scaler_amp.update()
        else:
            decoded, encoded = ae(batch)
            loss = criterion(decoded, batch)
            loss.backward()
            optimizer.step()

        running += loss.item() * batch.size(0)

    if (epoch + 1) % log_every == 0:
        print(f"[AE-2D] Epoch {epoch+1:4d}/{epochs}  loss={running/num_points:.6e}")

ae.eval()
with torch.no_grad():
    X_gpu = X_torch.to(device)
    X_rec, Z = ae(X_gpu)
    rec_mse = F.mse_loss(X_rec, X_gpu).item()
Z_np = Z.detach().cpu().numpy()
print(f"[AE-2D] final recon MSE = {rec_mse:.6e}")

# Plot latent (colored by isotherm if present)
plt.figure(figsize=(7, 6))
if isotherm_np is None:
    plt.scatter(Z_np[:, 0], Z_np[:, 1], s=3, alpha=0.35)
else:
    for iso in np.unique(isotherm_np):
        m = (isotherm_np == iso)
        plt.scatter(Z_np[m, 0], Z_np[m, 1], s=3, alpha=0.35, label=f"iso={iso:g}")
    plt.legend(markerscale=4, fontsize=9, ncol=2, frameon=True)

plt.gca().set_aspect("equal", adjustable="box")
plt.xlabel("z1")
plt.ylabel("z2")
plt.title("AE latent space (2D) using [Sd, curvature, stretch]")
plt.tight_layout()
plt.show()


# ============================================================
# 2) LEAST-VOLUME AE + DYNAMIC PRUNING (EFFECTIVE 2D)
#    (Enc: 3 -> latent_dim_full, Dec: latent_dim_full -> 3)
# ============================================================
ambient_dim = D_orig
latent_dim_full = 16        # intentionally > 2, so LV+DP can collapse/prune
target_latent = 2

width = 256
enc = MLP(ambient_dim, latent_dim_full, [width] * 4).to(device)
dec = SNMLP(latent_dim_full, ambient_dim, [width] * 4).to(device)
opt = torch.optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=1e-3)

eta = 1e-2   # stability for log(std+eta)
lam = 1e-2   # volume regularization strength

total_epochs = 8000
log_every = 250

# pruning schedule
warmup = 1500
prune_every = 250

# hard mask (DP)
mask = torch.ones(latent_dim_full, device=device)

def current_k(epoch: int) -> int:
    if epoch < warmup:
        return latent_dim_full
    return target_latent

X_gpu = X_torch.to(device)

enc.train()
dec.train()
for epoch in range(total_epochs):
    opt.zero_grad(set_to_none=True)

    z = enc(X_gpu)                 # (N, latent_dim_full)
    z_masked = z * mask            # prune via mask
    X_hat = dec(z_masked)

    rec_loss = F.mse_loss(X_hat, X_gpu)

    std = z_masked.std(dim=0)
    vol_loss = torch.exp(torch.log(std + eta).mean())
    loss = rec_loss + lam * vol_loss

    loss.backward()
    opt.step()

    # dynamic pruning
    if epoch >= warmup and ((epoch + 1) % prune_every == 0):
        with torch.no_grad():
            std_now = (enc(X_gpu) * mask).std(dim=0)
            k = current_k(epoch)
            keep_idx = torch.topk(std_now, k=k, largest=True).indices
            new_mask = torch.zeros_like(mask)
            new_mask[keep_idx] = 1.0
            mask = new_mask
            print(f"[LV+DP] Epoch {epoch+1:4d}: prune -> keep k={k} dims: {keep_idx.detach().cpu().numpy()}")

    if (epoch + 1) % log_every == 0:
        print(f"[LV+DP] Epoch {epoch+1:4d}/{total_epochs}  rec={rec_loss.item():.5e}  vol={vol_loss.item():.5e}")

# Evaluate + choose 2 most active dims (should match pruned dims)
enc.eval()
dec.eval()
with torch.no_grad():
    z_full = enc(X_gpu)
    z_eff = z_full * mask
    X_rec2 = dec(z_eff)
    rec_mse2 = F.mse_loss(X_rec2, X_gpu).item()
    std_final = z_eff.std(dim=0)
    idx_sorted = torch.argsort(std_final, descending=True)
    top2 = idx_sorted[:2].detach().cpu().numpy()

print(f"[LV+DP] final recon MSE = {rec_mse2:.6e}")
print(f"[LV+DP] active dims (top10) = {idx_sorted[:10].detach().cpu().numpy()}, plotting dims = {top2}")

Z_eff_np = z_eff.detach().cpu().numpy()

plt.figure(figsize=(7, 6))
if isotherm_np is None:
    plt.scatter(Z_eff_np[:, top2[0]], Z_eff_np[:, top2[1]], s=3, alpha=0.35)
else:
    for iso in np.unique(isotherm_np):
        m = (isotherm_np == iso)
        plt.scatter(Z_eff_np[m, top2[0]], Z_eff_np[m, top2[1]], s=3, alpha=0.35, label=f"iso={iso:g}")
    plt.legend(markerscale=4, fontsize=9, ncol=2, frameon=True)

plt.gca().set_aspect("equal", adjustable="box")
plt.xlabel(f"z[{top2[0]}]")
plt.ylabel(f"z[{top2[1]}]")
plt.title("Least-volume AE + Dynamic Pruning (effective 2D) on [Sd, curvature, stretch]")
plt.tight_layout()
plt.show()

# latent activity plot
std_np = std_final.detach().cpu().numpy()
plt.figure(figsize=(8, 4))
plt.bar(np.arange(len(std_np)), np.sort(std_np)[::-1])
plt.xlabel("latent dims sorted by std")
plt.ylabel("std")
plt.title("LV+DP: latent activity decay (std)")
plt.tight_layout()
plt.show()


Using device: cuda
GPU: NVIDIA GeForce RTX 3050 4GB Laptop GPU
X shape: (63412, 3)
[AE-2D] Epoch  250/4000  loss=2.649268e-03
[AE-2D] Epoch  500/4000  loss=2.191058e-03
[AE-2D] Epoch  750/4000  loss=1.984516e-03
[AE-2D] Epoch 1000/4000  loss=2.182999e-03
[AE-2D] Epoch 1250/4000  loss=1.985698e-03
[AE-2D] Epoch 1500/4000  loss=1.658070e-03
[AE-2D] Epoch 1750/4000  loss=1.698249e-03
[AE-2D] Epoch 2000/4000  loss=1.790653e-03


KeyboardInterrupt: 