# Improved XGBoost Pipeline with SHAP Explainability

This notebook implements the improved ML pipeline with all Phase 1 and Phase 2 enhancements:

**Phase 1 (Critical):**
- ✅ Data validation & feature consistency checks
- ✅ Separate validation set (no data leakage)
- ✅ SHAP values for model interpretability
- ✅ Git commit tracking
- ✅ Environment metadata

**Phase 2 (Enhanced):**
- ✅ JSON configuration
- ✅ Learning curves
- ✅ Adjusted R²
- ✅ Correlation heatmap
- ✅ QQ plots for residuals

In [None]:
import sys
import os

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.getcwd()))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from xgboost import XGBRegressor
import shap
import warnings

# Import our improved utilities
from src.ml_utils import (
    validate_train_test_features,
    check_missing_values,
    adjusted_r2,
    create_model_metadata,
    save_model_with_metadata,
    load_config,
    log_data_split_info
)

# Notebook-wide setup: silence all warnings, then apply the seaborn grid style
# and the default matplotlib figure size used by cells that do not set their own.
warnings.filterwarnings(action='ignore')
sns.set_style(style='whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

print("‚úÖ Libraries and utilities loaded!")

## 1. Load Processed Data

In [None]:
# Load the existing train/test split from the processed-data directory.
X_train_full = pd.read_parquet('../data/processed/X_train.parquet')
X_test = pd.read_parquet('../data/processed/X_test.parquet')

# squeeze("columns") turns the one-column target frame into a Series and keeps
# it a Series even if the file holds a single row. A bare squeeze() would
# collapse that case to a scalar and break the .shape/.iloc/.values uses below.
# NOTE(review): assumes each target file has exactly one column; with more
# columns the result stays a DataFrame, exactly as before.
y_train_full = pd.read_parquet('../data/processed/y_train.parquet').squeeze("columns")
y_test = pd.read_parquet('../data/processed/y_test.parquet').squeeze("columns")

# Report the shapes of the split as loaded, before a validation set is carved out.
print("Original split:")
print(f"  Train: X={X_train_full.shape}, y={y_train_full.shape}")
print(f"  Test:  X={X_test.shape}, y={y_test.shape}")
print(f"\nFeatures: {X_train_full.shape[1]}")

## 2. PHASE 1: Create Separate Validation Set

Split training data into train (80%) and validation (20%) to prevent data leakage.

In [None]:
# Hold out 20% of the original training rows as a validation set; the fixed
# random_state keeps the partition identical from run to run.
X_train, X_val, y_train, y_val = train_test_split(
    X_train_full,
    y_train_full,
    random_state=42,
    test_size=0.2,
)

# Pass all three partitions (features and targets) to the project's
# split-logging helper.
log_data_split_info(X_train, X_val, X_test, y_train, y_val, y_test)

## 3. PHASE 1: Data Validation

In [None]:
print("üîç PHASE 1: Data Validation\n")

# Validate feature consistency
validate_train_test_features(X_train, X_test)
validate_train_test_features(X_train, X_val)

# Check for missing values
check_missing_values(X_train, X_test)

## 4. PHASE 2: Load Configuration

In [None]:
# Read the model hyperparameters from the JSON config file; the resulting
# mapping is later unpacked into the XGBRegressor constructor.
params = load_config('../config/xgboost_params.json')

# Echo every configured setting, one per line.
print("\nModel Parameters:")
for name in params:
    print(f"  {name}: {params[name]}")

## 5. Train Model with Learning Curves

In [None]:
# Build the regressor from the configured hyperparameters.
model = XGBRegressor(**params)

print("Training XGBoost model with validation monitoring...")

# Evaluate on both partitions while fitting: entry 0 is the training data and
# entry 1 the validation data (read back later as 'validation_0' and
# 'validation_1' from evals_result()).
eval_set = [
    (X_train, y_train),
    (X_val, y_val),
]

# verbose=50 controls how often evaluation output is printed during fitting.
model.fit(X_train, y_train, eval_set=eval_set, verbose=50)

print("\n‚úÖ Training complete!")

## 6. PHASE 2: Learning Curves

In [None]:
# Per-round evaluation history recorded during fit(): one entry per eval_set
# element ('validation_0' = training data, 'validation_1' = validation data).
results = model.evals_result()

# Plot RMSE when it was recorded. Otherwise fall back to the first metric that
# was (e.g. when the JSON config sets a different eval_metric) instead of
# failing with a KeyError.
train_history = results['validation_0']
metric_name = 'rmse' if 'rmse' in train_history else next(iter(train_history))
metric_label = metric_name.upper()

# savefig() does not create missing directories, so make sure the plot folder
# exists before this first figure is written (os is imported in the first cell).
os.makedirs('../outputs/plots', exist_ok=True)

plt.figure(figsize=(12, 5))
plt.plot(train_history[metric_name], label=f'Train {metric_label}', linewidth=2)
plt.plot(results['validation_1'][metric_name], label=f'Validation {metric_label}', linewidth=2)
plt.xlabel('Iterations')
plt.ylabel(metric_label)
plt.title('XGBoost Learning Curve - Overfitting Detection', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('../outputs/plots/learning_curve_notebook.png', dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Learning curve saved")

## 7. Make Predictions

In [None]:
# Score every partition with the fitted model (order: train, validation, test).
y_train_pred, y_val_pred, y_test_pred = [
    model.predict(features) for features in (X_train, X_val, X_test)
]

print("Predictions generated for train, validation, and test sets")

## 8. Evaluation Metrics (with PHASE 2: Adjusted R²)

In [None]:
def _regression_scores(y_true, y_pred):
    """Return (RMSE, MAE, R^2) for one set of predictions."""
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    return rmse, mae, r2


# Standard metrics, computed once per partition.
train_rmse, train_mae, train_r2 = _regression_scores(y_train, y_train_pred)
val_rmse, val_mae, val_r2 = _regression_scores(y_val, y_val_pred)
test_rmse, test_mae, test_r2 = _regression_scores(y_test, y_test_pred)

# Fixed-width comparison table: one row per metric, one column per partition.
rule = "=" * 60
print("\n" + rule)
print("MODEL PERFORMANCE")
print(rule)
print(f"{'Metric':<15} {'Train':<12} {'Validation':<12} {'Test':<12}")
print("-" * 60)
print(f"{'RMSE':<15} {train_rmse:<12.4f} {val_rmse:<12.4f} {test_rmse:<12.4f}")
print(f"{'MAE':<15} {train_mae:<12.4f} {val_mae:<12.4f} {test_mae:<12.4f}")
print(f"{'R¬≤':<15} {train_r2:<12.4f} {val_r2:<12.4f} {test_r2:<12.4f}")

# PHASE 2: adjusted R^2 for the held-out partitions, from each partition's
# R^2, row count and feature count.
test_adj_r2 = adjusted_r2(test_r2, len(y_test), X_test.shape[1])
val_adj_r2 = adjusted_r2(val_r2, len(y_val), X_val.shape[1])

print("\nüìä PHASE 2: Adjusted R¬≤ (accounts for # of features):")
print(f"  Test:       {test_adj_r2:.4f} (vs R¬≤: {test_r2:.4f})")
print(f"  Validation: {val_adj_r2:.4f} (vs R¬≤: {val_r2:.4f})")
print(rule)

## 9. Visualization: Actual vs Predicted

In [None]:
def _plot_actual_vs_predicted(ax, actual, predicted, title, **scatter_style):
    """Draw one actual-vs-predicted scatter panel with a dashed y = x reference line.

    Extra keyword arguments (e.g. ``color``) are forwarded to ``ax.scatter``.
    """
    ax.scatter(actual, predicted, alpha=0.3, s=10, **scatter_style)
    lowest, highest = actual.min(), actual.max()
    ax.plot([lowest, highest], [lowest, highest], 'r--', lw=2)
    ax.set_xlabel('Actual Popularity')
    ax.set_ylabel('Predicted Popularity')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One panel per partition; the training panel keeps matplotlib's default colour.
_plot_actual_vs_predicted(
    axes[0], y_train, y_train_pred,
    f'Training Set\nR¬≤ = {train_r2:.4f}, RMSE = {train_rmse:.2f}',
)
_plot_actual_vs_predicted(
    axes[1], y_val, y_val_pred,
    f'Validation Set\nR¬≤ = {val_r2:.4f}, RMSE = {val_rmse:.2f}',
    color='orange',
)
_plot_actual_vs_predicted(
    axes[2], y_test, y_test_pred,
    f'Test Set\nR¬≤ = {test_r2:.4f}, RMSE = {test_rmse:.2f}',
    color='green',
)

plt.tight_layout()
plt.savefig('../outputs/plots/all_sets_predictions.png', dpi=300, bbox_inches='tight')
plt.show()

## 10. PHASE 2: Correlation Heatmap

In [None]:
# Attach the target to a copy of the training features (by position, via
# .values) so one corr() call yields every feature's correlation with it.
analysis_df = X_train.assign(popularity=y_train.values)

# Keep only the target's column of the correlation matrix, minus its
# self-correlation, ordered from most positive to most negative.
target_corr = analysis_df.corr()['popularity']
correlations = target_corr.drop('popularity').sort_values(ascending=False)

print("Top 15 Features Correlated with Popularity:")
print(correlations.head(15))

# Rank by magnitude but plot the signed values: green bars are positive
# correlations, red bars are everything else.
top_corr = correlations.abs().sort_values(ascending=False).head(20)
top_values = correlations.loc[top_corr.index]
colors = ['green' if value > 0 else 'red' for value in top_values]
bar_positions = range(len(top_corr))

plt.figure(figsize=(10, 8))
plt.barh(bar_positions, top_values.tolist(), color=colors, alpha=0.7)
plt.yticks(bar_positions, top_corr.index)
plt.xlabel('Correlation with Popularity')
plt.title('Top 20 Feature Correlations with Popularity', fontsize=14, fontweight='bold')
plt.axvline(x=0, color='black', linestyle='-', linewidth=0.8)
plt.gca().invert_yaxis()  # strongest correlation at the top
plt.tight_layout()
plt.savefig('../outputs/plots/feature_correlations.png', dpi=300, bbox_inches='tight')
plt.show()

## 11. PHASE 2: Residual Analysis with QQ Plot

In [None]:
# Residual = actual minus predicted, on the test set.
residuals = y_test - y_test_pred

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
scatter_ax, hist_ax, qq_ax = axes

# Panel 1: residuals against predictions, with a zero reference line.
scatter_ax.scatter(y_test_pred, residuals, alpha=0.3, s=10)
scatter_ax.axhline(y=0, color='r', linestyle='--', lw=2)
scatter_ax.set(xlabel='Predicted Popularity', ylabel='Residuals', title='Residual Plot')
scatter_ax.grid(True, alpha=0.3)

# Panel 2: residual histogram, with a zero reference line.
hist_ax.hist(residuals, bins=50, edgecolor='black', alpha=0.7)
hist_ax.axvline(x=0, color='r', linestyle='--', lw=2)
hist_ax.set(xlabel='Residuals', ylabel='Frequency', title='Residual Distribution')
hist_ax.grid(True, alpha=0.3)

# Panel 3 (PHASE 2): QQ plot against a normal distribution. The title is set
# after probplot() has drawn on the axes, so this title is the one displayed.
stats.probplot(residuals, dist="norm", plot=qq_ax)
qq_ax.set_title('QQ Plot (Normality Check)')
qq_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../outputs/plots/residual_analysis_complete.png', dpi=300, bbox_inches='tight')
plt.show()

# Summary statistics of the test residuals.
print("\nResidual Statistics:")
print(f"  Mean: {residuals.mean():.4f}")
print(f"  Std:  {residuals.std():.4f}")
print(f"  Min:  {residuals.min():.4f}")
print(f"  Max:  {residuals.max():.4f}")

## 12. Standard Feature Importance

In [None]:
# Pair each training column with the model's built-in importance score and
# order the table from most to least important.
importance_table = pd.DataFrame({
    'feature': X_train.columns,
    'importance': model.feature_importances_,
})
feature_importance = importance_table.sort_values(by='importance', ascending=False)

print("Top 20 Most Important Features (Standard XGBoost):")
print(feature_importance.head(20))

# Horizontal bar chart of the 20 highest-scoring features.
top_20 = feature_importance.head(20)
bar_positions = range(len(top_20))

fig, ax = plt.subplots(figsize=(10, 8))
ax.barh(bar_positions, top_20['importance'], color='steelblue')
ax.set_yticks(bar_positions)
ax.set_yticklabels(top_20['feature'])
ax.set_xlabel('Importance (Gain)')
ax.set_title('Top 20 Feature Importances (XGBoost)', fontsize=14, fontweight='bold')
ax.invert_yaxis()  # most important feature at the top
plt.tight_layout()
plt.savefig('../outputs/plots/standard_feature_importance.png', dpi=300, bbox_inches='tight')
plt.show()

## 13. PHASE 1 CRITICAL: SHAP Values for Interpretability

SHAP (SHapley Additive exPlanations) provides:
- Global feature importance (more reliable than tree-based importance)
- Feature impact direction (positive/negative)
- Individual prediction explanations

In [None]:
print("üî¨ Computing SHAP values (this may take a few minutes)...")

# Use a sample if test set is very large
X_test_shap = X_test
if len(X_test) > 10000:
    print(f"  Sampling {10000} test samples for faster computation...")
    X_test_shap = X_test.sample(n=10000, random_state=42)

# Create SHAP explainer
explainer = shap.Explainer(model, X_train)
shap_values = explainer(X_test_shap)

print("‚úÖ SHAP values computed successfully!")

### 13.1 SHAP Summary Plot (Bar) - Global Importance

In [None]:
# SHAP summary bar plot (global importance of the top 20 features).
# summary_plot resizes the current figure itself (plot_size defaults to
# "auto"), which silently discards the figsize given to plt.figure(). Passing
# the intended size through plot_size makes the saved image really 12x8 inches.
plt.figure(figsize=(12, 8))
shap.summary_plot(
    shap_values, X_test_shap,
    plot_type="bar", show=False, max_display=20, plot_size=(12, 8),
)
plt.title('SHAP Feature Importance (Global)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('../outputs/plots/shap_summary_bar_notebook.png', dpi=300, bbox_inches='tight')
plt.show()

### 13.2 SHAP Beeswarm Plot - Feature Impact Distribution

In [None]:
# SHAP beeswarm plot (shows feature impact direction) for the top 20 features.
# summary_plot resizes the current figure itself (plot_size defaults to
# "auto"), which silently discards the figsize given to plt.figure(). Passing
# the intended size through plot_size makes the saved image really 12x10 inches.
plt.figure(figsize=(12, 10))
shap.summary_plot(
    shap_values, X_test_shap,
    show=False, max_display=20, plot_size=(12, 10),
)
plt.title('SHAP Feature Impact (Beeswarm)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('../outputs/plots/shap_beeswarm_notebook.png', dpi=300, bbox_inches='tight')
plt.show()

print("""\nüéØ How to read the beeswarm plot:
- X-axis: SHAP value (impact on prediction)
- Color: Feature value (red=high, blue=low)
- Position: Each dot is one sample
- Example: If high values of a feature (red dots) appear on the right,
  it means that feature increases popularity when high.""")

### 13.3 Compare SHAP vs Standard Importance

In [None]:
# Global SHAP importance: mean absolute SHAP value per feature over the
# explained rows, ordered from largest to smallest.
mean_abs_shap = np.abs(shap_values.values).mean(axis=0)
shap_importance = pd.DataFrame({
    'feature': X_test_shap.columns,
    'shap_importance': mean_abs_shap,
}).sort_values('shap_importance', ascending=False)

# Put the model's built-in importance next to the SHAP importance, per feature.
xgb_importance = feature_importance[['feature', 'importance']]
comparison = shap_importance.merge(xgb_importance, on='feature')

print("\nTop 15 Features - SHAP vs Standard Importance:")
print(comparison.head(15).to_string(index=False))

# Restrict the chart to the 15 features with the highest SHAP importance and
# scale each measure by its own maximum so the two are comparable.
top_15_features = shap_importance.head(15)['feature'].values
in_top_15 = comparison['feature'].isin(top_15_features)
comparison_top = comparison[in_top_15].assign(
    shap_norm=lambda frame: frame['shap_importance'] / frame['shap_importance'].max(),
    xgb_norm=lambda frame: frame['importance'] / frame['importance'].max(),
)

# Grouped horizontal bars: SHAP just above, XGBoost just below each tick.
fig, ax = plt.subplots(figsize=(12, 8))
positions = np.arange(len(comparison_top))
bar_height = 0.35
half_height = bar_height / 2

ax.barh(positions - half_height, comparison_top['shap_norm'], bar_height,
        label='SHAP', color='steelblue')
ax.barh(positions + half_height, comparison_top['xgb_norm'], bar_height,
        label='XGBoost', color='orange')

ax.set_yticks(positions)
ax.set_yticklabels(comparison_top['feature'])
ax.set_xlabel('Normalized Importance')
ax.set_title('Feature Importance: SHAP vs XGBoost (Top 15)', fontsize=14, fontweight='bold')
ax.legend()
ax.invert_yaxis()
plt.tight_layout()
plt.savefig('../outputs/plots/shap_vs_standard_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

### 13.4 Example: Explain Individual Predictions

In [None]:
# Row to explain, given as a position within the SHAP sample
# (X_test_shap / shap_values).
sample_idx = 42  # You can change this

# shap_values rows follow X_test_shap, which is a shuffled 10,000-row sample
# of X_test whenever the test set is larger than that. Translate the SHAP row
# back to its position in X_test so the actual and predicted values printed
# below belong to the same row that the waterfall plot explains. When no
# sampling happened the two positions coincide.
# NOTE(review): assumes X_test index labels are unique; with duplicates the
# first matching position is used.
sample_label = X_test_shap.index[sample_idx]
test_pos = list(X_test.index).index(sample_label)

# y_test and y_test_pred are both aligned with X_test by position.
actual_value = y_test.iloc[test_pos]
predicted_value = y_test_pred[test_pos]

print(f"\nüìä Explaining prediction for sample {sample_idx}:")
print(f"  Actual popularity: {actual_value:.2f}")
print(f"  Predicted popularity: {predicted_value:.2f}")
print(f"  Error: {abs(actual_value - predicted_value):.2f}")

# Waterfall plot for individual prediction
plt.figure(figsize=(12, 8))
shap.plots.waterfall(shap_values[sample_idx], show=False, max_display=15)
plt.title(f'SHAP Explanation for Sample {sample_idx}', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(f'../outputs/plots/shap_waterfall_sample_{sample_idx}.png', dpi=300, bbox_inches='tight')
plt.show()

print("""\nüéØ How to read the waterfall plot:
- Base value: Average model prediction
- Red bars: Features pushing prediction higher
- Blue bars: Features pushing prediction lower
- Final value (f(x)): Actual prediction for this sample""")

## 14. PHASE 1: Save Model with Metadata

In [None]:
from datetime import datetime

# Gather every evaluation metric as a plain Python float (key order: RMSE,
# MAE, R^2 for train/val/test, then the adjusted R^2 values).
metric_values = (
    ('train_rmse', train_rmse),
    ('val_rmse', val_rmse),
    ('test_rmse', test_rmse),
    ('train_mae', train_mae),
    ('val_mae', val_mae),
    ('test_mae', test_mae),
    ('train_r2', train_r2),
    ('val_r2', val_r2),
    ('test_r2', test_r2),
    ('test_adjusted_r2', test_adj_r2),
    ('val_adjusted_r2', val_adj_r2),
)
metrics = {name: float(value) for name, value in metric_values}

# Build the metadata record from the hyperparameters, metrics, feature names
# and the shapes of the train and test feature frames.
metadata = create_model_metadata(
    model_params=params,
    metrics=metrics,
    feature_names=list(X_train.columns),
    train_size=X_train.shape,
    test_size=X_test.shape,
)

# Notebook-specific additions: where the run came from and which improvements
# it implements.
metadata['source'] = 'notebook/04_Improved_ML_Pipeline.ipynb'
metadata['improvements_implemented'] = {
    'phase_1': ['data_validation', 'separate_validation_set', 'shap_values', 'git_tracking', 'environment_metadata'],
    'phase_2': ['json_config', 'learning_curves', 'adjusted_r2', 'correlation_heatmap', 'qq_plots']
}

# A shared timestamp in both file names keeps the model and its metadata paired.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = f'../outputs/models/improved_xgb_model_{timestamp}.joblib'
metadata_path = f'../outputs/metadata/improved_xgb_metadata_{timestamp}.json'

save_model_with_metadata(model, metadata, model_path, metadata_path)

print("\n‚úÖ Model and metadata saved with full reproducibility tracking!")

## 15. Summary Report

In [None]:
# Final report: headline test metrics, the list of implemented improvements,
# the top SHAP features and the locations of the saved artifacts.
print("\n" + "="*80)
print("IMPROVED ML PIPELINE - EXECUTION SUMMARY")
print("="*80)

print("\nüìä Final Model Performance:")
print(f"  Test R¬≤:          {test_r2:.4f}")
print(f"  Test Adjusted R¬≤: {test_adj_r2:.4f}")
print(f"  Test RMSE:        {test_rmse:.4f}")
print(f"  Test MAE:         {test_mae:.4f}")

print("\n‚úÖ Phase 1 (Critical) Improvements Implemented:")
print("  ‚úì Data validation with feature consistency checks")
print("  ‚úì Separate validation set (no data leakage)")
print("  ‚úì SHAP values for model interpretability")
print("  ‚úì Git commit hash tracking")
print("  ‚úì Environment metadata capture")

print("\n‚úÖ Phase 2 (Enhanced) Improvements Implemented:")
print("  ‚úì JSON configuration management")
print("  ‚úì Learning curves (overfitting detection)")
print("  ‚úì Adjusted R¬≤ metric")
print("  ‚úì Correlation heatmap")
print("  ‚úì QQ plots for residual normality")

print("\nüéØ Top 5 Most Important Features (SHAP):")
# shap_importance was sorted by value, so its index still holds each feature's
# original column position. Number the rows with enumerate() to print true
# ranks 1-5 rather than that leftover index.
for rank, (_, row) in enumerate(shap_importance.head(5).iterrows(), start=1):
    print(f"  {rank}. {row['feature']}: {row['shap_importance']:.6f}")

print("\nüíæ Outputs Saved:")
print(f"  ‚úì Model: {model_path}")
print(f"  ‚úì Metadata: {metadata_path}")
print(f"  ‚úì Visualizations: ../outputs/plots/ (10+ plots)")

# Only show the commit when the metadata carries one; print its first 8 characters.
if metadata.get('git_commit'):
    print(f"\nüìå Git Commit: {metadata['git_commit'][:8]}")

print("\n" + "="*80)
print("üéâ IMPROVED PIPELINE COMPLETE!")
print("="*80)