bert modules #167

Merged (20 commits) on May 31, 2019
19 changes: 19 additions & 0 deletions docs/code/modules.rst
@@ -62,6 +62,11 @@ Encoders
.. autoclass:: texar.modules.TransformerEncoder
:members:

:hidden:`BertEncoder`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.modules.BertEncoder
:members:

:hidden:`Conv1DEncoder`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.modules.Conv1DEncoder
@@ -254,6 +259,12 @@ Classifiers
:members:
:inherited-members:

:hidden:`BertClassifier`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.modules.BertClassifier
:members:
:inherited-members:

Networks
========

@@ -324,3 +335,11 @@ Q-Nets
.. autoclass:: texar.modules.CategoricalQNet
:members:
:inherited-members:

Berts
=========

:hidden:`BertBase`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.modules.BertBase
:members:
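
As context for reviewers: the classes documented above are exercised by the new example script later in this PR. A minimal sketch of the intended usage on toy inputs (hparams other than 'clas_strategy' are left at their defaults, which is an assumption):

import tensorflow as tf
import texar as tx

# Toy batch: two sequences of length 4; id 0 is taken to be the pad id,
# matching the length computation in the example script below.
input_ids = tf.constant([[101, 2023, 2003, 102],
                         [101, 7592, 102, 0]])
segment_ids = tf.zeros_like(input_ids)
input_length = tf.reduce_sum(1 - tf.to_int32(tf.equal(input_ids, 0)), axis=1)

# 'cls_time' classifies from the output at the [CLS] position, as in the
# example script added by this PR.
model = tx.modules.BertClassifier(hparams={'clas_strategy': 'cls_time'})
logits, preds = model(input_ids, input_length, segment_ids)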
254 changes: 254 additions & 0 deletions examples/bert/bert_classifier_main_v2.py
@@ -0,0 +1,254 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of building a sentence classifier based on pre-trained BERT
model.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import importlib
import tensorflow as tf
import texar as tx

from utils import model_utils

# pylint: disable=invalid-name, too-many-locals, too-many-statements

flags = tf.flags

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "config_downstream", "config_classifier",
    "Configuration of the downstream part of the model and optimization.")
flags.DEFINE_string(
    "config_data", "config_data",
    "The dataset config.")
flags.DEFINE_string(
    "output_dir", "output/",
    "The output directory where the model checkpoints will be written.")
flags.DEFINE_string(
    "checkpoint", None,
    "Path to a model checkpoint (including bert modules) to restore from.")
flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_bool("do_test", False, "Whether to run test on the test set.")
flags.DEFINE_bool("distributed", False, "Whether to run in distributed mode.")

config_data = importlib.import_module(FLAGS.config_data)
config_downstream = importlib.import_module(FLAGS.config_downstream)


def main(_):
    """
    Builds the model and runs.
    """

    if FLAGS.distributed:
        import horovod.tensorflow as hvd
        hvd.init()

    tf.logging.set_verbosity(tf.logging.INFO)

    tx.utils.maybe_create_dir(FLAGS.output_dir)

    # Loads data
    num_train_data = config_data.num_train_data

    # Configures distributed mode
    if FLAGS.distributed:
        config_data.train_hparam["dataset"]["num_shards"] = hvd.size()
        config_data.train_hparam["dataset"]["shard_id"] = hvd.rank()
        config_data.train_hparam["batch_size"] //= hvd.size()

    train_dataset = tx.data.TFRecordData(hparams=config_data.train_hparam)
    eval_dataset = tx.data.TFRecordData(hparams=config_data.eval_hparam)
    test_dataset = tx.data.TFRecordData(hparams=config_data.test_hparam)

    iterator = tx.data.FeedableDataIterator({
        'train': train_dataset, 'eval': eval_dataset, 'test': test_dataset})
    batch = iterator.get_next()
    input_ids = batch["input_ids"]
    segment_ids = batch["segment_ids"]
    batch_size = tf.shape(input_ids)[0]
    # Token id 0 is the pad id; count the non-pad tokens per sequence.
    input_length = tf.reduce_sum(1 - tf.to_int32(tf.equal(input_ids, 0)),
                                 axis=1)
    # Builds BERT; 'cls_time' classifies from the [CLS] position output.
    hparams = {
        'clas_strategy': 'cls_time'
    }
    model = tx.modules.BertClassifier(hparams=hparams)
    logits, preds = model(input_ids, input_length, segment_ids)

    accu = tx.evals.accuracy(batch['label_ids'], preds)

    # Optimization
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=batch["label_ids"], logits=logits)
    global_step = tf.Variable(0, trainable=False)

    # Builds learning rate decay scheduler
    static_lr = config_downstream.lr['static_lr']
    num_train_steps = int(num_train_data / config_data.train_batch_size
                          * config_data.max_train_epoch)
    num_warmup_steps = int(num_train_steps * config_data.warmup_proportion)
    lr = model_utils.get_lr(global_step, num_train_steps,  # lr is a Tensor
                            num_warmup_steps, static_lr)

    opt = tx.core.get_optimizer(
        global_step=global_step,
        learning_rate=lr,
        hparams=config_downstream.opt
    )

    if FLAGS.distributed:
        opt = hvd.DistributedOptimizer(opt)

    train_op = tf.contrib.layers.optimize_loss(
        loss=loss,
        global_step=global_step,
        learning_rate=None,
        optimizer=opt)

    # Train/eval/test routine

    def _is_head():
        if not FLAGS.distributed:
            return True
        return hvd.rank() == 0

    def _train_epoch(sess):
        """Trains on the training set, and evaluates on the dev set
        periodically.
        """
        iterator.restart_dataset(sess, 'train')

        fetches = {
            'train_op': train_op,
            'loss': loss,
            'batch_size': batch_size,
            'step': global_step
        }

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                }
                rets = sess.run(fetches, feed_dict)
                step = rets['step']

                dis_steps = config_data.display_steps
                if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                    tf.logging.info('step:%d; loss:%f;' % (step, rets['loss']))

                eval_steps = config_data.eval_steps
                if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                    _eval_epoch(sess)

            except tf.errors.OutOfRangeError:
                break

    def _eval_epoch(sess):
        """Evaluates on the dev set.
        """
        iterator.restart_dataset(sess, 'eval')

        cum_acc = 0.0
        cum_loss = 0.0
        nsamples = 0
        fetches = {
            'accu': accu,
            'loss': loss,
            'batch_size': batch_size,
        }
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'eval'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL,
                }
                rets = sess.run(fetches, feed_dict)

                cum_acc += rets['accu'] * rets['batch_size']
                cum_loss += rets['loss'] * rets['batch_size']
                nsamples += rets['batch_size']
            except tf.errors.OutOfRangeError:
                break

        tf.logging.info('eval accu: {}; loss: {}; nsamples: {}'.format(
            cum_acc / nsamples, cum_loss / nsamples, nsamples))

    def _test_epoch(sess):
        """Does predictions on the test set.
        """
        iterator.restart_dataset(sess, 'test')

        _all_preds = []
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'test'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                _preds = sess.run(preds, feed_dict=feed_dict)
                _all_preds.extend(_preds.tolist())
            except tf.errors.OutOfRangeError:
                break

        output_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
        with tf.gfile.GFile(output_file, "w") as writer:
            writer.write('\n'.join(str(p) for p in _all_preds))

    # Broadcasts global variables from rank-0 process
    if FLAGS.distributed:
        bcast = hvd.broadcast_global_variables(0)

    session_config = tf.ConfigProto()
    if FLAGS.distributed:
        session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        if FLAGS.distributed:
            bcast.run()

        # Restores trained model if specified
        saver = tf.train.Saver()
        if FLAGS.checkpoint:
            saver.restore(sess, FLAGS.checkpoint)

        iterator.initialize_dataset(sess)

        if FLAGS.do_train:
            for _ in range(config_data.max_train_epoch):
                _train_epoch(sess)
            saver.save(sess, FLAGS.output_dir + '/model.ckpt')

        if FLAGS.do_eval:
            _eval_epoch(sess)

        if FLAGS.do_test:
            _test_epoch(sess)


if __name__ == "__main__":
    tf.app.run()
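
The script above is invoked with the flags it defines, e.g. `python bert_classifier_main_v2.py --do_train --do_eval` (with `--distributed` it is expected to be launched through Horovod). The `model_utils.get_lr` helper it calls is not part of this diff; BERT's standard schedule is linear warmup followed by linear decay, so a sketch of what such a helper plausibly computes (an assumption, not the actual implementation):

def get_lr(global_step, num_train_steps, num_warmup_steps, static_lr):
    """Linear warmup to `static_lr`, then linear decay to 0 (a sketch)."""
    step = tf.cast(global_step, tf.float32)
    warmup = tf.cast(num_warmup_steps, tf.float32)
    total = tf.cast(num_train_steps, tf.float32)
    warmup_lr = static_lr * step / warmup        # ramp up during warmup
    decay_lr = static_lr * (1.0 - step / total)  # then decay linearly
    return tf.where(step < warmup, warmup_lr, decay_lr)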
Empty file added examples/bert/utils/__init__.py
1 change: 1 addition & 0 deletions texar/modules/__init__.py
@@ -30,3 +30,4 @@
from texar.modules.policies import *
from texar.modules.qnets import *
from texar.modules.memory import *
from texar.modules.berts import *
25 changes: 25 additions & 0 deletions texar/modules/berts/__init__.py
@@ -0,0 +1,25 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Modules of texar library berts.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import

from texar.modules.berts.berts import *
from texar.modules.berts.bert_utils import *