In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from luh import AutoUncertaintyHead

from lm_polygraph import CausalLMWithUncertainty

from luh.calculator_infer_luh import CalculatorInferLuh
from luh.calculator_apply_uq_head import CalculatorApplyUQHead
from luh.luh_estimator_dummy import LuhEstimatorDummy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hub identifiers for the base LLM and for the uncertainty head used with it.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
uhead_name = "llm-uncertainty-head/uhead_Mistral-7B-Instruct-v0.2"

# Base causal LM, mapped onto the CUDA device.
llm = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda")

# Tokenizer for the same checkpoint; the EOS token doubles as the padding token
# so that `padding=True` can be used when tokenizing prompts later on.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Uncertainty head, loaded with the LLM above supplied as its base model.
uhead = AutoUncertaintyHead.from_pretrained(uhead_name, base_model=llm)

Loading checkpoint shards: 100%|██████████| 3/3 [00:28<00:00,  9.34s/it]


In [3]:
# Generation settings: the checkpoint's own generation config plus a cap of
# 50 newly generated tokens.
generation_config = GenerationConfig.from_pretrained(model_name)
args_generate = dict(generation_config=generation_config, max_new_tokens=50)

# Stat calculator that runs generation with the uncertainty head attached and
# is asked to predict per-token uncertainties.
calc_infer_llm = CalculatorInferLuh(
    uhead,
    tokenize=True,
    args_generate=args_generate,
    device="cuda",
    generations_cache_dir="",
    predict_token_uncertainties=True,
)

# Wrap the LLM so that `generate` also returns uncertainty scores.
estimator = LuhEstimatorDummy()
llm_adapter = CausalLMWithUncertainty(
    llm,
    tokenizer=tokenizer,
    stat_calculators=[calc_infer_llm],
    estimator=estimator,
)

In [4]:

# Prepare the prompts: a batch of conversations, each a list of chat messages.
messages = [
    [
        {
            "role": "user",
            "content": "In which year did the programming language Mercury first appear? Answer with a year only."
        }
    ]
]
# The correct answer is 1995

# Render every conversation to a prompt string with the model's chat template.
chat_messages = [
    tokenizer.apply_chat_template(m, tokenize=False, add_bos_token=False)
    for m in messages
]

# Tokenize the rendered prompts and move the tensors to the GPU.
# `truncation=True` was removed: no `max_length` is given and the tokenizer
# reported that the model has no predefined maximum length, so the flag fell
# back to "no truncation" and only produced a warning.
# Special tokens are not added here (`add_special_tokens=False`) — presumably
# because the chat template already inserts them; verify against the template.
inputs = tokenizer(
    chat_messages,
    return_tensors="pt",
    padding=True,
    add_special_tokens=False,
).to("cuda")

# NOTE(review): only `input_ids` is forwarded; with more than one prompt the
# padded positions would need an attention mask — confirm how the adapter
# handles this before batching several prompts.
output = llm_adapter.generate(inputs["input_ids"])

# Last expression of the cell: display the uncertainty scores.
output["uncertainty_score"]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


odict_keys(['sequences', 'scores', 'attentions', 'hidden_states', 'past_key_values', 'full_attention_mask', 'context_lengths'])


  output = torch._nested_tensor_from_mask(


[[0.6172351837158203,
  0.4881523847579956,
  0.5772963762283325,
  0.7417271733283997,
  0.7467945218086243,
  0.393752783536911,
  0.4387787878513336,
  0.5772822499275208,
  0.6776686906814575,
  0.7247191071510315,
  0.8887840509414673,
  0.9094085097312927,
  0.8421838283538818,
  0.8859848380088806,
  0.851408839225769,
  0.8874452114105225,
  0.8746453523635864,
  0.9152727127075195,
  0.8852745890617371,
  0.8997197151184082]]

In [5]:
# Drop the prompt tokens so that only the generated continuation is decoded.
prompt_len = len(inputs["input_ids"][0])
response_ids = output["sequences"][:, prompt_len:]
first_sequence_scores = output["uncertainty_score"][0]

print("Model response and uncertainty scores:")
print(f'Response: {tokenizer.batch_decode(response_ids)}')
print(f'UE Scores: {first_sequence_scores}')


Model response and uncertainty scores:
Response: ['Mercury is a logic programming language that was first announced in 1993. However,']
UE Scores: [0.6172351837158203, 0.4881523847579956, 0.5772963762283325, 0.7417271733283997, 0.7467945218086243, 0.393752783536911, 0.4387787878513336, 0.5772822499275208, 0.6776686906814575, 0.7247191071510315, 0.8887840509414673, 0.9094085097312927, 0.8421838283538818, 0.8859848380088806, 0.851408839225769, 0.8874452114105225, 0.8746453523635864, 0.9152727127075195, 0.8852745890617371, 0.8997197151184082]


In [6]:
def highlight_html_tokens(
    token_ids,
    positions_to_highlight,
    tokenizer,
    color="red",
    font_weight="bold"
):
    """
    Render a sequence of token IDs as an HTML string with selected tokens
    highlighted.

    Each ID is converted to its subword token via
    `tokenizer.convert_ids_to_tokens`. A leading '▁' (the word-boundary marker
    used by Mistral/Llama tokenizers) is replaced by a literal space. The token
    text is HTML-escaped, so tokens such as '<s>' or text containing '&' or '<'
    are shown verbatim instead of being interpreted as markup. Tokens whose
    index is in `positions_to_highlight` are wrapped in a styled <span>.

    Args:
        token_ids (List[int]): The sequence of token IDs.
        positions_to_highlight (Set[int] or List[int]): 0-based indices of tokens to highlight.
        tokenizer: An object exposing `convert_ids_to_tokens`, e.g. a Hugging Face
            tokenizer for mistralai/Mistral-7B-Instruct-v0.2.
        color (str): CSS color for the highlighted text (default "red").
        font_weight (str): CSS font weight (default "bold").

    Returns:
        str: The concatenated HTML string; empty if there are no tokens.
    """
    # Function-scope import keeps this cell runnable on its own.
    from html import escape

    # Convert the IDs to subword tokens (may contain leading "▁")
    raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # A set gives constant-time membership checks in the loop below.
    highlighted = set(positions_to_highlight)

    final_pieces = []
    for idx, token in enumerate(raw_tokens):
        # A leading "▁" marks a word boundary: show it as a literal space.
        text = " " + token[1:] if token.startswith("▁") else token

        # Escape before wrapping so that only our own <span> is real markup.
        text = escape(text, quote=False)

        if idx in highlighted:
            text = (
                f"<span style='color:{color}; font-weight:{font_weight};'>"
                f"{text}"
                "</span>"
            )

        final_pieces.append(text)

    # Spacing is already carried by the tokens themselves, so join without a separator.
    return "".join(final_pieces)

In [7]:
from IPython.display import HTML


def highlight_uncertain_claims(uncertainties, generated_tokens, claims, threshold=0.5):
    """
    Display the generated tokens as HTML, highlighting the uncertain ones.

    `uncertainties` and `claims` are walked in parallel; every entry of
    `claims` whose paired score is strictly greater than `threshold` is
    collected and passed to `highlight_html_tokens` as a position to
    highlight. Rendering uses the notebook-level `tokenizer`.

    Args:
        uncertainties: Iterable of uncertainty scores, one per claim.
        generated_tokens: Token IDs of the generated text to render.
        claims: Iterable of 0-based token positions, aligned with `uncertainties`.
        threshold (float): Scores strictly above this value are highlighted
            (default 0.5, the previously hard-coded value).

    Returns:
        None. The HTML is shown through `display`.
    """
    # zip stops at the shorter input, so unpaired scores/claims are ignored.
    tokens_to_highlight = {
        claim
        for ue_score, claim in zip(uncertainties, claims)
        if ue_score > threshold
    }

    display(HTML(highlight_html_tokens(generated_tokens, tokens_to_highlight, tokenizer)))

In [8]:
# Keep only the generated continuation of the first sequence (prompt removed).
prompt_length = len(inputs["input_ids"][0])
generated_token_ids = output["sequences"][:, prompt_length:][0]

# Every generated token position is treated as its own claim.
highlight_uncertain_claims(
    output["uncertainty_score"][0],
    generated_token_ids,
    list(range(len(generated_token_ids))),
)