<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Example_SHAP_for_NLP_Foundation_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install shap

In [None]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import shap
import numpy as np

# Load pre-trained model and tokenizer.
# Both come from the same "bert-base-uncased" checkpoint; the model is built
# with a 2-label sequence-classification head.
# NOTE(review): "bert-base-uncased" is a base checkpoint, so the 2-label head
# is presumably freshly initialised rather than fine-tuned - confirm before
# reading meaning into the resulting SHAP values.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Dummy input text data (must be a list of strings; `predict` rejects
# anything else with a ValueError).
text_data = ["This is a sample text for model explanation."]

# Tokenize the text data into PyTorch tensors.
# `inputs` is not referenced again in this notebook: `predict` re-tokenizes
# whatever text it is given, so this is effectively a demonstration only.
inputs = tokenizer(text_data, return_tensors="pt", truncation=True, padding=True)

# Define a wrapper for the model to use with SHAP
class ModelWrapper(torch.nn.Module):
    """Thin adapter that makes a classification model return bare logits.

    The wrapped object is called with ``input_ids`` and ``attention_mask``
    as keyword arguments and only the ``.logits`` attribute of its result
    is handed back to the caller.
    """

    def __init__(self, model):
        """Store ``model`` as ``self.model`` after initialising the Module."""
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask=None):
        """Run the wrapped model and return the ``logits`` of its output.

        Args:
            input_ids: Forwarded unchanged as the ``input_ids`` keyword.
            attention_mask: Forwarded unchanged as the ``attention_mask``
                keyword; defaults to ``None``.

        Returns:
            The ``logits`` attribute of the wrapped model's output.
        """
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

# Wrap the loaded model so that calling it yields the logits directly;
# the `predict` function below runs inference through this instance.
model_wrapper = ModelWrapper(model)

# Define the prediction function
def predict(input_texts):
    """Return the model's logits for a batch of raw text strings.

    Args:
        input_texts: A list of ``str``, or a NumPy array that converts to
            one via ``tolist()``.

    Returns:
        A NumPy array holding the logits produced by ``model_wrapper``.

    Raises:
        ValueError: If, after the optional array conversion, the input is
            not a list made up solely of strings.
    """
    # Arrays are normalised to plain Python lists before validation.
    texts = input_texts.tolist() if isinstance(input_texts, np.ndarray) else input_texts

    # Guard clause: bail out early on anything that is not a list of str.
    if not isinstance(texts, list) or any(not isinstance(item, str) for item in texts):
        raise ValueError("Input must be a list of strings.")

    # Batch-tokenize with padding; sequences are truncated at 512 tokens.
    encoded = tokenizer(texts, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        scores = model_wrapper(encoded['input_ids'], encoded['attention_mask'])
    return scores.cpu().numpy()

# Build a SHAP text masker from the tokenizer.
# NOTE(review): presumably the masker uses the tokenizer to split each input
# into the tokens that get hidden/revealed - confirm against the SHAP docs.
masker = shap.maskers.Text(tokenizer)

# Create the explainer around `predict` (text in, logits out), using the
# masker above to perturb the input.
explainer = shap.Explainer(predict, masker=masker)

# Compute SHAP values for every string in `text_data`.
shap_values = explainer(text_data)

# Visualize the SHAP values of the first (and here only) input text.
shap.plots.text(shap_values[0])

In [None]:
pip install datasets

In [None]:
# Import necessary libraries
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import shap

# 1. Load Dataset (Here, we use IMDB for binary sentiment classification).
# The "train" and "test" splits of the returned object are used further down.
dataset = load_dataset("imdb")

# 2. Initialize the BERT tokenizer.
# This rebinds the notebook-level name `tokenizer`, replacing the
# AutoTokenizer instance created in the earlier cell.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# 3. Tokenize the dataset (this will handle padding and truncation)
def tokenize_function(examples):
    """Tokenize the ``'text'`` entry of ``examples`` with the global tokenizer.

    Sequences are truncated and padded with ``padding="max_length"`` so the
    tokenizer output has a uniform length.

    Args:
        examples: Mapping with a ``'text'`` key holding the raw text.

    Returns:
        Whatever the tokenizer call returns for that text.
    """
    raw_text = examples['text']
    return tokenizer(raw_text, truncation=True, padding="max_length")

# Apply the tokenizer to every split; batched=True means `tokenize_function`
# receives a batch of examples per call rather than a single one.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])  # Remove raw text for efficiency
# Expose only the columns used for training, as torch tensors.
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# 4. Pick the train and validation sets.
# NOTE: no separate validation split is carved out here - the IMDB "test"
# split is what gets used as the evaluation set during training.
train_dataset = tokenized_datasets["train"]
val_dataset = tokenized_datasets["test"]

# 5. Initialize the model for sequence classification.
# This rebinds the notebook-level name `model`.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)  # Adjust num_labels for multi-class tasks

# 6. Define training arguments
# The "evaluate every epoch" option was renamed from `evaluation_strategy` to
# `eval_strategy` in newer transformers releases; the old keyword is
# deprecated and rejected by recent versions. Look up which spelling the
# installed TrainingArguments accepts so this cell runs on either.
import inspect

_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results",              # Output directory
    learning_rate=2e-5,                  # Learning rate
    per_device_train_batch_size=8,       # Batch size for training
    per_device_eval_batch_size=8,        # Batch size for evaluation
    num_train_epochs=3,                  # Number of training epochs
    weight_decay=0.01,                   # Weight decay
    **{_eval_strategy_kwarg: "epoch"},   # Evaluate at each epoch
)

# 7. Initialize the Trainer with the model, the training arguments and the
# two datasets prepared above.
trainer = Trainer(
    model=model,                         # The model to train
    args=training_args,                  # Training arguments
    train_dataset=train_dataset,         # Training dataset
    eval_dataset=val_dataset,            # Evaluation dataset (the IMDB "test" split)
)

# 8. Train the model (fine-tunes `model` in place).
trainer.train()

# 9. Save the fine-tuned model and tokenizer side by side in the same
# directory.
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")

# 10. Define a prediction function for SHAP
def predict(input_texts):
    """Return the fine-tuned model's logits for a batch of texts.

    Args:
        input_texts: A string, a list of strings, or an array-like of
            strings (anything exposing ``tolist()``, such as the NumPy
            arrays a SHAP explainer passes to its model function).

    Returns:
        A NumPy array containing the logits computed by ``model``.
    """
    # Array-likes are turned into plain Python lists first, because the
    # tokenizer expects str / list input rather than an array object.
    if hasattr(input_texts, "tolist"):
        input_texts = input_texts.tolist()

    # Make sure dropout is disabled: the model may have been left in
    # training mode by the preceding training step.
    model.eval()

    # After training the model can live on an accelerator, so the encoded
    # batch has to be moved to the same device before the forward pass.
    inputs = tokenizer(
        input_texts, padding=True, truncation=True, return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Tensor.numpy() only works on CPU tensors; bring the result back first.
    return logits.cpu().numpy()

# 11. Initialize SHAP explainer.
# Here the tokenizer itself is passed as the second positional argument,
# whereas the earlier cell built a `shap.maskers.Text` explicitly.
# NOTE(review): presumably SHAP wraps a tokenizer in a text masker
# automatically - confirm against the SHAP Explainer documentation.
explainer = shap.Explainer(predict, tokenizer)

# 12. Example input text for SHAP explanation (a list with one string)
input_texts = ["This is a sample text for model explanation."]

# 13. Get SHAP values for every text in `input_texts`
shap_values = explainer(input_texts)

# 14. Visualize SHAP values for the first input text
shap.plots.text(shap_values[0])

In [None]:
def predict(input_texts):
    """Debug variant of ``predict`` that reports the type it was called with.

    Args:
        input_texts: A list of ``str``, or a NumPy array that converts to
            one via ``tolist()``.

    Returns:
        A NumPy array holding the logits produced by ``model_wrapper``.

    Raises:
        ValueError: If, after the optional array conversion, the input is
            not a list consisting only of strings.
    """
    # Print before any conversion so the log shows what the caller sent.
    print(f"Input format: {type(input_texts)}")  # Print input type

    # Accept NumPy arrays the same way the notebook's first `predict` does;
    # without this every array batch would be rejected below.
    if isinstance(input_texts, np.ndarray):
        input_texts = input_texts.tolist()

    # Guard clause: reject anything that is not a list of strings.
    if not (isinstance(input_texts, list) and all(isinstance(i, str) for i in input_texts)):
        raise ValueError("Input must be a list of strings.")  # Raise an error if input format is incorrect

    # Tokenize input (support for batch processing)
    tokenized_inputs = tokenizer(input_texts, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model_wrapper(tokenized_inputs['input_ids'], tokenized_inputs['attention_mask'])
    return logits.cpu().numpy()