In [1]:
import ast # To parse the string representation of the list/tuples
import gc
import os
import time

import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
# --- Configuration ---
model_name = "meta-llama/Llama-3.1-8B-Instruct"
# Never hard-code an access token in a notebook: it leaks into saved copies and
# version control. Read it from the HF_TOKEN environment variable instead; None
# lets the Hugging Face libraries fall back to a cached `huggingface-cli login`.
access_token = os.environ.get("HF_TOKEN") or None

# Max tokens for the generated *SHAP-informed explanation*
max_shap_explanation_length = 50
# Input file path (output from the previous step)
input_file_with_explanations = '/content/Llama_3_8B_Instruct_explanations.csv'

# Number of highest-|SHAP| tokens quoted in each prompt.
top_n_shap_tokens = 10

# --- Original Label Mapping (Crucial for interpreting SHAP values) ---
# Class index i here must line up with position i of each token's SHAP value list.
label_map = {
    0: 'entailment',
    1: 'neutral',
    2: 'contradiction'
}
# Create inverse map: label name -> index
label_to_index = {v: k for k, v in label_map.items()}

# --- Load Model and Tokenizer (if not already loaded) ---
# Re-uses objects already present in the kernel to avoid reloading ~16 GB of
# weights. NOTE(review): this does not check that an existing `model` was loaded
# from `model_name`; restart the kernel after changing the model.
print("Checking/Loading model and tokenizer...")
try:
    model, tokenizer  # raises NameError if either is missing from the kernel
    print("Model and tokenizer already loaded.")
except NameError:
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=access_token,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    print("Model and tokenizer loaded.")

# Applied in both branches so a tokenizer left over from an earlier cell is
# configured exactly like a freshly loaded one (generation relies on these).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model.eval()  # Ensure dropout etc. are disabled for generation

Checking/Loading model and tokenizer...
Loading model and tokenizer...


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]



Model and tokenizer loaded.


In [3]:
# --- Helper Function for Generation ---
def generate_model_explanation(
    prompt,
    max_new_tokens,
    *,
    do_sample=True,
    temperature=0.1,
    top_k=50,
    repetition_penalty=1.15,
    seed=None,
    use_chat_template=False,
):
    """Generate a continuation of ``prompt`` with the globally loaded ``model``.

    Uses the module-level ``model`` and ``tokenizer`` created by the loading cell.

    Args:
        prompt: Text prompt to condition on.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample, temperature, top_k, repetition_penalty: Forwarded to
            ``model.generate``; the defaults are the values this helper
            previously hard-coded, so existing calls behave the same.
        seed: If not None, ``torch.manual_seed(seed)`` is called right before
            generation so sampled output is reproducible.
        use_chat_template: If True, wrap ``prompt`` as a single user turn via
            ``tokenizer.apply_chat_template`` (the input format *-Instruct
            models are trained on). Off by default to keep prior behaviour.

    Returns:
        The decoded newly generated text (prompt excluded), whitespace-stripped.
    """
    if use_chat_template:
        # The rendered template already starts with the BOS token, so the
        # tokenizer must not add a second one below.
        text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        add_special_tokens = False
    else:
        text = prompt
        add_special_tokens = True

    # Leave room in the context window for the tokens we are about to generate.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=tokenizer.model_max_length - max_new_tokens,
        add_special_tokens=add_special_tokens,
    ).to(model.device)

    if seed is not None:
        torch.manual_seed(seed)

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,  # Important for stopping
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )

    # generate() returns prompt + continuation; drop the prompt positions.
    input_length = inputs.input_ids.shape[1]
    explanation = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True).strip()

    # Release GPU memory between calls; this helper runs once per dataset row.
    del inputs, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return explanation

In [4]:
# --- Helper Function to Process SHAP String ---
def get_top_shap_tokens(shap_str, pred_label, label_to_index_map, top_n=10):
    """Return the ``top_n`` tokens with the largest |SHAP| value for the predicted class.

    Args:
        shap_str: String repr (as stored in the CSV) of a list of
            ``(token, values)`` pairs, where ``values[i]`` is the token's SHAP
            value for class index ``i``. Non-string input (e.g. NaN for an
            empty CSV cell) is reported as a parse error rather than crashing.
        pred_label: Predicted label name, looked up in ``label_to_index_map``;
            an exact match is tried first, then a stripped/lower-cased one.
        label_to_index_map: Mapping from label name to class index.
        top_n: Number of tokens to return.

    Returns:
        The selected tokens (whitespace-stripped, empty ones dropped) joined by
        ``", "`` in descending importance, or one of the sentinel strings
        ``"Error parsing SHAP"``, ``"No SHAP data"``, ``"Unknown label index"``,
        ``"Error processing SHAP"`` on failure.
    """
    try:
        # literal_eval only evaluates Python literals, so it is safe on file data.
        shap_data = ast.literal_eval(shap_str)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        # str(): a NaN/None cell is not sliceable and would raise again here.
        print(f"Warning: Could not parse SHAP string: {e}\nString: {str(shap_str)[:100]}...")
        return "Error parsing SHAP"

    if not isinstance(shap_data, (list, tuple)) or not shap_data:
        return "No SHAP data"

    try:
        pred_index = label_to_index_map.get(pred_label)
        if pred_index is None:
            # Tolerate cosmetic differences such as "Entailment" or "neutral ".
            pred_index = label_to_index_map.get(str(pred_label).strip().lower())
        if pred_index is None:
            print(f"Warning: Predicted label '{pred_label}' not found in label_to_index map.")
            return "Unknown label index"

        token_importance = []
        for item in shap_data:
            # Skip malformed entries: each must be a (token, values) pair whose
            # values list covers the predicted class index.
            if not (isinstance(item, (list, tuple)) and len(item) >= 2):
                continue
            token, values = item[0], item[1]
            if not (isinstance(values, (list, tuple)) and len(values) > pred_index):
                continue
            token_text = str(token).strip()
            if not token_text:
                # Whitespace-only tokens tell the model nothing; don't let them
                # occupy one of the top_n slots.
                continue
            # Importance = magnitude of the SHAP value for the predicted class.
            token_importance.append((token_text, abs(values[pred_index])))

        token_importance.sort(key=lambda pair: pair[1], reverse=True)
        return ", ".join(token for token, _ in token_importance[:top_n])

    except Exception as e:
        print(f"Error processing SHAP data for label '{pred_label}': {e}")
        return "Error processing SHAP"

In [5]:
# --- Load Data ---
# Columns the generation loop reads from every row.
required_columns = ['premise', 'hypothesis', 'pred_label', 'shap_value']

# Failures raise instead of calling exit(): in an IPython/Jupyter kernel exit()
# asks the kernel to shut down (discarding the loaded model) rather than just
# stopping this cell. Raising halts "Run All" while keeping kernel state.
print(f"Loading results from {input_file_with_explanations}...")
try:
    df_correct = pd.read_csv(input_file_with_explanations)
except FileNotFoundError:
    print(f"Error: File not found at {input_file_with_explanations}")
    raise
except Exception as e:
    print(f"Error loading CSV: {e}")
    raise

missing_columns = [col for col in required_columns if col not in df_correct.columns]
if missing_columns:
    raise ValueError(
        f"Error: column(s) {missing_columns} not found in {input_file_with_explanations}."
    )

print(f"Loaded {len(df_correct)} samples.")

Loading results from /content/llama3_nli_analysis_500_shap_correct.csv...
Loaded 179 samples.


In [10]:
# --- Generate SHAP-informed Explanations ---
# Exact sentinel strings get_top_shap_tokens returns on failure. Compare by
# equality: a substring test such as `"Error" in tokens` would also fire when a
# genuine top token merely contains a word like "Error".
shap_failure_markers = {
    "Error parsing SHAP",
    "No SHAP data",
    "Unknown label index",
    "Error processing SHAP",
}

model_explanations_shap = []
print(f"Generating SHAP-informed explanations (top {top_n_shap_tokens} tokens, max_new_tokens={max_shap_explanation_length})...")
start_time = time.time()

for index, row in tqdm(df_correct.iterrows(), total=df_correct.shape[0], desc="Generating SHAP Explanations"):
    premise = row['premise']
    hypothesis = row['hypothesis']
    predicted_label = row['pred_label']

    # Most influential tokens (by |SHAP|) for the predicted class.
    important_tokens_str = get_top_shap_tokens(row['shap_value'], predicted_label, label_to_index, top_n=top_n_shap_tokens)

    if important_tokens_str in shap_failure_markers:
        shap_explanation = f"Could not generate SHAP explanation ({important_tokens_str})."
        tqdm.write(f"Skipping SHAP explanation for index {index} due to SHAP processing issue: {important_tokens_str}")
    elif not important_tokens_str:
        shap_explanation = "Could not generate SHAP explanation (No important tokens found)."
        tqdm.write(f"Skipping SHAP explanation for index {index} due to no important tokens found.")
    else:
        # Built line by line so the code's own indentation does not leak into
        # the prompt (an indented triple-quoted f-string sends those leading
        # spaces to the model on every line).
        shap_prompt = "\n".join([
            f"Given [Premise: {premise}",
            f"Hypothesis: {hypothesis}",
            f"Classification: {predicted_label}]",
            f"These tokens were important: {important_tokens_str}.",
            "In exactly one sentence, why is this classification correct?",
        ])
        shap_explanation = generate_model_explanation(shap_prompt, max_new_tokens=max_shap_explanation_length)

    model_explanations_shap.append(shap_explanation)
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(shap_explanation)

# Loop-only duration (excludes any idle time before the next cell is run).
print(f"Generation loop finished in {time.time() - start_time:.2f} seconds.")


Generating SHAP-informed explanations (top 10 tokens, max_new_tokens=100)...


Generating SHAP Explanations:   0%|          | 0/179 [00:00<?, ?it/s]

The hypothesis and premise contradict each other because the premise describes two women embracing while holding grocery packages, which suggests they are together in harmony, whereas the hypothesis states that there is conflict between people outside a deli. Answer: This classification is correct because
The hypothesis entails that two kids in numbered jerseys wash their hand. The premise describes two kids in numbered jerseys (number 9 and number 2) who are indeed washing their hands. Therefore, the hypothesis logically follows from the premise. This makes
The classification as entailment is incorrect because it does not follow from the premise that there will be a man selling donuts. The premise only mentions an event and a sale but does not guarantee the presence of a man or donuts. Therefore,


KeyboardInterrupt: 

In [None]:
# --- Add Explanations to DataFrame ---
# The generation loop appends one explanation per row. If it was interrupted
# (e.g. KeyboardInterrupt) the list is shorter than the frame and a plain column
# assignment would fail with a length mismatch; pad the missing rows with None
# so partial progress is kept, and say so clearly.
n_rows = len(df_correct)
n_done = len(model_explanations_shap)
if n_done > n_rows:
    raise ValueError(
        f"{n_done} explanations for {n_rows} rows - df_correct changed after the "
        "generation cell ran; re-run that cell."
    )
if n_done < n_rows:
    print(f"Warning: only {n_done} of {n_rows} explanations were generated; "
          "the remaining rows are left empty. Re-run the generation cell to complete them.")

df_correct['shap_model_explanation'] = list(model_explanations_shap) + [None] * (n_rows - n_done)

if n_done == n_rows:
    print("\nFinished generating SHAP-informed explanations.")
# Measured from the start of the generation cell, so this includes any delay
# between running that cell and this one.
end_time = time.time()
print(f"Total time taken: {end_time - start_time:.2f} seconds")


Finished generating SHAP-informed explanations.
Total time taken: 310.77 seconds


In [None]:
# --- Save Updated DataFrame ---
output_file = 'Llama_3_8B_Instruct_explanations_final.csv'
print(f"\nSaving updated DataFrame to {output_file}...")
try:
    df_correct.to_csv(output_file, index=False)
except Exception as e:
    # Report, then re-raise: swallowing the error would let "Run All" end with
    # "Script finished." even though nothing was written.
    print(f"Error saving CSV: {e}")
    raise
print("DataFrame saved successfully.")

print("\nScript finished.")


Saving updated DataFrame to results_correct_with_shap_explanations.csv...
DataFrame saved successfully.

Script finished.
