# Interactive NLI Explainer

This notebook lets you enter a premise and a hypothesis for a fine-tuned DistilBERT natural language inference (NLI) model, and then shows:
1. The NLI model's prediction (entailment, contradiction, or neutral).
2. Attention visualizations (full, premise-only, and hypothesis-only self-attention).
3. A LIME explanation highlighting words contributing to the prediction.

## 1. Load Dependencies

In [None]:
import sys
sys.path.append('../src') # Add src directory to path

import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import numpy as np
# matplotlib and seaborn are used by xai_utils.plot_heatmap
import lime
import lime.lime_text
import os

from xai_utils import get_lime_predictor, process_model_attentions, plot_heatmap

%matplotlib inline

## 2. Load Fine-tuned Model and Tokenizer

In [None]:
# Fine-tuned checkpoint directory, given relative to the notebooks/ folder.
model_dir = '../src/nli_model/'

# Tokenizer and classifier are loaded from the same directory. output_attentions=True
# is requested so the forward pass exposes `outputs.attentions` for the
# attention visualization section.
tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir)
model = DistilBertForSequenceClassification.from_pretrained(
    model_dir,
    output_attentions=True,
)

# Prefer the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the weights to the chosen device and switch the module to evaluation mode.
model.to(device)
model.eval()

# Class index -> human-readable label.
# NOTE(review): this order must match the label ids used during fine-tuning — confirm.
label_map = {
    0: 'entailment',
    1: 'contradiction',
    2: 'neutral',
}
# Label names in the same (insertion) order as the indices above.
class_names = [label_map[idx] for idx in label_map]

print(f"Model loaded on: {device}")

## 3. User Input

In [None]:
# Read both sentences interactively; the premise is requested first.
premise = input("Enter premise: ")
hypothesis = input("Enter hypothesis: ")

# Echo the pair back so the inputs are recorded in the cell output.
print(f"\nPremise: {premise}\nHypothesis: {hypothesis}")

## 4. NLI Prediction

In [None]:
# Encode the premise/hypothesis pair as one sequence and move the tensors
# to the same device as the model.
inputs = tokenizer(
    premise,
    hypothesis,
    return_tensors='pt',
    truncation=True,
    padding=True,
)
inputs = inputs.to(device)

# Inference only, so gradient tracking is switched off for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)

# Logits drive the prediction; the attention weights are kept for the
# visualization section.
logits, model_attentions_output = outputs.logits, outputs.attentions

# Index of the highest-scoring class for the single input pair.
predicted_idx = logits.argmax(dim=1).item()
predicted_label = label_map[predicted_idx]

print(f"\nPredicted NLI Label: {predicted_label.upper()} (Class index: {predicted_idx})")

## 5. Attention Visualization

In [None]:
# Convert the raw attention output into an attention array plus the matching
# token strings (project helper from xai_utils).
# NOTE(review): assumes avg_attentions_np is a (tokens x tokens) matrix — the
# two-axis slicing below relies on it; confirm against process_model_attentions.
avg_attentions_np, tokens = process_model_attentions(
    model_attentions_output,
    inputs['input_ids'],
    tokenizer,
)

# Heatmap over the whole sequence: every token against every token.
plot_heatmap(
    avg_attentions_np,
    tokens,
    tokens,
    title='Full Attention Heatmap - Last Layer (Averaged Heads)',
)

# Locate the first separator token, which ends the premise part.
# sep_idx stays at -1 when no separator is present.
if tokenizer.sep_token in tokens:
    sep_idx = tokens.index(tokenizer.sep_token)
else:
    print(f"'{tokenizer.sep_token}' not found in tokens. Cannot generate separated attention views.")
    sep_idx = -1

if sep_idx == -1:
    print("Skipping separated attention view plots as SEP token was not found.")
elif len(tokens) <= 1:
    print("Skipping separated attention view plots due to insufficient tokens.")
else:
    # Premise: the tokens between the leading [CLS] and the first separator.
    premise_span = slice(1, sep_idx)
    premise_tokens = tokens[premise_span]
    premise_avg_attentions = avg_attentions_np[premise_span, premise_span]

    # Hypothesis: the tokens after the first separator, without the final token
    # (expected to be the closing separator).
    hypothesis_span = slice(sep_idx + 1, -1)
    hypothesis_tokens = tokens[hypothesis_span]
    hypothesis_avg_attentions = avg_attentions_np[hypothesis_span, hypothesis_span]

    # Each view is either plotted or, when its span is empty, reported as skipped.
    attention_views = (
        (
            premise_tokens,
            premise_avg_attentions,
            'Premise Self-Attention',
            "Not enough tokens to display premise self-attention.",
        ),
        (
            hypothesis_tokens,
            hypothesis_avg_attentions,
            'Hypothesis Self-Attention',
            "Not enough tokens to display hypothesis self-attention.",
        ),
    )
    for view_tokens, view_attentions, view_title, fallback_message in attention_views:
        if view_tokens and view_attentions.size > 0:
            plot_heatmap(view_attentions, view_tokens, view_tokens, title=view_title)
        else:
            print(fallback_message)

## 6. LIME Explanation

In [None]:
# Prediction callback for LIME, built by the project helper in xai_utils.
lime_predictor_fn = get_lime_predictor(model, tokenizer, device)

# LIME perturbs a single string, so both sentences are joined around the
# tokenizer's separator token.
# NOTE(review): presumably the predictor splits on this separator again, and
# LIME may mask the "SEP" word itself during sampling — verify in xai_utils.
text_instance = " ".join((premise, tokenizer.sep_token, hypothesis))

# bow=False explains in terms of word positions rather than a bag of words;
# the fixed random_state keeps the perturbation sampling repeatable.
explainer = lime.lime_text.LimeTextExplainer(
    class_names=class_names,
    bow=False,
    random_state=42,
)

print(f"\nExplaining LIME for text: '{text_instance}'")
print(f"Prediction LIME is explaining: {label_map[predicted_idx].upper()}")

# Explain only the predicted class, using the 10 strongest features
# estimated from 500 perturbed samples.
explanation_options = dict(
    num_features=10,
    num_samples=500,
    labels=(predicted_idx,),
)
lime_explanation = explainer.explain_instance(
    text_instance,
    lime_predictor_fn,
    **explanation_options,
)

print("\nDisplaying LIME explanation in notebook (for the predicted class):")
lime_explanation.show_in_notebook(text=True)

--- End of Interactive Explanation ---