In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to library source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Hugging Face Hub identifier of the causal LM this notebook works with.
model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.2"
# Device string handed to the Deberta NLI model in the setup cell.
device = "cuda:0"
# Number of prompts the DataLoader groups into one batch.
batch_size = 2

In [None]:
import torch
# Release cached, currently unused GPU memory held by PyTorch's CUDA allocator
# (for example left over from an earlier run in this kernel) before the model
# is loaded in the next cell.
torch.cuda.empty_cache()

In [None]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

# Load the model, tokenizer and generation config from the single identifier
# defined in the config cell (`model_name_or_path`), so switching models only
# requires editing one place instead of three string literals.
#
# 8-bit quantization is requested through an explicit BitsAndBytesConfig;
# passing `load_in_8bit=True` directly to `from_pretrained` is deprecated in
# transformers. `device_map="auto"` lets accelerate decide weight placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# Reuse EOS as the padding token so that prompts of different lengths can be
# padded into one batch (the generation cell tokenizes with padding=True).
tokenizer.pad_token = tokenizer.eos_token

generation_config = GenerationConfig.from_pretrained(model_name_or_path)

In [None]:
# One single-turn conversation per person. Each conversation is a list holding
# exactly one user message in the {"role", "content"} dictionary format.
prompts = [
    "Tell me a bio of Albert Einstein.",
    "Tell me a bio of Alla Pugacheva.",
    "Tell me a bio of Paul McCartney.",
]
messages = [[{"role": "user", "content": prompt}] for prompt in prompts]

# Render every conversation into a prompt string with the tokenizer's chat
# template. tokenize=False keeps the result as text instead of token ids.
chat_messages = []
for conversation in messages:
    chat_messages.append(tokenizer.apply_chat_template(conversation, tokenize=False))

In [None]:
import getpass
import os

from lm_polygraph.model_adapters import WhiteboxModelBasic
from lm_polygraph.estimators import ClaimConditionedProbabilityClaim
# Explicit names instead of `import *`, so it is clear which calculators this
# notebook relies on and where they come from.
from lm_polygraph.stat_calculators import (
    ClaimsExtractor,
    GreedyAlternativesNLICalculator,
    InferCausalLMCalculator,
)
from lm_polygraph.utils.openai_chat import OpenAIChat
from lm_polygraph.utils.deberta import Deberta


# Adapter wrapping the Hugging Face model and tokenizer for lm_polygraph.
model_adapter = WhiteboxModelBasic(model, tokenizer)

# tokenize=False: the generation cell tokenizes the prompts itself and passes
# them in as deps["model_inputs"].
# NOTE(review): presumably the calculator then skips its own tokenization --
# confirm against lm_polygraph.
calc_infer_llm = InferCausalLMCalculator(tokenize=False)

# Never hardcode the key in the notebook. An OPENAI_API_KEY already present in
# the environment is kept as is; otherwise it is requested interactively and
# never echoed or stored in the file. For non-interactive runs, export
# OPENAI_API_KEY before starting the kernel.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")
# Claim extraction is delegated to an OpenAI chat model ("gpt-4").
calc_claim_extractor = ClaimsExtractor(OpenAIChat("gpt-4"))

# NLI calculator backed by a Deberta model placed on the configured device.
calc_claim_nli = GreedyAlternativesNLICalculator(Deberta(device=device))

# Claim-level uncertainty estimator applied to the collected statistics.
estimator = ClaimConditionedProbabilityClaim()

In [None]:
from torch.utils.data import DataLoader


# Generation settings forwarded to the LLM inference calculator.
args_generate = {"generation_config" : generation_config,
                 "max_new_tokens": 100}

# The identity collate_fn keeps each batch as a plain list of prompt strings.
data_loader = DataLoader(chat_messages, batch_size=batch_size, shuffle=False, collate_fn=lambda x: x)
for batch in data_loader:
    # The prompts were rendered with apply_chat_template(tokenize=False), and
    # chat templates normally emit the special tokens (e.g. BOS) themselves.
    # add_special_tokens=False stops the tokenizer from prepending a second
    # BOS on top of the one already present in the text.
    encoded = tokenizer(
        batch, padding=True, return_tensors="pt", add_special_tokens=False)

    # `deps` accumulates the outputs of each stage; every calculator reads the
    # entries it needs and returns new ones that are merged back in.
    deps = {"model_inputs": encoded}
    deps.update(calc_infer_llm(
        deps, texts=batch, model=model_adapter, args_generate=args_generate))
    # Decode the generated token ids back to text for the later stages.
    deps.update({"greedy_texts" : tokenizer.batch_decode(deps['greedy_tokens'])})
    deps.update(calc_claim_extractor(deps, texts=batch, model=model_adapter))
    deps.update(calc_claim_nli(deps, texts=None, model=model_adapter))

    # One sequence of claim-level scores per generated text.
    uncertainty_scores = estimator(deps)

    for text, claims, ue_score in zip(deps["greedy_texts"], deps['claims'], uncertainty_scores):
        print("Output:", text)

        # Pair every extracted claim with its uncertainty score.
        for claim, ue in zip(claims, ue_score):
            print("claim:", claim.claim_text)
            print("aligned tokens:", claim.aligned_token_ids)
            print("UE score:", ue)

        print()