In [None]:
import os

import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, matthews_corrcoef, recall_score, precision_score
from sklearn.metrics import confusion_matrix, f1_score, classification_report
from sklearn.metrics import roc_curve
import torch.nn.functional as F
import shap
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Load the dict of pre-split datasets and pull out the train/validation entries.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects (and so run
# arbitrary code) — only load this file if it comes from a trusted source.
loaded_datasets_info = torch.load(
    '/root/autodl-tmp/data/saved_datasets.pth',
    weights_only=False,
)
loaded_train_dataset, loaded_val_dataset = (
    loaded_datasets_info['train_dataset'],
    loaded_datasets_info['val_dataset'],
)

In [None]:
from torch.utils.data import DataLoader

def extract_features_labels_from_subset(subset):
    """Materialise an entire map-style dataset/subset as a pair of NumPy arrays.

    All samples are collated into a single batch, then axis 1 of both the
    features and the labels is squeezed out (``squeeze(1)`` is a no-op when
    that axis is not of size 1).

    Args:
        subset: a map-style dataset supporting ``len()`` whose items are
            ``(features, label)`` tensor pairs. Presumably each item is shaped
            ``(1, n_features)`` / ``(1,)`` — TODO confirm against the saved data.

    Returns:
        tuple ``(features, labels)`` of NumPy arrays, rows in dataset order.

    Raises:
        ValueError: if ``subset`` is empty.
    """
    n_samples = len(subset)
    if n_samples == 0:
        raise ValueError("cannot extract features/labels from an empty subset")

    # One batch holding every sample; DataLoader does not shuffle by default,
    # so row order matches the dataset.
    loader = DataLoader(subset, batch_size=n_samples)
    features, labels = next(iter(loader))
    return features.squeeze(1).numpy(), labels.squeeze(1).numpy()

# Turn both splits into NumPy (features, labels) arrays for the sklearn / SHAP code below.
(X_train, y_train), (X_val, y_val) = (
    extract_features_labels_from_subset(loaded_train_dataset),
    extract_features_labels_from_subset(loaded_val_dataset),
)

In [None]:
# The training CSV's column headers become the SHAP feature labels.
# NOTE(review): the original trailing "644" annotation presumably records the expected
# column count — confirm. These names must line up 1:1 with the columns of X_val,
# so also confirm the CSV carries no label/ID column.
csv_path = '/root/autodl-tmp/data/train_data_new.csv'
train_data_new = pd.read_csv(csv_path)
feature_names = train_data_new.columns

#### Model

In [None]:
import pickle
# Restore the trained classifier from disk.
# SECURITY: pickle.load can execute arbitrary code embedded in the file —
# only load model files from a trusted source.
with open("../model_params/rf_model.pkl", mode="rb") as model_file:
    model = pickle.load(model_file)

#### SHAP

In [None]:
# Explain the random-forest model on (at most) the first N_SHAP_SAMPLES validation rows.
N_SHAP_SAMPLES = 2000
# Assumes a binary classifier whose positive class is column 1 of predict_proba
# (i.e. model.classes_[1]) — TODO confirm.
POSITIVE_CLASS = 1

# NOTE(review): passing X_train as background data makes TreeExplainer use interventional
# perturbation over every training row, which can be slow for a large training set
# (shap suggests a subsample, e.g. shap.sample(X_train, 100)).
explainer_rf = shap.TreeExplainer(model, X_train)
# check_additivity=False skips shap's check that attributions sum to (output - expected value).
raw_shap_values_rf = explainer_rf.shap_values(X_val[:N_SHAP_SAMPLES], check_additivity=False)

# Attributions of predict_proba sum to ~0 across classes (class probabilities sum to 1), so
# summing over the class axis cancels the signal. Keep only the positive-class attributions.
if isinstance(raw_shap_values_rf, list):
    # Older shap releases: one (n_samples, n_features) array per class.
    shap_values_rf = np.asarray(raw_shap_values_rf[POSITIVE_CLASS])
else:
    raw_shap_values_rf = np.asarray(raw_shap_values_rf)
    # Newer shap releases: (n_samples, n_features, n_classes); a 2-D result is already single-output.
    if raw_shap_values_rf.ndim == 3:
        shap_values_rf = raw_shap_values_rf[..., POSITIVE_CLASS]
    else:
        shap_values_rf = raw_shap_values_rf

In [None]:
# Base value = the reference model output that the SHAP attributions are measured against.
# For a TreeExplainer built on background data X_train, that is the mean model output over
# X_train, i.e. the mean predicted positive-class probability computed here.
# Assumes a binary classifier whose positive class is column 1 of predict_proba — TODO confirm.
predictions = model.predict_proba(X_train)
n_explained = shap_values_rf.shape[0]
# One base value per explained row, so base_values aligns with `values` and `data`.
expected_value = np.full(n_explained, predictions[:, 1].mean(), dtype='float32')

shap_values_rf_reconstructed = shap.Explanation(
    values=shap_values_rf,                 # (n_explained, n_features)
    base_values=expected_value,            # (n_explained,)
    data=X_val[:n_explained],              # exactly the rows that were explained
    feature_names=feature_names,
)

In [None]:
# Global feature-importance bar chart (mean |SHAP value| per feature), saved as a PDF.
# savefig raises FileNotFoundError if the target directory is missing, so create it first.
os.makedirs('/root/autodl-tmp/SHAP', exist_ok=True)
# NOTE(review): fig1_rf holds whatever shap returns with show=False (an Axes in recent releases).
fig1_rf = shap.plots.bar(shap_values_rf_reconstructed, show=False)
plt.savefig('/root/autodl-tmp/SHAP/shap_fig1_rf.pdf', bbox_inches='tight')

In [None]:
# Beeswarm summary plot of per-sample SHAP values, saved as a PDF.
# savefig raises FileNotFoundError if the target directory is missing, so create it first.
os.makedirs('/root/autodl-tmp/SHAP', exist_ok=True)
# NOTE(review): fig2_rf holds whatever shap returns with show=False (an Axes in recent releases).
fig2_rf = shap.plots.beeswarm(shap_values_rf_reconstructed, show=False)
plt.savefig('/root/autodl-tmp/SHAP/shap_fig2_rf.pdf', bbox_inches='tight')