This repo contains an implementation of TABi, a bi-encoder for entity retrieval that trains over knowledge graph types and unstructured text. TABi introduces a type-enforced contrastive loss to encourage query and entity embeddings to cluster by type in the embedding space. You can find more details in our paper.
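The following pure-Python sketch illustrates what a type-aware contrastive objective can look like. It is written in the style of a supervised contrastive loss. For each query, the positives are its gold entity plus any other in-batch query that shares its (toy, hypothetical) type label, and every other in-batch candidate is a negative. The function names, the `temperature` value, and the exact choice of positives are assumptions made for illustration. This is not the loss implemented in this repo; see the paper for the actual formulation.

```python
import math
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def type_aware_contrastive_loss(
    queries: List[Sequence[float]],
    entities: List[Sequence[float]],
    query_types: List[str],
    temperature: float = 0.05,
) -> float:
    """Illustrative supervised-contrastive-style loss; not TABi's implementation.

    queries[i] is paired with its gold entity entities[i]; embeddings are
    assumed to be L2-normalized. For query i, the candidates are all in-batch
    entities plus all other in-batch queries. Positives are its gold entity
    and every other query whose type label matches query_types[i].
    """
    n = len(queries)
    total = 0.0
    for i in range(n):
        # (candidate embedding, is_positive) pairs for anchor query i.
        cands = [(entities[j], j == i) for j in range(n)]
        cands += [
            (queries[j], query_types[j] == query_types[i])
            for j in range(n)
            if j != i
        ]
        logits = [dot(queries[i], c) / temperature for c, _ in cands]
        # Numerically stable log-sum-exp over all candidates.
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
        pos_logits = [z for z, (_, is_pos) in zip(logits, cands) if is_pos]
        # Mean negative log-probability of the positives for this anchor.
        total += -sum(z - log_denom for z in pos_logits) / len(pos_logits)
    return total / n


aligned = type_aware_contrastive_loss(
    queries=[[1.0, 0.0], [0.0, 1.0]],
    entities=[[1.0, 0.0], [0.0, 1.0]],
    query_types=["location", "person"],
)
swapped = type_aware_contrastive_loss(
    queries=[[1.0, 0.0], [0.0, 1.0]],
    entities=[[0.0, 1.0], [1.0, 0.0]],
    query_types=["location", "person"],
)
print(aligned < swapped)  # True: queries close to their gold entities get a lower loss
```

When two queries share a type, each one counts as a positive for the other. This pulls same-typed queries together in the embedding space, which is the clustering effect described above.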
This repo also includes pre-trained TABi models to retrieve Wikipedia pages from queries and training scripts to train TABi models on new datasets.
Our code is tested on Python 3.7. We recommend installing with a virtualenv.
pip install -r requirements.txt
pip install -e .
If you are using NVIDIA A100 GPUs, you will need to install a version of PyTorch that supports the `sm_80` CUDA architecture:
pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html
We provide the following pre-trained TABi models. We also provide the pre-computed entity embeddings over the KILT-E knowledge base. The pre-computed entity embeddings require 16GB of disk space to download.
We provide models that are trained with types using the type-enforced contrastive loss and models that are trained without types. Note that the models trained on the BLINK training data require mention boundaries (or mention detection, e.g. via flair) at test time. This is because all examples in the BLINK training data have mention boundaries.
| Training Data | Trained with Types | Weights | Entity Embs |
|---|---|---|---|
| KILT | Yes | url | url |
| KILT | No | url | url |
| BLINK | Yes | url | url |
| BLINK | No | url | url |
See our paper for hyperparameter settings for the pre-trained models.
We use a filtered version of the KILT knowledge base. We remove Wikimedia internal items (e.g. disambiguation pages, list articles) and add FIGER types to entities where available. The final knowledge base, KILT-Entity (KILT-E), has 5.45 million entities corresponding to English Wikipedia pages.
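The 16GB size quoted above for the pre-computed entity embeddings can be sanity-checked with a short calculation. A dense float32 matrix occupies (number of entities) × (embedding dimension) × 4 bytes. The 768-dimensional size used below is an assumption, typical of BERT-base encoders, and is not stated in this README.

```python
def embedding_matrix_bytes(num_entities: int, dim: int, bytes_per_value: int = 4) -> int:
    """Size in bytes of a dense num_entities x dim matrix (float32 by default)."""
    return num_entities * dim * bytes_per_value


# 5.45M KILT-E entities; dim=768 is an assumed embedding size.
size = embedding_matrix_bytes(5_450_000, 768)
print(f"{size / 1e9:.2f} GB")  # prints "16.74 GB" (decimal gigabytes)
```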
Download KILT-E:
- jsonlines: entity.jsonl
- pickle: entity.pkl
Both formats can be used for `--entity_file` in the following commands, but the pickle loads a bit faster.
We support two modes to use TABi interactively. We recommend using the models trained on KILT for the interactive mode. The interactive mode does not currently support mention detection or providing mention boundaries.
To retrieve entities from a pre-computed entity index, run:
python scripts/demo.py --model_checkpoint best_model.pth --entity_emb_path embs.npy --entity_file entity.pkl
To control the number of retrieved entities, use the `--top_k` flag. By default, the top 10 entities are returned.
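Retrieving from a pre-computed index means scoring the query embedding against every entity embedding and keeping the `top_k` highest scores. The pure-Python sketch below shows that step using cosine similarity and `heapq.nlargest` on toy vectors. The function names and data are hypothetical. At KILT-E scale, the same computation would run as batched matrix operations over the embedding file rather than over Python lists.

```python
import heapq
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity, which lies in [-1, 1]."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve_top_k(
    query_emb: Sequence[float],
    entity_embs: Dict[int, Sequence[float]],
    top_k: int = 10,
) -> List[Tuple[int, float]]:
    """Return (entity_id, score) pairs for the top_k most similar entities."""
    scored = ((eid, cosine(query_emb, emb)) for eid, emb in entity_embs.items())
    return heapq.nlargest(top_k, scored, key=lambda pair: pair[1])


# Toy 2-d embeddings standing in for a real entity index.
entity_embs = {0: [1.0, 0.0], 1: [0.6, 0.8], 2: [-1.0, 0.0]}
results = retrieve_top_k([1.0, 0.0], entity_embs, top_k=2)
print([eid for eid, _ in results])  # [0, 1]
```

If the embeddings are already L2-normalized, cosine similarity reduces to a plain dot product, so the normalization step can be skipped.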
To input your own entities (title and description) and get a score between the query and your entity, simply provide the model checkpoint:
python scripts/demo.py --model_checkpoint best_model.pth
The scores are cosine similarities and range from -1 to 1 (1 is most similar). The demo keeps prompting you for entities. To enter a new query and new entities, type `exit`.
We provide AmbER and KILT datasets for evaluation, and BLINK and KILT datasets for training, in the TABi data format (see Datasets). If you plan to use our provided datasets, you can skip to Evaluation and Training.
We require that the input to TABi be in the following format:
{
"id": # unique example id
"text": # question or sentence
"label_id": # list of gold knowledge base ids if available, otherwise use [-1]
"alt_label_id": # list of lists of alternate gold knowledge base ids, if none use [[]]
"mentions": # list of character spans of mention boundaries if available, otherwise []
}
Example (from Natural Questions):
{
"id": "-143054837169120955",
"text": "where are the giant redwoods located in california",
"label_id": [903760],
"alt_label_id": [[4683290, 2526048, 242069]],
"mentions": []
}
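Before running the scripts on your own data, you may want to check each line against the format above. The validator below is a hypothetical helper and is not part of the repo. It checks for the required keys and for the types implied by the field descriptions. It assumes that each mention span is a `[start, end]` pair of character offsets into `text`, which this README does not spell out.

```python
import json
from typing import Any, Dict, List

REQUIRED_KEYS = ("id", "text", "label_id", "alt_label_id", "mentions")


def validate_tabi_example(ex: Dict[str, Any]) -> List[str]:
    """Return a list of problems with one TABi-format example (empty if none)."""
    problems = [f"missing key: {k}" for k in REQUIRED_KEYS if k not in ex]
    if problems:
        return problems
    if not isinstance(ex["text"], str):
        problems.append("text must be a string")
    text_len = len(ex["text"]) if isinstance(ex["text"], str) else 0
    labels = ex["label_id"]
    if not (isinstance(labels, list) and all(isinstance(i, int) for i in labels)):
        problems.append("label_id must be a list of ints ([-1] if unknown)")
    alt = ex["alt_label_id"]
    if not (isinstance(alt, list) and all(isinstance(group, list) for group in alt)):
        problems.append("alt_label_id must be a list of lists ([[]] if none)")
    if not isinstance(ex["mentions"], list):
        problems.append("mentions must be a list ([] if none)")
        return problems
    for span in ex["mentions"]:
        # Assumption: each span is a [start, end] pair of character offsets.
        ok = (
            isinstance(span, list)
            and len(span) == 2
            and all(isinstance(x, int) for x in span)
            and 0 <= span[0] < span[1] <= text_len
        )
        if not ok:
            problems.append(f"bad mention span: {span}")
    return problems


line = '{"id": "-143054837169120955", "text": "where are the giant redwoods located in california", "label_id": [903760], "alt_label_id": [[4683290, 2526048, 242069]], "mentions": []}'
print(validate_tabi_example(json.loads(line)))  # []
```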
Note that when mention spans are provided, TABi currently disambiguates only one mention at a time and runs a separate evaluation query for each mention span in the list.
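Because TABi disambiguates one mention at a time, an example with several mention spans behaves like several single-mention queries. The hypothetical helper below shows that expansion. It assumes that spans are `[start, end]` character offsets and that `label_id[i]` (and `alt_label_id[i]`, when present) belongs to `mentions[i]`. Neither assumption is specified in this README, and the helper is not the repo's code.

```python
from typing import Any, Dict, List


def split_by_mention(ex: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Expand an example into one single-mention example per mention span.

    Assumptions: spans are [start, end] character offsets, and label_id[i]
    (and alt_label_id[i], when present) belongs to mentions[i].
    Examples without mention spans are returned unchanged.
    """
    if not ex["mentions"]:
        return [ex]
    labels, alts = ex["label_id"], ex["alt_label_id"]
    return [
        {
            "id": f"{ex['id']}_{i}",  # hypothetical id scheme
            "text": ex["text"],
            "label_id": [labels[i] if i < len(labels) else -1],
            "alt_label_id": [alts[i] if i < len(alts) else []],
            "mentions": [span],
        }
        for i, span in enumerate(ex["mentions"])
    ]


ex = {
    "id": "ex1",
    "text": "Paris is in France",
    "label_id": [1, 2],
    "alt_label_id": [[], []],
    "mentions": [[0, 5], [12, 18]],
}
for part in split_by_mention(ex):
    start, end = part["mentions"][0]
    print(part["id"], part["text"][start:end], part["label_id"])
# ex1_0 Paris [1]
# ex1_1 France [2]
```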
To convert a jsonlines file in the KILT data format to the TABi data format, run:
python scripts/preprocess_kilt.py --entity_file entity.pkl --input_file nq-dev-kilt.jsonl --output_file nq-dev-tabi.jsonl
To convert a directory of KILT-formatted files to the TABi format, run:
python scripts/preprocess_kilt.py --entity_file entity.pkl --input_dir kilt_dev --output_dir kilt_dev_tabi
The evaluation script runs model evaluation, reports accuracy@1 and accuracy@10, and saves the predictions in KILT-formatted files.
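For reference, accuracy@k is the fraction of queries whose gold entity appears among the top k predictions. The sketch below is a hypothetical re-implementation for illustration only. It assumes that a prediction counts as correct if it matches any id in `label_id`. The repo's eval script and the official AmbER/KILT scripts may treat `alt_label_id` differently.

```python
from typing import List, Sequence


def accuracy_at_k(
    ranked_preds: List[Sequence[int]], gold_ids: List[Sequence[int]], k: int
) -> float:
    """Fraction of queries with at least one gold id among the top-k predictions."""
    if not ranked_preds:
        return 0.0
    hits = sum(
        1
        for preds, gold in zip(ranked_preds, gold_ids)
        if set(preds[:k]) & set(gold)
    )
    return hits / len(ranked_preds)


# Toy ranked predictions (best first) and gold label_id lists.
preds = [[5, 3, 9], [1, 2, 7], [4, 8, 6]]
gold = [[5], [7], [0]]
print(accuracy_at_k(preds, gold, k=1))  # 1 of 3 queries hit at rank 1
print(accuracy_at_k(preds, gold, k=3))  # 2 of 3 queries hit within the top 3
```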
To evaluate a TABi model, run:
python tabi/eval.py --test_data_file nq-dev-kilt.jsonl --entity_file entity.pkl --model_checkpoint best_model.pth --entity_emb_path embs.npy --mode eval --log_dir logs
- `--log_dir` specifies where the log file and prediction file are written.
- You can also specify the name of the prediction file with `--pred_file`. For instance:
python tabi/eval.py --test_data_file nq-dev-kilt.jsonl --entity_file entity.pkl --model_checkpoint best_model.pth --entity_emb_path embs.npy --mode eval --log_dir logs --pred_file nq-dev-preds.jsonl
For benchmarks, we use the evaluation scripts provided by AmbER and KILT to report final numbers.
Training consists of a multi-step procedure.
- Train with local in-batch negatives.
- Extract entity embeddings.
- Extract hard negatives using nearest neighbor search with optional hard negative filtering.
- Train with in-batch negatives and hard negatives.
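Step 3 (hard-negative extraction) can be described as a nearest-neighbor search that keeps the highest-scoring entities that are not the query's gold entity. The pure-Python sketch below uses brute-force dot products on toy vectors. The function name and the `num_negatives` parameter are hypothetical. At KILT scale, a much faster (batched or GPU) search would be needed.

```python
from typing import Dict, List, Sequence, Set


def mine_hard_negatives(
    query_emb: Sequence[float],
    entity_embs: Dict[int, Sequence[float]],
    gold_ids: Set[int],
    num_negatives: int = 2,
) -> List[int]:
    """Return ids of the highest-scoring non-gold entities for one query."""
    scores = {
        eid: sum(q * e for q, e in zip(query_emb, emb))
        for eid, emb in entity_embs.items()
        if eid not in gold_ids  # the gold entity is never a negative
    }
    return sorted(scores, key=scores.get, reverse=True)[:num_negatives]


# Toy index: entity 3 is the gold answer; 0 and 1 are close but wrong.
entity_embs = {0: [0.9, 0.1], 1: [0.8, 0.2], 2: [0.1, 0.9], 3: [0.95, 0.0]}
print(mine_hard_negatives([1.0, 0.0], entity_embs, gold_ids={3}))  # [0, 1]
```

In step 4, these negatives are added to the in-batch negatives, so the model must separate the gold entity from entities that it currently scores as close.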
An example script is in scripts/run_sample.py. To run with the small sample data in the repo on a GPU:
python scripts/run_sample.py
To run with the small sample data in the repo on a CPU:
python scripts/run_sample_cpu.py
To train a new TABi model on your own dataset, make sure to format your training, eval, and test datasets in the TABi data format and modify `data_dir`, `train_file`, `dev_file`, and `test_file` in the example script.
To use a new entity knowledge base, each entity in the knowledge base (jsonlines file) should have the following format:
{
"label_id": # unique id of the entity (optional, if not provided, row in knowledge base is assigned as the id)
"title": # title of the entity
"text": # description of the entity
"types": # list of types ([] if none)
"wikipedia_page_id": # wikipedia page id (can exclude if not linking to Wikipedia)
}
See the KILT-E knowledge base for an example of the expected format. The type-enforced contrastive loss uses query types, which are assigned as the types of the query's gold entity. To see benefits from the type-enforced contrastive loss, the "types" fields in the knowledge base must not all be empty. Make sure to update `entity_file` in the example script to use your new knowledge base.
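The hypothetical helpers below illustrate the knowledge base format and how query types follow from gold entities. `load_entities` reads entity jsonlines and falls back to the row index when `label_id` is absent. `query_types` assigns each query the types of its gold entity. This mirrors the behavior described above but is not the repo's loader. In particular, taking the union of types over several gold ids is an assumption, and the toy entities are made up.

```python
import json
from typing import Any, Dict, Iterable, List


def load_entities(lines: Iterable[str]) -> Dict[int, Dict[str, Any]]:
    """Parse entity jsonlines; use the row index as the id when label_id is absent."""
    rows = [line for line in lines if line.strip()]  # skip blank lines
    entities: Dict[int, Dict[str, Any]] = {}
    for row, line in enumerate(rows):
        ent = json.loads(line)
        entities[ent.get("label_id", row)] = ent
    return entities


def query_types(example: Dict[str, Any], entities: Dict[int, Dict[str, Any]]) -> List[str]:
    """Types of the example's gold entities (union over gold ids is an assumption)."""
    types: List[str] = []
    for gold in example["label_id"]:
        for t in entities.get(gold, {}).get("types", []):
            if t not in types:
                types.append(t)
    return types


# Toy knowledge base in the format above (no label_id, so rows become ids 0 and 1).
kb_lines = [
    '{"title": "Giant sequoia", "text": "A species of tree.", "types": ["/living_thing"]}',
    '{"title": "California", "text": "A state in the United States.", "types": ["/location"]}',
]
entities = load_entities(kb_lines)
print(query_types({"label_id": [1]}, entities))  # ['/location']
```

A query whose gold entity has no types, or whose `label_id` is `[-1]`, gets an empty type list. This is why an entirely untyped knowledge base gives the type-enforced loss nothing to work with.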
We provide example scripts to train a new TABi model on the BLINK and KILT datasets. The datasets for training can be downloaded below. The provided pre-trained models were trained on 16 A100 GPUs for four epochs, which took approximately 9 and 11 hours total for the BLINK and KILT datasets, respectively.
We support DistributedDataParallel training on a single node with multiple GPUs. See the example scripts above for training on BLINK and KILT data using distributed training. For large datasets, you may need to increase the ulimit (number of open files) on your machine using `ulimit -n 100000`.
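If you launch training from Python and cannot change the shell limit, the standard-library `resource` module (Unix only) can produce a similar effect. Note that `ulimit -n` in the shell may also change the hard limit. The sketch below only raises the soft open-file limit toward 100000 and never exceeds the hard limit, which an unprivileged process cannot raise. The helper name is hypothetical.

```python
import resource


def raise_open_file_limit(desired: int = 100_000) -> int:
    """Raise the soft RLIMIT_NOFILE limit toward `desired` (Unix only).

    The soft limit is never set above the hard limit, which an unprivileged
    process cannot raise. Returns the soft limit in effect afterwards.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    target = desired if hard == resource.RLIM_INFINITY else min(desired, hard)
    if soft != resource.RLIM_INFINITY and soft < target:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
        except (ValueError, OSError):
            pass  # Some platforms cap the limit lower; keep the current value.
    return resource.getrlimit(resource.RLIMIT_NOFILE)[0]


# Call this early, before spawning worker processes, so they inherit the limit.
print(raise_open_file_limit())
```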
We support filtering hard negatives, following the procedure described in Botha et al. The goal is to balance how often an entity occurs as a hard negative against how often it occurs as a gold entity in the training dataset. Filtering can help reduce the proportion of hard negatives that are rare entities. To use filtering, pass the `--filter_negatives` flag. We only recommend this frequency-based filtering procedure for large training datasets (e.g. BLINK or KILT). On small training datasets, most entities may have very low or zero counts, leading to aggressive filtering.
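The exact filtering rule is the one from Botha et al. that `--filter_negatives` enables. The sketch below is only a simplified stand-in to convey the idea. It uses a hypothetical rule that caps how many times each entity may be used as a hard negative at its gold frequency in the training data, with a floor of `min_count`. The example also shows why rare or unseen entities are filtered aggressively when counts are low.

```python
from collections import Counter
from typing import Dict, List


def filter_hard_negatives(
    hard_negs: Dict[str, List[int]],
    gold_counts: Counter,
    min_count: int = 1,
) -> Dict[str, List[int]]:
    """Simplified frequency-based filter (hypothetical rule, not the repo's).

    Entity e may be used as a hard negative at most
    max(gold_counts[e], min_count) times across the whole dataset;
    later occurrences beyond that budget are dropped.
    """
    used: Counter = Counter()
    kept: Dict[str, List[int]] = {}
    for query_id, negs in hard_negs.items():
        kept[query_id] = []
        for e in negs:
            if used[e] < max(gold_counts[e], min_count):
                used[e] += 1
                kept[query_id].append(e)
    return kept


# Entity 9 never appears as a gold entity, so with min_count=0 it is always dropped.
gold_counts = Counter({7: 2, 8: 1})
hard_negs = {"q1": [7, 8, 9], "q2": [7, 8, 9], "q3": [7, 9]}
print(filter_hard_negatives(hard_negs, gold_counts, min_count=0))
# {'q1': [7, 8], 'q2': [7], 'q3': []}
```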
We provide evaluation files in the TABi data format for the AmbER and KILT benchmarks. For KILT, we include the 8 open-domain tasks:
The ids for the dev/test splits we use for AmbER in our paper can be found here.
We provide training and validation files in the TABi data format for:
- KILT (train) (11.7M examples)
- KILT (dev - combined) (35.8K examples)
- BLINK (train) (8.4M examples)
- BLINK (dev) (9.9K examples)
If you find this code useful, please cite the following paper:
@inproceedings{leszczynski-etal-2022-tabi,
title={{TAB}i: {T}ype-Aware Bi-Encoders for Open-Domain Entity Retrieval},
author={Megan Leszczynski and Daniel Y. Fu and Mayee F. Chen and Christopher R\'e},
booktitle={Findings of the Association for Computational Linguistics: ACL 2022},
year={2022}
}
Our work was inspired by the following repos: