In [64]:

!pip install -q transformers datasets evaluate seqeval tqdm torch


In [65]:
from numbers import Integral

import evaluate, torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [66]:
model_ckpt = "dslim/bert-base-NER"

# Tokenizer and model come from the same checkpoint so their vocabularies agree.
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

model = AutoModelForTokenClassification.from_pretrained(model_ckpt)
model = model.to(device)
model.eval()  # inference mode (e.g. dropout off) for deterministic predictions

# The model's own id -> label mapping. Predicted ids must be decoded with this,
# not with the dataset's label list, whose ordering may differ.
model_id2label = model.config.id2label


Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [70]:
dataset = load_dataset("conll2003", split="test", trust_remote_code=True)

# Gold tags are stored as integer ids; decode them with the dataset's own names.
dataset_label_list = dataset.features["ner_tags"].feature.names
dataset_id2label  = dict(enumerate(dataset_label_list))

# Predictions are decoded with the model's map and gold tags with the dataset's map.
# The two id orderings may differ, which is fine, because seqeval compares label
# *strings*. The string vocabularies must agree, however, or mismatched tags would
# be silently scored as errors.
assert set(model_id2label.values()) == set(dataset_label_list), (
    f"label vocabulary mismatch: model={sorted(model_id2label.values())} "
    f"dataset={sorted(dataset_label_list)}"
)


In [72]:
# seqeval scores tag sequences at the entity level (per-type and overall
# precision / recall / F1, plus token accuracy), given lists of label strings.
metric = evaluate.load("seqeval")

def predict_labels(tokens):
    """Return one predicted label per input word.

    Each word takes the model's prediction for its *first* sub-word token.
    Predictions for later sub-words of the same word are ignored.

    Args:
        tokens: list of pre-split words (e.g. ``ex["tokens"]`` from CoNLL-2003).

    Returns:
        list[str] with exactly ``len(tokens)`` entries, decoded with the model's
        own ``model_id2label`` map. A word receives ``"O"`` in two cases: it
        produces no sub-word tokens (e.g. the tokenizer strips its characters),
        or it falls past the truncation limit. This keeps the output word-aligned
        with the gold labels.
    """
    if not tokens:
        return []

    enc = tokenizer(tokens,
                    is_split_into_words=True,
                    truncation=True,   # cap at the tokenizer's model_max_length
                    return_tensors="pt")

    with torch.no_grad():
        logits = model(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device)
        ).logits[0]

    preds    = logits.argmax(dim=-1).tolist()
    word_ids = enc.word_ids(batch_index=0)

    # Fill by word index so that words with no sub-word tokens cannot shift
    # later labels out of alignment.
    labels = ["O"] * len(tokens)
    labelled = set()
    for pred_id, word_id in zip(preds, word_ids):
        if word_id is None or word_id in labelled:
            continue   # special token ([CLS]/[SEP]) or a continuation sub-word
        labels[word_id] = model_id2label[pred_id]     # <- use model map
        labelled.add(word_id)
    return labels


In [74]:
all_preds, all_gold = [], []

# Decode gold tag ids with the dataset's names and predict with the model.
# seqeval then compares the two sequences label string by label string.
for example in tqdm(dataset, desc="Evaluating"):
    words = example["tokens"]
    gold_labels = [dataset_id2label[tag_id] for tag_id in example["ner_tags"]]

    pred_labels = predict_labels(words)

    # Both sequences must be word-aligned for seqeval to score them.
    if len(gold_labels) != len(pred_labels):
        raise ValueError("Length mismatch (alignment error)")

    all_preds.append(pred_labels)
    all_gold.append(gold_labels)

results = metric.compute(references=all_gold, predictions=all_preds)
results


Evaluating:   0%|          | 0/3453 [00:00<?, ?it/s]

{'LOC': {'precision': 0.9320505111244738,
  'recall': 0.9292565947242206,
  'f1': 0.9306514560192134,
  'number': 1668},
 'MISC': {'precision': 0.7819650067294751,
  'recall': 0.8276353276353277,
  'f1': 0.8041522491349481,
  'number': 702},
 'ORG': {'precision': 0.8879107981220657,
  'recall': 0.9108970499698976,
  'f1': 0.89925705794948,
  'number': 1661},
 'PER': {'precision': 0.9573283858998145,
  'recall': 0.9573283858998145,
  'f1': 0.9573283858998145,
  'number': 1617},
 'overall_precision': 0.9065828531517374,
 'overall_recall': 0.9192634560906515,
 'overall_f1': 0.9128791208791208,
 'overall_accuracy': 0.9825347259610208}

In [75]:
def _format_metric(value):
    """Render an integer count (e.g. seqeval's ``number``) without decimals and
    every other score with four decimal places."""
    # numbers.Integral also matches NumPy integer scalars, which seqeval may return.
    return f"{value:d}" if isinstance(value, Integral) else f"{value:.4f}"


# Per-entity results are nested dicts; overall_* scores are top-level scalars.
for name, value in results.items():
    if isinstance(value, dict):
        print(f"{name}:")
        for sub_name, sub_value in value.items():
            print(f"  {sub_name:<15} {_format_metric(sub_value)}")
    else:
        print(f"{name:<20} {_format_metric(value)}")


LOC:
  precision       0.9321
  recall          0.9293
  f1              0.9307
  number          1668.0000
MISC:
  precision       0.7820
  recall          0.8276
  f1              0.8042
  number          702.0000
ORG:
  precision       0.8879
  recall          0.9109
  f1              0.8993
  number          1661.0000
PER:
  precision       0.9573
  recall          0.9573
  f1              0.9573
  number          1617.0000
overall_precision    0.9066
overall_recall       0.9193
overall_f1           0.9129
overall_accuracy     0.9825


<class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>
