### Model Explainability with SHAP

To improve transparency and root cause analysis, we use SHAP (SHapley Additive exPlanations) to explain anomaly predictions from the trained model.

This helps identify which features contributed most to a given anomaly.


In [None]:
import shap
from sklearn.ensemble import IsolationForest
from src.generate_data import generate_synthetic_telemetry
from src.preprocess import normalize_data
import pandas as pd
import numpy as np

# Load and process data
df = generate_synthetic_telemetry()

# Reuse features from previous notebooks
# Rolling statistics: the first (window - 1) rows of each are NaN and are
# removed by the dropna() further down.
df['cpu_rolling_mean'] = df['cpu'].rolling(window=20).mean()
df['cpu_rolling_std'] = df['cpu'].rolling(window=20).std()
df['latency_diff'] = df['latency'].diff()
df['errors_rolling_sum'] = df['errors'].rolling(window=10).sum()

# Mean FFT magnitude of the forward-looking 64-sample latency window that
# starts at each row. The raw values are pulled out once so the loop slices a
# NumPy array instead of a pandas Series on every iteration.
# NOTE(review): the strict `<` also leaves out the final *complete* window
# (start == len(df) - 64); kept as-is for parity with the earlier notebooks.
fft_window = 64
n_rows = len(df)
latency_values = df['latency'].to_numpy()
df['latency_fft_mean'] = pd.Series(
    [
        np.mean(np.abs(np.fft.fft(latency_values[start:start + fft_window])))
        if start + fft_window < n_rows
        else np.nan
        for start in range(n_rows)
    ],
    # Align explicitly with df: without this the new Series gets a fresh
    # 0..n-1 index and would be label-matched against df's index on assignment.
    index=df.index,
)

# Global z-scores (mean/std taken over the whole series)
df['cpu_z'] = (df['cpu'] - df['cpu'].mean()) / df['cpu'].std()
df['latency_z'] = (df['latency'] - df['latency'].mean()) / df['latency'].std()

# Lagged copies of the raw signals
df['cpu_lag1'] = df['cpu'].shift(1)
df['latency_lag3'] = df['latency'].shift(3)
df['errors_lag2'] = df['errors'].shift(2)

# Rolling correlation between latency and cpu over 20 samples
df['latency_cpu_corr_20'] = df['latency'].rolling(window=20).corr(df['cpu'])

# Drop every row that has a NaN in any feature (head rows from rolling/shift,
# tail rows from the FFT window). Row positions in df_clean therefore no
# longer match row labels of df.
df_clean = df.dropna()
# NOTE(review): normalize_data is defined in src.preprocess; presumably it
# returns a scaled DataFrame with the same row order -- confirm.
X = normalize_data(df_clean)

# Train model
# contamination=0.05 sets the expected share of anomalies; random_state pins
# the tree construction so results are reproducible.
model = IsolationForest(contamination=0.05, random_state=42)
model.fit(X)


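We use SHAP's model-agnostic KernelExplainer, which only needs the model's `decision_function`. (Recent SHAP releases can also explain Isolation Forest through the faster TreeExplainer; check your installed version if you want to switch.)
KernelExplainer can be slow on large datasets, so we summarize the background data with k-means and explain only a small subset of rows.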


In [None]:
# Use SHAP KernelExplainer (model must have decision_function)
# shap.kmeans(X, 10) condenses X into 10 centroids that serve as the
# background dataset for the explainer.
explainer = shap.KernelExplainer(model.decision_function, shap.kmeans(X, 10))

# Pick a sample anomaly (e.g. index 310, known injected anomaly)
i = 310
# `i` is a row label of the raw telemetry frame. dropna() removed leading rows
# when df_clean was built, so position i of X is a *different* (later) sample.
# Translate the label into its position within the cleaned data first.
# NOTE(review): assumes normalize_data preserved the row order of df_clean.
row_pos = df_clean.index.get_loc(i)
shap_values = explainer.shap_values(X.iloc[[row_pos]])

# Visualize explanation
shap.initjs()
shap.force_plot(explainer.expected_value, shap_values, X.iloc[row_pos], matplotlib=True)


# Alternatively: explain multiple predictions
# Draw a reproducible random batch of 50 rows, compute their SHAP values with
# the explainer built above, and summarize the attributions in one plot.
subset = X.sample(n=50, random_state=42)
shap_vals = explainer.shap_values(subset)
shap.summary_plot(shap_vals, features=subset)
