In [1]:
from transformers import BertTokenizerFast, BertForTokenClassification
import torch

###### Define the label-id-to-name mapping and the directory of the trained model, then load its tokenizer and token-classification model

In [56]:
# Tag names in label-id order: id 0 -> 'B-MOUNTAIN', 1 -> 'I-MOUNTAIN', 2 -> 'O'.
id2label = dict(enumerate(('B-MOUNTAIN', 'I-MOUNTAIN', 'O')))
# Directory the fine-tuned checkpoint was saved to; both the tokenizer and the
# token-classification model are loaded from this same location.
model_dir = 'model_save/'

tokenizer = BertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path=model_dir,
)
model = BertForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=model_dir,
)

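Predict named entities in the given text using the loaded model. The result contains one `(token, label)` pair per WordPiece token, including the special `[CLS]` and `[SEP]` tokens. A word split into several sub-tokens (e.g. `climb`, `##ers`) therefore receives one label per piece.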

In [57]:
def predict(text, max_length=512):
    """Tag every token of ``text`` with a label name from ``id2label``.

    Tokenizes ``text`` with the notebook-level ``tokenizer``, runs the
    notebook-level ``model`` without gradient tracking, and takes the
    arg-max label at each token position.

    Args:
        text: Input string to tag.
        max_length: Maximum number of tokens (special tokens included) passed
            to the model. Longer input is truncated and the dropped tail is
            not tagged. Defaults to 512, the previously hard-coded value.

    Returns:
        A list of ``(token, label_name)`` tuples, one per token position,
        including special tokens such as ``[CLS]`` and ``[SEP]``.
    """
    # Move the encoded tensors onto the model's device so this keeps working
    # after e.g. ``model.to("cuda")``; a no-op while everything is on CPU.
    tokenized_input = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(model.device)

    with torch.no_grad():
        outputs = model(**tokenized_input)

    # Only one text is encoded, so row 0 is the sole sequence in the batch.
    label_ids = outputs.logits.argmax(dim=-1)[0].tolist()
    token_ids = tokenized_input["input_ids"][0].tolist()
    # Map ids straight to wordpiece strings (e.g. '##ers') in one call instead
    # of running the full-sequence decoder on each id separately.
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    return [
        (token, id2label[label_id])
        for token, label_id in zip(tokens, label_ids)
    ]


In [58]:
# Sample sentence fed to `predict` in the next cell.
text = (
    "Alps is the tallest mountain in the world, "
    "attracting climbers from all over the globe."
)

In [59]:
# Run the tagger on the sample sentence; yields (token, label) pairs.
token_label_pairs = predict(text=text)

In [60]:
# Leave the list as the cell's last expression so Jupyter renders it as the
# cell output (IPython's pretty-printer wraps long lists) instead of print().
token_label_pairs

[('[CLS]', 'O'), ('alps', 'B-MOUNTAIN'), ('is', 'O'), ('the', 'O'), ('tallest', 'O'), ('mountain', 'O'), ('in', 'O'), ('the', 'O'), ('world', 'O'), (',', 'O'), ('attracting', 'O'), ('climb', 'O'), ('##ers', 'O'), ('from', 'O'), ('all', 'O'), ('over', 'O'), ('the', 'O'), ('globe', 'O'), ('.', 'O'), ('[SEP]', 'O')]
