# Model Explainability with SHAP

Understanding **why** our model makes predictions using SHAP (SHapley Additive exPlanations) values.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import shap

# Load the feature-engineered train/test splits and their encoded labels
X_train = np.load('data/X_train_fe.npy')
X_test = np.load('data/X_test_fe.npy')
y_train = np.load('data/y_train_encoded.npy')
y_test = np.load('data/y_test_encoded.npy')

# Load best model, falling back to the rf_classifier artifact if the final
# model cannot be loaded. Catch Exception instead of using a bare `except:`
# so KeyboardInterrupt/SystemExit are not swallowed, and report why the
# fallback was taken.
try:
    model = joblib.load('data/best_model_final.joblib')
    print('Loaded best_model_final.joblib')
except Exception as load_error:
    print(f'Could not load best_model_final.joblib ({load_error!r})')
    model = joblib.load('data/rf_classifier.joblib')
    print('Loaded rf_classifier.joblib as fallback')

# Load preprocessor, feature selector and label encoder
preprocessor_fe = joblib.load('data/preprocessor_fe.joblib')
selector = joblib.load('data/selector.joblib')
le = joblib.load('data/label_encoder.joblib')

# Get feature names: the preprocessor's output names, filtered by the
# selector's support mask so only the selected features remain
feature_names_all = preprocessor_fe.get_feature_names_out()
selected_mask = selector.get_support()
feature_names = feature_names_all[selected_mask]

print(f'\nModel type: {type(model).__name__}')
print(f'Number of features: {len(feature_names)}')
print(f'Classes: {le.classes_}')

## Initialize SHAP Explainer

In [None]:
# Use TreeExplainer for tree-based models (faster than KernelExplainer)
explainer = shap.TreeExplainer(model)

# Calculate SHAP values for test set (use subset for speed)
sample_size = min(500, len(X_test))
X_test_sample = X_test[:sample_size]
y_test_sample = y_test[:sample_size]

print(f'Calculating SHAP values for {sample_size} test samples...')
shap_values = explainer.shap_values(X_test_sample)
print('Done!')

# Newer SHAP releases return multi-class values as a single
# (n_samples, n_features, n_classes) array instead of a list holding one
# (n_samples, n_features) array per class. The rest of this notebook indexes
# shap_values[class_id], so normalise to the per-class list layout.
if isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
    shap_values = [shap_values[:, :, class_idx]
                   for class_idx in range(shap_values.shape[2])]

# For multi-class, shap_values is a list of arrays (one per class)
print(f'\nSHAP values shape: {len(shap_values)} classes x {shap_values[0].shape}')

## Global Feature Importance (Summary Plot)

In [None]:
# Global view: a single summary plot of feature importance over all
# predictions, labelled with the encoder's class names
summary_kwargs = {
    'feature_names': feature_names,
    'class_names': le.classes_,
    'show': False,  # keep the figure open so it can be saved below
}
shap.summary_plot(shap_values, X_test_sample, **summary_kwargs)
plt.tight_layout()
plt.savefig('data/shap_summary_all_classes.png', bbox_inches='tight', dpi=150)
plt.show()

In [None]:
# One bar-type summary plot per class, each written to its own PNG
for class_pos, class_label in enumerate(le.classes_):
    plt.figure()
    shap.summary_plot(
        shap_values[class_pos],
        X_test_sample,
        feature_names=feature_names,
        plot_type='bar',
        show=False,  # defer display until the title is set and the file saved
    )
    plt.title(f'Feature Importance for {class_label}', fontweight='bold')
    plt.tight_layout()
    importance_path = f'data/shap_importance_{class_label.lower()}.png'
    plt.savefig(importance_path, dpi=150, bbox_inches='tight')
    plt.show()

## Detailed Beeswarm Plots

In [None]:
# Beeswarm plot for each class (shows feature value distribution and impact)
for i, class_name in enumerate(le.classes_):
    plt.figure(figsize=(10, 8))
    # plot_size=None leaves the 10x8 figure created above untouched; the
    # default plot_size='auto' would resize it from the number of displayed
    # features and silently discard the requested figsize.
    shap.summary_plot(shap_values[i], X_test_sample, feature_names=feature_names,
                     show=False, max_display=15, plot_size=None)
    plt.title(f'SHAP Values for {class_name} Prediction', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'data/shap_beeswarm_{class_name.lower()}.png', dpi=150, bbox_inches='tight')
    plt.show()

## Individual Prediction Explanations

In [None]:
# Let's explain some specific predictions: pick one sample index per class,
# preferring a sample the model classifies correctly.

# Predict once for the whole sample and reuse the result, instead of calling
# model.predict again for every class and for every selected index.
sample_preds = model.predict(X_test_sample)

example_indices = []
for class_id in range(len(le.classes_)):
    # Get indices where true label matches class_id
    indices = np.where(y_test_sample == class_id)[0]
    if len(indices) == 0:
        # Class is absent from the sample: no example is added for it, so
        # example_indices may end up shorter than the number of classes.
        continue
    # Find a correctly predicted example; if none is correct, use the first one
    correct = indices[sample_preds[indices] == class_id]
    example_indices.append(correct[0] if len(correct) > 0 else indices[0])

print(f'Selected {len(example_indices)} example predictions to explain')
for idx in example_indices:
    true_class = le.classes_[y_test_sample[idx]]
    pred_class = le.classes_[sample_preds[idx]]
    print(f'  Index {idx}: True={true_class}, Predicted={pred_class}')

In [None]:
# Waterfall plot per selected example, explaining the class the model predicted
for sample_idx in example_indices:
    sample_row = X_test_sample[sample_idx:sample_idx+1]
    pred_class_id = model.predict(sample_row)[0]
    pred_class = le.classes_[pred_class_id]
    true_class = le.classes_[y_test_sample[sample_idx]]

    print(f'\n--- Explaining prediction for sample {sample_idx} ---')
    print(f'True label: {true_class}')
    print(f'Predicted: {pred_class}')

    # Bundle this sample's SHAP values for the predicted class together with
    # that class's expected value and the raw feature row
    shap_explanation = shap.Explanation(
        values=shap_values[pred_class_id][sample_idx],
        base_values=explainer.expected_value[pred_class_id],
        data=X_test_sample[sample_idx],
        feature_names=feature_names,
    )

    plt.figure(figsize=(10, 6))
    shap.plots.waterfall(shap_explanation, max_display=15, show=False)
    waterfall_title = f'Why predicted {pred_class}? (True: {true_class})'
    plt.title(waterfall_title, fontsize=12, fontweight='bold')
    plt.tight_layout()
    waterfall_path = f'data/shap_waterfall_sample{sample_idx}.png'
    plt.savefig(waterfall_path, dpi=150, bbox_inches='tight')
    plt.show()

## Force Plots (Alternative Individual Explanations)

In [None]:
# Force plot for a Graduate prediction.
# NOTE(review): position 2 is only the Graduate example if every class
# contributed an entry to example_indices — confirm against the selection cell.
if len(example_indices) > 2:
    graduate_idx = example_indices[2]
else:
    graduate_idx = example_indices[0]

# Explain whichever class the model actually predicts for this sample
graduate_row = X_test_sample[graduate_idx:graduate_idx+1]
pred_class_id = model.predict(graduate_row)[0]

shap.force_plot(
    explainer.expected_value[pred_class_id],
    shap_values[pred_class_id][graduate_idx],
    X_test_sample[graduate_idx],
    feature_names=feature_names,
    show=False,
    matplotlib=True,
)
plt.tight_layout()
plt.savefig('data/shap_force_plot.png', bbox_inches='tight', dpi=150)
plt.show()

## Dependence Plots (Feature Interactions)

In [None]:
# Get top features by absolute mean SHAP value: average |SHAP| over samples
# within each class, then average those per-class vectors over classes
mean_abs_shap = np.array([np.abs(sv).mean(axis=0) for sv in shap_values]).mean(axis=0)
# argsort is ascending, so the most important feature is last
top_feature_indices = np.argsort(mean_abs_shap)[-5:]
top_features = feature_names[top_feature_indices]

print('Top 5 most important features overall:')
for i, feat in enumerate(reversed(top_features)):
    print(f'  {i+1}. {feat}')

# Create dependence plots for top features
for class_id, class_name in enumerate(le.classes_):
    print(f'\nDependence plots for {class_name} class...')
    for feat_idx in top_feature_indices[-3:]:  # Top 3
        feat_name = feature_names[feat_idx]
        # Draw on an explicit axes. Without ax=, dependence_plot opens its own
        # figure, which would leave a figure made by plt.figure() empty and
        # ignore the requested 8x5 size.
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            shap.dependence_plot(
                feat_idx,
                shap_values[class_id],
                X_test_sample,
                feature_names=feature_names,
                ax=ax,
                show=False
            )
            ax.set_title(f'{feat_name} impact on {class_name} prediction', fontweight='bold')
            fig.tight_layout()
            # Make the feature name filesystem-friendly and keep it short
            safe_feat_name = feat_name.replace('/', '_').replace(' ', '_')[:30]
            fig.savefig(f'data/shap_dependence_{class_name}_{safe_feat_name}.png',
                        dpi=150, bbox_inches='tight')
            plt.show()
        except Exception as e:
            # Best effort: report the failure, discard the unfinished figure
            # and carry on with the next feature
            print(f'  Skipping {feat_name}: {str(e)[:50]}')
            plt.close(fig)
            continue

## Key Insights from SHAP Analysis

In [None]:
# Rank features per class by mean absolute SHAP value and collect the top five
insights_data = []
for class_id, class_name in enumerate(le.classes_):
    mean_abs_shap_class = np.abs(shap_values[class_id]).mean(axis=0)
    # Reverse the ascending argsort, then keep the five largest
    top_5_indices = np.argsort(mean_abs_shap_class)[::-1][:5]

    print(f'\n{class_name} - Top 5 most influential features:')
    for position, feature_idx in enumerate(top_5_indices):
        rank = position + 1
        record = {
            'Class': class_name,
            'Rank': rank,
            'Feature': feature_names[feature_idx],
            'Mean |SHAP|': mean_abs_shap_class[feature_idx],
        }
        print(f"  {rank}. {record['Feature']}: {record['Mean |SHAP|']:.4f}")
        insights_data.append(record)

# Persist the ranking table as CSV
insights_df = pd.DataFrame(insights_data)
insights_df.to_csv('data/shap_insights.csv', index=False)
print('\nSaved SHAP insights to data/shap_insights.csv')

## Summary

**What we learned from SHAP:**

1. **Most important predictors** (global):
   - 2nd semester grades and approved units
   - 1st semester performance
   - Engineered features like approval_rate

2. **Class-specific patterns**:
   - **Dropout**: Low grades and few approved units strongly predict dropout
   - **Graduate**: High grades and consistent approval rates
   - **Enrolled**: Mixed signals, harder to predict (likely due to class imbalance)

3. **Feature interactions**:
   - Dependence plots show non-linear relationships
   - Some features have threshold effects

4. **Model transparency**:
   - Waterfall plots explain individual predictions
   - Can identify why a student was flagged as at-risk

**Business value**: These explanations help stakeholders trust the model and design targeted interventions for at-risk students.