In [1]:
import torch
from datasets import load_dataset
from evaluate import evaluator
from pprint import pprint

# Set device
def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')  # For Apple Silicon
    else:
        return torch.device('cpu')

device = get_device()
print("Device:", device)

Device: cuda


In [2]:
test_size = 1000
data = load_dataset('squad', split=f'validation[:{test_size}]')

# model = 'distilbert-base-uncased-distilled-squad'
model = 'unsloth/Meta-Llama-3.1-8B'

task_evaluator = evaluator('question-answering')

eval_results = task_evaluator.compute(
    model_or_pipeline=model,
    data=data,
    metric='squad',
    strategy='bootstrap',
    n_resamples=30,
    squad_v2_format=False,  # Whether the dataset follows the format of squad_v2 dataset, where a question may have no answer in the context
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForQuestionAnswering were not initialized from the model checkpoint at unsloth/Meta-Llama-3.1-8B and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight', 'transformer.embed_tokens.weight', 'transformer.layers.0.input_layernorm.weight', 'transformer.layers.0.mlp.down_proj.weight', 'transformer.layers.0.mlp.gate_proj.weight', 'transformer.layers.0.mlp.up_proj.weight', 'transformer.layers.0.post_attention_layernorm.weight', 'transformer.layers.0.self_attn.k_proj.weight', 'transformer.layers.0.self_attn.o_proj.weight', 'transformer.layers.0.self_attn.q_proj.weight', 'transformer.layers.0.self_attn.v_proj.weight', 'transformer.layers.1.input_layernorm.weight', 'transformer.layers.1.mlp.down_proj.weight', 'transformer.layers.1.mlp.gate_proj.weight', 'transformer.layers.1.mlp.up_proj.weight', 'transformer.layers.1.post_attention_layernorm.weight', 'transformer.layers.1.self_attn.k_proj.weight', 'transformer.layers.1.self_attn.o_proj.weight', 'transformer.layers

In [3]:
pprint(eval_results, width=50, sort_dicts=False)

{'exact_match': {'confidence_interval': (0.1,
                                         0.7),
                 'standard_error': 0.16832508230603463,
                 'score': 0.3},
 'f1': {'confidence_interval': (3.4574312388097836,
                                4.913074192513453),
        'standard_error': 0.4179175251449614,
        'score': 4.172294649794647},
 'total_time_in_seconds': 52.72567349579185,
 'samples_per_second': 18.966092487748156,
 'latency_in_seconds': 0.05272567349579186}
