In [None]:
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap

In [None]:
def _select_output(shap_values, class_index: int = -1) -> np.ndarray:
    """Reduce SHAP values to a single 2-D (n_samples, n_features) array.

    Multi-output tree models (e.g. classifiers) yield either a list with one
    (n_samples, n_features) array per output (older shap releases) or one
    (n_samples, n_features, n_outputs) array (newer releases). The plots and
    the importance table need a single 2-D matrix, so one output is selected.
    Single-output values pass through unchanged.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[class_index])
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[:, :, class_index]
    return values


def _safe_filename(label) -> str:
    """Turn an arbitrary (possibly non-string) label into a filename-safe token.

    Every character outside [A-Za-z0-9_.-] becomes "_". Spaces map to "_"
    exactly as before, and path separators can no longer escape the output dir.
    """
    return re.sub(r"[^\w.-]", "_", str(label))


def evaluation(
    model,
    X: pd.DataFrame,
    name: str = None,
    max_display: int = 20,
    sample_size: int = None,
    random_state: int = 2025,
    output_dir: str = "Results",
    n_dependence: int = 10,
    class_index: int = -1,
    approximate: bool = False,
) -> pd.DataFrame:
    """Explain a tree model with SHAP and save plots plus an importance table.

    Writes into ``output_dir``:
      * ``<name>_shap_beeswarm.png`` and ``<name>_shap_bar.png`` (summary plots)
      * ``<name>_mean_abs_shap.csv`` (features ranked by mean |SHAP|)
      * ``<name>_dependence_<feature>.png`` for the top ``n_dependence`` features

    Args:
        model: A model accepted by ``shap.TreeExplainer``.
        X: Feature frame to explain; its columns name the features.
        name: Label used in titles and filenames ("model" if None).
        max_display: Maximum number of features shown in the summary plots.
        sample_size: If given, explain a random subsample of at most this many rows.
        random_state: Seed for the subsample, so reruns are reproducible.
        output_dir: Directory for all artefacts (created if missing).
        n_dependence: How many top-ranked features get a dependence plot.
        class_index: Which output to explain when the model is multi-output
            (default: last output; presumably the positive class for binary
            classifiers — confirm for your model).
        approximate: Forwarded to ``explainer.shap_values``. True selects the fast
            Saabas approximation; the default False computes exact Tree SHAP.

    Returns:
        DataFrame with columns ``feature`` and ``mean_abs_shap``, sorted
        descending by importance (the same table that is written to CSV).

    Raises:
        ValueError: If there are no rows to explain.
    """
    if sample_size is not None:
        n = min(sample_size, len(X))
        X = X.sample(n, random_state=random_state)
    if len(X) == 0:
        raise ValueError("evaluation() needs at least one row in X to explain")

    label = "model" if name is None else str(name)
    file_prefix = _safe_filename(label)
    os.makedirs(output_dir, exist_ok=True)

    # `approximate` is passed to shap_values() itself. As far as the shap source
    # shows, the TreeExplainer constructor flag is only consulted by explainer(X),
    # so the previous TreeExplainer(model, approximate=True) computed exact values.
    explainer = shap.TreeExplainer(model)
    shap_values = _select_output(
        explainer.shap_values(X, approximate=approximate), class_index
    )

    # Summary plots: summary_plot draws onto the current figure, so create it
    # first and close exactly that figure afterwards.
    for plot_type, suffix, title in ((None, "beeswarm", "Beeswarm"), ("bar", "bar", "Bar")):
        fig = plt.figure(figsize=(8, 6))
        shap.summary_plot(
            shap_values, X, plot_type=plot_type, max_display=max_display, show=False
        )
        plt.title(f"{label} SHAP Summary ({title})")
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{file_prefix}_shap_{suffix}.png"))
        plt.close(fig)

    # Global importance table: mean absolute SHAP value per feature.
    mean_abs_shap_df = (
        pd.DataFrame(
            {"feature": X.columns, "mean_abs_shap": np.abs(shap_values).mean(axis=0)}
        )
        .sort_values(by="mean_abs_shap", ascending=False)
        .reset_index(drop=True)
    )
    mean_abs_shap_df.to_csv(
        os.path.join(output_dir, f"{file_prefix}_mean_abs_shap.csv"), index=False
    )

    # Dependence plots for the most important features.
    columns = list(X.columns)
    for feature in mean_abs_shap_df["feature"].tolist()[:n_dependence]:
        fig, ax = plt.subplots(figsize=(8, 6))
        # Pass ax so shap draws here instead of opening a new (never-closed)
        # figure. Pass the column position so integer labels are not
        # misinterpreted as positions.
        shap.dependence_plot(
            columns.index(feature),
            shap_values,
            X,
            interaction_index=None,
            ax=ax,
            show=False,
        )
        ax.set_title(f"{label} SHAP Dependence for {feature}")
        fig.tight_layout()
        fig.savefig(
            os.path.join(
                output_dir, f"{file_prefix}_dependence_{_safe_filename(feature)}.png"
            )
        )
        plt.close(fig)

    return mean_abs_shap_df