# Entity Embedding Tutorial

In this tutorial, we walk through how to generate Bootleg contextual entity embeddings for use in downstream tasks using a pretrained Bootleg model. We also demonstrate how to extract other Bootleg embeddings for downstream tasks when contextualized embeddings are not needed.

### Requirements

You will need to download the following files for this notebook:
- Pretrained Bootleg uncased model and config [here](https://bootleg-data.s3-us-west-2.amazonaws.com/models/lateset/bootleg_uncased.tar.gz). Cased model and config [here](https://bootleg-data.s3-us-west-2.amazonaws.com/models/lateset/bootleg_cased.tar.gz)
- Sample of Natural Questions with hand-labelled entities [here](https://bootleg-data.s3-us-west-2.amazonaws.com/data/lateset/nq.tar.gz)
- Entity data [here](https://bootleg-data.s3-us-west-2.amazonaws.com/data/lateset/entity_db.tar.gz)

For convenience, you can run the commands below (from the root directory of the repo) to download all the above files and unpack them to `models` and `data` directories. It will take several minutes to download all the files.

```
    # use cased for cased model
    bash tutorials/download_model.sh uncased
    bash tutorials/download_data.sh
```

You can also run the download scripts directly from this notebook:

In [None]:
!sh download_model.sh uncased
!sh download_data.sh

## 1.  Prepare Model Config

As with the other tutorials, we set up the config to point to the correct data directories and model checkpoint. We use the sample of [Natural Questions](https://ai.google.com/research/NaturalQuestions) introduced in the End-to-End tutorial, with mentions already extracted by Bootleg.

In [1]:
from pathlib import Path

# set up logging
import sys
import logging
from importlib import reload
reload(logging)
logging.basicConfig(stream=sys.stdout, format='%(asctime)s %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

If you have a GPU with at least 12GB of memory available, set `device` below to 0 to run inference on the GPU. The default of -1 runs on the CPU.

In [2]:
device = -1

Below, we specify the directory where the files were downloaded.

In [6]:
from bootleg.utils.parser.parser_utils import parse_boot_and_emm_args
from bootleg.utils.utils import load_yaml_file
from bootleg.run import run_model

# root_dir = FILL IN FULL PATH TO DIRECTORY WHERE DATA IS DOWNLOADED (e.g., root_dir/data and root_dir/models)
root_dir = Path("../tutorial_data_ep")

config_in_path = str(root_dir / 'models/bootleg_uncased/bootleg_config.yaml')
data_dir =  str(root_dir / 'data/nq')
entity_dir = str(root_dir / 'data/entity_db')
alias_map = "alias2qids.json"
test_file = "test_50_bootleg.jsonl"

config_args = load_yaml_file(config_in_path)

# decrease number of data threads as this is a small file
config_args["run_config"]["dataset_threads"] = 2
# set the model checkpoint path 
config_args["emmental"]["model_path"] = str(root_dir / 'models/bootleg_uncased/bootleg_wiki.pth')

# set the path for the entity db and candidate map
config_args["data_config"]["entity_dir"] = entity_dir
config_args["data_config"]["alias_cand_map"] = alias_map

# set the data path
config_args["data_config"]["data_dir"] = data_dir

# to speed things up for the tutorial, we have already prepped the data with the mentions detected by Bootleg
config_args["data_config"]["test_dataset"]["file"] = test_file

# set the embedding paths 
config_args["data_config"]["emb_dir"] =  str(root_dir / 'data/entity_db')
config_args["data_config"]["word_embedding"]["cache_dir"] =  str(root_dir / 'data/pretrained_bert_models')

# set the device (-1 for CPU, 0 for GPU)
config_args["emmental"]["device"] = device

config_args = parse_boot_and_emm_args(config_args) # or you can pass in the config_out_path

## 2. Load Contextual Entity Embeddings

We now show how Bootleg contextualized embeddings can be loaded and used in downstream tasks. First we use the `dump_embs` mode to generate contextual entity embeddings. 

In [7]:
bootleg_label_file, bootleg_emb_file = run_model(mode="dump_embs", config=config_args)

2021-03-11 17:16:49,834 Logging was already initialized to use bootleg_logs/wiki_full_ft/2021_03_11/17_13_50/c2cd809c.  To configure logging manually, call emmental.init_logging before initialiting Meta.
2021-03-11 17:16:49,888 Loading Emmental default config from /dfs/scratch0/lorr1/env_bootleg_38/lib/python3.8/site-packages/emmental/emmental-default-config.yaml.
2021-03-11 17:16:49,889 Updating Emmental config from user provided config.
2021-03-11 17:16:49,890 Set random seed to 1234.
2021-03-11 17:16:50,010 COMMAND: /dfs/scratch0/lorr1/env_bootleg_38/lib/python3.8/site-packages/ipykernel_launcher.py -f /dfs/scratch0/lorr1/projects/:/afs/cs.stanford.edu/u/lorr1/.local/apt-cache/share/jupyter/runtime/kernel-1d6bf30d-8475-4b26-8c6d-c585b1302c91.json
2021-03-11 17:16:50,011 Saving config to bootleg_logs/wiki_full_ft/2021_03_11/17_13_50/c2cd809c/parsed_config.yaml
2021-03-11 17:16:50,419 Git Hash: Not able to retrieve git hash
2021-03-11 17:16:50,421 Loading entity symbols...
2021-03-11 17:19:08,489 Built dataloader for test set with 49 and 1 threads samples (Shuffle=False, Batch size=32).
2021-03-11 17:19:08,501 Building slice dataset for test from ../tutorial_data_ep/data/nq/test_50_bootleg.jsonl.
2021-03-11 17:19:08,553 Loading data from ../tutorial_data_ep/data/nq/prep/test_50_bootleg_bert-base-uncased_L100_A10_InC1_Aug1/ned_slices_1f126b5224.bin and ../tutorial_data_ep/data/nq/prep/test_50_bootleg_bert-base-uncased_L100_A10_InC1_Aug1/ned_slices_config.json


Building sent idx to row idx mapping: 100%|██████████| 50/50 [00:00<00:00, 10693.21it/s]

2021-03-11 17:19:08,675 Final slice data initialization time from test is 0.17346429824829102s
2021-03-11 17:19:08,676 Updating Emmental config from user provided config.
2021-03-11 17:19:08,677 Set random seed to 1234.
2021-03-11 17:19:08,683 Starting Bootleg Model
2021-03-11 17:19:08,684 Created emmental model Bootleg that contains task set().





2021-03-11 17:19:12,206 Loading embeddings...
2021-03-11 17:19:40,563 Created task: NED
2021-03-11 17:19:40,565 Moving bert module to CPU.
2021-03-11 17:19:40,572 Moving embedding_payload module to CPU.
2021-03-11 17:19:40,573 Moving attn_network module to CPU.
2021-03-11 17:19:40,576 Moving pred_layer module to CPU.
2021-03-11 17:19:40,577 Moving learned module to CPU.
2021-03-11 17:19:40,578 Moving title_static module to CPU.
2021-03-11 17:19:40,579 Moving learned_type module to CPU.
2021-03-11 17:19:40,580 Moving learned_type_wiki module to CPU.
2021-03-11 17:19:40,582 Moving learned_type_relations module to CPU.
2021-03-11 17:19:40,583 Moving adj_index module to CPU.
2021-03-11 17:19:47,806 Created task: Type
2021-03-11 17:19:47,810 Moving bert module to CPU.
2021-03-11 17:19:47,815 Moving embedding_payload module to CPU.
2021-03-11 17:19:47,816 Moving attn_network module to CPU.
2021-03-11 17:19:47,819 Moving pred_layer module to CPU.
2021-03-11 17:19:47,820 Moving learned module 

Evaluating Bootleg (test): 100%|██████████| 2/2 [00:09<00:00,  4.57s/it]


2021-03-11 17:20:34,497 Finished dumping. Merging results across accumulation steps.
2021-03-11 17:20:34,539 Bootleg labels saved at bootleg_logs/wiki_full_ft/2021_03_11/17_13_50/c2cd809c/test_50_bootleg/bootleg_wiki/bootleg_labels.jsonl
2021-03-11 17:20:34,540 Trying to merge numpy embedding arrays. If your machine is limited in memory, this may cause OOM errors. Is that happens, result files should be saved in bootleg_logs/wiki_full_ft/2021_03_11/17_13_50/c2cd809c/test_50_bootleg/bootleg_wiki/batch_results.
2021-03-11 17:20:34,598 Bootleg embeddings saved at bootleg_logs/wiki_full_ft/2021_03_11/17_13_50/c2cd809c/test_50_bootleg/bootleg_wiki/bootleg_embs.npy


In `dump_embs` mode, Bootleg saves the contextual entity embeddings corresponding to each mention in each sentence to a file. The path to this file is returned in the variable `bootleg_emb_file`. We can also see the full file path in the log (it ends in `.npy`).

In [8]:
import numpy as np
contextual_entity_embs = np.load(bootleg_emb_file)
contextual_entity_embs.shape

(89, 512)

Each row in the contextual entity embedding above corresponds to an extracted mention in a sentence. In the above embedding, there are 89 extracted mentions in total, and each corresponding contextual entity embedding has 512 dimensions.

The mapping from mentions to rows in the contextual entity embedding is stored in `ctx_emb_ids` in the label file. We now check out the label file, which was also generated and returned from running `dump_embs` mode.

In [9]:
import jsonlines
with jsonlines.open(bootleg_label_file) as f: 
    for i, line in enumerate(f): 
        print('sentence:', line['sentence'])
        print('mentions:', line['aliases'])
        print('contextual emb ids:', line['ctx_emb_ids'])
        print()
        if i == 5: 
            break

sentence: Which of these was not an export of Ancient Greece
mentions: ['ancient greece']
contextual emb ids: [49]

sentence: Who opened and closed the 1960 Winter Olympics
mentions: ['1960 winter olympics']
contextual emb ids: [50]

sentence: I see the river Tiber foaming with much blood
mentions: ['river tiber']
contextual emb ids: [51]

sentence: What causes a dead zone in the ocean
mentions: ['dead zone']
contextual emb ids: [52]

sentence: Who plays Claire Underwood 's mom on House of Cards
mentions: ['claire underwood', 'mom', 'house of cards']
contextual emb ids: [53, 54, 55]

sentence: What is the T Rex name in Land Before Time
mentions: ['t rex', 'time']
contextual emb ids: [56, 57]



In the fifth sentence above, we can find the corresponding contextual entity embeddings for "claire underwood", "mom", and "house of cards" in rows 53, 54, and 55 of `contextual_entity_embs`, respectively. Similarly, we have unique row ids for the mentions in each of the other sentences. A downstream task can use this process to load the correct contextual entity embeddings for each mention in a simple dataloader.
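The lookup described above can be sketched as a small generator. This is a minimal illustration, not part of Bootleg: the function name `mention_embeddings` and the toy data are hypothetical, and we only assume the `sentence`, `aliases`, and `ctx_emb_ids` keys shown in the label file output.

```python
import json


def mention_embeddings(label_lines, emb_rows):
    """Yield (sentence, mention, embedding) triples.

    label_lines: iterable of JSON strings, one per sentence, each with
        'sentence', 'aliases', and 'ctx_emb_ids' keys.
    emb_rows: anything indexable by an integer row id, for example the
        array returned by np.load(bootleg_emb_file).
    """
    for raw in label_lines:
        line = json.loads(raw)
        # aliases and ctx_emb_ids are parallel lists: one row id per mention
        for alias, row_id in zip(line["aliases"], line["ctx_emb_ids"]):
            yield line["sentence"], alias, emb_rows[row_id]


# Toy stand-ins (hypothetical data) for the label file and embedding matrix
toy_labels = [
    json.dumps({"sentence": "s0", "aliases": ["a", "b"], "ctx_emb_ids": [0, 1]}),
    json.dumps({"sentence": "s1", "aliases": ["c"], "ctx_emb_ids": [2]}),
]
toy_embs = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]

triples = list(mention_embeddings(toy_labels, toy_embs))
for sentence, alias, emb in triples:
    print(sentence, alias, emb)
```

With the real outputs, you could pass an open file handle on `bootleg_label_file` as `label_lines` and `contextual_entity_embs` as `emb_rows`.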

## 3. Load Static Embeddings

In addition to contextual entity embeddings, Bootleg learns static entity embeddings as well as type and relation embeddings. These can be useful when contextual information is not available for the downstream task, or when we want the same entity embedding regardless of the context or position of the mention.

We walk through how to extract the static, learned entity embeddings from a pretrained Bootleg model. First, we define a utility function to load a model.

In [12]:
import torch
import os
import emmental
from bootleg.task_config import NED_TASK, TYPE_PRED_TASK
from bootleg.tasks import ned_task, type_pred_task
from bootleg.symbols.entity_symbols import EntitySymbols
from emmental.model import EmmentalModel


def load_model(config, device=-1):
    # NOTE: `device` is not used inside this function.
    # Parse the config if it still contains a raw "emmental" section
    if "emmental" in config:
        config = parse_boot_and_emm_args(config)

    emmental.init(
        log_dir=config["meta_config"]["log_path"], config=config
    )

    print("Reading entity database")
    entity_db = EntitySymbols.load_from_cache(
        os.path.join(
            config.data_config.entity_dir,
            config.data_config.entity_map_dir,
        ),
        alias_cand_map_file=config.data_config.alias_cand_map,
    )

    # Decide which tasks to add
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create the model and attach the tasks
    model = EmmentalModel(name="Bootleg")
    model.add_task(ned_task.create_task(config, entity_db))
    if TYPE_PRED_TASK in tasks:
        model.add_task(type_pred_task.create_task(config, entity_db))
        # Add the mention type embedding to the embedding payload
        type_pred_task.update_ned_task(model)

    print("Loading model")
    # Load the best model from the pretrained model
    assert (
        config["model_config"]["model_path"] is not None
    ), "Must have a model to load in the model_path"
    model.load(config["model_config"]["model_path"])
    # Switch to evaluation mode for inference
    model.eval()
    return model

Load the pretrained Bootleg model. This will take several minutes. 

In [13]:
model = load_model(config=config_args, device=device)

2021-03-11 17:38:15,779 Logging was already initialized to use bootleg_logs/wiki_full_ft/2021_03_11/17_13_50/c2cd809c.  To configure logging manually, call emmental.init_logging before initialiting Meta.
2021-03-11 17:38:15,839 Loading Emmental default config from /dfs/scratch0/lorr1/env_bootleg_38/lib/python3.8/site-packages/emmental/emmental-default-config.yaml.
2021-03-11 17:38:15,840 Updating Emmental config from user provided config.
2021-03-11 17:38:15,842 Set random seed to 1234.
Reading entity database
2021-03-11 17:40:06,511 Created emmental model Bootleg that contains task set().
2021-03-11 17:40:24,198 Loading embeddings...
2021-03-11 17:40:46,679 Created task: NED
2021-03-11 17:40:46,680 Moving bert module to CPU.
2021-03-11 17:40:46,685 Moving embedding_payload module to CPU.
2021-03-11 17:40:46,686 Moving attn_network module to CPU.
2021-03-11 17:40:46,689 Moving pred_layer module to CPU.
2021-03-11 17:40:46,690 Moving learned module to CPU.
2021-03-11 17:40:46,690 Moving

Get the static, learned entity embedding as a torch tensor.

In [14]:
learned_emb_obj = model.module_pool.learned
embedding_as_tensor = torch.Tensor(learned_emb_obj.learned_entity_embedding.weight)
print(embedding_as_tensor.shape)

torch.Size([5832701, 200])


This Bootleg model was trained on data with 5.8 million entities and each entity embedding is 200-dimensional, as indicated by the shape of the static, learned entity embedding above.

The mapping from mentions to rows in the static, learned entity embedding (corresponding to the predicted entity) is also saved in the label file produced by `dump_embs` mode. We check out the label file below and use the `entity_ids` key to find the corresponding embedding row. The `entity_ids` can also be extracted from the returned `qids` by using the `qid2eid.json` mapping in `entity_dir/entity_mappings`.

In [15]:
import jsonlines
with jsonlines.open(bootleg_label_file) as f: 
    for i, line in enumerate(f): 
        print('sentence:', line['sentence'])
        print('mentions:', line['aliases'])
        print('entity ids:', line['entity_ids'])
        print()
        if i == 5: 
            break

sentence: Which of these was not an export of Ancient Greece
mentions: ['ancient greece']
entity ids: [552973]

sentence: Who opened and closed the 1960 Winter Olympics
mentions: ['1960 winter olympics']
entity ids: [91786]

sentence: I see the river Tiber foaming with much blood
mentions: ['river tiber']
entity ids: [2608573]

sentence: What causes a dead zone in the ocean
mentions: ['dead zone']
entity ids: [2793916]

sentence: Who plays Claire Underwood 's mom on House of Cards
mentions: ['claire underwood', 'mom', 'house of cards']
entity ids: [3443290, 564561, 3993575]

sentence: What is the T Rex name in Land Before Time
mentions: ['t rex', 'time']
entity ids: [3052538, 1284493]



Unlike the contextual entity embeddings, the static embeddings are not unique across mentions. For instance, if the same entity is predicted across two different mentions, the static entity embedding (and ids in the label file) will be the same for those mentions, whereas the contextual entity embeddings and ids will be different. 
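To make the lookup concrete, here is a minimal sketch of going from QIDs to entity ids to rows of the static embedding. The helper name `static_rows_for_qids` and the toy values are hypothetical; in practice, `qid2eid` would be the dictionary loaded from `qid2eid.json`, and `emb_rows` would be the embedding weight tensor.

```python
def static_rows_for_qids(qids, qid2eid, emb_rows):
    """Map QIDs to entity ids, then to rows of the static embedding.

    qid2eid: dict from QID string to integer entity id.
    emb_rows: anything indexable by an integer row id.
    """
    eids = [qid2eid[q] for q in qids]
    return eids, [emb_rows[e] for e in eids]


# Toy stand-ins (hypothetical data); real values come from qid2eid.json
toy_qid2eid = {"Q1": 1, "Q2": 2}
toy_static = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4]]

# The same entity predicted for two mentions maps to the same row
eids, rows = static_rows_for_qids(["Q1", "Q2", "Q1"], toy_qid2eid, toy_static)
print(eids)
print(rows[0] == rows[2])
```

The first and third mentions share an entity, so they receive identical static embeddings, which is the behavior described above.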

You can also extract the embeddings through the `forward` pass on the embedding class. We will use random entity ids for demonstration.

### Important: the `forward` pass will _normalize_ the embedding. Use the weight tensor above if you need unnormalized embeddings.

In [16]:
learned_emb_obj = model.module_pool.learned
batch = 5
M = 4
K = 3
eid_cands = torch.randint(0, 5000, (batch, M, K))
# batch_on_the_fly_data is a dictionary used for KG metadata; keep it empty for extracting embeddings
embs = learned_emb_obj(eid_cands, batch_on_the_fly_data={})
print(embs.shape)

torch.Size([5, 4, 3, 200])
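To illustrate why the raw weights and the `forward` output can differ, the sketch below normalizes a single vector. It assumes the normalization is an L2 normalization over the embedding dimension, which this tutorial does not state explicitly; `l2_normalize` is a hypothetical helper, not a Bootleg function.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit L2 norm (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


raw = [3.0, 4.0]          # stand-in for one row of the weight tensor
unit = l2_normalize(raw)  # what a normalized embedding would look like

print(unit)
print(math.sqrt(sum(x * x for x in unit)))
```

Under this assumption, normalized embeddings keep the direction of each raw row but discard its magnitude, so downstream tasks that rely on magnitude should read the weight tensor directly.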


You can repeat the same process to extract the type embeddings. Our type embeddings are 128-dimensional.

### Important: the type module `forward` will also _normalize_ and apply an additive attention mechanism to merge the multiple type embeddings for a single entity.

In [17]:
wd_type_obj = model.module_pool.learned_type_wiki
batch = 5
M = 4
K = 3
eid_cands = torch.randint(0, 5000, (batch, M, K))
embs = wd_type_obj(eid_cands, {})
print(embs.shape)

torch.Size([5, 4, 3, 128])
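For intuition about how several type embeddings can be merged into one vector per entity, here is a generic textbook sketch of additive attention pooling. It is not Bootleg's implementation: the function name, the parameters `w` and `v`, and the toy values are all assumptions made for illustration.

```python
import math


def additive_attention_pool(type_embs, w, v):
    """Merge several type embeddings into a single vector.

    score_i = v . tanh(W e_i); weights = softmax(scores);
    pooled  = sum_i weights_i * e_i
    """
    def matvec(mat, vec):
        return [sum(m * x for m, x in zip(row, vec)) for row in mat]

    scores = [
        sum(vi * math.tanh(h) for vi, h in zip(v, matvec(w, e)))
        for e in type_embs
    ]
    # Numerically stable softmax over the per-type scores
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [x / total for x in exps]

    dim = len(type_embs[0])
    pooled = [sum(wt * e[d] for wt, e in zip(weights, type_embs)) for d in range(dim)]
    return weights, pooled


# Toy example (hypothetical values): two 2-dimensional type embeddings
toy_types = [[1.0, 0.0], [0.0, 1.0]]
toy_w = [[1.0, 0.0], [0.0, 1.0]]
toy_v = [1.0, 1.0]

weights, pooled = additive_attention_pool(toy_types, toy_w, toy_v)
print(weights, pooled)
```

Because the two toy types receive equal scores, the pooled vector is their average. In a trained model, the learned parameters would weight some types more heavily than others.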
