In [None]:
!pip install datasets

In [None]:
import os
import time

import numpy as np
import pandas as pd
import shap
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Configuration
# Hugging Face access token for the gated checkpoint. It is read from the
# HF_TOKEN environment variable so the secret never gets saved inside the
# notebook; when the variable is unset or empty we pass None, which lets
# huggingface_hub fall back to a cached login instead of sending an empty token.
access_token = os.environ.get("HF_TOKEN") or None
model_name = "meta-llama/Llama-3.1-8B-Instruct"
max_length = 256  # Reduced sequence length for stability

# Label mapping: integer label id in the dataset -> label name
label_map = {
    0: 'entailment',
    1: 'neutral',
    2: 'contradiction'
}

# Load model and tokenizer (half precision, layers placed automatically on the
# available devices)
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token,
                                           torch_dtype=torch.float16,
                                           device_map="auto")

# Fix tokenizer settings: batching needs a pad token, so reuse EOS when the
# tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Pad on the left so that position -1 of every row in a batch is the last real
# token of that text (the prediction function reads the logits at position -1).
tokenizer.padding_side = "left"

# Inference mode. The trailing semicolon keeps the notebook from printing the
# full module tree of the model as the cell's output.
model.eval();

In [None]:
# Load dataset: only the first `num_samples` rows of the `esnli` validation
# split are requested, using the slice syntax inside the split string.
num_samples = 500
print("Loading dataset...")
dataset = load_dataset(
    path="esnli",
    split=f"validation[:{num_samples}]",
)

In [None]:
# Improved prediction function
def safe_predict(texts):
    """Score texts with the model's next-token logits for the three NLI labels.

    This is the model function handed to the SHAP explainer: it takes a batch
    of (possibly masked) texts and returns one row of scores per text.

    Args:
        texts: A single string, a list of strings, or a NumPy array of strings.
            Non-string items are converted with ``str()``.

    Returns:
        A float32 NumPy array of shape ``(len(texts), len(label_map))``. Column
        ``j`` is the logit, at the last position of the input, of the first
        token of the ``j``-th label name in ``label_map`` order. If anything
        fails, the error is printed and an all-zero array of the same shape is
        returned so the caller can keep going.
    """
    try:
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, np.ndarray):
            texts = texts.tolist()

        # Ensure we have a list of strings
        texts = [t if isinstance(t, str) else str(t) for t in texts]

        # Vocabulary ids that stand for the labels: the first sub-token of each
        # label word. Slicing the vocabulary axis with `:3` would instead pick
        # token ids 0, 1 and 2, which have nothing to do with the labels.
        # NOTE(review): a leading space is used because the label normally
        # follows other text; confirm the three first tokens are distinct for
        # the tokenizer in use.
        label_token_ids = [
            tokenizer.encode(f" {name}", add_special_tokens=False)[0]
            for name in label_map.values()
        ]

        inputs = tokenizer(texts, return_tensors="pt",
                         padding=True, truncation=True,
                         max_length=max_length).to(model.device)

        with torch.no_grad():
            outputs = model(**inputs)

        # Logits at the last position only; with left padding this is the last
        # real token of every row.
        last_token_logits = outputs.logits[:, -1, :]
        # Cast to float32 so downstream arithmetic on the scores is not done in
        # half precision.
        return last_token_logits[:, label_token_ids].float().cpu().numpy()

    except Exception as e:
        print(f"\nPrediction error: {str(e)}")
        return np.zeros((len(texts), len(label_map)))  # Return neutral predictions on error


In [None]:
# Robust SHAP explainer
def get_shap_values(text):
    """Compute SHAP attributions of ``safe_predict`` for a single text.

    A text masker (masked positions are replaced by the tokenizer's EOS token)
    and a permutation explainer limited to 200 model evaluations are built on
    every call, then applied to ``text``.

    Args:
        text: The input string to explain.

    Returns:
        The object returned by the explainer for the one-element batch
        ``[text]``, or ``None`` if any step raised (the error is printed).
    """
    try:
        text_masker = shap.maskers.Text(tokenizer, mask_token=tokenizer.eos_token)
        class_names = list(label_map.values())
        permutation_explainer = shap.Explainer(
            safe_predict,
            text_masker,
            algorithm="permutation",
            max_evals=200,  # Reduced for speed
            output_names=class_names,
        )

        # Attribution only needs forward passes, so gradients stay disabled.
        with torch.no_grad():
            return permutation_explainer([text])
    except Exception as e:
        print(f"\nSHAP error: {str(e)}")
        return None


In [None]:
# Robust explanation generator
def generate_explanation(prompt, max_new_tokens=100):
    """Generate a continuation of ``prompt`` and return only the new text.

    Args:
        prompt: The prompt string; it is truncated to ``max_length`` tokens.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped. If anything fails, the error is printed and the
        string ``"Error generating explanation"`` is returned.
    """
    try:
        inputs = tokenizer(prompt, return_tensors="pt",
                         max_length=max_length,
                         truncation=True).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )

        # Extract only the new generated part. The returned sequence starts
        # with the prompt tokens, so the cut has to happen on token ids:
        # `input_length` counts tokens, and using it to slice the decoded
        # string (characters) would cut at an unrelated place.
        input_length = inputs.input_ids.shape[1]
        new_token_ids = outputs[0][input_length:]
        return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
    except Exception as e:
        print(f"\nGeneration error: {str(e)}")
        return "Error generating explanation"

In [None]:
# Main processing loop
# For every sample: ask the model for a label, compute SHAP attributions for
# the premise/hypothesis text, and file the record under "correct"
# (results_c) or "incorrect" (results_ic) depending on whether the predicted
# label matches the ground truth.
results_c = []
results_ic = []
print(f"\nProcessing {num_samples} samples...")
start_time = time.time()

for i, example in enumerate(tqdm(dataset, desc="Processing samples")):
    if i >= num_samples:
        break

    premise = example['premise']
    hypothesis = example['hypothesis']
    text = f"Premise: {premise}\nHypothesis: {hypothesis}"
    gt_label = label_map[example['label']]

    # Get prediction
    pred_prompt = f"""Classify this as entailment, neutral, or contradiction:

Premise: {premise}
Hypothesis: {hypothesis}
Answer:"""

    pred_label = generate_explanation(pred_prompt, max_new_tokens=10)
    pred_label = pred_label.lower().strip()

    # Clean and validate prediction: map the free-form answer onto one of the
    # three label names by substring match, checked in this fixed order.
    pred_label = pred_label.replace('"', '').replace("'", "")
    if "entail" in pred_label:
        pred_label = "entailment"
    elif "neutral" in pred_label:
        pred_label = "neutral"
    elif "contradict" in pred_label:
        pred_label = "contradiction"
    else:
        pred_label = "neutral"  # Default fallback

    print(f"{gt_label} --- {pred_label}")

    # Get SHAP values. get_shap_values returns None when the explainer fails;
    # in that case keep the sample and record a missing attribution instead of
    # crashing the whole run on `None.data`.
    shap_values = get_shap_values(text)
    if shap_values is None:
        shap_str = None
    else:
        # Serialize as a list of (token, per-label values) pairs.
        shap_str = str([(token, values.tolist()) for token, values in zip(shap_values.data[0], shap_values.values[0])])

    # Store results: the record layout is the same for both outcomes, only the
    # destination list differs.
    record = {
        'premise': premise,
        'hypothesis': hypothesis,
        'gt_label': gt_label,
        'pred_label': pred_label,
        'shap_value': shap_str,
        'human_explanation': example.get('explanation_1', '')
    }
    if pred_label == gt_label:
        results_c.append(record)
    else:
        results_ic.append(record)


In [None]:
# One table per outcome list: df_c from results_c, df_ic from results_ic.
df_c, df_ic = (pd.DataFrame(records) for records in (results_c, results_ic))

In [None]:
# Save results
# The sample count in the file names comes from `num_samples`, so the names
# stay truthful when the run size is changed (with num_samples = 500 they are
# exactly the names used before).
output_file_c  = f"llama3_nli_analysis_{num_samples}_shap_correct.csv"
output_file_ic = f"llama3_nli_analysis_{num_samples}_shap_incorrect.csv"

# index=False: the row index carries no information, keep it out of the CSV.
df_c.to_csv(output_file_c, index=False)
df_ic.to_csv(output_file_ic, index=False)

In [None]:
# Display the (rows, columns) of the table built from results_c.
df_c.shape