In [None]:
# Re-import edited modules (e.g. the local `luh` package) automatically before
# each cell runs, so code changes take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Expose only GPU 0 to this process; set before torch is imported in a later cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# The OpenAI key is consumed by the GPT-4o claim extractor. Never hardcode it in
# the notebook: export OPENAI_API_KEY in the shell that launches Jupyter.
# Assigning a placeholder here would also clobber a key that is already set.
if not os.environ.get("OPENAI_API_KEY"):
    import warnings
    warnings.warn(
        "OPENAI_API_KEY is not set; claim extraction with OpenAIChat will fail. "
        "Export it before starting the notebook.",
        stacklevel=1,
    )

# Make the parent directory importable (presumably where the local `luh` package
# lives). Guarded so re-running this cell does not append duplicate entries.
if '../' not in sys.path:
    sys.path.append('../')

In [None]:
import torch

# Reproducibility: seed torch's RNGs (all devices) and force deterministic cuDNN
# kernels instead of autotuned, potentially nondeterministic ones.
SEED = 0
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "mistralai/Mistral-7B-Instruct-v0.2"
# `load_in_8bit` is a deprecated kwarg (quantisation is configured through
# `quantization_config`) and False is already the default, so it is omitted.
# NOTE(review): no `torch_dtype` is passed, so weights load in the library's
# default dtype (float32 unless configured otherwise) — confirm this is intended.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
)
# Kept for any downstream code that reads this attribute; `generate` itself
# does not look at it (see the generation-config assignments below).
model.pad_token_id = model.config.eos_token_id

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the pad token so batched prompts can be padded.
tokenizer.pad_token = tokenizer.eos_token
# Decoder-only models must be left-padded for batched generation; otherwise pad
# tokens end up between a short prompt and its continuation.
tokenizer.padding_side = "left"

generation_config = GenerationConfig.from_pretrained(model_name)
# The pad id has to live on the generation config that is actually passed to
# `generate` (this object goes into `args_generate`), not on the model object.
generation_config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:
from lm_polygraph.model_adapters import WhiteboxModelBasic
from lm_polygraph.stat_calculators.extract_claims import ClaimsExtractor
from lm_polygraph.utils.openai_chat import OpenAIChat

from luh.luh_claim_estimator import LuhClaimEstimator
from luh.calculator_infer_luh import CalculatorInferLuh
from luh.auto_uncertainty_head import AutoUncertaintyHead


# --- Wrap the HF model/tokenizer for lm-polygraph ----------------------------
# Special tokens are not added at tokenisation time (presumably because the
# chat-templated prompts already carry the BOS token — confirm).
model_adapter = WhiteboxModelBasic(
    model=model,
    tokenizer=tokenizer,
    tokenizer_args=dict(
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ),
    model_type="CausalLM",
)
model_adapter.model_path = model_name

# --- Generation + uncertainty-head inference ---------------------------------
args_generate = dict(generation_config=generation_config, max_new_tokens=50)
uq_head = "llm-uncertainty-head/saplma_Mistral-7B-Instruct-v0.2"
uncertainty_head = AutoUncertaintyHead.from_pretrained(uq_head, base_model=model)
calc_infer_llm = CalculatorInferLuh(
    uncertainty_head,
    tokenize=True,
    args_generate=args_generate,
    device="cuda",
    generations_cache_dir="",
)

# --- Claim extraction via GPT-4o, then claim-level aggregation ---------------
openai_chat = OpenAIChat(openai_model="gpt-4o", cache_path='./workdir/cache')
calc_extract_claims = ClaimsExtractor(openai_chat=openai_chat)

# NOTE(review): reduce_type="mean" presumably averages the head's token-level
# scores over each claim's tokens — confirm in LuhClaimEstimator.
estimator = LuhClaimEstimator(reduce_type="mean")

In [None]:
# Demo prompts: each entry is one single-turn conversation (a list of chat messages).
messages = [
    [{"role": "user", "content": "How many fingers are on a coala's foot?"}],
    [{"role": "user", "content": "Who sang a song Yesterday?"}],
    [{"role": "user", "content": "Who sang a song Кукла Колдуна?"}],
    [{"role": "user", "content": "Translate into French: 'I want a small cup of coffee'"}],
]

# Render each conversation to prompt text with the model's chat template.
# `add_bos_token` is not an `apply_chat_template` parameter (unknown kwargs are
# only forwarded to the Jinja template, which ignores it), so it is not passed.
# NOTE(review): the template presumably emits BOS itself, which is why the
# adapter tokenises with add_special_tokens=False — confirm for other models.
chat_messages = [
    tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
    for m in messages
]

In [None]:
from tqdm import tqdm
from torch.utils.data import DataLoader


# Batch the rendered prompts; the identity collate_fn keeps each batch a plain
# list of prompt strings.
data_loader = DataLoader(chat_messages, batch_size=2,
                         shuffle=False, collate_fn=lambda x: x)

# Accumulated across batches in prompt order (shuffle=False), so index i here
# corresponds to messages[i].
all_generated_tokens = []
all_claims = []
all_uncertainties = []
for texts in tqdm(data_loader):
    # Each calculator receives the stats gathered so far and returns new ones
    # that are merged into `deps` for the next stage.
    deps = dict()
    print("Performing inference...")
    deps.update(calc_infer_llm(deps, texts=texts, model=model_adapter))
    print("Extracting claims...")
    deps.update(calc_extract_claims(deps, texts=texts, model=model_adapter))
    print("Estimating uncertainty...")
    uncertainty_score = estimator(deps)

    all_generated_tokens += deps['greedy_tokens']
    all_claims += deps['claims']
    all_uncertainties += uncertainty_score

    print("Results:")
    # Show every generation in the batch (not just the first) together with
    # the uncertainty score of each extracted claim.
    for tokens, doc_claims, doc_ues in zip(deps['greedy_tokens'], deps['claims'], uncertainty_score):
        print(model_adapter.tokenizer.decode(tokens, skip_special_tokens=True))
        for claim, ue in zip(doc_claims, doc_ues):
            print(ue, claim)
        print()

In [None]:
import html


def highlight_html_tokens(
    token_ids,
    positions_to_highlight,
    tokenizer,
    color="red",
    font_weight="bold"
):
    """
    Render token IDs as an HTML fragment, highlighting tokens at given positions.

    Tokens come from ``tokenizer.convert_ids_to_tokens``; a leading '▁' (the
    word-boundary marker of SentencePiece tokenizers such as Mistral/Llama) is
    turned into a literal space. Token text is HTML-escaped so model output such
    as ``</s>`` or ``a < b`` is shown verbatim instead of being parsed as markup.

    Args:
        token_ids (List[int]): The sequence of token IDs.
        positions_to_highlight (Set[int] or Iterable[int]): 0-based indices of
            tokens to highlight.
        tokenizer: A Hugging Face tokenizer (e.g., for mistralai/Mistral-7B-Instruct-v0.2).
        color (str): CSS color for the highlighted text (default "red").
        font_weight (str): CSS font weight (default "bold").

    Returns:
        str: HTML in which the selected tokens are wrapped in a styled <span>.
    """
    raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)
    highlighted = set(positions_to_highlight)
    open_tag = f"<span style='color:{color}; font-weight:{font_weight};'>"

    pieces = []
    for idx, token in enumerate(raw_tokens):
        # Defensive: treat a missing (None) token as empty text.
        token = token or ""
        if token.startswith("▁"):
            token = " " + token[1:]
        # Only <, > and & matter inside element text, so quotes are left as-is.
        text = html.escape(token, quote=False)
        if idx in highlighted:
            text = f"{open_tag}{text}</span>"
        pieces.append(text)

    # Tokens already carry their own spacing, so join without separators.
    return "".join(pieces)

In [None]:
from IPython.display import HTML, display


def highlight_uncertain_claims(uncertainties, generated_tokens, claims,
                               threshold=0.5, tokenizer=None):
    """
    Display a generation as HTML with the tokens of uncertain claims highlighted.

    Args:
        uncertainties: One uncertainty score per claim, in the same order as `claims`.
        generated_tokens: Token IDs of the generated text.
        claims: Extracted claims; each exposes `aligned_token_ids`, the positions
            in `generated_tokens` covered by that claim.
        threshold (float): Claims scoring strictly above this value are
            highlighted (default 0.5).
        tokenizer: Tokenizer used to render the tokens; defaults to
            `model_adapter.tokenizer`.

    Returns:
        None. The HTML is rendered via IPython's `display`.
    """
    if tokenizer is None:
        tokenizer = model_adapter.tokenizer

    tokens_to_highlight = set()
    for ue_score, claim in zip(uncertainties, claims):
        if ue_score > threshold:
            tokens_to_highlight.update(claim.aligned_token_ids)

    display(HTML(highlight_html_tokens(generated_tokens, tokens_to_highlight, tokenizer)))

In [None]:
# Index 2 is the "Кукла Колдуна" prompt; results were accumulated in prompt
# order because the DataLoader does not shuffle.
idx = 2
highlight_uncertain_claims(
    uncertainties=all_uncertainties[idx],
    generated_tokens=all_generated_tokens[idx],
    claims=all_claims[idx],
)