In [10]:
import os
import joblib
import pandas as pd
from interpret import show

import plotly.io as pio
# Pick the plotly renderer that matches the front end this notebook runs in.
pio.renderers.default = "notebook"  # or "notebook_connected", "jupyterlab", "vscode", etc.


In [11]:

# --- Run configuration -------------------------------------------------------

# Model variants and sex groups to iterate over; they are combined into the
# CSV / pickle file names in the plotting cell.
variants = ["tnpca", "vae", "pca"]
sexes = ["M", "F"]

# Upper bound on how many terms get a saved plot per model.
TOP_K = 15

# Input CSVs are read from INPUT_DIR, pickled models from MODEL_DIR, and the
# rendered PNGs are written to a sub-folder of the model directory.
INPUT_DIR = "ebm_inputs"
MODEL_DIR = "ebm_alcohol"
OUT_DIR = os.path.join(MODEL_DIR, "shape_plots")

# Create the output folder up front so later writes cannot fail on a missing
# directory; exist_ok makes re-running this cell harmless.
os.makedirs(OUT_DIR, exist_ok=True)

def load_model(model_path):
    """Load a saved model artifact and return it as a ``(model, features)`` pair.

    Two on-disk layouts are accepted:

    * a dict bundle -- the model is taken from its ``"model"`` key and the
      feature list from its ``"features"`` key (``None`` when that key is
      absent);
    * anything else -- the loaded object itself is returned as the model and
      the feature list is ``None``.

    Parameters
    ----------
    model_path : path of the file handed to ``joblib.load``.

    Returns
    -------
    tuple
        ``(model, features_or_None)``.

    Raises
    ------
    KeyError
        If the artifact is a dict that has no ``"model"`` key.
    """
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # only point this at artifacts from a trusted source.
    artifact = joblib.load(model_path)

    if not isinstance(artifact, dict):
        return artifact, None

    return artifact["model"], artifact.get("features")


In [12]:

# Save per-term plots for every (variant, sex) model.
#
# For each variant the input CSV is read (only to report row count and positive
# rate per sex), the matching pickled model is loaded, a global explanation is
# requested from it, and up to TOP_K term figures are written as PNGs to OUT_DIR.

# Characters that act as path separators or are rejected in Windows file names.
# Each one is replaced by "_" when a term name is turned into a file name, so a
# term name can never redirect or break the output path.
_UNSAFE_FILENAME_CHARS = str.maketrans({ch: "_" for ch in ' /\\:*?"<>|'})

for variant in variants:
    data_path = os.path.join(INPUT_DIR, f"ebm_input_{variant}.csv")
    if not os.path.exists(data_path):
        print(f"[WARN] missing {data_path}, skipping variant={variant}")
        continue

    df = pd.read_csv(data_path)

    # Both columns are needed below (sex filter and positive-rate report).
    if "Gender" not in df.columns or "alc_y" not in df.columns:
        raise ValueError(f"'Gender' or 'alc_y' missing in {data_path}")

    for sex in sexes:
        # Read-only subset: it is only used for the summary line below.
        df_sex = df[df["Gender"] == sex]
        if df_sex.empty:
            print(f"[INFO] no rows for sex={sex}, variant={variant}, skipping")
            continue

        model_path = os.path.join(MODEL_DIR, f"ebm_{sex}_{variant}.pkl")
        if not os.path.exists(model_path):
            print(f"[WARN] missing model {model_path}, skipping sex={sex}, variant={variant}")
            continue

        print(f"\n=== shape plots for sex={sex}, variant={variant} ===")
        print("  n rows:", len(df_sex), "  pos rate:", df_sex["alc_y"].mean())

        # The stored feature list (second return value) is not needed here.
        ebm, _ = load_model(model_path)

        # Global explanation
        global_expl = ebm.explain_global(name=f"{sex}-{variant}")

        # The "names" entry of the keyless data() dict is used as the term list.
        term_names = global_expl.data()["names"]
        n_terms = len(term_names)
        print("  number of terms:", n_terms)

        # NOTE(review): this walks the first K terms in the order they are
        # listed; no ranking by importance is applied. If TOP_K is meant as
        # "most important", the indices need sorting first -- confirm intent.
        K = min(TOP_K, n_terms)

        for idx in range(K):
            name = term_names[idx]
            safe_name = str(name).translate(_UNSAFE_FILENAME_CHARS)

            fig = global_expl.visualize(idx)
            if fig is None:
                # Defensive: without a figure there is nothing to style or save.
                print(f"   [WARN] no figure for term {idx} ({name}), skipping")
                continue

            fig.update_layout(
                title=f"{sex} - {variant} - {name}",
                margin=dict(l=50, r=20, t=60, b=50),
            )

            # optional: show inline for quick check
            # fig.show()

            out_path = os.path.join(
                OUT_DIR,
                f"shape_{sex}_{variant}_{idx:02d}_{safe_name}.png"
            )
            fig.write_image(out_path, width=800, height=600, scale=2)
            print("   saved:", out_path)



=== shape plots for sex=M, variant=tnpca ===
  n rows: 487   pos rate: 0.19507186858316222
  number of terms: 86
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_00_Age_in_Yrs.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_01_PMAT24_A_CR.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_02_PMAT24_A_SI.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_03_PMAT24_A_RTCR.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_04_ReadEng_Unadj.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_05_ReadEng_AgeAdj.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_06_PicVocab_Unadj.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_07_PicVocab_AgeAdj.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_08_IWRD_TOT.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_09_IWRD_RTC.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_10_ProcSpeed_Unadj.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_11_ProcSpeed_AgeAdj.png
   saved: ebm_alcohol\shape_plots\shape_M_tnpca_12_DDisc_SV_1mo_20