In [None]:
import os

import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, matthews_corrcoef, recall_score, precision_score
from sklearn.metrics import confusion_matrix, f1_score, classification_report
from sklearn.metrics import roc_curve
import torch.nn.functional as F
import shap
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# The checkpoint stores pickled Dataset objects (not just tensors), hence weights_only=False.
# NOTE(review): unpickling executes arbitrary code — only load files from a trusted source.
loaded_datasets_info = torch.load(
    '/root/autodl-tmp/data/saved_datasets.pth',
    weights_only=False,
)
loaded_train_dataset, loaded_val_dataset = (
    loaded_datasets_info[split_key] for split_key in ('train_dataset', 'val_dataset')
)

In [None]:
from torch.utils.data import DataLoader

def extract_features_labels_from_subset(subset):
    """Materialise a whole map-style dataset as NumPy arrays in a single batch.

    Args:
        subset: dataset whose items are ``(features, labels)`` tensor pairs
            (presumably a ``torch.utils.data.Subset`` — TODO confirm).

    Returns:
        ``(features, labels)`` as NumPy arrays with dim 1 squeezed out of each.
        ``Tensor.squeeze(1)`` only removes that axis when it has size 1; otherwise
        the tensor is returned unchanged.

    Raises:
        ValueError: if ``subset`` is empty.
    """
    if len(subset) == 0:
        # DataLoader would reject batch_size=0 with a less informative message.
        raise ValueError("cannot extract features/labels from an empty dataset")

    # One batch holding every sample; shuffle defaults to False, so row order follows the dataset.
    loader = DataLoader(subset, batch_size=len(subset))
    features, labels = next(iter(loader))
    return features.squeeze(1).numpy(), labels.squeeze(1).numpy()

# Pull each split into (features, labels) NumPy arrays, training split first.
(X_train, y_train), (X_val, y_val) = (
    extract_features_labels_from_subset(split)
    for split in (loaded_train_dataset, loaded_val_dataset)
)

In [None]:
# X_train / X_val are 2-D (rows, features), but the Transformer model below takes 3-D input.
# A singleton axis is inserted at dim 1; the model's forward pass treats it as a length-1
# sequence (it is not a "channel").
X_train_tensor, X_val_tensor = (
    torch.tensor(split_array, dtype=torch.float32).unsqueeze(1)
    for split_array in (X_train, X_val)
)

In [None]:
# Training table on disk; in this section only its header row is used, as SHAP feature labels.
# NOTE(review): "644" was noted beside this load without explanation (presumably the row count).
# Also confirm the CSV holds exactly the model's 31 feature columns and no label column;
# otherwise the SHAP feature labels will be misaligned.
train_data_new = pd.read_csv(
    '/root/autodl-tmp/data/train_data_new.csv'
)
feature_names = train_data_new.keys()  # DataFrame.keys() is the column Index

#### Model

In [None]:
class TransformerEncoderClassification(nn.Module):
    """Transformer-encoder classifier over one feature vector per sample.

    The submodule names ``transformer_encoder`` and ``fc`` determine the state_dict keys,
    so they must stay stable for saved checkpoints to load.

    Args:
        d_model: size of each input feature vector (also the encoder width). Default 31.
        nhead: number of attention heads; must divide ``d_model``. Default 31.
        num_layers: number of stacked encoder layers. Default 3.
        num_classes: number of output logits. Default 3.
    """

    def __init__(self, d_model=31, nhead=31, num_layers=3, num_classes=3):
        super().__init__()
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers,
        )
        # forward() flattens (seq, d_model) into seq * d_model features, so this head
        # only lines up when the sequence length is 1.
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, 1, d_model) to logits of shape (batch, num_classes)."""
        # The encoder layer uses the default batch_first=False, i.e. it expects (seq, batch, d_model).
        x = x.permute(1, 0, 2)
        x = self.transformer_encoder(x)
        x = x.permute(1, 0, 2)  # back to (batch, seq, d_model)
        x = x.flatten(1)        # (batch, seq * d_model)
        return self.fc(x)

device = "cpu"  # every tensor in this notebook, and the SHAP explainer, stays on CPU
model = TransformerEncoderClassification().to(device=device)

In [None]:
# Restore the trained weights. map_location=device maps the stored tensors onto the CPU, so
# the load also works on CPU-only machines when the checkpoint was saved from GPU tensors.
# NOTE(review): a plain state_dict should also load with weights_only=True, which avoids
# unpickling arbitrary objects — consider switching once verified.
model.load_state_dict(
    torch.load(
        '/root/autodl-tmp/model_params/Transformer.pth',
        map_location=device,
        weights_only=False,
    )
)

#### SHAP

In [None]:
# The model is still in training mode after construction + load_state_dict, so the encoder's
# dropout (default 0.1) would randomise every gradient GradientExplainer samples.
# Switch to inference mode before explaining.
model.eval()

explainer_transformer = shap.GradientExplainer(model, X_train_tensor)  # X_train_tensor is the background data
# NOTE(review): with recent shap versions the output for this 3-logit model is presumably shaped
# (n, 1, 31, 3). Summing the last axis adds attributions across classes; confirm this aggregation
# (vs. per-class or mean |SHAP|) is what the plots should show.
shap_values_transformer = np.sum(
    explainer_transformer.shap_values(X_val_tensor[:2000], rseed=0),  # rseed fixes SHAP's own sampling
    axis=-1,
)  # 1. shap values

In [None]:
# Quick look at class-summed SHAP values for the first 10 validation rows. They are recomputed
# here rather than sliced from shap_values_transformer, so sampling noise may make them differ slightly.
np.sum(
    explainer_transformer.shap_values(X_val_tensor[:10]),
    axis=-1,
)

In [None]:
model.eval()  # disable dropout for the background forward pass
with torch.no_grad():
    predictions = model(X_train_tensor)  # logits on the background set, shape (n_train, 3)

# GradientExplainer attributions satisfy sum(SHAP) ≈ f(x) - E_background[f(x')].
# shap_values_transformer was summed over the class axis, so the matching base value is the
# background-mean output summed over classes. It is repeated once per explained row so that
# base_values lines up with values.
n_explained = shap_values_transformer.shape[0]
expected_value = np.full(n_explained, predictions.numpy().mean(axis=0).sum())  # 2. base_values

shap_values_transformer_reconstructed = shap.Explanation(
    values=shap_values_transformer.squeeze(1),           # drop the length-1 axis -> (n_explained, n_features)
    base_values=expected_value,                          # (n_explained,)
    data=X_val_tensor[:n_explained].squeeze(1).numpy(),  # plain NumPy array for the plotting code
    feature_names=list(feature_names),
)

In [None]:
# Display the assembled shap.Explanation (values, base_values, data, feature_names) for inspection.
shap_values_transformer_reconstructed

In [None]:
# Shapes of the three arrays handed to shap.Explanation: values, base_values, data.
for explanation_part in (
    shap_values_transformer.squeeze(1),
    expected_value,
    X_val_tensor[:2000].squeeze(1),
):
    print(explanation_part.shape)

In [None]:
# Start from a fresh figure so the bar plot never draws onto axes left over from an earlier plot
# (e.g. when this code runs outside Jupyter's inline backend).
plt.figure()
fig1_transformer = shap.plots.bar(shap_values_transformer_reconstructed, show=False)
os.makedirs('/root/autodl-tmp/SHAP', exist_ok=True)  # savefig does not create missing directories
plt.savefig('/root/autodl-tmp/SHAP/shap_fig1_transformer.pdf', bbox_inches='tight')

In [None]:
# Fresh figure so the beeswarm does not draw on top of the previous bar plot
# (e.g. when this code runs outside Jupyter's inline backend).
plt.figure()
fig2_transformer = shap.plots.beeswarm(shap_values_transformer_reconstructed, show=False)
os.makedirs('/root/autodl-tmp/SHAP', exist_ok=True)  # savefig does not create missing directories
plt.savefig('/root/autodl-tmp/SHAP/shap_fig2_transformer.pdf', bbox_inches='tight')