# In [None]:
import shap
import torch
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForMultipleChoice
from tqdm import tqdm
import torch.nn.functional as F

# --- Model ------------------------------------------------------------------
MODEL_NAME = "model of choice"  # Replace with actual model name

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# nn.Module.to() and .eval() both return the module itself, so this loads the
# multiple-choice model, moves it to `device`, and switches it to eval mode.
model = AutoModelForMultipleChoice.from_pretrained(MODEL_NAME).to(device).eval()

# --- Data -------------------------------------------------------------------
# Replace with actual dataset path; only single-answer questions are kept.
df = pd.read_csv("dataset.csv")
df = df.loc[df["choice_type"] == "Single"]

# Column names used throughout the analysis.
question_col = "Augmented_Question"
choices_cols = ["opa", "opb", "opc", "opd"]  # answer-option text columns
correct_col = "cop"  # presumably the correct option's index/label -- confirm against the dataset
# 0/1 indicator columns describing each question's demographic framing
# (NOTE(review): Low/Middle/High presumably denote socio-economic level -- confirm).
demographic_cols = [
    "Male", "Female",
    "White", "Black", "Arab", "Asian", "Other",
    "Low", "Middle", "High",
]

# Define prediction function for SHAP
def f_mcq(x, choices=None):
    """Return multiple-choice logits for one question or a batch of questions.

    Two call forms are supported:

    * ``f_mcq({"question": q, "choices": [c1, c2, ...]})`` -- the single-sample
      form; returns an array of shape ``(1, num_choices)``.
    * ``f_mcq([q1, q2, ...], choices=[c1, c2, ...])`` -- a batch of question
      strings (for example the masked variants produced by a
      ``shap.maskers.Text`` masker) scored against the same answer choices;
      returns an array of shape ``(len(questions), num_choices)``.

    Args:
        x: A ``{"question", "choices"}`` dict, a single question string, or an
            iterable of question strings.
        choices: Answer choices for the batch form; ignored when ``x`` is a dict.

    Returns:
        NumPy array of logits, one row per question, one column per choice.

    Raises:
        ValueError: If no answer choices are supplied.
    """
    if isinstance(x, dict):
        questions = [str(x["question"])]
        choice_texts = [str(c) for c in x["choices"]]
    else:
        questions = [x] if isinstance(x, str) else [str(q) for q in x]
        choice_texts = [] if choices is None else [str(c) for c in choices]

    if not choice_texts:
        raise ValueError("f_mcq needs at least one answer choice")

    num_choices = len(choice_texts)
    if not questions:
        return np.empty((0, num_choices), dtype=np.float32)

    # Pair every question with every choice: (q1,c1) (q1,c2) ... (q2,c1) ...
    first_segments = [q for q in questions for _ in range(num_choices)]
    second_segments = choice_texts * len(questions)

    inputs = tokenizer(
        first_segments,
        second_segments,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # *ForMultipleChoice models expect (batch, num_choices, seq_len); the
    # tokenizer returns a flat (batch * num_choices, seq_len) tensor, so
    # restore the choice axis (row-major order matches the pairing above).
    inputs = {
        k: v.view(len(questions), num_choices, -1).to(device)
        for k, v in inputs.items()
    }

    with torch.no_grad():
        outputs = model(**inputs)

    # logits: (len(questions), num_choices)
    return outputs.logits.cpu().numpy()

# Small background sample (up to 10 distinct rows, unseeded) of
# question/choices dicts built from the filtered dataset.
# NOTE(review): nothing in this file passes `background_dataset` to the masker
# or explainer below, so it does not influence the SHAP computation.
background_indices = np.random.choice(len(df), min(10, len(df)), replace=False)
background_dataset = [
    {
        "question": sample[question_col],
        "choices": [sample[c] for c in choices_cols if pd.notna(sample[c])],
    }
    for sample in (df.iloc[pos] for pos in background_indices)
]

# Text masker that splits inputs with `tokenizer`, and an explainer wired to
# f_mcq with one named output per answer option.
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(
    f_mcq,
    masker,
    output_names=["opa", "opb", "opc", "opd"],
)

# Process the dataset: explain each question and record confidence metrics.
# SHAP's Text masker hands the prediction function arrays of *masked question
# strings*, so every row gets its own prediction function (closing over that
# row's answer choices) and its own explainer built on the shared `masker`.
results = []
examples_plotted = 0  # HTML text plots written so far (capped at 5)


def _to_builtin(value):
    """Return ``value`` as a plain Python scalar if it is a NumPy scalar (json.dump-safe)."""
    return value.item() if isinstance(value, np.generic) else value


for index, row in tqdm(df.iterrows(), total=len(df)):
    try:
        # Skip rows with missing data
        if any(pd.isna(row[col]) for col in [question_col] + choices_cols):
            continue

        question = row[question_col]
        # Every choice column is non-null here (checked above).
        choices = [row[col] for col in choices_cols]
        sample_input = {"question": question, "choices": choices}

        def predict_masked(masked_questions, _choices=choices):
            """Score each masked question variant against this row's choices -> (n, num_choices)."""
            # One forward pass per variant; simple rather than fast.
            return np.vstack(
                [f_mcq({"question": str(q), "choices": _choices}) for q in masked_questions]
            )

        row_explainer = shap.Explainer(predict_masked, masker, output_names=choices_cols)
        explanation = row_explainer([str(question)])[0]

        # For one text input: values is (num_tokens, num_choices) and data
        # already holds the token strings (not token ids).
        abs_values = np.abs(np.asarray(explanation.values, dtype=float))
        tokens = [str(tok) for tok in explanation.data]

        # Get raw model predictions for confidence metrics: (1, num_choices)
        logits = f_mcq(sample_input)
        probabilities = F.softmax(torch.as_tensor(logits), dim=-1).numpy()

        # Basic data extraction
        shap_data = {
            "index": _to_builtin(index),
            "question": question,
            "correct_option": _to_builtin(row[correct_col]) if correct_col in row else None,
        }

        # Add demographic information from binary columns; a missing value is
        # skipped instead of letting int(NaN) discard the whole row.
        for col in demographic_cols:
            if col in row and pd.notna(row[col]):
                shap_data[f"demographic_{col}"] = int(row[col])

        # 1. TOKEN-LEVEL IMPORTANCE SCORES
        # Sum |SHAP| across the choices (axis 1) -> one score per token.
        token_importances = abs_values.sum(axis=1)
        top_token_indices = np.argsort(token_importances)[::-1][:10]
        shap_data["important_tokens"] = [tokens[i] for i in top_token_indices]
        # float() keeps the scores JSON-serializable.
        shap_data["token_importance_scores"] = [
            float(token_importances[i]) for i in top_token_indices
        ]

        # 2. CONFIDENCE METRICS
        shap_data["confidence"] = float(probabilities.max())
        shap_data["prediction"] = int(probabilities.argmax())
        # Shannon entropy in nats; 1e-10 guards against log(0).
        entropy = -np.sum(probabilities * np.log(probabilities + 1e-10))
        shap_data["uncertainty_entropy"] = float(entropy)
        # Gap between the most and least likely choice.
        shap_data["confidence_spread"] = float(probabilities.max() - probabilities.min())

        # 3. BIAS INDICATORS
        # Kept for output compatibility; nothing in this script fills it.
        shap_data["bias_indicators"] = {}

        # Mean |SHAP| over tokens (axis 0) -> how strongly the question text
        # drives each option's logit.
        for option, importance in zip(choices_cols, abs_values.mean(axis=0)):
            shap_data[f"importance_{option}"] = float(importance)

        results.append(shap_data)

        # Save a text plot for the first few explained rows. shap.plots.text
        # has no save_path argument; with display=False it returns HTML.
        if examples_plotted < 5:
            html = shap.plots.text(explanation, display=False)
            with open(f"shap_example_{index}.html", "w", encoding="utf-8") as fh:
                fh.write(html)
            examples_plotted += 1

    except Exception as e:
        print(f"Error processing row {index}: {e}")

# Convert results to DataFrame
results_df = pd.DataFrame(results)

# Save full results to JSON for better preservation of nested structures
import json


def _json_default(obj):
    """json.dump fallback: convert NumPy scalars/arrays to plain Python values.

    Raises:
        TypeError: For any other non-serializable object (json's normal behavior).
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# A NumPy scalar in `results` would otherwise make json.dump fail only after
# the whole (expensive) SHAP loop has finished.
with open("shap_analysis_full_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, default=_json_default)

# Create demographic-specific analysis: for every binary demographic column,
# compare rows flagged 1 (the group) with rows flagged 0 (everyone else).
demographic_analysis = {}


def _collect_token_scores(frame):
    """Map each token to the importance scores it received across ``frame``'s rows.

    Rows whose ``important_tokens`` / ``token_importance_scores`` entries are
    not lists (e.g. NaN in a partially filled results table) are skipped.
    """
    scores_by_token = {}
    if "important_tokens" not in frame.columns or "token_importance_scores" not in frame.columns:
        return scores_by_token
    for tokens, scores in zip(frame["important_tokens"], frame["token_importance_scores"]):
        if not isinstance(tokens, list) or not isinstance(scores, list):
            continue
        for token, score in zip(tokens, scores):
            scores_by_token.setdefault(token, []).append(score)
    return scores_by_token


# Analyze data by demographic groups
for demo_col in demographic_cols:
    demo_key = f"demographic_{demo_col}"
    if demo_key not in results_df.columns:
        continue

    demo_results = results_df[results_df[demo_key] == 1]
    non_demo_results = results_df[results_df[demo_key] == 0]
    # A comparison needs rows on both sides.
    if demo_results.empty or non_demo_results.empty:
        continue

    # Calculate average metrics for demographic comparison
    analysis = {
        "total_questions": len(demo_results),
        "average_confidence": demo_results["confidence"].mean(),
        "average_uncertainty": demo_results["uncertainty_entropy"].mean(),
        "option_importance": {},
    }

    # Compare option importance between demographic groups
    for option in choices_cols:
        option_col = f"importance_{option}"
        if option_col not in demo_results.columns:
            continue
        demo_avg = demo_results[option_col].mean()
        non_demo_avg = non_demo_results[option_col].mean()
        # Ratio > 1: the option matters more for this group; 0 when undefined.
        relative_importance = demo_avg / non_demo_avg if non_demo_avg > 0 else 0
        analysis["option_importance"][option] = {
            "demo_group_avg": float(demo_avg),
            "other_group_avg": float(non_demo_avg),
            "relative_importance": float(relative_importance),
        }

    # Compare token importance patterns. Only tokens appearing in both groups'
    # top-token lists have a defined ratio; flag those with a >2x difference.
    demo_tokens = _collect_token_scores(demo_results)
    non_demo_tokens = _collect_token_scores(non_demo_results)
    token_bias = {}
    for token in demo_tokens.keys() & non_demo_tokens.keys():
        demo_avg = float(np.mean(demo_tokens[token]))
        non_demo_avg = float(np.mean(non_demo_tokens[token]))
        if not (demo_avg > 0 and non_demo_avg > 0):
            continue
        relative_importance = demo_avg / non_demo_avg
        if relative_importance > 2 or relative_importance < 0.5:
            token_bias[token] = {
                "demo_importance": demo_avg,
                "non_demo_importance": non_demo_avg,
                "relative_importance": relative_importance,
            }

    # Strongest bias first (|log ratio| descending, ties by token) so the output
    # is reproducible across runs and its leading keys are the most biased tokens.
    analysis["token_bias"] = dict(
        sorted(
            token_bias.items(),
            key=lambda item: (-abs(np.log(item[1]["relative_importance"])), item[0]),
        )
    )

    demographic_analysis[demo_col] = analysis

# Save demographic analysis
with open("demographic_bias_analysis.json", "w", encoding="utf-8") as f:
    json.dump(demographic_analysis, f)

# Generate summary report
summary = {
    "total_questions_analyzed": len(results),
    "average_confidence": results_df["confidence"].mean() if "confidence" in results_df else None,
    "average_uncertainty": (
        results_df["uncertainty_entropy"].mean() if "uncertainty_entropy" in results_df else None
    ),
    "demographic_summary": {},
}


def _top_biased_tokens(token_bias, k=5):
    """Return up to ``k`` tokens whose relative importance is furthest from 1.

    Distance is |log(ratio)|, so a 3x over-weighting and a 3x under-weighting
    rank equally; ties are broken by token text for reproducible output.
    Assumes every ratio is > 0 (the demographic analysis only records such ratios).
    """
    ranked = sorted(
        token_bias.items(),
        key=lambda item: (-abs(np.log(item[1]["relative_importance"])), item[0]),
    )
    return [token for token, _ in ranked[:k]]


# Summarize key findings for each demographic group
for demo_col, analysis in demographic_analysis.items():
    option_ratios = [d["relative_importance"] for d in analysis["option_importance"].values()]
    summary["demographic_summary"][demo_col] = {
        "question_count": analysis["total_questions"],
        "biased_tokens_count": len(analysis["token_bias"]),
        "top_biased_tokens": _top_biased_tokens(analysis["token_bias"]),
        # NOTE: despite the key name, this is the *maximum* relative option
        # importance, not a variance (key kept for output compatibility).
        "option_importance_variance": max(option_ratios) if option_ratios else None,
    }

# Save summary
with open("shap_analysis_summary.json", "w", encoding="utf-8") as f:
    json.dump(summary, f)

# Create a CSV with the most important findings: one row per analyzed
# demographic group. The rows are collected first and the DataFrame is built
# once, instead of growing it cell-by-cell with .loc (which turns the integer
# counts into floats).
summary_rows = []
summary_labels = []
for demo in demographic_cols:
    if demo not in demographic_analysis:
        continue
    analysis = demographic_analysis[demo]
    record = {
        "Question Count": analysis["total_questions"],
        "Avg Confidence": analysis["average_confidence"],
        "Biased Tokens": len(analysis["token_bias"]),
    }
    # Add option importance data
    for option in choices_cols:
        if option in analysis["option_importance"]:
            record[f"{option} Rel Importance"] = analysis["option_importance"][option]["relative_importance"]
    summary_rows.append(record)
    summary_labels.append(demo)

summary_df = pd.DataFrame(summary_rows, index=summary_labels)
summary_df.to_csv("demographic_bias_summary.csv")

print("Enhanced SHAP analysis completed successfully!")