Merge pull request #41 from tanyuqian/master
examples/seq2seq_exposure added.
Showing 21 changed files with 2,324 additions and 3 deletions.
@@ -0,0 +1,108 @@
# Sequence Generation Algorithms Tackling Exposure Bias #

Despite its computational simplicity and efficiency, maximum likelihood training of sequence generation models (e.g., RNNs) suffers from exposure bias [(Ranzato et al., 2015)](https://arxiv.org/pdf/1511.06732.pdf). That is, the model is trained to predict the next token given the previous ground-truth tokens, but at test time it has no access to the ground truth, so it must condition each prediction on tokens it generated itself. Because of this discrepancy between training and test, prediction mistakes can quickly accumulate.

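The discrepancy described above can be illustrated with a toy sketch. Everything here is hypothetical and unrelated to the actual models in this example: the "model" is a lookup table from the previous token to the next one, with a single wrong transition. Under teacher forcing that mistake costs one token; when the model consumes its own outputs, the mistake pushes it into states it never saw during training.

```python
# Toy illustration of exposure bias. All names and data are hypothetical.
GOLD = ["a", "cat", "sat", "on", "the", "mat"]

# A hypothetical "model": previous token -> predicted next token.
# It has learned exactly one wrong transition ("cat" -> "ran").
NEXT_TOKEN = {
    "<bos>": "a",
    "a": "cat",
    "cat": "ran",   # the single mistake (gold continuation is "sat")
    "sat": "on",
    "on": "the",
    "the": "mat",
    "ran": "away",  # a state that never occurs in the gold sequence
}


def predict(prev_token):
    """Return the model's next token; unseen contexts map to <unk>."""
    return NEXT_TOKEN.get(prev_token, "<unk>")


def decode_teacher_forced(gold):
    """Training-time setting: every step conditions on the gold prefix."""
    inputs = ["<bos>"] + gold[:-1]
    return [predict(token) for token in inputs]


def decode_free_running(length):
    """Test-time setting: every step conditions on the model's own output."""
    outputs, prev = [], "<bos>"
    for _ in range(length):
        prev = predict(prev)
        outputs.append(prev)
    return outputs


def count_errors(hypothesis, gold):
    """Number of positions where the hypothesis differs from the gold."""
    return sum(h != g for h, g in zip(hypothesis, gold))


teacher_forced = decode_teacher_forced(GOLD)
free_running = decode_free_running(len(GOLD))
print(teacher_forced, count_errors(teacher_forced, GOLD))
print(free_running, count_errors(free_running, GOLD))
```

In this toy setup the teacher-forced output contains one wrong token, while the free-running output never recovers after the first mistake.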
This example provides implementations of some classic and advanced training algorithms that tackle exposure bias. The base model is an attentional seq2seq.

* Baseline: attentional seq2seq model with maximum likelihood training.
* Reward Augmented Maximum Likelihood (RAML): Described in [(Norouzi et al., 2016)](https://arxiv.org/pdf/1609.00150.pdf). We use the sampling approach (n-gram replacement) of [(Ma et al., 2017)](https://arxiv.org/abs/1705.07136).
* Scheduled Sampling: Described in [(Bengio et al., 2015)](https://arxiv.org/abs/1506.03099).
* Interpolation Algorithm: Described in [Connecting the Dots Between MLE and RL for Sequence Generation](https://www.cs.cmu.edu/~zhitingh/).

## Usage ##

### Dataset ###

Two example datasets are provided:

* iwslt14: The benchmark [IWSLT2014](https://sites.google.com/site/iwsltevaluation2014/home) (de-en) machine translation dataset, following [(Ranzato et al., 2015)](https://arxiv.org/pdf/1511.06732.pdf) for data pre-processing.
* gigaword: The benchmark [GIGAWORD](https://catalog.ldc.upenn.edu/LDC2003T05) text summarization dataset. We sampled 200K out of the 3.8M pre-processed training examples provided by [(Rush et al., 2015)](https://www.aclweb.org/anthology/D/D15/D15-1044.pdf) for the sake of training efficiency. We used the refined validation and test sets provided by [(Zhou et al., 2017)](https://arxiv.org/pdf/1704.07073.pdf).

Download the data with the following commands:

```
python utils/prepare_data.py --data iwslt14
python utils/prepare_data.py --data giga
```

### Train the models ###

#### Baseline Attentional Seq2seq

```
python baseline_seq2seq_attn_main.py \
    --config_model configs.config_model \
    --config_data configs.config_iwslt14
```

Here:
* `--config_model` specifies the model config. Do not include the `.py` suffix.
* `--config_data` specifies the data config.

[configs.config_model.py](./configs/config_model.py) specifies a single-layer seq2seq model with Luong attention and a bi-directional RNN encoder. Hyperparameters taking default values can be omitted from the config file.

For demonstration purposes, [configs.config_model_full.py](./configs/config_model_full.py) gives all possible hyperparameters for the model. The two config files lead to the same model.

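The `.py` suffix is omitted because the config is loaded as a Python module by its dotted name (`baseline_seq2seq_attn_main.py` calls `importlib.import_module(FLAGS.config_model)`). The sketch below shows the same mechanism in isolation. The standard-library module `json.decoder` stands in for `configs.config_model`, and the helper name `load_config` is hypothetical.

```python
import importlib


def load_config(dotted_name):
    """Import a config given as a dotted module path (hypothetical helper)."""
    return importlib.import_module(dotted_name)


# `json.decoder` is only a stand-in for `configs.config_model`.
config = load_config("json.decoder")
print(config.__name__)

# With a `.py` suffix, Python looks for a submodule named `py` and fails.
try:
    load_config("json.decoder.py")
except ImportError as err:
    print("import failed:", err)
```

Module attributes of the loaded config (for example `config_model.embedder` or `config_data.train` in the training script) are then read like any other module-level variable.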
#### Reward Augmented Maximum Likelihood (RAML)
```
python raml_main.py \
    --config_model configs.config_model \
    --config_data configs.config_iwslt14 \
    --raml_file data/iwslt14/samples_iwslt14.txt \
    --n_samples 10
```
Here:
* `--raml_file` specifies the file containing the augmented samples and rewards.
* `--n_samples` specifies the number of augmented samples for every target sentence.
* `--tau` specifies the temperature of the exponentiated payoff distribution in RAML.

The downloaded datasets include example files for `--raml_file`, which contain augmented samples for `iwslt14` and `gigaword` respectively. We also provide scripts for generating augmented samples yourself. Please refer to [utils/raml_samples_generation](utils/raml_samples_generation).


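To make the role of `--tau` concrete: in RAML (Norouzi et al., 2016), augmented targets are weighted by an exponentiated payoff distribution, q(y | y*) ∝ exp(r(y, y*) / τ). The following is a minimal sketch that normalizes this quantity over a small candidate set. The function name and the reward values are made up for illustration, and how `raml_main.py` actually consumes the rewards stored in `--raml_file` is not shown here.

```python
import math


def exponentiated_payoff(rewards, tau):
    """Normalize exp(r / tau) over a finite set of candidate rewards."""
    max_reward = max(rewards)
    # Subtracting the max keeps exp() in a safe range; it cancels in the ratio.
    weights = [math.exp((r - max_reward) / tau) for r in rewards]
    total = sum(weights)
    return [w / total for w in weights]


# Hypothetical rewards of three augmented samples for one target sentence.
rewards = [1.0, 0.8, 0.5]

sharp = exponentiated_payoff(rewards, tau=0.1)   # low temperature
flat = exponentiated_payoff(rewards, tau=10.0)   # high temperature
print(sharp)
print(flat)
```

A small temperature concentrates the weight on the highest-reward sample, while a large temperature makes the weights nearly uniform.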
#### Scheduled Sampling
```
python scheduled_sampling_main.py \
    --config_model configs.config_model \
    --config_data configs.config_iwslt14 \
    --decay_factor 500.
```
Here:
* `--decay_factor` specifies the hyperparameter controlling how quickly the probability of sampling from the model increases.


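As an illustration of how such a decay hyperparameter can shape the schedule, Bengio et al. (2015) describe, among others, an inverse sigmoid decay ε_i = k / (k + exp(i / k)), where ε_i is the probability of feeding the ground-truth token at training step i. The sketch below assumes that `--decay_factor` plays the role of k. That mapping, and the function names, are assumptions made for illustration and are not confirmed by this README.

```python
import math


def teacher_forcing_prob(step, decay_factor):
    """Inverse sigmoid decay: k / (k + exp(step / k))."""
    return decay_factor / (decay_factor + math.exp(step / decay_factor))


def model_sampling_prob(step, decay_factor):
    """Probability of feeding the model's own sample instead of the gold token."""
    return 1.0 - teacher_forcing_prob(step, decay_factor)


# With k = 500 the sampling probability starts near 0 and grows with the step.
for step in (0, 1000, 3000, 5000):
    print(step, round(model_sampling_prob(step, 500.0), 4))
```

Under this assumed schedule, a larger decay factor keeps training close to teacher forcing for more steps before the model starts consuming its own predictions.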
#### Interpolation Algorithm
```
python interpolation_main.py \
    --config_model configs.config_model \
    --config_data configs.config_iwslt14 \
    --lambdas_init [0.04,0.06,0.0] \
    --delta_lambda_self 0.06 \
    --delta_lambda_reward 0.06 \
    --lambda_reward_steps 4
```
Here:

* `--lambdas_init` specifies the initial values of the lambdas.
* `--delta_lambda_reward` specifies the increment of lambda_reward at every annealing step.
* `--delta_lambda_self` specifies the decrement of lambda_self at every annealing step.
* `--lambda_reward_steps` specifies the number of times lambda_reward is increased after increasing lambda_self once.

## Results ##

### Machine Translation
| Model | BLEU Score |
| ----------- | ------- |
| MLE | 26.44 ± 0.18 |
| Scheduled Sampling | 26.76 ± 0.17 |
| RAML | 27.22 ± 0.14 |
| Interpolation | 27.82 ± 0.11 |

### Text Summarization
| Model | Rouge-1 | Rouge-2 | Rouge-L |
| ----------- | ------- | ------- | ------- |
| MLE | 36.11 ± 0.21 | 16.39 ± 0.16 | 32.32 ± 0.19 |
| Scheduled Sampling | 36.59 ± 0.12 | 16.79 ± 0.22 | 32.77 ± 0.17 |
| RAML | 36.30 ± 0.24 | 16.69 ± 0.20 | 32.49 ± 0.17 |
| Interpolation | 36.72 ± 0.29 | 16.99 ± 0.17 | 32.95 ± 0.33 |

232 changes: 232 additions & 0 deletions
examples/seq2seq_exposure/baseline_seq2seq_attn_main.py
@@ -0,0 +1,232 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Attentional Seq2seq.
same as examples/seq2seq_attn except that here Rouge is also supported.
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

# pylint: disable=invalid-name, too-many-arguments, too-many-locals

from io import open
import importlib
import tensorflow as tf
import texar as tx
from rouge import Rouge

flags = tf.flags

flags.DEFINE_string("config_model", "configs.config_model", "The model config.")
flags.DEFINE_string("config_data", "configs.config_iwslt14",
                    "The dataset config.")

flags.DEFINE_string('output_dir', '.', 'where to keep training logs')

FLAGS = flags.FLAGS

config_model = importlib.import_module(FLAGS.config_model)
config_data = importlib.import_module(FLAGS.config_data)

if not FLAGS.output_dir.endswith('/'):
    FLAGS.output_dir += '/'
log_dir = FLAGS.output_dir + 'training_log_baseline/'
tx.utils.maybe_create_dir(log_dir)


def build_model(batch, train_data):
    """Assembles the seq2seq model.
    """
    source_embedder = tx.modules.WordEmbedder(
        vocab_size=train_data.source_vocab.size, hparams=config_model.embedder)

    encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=config_model.encoder)

    enc_outputs, _ = encoder(source_embedder(batch['source_text_ids']))

    target_embedder = tx.modules.WordEmbedder(
        vocab_size=train_data.target_vocab.size, hparams=config_model.embedder)

    # The decoder attends over the concatenated forward/backward outputs.
    decoder = tx.modules.AttentionRNNDecoder(
        memory=tf.concat(enc_outputs, axis=2),
        memory_sequence_length=batch['source_length'],
        vocab_size=train_data.target_vocab.size,
        hparams=config_model.decoder)

    # Teacher forcing: feed the target without its last token ...
    training_outputs, _, _ = decoder(
        decoding_strategy='train_greedy',
        inputs=target_embedder(batch['target_text_ids'][:, :-1]),
        sequence_length=batch['target_length'] - 1)

    # ... and predict the target shifted by one position.
    train_op = tx.core.get_train_op(
        tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=batch['target_text_ids'][:, 1:],
            logits=training_outputs.logits,
            sequence_length=batch['target_length'] - 1),
        hparams=config_model.opt)

    # Inference: beam search starting from one BOS token per example.
    start_tokens = tf.ones_like(batch['target_length']) *\
        train_data.target_vocab.bos_token_id
    beam_search_outputs, _, _ = \
        tx.modules.beam_search_decode(
            decoder_or_cell=decoder,
            embedding=target_embedder,
            start_tokens=start_tokens,
            end_token=train_data.target_vocab.eos_token_id,
            beam_width=config_model.beam_width,
            max_decoding_length=60)

    return train_op, beam_search_outputs


def print_stdout_and_file(content, file):
    print(content)
    print(content, file=file)


def main():
    """Entrypoint.
    """
    train_data = tx.data.PairedTextData(hparams=config_data.train)
    val_data = tx.data.PairedTextData(hparams=config_data.val)
    test_data = tx.data.PairedTextData(hparams=config_data.test)
    data_iterator = tx.data.TrainTestDataIterator(
        train=train_data, val=val_data, test=test_data)

    batch = data_iterator.get_next()

    train_op, infer_outputs = build_model(batch, train_data)

    def _train_epoch(sess, epoch_no):
        data_iterator.switch_to_train_data(sess)
        log_path = log_dir + 'training_log' + str(epoch_no) + '.txt'

        # The with-block closes the per-epoch log once the epoch ends.
        with open(log_path, 'w', encoding='utf-8') as training_log_file:
            step = 0
            while True:
                try:
                    loss = sess.run(train_op)
                    print("step={}, loss={:.4f}".format(step, loss),
                          file=training_log_file)
                    training_log_file.flush()
                    step += 1
                except tf.errors.OutOfRangeError:
                    # Raised when the training data is exhausted.
                    break

    def _eval_epoch(sess, mode, epoch_no):
        if mode == 'val':
            data_iterator.switch_to_val_data(sess)
        else:
            data_iterator.switch_to_test_data(sess)

        refs, hypos = [], []
        while True:
            try:
                # Drop the leading BOS token of the reference and keep only
                # the top beam of the prediction.
                fetches = [
                    batch['target_text'][:, 1:],
                    infer_outputs.predicted_ids[:, :, 0]
                ]
                feed_dict = {
                    tx.global_mode(): tf.estimator.ModeKeys.EVAL
                }
                target_texts_ori, output_ids = \
                    sess.run(fetches, feed_dict=feed_dict)

                target_texts = tx.utils.strip_special_tokens(
                    target_texts_ori.tolist(), is_token_list=True)
                target_texts = tx.utils.str_join(target_texts)
                output_texts = tx.utils.map_ids_to_strs(
                    ids=output_ids, vocab=val_data.target_vocab)

                tx.utils.write_paired_text(
                    target_texts, output_texts,
                    log_dir + mode + '_results' + str(epoch_no) + '.txt',
                    append=True, mode='h', sep=' ||| ')

                for hypo, ref in zip(output_texts, target_texts):
                    if config_data.eval_metric == 'bleu':
                        # BLEU takes a list of references per hypothesis.
                        hypos.append(hypo)
                        refs.append([ref])
                    elif config_data.eval_metric == 'rouge':
                        hypos.append(tx.utils.compat_as_text(hypo))
                        refs.append(tx.utils.compat_as_text(ref))
            except tf.errors.OutOfRangeError:
                break

        if config_data.eval_metric == 'bleu':
            return tx.evals.corpus_bleu_moses(
                list_of_references=refs, hypotheses=hypos)
        elif config_data.eval_metric == 'rouge':
            rouge = Rouge()
            return rouge.get_scores(hyps=hypos, refs=refs, avg=True)

    def _calc_reward(score):
        """
        Return the bleu score or the sum of (Rouge-1, Rouge-2, Rouge-L).
        """
        if config_data.eval_metric == 'bleu':
            return score
        elif config_data.eval_metric == 'rouge':
            return sum([value['f'] for key, value in score.items()])

    # Opening the scores file in the same with-statement closes it on exit.
    with tf.Session() as sess, \
            open(log_dir + 'scores.txt', 'w', encoding='utf-8') as scores_file:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        best_val_score = -1.
        for i in range(config_data.num_epochs):
            _train_epoch(sess, i)

            val_score = _eval_epoch(sess, 'val', i)
            test_score = _eval_epoch(sess, 'test', i)

            best_val_score = max(best_val_score, _calc_reward(val_score))

            if config_data.eval_metric == 'bleu':
                print_stdout_and_file(
                    'val epoch={}, BLEU={:.4f}; best-ever={:.4f}'.format(
                        i, val_score, best_val_score), file=scores_file)

                print_stdout_and_file(
                    'test epoch={}, BLEU={:.4f}'.format(i, test_score),
                    file=scores_file)
                print_stdout_and_file('=' * 50, file=scores_file)

            elif config_data.eval_metric == 'rouge':
                print_stdout_and_file(
                    'valid epoch {}:'.format(i), file=scores_file)
                for key, value in val_score.items():
                    print_stdout_and_file(
                        '{}: {}'.format(key, value), file=scores_file)
                print_stdout_and_file('fsum: {}; best_val_fsum: {}'.format(
                    _calc_reward(val_score), best_val_score), file=scores_file)

                print_stdout_and_file(
                    'test epoch {}:'.format(i), file=scores_file)
                for key, value in test_score.items():
                    print_stdout_and_file(
                        '{}: {}'.format(key, value), file=scores_file)
                print_stdout_and_file('=' * 110, file=scores_file)

            scores_file.flush()


if __name__ == '__main__':
    main()
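
For the ROUGE case, `_calc_reward` in the file above reduces the metric to a single number by summing the F-scores. The standalone sketch below assumes the dictionary shape returned by `Rouge().get_scores(..., avg=True)` from the `rouge` package (keys `rouge-1`, `rouge-2`, `rouge-l`, each holding `f`, `p`, `r`). The function name `rouge_f_sum` and the numbers are made up for illustration.

```python
def rouge_f_sum(score):
    """Sum the F-scores of all ROUGE variants in an averaged score dict."""
    return sum(value['f'] for value in score.values())


# Hypothetical averaged scores, shaped like the `rouge` package output.
example_score = {
    'rouge-1': {'f': 0.36, 'p': 0.40, 'r': 0.33},
    'rouge-2': {'f': 0.16, 'p': 0.18, 'r': 0.15},
    'rouge-l': {'f': 0.32, 'p': 0.35, 'r': 0.30},
}

print(round(rouge_f_sum(example_score), 2))
```

Using one scalar lets the training loop track `best_val_score` the same way for BLEU and ROUGE.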