Original Paper Link: [ConSens: Assessing context grounding in open-book question answering](https://arxiv.org/abs/2505.00065)

In [None]:
from huggingface_hub import login
login("hf_XXXXXX")  # Replace with your actual Hugging Face token

In [2]:
import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

device

device(type='cuda')

In [4]:
def compute_log_prob(model, tokenizer, answer):
    """
    Compute log P(answer)
    Returns the average log probability per token.
    """

    inputs = tokenizer(answer, return_tensors="pt").to(device)

    with torch.inference_mode():
        outputs = model(**inputs)

    logits = outputs.logits

    # shift for next token prediction
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = inputs['input_ids'][..., 1:].contiguous()

    # Log-softmax over vocabulary
    log_probs = F.log_softmax(shift_logits, dim=-1)

    # Log probs of the actual tokens
    token_log_probs = log_probs.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)

    return token_log_probs.mean().item()


def compute_log_prob_given_context(model, tokenizer, context, answer):
    """
    Compute log P(answer | context)
    Returns the average log probability per token given the context.
    """

    # Tokenize both context and answer and move to device
    context_inputs = tokenizer(context, return_tensors="pt").to(device)
    full_text = context + " " + answer
    full_inputs = tokenizer(full_text, return_tensors="pt").to(device)

    context_len = context_inputs['input_ids'].size(1)
    full_len = full_inputs['input_ids'].size(1)
    answer_len = full_len - context_len

    with torch.inference_mode():
        outputs = model(**full_inputs)

    logits = outputs.logits

    # Only consider logits corresponding to answer tokens
    answer_logits = logits[0, context_len-1:context_len-1+answer_len]
    answer_ids = full_inputs['input_ids'][0, context_len:context_len+answer_len]

    log_probs = F.log_softmax(answer_logits, dim=-1)
    token_log_probs = log_probs.gather(dim=-1, index=answer_ids.unsqueeze(-1)).squeeze(-1)

    return token_log_probs.mean().item()


def compute_consens_score(answer, context):
    """
    Compute ConSens score: measures how much context changes the probability of an answer.

    ConSens = 2 * sigmoid(log P(answer) - log P(answer|context)) - 1

    Returns a value between -1 and 1
    """

    model_name = "EleutherAI/gpt-neo-125M"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map={"": device.type},
        torch_dtype=torch.float16 if device.type != "cpu" else torch.float32,
        use_cache=True,
        # flash_attn=True,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    log_pa = compute_log_prob(model, tokenizer, answer)
    log_pac = compute_log_prob_given_context(model, tokenizer, context, answer)

    r = log_pa - log_pac
    consens_score = 2 * torch.sigmoid(torch.tensor(r)) - 1

    return consens_score.item()

Let's say we ask:

> **What is the capital of France?**

1. If relevant information exists in the retrieval context and the model becomes more confident in the correct answer after seeing context,
    * then ConSens < 0 → context helps, model uses it.

2. If the model was already confident without context and the context didn’t help (or hurt),
    * ConSens ≥ 0 → context did not help, possible memorization or hallucination if the answer is wrong.

3. If the context contradicts the correct answer and the model’s confidence drops,
    * ConSens > 0 → context hurts the model → can surface hallucination risks or confusion.

In [40]:
# Example where context should help

answer1 = "Paris is the capital of France."
context1 = "France is a country in Europe. It's capital is known for art and culture."

score2 = compute_consens_score(answer1, context1)
print(f"Context helps case - ConSens Score: {score2:.2f}")

Context helps case - ConSens Score: -0.46


In [41]:
# Example where context might hurt

answer2 = "Paris is the capital of France."
context2 = "In Douglas Adams' Hitchhiker's Guide to the Galaxy, what is the answer to life, the universe, and everything?"

score3 = compute_consens_score(answer2, context2)
print(f"Context hurts case - ConSens Score: {score3:.2f}")

Context hurts case - ConSens Score: 0.18
