Polish BERT example (#121)
* Polish BERT example
gpengzhi committed Jul 25, 2019
1 parent 8679104 commit 84747db
Showing 8 changed files with 97 additions and 115 deletions.
81 changes: 49 additions & 32 deletions examples/bert/README.md
@@ -1,15 +1,19 @@
# BERT: Pre-trained models and downstream applications

This is a Texar implementation of Google's BERT model, which allows loading pre-trained model parameters downloaded from the [official release](https://github.com/google-research/bert) and building/fine-tuning arbitrary downstream applications with **distributed training** (this example showcases BERT for sentence classification).
This is a Texar implementation of Google's BERT model, which allows loading pre-trained model parameters downloaded from the [official release](https://github.com/google-research/bert) and building/fine-tuning arbitrary downstream applications (this example showcases BERT for sentence classification).

With Texar, building the BERT model is as simple as creating a [`TransformerEncoder`](https://texar.readthedocs.io/en/latest/code/modules.html#transformerencoder) instance. We can initialize the parameters of the BERT model by specifying `pretrained_model_name` in `BertEncoder`. The pre-trained model can be any of: `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, `bert-large-cased`, `bert-base-multilingual-uncased`, `bert-base-multilingual-cased`, and `bert-base-chinese`.
Texar provides ready-to-use modules including [`BERTEncoder`](https://texar-pytorch.readthedocs.io/en/latest/code/modules.html#bertencoder), [`BERTClassifier`](https://texar-pytorch.readthedocs.io/en/latest/code/modules.html#bertclassifier), etc. This example shows the use of `BERTClassifier` for sentence classification tasks.
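
As a quick orientation, here is a minimal sketch (not part of this commit) of how such a classifier can be instantiated; `bert-base-uncased` and `num_classes=2` are illustrative choices matching the MRPC setup below:

```python
import texar.torch as tx

# Minimal sketch: build a sentence classifier on top of pre-trained BERT.
# `pretrained_model_name` selects the checkpoint to download and load;
# `num_classes` configures the classification layer on top of the encoder.
classifier = tx.modules.BERTClassifier(
    pretrained_model_name="bert-base-uncased",
    hparams={"num_classes": 2})
```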

In sum, this example showcases:

* Use of pre-trained Google BERT models in Texar
* Building and fine-tuning on downstream tasks
* Use of Texar `RecordData` module for data loading and processing

Future work:

* Train or fine-tune the model with distributed GPU

## Prerequisite

#### Install dependencies
@@ -27,63 +31,75 @@ pip install -r requirements.txt
We explain the use of the example code based on the Microsoft Research Paraphrase Corpus (MRPC) for sentence classification.

Download the data with the following command:
```

```bash
python data/download_glue_data.py --tasks=MRPC
```
By default, it will download the MRPC dataset into the `data` directory. Note that the MRPC dataset is part of the [GLUE](https://gluebenchmark.com/tasks) dataset collection.

### Prepare data

We first preprocess the downloaded raw data into [TFRecord](https://www.tensorflow.org/tutorials/load_data/tf_records) files. The preprocessing tokenizes raw text with BPE encoding, truncates sequences, adds special tokens, etc.
Run the following cmd to this end:
```
python prepare_data.py --task=MRPC
[--max_seq_length=128]
[--vocab_file=bert_pretrained_models/uncased_L-12_H-768_A-12/vocab.txt]
[--tfrecord_output_dir=data/MRPC]
We first preprocess the downloaded raw data into [pickled](https://docs.python.org/3/library/pickle.html) files. The preprocessing tokenizes raw text with BPE encoding, truncates sequences, adds special tokens, etc. Run the following command to this end:

```bash
python prepare_data.py --task=MRPC \
--max_seq_length=128 \
--pretrained_model_name=bert-base-uncased \
--output_dir=data/MRPC
```

- `task`: Specifies the dataset name to preprocess. BERT provides default support for `{'CoLA', 'MNLI', 'MRPC', 'XNLI', 'SST'}` data.
- `max_seq_length`: The maximum sequence length. This includes BERT special tokens that will be automatically added. Longer sequences will be trimmed.
- `vocab_file`: Path to a vocabulary file used for tokenization.
- `tfrecord_output_dir`: The output path where the resulting TFRecord files are written. By default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above command, the TFRecord files are output to `data/MRPC`.
- `pretrained_model_name`: The name of the pre-trained model to load, selected from: `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, `bert-large-cased`, `bert-base-multilingual-uncased`, `bert-base-multilingual-cased`, and `bert-base-chinese`.
- `output_dir`: The output path where the resulting pickled files are written. By default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above command, the pickled files are output to `data/MRPC`.

**Outcome of the Preprocessing**:
- The preprocessing will output 3 TFRecord data files `{train.tf_record, eval.tf_record, predict.tf_record}` in the specified output directory.
- The cmd also prints logs as follows:

- The preprocessing will output 3 pickled data files `{train.pkl, eval.pkl, predict.pkl}` in the specified output directory.

- The command also prints logs as follows:

```
INFO:tensorflow:Loading data
INFO:tensorflow:num_classes:2; num_train_data:3668
INFO:tensorflow:config_data.py has been updated
INFO:tensorflow:Data preparation finished
INFO:root:Loading data
>> Downloading uncased_L-12_H-768_A-12.zip 100.0%%
Successfully downloaded uncased_L-12_H-768_A-12.zip 407727028 bytes.
INFO:root:Extract bert_pretrained_models/uncased_L-12_H-768_A-12.zip
INFO:root:num_classes: 2; num_train_data: 3668
INFO:root:config_data.py has been updated
INFO:root:Data preparation finished
```
**Note that** the data info `num_classes` and `num_train_data`, as well as `max_seq_length` specified in the cmd, are required for BERT training in the following. They should be specified in the data configuration file passed to BERT training (see below).
- For convenience, the above cmd automatically writes `num_classes`, `num_train_data` and `max_seq_length` to `config_data.py`.
**Note that** the data information `num_classes` and `num_train_data`, as well as the `max_seq_length` specified in the command, are required for BERT training below. They should be specified in the data configuration file passed to BERT training (see below).

- For convenience, the above command automatically writes `num_classes`, `num_train_data` and `max_seq_length` to `config_data.py`.
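
To see how these pieces fit together, below is a minimal sketch (not part of this commit) of loading the resulting pickled records with Texar's `RecordData` module, using the `train_hparam` defined in the default [`config_data.py`](./config_data.py):

```python
import texar.torch as tx

import config_data  # the data configuration updated by prepare_data.py

# Minimal sketch: read the pickled training records produced above.
# `train_hparam` points to data/MRPC/train.pkl by default and declares
# the features listed in `feature_original_types`.
train_dataset = tx.data.RecordData(hparams=config_data.train_hparam)
iterator = tx.data.DataIterator(train_dataset)

for batch in iterator:
    input_ids = batch["input_ids"]      # shape: [batch_size, max_seq_length]
    segment_ids = batch["segment_ids"]
    labels = batch["label_ids"]
```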

### Train and Evaluate

For **single-GPU** training (and evaluation), run the following cmd. The training updates the classification layer and fine-tunes the pre-trained BERT parameters.
```
python bert_classifier_main.py --do_train --do_eval
[--config_downstream=config_classifier]
[--config_data=config_data]
[--output_dir=output]
For **single-GPU** training (and evaluation), run the following command. The training updates the classification layer and fine-tunes the pre-trained BERT parameters.

```bash
python bert_classifier_main.py --do_train --do_eval \
--config_downstream=config_classifier \
--config_data=config_data \
--output_dir=output
```
Here:

- `config_downstream`: Configuration of the downstream part. In this example, [`config_classifier`](./config_classifier.py) configures the classification layer and the optimization method.
- `config_data`: The data configuration. See the default [`config_data.py`](./config_data.py) for example. Make sure to specify `num_classes`, `num_train_data`, `max_seq_length`, and `tfrecord_data_dir` as used or output in the above [data preparation](#prepare-data) step.
- `output_dir`: The output path where checkpoints and TensorBoard summaries are saved.
- `config_data`: The data configuration. See the default [`config_data.py`](./config_data.py) for an example. Make sure to specify `num_classes`, `num_train_data`, `max_seq_length`, and `pickle_data_dir` as used or output in the above [data preparation](#prepare-data) step.
- `output_dir`: The output path where checkpoints are saved.
- `pretrained_model_name`: The name of the pre-trained model to load, selected from: `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, `bert-large-cased`, `bert-base-multilingual-uncased`, `bert-base-multilingual-cased`, and `bert-base-chinese`.
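
For orientation, here is a heavily simplified sketch (not part of this commit) of a single fine-tuning step with `BERTClassifier`; the actual loop, optimizer, and learning-rate schedule are set up in `bert_classifier_main.py` and [`config_classifier`](./config_classifier.py), and the plain `Adam` optimizer and learning rate below are only stand-ins:

```python
import torch
import torch.nn.functional as F
import texar.torch as tx

# Heavily simplified sketch of one fine-tuning step.
model = tx.modules.BERTClassifier(
    pretrained_model_name="bert-base-uncased",
    hparams={"num_classes": 2})
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)  # stand-in optimizer

def train_step(batch):
    # Sequence lengths can be recovered from the padding mask.
    seq_lengths = batch["input_mask"].sum(dim=1)
    # BERTClassifier returns (logits, predictions).
    logits, _ = model(batch["input_ids"], seq_lengths, batch["segment_ids"])
    loss = F.cross_entropy(logits, batch["label_ids"])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```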

After convergence, the evaluation performance is roughly as follows. Due to certain randomness (e.g., random initialization of the classification layer), the evaluation accuracy is reasonable as long as it is `>0.84`.

```
INFO:tensorflow:dev accu: 0.8676470588235294
INFO:root:dev accu: 0.8676470588235294
```

### Restore and Test

``
```bash
python bert_classifier_main.py --do_test --checkpoint=output/model.ckpt
``
```

The output is by default saved in `output/test_results.tsv`, where each line contains the predicted label for each sample.

@@ -93,7 +109,8 @@ The output is by default saved in `output/test_results.tsv`, where each line con
`bert_classifier_main.py` also supports other datasets/tasks. To do this, specify a different value for the `--task` flag when running [data preparation](#prepare-data).

For example, use the following commands to download the SST (Stanford Sentiment Treebank) dataset and run sentence classification. Make sure to specify the correct data path and other information in the data configuration file.
```

```bash
python data/download_glue_data.py --tasks=SST
python prepare_data.py --task=SST
python bert_classifier_main.py --do_train --do_eval --config_data=config_data
13 changes: 11 additions & 2 deletions examples/bert/bert_classifier_main.py
@@ -32,6 +32,13 @@
parser.add_argument(
"--config_downstream", default="config_classifier",
help="Configuration of the downstream part of the model")
parser.add_argument(
'--pretrained_model_name', type=str, default='bert-base-uncased',
help="The name of a pre-trained model to load selected in the "
"list of: `bert-base-uncased`, `bert-large-uncased`, "
"`bert-base-cased`, `bert-large-cased`, "
"`bert-base-multilingual-uncased`, `bert-base-multilingual-cased`, "
"and `bert-base-chinese`.")
parser.add_argument(
"--config_data", default="config_data", help="The dataset config.")
parser.add_argument(
@@ -71,8 +78,10 @@ def main():
num_train_data = config_data.num_train_data

# Builds BERT
model = tx.modules.BertClassifier(cache_dir='bert_pretrained_models',
hparams=config_downstream)
model = tx.modules.BERTClassifier(
pretrained_model_name=args.pretrained_model_name,
cache_dir='bert_pretrained_models',
hparams=config_downstream)
model.to(device)

num_train_steps = int(num_train_data / config_data.train_batch_size *
3 changes: 0 additions & 3 deletions examples/bert/bert_config_lib/README.md

This file was deleted.

Empty file.

This file was deleted.

20 changes: 10 additions & 10 deletions examples/bert/config_data.py
@@ -1,4 +1,4 @@
tfrecord_data_dir = "data/MRPC"
pickle_data_dir = "data/MRPC"
max_seq_length = 128
num_classes = 2
num_train_data = 3668
@@ -14,14 +14,14 @@
test_batch_size = 8

feature_original_types = {
# Reading features from TFRecord data file.
# E.g., Reading feature "input_ids" as dtype `tf.int64`;
# Reading features from pickled data file.
# E.g., Reading feature "input_ids" as dtype `int64`;
# "FixedLenFeature" indicates its length is fixed for all data instances;
# and the sequence length is limited by `max_seq_length`.
"input_ids": ["tf.int64", "FixedLenFeature", max_seq_length],
"input_mask": ["tf.int64", "FixedLenFeature", max_seq_length],
"segment_ids": ["tf.int64", "FixedLenFeature", max_seq_length],
"label_ids": ["tf.int64", "FixedLenFeature"]
"input_ids": ["int64", "FixedLenFeature", max_seq_length],
"input_mask": ["int64", "FixedLenFeature", max_seq_length],
"segment_ids": ["int64", "FixedLenFeature", max_seq_length],
"label_ids": ["int64", "FixedLenFeature"]
}

train_hparam = {
@@ -30,7 +30,7 @@
"dataset": {
"data_name": "data",
"feature_original_types": feature_original_types,
"files": "{}/train.tf_record".format(tfrecord_data_dir)
"files": "{}/train.pkl".format(pickle_data_dir)
},
"shuffle": True,
"shuffle_buffer_size": 100
@@ -42,7 +42,7 @@
"dataset": {
"data_name": "data",
"feature_original_types": feature_original_types,
"files": "{}/eval.tf_record".format(tfrecord_data_dir)
"files": "{}/eval.pkl".format(pickle_data_dir)
},
"shuffle": False
}
@@ -53,7 +53,7 @@
"dataset": {
"data_name": "data",
"feature_original_types": feature_original_types,
"files": "{}/predict.tf_record".format(tfrecord_data_dir)
"files": "{}/predict.pkl".format(pickle_data_dir)
},
"shuffle": False
}
35 changes: 22 additions & 13 deletions examples/bert/prepare_data.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Produces TFRecord files and modifies data configuration file
"""Produces pickle files and modifies data configuration file
"""

import argparse
@@ -28,15 +28,18 @@
choices=['COLA', 'MNLI', 'MRPC', 'XNLI', 'SST'],
help="The task to run experiment on.")
parser.add_argument(
"--vocab_file", type=str,
default='bert_pretrained_models/uncased_L-12_H-768_A-12/vocab.txt',
help="The one-wordpiece-per-line vocabary file directory.")
'--pretrained_model_name', type=str, default='bert-base-uncased',
help="The name of a pre-trained model to load selected in the "
"list of: `bert-base-uncased`, `bert-large-uncased`, "
"`bert-base-cased`, `bert-large-cased`, "
"`bert-base-multilingual-uncased`, `bert-base-multilingual-cased`, "
"and `bert-base-chinese`.")
parser.add_argument(
"--max_seq_length", type=int, default=128,
help="The maxium length of sequence, longer sequence will be trimmed.")
parser.add_argument(
"--tfrecord_output_dir", type=str, default=None,
help="The output directory where the TFRecord files will be generated. "
"--output_dir", type=str, default=None,
help="The output directory where the pickle files will be generated. "
"By default it will be set to 'data/{task}'. E.g.: if "
"task is 'MRPC', it will be set as 'data/MRPC'")
parser.add_argument(
@@ -64,11 +67,11 @@ def prepare_data():
if args.task.upper() in task_datasets_rename:
data_dir = f'data/{task_datasets_rename[args.task]}'

if args.tfrecord_output_dir is None:
tfrecord_output_dir = data_dir
if args.output_dir is None:
pickle_output_dir = data_dir
else:
tfrecord_output_dir = args.tfrecord_output_dir
tx.utils.maybe_create_dir(tfrecord_output_dir)
pickle_output_dir = args.output_dir
tx.utils.maybe_create_dir(pickle_output_dir)

processors = {
"COLA": data_utils.ColaProcessor,
@@ -81,21 +84,27 @@

from config_data import feature_original_types

pretrained_model_dir = tx.modules.load_pretrained_bert(
pretrained_model_name=args.pretrained_model_name,
cache_dir='bert_pretrained_models')

vocab_file = os.path.join(pretrained_model_dir, "vocab.txt")

num_classes = len(processor.get_labels())
num_train_data = len(processor.get_train_examples(data_dir))
logging.info("num_classes: %d; num_train_data: %d",
num_classes, num_train_data)
tokenizer = tokenization.FullTokenizer(
vocab_file=args.vocab_file,
vocab_file=vocab_file,
do_lower_case=args.do_lower_case)

# Produces TFRecord files
# Produces pickle files
data_utils.prepare_record_data(
processor=processor,
tokenizer=tokenizer,
data_dir=data_dir,
max_seq_length=args.max_seq_length,
output_dir=tfrecord_output_dir,
output_dir=pickle_output_dir,
feature_original_types=feature_original_types)
modify_config_data(args.max_seq_length, num_train_data, num_classes)

10 changes: 5 additions & 5 deletions examples/bert/utils/data_utils.py
@@ -413,7 +413,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
def file_based_convert_examples_to_features(
examples, label_list, max_seq_length, tokenizer, output_file,
feature_original_types):
r"""Convert a set of `InputExample`s to a TFRecord file."""
r"""Convert a set of `InputExample`s to a pickled file."""

with tx.data.RecordData.writer(
output_file, feature_original_types) as writer:
@@ -459,25 +459,25 @@ def prepare_record_data(processor, tokenizer,
SentencePiece Model.
data_dir: The input data directory.
max_seq_length: Max sequence length.
output_dir: The directory to save the TFRecord in.
output_dir: The directory to save the pickled file in.
feature_original_types: The original type of the feature.
"""
label_list = processor.get_labels()

train_examples = processor.get_train_examples(data_dir)
train_file = os.path.join(output_dir, "train.tf_record")
train_file = os.path.join(output_dir, "train.pkl")
file_based_convert_examples_to_features(
train_examples, label_list, max_seq_length,
tokenizer, train_file, feature_original_types)

eval_examples = processor.get_dev_examples(data_dir)
eval_file = os.path.join(output_dir, "eval.tf_record")
eval_file = os.path.join(output_dir, "eval.pkl")
file_based_convert_examples_to_features(
eval_examples, label_list,
max_seq_length, tokenizer, eval_file, feature_original_types)

test_examples = processor.get_test_examples(data_dir)
test_file = os.path.join(output_dir, "predict.tf_record")
test_file = os.path.join(output_dir, "predict.pkl")
file_based_convert_examples_to_features(
test_examples, label_list,
max_seq_length, tokenizer, test_file, feature_original_types)
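
As additional context for the writer used in `file_based_convert_examples_to_features` above, here is a minimal sketch (not part of this commit; the path and feature values are illustrative) of writing one converted example through `tx.data.RecordData.writer`:

```python
import texar.torch as tx

# The declared feature types mirror config_data.py (max_seq_length = 128).
feature_original_types = {
    "input_ids": ["int64", "FixedLenFeature", 128],
    "input_mask": ["int64", "FixedLenFeature", 128],
    "segment_ids": ["int64", "FixedLenFeature", 128],
    "label_ids": ["int64", "FixedLenFeature"],
}

# Minimal sketch: each example is written as a plain feature dict.
with tx.data.RecordData.writer("data/MRPC/train.pkl",
                               feature_original_types) as writer:
    writer.write({
        "input_ids": [0] * 128,     # illustrative (padded) token ids
        "input_mask": [0] * 128,
        "segment_ids": [0] * 128,
        "label_ids": [0],
    })
```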
