Merge pull request #19 from jxhe/jxhe/dev
Fix VAE text example when applying transformer
ZhitingHu committed Sep 7, 2018
2 parents 14d4424 + 57e1f6c commit e9a8cec
Showing 6 changed files with 64 additions and 107 deletions.
4 changes: 2 additions & 2 deletions examples/vae_text/README.md
@@ -32,6 +32,6 @@ Here:

|Dataset |Metrics | VAE-LSTM |VAE-Transformer |
|---------------|-------------|----------------|------------------------|
|Yahoo | Test PPL<br>Test NLL | 68.31<br>337.36 |59.56<br>326.41|
|PTB | Test PPL<br>Test NLL | 105.48<br>102.10 | 102.53<br>101.48 |
|Yahoo | Test PPL<br>Test NLL | 68.11<br>337.13 |59.95<br>326.93|
|PTB | Test PPL<br>Test NLL | 104.61<br>101.92 | 103.68<br>101.72 |
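
The two columns are related through the normalization used in the evaluation code further down in this commit: `_run_epoch` returns `nll_ / num_sents` and `np.exp(nll_ / num_words)`, i.e. NLL is reported per sentence and PPL is the exponential of the per-word NLL. A minimal sketch of that relationship, with made-up counts purely for illustration:

```python
import numpy as np

# Hypothetical test-set totals, for illustration only.
total_nll = 4.2e6       # summed negative log-likelihood over the test set
num_sents = 42068       # number of test sentences
num_words = 887521      # number of test tokens

test_nll = total_nll / num_sents          # "Test NLL": per-sentence NLL
test_ppl = np.exp(total_nll / num_words)  # "Test PPL": exp of per-word NLL

print("NLL %.2f, PPL %.2f" % (test_nll, test_ppl))
```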

12 changes: 7 additions & 5 deletions examples/vae_text/config_lstm_ptb.py
@@ -17,7 +17,8 @@

# pylint: disable=invalid-name, too-few-public-methods, missing-docstring

num_epochs = 50
dataset = "ptb"
num_epochs = 100
hidden_size = 256
dec_keep_prob_in = 0.5
dec_keep_prob_out = 0.5
@@ -31,8 +32,9 @@

lr_decay_hparams = {
"init_lr": 0.001,
"threshold": 5,
"rate": 0.5
"threshold": 2,
"decay_factor": 0.5,
"max_decay": 5
}


@@ -76,7 +78,7 @@
# KL annealing
kl_anneal_hparams={
"warm_up": 10,
"start": 0.01
"start": 0.1
}

train_data_hparams = {
@@ -119,4 +121,4 @@
"type": "clip_by_global_norm",
"kwargs": {"clip_norm": 5.}
}
}
}
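
The `lr_decay_hparams` block in this file replaces the old `rate` key with `decay_factor` and adds `max_decay`. A minimal annotated reading of the new keys, following how `vae_train.py` consumes them later in this commit (the annotations are paraphrases, not part of the config file):

```python
lr_decay_hparams = {
    "init_lr": 0.001,     # initial learning rate
    "threshold": 2,       # validation epochs without improvement before a decay
    "decay_factor": 0.5,  # learning_rate *= decay_factor on each decay
    "max_decay": 5,       # training stops after this many decays
}
```
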
63 changes: 6 additions & 57 deletions examples/vae_text/config_lstm_yahoo.py
@@ -17,8 +17,9 @@

# pylint: disable=invalid-name, too-few-public-methods, missing-docstring

dataset = "yahoo"
num_epochs = 100
hidden_size = 550
hidden_size = 550
dec_keep_prob_in = 0.5
dec_keep_prob_out = 0.5
enc_keep_prob_in = 1.0
@@ -30,8 +31,9 @@

lr_decay_hparams = {
"init_lr": 0.001,
"threshold": 5,
"rate": 0.5
"threshold": 2,
"decay_factor": 0.5,
"max_decay": 5
}


@@ -42,8 +44,7 @@
num_blocks = 3

decoder_hparams = {
"type": "lstm",
"train": "vae"
"type": "lstm"
}

enc_cell_hparams = {
@@ -78,58 +79,6 @@
}
}

# due to the residual connection, the embed_dim should be equal to hidden_size
trans_hparams = {
'share_embed_and_transform': True,
'transform_with_bias': False,
'beam_width': 1,
'multiply_embedding_mode': 'sqrt_depth',
'embedding_dropout': embedding_dropout,
'attention_dropout': attention_dropout,
'residual_dropout': residual_dropout,
'sinusoid': True,
'num_heads': 8,
'num_blocks': num_blocks,
'num_units': hidden_size,
'zero_pad': False,
'bos_pad': False,
'initializer': {
'type': 'variance_scaling_initializer',
'kwargs': {
'scale': 1.0,
'mode':'fan_avg',
'distribution':'uniform',
},
},
'poswise_feedforward': {
'name':'fnn',
'layers':[
{
'type':'Dense',
'kwargs': {
'name':'conv1',
'units':hidden_size*4,
'activation':'relu',
'use_bias':True,
},
},
{
'type':'Dropout',
'kwargs': {
'rate': relu_dropout,
}
},
{
'type':'Dense',
'kwargs': {
'name':'conv2',
'units':hidden_size,
'use_bias':True,
}
}
],
}
}

# KL annealing
# kl_weight = 1.0 / (1 + np.exp(-k*(step-x0)))
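
The KL-annealing comment above refers to the schedule driven by `kl_anneal_hparams` (`warm_up` and `start`, shown in the PTB config earlier in this commit): `vae_train.py` computes `anneal_r = 1 / (warm_up * steps_per_epoch)` and starts `kl_weight` at `start`. The per-step update is not visible in this diff, so the sketch below assumes the usual linear warm-up reading; the dataset numbers are purely illustrative.

```python
warm_up, start = 10, 0.1                  # kl_anneal_hparams values from this commit
dataset_size, batch_size = 42068, 32      # hypothetical dataset statistics

steps_per_epoch = dataset_size / batch_size
anneal_r = 1.0 / (warm_up * steps_per_epoch)     # as computed in vae_train.py

kl_weight = start
for step in range(int(warm_up * steps_per_epoch)):
    kl_weight = min(1.0, kl_weight + anneal_r)   # assumed linear warm-up toward 1.0

# Alternative sigmoid schedule, kept as a comment in the config:
# kl_weight = 1.0 / (1 + np.exp(-k * (step - x0)))
```
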
24 changes: 8 additions & 16 deletions examples/vae_text/config_trans_ptb.py
@@ -17,20 +17,22 @@

# pylint: disable=invalid-name, too-few-public-methods, missing-docstring

num_epochs = 50
dataset = "ptb"
num_epochs = 100
hidden_size = 256
enc_keep_prob_in = 1.0
enc_keep_prob_out = 1.0
dec_keep_prob_in = 0.5
dec_keep_prob_in = 1.0
batch_size = 32
embed_dim = 256

latent_dims = 32

lr_decay_hparams = {
"init_lr": 0.001,
"threshold": 1,
"rate": 0.1
"threshold": 2,
"decay_factor": 0.5,
"max_decay": 5
}


@@ -68,23 +70,13 @@

# due to the residual connection, the embed_dim should be equal to hidden_size
trans_hparams = {
'share_embed_and_transform': True,
'transform_with_bias': False,
'beam_width': 1,
'multiply_embedding_mode': 'sqrt_depth',
'output_layer_bias': False,
'embedding_dropout': embedding_dropout,
'attention_dropout': attention_dropout,
'residual_dropout': residual_dropout,
'position_embedder': {
'name': 'sinusoids',
'hparams': None,
},
'sinusoid': True,
'num_heads': 8,
'num_blocks': num_blocks,
'num_units': hidden_size,
'zero_pad': False,
'bos_pad': False,
'dim': hidden_size,
'initializer': {
'type': 'variance_scaling_initializer',
'kwargs': {
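
A toy illustration of the comment above about `embed_dim` having to equal `hidden_size` (the value passed as `dim`): each transformer block adds its input back onto the sub-layer output, and that residual addition only type-checks when both tensors share the same last dimension. The shapes below are arbitrary examples, not values from the config:

```python
import numpy as np

embed_dim = hidden_size = 256                    # must match, per the comment above
token_embed = np.zeros((32, 20, embed_dim))      # [batch, time, embed_dim]
sublayer_out = np.zeros((32, 20, hidden_size))   # [batch, time, hidden_size]

# Residual connection inside each block: element-wise add of input and output,
# which requires embed_dim == hidden_size.
residual = token_embed + sublayer_out
```
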
24 changes: 8 additions & 16 deletions examples/vae_text/config_trans_yahoo.py
@@ -17,20 +17,22 @@

# pylint: disable=invalid-name, too-few-public-methods, missing-docstring

num_epochs = 50
dataset = "yahoo"
num_epochs = 100
hidden_size = 512
enc_keep_prob_in = 1.0
enc_keep_prob_out = 1.0
dec_keep_prob_in = 0.5
dec_keep_prob_in = 1.0
batch_size = 32
embed_dim = 512

latent_dims = 32

lr_decay_hparams = {
"init_lr": 0.001,
"threshold": 1,
"rate": 0.1
"threshold": 2,
"decay_factor": 0.5,
"max_decay": 5
}


@@ -68,23 +70,13 @@

# due to the residual connection, the embed_dim should be equal to hidden_size
trans_hparams = {
'share_embed_and_transform': True,
'transform_with_bias': False,
'beam_width': 1,
'multiply_embedding_mode': 'sqrt_depth',
'output_layer_bias': False,
'embedding_dropout': embedding_dropout,
'attention_dropout': attention_dropout,
'residual_dropout': residual_dropout,
'position_embedder': {
'name': 'sinusoids',
'hparams': None,
},
'sinusoid': True,
'num_heads': 8,
'num_blocks': num_blocks,
'num_units': hidden_size,
'zero_pad': False,
'bos_pad': False,
'dim': hidden_size,
'initializer': {
'type': 'variance_scaling_initializer',
'kwargs': {
44 changes: 33 additions & 11 deletions examples/vae_text/vae_train.py
@@ -71,6 +71,21 @@ def _main(_):
'kl_weight': config.kl_anneal_hparams["start"]
}

decay_cnt = 0
max_decay = config.lr_decay_hparams["max_decay"]
decay_factor = config.lr_decay_hparams["decay_factor"]
decay_ts = config.lr_decay_hparams["threshold"]

save_dir = "./models/%s" % config.dataset

if not os.path.exists(save_dir):
os.makedirs(save_dir)

suffix = "%s_%sDecoder.ckpt" % \
(config.dataset, config.decoder_hparams["type"])

save_path = os.path.join(save_dir, suffix)

# KL term annealing rate
anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] * \
(train_data.dataset_size() / config.batch_size))
@@ -80,7 +95,8 @@ def _main(_):
vocab_size=train_data.vocab.size, hparams=config.emb_hparams)


output_embed = input_embed = embedder(data_batch["text_ids"])
input_embed = embedder(data_batch["text_ids"])
output_embed = embedder(data_batch["text_ids"][:, :-1])

if config.enc_keep_prob_in < 1:
input_embed = tf.nn.dropout(
@@ -133,15 +149,15 @@ def _main(_):
decoding_strategy="train_greedy",
inputs=output_embed,
sequence_length=data_batch["length"]-1)
logits = outputs.logits
else:
logits, _ = decoder(
decoder_input=data_batch["text_ids"][:, :-1],
encoder_output=dcdr_states,
encoder_decoder_attention_bias=None)
outputs = decoder(
inputs=output_embed,
memory=dcdr_states,
memory_sequence_length=tf.ones(tf.shape(dcdr_states)[0]))

seq_lengths = data_batch["length"]-1
logits = outputs.logits

seq_lengths = data_batch["length"] - 1
# Losses & train ops
rc_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
labels=data_batch["text_ids"][:, 1:],
@@ -155,7 +171,7 @@

learning_rate = \
tf.placeholder(dtype=tf.float32, shape=(), name='learning_rate')
train_op = tx.core.get_train_op(nll, learning_rate=learning_rate,
train_op = tx.core.get_train_op(nll, learning_rate=learning_rate,
hparams=config.opt_hparams)

def _run_epoch(sess, epoch, mode_string, display=10):
@@ -229,6 +245,7 @@ def _run_epoch(sess, epoch, mode_string, display=10):
return nll_ / num_sents, np.exp(nll_ / num_words)


saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
@@ -258,16 +275,21 @@ def _run_epoch(sess, epoch, mode_string, display=10):
opt_vars['steps_not_improved'] = 0
best_nll = test_nll
best_ppl = test_ppl
saver.save(sess, save_path)
else:
opt_vars['steps_not_improved'] += 1
if opt_vars['steps_not_improved'] == \
config.lr_decay_hparams["threshold"]:
if opt_vars['steps_not_improved'] == decay_ts:
old_lr = opt_vars['learning_rate']
opt_vars['learning_rate'] *= config.lr_decay_hparams["rate"]
opt_vars['learning_rate'] *= decay_factor
opt_vars['steps_not_improved'] = 0
new_lr = opt_vars['learning_rate']
print('-----\nchange lr, old lr: %f, new lr: %f\n-----' %
(old_lr, new_lr))
saver.restore(sess, save_path)
decay_cnt += 1
if decay_cnt == max_decay:
break


print('\nbest testing nll: %.4f, best testing ppl %.4f\n' %
(best_nll, best_ppl))
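
The `vae_train.py` changes above add best-checkpoint tracking to the plateau-based decay: the model is saved whenever validation NLL improves, restored from that checkpoint each time the learning rate is decayed, and training stops after `max_decay` decays. A condensed, self-contained sketch of that pattern (TF 1.x style; `run_valid_epoch` and the dummy variable are stand-ins for the example's `_run_epoch` validation pass and model graph):

```python
import os
import random
import tensorflow as tf

decay_ts, decay_factor, max_decay = 2, 0.5, 5   # from config.lr_decay_hparams
opt_vars = {"learning_rate": 0.001}

save_dir = "./models/ptb"                       # "./models/%s" % config.dataset
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
save_path = os.path.join(save_dir, "ptb_lstmDecoder.ckpt")

_ = tf.Variable(0.0, name="dummy")              # stand-in for the model variables
saver = tf.train.Saver()

def run_valid_epoch(sess, epoch):
    # Hypothetical stand-in for _run_epoch(sess, epoch, "valid"); returns an NLL.
    return random.random()

best_nll = float("inf")
steps_not_improved = decay_cnt = 0

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(100):
        valid_nll = run_valid_epoch(sess, epoch)
        if valid_nll < best_nll:
            best_nll, steps_not_improved = valid_nll, 0
            saver.save(sess, save_path)         # keep the best model so far
        else:
            steps_not_improved += 1
            if steps_not_improved == decay_ts:
                opt_vars["learning_rate"] *= decay_factor
                steps_not_improved = 0
                saver.restore(sess, save_path)  # roll back to the best model
                decay_cnt += 1
                if decay_cnt == max_decay:
                    break
```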
