Merge pull request #227 from gpengzhi/polish-bert
Polish BERT example
gpengzhi committed Oct 18, 2019
2 parents bd7d0d8 + bca8362 commit c58d423
Showing 8 changed files with 173 additions and 690 deletions.
6 changes: 3 additions & 3 deletions examples/bert/README.md
@@ -36,13 +36,13 @@ Run the following command to this end:
```
python prepare_data.py --task=MRPC
    [--max_seq_length=128]
    [--vocab_file=bert_pretrained_models/uncased_L-12_H-768_A-12/vocab.txt]
    [--pretrained_model_name=bert-base-uncased]
    [--tfrecord_output_dir=data/MRPC]
```

- `--task`: Specifies the dataset name to preprocess. BERT provides default support for `{'CoLA', 'MNLI', 'MRPC', 'XNLI', 'SST'}` data.
- `--max_seq_length`: The maximum length of a sequence. This includes the BERT special tokens that will be automatically added. Longer sequences will be trimmed.
- `--vocab_file`: Path to a vocabary file used for tokenization.
- `--pretrained_model_name`: The name of the pre-trained BERT model. See the [doc](https://texar.readthedocs.io/en/latest/code/modules.html#texar.tf.modules.PretrainedBERTMixin) for all supported models.
- `--tfrecord_output_dir`: The output path where the resulting TFRecord files will be written. By default, it is set to `data/{task}`, where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above command, the TFRecord files are output to `data/MRPC`.

**Outcome of the Preprocessing**:
@@ -75,7 +75,7 @@ Here:
- `config_downstream`: Configuration of the downstream part. In this example, [`config_classifier`](./config_classifier.py) configures the classification layer and the optimization method.
- `config_data`: The data configuration. See the default [`config_data.py`](./config_data.py) for example. Make sure to specify `num_classes`, `num_train_data`, `max_seq_length`, and `tfrecord_data_dir` as used or output in the above [data preparation](#prepare-data) step.
- `output_dir`: The output path where checkpoints and TensorBoard summaries are saved.
- `pretrained_model_name`: The name of a pre-trained model to load selected in the list of: `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, `bert-large-cased`, `bert-base-multilingual-uncased`, `bert-base-multilingual-cased`, and `bert-base-chinese`.
- `pretrained_model_name`: The name of the pre-trained BERT model. See the [doc](https://texar.readthedocs.io/en/latest/code/modules.html#texar.tf.modules.PretrainedBERTMixin) for all supported models.


For **Multi-GPU training** on one or multiple machines, you may first install the prerequisite OpenMPI and Horovod packages, as detailed in the [distributed_gpu](https://github.com/asyml/texar/tree/master/examples/distributed_gpu) example.
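
The practical effect of this README change: the vocabulary is now resolved from the pre-trained model name instead of a hand-specified `vocab.txt` path. A minimal sketch of the new flow, mirroring the `prepare_data.py` change further down (the sample sentence and the `map_text_to_id` call are assumptions based on the Texar tokenizer interface, not part of this diff):

```python
import texar.tf as tx

# The tokenizer is resolved from the model name alone; the matching
# vocabulary is fetched together with the checkpoint on first use.
tokenizer = tx.data.BERTTokenizer(pretrained_model_name="bert-base-uncased")

# Assumed interface: map raw text to wordpiece ids.
token_ids = tokenizer.map_text_to_id("this is an illustrative sentence")
```
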
5 changes: 3 additions & 2 deletions examples/bert/bert_classifier_main.py
@@ -35,8 +35,9 @@
"config_downstream", "config_classifier",
"Configuration of the downstream part of the model.")
flags.DEFINE_string(
"pretrained_model_name", "bert-base-uncased",
"Name of the pre-trained checkpoint to load.")
"pretrained_model_name", 'bert-base-uncased',
"The name of pre-trained BERT model. See the doc of "
"`texar.tf.modules.PretrainedBERTMixin for all supported models.`")
flags.DEFINE_string(
"config_data", "config_data",
"The dataset config.")
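
For context on where this flag ends up: it is typically passed straight through to the model constructor. A hedged sketch (the `BERTClassifier` call is an assumption based on the Texar docs, not part of this diff):

```python
import tensorflow as tf
import texar.tf as tx

flags = tf.flags
flags.DEFINE_string(
    "pretrained_model_name", "bert-base-uncased",
    "The name of the pre-trained BERT model.")
FLAGS = flags.FLAGS

# Assumption: texar.tf.modules.BERTClassifier accepts the model name and
# downloads the corresponding checkpoint the first time it is used.
classifier = tx.modules.BERTClassifier(
    pretrained_model_name=FLAGS.pretrained_model_name)
```
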
4 changes: 2 additions & 2 deletions examples/bert/data/README.md
@@ -1,6 +1,6 @@
This explains how the data is prepared.

When you run `data/download_glue_data.py` in the parent directory, by default, all datasets in GLEU will be stored here. For more information on GLUE, please refer to
When you run `data/download_glue_data.py` in the parent directory, by default, all datasets in the General Language Understanding Evaluation (GLUE) will be stored here. For more information on GLUE, please refer to
[gluebenchmark](https://gluebenchmark.com/tasks)

Here we show the data format of the SST-2 dataset.
@@ -26,4 +26,4 @@ index sentence
* The test data is in a different format: the first column is a unique index for each test example, and the second column is the space-separated string.


In [`bert/utils/data_utils.py`](https://github.com/asyml/texar/blob/master/examples/bert/utils/data_utils.py), there are 5 types of `Data Processor` Implemented. You can run `python bert_classifier_main.py` and specify `--task` to run on different datasets.
In [`bert/utils/data_utils.py`](https://github.com/asyml/texar/blob/master/examples/bert/utils/data_utils.py), there are 5 types of `Data Processor` implemented. You can run `python bert_classifier_main.py` and specify `--task` to run on different datasets.
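
To make the two-column format concrete, here is a short sketch of reading an SST-2 style `train.tsv`; the path and column layout follow this README and the default download location, everything else is illustrative:

```python
import csv

# train/dev files: a header row, then one "sentence<TAB>label" pair per line.
with open("data/SST-2/train.tsv") as f:
    reader = csv.reader(f, delimiter="\t")
    next(reader)  # skip the header
    for sentence, label in reader:
        print(label, sentence[:40])
```
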
85 changes: 56 additions & 29 deletions examples/bert/data/download_glue_data.py
@@ -3,26 +3,32 @@
Adapted from https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
"""
import argparse
import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile

TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
"MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
"QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
"STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
"MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
"SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
"QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
"RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
"WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
"diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI",
"RTE", "WNLI", "diagnostic"]

# pylint: disable=line-too-long

TASK2PATH = {
"CoLA": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
"SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
"MRPC": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
"QQP": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
"STS": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
"MNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
"SNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
"QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
"RTE": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
"WNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
"diagnostic": 'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}

# pylint: enable=line-too-long


def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
@@ -33,6 +39,7 @@ def download_and_extract(task, data_dir):
    os.remove(data_file)
    print("\tCompleted!")


def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
@@ -44,18 +51,21 @@ def format_mrpc(data_dir, path_to_data):
    else:
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
    urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
    assert os.path.isfile(mrpc_train_file), \
        "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), \
        "Test data not found at %s" % mrpc_test_file
    urllib.request.urlretrieve(TASK2PATH["MRPC"],
                               os.path.join(mrpc_dir, "dev_ids.tsv"))

    dev_ids = []
    with open(os.path.join(mrpc_dir, "dev_ids.tsv")) as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    with open(mrpc_train_file) as data_fh, \
            open(os.path.join(mrpc_dir, "train.tsv"), 'w') as train_fh, \
            open(os.path.join(mrpc_dir, "dev.tsv"), 'w') as dev_fh:
         open(os.path.join(mrpc_dir, "train.tsv"), 'w') as train_fh, \
         open(os.path.join(mrpc_dir, "dev.tsv"), 'w') as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
Expand All @@ -64,16 +74,18 @@ def format_mrpc(data_dir, path_to_data):
if [id1, id2] in dev_ids:
dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
else:
train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
train_fh.write(
"%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
with open(mrpc_test_file) as data_fh, \
open(os.path.join(mrpc_dir, "test.tsv"), 'w') as test_fh:
header = data_fh.readline()
        _ = data_fh.readline()
test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
for idx, row in enumerate(data_fh):
label, id1, id2, s1, s2 = row.strip().split('\t')
test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
print("\tCompleted!")


def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
@@ -83,6 +95,7 @@ def download_diagnostic(data_dir):
    print("\tCompleted!")
    return


def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
@@ -94,13 +107,21 @@ def get_tasks(task_names):
            tasks.append(task_name)
    return tasks


def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',
                        type=str, default='')
    parser.add_argument(
        '--data_dir', help='directory to save data to',
        type=str, default='data')
    parser.add_argument(
        '--tasks',
        help='tasks to download data for as a comma separated string',
        type=str, default='all')
    parser.add_argument(
        '--path_to_mrpc',
        help='path to directory containing extracted MRPC data, '
             'msr_paraphrase_train.txt and msr_paraphrase_text.txt',
        type=str, default='')
    args = parser.parse_args(arguments)

    if not os.path.isdir(args.data_dir):
@@ -112,8 +133,14 @@ def main(arguments):
        import subprocess
        if not os.path.exists("data/MRPC"):
            subprocess.run("mkdir data/MRPC", shell=True)
        subprocess.run('wget -P data/MRPC/ https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt', shell=True)
        subprocess.run('wget -P data/MRPC/ https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt', shell=True)
        # pylint: disable=line-too-long
        subprocess.run(
            'wget -P data/MRPC/ https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt',
            shell=True)
        subprocess.run(
            'wget -P data/MRPC/ https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt',
            shell=True)
        # pylint: enable=line-too-long
        format_mrpc(args.data_dir, args.path_to_mrpc)
        subprocess.run('rm data/MRPC/msr_paraphrase_train.txt', shell=True)
        subprocess.run('rm data/MRPC/msr_paraphrase_test.txt', shell=True)
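
As a usage note, the refactored helpers can also be driven directly from Python rather than via the CLI. A sketch under the assumption that the functions above are in scope (for MRPC it assumes the raw `msr_paraphrase_*.txt` files are already in place, as arranged by the `wget` branch of `main`):

```python
import os

def fetch(tasks="MRPC,SST", data_dir="data"):
    # Mirrors main(): dispatch each requested task to the right helper.
    if not os.path.isdir(data_dir):
        os.mkdir(data_dir)
    for task in get_tasks(tasks):          # get_tasks expands "all"
        if task == "MRPC":
            format_mrpc(data_dir, path_to_data="")
        elif task == "diagnostic":
            download_diagnostic(data_dir)
        else:
            download_and_extract(task, data_dir)

fetch()
```
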
109 changes: 52 additions & 57 deletions examples/bert/prepare_data.py
@@ -23,7 +23,7 @@
import texar.tf as tx

# pylint: disable=no-name-in-module
from utils import data_utils, tokenization
from utils import data_utils

# pylint: disable=invalid-name, too-many-locals, too-many-statements

@@ -35,25 +35,64 @@
"The task to run experiment on. One of "
"{'COLA', 'MNLI', 'MRPC', 'XNLI', 'SST'}.")
flags.DEFINE_string(
"vocab_file", 'bert_pretrained_models/uncased_L-12_H-768_A-12/vocab.txt',
"The one-wordpiece-per-line vocabary file directory.")
"pretrained_model_name", 'bert-base-uncased',
"The name of pre-trained BERT model. See the doc of "
"`texar.tf.modules.PretrainedBERTMixin for all supported models.`")
flags.DEFINE_integer(
"max_seq_length", 128,
"The maxium length of sequence, longer sequence will be trimmed.")
"The maximum length of sequence, longer sequence will be trimmed.")
flags.DEFINE_string(
"tfrecord_output_dir", None,
"The output directory where the TFRecord files will be generated. "
"By default it will be set to 'data/{task}'. E.g.: if "
"task is 'MRPC', it will be set as 'data/MRPC'")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")

tf.logging.set_verbosity(tf.logging.INFO)


def prepare_data():
def _modify_config_data(max_seq_length, num_train_data, num_classes):
    # Modify the data configuration file
    config_data_exists = os.path.isfile('./config_data.py')
    if config_data_exists:
        with open("./config_data.py", 'r') as file:
            filedata = file.read()
        filedata_lines = filedata.split('\n')
        idx = 0
        while True:
            if idx >= len(filedata_lines):
                break
            line = filedata_lines[idx]
            if (line.startswith('num_classes =') or
                    line.startswith('num_train_data =') or
                    line.startswith('max_seq_length =')):
                filedata_lines.pop(idx)
                idx -= 1
            idx += 1

        if len(filedata_lines) > 0:
            insert_idx = 1
        else:
            insert_idx = 0
        filedata_lines.insert(
            insert_idx, '{} = {}'.format(
                "num_train_data", num_train_data))
        filedata_lines.insert(
            insert_idx, '{} = {}'.format(
                "num_classes", num_classes))
        filedata_lines.insert(
            insert_idx, '{} = {}'.format(
                "max_seq_length", max_seq_length))

        with open("./config_data.py", 'w') as file:
            file.write('\n'.join(filedata_lines))
        tf.logging.info("config_data.py has been updated")
    else:
        tf.logging.info("config_data.py cannot be found")

    tf.logging.info("Data preparation finished")


def main():
"""Prepares data.
"""
# Loads data
@@ -88,9 +127,9 @@ def prepare_data():
    num_train_data = len(processor.get_train_examples(data_dir))
    tf.logging.info(
        'num_classes:%d; num_train_data:%d' % (num_classes, num_train_data))
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file,
        do_lower_case=FLAGS.do_lower_case)

    tokenizer = tx.data.BERTTokenizer(
        pretrained_model_name=FLAGS.pretrained_model_name)

    # Produces TFRecord files
    data_utils.prepare_TFRecord_data(
@@ -99,53 +138,9 @@ def prepare_data():
        data_dir=data_dir,
        max_seq_length=FLAGS.max_seq_length,
        output_dir=tfrecord_output_dir)
    modify_config_data(FLAGS.max_seq_length, num_train_data, num_classes)

def modify_config_data(max_seq_length, num_train_data, num_classes):
    # Modify the data configuration file
    config_data_exists = os.path.isfile('./config_data.py')
    if config_data_exists:
        with open("./config_data.py", 'r') as file:
            filedata = file.read()
        filedata_lines = filedata.split('\n')
        idx = 0
        while True:
            if idx >= len(filedata_lines):
                break
            line = filedata_lines[idx]
            if (line.startswith('num_classes =') or
                    line.startswith('num_train_data =') or
                    line.startswith('max_seq_length =')):
                filedata_lines.pop(idx)
                idx -= 1
            idx += 1

        if len(filedata_lines) > 0:
            insert_idx = 1
        else:
            insert_idx = 0
        filedata_lines.insert(
            insert_idx, '{} = {}'.format(
                "num_train_data", num_train_data))
        filedata_lines.insert(
            insert_idx, '{} = {}'.format(
                "num_classes", num_classes))
        filedata_lines.insert(
            insert_idx, '{} = {}'.format(
                "max_seq_length", max_seq_length))

        with open("./config_data.py", 'w') as file:
            file.write('\n'.join(filedata_lines))
        tf.logging.info("config_data.py has been updated")
    else:
        tf.logging.info("config_data.py cannot be found")

    tf.logging.info("Data preparation finished")
    _modify_config_data(FLAGS.max_seq_length, num_train_data, num_classes)

def main():
""" Starts the data preparation
"""
prepare_data()

if __name__ == "__main__":
    main()

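
The easiest way to see what `_modify_config_data` does is to look at the file it rewrites. After running `python prepare_data.py --task=MRPC` with the defaults, the head of `config_data.py` would look roughly like this (128 is the default `--max_seq_length`; 2 and 3668 are the usual MRPC class and train-set counts, shown for illustration):

```python
# Head of config_data.py after preparation. The three inserts above all use
# the same index, so the last value inserted (max_seq_length) ends up first.
max_seq_length = 128
num_classes = 2
num_train_data = 3668
```
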