# Measure Exact Answer & Spilled Energy Detection

This notebook demonstrates how to:
1. Load a sample from TriviaQA.
2. Generate an answer using an LLM.
3. Extract the exact answer from the generated text.
4. Compute **Spilled Energy (Delta)**, **Energy (E)**, and **Marginalized Energy (E_margin)** specifically on the exact answer tokens.
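5. Aggregate these metrics over the selected tokens using **Mean**, **Max**, **Min**, and **Sum**.
6. Flag a possible hallucination by thresholding the mean Spilled Energy, and compare against the TriviaQA ground truth.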

In [None]:
%load_ext autoreload
%autoreload 2

import logging
import os
import re
import sys

import datasets
import numpy as np
import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Root logger at INFO. `force=True` replaces any handlers already installed
# (e.g. by a previous run of this cell), so re-running the cell is safe.
logging.basicConfig(level=logging.INFO, force=True)

# Per-logger overrides: verbose output for the extraction step, quiet output
# for chatty third-party libraries to reduce clutter.
_LOGGER_LEVELS = {
    "spilled_energy.extraction": logging.DEBUG,
    "transformers": logging.ERROR,
    "httpx": logging.WARNING,
}
for _logger_name, _level in _LOGGER_LEVELS.items():
    logging.getLogger(_logger_name).setLevel(_level)

# The Hugging Face libraries also expose their own verbosity switches.
transformers.logging.set_verbosity_error()
datasets.logging.set_verbosity_error()

from spilled_energy.generation import generate_answer
from spilled_energy.extraction import extract_exact_answer
from spilled_energy.energy import spilled_energy

## 1. Load Data
We stream the first example of the TriviaQA `rc` validation split and keep its question plus the list of accepted answer aliases.

In [None]:
# Stream the validation split so examples are fetched lazily instead of
# downloading the whole split; only the first example is used below.
dataset = load_dataset("trivia_qa", "rc", split="validation", streaming=True)
sample = next(iter(dataset))

# `aliases` is the list of accepted surface forms of the gold answer.
question, ground_truth_aliases = sample["question"], sample["answer"]["aliases"]

print(f"Question: {question}")
print(f"Ground Truth(s): {ground_truth_aliases}")

## 2. Load Model
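Load one model and use it for both answer generation and answer extraction; using the same model for both steps is recommended. If the Llama-3 checkpoint cannot be loaded, the cell falls back to `facebook/opt-125m`.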

In [None]:
model_name = "meta-llama/Meta-Llama-3-8B"
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {model_name} on {device}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # bf16 halves weight memory on GPU; CPU inference stays in fp32.
    # The `dtype=` keyword is used per user request.
    load_dtype = torch.float32 if device != "cuda" else torch.bfloat16
    model = AutoModelForCausalLM.from_pretrained(model_name, dtype=load_dtype).to(device)
except Exception as e:
    # Deliberate best-effort fallback: report why the primary model failed
    # (e.g. gated checkpoint, no network, out of memory) and continue with a
    # small public model so the rest of the notebook still runs.
    print(f"Error loading {model_name}: {e}")
    print("Using fallback model...")
    model_name = "facebook/opt-125m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

## 3. Generate Answer
Generate a free-form answer to the question with greedy decoding (`do_sample=False`, up to 100 new tokens).

In [None]:
# Simple completion-style Q/A prompt.
prompt = f"Q: {question}\nA:"

# `do_sample=False` requests greedy (deterministic) decoding.
generation_settings = {
    "max_new_tokens": 100,
    "do_sample": False,
}
gen_output = generate_answer(
    prompt=prompt,
    model=model,
    tokenizer=tokenizer,
    device=device,
    **generation_settings,
)

# gen_output also carries 'scores' and 'sequences', used in the metrics cell.
generated_text = gen_output['text']
print("Generated Answer:")
print(generated_text)

## 4. Extract Exact Answer
Use the model to extract the short, exact answer string from the long-form generation.

In [None]:
print(f"Extracting answer for: {question}")

# The same model/tokenizer that generated the long answer is reused to pull
# out the short answer span.
exact_answer = extract_exact_answer(
    model=model,
    tokenizer=tokenizer,
    device=device,
    question=question,
    long_answer=generated_text,
)

print(f"Extracted Exact Answer: '{exact_answer}'")

## 5. Compute Metrics on Exact Tokens
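We locate the extracted answer within the original generation, map it to the corresponding generated tokens, and compute on those tokens:
- **Spilled Energy (Delta)**
- **Energy (E)**
- **Marginalized Energy (E_margin)**

For each metric we report the **Mean**, **Max**, **Min**, and **Sum** over the selected tokens. If the answer cannot be located in the generation, the metrics are computed over the full generation instead.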

In [None]:
def _find_answer_span(text, answer):
    """Return the (start, end) character span of `answer` in `text`, or None.

    An exact (case-sensitive) match is preferred. If there is none, a
    case-insensitive search is tried, because the extractor may re-case the
    answer. `re.IGNORECASE` is used instead of `.lower()` so the returned
    indices always refer to the original `text`.

    An empty `answer` never matches; `str.find("")` would otherwise report a
    bogus match at index 0.
    """
    if not answer:
        return None
    start = text.find(answer)
    if start != -1:
        return start, start + len(answer)
    match = re.search(re.escape(answer), text, flags=re.IGNORECASE)
    return match.span() if match else None


def _char_span_to_token_span(offsets, start, end):
    """Map a character span [start, end) to a token span [first, last + 1).

    Only token *start* offsets are used: token i is treated as covering
    [s_i, s_{i+1}). The first token is the one that contains `start`, not the
    first token starting at or after it. The latter rule would skip the
    answer's first token whenever that token's offsets include the preceding
    space (e.g. " Paris" starting one character before "Paris").
    Offsets are assumed to be non-decreasing, as for ordinary text.

    Returns (None, None) when no token starts before `end`.
    """
    first = None
    last_exclusive = None
    for i, (tok_start, _tok_end) in enumerate(offsets):
        if tok_start <= start:
            first = i
        if tok_start < end:
            last_exclusive = i + 1
    if last_exclusive is None:
        return None, None
    # `start` precedes every token start: begin at the first token.
    if first is None:
        first = 0
    return first, last_exclusive


# 1. Find the extracted answer in the generated text.
# Strip surrounding quotes the extractor may add; `or ""` covers a None result.
cleaned_exact_answer = (exact_answer or "").strip("'\"").strip()
answer_span = _find_answer_span(generated_text, cleaned_exact_answer)

token_start = None
token_end = None

if answer_span is None:
    print(f"Warning: Could not find extracted answer '{cleaned_exact_answer}' in generated text.")
    print("Calculating metrics on full generation instead.")
else:
    start_idx, end_idx = answer_span
    print(f"Found answer at character indices: {start_idx}-{end_idx}")

    # 2. Map character indices to token indices.
    # Re-tokenize the text to get per-token character offsets
    # (`return_offsets_mapping` requires a fast tokenizer).
    enc = tokenizer(generated_text, return_offsets_mapping=True, add_special_tokens=False)
    token_start, token_end = _char_span_to_token_span(enc.offset_mapping, start_idx, end_idx)

    print(f"Token span: {token_start}-{token_end}")
    if token_start is not None and token_end is not None:
        print(f"Corresponding text: {tokenizer.decode(enc.input_ids[token_start:token_end])}")

# Prepare inputs
logits = torch.stack(gen_output['scores'], dim=1)  # [batch, seq_len, vocab]
sequences = gen_output['sequences']  # [batch, total_len]

# Align logits and IDs: `sequences` also contains the prompt, so keep only the
# trailing positions that have a matching logit step.
input_len = sequences.shape[1] - logits.shape[1]
generated_ids = sequences[:, input_len:]

# Slice for exact answer if found
if token_start is not None and token_end is not None:
    # Offsets come from re-tokenizing the text, which may not reproduce the
    # generated ids exactly; warn when the lengths differ.
    if len(enc.input_ids) != generated_ids.shape[1]:
        print("Warning: Tokenizer mismatch (length). Alignment might be imperfect. Proceeding with mapping indices.")

    # Clamp into the generated range so the slice is never empty or out of bounds.
    token_start = max(0, min(token_start, logits.shape[1] - 1))
    token_end = max(token_start + 1, min(token_end, logits.shape[1]))

    logits = logits[:, token_start:token_end, :]
    generated_ids = generated_ids[:, token_start:token_end]

# Compute Spilled Energy.
# `.float()` first because NumPy has no bfloat16; spilled_energy is given
# nested Python lists (logits [batch][step][vocab], ids [batch][step]).
logits_list = logits.cpu().float().numpy().tolist()
ids_list = generated_ids.cpu().numpy().tolist()

delta, E_margin, E = spilled_energy(
    logits=logits_list,
    ids=ids_list,
    beta=1.0
)


def display_stats(name, values):
    """Print mean/max/min/sum of the first batch row of `values`; return the mean.

    Prints "[Empty]" and returns 0.0 when that row has no entries.
    """
    vals = np.array(values[0])
    if len(vals) == 0:
        print(f"{name}: [Empty]")
        return 0.0

    print(f"--- {name} ---")
    print(f"  Mean: {np.mean(vals):.4f}")
    print(f"  Max:  {np.max(vals):.4f}")
    print(f"  Min:  {np.min(vals):.4f}")
    print(f"  Sum:  {np.sum(vals):.4f}")
    return np.mean(vals)


print("=== Metrics on Exact Answer Tokens ===")
mean_se = display_stats("Spilled Energy (Delta)", delta)
display_stats("Energy (E)", E)
display_stats("Marginalized Energy (E_margin)", E_margin)


In [None]:
# Hallucination Detection Threshold Check (using Mean Spilled Energy)
# NOTE(review): 0.5 is not derived anywhere in this notebook; calibrate it on
# labelled examples before trusting the prediction below.
SE_THRESHOLD = 0.5

# Despite its name, `p_value` is not a statistical p-value. It is the mean
# spilled energy over the scored tokens (the exact answer when it was located,
# otherwise the full generation), used as the primary hallucination score.
p_value = mean_se
is_hallucination = p_value > SE_THRESHOLD
verdict = "HALLUCINATION" if is_hallucination else "RELIABLE"

print(f"Threshold: {SE_THRESHOLD}")
print(f"Score (Mean SE): {p_value:.4f}")
print(f"Prediction: {verdict}")

# Verify correctness: the answer counts as correct if any ground-truth alias
# appears (case-insensitively) inside the extracted answer.
is_correct = any(
    alias.lower() in exact_answer.lower()
    for alias in ground_truth_aliases
)
status = "Correct" if is_correct else "Incorrect"
print(f"Actual status: {status}")