# IEEE-CIS Fraud Detection - Model Interpretation

## Overview
This notebook provides comprehensive model interpretation for the fraud detection system.
Explainability is critical in fraud detection for:
1. **Regulatory Compliance**: Some jurisdictions and regulators require that automated decisions be explainable
2. **Fraud Analyst Trust**: Analysts need to understand why transactions are flagged
3. **Model Debugging**: Identify whether the model is learning spurious correlations
4. **Continuous Improvement**: Understand failure modes to improve the model

## Interpretation Methods
1. **Feature Importance**: Which features contribute most to predictions
2. **SHAP Values**: Global feature rankings and local explanations for individual predictions
3. **Error Analysis**: Understanding false positives and false negatives, overall and by transaction amount
4. **Threshold Analysis**: How the classification threshold trades precision against recall

In [None]:
# Standard library imports
import os
import sys
import warnings
import pickle
from pathlib import Path

# Data manipulation
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# ML libraries
from sklearn.metrics import confusion_matrix, classification_report

# Configuration
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid')

# Set random seed
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Define paths
BASE_PATH = Path('..').resolve()
PROCESSED_PATH = BASE_PATH / 'Data' / 'processed'
FEATURES_PATH = BASE_PATH / 'Data' / 'features'
OUTPUT_PATH = BASE_PATH / 'outputs'
MODELS_PATH = OUTPUT_PATH / 'models'
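VISUALS_PATH = OUTPUT_PATH / 'visuals'

# Make sure the figure directory exists before any plt.savefig call
VISUALS_PATH.mkdir(parents=True, exist_ok=True)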

# Add src to path
sys.path.insert(0, str(BASE_PATH / 'src'))

print(f"Base Path: {BASE_PATH}")

In [None]:
# Load best model and artifacts
print("Loading model and data...")

with open(MODELS_PATH / 'best_model.pkl', 'rb') as f:
    model_artifacts = pickle.load(f)

model = model_artifacts['model']
model_name = model_artifacts['model_name']
optimal_threshold = model_artifacts['optimal_threshold']
feature_cols = model_artifacts['feature_cols']

print(f"Model: {model_name}")
print(f"Optimal Threshold: {optimal_threshold:.4f}")
print(f"Number of Features: {len(feature_cols)}")

In [None]:
# Load processed data
train_df = pd.read_parquet(PROCESSED_PATH / 'train_processed.parquet')

# Time-based split (same as in modeling)
train_df_sorted = train_df.sort_values('TransactionDT').reset_index(drop=True)
split_idx = int(len(train_df_sorted) * 0.8)
val_data = train_df_sorted.iloc[split_idx:]

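# Prepare features (report any model features that are missing from the data,
# since predict_proba expects the same columns the model was trained on)
available_features = [c for c in feature_cols if c in val_data.columns]
missing_features = [c for c in feature_cols if c not in val_data.columns]
if missing_features:
    print(f"Warning: {len(missing_features)} model features missing from data: {missing_features[:10]}")
X_val = val_data[available_features]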
y_val = val_data['isFraud']

print(f"Validation set: {len(X_val)} samples")
print(f"Fraud rate: {y_val.mean()*100:.2f}%")

In [None]:
# Generate predictions
y_pred_proba = model.predict_proba(X_val)[:, 1]
y_pred = (y_pred_proba >= optimal_threshold).astype(int)

print(f"Predictions generated for {len(y_pred)} samples")

## 1. Feature Importance Analysis

Feature importance helps understand which features the model relies on most.
For tree-based models, we can use:
- **Gain**: Improvement in the split criterion from splits that use the feature (summed or averaged over those splits, depending on the library)
- **Split**: Number of times the feature is used for splitting
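
The difference between the two measures can be sketched with a handful of split records. This is a minimal illustration, not the model's actual output: the feature names and gain values below are made up, and `importance_from_splits` is a hypothetical helper rather than a library function.

```python
from collections import defaultdict

# Hypothetical split records from a small tree ensemble: (feature, gain of that split).
# Feature names and gain values are invented for illustration only.
splits = [
    ("TransactionAmt", 12.0),
    ("card1", 3.0),
    ("card1", 2.0),
    ("card1", 1.0),
    ("C13", 5.0),
    ("C13", 4.0),
]

def importance_from_splits(split_records):
    """Return per-feature split count, total gain, and average gain."""
    counts = defaultdict(int)
    total_gain = defaultdict(float)
    for feature, gain in split_records:
        counts[feature] += 1
        total_gain[feature] += gain
    return {
        feature: {
            "split": counts[feature],
            "total_gain": total_gain[feature],
            "avg_gain": total_gain[feature] / counts[feature],
        }
        for feature in counts
    }

importance = importance_from_splits(splits)

# The two rankings can disagree: a feature split on often with small gains
# ranks high by split count but low by gain.
by_split = sorted(importance, key=lambda f: importance[f]["split"], reverse=True)
by_gain = sorted(importance, key=lambda f: importance[f]["total_gain"], reverse=True)
print("Ranked by split count:", by_split)
print("Ranked by total gain: ", by_gain)
```

In this toy data `card1` is used most often but contributes the least total gain, which is why it is worth checking which importance type `feature_importances_` reports for the loaded model before interpreting the ranking.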

In [None]:
# Extract feature importance
if hasattr(model, 'feature_importances_'):
    importance_df = pd.DataFrame({
        'feature': available_features,
        'importance': model.feature_importances_
    }).sort_values('importance', ascending=False)
else:
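    # Fallback for models without feature_importances_: uniform placeholder values
    # (these keep the later cells running but are not a real importance estimate)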
    importance_df = pd.DataFrame({
        'feature': available_features,
        'importance': [1/len(available_features)] * len(available_features)
    })

print("Top 20 Most Important Features:")
print(importance_df.head(20).to_string(index=False))

In [None]:
# Visualization 1: Top 15 Feature Importance
fig, ax = plt.subplots(figsize=(12, 8))

top_15 = importance_df.head(15)
colors = plt.cm.RdYlGn(np.linspace(0.2, 0.8, len(top_15)))[::-1]

bars = ax.barh(range(len(top_15)), top_15['importance'].values, color=colors, edgecolor='black')
ax.set_yticks(range(len(top_15)))
ax.set_yticklabels(top_15['feature'].values)
ax.invert_yaxis()
ax.set_xlabel('Feature Importance', fontsize=12)
ax.set_title('Top 15 Most Important Features for Fraud Detection', fontsize=14, fontweight='bold')

# Add importance values as text
for bar, imp in zip(bars, top_15['importance'].values):
    ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height()/2, 
            f'{imp:.4f}', va='center', fontsize=10)

plt.tight_layout()
plt.savefig(VISUALS_PATH / 'feature_importance_top15.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved: {VISUALS_PATH / 'feature_importance_top15.png'}")

In [None]:
# Feature importance by category
import re

RAW_PREFIX_CATEGORIES = {
    'V': 'Vesta Features',
    'C': 'Counting Features',
    'D': 'Timedelta Features',
    'M': 'Match Features',
}

def categorize_feature(feature_name):
    """Categorize features by their prefix/type."""
    # Match a capital letter followed by a digit (V12, C3, D15, M4) so that
    # columns such as DeviceType are not mistaken for D-prefixed timedelta features.
    raw_match = re.match(r'([VCDM])\d', feature_name)
    if raw_match:
        return RAW_PREFIX_CATEGORIES[raw_match.group(1)]
    elif feature_name.startswith('id_'):
        return 'Identity Features'
    elif feature_name.startswith('card'):
        return 'Card Features'
    elif feature_name.startswith('addr'):
        return 'Address Features'
    elif 'email' in feature_name.lower():
        return 'Email Features'
    elif feature_name in ['hour', 'day', 'day_of_week', 'hour_sin', 'hour_cos', 
                          'dow_sin', 'dow_cos', 'is_night', 'is_weekend', 'is_business_hours']:
        return 'Temporal Features'
    elif 'TransactionAmt' in feature_name:
        return 'Amount Features'
    else:
        return 'Other Features'

importance_df['category'] = importance_df['feature'].apply(categorize_feature)

# Aggregate importance by category
category_importance = importance_df.groupby('category')['importance'].sum().sort_values(ascending=False)

print("\nFeature Importance by Category:")
print(category_importance)

In [None]:
# Visualization 2: Importance by Category
fig, ax = plt.subplots(figsize=(12, 6))

colors = plt.cm.Set3(np.linspace(0, 1, len(category_importance)))
bars = ax.bar(range(len(category_importance)), category_importance.values, color=colors, edgecolor='black')
ax.set_xticks(range(len(category_importance)))
ax.set_xticklabels(category_importance.index, rotation=45, ha='right')
ax.set_ylabel('Total Importance', fontsize=12)
ax.set_title('Feature Importance by Category', fontsize=14, fontweight='bold')

# Add values on bars
for bar, val in zip(bars, category_importance.values):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
            f'{val:.3f}', ha='center', fontsize=10)

plt.tight_layout()
plt.savefig(VISUALS_PATH / 'feature_importance_by_category.png', dpi=300, bbox_inches='tight')
plt.show()

## 2. SHAP Value Analysis

SHAP (SHapley Additive exPlanations) provides:
- **Global interpretability**: Which features matter most overall
- **Local interpretability**: Why a specific prediction was made
- **Feature interaction**: How features work together

Note: SHAP computation can be slow for large datasets, so we sample.
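
To make the idea behind SHAP concrete, the sketch below computes exact Shapley values by brute force for a tiny hypothetical scoring function with three features. It is an illustration of the underlying definition only: `model`, `background`, and the helper names are assumptions made for this example, and `shap.TreeExplainer` uses a much faster tree-specific algorithm rather than this enumeration.

```python
from itertools import combinations
from math import factorial

def model(x):
    """Hypothetical toy scoring function with an interaction between features 0 and 2."""
    return 0.5 * x[0] + 0.2 * x[1] + 0.3 * x[0] * x[2]

# Hypothetical background rows used to average out features that are "absent"
background = [
    (0.0, 0.0, 0.0),
    (1.0, 0.0, 1.0),
    (0.0, 1.0, 1.0),
    (1.0, 1.0, 0.0),
]

def coalition_value(x, subset):
    """Average model output with features in `subset` fixed to x and the rest taken from the background."""
    total = 0.0
    for row in background:
        mixed = [x[i] if i in subset else row[i] for i in range(len(x))]
        total += model(mixed)
    return total / len(background)

def shapley_values(x):
    """Exact Shapley values by enumerating every coalition (only feasible for a few features)."""
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (coalition_value(x, s | {i}) - coalition_value(x, s))
    return phi

x = (1.0, 1.0, 1.0)
base_value = coalition_value(x, set())  # average prediction over the background
phi = shapley_values(x)

print("Base value:", base_value)
print("Shapley values:", phi)
print("Base + contributions:", base_value + sum(phi), "| model output:", model(x))
```

The last line shows the additivity property that waterfall plots rely on: the base value plus the per-feature contributions reproduces the model output for that row.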

In [None]:
# Import SHAP if available; otherwise skip the SHAP analysis
try:
    import shap
    SHAP_AVAILABLE = True
    print("SHAP library available")
except ImportError:
    SHAP_AVAILABLE = False
    print("SHAP library not available. Skipping SHAP analysis.")
    print("Install with: pip install shap")

In [None]:
if SHAP_AVAILABLE:
    # Sample data for SHAP (full dataset too slow)
    sample_size = min(1000, len(X_val))
    sample_idx = np.random.choice(X_val.index, size=sample_size, replace=False)
    X_sample = X_val.loc[sample_idx]
    
    print(f"Computing SHAP values for {sample_size} samples...")
    
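    # TreeExplainer supports tree ensembles such as LightGBM and XGBoost;
    # other model types would need a different explainer
    explainer = shap.TreeExplainer(model)
    
    shap_values = explainer.shap_values(X_sample)
    
    # For binary classification, shap_values may hold one set of values per class:
    # a list in older SHAP versions, a 3-D array (samples, features, classes) in newer ones
    if isinstance(shap_values, list):
        shap_values = shap_values[1]  # Use positive class
    elif getattr(shap_values, 'ndim', 2) == 3:
        shap_values = shap_values[:, :, 1]  # Use positive class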
    
    print(f"SHAP values computed. Shape: {shap_values.shape}")

In [None]:
if SHAP_AVAILABLE:
    # Visualization 3: SHAP Summary Plot
    plt.figure(figsize=(12, 10))
    shap.summary_plot(shap_values, X_sample, plot_type="bar", max_display=15, show=False)
    plt.title('SHAP Feature Importance (Top 15)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(VISUALS_PATH / 'shap_importance_bar.png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
if SHAP_AVAILABLE:
    # Visualization 4: SHAP Beeswarm Plot
    plt.figure(figsize=(12, 10))
    shap.summary_plot(shap_values, X_sample, max_display=15, show=False)
    plt.title('SHAP Value Distribution (Top 15 Features)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(VISUALS_PATH / 'shap_beeswarm.png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
if SHAP_AVAILABLE:
    # Visualization 5: SHAP Waterfall for an Individual Fraud Prediction
    # Find a fraud case in the sample
    fraud_idx = X_sample[y_val.loc[sample_idx] == 1].index
    
    if len(fraud_idx) > 0:
        # Get index position in X_sample
        fraud_pos = np.where(X_sample.index == fraud_idx[0])[0][0]
        
        # expected_value may be a scalar or hold one value per class (list or array);
        # the last entry is the positive class in the binary case
        base_value = np.atleast_1d(explainer.expected_value)[-1]
        
        fig, ax = plt.subplots(figsize=(12, 8))
        shap.plots.waterfall(shap.Explanation(
            values=shap_values[fraud_pos],
            base_values=base_value,
            data=X_sample.iloc[fraud_pos].values,
            feature_names=X_sample.columns.tolist()
        ), max_display=15, show=False)
        plt.title('SHAP Explanation for a Fraud Transaction', fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.savefig(VISUALS_PATH / 'shap_waterfall_fraud.png', dpi=300, bbox_inches='tight')
        plt.show()

## 3. Error Analysis

Understanding model errors is crucial for fraud detection:
- **False Positives**: Legitimate transactions flagged as fraud (customer friction)
- **False Negatives**: Fraud transactions missed (financial loss)

We analyze patterns in errors to improve the model.
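
The two error types map onto two rates with different denominators. The sketch below works through them on a small made-up label vector; `confusion_counts` and `error_rates` are hypothetical helper names used only for this illustration.

```python
def confusion_counts(actual, predicted):
    """Count true/false positives and negatives for binary labels (1 = fraud)."""
    pairs = list(zip(actual, predicted))
    tp = sum(1 for a, p in pairs if a == 1 and p == 1)
    tn = sum(1 for a, p in pairs if a == 0 and p == 0)
    fp = sum(1 for a, p in pairs if a == 0 and p == 1)
    fn = sum(1 for a, p in pairs if a == 1 and p == 0)
    return tp, tn, fp, fn

def error_rates(actual, predicted):
    """False positive rate is relative to all legitimate cases, false negative rate to all fraud cases."""
    tp, tn, fp, fn = confusion_counts(actual, predicted)
    fpr = fp / (fp + tn) if (fp + tn) else 0.0
    fnr = fn / (fn + tp) if (fn + tp) else 0.0
    return {"false_positive_rate": fpr, "false_negative_rate": fnr}

# Made-up example: 8 legitimate transactions and 2 frauds
actual =    [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
predicted = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]

rates = error_rates(actual, predicted)
print(rates)
```

Because fraud is rare, a small false positive rate can still mean many more false alarms than caught frauds in absolute terms, so both the rates and the raw counts are worth reading together.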

In [None]:
# Create error analysis dataframe
error_df = val_data.copy()
error_df['predicted_proba'] = y_pred_proba
error_df['predicted'] = y_pred
error_df['actual'] = y_val.values

# Categorize predictions (vectorized; rows matching none of the conditions are false negatives)
conditions = [
    (error_df['actual'] == 1) & (error_df['predicted'] == 1),
    (error_df['actual'] == 0) & (error_df['predicted'] == 0),
    (error_df['actual'] == 0) & (error_df['predicted'] == 1),
]
labels = ['True Positive', 'True Negative', 'False Positive']
error_df['prediction_type'] = np.select(conditions, labels, default='False Negative')

# Summary
print("Prediction Type Distribution:")
print(error_df['prediction_type'].value_counts())

In [None]:
# Visualization 6: Confusion Matrix
fig, ax = plt.subplots(figsize=(8, 6))

cm = confusion_matrix(y_val, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
            xticklabels=['Legitimate', 'Fraud'],
            yticklabels=['Legitimate', 'Fraud'])
ax.set_xlabel('Predicted', fontsize=12)
ax.set_ylabel('Actual', fontsize=12)
ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig(VISUALS_PATH / 'confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Analyze False Positives (legitimate transactions flagged as fraud)
false_positives = error_df[error_df['prediction_type'] == 'False Positive']
true_negatives = error_df[error_df['prediction_type'] == 'True Negative']

print("\n" + "="*60)
print("FALSE POSITIVE ANALYSIS")
print("="*60)
print(f"\nTotal False Positives: {len(false_positives)}")
print(f"False Positive Rate: {len(false_positives) / (len(false_positives) + len(true_negatives)) * 100:.4f}%")

# Compare characteristics
if 'TransactionAmt' in false_positives.columns:
    print(f"\nTransaction Amount Comparison:")
    print(f"  False Positives - Mean: ${false_positives['TransactionAmt'].mean():.2f}, "
          f"Median: ${false_positives['TransactionAmt'].median():.2f}")
    print(f"  True Negatives - Mean: ${true_negatives['TransactionAmt'].mean():.2f}, "
          f"Median: ${true_negatives['TransactionAmt'].median():.2f}")

In [None]:
# Analyze False Negatives (fraud transactions missed)
false_negatives = error_df[error_df['prediction_type'] == 'False Negative']
true_positives = error_df[error_df['prediction_type'] == 'True Positive']

print("\n" + "="*60)
print("FALSE NEGATIVE ANALYSIS")
print("="*60)
print(f"\nTotal False Negatives: {len(false_negatives)}")
print(f"False Negative Rate: {len(false_negatives) / (len(false_negatives) + len(true_positives)) * 100:.2f}%")

if 'TransactionAmt' in false_negatives.columns:
    print(f"\nTransaction Amount Comparison:")
    print(f"  False Negatives - Mean: ${false_negatives['TransactionAmt'].mean():.2f}, "
          f"Median: ${false_negatives['TransactionAmt'].median():.2f}")
    print(f"  True Positives - Mean: ${true_positives['TransactionAmt'].mean():.2f}, "
          f"Median: ${true_positives['TransactionAmt'].median():.2f}")

In [None]:
# Visualization 7: Prediction Probability Distribution by Actual Class
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Distribution of probabilities
for label, color, name in [(0, '#2ecc71', 'Legitimate'), (1, '#e74c3c', 'Fraud')]:
    subset = error_df[error_df['actual'] == label]['predicted_proba']
    axes[0].hist(subset, bins=50, alpha=0.6, color=color, label=name, density=True)

axes[0].axvline(x=optimal_threshold, color='black', linestyle='--', 
                label=f'Threshold = {optimal_threshold:.3f}')
axes[0].set_xlabel('Predicted Probability', fontsize=12)
axes[0].set_ylabel('Density', fontsize=12)
axes[0].set_title('Prediction Probability Distribution by Actual Class', fontsize=14, fontweight='bold')
axes[0].legend()

# Box plot by prediction type
error_df.boxplot(column='predicted_proba', by='prediction_type', ax=axes[1])
axes[1].set_xlabel('Prediction Type', fontsize=12)
axes[1].set_ylabel('Predicted Probability', fontsize=12)
axes[1].set_title('Prediction Probability by Prediction Type', fontsize=14, fontweight='bold')
axes[1].axhline(y=optimal_threshold, color='red', linestyle='--', label='Threshold')
axes[1].legend()
plt.suptitle('')

plt.tight_layout()
plt.savefig(VISUALS_PATH / 'prediction_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

## 4. Error Analysis by Transaction Amount

Understanding how model performance varies with transaction amount is critical
because the business impact of fraud scales with amount.
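
One way to see why amount matters is to compare recall by count with recall weighted by amount. The sketch below uses made-up transactions; `count_recall` and `value_recall` are hypothetical helpers for this illustration, not metrics computed elsewhere in this notebook.

```python
def count_recall(transactions):
    """Share of fraud transactions that were caught."""
    frauds = [t for t in transactions if t["actual"] == 1]
    caught = [t for t in frauds if t["predicted"] == 1]
    return len(caught) / len(frauds) if frauds else 0.0

def value_recall(transactions):
    """Share of fraud amount that was caught."""
    fraud_total = sum(t["amount"] for t in transactions if t["actual"] == 1)
    caught_total = sum(
        t["amount"] for t in transactions if t["actual"] == 1 and t["predicted"] == 1
    )
    return caught_total / fraud_total if fraud_total else 0.0

# Made-up example: three small frauds are caught, one large fraud is missed
transactions = [
    {"amount": 20.0, "actual": 1, "predicted": 1},
    {"amount": 30.0, "actual": 1, "predicted": 1},
    {"amount": 50.0, "actual": 1, "predicted": 1},
    {"amount": 900.0, "actual": 1, "predicted": 0},
    {"amount": 75.0, "actual": 0, "predicted": 0},
]

print(f"Recall by count:  {count_recall(transactions):.2f}")
print(f"Recall by amount: {value_recall(transactions):.2f}")
```

Here three of four frauds are caught, yet only a tenth of the fraud amount is stopped, which is the kind of gap the per-bin analysis is meant to surface.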

In [None]:
# Create amount bins
if 'TransactionAmt' in error_df.columns:
    amount_bins = [0, 50, 100, 200, 500, 1000, 5000, float('inf')]
    amount_labels = ['$0-50', '$50-100', '$100-200', '$200-500', '$500-1K', '$1K-5K', '$5K+']
    
    error_df['amount_bin'] = pd.cut(error_df['TransactionAmt'], bins=amount_bins, labels=amount_labels)
    
    # Calculate metrics by amount bin
    amount_analysis = error_df.groupby('amount_bin', observed=True).apply(
        lambda x: pd.Series({
            'total': len(x),
            'frauds': (x['actual'] == 1).sum(),
            'fraud_rate': (x['actual'] == 1).mean() * 100,
            'true_positives': ((x['actual'] == 1) & (x['predicted'] == 1)).sum(),
            'false_negatives': ((x['actual'] == 1) & (x['predicted'] == 0)).sum(),
            'false_positives': ((x['actual'] == 0) & (x['predicted'] == 1)).sum(),
            'recall': ((x['actual'] == 1) & (x['predicted'] == 1)).sum() / max((x['actual'] == 1).sum(), 1) * 100,
            'avg_fraud_amount': x[x['actual'] == 1]['TransactionAmt'].mean() if (x['actual'] == 1).sum() > 0 else 0
        })
    ).reset_index()
    
    print("\nPerformance Analysis by Transaction Amount:")
    print(amount_analysis.to_string(index=False))

In [None]:
# Visualization 8: Performance by Transaction Amount
if 'TransactionAmt' in error_df.columns:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Fraud rate by amount
    axes[0, 0].bar(range(len(amount_analysis)), amount_analysis['fraud_rate'], 
                   color='steelblue', edgecolor='black')
    axes[0, 0].set_xticks(range(len(amount_analysis)))
    axes[0, 0].set_xticklabels(amount_analysis['amount_bin'], rotation=45)
    axes[0, 0].set_ylabel('Fraud Rate (%)', fontsize=12)
    axes[0, 0].set_title('Fraud Rate by Transaction Amount', fontsize=14, fontweight='bold')
    
    # Recall by amount
    axes[0, 1].bar(range(len(amount_analysis)), amount_analysis['recall'], 
                   color='coral', edgecolor='black')
    axes[0, 1].set_xticks(range(len(amount_analysis)))
    axes[0, 1].set_xticklabels(amount_analysis['amount_bin'], rotation=45)
    axes[0, 1].set_ylabel('Recall (%)', fontsize=12)
    axes[0, 1].set_title('Fraud Detection Recall by Amount', fontsize=14, fontweight='bold')
    axes[0, 1].axhline(y=error_df[error_df['actual']==1]['predicted'].mean()*100, 
                       color='red', linestyle='--', label='Overall Recall')
    axes[0, 1].legend()
    
    # False Negatives by amount
    axes[1, 0].bar(range(len(amount_analysis)), amount_analysis['false_negatives'], 
                   color='#e74c3c', edgecolor='black')
    axes[1, 0].set_xticks(range(len(amount_analysis)))
    axes[1, 0].set_xticklabels(amount_analysis['amount_bin'], rotation=45)
    axes[1, 0].set_ylabel('Count', fontsize=12)
    axes[1, 0].set_title('Missed Fraud Transactions by Amount', fontsize=14, fontweight='bold')
    
    # False Positives by amount
    axes[1, 1].bar(range(len(amount_analysis)), amount_analysis['false_positives'], 
                   color='#f39c12', edgecolor='black')
    axes[1, 1].set_xticks(range(len(amount_analysis)))
    axes[1, 1].set_xticklabels(amount_analysis['amount_bin'], rotation=45)
    axes[1, 1].set_ylabel('Count', fontsize=12)
    axes[1, 1].set_title('False Alarms by Transaction Amount', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(VISUALS_PATH / 'error_analysis_by_amount.png', dpi=300, bbox_inches='tight')
    plt.show()

## 5. Threshold Analysis

The classification threshold dramatically affects the precision-recall trade-off.
Different business contexts may require different thresholds:
- **High-value transactions**: Lower threshold (catch more fraud, accept more false positives)
- **Customer experience focus**: Higher threshold (fewer false positives)
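
The trade-off can be worked through by hand on a few scores. The sketch below uses made-up labels and probabilities; `precision_recall_at` is a hypothetical helper, and treating precision as 1.0 when nothing is flagged is a convention chosen for this example.

```python
def precision_recall_at(y_true, scores, threshold):
    """Precision and recall when every score >= threshold is flagged as fraud."""
    predicted = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for a, p in zip(y_true, predicted) if a == 1 and p == 1)
    fp = sum(1 for a, p in zip(y_true, predicted) if a == 0 and p == 1)
    fn = sum(1 for a, p in zip(y_true, predicted) if a == 1 and p == 0)
    precision = tp / (tp + fp) if (tp + fp) else 1.0  # convention when nothing is flagged
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

# Made-up labels (1 = fraud) and predicted probabilities
y_true = [0, 0, 0, 0, 1, 0, 1, 1]
scores = [0.05, 0.10, 0.20, 0.35, 0.40, 0.60, 0.70, 0.90]

for threshold in (0.3, 0.5, 0.8):
    precision, recall = precision_recall_at(y_true, scores, threshold)
    print(f"threshold={threshold:.1f}  precision={precision:.2f}  recall={recall:.2f}")
```

Lowering the threshold from 0.8 to 0.3 raises recall from one in three frauds to all three, while precision falls from 1.0 to 0.6.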

In [None]:
# Analyze performance at different thresholds
from sklearn.metrics import precision_recall_curve

precision, recall, thresholds = precision_recall_curve(y_val, y_pred_proba)

# Calculate F1 for each threshold
f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)

# Create threshold analysis dataframe
threshold_analysis = pd.DataFrame({
    'threshold': list(thresholds) + [1.0],
    'precision': precision,
    'recall': recall,
    'f1': f1_scores
})

# For each key threshold, select the row whose threshold is closest to it
key_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
nearest_idx = [(threshold_analysis['threshold'] - t).abs().idxmin() for t in key_thresholds]
key_analysis = threshold_analysis.loc[nearest_idx].drop_duplicates(subset=['threshold'])

print("Performance at Different Thresholds:")
print(key_analysis.to_string(index=False))

In [None]:
# Visualization 9: Threshold Trade-off Analysis
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Precision vs Recall at different thresholds
axes[0].plot(thresholds, precision[:-1], label='Precision', linewidth=2)
axes[0].plot(thresholds, recall[:-1], label='Recall', linewidth=2)
axes[0].plot(thresholds, f1_scores[:-1], label='F1 Score', linewidth=2)
axes[0].axvline(x=optimal_threshold, color='red', linestyle='--', 
                label=f'Optimal Threshold = {optimal_threshold:.3f}')
axes[0].set_xlabel('Threshold', fontsize=12)
axes[0].set_ylabel('Score', fontsize=12)
axes[0].set_title('Precision, Recall, F1 vs Threshold', fontsize=14, fontweight='bold')
axes[0].legend(loc='center right')
axes[0].set_xlim(0, 1)
axes[0].grid(True, alpha=0.3)

# Precision-Recall curve
axes[1].plot(recall, precision, linewidth=2, color='green')
axes[1].fill_between(recall, precision, alpha=0.3, color='green')
axes[1].set_xlabel('Recall', fontsize=12)
axes[1].set_ylabel('Precision', fontsize=12)
axes[1].set_title('Precision-Recall Trade-off', fontsize=14, fontweight='bold')
axes[1].set_xlim(0, 1)
axes[1].set_ylim(0, 1)
axes[1].grid(True, alpha=0.3)

# Mark operating point
idx = np.argmin(np.abs(thresholds - optimal_threshold))
axes[1].scatter([recall[idx]], [precision[idx]], color='red', s=100, zorder=5,
                label=f'Operating Point (t={optimal_threshold:.3f})')
axes[1].legend()

plt.tight_layout()
plt.savefig(VISUALS_PATH / 'threshold_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

## 6. Summary and Recommendations

### Key Findings

1. **Most Important Features**: The model relies heavily on V-features (Vesta engineered features),
   transaction amount, and card-related features.

2. **Error Patterns**:
   - False positives tend to be transactions with unusual but legitimate patterns
   - False negatives often involve fraud that mimics normal behavior

3. **Amount-Based Performance**:
   - Performance varies across transaction amount ranges
   - Higher-value transactions may warrant lower thresholds (catching more fraud at the cost of more false positives)

### Recommendations for Production

1. **Tiered Thresholds**: Consider different thresholds based on transaction amount
2. **Feature Monitoring**: Track distributions of top features for drift detection
3. **Regular Retraining**: Fraud patterns evolve; schedule regular model updates
4. **Human Review Queue**: Route high-uncertainty predictions to manual review
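
Recommendations 1 and 4 can be combined in a small routing rule. The sketch below is only an illustration under stated assumptions: the tier boundaries, thresholds, and review margin are placeholder values that have not been tuned on this data, and `threshold_for_amount` and `route` are hypothetical names.

```python
# Hypothetical tiers: (upper amount bound, flag threshold).
# Larger amounts get a lower threshold so that more potential fraud is caught.
AMOUNT_TIERS = [
    (100.0, 0.70),
    (1000.0, 0.50),
    (float("inf"), 0.30),
]

# Hypothetical width of the band just below the threshold that goes to manual review
REVIEW_MARGIN = 0.15

def threshold_for_amount(amount):
    """Return the flag threshold of the first tier whose upper bound covers the amount."""
    for upper_bound, threshold in AMOUNT_TIERS:
        if amount <= upper_bound:
            return threshold

def route(amount, fraud_probability):
    """Return 'flag', 'review', or 'approve' for a scored transaction."""
    threshold = threshold_for_amount(amount)
    if fraud_probability >= threshold:
        return "flag"
    if fraud_probability >= threshold - REVIEW_MARGIN:
        return "review"
    return "approve"

# Same score, different outcome depending on the amount
print(route(50.0, 0.35))    # small amount, score well below its tier threshold
print(route(5000.0, 0.35))  # large amount, score above its lower tier threshold
print(route(500.0, 0.40))   # mid amount, score inside the review band
```

In practice the tier thresholds would be chosen from the per-amount error analysis and the review capacity of the analyst team, not fixed by hand as they are here.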

In [None]:
# Save feature importance
import json

METRICS_PATH = OUTPUT_PATH / 'metrics'
METRICS_PATH.mkdir(parents=True, exist_ok=True)  # create the directory if it does not exist
importance_df.to_csv(METRICS_PATH / 'feature_importance.csv', index=False)

# Save error analysis (cast to built-in types so the values are JSON serializable)
error_summary = {k: int(v) for k, v in error_df['prediction_type'].value_counts().items()}
error_summary['optimal_threshold'] = float(optimal_threshold)
error_summary['model_name'] = model_name

with open(METRICS_PATH / 'error_analysis.json', 'w') as f:
    json.dump(error_summary, f, indent=2)

print(f"Feature importance saved to: {OUTPUT_PATH / 'metrics' / 'feature_importance.csv'}")
print(f"Error analysis saved to: {OUTPUT_PATH / 'metrics' / 'error_analysis.json'}")

In [None]:
print("\n" + "="*60)
print("MODEL INTERPRETATION COMPLETE")
print("="*60)
print(f"\nModel: {model_name}")
print(f"Optimal Threshold: {optimal_threshold:.4f}")
print(f"\nKey visualizations saved to: {VISUALS_PATH}")
print(f"Analysis metrics saved to: {OUTPUT_PATH / 'metrics'}")
print("\nAll notebooks complete. Check src/ for production scripts.")