# In [None]:
# SHAP Explainability Analysis

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import shap
import joblib
from sklearn.metrics import confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# Load best model and data
# NOTE(review): joblib.load unpickles arbitrary objects; only point this at a
# model file that comes from a trusted source.
model = joblib.load('../models/best_ecommerce_model.pkl')
X_test = pd.read_csv('../data/processed/ecommerce_X_test.csv')
y_test = pd.read_csv('../data/processed/ecommerce_y_test.csv')

# Flatten the label frame once and fail early if it does not line up with the
# feature rows; otherwise the boolean masks further down raise an opaque
# broadcasting error.
y_true = y_test.values.flatten()
if len(y_true) != len(X_test):
    raise ValueError(
        f"y_test has {len(y_true)} values but X_test has {len(X_test)} rows"
    )

# Class index that is explained throughout; it matches the `[:, 1]` column
# taken from predict_proba below.
POSITIVE_CLASS = 1


def _select_class_output(raw_values, raw_expected, class_index):
    """Reduce TreeExplainer output to one 2-D SHAP matrix and a scalar base value.

    Depending on the shap version and the model type, ``shap_values`` for a
    classifier may come back as a list of per-class arrays, as a 3-D array
    with the model outputs on the last axis, or as a plain 2-D array, and
    ``expected_value`` may be a scalar or a per-class array.  The plotting
    code below indexes the values as ``values[row, :]`` and needs a single
    base value, so everything is normalised here.

    Args:
        raw_values: Return value of ``explainer.shap_values(X)``.
        raw_expected: ``explainer.expected_value``.
        class_index: Output/class to keep when several are present; clamped
            to the last available output.

    Returns:
        Tuple ``(values, base_value)`` where ``values`` is 2-D
        (rows x features) and ``base_value`` is a float.
    """
    if isinstance(raw_values, list):
        output_idx = min(class_index, len(raw_values) - 1)
        values = np.asarray(raw_values[output_idx])
    else:
        values = np.asarray(raw_values)
        output_idx = 0
        if values.ndim == 3:
            output_idx = min(class_index, values.shape[2] - 1)
            values = values[:, :, output_idx]

    expected = np.ravel(raw_expected)
    if expected.size > output_idx:
        base_value = float(expected[output_idx])
    else:
        base_value = float(expected[0])
    return values, base_value


# Initialize SHAP explainer
explainer = shap.TreeExplainer(model)
shap_values, expected_value = _select_class_output(
    explainer.shap_values(X_test), explainer.expected_value, POSITIVE_CLASS
)

# 1. Summary plot (global feature importance)
plt.figure(figsize=(12, 8))
# plot_size=None keeps the figure size set above; the default "auto" would
# resize the current figure and silently discard figsize.
shap.summary_plot(shap_values, X_test, plot_type="bar", plot_size=None, show=False)
plt.title("Global Feature Importance (SHAP)")
plt.tight_layout()
plt.savefig('shap_global_importance.png', dpi=300, bbox_inches='tight')
plt.show()

# 2. Summary plot with SHAP values
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values, X_test, plot_size=None, show=False)
plt.title("SHAP Value Distribution")
plt.tight_layout()
plt.savefig('shap_summary.png', dpi=300, bbox_inches='tight')
plt.show()

# 3. Force plots for specific cases
predictions = model.predict(X_test)
pred_proba = model.predict_proba(X_test)[:, 1]

# Row positions of each outcome type (labels are compared against 0/1).
true_positives = np.where((predictions == 1) & (y_true == 1))[0]
false_positives = np.where((predictions == 1) & (y_true == 0))[0]
false_negatives = np.where((predictions == 0) & (y_true == 1))[0]

# One force plot per outcome type, using the first matching row if any exists.
force_plot_cases = [
    ("True Positive", 'shap_true_positive.png', true_positives),
    ("False Positive", 'shap_false_positive.png', false_positives),
    ("False Negative", 'shap_false_negative.png', false_negatives),
]
for case_label, output_path, case_rows in force_plot_cases:
    if len(case_rows) == 0:
        continue
    row = case_rows[0]
    print(f"\n{case_label} Analysis:")
    # NOTE(review): the force plot is drawn in the explainer's output space,
    # which may differ from the probability shown in the title - confirm.
    shap.force_plot(expected_value, shap_values[row, :],
                    X_test.iloc[row, :], matplotlib=True, show=False)
    plt.title(f"{case_label} - Prediction: {pred_proba[row]:.3f}")
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()

# 4. Dependence plots for top features
feature_names = X_test.columns
# Mean absolute SHAP value per feature = global importance score.
shap_importance = np.abs(shap_values).mean(0)
top_features = pd.DataFrame({
    'feature': feature_names,
    'importance': shap_importance
}).sort_values('importance', ascending=False).head(10)['feature'].values

for feature in top_features[:3]:
    shap.dependence_plot(feature, shap_values, X_test, interaction_index=None, show=False)
    plt.title(f"SHAP Dependence Plot for {feature}")
    plt.tight_layout()
    # Path separators in a column name would otherwise point savefig at a
    # non-existent directory.
    safe_feature = str(feature).replace('/', '_').replace('\\', '_')
    plt.savefig(f'shap_dependence_{safe_feature}.png', dpi=300, bbox_inches='tight')
    plt.show()

# 5. Business recommendations based on SHAP analysis
# NOTE(review): the text below is static; nothing in this section reads the
# SHAP values computed earlier, so keep it in sync with the plots by hand.
banner_rule = "=" * 60
print("\n" + banner_rule)
print("BUSINESS RECOMMENDATIONS")
print(banner_rule)

# Each entry is (section heading, bullet lines); printed in this order.
recommendation_sections = (
    (
        "\n1. HIGH-RISK PATTERNS:",
        (
            "   - Transactions within 1 hour of signup: Higher fraud risk",
            "   - Unusually high transaction velocity: Monitor users with >5 transactions/hour",
            "   - Specific browser/source combinations: Some combinations show 3x higher fraud rate",
        ),
    ),
    (
        "\n2. VERIFICATION STRATEGIES:",
        (
            "   - Implement step-up authentication for transactions > $500",
            "   - Flag transactions from high-risk countries for manual review",
            "   - Monitor device/IP addresses used by multiple users",
        ),
    ),
    (
        "\n3. REAL-TIME MONITORING:",
        (
            "   - Set up alerts for unusual purchase value deviations",
            "   - Monitor weekend/night-time transactions more closely",
            "   - Track user behavior changes (sudden increase in transaction frequency)",
        ),
    ),
    (
        "\n4. MODEL DEPLOYMENT:",
        (
            "   - Use ensemble of models for critical decisions",
            "   - Implement confidence thresholds for automated actions",
            "   - Regular model retraining with new fraud patterns",
        ),
    ),
)

for section_heading, section_bullets in recommendation_sections:
    print(section_heading)
    for bullet_line in section_bullets:
        print(bullet_line)