Robust Table Question Answering

Code for training and evaluating the transformer-based robust TableQA models introduced in the following ACL 2023 papers:

An Inner Table Retriever for Robust Table Question Answering

Paper Conference License: CC BY-NC 4.0

Inner Table Retriever (ITR) is a general-purpose approach for handling long tables in TableQA that extracts sub-tables to preserve the most relevant information for a question. ITR can be easily integrated into existing systems to improve their accuracy and achieve state-of-the-art results.

If you find our paper, code or framework useful, please link to this repo and cite this work in your paper:

@Inproceedings{Lin2023,
 author = {Weizhe Lin and Rexhina Blloshmi and Bill Byrne and Adrià de Gispert and Gonzalo Iglesias},
 title = {An inner table retriever for robust table question answering},
 year = {2023},
 url = {https://www.amazon.science/publications/an-inner-table-retriever-for-robust-table-question-answering},
 booktitle = {ACL 2023},
}

For more details on how to install and run the code, see Scripts for Inner Table Retriever in the Codebase section below.


LI-RAGE: Late Interaction Retrieval Augmented Generation with Explicit Signals for Open-Domain Table Question Answering

Paper Conference License: CC BY-NC 4.0

LI-RAGE is a framework for open-domain TableQA which addresses several limitations of previous approaches by:
  1. applying late interaction models, which enforce a finer-grained interaction between question and table embeddings at retrieval time;
  2. incorporating a joint training scheme for the retriever and reader with explicit table-level signals; and
  3. embedding a binary relevance token as a prefix to the answer generated by the reader, so that at inference time we can determine whether the table used to answer the question is reliable and filter accordingly (see the sketch below).

The combined strategies set a new state-of-the-art on two public open-domain TableQA datasets.
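To illustrate point 3, here is a minimal sketch of the filtering idea in Python; the prefix format and token strings are hypothetical placeholders, not the exact tokens used by LI-RAGE.

# Hypothetical example: the reader generates a relevance prefix before the answer,
# and predictions flagged as irrelevant can be filtered out at inference time.
def split_relevance_prefix(generated, relevant_token="yes"):
    prefix, _, answer = generated.partition(":")
    return prefix.strip() == relevant_token, answer.strip()

print(split_relevance_prefix("yes: 1987"))  # (True, '1987')
print(split_relevance_prefix("no: 1987"))   # (False, '1987') -> answer can be discarded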

If you find our paper, code or framework useful, please link to this repo and cite this work in your paper:

@Inproceedings{Lin2023,
 author = {Weizhe Lin and Rexhina Blloshmi and Bill Byrne and Adrià de Gispert and Gonzalo Iglesias},
 title = {LI-RAGE: Late interaction retrieval augmented generation with explicit signals for open-domain table question answering},
 year = {2023},
 url = {https://www.amazon.science/publications/li-rage-late-interaction-retrieval-augmented-generation-with-explicit-signals-for-open-domain-table-question-answering},
 booktitle = {ACL 2023},
}

For more details on how to install and run the code, see Scripts for Open-domain TableQA with Late Interaction Models in the Codebase section below.


Codebase

This codebase was created by Weizhe Lin (wzlin@amazon.co.uk) during his internship as an applied scientist at Alexa AI. It contains:

  • TableQA Baseline Systems
    • TAPAS (Huggingface)
    • TAPEX (Huggingface)
    • OmniTab (Huggingface)
  • Inner Table Retrieval (ITR)
    • Inner Table Retriever based on Dense Passage Retrieval (DPR)
    • TableQA systems with ITR
  • Late Interaction Models for Retrieval Augmented TableQA Systems
    • Late Interaction Table Retriever based on ColBERT
    • Retrieval Augmented TableQA Models (TaPEx) with Late Interaction Retriever
    • Dense Table Retriever (DPR)
    • Retrieval Augmented TableQA Models (TaPEx) with DPR

Codebase Structure

3rd-party Open-Source Tools

In this codebase, several open-source tools are used, including:

  • Weights and Biases
  • Pytorch-lightning
  • Pytorch
  • Pytorch-scatter
  • Huggingface-transformers
  • Other open-source tools
    • Experiment framework is built upon RAVQA

Overview

Training and testing are built on pytorch-lightning. The pre-trained Transformer models are from Huggingface-transformers, and the underlying training platform is Pytorch.

Structure

The framework consists of:

  1. main.py: the entry point to the main program. It loads a config file, overrides some entries with command-line arguments, and initialises a data loader wrapper, a model trainer, and a pytorch-lightning trainer to execute training and testing.
  2. Data Loader Wrapper: loads the data according to data_loader.dataset_modules defined in the config files. A data module called LoadDataLoaders is used to create pytorch dataloaders from the data.
  3. Datasets: automatically loaded by the data loader wrapper. .collate_fn is defined to collate the data. A decorator class, ModuleParser, helps generate the training inputs; it builds the input dict according to the configs (config.model_config.input_modules/decoder_input_modules/output_modules).
  4. Model Trainers: a pytorch-lightning LightningModule instance. It defines the training/testing behaviour (training steps, optimizers, schedulers, logging, checkpointing, and so on) and initialises the model being trained at self.model.
  5. Models: pytorch nn.Module models.
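To make the data flow concrete, below is a minimal, self-contained sketch of how a LightningModule trainer and a data module are wired together by a pytorch-lightning Trainer; the class names (ToyDataModule, ToyExecutor) are illustrative placeholders, not the actual classes in this repository.

# Minimal wiring sketch (hypothetical names); the real entry point builds these
# objects from the jsonnet config instead of hard-coding them.
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

class ToyDataModule(pl.LightningDataModule):
    # Stands in for the data loader wrapper; real code builds loaders from the config.
    def train_dataloader(self):
        x = torch.randn(64, 8)
        y = torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=4)

class ToyExecutor(pl.LightningModule):
    # Stands in for a model trainer; self.model holds the model being trained.
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(8, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.model(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=1, accelerator="cpu", devices=1, logger=False)
    trainer.fit(ToyExecutor(), datamodule=ToyDataModule())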

Configs

The configuration is handled with jsonnet, which enables inheritance between config files. For example, wtq/tapex_ITR_mix_wtq.jsonnet overrides some entries of wtq/tapex_ITR_column_wise_wtq.jsonnet, which in turn inherits from base_env.jsonnet, where most of the common configuration is defined.

Overriding is performed simply by including the corresponding key:value pair in the config file.
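As an illustration of how a jsonnet config can be evaluated and overridden from Python (a generic sketch using the jsonnet bindings installed in the Environment section, not necessarily how main.py implements it; overridden values stay strings here):

# Evaluate a jsonnet config and apply a dotted override such as "train.batch_size=1".
import json
import _jsonnet  # provided by `pip install jsonnet`

def load_config(path, overrides=()):
    config = json.loads(_jsonnet.evaluate_file(path))
    for item in overrides:
        dotted_key, value = item.split("=", 1)
        node = config
        *parents, leaf = dotted_key.split(".")
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value  # note: the value is kept as a string in this sketch
    return config

# config = load_config("configs/wtq/tapex_ITR_mix_wtq.jsonnet", ["train.batch_size=1"])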

ModuleParser

A decorator class that helps to parse data into features that are used by models.

An example taken from ITR is shown below:

"input_modules": {
    "module_list":[
    {"type": "QuestionInput",  "option": "default", 
                "separation_tokens": {"start": "", "end": ""}},
    ],
    "postprocess_module_list": [
    {"type": "PostProcessInputTokenization", "option": "default"},
    ],
},
"decoder_input_modules": {
    "module_list":[
    {"type": "TextBasedTableInput",  "option": "default",
                "separation_tokens": {"header_start": "<HEADER>", "header_sep": "<HEADER_SEP>", "header_end": "<HEADER_END>", "row_start": "<ROW>", "row_sep": "<ROW_SEP>", "row_end": "<ROW_END>"}},
    ],
    "postprocess_module_list": [
    {"type": "PostProcessDecoderInputTokenization", "option": "default"},
    ],
},
"output_modules": {
    "module_list":[
    {"type": "SimilarityOutput", "option": "default"},
    ],
    "postprocess_module_list": [
    {"type": "PostProcessConcatenateLabels", "option": "default"},
    ],
},

The parser first runs the input modules in the order defined in input_modules, and then the postprocessing unit PostProcessInputTokenization tokenizes the input into input_ids and input_attention_mask.

Similarly, decoder_input_modules generates item_input_ids and item_input_attention_mask (in ITR, the decoder input is used as the input to the item encoder; the naming is slightly confusing, but feel free to change it, though this is not deemed necessary). output_modules generates the labels for ITR training.

By defining new functions in ModuleParser, e.g. self.TextBasedVisionInput, new behaviour can easily be introduced to transform modules into training features.
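For illustration, a parser helper of this kind might look like the sketch below; the function name, signature, and sample dict are hypothetical and only show the general pattern of turning a sample plus the configured separation tokens into a text feature.

# Hypothetical helper in the spirit of a ModuleParser function: linearise a table
# dict into a single text sequence using the separation tokens from the config.
def text_based_table_input(sample, separation_tokens=None):
    sep = separation_tokens or {}
    header = sep.get("header_start", "") + sep.get("header_sep", " | ").join(sample["header"]) + sep.get("header_end", "")
    rows = [
        sep.get("row_start", "") + sep.get("row_sep", " | ").join(row) + sep.get("row_end", "")
        for row in sample["rows"]
    ]
    return {"text_sequence": " ".join([header] + rows)}

example = {"header": ["Year", "Champion"], "rows": [["2021", "Alice"], ["2022", "Bob"]]}
tokens = {"header_start": "<HEADER>", "header_sep": "<HEADER_SEP>", "header_end": "<HEADER_END>",
          "row_start": "<ROW>", "row_sep": "<ROW_SEP>", "row_end": "<ROW_END>"}
print(text_based_table_input(example, tokens)["text_sequence"])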

MetricsProcessor

The test.metrics entry in the config file defines the metrics to compute during validation and testing. Each module logs a log_dict of metrics_name: metrics_value pairs, which can be processed conveniently in the trainers.

"metrics": [
    {'name': 'compute_exact_match'},
    {'name': 'compute_retrieval_metrics'},
],
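As a rough illustration of what such a metric computes (not the repository's actual implementation, which may normalise answers differently), an exact-match score over predictions and references could look like this:

# Illustrative exact-match metric in the spirit of compute_exact_match.
def compute_exact_match(predictions, references):
    def normalise(text):
        return " ".join(text.lower().split())
    hits = sum(
        any(normalise(pred) == normalise(ref) for ref in refs)
        for pred, refs in zip(predictions, references)
    )
    return {"exact_match": hits / max(len(predictions), 1)}

print(compute_exact_match(["Paris ", "42"], [["paris"], ["41", "43"]]))  # {'exact_match': 0.5}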

Environment

conda create -n tableqa python=3.8
conda activate tableqa
pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.1+cu111.html --force-reinstall --no-cache-dir --no-index
pip install transformers==4.22.1 datasets==2.6.1
pip install jsonnet
pip install easydict tqdm pytorch-lightning==1.8.2 tfrecord frozendict
pip install wandb==0.13.4
conda install -c pytorch faiss-gpu -y
pip install bitarray spacy ujson gitpython
pip install ninja
pip install absl-py tensorboard
pip install -e src/ColBERT
pip install scipy scikit-learn
pip install setuptools==56.1.0

Note: All experiments were run on clusters of 8 V100/A100 GPUs. You may want to reduce the batch sizes or use smaller base models (e.g. TaPEx-base) if your GPUs have less memory.

Data

We use open-source data that can be obtained as described in the respective papers; please check the references in our papers. Make sure to add/change the data paths in the respective config files.

Useful Command-line Arguments

Some general command-line arguments are listed below. For more details, please read the code or look directly at how they are used in the training/evaluation of specific models.

Universal

  • All trainer parameters supported by pytorch-lightning, such as --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2
  • --experiment_name EXPERIMENT_NAME: the name of the experiment. It is used as the name of the experiment folder as well as the run name on WANDB.
  • --mode [train/test]: indicates whether to run training or testing.
  • --modules module1 module2 module3 ...: the list of modules to enable. They are saved to self.config.model_config.modules so that they are accessible anywhere in the framework.

Training

  • --opts [list of configurations] used at the end of the cli command. self.config will be overwritten by the configurations here. For example:

    • train.batch_size=1 batch size
    • train.scheduler=linear currently supports none/linear
    • train.epochs=20
    • train.lr=0.00002
    • train.retriever_lr=0.00001
    • train.additional.gradient_accumulation_steps=4
    • train.additional.warmup_steps=0
    • train.additional.early_stop_patience=7
    • train.additional.save_top_k=1
    • valid.step_size=400
    • valid.batch_size=4
    • model_config.GeneratorModelVersion=microsoft/tapex-large: an example of how you can change the pretrained model checkpoint for the answer generator
    • data_loader.additional.num_knowledge_passages=5: an example of how you can change K in ITR/ITR+TableQA/OpenDomainRA-TableQA
    • model_config.num_beams=5: number of beams in generation

Testing

  • --test_evaluation_name nq_tables_all this will create a folder under the experiment folder (indicated by --experiment_name) and save everything there. Also, in the WANDB run (run name indicated by --experiment_name), a new section with this name (nq_tables_all) will be created, and the evaluation scores will be logged into this section.
  • --opts test.batch_size=32
  • --opts test.load_epoch=3825 selects which checkpoint to load. Note that you need to use the same experiment name as in training.

Scripts for Inner Table Retriever

Inner Table Retrieval

NOTE: After training, you need to run inference to generate the index files used in TableQA + ITR. The index files will be generated in Experiments/{experiment_name}/test/epoch{load_epoch}/.

Main Experiments

ITR mix (intersecting columns and rows)

WikiSQL train

python src/main.py configs/wikisql/dpr_ITR_mix_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_mixed --mode train --override --opts train.batch_size=1 train.scheduler=None train.epochs=20 train.lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=200 train.additional.early_stop_patience=8 train.additional.save_top_k=3 valid.batch_size=8 test.batch_size=8 valid.step_size=200 reset=1

WikiSQL test

python src/main.py configs/wikisql/dpr_ITR_mix_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_mixed --mode test --test_evaluation_name original_sets --opts test.batch_size=32 test.load_epoch=11604

WikiTQ test

python src/main.py configs/wtq/dpr_ITR_mix_wtq.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_mixed --mode test --test_evaluation_name wtq_original_sets --opts test.batch_size=32 test.load_epoch=11604

Additional ablation experiments

Column-wise ITR

WikiSQL train

python src/main.py configs/dpr_ITR_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_single_column --mode train --opts train.batch_size=1 train.scheduler=None train.epochs=20 train.lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=200 train.additional.early_stop_patience=6 train.additional.save_top_k=3 train.save_interval=200 valid.batch_size=8 test.batch_size=8 valid.step_size=200 data_loader.dummy_dataloader=0

WikiSQL test

python src/main.py configs/dpr_ITR_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_single_column --mode test --test_evaluation_name wikisql_original_sets --opts test.batch_size=32 test.load_epoch=3801

WikiTQ test

python src/main.py configs/wtq/dpr_ITR_column_wise_wtq.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_single_column --mode test --test_evaluation_name wtq_original_sets --opts test.batch_size=32 test.load_epoch=3801

Row-wise ITR

WikiSQL train

python src/main.py configs/dpr_ITR_row_wise_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_single_row --mode train --opts train.batch_size=1 train.scheduler=None train.epochs=20 train.lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=200 train.additional.early_stop_patience=8 train.additional.save_top_k=3 train.save_interval=200 valid.batch_size=8 test.batch_size=8 valid.step_size=200 data_loader.dummy_dataloader=0

WikiSQL test

python src/main.py configs/dpr_ITR_row_wise_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_single_row --mode test --test_evaluation_name wikisql_original_sets --opts test.batch_size=32 test.load_epoch=6803

WikiTQ test

python src/main.py configs/wtq/dpr_ITR_row_wise_wikisql.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_single_row --mode test --test_evaluation_name wtq_original_sets --opts test.batch_size=32 test.load_epoch=6803

TableQA with ITR (ITR + TaPEx)

You will need to change the index paths in the config files. You can also override the index paths dynamically via --opts.

For example, after running the inference for ITR-mix, you can change the config file configs/wikisql/tapas_ITR_mix_wikisql.jsonnet as follows:

// here we put the index file paths
local index_files = {
  "index_paths": {
    "train": "DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_mixed/test/original_sets/step_11604/test.ITRWikiSQLDataset.train",
    "validation": "DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_mixed/test/original_sets/step_11604/test.ITRWikiSQLDataset.validation",
    "test": "DPR_InnerTableRetrieval_wikisql_with_in_batch_neg_sampling_mixed/test/original_sets/step_11604/test.ITRWikiSQLDataset.test",
  },
};

Some general notes:

  • You can set test.load_epoch=0 model_config.GeneratorModelVersion=microsoft/tapex-large-finetuned-wikisql to use the official checkpoint in evaluation.
  • Do make sure test.load_epoch= has the same number as the checkpoint. Otherwise, model_config.GeneratorModelVersion will be loaded.
  • Check arguments carefully before running! Especially for dangerous commands such as --override --opts reset=1.
  • data_loader.additional.max_decoder_source_length=128 controls the token limit for ITR; the default is 1024 for TaPEx and 512 for TaPas.
  • In the paper we provide different variations. Here we only show the main results with the ITR-mix strategy and the ablations with column- or row-wise ITR. Feel free to try your own configuration. You can change model_config.ModelClass to ITRRagReduceMixModel, ITRRagAdditionRowWiseModel, etc. Available model classes are in models/itr_rag.py and models/itr_rag_reduce.py.

Main Experiments

ITR mix (intersecting columns and rows)

WikiSQL train

python src/main.py configs/wikisql/tapex_ITR_mix_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name finetune_tapex_large_on_WikiSQL_with_ITR_mix_reduction_smoothing_overflow_only_original_sub_table_order --mode train --modules overflow_only original_sub_table_order shuffle_sub_table_order_in_training --override --opts train.batch_size=1 train.scheduler=linear train.epochs=10 train.lr=0.00003 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=1000 train.additional.early_stop_patience=7 train.additional.save_top_k=3 valid.step_size=1000 valid.batch_size=2 test.batch_size=2 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large data_loader.additional.num_knowledge_passages=10 reset=1

WikiSQL test

python src/main.py configs/wikisql/tapex_ITR_mix_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name finetune_tapex_large_on_WikiSQL_with_ITR_mix_reduction_smoothing_overflow_only_original_sub_table_order --mode test --modules overflow_only original_sub_table_order --test_evaluation_name original_sets --opts test.batch_size=4 test.load_epoch=[] model_config.GeneratorModelVersion=microsoft/tapex-large model_config.ModelClass=ITRRagReduceMixModel data_loader.additional.num_knowledge_passages=1 data_loader.additional.max_decoder_source_length=1024

WikiTQ train

python src/main.py configs/wtq/tapex_ITR_mix_wtq.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name finetune_tapex_large_on_WTQ_with_ITR_mix_reduction_smoothing_overflow_only_original_sub_table_order_K_10 --mode train --modules overflow_only original_sub_table_order --override --opts train.batch_size=1 train.scheduler=linear train.epochs=40 train.lr=0.00002 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=7 train.additional.save_top_k=3 valid.step_size=1000 valid.batch_size=2 test.batch_size=2 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large model_config.ModelClass=ITRRagReduceMixModel data_loader.additional.num_knowledge_passages=10

WikiTQ test

python src/main.py configs/wtq/tapex_ITR_mix_wtq.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name finetune_tapex_large_on_WTQ_with_ITR_mix_reduction_smoothing_overflow_only_original_sub_table_order_K_10 --mode test --modules overflow_only original_sub_table_order --test_evaluation_name original_sets --opts test.batch_size=2 test.load_epoch=[] model_config.GeneratorModelVersion=microsoft/tapex-large-finetuned-wtq model_config.ModelClass=ITRRagReduceMixModel data_loader.additional.num_knowledge_passages=10 data_loader.additional.max_decoder_source_length=1024

Additional Ablation Experiments

Column-wise ITR

WikiSQL train

python src/main.py configs/tapex_ITR_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name finetune_tapex_large_on_WikiSQL_with_ITR_addition_smoothing_overflow_only --mode train --modules overflow_only --opts train.batch_size=1 train.scheduler=linear train.epochs=10 train.lr=0.00003 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=1000 train.additional.early_stop_patience=5 train.additional.save_top_k=3 valid.step_size=1000 valid.batch_size=4 test.batch_size=4 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large data_loader.additional.num_knowledge_passages=5

WikiSQL test

python src/main.py configs/wikisql/tapex_ITR_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name finetune_tapex_large_on_WikiSQL_with_ITR_addition_smoothing_overflow_only_original_sub_table_order --modules overflow_only original_sub_table_order --mode test --test_evaluation_name ITR_addition_oo_osto_official --opts test.batch_size=4 test.load_epoch=[] model_config.GeneratorModelVersion=microsoft/tapex-large  model_config.ModelClass=ITRRagModel data_loader.additional.num_knowledge_passages=5

WikiTQ train

python src/main.py configs/wtq/tapex_ITR_column_wise_wtq.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name finetune_tapex_large_on_WTQ_with_ITR_column_wise_addition_smoothing_overflow_only_original_sub_table_order_K_10 --override --mode train --modules overflow_only original_sub_table_order --opts train.batch_size=1 train.scheduler=linear train.epochs=40 train.lr=0.00002 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=7 train.additional.save_top_k=3 valid.step_size=1000 valid.batch_size=2 test.batch_size=2 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large model_config.ModelClass=ITRRagModel data_loader.additional.num_knowledge_passages=10 reset=1

WikiTQ test

python src/main.py configs/wtq/tapex_ITR_column_wise_wtq.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name finetune_tapex_large_on_WTQ_with_ITR_column_wise_addition_smoothing_overflow_only_original_sub_table_order_K_10 --mode test --modules overflow_only original_sub_table_order --test_evaluation_name ITR_addition_oo_osto_K_10 --opts test.batch_size=2 test.load_epoch=[] model_config.GeneratorModelVersion=microsoft/tapex-large model_config.ModelClass=ITRRagModel data_loader.additional.num_knowledge_passages=10

Row-wise ITR

WikiSQL train

python src/main.py configs/tapex_ITR_row_wise_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name finetune_tapex_large_on_WikiSQL_with_ITR_row_wise_addition_smoothing_overflow_only_original_sub_table_order_K_10 --mode train --modules overflow_only original_sub_table_order force_select_last --opts train.batch_size=1 train.scheduler=linear train.epochs=10 train.lr=0.00003 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=1000 train.additional.early_stop_patience=7 train.additional.save_top_k=3 valid.step_size=1000 valid.batch_size=2 test.batch_size=2 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large model_config.ModelClass=ITRRagAdditionRowWiseModel data_loader.additional.num_knowledge_passages=10

WikiSQL test

python src/main.py configs/tapex_ITR_row_wise_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name finetune_tapex_large_on_WikiSQL_with_ITR_row_wise_addition_smoothing_overflow_only_original_sub_table_order_K_10 --mode test --modules overflow_only original_sub_table_order --test_evaluation_name ITR_reduction_oo_osto_K_10 --opts test.batch_size=2 test.load_epoch=[] model_config.GeneratorModelVersion=microsoft/tapex-large model_config.ModelClass=ITRRagReduceRowWiseModel data_loader.additional.num_knowledge_passages=10

WikiTQ train

python src/main.py configs/wtq/tapex_ITR_row_wise_wtq.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name finetune_tapex_large_on_WTQ_with_ITR_row_wise_addition_smoothing_overflow_only_original_sub_table_order_K_10 --mode train --modules overflow_only original_sub_table_order force_select_last --opts train.batch_size=1 train.scheduler=linear train.epochs=30 train.lr=0.00003 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=1000 train.additional.early_stop_patience=7 train.additional.save_top_k=3 valid.step_size=1000 valid.batch_size=2 test.batch_size=2 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large model_config.ModelClass=ITRRagAdditionRowWiseModel data_loader.additional.num_knowledge_passages=10

WikiTQ test

python src/main.py configs/wtq/tapex_ITR_row_wise_wtq.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name finetune_tapex_large_on_WTQ_with_ITR_row_wise_addition_smoothing_overflow_only_original_sub_table_order_K_10 --mode test --modules overflow_only original_sub_table_order force_select_last --test_evaluation_name ITR_addition_oo_osto_fsl_K_10 --opts test.batch_size=4 test.load_epoch=[] model_config.GeneratorModelVersion=microsoft/tapex-large model_config.ModelClass=ITRRagAdditionRowWiseModel data_loader.additional.num_knowledge_passages=10

TableQA with ITR (TaPas + ITR)

WikiSQL test

python src/main.py configs/wikisql/tapas_ITR_mix_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name evaluate_tapas_with_ITR_on_WikiSQL --mode test --modules overflow_only original_sub_table_order --test_evaluation_name 128tokens --opts test.batch_size=32 test.load_epoch=0 model_config.GeneratorModelVersion=google/tapas-large-finetuned-wikisql-supervised model_config.ModelClass=ITRRagReduceMixModel data_loader.additional.num_knowledge_passages=1 data_loader.additional.max_decoder_source_length=128 model_config.min_columns=1

WikiTQ test

python src/main.py configs/wtq/tapas_ITR_mix_wtq.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name evaluate_tapas_with_ITR_on_WTQ_token_limit_exploration --mode test --modules overflow_only original_sub_table_order --test_evaluation_name 128tokens_official --opts test.batch_size=16 test.load_epoch=0 model_config.GeneratorModelVersion=google/tapas-large-finetuned-wtq model_config.ModelClass=ITRRagReduceMixModel data_loader.additional.num_knowledge_passages=1 data_loader.additional.max_decoder_source_length=128 model_config.min_columns=1

TableQA with ITR (OmniTab + ITR)

python src/main.py configs/wtq/tapex_ITR_mix_wtq.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name evaluate_ominitab_on_WTQ --mode test --modules overflow_only original_sub_table_order --test_evaluation_name original_sets --opts test.batch_size=2 test.load_epoch=0 model_config.GeneratorModelVersion=neulab/omnitab-large-finetuned-wtq model_config.DecoderTokenizerModelVersion=neulab/omnitab-large-finetuned-wtq model_config.ModelClass=ITRRagReduceMixModel data_loader.additional.num_knowledge_passages=10

Disable ITR

Add --modules suppress_ITR, for example:

python src/main.py configs/wtq/tapas_ITR_mix_wtq.jsonnet --accelerator gpu --devices 8 --strategy ddp --experiment_name evaluate_tapas_with_ITR_on_WTQ_token_limit_exploration --mode test --modules overflow_only original_sub_table_order suppress_ITR --test_evaluation_name 128tokens_official_no_ITR --opts test.batch_size=16 test.load_epoch=0 model_config.GeneratorModelVersion=google/tapas-large-finetuned-wtq model_config.ModelClass=ITRRagReduceMixModel data_loader.additional.num_knowledge_passages=1 data_loader.additional.max_decoder_source_length=128 model_config.min_columns=1

TableQA Baselines

TaPEx

WikiSQL train

python src/main.py configs/wikisql/tapex_wikisql.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name finetune_tapex_large_on_WikiSQL_smoothing_0.1 --mode train --opts train.batch_size=1 train.scheduler=linear train.epochs=20 train.lr=0.00003 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=1000 train.additional.early_stop_patience=6 train.additional.save_top_k=3 train.save_interval=1000 valid.batch_size=4 test.batch_size=4 data_loader.dummy_dataloader=0 train.additional.label_smoothing_factor=0.1

WikiSQL test

python src/main.py configs/wikisql/tapex_wikisql.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name finetune_tapex_large_on_WikiSQL_smoothing_0.1 --mode test --log_prediction_tables --test_evaluation_name original_valid_test_set --opts test.batch_size=16 test.load_epoch=[]

Change the config to configs/wtq/tapex_wtq.jsonnet to use WTQ instead.

Scripts for Open-domain TableQA with Late Interaction Models

Main Experiments

ColBERT Retrieval

Some additional useful arguments:

  • model_config.nbits=2: how many bits the embeddings are quantized (compressed) into. A higher nbits will significantly increase the index size.

NQ-TABLES train

python src/main.py configs/nq_tables/colbert.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name ColBERT_NQTables_bz4_negative4_fix_doclen_full_search_NewcrossGPU --mode train --override --opts train.batch_size=6 train.scheduler=None train.epochs=1000 train.lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=10 train.additional.save_top_k=3 valid.batch_size=32 test.batch_size=32 valid.step_size=200 data_loader.dummy_dataloader=0 reset=1 model_config.num_negative_samples=4 model_config.bm25_top_k=5 model_config.bm25_ratio=0 model_config.nbits=2

NQ-TABLES test

python src/main.py configs/nq_tables/colbert.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name ColBERT_NQTables_bz4_negative4_fix_doclen_full_search_NewcrossGPU --mode test --test_evaluation_name nq_tables_all --opts test.batch_size=32 test.load_epoch=5427 model_config.nbits=8

E2E_WTQ train

python src/main.py configs/e2e_wtq/colbert.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 0 --experiment_name ColBERT_E2EWTQ_bz4_negative4_fix_doclen_full_search_NewcrossGPU --mode train --override --modules exhaustive_search_in_testing --opts train.batch_size=6 train.scheduler=None train.epochs=1000 train.lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=30 train.additional.save_top_k=3 valid.batch_size=32 test.batch_size=32 valid.step_size=10 data_loader.dummy_dataloader=0 reset=1 model_config.num_negative_samples=4 model_config.bm25_top_k=5 model_config.bm25_ratio=0 model_config.nbits=2

E2E_WTQ test

python src/main.py configs/e2e_wtq/colbert.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name ColBERT_E2EWTQ_bz4_negative4_fix_doclen_full_search_NewcrossGPU --mode test --test_evaluation_name e2e_wtq_all --opts test.batch_size=32 test.load_epoch=300 model_config.nbits=8

LI-RAGE

Note: if you are not using the index files specified in the config files, you may want to

  • change the paths in the config file; or
  • use the following to change the paths directly on the command line:
--opts model_config.QueryEncoderModelVersion=$ColBERT_NQTables_bz4_negative4_fix_doclen_full_search_NewcrossGPU/train/saved_model/step_5427 model_config.index_files.index_passages_path=ColBERT_NQTables_bz4_negative4_fix_doclen_full_search_NewcrossGPU/test/nq_tables_all/step_5427/table_dataset model_config.index_files.index_path=ColBERT_NQTables_bz4_negative4_fix_doclen_full_search_NewcrossGPU/test/nq_tables_all/step_5427/table_dataset_colbert_index model_config.index_files.embedding_path=ColBERT_NQTables_bz4_negative4_fix_doclen_full_search_NewcrossGPU/test/nq_tables_all/step_5427/item_embeddings.pkl

NQ-TABLES train

python src/main.py configs/nq_tables/colbert_rag.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name RAG_ColBERT_NQTables_RAVQA_Approach5_add_prompt --mode train --modules add_binary_labels_as_prompt --override --opts train.batch_size=1 train.scheduler=linear train.epochs=20 train.lr=0.00002 train.retriever_lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=3 train.additional.save_top_k=1 valid.step_size=400 valid.batch_size=4 test.batch_size=4 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large data_loader.additional.num_knowledge_passages=5 model_config.num_beams=5 reset=1

NQ-TABLES test

python src/main.py configs/nq_tables/colbert_rag.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name RAG_ColBERT_NQTables_RAVQA_Approach5_add_prompt --mode test --test_evaluation_name alternative_answers --modules add_binary_labels_as_prompt --opts test.batch_size=1 test.load_epoch=[] model_config.num_beams=5 data_loader.additional.num_knowledge_passages=5

E2E_WTQ train (using the officially released TaPEx checkpoint does not differ much from using our own finetuned one; to save steps, we simply use it here. You can change it to your own finetuned TaPEx version.)

python src/main.py configs/e2e_wtq/colbert_rag.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name RAG_ColBERT_E2EWTQ_RAVQA_Approach5_add_prompt_pretrained --modules add_binary_labels_as_prompt --mode train --override --opts train.batch_size=1 train.scheduler=none train.epochs=100 train.lr=0.000015 train.retriever_lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=3 train.additional.save_top_k=1 valid.step_size=25 valid.batch_size=4 test.batch_size=4 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large-finetuned-wtq data_loader.additional.num_knowledge_passages=5 model_config.num_beams=5 model_config.RAVQA_loss_type=Approach5 reset=1

E2E_WTQ test

python src/main.py configs/e2e_wtq/colbert_rag.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name RAG_ColBERT_E2EWTQ_RAVQA_Approach5_add_prompt_pretrained --mode test --test_evaluation_name K5 --modules add_binary_labels_as_prompt --opts test.batch_size=1 test.load_epoch=176 model_config.num_beams=5 data_loader.additional.num_knowledge_passages=5

Additional Ablation Experiments

Dense Passage Retrieval (DPR)

Some useful arguments

  • --modules negative_samples_across_gpus: sharing negative samples across GPUs
  • --modules exhaustive_search_in_testing: use exhaustive search in testing. This is significantly slower than building an HNSW index, which offers faster dynamic search later.
  • --opts model_config.bm25_ratio=0 model_config.bm25_top_k=3
    • ratio=0: no BM25-mined negative examples are used
    • bm25_top_k=K: draw negative examples from the top-K BM25-mined candidates (see the sketch below). Note that K should be large enough when model_config.num_negative_samples is large, so that enough examples can be found.
    • Note: this did not improve performance in the end, so it is disabled now.
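To illustrate the idea (a generic sketch, not the repository's sampling code): negatives are drawn from the top-K BM25-ranked tables after excluding the gold tables, so bm25_top_k must be comfortably larger than the number of negatives requested.

import random

# Pick negatives from the top-K BM25 candidates, skipping gold tables.
def sample_bm25_negatives(bm25_ranked_ids, gold_ids, num_negative_samples, bm25_top_k):
    candidates = [tid for tid in bm25_ranked_ids[:bm25_top_k] if tid not in set(gold_ids)]
    if len(candidates) < num_negative_samples:
        raise ValueError("bm25_top_k is too small for the requested number of negatives")
    return random.sample(candidates, num_negative_samples)

print(sample_bm25_negatives(["t7", "t3", "t9", "t1", "t4"], gold_ids=["t3"], num_negative_samples=2, bm25_top_k=5))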

NQ-TABLES train

python src/main.py configs/nq_tables/dpr.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name DPR_NQTables_train_bz8_gc_4_crossGPU --mode train --override --modules negative_samples_across_gpus exhaustive_search_in_testing --opts train.batch_size=8 train.scheduler=None train.epochs=1000 train.lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=10 train.additional.save_top_k=3 valid.batch_size=32 test.batch_size=32 valid.step_size=200 data_loader.dummy_dataloader=0 reset=1 model_config.num_negative_samples=4 model_config.bm25_ratio=0 model_config.bm25_top_k=3

NQ-TABLES test

python src/main.py configs/nq_tables/dpr.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name DPR_NQTables_train_bz8_gc_4_crossGPU --mode test --test_evaluation_name nq_tables_all --opts test.batch_size=32 test.load_epoch=[]

E2E_WTQ train

(Note that loading a pre-trained DPR checkpoint does not improve the performance much. If you don't have a checkpoint pre-trained on NQ-TABLES, simply drop model_config.QueryEncoderModelVersion and model_config.ItemEncoderModelVersion.)

python src/main.py configs/e2e_wtq/dpr.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name DPR_E2EWTQ_train_bz8_gc_4_neg4 --mode train --override --modules exhaustive_search_in_testing --opts train.batch_size=8 train.scheduler=None train.epochs=300 train.lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=100 train.additional.save_top_k=3 valid.batch_size=32 test.batch_size=32 valid.step_size=10 data_loader.dummy_dataloader=0 reset=1 model_config.num_negative_samples=4 model_config.QueryEncoderModelVersion=/wd/Experiments/DPR_NQTables_train_bz8_gc_4_crossGPU/train/saved_model/step_2039/query_encoder model_config.ItemEncoderModelVersion=/wd/Experiments/DPR_NQTables_train_bz8_gc_4_crossGPU/train/saved_model/step_2039/item_encoder model_config.bm25_top_k=5 model_config.bm25_ratio=0

E2E_WTQ test

python src/main.py configs/e2e_wtq/dpr.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name DPR_E2EWTQ_train_bz8_gc_4_neg4 --mode test --test_evaluation_name e2e_wtq_all --opts test.batch_size=32 test.load_epoch=480

DPR + RAGE

Some settings:

  • --modules add_binary_labels_as_prompt: add binary relevance tokens; this must be enabled in both training and testing.
  • nq_tables/rag.jsonnet: joint training of the retriever and reader; frozen_rag.jsonnet: freezes the retriever during training.

NQ-TABLES train

python src/main.py configs/nq_tables/rag.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name RAG_NQTables_RAVQA_loss_approach5_add_prompt --mode train --modules add_binary_labels_as_prompt --override --opts train.batch_size=1 train.scheduler=linear train.epochs=20 train.lr=0.00002 train.retriever_lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=3 train.additional.save_top_k=1 valid.step_size=400 valid.batch_size=4 test.batch_size=4 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large data_loader.additional.num_knowledge_passages=5 model_config.num_beams=5 model_config.RAVQA_loss_type=Approach5 reset=1

NQ-TABLES test

python src/main.py configs/nq_tables/rag.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name RAG_NQTables_RAVQA_loss_approach5_add_prompt --mode test --test_evaluation_name official_sets --modules add_binary_labels_as_prompt --opts test.batch_size=1 test.load_epoch=[] model_config.num_beams=5 data_loader.additional.num_knowledge_passages=5

E2E_WTQ train

python src/main.py configs/e2e_wtq/rag.jsonnet --accelerator gpu --devices 8 --strategy ddp --num_sanity_val_steps 2 --experiment_name RAG_E2EWTQ_RAVQA_Approach5_add_prompt --mode train --modules add_binary_labels_as_prompt --override --opts train.batch_size=1 train.scheduler=none train.epochs=100 train.lr=0.000015 train.retriever_lr=0.00001 train.additional.gradient_accumulation_steps=4 train.additional.warmup_steps=0 train.additional.early_stop_patience=3 train.additional.save_top_k=-1 valid.step_size=25 valid.batch_size=4 test.batch_size=4 data_loader.dummy_dataloader=0 model_config.GeneratorModelVersion=microsoft/tapex-large-finetuned-wtq data_loader.additional.num_knowledge_passages=5 model_config.num_beams=5 model_config.RAVQA_loss_type=Approach5 reset=1

E2E_WTQ test

python src/main.py configs/e2e_wtq/rag.jsonnet --accelerator gpu --devices 1 --strategy ddp --experiment_name RAG_E2EWTQ_RAVQA_Approach5_add_prompt --mode test --test_evaluation_name K5 --modules add_binary_labels_as_prompt --opts test.batch_size=1 test.load_epoch=126 model_config.num_beams=5 data_loader.additional.num_knowledge_passages=5

Miscellaneous

Difference w.r.t. the official ColBERT repository

We have made several changes to the ColBERT codebase so that it can be run and integrated into our framework.

  • The original code uses total_visible_gpus to determine whether to use GPUs. However, for dynamic retrieval we want to use the CPU for retrieval only, while keeping GPUs visible to the main training framework. Therefore, we modified colbert/indexing/codecs/residual.py to add a disable_gpu argument. Similarly, self.use_gpu was added to colbert/search/index_loader.py, and colbert/searcher.py reads initial_config.total_visible_gpus = config.total_visible_gpus from the passed-in config.
  • colbert/indexing/collection_indexer.py: we commented out self.config.help(), which generates redundant information that floods the terminal.
  • colbert/modeling/colbert.py: originally, negative samples are shared only across batches on a single device. We added support for sharing negative samples across all batches on all GPUs, and rewrote the original in-batch negative sampling loss in a cleaner form.
  • colbert/modeling/tokenization/doc_tokenization.py: the padding strategy was changed from longest to max_length (illustrated below), because our tables are typically long and we set max_length=512.
  • Other issues that blocked running
    • colbert/search/index_storage.py: approx_scores = torch.cat(approx_scores, dim=0).float()
    • colbert/search/strided_tensor.py: pids = pids.cpu() is necessary; they are moved back to GPU afterwards.
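For context on the padding change above, here is a generic Huggingface tokenizer call showing the difference between the two strategies (illustrative only, using bert-base-uncased; this is not the modified ColBERT code):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
docs = ["short table", "a much longer linearised table " * 20]

# padding="longest" pads only up to the longest sequence in the batch, so batch
# shapes vary; padding="max_length" always pads/truncates to a fixed 512 tokens.
batch_longest = tokenizer(docs, padding="longest", truncation=True, return_tensors="pt")
batch_fixed = tokenizer(docs, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
print(batch_longest["input_ids"].shape, batch_fixed["input_ids"].shape)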

Security

See CONTRIBUTING for more information.

License Summary

The documentation is made available under the Creative Commons Attribution-ShareAlike 4.0 International License. See the LICENSE file.

The sample code within this documentation is made available under the MIT-0 license. See the LICENSE-SAMPLECODE file.
