In [None]:
# Task: Using SHAP for Feature Drift Analysis

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import shap
import matplotlib.pyplot as plt

# Suppress warnings (globally, for the whole session) to keep notebook output clean
import warnings
warnings.filterwarnings("ignore")


def _simulate_split(f1_loc, f1_scale, f2_loc, f2_scale, n_rows=1000):
    """Draw one synthetic split from NumPy's global RNG.

    feature_1 and feature_2 are Gaussian with the given location/scale,
    feature_3 is a fair 0/1 draw, and ``label`` is an independent fair 0/1
    draw (it is generated without reference to the features).

    Columns are drawn in a fixed order: feature_1, feature_2, feature_3, label.
    That order matters, because every draw advances the shared global RNG
    state seeded below.
    """
    columns = {}
    columns['feature_1'] = np.random.normal(f1_loc, f1_scale, n_rows)
    columns['feature_2'] = np.random.normal(f2_loc, f2_scale, n_rows)
    columns['feature_3'] = np.random.randint(0, 2, n_rows)
    columns['label'] = np.random.randint(0, 2, n_rows)
    return pd.DataFrame(columns)


# Load or simulate datasets with drift (seeded for reproducibility)
np.random.seed(42)

# Training split: the reference distribution
train_data = _simulate_split(0, 1, 5, 1.5)

# Test split: feature_1 and feature_2 are shifted and widened to simulate drift;
# feature_3 and label keep the training distribution
test_data = _simulate_split(1, 1.1, 6, 1.7)

# Separate the feature matrix from the target for each split
X_train, y_train = train_data.drop(columns='label'), train_data['label']
X_test, y_test = test_data.drop(columns='label'), test_data['label']

# Train a model on training data
# NOTE: the simulated labels are drawn independently of the features, so the
# forest fits noise; the SHAP comparison below still illustrates the technique.
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# SHAP analysis on both datasets
explainer = shap.TreeExplainer(model)

# Explain the output for label == 1. Look up its position in the fitted model's
# class order instead of hard-coding index 1.
POSITIVE_CLASS = 1
positive_class_index = list(model.classes_).index(POSITIVE_CLASS)


def _class_shap_matrix(raw_shap, class_index, expected_shape):
    """Return the (n_samples, n_features) SHAP matrix for one output class.

    ``TreeExplainer.shap_values`` returns different formats depending on the
    SHAP release for multi-output models such as a binary classifier:
    - older releases return a list with one (n_samples, n_features) array
      per class;
    - recent releases (0.45+) return one (n_samples, n_features, n_classes)
      array.
    Plain ``[1]`` indexing is only correct for the list form. On the 3-D array
    it silently selects the second *sample*.

    Args:
        raw_shap: Whatever ``explainer.shap_values(X)`` returned.
        class_index: Position of the class to explain in the model's class order.
        expected_shape: Shape of the explained feature matrix ``X``.

    Returns:
        A NumPy array with the same shape as ``X``.

    Raises:
        ValueError: If the extracted matrix does not match ``expected_shape``.
    """
    if isinstance(raw_shap, list):
        matrix = np.asarray(raw_shap[class_index])
    else:
        matrix = np.asarray(raw_shap)
        if matrix.ndim == 3:
            matrix = matrix[:, :, class_index]
        # A 2-D result is treated as a single explained output and used as-is.
    if matrix.shape != tuple(expected_shape):
        raise ValueError(
            f"Unexpected SHAP value shape {matrix.shape}; "
            f"expected {tuple(expected_shape)}"
        )
    return matrix


shap_values_train = _class_shap_matrix(
    explainer.shap_values(X_train), positive_class_index, X_train.shape
)
shap_values_test = _class_shap_matrix(
    explainer.shap_values(X_test), positive_class_index, X_test.shape
)

# Mean absolute SHAP value per feature (global importance on each split)
shap_train_importance = np.abs(shap_values_train).mean(axis=0)
shap_test_importance = np.abs(shap_values_test).mean(axis=0)

# Create comparison DataFrame.
# Drift = test importance - train importance. Positive means the feature
# receives more attribution on the test split. Rows are sorted by drift,
# largest increase first.
shap_df = pd.DataFrame({
    'Feature': X_train.columns,
    'Train_SHAP_Importance': shap_train_importance,
    'Test_SHAP_Importance': shap_test_importance,
    'Drift': shap_test_importance - shap_train_importance
}).sort_values('Drift', ascending=False)

# Print the comparison table (one row per feature, in shap_df's current order)
print("\nSHAP-Based Feature Drift Analysis:\n")
print(shap_df)

# Grouped bar chart: for each feature, one bar for train and one for test mean |SHAP|
ax = (
    shap_df
    .set_index("Feature")
    .loc[:, ["Train_SHAP_Importance", "Test_SHAP_Importance"]]
    .plot(kind='bar', figsize=(10,5))
)
ax.set_title("SHAP Importance Comparison (Train vs Test)")
ax.set_ylabel("Mean |SHAP value|")
plt.setp(ax.get_xticklabels(), rotation=0)  # keep feature names horizontal
ax.figure.tight_layout()
plt.show()
