With this kind of pipeline we cannot use SHAP because the input to XGBoost is the precomputed BERT embeddings, not raw words or tokens. We need to modify our pipeline so that SHAP can work directly on the raw text before it is converted into BERT embeddings.

Instead of training XGBoost on BERT embeddings, you will fine-tune a BERT classifier instead of using precomputed embeddings and use SHAP’s DeepExplainer or GradientExplainer on the BERT model to explain feature importance at the token level.

The implementation of this code is based on this website: https://medium.com/@raoashish10/fine-tuning-a-pre-trained-bert-model-for-classification-using-native-pytorch-c5f33e87616e


In [2]:
import pandas as pd
import numpy as np
import re
import torch
import shap
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}\n")

# Load dataset
file_path1 = "Data/SpamAssasin.csv" # 5809 total values
file_path2 = "Data/Enron.csv" #29767 total values
file_path3 = "Data/Nazario.csv" # 1565 total values
print("Loading datasets...")
df1 = pd.read_csv(file_path1)
df2 = pd.read_csv(file_path2)
df3 = pd.read_csv(file_path3)

# Concatenate DataFrames
df_combined = pd.concat([df1, df2, df3], ignore_index=True)
df = df_combined[['body', 'label']].dropna()
print(f"Dataset loaded: {len(df)} emails\n") # 37140 emails

# Preprocessing function
def preprocess_text(text):
    text = re.sub(r"http\S+|www\S+|https\S+", "", text)  # Remove URLs
    text = re.sub(r"\d+", "", text)  # Remove numbers
    text = re.sub(r"\S+@\S+\.\S+", "", text)  # Remove email addresses
    text = re.sub(r"[^A-Za-z0-9\s]", "", text)  # Remove special characters
    text = re.sub(r"\s+", " ", text).strip()  # Normalize spaces
    return text.lower()

# Apply text preprocessing
print("Preprocessing email content...")
df['body'] = df['body'].apply(preprocess_text)
print("Text preprocessing complete!\n")

# Load BERT tokenizer
MODEL_NAME = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(df['body'], df['label'], test_size=0.2, random_state=42)
print(f"Training set: {len(X_train)} emails, Test set: {len(X_test)} emails\n")

# Define PyTorch Dataset
class SpamDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts = texts.tolist()
        self.labels = labels.tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Create DataLoaders
print("Creating DataLoader instances...")
train_dataset = SpamDataset(X_train, y_train, tokenizer)
test_dataset = SpamDataset(X_test, y_test, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True) # batch size 16/32 (article BERT)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)
print("DataLoader instances created successfully!\n")

# Define BERT model
print("Initializing BERT model for spam classification...")
bert_model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2).to(device)
print("BERT model loaded successfully!\n")

# Define optimizer and loss function
optimizer = AdamW(bert_model.parameters(), lr=2e-5) # Learning rate: 2e-5, 3e-5, or 5e-5 (article BERT)
criterion = torch.nn.CrossEntropyLoss()
    
# Training loop
print("Starting training process...\n")
epochs = 1 # 2-4 (article BERT)
bert_model.train()

for epoch in range(epochs):
    total_loss = 0
    print(f"Epoch {epoch+1}/{epochs} starting...\n")

    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = bert_model(input_ids, attention_mask=attention_mask)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

        # Print progress every 10 batches
        if (batch_idx + 1) % 10 == 0:
            print(f"  Batch {batch_idx+1}/{len(train_loader)} - Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} completed! Average Loss: {avg_loss:.4f}\n")

print("Training complete!\n")

# Save the trained model
#MODEL_SAVE_PATH = "bert_spam_classifier.pth"
#print(f"Saving trained model to {MODEL_SAVE_PATH}...\n")
#torch.save(bert_model.state_dict(), MODEL_SAVE_PATH)
#print("Model saved successfully!\n")

# Evaluate model
print("Evaluating model performance on test data...\n")
bert_model.eval()
y_preds = []
y_true = []

with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].cpu().numpy()
        
        outputs = bert_model(input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=1).cpu().numpy()
        
        y_preds.extend(preds)
        y_true.extend(labels)

        if (batch_idx + 1) % 10 == 0:
            print(f"  Processed {batch_idx+1}/{len(test_loader)} test batches...")

# Print classification report
print("\nFinal Model Performance:")
print(classification_report(y_true, y_preds))


Using device: cpu

Loading dataset...
Dataset loaded: 5808 emails

Preprocessing email content...
Selected 5808 emails from the dataset.
Text preprocessing complete!

Loading BERT tokenizer...
BERT tokenizer loaded successfully!

Splitting dataset into training and testing sets...
Training set: 4646 emails, Test set: 1162 emails

Creating DataLoader instances...
DataLoader instances created successfully!

Initializing BERT model for spam classification...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERT model loaded successfully!

Starting training process...

Epoch 1/1 starting...

  Batch 10/581 - Loss: 0.6119
  Batch 20/581 - Loss: 0.4172
  Batch 30/581 - Loss: 0.1927
  Batch 40/581 - Loss: 0.3196
  Batch 50/581 - Loss: 0.3118
  Batch 60/581 - Loss: 0.2048
  Batch 70/581 - Loss: 0.0455
  Batch 80/581 - Loss: 0.1368
  Batch 90/581 - Loss: 0.1010
  Batch 100/581 - Loss: 0.1500
  Batch 110/581 - Loss: 0.1193
  Batch 120/581 - Loss: 0.1695
  Batch 130/581 - Loss: 0.0142
  Batch 140/581 - Loss: 0.0082
  Batch 150/581 - Loss: 0.2587
  Batch 160/581 - Loss: 0.0189
  Batch 170/581 - Loss: 0.0066
  Batch 180/581 - Loss: 0.0057
  Batch 190/581 - Loss: 0.0145
  Batch 200/581 - Loss: 0.0358
  Batch 210/581 - Loss: 0.0044
  Batch 220/581 - Loss: 0.0047
  Batch 230/581 - Loss: 0.0244
  Batch 240/581 - Loss: 0.0047
  Batch 250/581 - Loss: 0.0525
  Batch 260/581 - Loss: 0.2116
  Batch 270/581 - Loss: 0.3124
  Batch 280/581 - Loss: 0.1006
  Batch 290/581 - Loss: 0.1326
  Batch 300/581 - Loss: 

ValueError: text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

In [53]:
### shap.plots.text(shap_values) ###

import shap

# Ensure the model is in evaluation mode
bert_model.eval()

# Function to tokenize text and return model predictions
def predict_proba(texts):
    """Tokenizes input texts and returns model probability predictions."""
    if isinstance(texts, str):  # Convert a single string input to a list
        texts = [texts]

    # Tokenize input text (convert to BERT-compatible format)
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")

    # Move tensors to the same device as the model
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = bert_model(**inputs).logits  # Get raw logits

    return torch.softmax(outputs, dim=1).cpu().numpy()  # Convert logits to probabilities

# Select a sample email from the test set
sample_email = str(X_test.iloc[3])  # Convert to string --> 3 = spam; 6 = ham

# Apply SHAP explainability
print(f"\nApplying SHAP for explainability on a test email:\n{sample_email}\n")

# Use SHAP's Text masker to process raw text correctly
masker = shap.maskers.Text(tokenizer)

# Create the SHAP Explainer with lambda function to ensure proper input format
explainer = shap.Explainer(lambda x: predict_proba([str(t) for t in x]), masker)

# Compute SHAP values for the sample email
shap_values = explainer([sample_email])  # Pass as a list

# Visualize explanation
shap.plots.text(shap_values)

print("\n✅ SHAP analysis complete! Check the visualization for token-level explainability.")


Applying SHAP for explainability on a test email:
me and my friends have this brand new idea a live webcam click here this is not spam you have received this email because at one time or another you entered the weekly draw at one of our portals or ffa sites we comply with all proposed and current laws on commercial email under bill s title iii passed by the th congress if you have received this email in error we apologize for the inconvenience and ask that you remove yourself click here to unsubscribe fysibvcgjyuwinmyvbpjtaebsymyukbrkn



PartitionExplainer explainer: 2it [00:19, 19.94s/it]               



✅ SHAP analysis complete! Check the visualization for token-level explainability.


In [52]:
# ### shap.plots.bar(shap_values) ###

# import shap

# # Ensure the model is in evaluation mode
# bert_model.eval()

# # Function to tokenize text and return model predictions
# def predict_proba(texts):
#     """Tokenizes input texts and returns model probability predictions."""
#     if isinstance(texts, str):  # If a single string is passed, convert it to a list
#         texts = [texts]

#     # Ensure texts are all strings
#     texts = [str(t) for t in texts]

#     # Tokenize input text (convert to BERT-compatible format)
#     inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
#     inputs = {key: value.to(device) for key, value in inputs.items()}  # Move tensors to device

#     with torch.no_grad():
#         outputs = bert_model(**inputs).logits  # Get raw logits

#     return torch.softmax(outputs, dim=1).cpu().numpy()  # Convert logits to probabilities

# # Select multiple sample emails and explicitly convert to a list of strings
# num_samples = 2  # Adjust as needed
# sample_emails = X_test.sample(n=num_samples, random_state=42).astype(str).tolist()  # ✅ Ensure strings

# # Apply SHAP explainability
# print(f"\nApplying SHAP for explainability on {num_samples} test emails...\n")

# # Use SHAP's Text masker to process raw text correctly
# masker = shap.maskers.Text(tokenizer)

# # Create the SHAP Explainer using the defined predict function
# explainer = shap.Explainer(predict_proba, masker)

# # Compute SHAP values for multiple samples
# shap_values = explainer(sample_emails) 

# # Visualize explanations - Bar Plot
# shap.plots.bar(shap_values)

# print("\n✅ SHAP analysis complete! Check the bar plot for global feature importance.")