# **Token Classification**
Token classification is a natural language understanding task in which a label is assigned to each individual token in a sentence. Popular token classification tasks include Named Entity Recognition (NER) and Part-of-Speech (POS) tagging.

We will fine-tune a BERT model with TensorFlow on an NER task, using the CoNLL-2003 dataset, which contains Reuters news stories.

### **1. Install and Import Required Libraries**

In [None]:
# Hugging Face libraries used below; seqeval backs evaluate.load('seqeval') in section 6.
!pip install datasets transformers evaluate seqeval

In [None]:
import tensorflow as tf
import numpy as np
import evaluate

from transformers import AutoTokenizer, DataCollatorForTokenClassification, TFAutoModelForTokenClassification, create_optimizer, pipeline
from datasets import load_dataset

### **2. Load Data**

In [None]:
# Download CoNLL-2003 from the Hugging Face Hub; it comes with train/validation/test splits.
raw_dataset = load_dataset(path='conll2003')

In [None]:
# Display the splits, column names and row counts of the loaded DatasetDict.
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

In [None]:
# The ner_tags column stores integer ids; `.feature.names` maps each id to its tag string.
ner_tag_feature = raw_dataset['train'].features['ner_tags']
label_names = ner_tag_feature.feature.names
label_names

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

### **3. Preprocess Data**

In [None]:
# A single checkpoint name is shared by this tokenizer and by the model in section 4.
model_checkpoint = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
def align_labels_with_tokens(labels, word_ids):
  """Expand word-level NER tag ids to one label per tokenizer token.

  Args:
    labels: Integer tag ids, one per word (indices into ``label_names``).
    word_ids: For each token, the index of the word it came from, or ``None``
      for special tokens (as returned by ``BatchEncoding.word_ids``).

  Returns:
    A list with one entry per token:
      * ``-100`` for tokens without a word (the evaluation loop skips these);
      * the word's own tag id for the first token of a word;
      * the matching ``I-XXXX`` id for any later token of the same word.
  """
  new_labels = []
  # Track the previous token's word id explicitly. Indexing word_ids[i - 1]
  # would wrap to the last element when i == 0.
  previous_word_id = None
  for word_id in word_ids:
    if word_id is None:
      new_labels.append(-100)
    elif word_id == previous_word_id:
      # Continuation token. In label_names, B-XXXX ids are odd and
      # I-XXXX == B-XXXX + 1, so adding (label % 2) maps B-XXXX -> I-XXXX
      # and leaves O (0) and I-XXXX (even) unchanged.
      label = labels[word_id]
      new_labels.append(label + label % 2)
    else:
      new_labels.append(labels[word_id])
    previous_word_id = word_id
  return new_labels

In [None]:
def tokenize_and_align_labels(examples):
  """Tokenize a batch of pre-split sentences and attach token-level NER labels.

  Meant for ``Dataset.map(..., batched=True)``. ``examples['tokens']`` holds one
  word list per sentence, and ``examples['ner_tags']`` holds the matching
  word-level tag ids. The returned encoding gains a ``'labels'`` entry with one
  label list per sentence, aligned to the tokens by ``align_labels_with_tokens``.
  """
  encoding = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)
  encoding['labels'] = [
      align_labels_with_tokens(word_labels, encoding.word_ids(sentence_idx))
      for sentence_idx, word_labels in enumerate(examples['ner_tags'])
  ]
  return encoding

In [None]:
# Swap the raw word-level columns for token-level model inputs plus aligned labels.
columns_to_remove = raw_dataset['train'].column_names
tokenized_dataset = raw_dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=columns_to_remove,
)

In [None]:
# Same splits and row counts as raw_dataset, now holding token-level inputs and labels.
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 3453
    })
})

In [None]:
batch_size = 16
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors='tf')


def _to_tf_split(split_name, shuffle):
  """Wrap one split of ``tokenized_dataset`` as a batched ``tf.data.Dataset``.

  Each batch is assembled by ``data_collator`` and carries the model inputs
  together with the ``labels`` column.
  """
  return tokenized_dataset[split_name].to_tf_dataset(
      columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
      collate_fn=data_collator,
      shuffle=shuffle,
      batch_size=batch_size,
  )


# Only the training split is shuffled; validation and test keep their original order.
tf_train_dataset = _to_tf_split('train', shuffle=True)
tf_validation_dataset = _to_tf_split('validation', shuffle=False)
tf_test_dataset = _to_tf_split('test', shuffle=False)

### **4. Define Model**

In [None]:
# Index <-> tag-name mappings handed to the model below, so outputs can be reported by name.
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
# Pretrained BERT body with a token-classification head sized from the tag mappings.
model = TFAutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# Sanity check: should equal len(label_names), i.e. one output per NER tag.
model.config.num_labels

9

### **5. Fine-tune the Model**

In [None]:
# Training in mixed-precision float16.
# NOTE(review): Keras applies the global dtype policy to layers created *after*
# it is set. `model` was already built above, so this call may have no effect on
# this model. TODO confirm; it likely needs to run before `from_pretrained`.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

num_epochs = 5
# One optimizer step per training batch, so the schedule covers every batch of every epoch.
steps_per_epoch = len(tf_train_dataset)
num_train_steps = steps_per_epoch * num_epochs

# Returns (optimizer, learning-rate schedule); the schedule is not used further here.
optimizer, schedule = create_optimizer(
    init_lr=2e-5,
    num_warmup_steps=0,
    num_train_steps=num_train_steps,
    weight_decay_rate=0.01,
)

# No loss is passed. Presumably the model computes its own loss from the 'labels'
# entry in each batch (TODO confirm).
model.compile(optimizer=optimizer, metrics=['accuracy'])
history = model.fit(
    tf_train_dataset,
    validation_data=tf_validation_dataset,
    epochs=num_epochs,
    verbose=1,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


### **6. Compute Metrics**

In [None]:
# Entity-level precision/recall/F1 over IOB tag sequences.
metric = evaluate.load(path='seqeval')

In [None]:
# Evaluate the fine-tuned model on the test split with seqeval.
# Tags are collected as ONE LIST PER SENTENCE. seqeval scores whole entities, and
# concatenating all sentences into a single sequence would let an entity at the
# end of one sentence merge with an I-XXXX tag at the start of the next.
all_predictions = list()
all_labels = list()

for batch in tf_test_dataset:
  logits = model.predict_on_batch(batch)['logits']
  predictions = np.argmax(logits, axis=-1)
  # Convert once to NumPy so masking and indexing below use plain integers,
  # not per-element eager-tensor ops.
  labels = np.asarray(batch['labels'])

  for prediction, label in zip(predictions, labels):
    # -100 marks tokens that carry no word label (see align_labels_with_tokens);
    # presumably collator padding uses it as well (TODO confirm). Drop them all.
    keep = label != -100
    all_predictions.append([label_names[tag_id] for tag_id in prediction[keep]])
    all_labels.append([label_names[tag_id] for tag_id in label[keep]])

metric.compute(predictions=all_predictions, references=all_labels)

{'LOC': {'precision': 0.9061771561771562,
  'recall': 0.9322541966426858,
  'f1': 0.9190307328605201,
  'number': 1668},
 'MISC': {'precision': 0.7234314980793854,
  'recall': 0.8048433048433048,
  'f1': 0.7619689817936615,
  'number': 702},
 'ORG': {'precision': 0.8567351598173516,
  'recall': 0.9036724864539434,
  'f1': 0.8795780837972459,
  'number': 1661},
 'PER': {'precision': 0.9451553930530164,
  'recall': 0.9591836734693877,
  'f1': 0.9521178637200736,
  'number': 1617},
 'overall_precision': 0.8780984719864177,
 'overall_recall': 0.9157223796033994,
 'overall_f1': 0.8965158606344253,
 'overall_accuracy': 0.9716361229731646}

### **7. Predict using the Fine-tuned Model**

In [None]:
# 'simple' aggregation merges consecutive tokens of one entity into a single span
# (e.g. 'Opel AG' in the example output).
token_classifier = pipeline(
    task='token-classification',
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy='simple',
)

In [None]:
# Each detected entity comes back with its group, score, text and start/end character offsets.
token_classifier('Opel AG together with General Motors came in second place.')

[{'entity_group': 'ORG',
  'score': 0.99919504,
  'word': 'Opel AG',
  'start': 0,
  'end': 7},
 {'entity_group': 'ORG',
  'score': 0.9992789,
  'word': 'General Motors',
  'start': 22,
  'end': 36}]