In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell is executed, so edits to local source files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Specify Hyperparameters

In [None]:
# --- Run configuration -------------------------------------------------------
# Hugging Face hub id (or a local path) of the causal LM checkpoint.
model_name_or_path: str = "mistralai/Mistral-7B-Instruct-v0.2"

# Path to a CSV dataset (not referenced by the cells shown here).
dataset_name: str = "../workdir/data/triviaqa.csv"

# Number of prompts processed per batch.
batch_size: int = 2

# Torch device string for auxiliary models.
device: str = "cuda:0"

# Initialize Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


# Load the model, tokenizer and generation config from the single
# `model_name_or_path` setting defined in the hyperparameter cell, so that
# changing the checkpoint there is enough to switch all three consistently.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    load_in_8bit=True,   # 8-bit quantized weights
    device_map="auto",   # let the library place the weights on available devices
)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# Reuse the EOS token for padding so that batches of unequal length can be built.
tokenizer.pad_token = tokenizer.eos_token
# Decoder-only models continue generating from the last position of the input,
# so batched prompts must be padded on the left rather than on the right.
tokenizer.padding_side = "left"

generation_config = GenerationConfig.from_pretrained(model_name_or_path)

In [None]:
# One single-turn conversation per prompt, in the role/content chat format
# consumed by `tokenizer.apply_chat_template`.
messages = [
    [{"role": "user", "content": "How many fingers on a coala's foot?"}],
    [{"role": "user", "content": "Who sang a song Yesterday?"}],
    [{"role": "user", "content": "Кто спел песню Кукла Колдуна?"}],
    [{"role": "user", "content": "Translate into French: 'I want a small cup of coffee'"}],
]

# Render every conversation to a prompt string (tokenize=False keeps it as text).
# add_generation_prompt=True appends the template's assistant-turn header, if it
# defines one, so the model answers instead of continuing the user's message.
chat_messages = [
    tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    for conversation in messages
]

# Run LLM inference and get uncertainty scores

In [None]:
from lm_polygraph.stat_calculators.infer_causal_lm_calculator import InferCausalLMCalculator
from lm_polygraph.stat_calculators.greedy_alternatives_nli import GreedyAlternativesNLICalculator
from lm_polygraph.estimators.claim_conditioned_probability import ClaimConditionedProbability
from lm_polygraph.utils.deberta import Deberta
from lm_polygraph.model_adapters import WhiteboxModelBasic

from torch.utils.data import DataLoader


# Wrap the Hugging Face model/tokenizer pair in the adapter handed to the
# stat calculators below.
model_adapter = WhiteboxModelBasic(model, tokenizer)

# tokenize=False: prompts are tokenized in the loop below and passed in through
# the "model_inputs" dependency (presumably instead of being tokenized by the
# calculator itself).
calc_infer_llm = InferCausalLMCalculator(tokenize=False)
nli_model = Deberta(device=device)
nli_model.setup()
calc_nli = GreedyAlternativesNLICalculator(nli_model=nli_model)

args_generate = {"generation_config" : generation_config,
                 "max_new_tokens": 30}

estimator = ClaimConditionedProbability()

# The identity collate_fn keeps every batch as a plain list of prompt strings.
data_loader = DataLoader(chat_messages, batch_size=batch_size, shuffle=False, collate_fn=lambda x: x)
for batch in data_loader:
    # The prompts were rendered with the chat template, which already inserts
    # the special tokens it needs. Tokenizing with the default
    # add_special_tokens=True would add them a second time (e.g. a duplicated
    # BOS token), so it is disabled here.
    encoded = tokenizer(
        batch,
        padding=True,
        return_tensors="pt",
        add_special_tokens=False,
    )

    # `deps` accumulates the statistics produced by each calculator; later
    # calculators and the estimator read what earlier ones stored.
    deps = {"model_inputs": encoded}
    deps.update(calc_infer_llm(
        deps, texts=batch, model=model_adapter, args_generate=args_generate))
    deps.update(calc_nli(deps, texts=None, model=model_adapter))

    uncertainty_scores = estimator(deps)
    generated_texts = tokenizer.batch_decode(deps['greedy_tokens'])

    for text, ue_score in zip(generated_texts, uncertainty_scores):
        print("Output:", text)
        print("Uncertainty score:", ue_score)
        print()