Integrating support for GPT/T5/BART for Question Answering (#4532)
* added class for qa related metrics

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>
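
The metrics class itself is not shown on this page; below is a minimal sketch of the standard SQuAD-style exact-match and F1 scores that a QA metrics class typically wraps. Function and helper names are illustrative, not NeMo's actual API.

```python
import re
import string
from collections import Counter


def _normalize(text: str) -> str:
    """Standard SQuAD answer normalization: lowercase, drop punctuation,
    drop articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in set(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, ground_truth: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(_normalize(prediction) == _normalize(ground_truth))


def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between normalized prediction and ground truth."""
    pred_tokens = _normalize(prediction).split()
    gt_tokens = _normalize(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)
```

These two scores apply to generative (GPT/T5/BART) outputs as well as extractive spans, which is what lets one metrics class serve all the model types in this PR.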

* removed BLEU code from QA metrics

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added classes for data handling and loading for BERT/T5/BART/GPT

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* removed unnecessary main function

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added classes for BERT, S2S(T5/BART), GPT question answering models

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* created separate modules for model specific input features

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* moved non-model methods to QAMetrics and refactored method names to be more intuitive

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* changed classmethods to staticmethods

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* removed unnecessary copyright

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* removed deprecated input features file

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* abstracted cache filename, feature loading, feature dumping to QADataset

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* removed unused imports and added dataclass decorator

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* removed unused imports and refactored method name

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added base class for QA models and abstracted out common methods

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* moved non-model eval code and predictions file dump to metrics class

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added combined example of train/eval/test/inference for all qa models

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* renamed qa example file

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* fixed trailing whitespaces

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added type casting to float for logger warning

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* removed unused import

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* converted cached filename creation to class method

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* moved common code in dataset classes to base class, renamed Features class to Example

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* converted base QA example class to dataclass

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>
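
The dataclass conversion can be pictured with a minimal sketch; the field names below are illustrative guesses at what a base QA example carries, not the actual NeMo class definition.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class QAExample:
    """Sketch of a dataclass-based QA example. A dataclass gives the base
    example type free __init__/__repr__/__eq__, so model-specific subclasses
    only add their own fields."""
    qas_id: str
    question_text: str
    context_text: str
    answer_text: Optional[str] = None   # absent for unanswerable questions
    is_impossible: bool = False         # SQuAD v2.0-style "no answer" flag
```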

* reduced code repetition in prediction evaluation

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* converted prediction output files to jsonl

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added flag for checking if ground truth present in context spans

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>
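
The Jenkinsfile diff below passes a `model.dataset.check_if_answer_in_context` override for the generative models; here is a minimal sketch of the kind of substring check such a flag might gate. The helper name and filtering logic are illustrative, not the commit's actual implementation.

```python
def answer_in_context(context: str, answer: str) -> bool:
    """True if the ground-truth answer appears verbatim in the context passage.
    Generative models can be trained on free-form answers, so this check is
    made optional rather than always enforced."""
    return answer.strip() != "" and answer in context


# Hypothetical filtering pass over (context, answer) training pairs:
examples = [
    ("Paris is the capital of France.", "Paris"),
    ("Paris is the capital of France.", "London"),
]
kept = [ex for ex in examples if answer_in_context(*ex)]
```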

* converted predictions dump to jsonl from json

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* converted nbest predictions dump to jsonl from json

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>
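
JSON Lines writes one object per line, so large prediction dumps can be appended and streamed instead of being parsed as one big JSON array. A minimal sketch of the json-to-jsonl switch (function name and record fields are illustrative):

```python
import json


def dump_predictions_jsonl(predictions, path):
    """Write one JSON object per line (JSON Lines) rather than a single array."""
    with open(path, "w", encoding="utf-8") as f:
        for pred in predictions:
            f.write(json.dumps(pred) + "\n")


preds = [{"id": "q1", "answer": "Paris"}, {"id": "q2", "answer": "42"}]
dump_predictions_jsonl(preds, "predictions.jsonl")
```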

* removed unused argument to no pad loss method

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added unit tests for qa metrics and dataset utilities

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* applied style fix on new files

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added integration tests

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* restored default values in qa config

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* renamed stage to avoid duplicate

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added init files for new modules

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* applied style fix for module init files

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added inline comments to make code concise

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* specified class as abstract

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* specified .json format for output prediction files

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* created separate variable for answer in context check for readability

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* shifted stages to parallel

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* applied style fix

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* restored file modified by linter

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added transformers offline flag to true and moved all stages to parallel

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* moved inference code inside test_ds check

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added script for converting msmarco dataset to squad format

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>
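
The conversion script itself is not shown on this page; below is a simplified sketch of mapping one MS MARCO-style record onto the nested SQuAD JSON schema. The field mapping is an assumption for illustration, not the commit's actual script.

```python
def msmarco_to_squad_example(query_id, query, passage, answer):
    """Map one MS MARCO QA record onto the SQuAD v1.1 nested layout
    (title -> paragraphs -> qas -> answers)."""
    return {
        "title": str(query_id),
        "paragraphs": [
            {
                "context": passage,
                "qas": [
                    {
                        "id": str(query_id),
                        "question": query,
                        # MS MARCO answers are free-form, not guaranteed
                        # extractive spans, so find() may return -1.
                        "answers": [
                            {"text": answer, "answer_start": passage.find(answer)}
                        ],
                    }
                ],
            }
        ],
    }
```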

* added tutorial for question answering with generative models

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added copyright header

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* renamed old qa docs with _squad postfix and added docs for new qa modules

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* added generative qa architecture diagram

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* modified tutorial with colab testing changes, improved documentation

Signed-off-by: Ameya Mahabaleshwarkar <ameyasm1154@gmail.com>

* changed branch name to main in tutorial

* deprecated old QA tutorial

* deprecated old QA docs

* deprecated old QA example

* removed deprecated ci test for old qa example

* removed additional deprecated ci tests
ameyasm1154 committed Jul 25, 2022
1 parent 6442e33 commit 6b9617d
Showing 31 changed files with 5,321 additions and 1,141 deletions.
Jenkinsfile: 235 changes (147 additions, 88 deletions)
@@ -1088,7 +1088,7 @@ pipeline {
}
}
}
stage('L2: Parallel BERT SQUAD v1.1 / v2.0') {
stage('L2: Duplex Text Normalization') {
when {
anyOf {
branch 'main'
@@ -1097,53 +1097,6 @@
}
failFast true
parallel {
stage('BERT SQUAD 1.1') {
// Cannot do fast_dev_run because squad needs whole dev dataset
steps {
sh 'cd examples/nlp/question_answering && \
python question_answering_squad.py \
model.train_ds.file=/home/TestData/nlp/squad_mini/v1.1/train-v1.1.json \
model.dataset.use_cache=false \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v1.1/dev-v1.1.json \
model.test_ds.file=/home/TestData/nlp/squad_mini/v1.1/dev-v1.1.json \
model.train_ds.batch_size=2 \
model.train_ds.num_samples=2 \
model.validation_ds.batch_size=2 \
model.validation_ds.num_samples=2 \
model.test_ds.num_samples=2 \
model.test_ds.batch_size=2 \
trainer.max_epochs=1 \
+trainer.max_steps=1 \
model.language_model.pretrained_model_name=bert-base-uncased \
model.dataset.version_2_with_negative=false \
trainer.precision=16 \
trainer.devices=[0] \
trainer.accelerator="gpu" \
exp_manager=null'
}
}
stage('BERT SQUAD 2.0') {
// Cannot do fast_dev_run because squad needs whole dev dataset
steps {
sh 'cd examples/nlp/question_answering && \
python question_answering_squad.py \
model.train_ds.file=/home/TestData/nlp/squad_mini/v2.0/train-v2.0.json \
model.dataset.use_cache=false \
model.train_ds.batch_size=2 \
model.train_ds.num_samples=2 \
model.validation_ds.batch_size=2 \
model.validation_ds.num_samples=2 \
trainer.max_epochs=1 \
+trainer.max_steps=1 \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v2.0/dev-v2.0.json \
model.language_model.pretrained_model_name=bert-base-uncased \
model.dataset.version_2_with_negative=true \
trainer.precision=16 \
trainer.devices=[1] \
trainer.accelerator="gpu" \
exp_manager=null'
}
}
stage('Duplex Text Normalization with Tarred dataset') {
steps {
sh 'cd examples/nlp/duplex_text_normalization && \
@@ -1199,7 +1152,36 @@
// }
// }

stage('L2: Parallel SQUAD v1.1 & v2.0') {
stage('L2: BERT Text Classification') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
parallel {
stage ('Text Classification with BERT Test') {
steps {
sh 'cd examples/nlp/text_classification && \
python text_classification_with_bert.py \
model.dataset.num_classes=6 \
model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
model.validation_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \
model.language_model.pretrained_model_name=distilbert-base-uncased \
model.train_ds.batch_size=10 \
model.dataset.max_seq_length=50 \
model.dataset.use_cache=false \
trainer.devices=[0] \
trainer.accelerator="gpu" \
+trainer.fast_dev_run=true \
exp_manager=null'
}
}
}
}

stage('L2: Parallel BERT/BART/GPT2 Question-Answering SQUAD v1.1 & v2.0') {
when {
anyOf {
branch 'main'
@@ -1208,72 +1190,149 @@
}
failFast true
parallel {
// TODO: use megatron bert when supported again
stage('SQUAD v2.0 with DistilBERT Uncased') {
// stage('SQUAD v2.0 with Megatron with ckpt & config') {
stage('BERT SQUAD 1.1') {
// Cannot do fast_dev_run because squad needs whole dev dataset
// model.language_model.pretrained_model_name=megatron-bert-uncased \
// model.language_model.lm_checkpoint=/home/TestData/nlp/megatron_345m_uncased/model_optim_rng.pt \
// model.language_model.config_file=/home/TestData/nlp/megatron_345m_uncased/345m_config.json \
steps {
sh 'cd examples/nlp/question_answering && \
python question_answering_squad.py \
model.train_ds.file=/home/TestData/nlp/squad_mini/v2.0/train-v2.0.json \
sh 'TRANSFORMERS_OFFLINE=0 && cd examples/nlp/question_answering && \
python question_answering.py \
model.train_ds.file=/home/TestData/nlp/squad_mini/v1.1/train-v1.1.json \
model.dataset.use_cache=false \
model.train_ds.batch_size=1 \
model.train_ds.num_samples=1 \
model.validation_ds.batch_size=1 \
model.validation_ds.num_samples=1 \
trainer.accelerator=gpu \
trainer.strategy=ddp \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v1.1/dev-v1.1.json \
model.test_ds.file=/home/TestData/nlp/squad_mini/v1.1/dev-v1.1.json \
model.train_ds.batch_size=2 \
model.train_ds.num_samples=2 \
model.validation_ds.batch_size=2 \
model.validation_ds.num_samples=2 \
model.test_ds.num_samples=2 \
model.test_ds.batch_size=2 \
trainer.max_epochs=1 \
+trainer.max_steps=1 \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v2.0/dev-v2.0.json \
model.language_model.pretrained_model_name=distilbert-base-uncased \
model.dataset.version_2_with_negative=true \
trainer.max_steps=1 \
model.language_model.pretrained_model_name=bert-base-uncased \
model.dataset.version_2_with_negative=false \
trainer.precision=16 \
trainer.devices=[1] \
trainer.devices=[0] \
trainer.accelerator="gpu" \
exp_manager=null'
exp_manager=null && TRANSFORMERS_OFFLINE=1'
}
}
stage('RoBERTa SQUAD 1.1') {
stage('BART SQUAD 1.1') {
// Cannot do fast_dev_run because squad needs whole dev dataset
steps {
sh 'cd examples/nlp/question_answering && \
python question_answering_squad.py \
sh 'TRANSFORMERS_OFFLINE=0 && cd examples/nlp/question_answering && \
python question_answering.py \
model.train_ds.file=/home/TestData/nlp/squad_mini/v1.1/train-v1.1.json \
model.dataset.use_cache=false \
model.dataset.check_if_answer_in_context=false \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v1.1/dev-v1.1.json \
model.test_ds.file=/home/TestData/nlp/squad_mini/v1.1/dev-v1.1.json \
model.train_ds.batch_size=2 \
model.train_ds.num_samples=2 \
model.validation_ds.batch_size=2 \
model.validation_ds.num_samples=2 \
model.test_ds.num_samples=2 \
model.test_ds.batch_size=2 \
trainer.max_epochs=1 \
+trainer.max_steps=1 \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v1.1/dev-v1.1.json \
model.language_model.pretrained_model_name=roberta-base \
trainer.max_steps=1 \
model.language_model.pretrained_model_name=facebook/bart-base \
model.dataset.version_2_with_negative=false \
trainer.precision=16 \
trainer.devices=[0] \
trainer.accelerator="gpu" \
exp_manager=null'
exp_manager=null && TRANSFORMERS_OFFLINE=1'
}
}
stage ('Text Classification with BERT Test') {
stage('GPT2 SQUAD 1.1') {
// Cannot do fast_dev_run because squad needs whole dev dataset
steps {
sh 'cd examples/nlp/text_classification && \
python text_classification_with_bert.py \
model.dataset.num_classes=6 \
model.train_ds.file_path=/home/TestData/nlp/retail_text_classification/train.tsv \
model.validation_ds.file_path=/home/TestData/nlp/retail_text_classification/dev.tsv \
model.language_model.pretrained_model_name=distilbert-base-uncased \
model.train_ds.batch_size=10 \
model.dataset.max_seq_length=50 \
sh 'TRANSFORMERS_OFFLINE=0 && cd examples/nlp/question_answering && \
python question_answering.py \
model.train_ds.file=/home/TestData/nlp/squad_mini/v1.1/train-v1.1.json \
model.dataset.use_cache=false \
model.dataset.check_if_answer_in_context=false \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v1.1/dev-v1.1.json \
model.test_ds.file=/home/TestData/nlp/squad_mini/v1.1/dev-v1.1.json \
model.train_ds.batch_size=2 \
model.train_ds.num_samples=2 \
model.validation_ds.batch_size=2 \
model.validation_ds.num_samples=2 \
model.test_ds.num_samples=2 \
model.test_ds.batch_size=2 \
trainer.max_epochs=1 \
trainer.max_steps=1 \
model.language_model.pretrained_model_name=gpt2 \
model.dataset.version_2_with_negative=false \
trainer.precision=16 \
trainer.devices=[0] \
trainer.accelerator="gpu" \
+trainer.fast_dev_run=true \
exp_manager=null'
exp_manager=null && TRANSFORMERS_OFFLINE=1'
}
}
stage('BERT SQUAD 2.0') {
// Cannot do fast_dev_run because squad needs whole dev dataset
steps {
sh 'TRANSFORMERS_OFFLINE=0 && cd examples/nlp/question_answering && \
python question_answering.py \
model.train_ds.file=/home/TestData/nlp/squad_mini/v2.0/train-v2.0.json \
model.dataset.use_cache=false \
model.train_ds.batch_size=2 \
model.train_ds.num_samples=2 \
model.validation_ds.batch_size=2 \
model.validation_ds.num_samples=2 \
trainer.max_epochs=1 \
trainer.max_steps=1 \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v2.0/dev-v2.0.json \
model.language_model.pretrained_model_name=bert-base-uncased \
model.dataset.version_2_with_negative=true \
trainer.precision=16 \
trainer.devices=[1] \
trainer.accelerator="gpu" \
exp_manager=null && TRANSFORMERS_OFFLINE=1'
}
}
stage('BART SQUAD 2.0') {
// Cannot do fast_dev_run because squad needs whole dev dataset
steps {
sh 'TRANSFORMERS_OFFLINE=0 && cd examples/nlp/question_answering && \
python question_answering.py \
model.train_ds.file=/home/TestData/nlp/squad_mini/v2.0/train-v2.0.json \
model.dataset.use_cache=false \
model.dataset.check_if_answer_in_context=false \
model.train_ds.batch_size=2 \
model.train_ds.num_samples=2 \
model.validation_ds.batch_size=2 \
model.validation_ds.num_samples=2 \
trainer.max_epochs=1 \
trainer.max_steps=1 \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v2.0/dev-v2.0.json \
model.language_model.pretrained_model_name=facebook/bart-base \
model.dataset.version_2_with_negative=true \
trainer.precision=16 \
trainer.devices=[1] \
trainer.accelerator="gpu" \
exp_manager=null && TRANSFORMERS_OFFLINE=1'
}
}
stage('GPT2 SQUAD 2.0') {
// Cannot do fast_dev_run because squad needs whole dev dataset
steps {
sh 'TRANSFORMERS_OFFLINE=0 && cd examples/nlp/question_answering && \
python question_answering.py \
model.train_ds.file=/home/TestData/nlp/squad_mini/v2.0/train-v2.0.json \
model.dataset.use_cache=false \
model.dataset.check_if_answer_in_context=false \
model.train_ds.batch_size=2 \
model.train_ds.num_samples=2 \
model.validation_ds.batch_size=2 \
model.validation_ds.num_samples=2 \
trainer.max_epochs=1 \
trainer.max_steps=1 \
model.validation_ds.file=/home/TestData/nlp/squad_mini/v2.0/dev-v2.0.json \
model.language_model.pretrained_model_name=gpt2 \
model.dataset.version_2_with_negative=true \
trainer.precision=16 \
trainer.devices=[1] \
trainer.accelerator="gpu" \
exp_manager=null && TRANSFORMERS_OFFLINE=1'
}
}
}