This is a shorter version of the token classification notebook prepared by Hugging Face.

Some parts have been intentionally removed to make the notebook easier to follow for educational purposes.

Source of the original notebook: 

https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/token_classification.ipynb#scrollTo=545PP3o8IrJV

# Token Classification

Token classification is the task of assigning a label to each part (word, subword, ...) of a text.

The best-known application is Named Entity Recognition (NER):
- [ "My", "name", "is", "Ahmet", "." ]
- [ "O", "O", "O", "PERSON", "O" ]

Named entity recognition finds the named entities in a text, such as persons, locations, and dates.

It is a type of token classification in which the classes are, for example, "PERSON", "LOC", and "DATE", plus "O" for tokens that are outside any entity.

## Data Exploration for Named Entity Recognition

In [1]:
! pip install datasets transformers

In [2]:
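# Token classification tasks covered by CoNLL-2003:
# named entity recognition (NER), part-of-speech tagging (POS) and chunking

# distilbert-base-uncased is DistilBERT, a smaller BERT model trained with knowledge distillation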

task = "ner" # Should be one of "ner", "pos" or "chunk"
model_checkpoint = "distilbert-base-uncased"
batch_size = 16

from datasets import load_dataset, load_metric
datasets = load_dataset("conll2003")

datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

In [3]:
datasets["train"][0]

{'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'id': '0',
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.']}

In [4]:
label_list = datasets["train"].features[f"{task}_tags"].feature.names
label_list

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
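The `B-` and `I-` prefixes follow the IOB2 scheme: `B-X` marks the first token of an entity of type X, `I-X` marks a following token inside the same entity, and `O` marks a token outside any entity. As a minimal sketch (the helper name `iob2_to_spans` is hypothetical; the label list and the first training example are copied from the outputs above), the tags can be grouped back into entity spans like this:

```python
# Label names and the first training example, copied from the outputs above.
label_list = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
tokens = ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
ner_tags = [3, 0, 7, 0, 0, 0, 7, 0, 0]


def iob2_to_spans(tokens, tags):
    """Group IOB2 tag strings into (entity_type, text) spans (hypothetical helper)."""
    spans = []
    current_type, current_tokens = None, []
    for token, tag in zip(tokens, tags):
        # A "B-" tag, or an "I-" tag whose type differs from the open entity, starts a new entity.
        starts_new = tag.startswith("B-") or (tag.startswith("I-") and tag[2:] != current_type)
        if tag == "O" or starts_new:
            # Close the currently open entity, if any.
            if current_type is not None:
                spans.append((current_type, " ".join(current_tokens)))
            current_type, current_tokens = (tag[2:], [token]) if starts_new else (None, [])
        else:
            # An "I-" tag that continues the open entity.
            current_tokens.append(token)
    if current_type is not None:
        spans.append((current_type, " ".join(current_tokens)))
    return spans


tag_names = [label_list[i] for i in ner_tags]
print(iob2_to_spans(tokens, tag_names))  # [('ORG', 'EU'), ('MISC', 'German'), ('MISC', 'British')]
```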

### Show random elements from dataset to understand it better

In [5]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [6]:
show_random_elements(datasets["train"])

|   | chunk_tags | id | ner_tags | pos_tags | tokens |
|---|---|---|---|---|---|
| 0 | [O, B-PP, B-ADVP, B-NP, B-VP, I-VP, B-PP, B-NP, O, B-PP, B-NP, B-NP, B-VP, I-VP, I-VP, B-NP, O, B-LST, B-NP, B-NP, I-NP, I-NP, I-NP, O, B-NP, B-VP, B-NP, O, O, B-NP, I-NP, O, B-NP, B-PP, B-NP, I-NP, I-NP, I-NP, I-NP, I-NP, B-VP, B-NP, I-NP, B-VP, B-PP, B-NP, I-NP, I-NP, O] | 11559 | [O, O, O, O, O, O, O, B-LOC, O, O, B-PER, O, O, O, O, B-PER, O, O, B-PER, O, O, B-PER, I-PER, O, B-PER, O, O, O, O, B-PER, I-PER, O, O, O, O, O, O, B-ORG, I-ORG, I-ORG, O, O, O, O, O, O, O, O, O] | [", IN, RB, PRP, MD, VB, TO, NNP, CC, IN, NNP, PRP, MD, RB, VB, NNP, CC, (, NNP, POS, NN, NNP, NNP, ), NNP, VB, NN, ,, ", NNP, NNP, ,, NN, IN, DT, JJ, NN, NNP, NNP, NNP, VBD, DT, NN, VBD, IN, DT, JJ, NN, .] | [", From, here, we, will, march, to, Islamabad, and, by, God, we, will, not, let, Benazir, and, (, Bhutto, 's, husband, Asif, Ali, ), Zardari, escape, justice, ,, ", Nawaz, Sharif, ,, leader, of, the, main, opposition, Pakistan, Muslim, League, told, a, rally, organised, by, a, 16-party, alliance, .] |
| 1 | [B-NP, I-NP] | 2533 | [B-LOC, O] | [NNS, CD] | [ATHENS, 1996-08-23] |
| 2 | [B-NP, I-NP, B-VP, I-VP, I-VP, B-SBAR, B-NP, I-NP, B-VP, B-PP, B-NP, I-NP, B-PP, B-NP, I-NP, I-NP, I-NP, B-VP, B-PP, B-NP, I-NP, B-PP, B-NP, I-NP, O] | 9072 | [O, O, O, O, O, O, O, B-MISC, O, O, O, O, O, O, O, O, O, O, O, B-MISC, O, O, O, O, O] | [DT, NNS, VBD, VBN, VBN, IN, DT, NNPS, VBD, IN, DT, NNS, TO, NN, NNS, CC, NN, VBD, IN, JJ, NNS, IN, DT, NN, .] | [The, talks, had, been, postponed, while, the, Russians, waited, for, the, rebels, to, return, arms, and, ammunition, seized, from, Russian, soldiers, at, the, weekend, .] |
| 3 | [B-NP, I-NP, I-NP, O, B-NP, O, B-NP, I-NP] | 9956 | [O, B-PER, I-PER, O, B-LOC, O, O, O] | [CD, NNP, NNP, (, NNPS, ), CD, NNS] | [1., Chandra, Sturrup, (, Bahamas, ), 11.34, seconds] |
| 4 | [B-NP, I-NP, I-NP, I-NP, I-NP, I-NP, I-NP, I-NP] | 10339 | [B-ORG, I-ORG, I-ORG, O, B-ORG, I-ORG, I-ORG, O] | [NNP, NNP, NNP, CD, NNP, NNP, NNP, CD] | [Hapoel, Kfar, Sava, 0, Hapoel, Zafririm, Holon, 1] |
| 5 | [B-PP, B-NP, I-NP, I-NP, O, B-NP, B-VP, B-NP, I-NP, B-PP, B-NP, B-NP, B-VP, B-ADJP, B-VP, I-VP, I-VP, I-VP, O, B-SBAR, B-VP, B-PP, B-NP, I-NP, I-NP, B-VP, B-PP, B-NP, I-NP, O] | 9907 | [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O] | [IN, DT, NN, NN, ,, NNS, VBP, DT, NN, IN, WDT, PRP, VBP, JJ, TO, VB, CC, VB, ,, IN, VBN, TO, DT, NN, NN, VBN, IN, VBG, NNS, .] | [In, a, limit, order, ,, investors, specify, the, price, at, which, they, are, willing, to, buy, or, sell, ,, as, opposed, to, a, market, order, executed, at, prevailing, prices, .] |
| 6 | [B-NP] | 402 | [O] | [CD] | [7-5] |
| 7 | [B-NP, I-NP, I-NP] | 760 | [B-MISC, I-MISC, I-MISC] | [NNP, NNP, NNP] | [Major, League, Baseball] |
| 8 | [B-NP, I-NP, I-NP, I-NP, O, B-NP, I-NP, O, B-NP, I-NP] | 12429 | [O, O, O, O, O, O, O, O, O, O] | [NNP, NNP, NNP, NNP, SYM, NN, NNP, SYM, DT, NN] | [W, D, L, G, /, F, G, /, A, P] |
| 9 | [B-NP, I-NP, O, B-NP, O, B-VP, B-NP, I-NP, O, B-NP, O, B-NP, I-NP] | 10323 | [B-PER, I-PER, O, B-LOC, O, O, B-PER, I-PER, O, B-LOC, O, O, O] | [NNP, NNP, (, NNP, ), VB, NNP, NNP, (, NNP, ), CD, JJ] | [Amr, Shabana, (, Egypt, ), beat, John, White, (, Australia, ), 10-15, 15-9] |


## Preprocessing for Named Entity Recognition

### Tokenization

In [7]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [8]:
# Let's see what the tokenizer does
tokenizer("Hello, this is one sentence!")

{'input_ids': [101, 7592, 1010, 2023, 2003, 2028, 6251, 999, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}

### What is subword tokenization? 

Subword tokens are used in modern approaches in place of stemming and lemmatization.

Since it is impractical to give every form of a word its own token, like:
- head -> token id: 1
- hunt -> token id: 2
- hunter -> token id: 3
- headhunter -> token id: 4

We instead do this:
- head -> token id: 1
- hunt -> token id: 2
- -er -> token id: 3
- headhunter -> token ids: 1 2 3 

This way, compound words and words with affixes are much easier to represent. This matters especially for agglutinative languages such as Turkish, where a single word can carry many suffixes.
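To make the idea concrete, here is a toy sketch of greedy longest-match-first subword tokenization in the style of WordPiece, the algorithm behind BERT's `##` pieces. The vocabulary below is hypothetical and tiny; a real tokenizer's vocabulary has tens of thousands of entries and also handles lowercasing and punctuation splitting, which this sketch skips.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Split one word into the longest vocabulary pieces, scanning left to right."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it from the right.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # non-initial pieces carry the "##" prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # the word cannot be covered by the vocabulary
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary.
vocab = {"head", "hunt", "##hunt", "##er", "sheep", "##me", "##at"}
print(wordpiece_tokenize("headhunter", vocab))  # ['head', '##hunt', '##er']
print(wordpiece_tokenize("sheepmeat", vocab))   # ['sheep', '##me', '##at']
```

Because the longest match wins at each step, the same pieces ("hunt", "##er") are reused across many words instead of storing every inflected form.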

Techniques like Byte-Pair Encoding (BPE) are used when we want to be language-agnostic and learn the subword vocabulary from data in an unsupervised way. (BERT and DistilBERT use a closely related algorithm, WordPiece.)
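Byte-Pair Encoding learns its subword vocabulary by repeatedly merging the most frequent pair of adjacent symbols in a training corpus. A minimal sketch under simplifying assumptions: the word counts are hypothetical, words start as sequences of characters, there is no end-of-word marker, and ties between equally frequent pairs go to whichever pair was counted first (real implementations differ in these details).

```python
from collections import Counter


def learn_bpe(word_freqs, num_merges):
    """Learn BPE merges from a {word: count} dict; returns (merges, segmented vocab)."""
    # Each word starts as a tuple of single characters.
    vocab = {tuple(word): freq for word, freq in word_freqs.items()}
    merges = []
    for _ in range(num_merges):
        # Count every adjacent symbol pair, weighted by word frequency.
        pairs = Counter()
        for symbols, freq in vocab.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # ties: first pair counted wins
        merges.append(best)
        # Rewrite every word, replacing occurrences of the best pair with one symbol.
        new_vocab = {}
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] = new_vocab.get(tuple(merged), 0) + freq
        vocab = new_vocab
    return merges, vocab


# Hypothetical word counts echoing the head/hunt/hunter example above.
word_freqs = {"hunt": 5, "hunter": 3, "head": 4, "headhunter": 1}
merges, vocab = learn_bpe(word_freqs, num_merges=6)
print(merges)  # [('h', 'u'), ('hu', 'n'), ('hun', 't'), ('h', 'e'), ('he', 'a'), ('hea', 'd')]
print(vocab)   # {('hunt',): 5, ('hunt', 'e', 'r'): 3, ('head',): 4, ('head', 'hunt', 'e', 'r'): 1}
```

After six merges, "headhunter" is represented with the learned pieces "head" and "hunt", without anyone having specified those units by hand.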

Note that Transformer models are often pretrained with subword tokenizers, meaning that even if your inputs have already been split into words, each of those words could be split again by the tokenizer. Let's look at an example:

In [9]:
example = datasets["train"][4]
print(example["tokens"])

['Germany', "'s", 'representative', 'to', 'the', 'European', 'Union', "'s", 'veterinary', 'committee', 'Werner', 'Zwingmann', 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.']


In [10]:
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
print(tokens)

['[CLS]', 'germany', "'", 's', 'representative', 'to', 'the', 'european', 'union', "'", 's', 'veterinary', 'committee', 'werner', 'z', '##wing', '##mann', 'said', 'on', 'wednesday', 'consumers', 'should', 'buy', 'sheep', '##me', '##at', 'from', 'countries', 'other', 'than', 'britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.', '[SEP]']


Here the words "Zwingmann" and "sheepmeat" have each been split into three subtokens.

However, the dataset does not give us a separate label for each subword token of "Zwingmann", like:

`Z -> [PER] ##wing -> [PER] ##mann -> [PER]`

(The "##" prefix marks a subtoken that continues a word rather than starting one, i.e. a non-first piece of an original word.)

Instead, we have only one label for the whole word: 

`Zwingmann -> [PER]`

This means we need to process our labels, because the input ids returned by the tokenizer are longer than the label lists in our dataset: first, special tokens are added (we can see a `[CLS]` and a `[SEP]` above), and second, some words are split into several tokens.

Thankfully, the tokenizer returns outputs that have a `word_ids` method (available with "fast" tokenizers, which `AutoTokenizer` loads by default when one exists) which can help us.

In [11]:
print(tokenized_input.word_ids())

[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 11, 11, 12, 13, 14, 15, 16, 17, 18, 18, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, None]


As we can see, it returns a list with the same number of elements as our processed input ids, mapping special tokens to `None` and all other tokens to their respective word. This way, we can align the labels with the processed input ids.

### Aligning subwords with word labels

In [12]:
word_ids = tokenized_input.word_ids()
aligned_labels = [-100 if i is None else example[f"{task}_tags"][i] for i in word_ids]
print(len(aligned_labels), len(tokenized_input["input_ids"]))

39 39


Here we set the labels of all special tokens to -100 (the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so these positions do not contribute to the loss) and the labels of all other tokens to the label of the word they come from.
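What "ignored" means can be sketched in a few lines of NumPy: the loss is averaged only over positions whose label is not -100. This mirrors the default behaviour of PyTorch's `CrossEntropyLoss` (mean reduction over non-ignored targets) but is an illustration, not PyTorch's implementation; the function name and the logits are made up.

```python
import numpy as np


def masked_cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over the positions whose label is not ignore_index."""
    logits = np.asarray(logits, dtype=float)  # shape: (num_tokens, num_classes)
    labels = np.asarray(labels)               # shape: (num_tokens,)
    # Numerically stable log-softmax over the class dimension.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Keep only positions with a real label; -100 positions are skipped entirely.
    positions = np.flatnonzero(labels != ignore_index)
    return -log_probs[positions, labels[positions]].mean()


# Made-up logits for 3 tokens and 2 classes; the last token plays the role of [SEP].
logits = [[2.0, 0.0], [0.0, 2.0], [5.0, -5.0]]
labels = [0, 1, -100]
print(masked_cross_entropy(logits, labels))  # about 0.1269, whatever the [SEP] logits are
```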

We're now ready to write the function that will preprocess our samples. We feed them to the `tokenizer` with the argument `truncation=True` (to truncate texts that are longer than the maximum length allowed by the model) and `is_split_into_words=True` (as seen above). Then we align the labels with the token ids using the strategy we picked:

In [13]:
# If you wonder what `label_all_tokens` is, see the original notebook cited at the top of this notebook.
# I intentionally removed the explanation to reduce the complexity of the notebook.
label_all_tokens = True

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples[f"{task}_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs
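To see what the `label_all_tokens` flag changes, the same alignment loop can be run on a hand-written `word_ids` list, without a tokenizer. This is a sketch: `align_labels` is a hypothetical standalone copy of the loop above, and the inputs mimic "Werner Zwingmann" being tokenized into `[CLS] werner z ##wing ##mann [SEP]`, with word labels `B-PER` (1) and `I-PER` (2).

```python
def align_labels(word_ids, word_labels, label_all_tokens=True):
    """Standalone copy of the alignment loop in tokenize_and_align_labels (hypothetical name)."""
    aligned = []
    previous = None
    for word_idx in word_ids:
        if word_idx is None:
            aligned.append(-100)  # special tokens are ignored by the loss
        elif word_idx != previous:
            aligned.append(word_labels[word_idx])  # first subtoken of a word gets the word label
        else:
            # Later subtokens: copy the word label, or ignore them.
            aligned.append(word_labels[word_idx] if label_all_tokens else -100)
        previous = word_idx
    return aligned


# [CLS] werner z ##wing ##mann [SEP]
word_ids = [None, 0, 1, 1, 1, None]
word_labels = [1, 2]  # B-PER, I-PER
print(align_labels(word_ids, word_labels, label_all_tokens=True))   # [-100, 1, 2, 2, 2, -100]
print(align_labels(word_ids, word_labels, label_all_tokens=False))  # [-100, 1, 2, -100, -100, -100]
```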

In [14]:
tokenize_and_align_labels(datasets['train'][:5])

{'input_ids': [[101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102], [101, 2848, 13934, 102], [101, 9371, 2727, 1011, 5511, 1011, 2570, 102], [101, 1996, 2647, 3222, 2056, 2006, 9432, 2009, 18335, 2007, 2446, 6040, 2000, 10390, 2000, 18454, 2078, 2329, 12559, 2127, 6529, 5646, 3251, 5506, 11190, 4295, 2064, 2022, 11860, 2000, 8351, 1012, 102], [101, 2762, 1005, 1055, 4387, 2000, 1996, 2647, 2586, 1005, 1055, 15651, 2837, 14121, 1062, 9328, 5804, 2056, 2006, 9317, 10390, 2323, 4965, 8351, 4168, 4017, 2013, 3032, 2060, 2084, 3725, 2127, 1996, 4045, 6040, 2001, 24509, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100], [-100, 1, 2, -100], [-100, 5, 0, 

In [15]:
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

## Fine-tuning the NER model

In [16]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN t

In [17]:
args = TrainingArguments(
    f"test-{task}",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

In [18]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)
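`DataCollatorForTokenClassification` pads each batch dynamically to its longest sequence and, by default, pads the labels with -100 (its `label_pad_token_id`), so padding is ignored by the loss just like special tokens. The sketch below mimics that behaviour in plain Python on the first two preprocessed training examples printed by `tokenize_and_align_labels` above. The function name `pad_batch` is hypothetical, right-padding is assumed, and the pad id 0 is an assumption (it is the `[PAD]` id in the BERT uncased vocabulary).

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Right-pad input_ids, attention_mask and labels to the longest example (hypothetical helper)."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)  # 0 = do not attend
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_pad)  # -100 = ignore in loss
    return batch


# The first two training examples after tokenize_and_align_labels (values copied from above).
features = [
    {"input_ids": [101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102],
     "attention_mask": [1] * 11,
     "labels": [-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100]},
    {"input_ids": [101, 2848, 13934, 102],
     "attention_mask": [1] * 4,
     "labels": [-100, 1, 2, -100]},
]
batch = pad_batch(features)
print(batch["labels"][1])  # [-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100]
```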

In [19]:
!pip install seqeval
metric = load_metric("seqeval")

This metric takes lists of string labels (not label ids) for the predictions and references:

In [20]:
labels = [label_list[i] for i in example[f"{task}_tags"]]
metric.compute(predictions=[labels], references=[labels])

{'LOC': {'f1': 1.0, 'number': 2, 'precision': 1.0, 'recall': 1.0},
 'ORG': {'f1': 1.0, 'number': 1, 'precision': 1.0, 'recall': 1.0},
 'PER': {'f1': 1.0, 'number': 1, 'precision': 1.0, 'recall': 1.0},
 'overall_accuracy': 1.0,
 'overall_f1': 1.0,
 'overall_precision': 1.0,
 'overall_recall': 1.0}

The model, however, outputs a vector of logits for each token rather than string labels, so we need to do a bit of post-processing on our predictions:
- select the predicted index (with the maximum logit) for each token
- convert it to its string label
- ignore everywhere we set a label of -100

The following function does all this post-processing on the predictions and labels that the `Trainer` gathers during evaluation (passed in as an `EvalPrediction` named tuple) before applying the metric:

In [21]:
import numpy as np

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

Note that we drop the precision/recall/f1 computed for each category and only focus on the overall precision/recall/f1/accuracy.

Then we just need to pass all of this along with our datasets to the `Trainer`:

In [22]:
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

We can now fine-tune our model by simply calling the `train` method:

In [23]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |
|---|---|---|---|---|---|---|
| 1 | 0.242019 | 0.06918 | 0.906384 | 0.92281 | 0.914523 | 0.98065 |
| 2 | 0.052258 | 0.060312 | 0.923673 | 0.932767 | 0.928198 | 0.983017 |
| 3 | 0.029636 | 0.060102 | 0.92835 | 0.939255 | 0.933771 | 0.984209 |


TrainOutput(global_step=2634, training_loss=0.08512135808361056)
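The `global_step=2634` reported above can be checked by hand. Assuming a single device and no gradient accumulation (an assumption about this run's setup that the numbers are consistent with), each epoch takes ceil(14041 / 16) = 878 optimizer steps, since the last partial batch still counts as a step, and 3 epochs give 2634.

```python
import math

# Values from this notebook: train split size, batch_size and num_train_epochs.
num_train_examples = 14041
batch_size = 16
num_train_epochs = 3

# Assumes a single device and no gradient accumulation; the last, partial batch still counts as a step.
steps_per_epoch = math.ceil(num_train_examples / batch_size)
total_steps = steps_per_epoch * num_train_epochs
print(steps_per_epoch, total_steps)  # 878 2634
```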

## Evaluation of the NER model

The `evaluate` method allows you to evaluate again on the evaluation dataset or on another dataset:

In [24]:
trainer.evaluate()

{'epoch': 3.0,
 'eval_accuracy': 0.9842089376777289,
 'eval_f1': 0.9337707835177668,
 'eval_loss': 0.060101624578237534,
 'eval_precision': 0.9283502874834144,
 'eval_recall': 0.9392549502181452}

To get the precision/recall/f1 computed for each category now that we have finished training, we can apply the same post-processing as before to the result of the `predict` method:

In [25]:
predictions, labels, _ = trainer.predict(tokenized_datasets["validation"])
predictions = np.argmax(predictions, axis=2)

# Remove ignored index (special tokens)
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

{'LOC': {'f1': 0.9560606060606062,
  'number': 2618,
  'precision': 0.9481592787377912,
  'recall': 0.9640947288006112},
 'MISC': {'f1': 0.8214429665457478,
  'number': 1231,
  'precision': 0.8152,
  'recall': 0.8277822908204712},
 'ORG': {'f1': 0.9069373942470389,
  'number': 2056,
  'precision': 0.9014896684286401,
  'recall': 0.9124513618677043},
 'PER': {'f1': 0.9784716516023007,
  'number': 3034,
  'precision': 0.9757456571615863,
  'recall': 0.9812129202373104},
 'overall_accuracy': 0.9842089376777289,
 'overall_f1': 0.9337707835177668,
 'overall_precision': 0.9283502874834144,
 'overall_recall': 0.9392549502181452}
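seqeval scores at the entity level: a predicted entity counts as correct only if both its type and its exact token span match the reference. This is one reason why the token-level accuracy above (0.984) is noticeably higher than the F1 (0.934): most tokens are `O` and easy to get right. The following is a simplified sketch of micro-averaged entity-level precision, recall and F1 for IOB2 tags, not seqeval's actual implementation; the function names and toy sequences are hypothetical.

```python
def extract_entities(tags):
    """Return (type, start, end) spans from an IOB2 tag list; end is exclusive."""
    entities = set()
    ent_type, start = None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # trailing "O" closes a final entity
        # A "B-" tag, or an "I-" tag of a different type, starts a new entity.
        starts_new = tag.startswith("B-") or (tag.startswith("I-") and tag[2:] != ent_type)
        if ent_type is not None and (tag == "O" or starts_new):
            entities.add((ent_type, start, i))
            ent_type = None
        if starts_new:
            ent_type, start = tag[2:], i
    return entities


def entity_prf(references, predictions):
    """Micro-averaged entity-level precision, recall and F1 over several sentences."""
    true_ents, pred_ents = set(), set()
    for sent_id, (ref, pred) in enumerate(zip(references, predictions)):
        # Prefix each span with its sentence index so spans from different sentences never collide.
        true_ents |= {(sent_id,) + e for e in extract_entities(ref)}
        pred_ents |= {(sent_id,) + e for e in extract_entities(pred)}
    correct = len(true_ents & pred_ents)
    precision = correct / len(pred_ents) if pred_ents else 0.0
    recall = correct / len(true_ents) if true_ents else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Toy example: 3 of 4 tokens are tagged correctly, but the truncated PER span earns no credit.
references = [["B-PER", "I-PER", "O", "B-LOC"]]
predictions = [["B-PER", "O", "O", "B-LOC"]]
print(entity_prf(references, predictions))  # (0.5, 0.5, 0.5)
```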