In [1]:
import os
import joblib
import pandas as pd
from interpret import show
import plotly.io as pio

pio.renderers.default = "notebook"  # or "inline", "jupyterlab", etc.


In [4]:

INPUT_DIR = "ebm_inputs"
MODEL_DIR = "models"
OUT_DIR_SHAPES = os.path.join(MODEL_DIR, "plots\shape_plots")
OUT_DIR_SUMMARY = os.path.join(MODEL_DIR, "plots\summary_plots")

os.makedirs(OUT_DIR_SHAPES, exist_ok=True)
os.makedirs(OUT_DIR_SUMMARY, exist_ok=True)

sexes = ["M", "F"]
TOP_K = 15  # number of feature-wise shape plots per model to save

# Base datasets we trained from
base_variants = {
    "tnpca": os.path.join(INPUT_DIR, "ebm_input_tnpca.csv"),
    "vae":   os.path.join(INPUT_DIR, "ebm_input_vae.csv"),
    "pca":   os.path.join(INPUT_DIR, "ebm_input_pca.csv"),
}

# For each base dataset, which model variants exist?
# These must match what ebm_alcohol.py trains/saves.
model_variants_by_base = {
    "tnpca": ["tnpca", "cog_only", "tnpca_only"],
    "vae":   ["vae", "vae_only"],
    "pca":   ["pca", "pca_only"],
}

def load_model(model_path):
    obj = joblib.load(model_path)
    if isinstance(obj, dict):
        return obj["model"], obj.get("features", None)
    else:
        return obj, None


In [5]:

for base_variant, data_path in base_variants.items():
    if not os.path.exists(data_path):
        print(f"[WARN] missing {data_path}, skipping base_variant={base_variant}")
        continue

    df = pd.read_csv(data_path)

    if "Gender" not in df.columns or "alc_y" not in df.columns:
        raise ValueError(f"'Gender' or 'alc_y' missing in {data_path}")

    model_variants = model_variants_by_base[base_variant]

    for sex in sexes:
        df_sex = df[df["Gender"] == sex].copy()
        if df_sex.empty:
            print(f"[INFO] no rows for sex={sex}, base_variant={base_variant}, skipping")
            continue

        pos_rate = df_sex["alc_y"].mean()
        print(f"\n=== sex={sex}, base_variant={base_variant} (pos rate={pos_rate:.3f}) ===")

        for variant in model_variants:
            model_path = os.path.join(MODEL_DIR, f"ebm_{sex}_{variant}.pkl")
            if not os.path.exists(model_path):
                print(f"  [WARN] missing model {model_path}, skipping variant={variant}")
                continue

            print(f"  -> plots for model variant={variant}")

            ebm, feat_list = load_model(model_path)

            # Global explanation
            global_expl = ebm.explain_global(name=f"{sex}-{variant}")

            # ----- 1) Summary / coverage plot (feature importance) -----
            summary_fig = global_expl.visualize()   # overview bar chart
            summary_fig.update_layout(
                title=f"{sex} - {variant} - Feature importance summary",
                margin=dict(l=80, r=20, t=60, b=50),
            )

            summary_path = os.path.join(
                OUT_DIR_SUMMARY, f"summary_{sex}_{variant}.png"
            )
            summary_fig.write_image(summary_path, width=900, height=700, scale=2)
            print("     saved summary plot:", summary_path)

            # ----- 2) Top-K shape plots -----
            data_all = global_expl.data()
            term_names = data_all["names"]
            n_terms = len(term_names)
            print("     number of terms:", n_terms)

            K = min(TOP_K, n_terms)

            for idx in range(K):
                name = term_names[idx]
                safe_name = str(name).replace(" ", "_").replace("/", "_")

                fig = global_expl.visualize(idx)
                fig.update_layout(
                    title=f"{sex} - {variant} - {name}",
                    margin=dict(l=50, r=20, t=60, b=50),
                )

                out_path = os.path.join(
                    OUT_DIR_SHAPES,
                    f"shape_{sex}_{variant}_{idx:02d}_{safe_name}.png"
                )
                fig.write_image(out_path, width=800, height=600, scale=2)
                print("       saved shape plot:", out_path)




=== sex=M, base_variant=tnpca (pos rate=0.195) ===
  -> plots for model variant=tnpca
     saved summary plot: models\plots\summary_plots\summary_M_tnpca.png
     number of terms: 86
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_00_Age_in_Yrs.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_01_PMAT24_A_CR.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_02_PMAT24_A_SI.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_03_PMAT24_A_RTCR.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_04_ReadEng_Unadj.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_05_ReadEng_AgeAdj.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_06_PicVocab_Unadj.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_07_PicVocab_AgeAdj.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_08_IWRD_TOT.png
       saved shape plot: models\plots\shape_plots\shape_M_t