In [1]:
import warnings
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from matplotlib.patches import Patch
from sklearn.ensemble import RandomForestClassifier

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# CONFIG -- input/output locations, relative to the notebook's working directory
TRAINING_DATA_PATH = "data/training_set.csv"
TEST_DATA_PATH = "data/test_set.csv"
OUTPUT_MODEL_PATH = "models/rf_final_model.pkl"

# Random-forest settings, kept in one place so they are easy to find and tune
N_ESTIMATORS = 1000
RANDOM_STATE = 2005

# Create the output directory *before* the expensive fit: otherwise a missing
# "models/" folder only surfaces when joblib.dump() fails after training.
Path(OUTPUT_MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)

# Load data
train_df = pd.read_csv(TRAINING_DATA_PATH)

# Define features (column names as they appear in the training/test CSV files)
feature_cols = [
    'AllofUs_ALL_AF', 'AlphaMissense_score', 'BayesDel_addAF_score', 'BayesDel_noAF_score',
    'CADD_raw', 'ClinPred_score', 'DANN_score', 'DEOGEN2_score', 'ESM1b_score',
    'Eigen-PC-raw_coding', 'Eigen-raw_coding', 'GERP++_RS', 'GERP_91_mammals',
    'M-CAP_score', 'MPC_score', 'MVP_score', 'MetaLR_score', 'MetaRNN_score',
    'MetaSVM_score', 'MutFormer_score', 'MutationAssessor_score', 'MutationTaster_score',
    'PROVEAN_score', 'Polyphen2_HDIV_score', 'Polyphen2_HVAR_score', 'PrimateAI_score',
    'REVEL_score', 'RegeneronME_ALL_AF', 'SIFT4G_score', 'SIFT_score', 'VEST4_score',
    'bStatistic', 'dbNSFP_POPMAX_AF', 'fathmm-XF_coding_score', 'gMVP_score',
    'gnomAD4.1_joint_AF', 'phastCons100way_vertebrate', 'phastCons17way_primate',
    'phastCons470way_mammalian', 'phyloP100way_vertebrate',
    'phyloP17way_primate', 'phyloP470way_mammalian'
]

# Extract input features and class labels.
# .copy() keeps later edits to X_train / y_train from touching train_df.
X_train = train_df[feature_cols].copy()
y_train = train_df["label"].copy()

# Train the RF model (same hyperparameters as in R)
model = RandomForestClassifier(n_estimators=N_ESTIMATORS, random_state=RANDOM_STATE)
model.fit(X_train, y_train)

# Save the fitted model so it can be reloaded without retraining
joblib.dump(model, OUTPUT_MODEL_PATH)
print(f"Model has been saved to：{OUTPUT_MODEL_PATH}")


Model has been saved to：models/rf_final_model.pkl


In [3]:
# Load test data for SHAP analysis
test_df = pd.read_csv(TEST_DATA_PATH)
X_test = test_df[feature_cols].copy()

# Create SHAP explainer and calculate SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

# Keep only the SHAP values of the class at index 1 of the model's outputs
# (presumably the positive class of a binary 0/1 label -- TODO confirm).
# The return layout of shap_values() depends on the installed shap version:
#   * a list holding one (n_samples, n_features) array per class, or
#   * a single (n_samples, n_features, n_classes) array.
# Handle both so the notebook does not break on a different shap release.
if isinstance(shap_values, list):
    shap_values_flat = np.asarray(shap_values[1])
else:
    shap_values = np.asarray(shap_values)
    if shap_values.ndim != 3:
        raise ValueError(
            "Expected per-class SHAP values with 3 dimensions "
            f"(samples, features, classes), got shape {shap_values.shape}"
        )
    shap_values_flat = shap_values[:, :, 1]

# SHAP beeswarm plot
beeswarm_path = Path('plots/shap_beeswarm_testset.png')
beeswarm_path.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create folders
shap.summary_plot(
    shap_values_flat,
    features=X_test,
    feature_names=feature_cols,
    max_display=len(feature_cols),  # show every feature, however many are configured
    show=False
)
plt.savefig(beeswarm_path, dpi=300, bbox_inches='tight')
plt.close()
print("SHAP beeswarm plot saved")

# SHAP bar chart of mean absolute values.
# argsort is ascending and barh draws bottom-up, so the largest bar ends up on top.
mean_shap = np.abs(shap_values_flat).mean(axis=0)
sorted_idx = np.argsort(mean_shap)
sorted_shap_values = mean_shap[sorted_idx]
sorted_feature_names = np.array(feature_cols)[sorted_idx]

# Define correlated groups
# NOTE(review): "MutScore_score" below is not in feature_cols, which contains
# "MutFormer_score" instead. One of the two may be a typo -- confirm which
# predictor is intended. The check after this dict reports such mismatches.
feature_groups = {
    "Group_1": [
        "AlphaMissense_score", "BayesDel_addAF_score", "BayesDel_noAF_score", "CADD_raw", "ClinPred_score",
        "Eigen-PC-raw_coding", "Eigen-raw_coding", "GERP++_RS", "MetaLR_score", "MetaRNN_score",
        "MetaSVM_score", "MutScore_score", "MutationTaster_score", "Polyphen2_HDIV_score",
        "Polyphen2_HVAR_score", "PrimateAI_score", "REVEL_score", "VEST4_score",
        "fathmm-XF_coding_score", "gMVP_score", "phastCons100way_vertebrate",
        "phyloP100way_vertebrate", "phyloP470way_mammalian"
    ],
    "Group_2": ["RegeneronME_ALL_AF", "gnomAD4.1_joint_AF"],
    "Group_3": ["SIFT4G_score", "SIFT_score"]
}

# A grouped name that is not a model feature can never colour a bar; surface it
# instead of letting the mismatch pass silently.
grouped_features = {feat for feats in feature_groups.values() for feat in feats}
unknown_group_features = sorted(grouped_features - set(feature_cols))
if unknown_group_features:
    warnings.warn(
        "feature_groups lists names that are not in feature_cols and will "
        f"therefore never be coloured: {unknown_group_features}"
    )

# Define group colors ("Other" is the fallback for features in no group)
group_colors = {
    "Group_1": "tomato",
    "Group_2": "mediumseagreen",
    "Group_3": "royalblue",
    "Other": "lightgrey"
}

# Map features to colors
feature_to_color = {
    feat: group_colors[group]
    for group, features in feature_groups.items()
    for feat in features
}

bar_colors = [feature_to_color.get(f, group_colors["Other"]) for f in sorted_feature_names]

bar_chart_path = Path('plots/shap_bar_testset_feature_groups.png')
bar_chart_path.parent.mkdir(parents=True, exist_ok=True)

plt.figure(figsize=(8, 18))
plt.barh(sorted_feature_names, sorted_shap_values, color=bar_colors)
plt.xlabel('Mean(|SHAP value|)', fontsize=12)
plt.title('Global Feature Importance (SHAP, Test Set)', fontsize=14)
plt.grid(True, axis='x', linestyle='--', alpha=0.6)
# Without a legend the bar colours cannot be interpreted from the saved figure.
legend_handles = [Patch(facecolor=color, label=name) for name, color in group_colors.items()]
plt.legend(handles=legend_handles, title="Correlated feature group", loc="lower right")
plt.tight_layout()
plt.savefig(bar_chart_path, dpi=300, bbox_inches='tight')
plt.close()
print("SHAP bar chart saved")

SHAP beeswarm plot saved
SHAP bar chart saved


In [4]:
# Save SHAP value matrix: one row per test-set sample, one column per feature
# (columns are labelled with feature_cols, in the same order as the model input).
shap_df = pd.DataFrame(shap_values_flat, columns=feature_cols)
# Plain string literal: the path has no placeholders, so no f-prefix is needed.
shap_df.to_csv("data/shap_values_testset.csv", index=False)
print("SHAP values matrix saved.")

SHAP values matrix saved.


### Compare SHAP & Gini

In [5]:
# Load SHAP data (the matrix written to CSV earlier in this notebook) and
# reduce it to one global importance per feature: the mean absolute SHAP value.
shap_df = pd.read_csv("data/shap_values_testset.csv")
shap_mean = shap_df.abs().mean().rename("mean_abs_shap")

# Load Gini importances from a tab-separated file whose first column holds the
# feature names. NOTE(review): this file is not produced by this notebook --
# presumably exported from the R model; confirm it describes the same model.
gini_df = pd.read_csv("data/rf_feature_importance.txt", sep="\t", index_col=0)
# Prefer the "MeanDecreaseGini" column; otherwise fall back to the first column.
gini_col_name = "MeanDecreaseGini" if "MeanDecreaseGini" in gini_df.columns else gini_df.columns[0]
gini_importance = gini_df[gini_col_name].rename("gini")

# The inner join below keeps only features present in BOTH sources. Report
# anything it would drop (e.g. names spelled differently in the Gini file),
# because a silently shortened chart is easy to miss.
only_in_shap = shap_mean.index.difference(gini_importance.index)
only_in_gini = gini_importance.index.difference(shap_mean.index)
if len(only_in_shap) or len(only_in_gini):
    warnings.warn(
        "Feature names differ between SHAP and Gini inputs; unmatched features "
        f"are excluded from the comparison. Only in SHAP: {list(only_in_shap)}; "
        f"only in Gini: {list(only_in_gini)}"
    )

# Merge
comparison_df = pd.concat([shap_mean, gini_importance], axis=1, join="inner")
if comparison_df.empty:
    raise ValueError(
        "No feature names are shared between the SHAP values and the Gini "
        "importance file; nothing to compare."
    )

# Normalization: scale each measure by its own maximum so the top feature of
# each is 1.0 and both fit on a common axis.
comparison_df["shap_norm"] = comparison_df["mean_abs_shap"] / comparison_df["mean_abs_shap"].max()
comparison_df["gini_norm"] = comparison_df["gini"] / comparison_df["gini"].max()

# Sort ascending by SHAP so the most important feature is drawn at the top
comparison_df_sorted = comparison_df.sort_values("shap_norm", ascending=True)

# Plot: two horizontal bars per feature, offset by bar_width
plt.figure(figsize=(10, 16))
bar_width = 0.4
indices = np.arange(len(comparison_df_sorted))

plt.barh(indices, comparison_df_sorted["shap_norm"], height=bar_width, label="SHAP", color="dodgerblue")
plt.barh(indices + bar_width, comparison_df_sorted["gini_norm"], height=bar_width, label="GiniIndex", color="orange")

# Centre each tick label between the feature's two bars
plt.yticks(indices + bar_width / 2, comparison_df_sorted.index)
plt.xlabel("Normalized Feature Importance", fontsize=12)
plt.title("SHAP vs GiniIndex Feature Importance", fontsize=14)
plt.legend()
plt.grid(axis='x', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.savefig("plots/shap_vs_gini_comparison.png", dpi=300, bbox_inches='tight')
plt.close()
print("SHAP vs Gini bar chart saved.")

SHAP vs Gini bar chart saved.
