
Merge pull request #38 from jxhe/master
Add sentence generation feature to vae_text example

Fixes #30
ZhitingHu committed Oct 26, 2018
2 parents 1a5dbc6 + 7835807 commit 3d684fa
Showing 2 changed files with 109 additions and 9 deletions.
24 changes: 22 additions & 2 deletions examples/vae_text/README.md
@@ -13,7 +13,7 @@ python prepare_data.py --data ptb
python prepare_data.py --data yahoo
```

## Usage
## Training
Train with the following command:

```shell
@@ -27,11 +27,31 @@ Here:
- [config_lstm_yahoo.py](./config_lstm_yahoo.py): LSTM decoder, on the Yahoo data
- [config_trans_ptb.py](./config_trans_ptb.py): Transformer decoder, on the PTB data
- [config_trans_yahoo.py](./config_trans_yahoo.py): Transformer decoder, on the Yahoo data


## Generation
Sentences can be generated from a pre-trained model with the following command:
```shell
python vae_train.py --config config_file --mode predict --model /path/to/model.ckpt --out /path/to/output
```

Here `--model` specifies the model checkpoint, which is saved under `./models/dataset_name/` at training time. For example, the model path is `./models/ptb/ptb_lstmDecoder.ckpt` when generating with an LSTM decoder trained on the PTB dataset. Generated sentences are written to standard output if `--out` is not specified.
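
For example, to sample from the LSTM model trained on PTB, a command along these lines should work (the config name `config_lstm_ptb` and the output path are illustrative; substitute the config and checkpoint you actually trained):

```shell
python vae_train.py --config config_lstm_ptb --mode predict \
    --model ./models/ptb/ptb_lstmDecoder.ckpt --out ./samples.txt
```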

## Results

### Language Modeling

|Dataset |Metrics | VAE-LSTM |VAE-Transformer |
|---------------|-------------|----------------|------------------------|
|Yahoo | Test PPL<br>Test NLL | 68.11<br>337.13 |59.95<br>326.93|
|PTB | Test PPL<br>Test NLL | 104.61<br>101.92 | 103.68<br>101.72 |

### Generated Examples
The examples below were generated with a Transformer decoder trained on the PTB training data.

|Examples|
|:---------|
|i 'm always looking at a level of \$ N to \$ N billion \<EOS\> |
|after four years ago president bush has federal regulators decided to file financing for the waiver \<EOS\> |
|the savings & loan association said total asset revenue was about \$ N billion compared with \$ N billion \<EOS\> |
|the trend would seem to be effective \<EOS\> |
|chicago city 's \<unk\> computer bank of britain posted a N N jump in third-quarter net income \<EOS\>|
94 changes: 87 additions & 7 deletions examples/vae_text/vae_train.py
@@ -25,6 +25,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# pylint: disable=invalid-name, no-member, too-many-locals
# pylint: disable=too-many-branches, too-many-statements, redefined-variable-type
@@ -33,13 +34,22 @@
import sys
import time
import importlib
from io import open

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import texar as tx


tfd = tfp.distributions

flags = tf.flags

flags.DEFINE_string("config", "config", "The config to use.")
flags.DEFINE_string("mode", "train", "train or predict")
flags.DEFINE_string("model", None, "model path for generating sentences")
flags.DEFINE_string("out", None, "generation output path")

FLAGS = flags.FLAGS

@@ -92,12 +102,14 @@ def _main(_):
(train_data.dataset_size() / config.batch_size))

# Model architecture
embedder = tx.modules.WordEmbedder(
vocab_size=train_data.vocab.size, hparams=config.emb_hparams)
encoder_embedder = tx.modules.WordEmbedder(
vocab_size=train_data.vocab.size, hparams=config.emb_hparams)
decoder_embedder = tx.modules.WordEmbedder(
vocab_size=train_data.vocab.size, hparams=config.emb_hparams)
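# Note: the encoder and decoder now use separate word embedders; the decoder-side
# embedder is also reused for sampling in generate() below.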


input_embed = embedder(data_batch["text_ids"])
output_embed = embedder(data_batch["text_ids"][:, :-1])
input_embed = encoder_embedder(data_batch["text_ids"])
output_embed = decoder_embedder(data_batch["text_ids"][:, :-1])

if config.enc_keep_prob_in < 1:
input_embed = tf.nn.dropout(
@@ -117,7 +129,7 @@ def _main(_):
decoder_initial_state_size = decoder.cell.state_size
elif config.decoder_hparams["type"] == 'transformer':
decoder = tx.modules.TransformerDecoder(
embedding=embedder.embedding,
embedding=decoder_embedder.embedding,
hparams=config.trans_hparams)
decoder_initial_state_size = tf.TensorShape(
[1, config.emb_hparams["dim"]])
@@ -130,6 +142,7 @@ def _main(_):
connector_stoch = tx.modules.ReparameterizedStochasticConnector(
decoder_initial_state_size)


_, ecdr_states = encoder(
input_embed,
sequence_length=data_batch["length"])
@@ -138,14 +151,19 @@ def _main(_):
mean, logvar = tf.split(mean_logvar, 2, 1)
kl_loss = kl_dvg(mean, logvar)

dst = tf.contrib.distributions.MultivariateNormalDiag(
dst = tfd.MultivariateNormalDiag(
loc=mean,
scale_diag=tf.exp(0.5 * logvar))

dcdr_states, _ = connector_stoch(dst)
dcdr_states, latent_z = connector_stoch(dst)

# decoder
if config.decoder_hparams["type"] == "lstm":
# concat latent variable to input at every time step
latent_z = tf.expand_dims(latent_z, axis=1)
latent_z = tf.tile(latent_z, [1, tf.shape(output_embed)[1], 1])
output_embed = tf.concat([output_embed, latent_z], axis=2)
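# output_embed now has shape [batch_size, num_steps, emb_dim + latent_dims]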

outputs, _, _ = decoder(
initial_state=dcdr_states,
decoding_strategy="train_greedy",
@@ -246,9 +264,71 @@ def _run_epoch(sess, epoch, mode_string, display=10):

return nll_ / num_sents, np.exp(nll_ / num_words)

def generate(sess, saver, fname=None):
if tf.train.checkpoint_exists(FLAGS.model):
saver.restore(sess, FLAGS.model)
else:
raise ValueError("Cannot find model checkpoint: %s" % FLAGS.model)

batch_size = train_data.batch_size

dst = tfd.MultivariateNormalDiag(
loc=tf.zeros([batch_size, config.latent_dims]),
scale_diag=tf.ones([batch_size, config.latent_dims]))
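# At generation time, latent codes are drawn from the standard normal prior N(0, I)
# rather than from the approximate posterior used during training.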

dcdr_states, latent_z = connector_stoch(dst)

# to concatenate latent variable to input word embeddings
def _cat_embedder(ids):
embedding = decoder_embedder(ids)
return tf.concat([embedding, latent_z], axis=1)

vocab = train_data.vocab
start_tokens = tf.ones(batch_size, tf.int32) * vocab.bos_token_id
end_token = vocab.eos_token_id
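# "infer_sample" decoding samples each step's token from the model's output
# distribution, so repeated runs yield different sentences.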

if config.decoder_hparams["type"] == "lstm":
outputs, _, _ = decoder(
initial_state=dcdr_states,
decoding_strategy="infer_sample",
embedding=_cat_embedder,
max_decoding_length=100,
start_tokens=start_tokens,
end_token=end_token)
else:
outputs, _ = decoder(
memory=dcdr_states,
decoding_strategy="infer_sample",
memory_sequence_length=tf.ones(tf.shape(dcdr_states)[0]),
max_decoding_length=100,
start_tokens=start_tokens,
end_token=end_token)

sample_tokens = vocab.map_ids_to_tokens(outputs.sample_id)
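# map_ids_to_tokens is backed by a TF lookup table, hence the tables_initializer() call below.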
sess.run(tf.tables_initializer())

mode_key = tf.estimator.ModeKeys.EVAL
feed = {tx.global_mode(): mode_key}
sample_tokens_ = sess.run(sample_tokens, feed_dict=feed)
if fname is None:
fh = sys.stdout
else:
fh = open(fname, 'w', encoding='utf-8')

for sent in sample_tokens_:
sent = list(sent)
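# truncate each sample at the first <EOS> token (keeping the <EOS> itself)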
end_id = sent.index(vocab.eos_token)
fh.write(' '.join(sent[:end_id+1]) + '\n')

if fname is not None:  # avoid closing sys.stdout
fh.close()

saver = tf.train.Saver()
with tf.Session() as sess:
# generate samples from prior
if FLAGS.mode == "predict":
generate(sess, saver, FLAGS.out)
return

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
