Merge pull request #126 from TomNong/master
- Refactor TransformerEncoder/Decoder to separate position embeddings from the modules
  * Refactor the interfaces of helper modules
- Make the TransformerDecoder construction interface exactly the same as that of RNN decoders
- Allow passing a Tensor to the `output_layer` of decoders, used for weight tying between the output layer and the input embedding matrix (see the sketch below)
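
In rough terms, the new usage pattern looks like the minimal sketch below, pieced together from the example diffs in this commit (bert_classifier_main.py, gpt2_generate_main.py, transformer_main.py). The vocabulary size, embedding dimension, and the input_ids placeholder are illustrative assumptions, not values taken from the commit:

import tensorflow as tf
import texar as tx

# Placeholder token ids of shape [batch, time]; sizes below are illustrative.
input_ids = tf.placeholder(tf.int64, shape=(None, None))
batch_size = tf.shape(input_ids)[0]
seq_len = tf.ones([batch_size], tf.int32) * tf.shape(input_ids)[1]

# Word and position embeddings are now constructed outside the Transformer
# modules and summed by the caller before being fed to the encoder/decoder.
word_embedder = tx.modules.WordEmbedder(vocab_size=30000, hparams={'dim': 512})
pos_embedder = tx.modules.PositionEmbedder(position_size=512, hparams={'dim': 512})
input_embeds = word_embedder(input_ids) + pos_embedder(sequence_length=seq_len)

encoder = tx.modules.TransformerEncoder(hparams={'dim': 512})
encoder_output = encoder(inputs=input_embeds, sequence_length=seq_len)

# Weight tying: the transposed input embedding matrix is passed as the
# decoder's output (logit) layer instead of letting the decoder build one.
output_layer = tf.transpose(word_embedder.embedding, (1, 0))
decoder = tx.modules.TransformerDecoder(
    vocab_size=30000,
    output_layer=output_layer,
    hparams={'dim': 512})

# For decoding, the word and position embedders are combined in a single
# embedding function passed to the decoder or decoding helper.
def _embedding_fn(ids, positions):
    return word_embedder(ids) + pos_embedder(positions)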
ZhitingHu committed Apr 9, 2019
2 parents 0a1940e + 7070bb6 commit 1ff01fe
Showing 27 changed files with 1,625 additions and 485 deletions.
13 changes: 11 additions & 2 deletions examples/bert/bert_classifier_main.py
@@ -111,18 +111,27 @@ def main(_):
axis=1)
# Builds BERT
with tf.variable_scope('bert'):
# Word embedding
embedder = tx.modules.WordEmbedder(
vocab_size=bert_config.vocab_size,
hparams=bert_config.embed)
word_embeds = embedder(input_ids)

# Creates segment embeddings for each type of tokens.
# Segment embedding for each type of tokens
segment_embedder = tx.modules.WordEmbedder(
vocab_size=bert_config.type_vocab_size,
hparams=bert_config.segment_embed)
segment_embeds = segment_embedder(segment_ids)

input_embeds = word_embeds + segment_embeds
# Position embedding
position_embedder = tx.modules.PositionEmbedder(
position_size=bert_config.position_size,
hparams=bert_config.position_embed)
seq_length = tf.ones([batch_size], tf.int32) * tf.shape(input_ids)[1]
pos_embeds = position_embedder(sequence_length=seq_length)

# Aggregates embeddings
input_embeds = word_embeds + segment_embeds + pos_embeds

# The BERT model (a TransformerEncoder)
encoder = tx.modules.TransformerEncoder(hparams=bert_config.encoder)
@@ -10,6 +10,13 @@
}
type_vocab_size = 2

position_embed = {
'dim': 768,
'name': 'position_embeddings'
}
position_size = 512


encoder = {
'dim': 768,
'embedding_dropout': 0.1,
@@ -23,11 +30,6 @@
},
'name': 'encoder',
'num_blocks': 12,
'position_embedder_hparams': {
'dim': 768
},
'position_embedder_type': 'variables',
'position_size': 512,
'poswise_feedforward': {
'layers': [
{ 'kwargs': {
17 changes: 9 additions & 8 deletions examples/bert/utils/model_utils.py
@@ -24,18 +24,19 @@ def transform_bert_to_texar_config(input_json):
'name': 'word_embeddings',
'dim': hidden_dim}
configs['vocab_size'] = config_ckpt['vocab_size']

configs['segment_embed'] = {
'name': 'token_type_embeddings',
'dim': hidden_dim}
configs['type_vocab_size'] = config_ckpt['type_vocab_size']

configs['position_embed'] = {
'name': 'position_embeddings',
'dim': hidden_dim}
configs['position_size'] = config_ckpt['max_position_embeddings']

configs['encoder'] = {
'name': 'encoder',
'position_embedder_type': 'variables',
'position_size': config_ckpt['max_position_embeddings'],
'position_embedder_hparams': {
'dim': hidden_dim,
},
'embedding_dropout': config_ckpt['hidden_dropout_prob'],
'num_blocks': config_ckpt['num_hidden_layers'],
'multihead_attention': {
@@ -128,7 +129,7 @@ def _get_assignment_map_from_checkpoint(tvars, init_checkpoint):
'bert/embeddings/word_embeddings': 'bert/word_embeddings/w',
'bert/embeddings/token_type_embeddings': 'bert/token_type_embeddings/w',
'bert/embeddings/position_embeddings':
'bert/encoder/position_embedder/w',
'bert/position_embeddings/w',
'bert/embeddings/LayerNorm/beta': 'bert/encoder/LayerNorm/beta',
'bert/embeddings/LayerNorm/gamma': 'bert/encoder/LayerNorm/gamma',
}
@@ -172,9 +173,9 @@ def _get_assignment_map_from_checkpoint(tvars, init_checkpoint):
return (assignment_map, initialized_variable_names)

def init_bert_checkpoint(init_checkpoint):
"""Initializes BERT model parameters from a checkpoint provided by
"""Initializes BERT model parameters from a checkpoint provided by
Google.
Args:
init_checkpoint (str): Path to the checkpoint.
"""
11 changes: 5 additions & 6 deletions examples/gpt-2/configs/config_model.py
@@ -8,8 +8,12 @@
"dim": dim,
}

pos_embed = {
'dim': dim
}
position_size = 1024

decoder = {
"scale_embeds": False,
"dim": dim,
"num_blocks": 12,
"multihead_attention": {
@@ -18,11 +22,6 @@
"num_heads": 12,
"output_dim": dim,
},
"position_embedder_type": "variables",
"position_size": 1024,
"position_embedder_hparams": {
"dim": dim,
},
"initializer": {
"type": "variance_scaling_initializer",
"kwargs": {
23 changes: 16 additions & 7 deletions examples/gpt-2/gpt2_generate_main.py
@@ -76,7 +76,6 @@ def main(_):
batch_size = FLAGS.batch_size
max_decoding_length = FLAGS.max_decoding_length


ckpt_path = FLAGS.checkpoint
# Load GPT-2 model configuration
if FLAGS.config_type == "json":
@@ -88,8 +87,8 @@
else:
raise ValueError('Unknown config_type.')

assert max_decoding_length <= gpt2_config.decoder["position_size"], (
"max_decoding_length should be smaller than position size")
assert max_decoding_length <= gpt2_config.position_size, (
"max_decoding_length should not be greater than position size")
assert nsamples % batch_size == 0, (
"nsamples must be dividable by batch_size")

@@ -107,20 +106,30 @@
start_tokens = tf.fill([batch_size], end_token)

# Build the GPT-2 model
embedder = tx.modules.WordEmbedder(
word_embedder = tx.modules.WordEmbedder(
vocab_size=gpt2_config.vocab_size,
hparams=gpt2_config.embed)

pos_embedder = tx.modules.PositionEmbedder(
position_size=gpt2_config.position_size,
hparams=gpt2_config.pos_embed
)

def _embedding_fn(x, y):
return word_embedder(x) + pos_embedder(y)

helper = tx.modules.TopKSampleEmbeddingHelper(
embedding=embedder,
embedding=_embedding_fn,
start_tokens=start_tokens,
end_token=end_token,
top_k=FLAGS.top_k,
softmax_temperature=FLAGS.temperature)
output_layer = tf.transpose(word_embedder.embedding, (1, 0))

decoder = TransformerDecoder(
embedding=embedder.embedding, hparams=gpt2_config.decoder)

vocab_size=gpt2_config.vocab_size,
output_layer=output_layer,
hparams=gpt2_config.decoder)

with tf.Session() as sess:

14 changes: 6 additions & 8 deletions examples/gpt-2/utils/model_utils.py
@@ -18,10 +18,13 @@ def transform_gpt2_to_texar_config(input_json_path):
configs["embedding_size"] = config_gpt["n_embd"]
hidden_dim = config_gpt["n_embd"]
configs['embed'] = {
'dim': config_gpt["n_embd"],
'dim': hidden_dim,
}
configs['position_size'] = config_gpt['n_ctx']
configs['pos_embed'] = {
'dim': hidden_dim
}
configs['decoder'] = {
'scale_embeds': False,
'dim': hidden_dim,
'num_blocks': config_gpt['n_layer'],
'multihead_attention': {
@@ -30,11 +33,6 @@
'num_heads': config_gpt['n_head'],
'output_dim': hidden_dim,
},
'position_embedder_type': 'variables',
'position_size': config_gpt['n_ctx'],
'position_embedder_hparams': {
'dim': hidden_dim,
},
'initializer': {
'type': 'variance_scaling_initializer',
'kwargs': {
@@ -76,7 +74,7 @@ def _map_tensor_names(original_tensor_name):
"""
global_tensor_map = {
"model/wte": "word_embedder/w",
"model/wpe": "transformer_decoder/position_embedder/w",
"model/wpe": "position_embedder/w",
"model/ln_f/b": "transformer_decoder/beta",
"model/ln_f/g": "transformer_decoder/gamma",
}
7 changes: 4 additions & 3 deletions examples/transformer/config_model.py
@@ -20,6 +20,10 @@
}
}

position_embedder_hparams = {
'dim': hidden_dim
}

encoder = {
'dim': hidden_dim,
'num_blocks': 6,
@@ -28,9 +32,6 @@
'output_dim': hidden_dim
# See documentation for more optional hyperparameters
},
'position_embedder_hparams': {
'dim': hidden_dim
},
'initializer': {
'type': 'variance_scaling_initializer',
'kwargs': {
59 changes: 45 additions & 14 deletions examples/transformer/transformer_main.py
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@
from utils import data_utils, utils
from utils.preprocess import bos_token_id, eos_token_id
from bleu_tool import bleu_wrapper

# pylint: disable=invalid-name, too-many-locals

flags = tf.flags
@@ -70,6 +71,7 @@ def main():
# Build model graph
encoder_input = tf.placeholder(tf.int64, shape=(None, None))
decoder_input = tf.placeholder(tf.int64, shape=(None, None))
batch_size = tf.shape(encoder_input)[0]
# (text sequence length excluding padding)
encoder_input_length = tf.reduce_sum(
1 - tf.to_int32(tf.equal(encoder_input, 0)), axis=1)
@@ -82,27 +84,52 @@
global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
learning_rate = tf.placeholder(tf.float64, shape=(), name='lr')

embedder = tx.modules.WordEmbedder(
# Source word embedding
src_word_embedder = tx.modules.WordEmbedder(
vocab_size=vocab_size, hparams=config_model.emb)
encoder = TransformerEncoder(hparams=config_model.encoder)
src_word_embeds = src_word_embedder(encoder_input)
src_word_embeds = src_word_embeds * config_model.hidden_dim ** 0.5

# Position embedding (shared b/w source and target)
pos_embedder = tx.modules.SinusoidsPositionEmbedder(
position_size=config_data.max_decoding_length,
hparams=config_model.position_embedder_hparams)
src_seq_len = tf.ones([batch_size], tf.int32) * tf.shape(encoder_input)[1]
src_pos_embeds = pos_embedder(sequence_length=src_seq_len)

encoder_output = encoder(inputs=embedder(encoder_input),
src_input_embedding = src_word_embeds + src_pos_embeds

encoder = TransformerEncoder(hparams=config_model.encoder)
encoder_output = encoder(inputs=src_input_embedding,
sequence_length=encoder_input_length)

# The decoder ties the input word embedding with the output logit layer.
# Since the decoder masks out <PAD>'s embedding, which in effect gives
# <PAD> an all-zero embedding, we explicitly set <PAD>'s embedding to
# all-zero here.
tgt_embedding = tf.concat(
[tf.zeros(shape=[1, embedder.dim]), embedder.embedding[1:, :]], axis=0)
decoder = TransformerDecoder(embedding=tgt_embedding,
[tf.zeros(shape=[1, src_word_embedder.dim]),
src_word_embedder.embedding[1:, :]],
axis=0)
tgt_embedder = tx.modules.WordEmbedder(tgt_embedding)
tgt_word_embeds = tgt_embedder(decoder_input)
tgt_word_embeds = tgt_word_embeds * config_model.hidden_dim ** 0.5

tgt_seq_len = tf.ones([batch_size], tf.int32) * tf.shape(decoder_input)[1]
tgt_pos_embeds = pos_embedder(sequence_length=tgt_seq_len)

tgt_input_embedding = tgt_word_embeds + tgt_pos_embeds

_output_w = tf.transpose(tgt_embedder.embedding, (1, 0))

decoder = TransformerDecoder(vocab_size=vocab_size,
output_layer=_output_w,
hparams=config_model.decoder)
# For training
outputs = decoder(
memory=encoder_output,
memory_sequence_length=encoder_input_length,
inputs=embedder(decoder_input),
sequence_length=decoder_input_length,
inputs=tgt_input_embedding,
decoding_strategy='train_greedy',
mode=tf.estimator.ModeKeys.TRAIN
)
@@ -121,19 +148,23 @@
tf.summary.scalar('mle_loss', mle_loss)
summary_merged = tf.summary.merge_all()

# For inference
start_tokens = tf.fill([tx.utils.get_batch_size(encoder_input)],
bos_token_id)
# For inference (beam-search)
start_tokens = tf.fill([batch_size], bos_token_id)

def _embedding_fn(x, y):
return tgt_embedder(x) * config_model.hidden_dim ** 0.5 + pos_embedder(
y)

predictions = decoder(
memory=encoder_output,
memory_sequence_length=encoder_input_length,
beam_width=beam_width,
length_penalty=config_model.length_penalty,
start_tokens=start_tokens,
end_token=eos_token_id,
embedding=_embedding_fn,
max_decoding_length=config_data.max_decoding_length,
mode=tf.estimator.ModeKeys.PREDICT
)
mode=tf.estimator.ModeKeys.PREDICT)
# Uses the best sample by beam search
beam_search_ids = predictions['sample_id'][:, :, 0]

@@ -151,7 +182,7 @@ def _eval_epoch(sess, epoch, mode):
references, hypotheses = [], []
bsize = config_data.test_batch_size
for i in range(0, len(eval_data), bsize):
sources, targets = zip(*eval_data[i:i+bsize])
sources, targets = zip(*eval_data[i:i + bsize])
x_block = data_utils.source_pad_concat_convert(sources)
feed_dict = {
encoder_input: x_block,
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,4 +2,4 @@ tensorflow >= 1.7.0
tensorflow-gpu >= 1.7.0
tensorflow-probability >= 0.3.0
tensorflow-probability-gpu >= 0.3.0
funcsigs >= 1.0.2
funcsigs >= 1.0.2
1 change: 1 addition & 0 deletions texar/models/model_base.py
@@ -27,6 +27,7 @@
"ModelBase"
]


class ModelBase(object):
"""Base class inherited by all model classes.
1 change: 1 addition & 0 deletions texar/models/seq2seq/basic_seq2seq.py
@@ -32,6 +32,7 @@
"BasicSeq2seq"
]


class BasicSeq2seq(Seq2seqBase):
"""The basic seq2seq model (without attention).
1 change: 1 addition & 0 deletions texar/models/seq2seq/seq2seq_base.py
@@ -36,6 +36,7 @@
"Seq2seqBase"
]


class Seq2seqBase(ModelBase):
"""Base class inherited by all seq2seq model classes.
1 change: 1 addition & 0 deletions texar/modules/decoders/__init__.py
@@ -23,6 +23,7 @@

from texar.modules.decoders.rnn_decoder_base import *
from texar.modules.decoders.rnn_decoders import *
from texar.modules.decoders.tf_helpers import *
from texar.modules.decoders.rnn_decoder_helpers import *
from texar.modules.decoders.transformer_decoders import *
from texar.modules.decoders.beam_search_decode import *
