# USG Failure Prediction - SHAP Interpretability Analysis

**Objective:** Explain model predictions for business stakeholders

**Key Analyses:**
- SHAP summary plot (global feature importance)
- SHAP waterfall plots (individual predictions)
- SHAP dependence plots (partial-dependence-style view of the top features)
- Feature importance comparison
- Business insights

In [None]:
# Import libraries
import sys
sys.path.append('../src')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import shap
import joblib
from datetime import datetime
import warnings

# Silence every warning category for the rest of the session to keep the
# notebook output readable.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# used below - consider narrowing the filter when debugging.
warnings.filterwarnings('ignore')

# Configure
np.random.seed(42)  # seed NumPy's global RNG
plt.style.use('seaborn-v0_8-darkgrid')  # matplotlib style sheet used for every figure below
shap.initjs()  # load SHAP's JavaScript support for interactive notebook plots
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline

# Timestamp the start of the run; the final cell prints the matching end time.
print(f"SHAP Analysis started: {datetime.now()}")

## 1. Load Model and Data

In [None]:
# Load the trained model and the processed dataset, then rebuild the held-out
# test split. Every artefact falls back to None when its file is missing so the
# later cells can skip themselves instead of failing on undefined names.
from sklearn.model_selection import train_test_split

# Load trained model
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code - only load model files produced by this project's own training run.
try:
    model = joblib.load('../models/model.pkl')
except FileNotFoundError:
    print("⚠ Model not found. Please run Notebook 03 first.")
    model = None
else:
    print("✓ Model loaded successfully")

# Load test data
# Reset every data variable first: a failed load must not leave values from an
# earlier run of this cell (or undefined names) behind for later cells.
X = y = None
X_train = X_test = y_train = y_test = None
try:
    X = pd.read_csv('../data/processed/X_processed.csv')
    y = pd.read_csv('../data/processed/y_target.csv').squeeze()
except FileNotFoundError:
    print("⚠ Processed data not found. Please run Notebook 02 first.")
    X = y = None
else:
    print(f"✓ Data loaded: {X.shape}")

    # Split to get test set (same random state as training)
    # NOTE(review): this reproduces the training split only if the training
    # notebook used the same test_size, random_state and stratification - confirm.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print(f"✓ Test set: {X_test.shape}")

## 2. Initialize SHAP Explainer

In [None]:
# Build a SHAP TreeExplainer for the underlying tree model, compute SHAP values
# for the held-out test set and persist the explainer. Runs only when both the
# model and the test split were loaded by the previous cell.
if model is not None and X_test is not None:
    print("Initializing SHAP explainer...")
    print("This may take 1-2 minutes...\n")

    # Sample data for background (use subset for speed)
    # NOTE(review): X_background is NOT passed to the explainer below, so it has
    # no effect on the SHAP values computed in this cell. It is kept only in
    # case cells outside this view use it - confirm and remove if unused.
    X_background = shap.sample(X_train, 100, random_state=42)

    # Create TreeExplainer for tree-based models
    # Get the base model from the calibrated wrapper
    # NOTE(review): only the first calibrated fold's estimator is explained, and
    # the explanation covers the uncalibrated estimator's output rather than
    # the calibrated probabilities - confirm this is acceptable.
    if hasattr(model, 'model'):
        if hasattr(model.model, 'calibrated_model'):
            base_estimator = model.model.calibrated_model.calibrated_classifiers_[0].estimator
        else:
            base_estimator = model.model
    else:
        base_estimator = model

    explainer = shap.TreeExplainer(base_estimator)

    # Calculate SHAP values for test set
    print("Calculating SHAP values for test set...")
    shap_values = explainer.shap_values(X_test)

    # For binary classification, get positive class.
    # Depending on the SHAP version and model type the per-class values come
    # back as a list of 2-D arrays (older releases) or as one 3-D array with
    # the class on the last axis (newer releases); a plain 2-D array already
    # refers to a single output and is used unchanged.
    if isinstance(shap_values, list):
        shap_values_positive = shap_values[1]
    elif np.ndim(shap_values) == 3:
        shap_values_positive = shap_values[:, :, 1]
    else:
        shap_values_positive = shap_values

    print(f"\n✓ SHAP values computed: {shap_values_positive.shape}")

    # Save explainer for API
    joblib.dump(explainer, '../models/shap_explainer.pkl')
    print("✓ SHAP explainer saved to models/shap_explainer.pkl")

## 3. SHAP Summary Plot - Global Feature Importance

In [None]:
# Global beeswarm view of the 20 most influential features, saved as a PNG.
# Skipped when the SHAP values were not computed by the earlier cell.
if 'shap_values_positive' in locals():
    # Summary plot (beeswarm)
    plt.figure(figsize=(12, 10))
    # plot_size=None leaves the figure size chosen above untouched; with the
    # default ("auto") SHAP resizes the current figure and figsize is ignored.
    shap.summary_plot(shap_values_positive, X_test, max_display=20,
                      plot_size=None, show=False)
    plt.title('SHAP Summary Plot - Top 20 Features', fontsize=16, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig('../reports/visualizations/shap_summary_plot.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("✓ SHAP summary plot saved")

## 4. SHAP Bar Plot - Mean Absolute Impact

In [None]:
# Bar chart of mean |SHAP value| per feature plus the ranked importance table
# (`feature_importance`) that the later cells build on.
if 'shap_values_positive' in locals():
    # Bar plot showing mean absolute SHAP values
    plt.figure(figsize=(12, 8))
    # plot_size=None leaves the figure size chosen above untouched; with the
    # default ("auto") SHAP resizes the current figure and figsize is ignored.
    shap.summary_plot(shap_values_positive, X_test, plot_type="bar",
                      max_display=20, plot_size=None, show=False)
    plt.title('Feature Importance - Mean |SHAP Value|', fontsize=16, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig('../reports/visualizations/shap_bar_plot.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Get top features
    # Importance = mean absolute SHAP value over all test rows (axis 0).
    # The index is reset after sorting so that row label i is the feature's
    # rank (0 = most important) instead of its original column position.
    feature_importance = pd.DataFrame({
        'feature': X_test.columns,
        'importance': np.abs(shap_values_positive).mean(axis=0)
    }).sort_values('importance', ascending=False).reset_index(drop=True)

    print("\nTop 20 Most Important Features:")
    print("="*60)
    display(feature_importance.head(20))

## 5. SHAP Waterfall Plots - Individual Predictions

In [None]:
# Waterfall explanations for two individual test rows: the first row labelled
# 'Yes' (failure) and the first labelled 'No' (no failure).
# NOTE(review): assumes y_test holds the string labels 'Yes'/'No' - confirm
# against the processed target file.
if 'shap_values_positive' in locals():
    # Create explanation objects
    # The explainer's expected value may be a scalar or one value per model
    # output (list or array). The waterfall plot needs a scalar, so take the
    # positive-class entry when two outputs exist, otherwise the only entry.
    base_value = explainer.expected_value
    if np.ndim(base_value) > 0:
        per_output = np.ravel(base_value)
        base_value = per_output[1] if per_output.size > 1 else per_output[0]

    def _plot_waterfall_example(label, heading, title, out_path):
        """Plot and save the waterfall for the first test row whose target is *label*.

        Prints *heading*, draws the plot titled *title* and writes it to
        *out_path*. Returns the plotted shap.Explanation, or None when the
        test set contains no row with that label.
        """
        matching_rows = y_test[y_test == label].index
        if len(matching_rows) == 0:
            print(f"\n⚠ No test rows with label {label!r} - skipping waterfall plot")
            return None

        # Get position in test set: the SHAP array is positional, while
        # y_test/X_test keep the index labels of the original data frame.
        position = X_test.index.get_loc(matching_rows[0])

        print(heading)
        print("="*60)

        explanation = shap.Explanation(
            values=shap_values_positive[position],
            base_values=base_value,
            data=X_test.iloc[position].values,
            feature_names=X_test.columns.tolist()
        )

        plt.figure(figsize=(10, 8))
        shap.plots.waterfall(explanation, max_display=20, show=False)
        plt.title(title, fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.show()
        return explanation

    # Waterfall for failure case
    explanation_failure = _plot_waterfall_example(
        'Yes',
        "\nExample 1: Device with Warranty Claim (Failure)",
        'SHAP Waterfall - Failure Case',
        '../reports/visualizations/shap_waterfall_failure.png',
    )

    # Waterfall for no-failure case
    explanation_no_failure = _plot_waterfall_example(
        'No',
        "\nExample 2: Device without Warranty Claim (No Failure)",
        'SHAP Waterfall - No Failure Case',
        '../reports/visualizations/shap_waterfall_no_failure.png',
    )

## 6. SHAP Dependence Plots (PDP-style)

In [None]:
# SHAP dependence plots (feature value vs. SHAP value) for the most important
# features, laid out two per row and saved as a single PNG.
if 'feature_importance' in locals():
    # Get top 10 features for PDP
    top_features = feature_importance.head(10)['feature'].tolist()

    print(f"Creating Partial Dependence Plots for top {len(top_features)} features...")

    if not top_features:
        print("⚠ No features available - skipping dependence plots")
    else:
        # Create SHAP dependence plots
        # Size the grid from the number of features actually available; with
        # the full 10 features this is the 5x2, 16x20-inch layout.
        n_cols = 2
        n_rows = -(-len(top_features) // n_cols)  # ceiling division
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows), squeeze=False)
        axes = axes.ravel()

        for ax, feature in zip(axes, top_features):
            feature_idx = X_test.columns.get_loc(feature)

            plt.sca(ax)
            shap.dependence_plot(
                feature_idx,
                shap_values_positive,
                X_test,
                show=False,
                ax=ax
            )
            ax.set_title(f'Dependence: {feature}', fontweight='bold')

        # Hide the spare panel left over when the feature count is odd.
        for spare_ax in axes[len(top_features):]:
            spare_ax.set_visible(False)

        plt.tight_layout()
        plt.savefig('../reports/visualizations/shap_dependence_plots.png',
                   dpi=300, bbox_inches='tight')
        plt.show()

        print("✓ Dependence plots saved")

## 7. Feature Importance Comparison

In [None]:
# Compare SHAP-based importance with the tree model's native
# feature_importances_ for the top 20 SHAP features. Best-effort: any failure
# is reported and the cell continues.
if model is not None and 'feature_importance' in locals():
    # Get XGBoost native importance
    try:
        # Extract base XGBoost model
        if hasattr(model, 'model'):
            if hasattr(model.model, 'calibrated_model'):
                xgb_model = model.model.calibrated_model.calibrated_classifiers_[0].estimator
            else:
                xgb_model = model.model
        else:
            xgb_model = model

        # Check if it's an ensemble
        xgb_base = xgb_model
        if hasattr(xgb_model, 'estimators_'):
            # scikit-learn ensembles keep bare fitted estimators in
            # `estimators_`; the names live in `named_estimators_`. Support
            # both, plus custom ensembles that store (name, estimator) pairs.
            if hasattr(xgb_model, 'named_estimators_'):
                members = list(xgb_model.named_estimators_.items())
            else:
                members = [
                    member if isinstance(member, tuple) and len(member) == 2
                    else (type(member).__name__, member)
                    for member in xgb_model.estimators_
                ]

            # Get XGBoost from ensemble: match on the member name or on the
            # estimator's class name; skip placeholder strings such as 'drop'.
            xgb_base = next(
                (
                    estimator
                    for name, estimator in members
                    if not isinstance(estimator, str)
                    and ('xgb' in str(name).lower()
                         or 'xgb' in type(estimator).__name__.lower())
                ),
                None,
            )
            if xgb_base is None:
                # Reported by the handler below instead of surfacing later as
                # an undefined or stale `xgb_base`.
                raise ValueError("no XGBoost estimator found in the ensemble")

        # Get feature importances
        if hasattr(xgb_base, 'feature_importances_'):
            # NOTE(review): labelled "gain" on the assumption that the
            # estimator's importance type is gain - confirm the model config.
            gain_importance = pd.DataFrame({
                'feature': X_test.columns,
                'gain': xgb_base.feature_importances_
            }).sort_values('gain', ascending=False)

            # Compare with SHAP
            comparison = feature_importance.merge(
                gain_importance, on='feature', how='left'
            ).head(20)

            # Normalize for comparison (each scale divided by its own maximum)
            comparison['importance_norm'] = comparison['importance'] / comparison['importance'].max()
            comparison['gain_norm'] = comparison['gain'] / comparison['gain'].max()

            # Plot comparison
            fig, ax = plt.subplots(figsize=(14, 8))

            x = np.arange(len(comparison))
            width = 0.35

            ax.barh(x - width/2, comparison['importance_norm'], width,
                   label='SHAP Importance', color='steelblue')
            ax.barh(x + width/2, comparison['gain_norm'], width,
                   label='XGBoost Gain', color='coral')

            ax.set_yticks(x)
            ax.set_yticklabels(comparison['feature'])
            ax.set_xlabel('Normalized Importance', fontsize=12)
            ax.set_title('Feature Importance Comparison: SHAP vs XGBoost Gain',
                        fontsize=14, fontweight='bold')
            ax.legend()
            ax.invert_yaxis()  # most important feature at the top

            plt.tight_layout()
            plt.savefig('../reports/visualizations/feature_importance_comparison.png',
                       dpi=300, bbox_inches='tight')
            plt.show()

            print("\nFeature Importance Comparison (Top 10):")
            display(comparison[['feature', 'importance', 'gain']].head(10))
        else:
            print("Could not extract XGBoost importances: "
                  "estimator exposes no feature_importances_")

    except Exception as e:
        # Deliberately broad: this comparison is optional and must not stop
        # the notebook.
        print(f"Could not extract XGBoost importances: {e}")

## 8. Business Insights from SHAP Analysis

In [None]:
# Text report: ranked top-10 features, recommendations keyed on which kinds of
# features appear in that top 10, and a fixed deployment checklist.
if 'feature_importance' in locals():
    print("\n" + "="*80)
    print("BUSINESS INSIGHTS FROM SHAP ANALYSIS")
    print("="*80)

    top_10_features = feature_importance.head(10)

    print("\n1. CRITICAL FAILURE PREDICTORS:")
    print("-" * 80)
    # Number the rows by position in the sorted table. The DataFrame index is
    # not a reliable rank: after sort_values it can still hold each feature's
    # original column position.
    ranked = zip(top_10_features['feature'], top_10_features['importance'])
    for rank, (feature, importance) in enumerate(ranked, start=1):
        print(f"   {rank}. {feature}: {importance:.4f}")

    print("\n2. ACTIONABLE RECOMMENDATIONS:")
    print("-" * 80)

    # Identify feature categories (case-sensitive substring match on the name)
    supplier_features = [f for f in top_10_features['feature'] if 'Supplier' in f]
    batch_features = [f for f in top_10_features['feature'] if 'Batch' in f]
    temp_features = [f for f in top_10_features['feature'] if 'Temp' in f]
    interaction_features = [f for f in top_10_features['feature'] if '_x_' in f or '_div_' in f]

    # (heading, matched features, description, recommended actions);
    # a section is printed only when at least one top-10 feature matched.
    recommendation_sections = [
        ("SUPPLIER QUALITY", supplier_features, "supplier-related features",
         ["Review and audit high-risk suppliers",
          "Implement supplier qualification programs"]),
        ("BATCH MONITORING", batch_features, "batch-related features",
         ["Enhance batch-level quality controls",
          "Implement real-time batch anomaly detection"]),
        ("ENVIRONMENTAL CONTROLS", temp_features, "temperature-related features",
         ["Tighten environmental parameter tolerances",
          "Install real-time monitoring systems"]),
        ("PROCESS INTERACTIONS", interaction_features, "interaction features",
         ["Optimize parameter combinations",
          "Avoid extreme parameter pairings"]),
    ]
    for heading, matched, description, actions in recommendation_sections:
        if not matched:
            continue
        print(f"\n   {heading}:")
        print(f"   - {len(matched)} {description} in top 10")
        for action in actions:
            print(f"   - Action: {action}")

    print("\n3. DEPLOYMENT STRATEGY:")
    print("-" * 80)
    print("   - Use SHAP waterfall plots in production for each prediction")
    print("   - Alert quality team when high-risk devices detected")
    print("   - Generate weekly reports on feature trends")
    print("   - Integrate with manufacturing execution system (MES)")

    print("\n" + "="*80)

## 9. Summary

In [None]:
# Closing recap: lists the expected visualizations and artefacts, the headline
# findings (only when the insights cell defined `top_10_features`), the next
# steps and the completion time. The check marks are static text; nothing here
# verifies that the files were actually written.
rule = "=" * 80

print("\n" + rule)
print("SHAP ANALYSIS SUMMARY")
print(rule)

print("\nVisualizations Created:")
for line in (
    "  ✓ SHAP Summary Plot (beeswarm)",
    "  ✓ SHAP Bar Plot (mean absolute impact)",
    "  ✓ SHAP Waterfall Plots (2 examples)",
    "  ✓ Partial Dependence Plots (top 10 features)",
    "  ✓ Feature Importance Comparison",
):
    print(line)

print("\nArtifacts Saved:")
for line in (
    "  ✓ SHAP explainer (models/shap_explainer.pkl)",
    "  ✓ All visualizations (reports/visualizations/)",
):
    print(line)

print("\nKey Findings:")
if 'top_10_features' in locals():
    print(f"  - {len(top_10_features)} critical features identified")
    print(f"  - Top feature: {top_10_features.iloc[0]['feature']}")
    print("  - Model is fully interpretable for business use")

print("\nNext Steps:")
for line in (
    "  → Deploy API with SHAP explanations (src/api.py)",
    "  → Generate business report",
    "  → Present findings to stakeholders",
):
    print(line)

print("\n" + rule)
print(f"SHAP analysis completed: {datetime.now()}")
print(rule)