In [1]:
%load_ext autoreload
%autoreload 2

# Specify HyperParameters

In [2]:
model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.2"
device = "cuda:0"
dataset_name = "../workdir/data/triviaqa.csv"
batch_size = 4
seed = 42

# Initialize Model

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    load_in_8bit=True,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

generation_config = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
messages = [
    [
        {
            "role": "user", 
            "content": "How many fingers are on a coala's foot?"
        }
    ],
    [
        {
            "role": "user",
            "content": "Who sang a song Yesterday?"
        }
    ],
    [
        {
            "role": "user",
            "content": "Кто спел песню Кукла Колдуна?"
        }
    ],
    [
        {
            "role": "user",
            "content": "Translate into French: 'I want a small cup of coffee'"
        }
    ]
]

chat_messages = [tokenizer.apply_chat_template(m, tokenize=False) for m in messages]

In [5]:
from lm_polygraph.stat_calculators.basic_probs import BasicGreedyProbsCalculatorCausalLM
from lm_polygraph.stat_calculators.entropy import EntropyCalculator
from lm_polygraph.estimators import MaximumSequenceProbability, MeanTokenEntropy

from torch.utils.data import DataLoader


model_wrapper = BasicGreedyProbsCalculatorCausalLM()
calc_entropy = EntropyCalculator()

args_generate = {"generation_config" : generation_config,
                 "max_new_tokens": 30}

#estimator = MaximumSequenceProbability()
estimator = MeanTokenEntropy()

tokenizer.pad_token = tokenizer.eos_token

data_loader = DataLoader(chat_messages, batch_size=2, shuffle=False, collate_fn=lambda x: x)

for batch in data_loader:
    encoded = tokenizer(batch, padding=True, return_tensors="pt")
    out = model_wrapper(encoded, model, args_generate)
    out.update(calc_entropy(out))

    uncertianty_score = estimator(out)
    result = tokenizer.batch_decode(out['greedy_tokens'])

    for text, ue_score in zip(result, uncertianty_score):
        print("Output:", text)
        print("UE score:", ue_score)
        print()

  _torch_pytree._register_pytree_node(
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
2024-04-07 22:58:49.097208: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-07 22:58:49.419301: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-07 22:58:49.419345: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-07 22:58:49.421560: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register fac

Output: A koala's paws have five digits, similar to a human hand. Each paw has two opposable thumbs, which are
UE score: 1.3626975e-05

Output: The song "Yesterday" was written and first performed by the English singer-songwriter Paul McCartney. It was originally credited to
UE score: 9.504515e-06

Output: The song "Kukla Koldun" is a popular Russian children's song. The original version was recorded by the Soviet singer, Y
UE score: 2.5494732e-05

Output: In French, "I want a small cup of coffee" can be translated as "Je veux une tasse petite de café" or "
UE score: 6.462212e-06

