*Copyright (c) Microsoft Corporation. All rights reserved.*  
*Licensed under the MIT License.*

# Named Entity Recognition Using BERT

## Before You Start

The running times shown in this notebook were measured on a Standard_NC6 Azure Deep Learning Virtual Machine with 1 NVIDIA Tesla K80 GPU. 
> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. 

The table below provides reference running times on different machine configurations.  

|QUICK_RUN|Machine Configurations|Running time|
|:---------|:----------------------|:------------|
|True|4 **CPU**s, 14GB memory| ~ 2 minutes|
|False|4 **CPU**s, 14GB memory| ~1.5 hours|
|True|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ 1 minute|
|False|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ 7 minutes |

If you run into a CUDA out-of-memory error or the Jupyter kernel dies repeatedly, try reducing `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance may be compromised. 

In [1]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = True

## Summary
This notebook demonstrates how to fine-tune a [pretrained BERT model](https://github.com/huggingface/pytorch-pretrained-BERT) for the named entity recognition (NER) task. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, and model evaluation. 

[BERT (Bidirectional Encoder Representations from Transformers)](https://arxiv.org/pdf/1810.04805.pdf) is a powerful pre-trained language model that can be used for multiple NLP tasks, including text classification, question answering, and named entity recognition. It can achieve state-of-the-art performance with only a few epochs of fine-tuning on task-specific datasets.  
The figure below illustrates how BERT can be fine tuned for NER tasks. The input data is a list of tokens representing a sentence. In the training data, each token has an entity label. After fine tuning, the model predicts an entity label for each token in a given testing sentence. 

<img src="https://nlpbp.blob.core.windows.net/images/bert_architecture.png">

In [2]:
import sys
import os
import random
import scrapbook as sb
from seqeval.metrics import classification_report

import torch

nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.models.bert.token_classification import BERTTokenClassifier, create_label_map, postprocess_token_labels
from utils_nlp.models.bert.common import Language, Tokenizer
from utils_nlp.dataset.wikigold import load_train_test_dfs, get_unique_labels
from utils_nlp.common.timer import Timer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
import pandas as pd
import numpy as np
import spacy
import re

## Configurations

In [4]:
TRAIN_DATA_FRACTION = 1
TEST_DATA_FRACTION = 1
NUM_TRAIN_EPOCHS = 2

if QUICK_RUN:
    TRAIN_DATA_FRACTION = 0.1
    TEST_DATA_FRACTION = 0.1
    NUM_TRAIN_EPOCHS = 1

if torch.cuda.is_available():
    BATCH_SIZE = 16
else:
    BATCH_SIZE = 8

CACHE_DIR="./temp"

# set random seeds
RANDOM_SEED = 100
torch.manual_seed(RANDOM_SEED)

# model configurations
LANGUAGE = Language.ENGLISHCASED
DO_LOWER_CASE = False
MAX_SEQ_LENGTH = 200

# optimizer configuration
LEARNING_RATE = 3e-5

# data configurations
TEXT_COL = "sentence"
LABELS_COL = "label"

## Preprocess Data

### Get training and testing data
The dataset used in this notebook is the [wikigold dataset](https://www.aclweb.org/anthology/W09-3302). The wikigold dataset consists of 145 manually labelled Wikipedia articles, including 1841 sentences and 40k tokens in total. The dataset can be directly downloaded from [here](https://github.com/juand-r/entity-recognition-datasets/tree/master/data/wikigold). 

The helper function `load_train_test_dfs` downloads the data file if it doesn't exist in `local_cache_path`. It splits the dataset into training and testing sets according to `test_fraction`. Because this is a relatively small dataset, we set `test_fraction` to 0.5 in order to have enough data for model evaluation. Running this notebook multiple times with different random seeds produces similar results.   
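
As a rough illustration of what a seeded split by `test_fraction` involves, the sketch below shuffles sentence indices with a fixed seed and cuts the list in two. The helper name `split_by_fraction` is hypothetical, and this is not the implementation inside `load_train_test_dfs`.

```python
import random


def split_by_fraction(items, test_fraction, random_seed):
    """Return (train, test) lists; hypothetical helper, not part of utils_nlp."""
    indices = list(range(len(items)))
    # A dedicated Random instance keeps the shuffle reproducible for a given seed.
    random.Random(random_seed).shuffle(indices)
    n_test = int(round(len(items) * test_fraction))
    test = [items[i] for i in indices[:n_test]]
    train = [items[i] for i in indices[n_test:]]
    return train, test


sentences = [["sentence", str(i)] for i in range(10)]
train, test = split_by_fraction(sentences, test_fraction=0.5, random_seed=100)
print(len(train), len(test))
```

Calling the function again with the same seed returns the same split, which is what makes results comparable across runs.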

The helper function `get_unique_labels` returns the unique entity labels in the dataset. There are 5 unique labels in the original dataset: 'O' (non-entity), 'I-LOC' (location), 'I-MISC' (miscellaneous), 'I-PER' (person), and 'I-ORG' (organization). 

The maximum number of words in a sentence is 144, so we set MAX_SEQ_LENGTH to 200 above, because the number of tokens will grow after WordPiece tokenization.

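### Read test and train data

The cells below load the GENIA files (`Genia4ERtask1.txt` and `Genia4EReval1.txt`) with a custom reader instead of `load_train_test_dfs`, so the label set differs from the wikigold labels described above.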

In [5]:
def get_data(input_file_path):
    """Read a tab-separated token/tag file into a DataFrame of word and label lists."""
    # Use a context manager so the file handle is always closed.
    with open(input_file_path, 'r') as f:
        data = f.read()
    # Sentences are separated by blank lines; drop empty chunks.
    sentences = [s for s in re.split(r'\n\n', data) if s]
    rows = []
    for sentence in sentences:
        # Each non-empty line holds a word and its tag separated by a tab.
        word_tag_list = [re.split(r'\t', line) for line in re.split(r'\n', sentence)]
        columns = list(zip(*[i for i in word_tag_list if i != ['']]))
        rows.append({'sentence': list(columns[0]), 'label': list(columns[1])})
    # Build the DataFrame once instead of appending row by row
    # (DataFrame.append was removed in pandas 2.0).
    return pd.DataFrame(rows, columns=['sentence', 'label'])
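
To make the file layout that `get_data` expects concrete, the sketch below parses a small in-memory sample with one token and tag per line, a tab between them, and a blank line between sentences. The sample text and the helper name `parse_iob` are illustrative assumptions; the actual data files are not reproduced here.

```python
def parse_iob(text):
    """Parse tab-separated token/tag lines into (tokens, tags) pairs per sentence."""
    sentences = []
    for block in text.strip().split("\n\n"):
        tokens, tags = [], []
        for line in block.split("\n"):
            if not line.strip():
                continue  # skip stray empty lines
            token, tag = line.split("\t")[:2]
            tokens.append(token)
            tags.append(tag)
        if tokens:
            sentences.append((tokens, tags))
    return sentences


# Made-up two-sentence sample in the assumed layout.
sample = "IL-2\tB-DNA\ngene\tI-DNA\nexpression\tO\n\nCD28\tB-protein\nreceptor\tI-protein\n"
for tokens, tags in parse_iob(sample):
    print(tokens, tags)
```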


In [6]:
train_df = get_data('/home/NER/data/Genia4ERtraining/Genia4ERtask1.txt')
test_df = get_data('/home/NER/data/Genia4ERtest/Genia4EReval1.txt')

In [7]:
train_df.head()

Unnamed: 0,sentence,label
0,"[IL-2, gene, expression, and, NF-kappa, B, act...","[B-DNA, I-DNA, O, O, B-protein, I-protein, O, ..."
1,"[Activation, of, the, CD28, surface, receptor,...","[O, O, O, B-protein, I-protein, I-protein, O, ..."
2,"[In, primary, T, lymphocytes, we, show, that, ...","[O, B-cell_type, I-cell_type, I-cell_type, O, ..."
3,"[Delineation, of, the, CD28, signaling, cascad...","[O, O, O, B-protein, O, O, O, O, O, O, B-prote..."
4,"[Our, data, suggest, that, lipoxygenase, metab...","[O, O, O, O, B-protein, I-protein, O, O, O, O,..."


In [8]:
test_df.head()

Unnamed: 0,sentence,label
0,"[Number, of, glucocorticoid, receptors, in, ly...","[O, O, B-protein, I-protein, O, B-cell_type, O..."
1,"[The, study, demonstrated, a, decreased, level...","[O, O, O, O, O, O, O, B-protein, I-protein, O,..."
2,"[In, the, lymphocytes, with, a, high, GR, numb...","[O, O, B-cell_type, O, O, O, B-protein, O, O, ..."
3,"[On, the, other, hand, ,, a, decreased, GR, nu...","[O, O, O, O, O, O, O, B-protein, O, O, O, O, O..."
4,"[These, data, showed, that, the, sensitivity, ...","[O, O, O, O, O, O, O, B-cell_type, O, O, O, O,..."


In [9]:
train_df.shape, test_df.shape

((18546, 2), (3856, 2))

In [10]:
# train_df, test_df = load_train_test_dfs(local_cache_path=CACHE_DIR, test_fraction=0.5,random_seed=RANDOM_SEED)
label_list = ['O', 'B-protein', 'I-protein', 'B-DNA', 'I-DNA', 'B-cell_type',
       'I-cell_type', 'B-cell_line', 'I-cell_line', 'B-RNA', 'I-RNA']
print('\nUnique entity labels: \n{}\n'.format(label_list))
print('Sample sentence: \n{}\n'.format(train_df[TEXT_COL][0]))
print('Sample sentence labels: \n{}\n'.format(train_df[LABELS_COL][0]))


Unique entity labels: 
['O', 'B-protein', 'I-protein', 'B-DNA', 'I-DNA', 'B-cell_type', 'I-cell_type', 'B-cell_line', 'I-cell_line', 'B-RNA', 'I-RNA']

Sample sentence: 
['IL-2', 'gene', 'expression', 'and', 'NF-kappa', 'B', 'activation', 'through', 'CD28', 'requires', 'reactive', 'oxygen', 'production', 'by', '5-lipoxygenase', '.']

Sample sentence labels: 
['B-DNA', 'I-DNA', 'O', 'O', 'B-protein', 'I-protein', 'O', 'O', 'B-protein', 'O', 'O', 'O', 'O', 'O', 'B-protein', 'O']



In [11]:
train_df.head()

Unnamed: 0,sentence,label
0,"[IL-2, gene, expression, and, NF-kappa, B, act...","[B-DNA, I-DNA, O, O, B-protein, I-protein, O, ..."
1,"[Activation, of, the, CD28, surface, receptor,...","[O, O, O, B-protein, I-protein, I-protein, O, ..."
2,"[In, primary, T, lymphocytes, we, show, that, ...","[O, B-cell_type, I-cell_type, I-cell_type, O, ..."
3,"[Delineation, of, the, CD28, signaling, cascad...","[O, O, O, B-protein, O, O, O, O, O, O, B-prote..."
4,"[Our, data, suggest, that, lipoxygenase, metab...","[O, O, O, O, B-protein, I-protein, O, O, O, O,..."


In [12]:
train_df = train_df.sample(frac=TRAIN_DATA_FRACTION).reset_index(drop=True)
test_df = test_df.sample(frac=TEST_DATA_FRACTION).reset_index(drop=True)

In [13]:
train_df.iloc[0]

sentence    [IL-2R, chains, were, measured, by, flow, cyto...
label       [O, O, O, O, O, O, O, O, O, B-protein, O, O, O...
Name: 0, dtype: object

In [14]:
train_df.shape, test_df.shape

((1855, 2), (386, 2))

**Note that the input texts are lists of words instead of raw sentences. This format ensures matching between input words and token labels when the words are further tokenized by `Tokenizer.tokenize_ner`.**

### Tokenization and Preprocessing


**Create a dictionary that maps labels to numerical values**  
Note that there is an argument called `trailing_piece_tag`. BERT uses a WordPiece tokenizer, which breaks some words into multiple tokens, e.g. "criticize" is tokenized into "critic" and "##ize". Since the input data only comes with one token label for "criticize", within `Tokenizer.tokenize_ner` the original token label is assigned to the first token "critic" and the second token "##ize" is labeled "X". By default, `trailing_piece_tag` is set to "X". If "X" already exists in your data, you can set `trailing_piece_tag` to another value that doesn't exist in your data.
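
The labeling rule described above can be sketched in a few lines. This is a minimal illustration, assuming a toy lookup table in place of the real WordPiece tokenizer; `align_labels`, `word_to_pieces`, and `toy_splits` are hypothetical names that do not come from `utils_nlp`.

```python
def align_labels(words, labels, word_to_pieces, trailing_piece_tag="X"):
    """Expand word-level labels to word pieces.

    `word_to_pieces` is a hypothetical stand-in for a real WordPiece tokenizer.
    """
    pieces, piece_labels = [], []
    for word, label in zip(words, labels):
        word_pieces = word_to_pieces(word)
        pieces.extend(word_pieces)
        # The first piece keeps the original label, the rest get the trailing tag.
        piece_labels.extend([label] + [trailing_piece_tag] * (len(word_pieces) - 1))
    return pieces, piece_labels


# Toy split table used only for this illustration.
toy_splits = {"criticize": ["critic", "##ize"]}
pieces, piece_labels = align_labels(
    ["They", "criticize", "Paris"],
    ["O", "O", "I-LOC"],
    lambda w: toy_splits.get(w, [w]),
)
print(pieces)        # ['They', 'critic', '##ize', 'Paris']
print(piece_labels)  # ['O', 'O', 'X', 'I-LOC']
```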

In [15]:
label_map = create_label_map(label_list, trailing_piece_tag="X")

In [16]:
label_list

['O',
 'B-protein',
 'I-protein',
 'B-DNA',
 'I-DNA',
 'B-cell_type',
 'I-cell_type',
 'B-cell_line',
 'I-cell_line',
 'B-RNA',
 'I-RNA']

In [17]:
label_map

{'O': 0,
 'B-protein': 1,
 'I-protein': 2,
 'B-DNA': 3,
 'I-DNA': 4,
 'B-cell_type': 5,
 'I-cell_type': 6,
 'B-cell_line': 7,
 'I-cell_line': 8,
 'B-RNA': 9,
 'I-RNA': 10,
 'X': 11}
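
The mapping printed above is consistent with enumerating `label_list` in order and giving the trailing piece tag the last id. A minimal sketch that reproduces that shape is shown below; the helper `build_label_map` is hypothetical and the actual implementation of `create_label_map` may differ.

```python
def build_label_map(label_list, trailing_piece_tag="X"):
    """Map each label to an integer id and give the trailing piece tag the last id."""
    label_map = {label: i for i, label in enumerate(label_list)}
    if trailing_piece_tag not in label_map:
        label_map[trailing_piece_tag] = len(label_list)
    return label_map


label_map_sketch = build_label_map(["O", "B-protein", "I-protein"])
# The inverse mapping is handy for turning predicted ids back into labels.
id_to_label = {v: k for k, v in label_map_sketch.items()}
print(label_map_sketch)  # {'O': 0, 'B-protein': 1, 'I-protein': 2, 'X': 3}
print(id_to_label[3])    # X
```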

**Create a tokenizer**

In [18]:
tokenizer = Tokenizer(language=LANGUAGE, 
                      to_lower=DO_LOWER_CASE, 
                      cache_dir=CACHE_DIR)

**Tokenize and preprocess text**  
The `tokenize_ner` method of the `Tokenizer` class converts text and labels in strings to numerical features, involving the following steps:
1. WordPiece tokenization.
2. Convert tokens and labels to numerical values, i.e. token ids and label ids.
3. Sequence padding or truncation according to the `max_seq_length` configuration.
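
Steps 2 and 3 can be sketched as follows. This is an illustration under stated assumptions: the four-entry `vocab` is a toy stand-in for the BERT vocabulary, `pad_or_truncate` is a hypothetical helper, and the padding id is assumed to be 0.

```python
def pad_or_truncate(token_ids, max_len, pad_id=0):
    """Return (ids, mask) of length max_len; mask is 1 for real tokens, 0 for padding."""
    ids = list(token_ids[:max_len])
    mask = [1] * len(ids)
    n_pad = max_len - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


# Toy vocabulary; real ids come from the pretrained model's vocabulary file.
vocab = {"[PAD]": 0, "critic": 1, "##ize": 2, "They": 3}
pieces = ["They", "critic", "##ize"]
ids = [vocab[p] for p in pieces]

padded_ids, attention_mask = pad_or_truncate(ids, max_len=5)
print(padded_ids)      # [3, 1, 2, 0, 0]
print(attention_mask)  # [1, 1, 1, 0, 0]
```

Sequences longer than `max_len` are cut from the end in this sketch, so any labels for the removed tokens would be lost as well.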

In [19]:
train_token_ids, train_input_mask, train_trailing_token_mask, train_label_ids = \
    tokenizer.tokenize_ner(text=train_df[TEXT_COL],
                           label_map=label_map,
                           max_len=MAX_SEQ_LENGTH,
                           labels=train_df[LABELS_COL],
                           trailing_piece_tag="X")
test_token_ids, test_input_mask, test_trailing_token_mask, test_label_ids = \
    tokenizer.tokenize_ner(text=test_df[TEXT_COL],
                           label_map=label_map,
                           max_len=MAX_SEQ_LENGTH,
                           labels=test_df[LABELS_COL],
                           trailing_piece_tag="X")

`Tokenizer.tokenize_ner` outputs three or four lists of numerical features. Each sublist contains the features of one input sentence: 
1. token ids: list of numerical values each corresponds to a token.
2. attention mask: list of 1s and 0s, 1 for input tokens and 0 for padded tokens, so that padded tokens are not attended to. 
3. trailing word piece mask: boolean list, `True` for the first word piece of each original word, `False` for the trailing word pieces, e.g. ##ize. This mask is useful for removing predictions on trailing word pieces, so that each original word in the input text has a unique predicted label. 
4. label ids: list of numerical values each corresponds to an entity label, if `labels` is provided.
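
The trailing word piece mask in item 3 can be derived from word boundaries, as in the sketch below. It assumes the pieces are already grouped per original word; the split of "IL-2" into three pieces is an assumption for illustration, and `first_piece_mask` is a hypothetical helper. Word boundaries are used rather than the "##" prefix because a word can also be split at punctuation into pieces that carry no "##" prefix.

```python
def first_piece_mask(pieces_per_word):
    """Flatten per-word pieces; the mask is True only for the first piece of each word."""
    pieces, mask = [], []
    for word_pieces in pieces_per_word:
        for i, piece in enumerate(word_pieces):
            pieces.append(piece)
            mask.append(i == 0)
    return pieces, mask


# Assumed grouping of word pieces for three input words.
pieces_per_word = [["IL", "-", "2"], ["gene"], ["expression"]]
flat_pieces, trailing_mask = first_piece_mask(pieces_per_word)
print(flat_pieces)    # ['IL', '-', '2', 'gene', 'expression']
print(trailing_mask)  # [True, False, False, True, True]
```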

In [20]:
print("Sample token ids:\n{}\n".format(train_token_ids[0]))
print("Sample attention mask:\n{}\n".format(train_input_mask[0]))
print("Sample trailing token mask:\n{}\n".format(train_trailing_token_mask[0]))
print("Sample label ids:\n{}\n".format(train_label_ids[0]))

Sample token ids:
[15393, 118, 123, 2069, 9236, 1127, 7140, 1118, 4235, 172, 25669, 6758, 6013, 117, 1105, 147, 27843, 120, 1457, 9971, 1118, 15059, 1202, 2007, 3457, 1233, 28117, 9654, 2193, 185, 23415, 7409, 19944, 11787, 2007, 27426, 24266, 7880, 12238, 1548, 113, 19416, 1708, 118, 8544, 16523, 114, 1105, 2102, 171, 7841, 119, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

## Create Token Classifier
The value of the `language` argument determines which BERT model is used:
* Language.ENGLISH: "bert-base-uncased"
* Language.ENGLISHCASED: "bert-base-cased"
* Language.ENGLISHLARGE: "bert-large-uncased"
* Language.ENGLISHLARGECASED: "bert-large-cased"
* Language.CHINESE: "bert-base-chinese"
* Language.MULTILINGUAL: "bert-base-multilingual-cased"
* Language.ENGLISHLARGEWWM: "bert-large-uncased-whole-word-masking"
* Language.ENGLISHLARGECASEDWWM: "bert-large-cased-whole-word-masking"

Here we use the base, cased pretrained model.

In [21]:
token_classifier = BERTTokenClassifier(language=LANGUAGE,
                                       num_labels=len(label_map),
                                       cache_dir=CACHE_DIR)

In [22]:
NUM_TRAIN_EPOCHS = 7

## Train Model

In [23]:
with Timer() as t:
    token_classifier.fit(token_ids=train_token_ids, 
                         input_mask=train_input_mask, 
                         labels=train_label_ids,
                         num_epochs=NUM_TRAIN_EPOCHS, 
                         batch_size=BATCH_SIZE, 
                         learning_rate=LEARNING_RATE)
print("Training time : {:.3f} hrs".format(t.interval / 3600))

Train loss: 0.18479028641214146
Train loss: 0.07510603961117308
Train loss: 0.0439340058384977
Train loss: 0.02909880417904913
Train loss: 0.019611931382879165
Train loss: 0.014614704747459498
Train loss: 0.010886221540015962
Training time : 4.422 hrs





In [23]:
import pickle

In [24]:
Pkl_Filename = "temp/genia_bert.pkl" 

## Uncomment these lines to save the model

In [26]:
 
# with open(Pkl_Filename, 'wb') as file:  
#     pickle.dump(token_classifier, file)

In [25]:
# Open the file in a context manager so the handle is closed after loading.
with open(Pkl_Filename, 'rb') as file:
    loaded_model = pickle.load(file)
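
The save and load steps form a simple round trip, sketched below with a plain dictionary standing in for the trained classifier. The stub object and the temporary file name are assumptions for illustration only.

```python
import os
import pickle
import tempfile

# A plain dict stands in for the trained classifier in this illustration.
model_stub = {"num_labels": 12, "language": "bert-base-cased"}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_stub.pkl")
    # Context managers close the file handles even if pickling fails.
    with open(path, "wb") as f:
        pickle.dump(model_stub, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(restored == model_stub)  # True
```

Unpickling executes code from the file, so pickle files should only be loaded from trusted sources. The classes of the pickled object also need to be importable in the session that loads it.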

In [26]:
len(['B-DNA', 'I-DNA', 'O', 'O', 'B-protein', 'I-protein', 'O', 'O', 'B-protein', 'O', 'O', 'O', 'O', 'O', 'B-protein', 'O'])

16

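## Code to test on a single sentence

Note that the cell below reuses the names `test_token_ids`, `test_input_mask`, `test_trailing_token_mask`, and `test_label_ids`, so it overwrites the test-set features created earlier. Re-run the tokenization cell for `test_df` before predicting on the full test set.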

In [27]:
test_token_ids, test_input_mask, test_trailing_token_mask, test_label_ids = \
    tokenizer.tokenize_ner(text=[['IL-2', 'gene', 'expression', 'and', 'NF-kappa', 'B', 'activation', 'through', 'CD28', 'requires', 'reactive', 'oxygen', 'production', 'by', '5-lipoxygenase', '.']],
                           label_map=label_map,
#                            max_len=16,
                           labels=[['B-DNA', 'I-DNA', 'O', 'O', 'B-protein', 'I-protein', 'O', 'O', 'B-protein', 'O', 'O', 'O', 'O', 'O', 'B-protein', 'O']],
                           trailing_piece_tag="X")

In [76]:
label_map

{'O': 0,
 'B-protein': 1,
 'I-protein': 2,
 'B-DNA': 3,
 'I-DNA': 4,
 'B-cell_type': 5,
 'I-cell_type': 6,
 'B-cell_line': 7,
 'I-cell_line': 8,
 'B-RNA': 9,
 'I-RNA': 10,
 'X': 11}

In [28]:
test_token_ids

[[15393,
  118,
  123,
  5565,
  2838,
  1105,
  151,
  2271,
  118,
  24181,
  13059,
  139,
  14915,
  1194,
  2891,
  24606,
  5315,
  26844,
  7621,
  1707,
  1118,
  126,
  118,
  4764,
  10649,
  1183,
  4915,
  6530,
  119,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,

In [29]:
test_input_mask

[[1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,

In [30]:
test_trailing_token_mask

[[True,
  False,
  False,
  True,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  T

In [31]:
test_label_ids

[[3,
  11,
  11,
  4,
  0,
  0,
  1,
  11,
  11,
  11,
  11,
  2,
  0,
  0,
  1,
  11,
  0,
  0,
  0,
  0,
  0,
  1,
  11,
  11,
  11,
  11,
  11,
  11,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  

In [32]:
ids = loaded_model.predict(token_ids=test_token_ids, 
                           input_mask=test_input_mask, 
                           labels=test_label_ids, 
                           batch_size=BATCH_SIZE)

Iteration: 100%|██████████| 1/1 [00:02<00:00,  2.01s/it]

Evaluation loss: 0.009510550647974014





In [33]:
ids

[[3,
  11,
  11,
  4,
  0,
  0,
  1,
  11,
  11,
  11,
  11,
  2,
  0,
  0,
  1,
  11,
  0,
  0,
  0,
  0,
  0,
  1,
  11,
  11,
  11,
  11,
  11,
  11,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  

In [66]:
label_to_text = dict()
for key, value in label_map.items():
    label_to_text[value] = key

In [68]:
[label_to_text[x] for x in ids[0][:16]]

['B-DNA',
 'X',
 'X',
 'I-DNA',
 'O',
 'O',
 'B-protein',
 'X',
 'X',
 'X',
 'X',
 'I-protein',
 'O',
 'O',
 'B-protein',
 'X']

## Predict on Test Data

In [28]:
with Timer() as t:
    pred_label_ids = loaded_model.predict(token_ids=test_token_ids, 
                                          input_mask=test_input_mask, 
                                          labels=test_label_ids, 
                                          batch_size=BATCH_SIZE)
print("Prediction time : {:.3f} hrs".format(t.interval / 3600))

Iteration: 100%|██████████| 49/49 [02:11<00:00,  2.68s/it]

Evaluation loss: 0.2593512408891503
Prediction time : 0.036 hrs





## Evaluate Model
The `predict` method of the token classifier outputs label ids for all tokens, including the padded tokens. `postprocess_token_labels` is a helper function that removes the predictions on padded tokens. If a `label_map` is provided, it maps the numerical label ids back to the original token labels, which are usually strings. 
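
The padding removal described above can be sketched as follows. This is a simplified illustration with toy inputs, not the implementation of `postprocess_token_labels`; the helper name `strip_padding` and the `toy_*` values are hypothetical.

```python
def strip_padding(label_ids, input_mask, label_map=None):
    """Drop predictions at padded positions; optionally map ids back to labels."""
    id_to_label = {v: k for k, v in label_map.items()} if label_map else None
    result = []
    for ids, mask in zip(label_ids, input_mask):
        # Keep only positions where the attention mask marks a real token.
        kept = [i for i, m in zip(ids, mask) if m]
        if id_to_label is not None:
            kept = [id_to_label[i] for i in kept]
        result.append(kept)
    return result


toy_label_map = {"O": 0, "B-protein": 1, "X": 2}
toy_pred_ids = [[1, 2, 0, 0, 0]]
toy_mask = [[1, 1, 1, 0, 0]]
print(strip_padding(toy_pred_ids, toy_mask, toy_label_map))  # [['B-protein', 'X', 'O']]
```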

In [55]:
pred_tags_no_padding = postprocess_token_labels(pred_label_ids, 
                                                test_input_mask, 
                                                label_map)
true_tags_no_padding = postprocess_token_labels(test_label_ids, 
                                                test_input_mask, 
                                                label_map)
report_no_padding = classification_report(true_tags_no_padding, 
                                          pred_tags_no_padding, 
                                          digits=2)
print(report_no_padding)

           precision    recall  f1-score   support

cell_line       0.33      0.59      0.43        96
        X       0.96      0.96      0.96      2268
  protein       0.59      0.76      0.66       765
      DNA       0.50      0.70      0.58       242
cell_type       0.72      0.37      0.49       324
      RNA       0.53      0.33      0.41        24

micro avg       0.79      0.84      0.81      3719
macro avg       0.82      0.84      0.82      3719



In [59]:
pred_tags_no_padding = postprocess_token_labels(pred_label_ids, 
                                                test_input_mask, 
                                                label_map)
true_tags_no_padding = postprocess_token_labels(test_label_ids, 
                                                test_input_mask, 
                                                label_map)
report_no_padding = classification_report(true_tags_no_padding, 
                                          pred_tags_no_padding, 
                                          digits=2)
print(report_no_padding)

           precision    recall  f1-score   support

cell_line       0.34      0.64      0.44        96
        X       0.97      0.97      0.97      2268
  protein       0.71      0.75      0.73       765
      DNA       0.58      0.76      0.66       242
cell_type       0.74      0.53      0.62       324
      RNA       0.60      0.75      0.67        24

micro avg       0.84      0.86      0.85      3719
macro avg       0.85      0.86      0.85      3719



In [30]:
pred_tags_no_padding = postprocess_token_labels(pred_label_ids, 
                                                test_input_mask, 
                                                label_map)
true_tags_no_padding = postprocess_token_labels(test_label_ids, 
                                                test_input_mask, 
                                                label_map)
report_no_padding = classification_report(true_tags_no_padding, 
                                          pred_tags_no_padding, 
                                          digits=2)
print(report_no_padding)

           precision    recall  f1-score   support

cell_line       0.39      0.78      0.52       102
cell_type       0.71      0.62      0.66       393
      DNA       0.59      0.81      0.68       213
        X       0.98      0.97      0.98      2250
  protein       0.64      0.81      0.72       654
      RNA       0.77      1.00      0.87        17

micro avg       0.82      0.89      0.86      3629
macro avg       0.85      0.89      0.87      3629



`postprocess_token_labels` also provides an option to remove the predictions on trailing word pieces, e.g. ##ize, so that the final predicted labels correspond to the original words in the input text. The `trailing_token_mask` is obtained from `tokenizer.tokenize_ner`.
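
A minimal sketch of this filtering, assuming toy inputs: a label is kept only where the attention mask marks a real token and the trailing token mask marks the first piece of a word. The helper `keep_first_pieces` is hypothetical and only illustrates the idea behind `remove_trailing_word_pieces=True`.

```python
def keep_first_pieces(labels, input_mask, trailing_token_mask):
    """Keep labels only where the token is real (mask 1) and starts a word."""
    result = []
    for seq, mask, first in zip(labels, input_mask, trailing_token_mask):
        result.append([label for label, m, f in zip(seq, mask, first) if m and f])
    return result


toy_labels = [["B-DNA", "X", "X", "I-DNA", "O", "O"]]
toy_input_mask = [[1, 1, 1, 1, 1, 0]]
toy_trailing_mask = [[True, False, False, True, True, True]]
print(keep_first_pieces(toy_labels, toy_input_mask, toy_trailing_mask))
# [['B-DNA', 'I-DNA', 'O']]
```

After this step each remaining label lines up with one original input word, so the "X" tag no longer appears in the evaluation.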

In [69]:
pred_tags_no_padding_no_trailing = postprocess_token_labels(pred_label_ids, 
                                                            test_input_mask, 
                                                            label_map, 
                                                            remove_trailing_word_pieces=True, 
                                                            trailing_token_mask=test_trailing_token_mask)
true_tags_no_padding_no_trailing = postprocess_token_labels(test_label_ids, 
                                                            test_input_mask, 
                                                            label_map, 
                                                            remove_trailing_word_pieces=True, 
                                                            trailing_token_mask=test_trailing_token_mask)
report_no_padding_no_trailing = classification_report(true_tags_no_padding_no_trailing, 
                                                      pred_tags_no_padding_no_trailing, 
                                                      digits=2)
print(report_no_padding_no_trailing)

           precision    recall  f1-score   support

  protein       1.00      1.00      1.00         2
      DNA       1.00      1.00      1.00         1

micro avg       1.00      1.00      1.00         3
macro avg       1.00      1.00      1.00         3

