# Fine-Tuning BERT models for NER

by Benjamin Kissinger & Andreas Sünder

## Install required packages (only once)

```bash
%pip install -r requirements.txt
```

## Setup

Open up a terminal and run the following commands:

```bash
huggingface-cli login
wandb login
```

In [1]:
model_id = 'distilbert-base-cased'

## Load dataset

In [2]:
from datasets import load_dataset
dataset = load_dataset('textminr/ner_tokenized')

## Process dataset

In [3]:
label_list = ['O', 'AUTHOR', 'DATE']

In [4]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [5]:
def tokenize_and_align_labels(row):
  tokenized_inputs = tokenizer(row['tokens'], truncation=True, is_split_into_words=True)

  labels = []
  for i, label in enumerate(row[f'ner_ids']):
    word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:  # Set the special tokens to -100.
      if word_idx is None:
        label_ids.append(-100)
      elif word_idx != previous_word_idx:  # Only label the first token of a given word.
        label_ids.append(label[word_idx])
      else:
        label_ids.append(-100)
      previous_word_idx = word_idx
    labels.append(label_ids)

  tokenized_inputs['labels'] = labels
  return tokenized_inputs

tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)

In [6]:
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer)

In [7]:
import evaluate
seqeval = evaluate.load('seqeval')

In [8]:
import numpy as np

def compute_metrics(p):
  predictions, labels = p
  predictions = np.argmax(predictions, axis=2)

  true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
  ]
  true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
  ]

  results = seqeval.compute(predictions=true_predictions, references=true_labels)
  return {
    'precision': results['overall_precision'],
    'recall': results['overall_recall'],
    'f1': results['overall_f1'],
    'accuracy': results['overall_accuracy'],
  }

## Train model

In [9]:
id2tag= {
  0: 'O',
  1: 'AUTHOR',
  2: 'DATE',
}

tag2id = {v: k for k, v in id2tag.items()}

In [10]:
from transformers import AutoModelForTokenClassification
import torch

model = AutoModelForTokenClassification.from_pretrained(
  model_id,
  num_labels=len(label_list),
  id2label=id2tag,
  label2id=tag2id,
  # device_map='cuda:0',
)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
from transformers import TrainingArguments, Trainer

project_name = 'ner-distilbert-english'

training_args = TrainingArguments(
  output_dir=f'models/{project_name}',
  fp16=False,
  learning_rate=2e-5,
  auto_find_batch_size=True,
  num_train_epochs=1,
  evaluation_strategy='steps',
  eval_steps=200,
  save_strategy='steps',
  save_steps=1000,
  load_best_model_at_end=True,
)
  
trainer = Trainer(
  model=model,
  args=training_args,
  train_dataset=tokenized_datasets['train'],
  eval_dataset=tokenized_datasets['validation'],
  tokenizer=tokenizer,
  data_collator=data_collator,
  compute_metrics=compute_metrics,
)

In [12]:
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33masuender[0m ([33mtextminr[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/2025 [00:00<?, ?it/s]

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  0%|          | 0/507 [00:00<?, ?it/s]



{'eval_loss': 0.00324088241904974, 'eval_precision': 0.9988090512107979, 'eval_recall': 0.9913317572892041, 'eval_f1': 0.9950563575242239, 'eval_accuracy': 0.9991569772065724, 'eval_runtime': 3.917, 'eval_samples_per_second': 1033.692, 'eval_steps_per_second': 129.435, 'epoch': 0.1}


  0%|          | 0/507 [00:00<?, ?it/s]

{'eval_loss': 0.0015421499265357852, 'eval_precision': 0.9998020194020986, 'eval_recall': 0.9948778565799843, 'eval_f1': 0.9973338599782758, 'eval_accuracy': 0.9995546294676232, 'eval_runtime': 3.883, 'eval_samples_per_second': 1042.755, 'eval_steps_per_second': 130.57, 'epoch': 0.2}
{'loss': 0.0218, 'learning_rate': 1.506172839506173e-05, 'epoch': 0.25}


  0%|          | 0/507 [00:00<?, ?it/s]

{'eval_loss': 0.001125755487009883, 'eval_precision': 0.9990157480314961, 'eval_recall': 0.9998029944838456, 'eval_f1': 0.999409216226861, 'eval_accuracy': 0.9998886573669058, 'eval_runtime': 3.8631, 'eval_samples_per_second': 1048.12, 'eval_steps_per_second': 131.241, 'epoch': 0.3}


  0%|          | 0/507 [00:00<?, ?it/s]

{'eval_loss': 0.0005293559515848756, 'eval_precision': 0.9998029944838456, 'eval_recall': 0.9998029944838456, 'eval_f1': 0.9998029944838456, 'eval_accuracy': 0.9999522817286739, 'eval_runtime': 3.8978, 'eval_samples_per_second': 1038.801, 'eval_steps_per_second': 130.075, 'epoch': 0.4}
{'loss': 0.001, 'learning_rate': 1.0123456790123458e-05, 'epoch': 0.49}


  0%|          | 0/507 [00:00<?, ?it/s]

{'eval_loss': 0.00031827494967728853, 'eval_precision': 0.9996059889676912, 'eval_recall': 0.9996059889676912, 'eval_f1': 0.9996059889676912, 'eval_accuracy': 0.9999363756382319, 'eval_runtime': 3.9145, 'eval_samples_per_second': 1034.371, 'eval_steps_per_second': 129.52, 'epoch': 0.49}


  0%|          | 0/507 [00:00<?, ?it/s]



{'eval_loss': 0.000282159773632884, 'eval_precision': 0.9996060665747488, 'eval_recall': 0.9998029944838456, 'eval_f1': 0.9997045208312814, 'eval_accuracy': 0.9999522817286739, 'eval_runtime': 4.0818, 'eval_samples_per_second': 991.96, 'eval_steps_per_second': 124.209, 'epoch': 0.59}


  0%|          | 0/507 [00:00<?, ?it/s]

{'eval_loss': 0.00027990678790956736, 'eval_precision': 0.9998030332873744, 'eval_recall': 1.0, 'eval_f1': 0.9999015069437605, 'eval_accuracy': 0.9999681878191159, 'eval_runtime': 3.8802, 'eval_samples_per_second': 1043.505, 'eval_steps_per_second': 130.664, 'epoch': 0.69}
{'loss': 0.0003, 'learning_rate': 5.185185185185185e-06, 'epoch': 0.74}


  0%|          | 0/507 [00:00<?, ?it/s]

{'eval_loss': 0.0002756484318524599, 'eval_precision': 0.9998030332873744, 'eval_recall': 1.0, 'eval_f1': 0.9999015069437605, 'eval_accuracy': 0.9999681878191159, 'eval_runtime': 4.027, 'eval_samples_per_second': 1005.466, 'eval_steps_per_second': 125.9, 'epoch': 0.79}


  0%|          | 0/507 [00:00<?, ?it/s]

{'eval_loss': 0.00027747408603318036, 'eval_precision': 0.9998030332873744, 'eval_recall': 1.0, 'eval_f1': 0.9999015069437605, 'eval_accuracy': 0.9999681878191159, 'eval_runtime': 3.9227, 'eval_samples_per_second': 1032.189, 'eval_steps_per_second': 129.247, 'epoch': 0.89}
{'loss': 0.0003, 'learning_rate': 2.469135802469136e-07, 'epoch': 0.99}


  0%|          | 0/507 [00:00<?, ?it/s]

{'eval_loss': 0.00026903985417447984, 'eval_precision': 0.9998030332873744, 'eval_recall': 1.0, 'eval_f1': 0.9999015069437605, 'eval_accuracy': 0.9999681878191159, 'eval_runtime': 4.0134, 'eval_samples_per_second': 1008.871, 'eval_steps_per_second': 126.327, 'epoch': 0.99}
{'train_runtime': 117.641, 'train_samples_per_second': 137.656, 'train_steps_per_second': 17.213, 'train_loss': 0.00579406047096177, 'epoch': 1.0}


TrainOutput(global_step=2025, training_loss=0.00579406047096177, metrics={'train_runtime': 117.641, 'train_samples_per_second': 137.656, 'train_steps_per_second': 17.213, 'train_loss': 0.00579406047096177, 'epoch': 1.0})

## Inference

In [4]:
model_id = 'distilbert-base-cased'

In [8]:
from transformers import pipeline
classifier = pipeline(
  'ner',
  model='models/ner-distilbert-english/checkpoint-2000',
  tokenizer=model_id,
  aggregation_strategy='simple'
)

In [15]:
# sentence = "In his book 'The Whale' Herman Melville mentions the date 1851. He was born in 1819. His mother, Maria Gansevoort, died in 1832." 
sentence = "In his book 'The Whale' Herman Melville mentions the date 1851. He was born in 1819 and wrote 'The Whale' in 2020. His mother, Maria Gansevoort, died in 1832." 

In [16]:
classifier(sentence)

[{'entity_group': 'AUTHOR',
  'score': 0.99893075,
  'word': 'Herman Melville',
  'start': 24,
  'end': 39},
 {'entity_group': 'DATE',
  'score': 0.9672856,
  'word': '1851',
  'start': 58,
  'end': 62},
 {'entity_group': 'DATE',
  'score': 0.96138334,
  'word': '1819',
  'start': 79,
  'end': 83},
 {'entity_group': 'DATE',
  'score': 0.999453,
  'word': '2020',
  'start': 109,
  'end': 113},
 {'entity_group': 'AUTHOR',
  'score': 0.9973533,
  'word': 'Maria G',
  'start': 127,
  'end': 134},
 {'entity_group': 'DATE',
  'score': 0.99920195,
  'word': '1832',
  'start': 153,
  'end': 157}]