# Entity Embedding Tutorial

In this tutorial, we walk through how to generate Bootleg contextual entity embeddings for use in downstream tasks using a pretrained Bootleg model. We also demonstrate how to extract other Bootleg embeddings for downstream tasks when contextualized embeddings are not needed.

### Requirements

You will need to download the following files for this notebook:
- Pretrained Bootleg uncased model and config [here](https://bootleg-data.s3.amazonaws.com/models/lateset/bootleg_uncased.tar.gz). Cased model and config [here](https://bootleg-data.s3.amazonaws.com/models/lateset/bootleg_cased.tar.gz)
- Sample of Natural Questions with hand-labelled entities [here](https://bootleg-data.s3.amazonaws.com/data/lateset/nq.tar.gz)
- Entity data [here](https://bootleg-data.s3.amazonaws.com/data/lateset/wiki_entity_data.tar.gz)
- Embedding data [here](https://bootleg-data.s3.amazonaws.com/data/lateset/emb_data.tar.gz)

For convenience, you can run the commands below (from the root directory of the repo) to download all the above files and unpack them to `models` and `data` directories. It will take several minutes to download all the files.

```
    # use cased for cased model
    bash tutorials/download_model.sh uncased
    bash tutorials/download_data.sh
```

You can also run the download scripts directly from this notebook:

In [None]:
!sh download_model.sh uncased
!sh download_data.sh
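
Once the downloads finish, it can help to confirm that the files landed where the rest of this tutorial expects them. The sketch below is an illustrative check and not part of Bootleg: `missing_paths` and `EXPECTED_PATHS` are hypothetical names, and the paths are simply the ones used in the config cell of the next section.

```python
from pathlib import Path

# Paths this tutorial reads, relative to the directory that holds `models` and `data`.
EXPECTED_PATHS = [
    "models/bootleg_uncased/bootleg_config.yaml",
    "models/bootleg_uncased/bootleg_wiki.pth",
    "data/nq",
    "data/wiki_entity_data",
    "data/emb_data",
]


def missing_paths(root_dir, expected=EXPECTED_PATHS):
    """Return the expected paths that do not exist under root_dir."""
    root = Path(root_dir)
    return [p for p in expected if not (root / p).exists()]


# An empty list means everything the tutorial needs is in place.
print(missing_paths("."))
```

If the list is not empty, rerun the download scripts or adjust `root_dir` before continuing.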

## 1.  Prepare Model Config

As with the other tutorials, we set up the config to point to the correct data directories and model checkpoint. We use the sample of [Natural Questions](https://ai.google.com/research/NaturalQuestions) introduced in the End-to-End tutorial, with mentions already extracted by Bootleg.

In [1]:
from pathlib import Path

# set up logging
import sys
import logging
from importlib import reload
reload(logging)
logging.basicConfig(stream=sys.stdout, format='%(asctime)s %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)


If you have a GPU with at least 12GB of memory available, set `device` to 0 below to run inference on the GPU; leave it at -1 to run on the CPU.

In [2]:
device = -1
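
If you would rather not set the value by hand, a small helper can pick it automatically. This is a sketch and not part of Bootleg: `pick_device` is a hypothetical name, it only checks whether PyTorch can see a CUDA device, and it does not verify the 12GB memory requirement mentioned above.

```python
def pick_device():
    """Return 0 (first GPU) if PyTorch can see a CUDA device, else -1 (CPU)."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is nothing to probe, so fall back to CPU.
        return -1
    return 0 if torch.cuda.is_available() else -1


device = pick_device()
print(device)
```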

Next, specify the root directory where the files were downloaded (`root_dir` below).

In [7]:
from bootleg.utils.parser.parser_utils import parse_boot_and_emm_args
from bootleg.utils.utils import load_yaml_file
from bootleg.run import run_model

# root_dir = FILL IN FULL PATH TO DIRECTORY WHERE DATA IS DOWNLOADED (e.g., root_dir/data and root_dir/models)
root_dir = Path(".")

config_in_path = str(root_dir / 'models/bootleg_uncased/bootleg_config.yaml')
data_dir =  str(root_dir / 'data/nq')
entity_dir = str(root_dir / 'data/wiki_entity_data')
alias_map = "alias2qids_wiki_filt.json"
test_file = "test_bootleg_men.jsonl"

config_args = load_yaml_file(config_in_path)

# decrease number of data threads as this is a small file
config_args["run_config"]["dataset_threads"] = 2
# set the model checkpoint path 
config_args["emmental"]["model_path"] = str(root_dir / 'models/bootleg_uncased/bootleg_wiki.pth')

# set the path for the entity db and candidate map
config_args["data_config"]["entity_dir"] = entity_dir
config_args["data_config"]["alias_cand_map"] = alias_map

# set the data path (the directory that contains the test file)
config_args["data_config"]["data_dir"] = data_dir

# to speed things up for the tutorial, we have already prepped the data with the mentions detected by Bootleg
config_args["data_config"]["test_dataset"]["file"] = test_file

# set the embedding paths 
config_args["data_config"]["emb_dir"] =  str(root_dir / 'data/emb_data')
config_args["data_config"]["word_embedding"]["cache_dir"] =  str(root_dir / 'data/embs/pretrained_bert_models')

# set the device (-1 runs on CPU)
config_args["emmental"]["device"] = device

config_args = parse_boot_and_emm_args(config_args) # or you can pass in the config_out_path

## 2. Load Contextual Entity Embeddings

We now show how Bootleg contextualized embeddings can be loaded and used in downstream tasks. First we use the `dump_embs` mode to generate contextual entity embeddings. 

In [8]:
bootleg_label_file, bootleg_emb_file = run_model(mode="dump_embs", config=config_args)

2021-01-28 14:00:56,371 Logging was already initialized to use bootleg_logs/wiki_full_ft/2021_01_28/12_20_50/bc3d0092.  To configure logging manually, call emmental.init_logging before initialiting Meta.
2021-01-28 14:00:56,424 Loading Emmental default config from /dfs/scratch0/lorr1/projects/emmental/src/emmental/emmental-default-config.yaml.
2021-01-28 14:00:56,425 Updating Emmental config from user provided config.
2021-01-28 14:00:56,549 COMMAND: /dfs/scratch0/lorr1/env_bootleg_38/lib/python3.8/site-packages/ipykernel_launcher.py -f /dfs/scratch0/lorr1/projects/bootleg/notebooks/:/afs/cs.stanford.edu/u/lorr1/.local/apt-cache/share/jupyter/runtime/kernel-4a75c8f6-3129-4873-a5c1-0a51ed79b2fe.json
2021-01-28 14:00:56,550 Saving config to bootleg_logs/wiki_full_ft/2021_01_28/12_20_50/bc3d0092/parsed_config.yaml
2021-01-28 14:00:57,287 Git Hash: Not able to retrieve git hash
2021-01-28 14:00:57,288 Loading entity symbols...
2021-01-28 14:03:39,383 Starting to build data for test from /d


2021-01-28 14:03:39,622 Built dataloader for test set with 2540 and 1 threads samples (Shuffle=False, Batch size=32).
2021-01-28 14:03:39,630 Building slice dataset for test from /dfs/scratch0/lorr1/projects/bootleg-data/data/benchmarks/aida_0928_nosep/filtered/test_bootleg_men.jsonl.
2021-01-28 14:03:39,686 Loading data from /dfs/scratch0/lorr1/projects/bootleg-data/data/benchmarks/aida_0928_nosep/filtered/prep/test_bootleg_men_bert-base-cased_L100_A10_InC1_Aug1/ned_slices_1f126b5224.bin and /dfs/scratch0/lorr1/projects/bootleg-data/data/benchmarks/aida_0928_nosep/filtered/prep/test_bootleg_men_bert-base-cased_L100_A10_InC1_Aug1/ned_slices_config.json


Building sent idx to row idx mapping: 100%|██████████| 2465/2465 [00:00<00:00, 16114.36it/s]


2021-01-28 14:03:39,901 Final slice data initialization time from test is 0.2705113887786865s
2021-01-28 14:03:39,902 Updating Emmental config from user provided config.
2021-01-28 14:03:39,908 Starting Bootleg Model
2021-01-28 14:03:39,909 Created emmental model Bootleg that contains task set().
2021-01-28 14:03:46,040 Loading embeddings...
2021-01-28 14:03:46,041 Embedding "learned" has params (these can be changed in the config)
{
    "load_class":"LearnedEntityEmb",
    "key":"learned",
    "cpu":false,
    "freeze":false,
    "dropout1d":0.0,
    "dropout2d":0.0,
    "normalize":true,
    "send_through_bert":false
}
2021-01-28 14:04:36,082 Embedding "title_static" has params (these can be changed in the config)
{
    "load_class":"StaticEmb",
    "key":"title_static",
    "cpu":false,
    "freeze":false,
    "dropout1d":0.0,
    "dropout2d":0.0,
    "normalize":true,
    "send_through_bert":false
}
2021-01-28 14:06:57,154 Embedding "learned_type" has params (these can be changed i

Evaluating Bootleg (test): 100%|██████████| 80/80 [05:15<00:00,  3.94s/it]
100%|██████████| 2540/2540 [00:00<00:00, 4134.38it/s]
Reading values for marisa trie: 100%|██████████| 2465/2465 [00:00<00:00, 360368.05it/s]


2021-01-28 14:15:11,440 Merging sentences together with 2 processes


Building sent_idx, alias_list_pos mapping: 100%|██████████| 11360/11360 [00:00<00:00, 97337.90it/s] 
Reading values for marisa trie: 100%|██████████| 11360/11360 [00:00<00:00, 401629.31it/s]
Writing data: 100%|██████████| 1233/1233 [00:03<00:00, 400.69it/s]
Writing data: 100%|██████████| 1232/1232 [00:00<00:00, 5720.78it/s]


2021-01-28 14:16:48,177 Merging output files
2021-01-28 14:16:50,036 Saving contextual entity embeddings to bootleg_logs/wiki_full_ft/2021_01_28/12_20_50/bc3d0092/test_bootleg_men/bootleg_wiki_1/bootleg_embs.npy
2021-01-28 14:16:50,038 Wrote predictions to bootleg_logs/wiki_full_ft/2021_01_28/12_20_50/bc3d0092/test_bootleg_men/bootleg_wiki_1/bootleg_labels.jsonl


In `dump_embs` mode, Bootleg saves the contextual entity embedding for each mention in each sentence to a file. The path to this file is returned in the variable `bootleg_emb_file`. The full file path also appears in the log (it ends in `.npy`).

In [9]:
import numpy as np
contextual_entity_embs = np.load(bootleg_emb_file)
contextual_entity_embs.shape

(11360, 512)

Each row in the contextual entity embedding above corresponds to an extracted mention in a sentence. The array above holds 11,360 extracted mentions in total, each with a 512-dimensional contextual entity embedding.

The mapping from mentions to rows in the contextual entity embedding is stored in `ctx_emb_ids` in the label file. We now check out the label file, which was also generated and returned from running `dump_embs` mode.

In [10]:
import jsonlines
with jsonlines.open(bootleg_label_file) as f: 
    for i, line in enumerate(f): 
        print('sentence:', line['sentence'])
        print('mentions:', line['aliases'])
        print('contextual emb ids:', line['ctx_emb_ids'])
        print()
        if i == 5: 
            break

sentence: soccer - japan get lucky win , china in surprise defeat . soccer - japan get lucky win , china in surprise defeat .
mentions: ['soccer', 'japan', 'china', 'soccer', 'japan', 'china']
contextual emb ids: [0, 1, 2, 3, 4, 5]

sentence: soccer - japan get lucky win , china in surprise defeat . al-ain , United Arab Emirates 1996-12-06
mentions: ['soccer', 'japan', 'china', 'alain', 'united arab emirates']
contextual emb ids: [6, 7, 8, 9, 10]

sentence: soccer - japan get lucky win , china in surprise defeat . Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday .
mentions: ['soccer', 'japan', 'china', 'japan', 'asian cup', 'syria', 'group c', 'championship match', 'friday']
contextual emb ids: [11, 12, 13, 14, 15, 16, 17, 18, 19]

sentence: soccer - japan get lucky win , china in surprise defeat . But China saw their luck desert them in the second match of the group , crashing to a surprise 2-0 defeat to newc

In the first sentence, we can find the contextual entity embeddings for the first three mentions, "soccer", "japan", and "china", in rows 0, 1, and 2 of `contextual_entity_embs`, respectively. The repeated mentions of the same strings get their own rows (3, 4, and 5), and the mentions in each of the other sentences likewise have unique row ids. A downstream task can use this mapping to load the correct contextual entity embedding for each mention in a simple dataloader.
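
The mapping described above can be flattened into (sentence, mention, row) triples, which is roughly what a downstream dataloader needs. The sketch below uses two records copied from the output above; `iter_mention_rows` is a hypothetical helper, not a Bootleg function.

```python
def iter_mention_rows(records):
    """Yield (sentence_index, mention, embedding_row) for every mention."""
    for sent_idx, record in enumerate(records):
        for mention, row in zip(record["aliases"], record["ctx_emb_ids"]):
            yield sent_idx, mention, row


# Two records copied from the label-file output shown above.
records = [
    {"aliases": ["soccer", "japan", "china", "soccer", "japan", "china"],
     "ctx_emb_ids": [0, 1, 2, 3, 4, 5]},
    {"aliases": ["soccer", "japan", "china", "alain", "united arab emirates"],
     "ctx_emb_ids": [6, 7, 8, 9, 10]},
]

for sent_idx, mention, row in iter_mention_rows(records):
    # With the real data, the vector would be contextual_entity_embs[row].
    print(sent_idx, mention, row)
```

With the real label file, pass the lines read through `jsonlines` instead of the hard-coded `records` list.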

## 3. Load Static Embeddings

In addition to contextual entity embeddings, Bootleg learns static entity embeddings as well as type and relation embeddings. These can be useful when contextual information is not available for the downstream task, or when we want the same entity embedding regardless of the context or position of the mention.

We walk through how to extract the static, learned entity embeddings from a pretrained Bootleg model. First, we define a utility function to load a model.

In [19]:
import torch
import os
import emmental
from bootleg.task_config import NED_TASK, TYPE_PRED_TASK
from bootleg.tasks import ned_task, type_pred_task
from bootleg.symbols.entity_symbols import EntitySymbols
from emmental.model import EmmentalModel


def load_model(config, device=-1):
    """Build the Bootleg Emmental model and load the pretrained checkpoint.

    Note: `device` is accepted for convenience but is not used in this function.
    """
    # Parse the config if it still has a raw "emmental" section
    if "emmental" in config:
        config = parse_boot_and_emm_args(config)

    emmental.init(log_dir=config["meta_config"]["log_path"], config=config)

    print("Reading entity database")
    entity_db = EntitySymbols(
        os.path.join(
            config.data_config.entity_dir,
            config.data_config.entity_map_dir,
        ),
        alias_cand_map_file=config.data_config.alias_cand_map,
    )

    # Decide which tasks to add
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create the model and add the tasks
    model = EmmentalModel(name="Bootleg")
    model.add_task(ned_task.create_task(config, entity_db))
    if TYPE_PRED_TASK in tasks:
        model.add_task(type_pred_task.create_task(config, entity_db))
        # Add the mention type embedding to the embedding payload
        type_pred_task.update_ned_task(model)

    print("Loading model")
    # Load the pretrained checkpoint
    assert (
        config["model_config"]["model_path"] is not None
    ), "Must have a model to load in the model_path for the BootlegAnnotator"
    model.load(config["model_config"]["model_path"])
    model.eval()
    return model

Load the pretrained Bootleg model. This will take several minutes. 

In [20]:
model = load_model(config=config_args, device=device)

2021-01-28 15:01:52,858 Logging was already initialized to use bootleg_logs/wiki_full_ft/2021_01_28/12_20_50/bc3d0092.  To configure logging manually, call emmental.init_logging before initialiting Meta.
2021-01-28 15:01:52,917 Loading Emmental default config from /dfs/scratch0/lorr1/projects/emmental/src/emmental/emmental-default-config.yaml.
2021-01-28 15:01:52,918 Updating Emmental config from user provided config.
Reading entity database
Reading word tokenizers
2021-01-28 15:04:07,517 Created emmental model Bootleg that contains task set().
2021-01-28 15:04:10,958 Loading embeddings...
2021-01-28 15:04:10,959 Embedding "learned" has params (these can be changed in the config)
{
    "load_class":"LearnedEntityEmb",
    "key":"learned",
    "cpu":false,
    "freeze":false,
    "dropout1d":0.0,
    "dropout2d":0.0,
    "normalize":true,
    "send_through_bert":false
}
2021-01-28 15:04:23,757 Embedding "title_static" has params (these can be changed in the config)
{
    "load_class":"S

Get the static, learned entity embedding as a torch tensor.

In [54]:
learned_emb_obj = model.module_pool.learned
embedding_as_tensor = torch.Tensor(learned_emb_obj.learned_entity_embedding.weight)
print(embedding_as_tensor.shape)

torch.Size([5832701, 200])


This Bootleg model was trained on data with 5.8 million entities and each entity embedding is 200-dimensional, as indicated by the shape of the static, learned entity embedding above.
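
To get a feel for what these shapes mean in memory, the arithmetic below estimates the size of both tables. It is a back-of-the-envelope sketch: `embedding_bytes` is a hypothetical helper, and 4 bytes per value assumes the embeddings are stored as float32, which the tutorial does not state.

```python
def embedding_bytes(num_rows, dim, bytes_per_value=4):
    """Size of a dense num_rows x dim table; 4 bytes per value assumes float32."""
    return num_rows * dim * bytes_per_value


# Shapes taken from the outputs shown in this tutorial.
static_bytes = embedding_bytes(5832701, 200)
contextual_bytes = embedding_bytes(11360, 512)

print(f"static table:     {static_bytes / 1e9:.2f} GB")
print(f"contextual table: {contextual_bytes / 1e6:.2f} MB")
```

Under that assumption the static table is several gigabytes while the contextual file for this sample is a few tens of megabytes, which is worth keeping in mind when deciding where to hold the static embeddings.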

The mapping from mentions to rows in the static, learned entity embedding (corresponding to the predicted entity) is also saved in the label file produced by `dump_embs` mode. We check out the label file below and use the `entity_ids` key to find the corresponding embedding row. The `entity_ids` can also be extracted from the returned `qids` by using the `qid2eid.json` mapping in `entity_dir/entity_mappings`.

In [56]:
import jsonlines
with jsonlines.open(bootleg_label_file) as f: 
    for i, line in enumerate(f): 
        print('sentence:', line['sentence'])
        print('mentions:', line['aliases'])
        print('entity ids:', line['entity_ids'])
        print()
        if i == 5: 
            break

sentence: soccer - japan get lucky win , china in surprise defeat . soccer - japan get lucky win , china in surprise defeat .
mentions: ['soccer', 'japan', 'china', 'soccer', 'japan', 'china']
entity ids: [3157011, 3705410, 106035, 3157011, 3705410, 4486968]

sentence: soccer - japan get lucky win , china in surprise defeat . al-ain , United Arab Emirates 1996-12-06
mentions: ['soccer', 'japan', 'china', 'alain', 'united arab emirates']
entity ids: [3157011, 4223535, 4486968, 1944367, 2478913]

sentence: soccer - japan get lucky win , china in surprise defeat . Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday .
mentions: ['soccer', 'japan', 'china', 'japan', 'asian cup', 'syria', 'group c', 'championship match', 'friday']
entity ids: [3157011, 1593316, 4486968, 4223535, 2069182, 120274, 1853145, 519968, 1932597]

sentence: soccer - japan get lucky win , china in surprise defeat . But China saw their luck deser
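
The `qid2eid.json` route mentioned above can be sketched as follows. This is a minimal illustration under two assumptions that are not confirmed here: that `qid2eid.json` is a flat JSON object mapping QID strings to integer entity ids, and that each label record stores its predictions under a `qids` key. `load_qid2eid` and `qids_to_entity_ids` are hypothetical helper names.

```python
import json
from pathlib import Path


def load_qid2eid(entity_dir):
    """Load the QID -> entity id mapping (assumed to be a flat JSON object)."""
    path = Path(entity_dir) / "entity_mappings" / "qid2eid.json"
    with open(path) as f:
        return json.load(f)


def qids_to_entity_ids(qids, qid2eid):
    """Map each QID to its embedding row; unknown QIDs map to None."""
    return [qid2eid.get(qid) for qid in qids]


# Usage with the tutorial's variables (not run here):
#   qid2eid = load_qid2eid(entity_dir)
#   entity_ids = qids_to_entity_ids(line["qids"], qid2eid)
```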

Unlike the contextual entity embeddings, the static embeddings are not unique across mentions. For instance, if the same entity is predicted across two different mentions, the static entity embedding (and ids in the label file) will be the same for those mentions, whereas the contextual entity embeddings and ids will be different. 
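
The difference is easy to see by grouping the contextual rows of one sentence by predicted entity id. The sketch below uses the first record from the outputs in this tutorial and assumes a label record carries both `entity_ids` and `ctx_emb_ids`; `rows_by_entity` is a hypothetical helper.

```python
from collections import defaultdict


def rows_by_entity(record):
    """Group a record's contextual embedding rows by predicted entity id."""
    groups = defaultdict(list)
    for entity_id, row in zip(record["entity_ids"], record["ctx_emb_ids"]):
        groups[entity_id].append(row)
    return dict(groups)


# First record from the outputs shown in this tutorial.
record = {
    "aliases": ["soccer", "japan", "china", "soccer", "japan", "china"],
    "entity_ids": [3157011, 3705410, 106035, 3157011, 3705410, 4486968],
    "ctx_emb_ids": [0, 1, 2, 3, 4, 5],
}

print(rows_by_entity(record))
```

Both "soccer" mentions share the static row 3157011 but keep separate contextual rows 0 and 3, while the two "china" mentions were resolved to different entities and therefore use different static rows.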

You can also extract the embeddings through the `forward` pass on the embedding class. We will use random entity ids for demonstration.

### Important: the `forward` pass will _normalize_ the embedding. To get unnormalized embeddings, use the weight tensor above.

In [53]:
learned_emb_obj = model.module_pool.learned
batch = 5
M = 4
K = 3
eid_cands = torch.randint(0, 5000, (batch, M, K))
# batch_on_the_fly_data is a dictionary used for KG metadata; keep it empty for extracting embeddings
embs = learned_emb_obj(eid_cands, batch_on_the_fly_data={})
print(embs.shape)

torch.Size([5, 4, 3, 200])
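
To make the output shape and the effect of normalization concrete, here is a small NumPy sketch that mimics an embedding lookup followed by normalization. It is not Bootleg's implementation: the table is random and tiny, and L2 normalization over the last dimension is an assumption about what "normalize" means here.

```python
import numpy as np

rng = np.random.default_rng(0)

num_entities, dim = 5000, 8                   # toy sizes, far smaller than the real table
table = rng.normal(size=(num_entities, dim))  # stand-in for the learned weight tensor

batch, M, K = 5, 4, 3
eid_cands = rng.integers(0, num_entities, size=(batch, M, K))

# Indexing with an integer array appends the embedding dimension to the index shape.
raw = table[eid_cands]

# L2-normalize each embedding (assumed form of the normalization).
norms = np.linalg.norm(raw, axis=-1, keepdims=True)
normalized = raw / np.clip(norms, 1e-12, None)

print(raw.shape)
print(np.allclose(np.linalg.norm(normalized, axis=-1), 1.0))
```

The lookup keeps the `(batch, M, K)` layout of the ids and adds one trailing axis for the embedding, which matches the four-dimensional shape printed above.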


You can repeat the same process to extract the type embeddings. Our type embeddings are 128-dimensional.

### Important: the type module `forward` will also _normalize_ and apply an additive attention mechanism to merge the multiple type embeddings for a single entity.

In [55]:
wd_type_obj = model.module_pool.learned_type_wiki
batch = 5
M = 4
K = 3
eid_cands = torch.randint(0, 5000, (batch, M, K))
embs = wd_type_obj(eid_cands, {})
print(embs.shape)

torch.Size([5, 4, 3, 128])
