# Multi-Class Outage Cause Prediction with XGBoost & SHAP

From the [Sisyphean Gridworks ML Playground](https://sgridworks.com/ml-playground/guides/09-advanced-outage-prediction.html)

## Setup

Clone the repository and install dependencies. Run this cell first.

In [None]:
!git clone https://github.com/SGridworks/Dynamic-Network-Model.git 2>/dev/null || echo 'Already cloned'
%cd Dynamic-Network-Model
!pip install -q pandas numpy matplotlib seaborn scikit-learn xgboost lightgbm shap pyarrow

## Load and Merge All Data Sources

Unlike Guide 01 where we used only weather data, here we merge weather and asset data to give the model a richer picture of outage drivers.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import xgboost as xgb
import shap
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder

from demo_data.load_demo_data import (
    load_outage_history, load_weather_data,
    load_transformers, load_network_edges
)

# Load all datasets
outages = load_outage_history()
weather = load_weather_data()
transformers = load_transformers()
edges = load_network_edges()

print(f"Outages: {len(outages):,} events")
print(f"Cause codes: {outages['cause_code'].unique()}")
print(f"\nCause code distribution:")
print(outages["cause_code"].value_counts())

## Build Enriched Feature Set

We combine daily weather summaries with asset condition data for each outage's feeder. This gives the model both environmental and infrastructure context. Note that we keep all cause codes, including "unknown": in real utility data, a significant portion of outages have undetermined causes, and the model should learn to recognize these patterns rather than ignore them.

In [None]:
# Create daily weather features (same as Guide 01)
weather["date"] = weather["timestamp"].dt.date
daily_weather = weather.groupby("date").agg({
    "temperature":    ["max", "min", "mean"],
    "wind_speed":     ["max", "mean"],
    "is_storm":       "max",
    "humidity":       "mean",
}).reset_index()
daily_weather.columns = ["date", "temp_max", "temp_min", "temp_mean",
                          "wind_max", "wind_mean", "is_storm", "humidity_mean"]

# Add asset features per feeder
feeder_assets = transformers.groupby("feeder_id").agg({
    "age_years":  "mean",
    "kva_rating": "sum",
}).reset_index()
feeder_assets.columns = ["feeder_id", "avg_asset_age", "total_kva"]

# Build the training table: one row per outage event
# Parse timestamps first so the .dt accessor works even if they load as strings
outages["fault_detected"] = pd.to_datetime(outages["fault_detected"])
outages["date"] = outages["fault_detected"].dt.date
df = outages.merge(daily_weather, on="date", how="left")
df = df.merge(feeder_assets, on="feeder_id", how="left")

# Add time features
df["month"] = df["fault_detected"].dt.month
df["hour"] = df["fault_detected"].dt.hour
df["day_of_week"] = df["fault_detected"].dt.dayofweek
df["is_summer"] = df["month"].isin([6, 7, 8]).astype(int)
df["is_storm_season"] = df["month"].isin([3, 4, 5, 6]).astype(int)

# Drop rows with missing weather data (keep all cause codes including unknown)
df = df.dropna(subset=["temp_max"])

print(f"Training samples: {len(df):,}")
print(f"\nFeatures built per event:")
print(f"  Weather: 7 | Asset: 2 | Calendar: 5")
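
Left merges like the two above can silently duplicate outage rows if the right-hand table repeats a key, and they leave gaps when a key has no match. A minimal sketch on toy data (the frames and values below are invented for illustration and are not from the SP&L dataset) showing how the `validate` and `indicator` options of pandas `merge` surface both problems:

```python
import pandas as pd

# Toy stand-ins for the outage and feeder-asset tables (hypothetical values)
toy_outages = pd.DataFrame({
    "outage_id": [1, 2, 3],
    "feeder_id": ["F1", "F2", "F9"],   # F9 has no asset record
})
toy_assets = pd.DataFrame({
    "feeder_id": ["F1", "F2"],
    "avg_asset_age": [12.5, 30.0],
})

# validate="many_to_one" raises MergeError if feeder_id repeats in toy_assets;
# indicator=True adds a _merge column that marks unmatched rows.
merged = toy_outages.merge(
    toy_assets, on="feeder_id", how="left",
    validate="many_to_one", indicator=True,
)

unmatched = merged[merged["_merge"] == "left_only"]
print(f"Rows before: {len(toy_outages)}, after: {len(merged)}")
print(f"Outages without asset data: {len(unmatched)}")
```

The same two arguments could be added to the merges in the cell above. Whether unmatched feeders should be dropped or imputed is a modeling decision this sketch does not make.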

## Time-Aware Train/Test Split

In Guide 01, we used a random split. But in production, your model always predicts the future based on the past. A time-aware split is more honest: train on 2020–2023 data, test on 2024–2025.

In [None]:
# Define features and target
feature_cols = [
    "temp_max", "temp_min", "temp_mean",
    "wind_max", "wind_mean", "is_storm", "humidity_mean",
    "avg_asset_age", "total_kva",
    "month", "hour", "day_of_week", "is_summer", "is_storm_season"
]

# Encode cause codes as integers
le = LabelEncoder()
df["cause_label"] = le.fit_transform(df["cause_code"])
cause_names = le.classes_

# Time-aware split: train on 2020-2023, test on 2024-2025
train_mask = df["fault_detected"].dt.year <= 2023
test_mask  = df["fault_detected"].dt.year >= 2024

X_train = df.loc[train_mask, feature_cols]
y_train = df.loc[train_mask, "cause_label"]
X_test  = df.loc[test_mask, feature_cols]
y_test  = df.loc[test_mask, "cause_label"]

print(f"Training set: {len(X_train):,} events (2020-2023)")
print(f"Test set:     {len(X_test):,} events (2024-2025)")
print(f"Classes: {list(cause_names)}")
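
The split logic can also be wrapped in a small helper with a built-in leakage check. This is a sketch on toy data: `time_split` and its parameters are hypothetical names introduced here, not part of the guide's codebase.

```python
import pandas as pd

def time_split(frame, time_col, first_test_year):
    """Split rows into (train, test) by the calendar year of time_col."""
    years = pd.to_datetime(frame[time_col]).dt.year
    train = frame[years < first_test_year]
    test = frame[years >= first_test_year]
    return train, test

# Toy events (invented values)
toy = pd.DataFrame({
    "fault_detected": pd.to_datetime([
        "2021-03-01", "2022-07-15", "2023-12-31", "2024-01-01", "2025-06-10",
    ]),
    "cause_code": ["weather", "equipment", "vegetation", "weather", "unknown"],
})

train, test = time_split(toy, "fault_detected", first_test_year=2024)

# Leakage check: every training event must precede every test event
assert train["fault_detected"].max() < test["fault_detected"].min()
print(f"Train rows: {len(train)}, test rows: {len(test)}")
```

A random split would fail this assertion on most draws, which is the property the time-aware split is meant to guarantee.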

## Train XGBoost Multi-Class Classifier

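XGBoost (Extreme Gradient Boosting) builds trees sequentially, where each new tree corrects the mistakes of the previous ones. It often outperforms Random Forest on structured (tabular) data, and because it accepts per-sample weights, class imbalance can be handled by weighting rare causes more heavily, as the cell below does.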

In [None]:
# Calculate class weights for imbalanced data
class_counts = y_train.value_counts().sort_index()
total = len(y_train)
n_classes = len(class_counts)
class_weights = {i: total / (n_classes * count)
                 for i, count in class_counts.items()}

# Assign sample weights
sample_weights = y_train.map(class_weights)

# Train XGBoost
model = xgb.XGBClassifier(
    n_estimators=300,
    max_depth=6,
    learning_rate=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    objective="multi:softprob",
    num_class=n_classes,
    random_state=42,
    eval_metric="mlogloss",
)

model.fit(
    X_train, y_train,
    sample_weight=sample_weights,
    eval_set=[(X_test, y_test)],
    verbose=50
)

print("Training complete.")
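
The weighting formula used above, `total / (n_classes * count)`, is easier to see on small numbers. A worked sketch with invented label counts (pure Python, no model involved):

```python
from collections import Counter

# Hypothetical training labels: 6 weather, 3 equipment, 1 animal
labels = ["weather"] * 6 + ["equipment"] * 3 + ["animal"] * 1

counts = Counter(labels)
total = len(labels)
n_classes = len(counts)

# weight_c = total / (n_classes * count_c): rarer classes get larger weights
weights = {c: total / (n_classes * n) for c, n in counts.items()}

for c, w in sorted(weights.items()):
    print(f"{c:<10} count={counts[c]}  weight={w:.3f}")

# Each class ends up with the same summed weight: count_c * weight_c = total / n_classes
contrib = {c: counts[c] * weights[c] for c in counts}
```

With these weights the single "animal" event counts as much in the loss as all six "weather" events combined. That is the intent, but it can also amplify noise in very small classes.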

## Evaluate Multi-Class Performance

In [None]:
# Predict on test set
y_pred = model.predict(X_test)

# Classification report with cause code names
print(classification_report(y_test, y_pred,
      target_names=cause_names))

# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
fig, ax = plt.subplots(figsize=(8, 7))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=cause_names, yticklabels=cause_names, ax=ax)
ax.set_xlabel("Predicted Cause")
ax.set_ylabel("Actual Cause")
ax.set_title("Multi-Class Confusion Matrix: Outage Cause Prediction")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()
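
The precision and recall columns in the classification report can be read directly off the confusion matrix. A small sketch with an invented 3-class matrix (rows are actual causes, columns are predicted causes, matching the heatmap above):

```python
import numpy as np

# Hypothetical confusion matrix: rows = actual class, columns = predicted class
cm_toy = np.array([
    [50,  5,  5],
    [10, 20, 10],
    [ 5,  5, 10],
])

true_pos = np.diag(cm_toy)

# Recall: of all events that truly belong to a class, the share predicted correctly
recall = true_pos / cm_toy.sum(axis=1)

# Precision: of all events predicted as a class, the share that truly belong to it
# (a class that is never predicted would give a division by zero here)
precision = true_pos / cm_toy.sum(axis=0)

print("recall:   ", np.round(recall, 3))
print("precision:", np.round(precision, 3))
```

Reading rows gives recall and reading columns gives precision, so a bright off-diagonal cell shows which pair of causes the model confuses.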

## Explain Predictions with SHAP

Feature importance tells you which features matter globally. SHAP (SHapley Additive exPlanations) goes further: it tells you how much each feature contributed to a specific prediction and in which direction.

In [None]:
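# Create SHAP explainer
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

# Older SHAP releases return a list with one (samples, features) array per class;
# newer ones return a single (samples, features, classes) array. Normalize to a list.
if not isinstance(shap_values, list) and shap_values.ndim == 3:
    shap_values = [shap_values[:, :, i] for i in range(shap_values.shape[-1])]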

# Summary plot: how each feature affects each class
fig, axes = plt.subplots(1, len(cause_names), figsize=(20, 6))
for i, cause in enumerate(cause_names):
    plt.sca(axes[i])
    shap.summary_plot(shap_values[i], X_test,
                       feature_names=feature_cols,
                       show=False, max_display=8,
                       plot_size=None)  # keep the shared figure's size
    axes[i].set_title(f"{cause}")
plt.tight_layout()
plt.show()
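
SHAP values are Shapley values, so they are additive: for one event, the per-feature contributions sum to the model output minus a baseline output. The sketch below is not how `TreeExplainer` computes them (it uses a tree-specific algorithm). Instead it brute-forces exact Shapley values for a hypothetical three-feature scoring function by averaging each feature's marginal contribution over all feature orderings, with absent features set to baseline values. The function, baseline, and event are invented for illustration.

```python
from itertools import permutations

def toy_model(x):
    # Hypothetical risk score with a wind x storm interaction (not the trained model)
    return (0.5 * x["wind_max"]
            + 2.0 * x["is_storm"] * x["wind_max"]
            + 0.1 * x["avg_asset_age"])

baseline = {"wind_max": 10.0, "is_storm": 0.0, "avg_asset_age": 20.0}
event = {"wind_max": 40.0, "is_storm": 1.0, "avg_asset_age": 35.0}
features = list(event)

def value(present):
    """Model output when only the features in `present` take the event's values."""
    x = {f: (event[f] if f in present else baseline[f]) for f in features}
    return toy_model(x)

# Average each feature's marginal contribution over every ordering
shapley = {f: 0.0 for f in features}
orderings = list(permutations(features))
for order in orderings:
    present = set()
    for f in order:
        before = value(present)
        present.add(f)
        shapley[f] += (value(present) - before) / len(orderings)

base_value = value(set())
prediction = value(set(features))

for f in features:
    print(f"{f:<14} {shapley[f]:+.2f}")
print(f"base {base_value:.2f} + contributions {sum(shapley.values()):.2f} = {prediction:.2f}")
```

The interaction term is split between `wind_max` and `is_storm`, while the purely additive `avg_asset_age` term is attributed only to that feature. The same reading applies to each row of the summary plots: a point's horizontal position is that feature's contribution for one event.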

## Explain Individual Predictions

SHAP's real power is explaining individual events. Pick a specific outage and see exactly why the model predicted its cause.

In [None]:
# Pick a specific outage event from the test set
event_idx = 0
event = X_test.iloc[event_idx]
actual = cause_names[y_test.iloc[event_idx]]
predicted = cause_names[y_pred[event_idx]]

print(f"Event details:")
print(f"  Actual cause:    {actual}")
print(f"  Predicted cause: {predicted}")
print(f"  Wind max:        {event['wind_max']:.1f} mph")
print(f"  Storm flag:      {event['is_storm']}")
print(f"  Avg asset age:   {event['avg_asset_age']:.0f} years")

# Waterfall plot for this prediction
predicted_class = y_pred[event_idx]
shap.waterfall_plot(
    shap.Explanation(
        values=shap_values[predicted_class][event_idx],
        base_values=explainer.expected_value[predicted_class],
        data=event.values,
        feature_names=feature_cols
    )
)

## Benchmark Against SAIFI Metrics

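How does your model's predicted outage distribution compare to SP&L's actual reliability metrics? The cell below first computes annual SAIFI (System Average Interruption Frequency Index: total customer interruptions divided by total customers served) from the outage history, then compares the predicted and actual cause distributions for the test years. This bridges the gap between ML model accuracy and real utility KPIs.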

In [None]:
# Compute SAIFI from outage_history (no separate reliability file needed)
total_customers = 48_000  # SP&L total customers served
annual_saifi = (outages.groupby(outages["fault_detected"].dt.year)["affected_customers"]
                .sum() / total_customers)
print("Annual SAIFI (computed from outage_history):")
print(annual_saifi)

# Compare predicted vs actual cause distribution for test years
test_outages = df[test_mask]

actual_dist = test_outages["cause_code"].value_counts(normalize=True)
pred_causes = pd.Series(cause_names[y_pred])
pred_dist = pred_causes.value_counts(normalize=True)

comparison = pd.DataFrame({
    "Actual %": (actual_dist * 100).round(1),
    "Predicted %": (pred_dist * 100).round(1)
}).fillna(0)  # a cause the model never predicts shows as 0% instead of NaN
print("\nCause distribution comparison (test set):")
print(comparison)

# Plot side-by-side
fig, ax = plt.subplots(figsize=(10, 5))
x = np.arange(len(comparison))
width = 0.35
ax.bar(x - width/2, comparison["Actual %"], width, label="Actual", color="#2D6A7A")
ax.bar(x + width/2, comparison["Predicted %"], width, label="Predicted", color="#5FCCDB")
ax.set_xticks(x)
ax.set_xticklabels(comparison.index, rotation=45, ha="right")
ax.set_ylabel("Percentage of Outages")
ax.set_title("Predicted vs Actual Outage Cause Distribution (2024-2025)")
ax.legend()
plt.tight_layout()
plt.show()
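
For reference, the SAIFI arithmetic used in the cell above can be checked by hand on a few rows. The events below are invented; only the 48,000 customer count is taken from the guide.

```python
import pandas as pd

# Hypothetical outage events
toy_events = pd.DataFrame({
    "fault_detected": pd.to_datetime(["2024-02-01", "2024-08-15", "2025-05-20"]),
    "affected_customers": [1_200, 3_600, 2_400],
})
customers_served = 48_000  # same figure as the cell above

# SAIFI = total customer interruptions in the year / total customers served
interruptions = (toy_events
                 .groupby(toy_events["fault_detected"].dt.year)["affected_customers"]
                 .sum())
saifi = interruptions / customers_served
print(saifi)
```

Here 2024 has 4,800 customer interruptions, giving a SAIFI of 0.1, meaning the average customer experienced 0.1 interruptions that year.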

## Seasonal Cause Analysis

Outage causes vary by season. Vegetation outages peak in spring and summer during the growing season. Weather outages spike during storm season. Equipment failures may increase in extreme heat. Let's check whether the model's predictions reflect these patterns.

In [None]:
# Predicted causes by month
test_outages_with_pred = test_outages.copy()
test_outages_with_pred["predicted_cause"] = cause_names[y_pred]

monthly_causes = test_outages_with_pred.groupby(
    ["month", "predicted_cause"]
).size().unstack(fill_value=0)

fig, ax = plt.subplots(figsize=(12, 6))
monthly_causes.plot(kind="bar", stacked=True, ax=ax,
                    colormap="Set2")
ax.set_xlabel("Month")
ax.set_ylabel("Predicted Outage Count")
ax.set_title("Predicted Outage Causes by Month")
ax.legend(title="Cause", bbox_to_anchor=(1.05, 1))
plt.tight_layout()
plt.show()

## What You Built and Next Steps

In [None]:
# For reproducible results, move this seed call to the top of your notebook,
# before any data splitting or training
np.random.seed(42)

# Save the trained model for later use
import joblib
joblib.dump(model, "outage_cause_model.pkl")

# Load it back
model = joblib.load("outage_cause_model.pkl")
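
The saved model outputs integer class labels, so the label order held in `cause_names` is needed to decode predictions in a later session. A minimal sketch that stores the mapping as JSON, assuming a hypothetical file name and example class names (in the notebook you would pass `list(cause_names)` instead):

```python
import json
import os
import tempfile

# Hypothetical class names in LabelEncoder (alphabetical) order
cause_names_example = ["animal", "equipment", "unknown", "vegetation", "weather"]

# Written to a temp directory here; in practice, save it next to the model file
path = os.path.join(tempfile.mkdtemp(), "outage_cause_labels.json")
with open(path, "w") as f:
    json.dump({"classes": cause_names_example}, f)

# Later session: reload the mapping and decode integer predictions
with open(path) as f:
    restored = json.load(f)["classes"]

predicted_ids = [4, 1, 1]  # stand-in for model.predict(...) output
decoded = [restored[i] for i in predicted_ids]
print(decoded)
```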