# Load dataset

In [1]:
from datasets import load_dataset

raw_datasets = load_dataset("squad_v2")
dataset = raw_datasets["validation"]
dataset

  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 11873
})

In [2]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, dataset: Dataset, key1: str, key2: str):
        self.dataset = dataset
        self.key1 = key1
        self.key2 = key2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return {"question": self.dataset[i][self.key1], "context": self.dataset[i][self.key2]}

# Load model

In [3]:
import torch
from transformers import pipeline

model = "deepset/roberta-base-squad2"
pipe = pipeline("question-answering", model=model, 
                framework="pt", device="cuda", torch_dtype=torch.float16, # torch.float16 or torch.float32
                padding='max_length', # Will pad the sequences up to the model max length
                truncation=True) # Will truncate the sequences that are longer than the specified max length

# Inference

In [4]:
import time
from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset

for batch_size in [8]:
    start = time.time()
    for output in tqdm(pipe(CustomDataset(dataset, "question", "context"), batch_size=batch_size), total=len(dataset)):
        pass
    end = time.time()
    
    inference_time = end - start
    num_requests = len(dataset)
    print(f"Batch size: {batch_size}")
    print(f"Total inference time: {round(inference_time, 4)}s")
    print(f"Total sample: {num_requests}")
    print(f"Result: {round(num_requests / inference_time)} sample/s")
    print('---------------------------------------------------------')

100%|█████████████████████████████████████████████| 11873/11873 [00:37<00:00, 313.70it/s]

Batch size: 8
Total inference time: 37.8548s
Total sample: 11873
Result: 314 sample/s
---------------------------------------------------------



