In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Expose only GPU 0 to this process. This is set here, ahead of the
# `import torch` in a later cell.
# NOTE(review): presumably so the restriction is in place before CUDA is
# initialised - keep this cell above any torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import sys
# Put the parent directory on the import path.
# NOTE(review): presumably this is where the local `luh` package lives - confirm.
sys.path.append('../')

In [None]:
import torch

# Reproducibility settings for cuDNN: request deterministic algorithms and
# turn off the benchmark-driven algorithm auto-selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Identifiers passed to `from_pretrained` in the loading cell:
# `model_name` for the base LLM and its tokenizer, `uhead_name` for the
# uncertainty head that is loaded on top of that base model.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
uhead_name = "llm-uncertainty-head/uhead_Mistral-7B-Instruct-v0.2"

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from luh import AutoUncertaintyHead, CausalLMWithUncertainty

# Load the base causal LM and place it on the GPU.
llm = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="cuda")
# Tokenizer matching the base model.
tokenizer = AutoTokenizer.from_pretrained(
    model_name)
# Reuse the end-of-sequence token as the padding token; the batch of prompts
# is tokenized with padding=True in a later cell.
tokenizer.pad_token = tokenizer.eos_token
# Load the uncertainty head, handing it the already-loaded base model.
uhead = AutoUncertaintyHead.from_pretrained(
    uhead_name, base_model=llm)
# Bundle model, head and tokenizer; `llm_adapter.generate(...)` is what the
# generation cell calls.
llm_adapter = CausalLMWithUncertainty(llm, uhead, tokenizer=tokenizer)

In [None]:
# Demo prompts. Each question becomes its own single-turn conversation
# consisting of exactly one user message.
user_questions = [
    "How many fingers are on a coala's foot?",
    "Who sang a song Yesterday?",
    "Who sang a song Кукла Колдуна?",
    "Translate into French: 'I want a small cup of coffee'",
]
messages = [[{"role": "user", "content": question}] for question in user_questions]

# Render every conversation to prompt text with the tokenizer's chat template
# (tokenize=False returns strings rather than token IDs).
chat_messages = [
    tokenizer.apply_chat_template(conversation, tokenize=False, add_bos_token=False)
    for conversation in messages
]

# Tokenize the whole batch at once, padded to a common length, and move the
# tensors to the GPU.
# NOTE(review): add_special_tokens=False presumably avoids duplicating special
# tokens the chat template already wrote into the text - confirm.
inputs = tokenizer(
    chat_messages,
    return_tensors="pt",
    padding=True,
    truncation=True,
    add_special_tokens=False,
).to("cuda")

In [None]:
# Run generation for the whole batch through the uncertainty-aware wrapper.
# The result is indexed further down by the "greedy_tokens" and
# "uncertainty_logits" keys.
output = llm_adapter.generate(inputs)

In [None]:
def highlight_html_tokens(
    token_ids,
    positions_to_highlight,
    tokenizer,
    color="red",
    font_weight="bold"
):
    """
    Render a sequence of token IDs as an HTML string with selected tokens highlighted.

    Each ID is mapped to its subword token with ``tokenizer.convert_ids_to_tokens``.
    A leading '▁' (the word-boundary marker used by Mistral/Llama tokenizers) is
    shown as a plain space. Token text is HTML-escaped before it is inserted, so a
    token that contains '<', '>' or '&' (for instance one spelled like '<s>') is
    displayed literally instead of being parsed as markup by the browser.

    Args:
        token_ids (List[int]): The sequence of token IDs.
        positions_to_highlight (Set[int] or List[int]): 0-based indices of tokens to highlight.
        tokenizer: Any object exposing ``convert_ids_to_tokens(ids)``, e.g. a
            Hugging Face tokenizer.
        color (str): CSS color for the highlighted text (default "red").
        font_weight (str): CSS font weight (default "bold").

    Returns:
        str: The concatenated token texts as HTML; highlighted tokens are wrapped
        in a ``<span>`` carrying the requested inline style.
    """
    # Imported locally: this cell has no import section of its own.
    import html

    # Convert the IDs to subword tokens (may contain leading "▁")
    raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # Ensure positions_to_highlight is a set for quick membership check
    if not isinstance(positions_to_highlight, set):
        positions_to_highlight = set(positions_to_highlight)

    # The opening tag is the same for every highlighted token, so build it once.
    # Quotes are escaped as well so the values cannot terminate the style attribute.
    span_open = (
        f"<span style='color:{html.escape(str(color))}; "
        f"font-weight:{html.escape(str(font_weight))};'>"
    )

    final_pieces = []

    for idx, token in enumerate(raw_tokens):
        # If the token starts with "▁", replace that with a literal space
        if token.startswith("▁"):
            display_str = " " + token[1:]
        else:
            display_str = token

        # Escape markup characters so the token is shown, not interpreted.
        display_str = html.escape(display_str, quote=False)

        # If this position is in positions_to_highlight, wrap in <span>
        if idx in positions_to_highlight:
            display_str = f"{span_open}{display_str}</span>"

        final_pieces.append(display_str)

    # Join everything without extra spaces
    return "".join(final_pieces)

In [None]:
from IPython.display import HTML

from scipy.special import expit


def highlight_uncertain_tokens(generated_tokens, uncertain_logits, threshold=0.5):
    """
    Display a generated sequence with its uncertain tokens highlighted.

    The logits are passed through the logistic sigmoid (``expit``). Every position
    whose resulting value is strictly greater than `threshold` is highlighted.
    The HTML is built by `highlight_html_tokens` using the notebook-level
    `tokenizer` and shown with `display`.

    Args:
        generated_tokens (List[int]): The sequence of generated token IDs.
        uncertain_logits (List[float]): One uncertainty logit per token
            (presumably aligned index-for-index with `generated_tokens` - TODO confirm).
        threshold (float): Value above which a token counts as uncertain (default 0.5).

    Returns:
        The result of the `display(...)` call; the highlighted HTML is rendered
        as a side effect rather than returned as a string.
    """
    # Map logits to the (0, 1) range.
    probabilities = expit(uncertain_logits)

    # Collect the indices whose value exceeds the threshold.
    flagged_positions = []
    for position, probability in enumerate(probabilities):
        if probability > threshold:
            flagged_positions.append(position)

    # Build the markup, then render it in the notebook.
    markup = highlight_html_tokens(generated_tokens, flagged_positions, tokenizer)
    return display(HTML(markup))

In [None]:
# Pick one prompt of the batch by its 0-based index (2 is the third prompt in
# `messages`) and render its generation, highlighting tokens whose sigmoid-ed
# uncertainty logit is above 0.3.
idx = 2
highlight_uncertain_tokens(output["greedy_tokens"][idx], output["uncertainty_logits"][idx], threshold=0.3)