## Perform SHAP analysis and generate figures & tables

#### Generate legends

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib import font_manager as fm
import matplotlib as mpl
from pathlib import Path

# Register the bundled font file and make it the default family for all text.
font_path = "../resources/fonts/Aptos.ttf"
fm.fontManager.addfont(font_path)
prop = fm.FontProperties(fname=font_path)
mpl.rcParams["font.family"] = prop.get_name()

# Directory that receives the standalone legend image (created on demand).
out_dir_legend = Path("../results/figures/shap/legend")
out_dir_legend.mkdir(exist_ok=True, parents=True)

# Legend label -> fill colour, in the order the entries are drawn.
type_colors = {
    "Variant effect predictor": "#BD2424",
    "Mutation-level": "#2ca02c",
    "Gene-level": "#9467bd",
    "Conservation score-based": "#1f77b4",
    "Residue-level": "#c2d41e",
}

# One thin-bordered colour swatch per feature type.
legend_patches = [
    Patch(label=type_label, facecolor=fill, edgecolor="black", linewidth=0.4)
    for type_label, fill in type_colors.items()
]

def save_shap_legend():
    """Render the feature-type colour legend as a standalone PNG.

    Draws ``legend_patches`` in a single row on an otherwise empty figure and
    writes it to ``out_dir_legend / "shap_legend_by_type.png"``. Prints the
    output path and returns nothing.
    """
    fig = plt.figure(figsize=(6, 0.3))
    # Axes spanning the whole figure, used only as an anchor for the legend.
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis("off")

    # Matplotlib ignores ``fontsize`` whenever ``prop`` is supplied, so the
    # size has to live on the FontProperties object itself. A dedicated
    # instance is used so the shared module-level ``prop`` is not modified.
    legend_font = fm.FontProperties(fname=font_path, size=14)

    ax.legend(handles=legend_patches,
              loc="center",
              frameon=False,
              # Count the handles actually drawn, not the global
              # ``type_colors`` that other cells may rebind.
              ncol=len(legend_patches),
              prop=legend_font)

    out_path = out_dir_legend / "shap_legend_by_type.png"
    # bbox_inches="tight" crops to the legend, so the small figsize above
    # does not clip the larger text.
    fig.savefig(out_path, dpi=300, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    print(f"Legend saved to: {out_path}")

# Write the legend PNG into out_dir_legend when this cell runs.
save_shap_legend()


Legend saved to: ..\results\figures\shap\legend\shap_legend_by_type.png


#### Generate importance and summary plots

In [None]:
import pandas as pd
import joblib
import shap
import matplotlib.pyplot as plt
import numpy as np
import os
import re
from pathlib import Path
from matplotlib import font_manager as fm
import matplotlib as mpl

# Register the bundled font file and use it as the default family.
font_path = "../resources/fonts/Aptos.ttf"
fm.fontManager.addfont(font_path)
mpl.rcParams["font.family"] = fm.FontProperties(fname=font_path).get_name()

# Multipliers applied to the base point sizes of each plot kind.
FONT_SCALE_BAR = 0.9
FONT_SCALE_SUM = 1.0

# Base sizes are 17 / 14 / 14 / 12 pt for title, x-label, x-ticks and y-ticks;
# int() truncates the scaled value.

# Importance plot font sizes
TITLE_SIZE_BAR, XLABEL_SIZE_BAR, XTICK_SIZE_BAR, YTICK_SIZE_BAR = (
    int(base_pt * FONT_SCALE_BAR) for base_pt in (17, 14, 14, 12)
)

# Summary plot font sizes
TITLE_SIZE_SUM, XLABEL_SIZE_SUM, XTICK_SIZE_SUM, YTICK_SIZE_SUM = (
    int(base_pt * FONT_SCALE_SUM) for base_pt in (17, 14, 14, 12)
)

TITLE_WEIGHT = XLABEL_WEIGHT = TICK_WEIGHT = "normal"

# Models to explain, plus the shared input table and feature metadata.
models = [
    "FuncVEP_CTI", "FuncVEP_CTE", "FuncVEP_SP",
    "ClinVEP_CTI", "ClinVEP_CTE", "ClinVEP_SP",
]
data_path = "../data/final/functional_labels_model_input.txt"
meta_path = "../resources/feature_lists/all_columns.txt"

id_column = "ID"
target_col = "functional_label"
max_display = 20   # number of top features shown per plot
linewidth = 0.4    # bar outline width

# Output folders for the per-model importance (bar) and summary plots.
out_bar = Path("../results/figures/shap/importance")
out_sum = Path("../results/figures/shap/summary")
out_bar.mkdir(exist_ok=True, parents=True)
out_sum.mkdir(exist_ok=True, parents=True)

# Feature name -> feature type, read from the metadata table.
meta_df = pd.read_csv(meta_path, sep="\t")
feature_type_map = meta_df.set_index("Name")["Type"].to_dict()

# Feature type -> bar colour; types missing here fall back to grey in the plots.
type_colors = {
    "Variant Effect Predictor": "#BD2424",
    "Mutation-Level": "#2ca02c",
    "Gene-Level": "#9467bd",
    "Conservation Score-Based": "#1f77b4",
    "Residue-Level": "#c2d41e",
}

def clean_feature_name(feat: str) -> str:
    """Turn a raw feature column name into a plot label.

    Removes a leading ``glm_`` and a trailing ``_score``, then converts
    underscores to hyphens: runs of three are collapsed first, then runs of
    two, then single underscores.
    """
    label = re.sub(r"_score$", "", re.sub(r"^glm_", "", feat))
    # Longest separator first so "___" and "__" each become a single hyphen.
    for separator in ("___", "__", "_"):
        label = label.replace(separator, "-")
    return label

def filter_features(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
    """Drop the feature columns a given model is not allowed to see.

    Every model loses the columns listed in the "unavailable training sets"
    file. The ``*_CTE`` models additionally lose the columns listed in the
    clinical-trained tools file, and the ``*_SP`` models lose every column
    listed in the all-tools file. Each list file holds one column name per
    line; names that are not present in ``df`` are ignored.

    Returns the reduced frame produced by ``DataFrame.drop``.
    """
    exclusion_files = [
        "../resources/feature_lists/tools_excluded_due_to_unavailable_training_sets.txt",
    ]
    if model_name in ("FuncVEP_CTE", "ClinVEP_CTE"):
        exclusion_files.append("../resources/feature_lists/clinical_trained_tools.txt")
    if model_name in ("FuncVEP_SP", "ClinVEP_SP"):
        exclusion_files.append("../resources/feature_lists/all_tools.txt")

    for list_file in exclusion_files:
        with open(list_file) as handle:
            unwanted = [line.strip() for line in handle]
        df = df.drop(columns=unwanted, errors="ignore")
    return df

# The input table is the same for every model, so parse it once; each
# iteration derives its own filtered copy from this raw frame.
raw_df = pd.read_csv(data_path, sep="\t")
raw_df.columns = raw_df.columns.str.replace(" ", "_")

for model_name in models:
    print(f"Processing {model_name}")
    model_dir = f"../models/{model_name}"
    # NOTE: joblib.load unpickles the file; only load model files you trust.
    model = joblib.load(os.path.join(model_dir, "model.pkl"))

    train_ids = pd.read_csv(os.path.join(model_dir, "training_variants.txt"),
                            sep="\t")["ID"].tolist()
    trained_features = model.feature_name_

    # Keep only variants the model was not trained on. The explicit copies
    # make the column assignments below operate on independent frames
    # instead of on slices of a parent frame.
    df = filter_features(raw_df, model_name)
    df = df[~df[id_column].isin(train_ids)].copy()
    df[target_col] = df[target_col].replace({"PS3": 1, "BS3": 0})
    df = df[df["weight"] == 1].copy()

    # Everything except the ID, label and weight columns is coerced to
    # numeric; unparsable entries become NaN.
    num_cols = df.columns.difference([id_column, target_col, "weight"])
    df[num_cols] = df[num_cols].apply(pd.to_numeric, errors="coerce")
    df = df.dropna(subset=[target_col])

    X = df[trained_features].copy()

    explainer = shap.TreeExplainer(model)
    shap_vals_raw = explainer.shap_values(X)
    # When a two-element list comes back, use the second entry.
    if isinstance(shap_vals_raw, list) and len(shap_vals_raw) == 2:
        shap_vals = shap_vals_raw[1]
    else:
        shap_vals = shap_vals_raw

    # Rank features by mean absolute SHAP value and keep the top ones.
    shap_means = np.abs(shap_vals).mean(axis=0)
    top_idx = np.argsort(shap_means)[::-1][:max_display]
    top_feats = X.columns[top_idx]
    top_import = shap_means[top_idx]
    top_colors = [type_colors.get(feature_type_map.get(f, "Other"), "#7f7f7f")
                  for f in top_feats]
    top_labels = [clean_feature_name(f) for f in top_feats]

    # A model can have fewer than max_display features, so size the bar
    # positions from what was actually selected. Reversed so the most
    # important feature ends up at the top of the chart.
    bar_positions = range(len(top_idx))[::-1]

    # --- importance (bar) plot ---
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.barh(bar_positions, top_import,
            color=top_colors, edgecolor="black", linewidth=linewidth)
    ax.set_yticks(bar_positions)
    ax.set_yticklabels(top_labels, fontsize=YTICK_SIZE_BAR, fontweight=TICK_WEIGHT)
    ax.set_xlabel("Mean |SHAP value|", fontsize=XLABEL_SIZE_BAR, fontweight=XLABEL_WEIGHT)
    ax.set_title(model_name.replace("_", "-"),
                 fontsize=TITLE_SIZE_BAR, fontweight=TITLE_WEIGHT, pad=15)
    ax.tick_params(axis="x", labelsize=XTICK_SIZE_BAR)
    fig.tight_layout()
    fig.savefig(out_bar / f"shap_bar_{model_name}.png", dpi=300, bbox_inches="tight")
    plt.close(fig)

    # --- summary plot ---
    # shap draws onto the current figure; grab it afterwards to restyle it.
    shap.summary_plot(shap_vals, X, max_display=max_display, show=False)
    fig = plt.gcf()
    # Figure.suptitle replaces the text of any existing super-title, so no
    # private-attribute reset is needed first.
    fig.suptitle(model_name.replace("_", "-"),
                 fontsize=TITLE_SIZE_SUM, fontweight=TITLE_WEIGHT, y=0.98)

    ax_sum = fig.axes[0]
    ax_sum.tick_params(axis="x", labelsize=XTICK_SIZE_SUM)
    ax_sum.tick_params(axis="y", labelsize=YTICK_SIZE_SUM)

    # Leave the top 3% of the figure free for the super-title.
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.savefig(out_sum / f"shap_summary_{model_name}.png", dpi=600, bbox_inches="tight")
    plt.close(fig)

print("All SHAP plots saved.")


Processing FuncVEP_CTI


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  df[target_col] = df[target_col].replace({"PS3": 1, "BS3": 0})


Processing FuncVEP_CTE


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  df[target_col] = df[target_col].replace({"PS3": 1, "BS3": 0})


Processing FuncVEP_SP


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  df[target_col] = df[target_col].replace({"PS3": 1, "BS3": 0})


Processing ClinVEP_CTI


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  df[target_col] = df[target_col].replace({"PS3": 1, "BS3": 0})


Processing ClinVEP_CTE


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  df[target_col] = df[target_col].replace({"PS3": 1, "BS3": 0})


Processing ClinVEP_SP


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  df[target_col] = df[target_col].replace({"PS3": 1, "BS3": 0})


All SHAP plots saved.


#### Combine importance plots

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.gridspec as gridspec
from pathlib import Path
import string
import math

# Models whose per-model figures are stitched together.
models = ["FuncVEP_CTI", "FuncVEP_CTE", "FuncVEP_SP", "ClinVEP_CTI", "ClinVEP_CTE", "ClinVEP_SP"]

# Locations of the per-model figures and of the shared legend image.
fig_dir_imp = Path("../results/figures/shap/importance")
fig_dir_sum = Path("../results/figures/shap/summary")
legend_path = Path("../results/figures/shap/legend/shap_legend_by_type.png")

# Destination files for the combined figures.
out_imp = fig_dir_imp / "shap_bar_combined.png"
out_sum = fig_dir_sum / "shap_summary_combined.png"

func_models = ["FuncVEP_CTI", "FuncVEP_CTE", "FuncVEP_SP"]
clin_models = ["ClinVEP_CTI", "ClinVEP_CTE", "ClinVEP_SP"]

# Interleave the two families so each grid row pairs a FuncVEP panel (left)
# with the matching ClinVEP panel (right).
imp_paths = [
    fig_dir_imp / f"shap_bar_{panel_model}.png"
    for model_pair in zip(func_models, clin_models)
    for panel_model in model_pair
]

def _panel_label(idx):
    """Return a spreadsheet-style panel letter: 0 -> A, 25 -> Z, 26 -> AA."""
    label = ""
    remaining = idx + 1
    while remaining:
        remaining, offset = divmod(remaining - 1, len(string.ascii_uppercase))
        label = string.ascii_uppercase[offset] + label
    return label


def build_grid_two_columns(img_paths, out_path, include_legend=False, legend_img=None,
                           wspace=0.2, hspace=0.2, figsize_override=None):
    """Tile saved images into a lettered two-column figure and save it.

    Parameters
    ----------
    img_paths : sequence of path-like
        Images to place, filled row by row (left, right, next row, ...).
    out_path : path-like
        Destination file for the combined figure.
    include_legend : bool
        When true, append a short full-width row showing ``legend_img``.
    legend_img : path-like or None
        Legend image; required when ``include_legend`` is true.
    wspace, hspace : float
        Horizontal / vertical spacing passed to ``Figure.subplots_adjust``.
    figsize_override : tuple or None
        Figure size in inches; defaults to 7 in per column and 4 in per row.

    Raises
    ------
    ValueError
        If ``include_legend`` is true without ``legend_img``, or if
        ``img_paths`` is empty.
    """
    if include_legend and legend_img is None:
        raise ValueError("legend_img path must be provided when include_legend=True")

    img_paths = list(img_paths)
    if not img_paths:
        raise ValueError("img_paths must contain at least one image")

    out_path = Path(out_path)

    cols = 2
    rows = math.ceil(len(img_paths) / cols)
    height_ratios = [1] * rows

    if include_legend:
        # The legend gets its own, much shorter, row at the bottom.
        rows += 1
        height_ratios.append(0.18)

    gs = gridspec.GridSpec(rows, cols, height_ratios=height_ratios)

    figsize = (cols * 7, rows * 4) if figsize_override is None else figsize_override
    fig = plt.figure(figsize=figsize)

    for idx, img_path in enumerate(img_paths):
        ax = fig.add_subplot(gs[idx // cols, idx % cols])
        # aspect="auto" stretches the image to fill its grid cell.
        ax.imshow(mpimg.imread(img_path), aspect='auto')
        ax.axis("off")

        # Panel letter in the top-left corner of each image.
        ax.text(
            0.01,
            0.97,
            _panel_label(idx),
            transform=ax.transAxes,
            fontsize=14,
            fontweight="bold",
            va="top",
            ha="left",
        )

    if include_legend:
        legend_ax = fig.add_subplot(gs[-1, :])
        legend_ax.axis("off")
        legend_ax.imshow(mpimg.imread(legend_img))
        # Pin the legend to the top of its row, right under the panels.
        legend_ax.set_anchor("N")

    # No tight_layout() afterwards: it recomputes the subplot parameters and
    # would discard the caller's wspace / hspace.
    fig.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05,
                        wspace=wspace, hspace=hspace)
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    # Report the path relative to its third ancestor; fall back to the full
    # path when it is too shallow to have one.
    try:
        shown = out_path.relative_to(out_path.parents[2])
    except IndexError:
        shown = out_path
    print(f"saved → {shown}")

# Stitch the per-model importance plots into one figure with the shared
# legend underneath.
build_grid_two_columns(
    imp_paths, out_imp,
    include_legend=True, legend_img=legend_path,
    wspace=0.2, hspace=0.2,
)

saved → shap\importance\shap_bar_combined.png


#### Export per-feature SHAP importances for all models

In [None]:
import pandas as pd
import numpy as np

# Model name -> Series of mean |SHAP| values indexed by feature name.
importance_dict = {}

# The input table is identical for every model, so parse it once.
raw_df = pd.read_csv(data_path, sep="\t")
raw_df.columns = raw_df.columns.str.replace(" ", "_")

for model_name in models:
    print(f"Collecting SHAP importances: {model_name}")

    model_dir = f"../models/{model_name}"
    # NOTE: joblib.load unpickles the file; only load model files you trust.
    model = joblib.load(os.path.join(model_dir, "model.pkl"))

    train_ids = pd.read_csv(os.path.join(model_dir, "training_variants.txt"), sep="\t")["ID"].tolist()
    trained_features = model.feature_name_

    # Keep only variants the model was not trained on; the copy makes the
    # column assignments below operate on an independent frame.
    df = filter_features(raw_df, model_name)
    df = df[~df[id_column].isin(train_ids)].copy()
    # Uses target_col, the label-column name defined in the plotting cell's
    # configuration; no other label-column variable is defined in this notebook.
    df[target_col] = df[target_col].replace({"PS3": 1, "BS3": 0})
    # NOTE(review): unlike the plotting cell and the by-type tables cell, no
    # ``weight == 1`` filter is applied here -- confirm this is intentional.

    numeric_cols = df.columns.difference([id_column, target_col, "weight"])
    df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors="coerce")
    df = df.dropna(subset=[target_col])

    X = df[trained_features].copy()

    explainer = shap.TreeExplainer(model)
    shap_vals_full = explainer.shap_values(X)
    # When a two-element list comes back, use the second entry.
    if isinstance(shap_vals_full, list) and len(shap_vals_full) == 2:
        shap_vals = shap_vals_full[1]
    else:
        shap_vals = shap_vals_full

    shap_means = np.abs(shap_vals).mean(axis=0)

    importance_dict[model_name] = pd.Series(shap_means, index=X.columns, name=model_name)

# One row per feature, one column per model; a feature that a model does not
# use is left as NaN in that model's column.
importance_df = pd.DataFrame(importance_dict).rename_axis("Feature").reset_index()

# to_csv does not create missing folders.
os.makedirs("../results/tables/shap", exist_ok=True)
importance_df.to_csv("../results/tables/shap/all_models_shap_importances.txt", sep="\t", index=False)


Collecting SHAP importances: FuncVEP_CTI


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Collecting SHAP importances: FuncVEP_CTE


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Collecting SHAP importances: FuncVEP_SP


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Collecting SHAP importances: ClinVEP_CTI


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Collecting SHAP importances: ClinVEP_CTE


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Collecting SHAP importances: ClinVEP_SP


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})


#### Generate importance tables

In [None]:
import pandas as pd
import numpy as np

# Column order of the output tables: FuncVEP / ClinVEP pairs per feature set.
models = ["FuncVEP_CTI", "ClinVEP_CTI", "FuncVEP_CTE", "ClinVEP_CTE", "FuncVEP_SP", "ClinVEP_SP"]
meta_df = pd.read_csv("../resources/feature_lists/all_columns.txt", sep="\t")
feature_type_map = meta_df.set_index("Name")["Category"].to_dict()
# Table rows: every category in the metadata except "Functional-Trained"
# (filtering instead of list.remove so a missing category does not raise).
unique_types = [
    category
    for category in sorted(meta_df["Category"].dropna().unique())
    if category != "Functional-Trained"
]

# to_csv does not create missing folders.
os.makedirs("../results/tables/shap", exist_ok=True)

# The input table is identical for every model, so parse it once.
raw_df = pd.read_csv(data_path, sep="\t")
raw_df.columns = raw_df.columns.str.replace(" ", "_")

# Model name -> {"total": {type: value}, "average": {type: value}}.
summary_by_type = {}

for model_name in models:
    print(f"Processing model: {model_name}")

    model_dir = f"../models/{model_name}"
    # NOTE: joblib.load unpickles the file; only load model files you trust.
    model = joblib.load(os.path.join(model_dir, "model.pkl"))

    train_ids = pd.read_csv(os.path.join(model_dir, "training_variants.txt"), sep="\t")["ID"].tolist()

    # Keep only variants the model was not trained on; the copies make the
    # column assignments below operate on independent frames.
    df = filter_features(raw_df, model_name)
    df = df[~df[id_column].isin(train_ids)].copy()
    # Uses target_col, the label-column name defined in the plotting cell's
    # configuration; no other label-column variable is defined in this notebook.
    df[target_col] = df[target_col].replace({"PS3": 1, "BS3": 0})
    df = df[df["weight"] == 1].copy()

    numeric_cols = df.columns.difference([id_column, target_col, "weight"])
    df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors="coerce")
    df = df.dropna(subset=[target_col])

    trained_features = model.feature_name_
    X = df[trained_features].copy()

    explainer = shap.TreeExplainer(model)
    shap_vals_full = explainer.shap_values(X)
    # When a two-element list comes back, use the second entry.
    if isinstance(shap_vals_full, list) and len(shap_vals_full) == 2:
        shap_vals = shap_vals_full[1]
    else:
        shap_vals = shap_vals_full

    shap_means = np.abs(shap_vals).mean(axis=0)

    importance_df = pd.DataFrame({
        "feature": X.columns,
        "mean_abs_shap": shap_means
    })
    # Features without a metadata category are grouped under "Other".
    importance_df["type"] = importance_df["feature"].map(feature_type_map).fillna("Other")

    grouped = importance_df.groupby("type")["mean_abs_shap"]
    total_importance = grouped.sum()
    average_importance = grouped.mean()

    summary_by_type[model_name] = {
        "total": total_importance.to_dict(),
        "average": average_importance.to_dict()
    }

total_df = pd.DataFrame(index=unique_types, columns=models)
average_df = pd.DataFrame(index=unique_types, columns=models)

# Only categories listed in unique_types are written; a type a model has no
# features for is recorded as 0.0. The loop variable is deliberately not
# called ``model`` so the loaded model object above is not rebound to a string.
for name in models:
    for t in unique_types:
        total_df.loc[t, name] = summary_by_type[name]["total"].get(t, 0.0)
        average_df.loc[t, name] = summary_by_type[name]["average"].get(t, 0.0)

total_df = total_df.astype(float)
average_df = average_df.astype(float)

total_df.to_csv("../results/tables/shap/shap_total_importance_by_type.csv", index_label="Feature_Type")
average_df.to_csv("../results/tables/shap/shap_average_importance_by_type.csv", index_label="Feature_Type")

print("SHAP total and average importance by feature type saved.")

# Rescale each model's column so its entries sum to 1.
total_df_normalized = total_df.div(total_df.sum(axis=0), axis=1)

average_df_normalized = average_df.div(average_df.sum(axis=0), axis=1)

total_df_normalized.to_csv("../results/tables/shap/shap_total_importance_by_type_normalized.csv", index_label="Feature_Type")
average_df_normalized.to_csv("../results/tables/shap/shap_average_importance_by_type_normalized.csv", index_label="Feature_Type")

print("Normalized SHAP importance values saved.")

Processing model: FuncVEP_CTI


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Processing model: ClinVEP_CTI


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Processing model: FuncVEP_CTE


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Processing model: ClinVEP_CTE


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Processing model: FuncVEP_SP


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Processing model: ClinVEP_SP


  df[target_column] = df[target_column].replace({"PS3": 1, "BS3": 0})


SHAP total and average importance by feature type saved.
Normalized SHAP importance values saved.


