In [1]:
from time import perf_counter, time

from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_polygraph.estimators import TopologicalDivergence

from lm_polygraph.utils import WhiteboxModel, UEManager, Dataset
from lm_polygraph.utils.builder_enviroment_stat_calculator import BuilderEnvironmentStatCalculator
from lm_polygraph.utils.factory_stat_calculator import StatCalculatorContainer
from lm_polygraph.utils.estimate_uncertainty import UncertaintyOutput

from lm_polygraph.defaults.register_default_stat_calculators import (
    register_default_stat_calculators,
)
from lm_polygraph.stat_calculators import TrainMTopDivCalculator


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def estimate_uncertainty(
    model, estimator, available_stat_calculators, input_text: str
) -> UncertaintyOutput:
    """Run one uncertainty estimator on a single input text.

    Wraps ``input_text`` in a single-item ``Dataset`` (the second list holds one
    empty-string placeholder) and runs a ``UEManager`` with only ``estimator``.
    No generation metrics, UE metrics or processors are used, and exceptions
    are not ignored. The result is packed into an ``UncertaintyOutput``.

    Args:
        model: wrapped model used for generation (a ``WhiteboxModel`` in this
            notebook); its ``model_path`` is copied into the output.
        estimator: uncertainty estimator; ``estimator.level`` together with
            ``str(estimator)`` selects its entry in ``man.estimations``.
        available_stat_calculators: stat calculator containers the manager may
            use to compute the statistics the estimator depends on.
        input_text: prompt to generate from and score.

    Returns:
        UncertaintyOutput holding the uncertainty for the input, the prompt,
        the first greedy generation text (``None`` if the manager produced no
        ``greedy_texts``), the first greedy token sequence without its last
        token (``None`` if there is no ``greedy_tokens`` stat), the model path
        and the estimator name.
    """
    man = UEManager(
        Dataset([input_text], [""], batch_size=1),
        model,
        [estimator],
        available_stat_calculators=available_stat_calculators,
        builder_env_stat_calc=BuilderEnvironmentStatCalculator(model),
        generation_metrics=[],
        ue_metrics=[],
        processors=[],
        ignore_exceptions=False,
        verbose=False,
    )
    man()

    ue = man.estimations[estimator.level, str(estimator)]

    # Both stats are looked up with .get(), so either may be absent; only index
    # them when present instead of failing with TypeError/IndexError.
    texts = man.stats.get("greedy_texts", None)
    generation_text = texts[0] if texts is not None and len(texts) > 0 else None

    tokens = man.stats.get("greedy_tokens", None)
    if tokens is not None and len(tokens) > 0:
        # Remove the last token, which is the end-of-sequence token, since its
        # uncertainty is not included in the estimator's output.
        # NOTE(review): the drop is unconditional; if generation stopped at the
        # length limit without emitting EOS, a real token is removed — confirm.
        tokens = tokens[0][:-1]

    return UncertaintyOutput(
        ue[0], input_text, generation_text, tokens, model.model_path, str(estimator)
    )


In [3]:
def default_torch_device() -> str:
    """Return the first available torch device name: "cuda", then "mps", else "cpu"."""
    import torch

    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is missing on old torch builds, so look it up defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


class Config:
    """Notebook settings for the model and the MTopDiv training setup.

    An instance of this class is passed as ``cfg`` to the
    ``TrainMTopDivCalculator`` stat calculator container. Which fields that
    builder reads is not visible here; presumably the dataset, column,
    subsampling and seed fields (confirm). ``max_heads`` and ``n_jobs`` are
    passed to ``TopologicalDivergence``.
    """

    model_name = 'Qwen/Qwen2.5-0.5B-Instruct'
    # Training CSV path, relative to the notebook's working directory.
    train_dataset = '../../coqa_Meta-Llama-3-8B-Instruct.csv'
    # Column names presumably expected in train_dataset — verify against the CSV.
    context_column = 'context'
    question_column = 'question'
    prompt_column = 'prompt'
    response_column = 'generated_answer'
    label_column = 'hallucination'
    batch_size = 1
    # Presumably the number of training rows sampled (the run shows 4/4 steps).
    subsample_train_dataset = 4
    seed = 52
    # Chosen at runtime instead of hard-coding 'mps', which only exists on
    # Apple Silicon; MPS machines still get 'mps'.
    device_map = default_torch_device()
    max_heads = 6
    # -1 presumably means "use all cores" (joblib convention) — confirm.
    n_jobs = -1

cfg = Config()

# Load the Hugging Face model and tokenizer, then wrap them for lm_polygraph.
base_model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name,
    device_map=cfg.device_map,
)
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
# NOTE(review): no model_path is passed here, which is why the outputs show
# model_path=None; consider passing cfg.model_name if WhiteboxModel accepts it.
model = WhiteboxModel(base_model, tokenizer)

# Start from the default white-box stat calculators and add the MTopDiv
# training calculator (presumably required by TopologicalDivergence — confirm).
available_stat_calculators = register_default_stat_calculators("Whitebox")

# meta_info() yields (stats, dependencies); call it once and index the result.
mtopdiv_meta = TrainMTopDivCalculator.meta_info()
sc = StatCalculatorContainer(
    name=TrainMTopDivCalculator.__name__,
    obj=TrainMTopDivCalculator,
    builder="lm_polygraph.defaults.stat_calculator_builders.default_TrainMTopDivCalculator",
    cfg=cfg,
    dependencies=mtopdiv_meta[1],
    stats=mtopdiv_meta[0],
)
available_stat_calculators.append(sc)


In [4]:
# No heads are fixed up front (heads=None); the estimator presumably selects
# up to cfg.max_heads itself. The resulting best_heads are printed at the end.
selected_heads = None
input_text = 'How many floors are in the Empire State Building?'

estimator = TopologicalDivergence(
    heads=selected_heads,
    max_heads=cfg.max_heads,
    n_jobs=cfg.n_jobs,
)

# perf_counter() is monotonic, so the elapsed time cannot be skewed by
# wall-clock adjustments the way time() can.
start = perf_counter()
print(estimate_uncertainty(
    model,
    estimator,
    available_stat_calculators,
    input_text,
))
print(f"Time taken: {perf_counter() - start:.2f} seconds")
print(f"Best heads: {estimator.best_heads}")

100%|██████████| 4/4 [00:43<00:00, 10.77s/it]


UncertaintyOutput(uncertainty=11.994889825582504, input_text='How many floors are in the Empire State Building?', generation_text='The Empire State Building has 105 floors.', generation_tokens=[785, 20448, 3234, 16858, 702, 220, 16, 15, 20, 25945, 13], model_path=None, estimator='TopologicalDivergence')
Time taken: 53.50 seconds
Best heads: [(8, 3), (0, 1)]


In [5]:
# Reuses the estimator from the previous cell, presumably together with the
# heads it selected on the first call — confirm that no reselection happens.
input_text = 'What has a head and a tail but no body?'

# perf_counter() is monotonic and intended for measuring elapsed time.
start = perf_counter()
print(estimate_uncertainty(
    model,
    estimator,
    available_stat_calculators,
    input_text,
))
print(f"Time taken: {perf_counter() - start:.2f} seconds")

UncertaintyOutput(uncertainty=87.97025752067566, input_text='What has a head and a tail but no body?', generation_text="The answer to this question is a virus. Viruses are small, non-living entities that can only replicate within living cells. They do not have a head or a tail, but they can cause harm to living organisms by attaching to and hijacking the host cell's machinery. Viruses are a significant threat to public health and can cause a wide range of diseases, including but not limited to, influenza, HIV, and cancer.", generation_tokens=[785, 4226, 311, 419, 3405, 374, 264, 16770, 13, 9542, 4776, 525, 2613, 11, 2477, 2852, 2249, 14744, 429, 646, 1172, 45013, 2878, 5382, 7761, 13, 2379, 653, 537, 614, 264, 1968, 476, 264, 9787, 11, 714, 807, 646, 5240, 11428, 311, 5382, 43204, 553, 71808, 311, 323, 21415, 8985, 279, 3468, 2779, 594, 25868, 13, 9542, 4776, 525, 264, 5089, 5899, 311, 584, 2820, 323, 646, 5240, 264, 6884, 2088, 315, 18808, 11, 2670, 714, 537, 7199, 311, 11, 61837, 11,