# Load dataset

In [5]:
from datasets import load_dataset

# Load SQuAD v2 through the `datasets` library; the result is indexed by split name below.
raw_datasets = load_dataset(path="squad_v2")

# Only the validation split is used in the rest of the notebook.
dataset = raw_datasets["validation"]

# Bare last expression so the notebook renders the split's summary.
dataset

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 11873
})

In [8]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset exposing two fields of each row as question-answering inputs.

    Every item is a dict with the fixed keys ``"question"`` and ``"context"``;
    their values are read from the ``key1`` and ``key2`` fields of the
    corresponding row of the wrapped dataset.

    Args:
        dataset: Sized, integer-indexable collection whose rows support
            ``row[key]`` lookup.
        key1: Row field returned under ``"question"``.
        key2: Row field returned under ``"context"``.
    """

    def __init__(self, dataset: Dataset, key1: str, key2: str):
        self.dataset = dataset
        self.key1 = key1
        self.key2 = key2

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return ``{"question": ..., "context": ...}`` built from row ``i``."""
        # Read the row once and reuse it: indexing the wrapped dataset may be
        # costly, and both fields come from the same row.
        row = self.dataset[i]
        return {"question": row[self.key1], "context": row[self.key2]}

# Load model

In [6]:
import torch
from transformers import pipeline

model = "deepset/roberta-base-squad2"

# Use the GPU when one is visible; otherwise fall back to the CPU so this cell
# still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is kept for the GPU path only; the CPU fallback uses float32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32  # torch.float16 or torch.float32

pipe = pipeline(
    "question-answering",
    model=model,
    framework="pt",
    device=device,
    torch_dtype=torch_dtype,
    padding="max_length",  # Will pad the sequences up to the model max length
    # Intended to truncate sequences longer than the model max length.
    # NOTE(review): the question-answering pipeline appears to choose its own
    # truncation strategy internally, so this argument may be ignored - confirm.
    truncation=True,
)

Downloading (…)lve/main/config.json: 100%|█████| 571/571 [00:00<00:00, 78.3kB/s]
Downloading model.safetensors: 100%|██████████| 496M/496M [00:00<00:00, 633MB/s]
Downloading (…)okenizer_config.json: 100%|███| 79.0/79.0 [00:00<00:00, 8.36kB/s]
Downloading (…)olve/main/vocab.json: 100%|████| 899k/899k [00:00<00:00, 931kB/s]
Downloading (…)olve/main/merges.txt: 100%|████| 456k/456k [00:00<00:00, 634kB/s]
Downloading (…)cial_tokens_map.json: 100%|██████| 772/772 [00:00<00:00, 380kB/s]


# Inference

In [12]:
import time
from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset

# Loop-invariant setup: the wrapped inputs and the sample count do not depend
# on the batch size, so build them once outside the timed region.
qa_inputs = CustomDataset(dataset, "question", "context")
num_requests = len(dataset)

for batch_size in [8]:
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    for output in tqdm(pipe(qa_inputs, batch_size=batch_size), total=num_requests):
        pass  # results are discarded; the loop only consumes the pipeline output
    end = time.perf_counter()

    inference_time = end - start
    print(f"Batch size: {batch_size}")
    print(f"Total inference time: {round(inference_time, 4)}s")
    print(f"Total sample: {num_requests}")
    print(f"Result: {round(num_requests / inference_time)} sample/s")
    print('---------------------------------------------------------')

100%|████████████████████████████████████| 11873/11873 [00:36<00:00, 321.40it/s]

Batch size: 8
Total inference time: 36.9465s
Total sample: 11873
Result: 321 sample/s
---------------------------------------------------------



