In [5]:
import pandas as pd
import numpy as np
import joblib
import shap
import matplotlib.pyplot as plt
import warnings
# Silence all warnings so the notebook output stays readable.
# NOTE: this filter is process-wide, so it also hides warnings raised by
# every later cell in the same kernel.
warnings.filterwarnings('ignore')

print("="*70)
print(" SHAP EXPLAINABILITY ANALYSIS")
print("="*70)

# Load model and data
print("\nLoading model and test data...")
model = joblib.load('../models/phishing_detector_model.pkl')
# Column names used to select (and order) the model's input features below.
feature_names = joblib.load('../models/feature_names.pkl')

# Load dataset
df = pd.read_csv('../data/raw/dataset_phishing.csv')

# Encode the target as 0 = legitimate, 1 = phishing. Series.map turns any
# label not in the mapping into NaN, which would then silently flow into the
# stratified split and the SHAP sample, so reject such labels explicitly.
LABEL_MAP = {'legitimate': 0, 'phishing': 1}
if df['status'].dtype == 'object':
    mapped_status = df['status'].map(LABEL_MAP)
    unmapped = df.loc[mapped_status.isna() & df['status'].notna(), 'status'].unique()
    if len(unmapped) > 0:
        raise ValueError(
            f"Unexpected values in 'status' column (expected one of "
            f"{sorted(LABEL_MAP)}): {list(unmapped)}"
        )
    df['status'] = mapped_status

# Rebuild the held-out test set.
# NOTE(review): this reproduces the training notebook's test split only if
# training used the same dataset, feature list, test_size=0.2,
# random_state=42 and stratify=y. Confirm this, otherwise the "test" rows
# explained below may overlap the training data.
from sklearn.model_selection import train_test_split
X = df[feature_names]
y = df['status']
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print("âœ“ Model loaded")
print(f"âœ“ Test set: {X_test.shape}")

# Explain a fixed random subset (at most 100 rows) of the test set to keep
# the SHAP computation time manageable.
X_sample = X_test.sample(min(100, len(X_test)), random_state=42)
# True labels for exactly the sampled rows, in the same order as X_sample.
y_sample = y_test.loc[X_sample.index]

print(f"âœ“ Using {len(X_sample)} samples for SHAP analysis")

# Create SHAP explainer
print("\n" + "="*70)
print(" CREATING SHAP EXPLAINER (this may take 1-2 minutes)...")
print("="*70)

# Tree SHAP values for the loaded model on the sampled test rows.
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_sample)

print("\nâœ“ SHAP values computed!")
print(f"  Raw shape: {np.array(shap_values).shape if isinstance(shap_values, list) else shap_values.shape}")

# Reduce the explainer output to one (n_samples, n_features) matrix for the
# positive class (1 = phishing, per the label encoding used above). Depending
# on the SHAP version and model type, shap_values() may return:
#   * a list with one (n_samples, n_features) array per class,
#   * a single 3-D array (n_samples, n_features, n_classes), or
#   * a single 2-D array (single-output models).
# The list case must be tested first: a Python list has no `.shape`, so
# checking the array rank first would raise AttributeError.
if isinstance(shap_values, list):
    shap_values_phishing = shap_values[1]
    print(f"  Using list index [1]: {shap_values_phishing.shape}")
elif np.ndim(shap_values) == 3:
    shap_values_phishing = shap_values[:, :, 1]  # Extract class 1
    print(f"  Extracted phishing SHAP values: {shap_values_phishing.shape}")
else:
    shap_values_phishing = shap_values
    print(f"  Using as-is: {shap_values_phishing.shape}")

# Base value consistent with the matrix above. Use the class-1 entry when the
# explainer reports one expected value per output; otherwise use the single
# value. The single value may be a scalar or a length-1 array, and indexing
# [1] on a length-1 array would raise IndexError.
_base_values = np.ravel(explainer.expected_value)
expected_value = _base_values[1] if _base_values.size > 1 else _base_values[0]

print(f"  Base value (expected): {expected_value:.4f}")

# ============================================================
# PLOTS 1-2: Global importance (beeswarm summary + bar chart)
# ============================================================
# Each spec is (progress message, extra summary_plot kwargs, title,
# output path, confirmation message). Specs are rendered in order, one
# figure each.
_global_plot_specs = (
    (
        "\n[1/5] Generating SHAP summary plot...",
        {},
        'SHAP Feature Importance - Phishing Detection',
        '../models/shap_summary.png',
        "âœ“ Saved: models/shap_summary.png",
    ),
    (
        "[2/5] Generating SHAP bar plot...",
        {'plot_type': "bar"},
        'Mean Absolute SHAP Values',
        '../models/shap_bar.png',
        "âœ“ Saved: models/shap_bar.png",
    ),
)

for progress, plot_kwargs, plot_title, out_path, done_msg in _global_plot_specs:
    print(progress)
    plt.figure(figsize=(10, 8))
    # At most 20 features are shown, based on the positive-class SHAP matrix.
    shap.summary_plot(shap_values_phishing, X_sample, show=False, max_display=20, **plot_kwargs)
    plt.title(plot_title, fontweight='bold', fontsize=14, pad=20)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(done_msg)

# ============================================================
# PLOTS 3-4: Waterfall and force plots for one example per class
# ============================================================
print("[3/5] Generating waterfall plots...")

# Index labels of the sampled rows for each true class
# (1 = phishing, 0 = legitimate).
phishing_indices = y_sample[y_sample == 1].index
legit_indices = y_sample[y_sample == 0].index

# (title word, file-name suffix, candidate index labels). The first sampled
# row of each class is explained; a class absent from the sample is skipped.
_examples = (
    ('Phishing', 'phishing', phishing_indices),
    ('Legitimate', 'legitimate', legit_indices),
)


def _first_position(index_labels):
    """Return the row position in X_sample of the first label, or None if there are none."""
    if len(index_labels) == 0:
        return None
    return X_sample.index.get_loc(index_labels[0])


def _save_waterfall(pos, kind, suffix):
    """Save a waterfall plot explaining row ``pos`` of X_sample (positive class)."""
    explanation = shap.Explanation(
        values=shap_values_phishing[pos],  # one SHAP value per feature for this row
        base_values=expected_value,
        data=X_sample.iloc[pos].values,
        feature_names=feature_names,
    )
    plt.figure(figsize=(12, 6))
    shap.waterfall_plot(explanation, show=False, max_display=15)
    plt.title(f'SHAP Waterfall: {kind} URL Example', fontweight='bold', fontsize=12)
    plt.tight_layout()
    plt.savefig(f'../models/shap_waterfall_{suffix}.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"âœ“ Saved: models/shap_waterfall_{suffix}.png")


def _save_force(pos, suffix):
    """Save a static (matplotlib) force plot explaining row ``pos`` of X_sample."""
    # With matplotlib=True, force_plot takes its own figsize; savefig then
    # writes whatever figure is current afterwards.
    shap.force_plot(
        expected_value,
        shap_values_phishing[pos],
        X_sample.iloc[pos],
        matplotlib=True,
        show=False,
        figsize=(20, 3),
    )
    plt.savefig(f'../models/shap_force_{suffix}.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"âœ“ Saved: models/shap_force_{suffix}.png")


for kind, suffix, labels in _examples:
    pos = _first_position(labels)
    if pos is not None:
        _save_waterfall(pos, kind, suffix)

print("[4/5] Generating force plots...")

for kind, suffix, labels in _examples:
    pos = _first_position(labels)
    if pos is not None:
        _save_force(pos, suffix)

# ============================================================
# PLOT 5: Dependence Plots
# ============================================================
print("[5/5] Generating feature dependence plots...")

# Rank features by mean |SHAP| over the sample; keep the three largest,
# largest first.
mean_abs_shap = np.abs(shap_values_phishing).mean(axis=0)
ranked = np.argsort(mean_abs_shap)
top_features_idx = ranked[-3:][::-1]
top_features = [feature_names[i] for i in top_features_idx]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One panel per top feature: its SHAP value plotted against its feature value.
for ax, col_idx, col_name in zip(axes, top_features_idx, top_features):
    shap.dependence_plot(
        col_idx,
        shap_values_phishing,
        X_sample,
        feature_names=feature_names,
        ax=ax,
        show=False,
    )
    ax.set_title(f'{col_name}', fontweight='bold')

plt.suptitle('SHAP Dependence Plots - Top 3 Features', fontweight='bold', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('../models/shap_dependence.png', dpi=150, bbox_inches='tight')
plt.close()
print("âœ“ Saved: models/shap_dependence.png")

# ============================================================
# SUMMARY STATISTICS
# ============================================================
banner = "="*70
print("\n" + banner)
print(" TOP 10 MOST IMPORTANT FEATURES (BY MEAN |SHAP|)")
print(banner)

# Global importance per feature = mean absolute positive-class SHAP value
# over the explained sample, sorted from most to least important.
feature_importance = pd.DataFrame(
    {
        'Feature': feature_names,
        'Mean_Abs_SHAP': np.abs(shap_values_phishing).mean(axis=0),
    }
).sort_values('Mean_Abs_SHAP', ascending=False)

print(feature_importance.head(10).to_string(index=False))

# Persist the full ranking (all features) alongside the plots.
feature_importance.to_csv('../models/shap_feature_importance.csv', index=False)
print("\nâœ“ Saved: models/shap_feature_importance.csv")

# (file name, description) of every artifact this analysis writes, in the
# order they are listed in the final report.
generated_artifacts = (
    ("shap_summary.png", "Feature importance beeswarm plot"),
    ("shap_bar.png", "Mean absolute SHAP values"),
    ("shap_waterfall_phishing.png", "Phishing prediction breakdown"),
    ("shap_waterfall_legitimate.png", "Legitimate prediction breakdown"),
    ("shap_force_phishing.png", "Force plot (phishing)"),
    ("shap_force_legitimate.png", "Force plot (legitimate)"),
    ("shap_dependence.png", "Top 3 feature interactions"),
    ("shap_feature_importance.csv", "Detailed rankings"),
)

print("\n" + banner)
print(" SHAP ANALYSIS COMPLETE!")
print(banner)
print("\nðŸ“Š Generated Visualizations:")
for number, (artifact, description) in enumerate(generated_artifacts, start=1):
    print(f"  {number}. {artifact} - {description}")
print("\nâœ… All explainability visualizations ready for documentation!")
print(banner)


 SHAP EXPLAINABILITY ANALYSIS

Loading model and test data...
✓ Model loaded
✓ Test set: (2286, 42)
✓ Using 100 samples for SHAP analysis

 CREATING SHAP EXPLAINER (this may take 1-2 minutes)...

✓ SHAP values computed!
  Raw shape: (100, 42, 2)
  Extracted phishing SHAP values: (100, 42)
  Base value (expected): 0.4996

[1/5] Generating SHAP summary plot...
✓ Saved: models/shap_summary.png
[2/5] Generating SHAP bar plot...
✓ Saved: models/shap_bar.png
[3/5] Generating waterfall plots...
✓ Saved: models/shap_waterfall_phishing.png
✓ Saved: models/shap_waterfall_legitimate.png
[4/5] Generating force plots...
✓ Saved: models/shap_force_phishing.png
✓ Saved: models/shap_force_legitimate.png
[5/5] Generating feature dependence plots...
✓ Saved: models/shap_dependence.png

 TOP 10 MOST IMPORTANT FEATURES (BY MEAN |SHAP|)
           Feature  Mean_Abs_SHAP
            nb_www       0.107678
  ratio_digits_url       0.039180
 longest_word_path       0.038292
        nb_hyp