In [1]:
# --- 1. IMPORT NECESSARY LIBRARIES ---
import pandas as pd
import xgboost as xgb
import shap
import os
import joblib # Used for loading the trained models
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# --- 2. SETUP AND LOOP THROUGH EACH CLASS ---
# This script assumes you have a 'trained_models' folder with the 4 saved models
# and a 'data' folder with the 4 CSV files.

# Create a directory to save the SHAP plots
if not os.path.exists('shap_plots'):
    os.makedirs('shap_plots')

# Loop through each of the 4 classes
for i in range(1, 5):
    print(f"--- Processing and Generating SHAP Plots for Class {i} ---")
    
    # --- A. Load the Model and Data ---
    model_path = f'trained_models/xgb_model_class_{i}.joblib'
    data_path = f'data/class_{i}_modeling_data.csv'
    
    try:
        model = joblib.load(model_path)
        df = pd.read_csv(data_path)
    except FileNotFoundError as e:
        print(f"Error: Could not find a necessary file. {e}")
        continue

    # We need the test set to explain the model's predictions on unseen data
    X = df.drop('udpyilal', axis=1)
    y = df['udpyilal']
    
    # Re-create the same train/test split to get the correct test set
    # Using the same random_state ensures the split is identical to the one in Step 2
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    
    # --- B. Calculate SHAP Values ---
    # SHAP needs an "explainer" that is specific to the type of model.
    # For XGBoost and other tree-based models, we use the TreeExplainer.
    explainer = shap.TreeExplainer(model)
    
    print("Calculating SHAP values for the test set...")
    # The shap_values object contains the impact of every feature for every sample
    shap_values = explainer.shap_values(X_test)
    
    # --- C. Generate and Save SHAP Bar Plot (Global Importance) ---
    # This plot shows the mean absolute SHAP value for each feature,
    # giving a simple ranking of feature importance.
    print("Generating SHAP bar plot...")
    plt.figure()
    shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
    plt.title(f'Feature Importance - Class {i}')
    # Save the plot
    bar_plot_path = f'shap_plots/class_{i}_bar_plot.png'
    plt.savefig(bar_plot_path, bbox_inches='tight')
    plt.close() # Close the plot to free up memory
    print(f"Saved bar plot to: {bar_plot_path}")

    # --- D. Generate and Save SHAP Summary Plot (Beeswarm) ---
    # This is the most detailed plot. Each dot is a person in the test set.
    # - X-axis: The SHAP value (impact on model output). Positive values increase risk.
    # - Y-axis: Features, ranked by importance.
    # - Color: The value of the feature (red=high, blue=low).
    print("Generating SHAP summary (beeswarm) plot...")
    plt.figure()
    shap.summary_plot(shap_values, X_test, show=False)
    plt.title(f'SHAP Summary Plot - Class {i}')
    # Save the plot
    summary_plot_path = f'shap_plots/class_{i}_summary_plot.png'
    plt.savefig(summary_plot_path, bbox_inches='tight')
    plt.close() # Close the plot to free up memory
    print(f"Saved summary plot to: {summary_plot_path}\n")

print("--- Step 3 Complete. All SHAP plots have been generated and saved. ---")


--- Processing and Generating SHAP Plots for Class 1 ---
Calculating SHAP values for the test set...
Generating SHAP bar plot...
Saved bar plot to: shap_plots/class_1_bar_plot.png
Generating SHAP summary (beeswarm) plot...
Saved summary plot to: shap_plots/class_1_summary_plot.png

--- Processing and Generating SHAP Plots for Class 2 ---
Calculating SHAP values for the test set...
Generating SHAP bar plot...
Saved bar plot to: shap_plots/class_2_bar_plot.png
Generating SHAP summary (beeswarm) plot...
Saved summary plot to: shap_plots/class_2_summary_plot.png

--- Processing and Generating SHAP Plots for Class 3 ---
Calculating SHAP values for the test set...
Generating SHAP bar plot...
Saved bar plot to: shap_plots/class_3_bar_plot.png
Generating SHAP summary (beeswarm) plot...
Saved summary plot to: shap_plots/class_3_summary_plot.png

--- Processing and Generating SHAP Plots for Class 4 ---
Calculating SHAP values for the test set...
Generating SHAP bar plot...
Saved bar plot to: sha

In [1]:
# --- 1. IMPORT NECESSARY LIBRARIES ---
import pandas as pd
import xgboost as xgb
import shap
import os
import joblib # Used for loading the trained models
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import numpy as np
import re # For regular expressions to group column names

# --- 2. SETUP AND LOOP THROUGH EACH CLASS ---
# This script assumes you have a 'trained_models' folder with the 4 saved models
# and a 'data' folder with the 4 CSV files.

# Create a directory to save the SHAP plots
if not os.path.exists('shap_plots1'):
    os.makedirs('shap_plots1')

# Loop through each of the 4 classes
for i in range(1, 5):
    print(f"--- Processing and Generating SHAP Plots for Class {i} ---")
    
    # --- A. Load the Model and Data ---
    model_path = f'trained_models/xgb_model_class_{i}.joblib'
    data_path = f'data/class_{i}_modeling_data.csv'
    
    try:
        model = joblib.load(model_path)
        df = pd.read_csv(data_path)
    except FileNotFoundError as e:
        print(f"Error: Could not find a necessary file. {e}")
        continue

    X = df.drop('udpyilal', axis=1)
    y = df['udpyilal']
    
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    
    # --- B. Calculate SHAP Values ---
    explainer = shap.TreeExplainer(model)
    print("Calculating SHAP values for the test set...")
    shap_values = explainer.shap_values(X_test)
    
    # --- C. Generate and Save Original (Encoded) SHAP Plots ---
    # We'll still save the original plots for detailed inspection if needed.
    print("Generating original SHAP plots (with one-hot encoded features)...")
    plt.figure()
    shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
    plt.title(f'Feature Importance (Encoded) - Class {i}')
    plt.savefig(f'shap_plots/class_{i}_bar_plot_encoded.png', bbox_inches='tight')
    plt.close()

    plt.figure()
    shap.summary_plot(shap_values, X_test, show=False)
    plt.title(f'SHAP Summary Plot (Encoded) - Class {i}')
    plt.savefig(f'shap_plots/class_{i}_summary_plot_encoded.png', bbox_inches='tight')
    plt.close()

    # --- NEW: D. Aggregate SHAP values for original features ---
    print("Aggregating SHAP values for original categorical features...")
    
    # Create a DataFrame of the SHAP values for easier manipulation
    shap_df = pd.DataFrame(shap_values, columns=X_test.columns)
    
    # Get the base feature names by splitting the column names (e.g., 'age_cat.18.25' -> 'age_cat')
    # We use a regular expression to get the part of the string before the first dot.
    base_features = [re.split(r'\.', f)[0] for f in X_test.columns]
    
    # Create a new DataFrame to store the aggregated SHAP values
    # We take the absolute SHAP values because we want to measure total impact
    abs_shap_df = shap_df.abs()
    
    # Group by the base feature name and sum the SHAP values
    aggregated_shap_values = abs_shap_df.groupby(base_features, axis=1).sum()

    # Now, calculate the mean of these aggregated values for the bar plot
    mean_aggregated_shap = aggregated_shap_values.mean(axis=0).sort_values(ascending=False)

    # --- E. Generate and Save Aggregated Bar Plot ---
    print("Generating new aggregated SHAP bar plot...")
    plt.figure(figsize=(10, 8))
    mean_aggregated_shap.plot(kind='barh', color='dodgerblue')
    plt.gca().invert_yaxis() # To show the most important feature at the top
    plt.title(f'Aggregated Feature Importance - Class {i}')
    plt.xlabel('mean(|SHAP value|) (average impact on model output magnitude)')
    plt.grid(axis='x', linestyle='--', alpha=0.6)
    
    # Save the new aggregated plot
    agg_plot_path = f'shap_plots/class_{i}_bar_plot_aggregated.png'
    plt.savefig(agg_plot_path, bbox_inches='tight')
    plt.close()
    print(f"Saved aggregated bar plot to: {agg_plot_path}\n")

print("--- Step 3 Complete. All SHAP plots have been generated and saved. ---")


--- Processing and Generating SHAP Plots for Class 1 ---
Calculating SHAP values for the test set...
Generating original SHAP plots (with one-hot encoded features)...
Aggregating SHAP values for original categorical features...
Generating new aggregated SHAP bar plot...


  aggregated_shap_values = abs_shap_df.groupby(base_features, axis=1).sum()


Saved aggregated bar plot to: shap_plots/class_1_bar_plot_aggregated.png

--- Processing and Generating SHAP Plots for Class 2 ---
Calculating SHAP values for the test set...
Generating original SHAP plots (with one-hot encoded features)...
Aggregating SHAP values for original categorical features...
Generating new aggregated SHAP bar plot...


  aggregated_shap_values = abs_shap_df.groupby(base_features, axis=1).sum()


Saved aggregated bar plot to: shap_plots/class_2_bar_plot_aggregated.png

--- Processing and Generating SHAP Plots for Class 3 ---
Calculating SHAP values for the test set...
Generating original SHAP plots (with one-hot encoded features)...
Aggregating SHAP values for original categorical features...
Generating new aggregated SHAP bar plot...


  aggregated_shap_values = abs_shap_df.groupby(base_features, axis=1).sum()


Saved aggregated bar plot to: shap_plots/class_3_bar_plot_aggregated.png

--- Processing and Generating SHAP Plots for Class 4 ---
Calculating SHAP values for the test set...
Generating original SHAP plots (with one-hot encoded features)...
Aggregating SHAP values for original categorical features...
Generating new aggregated SHAP bar plot...


  aggregated_shap_values = abs_shap_df.groupby(base_features, axis=1).sum()


Saved aggregated bar plot to: shap_plots/class_4_bar_plot_aggregated.png

--- Step 3 Complete. All SHAP plots have been generated and saved. ---
