In [8]:
"""
You can run this notebook either locally (if you have all the dependencies and a GPU) or on Google Colab.

Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GITHUB" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies.
"""


## Install dependencies
!pip install wget
!pip install faiss-gpu

## Install NeMo
BRANCH = 'r1.0.0rc1'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]

In [1]:
import faiss
import torch
import wget
import os
import numpy as np
import pandas as pd

from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from IPython.display import display
from tqdm import tqdm

from nemo.collections import nlp as nemo_nlp
from nemo.utils.exp_manager import exp_manager

[NeMo W 2021-04-15 20:14:12 optimizers:46] Apex was not found. Using the lamb optimizer will error out.


## Entity Linking

#### Task Description
[Entity linking](https://en.wikipedia.org/wiki/Entity_linking) is the process of connecting concepts mentioned in natural language to their canonical forms stored in a knowledge base. For example, say a knowledge base contained the entity 'ID3452 influenza' and we wanted to process some natural language containing the sentence "The patient has flu-like symptoms". An entity linking model would match the word 'flu' to the knowledge base entity 'ID3452 influenza', allowing for disambiguation and normalization of concepts referenced in text. Entity linking applications range from helping automate data ingestion to assisting in real-time dialogue concept normalization. We will focus on entity linking in the medical domain for this demo, but the entity linking model, dataset, and training code within NVIDIA NeMo can be applied to other domains like finance and retail.

Within NeMo and this tutorial, we use the entity linking approach described in Liu et al.'s NAACL 2021 paper "[Self-alignment Pre-training for Biomedical Entity Representations](https://arxiv.org/abs/2010.11784v2)". The main idea behind this approach is to reshape an initial concept embedding space such that synonyms of the same concept are pulled closer together and unrelated concepts are pushed further apart. The concept embeddings from this reshaped space can then be used to build a knowledge base embedding index. This index stores concept IDs mapped to their respective concept embeddings in a format conducive to efficient nearest neighbor search. We can then link query concepts to their canonical forms in the knowledge base by performing a nearest neighbor search, matching each query concept embedding to the most similar concept embeddings in the knowledge base index.
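As a minimal sketch of the linking step, assuming the embeddings have already been computed, the snippet below L2-normalizes a toy index so that a dot product equals cosine similarity, then returns the IDs of the nearest entries by brute-force search. The 2-D vectors and the IDs are made up for illustration, and `build_index`/`link` are hypothetical helpers; the tutorial itself builds its index with faiss rather than with this exhaustive search.

```python
import numpy as np

def build_index(embeddings):
    """L2-normalize rows so that a dot product equals cosine similarity."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / np.clip(norms, 1e-12, None)

def link(query_embs, index_embs, index_ids, k=1):
    """Return the IDs of the k most similar index entries for each query."""
    q = build_index(query_embs)
    scores = q @ index_embs.T                  # (num_queries, num_index_entries)
    top = np.argsort(-scores, axis=1)[:, :k]   # highest similarity first
    return [[index_ids[j] for j in row] for row in top]

# Hypothetical toy index: two synonym embeddings per concept
index_embs = build_index(np.array([[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]))
index_ids = ["ID3452", "ID3452", "ID7", "ID7"]
queries = np.array([[0.95, 0.15], [0.0, 1.0]])
print(link(queries, index_embs, index_ids, k=1))
```

Brute force compares every query with every index entry, which is fine for a toy index; libraries like faiss exist to make the same search fast at the scale of a real knowledge base.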

In this tutorial we will be using the [faiss](https://github.com/facebookresearch/faiss) library to build our concept index.

#### Self-Alignment Pretraining
Self-alignment pretraining (SAP) is a second stage of pretraining applied to an existing encoder (called second stage because the encoder model can be further finetuned after this more general pretraining step). The training dataset consists of pairs of concept synonyms that map to the same ID. At each training iteration, we select only the *hard* examples present in the mini-batch to calculate the loss and update the model weights. In this context, a hard example is one where a concept is closer, by some margin, to an unrelated concept in the mini-batch than it is to the synonym it is paired with. I encourage you to take a look at [section 2 of the paper](https://arxiv.org/pdf/2010.11784.pdf) for a more formal and in-depth description of how hard examples are selected.
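The selection step can be sketched in NumPy. This version is an assumption on our part: it uses the pair-mining rule from the multi-similarity loss paper (Wang et al., CVPR 2019), which also keeps negatives that come within `margin` of the anchor's least similar synonym rather than only strictly closer ones, an illustrative `margin=0.1`, and a made-up 2-D mini-batch. NeMo's own implementation may differ in its details.

```python
import numpy as np

def mine_hard_pairs(embs, labels, margin=0.1):
    """Return (anchor, hard_positive_indices, hard_negative_indices) per anchor.

    A negative is kept if its similarity to the anchor exceeds the anchor's
    least similar positive minus `margin`; a positive is kept if its
    similarity is below the anchor's most similar negative plus `margin`.
    """
    unit = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    sim = unit @ unit.T  # pairwise cosine similarities
    labels = np.asarray(labels)
    idx = np.arange(len(labels))
    mined = []
    for i in idx:
        pos_mask = (labels == labels[i]) & (idx != i)  # synonyms, excluding the anchor itself
        neg_mask = labels != labels[i]                 # unrelated concepts
        if not pos_mask.any() or not neg_mask.any():
            continue
        pos_sims, neg_sims = sim[i, pos_mask], sim[i, neg_mask]
        hard_neg = idx[neg_mask][neg_sims + margin > pos_sims.min()]
        hard_pos = idx[pos_mask][pos_sims - margin < neg_sims.max()]
        mined.append((int(i), hard_pos.tolist(), hard_neg.tolist()))
    return mined

# Toy mini-batch: two synonym pairs (IDs 1 and 2) embedded in 2-D
labels = [1, 1, 2, 2]
embs = np.array([[1.0, 0.0], [0.6, 0.8], [0.8, 0.6], [0.0, 1.0]])
for anchor, hard_pos, hard_neg in mine_hard_pairs(embs, labels):
    print(anchor, hard_pos, hard_neg)
```

In this toy batch, item 0 is more similar to item 2 (an unrelated concept, cosine 0.8) than to its own synonym item 1 (cosine 0.6), so item 2 is mined as a hard negative for anchor 0, while item 3 (cosine 0.0) is easy and is ignored.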

We then compute a [multi-similarity metric learning loss](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf) from the selected hard examples. This loss reshapes the concept representation space so that it is better suited to entity matching via embedding cosine similarity.
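For reference, the multi-similarity loss as written in the linked paper is shown below. Here $m$ is the mini-batch size, $S_{ik}$ is the cosine similarity between items $i$ and $k$, $\mathcal{P}_i$ and $\mathcal{N}_i$ are the hard positives and hard negatives mined for anchor $i$, and $\alpha$, $\beta$, $\lambda$ are hyperparameters. The values NeMo uses for these hyperparameters are not stated in this tutorial, so treat this as the general form rather than NeMo's exact configuration.

$$
\mathcal{L}_{MS} = \frac{1}{m}\sum_{i=1}^{m}\left[\frac{1}{\alpha}\log\Big(1+\sum_{k\in\mathcal{P}_i}e^{-\alpha(S_{ik}-\lambda)}\Big)+\frac{1}{\beta}\log\Big(1+\sum_{k\in\mathcal{N}_i}e^{\beta(S_{ik}-\lambda)}\Big)\right]
$$

The first term penalizes synonyms whose similarity falls below $\lambda$, and the second penalizes unrelated concepts whose similarity rises above it, which is exactly the "pull synonyms together, push unrelated concepts apart" reshaping described above.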

Now that we have an idea of what's going on, let's get started!

## Dataset Preprocessing

In [2]:
# Download data
DATA_DIR = "tiny_example_data"

if not os.path.isdir(DATA_DIR):
    wget.download('https://dldata-public.s3.us-east-2.amazonaws.com/tiny_example_data.zip',
                  "tiny_example_data.zip")

    # Extracts the archive into the tiny_example_data/ directory
    !unzip tiny_example_data.zip

Archive:  tiny_example_data.zip
   creating: tiny_example_data/
  inflating: tiny_example_data/tiny_example_dev_data.csv  
  inflating: tiny_example_data/tiny_example_test_queries.tsv  
  inflating: tiny_example_data/tiny_example_validation_pairs.tsv  
  inflating: tiny_example_data/tiny_example_train_pairs.tsv  
  inflating: tiny_example_data/tiny_example_index_data.tsv  
  inflating: tiny_example_data/tiny_example_test_kb.tsv  


In this tutorial we will be using a tiny toy dataset to demonstrate how to use NeMo's entity linking model functionality. The dataset includes synonyms for 12 medical concepts. Entity phrases with the same ID are synonyms for the same concept. For example, "*chronic kidney failure*", "*gradual loss of kidney function*", and "*CKD*" are all synonyms of concept ID 5. Here's the dataset before preprocessing:

In [3]:
raw_data = pd.read_csv(os.path.join(DATA_DIR, "tiny_example_dev_data.csv"), names=["ID", "CONCEPT"], index_col=False)
print(raw_data)

    ID                                            CONCEPT
0    1                                          Head ache
1    1                                           Headache
2    1                                           Migraine
3    1                                   Pain in the head
4    1                                          cephalgia
5    1                                        cephalalgia
6    2                                       heart attack
7    2                              Myocardial infraction
8    2                           necrosis of heart muscle
9    2                                                 MI
10   3                                                CAD
11   3                            Coronary artery disease
12   3                      atherosclerotic heart disease
13   3                                      heart disease
14   3                damage of major heart blood vessels
15   4                                myocardial ischemia
16   4        

We've already paired off the concepts for this dataset in the tab-separated format `ID concept_synonym1 concept_synonym2`. Here are the first ten rows:

In [4]:
training_data = pd.read_table(os.path.join(DATA_DIR, "tiny_example_train_pairs.tsv"), names=["ID", "CONCEPT_SYN1", "CONCEPT_SYN2"], delimiter='\t')
print(training_data.head(10))

   ID      CONCEPT_SYN1      CONCEPT_SYN2
0   1  Pain in the head         cephalgia
1   1  Pain in the head       cephalalgia
2   1          Migraine         cephalgia
3   1         Head ache  Pain in the head
4   1         Head ache          Migraine
5   1         Head ache       cephalalgia
6   1          Headache          Migraine
7   1          Migraine       cephalalgia
8   1         cephalgia       cephalalgia
9   1          Headache  Pain in the head
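One way to produce pairs in this format from grouped synonym data like `raw_data` above is to take every unordered combination of synonyms that share an ID. This is a hypothetical sketch (`make_synonym_pairs` is not a NeMo function), and the tutorial's files may have been generated differently, for example with some pairs held out for the validation file.

```python
from collections import defaultdict
from itertools import combinations

def make_synonym_pairs(rows):
    """rows: iterable of (concept_id, synonym) -> list of (id, syn1, syn2)."""
    by_id = defaultdict(list)
    for cid, name in rows:
        by_id[cid].append(name)
    pairs = []
    for cid, names in by_id.items():
        # Every unordered pair of synonyms sharing the same concept ID
        for a, b in combinations(names, 2):
            pairs.append((cid, a, b))
    return pairs

rows = [(1, "Head ache"), (1, "Headache"), (1, "Migraine"), (2, "heart attack"), (2, "MI")]
pairs = make_synonym_pairs(rows)
for cid, a, b in pairs:
    print(f"{cid}\t{a}\t{b}")
```

With the notebook's DataFrame, `make_synonym_pairs(raw_data.itertuples(index=False))` should work the same way, since each row unpacks into `(ID, CONCEPT)`. A concept with n synonyms yields n(n-1)/2 pairs, so the number of pairs grows quickly for concepts with many synonyms.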


For full medical-domain entity linking training, use the [Unified Medical Language System (UMLS)](https://www.nlm.nih.gov/research/umls/index.html) dataset. The data contains over 9 million entities and is a table of medical concepts with their corresponding concept IDs (CUIs). After [requesting a free license and making a UMLS Terminology Services (UTS) account](https://www.nlm.nih.gov/research/umls/index.html), you can download the [entire UMLS dataset](https://www.nlm.nih.gov/research/umls/licensedcontent/umlsknowledgesources.html) from the NIH's website. If you've cloned the NeMo repo, you can run the data processing script located at `examples/nlp/entity_linking/data/umls_dataset_processing.py` on the full dataset. This script takes in the initial table of UMLS concepts and produces a .tsv file with each row formatted as `CUI\tconcept_synonym1\tconcept_synonym2`. Once the UMLS dataset .RRF file is downloaded, run the script from the `examples/nlp/entity_linking` directory like so:
```
python data/umls_dataset_processing.py --cfg conf/umls_medical_entity_linking_config.yaml
```
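To give a rough idea of what the first step of such processing involves, the sketch below groups English concept strings by CUI from pipe-delimited `.RRF` rows. It assumes the file is `MRCONSO.RRF` with CUI in column 0, language in column 1, and the concept string in column 14 (check these positions against the UMLS Reference Manual). It is not NeMo's `umls_dataset_processing.py`, and the rows in the example are made-up stand-ins, not real UMLS data.

```python
import io
from collections import defaultdict

# Assumed MRCONSO.RRF column positions (pipe-delimited, 18 fields per row)
CUI_COL, LAT_COL, STR_COL = 0, 1, 14

def read_synonyms(lines, language="ENG"):
    """Group concept strings by CUI, keeping only rows in `language`."""
    synonyms = defaultdict(set)
    for line in lines:
        fields = line.rstrip("\n").split("|")
        if fields[LAT_COL] == language:
            synonyms[fields[CUI_COL]].add(fields[STR_COL])
    return dict(synonyms)

def fake_row(cui, lat, text):
    """Build a made-up row with only the three columns we read filled in."""
    fields = [""] * 18
    fields[CUI_COL], fields[LAT_COL], fields[STR_COL] = cui, lat, text
    return "|".join(fields) + "|\n"

rrf = io.StringIO("".join([
    fake_row("C0000001", "ENG", "Headache"),
    fake_row("C0000001", "ENG", "Cephalalgia"),
    fake_row("C0000001", "GER", "Kopfschmerz"),
    fake_row("C0000002", "ENG", "Heart attack"),
]))
result = read_synonyms(rrf)
print(result)
```

Each group could then be expanded into `CUI\tconcept_synonym1\tconcept_synonym2` rows by pairing synonyms within the same CUI.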

## Model Training

Next, we run second-stage pretraining of a BERT Base encoder on the self-alignment pretraining (SAP) task to improve entity linking. On a GPU, training on this example dataset should take 5 minutes or less, and training progress is printed below the cell.

In [5]:
# Download config
wget.download("https://raw.githubusercontent.com/vadam5/NeMo/main/examples/nlp/entity_linking/conf/tiny_example_entity_linking_config.yaml",
              os.path.join("tiny_example_entity_linking_config.yaml"))

# Load in config file
cfg = OmegaConf.load(os.path.join("tiny_example_entity_linking_config.yaml"))

In [6]:
# Initialize the trainer and model
trainer = Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = nemo_nlp.models.EntityLinkingModel(cfg=cfg.model, trainer=trainer)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.


[NeMo I 2021-04-15 20:14:25 exp_manager:216] Experiments will be logged at SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25
[NeMo I 2021-04-15 20:14:25 exp_manager:563] TensorboardLogger has been set up
[NeMo I 2021-04-15 20:14:27 entity_linking_dataset:61] Getting datafile newline indices
[NeMo I 2021-04-15 20:14:27 entity_linking_dataset:78] Loaded dataset with 63 examples
[NeMo I 2021-04-15 20:14:27 entity_linking_dataset:61] Getting datafile newline indices
[NeMo I 2021-04-15 20:14:27 entity_linking_dataset:78] Loaded dataset with 21 examples


In [7]:
# Train and save the model
trainer.fit(model)
model.save_to(cfg.model.nemo_path)

[NeMo I 2021-04-15 20:14:32 modelPT:688] Optimizer config = Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        eps: 1e-08
        lr: 3e-05
        weight_decay: 0.0
    )
[NeMo I 2021-04-15 20:14:32 lr_scheduler:621] Scheduler "<nemo.core.optim.lr_scheduler.CosineAnnealing object at 0x7f364cbe9e10>" 
    will be used during training (effective maximum steps = 24) - 
    Parameters : 
    (warmup_steps: null
    warmup_ratio: 0.1
    min_lr: 0.0
    last_epoch: -1
    max_steps: 24
    )


initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1

  | Name  | Type                | Params
----------------------------------------------
0 | model | BertEncoder         | 109 M 
1 | loss  | MultiSimilarityLoss | 0     
----------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
437.929   Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:14:35 entity_linking_model:125] val loss: 1.1948829889297485
Validation sanity check:  50%|█████     | 1/2 [00:00<00:00,  9.94it/s]
[NeMo I 2021-04-15 20:14:35 entity_linking_model:125] val loss: 0.8535466194152832

Epoch 0:  10%|█         | 2/20 [00:00<00:02,  7.68it/s, loss=0.741, v_num=4-25, val_loss=1.020, lr=1e-5]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:14:35 entity_linking_model:125] val loss: 1.009026050567627
Epoch 0:  20%|██        | 4/20 [00:00<00:01, 10.89it/s, loss=0.741, v_num=4-25, val_loss=1.020, lr=1e-5]
[NeMo I 2021-04-15 20:14:35 entity_linking_model:125] val loss: 0.816156268119812
Validating:  67%|██████▋   | 2/3 [00:00<00:00, 15.78it/s]
[NeMo I 2021-04-15 20:14:35 entity_linking_model:125] val loss: 0.7925214171409607
Epoch 0:  30%|███       | 6/20 [00:00<00:01, 12.90it/s, loss=0.741, v_num=4-25, val_loss=0.873, lr=2e-5]

Epoch 0, global step 1: val_loss reached 0.87257 (best 0.87257), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.87-epoch=0.ckpt" as top 3


Epoch 0:  40%|████      | 8/20 [00:04<00:06,  1.87it/s, loss=0.717, v_num=4-25, val_loss=0.873, lr=3e-5]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:14:39 entity_linking_model:125] val loss: 0.9627444744110107
Epoch 0:  50%|█████     | 10/20 [00:04<00:04,  2.26it/s, loss=0.717, v_num=4-25, val_loss=0.873, lr=3e-5]
[NeMo I 2021-04-15 20:14:39 entity_linking_model:125] val loss: 0.6589674949645996
[NeMo I 2021-04-15 20:14:39 entity_linking_model:125] val loss: 0.7004890441894531
Epoch 0:  60%|██████    | 12/20 [00:04<00:03,  2.65it/s, loss=0.717, v_num=4-25, val_loss=0.774, lr=2.98e-5]

Epoch 0, global step 3: val_loss reached 0.77407 (best 0.77407), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.77-epoch=0.ckpt" as top 3


Epoch 0:  70%|███████   | 14/20 [00:10<00:04,  1.28it/s, loss=0.715, v_num=4-25, val_loss=0.774, lr=2.94e-5]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:14:46 entity_linking_model:125] val loss: 0.876510739326477
Epoch 0:  80%|████████  | 16/20 [00:11<00:02,  1.45it/s, loss=0.715, v_num=4-25, val_loss=0.774, lr=2.94e-5]
[NeMo I 2021-04-15 20:14:46 entity_linking_model:125] val loss: 0.628505289554596
[NeMo I 2021-04-15 20:14:46 entity_linking_model:125] val loss: 0.6741818785667419
Epoch 0:  90%|█████████ | 18/20 [00:11<00:01,  1.61it/s, loss=0.715, v_num=4-25, val_loss=0.726, lr=2.86e-5]

Epoch 0, global step 5: val_loss reached 0.72640 (best 0.72640), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.73-epoch=0.ckpt" as top 3


Epoch 0: 100%|██████████| 20/20 [00:18<00:00,  1.10it/s, loss=0.677, v_num=4-25, val_loss=0.726, lr=2.76e-5]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:14:53 entity_linking_model:125] val loss: 0.8661032915115356
Validating:  33%|███▎      | 1/3 [00:00<00:00,  7.76it/s]
[NeMo I 2021-04-15 20:14:53 entity_linking_model:125] val loss: 0.6094290614128113
[NeMo I 2021-04-15 20:14:53 entity_linking_model:125] val loss: 0.6588873267173767
Epoch 0: 100%|██████████| 20/20 [00:18<00:00,  1.08it/s, loss=0.677, v_num=4-25, val_loss=0.711, lr=2.63e-5]

Epoch 0, global step 7: val_loss reached 0.71147 (best 0.71147), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.71-epoch=0.ckpt" as top 3


Epoch 1:  15%|█▌        | 3/20 [00:00<00:01, 10.46it/s, loss=0.672, v_num=4-25, val_loss=0.711, lr=2.48e-5] 
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:15:00 entity_linking_model:125] val loss: 0.8344232439994812
Validating:  33%|███▎      | 1/3 [00:00<00:00,  6.39it/s]
[NeMo I 2021-04-15 20:15:01 entity_linking_model:125] val loss: 0.6012848019599915
[NeMo I 2021-04-15 20:15:01 entity_linking_model:125] val loss: 0.6414815783500671
Epoch 1:  30%|███       | 6/20 [00:00<00:01, 10.35it/s, loss=0.672, v_num=4-25, val_loss=0.692, lr=2.31e-5]

Epoch 1, global step 9: val_loss reached 0.69240 (best 0.69240), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.69-epoch=1.ckpt" as top 3


Epoch 1:  45%|████▌     | 9/20 [00:07<00:09,  1.15it/s, loss=0.637, v_num=4-25, val_loss=0.692, lr=2.12e-5]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:15:08 entity_linking_model:125] val loss: 0.8027772903442383
Validating:  33%|███▎      | 1/3 [00:00<00:00,  9.43it/s]
[NeMo I 2021-04-15 20:15:08 entity_linking_model:125] val loss: 0.581346333026886
[NeMo I 2021-04-15 20:15:08 entity_linking_model:125] val loss: 0.6344283223152161
Epoch 1:  60%|██████    | 12/20 [00:08<00:05,  1.48it/s, loss=0.637, v_num=4-25, val_loss=0.673, lr=1.92e-5]

Epoch 1, global step 11: val_loss reached 0.67285 (best 0.67285), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.67-epoch=1.ckpt" as top 3


Epoch 1:  75%|███████▌  | 15/20 [00:14<00:04,  1.03it/s, loss=0.624, v_num=4-25, val_loss=0.673, lr=1.71e-5]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:15:15 entity_linking_model:125] val loss: 0.782863974571228
Validating:  33%|███▎      | 1/3 [00:00<00:00,  9.53it/s]
[NeMo I 2021-04-15 20:15:15 entity_linking_model:125] val loss: 0.5425866842269897
[NeMo I 2021-04-15 20:15:15 entity_linking_model:125] val loss: 0.6156286597251892
Epoch 1:  90%|█████████ | 18/20 [00:14<00:01,  1.22it/s, loss=0.624, v_num=4-25, val_loss=0.647, lr=1.5e-5]

Epoch 1, global step 13: val_loss reached 0.64703 (best 0.64703), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.65-epoch=1.ckpt" as top 3


Epoch 1: 100%|██████████| 20/20 [00:21<00:00,  1.08s/it, loss=0.609, v_num=4-25, val_loss=0.647, lr=1.29e-5]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:15:22 entity_linking_model:125] val loss: 0.7710549235343933
Validating:  33%|███▎      | 1/3 [00:00<00:00,  9.93it/s]
[NeMo I 2021-04-15 20:15:22 entity_linking_model:125] val loss: 0.5387244820594788
[NeMo I 2021-04-15 20:15:22 entity_linking_model:125] val loss: 0.6098815202713013
Epoch 1: 100%|██████████| 20/20 [00:21<00:00,  1.09s/it, loss=0.609, v_num=4-25, val_loss=0.640, lr=1.08e-5]

Epoch 1, global step 15: val_loss reached 0.63989 (best 0.63989), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.64-epoch=1.ckpt" as top 3


Epoch 2:  15%|█▌        | 3/20 [00:00<00:01, 10.66it/s, loss=0.614, v_num=4-25, val_loss=0.640, lr=8.77e-6] 
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:15:29 entity_linking_model:125] val loss: 0.7613283395767212
Validating:  33%|███▎      | 1/3 [00:00<00:00,  8.70it/s]
[NeMo I 2021-04-15 20:15:29 entity_linking_model:125] val loss: 0.5353438258171082
[NeMo I 2021-04-15 20:15:29 entity_linking_model:125] val loss: 0.606192409992218
Epoch 2:  30%|███       | 6/20 [00:00<00:01, 11.17it/s, loss=0.614, v_num=4-25, val_loss=0.634, lr=6.89e-6]

Epoch 2, global step 17: val_loss reached 0.63429 (best 0.63429), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.63-epoch=2.ckpt" as top 3


Epoch 2:  45%|████▌     | 9/20 [00:08<00:10,  1.08it/s, loss=0.599, v_num=4-25, val_loss=0.634, lr=5.18e-6]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:15:37 entity_linking_model:125] val loss: 0.729590117931366
Validating:  33%|███▎      | 1/3 [00:00<00:00,  8.73it/s]
[NeMo I 2021-04-15 20:15:37 entity_linking_model:125] val loss: 0.5150667428970337
[NeMo I 2021-04-15 20:15:37 entity_linking_model:125] val loss: 0.6025132536888123
Epoch 2:  60%|██████    | 12/20 [00:08<00:05,  1.41it/s, loss=0.599, v_num=4-25, val_loss=0.616, lr=3.66e-6]

Epoch 2, global step 19: val_loss reached 0.61572 (best 0.61572), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.62-epoch=2.ckpt" as top 3


Epoch 2:  75%|███████▌  | 15/20 [00:15<00:05,  1.03s/it, loss=0.569, v_num=4-25, val_loss=0.616, lr=2.38e-6]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:15:44 entity_linking_model:125] val loss: 0.7279607057571411
[NeMo I 2021-04-15 20:15:44 entity_linking_model:125] val loss: 0.5141437649726868
Validating:  67%|██████▋   | 2/3 [00:00<00:00, 15.23it/s]
[NeMo I 2021-04-15 20:15:44 entity_linking_model:125] val loss: 0.6012243032455444
Epoch 2:  90%|█████████ | 18/20 [00:15<00:01,  1.15it/s, loss=0.569, v_num=4-25, val_loss=0.614, lr=1.36e-6]

Epoch 2, global step 21: val_loss reached 0.61444 (best 0.61444), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.61-epoch=2.ckpt" as top 3


Epoch 2: 100%|██████████| 20/20 [00:22<00:00,  1.13s/it, loss=0.552, v_num=4-25, val_loss=0.614, lr=6.08e-7]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/3 [00:00<?, ?it/s]
[NeMo I 2021-04-15 20:15:51 entity_linking_model:125] val loss: 0.7275949716567993
Validating:  33%|███▎      | 1/3 [00:00<00:00,  8.58it/s]
[NeMo I 2021-04-15 20:15:51 entity_linking_model:125] val loss: 0.5138847231864929
[NeMo I 2021-04-15 20:15:51 entity_linking_model:125] val loss: 0.601005494594574
Epoch 2: 100%|██████████| 20/20 [00:22<00:00,  1.14s/it, loss=0.552, v_num=4-25, val_loss=0.614, lr=1.53e-7]

Epoch 2, global step 23: val_loss reached 0.61416 (best 0.61416), saving model to "/home/vadams/Projects/entity-linking-research/NeMo/tutorials/nlp/SelfAlignmentPretrainingTinyExample/2021-04-15_20-14-25/checkpoints/SelfAlignmentPretrainingTinyExample---val_loss=0.61-epoch=2-v1.ckpt" as top 3


Epoch 2: 100%|██████████| 20/20 [00:25<00:00,  1.30s/it, loss=0.552, v_num=4-25, val_loss=0.614, lr=1.53e-7]

Saving latest checkpoint...


Epoch 2: 100%|██████████| 20/20 [00:29<00:00,  1.48s/it, loss=0.552, v_num=4-25, val_loss=0.614, lr=1.53e-7]
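The `lr` values in the progress bars above are consistent with a short linear warmup followed by cosine decay to `min_lr: 0.0`. The sketch below reproduces the sequence of logged values under our own assumptions: 63 training examples at `batch_size: 8` give 8 optimizer steps per epoch, 3 epochs give the 24 steps reported by the scheduler, and `warmup_ratio: 0.1` truncates to 2 warmup steps. It is a reconstruction for illustration, not the source of NeMo's `CosineAnnealing` scheduler.

```python
import math

def cosine_lr(step, max_lr=3e-5, max_steps=24, warmup_ratio=0.1, min_lr=0.0):
    """Learning rate at a 0-indexed optimizer step (reconstruction, not NeMo code)."""
    warmup_steps = int(warmup_ratio * max_steps)  # int(2.4) == 2
    if step < warmup_steps:
        # Linear warmup towards max_lr
        return max_lr * (step + 1) / (warmup_steps + 1)
    # Cosine decay from max_lr to min_lr over the remaining steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

for step in range(24):
    print(step, f"{cosine_lr(step):.3g}")
```

The printed sequence starts 1e-05, 2e-05, 3e-05, 2.98e-05 and ends 6.08e-07, 1.53e-07, matching the progress bars. Because the schedule is tied to `max_steps`, changing the batch size or number of epochs changes how quickly the learning rate decays.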


To train a model on a larger dataset, run the script at `examples/nlp/entity_linking/self_alignment_pretraining.py` from the `examples/nlp/entity_linking` directory:

```
python self_alignment_pretraining.py
```

## Model Evaluation

Let's evaluate our freshly trained model and compare its performance with a BERT Base encoder that hasn't undergone self-alignment pretraining. We first need to restore our trained model and load our BERT Base Baseline model.

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Restore second stage pretrained model
sap_model_cfg = cfg
sap_model = nemo_nlp.models.EntityLinkingModel.restore_from(sap_model_cfg.model.nemo_path).to(device)

# Load original model
base_model_cfg = OmegaConf.load("tiny_example_entity_linking_config.yaml")

# Set train/val datasets to None to avoid loading datasets associated with training
base_model_cfg.model.train_ds = None
base_model_cfg.model.validation_ds = None
base_model_cfg.index.index_save_name = "base_model_index"
base_model = nemo_nlp.models.EntityLinkingModel(base_model_cfg.model).to(device)

[NeMo W 2021-04-15 20:16:31 modelPT:133] Please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    data_file: tiny_example_data/tiny_example_train_pairs.tsv
    max_seq_length: 128
    batch_size: 8
    shuffle: true
    num_workers: 2
    pin_memory: false
    drop_last: false
    
[NeMo W 2021-04-15 20:16:31 modelPT:140] Please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    data_file: tiny_example_data/tiny_example_validation_pairs.tsv
    max_seq_length: 128
    batch_size: 8
    shuffle: false
    num_workers: 2
    pin_memory: false
    drop_last: false
    
[NeMo W 2021-04-15 20:16:31 modelPT:1137] World size can only be set by PyTorch Lightning Trainer.


[NeMo I 2021-04-15 20:16:35 modelPT:376] Model EntityLinkingModel was successfully restored from tiny_example_sap_bert_model.nemo.


[NeMo W 2021-04-15 20:16:37 modelPT:1137] World size can only be set by PyTorch Lightning Trainer.


We are going to evaluate our models on a nearest neighbor task using top-1 and top-5 accuracy as our metrics. We will use a tiny example test knowledge base and a set of test queries. For this evaluation, we compare every test query with every concept vector in the test knowledge base and rank each knowledge base item by its cosine similarity with the query. We then compare the IDs of the most similar knowledge base concepts with the ground truth query IDs to calculate top-1 and top-5 accuracy: a query counts as correctly linked at k if its ground truth ID appears among its k highest-ranked concepts. For both metrics, higher is better.
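To make the metric concrete, here is a small worked example on a hand-made score matrix. The scores and IDs are hypothetical, and `topk_accuracy` is an illustrative helper that uses the same `np.argpartition` selection as the `evaluate` function below.

```python
import numpy as np

def topk_accuracy(score_matrix, query_ids, kb_ids, k):
    """Fraction of queries whose true ID is among the k highest-scoring kb entries."""
    hits = 0
    for scores, qid in zip(score_matrix, query_ids):
        topk = np.argpartition(scores, -k)[-k:]  # indices of the k largest scores (unordered)
        hits += int(qid in {kb_ids[j] for j in topk})
    return hits / len(query_ids)

# Hypothetical similarity scores: 3 queries x 4 knowledge base entries
scores = np.array([
    [0.9, 0.1, 0.3, 0.2],  # truth "A" ranked 1st
    [0.2, 0.8, 0.7, 0.1],  # truth "C" ranked 2nd
    [0.4, 0.3, 0.2, 0.1],  # truth "D" ranked last
])
kb_ids = ["A", "B", "C", "D"]
query_ids = ["A", "C", "D"]
print(topk_accuracy(scores, query_ids, kb_ids, 1), topk_accuracy(scores, query_ids, kb_ids, 2))
```

Here top-1 accuracy is 1/3 (only the first query is linked correctly) and top-2 accuracy is 2/3 (the second query's truth is its runner-up). `np.argpartition` only guarantees that the last k positions hold the k largest scores, not that they are sorted, which is sufficient because the check is set membership.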

In [9]:
# Helper function to get data embeddings
def get_embeddings(model, dataloader):
    embeddings, cids = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids, token_type_ids, attention_mask, batch_cids = batch
            batch_embeddings = model.forward(input_ids=input_ids.to(device), 
                                             token_type_ids=token_type_ids.to(device), 
                                             attention_mask=attention_mask.to(device))

            # Accumulate index embeddings and their corresponding IDs
            embeddings.extend(batch_embeddings.cpu().detach().numpy())
            cids.extend(batch_cids)
            
    return embeddings, cids

In [10]:
def evaluate(model, test_kb, test_queries, ks):
    # Initialize knowledge base and query data loaders
    test_kb_dataloader = model.setup_dataloader(test_kb, is_index_data=True)
    test_query_dataloader = model.setup_dataloader(test_queries, is_index_data=True)

    # Get knowledge base and query embeddings
    test_kb_embs, test_kb_cids = get_embeddings(model, test_kb_dataloader)
    test_query_embs, test_query_cids = get_embeddings(model, test_query_dataloader)

    # Score every query against every knowledge base concept. This dot product
    # equals cosine similarity only if the embeddings are L2-normalized.
    score_matrix = np.matmul(np.array(test_query_embs), np.array(test_kb_embs).T)
    accs = {k: 0 for k in ks}

    # Compare each query's ground truth ID with the IDs of the k knowledge base
    # concepts that have the highest similarity scores
    for query_idx in tqdm(range(len(test_query_cids))):
        query_cid = test_query_cids[query_idx]
        query_scores = score_matrix[query_idx]

        for k in ks:
            # Indices of the k highest scores (in no particular order)
            topk_idxs = np.argpartition(query_scores, -k)[-k:]
            topk_cids = [test_kb_cids[idx] for idx in topk_idxs]

            # If the correct query ID is among the top k closest kb IDs,
            # the model correctly linked the entity
            match = int(query_cid in topk_cids)
            accs[k] += match

    for k in ks:
        accs[k] /= len(test_query_cids)

    return accs

In [11]:
# Create configs for our test data
test_kb = OmegaConf.create({
    "data_file": os.path.join(DATA_DIR, "tiny_example_test_kb.tsv"),
    "max_seq_length": 128,
    "batch_size": 10,
    "shuffle": False,
})

test_queries = OmegaConf.create({
    "data_file": os.path.join(DATA_DIR, "tiny_example_test_queries.tsv"),
    "max_seq_length": 128,
    "batch_size": 10,
    "shuffle": False,
})

ks = [1, 5]

# Evaluate both models on our test data
base_accs = evaluate(base_model, test_kb, test_queries, ks)
base_accs["Model"] = "BERT Base Baseline"

sap_accs = evaluate(sap_model, test_kb, test_queries, ks)
sap_accs["Model"] = "BERT + SAP"

print("Top 1 and Top 5 Accuracy Comparison:")
results_df = pd.DataFrame([base_accs, sap_accs], columns=["Model", 1, 5])
results_df = results_df.style.set_properties(**{'text-align': 'left', }).set_table_styles([dict(selector='th', props=[('text-align', 'left')])])
display(results_df)

[NeMo I 2021-04-15 20:16:41 entity_linking_dataset:61] Getting datafile newline indices
[NeMo I 2021-04-15 20:16:41 entity_linking_dataset:78] Loaded dataset with 22 examples
[NeMo I 2021-04-15 20:16:41 entity_linking_dataset:61] Getting datafile newline indices
[NeMo I 2021-04-15 20:16:41 entity_linking_dataset:78] Loaded dataset with 10 examples


100%|██████████| 3/3 [00:00<00:00, 12.43it/s]
100%|██████████| 1/1 [00:00<00:00,  6.64it/s]
100%|██████████| 10/10 [00:00<00:00, 7436.71it/s]

[NeMo I 2021-04-15 20:16:41 entity_linking_dataset:78] Loaded dataset with 22 examples
[NeMo I 2021-04-15 20:16:41 entity_linking_dataset:78] Loaded dataset with 10 examples



100%|██████████| 3/3 [00:00<00:00, 12.69it/s]
100%|██████████| 1/1 [00:00<00:00,  5.67it/s]
100%|██████████| 10/10 [00:00<00:00, 7913.78it/s]


Top 1 and Top 5 Accuracy Comparison:


| | Model | Top 1 accuracy | Top 5 accuracy |
|---|---|---|---|
| 0 | BERT Base Baseline | 0.7 | 1.0 |
| 1 | BERT + SAP | 0.9 | 1.0 |


The purpose of this section was to show an example of evaluating your entity linking model. This evaluation set contains very little data, so no serious conclusions should be drawn about model performance. Top 1 accuracy should be between 0.7 and 1.0 for both models, and top 5 accuracy should be between 0.9 and 1.0. When evaluating a model trained on a larger dataset, you can use a nearest neighbors index to speed up evaluation.
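To illustrate why an index speeds up search, here is a toy, NumPy-only sketch of the inverted-file (IVF) idea behind `faiss.IndexIVFFlat`, which `build_index` in this notebook uses. It is not faiss's implementation; `ToyIVF`, `kmeans` and `sq_dists` are hypothetical names for illustration. Vectors are clustered around `nlist` centroids, and a query scans only the vectors in its `nprobe` nearest clusters instead of the whole knowledge base (faiss IVF indexes expose a similar `nprobe` setting).

```python
import numpy as np

def sq_dists(a, b):
    """Squared L2 distances between every row of a and every row of b."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)

def kmeans(x, n_clusters, n_iter=10, seed=0):
    """Plain Lloyd's k-means, used here to train the coarse quantizer."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), n_clusters, replace=False)].copy()
    for _ in range(n_iter):
        assign = sq_dists(x, centroids).argmin(axis=1)
        for c in range(n_clusters):
            members = x[assign == c]
            if len(members):  # keep the old centroid if a cluster becomes empty
                centroids[c] = members.mean(axis=0)
    return centroids

class ToyIVF:
    def __init__(self, nlist, nprobe=1):
        self.nlist, self.nprobe = nlist, nprobe

    def train(self, x):
        self.centroids = kmeans(x, self.nlist)

    def add(self, x):
        # Store the vectors and bucket each one under its nearest centroid
        self.vectors = x
        assign = sq_dists(x, self.centroids).argmin(axis=1)
        self.lists = [np.flatnonzero(assign == c) for c in range(self.nlist)]

    def search(self, queries, k):
        """Return per-query arrays of (squared distances, ids) for up to k neighbors."""
        all_d, all_i = [], []
        for q in queries:
            # Probe only the nprobe clusters whose centroids are closest to q
            probe = np.argsort(sq_dists(q[None], self.centroids)[0])[: self.nprobe]
            cand = np.concatenate([self.lists[c] for c in probe])
            d = sq_dists(q[None], self.vectors[cand])[0]
            order = np.argsort(d)[:k]
            all_d.append(d[order])
            all_i.append(cand[order])
        return all_d, all_i

rng = np.random.default_rng(0)
kb = rng.normal(size=(200, 16))
index = ToyIVF(nlist=8, nprobe=2)
index.train(kb)
index.add(kb)
_, ids = index.search(kb[:3], k=1)
print([int(i[0]) for i in ids])  # each stored vector finds itself
```

With `nprobe` equal to `nlist`, every list is scanned and the result matches exhaustive search; smaller `nprobe` values trade some recall for speed.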

## Building an Index

To qualitatively observe the improvement we gain from the second stage pretraining, let's build two indices: one with BERT base embeddings from before self-alignment pretraining, and one with embeddings from the model we just trained. The knowledge base in this tutorial is in the same domain as the training set and shares some concepts with it. Each line of the data file is formatted as `ID\tconcept`.

The `EntityLinkingDataset` class can load the data used for training the entity linking encoder, as well as the data for building the index if the `is_index_data` flag is set to `True`.

In [12]:
def build_index(cfg, model):
    # Setup index dataset loader
    index_dataloader = model.setup_dataloader(cfg.index.index_ds, is_index_data=True)
    
    # Get index dataset embeddings
    embeddings, _ = get_embeddings(model, index_dataloader)
    
    # Train IVFFlat index using faiss
    embeddings = np.array(embeddings)
    quantizer = faiss.IndexFlatL2(cfg.index.dims)
    index = faiss.IndexIVFFlat(quantizer, cfg.index.dims, cfg.index.nlist)
    index = faiss.index_cpu_to_all_gpus(index)
    index.train(embeddings)
    
    # Add concept embeddings to index
    for i in tqdm(range(0, embeddings.shape[0], cfg.index.index_batch_size)):
        index.add(embeddings[i:i+cfg.index.index_batch_size])

    # Save index
    faiss.write_index(faiss.index_gpu_to_cpu(index), cfg.index.index_save_name)

In [13]:
build_index(sap_model_cfg, sap_model.to(device))
build_index(base_model_cfg, base_model.to(device))

[NeMo I 2021-04-15 20:16:42 entity_linking_dataset:61] Getting datafile newline indices
[NeMo I 2021-04-15 20:16:42 entity_linking_dataset:78] Loaded dataset with 12 examples


100%|██████████| 1/1 [00:00<00:00,  4.74it/s]
100%|██████████| 2/2 [00:00<00:00, 2355.69it/s]

[NeMo I 2021-04-15 20:16:43 entity_linking_dataset:78] Loaded dataset with 12 examples



100%|██████████| 1/1 [00:00<00:00,  6.22it/s]
100%|██████████| 2/2 [00:00<00:00, 1380.84it/s]


## Entity Linking via Nearest Neighbor Search

Now it's time to query our indices! We will query both the index built with BERT Base embeddings and the index built with embeddings from the SAP BERT model we trained. Our sample query phrases are "*high blood sugar*" and "*head pain*".

To query an index, we first need to get the embedding of each query from the corresponding encoder model. We then pass these query embeddings to the faiss index, which performs a nearest neighbor search comparing each query embedding with the concept embeddings stored in the index. Once we have the IDs of the knowledge base concepts that most closely match each query, all that is left to do is map those IDs to representative strings describing the concepts.
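The faiss index built in this notebook compares vectors by squared L2 distance (an `IndexFlatL2` quantizer with the default metric), not by cosine similarity directly. For unit-length vectors, ||q - x||² = 2 - 2·cos(q, x), so ranking by squared L2 distance gives the same order as ranking by cosine similarity. The NumPy check below illustrates this with made-up embeddings; `top_k_both_ways` and `toy_id2string` are hypothetical names, not part of NeMo or faiss.

```python
import numpy as np

def top_k_both_ways(kb, query, k):
    """Top-k KB row positions ranked by squared L2 distance and by cosine similarity."""
    sq_l2 = ((kb - query) ** 2).sum(axis=1)  # squared L2, as faiss L2 indexes report
    cos = kb @ query / (np.linalg.norm(kb, axis=1) * np.linalg.norm(query))
    return np.argsort(sq_l2)[:k], np.argsort(-cos)[:k]

# Made-up stand-ins for encoder outputs, normalized to unit length
rng = np.random.default_rng(0)
kb = rng.normal(size=(12, 8))
kb /= np.linalg.norm(kb, axis=1, keepdims=True)
query = rng.normal(size=8)
query /= np.linalg.norm(query)

by_l2, by_cos = top_k_both_ways(kb, query, k=3)
print(by_l2, by_cos)  # the same positions in the same order for unit vectors

# Hypothetical position -> concept string map, playing the role of id2string
toy_id2string = {i: f"concept_{i}" for i in range(len(kb))}
print([toy_id2string[int(i)] for i in by_l2])
```

One consequence, assuming the encoder's embeddings are unit length: the `1 - dist` values printed by `query_index` equal 2·cos − 1 rather than the cosine similarity itself, which is why they can fall below zero.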

In [14]:
def query_index(cfg, model, index, queries, id2string):
    # Get query embeddings from our entity linking encoder model
    query_embs = get_query_embedding(queries, model).cpu().detach().numpy()
    
    # Use query embedding to find closest concept embedding in knowledge base
    distances, neighbors = index.search(query_embs, cfg.index.top_n)
    
    # Get the canonical strings corresponding to the IDs of the query's nearest neighbors in the kb 
    neighbor_concepts = [[id2string[concept_id] for concept_id in query_neighbor] \
                                                for query_neighbor in neighbors]
    
    # Display most similar concepts in the knowledge base. 
    for query_idx in range(len(queries)):
        print(f"\nThe most similar concepts to {queries[query_idx]} are:")
        for cid, concept, dist in zip(neighbors[query_idx], neighbor_concepts[query_idx], distances[query_idx]):
            print(cid, concept, 1 - dist)

    
def get_query_embedding(queries, model):
    # Tokenize our queries
    model_input =  model.tokenizer(queries,
                                   add_special_tokens = True,
                                   padding = True,
                                   truncation = True,
                                   max_length = 512,
                                   return_token_type_ids = True,
                                   return_attention_mask = True)
    
    # Pass tokenized input into model (no gradients are needed for inference)
    with torch.no_grad():
        query_emb = model.forward(input_ids=torch.LongTensor(model_input["input_ids"]).to(device),
                                  token_type_ids=torch.LongTensor(model_input["token_type_ids"]).to(device),
                                  attention_mask=torch.LongTensor(model_input["attention_mask"]).to(device))
    
    return query_emb

In [15]:
# Load indices
sap_index = faiss.read_index(sap_model_cfg.index.index_save_name)
base_index = faiss.read_index(base_model_cfg.index.index_save_name)

In [16]:
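# Map concept IDs to one canonical string. faiss assigns sequential IDs
# starting at 0 in the order embeddings were added, so this mapping assumes
# the file's IDs run from 1 to N in the same order and shifts them down by one.
id2string = {}

with open(sap_model_cfg.index.index_ds.data_file, "r", encoding='utf-8-sig') as index_data:
    for line in index_data:
        cid, concept = line.split("\t")
        id2string[int(cid) - 1] = concept.strip()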

In [17]:
id2string

{0: 'Headache',
 1: 'Myocardial infraction',
 2: 'Coronary artery disease',
 3: 'myocardial ischemia',
 4: 'chronic kidney disease',
 5: 'alchohol intoxication',
 6: 'diabetes',
 7: 'Hyperinsulinemia',
 8: 'Nesina',
 9: 'hypoglycemia',
 10: 'anticoagulants',
 11: 'Ibuprofen'}

In [18]:
# Some sample queries
queries = ["high blood sugar", "head pain"]

# Query BERT Base
print("BERT Base output before Self Alignment Pretraining:")
query_index(base_model_cfg, base_model, base_index, queries, id2string)
print("\n" + "-" * 50 + "\n")

# Query SAP BERT
print("SAP BERT output after Self Alignment Pretraining:")
query_index(sap_model_cfg, sap_model, sap_index, queries, id2string)
print("\n" + "-" * 50 + "\n")

BERT Base output before Self Alignment Pretraining:

The most similar concepts to high blood sugar are:
6 diabetes 0.9095035567879677
0 Headache 0.9046077281236649
8 Nesina 0.8512845635414124

The most similar concepts to head pain are:
1 Myocardial infraction 0.7848672568798065
4 chronic kidney disease 0.7667323648929596
3 myocardial ischemia 0.761662557721138

--------------------------------------------------

SAP BERT output after Self Alignment Pretraining:

The most similar concepts to high blood sugar are:
7 Hyperinsulinemia 0.3652629256248474
9 hypoglycemia 0.27968084812164307
3 myocardial ischemia 0.24829477071762085

The most similar concepts to head pain are:
0 Headache 0.6382399797439575
6 diabetes 0.18829917907714844
8 Nesina -0.005720615386962891

--------------------------------------------------



Even after training on only this tiny amount of data, the qualitative performance boost from self-alignment pretraining is visible. The baseline model links "*high blood sugar*" to the entity "*diabetes*" (ID 6), while our SAP BERT model accurately links "*high blood sugar*" to "*Hyperinsulinemia*". Similarly, the baseline links "*head pain*" to "*Myocardial infraction*", which is not the same concept, whereas SAP BERT links it to "*Headache*", which is.

For larger knowledge bases, keeping the default embedding size may require too much memory and cause out-of-memory errors. You can apply PCA or another dimensionality reduction method to the embeddings to reduce their memory footprint. Code for creating a text file of all the UMLS entities in the format needed to build an index, and for creating a dictionary that maps concept IDs to canonical concept strings, can be found in `examples/nlp/entity_linking/data/umls_dataset_processing.py`.
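As a rough illustration of the dimensionality reduction step, here is a minimal PCA sketch in NumPy (via SVD). It assumes the embeddings fit in memory as a dense array; `fit_pca` and `apply_pca` are hypothetical helpers and not the code in `build_and_query_index.py`. faiss also offers a `PCAMatrix` transform for the same purpose.

```python
import numpy as np

def fit_pca(x, n_components):
    """Return (mean, components) for projecting rows of x onto the top principal axes."""
    mean = x.mean(axis=0)
    # Rows of vt are principal directions, sorted by decreasing singular value
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    return mean, vt[:n_components]

def apply_pca(x, mean, components):
    """Project x into the reduced space; queries must use the same mean and components."""
    return (x - mean) @ components.T

# Random stand-ins for 768-dimensional (BERT-base-sized) concept embeddings
rng = np.random.default_rng(0)
embs = rng.normal(size=(1000, 768)).astype(np.float32)
mean, comps = fit_pca(embs, n_components=128)
reduced = apply_pca(embs, mean, comps)
print(embs.shape, "->", reduced.shape)
```

Fitting once and reusing the same `mean` and `components` for both knowledge base and query embeddings keeps them in the same space. If the index compares vectors by cosine similarity or dot product, you may want to re-normalize the reduced vectors, since the projection changes their lengths.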

The code for extracting knowledge base concept embeddings, training and applying a PCA transformation to the embeddings, building a faiss index and querying the index from the command line is located at `examples/nlp/entity_linking/build_and_query_index.py`. 

If you've cloned the NeMo repo, both of these steps can be run as follows on the command line from the `examples/nlp/entity_linking/` directory.

```
python data/umls_dataset_processing.py --index --cfg conf/medical_entity_linking_config.yaml
python build_and_query_index.py --restore --cfg conf/medical_entity_linking_config.yaml --top_n 5 
```
Intermediate steps of the index building process are saved, so if an error occurs, previously completed steps do not need to be rerun.

## Command Recap

Here is a recap of the commands and steps to repeat this process on the full UMLS dataset. 

1) Download the UMLS dataset file `MRCONSO.RRF` from the NIH website and place it in the `examples/nlp/entity_linking/data` directory.

2) Run the following commands from the `examples/nlp/entity_linking` directory:
```
python data/umls_dataset_processing.py --cfg conf/umls_medical_entity_linking_config.yaml
python self_alignment_pretraining.py
python data/umls_dataset_processing.py --index --cfg conf/umls_medical_entity_linking_config.yaml
python build_and_query_index.py --restore --cfg conf/umls_medical_entity_linking_config.yaml --top_n 5
```
Training takes roughly 24 hours on two GPUs and roughly 48 hours on one GPU.

As mentioned in the introduction, entity linking within NVIDIA NeMo is not limited to the medical domain. The same data processing and training steps can be applied to a variety of domains and use cases. You can edit the datasets used as well as training and loss function hyperparameters within your config file to better suit your domain.