# In [None]:
# === CLEAN TWO-STAGE CURVE MODELING PIPELINE ===
# Uses:
#   - Long_Normalized_All.xlsx : long format flux–time curves (possibly many sheets)
#   - CurveMeta_All.csv        : curve-level metadata (one row per article/curve/condition)
#
# Stage A: fit KWW to each curve -> J0, Jinf, tau, beta
# Stage B: metadata -> parameters models with GroupKFold (grouped by article_id)
# Stage C: reconstruct curves & compute curve-level MAE

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.optimize import curve_fit
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.model_selection import GroupKFold, KFold, cross_val_score
from sklearn.metrics import mean_absolute_error

# -------------------------------------------------------------------
# 0) Paths and small helpers
# -------------------------------------------------------------------
# Excel workbook with the long-format flux-time rows; every sheet in it is read.
LONG_XLSX = "Long_Normalized_All.xlsx"
# CSV with the curve-level metadata that is joined onto the fitted parameters.
META_CSV  = "CurveMeta_All.csv"

def norm_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with cleaned column names.

    Each name is trimmed, every run of whitespace becomes a single '_',
    and the result is lowercased. The input frame is left untouched.
    """
    out = df.copy()
    trimmed = out.columns.str.strip()
    underscored = trimmed.str.replace(r"\s+", "_", regex=True)
    out.columns = underscored.str.lower()
    return out

def norm_id_series(s: pd.Series) -> pd.Series:
    """
    Normalize an ID column into comparable string keys.

    Steps:
    - cast to str, strip, lowercase
    - values that parse as whole numbers are rewritten as integer strings
      (so '5.0' -> '5' and '007' -> '7'); non-integer numerics such as
      '5.5' and non-numeric text keep their cleaned text
    - whitespace runs are collapsed and replaced by underscores

    Returns a new Series with the same index as ``s``.
    """
    s = s.astype(str).str.strip().str.lower()
    asnum = pd.to_numeric(s, errors="coerce")

    # Only finite, integer-valued numbers may be rewritten as integers:
    # casting a value like 5.5 to an integer dtype raises.  NaN % 1 and
    # inf % 1 are NaN, so `.eq(0)` is False for them; the magnitude guard
    # keeps the int64 cast below from overflowing.
    whole = asnum.mod(1).eq(0) & asnum.abs().lt(2.0 ** 63)

    # Non-whole entries are filled with a dummy 0 so the cast is safe; they
    # are discarded again by the `where` below.
    s_num = asnum.where(whole, 0).astype("int64").astype(str)  # '5' instead of '5.0'
    s = s.where(~whole, s_num)
    s = s.str.replace(r"\s+", " ", regex=True).str.replace(" ", "_")
    return s

# -------------------------------------------------------------------
# 1) Load data
# -------------------------------------------------------------------
# Fail early, with an actionable message, if either input file is missing.
if not os.path.exists(LONG_XLSX):
    raise FileNotFoundError(
        f"Time-series workbook '{LONG_XLSX}' not found. "
        "Upload it or adjust LONG_XLSX."
    )
if not os.path.exists(META_CSV):
    raise FileNotFoundError(
        f"Metadata file '{META_CSV}' not found. "
        "Upload it or adjust META_CSV."
    )

# sheet_name=None loads every sheet as {sheet_name: DataFrame}.
sheets = pd.read_excel(LONG_XLSX, sheet_name=None)

long_parts = []
for sheet_name, df_sheet in sheets.items():
    # Normalize headers per sheet BEFORE looking for article_id. Otherwise a
    # sheet whose header is e.g. "Article_ID" would get a second article_id
    # column, and sheets that differ only in header casing/spacing would be
    # concatenated into separate columns that collide after normalization.
    # norm_cols returns a copy, so the loaded sheet is not mutated.
    df_sheet = norm_cols(df_sheet)
    if "article_id" not in df_sheet.columns:
        # Fall back to the sheet name as the article identifier.
        df_sheet["article_id"] = sheet_name
    long_parts.append(df_sheet)

long = pd.concat(long_parts, ignore_index=True)
meta = norm_cols(pd.read_csv(META_CSV))

# Normalize ID columns (article_id, curve_id, conditionid) identically in
# both tables so they can be used as join keys.
for df_ in (long, meta):
    for c in ["article_id", "curve_id", "conditionid"]:
        if c in df_.columns:
            df_[c] = norm_id_series(df_[c])

print(f"Loaded LONG rows: {len(long)}")
print(f"Loaded META rows: {len(meta)}")

# -------------------------------------------------------------------
# 2) Time & target columns (NormFlux)
# -------------------------------------------------------------------
y_col = "normflux"

# Candidate time-column names, in normalized (lowercase, underscored) form.
# NOTE: the generic names 't' and 'time' are treated as seconds.
cand_min = ["t_min", "time_min", "minutes", "min"]
cand_sec = ["t_s", "t_sec", "time_s", "seconds", "sec", "t", "time"]

# Prefer a column that is already in minutes; otherwise convert the first
# seconds-like column found into a new 't_min' column.
t_col = next((c for c in cand_min if c in long.columns), None)
if t_col is None:
    for c in cand_sec:
        if c in long.columns:
            long[c] = pd.to_numeric(long[c], errors="coerce")
            long["t_min"] = long[c] / 60.0
            t_col = "t_min"
            break

if t_col is None:
    raise KeyError(
        "No time column found in LONG. "
        f"Expected one of {cand_min + cand_sec}, got {long.columns.tolist()}"
    )

# per-curve ID key (triple)
id_cols = ["article_id", "curve_id", "conditionid"]
need_cols = id_cols + [y_col, t_col]

# Report a missing flux/ID column explicitly (same style as the time-column
# check above) instead of failing later with a bare KeyError.
missing_cols = [c for c in need_cols if c not in long.columns]
if missing_cols:
    raise KeyError(
        f"LONG is missing required column(s) {missing_cols}; "
        f"got {long.columns.tolist()}"
    )

# numeric coercion (unparseable cells become NaN and are dropped below)
long[y_col] = pd.to_numeric(long[y_col], errors="coerce")
long[t_col] = pd.to_numeric(long[t_col], errors="coerce")

long_use = long.dropna(subset=need_cols).copy()
print("Usable long rows (no NaN in ID/time/flux):", len(long_use))

# Data-driven upper bound for flux: 1.2 x the 99th percentile of the observed
# values. Used as the upper bound for J0/Jinf in the fits and as the clip
# limit for predicted J0/Jinf.
Y_UB = float(1.2 * np.nanpercentile(long_use[y_col], 99))
if not np.isfinite(Y_UB) or Y_UB <= 0:
    Y_UB = 1.5  # fallback when the percentile is unusable (e.g. no rows)
print(f"Flux upper bound (Y_UB): {Y_UB:.3f}")

# -------------------------------------------------------------------
# 3) Stage A: per-curve KWW fits
# -------------------------------------------------------------------
def kww(t, J0, Jinf, tau, beta):
    """
    Evaluate the stretched-exponential (KWW) decay:
        J(t) = Jinf + (J0 - Jinf) * exp( - (t/tau)^beta )

    Negative times are treated as 0 and tau is floored at 1e-9 so the
    ratio t/tau is always defined.
    """
    tau_safe = max(float(tau), 1e-9)
    exponent = float(beta)
    t_nonneg = np.clip(np.asarray(t, float), 0, None)
    decay = np.exp(-np.power(t_nonneg / tau_safe, exponent))
    return Jinf + (J0 - Jinf) * decay

def fit_one_curve(t, y):
    """Fit the KWW model to one curve.

    Parameters
    ----------
    t, y : array-like
        Time points and flux values of a single curve.

    Returns
    -------
    dict or None
        ``{"J0", "Jinf", "tau", "beta", "fit_r2"}`` as floats, or ``None``
        when fewer than 3 points are supplied. If the optimizer fails, the
        in-bounds initial guess is returned in place of fitted values, and
        ``fit_r2`` then describes that guess.
    """
    t = np.asarray(t, float)
    y = np.asarray(y, float)
    if len(t) < 3:
        return None

    # Parameter bounds (J0, Jinf, tau, beta), based on the data-driven Y_UB.
    lower = np.array([0.0, 0.0, 1e-6, 0.05])
    upper = np.array([Y_UB, Y_UB, 1e6, 3.0])

    # Initial guesses from the data. They are clipped into the bounds because
    # a bounded fit rejects an infeasible starting point: without the clip, a
    # curve whose max exceeds Y_UB (or whose min is negative) would never be
    # fitted and would fall straight through to the fallback below.
    p0 = np.clip(
        [
            float(np.nanmax(y)),                    # J0: initial level
            float(np.nanmin(y)),                    # Jinf: plateau level
            max(float(np.nanmedian(t)), 1e-3),      # tau: time scale
            1.0,                                    # beta: plain exponential
        ],
        lower,
        upper,
    )

    try:
        popt, _ = curve_fit(kww, t, y, p0=p0, bounds=(lower, upper), maxfev=20000)
    except (RuntimeError, ValueError):
        # Non-convergence or invalid input: keep the best-effort guess.
        popt = p0

    yhat = kww(t, *popt)
    ss_res = np.sum((y - yhat)**2)
    ss_tot = np.sum((y - np.mean(y))**2) + 1e-12  # epsilon guards flat curves
    r2 = 1.0 - ss_res / ss_tot

    return {
        "J0":   float(popt[0]),
        "Jinf": float(popt[1]),
        "tau":  float(popt[2]),
        "beta": float(popt[3]),
        "fit_r2": float(r2),
    }

MIN_POINTS = 5  # minimum points per curve to attempt a fit

# Fit every (article_id, curve_id, conditionid) curve that has enough points.
fit_rows = []
for keys, g in long_use.groupby(id_cols):
    if len(g) < MIN_POINTS:
        continue
    g = g.sort_values(t_col)
    res = fit_one_curve(g[t_col].values, g[y_col].values)
    if res is None:
        continue
    # One record per curve: the ID triple followed by the fitted parameters.
    fit_rows.append({**dict(zip(id_cols, keys)), **res})

# Explicit column list: when no curve qualifies, params_df still carries the
# ID columns, so the metadata merge yields an empty frame and the explicit
# "no overlap" error instead of an opaque KeyError on the join keys.
params_df = pd.DataFrame(
    fit_rows,
    columns=id_cols + ["J0", "Jinf", "tau", "beta", "fit_r2"],
)
print("Fitted curves:", len(params_df))

# -------------------------------------------------------------------
# 4) Join KWW params with META (triple key)
# -------------------------------------------------------------------
num_feats = ["p_bar", "oil_ppm", "salt_gl", "temp_c",
             "porosity", "pore_nm", "droplet_um"]
cat_feats = ["material", "geometry", "oil_type"]

# Coerce the numeric metadata columns that exist (bad cells become NaN).
for col in [f for f in num_feats if f in meta.columns]:
    meta[col] = pd.to_numeric(meta[col], errors="coerce")

feature_cols = [f for f in num_feats + cat_feats if f in meta.columns]
keep_cols = id_cols + feature_cols


def _first_non_null(s):
    """Return the first non-null value of ``s``, or NaN if there is none."""
    present = s.dropna()
    return present.iloc[0] if present.size else np.nan


# Collapse META to one row per (article_id, curve_id, conditionid), taking
# the first non-null value of every feature within each group.
grouped_meta = meta[keep_cols].groupby(id_cols, as_index=False)
meta_one = grouped_meta.agg({f: _first_non_null for f in feature_cols})

# Inner join: only curves that have both a fit and metadata survive.
meta_full = params_df.merge(meta_one, on=id_cols, how="inner")
print("Rows after join (params + metadata):", len(meta_full))

if meta_full.empty:
    raise RuntimeError(
        "No overlap between LONG and META on (article_id, curve_id, conditionid). "
        "Check ID formatting in your Excel/CSV."
    )

# -------------------------------------------------------------------
# 5) Stage B: metadata -> parameter models
# -------------------------------------------------------------------
# Features actually available after the join.
present_num = [c for c in num_feats if c in meta_full.columns]
present_cat = [c for c in cat_feats if c in meta_full.columns]

X = meta_full[present_num + present_cat].copy()
# tau is modeled on a log scale; the floor avoids log(0).
meta_full["log_tau"] = np.log(meta_full["tau"].clip(lower=1e-6))

# One regression target per KWW parameter.
targets = {
    "J0":      meta_full["J0"].astype(float).values,
    "Jinf":    meta_full["Jinf"].astype(float).values,
    "log_tau": meta_full["log_tau"].astype(float).values,
    "beta":    meta_full["beta"].astype(float).values,
}

# Numeric features: median imputation, then standardization.
num_pipe = Pipeline([
    ("imp", SimpleImputer(strategy="median")),
    ("sc",  StandardScaler())
])
# Categorical features: mode imputation, then one-hot encoding; categories
# unseen at fit time are encoded as all zeros.
cat_pipe = Pipeline([
    ("imp", SimpleImputer(strategy="most_frequent")),
    ("oh",  OneHotEncoder(handle_unknown="ignore"))
])

pre = ColumnTransformer(
    transformers=[
        ("num", num_pipe, present_num),
        ("cat", cat_pipe, present_cat),
    ],
    remainder="drop",
    # Always emit a dense matrix. With the default threshold (0.3) the stacked
    # output turns sparse once the one-hot block has enough categories, and
    # HistGradientBoostingRegressor requires dense input.
    sparse_threshold=0.0,
)

reg = HistGradientBoostingRegressor(
    learning_rate=0.06,
    max_leaf_nodes=31,
    min_samples_leaf=5,
    random_state=42
)

def make_cv(n_rows: int, groups: pd.Series | None):
    """Build GroupKFold by article_id if possible, else plain KFold."""
    if n_rows < 2:
        return None, {}
    if groups is not None and groups.nunique() >= 2:
        n_splits = min(5, groups.nunique(), n_rows)
        return GroupKFold(n_splits=n_splits), {"groups": groups}
    # fallback: plain KFold
    from sklearn.model_selection import KFold
    n_splits = min(5, max(2, min(3, n_rows)))
    return KFold(n_splits=n_splits, shuffle=True, random_state=42), {}

def eval_target(y_vec: np.ndarray):
    """Cross-validate and fit the model for one parameter target.

    Parameters
    ----------
    y_vec : np.ndarray
        Target values aligned row-for-row with ``X``; non-finite entries
        are excluded from both CV and the final fit.

    Returns
    -------
    (mae, rmse, model)
        Mean CV MAE and RMSE (NaN when cross-validation is not possible)
        and a pipeline fitted on all usable rows.
    """
    from sklearn.base import clone

    mask = np.isfinite(y_vec)
    Xy = X.loc[mask]
    yy = y_vec[mask]
    n = len(Xy)

    # Pipeline.fit fits its steps in place without cloning them. Wrapping the
    # shared `pre`/`reg` objects directly would make every target's returned
    # model point at the same estimators, so all of them would end up
    # predicting whichever target was fitted last. Clone per call instead.
    pipe = Pipeline([("pre", clone(pre)), ("reg", clone(reg))])

    if n < 2:
        # Too few rows to cross-validate; fit on whatever is available.
        model = pipe.fit(Xy, yy) if n >= 1 else pipe.fit(X, y_vec)
        return np.nan, np.nan, model

    # Group by article so curves from one article never straddle a split.
    groups = meta_full.loc[mask, "article_id"].astype(str) if "article_id" in meta_full.columns else None
    cv, kwargs = make_cv(n, groups)

    if cv is None:
        model = pipe.fit(Xy, yy)
        return np.nan, np.nan, model

    # cross_val_score fits clones of `pipe`, so `pipe` itself is still
    # unfitted until the final fit below.
    mae = -cross_val_score(pipe, Xy, yy, cv=cv,
                           scoring="neg_mean_absolute_error",
                           n_jobs=-1, **kwargs).mean()
    rmse = -cross_val_score(pipe, Xy, yy, cv=cv,
                            scoring="neg_root_mean_squared_error",
                            n_jobs=-1, **kwargs).mean()
    model = pipe.fit(Xy, yy)
    return mae, rmse, model

def _fmt_score(value):
    """Format a CV score to 4 decimals, or a dash when it is not finite."""
    if np.isfinite(value):
        return f"{value:.4f}"
    return "—"


# Fit one model per KWW parameter and report its cross-validated errors.
models = {}
print("\nCV (per-parameter):")
for name, y_vec in targets.items():
    mae, rmse, model = eval_target(y_vec)
    models[name] = model
    mae_s, rmse_s = _fmt_score(mae), _fmt_score(rmse)
    print(f"  {name:>7s}  MAE={mae_s}  RMSE={rmse_s}")

# -------------------------------------------------------------------
# 6) Reconstruct some curves from predicted parameters
# -------------------------------------------------------------------
def predict_params_for_row(xrow: pd.Series) -> dict:
    """Predict the KWW parameters for a single metadata row.

    J0 and Jinf are clipped to [0, Y_UB], beta to [0.05, 3.0]; tau is
    recovered from the log-scale model and floored at 1e-6.
    """
    xdf = pd.DataFrame([xrow], columns=X.columns)

    def _raw(target_name):
        # First (and only) prediction of the model stored under this name.
        return models[target_name].predict(xdf)[0]

    j0_pred = float(np.clip(_raw("J0"), 0.0, Y_UB))
    jinf_pred = float(np.clip(_raw("Jinf"), 0.0, Y_UB))
    tau_pred = float(np.exp(_raw("log_tau")))
    beta_pred = float(np.clip(_raw("beta"), 0.05, 3.0))
    return {
        "J0": j0_pred,
        "Jinf": jinf_pred,
        "tau": max(tau_pred, 1e-6),
        "beta": beta_pred,
    }

# Plot a few random curves
# Draw up to four randomly chosen curves, one subplot each.
example_rows = meta_full[id_cols].sample(min(4, len(meta_full)), random_state=42)
n_examples = len(example_rows)
fig, axes = plt.subplots(n_examples, 1, figsize=(6, 3*n_examples))
if n_examples == 1:
    # A single subplot comes back as one Axes object, not a sequence.
    axes = [axes]

for ax, (_, rec) in zip(axes, example_rows.iterrows()):
    a = rec["article_id"]
    c = rec["curve_id"]
    k = rec["conditionid"]

    # Measured points of this curve, ordered by time.
    in_long = (
        (long_use["article_id"] == a)
        & (long_use["curve_id"] == c)
        & (long_use["conditionid"] == k)
    )
    g = long_use.loc[in_long, [t_col, y_col]].dropna().sort_values(t_col)
    if g.empty:
        continue

    # Feature row of the same curve, fed to the parameter models.
    in_meta = (
        (meta_full["article_id"] == a)
        & (meta_full["curve_id"] == c)
        & (meta_full["conditionid"] == k)
    )
    xr = meta_full.loc[in_meta, present_num + present_cat].iloc[0]

    p = predict_params_for_row(xr)

    # Evaluate the predicted curve on a dense grid over the observed range.
    t_lo, t_hi = float(g[t_col].min()), float(g[t_col].max())
    tt = np.linspace(t_lo, t_hi, 200)
    yp = kww(tt, p["J0"], p["Jinf"], p["tau"], p["beta"])

    ax.plot(g[t_col], g[y_col], "o", ms=4, lw=1, label="true")
    ax.plot(tt, yp, lw=2, label="pred")
    ax.set_title(
        f"{c} | cond={k} | "
        f"J0={p['J0']:.3f}, Jinf={p['Jinf']:.3f}, tau={p['tau']:.3g}, beta={p['beta']:.2f}"
    )
    ax.set_xlabel("t_min")
    ax.set_ylabel("NormFlux")
    ax.legend()

plt.tight_layout()
plt.show()

# -------------------------------------------------------------------
# 7) Curve-level MAE (in-sample, using predicted params)
# -------------------------------------------------------------------
def curve_mae_triple(a, c, k):
    """Return the MAE between one measured curve and its reconstruction.

    The curve is identified by (article_id, curve_id, conditionid) =
    (a, c, k). Its KWW parameters are predicted from the matching metadata
    row and evaluated at the measured time points. Returns NaN when the
    curve has no usable points.
    """
    in_long = (
        (long_use["article_id"] == a)
        & (long_use["curve_id"] == c)
        & (long_use["conditionid"] == k)
    )
    points = long_use.loc[in_long, [t_col, y_col]].dropna().sort_values(t_col)
    if points.empty:
        return np.nan

    in_meta = (
        (meta_full["article_id"] == a)
        & (meta_full["curve_id"] == c)
        & (meta_full["conditionid"] == k)
    )
    feature_row = meta_full.loc[in_meta, present_num + present_cat].iloc[0]
    params = predict_params_for_row(feature_row)

    observed = points[y_col].values
    reconstructed = kww(
        points[t_col].values,
        params["J0"], params["Jinf"], params["tau"], params["beta"],
    )
    return mean_absolute_error(observed, reconstructed)

# Reconstruction error for every curve that has both a fit and metadata.
curve_errs = pd.Series(
    [
        curve_mae_triple(art, crv, cond)
        for art, crv, cond in zip(
            meta_full["article_id"],
            meta_full["curve_id"],
            meta_full["conditionid"],
        )
    ],
    name="curve_MAE",
)
print("\nCurve-level reconstruction MAE (in-sample, using predicted params):")
print(curve_errs.describe())
