# In [None]:
import json

import pandas as pd
import shap
import torch
from transformers import AutoTokenizer, AutoModelForMultipleChoice

# Checkpoint identifier shared by the tokenizer and the model.
# The value is a placeholder: replace it with a fine-tuned MCQ model.
MODEL_NAME = "model of choice"

# Column layout of the dataset.
question_col = "Augmented_Question"  # question text that includes the demographic variable
choices_cols = ["opa", "opb", "opc", "opd"]  # one column per answer option
correct_col = "cop"
demographic_cols = [
    "Male", "Female",
    "White", "Black", "Arab", "Asian", "Other",
    "Low", "Middle", "High",
]

# Tokenizer and multiple-choice model, both loaded from MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMultipleChoice.from_pretrained(MODEL_NAME)

# Read the dataset (replace the path with the actual dataset location) and
# keep only the rows whose "choice_type" is "Single".
df = pd.read_csv("dataset.csv").loc[lambda frame: frame["choice_type"] == "Single"]

# Format input for multiple-choice models
def format_mcq(row):
    """Tokenize one dataset row as (question, choice) pairs.

    Args:
        row: Row-like object indexable by the names in ``choices_cols`` and
            by ``question_col``.

    Returns:
        The tokenizer output (PyTorch tensors, padded and truncated) with one
        encoded sequence pair per answer choice.
    """
    options = [row[name] for name in choices_cols]
    question = row[question_col]
    # The question is repeated so that it is paired with every option.
    repeated_question = [question for _ in options]
    return tokenizer(
        repeated_question,
        options,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

# Define prediction function
def f_mcq(x):
    """Return the model logits for one multiple-choice question.

    Args:
        x: Mapping with a ``"question"`` entry (the question text) and a
            ``"choices"`` entry (list of answer texts).

    Returns:
        The model's ``logits`` converted to a NumPy array (presumably one
        row with one score per choice - confirm against the model head).

    NOTE(review): a SHAP text masker hands the model masked *strings*, not a
    mapping like this one - confirm how this function is wired into SHAP.
    """
    options = x["choices"]
    repeated_question = [x["question"]] * len(options)
    encoded = tokenizer(
        repeated_question,
        options,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Prepend a dimension of size 1 to every tensor before calling the model.
    model_inputs = {name: tensor.unsqueeze(0) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**model_inputs).logits
    return logits.numpy()

# Define token-to-string mapping
def out_names(x):
    """Convert token ids to their token strings with the module tokenizer.

    NOTE(review): this is attached below as ``f_mcq.output_names``, yet
    ``f_mcq`` returns the logits of a multiple-choice model rather than
    per-token outputs - confirm that token strings are the intended labels.
    """
    token_strings = tokenizer.convert_ids_to_tokens(x)
    return token_strings


# Expose the mapping as an attribute of the prediction function.
f_mcq.output_names = out_names

# Initialize the SHAP text masker. It tokenizes and masks plain strings, so
# every explanation below is requested for the question *string*; the answer
# choices are held fixed inside the scoring function instead of being masked.
masker = shap.maskers.Text(tokenizer)


def _make_question_scorer(choices):
    """Build a SHAP-compatible model for one question's fixed answer choices.

    Args:
        choices: List of answer texts that every masked question is scored
            against.

    Returns:
        A function that takes an iterable of (masked) question strings and
        returns one vector of per-choice logits for each of them.
    """

    def score(masked_questions):
        # f_mcq returns a single row of logits per call; [0] drops that row axis.
        return [
            f_mcq({"question": str(question), "choices": choices})[0]
            for question in masked_questions
        ]

    return score


def _to_jsonable(array_like):
    """Convert NumPy arrays/scalars to plain Python values; pass others through."""
    return array_like.tolist() if hasattr(array_like, "tolist") else array_like


# Process the entire dataset and group by demographic variables
shap_results = []  # in-memory results, including the SHAP Explanation objects
export_rows = []   # flat, text-serializable copies of the same results
for index, row in df.iterrows():
    choices = [row[col] for col in choices_cols]
    demographics = {col: row[col] for col in demographic_cols}

    # One explainer per row because the scoring function closes over this
    # row's choices; the outputs are labelled with the choice column names.
    explainer = shap.Explainer(
        _make_question_scorer(choices), masker, output_names=choices_cols
    )
    shap_values = explainer([row[question_col]])

    # Store SHAP results along with demographic information
    shap_results.append(
        {"index": index, "shap_values": shap_values, "demographics": demographics}
    )

    # Writing an Explanation object (or a nested dict) to CSV would only
    # store its repr(), which cannot be parsed back. Export the underlying
    # arrays as JSON and the demographics as ordinary columns instead.
    # Index 0 selects the single question that was explained above.
    export_rows.append(
        {
            "index": index,
            **demographics,
            "tokens": json.dumps([str(token) for token in shap_values.data[0]]),
            "shap_values": json.dumps(_to_jsonable(shap_values.values[0])),
            "base_values": json.dumps(_to_jsonable(shap_values.base_values[0])),
        }
    )

# Convert results to a DataFrame for further analysis (keeps the objects)
shap_df = pd.DataFrame(shap_results)

# Save the flattened results for later analysis
pd.DataFrame(export_rows).to_csv("shap_results.csv", index=False)

# Visualize an example explanation
# shap.plots.text(shap_results[0]["shap_values"])