# Load dataset

In [1]:
from datasets import load_dataset

# Dataset identifier and split are named so they are easy to spot and swap.
DATASET_NAME = "conll2003"
EVAL_SPLIT = "test"

# Load every split, then keep only the one used by the rest of the notebook.
raw_datasets = load_dataset(DATASET_NAME)
dataset = raw_datasets[EVAL_SPLIT]
dataset  # bare last expression so the notebook displays the split summary

  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
    num_rows: 3453
})

In [2]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset that exposes one field of each record as a string.

    Item ``i`` is the sequence stored under ``key1`` in record ``i`` of the
    wrapped dataset, joined with single spaces.

    NOTE(review): ``dataset`` is annotated with the torch ``Dataset`` class,
    but this notebook passes a Hugging Face split; any indexable collection
    of mapping-like records with a ``len`` appears to work.
    """

    def __init__(self, dataset: Dataset, key1: str):
        """Remember the wrapped dataset and the field to read from it.

        Args:
            dataset: Indexable collection of records.
            key1: Key whose value (an iterable of strings) becomes the item text.
        """
        self.key1 = key1
        self.dataset = dataset

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return the ``key1`` field of record ``i`` joined with single spaces."""
        record = self.dataset[i]
        pieces = record[self.key1]
        return " ".join(pieces)

# Load model

In [3]:
import torch
from transformers import pipeline

model = "dslim/bert-base-NER-uncased"

# Use the GPU when one is visible; otherwise fall back to CPU so the notebook
# still runs end-to-end (Restart & Run All) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# On GPU: torch.float16 or torch.float32. On CPU we stay in float32, since
# half-precision kernels may be unsupported or slow there.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Token-classification (NER) pipeline on the PyTorch backend.
pipe = pipeline("token-classification", model=model,
                framework="pt", device=device, torch_dtype=torch_dtype)

Some weights of the model checkpoint at dslim/bert-base-NER-uncased were not used when initializing BertForTokenClassification: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Inference

In [4]:
import time
from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset

# Throughput benchmark: run the whole split through the pipeline once per
# batch size and report samples per second.

# Loop-invariant setup is hoisted so it is built once and stays out of the
# timed region.
text_dataset = CustomDataset(dataset, "tokens")
num_requests = len(dataset)

for batch_size in [8]:
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring elapsed time; time.time() can jump if the system clock changes.
    start = time.perf_counter()
    # Drain the pipeline's results; predictions are discarded because only the
    # elapsed time matters here. The loop variable keeps its original name so
    # the last prediction stays inspectable after the cell runs.
    for output in tqdm(pipe(text_dataset, batch_size=batch_size), total=num_requests):
        pass
    end = time.perf_counter()

    inference_time = end - start
    print(f"Batch size: {batch_size}")
    print(f"Total inference time: {round(inference_time, 4)}s")
    print(f"Total sample: {num_requests}")
    print(f"Result: {round(num_requests / inference_time)} sample/s")
    print('---------------------------------------------------------')

100%|███████████████████████████████████████████████| 3453/3453 [00:05<00:00, 595.19it/s]

Batch size: 8
Total inference time: 5.8117s
Total sample: 3453
Result: 594 sample/s
---------------------------------------------------------



