In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_polygraph.estimators import *
from lm_polygraph.utils.model import WhiteboxModel
from lm_polygraph.utils.dataset import Dataset
from lm_polygraph.utils.processor import Logger
from lm_polygraph.utils.manager import UEManager
from lm_polygraph.ue_metrics import PredictionRejectionArea
from lm_polygraph.generation_metrics import RougeMetric, BartScoreSeqMetric, ModelScoreSeqMetric, ModelScoreTokenwiseMetric, AggregatedMetric

# Specify Hyperparameters

In [None]:
# Checkpoint identifier handed to both `from_pretrained` calls when the model
# and tokenizer are loaded.
model_path: str = "bigscience/bloomz-560m"
# Forwarded as `device_map` when the model weights are loaded.
device: str = "cpu"

# First argument to `Dataset.load`; presumably (dataset, config) -- TODO confirm.
dataset_name: tuple = ("trivia_qa", "rc.nocontext")
# Batch size given to `Dataset.load` for both splits.
batch_size: int = 4

# Seed passed to `Dataset.subsample` for both splits.
seed: int = 42

# Initialize Model

In [None]:
# Load the causal LM from the configured checkpoint, passing the configured
# device as `device_map`.
base_model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
# The tokenizer comes from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Bundle model and tokenizer into the wrapper object that is later handed to UEManager.
model = WhiteboxModel(base_model, tokenizer)

# Train and Eval Datasets

In [None]:
# Keyword arguments shared by both loads; only `split` differs between them.
common_load_kwargs = dict(
    batch_size=batch_size,
    prompt="Question: {question}\nAnswer:{answer}",
)

# Evaluation data: use validation split, since test split of trivia_qa
# doesn't have reference answers.
dataset = Dataset.load(
    dataset_name, "question", "answer", **common_load_kwargs, split="validation"
)
# Keep the notebook quick: subsample to 16 examples using the configured seed.
dataset.subsample(16, seed=seed)

# Training data: same columns and prompt, taken from the train split; it is
# later given to UEManager as `train_data`.
train_dataset = Dataset.load(
    dataset_name, "question", "answer", **common_load_kwargs, split="train"
)
train_dataset.subsample(16, seed=seed)

# Metric, UE Metric, and UE Methods

In [None]:
# Uncertainty estimators to evaluate, one instance per method.
# NOTE(review): the "decoder" argument presumably selects which embeddings the
# Mahalanobis estimator uses -- confirm against lm_polygraph.
ue_methods = [
    MaximumSequenceProbability(),
    SemanticEntropy(),
    MahalanobisDistanceSeq("decoder"),
]

# Metrics used to score the uncertainty estimates themselves.
ue_metrics = [PredictionRejectionArea()]

# Wrap generation metric in AggregatedMetric, since trivia_qa is a multi-reference dataset
# (y is a list of possible correct answers)
rouge_l = RougeMetric('rougeL')
metrics = [AggregatedMetric(rouge_l)]

# Loggers handed to UEManager.
loggers = [Logger()]

# Initialize UE Manager

In [None]:
# Positional order: evaluation data, model, estimators, generation metrics,
# UE metrics, loggers. The training subset is supplied by keyword.
man = UEManager(dataset, model, ue_methods, metrics, ue_metrics, loggers, train_data=train_dataset)

# Compute Results

In [None]:
# Invoke the manager. The returned mapping is keyed by tuples, which the
# following cell indexes positionally when printing the scores.
results = man()

In [None]:
# Print one line per result entry. Positions 1-3 of each key hold the
# estimator name, the generation metric and the UE metric; position 0 is unused here.
for key, score in results.items():
    ue_name, metric_name, ue_metric_name = key[1], key[2], key[3]
    print(f"UE Score: {ue_name}, Metric: {metric_name}, UE Metric: {ue_metric_name}, Score: {score:.3f}")