# Fine-Tuning BERT models for NER

by Benjamin Kissinger & Andreas Sünder

## Install required packages (only once)

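```bash
pip install -r requirements.txt
```

To install from inside this notebook instead, run `%pip install -r requirements.txt` in a code cell. `%pip` is an IPython magic that installs into the kernel's environment; it is not a shell command.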

## Setup

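Open a terminal and log in to the Hugging Face Hub (needed to push the fine-tuned model) and to Weights & Biases (needed for `report_to='wandb'` during training):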

```bash
huggingface-cli login
wandb login
```

In [None]:
# Base checkpoint to fine-tune; the tokenizer and the model below are both loaded from it.
model_id = "distilbert-base-cased"

## Load dataset

In [None]:
from datasets import load_dataset
# Pre-tokenized NER dataset from the Hub. The preprocessing below reads its
# 'tokens' and 'ner_ids' columns; training uses its 'train' and 'validation' splits.
dataset = load_dataset(path='textminr/ner_tokenized')

## Process dataset

In [None]:
# Tag names indexed by integer class id (position i <-> id i), as used when
# decoding ids back to tag strings for seqeval.
label_list = [
  'O',
  'AUTHOR',
  'DATE',
]

In [None]:
from transformers import AutoTokenizer
# add_prefix_space=True is required by byte-level BPE tokenizers (e.g. RoBERTa)
# for pre-split input (is_split_into_words=True); presumably a no-op for
# DistilBERT's WordPiece tokenizer — kept so `model_id` can be swapped. TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(
  pretrained_model_name_or_path=model_id,
  add_prefix_space=True,
)

In [None]:
def tokenize_and_align_labels(row):
  """Tokenize a batch of pre-split sentences and align word-level NER ids to sub-tokens.

  Intended for ``Dataset.map(..., batched=True)``: despite its name, ``row`` is a
  batch dict whose ``'tokens'`` column holds one list of words per example and
  whose ``'ner_ids'`` column holds the matching list of integer label ids.

  Only the first sub-token of each word keeps that word's label. Special tokens
  (word id ``None``) and continuation sub-tokens get ``-100``. That value is
  filtered out by ``compute_metrics`` and, per the usual PyTorch/Transformers
  convention, ignored by the loss.

  Returns the tokenizer output with an added ``'labels'`` column (one list of
  ids per example, same length as that example's ``input_ids``).
  """
  tokenized_inputs = tokenizer(row['tokens'], truncation=True, is_split_into_words=True)

  labels = []
  for batch_index, word_labels in enumerate(row['ner_ids']):
    # Map each sub-token of this example back to the index of the word it came from.
    word_ids = tokenized_inputs.word_ids(batch_index=batch_index)
    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:
      if word_idx is None:
        # Special token (e.g. [CLS]/[SEP]): belongs to no word, so no label.
        label_ids.append(-100)
      elif word_idx != previous_word_idx:
        # First sub-token of a new word carries the word's label.
        label_ids.append(word_labels[word_idx])
      else:
        # Continuation sub-token of the same word: masked out.
        label_ids.append(-100)
      previous_word_idx = word_idx
    labels.append(label_ids)

  tokenized_inputs['labels'] = labels
  return tokenized_inputs

# Apply the tokenizer/label alignment to every split, one batch of examples per call.
tokenized_datasets = dataset.map(function=tokenize_and_align_labels, batched=True)

In [None]:
from transformers import DataCollatorForTokenClassification
# Pads input_ids and labels per batch; labels are padded with the collator's
# label_pad_token_id (documented default: -100), matching the -100 masking above.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [None]:
import evaluate
# Entity-level NER metric used by compute_metrics.
seqeval = evaluate.load(path='seqeval')

In [None]:
import numpy as np

def compute_metrics(p):
  """Compute seqeval metrics for ``Trainer`` evaluation.

  ``p`` is what the Trainer passes in: an ``EvalPrediction`` (or a plain tuple)
  whose first element is the per-token logits and whose second is the label ids.
  The logits are argmax-ed over the last axis; presumably they are shaped
  (batch, seq_len, num_labels) — TODO confirm.

  Positions whose label is ``-100`` are dropped from both sequences before
  scoring. These are special tokens, continuation sub-tokens and label padding.

  Returns a dict with overall ``precision``, ``recall``, ``f1`` and ``accuracy``.
  """
  # Index rather than unpack: EvalPrediction may also carry inputs/losses
  # (e.g. include_inputs_for_metrics=True), which would break 2-way unpacking.
  logits, label_ids = p[0], p[1]
  predicted_ids = np.argmax(logits, axis=-1)

  true_predictions = []
  true_labels = []
  for pred_row, label_row in zip(predicted_ids, label_ids):
    # Keep only positions that carry a real word-level label.
    kept = [(pred, gold) for pred, gold in zip(pred_row, label_row) if gold != -100]
    true_predictions.append([label_list[pred] for pred, _ in kept])
    true_labels.append([label_list[gold] for _, gold in kept])

  results = seqeval.compute(predictions=true_predictions, references=true_labels)
  return {
    'precision': results['overall_precision'],
    'recall': results['overall_recall'],
    'f1': results['overall_f1'],
    'accuracy': results['overall_accuracy'],
  }

## Train model

In [None]:
# Label maps for the model config, derived from `label_list` so the ids stored in
# the model always agree with the ids compute_metrics decodes via `label_list`.
id2tag = dict(enumerate(label_list))

tag2id = {tag: idx for idx, tag in id2tag.items()}

In [None]:
from transformers import AutoModelForTokenClassification
import torch

# Label metadata stored in the model config. The token-classification head is
# sized by num_labels (presumably newly initialized, as the base checkpoint has
# no such head — TODO confirm from the load warning).
label_config = dict(
  num_labels=len(label_list),
  id2label=id2tag,
  label2id=tag2id,
)
model = AutoModelForTokenClassification.from_pretrained(model_id, **label_config)

In [None]:
from transformers import TrainingArguments, Trainer
from datetime import datetime

import os

# Derive the W&B project / output folder from the base checkpoint so they follow
# `model_id` when it changes (last path component drops any "org/" prefix).
PROJECT_NAME = f"ner_{model_id.split('/')[-1]}"
# Plain-Python equivalent of `%env WANDB_PROJECT=...`; the Trainer's W&B
# integration reads this variable when the run is initialized.
os.environ['WANDB_PROJECT'] = PROJECT_NAME

training_args = TrainingArguments(
  output_dir=f'models/{PROJECT_NAME}',
  fp16=False,
  bf16=False,
  learning_rate=2e-5,
  # Retries with a smaller batch size on CUDA out-of-memory (needs `accelerate`).
  auto_find_batch_size=True,
  num_train_epochs=1,
  logging_strategy='steps',
  logging_steps=200,
  # NOTE(review): newer transformers releases renamed this to `eval_strategy`
  # (the old name was deprecated, then removed) — switch if the pinned version rejects it.
  evaluation_strategy='steps',
  eval_steps=200,
  report_to='wandb',
  # No intermediate checkpoints; the final model is pushed to the Hub after training.
  save_strategy='no',
  run_name=f'{PROJECT_NAME}-{datetime.now().strftime("%Y-%m-%d-%H-%M")}'
)

trainer = Trainer(
  model=model,
  args=training_args,
  train_dataset=tokenized_datasets['train'],
  eval_dataset=tokenized_datasets['validation'],
  # NOTE(review): newer transformers releases prefer `processing_class=`.
  tokenizer=tokenizer,
  data_collator=data_collator,
  compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Upload the fine-tuned weights/config and the tokenizer to the same Hub repo, so
# the checkpoint is self-contained and can be loaded without naming a tokenizer.
HUB_REPO_ID = 'textminr/ner_distil-bert'
model.push_to_hub(HUB_REPO_ID)
tokenizer.push_to_hub(HUB_REPO_ID)

## Inference

In [None]:
# Fine-tuned checkpoint on the Hub (the repo the training section pushes to).
# NOTE: this rebinds `model_id`, which the earlier cells use for the *base*
# checkpoint — re-run the setup cell before re-running tokenization/training.
model_id = "textminr/ner_distil-bert"

In [None]:
from transformers import pipeline
classifier = pipeline(
  task='ner',
  model=model_id,
  # Tokenizer of the base checkpoint the model was fine-tuned from.
  tokenizer='distilbert-base-cased',
  # Group adjacent sub-token predictions into entity spans.
  aggregation_strategy='simple',
)

In [None]:
# Probe sentence for the classifier. An alternative probe that was also tried:
#   "His book, written in 2013, mentions his Project and helps young people."
sentence = (
  "Captivated by the medieval tapestry, the author Albert Einstein "
  "transcribed the wisdom of 'Ink and Parchment' by Eleanor the Wise "
  "in the year 1268"
)

In [None]:
# Run NER on the probe sentence; the bare name on the last line renders the
# predicted entity groups as the cell output.
entities = classifier(sentence)
entities