Merge pull request #41 from tanyuqian/master
examples/seq2seq_exposure added.
ZhitingHu committed Nov 20, 2018
2 parents f2a040b + 05ddbee commit 388390b
Showing 21 changed files with 2,324 additions and 3 deletions.
4 changes: 2 additions & 2 deletions examples/seq2seq_attn/README.md
@@ -9,7 +9,7 @@ This example builds an attentional seq2seq model for machine translation.
Two example datasets are provided:

* toy_copy: A small toy autoencoding dataset from [TF Seq2seq toolkit](https://github.com/google/seq2seq/tree/2500c26add91b079ca00cf1f091db5a99ddab9ae).
* iwslt14: The benchmark [IWSLT2014](https://sites.google.com/site/iwsltevaluation2014/home) (de-en) machine translation dataset.
* iwslt14: The benchmark [IWSLT2014](https://sites.google.com/site/iwsltevaluation2014/home) (de-en) machine translation dataset, following [(Ranzato et al., 2015)](https://arxiv.org/pdf/1511.06732.pdf) for data pre-processing.

Download the data with the following commands:

@@ -36,5 +36,5 @@ For demonstration purpose, [config_model_full.py](./config_model_full.py) gives

## Results ##

On the IWSLT14 dataset, using original target texts as reference (no `<UNK>` in the reference), the model achieves `BLEU=21.66` within `10` epochs.
On the IWSLT14 dataset, using original target texts as reference (no `<UNK>` in the reference), the model achieves `BLEU = 26.44 ± 0.18`.

2 changes: 1 addition & 1 deletion examples/seq2seq_attn/prepare_data.py
@@ -30,7 +30,7 @@ def prepare_data():
    if FLAGS.data == 'iwslt14':
        tx.data.maybe_download(
            urls='https://drive.google.com/file/d/'
                 '1Vuv3bed10qUxrpldHdYoiWLzPKa4pNXd/view?usp=sharing',
                 '1y4mUWXRS2KstgHopCS9koZ42ENOh6Yb9/view?usp=sharing',
            path='./',
            filenames='iwslt14.zip',
            extract=True)
108 changes: 108 additions & 0 deletions examples/seq2seq_exposure/README.md
@@ -0,0 +1,108 @@
# Sequence Generation Algorithms Tackling Exposure Bias #

Despite its computational simplicity and efficiency, maximum likelihood training of sequence generation models (e.g., RNNs) suffers from exposure bias [(Ranzato et al., 2015)](https://arxiv.org/pdf/1511.06732.pdf). That is, the model is trained to predict the next token given the previous ground-truth tokens, while at test time, since the resulting model does not have access to the ground truth, tokens generated by the model itself are instead used to make the next prediction. This discrepancy between training and test means that prediction mistakes can quickly accumulate at test time.
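
The sketch below gives a schematic, framework-free illustration of this mismatch (plain Python, not part of this example's code): the same decoding loop feeds back either the ground-truth token (training) or the model's own prediction (inference).

```
def toy_step(token):
    """A stand-in for a real model's next-token prediction."""
    return token + 1

def decode(model_step, first_token, target_tokens=None, max_len=5):
    """Generate a sequence of predictions, feeding back the ground-truth
    token when `target_tokens` is given (teacher forcing) and the model's
    own previous prediction otherwise (free-running inference)."""
    outputs, prev = [], first_token
    for t in range(max_len):
        pred = model_step(prev)
        outputs.append(pred)
        # Training conditions on the ground truth; inference conditions on
        # the model's own (possibly wrong) earlier predictions, so a single
        # early mistake changes every later conditioning context.
        prev = target_tokens[t] if target_tokens is not None else pred
    return outputs

print(decode(toy_step, 0, target_tokens=[3, 5, 7, 9, 11]))  # teacher-forced
print(decode(toy_step, 0))                                  # free-running
```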

This example provides implementations of several classic and advanced training algorithms that tackle exposure bias. The base model is an attentional seq2seq model.

* Baseline: attentional seq2seq model with maximum likelihood training.
* Reward Augmented Maximum Likelihood (RAML): Described in [(Norouzi et al., 2016)](https://arxiv.org/pdf/1609.00150.pdf); we use the sampling approach (n-gram replacement) of [(Ma et al., 2017)](https://arxiv.org/abs/1705.07136).
* Scheduled Sampling: Described in [(Bengio et al., 2015)](https://arxiv.org/abs/1506.03099).
* Interpolation Algorithm: Described in [Connecting the Dots Between MLE and RL for Sequence Generation](https://www.cs.cmu.edu/~zhitingh/).

## Usage ##

### Dataset ###

Two example datasets are provided:

* iwslt14: The benchmark [IWSLT2014](https://sites.google.com/site/iwsltevaluation2014/home) (de-en) machine translation dataset, following [(Ranzato et al., 2015)](https://arxiv.org/pdf/1511.06732.pdf) for data pre-processing.
* gigaword: The benchmark [GIGAWORD](https://catalog.ldc.upenn.edu/LDC2003T05) text summarization dataset. We sampled 200K out of the 3.8M pre-processed training examples provided by [(Rush et al., 2015)](https://www.aclweb.org/anthology/D/D15/D15-1044.pdf) for the sake of training efficiency. We used the refined validation and test sets provided by [(Zhou et al., 2017)](https://arxiv.org/pdf/1704.07073.pdf).

Download the data with the following commands:

```
python utils/prepare_data.py --data iwslt14
python utils/prepare_data.py --data giga
```

### Train the models ###

#### Baseline Attentional Seq2seq

```
python baseline_seq2seq_attn_main.py \
--config_model configs.config_model \
--config_data configs.config_iwslt14
```

Here:
* `--config_model` specifies the model config. Note that the `.py` suffix should not be included.
* `--config_data` specifies the data config.

[configs.config_model.py](./configs/config_model.py) specifies a single-layer seq2seq model with Luong attention and a bi-directional RNN encoder. Hyperparameters taking default values can be omitted from the config file.

For demonstration purposes, [configs.config_model_full.py](./configs/config_model_full.py) gives all possible hyperparameters for the model. The two config files lead to the same model.
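
For orientation, such a config module is just a set of Python dictionaries consumed by the Texar modules in `baseline_seq2seq_attn_main.py`. A minimal sketch of its general shape follows; the values are illustrative and may differ from the actual [configs/config_model.py](./configs/config_model.py).

```
# Illustrative hyperparameter values only.
num_units = 256
beam_width = 10

embedder = {
    'dim': num_units
}
encoder = {
    'rnn_cell_fw': {
        'kwargs': {'num_units': num_units}
    }
}
decoder = {
    'rnn_cell': {
        'kwargs': {'num_units': num_units}
    },
    'attention': {
        'kwargs': {'num_units': num_units},
        'attention_layer_size': num_units
    }
}
opt = {
    'optimizer': {
        'type': 'AdamOptimizer',
        'kwargs': {'learning_rate': 0.001}
    }
}
```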

#### Reward Augmented Maximum Likelihood (RAML)
```
python raml_main.py \
--config_model configs.config_model \
--config_data configs.config_iwslt14 \
--raml_file data/iwslt14/samples_iwslt14.txt \
--n_samples 10
```
Here:
* `--raml_file` specifies the file containing the augmented samples and rewards.
* `--n_samples` specifies the number of augmented samples for every target sentence.
* `--tau` specifies the temperature of the exponentiated payoff distribution in RAML (see the sketch below).

In the downloaded datasets, we have provided example files for `--raml_file`, which include augmented samples for `iwslt14` and `gigaword`, respectively. We also provide scripts for generating augmented samples yourself; please refer to [utils/raml_samples_generation](utils/raml_samples_generation).
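
In RAML, each augmented sample `y` is weighted according to the exponentiated payoff distribution `q(y|y*)` proportional to `exp(r(y, y*) / tau)`, so `--tau` controls how sharply the training weight concentrates on high-reward samples. A minimal illustrative sketch (not code from this repository; `raml_weights` is a hypothetical helper):

```
import math

def raml_weights(rewards, tau=0.4):
    """Normalized sample weights under q(y|y*) proportional to
    exp(r(y, y*) / tau).  Smaller tau concentrates weight on the
    highest-reward samples; larger tau approaches uniform weighting."""
    scores = [math.exp(r / tau) for r in rewards]
    total = sum(scores)
    return [s / total for s in scores]

# e.g. rewards (such as sentence-level BLEU) of four augmented samples
print(raml_weights([1.0, 0.8, 0.8, 0.6], tau=0.4))
print(raml_weights([1.0, 0.8, 0.8, 0.6], tau=10.0))
```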


#### Scheduled Sampling
```
python scheduled_sampling_main.py \
--config_model configs.config_model \
--config_data configs.config_iwslt14 \
--decay_factor 500.
```
Here:
* `--decay_factor` specifies the hyperparameter controlling how quickly the probability of sampling from the model increases (a sketch of one such schedule is given below).
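
Scheduled sampling anneals the probability of feeding the ground-truth token toward sampling from the model as training progresses. Assuming the inverse-sigmoid decay variant of (Bengio et al., 2015), the sketch below shows how `decay_factor` would shape that probability; the exact schedule used here is defined in `scheduled_sampling_main.py`.

```
import math

def truth_prob(step, decay_factor=500.):
    """Inverse-sigmoid decay (Bengio et al., 2015): probability of feeding
    the ground-truth token at a given training step; the model's own sample
    is used with probability 1 - truth_prob(step)."""
    return decay_factor / (decay_factor + math.exp(step / decay_factor))

for step in (0, 1000, 3000, 5000):
    print(step, round(truth_prob(step), 3))
```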


#### Interpolation Algorithm
```
python interpolation_main.py \
--config_model configs.config_model \
--config_data configs.config_iwslt14 \
--lambdas_init [0.04,0.06,0.0] \
--delta_lambda_self 0.06 \
--delta_lambda_reward 0.06 \
--lambda_reward_steps 4
```
Here:

* `--lambdas_init` specifies the initial values of the lambdas.
* `--delta_lambda_reward` specifies the increment of lambda_reward at every annealing step.
* `--delta_lambda_self` specifies the decrement of lambda_self at every annealing step.
* `--lambda_reward_steps` specifies the number of times lambda_reward is increased for each single update of lambda_self (see the sketch below).
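
A rough, hypothetical sketch of the annealing these flags describe (based only on the descriptions above; `anneal_lambdas` is not a function in this repository and the script's actual update rule may differ):

```
def anneal_lambdas(lambda_self, lambda_reward,
                   delta_self=0.06, delta_reward=0.06, reward_steps=4):
    """One annealing cycle: lambda_reward is incremented `reward_steps`
    times, after which lambda_self is decremented once."""
    for _ in range(reward_steps):
        lambda_reward += delta_reward
        yield lambda_self, lambda_reward
    lambda_self = max(0.0, lambda_self - delta_self)
    yield lambda_self, lambda_reward

# Two cycles from illustrative starting values of (lambda_self, lambda_reward).
ls, lr = 0.06, 0.0
for _ in range(2):
    for ls, lr in anneal_lambdas(ls, lr):
        print(round(ls, 2), round(lr, 2))
```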

## Results ##

### Machine Translation
| Model | BLEU Score |
| -----------| -------|
| MLE | 26.44 ± 0.18 |
| Scheduled Sampling | 26.76 ± 0.17 |
| RAML | 27.22 ± 0.14 |
| Interpolation | 27.82 ± 0.11 |

### Text Summarization
| Model | Rouge-1 | Rouge-2 | Rouge-L |
| -----------| -------|-------|-------|
| MLE | 36.11 ± 0.21 | 16.39 ± 0.16 | 32.32 ± 0.19 |
| Scheduled Sampling | 36.59 ± 0.12 | 16.79 ± 0.22 | 32.77 ± 0.17 |
| RAML | 36.30 ± 0.24 | 16.69 ± 0.20 | 32.49 ± 0.17 |
| Interpolation | 36.72 ± 0.29 | 16.99 ± 0.17 | 32.95 ± 0.33 |


232 changes: 232 additions & 0 deletions examples/seq2seq_exposure/baseline_seq2seq_attn_main.py
@@ -0,0 +1,232 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Attentional Seq2seq.
same as examples/seq2seq_attn except that here Rouge is also supported.
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

# pylint: disable=invalid-name, too-many-arguments, too-many-locals

from io import open
import importlib
import tensorflow as tf
import texar as tx
from rouge import Rouge

flags = tf.flags

flags.DEFINE_string("config_model", "configs.config_model", "The model config.")
flags.DEFINE_string("config_data", "configs.config_iwslt14",
                    "The dataset config.")

flags.DEFINE_string('output_dir', '.', 'where to keep training logs')

FLAGS = flags.FLAGS

config_model = importlib.import_module(FLAGS.config_model)
config_data = importlib.import_module(FLAGS.config_data)

if not FLAGS.output_dir.endswith('/'):
    FLAGS.output_dir += '/'
log_dir = FLAGS.output_dir + 'training_log_baseline/'
tx.utils.maybe_create_dir(log_dir)


def build_model(batch, train_data):
    """Assembles the seq2seq model.
    """
    source_embedder = tx.modules.WordEmbedder(
        vocab_size=train_data.source_vocab.size, hparams=config_model.embedder)

    encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=config_model.encoder)

    enc_outputs, _ = encoder(source_embedder(batch['source_text_ids']))

    target_embedder = tx.modules.WordEmbedder(
        vocab_size=train_data.target_vocab.size, hparams=config_model.embedder)

    decoder = tx.modules.AttentionRNNDecoder(
        memory=tf.concat(enc_outputs, axis=2),
        memory_sequence_length=batch['source_length'],
        vocab_size=train_data.target_vocab.size,
        hparams=config_model.decoder)

    training_outputs, _, _ = decoder(
        decoding_strategy='train_greedy',
        inputs=target_embedder(batch['target_text_ids'][:, :-1]),
        sequence_length=batch['target_length'] - 1)

    train_op = tx.core.get_train_op(
        tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=batch['target_text_ids'][:, 1:],
            logits=training_outputs.logits,
            sequence_length=batch['target_length'] - 1),
        hparams=config_model.opt)

    start_tokens = tf.ones_like(batch['target_length']) * \
        train_data.target_vocab.bos_token_id
    beam_search_outputs, _, _ = \
        tx.modules.beam_search_decode(
            decoder_or_cell=decoder,
            embedding=target_embedder,
            start_tokens=start_tokens,
            end_token=train_data.target_vocab.eos_token_id,
            beam_width=config_model.beam_width,
            max_decoding_length=60)

    return train_op, beam_search_outputs


def print_stdout_and_file(content, file):
    print(content)
    print(content, file=file)


def main():
    """Entrypoint.
    """
    train_data = tx.data.PairedTextData(hparams=config_data.train)
    val_data = tx.data.PairedTextData(hparams=config_data.val)
    test_data = tx.data.PairedTextData(hparams=config_data.test)
    data_iterator = tx.data.TrainTestDataIterator(
        train=train_data, val=val_data, test=test_data)

    batch = data_iterator.get_next()

    train_op, infer_outputs = build_model(batch, train_data)

    def _train_epoch(sess, epoch_no):
        data_iterator.switch_to_train_data(sess)
        training_log_file = \
            open(log_dir + 'training_log' + str(epoch_no) + '.txt', 'w',
                 encoding='utf-8')

        step = 0
        while True:
            try:
                loss = sess.run(train_op)
                print("step={}, loss={:.4f}".format(step, loss),
                      file=training_log_file)
                training_log_file.flush()
                step += 1
            except tf.errors.OutOfRangeError:
                break

    def _eval_epoch(sess, mode, epoch_no):
        if mode == 'val':
            data_iterator.switch_to_val_data(sess)
        else:
            data_iterator.switch_to_test_data(sess)

        refs, hypos = [], []
        while True:
            try:
                fetches = [
                    batch['target_text'][:, 1:],
                    infer_outputs.predicted_ids[:, :, 0]
                ]
                feed_dict = {
                    tx.global_mode(): tf.estimator.ModeKeys.EVAL
                }
                target_texts_ori, output_ids = \
                    sess.run(fetches, feed_dict=feed_dict)

                target_texts = tx.utils.strip_special_tokens(
                    target_texts_ori.tolist(), is_token_list=True)
                target_texts = tx.utils.str_join(target_texts)
                output_texts = tx.utils.map_ids_to_strs(
                    ids=output_ids, vocab=val_data.target_vocab)

                tx.utils.write_paired_text(
                    target_texts, output_texts,
                    log_dir + mode + '_results' + str(epoch_no) + '.txt',
                    append=True, mode='h', sep=' ||| ')

                for hypo, ref in zip(output_texts, target_texts):
                    if config_data.eval_metric == 'bleu':
                        hypos.append(hypo)
                        refs.append([ref])
                    elif config_data.eval_metric == 'rouge':
                        hypos.append(tx.utils.compat_as_text(hypo))
                        refs.append(tx.utils.compat_as_text(ref))
            except tf.errors.OutOfRangeError:
                break

        if config_data.eval_metric == 'bleu':
            return tx.evals.corpus_bleu_moses(
                list_of_references=refs, hypotheses=hypos)
        elif config_data.eval_metric == 'rouge':
            rouge = Rouge()
            return rouge.get_scores(hyps=hypos, refs=refs, avg=True)

    def _calc_reward(score):
        """
        Return the bleu score or the sum of (Rouge-1, Rouge-2, Rouge-L).
        """
        if config_data.eval_metric == 'bleu':
            return score
        elif config_data.eval_metric == 'rouge':
            return sum([value['f'] for key, value in score.items()])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        best_val_score = -1.
        scores_file = open(log_dir + 'scores.txt', 'w', encoding='utf-8')
        for i in range(config_data.num_epochs):
            _train_epoch(sess, i)

            val_score = _eval_epoch(sess, 'val', i)
            test_score = _eval_epoch(sess, 'test', i)

            best_val_score = max(best_val_score, _calc_reward(val_score))

            if config_data.eval_metric == 'bleu':
                print_stdout_and_file(
                    'val epoch={}, BLEU={:.4f}; best-ever={:.4f}'.format(
                        i, val_score, best_val_score), file=scores_file)

                print_stdout_and_file(
                    'test epoch={}, BLEU={:.4f}'.format(i, test_score),
                    file=scores_file)
                print_stdout_and_file('=' * 50, file=scores_file)

            elif config_data.eval_metric == 'rouge':
                print_stdout_and_file(
                    'valid epoch {}:'.format(i), file=scores_file)
                for key, value in val_score.items():
                    print_stdout_and_file(
                        '{}: {}'.format(key, value), file=scores_file)
                print_stdout_and_file('fsum: {}; best_val_fsum: {}'.format(
                    _calc_reward(val_score), best_val_score), file=scores_file)

                print_stdout_and_file(
                    'test epoch {}:'.format(i), file=scores_file)
                for key, value in test_score.items():
                    print_stdout_and_file(
                        '{}: {}'.format(key, value), file=scores_file)
                print_stdout_and_file('=' * 110, file=scores_file)

            scores_file.flush()


if __name__ == '__main__':
    main()