In [None]:
# Reload edited Python modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

# Specify hyperparameters

In [None]:
# Model id/path passed to vLLM's `LLM` constructor.
model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"
# Device handed to the model adapter and the DeBERTa NLI model. CUDA device
# indices are relative to CUDA_VISIBLE_DEVICES, which the model-initialization
# cell sets, so "cuda:0" is the first *visible* GPU.
device = "cuda:0"
# NOTE(review): not referenced anywhere in this notebook; the demo below uses
# the inline `messages` list instead.
dataset_name = "../workdir/data/triviaqa.csv"
# Number of prompts per DataLoader batch in the inference loop.
batch_size = 2

# Initialize Model

In [None]:
import os

# Select the GPU *before* importing vLLM (which imports torch): CUDA reads
# CUDA_VISIBLE_DEVICES once, when it is first initialized, and ignores later
# changes, so setting it after the import may silently have no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from vllm import LLM, SamplingParams

# Cap vLLM at half of the GPU memory: the NLI model created later is placed on
# the same visible device (`device`).
llm = LLM(model=model_name_or_path, gpu_memory_utilization=0.5)
# Generation configuration passed to the calculators in the inference loop:
# at most 30 new tokens, and the top-20 log-probabilities at every position
# (must be >= InfervLLMCalculator's n_alternatives, 10 by default).
# No temperature is given, so vLLM's default (1.0) makes decoding stochastic.
sampling_params = SamplingParams(max_tokens=30, logprobs=20)

In [None]:
messages = [
    [
        {
            "role": "user", 
            "content": "How many fingers on a coala's foot?"
        }
    ],
    [
        {
            "role": "user",
            "content": "Who sang a song Yesterday?"
        }
    ],
    [
        {
            "role": "user",
            "content": "Кто спел песню Кукла Колдуна?"
        }
    ],
    [
        {
            "role": "user",
            "content": "Translate into French: 'I want a small cup of coffee'"
        }
    ]
]

tokenizer = llm.get_tokenizer()
# add_generation_prompt=True appends the assistant-turn header, so the model
# writes an answer instead of continuing the user's turn.
# NOTE(review): the templated string already starts with the BOS token; vLLM
# may add another BOS when it tokenizes string prompts — confirm.
chat_messages = [
    tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
    for m in messages
]

# Run the LLM and compute uncertainty scores

In [None]:
from typing import List
from lm_polygraph.utils.model import Model
from transformers.generation import GenerateDecoderOnlyOutput

class WhiteboxModelvLLM(Model):
    """Whitebox model adapter for using vLLM in stat calculators and uncertainty estimators.

    vLLM returns, for every generated position, only the top
    ``SamplingParams.logprobs`` log-probabilities. ``generate`` converts them into
    one dense ``(max_seq_len, vocab_size)`` log-probability tensor per generated
    sample (every other entry is ``-inf``) and one token-id tensor per sample
    padded with the EOS id, packed into a ``GenerateDecoderOnlyOutput``.
    """

    def __init__(self, model: LLM, device: str = "cuda"):
        self.model = model
        self.tokenizer = self.model.get_tokenizer()
        self.base_device = device
        # Deliberately not "CausalLM": vLLM completions hold only the generated
        # tokens, so calculators must not strip a prompt prefix from them.
        self.model_type = "vLLMCausalLM"

    def _generate_raw(self, args, kwargs):
        """Run ``LLM.generate`` and return vLLM's raw list of request outputs.

        ``kwargs`` must contain ``sampling_params``; ``num_return_sequences``
        (default 1) sets how many completions are drawn per prompt. The caller's
        ``SamplingParams`` object is cloned so it is never mutated.

        Raises:
            TypeError: if ``sampling_params`` is missing.
        """
        if "sampling_params" not in kwargs:
            raise TypeError(
                "WhiteboxModelvLLM.generate() requires a `sampling_params` keyword argument"
            )
        sampling_params = kwargs["sampling_params"].clone()
        sampling_params.n = kwargs.get("num_return_sequences", 1)
        return self.model.generate(*args, sampling_params)

    def generate(self, *args, **kwargs):
        """Generate with vLLM and return a ``GenerateDecoderOnlyOutput`` (see ``post_processing``)."""
        return self.post_processing(self._generate_raw(args, kwargs))

    def device(self):
        return self.base_device

    def tokenize(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.generate(*args, **kwargs)

    def generate_texts(self, input_texts: List[str], **args):
        """Return the generated texts, prompt-major (all completions of prompt 0 first)."""
        request_outputs = self._generate_raw((input_texts,), args)
        return [
            completion.text
            for request_output in request_outputs
            for completion in request_output.outputs
        ]

    def post_processing(self, outputs):
        """Convert vLLM request outputs into a ``GenerateDecoderOnlyOutput``.

        ``logits`` and ``scores`` are the same list of per-sample
        ``(max_seq_len, vocab_size)`` float tensors holding the returned top-k
        log-probabilities (``-inf`` elsewhere, including padding rows);
        ``sequences`` is a list of per-sample ``long`` tensors padded with the
        EOS id. Samples are ordered prompt-major.

        Raises:
            ValueError: if vLLM returned no logprobs (``SamplingParams.logprobs`` unset).
        """
        import torch  # this notebook imports torch only in a later cell

        completions = [
            completion
            for request_output in outputs
            for completion in request_output.outputs
        ]
        # len(tokenizer) includes added/special tokens; the second term guards
        # against added-token ids beyond it (+1 because ids are 0-based).
        vocab_size = max(
            len(self.tokenizer),
            max(self.tokenizer.added_tokens_decoder.keys(), default=-1) + 1,
        )
        max_seq_len = max((len(c.token_ids) for c in completions), default=0)
        eos_token_id = self.tokenizer.eos_token_id

        logits = []
        sequences = []
        for completion in completions:
            if completion.logprobs is None:
                raise ValueError(
                    "vLLM returned no logprobs; set `logprobs` in SamplingParams"
                )
            log_prob = torch.full((max_seq_len, vocab_size), -torch.inf)
            sequence = torch.full((max_seq_len,), eos_token_id, dtype=torch.long)
            for i, position_logprobs in enumerate(completion.logprobs):
                top_tokens = torch.tensor(list(position_logprobs.keys()), dtype=torch.long)
                top_values = torch.tensor([lp.logprob for lp in position_logprobs.values()])
                log_prob[i, top_tokens] = top_values
                sequence[i] = completion.token_ids[i]
            logits.append(log_prob)
            sequences.append(sequence)

        standard_output = GenerateDecoderOnlyOutput()
        standard_output.logits = logits
        standard_output.scores = logits
        standard_output.sequences = sequences
        return standard_output

In [None]:
from lm_polygraph.stat_calculators import StatCalculator

from typing import Dict, List, Tuple

import numpy as np

import torch


class InfervLLMCalculator(StatCalculator):
    """
    Runs inference through a vLLM-backed model adapter and collects per-token
    statistics of the generation.

    For every input text it returns:
    * the generated token ids, cut right after the first EOS token,
    * the decoded generation text (without the trailing EOS),
    * the log-probability distribution at every generated position,
    * the log-likelihood of every generated token,
    * up to ``n_alternatives`` candidate ``(token_id, log_prob)`` pairs at
      every position, with the generated token always listed first.

    The statistics keep lm_polygraph's ``greedy_*`` names; they describe greedy
    decoding only if the ``sampling_params`` passed to ``__call__`` use
    ``temperature=0``. Only the top ``SamplingParams.logprobs`` entries of each
    distribution are finite; all others are ``-inf``.
    """

    def __init__(
        self,
        n_alternatives: int = 10,
        return_embeddings: bool = False,
    ):
        super().__init__()
        if return_embeddings:
            # Embeddings are not supported by vLLM: fail here instead of at the
            # first call.
            raise NotImplementedError(
                "InfervLLMCalculator cannot return embeddings: not supported by vLLM"
            )
        self.n_alternatives = n_alternatives
        self._return_embeddings = return_embeddings

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.
        """

        return [
            "greedy_log_probs",
            "greedy_logits",
            "greedy_tokens",
            "greedy_texts",
            "greedy_log_likelihoods",
            "greedy_tokens_alternatives",
        ], []

    def _top_alternatives(self, position_log_probs, chosen_token):
        """Return up to ``n_alternatives`` ``(token_id, log_prob)`` pairs for one position.

        Only tokens with a finite log-probability are kept, since vLLM returns
        only a top-k and the rest is ``-inf``. The generated token is always
        included, even if it is outside the top ``n_alternatives`` (possible
        when sampling), and is placed first.
        """
        k = min(self.n_alternatives, len(position_log_probs))
        # Guard k == 0: a [-0:] slice would select the whole vocabulary.
        candidates = np.argpartition(position_log_probs, -k)[-k:] if k > 0 else []
        best = [int(t) for t in candidates if np.isfinite(position_log_probs[t])]
        if chosen_token not in best:
            best.append(chosen_token)
        alternatives = [(t, float(position_log_probs[t])) for t in best]
        alternatives.sort(key=lambda x: x[0] == chosen_token, reverse=True)
        return alternatives

    def _post_process_logits(self, out, model_inputs, eos_token_id):
        """Cut every generated sequence and derive the per-token statistics.

        Parameters:
            out: adapter output with per-sample ``sequences`` (token ids) and
                ``scores`` (``(max_seq_len, vocab_size)`` log-probabilities).
            model_inputs: input batch; one generation per input is expected.
            eos_token_id: id after which a sequence is cut (the EOS itself is kept).
        Returns:
            Dict with every statistic from ``meta_info`` except ``greedy_texts``.
        """
        cut_sequences, cut_log_probs, cut_alternatives, lls = [], [], [], []

        for i in range(len(model_inputs)):
            seq = np.asarray(out.sequences[i])
            scores = np.asarray(out.scores[i])

            length = len(seq)
            for j, token in enumerate(seq):
                if not np.isfinite(scores[j, token]):
                    # Padding slot: the vLLM adapter pads shorter samples with
                    # EOS ids and an all -inf row; it was never generated.
                    length = j
                    break
                if token == eos_token_id:
                    length = j + 1  # keep the EOS token itself
                    break

            tokens = seq[:length].tolist()
            log_probs = scores[:length, :]
            cut_sequences.append(tokens)
            cut_log_probs.append(log_probs)
            lls.append([float(log_probs[j, t]) for j, t in enumerate(tokens)])
            cut_alternatives.append(
                [self._top_alternatives(log_probs[j], t) for j, t in enumerate(tokens)]
            )

        return {
            "greedy_log_probs": cut_log_probs,
            # Only log-probabilities are available from the adapter, so they
            # also serve as "logits" (separate list, same arrays).
            "greedy_logits": list(cut_log_probs),
            "greedy_tokens": cut_sequences,
            "greedy_log_likelihoods": lls,
            "greedy_tokens_alternatives": cut_alternatives,
        }

    def __call__(
        self,
        dependencies: Dict[str, np.array],
        texts: List[str],
        model: Model,
        max_new_tokens: int = 100,
        **kwargs,
    ) -> Dict[str, np.ndarray]:
        """
        Calculates the statistics of probabilities at each token position in the generation.

        Parameters:
            dependencies (Dict[str, np.ndarray]): input statistics, can be empty (not used).
            texts (List[str]): Input texts batch used for model generation.
            model (Model): vLLM model adapter used for generation.
            max_new_tokens (int): Not used; the generation length comes from
                ``sampling_params.max_tokens``. Kept for interface compatibility.
            **kwargs: forwarded to ``model.generate``; must include ``sampling_params``.
        Returns:
            Dict[str, np.ndarray]: dictionary with the following items:
                - 'greedy_log_probs' (List[np.ndarray]): log-probability distributions
                        at each generated position,
                - 'greedy_logits' (List[np.ndarray]): the same arrays as 'greedy_log_probs',
                - 'greedy_texts' (List[str]): model generations without the trailing EOS token,
                - 'greedy_tokens' (List[List[int]]): tokenized model generations,
                - 'greedy_log_likelihoods' (List[List[float]]): log-probabilities of the generated tokens,
                - 'greedy_tokens_alternatives' (List[List[List[Tuple[int, float]]]]):
                        per-position (token id, log-probability) candidates, generated token first.
        """
        out = model.generate(texts, **kwargs)
        eos_token_id = model.tokenizer.eos_token_id
        result_dict = self._post_process_logits(out, texts, eos_token_id)
        result_dict["greedy_texts"] = [
            model.tokenizer.decode(
                tokens[:-1] if tokens and tokens[-1] == eos_token_id else tokens
            )
            for tokens in result_dict["greedy_tokens"]
        ]
        return result_dict

In [None]:
def _gen_samples(n_samples, model, batch, **kwargs):
    batch_size = len(batch["input_ids"])
    logits, sequences = [[] for _ in range(batch_size)], [[] for _ in range(batch_size)]
    with torch.no_grad():
        for k in range(n_samples):
            out = model.generate(batch["texts"], **kwargs)
            cur_logits = torch.stack(out.scores, dim=0)
            for i in range(batch_size):
                sequences[i].append(out.sequences[i])
                logits[i].append(cur_logits[i])
    sequences = [s for sample_seqs in sequences for s in sample_seqs]
    return sequences, sum(logits, [])

class SamplingGenerationCalculator(StatCalculator):
    """
    For WhiteboxModelvLLM model, at input texts batch calculates:
    * sampled texts
    * tokens of the sampled texts
    * probabilities of the sampled tokens generation
    """

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.
        """

        return [
            "sample_log_probs",
            "sample_tokens",
            "sample_texts",
            "sample_log_likelihoods",
        ], []

    def __init__(self, samples_n: int = 10):
        super().__init__()
        self.samples_n = samples_n

    def __call__(
        self,
        dependencies: Dict[str, np.array],
        texts: List[str],
        model: WhiteboxModelvLLM,
        max_new_tokens: int = 100,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Calculates the statistics of sampling texts.

        Parameters:
            dependencies (Dict[str, np.ndarray]): input statistics, can be empty (not used).
            texts (List[str]): Input texts batch used for model generation.
            model (Model): Model used for generation.
            max_new_tokens (int): Not used; for vLLM the generation length comes from
                ``sampling_params.max_tokens``. Kept for interface compatibility.
            **kwargs: forwarded to ``model.generate`` (e.g. ``sampling_params``).
        Returns:
            Dict[str, np.ndarray]: dictionary with the following items:
                - 'sample_texts' (List[List[str]]): `samples_n` texts for each input text in the batch,
                - 'sample_tokens' (List[List[List[float]]]): tokenized 'sample_texts',
                - 'sample_log_probs' (List[List[float]]): sum of the log probabilities at each token of the sampling generation.
                - 'sample_log_likelihoods' (List[List[List[float]]]): log probabilities at each token of the sampling generation.
        """
        batch: Dict[str, torch.Tensor] = model.tokenize(texts)
        batch["texts"] = texts
        sequences, logits = _gen_samples(
            self.samples_n,
            model,
            batch,
            **kwargs,
        )

        n_inputs = len(texts)
        log_probs = [[] for _ in range(n_inputs)]
        tokens = [[] for _ in range(n_inputs)]
        sample_texts = [[] for _ in range(n_inputs)]
        log_likelihoods = [[] for _ in range(n_inputs)]
        if model.model_type == "Seq2SeqLM":
            sequences = [seq[1:] for seq in sequences]
        eos_token_id = model.tokenizer.eos_token_id

        # Samples are ordered input-major, so sample i belongs to input i // samples_n.
        for i, (sample_seq, sample_logits) in enumerate(zip(sequences, logits)):
            input_idx = i // self.samples_n
            # Only "CausalLM" sequences are treated as prompt + generation; the
            # vLLM adapter ("vLLMCausalLM") yields generated tokens only.
            inp_size = (
                len(batch["input_ids"][input_idx])
                if model.model_type == "CausalLM"
                else 0
            )
            log_prob, ll, toks = 0, [], []
            for j in range(len(sample_seq) - inp_size):
                cur_token = sample_seq[j + inp_size].item()
                token_log_prob = sample_logits[j][cur_token].item()
                if token_log_prob == -np.inf:
                    # Padding slot (EOS id with an all -inf row) after a sample
                    # that stopped early without emitting EOS: not generated.
                    break
                log_prob += token_log_prob
                if cur_token == eos_token_id:
                    break
                ll.append(token_log_prob)
                toks.append(cur_token)

            log_likelihoods[input_idx].append(ll)
            log_probs[input_idx].append(log_prob)
            tokens[input_idx].append(toks)
            sample_texts[input_idx].append(model.tokenizer.decode(toks))

        return {
            "sample_log_likelihoods": log_likelihoods,
            "sample_log_probs": log_probs,
            "sample_tokens": tokens,
            "sample_texts": sample_texts,
        }

In [None]:
from lm_polygraph.stat_calculators.greedy_alternatives_nli import GreedyAlternativesNLICalculator
from lm_polygraph.stat_calculators.cross_encoder_similarity import CrossEncoderSimilarityMatrixCalculator
from lm_polygraph.stat_calculators.semantic_matrix import SemanticMatrixCalculator
from lm_polygraph.stat_calculators.semantic_classes import SemanticClassesCalculator

from lm_polygraph.estimators import MaximumSequenceProbability, ClaimConditionedProbability, DegMat, SemanticEntropy, SAR

from lm_polygraph.utils.deberta import Deberta

from torch.utils.data import DataLoader

# Wrap the vLLM engine so lm_polygraph calculators and estimators can use it.
model_adapter = WhiteboxModelvLLM(llm, device)

# Stat calculators; the inference loop below calls them in this order.
calc_infer_llm = InfervLLMCalculator()
# One DeBERTa NLI model, shared by the NLI-based calculators.
nli_model = Deberta(device=device)
nli_model.setup()
calc_nli = GreedyAlternativesNLICalculator(nli_model=nli_model)

calc_samples = SamplingGenerationCalculator()
calc_cross_encoder = CrossEncoderSimilarityMatrixCalculator()
calc_semantic_matrix = SemanticMatrixCalculator(nli_model=nli_model)
calc_semantic_classes = SemanticClassesCalculator()

# Uncertainty estimators evaluated on the statistics collected per batch.
estimators = [MaximumSequenceProbability(), 
              ClaimConditionedProbability(),
              DegMat(), 
              SemanticEntropy(), 
              SAR()]

In [None]:
# The "greedy_*" statistics need deterministic decoding, while `sampling_params`
# keeps vLLM's default temperature (1.0) for the sampling-based estimators.
# Build a temperature-0 counterpart with the same length / logprobs settings.
greedy_sampling_params = SamplingParams(
    max_tokens=sampling_params.max_tokens,
    logprobs=sampling_params.logprobs,
    temperature=0.0,
)

data_loader = DataLoader(chat_messages, batch_size=batch_size, shuffle=False, collate_fn=lambda x: x)
for batch in data_loader:
    deps = {"input_texts": batch}
    deps.update(calc_infer_llm(
        deps, texts=batch, model=model_adapter, sampling_params=greedy_sampling_params))
    deps.update(calc_nli(deps, texts=batch, model=model_adapter))
    deps.update(calc_samples(deps, texts=batch, model=model_adapter, sampling_params=sampling_params))
    deps.update(calc_cross_encoder(deps, texts=batch, model=model_adapter))
    deps.update(calc_semantic_matrix(deps, texts=batch, model=model_adapter))
    deps.update(calc_semantic_classes(deps, texts=batch, model=model_adapter))

    # greedy_tokens keep the final EOS token; skip special tokens when printing.
    generated_texts = tokenizer.batch_decode(deps["greedy_tokens"], skip_special_tokens=True)
    ues = [(str(estimator), estimator(deps)) for estimator in estimators]

    for i, text in enumerate(generated_texts):
        print("Output:", text)
        for estimator_name, scores in ues:
            print(f"Uncertainty score by {estimator_name}: {scores[i]}")
        print()