# In [None]:
import pickle
import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import RobustScaler

# 1. Restore the trained Random Forest from its pickle file.
#    pickle.load can execute arbitrary code, so only open model files you trust.
with open('path_to_your_model.pkl', mode='rb') as model_file:
    rf_model = pickle.load(model_file)

# 2. Prepare the test data: it must contain the same features used during training,
#    preprocessed the same way.
test_data = pd.read_csv('path_to_test_data.csv')  # Adjust with your actual data source

# Drop the columns that are not model inputs (must match what was excluded at
# training time). errors='ignore' lets the script also run on unlabeled data
# where the target column is absent.
X_test = test_data.drop(
    columns=['org_name', 'ab_name', 'susceptible_flag', 'charttime'],
    errors='ignore',
)
# Target variable if available, otherwise None.
y_test = test_data['susceptible_flag'] if 'susceptible_flag' in test_data.columns else None

# SHAP's TreeExplainer works on the underlying feature array, i.e. by column
# position rather than column name, so the columns must be in exactly the order
# the model was fitted with. scikit-learn estimators fitted on a DataFrame
# record that order in `feature_names_in_`; enforce it when it is available.
training_features = getattr(rf_model, 'feature_names_in_', None)
if training_features is not None:
    missing_features = [col for col in training_features if col not in X_test.columns]
    if missing_features:
        raise ValueError(f"Test data is missing training features: {missing_features}")
    # .copy() yields an independent frame, so later in-place column assignment
    # (e.g. scaling) does not trigger pandas' SettingWithCopyWarning.
    X_test = X_test[list(training_features)].copy()

# Apply the scaling used in training (if the model was trained on scaled
# features). This must be the scaler fitted on the TRAINING data: refitting on
# the test set rescales every column with test-set statistics, so the values no
# longer correspond to what the forest was trained on and the SHAP values would
# explain the wrong inputs.
SCALER_PATH = 'path_to_your_scaler.pkl'  # Adjust: scaler pickled at training time
num_cols = X_test.select_dtypes(include=['number']).columns
try:
    with open(SCALER_PATH, 'rb') as f:
        scaler = pickle.load(f)  # NOTE: pickle can run arbitrary code; only load trusted files
except FileNotFoundError:
    # Last-resort fallback: fit on the test set, but say so loudly because the
    # resulting explanations are only approximate.
    print(
        f"WARNING: training scaler not found at {SCALER_PATH!r}; fitting a new "
        "RobustScaler on the test data instead. SHAP values may not reflect "
        "the trained model's behaviour."
    )
    scaler = RobustScaler().fit(X_test[num_cols])
else:
    # A scaler fitted on a DataFrame records the columns it was fitted on; use
    # exactly those so transform() receives the columns it expects.
    fitted_cols = getattr(scaler, 'feature_names_in_', None)
    if fitted_cols is not None:
        num_cols = pd.Index(fitted_cols)
X_test[num_cols] = scaler.transform(X_test[num_cols])

# 3. Wrap the fitted forest in SHAP's tree-model explainer
explainer = shap.TreeExplainer(rf_model)

# 4. Attribute every test prediction to the input features
shap_values = explainer.shap_values(X_test)

# 5. Report which container SHAP returned before indexing into it: for a
#    classifier it may be a list holding one array per class, or a single
#    ndarray (3-D when a class axis is present).
print(f"SHAP values shape/type: {type(shap_values)}")
if isinstance(shap_values, np.ndarray):
    print(f"Array shape: {shap_values.shape}")
elif isinstance(shap_values, list):
    print(f"List length: {len(shap_values)}")
    for class_idx, class_shap in enumerate(shap_values):
        print(f"Class {class_idx} shape: {class_shap.shape}")

# 6. Reduce the SHAP values to one importance score per feature (mean |SHAP|).
# SHAP may return either a list with one [samples, features] array per class,
# or a single ndarray that is [samples, features, classes] when a class axis is
# present. For classifiers, attributions toward class index 1 (the positive
# class of a binary target) are used in both layouts.
if isinstance(shap_values, list) and len(shap_values) >= 2:
    # Per-class list layout, any number of classes. Averaging the whole list
    # over axis 0 would collapse the class axis instead of the sample axis and
    # leave one value per sample rather than per feature.
    shap_values_class1 = np.asarray(shap_values[1])
    feature_importance = np.abs(shap_values_class1).mean(axis=0)
elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
    # Stacked layout [samples, features, classes]
    shap_values_class1 = shap_values[:, :, 1]
    feature_importance = np.abs(shap_values_class1).mean(axis=0)
else:
    # Single output (e.g. regression), possibly wrapped in a one-element list;
    # expected to be a 2-D [samples, features] array.
    single_output = shap_values[0] if isinstance(shap_values, list) else shap_values
    feature_importance = np.abs(np.asarray(single_output)).mean(axis=0)

# Guard against silently mislabelled features: the scores are later paired
# with X_test's column names positionally.
if np.shape(feature_importance) != (X_test.shape[1],):
    raise ValueError(
        f"Expected one SHAP importance per feature ({X_test.shape[1]}), "
        f"got an array of shape {np.shape(feature_importance)}"
    )

# 7. Tabulate mean |SHAP| per feature, ranked from most to least important.
#    X_test's column order is the feature axis SHAP was computed on.
feature_names = X_test.columns
importance_data = [
    {'Feature': name, 'SHAP Importance': score}
    for name, score in zip(feature_names, feature_importance)
]
feature_importance_df = pd.DataFrame(importance_data).sort_values(
    by='SHAP Importance', ascending=False
)

# 8. Show the 20 highest-ranked features
print("\nTop 20 important features:")
print(feature_importance_df.head(20))

# 9. Horizontal bar chart of the 20 largest mean |SHAP| scores
top_features = feature_importance_df.head(20)
fig, ax = plt.subplots(figsize=(12, 8))
ax.barh(top_features['Feature'], top_features['SHAP Importance'])
ax.set_xlabel('Mean |SHAP Value|')
ax.set_title('Top 20 Feature Importance')
ax.invert_yaxis()  # Rows are sorted descending; flip so the largest bar is on top
fig.tight_layout()
fig.savefig('feature_importance.png')
plt.show()

# 10. Optional: SHAP summary plots on a subset of the test rows.
# Plotting is best-effort: any failure is reported and the script continues.
try:
    # Cap the number of plotted rows so the plots stay fast and readable.
    # Rows are drawn at random (fixed seed, so reruns are reproducible) rather
    # than taken from the top of the file, which may be ordered (e.g. by
    # charttime) and therefore unrepresentative.
    n_rows = X_test.shape[0]
    sample_size = min(100, n_rows)
    sample_idx = np.sort(
        np.random.default_rng(0).choice(n_rows, size=sample_size, replace=False)
    )
    X_sample = X_test.iloc[sample_idx]

    # Pick one [samples, features] matrix for the sampled rows: class index 1
    # for classifiers (per-class list of any length, or a
    # [samples, features, classes] array), otherwise the single-output array.
    if isinstance(shap_values, list) and len(shap_values) >= 2:
        sample_shap = np.asarray(shap_values[1])[sample_idx]
    elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
        sample_shap = shap_values[sample_idx, :, 1]
    else:
        single_output = shap_values[0] if isinstance(shap_values, list) else shap_values
        sample_shap = np.asarray(single_output)[sample_idx]

    # Bar summary (mean |SHAP| per feature). plot_size=None keeps the 12x8
    # figure created here; SHAP's default ("auto") would resize it.
    plt.figure(figsize=(12, 8))
    shap.summary_plot(
        sample_shap,
        X_sample,
        plot_type="bar",
        max_display=20,
        plot_size=None,
        show=False
    )
    plt.title("SHAP Summary Plot")
    plt.tight_layout()
    plt.savefig('shap_summary.png')  # Save the figure
    plt.show()

    # Dot (beeswarm) summary: also shows the direction of each feature's impact.
    plt.figure(figsize=(12, 8))
    shap.summary_plot(
        sample_shap,
        X_sample,
        max_display=20,
        plot_size=None,
        show=False
    )
    plt.title("SHAP Dot Summary Plot")
    plt.tight_layout()
    plt.savefig('shap_dot_summary.png')  # Save the figure
    plt.show()
except Exception as e:
    # Deliberately broad: these plots are optional extras. Include the
    # exception type so failures are diagnosable.
    print(f"Could not create SHAP summary plot: {type(e).__name__}: {e}")