In [1]:
from sklearn.inspection import partial_dependence
from sklearn.inspection import PartialDependenceDisplay
import os
import pandas as pd
import numpy as np
import shap
from sklearn.ensemble import RandomForestRegressor
import matplotlib.pyplot as plt

# Output folder for the explanation figures; exist_ok=True makes a re-run
# of this cell a no-op when the folder is already there.
os.makedirs(name="shap_pdp", exist_ok=True)

In [2]:
# --- Configuration ---------------------------------------------------------
# Folder holding the pre-split train/test pickles.
folder_path = "train_test_pickles"
# Destination for the SHAP dependence plots. This is the folder the first
# cell creates; it is (re)created below so this cell also runs on its own.
PLOT_DIR = "shap_pdp"
# Number of leading training rows explained with SHAP (SHAP cost grows with
# the number of rows explained, so it is capped instead of using all rows).
SHAP_SAMPLE_SIZE = 100_000
# Number of highest-impact features (by mean |SHAP|) to plot.
TOP_K_FEATURES = 10

# --- Load data ---------------------------------------------------------------
# NOTE: read_pickle can execute arbitrary code; only load pickles produced
# by this project.
train_pickle_path = os.path.join(folder_path, "train_df_0.pickle")
train_df = pd.read_pickle(train_pickle_path)

# Columns excluded from the feature matrix. The list includes the target
# 'normalizzed_rtt', so the target cannot leak into X_train.
cols_dropped = ['date', 'last_rtt', 'prb_id', 'dst_id', 'normalizzed_rtt', 'src_names', 'distance',
                'Latitude_source', 'Longitude_source', 'Latitude_destination', 'Longitude_destination',
                'Public_destination', 'Public_source', 'norm_storedtimestamp']

# Feature matrix (X) and target (y) for the training split only.
X_train = train_df.drop(columns=cols_dropped)
y_train = train_df['normalizzed_rtt']

# --- Model -------------------------------------------------------------------
rand_forest = RandomForestRegressor(
    n_estimators=10,
    criterion='squared_error',
    random_state=42,
)
rand_forest.fit(X_train, y_train)

# --- SHAP --------------------------------------------------------------------
# Explain only the first SHAP_SAMPLE_SIZE rows (positional slice).
X_shap = X_train.iloc[:SHAP_SAMPLE_SIZE]
explainer_shap = shap.Explainer(rand_forest)
shap_values = explainer_shap(X_shap)

feature_names = X_train.columns.tolist()

# Rank features by mean absolute SHAP value. argsort is ascending, so the
# last TOP_K_FEATURES indices are the most influential ones
# (ordered least -> most influential).
top_features_idx = np.abs(shap_values.values).mean(axis=0).argsort()[-TOP_K_FEATURES:]

# --- Plots -------------------------------------------------------------------
os.makedirs(PLOT_DIR, exist_ok=True)

for feature_idx in top_features_idx:
    column_name = feature_names[feature_idx]
    # SHAP contribution of this feature for every explained row.
    shap_feature_values = shap_values.values[:, feature_idx]

    # SHAP dependence scatter: feature value vs. its SHAP contribution.
    fig, ax = plt.subplots()
    ax.scatter(X_shap.iloc[:, feature_idx], shap_feature_values, alpha=0.5)
    ax.set_xlabel(f"{column_name}")
    ax.set_ylabel("SHAP Values")
    ax.set_title(f"SHAP dependence: {column_name}")

    fig.savefig(os.path.join(PLOT_DIR, f"shap_scatter_plot_feature_{feature_idx}.png"))
    # Close explicitly so figures do not pile up in memory across the loop
    # (and no empty figure is left behind as cell output).
    plt.close(fig)

<Figure size 640x480 with 0 Axes>