In [None]:
# =========================================
# 🧱 TILE LINE — VARIOUS PREDICTIONS + ACCURACY + STAGE FLAGS + SHAP XAI
#
# ✅ What this does
# - Trains ML models (RandomForest) to predict stage ratios:
#     press_ratio = vol_press_out  / vol_start
#     glaze_ratio = vol_glaze_out  / vol_press_out
#     kiln_ratio  = vol_kiln_out   / vol_glaze_out
# - Creates a TEST set with mixed starting points (Start / Press / Glaze / Kiln / Full).
# - Predicts ONLY forward-missing stages.
# - Builds per-stage “minimum required” = (predicted ratio - 0.03), clipped to [0,1].
# - Hard rule for Sort ONLY: minimum ratio = 0.90 (no ML for Sort).
# - Flags rows where a *predicted* stage falls below its min (e.g., LOW_KILN_YIELD).
# - Optionally explains flagged *ML* stages with SHAP (top negative contributors).
#
# 📦 Artifacts written:
#   - tiles_staged_test_input.csv          (masked test)
#   - tiles_staged_test_predictions.csv    (filled predictions)
#   - tiles_predictions_friendly.csv       (human-friendly per-row report)
#   - tiles_predictions_compact.txt/.csv   (one-line summaries)
#   - tiles_predictions_eval.csv           (per-stage eval vs ACTUAL)
#   - tiles_flags_only.csv                 (rows with any stage flag, NO hardcoded spec flags)
#   - tiles_predictions_explanations.csv   (SHAP top negative features, if shap installed)
#   - tiles_predictions_explanations.json  (same in JSON, if shap installed)
#   - tile_models_bundle.joblib            (bundle to load in your UI service)
# =========================================

import os, json
from datetime import datetime

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error

# ---------- SHAP (Explainability) ----------
# SHAP is optional. If importing it fails for any reason, the XAI section is
# skipped: `shap` is left as None and HAS_SHAP is False.
try:
    import shap
except Exception:
    print("[WARN] SHAP not installed; skipping XAI. Install with: pip install shap")
    shap = None
    HAS_SHAP = False
else:
    HAS_SHAP = True

# -----------------------
# CONFIG
# -----------------------
DATA_PATH = "tiles6.csv"   # <--- your input file
RANDOM_STATE = 42          # seed for the split, the forests and the masking RNG
TEST_SIZE = 0.15           # fraction of rows held out for the staged test

# How far along the line each test row is already known (probabilities sum to 1.0):
#   index 0 = Start only, 1 = up to Press, 2 = up to Glaze, 3 = up to Kiln, 4 = Full known
MASK_PROBS = [0.30, 0.25, 0.25, 0.15, 0.05]

# Sort is rule-based (no ML): the predicted ratio is SORT_DEFAULT clipped into
# [SORT_MIN, SORT_MAX], and SORT_MIN is the only hard-coded minimum.
SORT_MIN, SORT_DEFAULT, SORT_MAX = 0.90, 0.90, 0.92

# ML stages: minimum required ratio = predicted ratio - THRESHOLD_MARGIN (clipped to [0, 1]).
THRESHOLD_MARGIN = 0.03

# Output artifacts
OUT_MASKED = "tiles_staged_test_input.csv"
OUT_PRED = "tiles_staged_test_predictions.csv"
OUT_FRIENDLY = "tiles_predictions_friendly.csv"
OUT_COMPACT_C = "tiles_predictions_compact.csv"
OUT_COMPACT_T = "tiles_predictions_compact.txt"
OUT_EVAL = "tiles_predictions_eval.csv"
FLAGS_ONLY_OUT = "tiles_flags_only.csv"
XAI_OUT_CSV = "tiles_predictions_explanations.csv"
XAI_OUT_JSON = "tiles_predictions_explanations.json"
BUNDLE_PATH = "tile_models_bundle.joblib"

# -----------------------
# HELPERS
# -----------------------
def safe_ratio(num, den):
    """Return ``num / den`` clipped to [0, 1], with NaN where *den* is missing or <= 0.

    Inputs may be scalars, arrays or Series; both are coerced with
    ``pd.to_numeric(errors="coerce")``, so non-numeric values become NaN.
    Returns a NumPy array (0-d / NumPy scalar for scalar inputs).
    """
    num = pd.to_numeric(num, errors="coerce")
    den = pd.to_numeric(den, errors="coerce")
    # np.where evaluates the division for every element, including the ones it
    # then discards. Silence the divide-by-zero / 0/0 warnings those produce.
    # np.divide (rather than `/`) also keeps plain Python scalars from raising
    # ZeroDivisionError.
    with np.errstate(divide="ignore", invalid="ignore"):
        invalid = (den <= 0) | pd.isna(den)
        r = np.where(invalid, np.nan, np.divide(num, den))
    return np.clip(r, 0.0, 1.0)

def clip01(x):
    """Clamp *x* element-wise into [0, 1]; NaN values propagate unchanged."""
    lower_clamped = np.maximum(0.0, x)
    return np.minimum(1.0, lower_clamped)

def volume_from_ratio(prev_vol, ratio, round_to_int=True):
    """Scale the upstream volume by *ratio*.

    With ``round_to_int`` (default) the result is rounded with ``np.rint``
    (round-half-to-even); otherwise the raw product is returned.
    """
    scaled = prev_vol * ratio
    if not round_to_int:
        return scaled
    return np.rint(scaled)

def as_int_nullable(values):
    """Round *values* to the nearest integer as a pandas nullable ``Int64`` Series.

    Non-numeric or missing entries become ``<NA>``. A Series input keeps its index.
    """
    numeric = pd.to_numeric(pd.Series(values), errors="coerce")
    return numeric.round().astype("Int64")

def met_min_str(pred_mask, actual_series, minreq_series):
    """Label predicted rows "YES"/"NO" by whether actual >= minimum required.

    Rows outside *pred_mask*, or missing either the actual or the minimum,
    get "". The result is an object Series indexed like *actual_series*.
    """
    actual = pd.to_numeric(actual_series, errors="coerce").astype(float).to_numpy()
    minimum = pd.to_numeric(minreq_series, errors="coerce").astype(float).to_numpy()
    selected = np.asarray(pred_mask, dtype=bool)

    comparable = selected & ~np.isnan(actual) & ~np.isnan(minimum)
    labels = np.where(actual >= minimum, "YES", "NO").astype(object)
    labels[~comparable] = ""
    return pd.Series(labels, index=actual_series.index)

def fmt_int(x):
    """Format *x* as a rounded integer with thousands separators, or "-" if not numeric/missing."""
    try:
        number = pd.to_numeric(x)
        if not pd.isna(number):
            return "{:,}".format(int(round(float(number))))
    except Exception:
        return "-"
    return "-"

# -----------------------
# HARD-CODED RECIPE SETPOINTS (memory for features only)
# -----------------------
# Field order of every setpoint tuple below.
_SETPOINT_FIELDS = (
    "k_set_max_temp", "k_set_cooling", "k_set_moisture", "k_set_humidity",
    "k_set_air_flow", "k_set_air_cooling", "k_set_thickness_mm",
)
RECIPE_SETPOINTS = {
    recipe: dict(zip(_SETPOINT_FIELDS, values))
    for recipe, values in (
        (241, (1197.0, 5.2, 6.0, 50.0, 6.0, 5.0, 8.5)),
        (244, (1195.0, 5.0, 6.0, 48.0, 5.8, 5.1, 8.5)),
        (245, (1196.0, 5.3, 6.0, 52.0, 6.2, 4.9, 8.5)),
        (246, (1194.0, 4.9, 6.0, 47.0, 5.7, 5.2, 8.5)),
        (247, (1198.0, 5.1, 6.0, 49.0, 6.1, 5.0, 8.5)),
    )
}

# -----------------------
# 1) LOAD
# -----------------------
df = pd.read_csv(DATA_PATH, parse_dates=["datetime"])
print(f"Loaded {len(df)} rows from {DATA_PATH}")

# Columns the pipeline reads directly; fail fast if any is absent.
required_cols = [
    "datetime", "vol_start",
    "pressure_psi", "vol_press_out",
    "recipe_id", "vol_glaze_out",
    "max_temp", "cooling_profile", "moisture_pct", "external_humidity",
    "air_flow_top_setting", "air_cooling", "thickness_mm", "vol_kiln_out",
    "vol_sort_out",
]
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing required columns: {missing}")

# enforce monotonic safety: a stage never outputs more than its upstream stage.
# Clamped in line order, so each stage is compared to its already-clamped input.
_VOLUME_CHAIN = ["vol_start", "vol_press_out", "vol_glaze_out", "vol_kiln_out", "vol_sort_out"]
for _upstream, _downstream in zip(_VOLUME_CHAIN[:-1], _VOLUME_CHAIN[1:]):
    df[_downstream] = np.minimum(df[_downstream], df[_upstream])

# attach setpoints (left merge: recipe_ids without setpoints get NaN) + deltas
sp = (
    pd.DataFrame.from_dict(RECIPE_SETPOINTS, orient="index")
    .reset_index()
    .rename(columns={"index": "recipe_id"})
)
df = df.merge(sp, on="recipe_id", how="left")

# delta_* = measured value - recipe setpoint
_DELTA_SOURCES = {
    "delta_max_temp": ("max_temp", "k_set_max_temp"),
    "delta_cooling": ("cooling_profile", "k_set_cooling"),
    "delta_moisture": ("moisture_pct", "k_set_moisture"),
    "delta_humidity": ("external_humidity", "k_set_humidity"),
    "delta_air_flow": ("air_flow_top_setting", "k_set_air_flow"),
    "delta_air_cooling": ("air_cooling", "k_set_air_cooling"),
}
for _delta, (_measured, _setpoint) in _DELTA_SOURCES.items():
    df[_delta] = df[_measured] - df[_setpoint]

# targets: per-stage yield = stage output / stage input (via safe_ratio)
_RATIO_SOURCES = {
    "press_ratio": ("vol_press_out", "vol_start"),
    "glaze_ratio": ("vol_glaze_out", "vol_press_out"),
    "kiln_ratio": ("vol_kiln_out", "vol_glaze_out"),
    "sort_ratio": ("vol_sort_out", "vol_kiln_out"),
}
for _ratio, (_num, _den) in _RATIO_SOURCES.items():
    df[_ratio] = safe_ratio(df[_num], df[_den])

# -----------------------
# 2) FEATURE SETS
# -----------------------
PRESS_FEATS = ["vol_start", "pressure_psi", "recipe_id"]
GLAZE_FEATS = ["recipe_id", "vol_press_out", "pressure_psi"]

# Kiln sees the incoming volume plus setpoints, measurements and their deltas.
_KILN_SETPOINTS = [f"k_set_{name}" for name in (
    "max_temp", "cooling", "moisture", "humidity", "air_flow", "air_cooling", "thickness_mm")]
_KILN_MEASURED = [
    "max_temp", "cooling_profile", "moisture_pct", "external_humidity",
    "air_flow_top_setting", "air_cooling", "thickness_mm",
]
_KILN_DELTAS = [f"delta_{name}" for name in (
    "max_temp", "cooling", "moisture", "humidity", "air_flow", "air_cooling")]
KILN_FEATS = ["recipe_id", "vol_glaze_out"] + _KILN_SETPOINTS + _KILN_MEASURED + _KILN_DELTAS

# ML target column -> feature columns used to predict it.
STAGE_TO_FEATS = dict(
    press_ratio=PRESS_FEATS,
    glaze_ratio=GLAZE_FEATS,
    kiln_ratio=KILN_FEATS,
)

def make_pipe(feat_cols):
    """Build the per-stage model: preprocessing followed by a random forest.

    ``recipe_id`` (if present in *feat_cols*) is one-hot encoded, ignoring
    unseen categories. Every other feature is median-imputed. Columns not in
    *feat_cols* are dropped. Step names "pre"/"rf" are relied on by the SHAP section.
    """
    categorical = [col for col in feat_cols if col == "recipe_id"]
    numeric = [col for col in feat_cols if col not in categorical]

    preprocessing = ColumnTransformer(
        transformers=[
            ("cat", OneHotEncoder(handle_unknown="ignore"), categorical),
            ("num", SimpleImputer(strategy="median"), numeric),
        ],
        remainder="drop",
    )
    forest = RandomForestRegressor(n_estimators=500, random_state=RANDOM_STATE, n_jobs=-1)
    steps = [("pre", preprocessing), ("rf", forest)]
    return Pipeline(steps)

# -----------------------
# 3) TRAIN / TEST SPLIT
# -----------------------
train_df, test_df = train_test_split(df, test_size=TEST_SIZE, random_state=RANDOM_STATE, shuffle=True)
train_df, test_df = train_df.reset_index(drop=True), test_df.reset_index(drop=True)
print(f"Train: {len(train_df)} | Test: {len(test_df)}")

# -----------------------
# 4) FIT MODELS
# -----------------------
models = {}
for target, feats in STAGE_TO_FEATS.items():
    labelled = train_df[target].notna()          # train only on rows where the ratio is known
    X_fit = train_df.loc[labelled, feats]
    y_fit = train_df.loc[labelled, target]
    models[target] = make_pipe(feats).fit(X_fit, y_fit)
    # In-sample (training-set) fit quality, clipped like the predictions used later.
    y_hat = clip01(models[target].predict(X_fit))
    print(f"[FIT] {target}: R2={r2_score(y_fit, y_hat):.3f} MAE={mean_absolute_error(y_fit, y_hat):.4f} n={len(X_fit)}")

# -----------------------
# 5) STAGED MASKING (various prediction starting points)
# -----------------------
rng = np.random.default_rng(RANDOM_STATE)
levels = rng.choice([0,1,2,3,4], size=len(test_df), p=MASK_PROBS)
masked = test_df.copy()
masked["known_level"] = levels

# Downstream volumes hidden for each known level (level 4 = everything known).
_HIDDEN_BY_LEVEL = {
    0: ["vol_press_out", "vol_glaze_out", "vol_kiln_out", "vol_sort_out"],
    1: ["vol_glaze_out", "vol_kiln_out", "vol_sort_out"],
    2: ["vol_kiln_out", "vol_sort_out"],
    3: ["vol_sort_out"],
    4: [],
}
for i, lvl in enumerate(levels):
    hidden = _HIDDEN_BY_LEVEL[int(lvl)]
    if hidden:
        # masked has a RangeIndex (test_df was reset), so label i == position i
        masked.loc[i, hidden] = np.nan

# Unmasked copy of the volumes, used later to score predictions.
truth = test_df[["vol_start","vol_press_out","vol_glaze_out","vol_kiln_out","vol_sort_out"]].copy()

_START_LABELS = {0: "Start", 1: "Press", 2: "Glaze", 3: "Kiln", 4: "Full"}
print("\nMasked start-points (test):")
# Report in pipeline order (sorting the string labels put them alphabetically)
# and include levels that drew no rows as 0.
_start_counts = pd.Series(levels).value_counts().reindex(list(_START_LABELS), fill_value=0)
_start_counts.index = [_START_LABELS[k] for k in _start_counts.index]
print(_start_counts.to_string())

# -----------------------
# 6) PREDICT ONLY MISSING FORWARD
# -----------------------
pred = masked.copy()

def predict_stage(mdf, stage_key, feat_cols, prev_col, out_col):
    """Fill *out_col* for rows where it is missing but the upstream *prev_col* is known.

    For those rows it writes, in place on *mdf*:
      - ``{stage_key}_pred``: model ratio clipped to [0, 1]
      - ``{stage_key}_min``:  that ratio minus THRESHOLD_MARGIN, clipped to [0, 1]
      - ``{out_col}_pred``:   prev_col * ratio, rounded via volume_from_ratio
      - ``out_col``:          copied from ``{out_col}_pred``
    The three ``*_pred`` / ``*_min`` columns are always created (all-NaN when no
    row needs this stage), so later sections that index them directly
    (friendly report, compact summary) do not raise KeyError.

    Returns *mdf*.
    """
    # Create the columns in the same order the assignments below would.
    for col in (f"{stage_key}_pred", f"{stage_key}_min", f"{out_col}_pred"):
        if col not in mdf.columns:
            mdf[col] = np.nan

    need = mdf[out_col].isna() & mdf[prev_col].notna()
    if need.any():
        r = clip01(models[stage_key].predict(mdf.loc[need, feat_cols]))
        mdf.loc[need, f"{stage_key}_pred"] = r
        mdf.loc[need, f"{stage_key}_min"] = clip01(r - THRESHOLD_MARGIN)
        mdf.loc[need, f"{out_col}_pred"] = volume_from_ratio(mdf.loc[need, prev_col].values, r, round_to_int=True)
        mdf.loc[need, out_col] = mdf.loc[need, f"{out_col}_pred"]
    return mdf

# Chained forward fill. Each predicted stage feeds the next one's input,
# and each stage is re-clamped so it never exceeds its upstream volume.
# Press
pred = predict_stage(pred, "press_ratio", PRESS_FEATS, prev_col="vol_start",     out_col="vol_press_out")
pred["vol_press_out"] = np.minimum(pred["vol_press_out"], pred["vol_start"])
# Glaze
pred = predict_stage(pred, "glaze_ratio", GLAZE_FEATS, prev_col="vol_press_out", out_col="vol_glaze_out")
pred["vol_glaze_out"] = np.minimum(pred["vol_glaze_out"], pred["vol_press_out"])
# Kiln
pred = predict_stage(pred, "kiln_ratio",  KILN_FEATS,  prev_col="vol_glaze_out", out_col="vol_kiln_out")
pred["vol_kiln_out"]  = np.minimum(pred["vol_kiln_out"],  pred["vol_glaze_out"])

# Sort (rule, only hard-coded threshold here)
# Create the Sort output columns up front, so later sections can index them even
# when no row needs a Sort prediction.
for _col in ("sort_ratio_pred", "sort_ratio_min", "vol_sort_out_pred"):
    if _col not in pred.columns:
        pred[_col] = np.nan
need_sort = pred["vol_sort_out"].isna() & pred["vol_kiln_out"].notna()
if need_sort.any():
    r = np.clip(np.full(need_sort.sum(), SORT_DEFAULT), SORT_MIN, SORT_MAX)
    pred.loc[need_sort, "sort_ratio_pred"] = r
    pred.loc[need_sort, "sort_ratio_min"]  = SORT_MIN
    # Round UP, not to nearest. Rounding kiln * 0.90 to nearest often lands one
    # unit under the 0.90 floor (kiln=1236 -> 1112.4 -> 1112 -> ratio 0.8997).
    # That tripped SORT_BELOW_90 on rows predicted exactly at the minimum.
    # The 6-decimal pre-round absorbs float noise when the product is
    # mathematically a whole number.
    kiln_vol = pred.loc[need_sort, "vol_kiln_out"].to_numpy(dtype=float)
    pred.loc[need_sort, "vol_sort_out_pred"] = np.ceil(np.round(kiln_vol * r, 6))
    pred.loc[need_sort, "vol_sort_out"] = pred.loc[need_sort, "vol_sort_out_pred"]
pred["vol_sort_out"] = np.minimum(pred["vol_sort_out"], pred["vol_kiln_out"])

# -----------------------
# 7) ACCURACY (predicted rows only)
# -----------------------
def stage_metrics(stage_out, base_col, pred_out_col):
    """R² / MAE of the stage ratio (stage_out / base_col) over rows where *pred_out_col* was filled.

    The true ratio comes from ``truth``. The predicted ratio comes from ``pred``,
    whose base volume may itself be predicted. Rows where either ratio is NaN
    (missing volume or non-positive base) are excluded, because sklearn's
    metrics raise on NaN input.

    Returns ``dict(n, R2, MAE)``; ``n=0`` with NaN metrics when nothing is scoreable.
    """
    if pred_out_col not in pred.columns:
        return dict(n=0, R2=np.nan, MAE=np.nan)
    tr = safe_ratio(truth[stage_out].to_numpy(), truth[base_col].to_numpy())
    pr = safe_ratio(pred[stage_out].to_numpy(), pred[base_col].to_numpy())
    mask = pred[pred_out_col].notna().to_numpy() & ~np.isnan(tr) & ~np.isnan(pr)
    if not mask.any():
        return dict(n=0, R2=np.nan, MAE=np.nan)
    return dict(n=int(mask.sum()),
                R2=float(r2_score(tr[mask], pr[mask])),
                MAE=float(mean_absolute_error(tr[mask], pr[mask])))

print("\n=== ACCURACY (predicted rows only) ===")
print("Press:", stage_metrics("vol_press_out", "vol_start",     "vol_press_out_pred"))
print("Glaze:", stage_metrics("vol_glaze_out", "vol_press_out", "vol_glaze_out_pred"))
print("Kiln :", stage_metrics("vol_kiln_out",  "vol_glaze_out", "vol_kiln_out_pred"))

# -----------------------
# 8) REALIZED RATIOS & STAGE FLAGS
# -----------------------
# (stage, output volume, upstream volume) in line order
_REALIZED_STAGES = (
    ("press", "vol_press_out", "vol_start"),
    ("glaze", "vol_glaze_out", "vol_press_out"),
    ("kiln",  "vol_kiln_out",  "vol_glaze_out"),
    ("sort",  "vol_sort_out",  "vol_kiln_out"),
)
for _stage, _out, _base in _REALIZED_STAGES:
    pred[f"{_stage}_ratio_real"] = safe_ratio(pred[_out], pred[_base])

# *_ratio_min may be absent for stages nothing was predicted for; ensure every stage has it.
for _stage, _out, _base in _REALIZED_STAGES:
    if f"{_stage}_ratio_min" not in pred.columns:
        pred[f"{_stage}_ratio_min"] = np.nan
# Sort's minimum is the fixed rule, applied wherever a Kiln volume exists.
pred.loc[pred["vol_kiln_out"].notna(), "sort_ratio_min"] = SORT_MIN

# ML stages are flagged only where a minimum exists (set in section 6 for predicted rows).
# NOTE(review): for those rows the realized ratio is computed from the predicted
# volume itself, while the min is predicted ratio - THRESHOLD_MARGIN. So these flags
# appear able to fire only through rounding/clamping; confirm this is intended.
for _stage in ("press", "glaze", "kiln"):
    _min = pred[f"{_stage}_ratio_min"]
    pred[f"{_stage}_flag"] = np.where(
        _min.notna() & (pred[f"{_stage}_ratio_real"] < _min),
        f"LOW_{_stage.upper()}_YIELD",
        "",
    )
# Sort is checked on every row against the hard-coded rule.
pred["sort_flag"] = np.where(pred["sort_ratio_real"] < SORT_MIN, "SORT_BELOW_90", "")

# -----------------------
# 9) FRIENDLY REPORT (known vs predicted + min + flag + correctness)
# -----------------------
LEVEL_LABELS = {0:"Start only",1:"Up to Press",2:"Up to Glaze",3:"Up to Kiln",4:"Full known"}

# "known" = present in the masked test input; "predicted" = filled in section 6.
press_known, glaze_known, kiln_known, sort_known = (
    masked[c].notna() for c in ("vol_press_out", "vol_glaze_out", "vol_kiln_out", "vol_sort_out")
)
press_pred, glaze_pred, kiln_pred, sort_pred = (
    pred[c].notna() for c in ("vol_press_out_pred", "vol_glaze_out_pred", "vol_kiln_out_pred", "vol_sort_out_pred")
)

# Minimum required volume = minimum ratio x upstream volume (which may itself be predicted).
press_min_vol_num = pred["press_ratio_min"] * pred["vol_start"]
glaze_min_vol_num = pred["glaze_ratio_min"] * pred["vol_press_out"]
kiln_min_vol_num  = pred["kiln_ratio_min"]  * pred["vol_glaze_out"]
sort_min_vol_num  = SORT_MIN * pred["vol_kiln_out"]

friendly = pd.DataFrame({
    "datetime": pred["datetime"],
    "recipe_id": pred["recipe_id"],
    "known_level": pred["known_level"],
    "known_level_label": pd.Series(pred["known_level"]).map(LEVEL_LABELS),
    "START_amount_ft2": as_int_nullable(pred["vol_start"]),
})


def _add_friendly_stage(report, prefix, out_col, known, predicted, min_vol):
    """Append the seven ``{prefix}_*`` columns for one stage to *report* (in place)."""
    report[f"{prefix}_status"] = np.where(known, "known", np.where(predicted, "predicted", "unknown"))
    report[f"{prefix}_known_amount_ft2"] = as_int_nullable(np.where(known, pred[out_col], np.nan))
    report[f"{prefix}_predicted_amount_ft2"] = as_int_nullable(np.where(predicted, pred[f"{out_col}_pred"], np.nan))
    report[f"{prefix}_min_required_ft2"] = as_int_nullable(np.where(predicted, min_vol, np.nan))
    report[f"{prefix}_flag"] = np.where(predicted, pred[f"{prefix.lower()}_flag"].fillna(""), "")
    report[f"{prefix}_actual_ft2"] = as_int_nullable(truth[out_col])
    report[f"{prefix}_met_min"] = met_min_str(
        predicted, report[f"{prefix}_actual_ft2"], report[f"{prefix}_min_required_ft2"]
    )


for _prefix, _out_col, _known, _predicted, _min_vol in (
    ("PRESS", "vol_press_out", press_known, press_pred, press_min_vol_num),
    ("GLAZE", "vol_glaze_out", glaze_known, glaze_pred, glaze_min_vol_num),
    ("KILN",  "vol_kiln_out",  kiln_known,  kiln_pred,  kiln_min_vol_num),
    ("SORT",  "vol_sort_out",  sort_known,  sort_pred,  sort_min_vol_num),   # SORT uses the rule minimum
):
    _add_friendly_stage(friendly, _prefix, _out_col, _known, _predicted, _min_vol)

friendly = friendly.sort_values("datetime").reset_index(drop=True)
friendly.to_csv(OUT_FRIENDLY, index=False)
print(f"\nWrote friendly per-row report:\n - {OUT_FRIENDLY}")

# -----------------------
# 10) COMPACT ONE-LINE SUMMARY PER ROW
# -----------------------
def min_volumes_row(row):
    """Return (press_min, glaze_min, kiln_min, sort_min) minimum volumes for one row.

    Each ML minimum is ``<stage>_ratio_min * upstream volume``; Sort uses
    ``SORT_MIN * vol_kiln_out``. Missing fields yield NaN.
    """
    ml_sources = (
        ("press_ratio_min", "vol_start"),
        ("glaze_ratio_min", "vol_press_out"),
        ("kiln_ratio_min", "vol_glaze_out"),
    )
    ml_mins = tuple(row.get(ratio_key, np.nan) * row.get(base_key, np.nan) for ratio_key, base_key in ml_sources)
    return ml_mins + (SORT_MIN * row.get("vol_kiln_out", np.nan),)

def stage_piece(name, known_val, pred_val, min_vol):
    """Render one stage of the compact summary.

    Known value wins; otherwise the predicted value with its minimum;
    otherwise "<name> = -".
    """
    if not pd.isna(known_val):
        return f"{name} = {fmt_int(known_val)}"
    if pd.isna(pred_val):
        return f"{name} = -"
    return f"{name} = {fmt_int(pred_val)} (pred), min {fmt_int(min_vol)}"

# "known" = present in the masked test input; "predicted" = filled in section 6.
press_known_mask = masked["vol_press_out"].notna()
glaze_known_mask = masked["vol_glaze_out"].notna()
kiln_known_mask  = masked["vol_kiln_out"].notna()
sort_known_mask  = masked["vol_sort_out"].notna()

# Per-row "was predicted" helpers.
# NOTE(review): these scratch columns stay on `pred` and so are also written
# to the predictions CSV later.
pred["__press_pred__"] = pred["vol_press_out_pred"].notna()
pred["__glaze_pred__"] = pred["vol_glaze_out_pred"].notna()
pred["__kiln_pred__"]  = pred["vol_kiln_out_pred"].notna()
pred["__sort_pred__"]  = pred["vol_sort_out_pred"].notna()

# (label, volume column, known mask, predicted-helper column) in line order
_COMPACT_STAGES = (
    ("Press", "vol_press_out", press_known_mask, "__press_pred__"),
    ("Glaze", "vol_glaze_out", glaze_known_mask, "__glaze_pred__"),
    ("Kiln",  "vol_kiln_out",  kiln_known_mask,  "__kiln_pred__"),
    ("Sort",  "vol_sort_out",  sort_known_mask,  "__sort_pred__"),
)

lines = []
rows = []
for i, row in pred.reset_index(drop=True).iterrows():
    pieces = [f"Start = {fmt_int(row['vol_start'])}"]
    for (label, out_col, known_mask, pred_flag_col), min_vol in zip(_COMPACT_STAGES, min_volumes_row(row)):
        pieces.append(stage_piece(
            label,
            known_val=pred.loc[i, out_col] if known_mask.iloc[i] else np.nan,
            pred_val=pred.loc[i, f"{out_col}_pred"] if pred[pred_flag_col].iloc[i] else np.nan,
            min_vol=min_vol,
        ))
    summary = "  ".join(pieces)
    lines.append(summary)
    rows.append({"datetime": row["datetime"], "recipe_id": row["recipe_id"], "known_level": row["known_level"], "summary": summary})

# Sort once and write both artifacts from the same frame, so the .txt lines are in
# the same datetime order as the .csv. Explicit columns keep this working with zero rows.
compact_df = pd.DataFrame(rows, columns=["datetime", "recipe_id", "known_level", "summary"]).sort_values("datetime")
with open(OUT_COMPACT_T, "w", encoding="utf-8") as f:
    f.writelines(s + "\n" for s in compact_df["summary"])
compact_df.to_csv(OUT_COMPACT_C, index=False)
print(f"\nWrote compact summaries:\n - {OUT_COMPACT_T}\n - {OUT_COMPACT_C}")

# -----------------------
# 11) PER-STAGE EVAL vs ACTUAL
# -----------------------
_EVAL_COLUMNS = [
    "stage", "datetime", "recipe_id",
    "predicted_amount_ft2", "min_required_ft2", "actual_amount_ft2",
    "abs_error_ft2", "pct_error", "met_min", "flag_result",
]

def stage_eval(stage, base_col, out_col, pred_out_col, min_ratio_col, fixed_min=None,
               pred_df=None, truth_df=None):
    """Compare predicted stage volumes with the actual (unmasked) volumes.

    Only rows where *pred_out_col* is filled are evaluated.
    - min_required_ft2 = min ratio x base volume. The min ratio is taken from
      *min_ratio_col*, or *fixed_min* when given (Sort rule).
    - met_min is "YES"/"NO" for actual >= min_required.
    - flag_result is "FLAG"/"OK" for actual ratio < min ratio.
    Both are "" when the comparison cannot be made (missing actual, missing
    minimum, or zero base), matching ``met_min_str`` in the friendly report.

    *pred_df* / *truth_df* default to the module-level ``pred`` / ``truth`` frames.
    Returns a DataFrame with ``_EVAL_COLUMNS`` (empty if nothing was predicted).
    """
    pdf = pred if pred_df is None else pred_df
    tdf = truth if truth_df is None else truth_df
    if pred_out_col not in pdf.columns:
        return pd.DataFrame(columns=_EVAL_COLUMNS)
    m = pdf[pred_out_col].notna()
    if not m.any():
        return pd.DataFrame(columns=_EVAL_COLUMNS)

    if fixed_min is None:
        min_ratio = pd.to_numeric(pdf.loc[m, min_ratio_col], errors="coerce")
    else:
        min_ratio = pd.Series(fixed_min, index=pdf.index)[m]
    base = pd.to_numeric(pdf.loc[m, base_col], errors="coerce")
    min_required = min_ratio * base

    pred_amt = pd.to_numeric(pdf.loc[m, pred_out_col], errors="coerce")
    actual_amt = pd.to_numeric(tdf.loc[m, out_col], errors="coerce")

    abs_err = (pred_amt - actual_amt).abs()
    pct_err = abs_err / actual_amt.replace({0: np.nan})
    actual_ratio = actual_amt / base.replace({0: np.nan})

    # Only judge rows where both sides of the comparison exist. Otherwise leave ""
    # rather than reporting a misleading "NO" / "OK".
    can_meet = (actual_amt.notna() & min_required.notna()).to_numpy()
    can_flag = (actual_ratio.notna() & min_ratio.notna()).to_numpy()
    met_min = np.where(can_meet, np.where(actual_amt >= min_required, "YES", "NO"), "")
    flag_result = np.where(can_flag, np.where(actual_ratio < min_ratio, "FLAG", "OK"), "")

    return pd.DataFrame({
        "stage": stage,
        "datetime": pdf.loc[m, "datetime"],
        "recipe_id": pdf.loc[m, "recipe_id"],
        "predicted_amount_ft2": pred_amt.round().astype("Int64"),
        "min_required_ft2": min_required.round().astype("Int64"),
        "actual_amount_ft2": actual_amt.round().astype("Int64"),
        "abs_error_ft2": abs_err.round().astype("Int64"),
        "pct_error": pct_err,
        "met_min": met_min,
        "flag_result": flag_result,
    })

# Evaluate every predicted stage. Sort uses the fixed rule minimum instead of a per-row ratio.
eval_press, eval_glaze, eval_kiln, eval_sort = (
    stage_eval(
        stage,
        base_col=base,
        out_col=out,
        pred_out_col=f"{out}_pred",
        min_ratio_col=f"{stage.lower()}_ratio_min",
        fixed_min=fixed,
    )
    for stage, base, out, fixed in (
        ("PRESS", "vol_start",     "vol_press_out", None),
        ("GLAZE", "vol_press_out", "vol_glaze_out", None),
        ("KILN",  "vol_glaze_out", "vol_kiln_out",  None),
        ("SORT",  "vol_kiln_out",  "vol_sort_out",  SORT_MIN),
    )
)

eval_all = pd.concat([eval_press, eval_glaze, eval_kiln, eval_sort], ignore_index=True)
eval_all.to_csv(OUT_EVAL, index=False)
print(f"\nWrote per-stage evaluation vs ACTUAL:\n - {OUT_EVAL}")

def summarize_stage(df_stage):
    """One-line summary of a ``stage_eval`` frame: n, MAE (ft²), MAPE (%) and compliance (%).

    Compliance is the share of "YES" among rows with a definite "YES"/"NO"
    verdict. Blank (undeterminable) verdicts are excluded from the denominator.
    Metrics with no usable values print as "nan". An empty frame returns "n=0".
    """
    if df_stage.empty:
        return "n=0"
    n = len(df_stage)
    # Cast to float so nullable Int64 columns reduce to NaN (not <NA>) when empty.
    mae = pd.to_numeric(df_stage["abs_error_ft2"], errors="coerce").astype(float).dropna()
    mape = pd.to_numeric(df_stage["pct_error"], errors="coerce").astype(float).dropna()
    verdicts = df_stage["met_min"]
    decided = verdicts.isin(["YES", "NO"])
    comp = (verdicts[decided] == "YES").mean() if decided.any() else np.nan
    return f"n={n}, MAE={mae.mean():.1f} ft², MAPE={(mape.mean()*100):.2f}%, compliance={comp*100:.1f}%"

print("\n=== RESULTS vs ACTUAL (predicted rows) ===")
for _label, _frame in (
    ("PRESS :", eval_press),
    ("GLAZE :", eval_glaze),
    ("KILN  :", eval_kiln),
    ("SORT  :", eval_sort),
):
    print(_label, summarize_stage(_frame))

# Also write the filled test & masked inputs for reference
pred.to_csv(OUT_PRED, index=False)
masked.to_csv(OUT_MASKED, index=False)
print(f"\nWrote:\n - {OUT_MASKED}\n - {OUT_PRED}")

# Quick friendly preview: identity columns, then the seven per-stage columns in line order.
print("\nFriendly preview (first 6 rows):")
cols = ["datetime", "recipe_id", "known_level_label", "START_amount_ft2"] + [
    f"{stage}_{field}"
    for stage in ("PRESS", "GLAZE", "KILN", "SORT")
    for field in (
        "status", "known_amount_ft2", "predicted_amount_ft2",
        "min_required_ft2", "flag", "actual_ft2", "met_min",
    )
]
print(friendly[cols].head(6).to_string(index=False))

# -----------------------
# 12) SHAP XAI (ONLY for ML stages; NO hardcoded spec flags)
# -----------------------
def get_ct_feature_names(preprocessor, n_features=None):
    """Return readable output feature names of a fitted ColumnTransformer.

    Strips the ``num__`` / ``cat__`` / ``pre__`` prefixes and renders one-hot
    columns as ``recipe_id=<value>``.

    If ``get_feature_names_out`` is unavailable or fails, returns
    ``f0 .. f{n_features-1}`` when *n_features* is given, otherwise an empty
    array. Callers can compare the length to the transformed width and fall
    back to generic names. The previous fallback called ``transform`` on an
    empty frame, which itself raised.
    """
    try:
        raw = preprocessor.get_feature_names_out()
    except Exception:
        # Best effort only: old sklearn versions or an unfitted transformer.
        return np.array([f"f{i}" for i in range(n_features or 0)])
    names = [str(n).replace("num__", "").replace("cat__", "").replace("pre__", "") for n in raw]
    names = [n.replace("recipe_id_", "recipe_id=") for n in names]
    return np.array(names)

def top_negative_shap(pipeline, X_df, topk=5):
    """Return per-row SHAP contributors sorted most-negative first, plus the explainer base value.

    *pipeline* must expose steps "pre" (the ColumnTransformer) and "rf" (the
    tree model). Returns ``(list_of_lists, base_value)``, where each inner list
    holds up to *topk* ``(feature_name, shap_value)`` pairs. When SHAP is not
    installed, returns one empty list per row and ``None``, without touching
    the pipeline.
    """
    if not HAS_SHAP:
        return [[] for _ in range(len(X_df))], None

    pre = pipeline.named_steps["pre"]
    rf = pipeline.named_steps["rf"]

    X_t = pre.transform(X_df)
    # The ColumnTransformer may return a sparse matrix; give TreeExplainer a dense one.
    if hasattr(X_t, "toarray"):
        X_t = X_t.toarray()

    explainer = shap.TreeExplainer(rf)
    shap_vals = np.asarray(explainer.shap_values(X_t))  # (n_rows, n_features)
    # Depending on the shap release, expected_value is a scalar or a length-1 array.
    base_val = float(np.ravel(explainer.expected_value)[0])

    feat_names = get_ct_feature_names(pre)
    if len(feat_names) != X_t.shape[1]:
        feat_names = np.array([f"f{i}" for i in range(X_t.shape[1])])

    out = []
    for sv in shap_vals:
        most_negative = np.argsort(sv)[:topk]  # ascending: most negative first
        out.append([(feat_names[j], float(sv[j])) for j in most_negative])
    return out, base_val

# collect explanations for flagged *predicted* rows
xai_records = []

def explain_stage(stage_key, feat_cols, ratio_min_col, flag_col, out_pred_col, base_col):
    """Append SHAP explanations to ``xai_records`` for rows where *stage_key* was predicted AND flagged.

    Nothing is recorded when SHAP is unavailable. Previously, records with
    empty feature lists were still collected, so "if shap installed" files got
    written anyway.

    Each record's ``actual_ratio`` / ``min_ratio`` is None when undefined
    (instead of NaN), so the JSON artifact stays valid JSON.
    """
    if not HAS_SHAP or out_pred_col not in pred.columns:
        return
    flagged = pred[out_pred_col].notna() & (pred[flag_col] != "")
    idx = pred.index[flagged]
    if len(idx) == 0:
        return

    stage = stage_key.replace("_ratio", "")     # e.g. "press"
    actual_col = f"vol_{stage}_out"              # true output volume for this stage
    pairs_list, base_val = top_negative_shap(models[stage_key], pred.loc[idx, feat_cols].copy(), topk=5)

    for row_idx, pairs in zip(idx, pairs_list):
        min_val = pred.at[row_idx, ratio_min_col]
        actual = float(safe_ratio(truth.at[row_idx, actual_col], truth.at[row_idx, base_col]))
        xai_records.append({
            "stage": stage.upper(),
            "datetime": str(pred.at[row_idx, "datetime"]),
            "recipe_id": int(pred.at[row_idx, "recipe_id"]),
            "known_level": int(pred.at[row_idx, "known_level"]),
            "predicted_ratio": float(pred.at[row_idx, f"{stage_key}_pred"]),
            "min_ratio": None if pd.isna(min_val) else float(min_val),
            "actual_ratio": None if np.isnan(actual) else actual,
            "flag": str(pred.at[row_idx, flag_col]),
            "top_neg": [{"feature": f, "shap": v} for (f, v) in pairs],
            "base_value": base_val,
        })

# explain ML stages only (press/glaze/kiln)
explain_stage("press_ratio", PRESS_FEATS, "press_ratio_min", "press_flag", "vol_press_out_pred", "vol_start")
explain_stage("glaze_ratio", GLAZE_FEATS, "glaze_ratio_min", "glaze_flag", "vol_glaze_out_pred", "vol_press_out")
explain_stage("kiln_ratio",  KILN_FEATS,  "kiln_ratio_min",  "kiln_flag",  "vol_kiln_out_pred",  "vol_glaze_out")

# Write the XAI artifacts only when at least one explanation record exists.
if xai_records:
    _XAI_BASE_KEYS = (
        "stage", "datetime", "recipe_id", "known_level",
        "predicted_ratio", "min_ratio", "actual_ratio", "flag",
    )
    rows_csv = []
    for r in xai_records:
        flat = {key: r[key] for key in _XAI_BASE_KEYS}
        # Flatten the top-5 negative contributors into neg1..neg5 columns, padding with "".
        top = list(r["top_neg"][:5])
        top += [None] * (5 - len(top))
        for rank, entry in enumerate(top, start=1):
            flat[f"neg{rank}_feature"] = entry["feature"] if entry is not None else ""
            flat[f"neg{rank}_shap"] = entry["shap"] if entry is not None else ""
        rows_csv.append(flat)
    pd.DataFrame(rows_csv).to_csv(XAI_OUT_CSV, index=False)
    with open(XAI_OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(xai_records, f, indent=2)
    print(f"\nWrote SHAP explanations:\n - {XAI_OUT_CSV}\n - {XAI_OUT_JSON}")
else:
    print("\n[INFO] No flagged predicted ML rows or SHAP unavailable → no XAI files written.")

# -----------------------
# 13) FLAGS-ONLY OUTPUT (NO hardcoded spec flags; ONLY stage flags + Sort@90)
# -----------------------
stage_cols = ["press_flag","glaze_flag","kiln_flag","sort_flag"]
any_flag_mask = (pred[stage_cols] != "").any(axis=1)

flags_cols_order = [
    "datetime","recipe_id","known_level",
    # stage results
    "vol_start","vol_press_out","press_ratio_real","press_ratio_min","press_flag",
    "vol_glaze_out","glaze_ratio_real","glaze_ratio_min","glaze_flag",
    "vol_kiln_out","kiln_ratio_real","kiln_ratio_min","kiln_flag",
    "vol_sort_out","sort_ratio_real","sort_ratio_min","sort_flag",
]
flags_cols_order = [c for c in flags_cols_order if c in pred.columns]

flags_only = pred.loc[any_flag_mask, flags_cols_order].sort_values("datetime").reset_index(drop=True)

# If we produced SHAP, attach a compact column per row (top 3 neg features per stage)
def topneg_str_for_row(dt, rid, stage, records=None, top_k=3):
    """Format the top negative SHAP contributors for one row and stage.

    Scans the SHAP records in order and uses the FIRST one whose ``stage``,
    ``datetime`` and ``recipe_id`` all match; its first ``top_k`` ``top_neg``
    entries are rendered as ``feature(+0.1234)`` and joined with ", ".

    Parameters
    ----------
    dt : object
        Row timestamp; matched against ``rec["datetime"]`` as ``str(dt)``.
    rid : int-like
        Recipe id; matched against ``rec["recipe_id"]`` as ``int(rid)``.
    stage : str
        Stage label as stored in the records (e.g. "PRESS", "GLAZE", "KILN").
    records : list of dict, optional
        Records to search. Defaults to the module-level ``xai_records``.
    top_k : int, default 3
        Maximum number of contributors to include.

    Returns
    -------
    str
        The formatted contributors, or "" when no record matches or ``rid``
        cannot be converted to int (e.g. NaN / missing recipe id).
    """
    if records is None:
        records = xai_records
    dt_key = str(dt)
    try:
        rid_key = int(rid)
    except (TypeError, ValueError):
        # NaN / non-numeric ids cannot match an int recipe_id; don't abort the
        # whole DataFrame.apply over one bad row.
        return ""
    # Linear scan per call; first match wins.
    for rec in records:
        if rec["stage"] == stage and rec["datetime"] == dt_key and rec["recipe_id"] == rid_key:
            parts = [f'{p["feature"]}({p["shap"]:+.4f})' for p in rec["top_neg"][:top_k]]
            return ", ".join(parts)
    return ""

# Attach one "<STAGE>_topneg" column per ML stage, but only when there are
# flagged rows AND SHAP records to draw from. A stage whose flag column is
# absent gets a constant "" column so the schema is the same for all stages.
have_flagged_rows = len(flags_only) > 0
have_xai = len(xai_records) > 0
if have_flagged_rows and have_xai:
    for stage_name in ("PRESS", "GLAZE", "KILN"):
        target_col = f"{stage_name}_topneg"
        if f"{stage_name.lower()}_flag" not in flags_only.columns:
            flags_only[target_col] = ""
            continue
        flags_only[target_col] = flags_only.apply(
            lambda row, s=stage_name: topneg_str_for_row(row["datetime"], row["recipe_id"], s),
            axis=1,
        )

flags_only.to_csv(FLAGS_ONLY_OUT, index=False)
print(f"\nWrote flags-only rows:\n - {FLAGS_ONLY_OUT} (rows with any stage flag)")
print(f"Flagged rows: {len(flags_only)} of {len(pred)} total")
if have_flagged_rows:
    print("\nFlags-only preview (first 8):")
    print(flags_only.head(8).to_string(index=False))

# -----------------------
# 14) PRINT-ONLY REPORTER (console summary)
# -----------------------
def exists(path):
    """Report whether *path* exists, printing a "[WARN]" line when it does not.

    Returns the boolean result of ``os.path.exists(path)``.
    """
    if os.path.exists(path):
        return True
    print(f"[WARN] File not found: {path}")
    return False

def print_header(title):
    """Print *title* between two '=' rules of the same width, after a blank line."""
    rule = "=" * len(title)
    print(f"\n{rule}\n{title}\n{rule}")

# Short aliases for the artifact paths; the report below re-reads them from disk.
MASKED      = OUT_MASKED
PREDICTED   = OUT_PRED
FRIENDLY    = OUT_FRIENDLY
COMPACT_TXT = OUT_COMPACT_T
EVAL_PATH   = OUT_EVAL
FLAGS_PATH  = FLAGS_ONLY_OUT

def _load_report_csv(path):
    """Read *path* with its 'datetime' column parsed, or return None (after a warning) if missing."""
    if not exists(path):
        return None
    return pd.read_csv(path, parse_dates=["datetime"])

masked_o, pred_o, friendly_o, evaldf_o, flags_o = [
    _load_report_csv(path) for path in (MASKED, PREDICTED, FRIENDLY, EVAL_PATH, FLAGS_PATH)
]

print_header("FILES SNAPSHOT")
snapshot_paths = [MASKED, PREDICTED, FRIENDLY, COMPACT_TXT, EVAL_PATH, FLAGS_PATH, XAI_OUT_CSV, XAI_OUT_JSON]
for p in snapshot_paths:
    marker = "✓ " if os.path.exists(p) else "✗ "
    print(marker + p)

if masked_o is not None and "known_level" in masked_o.columns:
    print_header("MASKED START-POINTS (test rows)")
    # known_level code (0..4) -> starting-point label, in pipeline order.
    stage_order = ["Start", "Press", "Glaze", "Kiln", "Full"]
    labels = dict(enumerate(stage_order))
    label_counts = masked_o["known_level"].map(labels).value_counts()
    # Reindex so every start point is listed, with 0 for those not present.
    dist = label_counts.reindex(stage_order).fillna(0).astype(int)
    print(dist.to_string())

if pred_o is not None:
    print_header("PREDICTION COVERAGE BY STAGE")

    def has(col):
        """True when *col* exists in pred_o and holds at least one non-null value."""
        return col in pred_o.columns and pred_o[col].notna().any()

    # Report label -> predicted-volume column it counts.
    stage_pred_cols = {
        "Press_pred_rows": "vol_press_out_pred",
        "Glaze_pred_rows": "vol_glaze_out_pred",
        "Kiln_pred_rows":  "vol_kiln_out_pred",
        "Sort_pred_rows":  "vol_sort_out_pred",
    }
    cov = {
        label: (int(pred_o[col].notna().sum()) if has(col) else 0)
        for label, col in stage_pred_cols.items()
    }
    cov["Test_total_rows"] = len(pred_o)

    for label, count in cov.items():
        print(f"{label:>18}: {count}")

if friendly_o is not None:
    print_header("FRIENDLY REPORT (first 8 rows)")
    # Every stage contributes the same seven columns, prefixed by its name.
    per_stage_fields = [
        "status", "known_amount_ft2", "predicted_amount_ft2",
        "min_required_ft2", "flag", "actual_ft2", "met_min",
    ]
    cols_disp = ["datetime", "recipe_id", "known_level_label", "START_amount_ft2"]
    for stage_prefix in ("PRESS", "GLAZE", "KILN", "SORT"):
        cols_disp.extend(f"{stage_prefix}_{field}" for field in per_stage_fields)
    # Keep only the columns the friendly CSV actually has.
    show = [name for name in cols_disp if name in friendly_o.columns]
    print(friendly_o[show].head(8).to_string(index=False))

if exists(COMPACT_TXT):
    print_header("COMPACT SUMMARIES (first 8 lines)")
    with open(COMPACT_TXT, "r", encoding="utf-8") as compact_fh:
        # zip() stops once range(8) is exhausted, so at most 8 lines are echoed.
        for _, line in zip(range(8), compact_fh):
            print(line.rstrip("\n"))

if evaldf_o is not None and not evaldf_o.empty:
    print_header("RESULTS vs ACTUAL (predicted rows only)")

    def summarize_stage_o(df_stage):
        """Return a one-line MAE / MAPE / compliance summary for one stage's eval rows.

        MAE and MAPE average only the numeric values of ``abs_error_ft2`` and
        ``pct_error`` (non-numeric entries are coerced to NaN and dropped).
        ``pct_error`` is multiplied by 100 for display, so it is presumably
        stored as a fraction — NOTE(review): confirm against the eval writer.
        Compliance is the share of the stage's rows whose ``met_min`` is "YES".
        A metric whose column is absent, or that has no numeric values, is shown
        as "n/a" instead of raising KeyError or printing "nan".
        """
        if df_stage.empty:
            return "n=0"
        n = len(df_stage)

        def numeric(col):
            # Missing column -> empty series, reported as "n/a" below.
            if col not in df_stage.columns:
                return pd.Series(dtype=float)
            return pd.to_numeric(df_stage[col], errors="coerce").dropna()

        mae = numeric("abs_error_ft2")
        mape = numeric("pct_error")
        mae_txt = f"{mae.mean():.1f} ft²" if not mae.empty else "n/a"
        mape_txt = f"{mape.mean() * 100:.2f}%" if not mape.empty else "n/a"
        if "met_min" in df_stage.columns:
            comp_txt = f"{(df_stage['met_min'] == 'YES').mean() * 100:.1f}%"
        else:
            comp_txt = "n/a"
        return f"n={n}, MAE={mae_txt}, MAPE={mape_txt}, compliance={comp_txt}"

    for stg in ["PRESS","GLAZE","KILN","SORT"]:
        s = summarize_stage_o(evaldf_o[evaldf_o["stage"]==stg])
        print(f"{stg:>5}: {s}")

    print("\nPer-stage eval sample (first 10 rows):")
    keep = ["stage","datetime","recipe_id","predicted_amount_ft2","min_required_ft2",
            "actual_amount_ft2","abs_error_ft2","pct_error","met_min","flag_result"]
    # Print only the columns present in the eval file; indexing a missing
    # column would raise KeyError and abort the whole report.
    keep = [c for c in keep if c in evaldf_o.columns]
    print(evaldf_o[keep].head(10).to_string(index=False))

if flags_o is not None and not flags_o.empty:
    print_header("FLAGS SUMMARY")
    print(f"Flagged rows: {len(flags_o)}")
    # Tally each stage flag column that is present (no spec flags exist anymore).
    for flag_col in ("press_flag", "glaze_flag", "kiln_flag", "sort_flag"):
        if flag_col not in flags_o.columns:
            continue
        # read_csv loads empty fields as NaN, so fill before testing for "".
        n_set = int(flags_o[flag_col].fillna("").ne("").sum())
        if n_set:
            print(f"{flag_col:>12}: {n_set}")

    print("\nFlags-only preview (first 8):")
    print(flags_o.head(8).to_string(index=False))

# -----------------------
# 15) SAVE TRAINED BUNDLE FOR UI
# -----------------------
import joblib
import sklearn
from datetime import timezone

# Everything a UI service needs to reuse the trained stage models: the fitted
# per-stage models, the feature list each one expects, the threshold / Sort
# configuration, the recipe set-points, and provenance metadata.
bundle = {
    "models": models,  # sklearn Pipelines per stage
    "stage_features": {
        "press_ratio": PRESS_FEATS,
        "glaze_ratio": GLAZE_FEATS,
        "kiln_ratio":  KILN_FEATS
    },
    "config": {
        "THRESHOLD_MARGIN": THRESHOLD_MARGIN,
        "SORT_MIN": SORT_MIN,
        "SORT_DEFAULT": SORT_DEFAULT,
        "SORT_MAX": SORT_MAX
    },
    "recipe_setpoints": RECIPE_SETPOINTS,
    "meta": {
        # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
        # timestamp and drop tzinfo so the string keeps the same naive
        # "YYYY-MM-DDTHH:MM:SS" format existing consumers already read.
        "created_at_utc": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(timespec="seconds"),
        "sklearn_version": sklearn.__version__,
        "notes": "RF models with per-stage features; Sort rule at 90% min; SHAP XAI optional"
    }
}
joblib.dump(bundle, BUNDLE_PATH)
print(f"\n✅ Saved model bundle → {BUNDLE_PATH}")