In [1]:
# %% [markdown]
# # Step 4: Model Explainability with SHAP

# %%
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shap
import joblib
import warnings
warnings.filterwarnings('ignore')

# Load the two trained XGBoost models, the saved feature-name list and the test data.
# NOTE(review): joblib.load unpickles arbitrary Python objects — only point these
# paths at model files that come from a trusted source.
xgb_p92 = joblib.load('../models/xgb_petrol92.pkl')
xgb_ad = joblib.load('../models/xgb_diesel.pkl')
# Normalise to a plain list: later cells use `in` and `.index()` on feature_names,
# which only behave as intended on a list (not on an ndarray or pandas Index).
feature_names = list(joblib.load('../models/feature_names.pkl'))

X_test = pd.read_csv('../feature_engineered_data/X_test.csv')
y_test = pd.read_csv('../feature_engineered_data/y_test.csv')

print("Models and data loaded successfully")

# %% [markdown]
# ## SHAP Analysis for Petrol 92 Model

# %%
# Explain a fixed random subset of the test set (at most 100 rows) to keep SHAP fast
X_sample = X_test.sample(n=min(100, len(X_test)), random_state=42)

# NOTE(review): the recorded run of this notebook stopped right after the load
# message with "ValueError: could not convert string to float: '[1.3473811E2]'".
# That looks like shap failing to parse the XGBoost model's base_score while
# building the TreeExplainer (a shap/xgboost version mismatch) — confirm, and
# align the installed shap and xgboost versions before re-running.
explainer_p92 = shap.TreeExplainer(xgb_p92)
shap_values_p92 = explainer_p92.shap_values(X_sample)

# Summary plot.
# plot_size=None tells summary_plot to leave the current figure's size alone;
# with the default plot_size="auto" the figsize requested here is overridden.
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values_p92, X_sample, feature_names=feature_names,
                  plot_size=None, show=False)
plt.title('SHAP Feature Importance - Petrol 92 Model')
plt.tight_layout()
plt.savefig('../reports/shap_summary_petrol92.png', dpi=150, bbox_inches='tight')
plt.show()

# %%
# Bar plot of mean absolute SHAP value per feature (Petrol 92 model).
# plot_size=None keeps the figsize set below; the default plot_size="auto"
# would make summary_plot resize the current figure.
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values_p92, X_sample, feature_names=feature_names,
                  plot_type="bar", plot_size=None, show=False)
plt.title('SHAP Feature Importance (Bar) - Petrol 92 Model')
plt.tight_layout()
plt.savefig('../reports/shap_bar_petrol92.png', dpi=150, bbox_inches='tight')
plt.show()

# %% [markdown]
# ## SHAP Analysis for Auto Diesel Model

# %%
# SHAP for the diesel model, computed on the same X_sample rows as Petrol 92.
# NOTE(review): the recorded run failed with a ValueError before reaching this
# cell, so it has not been exercised; the TreeExplainer construction below may
# hit the same error — confirm after aligning shap/xgboost versions.
explainer_ad = shap.TreeExplainer(xgb_ad)
shap_values_ad = explainer_ad.shap_values(X_sample)

# Summary plot.
# plot_size=None tells summary_plot to leave the current figure's size alone;
# with the default plot_size="auto" the figsize requested here is overridden.
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values_ad, X_sample, feature_names=feature_names,
                  plot_size=None, show=False)
plt.title('SHAP Feature Importance - Auto Diesel Model')
plt.tight_layout()
plt.savefig('../reports/shap_summary_diesel.png', dpi=150, bbox_inches='tight')
plt.show()

# %%
# Bar plot of mean absolute SHAP value per feature (Auto Diesel model).
# plot_size=None keeps the figsize set below; the default plot_size="auto"
# would make summary_plot resize the current figure.
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values_ad, X_sample, feature_names=feature_names,
                  plot_type="bar", plot_size=None, show=False)
plt.title('SHAP Feature Importance (Bar) - Auto Diesel Model')
plt.tight_layout()
plt.savefig('../reports/shap_bar_diesel.png', dpi=150, bbox_inches='tight')
plt.show()

# %% [markdown]
# ## Feature Interaction Analysis

# %%
# SHAP dependence plots (Petrol 92 model) for a hand-picked set of features,
# drawn into a 2x2 grid and saved as a single image.
top_features = ['LP_92_lag_1', 'Crude_Oil_USD', 'Exchange Rate', 'Inflation_Rate']

# Work on a list copy so the membership test and .index() behave the same
# whether feature_names was loaded as a list, an ndarray or a pandas Index.
known_features = list(feature_names)

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# axes.flat walks the grid row by row, matching the original axes[i//2, i%2] order.
for ax, feature in zip(axes.flat, top_features):
    if feature not in known_features:
        # Report the skip and hide the panel rather than saving an empty subplot.
        print(f"Skipping dependence plot: '{feature}' not found in feature_names")
        ax.set_visible(False)
        continue
    feature_idx = known_features.index(feature)
    shap.dependence_plot(
        feature_idx, shap_values_p92, X_sample,
        feature_names=feature_names, ax=ax, show=False
    )
    ax.set_title(f'SHAP Dependence: {feature}')

plt.tight_layout()
plt.savefig('../reports/shap_dependence.png', dpi=150, bbox_inches='tight')
plt.show()

# %% [markdown]
# ## Individual Prediction Explanations

# %%
# Explain the first few test rows individually: print actual vs. predicted prices
# for both fuels, then save a SHAP force plot for the Petrol 92 prediction.
# Bounded by len(X_test) so a test set with fewer than 3 rows does not raise.
for i in range(min(3, len(X_test))):
    print(f"\n{'='*60}")
    print(f"Explanation for Test Sample {i+1}")
    print('='*60)

    # Keep the row as a one-row DataFrame (i:i+1) so predict() gets 2-D input
    sample = X_test.iloc[i:i+1]
    actual_p92 = y_test.iloc[i]['LP_92']
    actual_ad = y_test.iloc[i]['LAD']

    # Predict
    pred_p92 = xgb_p92.predict(sample)[0]
    pred_ad = xgb_ad.predict(sample)[0]

    print(f"Actual Petrol 92: {actual_p92:.2f} LKR")
    print(f"Predicted Petrol 92: {pred_p92:.2f} LKR")
    print(f"Error: {actual_p92 - pred_p92:.2f} LKR")
    print()
    print(f"Actual Auto Diesel: {actual_ad:.2f} LKR")
    print(f"Predicted Auto Diesel: {pred_ad:.2f} LKR")
    print(f"Error: {actual_ad - pred_ad:.2f} LKR")

    # SHAP force plot for Petrol 92
    shap_values_single = explainer_p92.shap_values(sample)

    # force_plot(matplotlib=True) opens its own figure, so the size is passed via
    # its figsize argument; a preceding plt.figure() would only leave a blank
    # extra figure in the output.
    shap.force_plot(
        explainer_p92.expected_value,
        shap_values_single[0],
        sample.iloc[0],
        feature_names=feature_names,
        matplotlib=True,
        figsize=(12, 3),
        show=False
    )
    plt.title(f'SHAP Force Plot - Petrol 92 (Sample {i+1})')
    plt.tight_layout()
    plt.savefig(f'../reports/shap_force_p92_sample{i+1}.png', dpi=150, bbox_inches='tight')
    plt.show()

# %% [markdown]
# ## Feature Importance Comparison

# %%
# XGBoost's own feature importances for each model, highest importance first
importance_p92 = (
    pd.DataFrame({'Feature': feature_names, 'Importance': xgb_p92.feature_importances_})
    .sort_values('Importance', ascending=False)
)
importance_ad = (
    pd.DataFrame({'Feature': feature_names, 'Importance': xgb_ad.feature_importances_})
    .sort_values('Importance', ascending=False)
)

# One panel per model, each showing the ten highest-ranked features
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

panels = [
    (axes[0], importance_p92, 'XGBoost Feature Importance - Petrol 92'),
    (axes[1], importance_ad, 'XGBoost Feature Importance - Auto Diesel'),
]
for ax, ranked, title in panels:
    top10 = ranked.head(10)
    ax.barh(top10['Feature'], top10['Importance'])
    ax.set_xlabel('Importance')
    ax.set_title(title)
    # barh draws the first row at the bottom; flip so the top feature is on top
    ax.invert_yaxis()

plt.tight_layout()
plt.savefig('../reports/feature_importance_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# %% [markdown]
# ## Key Findings from Explainability

# 1. **Most Important Features**:
#    - Lagged fuel prices (especially 1-day lag)
#    - Crude oil prices
#    - USD/LKR exchange rate
#    - Rolling averages of fuel prices

# 2. **Model Behavior**:
#    - The model correctly learns that yesterday's price is the best predictor of today's price
#    - Crude oil price has a positive relationship with local fuel prices
#    - Exchange rate depreciation (higher USD/LKR) leads to higher fuel prices

# 3. **Alignment with Domain Knowledge**:
#    - Confirms that Sri Lanka's fuel prices are heavily influenced by international crude prices
#    - Exchange rate impact reflects import dependence
#    - Crisis periods (2022) show expected extreme values

Models and data loaded successfully


ValueError: could not convert string to float: '[1.3473811E2]'