In [5]:
import os
import joblib
import pandas as pd
from interpret import show
import plotly.io as pio

pio.renderers.default = "notebook"  # or "inline", "jupyterlab", etc.


In [8]:

INPUT_DIR = "data_inputs"
MODEL_DIR = "models"
OUT_DIR_SHAPES = os.path.join(MODEL_DIR, "plots", "shape_plots")
OUT_DIR_SUMMARY = os.path.join(MODEL_DIR, "plots", "summary_plots")


os.makedirs(OUT_DIR_SHAPES, exist_ok=True)
os.makedirs(OUT_DIR_SUMMARY, exist_ok=True)

sexes = ["M", "F"]
TOP_K = 15  # number of feature-wise shape plots per model to save

# Base datasets we trained from
base_variants = {
    "tnpca": os.path.join(INPUT_DIR, "ebm_input_tnpca.csv"),
    "vae":   os.path.join(INPUT_DIR, "ebm_input_vae.csv"),
    "pca":   os.path.join(INPUT_DIR, "ebm_input_pca.csv"),
}

# For each base dataset, which model variants exist?
# These must match what ebm_alcohol.py trains/saves.
model_variants_by_base = {
    "tnpca": ["tnpca", "cog_only", "tnpca_only"],
    "vae":   ["vae", "vae_only"],
    "pca":   ["pca", "pca_only"],
}

def load_model(model_path):
    obj = joblib.load(model_path)
    if isinstance(obj, dict):
        pipe = obj["model"]
        features = obj.get("features", None)
    else:
        pipe = obj
        features = None

    # If this is an imblearn/sklearn pipeline with a "model" step, extract EBM
    if hasattr(pipe, "named_steps") and "model" in pipe.named_steps:
        ebm = pipe.named_steps["model"]
    else:
        ebm = pipe

    return ebm, features



In [9]:

for base_variant, data_path in base_variants.items():
    if not os.path.exists(data_path):
        print(f"[WARN] missing {data_path}, skipping base_variant={base_variant}")
        continue

    df = pd.read_csv(data_path)

    if "Gender" not in df.columns or "alc_y" not in df.columns:
        raise ValueError(f"'Gender' or 'alc_y' missing in {data_path}")

    model_variants = model_variants_by_base[base_variant]

    for sex in sexes:
        df_sex = df[df["Gender"] == sex].copy()
        if df_sex.empty:
            print(f"[INFO] no rows for sex={sex}, base_variant={base_variant}, skipping")
            continue

        pos_rate = df_sex["alc_y"].mean()
        print(f"\n=== sex={sex}, base_variant={base_variant} (pos rate={pos_rate:.3f}) ===")

        for variant in model_variants:
            model_path = os.path.join(MODEL_DIR, f"ebm_{sex}_{variant}.pkl")
            if not os.path.exists(model_path):
                print(f"  [WARN] missing model {model_path}, skipping variant={variant}")
                continue

            print(f"  -> plots for model variant={variant}")

            ebm, feat_list = load_model(model_path)

            # Global explanation
            global_expl = ebm.explain_global(name=f"{sex}-{variant}")

            # ----- 1) Summary / coverage plot (feature importance) -----
            summary_fig = global_expl.visualize()   # overview bar chart
            summary_fig.update_layout(
                title=f"{sex} - {variant} - Feature importance summary",
                margin=dict(l=80, r=20, t=60, b=50),
            )

            summary_path = os.path.join(
                OUT_DIR_SUMMARY, f"summary_{sex}_{variant}.png"
            )
            summary_fig.write_image(summary_path, width=900, height=700, scale=2)
            print("     saved summary plot:", summary_path)

            # ----- 2) Top-K shape plots -----
            data_all = global_expl.data()
            term_names = data_all["names"]
            n_terms = len(term_names)
            print("     number of terms:", n_terms)

            K = min(TOP_K, n_terms)

            for idx in range(K):
                name = term_names[idx]
                safe_name = str(name).replace(" ", "_").replace("/", "_")

                fig = global_expl.visualize(idx)
                fig.update_layout(
                    title=f"{sex} - {variant} - {name}",
                    margin=dict(l=50, r=20, t=60, b=50),
                )

                out_path = os.path.join(
                    OUT_DIR_SHAPES,
                    f"shape_{sex}_{variant}_{idx:02d}_{safe_name}.png"
                )
                fig.write_image(out_path, width=800, height=600, scale=2)
                print("       saved shape plot:", out_path)




=== sex=M, base_variant=tnpca (pos rate=0.195) ===
  -> plots for model variant=tnpca



Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_M_tnpca.png
     number of terms: 86
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_09_feature_0009.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_M_cog_only.png
     number of terms: 46
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_M_cog_only_09_feature_0009.png
       saved shape plot: model


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_M_tnpca_only.png
     number of terms: 40
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_M_tnpca_only_09_feature_0009.png
       s


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_F_tnpca.png
     number of terms: 86
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_09_feature_0009.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_F_cog_only.png
     number of terms: 46
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_F_cog_only_09_feature_0009.png
       saved shape plot: model


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_F_tnpca_only.png
     number of terms: 40
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_F_tnpca_only_09_feature_0009.png
       s


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_M_vae.png
     number of terms: 86
       saved shape plot: models\plots\shape_plots\shape_M_vae_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_09_feature_0009.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_10_feature_0010.png
   


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_M_vae_only.png
     number of terms: 40
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_M_vae_only_09_feature_0009.png
       saved shape plot: model


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_F_vae.png
     number of terms: 86
       saved shape plot: models\plots\shape_plots\shape_F_vae_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_09_feature_0009.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_10_feature_0010.png
   


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_F_vae_only.png
     number of terms: 40
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_F_vae_only_09_feature_0009.png
       saved shape plot: model


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_M_pca.png
     number of terms: 86
       saved shape plot: models\plots\shape_plots\shape_M_pca_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_09_feature_0009.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_10_feature_0010.png
   


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_M_pca_only.png
     number of terms: 40
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_M_pca_only_09_feature_0009.png
       saved shape plot: model


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_F_pca.png
     number of terms: 86
       saved shape plot: models\plots\shape_plots\shape_F_pca_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_09_feature_0009.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_10_feature_0010.png
   


Trying to unpickle estimator NearestNeighbors from version 1.2.2 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations



     saved summary plot: models\plots\summary_plots\summary_F_pca_only.png
     number of terms: 40
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_00_feature_0000.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_01_feature_0001.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_02_feature_0002.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_03_feature_0003.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_04_feature_0004.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_05_feature_0005.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_06_feature_0006.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_07_feature_0007.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_08_feature_0008.png
       saved shape plot: models\plots\shape_plots\shape_F_pca_only_09_feature_0009.png
       saved shape plot: model