In [None]:
# Copyright 2019 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

<img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">

# BioBERT Named-Entity Recognition Inference with Mixed Precision


## 1. Overview

Bidirectional Encoder Representations from Transformers (BERT) is a method of pre-training language representations that obtains state-of-the-art results on a wide array of Natural Language Processing (NLP) tasks.

BioBERT is a domain-specific version of BERT that has been trained on PubMed abstracts.

The original BioBERT paper can be found here: https://arxiv.org/abs/1901.08746

NVIDIA's BioBERT is an optimized version of the implementation presented in the paper. It leverages mixed precision arithmetic and Tensor Cores on V100 GPUs for faster training times while maintaining target accuracy.

### 1.a Learning objectives

This notebook demonstrates:
- Inference on a Named-Entity Recognition (NER) task with a BioBERT model
- Downloading and using fine-tuned NVIDIA BioBERT models
- Using mixed precision for inference

## 2. Requirements

Please refer to the README file.

## 3. BioBERT Inference: Named-Entity Recognition

We can run inference on a fine-tuned BioBERT model for tasks like Named-Entity Recognition.

Here we use a BioBERT model fine-tuned on the [BC5CDR-disease dataset](https://www.ncbi.nlm.nih.gov/research/bionlp/Data/), which consists of 1,500 PubMed articles with 5,818 annotated diseases.

### 3.a Extract Disease Information from Text

In this example, we use a Named-Entity Recognition model built on BioBERT to extract disease information from the following paragraph:

**Input Text**

_"The authors describe the case of a 56 year old woman with chronic, severe heart failure 
secondary to dilated cardiomyopathy and absence of significant ventricular arrhythmias 
who developed QT prolongation and torsade de pointes ventricular tachycardia during one cycle 
of intermittent low dose (2.5 mcg/kg per min) dobutamine. 
This report of torsade de pointes ventricular tachycardia during intermittent dobutamine 
supports the hypothesis that unpredictable fatal arrhythmias may occur even with low doses 
and in patients with no history of significant rhythm disturbances.
The mechanisms of proarrhythmic effects of Dubutamine are discussed."_

**Output visualized using displaCy**

<div class="entities" style="line-height: 2.5; direction: ltr">The authors describe the case of a 56 year old woman with chronic , severe 
<mark class="entity" style="background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone">
    heart failure 
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">DISEASE</span>
</mark>
secondary to 
<mark class="entity" style="background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone">
    dilated cardiomyopathy 
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">DISEASE</span>
</mark>
and absence of significant 
<mark class="entity" style="background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone">
    ventricular arrhythmias 
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">DISEASE</span>
</mark>
who developed QT 
<mark class="entity" style="background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone">
    prolongation 
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">DISEASE</span>
</mark>
and torsade de pointes 
<mark class="entity" style="background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone">
    ventricular tachycardia 
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">DISEASE</span>
</mark>
during one cycle of intermittent low dose ( 2.5 mcg / kg per min ) dobutamine . This report of torsade de pointes 
<mark class="entity" style="background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone">
    ventricular tachycardia 
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">DISEASE</span>
</mark>
during intermittent dobutamine supports the hypothesis that unpredictable fatal 
<mark class="entity" style="background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone">
    arrhythmias 
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">DISEASE</span>
</mark>
may occur even with low doses and in patients with no history of significant 
<mark class="entity" style="background: #ddd; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em; box-decoration-break: clone; -webkit-box-decoration-break: clone">
    rhythm disturbances 
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">DISEASE</span>
</mark>
. The mechanisms of proarrhythmic effects of Dubutamine are discussed . </div>

In [None]:
text = """
The authors describe the case of a 56 year old woman with chronic, severe heart failure
secondary to dilated cardiomyopathy and absence of significant ventricular arrhythmias
who developed QT prolongation and torsade de pointes ventricular tachycardia during one cycle
of intermittent low dose (2.5 mcg/kg per min) dobutamine.
This report of torsade de pointes ventricular tachycardia during intermittent dobutamine
supports the hypothesis that unpredictable fatal arrhythmias may occur even with low doses
and in patients with no history of significant rhythm disturbances.
The mechanisms of proarrhythmic effects of Dubutamine are discussed.
"""

In [None]:
import os
import sys

notebooks_dir = '../notebooks'
working_dir = '../'
if working_dir not in sys.path:
    sys.path.append(working_dir)

In [None]:
# Convert the text into the IOB tag format seen during training, using dummy 'O' placeholder labels
import spacy
nlp = spacy.load("en_core_web_sm")

text = text.strip()
doc = nlp(text)
input_file = os.path.join(notebooks_dir, 'input.tsv')
with open(input_file, 'w') as wf:
    for word in doc:
        # Skip whitespace tokens such as line breaks ('is' must not be used to compare strings)
        if word.is_space:
            continue
        wf.write(word.text + '\tO\n')
    wf.write('\n')  # A blank line marks the end of the text
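The file written above contains one token per line, a tab, and the dummy `O` label, followed by a blank line that ends the text. As a dependency-free sketch of the same format, the hypothetical helper `to_dummy_iob` below uses a simple regular-expression tokenizer in place of spaCy. This is an assumption for illustration only: spaCy's tokenization rules differ, so token boundaries will not always match.

```python
import re


def to_dummy_iob(text):
    """Return 'token\tO' lines followed by a blank line (hypothetical helper).

    A regex tokenizer stands in for spaCy: it keeps words (allowing internal
    dots, e.g. '2.5') and splits off each punctuation mark as its own token.
    """
    tokens = re.findall(r"\w+(?:\.\w+)*|[^\w\s]", text)
    lines = [tok + "\tO" for tok in tokens]
    return "\n".join(lines) + "\n\n"


sample = "Low dose (2.5 mcg/kg) dobutamine."
print(to_dummy_iob(sample))
```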

### 3.b Mixed Precision

Mixed precision training offers significant computational speedup by performing operations in half-precision (FP16) format, while keeping critical values, such as the master copy of the weights and certain accumulations, in single precision (FP32) to retain as much information as possible in sensitive parts of the network. Since the introduction of Tensor Cores in the Volta and Turing architectures, switching to mixed precision has delivered significant training speedups: up to 3x overall on the most arithmetically intense model architectures. Tensor Cores accelerate FP16 math at inference time as well, which is what this notebook relies on.
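To make the precision trade-off concrete, the sketch below compares FP16 and FP32 on two classic failure modes. NumPy is an assumption here; the notebook itself does not use it. A tiny value underflows to zero in FP16, and a long running sum in an FP16 accumulator stops growing once each increment falls below half the spacing between representable values. This is why mixed precision keeps such accumulations in FP32.

```python
import numpy as np

# FP16 has a much narrower range and fewer mantissa bits than FP32
print("FP16 max:", np.finfo(np.float16).max)
print("FP16 eps:", np.finfo(np.float16).eps)

# A tiny value underflows to zero in FP16 but survives in FP32
tiny = 1e-8
print("FP16:", np.float16(tiny), "FP32:", np.float32(tiny))

# Summing 10,000 increments of 0.0001 (exact answer: 1.0).
# The FP16 accumulator stalls once 0.0001 is below half its spacing,
# while the FP32 accumulator stays close to the exact result.
acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
for _ in range(10000):
    acc16 = np.float16(acc16 + np.float16(0.0001))
    acc32 = np.float32(acc32 + np.float32(0.0001))
print("FP16 sum:", acc16, "FP32 sum:", acc32)
```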

For information about:
- How to train using mixed precision, see the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) documentation.
- How to access and enable AMP for TensorFlow, see [Using TF-AMP](https://docs.nvidia.com/deeplearning/dgx/tensorflow-user-guide/index.html#tfamp) from the TensorFlow User Guide.
- Techniques used for mixed precision training, see the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.

In this notebook, we enable automatic mixed precision (AMP) with the following environment variable:

In [None]:
import os
# Enable TensorFlow's automatic mixed precision (AMP) graph rewrite
os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"

The model we'll use was trained with mixed precision, which takes much less time to train than the FP32 version without losing accuracy. We therefore set the following flag:

In [None]:
use_mixed_precision_model = True

## 4. Fine-Tuned NVIDIA BioBERT TF Models

The following Named-Entity Recognition models fine-tuned from BioBERT are available on NGC (NVIDIA GPU Cloud, https://ngc.nvidia.com).

| **Model** | **Description** |
|:---------:|:----------:|
|BioBERT NER BC5CDR Disease  | NER model to extract disease information from text, trained on the BC5CDR-Disease dataset. |
|BioBERT NER BC5CDR Chemical | NER model to extract chemical information from text, trained on the BC5CDR-Chemical dataset. |


For this example, we will download the disease NER model trained on the BC5CDR-disease dataset.


In [None]:
# Download and extract the fine-tuned biobert_uncased_base_ner_disease model
DATA_DIR_FP16 = '../data/download/finetuned_model_fp16'
!mkdir -p $DATA_DIR_FP16
!wget -nc -q --show-progress -O $DATA_DIR_FP16/biobert_uncased_base_ner_disease.zip \
https://api.ngc.nvidia.com/v2/models/nvidia/biobert_uncased_base_ner_disease/versions/1/zip
!unzip -n -d $DATA_DIR_FP16/ $DATA_DIR_FP16/biobert_uncased_base_ner_disease.zip 
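If shell escapes such as `!wget` and `!unzip` are unavailable, for example when running these steps as a plain Python script, the same step can be sketched with the standard library. The helper name `download_and_extract` is hypothetical; it roughly mirrors `wget -nc` (skip the download if the archive already exists) and `unzip -n` (never overwrite files that already exist).

```python
import os
import urllib.request
import zipfile


def download_and_extract(url, dest_dir, zip_name):
    """Download a zip archive unless it is already present, then extract
    only the members that do not exist yet (hypothetical helper)."""
    os.makedirs(dest_dir, exist_ok=True)
    zip_path = os.path.join(dest_dir, zip_name)
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        for member in zf.namelist():
            if not os.path.exists(os.path.join(dest_dir, member)):
                zf.extract(member, dest_dir)
    return zip_path
```

For the model used here, the call would be `download_and_extract('https://api.ngc.nvidia.com/v2/models/nvidia/biobert_uncased_base_ner_disease/versions/1/zip', DATA_DIR_FP16, 'biobert_uncased_base_ner_disease.zip')`.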

In the code that follows, we will refer to this model.

## 5. Running NER task inference

To run NER inference, we follow, step by step, the flow implemented in `run_ner.py`.

### 5.a Configure Things

In [None]:
import run_ner
from run_ner import BC5CDRProcessor, model_fn_builder, file_based_input_fn_builder, filed_based_convert_examples_to_features, result_to_pair

import os, sys
import time

import tensorflow as tf
import modeling
import tokenization

tf.logging.set_verbosity(tf.logging.ERROR)

# Create the output directory where all the results are saved.
output_dir = os.path.join(working_dir, 'output')
tf.gfile.MakeDirs(output_dir)

# The config json file corresponding to the pre-trained BERT model.
# This specifies the model architecture.
bert_config_file = os.path.join(DATA_DIR_FP16, 'bert_config.json')

# The vocabulary file that the BERT model was trained on.
vocab_file = os.path.join(DATA_DIR_FP16, 'vocab.txt')

init_checkpoint = os.path.join(DATA_DIR_FP16, 'model.ckpt-10251')

# Whether to lower case the input text. 
# Should be True for uncased models and False for cased models.
# The BioBERT available in NGC is uncased
do_lower_case = True
  
# Total batch size for predictions
predict_batch_size = 1
params = {'batch_size': predict_batch_size}

# The maximum total input sequence length after WordPiece tokenization. 
# Sequences longer than this will be truncated, and sequences shorter than this will be padded.
max_seq_length = 128

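# Workaround for using tf.flags (which run_ner.py relies on) inside Jupyter:
# the flags are parsed lazily from sys.argv, and the Jupyter kernel is launched
# with an extra "-f <connection file>" argument, so we define a dummy "f" flag
# to keep flag parsing from failing.
flags = tf.flags

if 'f' not in tf.flags.FLAGS:
    tf.app.flags.DEFINE_string('f', '', 'kernel')
FLAGS = flags.FLAGS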

FLAGS.output_dir = output_dir
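The `max_seq_length` setting above controls how each example is truncated and padded. The sketch below is a simplified, hypothetical illustration of typical BERT NER preprocessing, not the actual `filed_based_convert_examples_to_features` code. It assumes the label list `B, I, O, X, [CLS], [SEP]` numbered from 1 with 0 reserved for padding, gives continuation sub-word pieces the `X` label, and uses a toy splitter in place of `tokenization.FullTokenizer`. Real features also map tokens to vocabulary ids and include segment ids.

```python
def build_ner_features(words, labels, wordpiece, label_map, max_seq_length):
    """Hypothetical sketch: align word-level IOB labels with sub-word tokens."""
    tokens, label_ids = [], []
    for word, label in zip(words, labels):
        for i, piece in enumerate(wordpiece(word)):
            tokens.append(piece)
            # Only the first piece keeps the word's label; the rest get 'X'
            label_ids.append(label_map[label] if i == 0 else label_map["X"])

    # Truncate, leaving room for the [CLS] and [SEP] markers
    tokens = ["[CLS]"] + tokens[:max_seq_length - 2] + ["[SEP]"]
    label_ids = ([label_map["[CLS]"]] + label_ids[:max_seq_length - 2]
                 + [label_map["[SEP]"]])
    input_mask = [1] * len(tokens)

    # Pad to exactly max_seq_length; label id 0 and mask 0 mark padding
    num_pad = max_seq_length - len(tokens)
    tokens += ["[PAD]"] * num_pad
    label_ids += [0] * num_pad
    input_mask += [0] * num_pad
    return tokens, label_ids, input_mask


# Assumed label order, numbered from 1 so that 0 can mean padding
toy_label_map = {label: i for i, label in
                 enumerate(["B", "I", "O", "X", "[CLS]", "[SEP]"], 1)}


def toy_wordpiece(word):
    """Toy splitter standing in for WordPiece: cut long words after 4 chars."""
    return [word[:4], "##" + word[4:]] if len(word) > 4 else [word]


print(build_ner_features(["heart", "failure", "."], ["B", "I", "O"],
                         toy_wordpiece, toy_label_map, max_seq_length=8))
```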

### 5.b Define Tokenizer & Create Estimator

In [None]:
# Validate the casing config consistency with the checkpoint name.
tokenization.validate_case_matches_checkpoint(do_lower_case, init_checkpoint)

# Create the tokenizer.
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)

# Load the configuration from file
bert_config = modeling.BertConfig.from_json_file(bert_config_file)


# Use the data processor for BC5CDR
processor = BC5CDRProcessor()
# Get labels in the index order that was used during training
label_list = processor.get_labels()

# Reverse-index the labels; this is used later to decode the predictions.
# Label ids start at 1 because id 0 is reserved for padding
# (hence num_labels=len(label_list) + 1 below).
id2label = {i: label for i, label in enumerate(label_list, 1)}


config = tf.ConfigProto(log_device_placement=True)
run_config = tf.estimator.RunConfig(
    model_dir=None,
    session_config=config,
    save_checkpoints_steps=1000,
    keep_checkpoint_max=1)


# Use the model function builder to create the model function
model_fn = model_fn_builder(
    bert_config=bert_config,
    num_labels=len(label_list) + 1,
    init_checkpoint=init_checkpoint,
    use_fp16=use_mixed_precision_model)

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    config=run_config,
    params=params)
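The `id2label` mapping above turns predicted label ids back into IOB tags. The self-contained sketch below shows that decoding on a made-up prediction. It assumes `BC5CDRProcessor.get_labels()` returns `["B", "I", "O", "X", "[CLS]", "[SEP]"]` in that order (verify against `run_ner.py`). The names `toy_id2label` and `decode_word_labels` are hypothetical and chosen so they do not overwrite the notebook's own variables.

```python
# Assumed label order; ids start at 1 because 0 is reserved for padding
TOY_LABELS = ["B", "I", "O", "X", "[CLS]", "[SEP]"]
toy_id2label = {i: label for i, label in enumerate(TOY_LABELS, 1)}


def decode_word_labels(pred_ids, id2label):
    """Map predicted ids to word-level IOB tags between [CLS] and [SEP]."""
    word_labels = []
    for label_id in pred_ids:
        if label_id == 0:              # padding position
            continue
        label = id2label[label_id]
        if label == "[SEP]":           # end of the sequence
            break
        if label in ("[CLS]", "X"):    # sequence start / WordPiece continuation
            continue
        word_labels.append(label)
    return word_labels


# A made-up prediction: [CLS] B X I X O [SEP] [PAD]
print(decode_word_labels([5, 1, 4, 2, 4, 3, 6, 0], toy_id2label))
```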

### 5.c Run Inference

In [None]:
# Load the input data using the BC5CDR processor
predict_examples = processor.get_test_examples(notebooks_dir, file_name='input.tsv')


# Convert the examples to features and save them as a TFRecord file
predict_file = os.path.join(output_dir, "predict.tf_record")
filed_based_convert_examples_to_features(predict_examples, label_list,
                                         max_seq_length, tokenizer,
                                         predict_file)

# TF log verbosity is set to ERROR above, so print() is used for progress output
print("***** Running predictions *****")
print("  Num orig examples = %d" % len(predict_examples))
print("  Batch size = %d" % predict_batch_size)

# Build the input function that reads the TFRecord file
predict_input_fn = file_based_input_fn_builder(
    input_file=predict_file,
    batch_size=predict_batch_size,
    seq_length=max_seq_length,
    is_training=False,
    drop_remainder=False)


pred_start_time = time.time()

# estimator.predict returns a lazy generator; list() forces inference to run.
# The measured time also includes graph construction and checkpoint loading.
predictions = estimator.predict(input_fn=predict_input_fn)
predictions = list(predictions)

pred_time_elapsed = time.time() - pred_start_time

print("-----------------------------")
print("Total Inference Time = %0.2f s" % pred_time_elapsed)
print("Inference Performance = %0.4f examples/sec" % (len(predict_examples) / pred_time_elapsed))
print("-----------------------------")
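A single measurement mixes setup cost with inference: with `tf.estimator.Estimator`, each `predict` call builds a new graph and restores the checkpoint, and the first run may also pay one-time GPU initialization costs. A minimal sketch of a somewhat more careful measurement, using a hypothetical `benchmark` helper, runs warm-up passes and reports the median of several timed runs.

```python
import statistics
import time


def benchmark(run_once, num_examples, warmup=1, repeats=3):
    """Time run_once() after warm-up runs (hypothetical helper).

    run_once must fully consume the predictions, e.g. by wrapping them in
    list(). Returns (median seconds per run, examples per second).
    """
    for _ in range(warmup):
        run_once()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_once()
        timings.append(time.perf_counter() - start)
    median = statistics.median(timings)
    throughput = num_examples / median if median > 0 else float("inf")
    return median, throughput
```

In this notebook it could be called as `benchmark(lambda: list(estimator.predict(input_fn=predict_input_fn)), len(predict_examples))`. Because every call still rebuilds the graph, a larger input file gives a more representative per-example figure than repeating this single paragraph.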

### 5.d Save Predictions

In [None]:
# Let's now process the predictions and save them to file(s)
print("Saving predictions to", FLAGS.output_dir)

# File containing the list of predictions as IOB tags
output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt")
# File containing the list of words, the dummy label and the predicted IOB tag
test_labels_file = os.path.join(FLAGS.output_dir, "test_labels.txt")
test_labels_err_file = os.path.join(FLAGS.output_dir, "test_labels_errs.txt")

# Reuse the predictions computed above instead of running inference a second time
with tf.gfile.Open(output_predict_file, 'w') as writer, \
        tf.gfile.Open(test_labels_file, 'w') as tl, \
        tf.gfile.Open(test_labels_err_file, 'w') as tle:
    for example, prediction in zip(predict_examples, predictions):
        # Skip padding positions (label id 0)
        output_line = "\n".join(id2label[label_id] for label_id in prediction if label_id != 0) + "\n"
        writer.write(output_line)
        result_to_pair(example, prediction, id2label, tl, tle)

### 5.e Visualize Predictions

In [None]:
# Format the predictions as the manual "ents" input expected by displaCy
def predictions_for_displacy(predict_examples, predictions, id2label, entity_label='DISEASE'):
    processed_text = ''
    entities = []
    current_pos = 0     # character offset of the current word in processed_text
    start_pos = None    # start offset of the entity being built, if any
    end_pos = 0         # end offset of the last word in that entity

    for predict_line, pred_ids in zip(predict_examples, predictions):
        words = str(predict_line.text).split(' ')

        # Keep word-level labels between [CLS] and [SEP]; skip padding (id 0)
        # and 'X' labels, which mark WordPiece continuation pieces
        pred_labels = []
        for label_id in pred_ids:
            if label_id == 0:
                continue
            curr_label = id2label[label_id]
            if curr_label == '[SEP]':
                break
            if curr_label in ('[CLS]', 'X'):
                continue
            pred_labels.append(curr_label)

        for tok, pred_label in zip(words, pred_labels):
            # An entity begins at 'B', or at an 'I' with no open entity
            begins = pred_label == 'B' or (pred_label == 'I' and start_pos is None)
            # Close the open entity when a new one begins or an 'O' follows it
            if start_pos is not None and (begins or pred_label == 'O'):
                entities.append({'start': start_pos, 'end': end_pos, 'label': entity_label})
                start_pos = None
            if begins:
                start_pos = current_pos
            if start_pos is not None:
                end_pos = current_pos + len(tok)

            processed_text = processed_text + tok + ' '
            current_pos = current_pos + len(tok) + 1

    # Handle an entity that runs to the very end of the text
    if start_pos is not None:
        entities.append({'start': start_pos, 'end': end_pos, 'label': entity_label})

    displacy_input = [{"text": processed_text,
                       "ents": entities,
                       "title": None}]

    return displacy_input

In [None]:
# Convert the predictions to the named-entity format required by displaCy and visualize
from spacy import displacy

displacy_input = predictions_for_displacy(predict_examples, predictions, id2label)
html = displacy.render(displacy_input, style="ent", manual=True)
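Besides rendering, you may want the extracted entities as plain strings for downstream use. A minimal sketch, using a hypothetical `entity_strings` helper, slices each span out of displaCy's manual input format. The sample input below is made up for illustration.

```python
def entity_strings(displacy_input):
    """Return (entity text, label) pairs from displaCy's manual 'ents' format."""
    pairs = []
    for doc in displacy_input:
        for ent in doc["ents"]:
            # strip() guards against spans that include a trailing space
            span = doc["text"][ent["start"]:ent["end"]].strip()
            pairs.append((span, ent["label"]))
    return pairs


demo_input = [{"text": "severe heart failure and QT prolongation ",
               "ents": [{"start": 7, "end": 20, "label": "DISEASE"},
                        {"start": 28, "end": 40, "label": "DISEASE"}],
               "title": None}]
print(entity_strings(demo_input))
```

Applied to the notebook's own result, `entity_strings(displacy_input)` lists the detected diseases.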

## 6. What's next

Now that you are familiar with running NER inference with BioBERT using mixed precision, you may want to try extracting disease information from other biomedical text.