
Commit b56d4f4

polished gpt2
qkaren committed Apr 27, 2019
1 parent 0382b41 commit b56d4f4
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/gpt-2/gpt2_train_main.py
@@ -106,7 +106,7 @@ def main(_):

     ## Loads data

-    # Configures training data shard in distribued mode
+    # Configures training data shard in distributed mode
     if FLAGS.distributed:
         config_train.train_hparam["dataset"]["num_shards"] = hvd.size()
         config_train.train_hparam["dataset"]["shard_id"] = hvd.rank()
@@ -147,9 +147,9 @@ def main(_):
     # For training
     seq_len = tf.fill([batch_size], tf.shape(batch['text_ids'])[1])
     pos_embeds = pos_embedder(sequence_length=seq_len)
-    input_embedding = word_embedder(batch['text_ids']) + pos_embeds
+    input_embeds = word_embedder(batch['text_ids']) + pos_embeds

-    outputs = decoder(inputs=input_embedding, decoding_strategy='train_greedy')
+    outputs = decoder(inputs=input_embeds, decoding_strategy='train_greedy')

     loss = tx.losses.sequence_sparse_softmax_cross_entropy(
         labels=batch['text_ids'][:, 1:],
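The loss in the second hunk uses the standard next-token alignment: the labels drop the first token (text_ids[:, 1:]) and are paired against logits shifted the other way. A minimal sketch of that alignment in plain TensorFlow 1.x, independent of Texar's tx.losses helper; the placeholder names and vocabulary size below are illustrative assumptions, not from this commit:

    # Minimal sketch of the label/logit shift behind the loss above (TF 1.x).
    # Names and the vocab size 50257 are illustrative, not from this commit.
    import tensorflow as tf

    text_ids = tf.placeholder(tf.int32, [None, None])          # [batch, time]
    logits = tf.placeholder(tf.float32, [None, None, 50257])   # [batch, time, vocab]

    # The logit at position t predicts the token at position t + 1, so the
    # targets drop the first token and the logits drop the last position.
    labels = text_ids[:, 1:]
    shifted_logits = logits[:, :-1, :]

    per_token_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=shifted_logits)
    loss = tf.reduce_mean(per_token_loss)

Unlike tx.losses.sequence_sparse_softmax_cross_entropy, this sketch averages over every position, including padding; the Texar helper additionally masks the loss by sequence length.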
