Merge pull request #76 from wwt17/transformer_clean
Add helper to, and refactor code of, TransformerDecoder
ZhitingHu committed Feb 21, 2019
2 parents 214d63c + 3f9a066 commit a2e28b2
Showing 7 changed files with 401 additions and 270 deletions.
26 changes: 13 additions & 13 deletions examples/transformer/README.md
@@ -11,7 +11,7 @@ This is an implementation of the Transformer model described in [Vaswani, Ashish
### Prerequisites ###

Run the following cmd to install necessary packages for the example:
-```
+```bash
pip install -r requirements.txt
```

@@ -22,15 +22,15 @@ Two example datasets are provided:
- WMT'14 **EN-DE** for English-German translation

Download and pre-process the **IWSLT'15 EN-VI** data with the following cmds:
-```
+```bash
sh scripts/iwslt15_en_vi.sh
sh preprocess_data.sh spm en vi
```
By default, the downloaded dataset is in `./data/en_vi`.
As with the [official implementation](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py), `spm` (`sentencepiece`) encoding is used to encode the raw text during data pre-processing. The encoded data is by default in `./temp/run_en_vi_spm`.
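To inspect the encoding by hand, the `sentencepiece` Python package can load the same model file that the preprocessing script produces. A minimal sketch (not part of the example scripts; the model path is taken from the decoding command shown later in this README and may differ if you changed the output directory):
```python
import sentencepiece as spm

# Load the sentencepiece model written by preprocess_data.sh.
# The path below is the default from this README; adjust as needed.
sp = spm.SentencePieceProcessor()
sp.Load("temp/run_en_vi_spm/data/spm-codes.32000.model")

pieces = sp.EncodeAsPieces("A minimal example sentence.")
print(pieces)                   # subword pieces produced by the spm model
print(sp.DecodePieces(pieces))  # round-trips back to the original text
```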

For the **WMT'14 EN-DE** data, download and pre-process with:
-```
+```bash
sh scripts/wmt14_en_de.sh
sh preprocess_data.sh bpe en de
```
@@ -41,7 +41,7 @@ Note that for this dataset, `bpe` encoding (Byte pair encoding) is used instead.
### Train and evaluate the model ###

Train the model with the cmd:
-```
+```bash
python transformer_main.py --run_mode=train_and_evaluate --config_model=config_model --config_data=config_iwslt15
```
* Specify `--model_dir` to dump model checkpoints, training logs, and tensorboard summaries to a desired directory. By default it is set to `./outputs`.
@@ -51,22 +51,22 @@ python transformer_main.py --run_mode=train_and_evaluate --config_model=config_m
### Test a trained model ###

To only evaluate a model checkpoint without training, first load the checkpoint and generate samples:
-```
+```bash
python transformer_main.py --run_mode=test --config_data=config_iwslt15 --model_dir=./outputs
```
The latest checkpoint in `./outputs` is used. Generated samples are in the file `./outputs/test.output.hyp`, and reference sentences are in the file `./outputs/test.output.ref`.

Next, decode the samples with the respective decoder, and evaluate with `bleu_tool`:
-```
+```bash
../../bin/utils/spm_decode --infile ./outputs/test.output.hyp --outfile temp/test.output.spm --model temp/run_en_vi_spm/data/spm-codes.32000.model --input_format=piece

python bleu_tool.py --reference=data/en_vi/test.vi --translation=temp/test.output.spm
```

For WMT'14, the corresponding cmds are:
-```
+```bash
# Loads model and generates samples
-python transformer_main.py --run_mode=test --config_data=config_wmt14 --log_dir=./outputs
+python transformer_main.py --run_mode=test --config_data=config_wmt14 --model_dir=./outputs

# BPE decoding
cat outputs/test.output.hyp | sed -E 's/(@@ )|(@@ ?$)//g' > temp/test.output.bpe
@@ -107,7 +107,7 @@ Create a data directory and put the raw data in the directory. To be compatible
### 2. Preprocess the data

To obtain the processed dataset, run
-```
+```bash
preprocess_data.sh ${encoder} ${src} ${tgt} ${vocab_size} ${max_seq_length}
```
where
@@ -137,7 +137,7 @@ Please refer to the example configuration files `config_model.py` for model conf
### 4. Train the model

Train the model with the following cmd:
-```
+```bash
python transformer_main.py --run_mode=train_and_evaluate --config_model=custom_config_model --config_data=custom_config_data
```
where the model and data configuration files are `custom_config_model.py` and `custom_config_data.py`, respectively.
@@ -147,12 +147,12 @@ Outputs such as model checkpoints are by default under `outputs/`.
### 5. Test the model

Test with the following cmd:
-```
+```bash
python transformer_main.py --run_mode=test --config_data=custom_config_data --model_dir=./outputs
```

Generated samples on the test set are in `outputs/test.output.hyp`, and reference sentences are in `outputs/test.output.ref`. If you've used `bpe` or `spm` encoding in the data preprocessing step, the text in these files is in the respective encoding too. To decode, use the respective cmd (a Python equivalent of the BPE `sed` step is sketched after the commands):
-```
+```bash
# BPE decoding
cat outputs/test.output.hyp | sed -E 's/(@@ )|(@@ ?$)//g' > temp/test.output.hyp.final

@@ -161,7 +161,7 @@ cat outputs/test.output.hyp | sed -E 's/(@@ )|(@@ ?$)//g' > temp/test.output.hyp
```
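The `sed` expression above only strips the `@@ ` continuation markers that BPE inserts at subword boundaries. An equivalent Python snippet, shown purely for illustration:
```python
import re

def strip_bpe(line):
    # Remove BPE's "@@ " continuation markers, mirroring the command
    # sed -E 's/(@@ )|(@@ ?$)//g' used above.
    return re.sub(r"(@@ )|(@@ ?$)", "", line)

print(strip_bpe("The qu@@ ick bro@@ wn fox"))  # -> "The quick brown fox"
```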

Finally, to evaluate the BLEU score against the ground truth on the test set:
-```
+```bash
python bleu_tool.py --reference=your_reference_file --translation=temp/test.output.hyp.final
```
E.g., in the `iwslt15_en_vi` example, use `--reference=data/en_vi/test.vi`.
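As an optional sanity check, the third-party `sacrebleu` package can score the same files. This is not the reference scorer for this example (`bleu_tool.py` is), and its scores can differ slightly due to tokenization; a sketch using the iwslt15_en_vi file names from above:
```python
import sacrebleu  # pip install sacrebleu

# File names follow the iwslt15_en_vi example above; substitute your own.
with open("temp/test.output.hyp.final", encoding="utf-8") as f:
    hypotheses = [line.strip() for line in f]
with open("data/en_vi/test.vi", encoding="utf-8") as f:
    references = [line.strip() for line in f]

bleu = sacrebleu.corpus_bleu(hypotheses, [references])
print(bleu.score)
```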
13 changes: 4 additions & 9 deletions examples/transformer/transformer_main.py
@@ -127,20 +127,15 @@ def main():
predictions = decoder(
memory=encoder_output,
memory_sequence_length=encoder_input_length,
-decoding_strategy='infer_greedy',
beam_width=beam_width,
alpha=config_model.alpha,
start_tokens=start_tokens,
end_token=eos_token_id,
max_decoding_length=config_data.max_decoding_length,
mode=tf.estimator.ModeKeys.PREDICT
)
-if beam_width <= 1:
-    inferred_ids = predictions[0].sample_id
-else:
-    # Uses the best sample by beam search
-    inferred_ids = predictions['sample_id'][:, :, 0]
-
+# Uses the best sample by beam search
+beam_search_ids = predictions['sample_id'][:, :, 0]

saver = tf.train.Saver(max_to_keep=5)
best_results = {'score': 0, 'epoch': -1}
@@ -163,11 +158,11 @@ def _eval_epoch(sess, epoch, mode):
tx.global_mode(): tf.estimator.ModeKeys.EVAL,
}
fetches = {
-'inferred_ids': inferred_ids,
+'beam_search_ids': beam_search_ids,
}
fetches_ = sess.run(fetches, feed_dict=feed_dict)

-hypotheses.extend(h.tolist() for h in fetches_['inferred_ids'])
+hypotheses.extend(h.tolist() for h in fetches_['beam_search_ids'])
references.extend(r.tolist() for r in targets)
hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
references = utils.list_strip_eos(references, eos_token_id)
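The change above keeps only the top-ranked beam. Assuming `predictions['sample_id']` has shape `[batch_size, max_decoding_length, beam_width]` with beams sorted best-first (which the `[:, :, 0]` indexing implies), the selection is a plain slice; a NumPy sketch:
```python
import numpy as np

batch_size, max_len, beam_width = 2, 5, 4

# Stand-in for predictions['sample_id']: token ids for every beam,
# assumed to be ordered from best to worst scoring beam.
sample_id = np.random.randint(0, 100, size=(batch_size, max_len, beam_width))

# Keep only the highest-scoring beam per example, as in
# beam_search_ids = predictions['sample_id'][:, :, 0] above.
beam_search_ids = sample_id[:, :, 0]
print(beam_search_ids.shape)  # (2, 5) -> [batch_size, max_decoding_length]
```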
8 changes: 4 additions & 4 deletions texar/core/layers.py
@@ -1173,19 +1173,19 @@ def default_average_pooling3d_kwargs():
}

def layer_normalize(inputs,
-scope=None):
+scope=None,
+**kwargs):
'''Applies layer normalization, averaging over the last dimension.
Args:
inputs: A tensor with 2 or more dimensions, where the first
dimension has `batch_size`.
epsilon: A small floating-point number for preventing division-by-zero errors.
scope: Optional scope for `variable_scope`.
Returns:
A tensor with the same shape and data dtype as `inputs`.
'''
return tf.contrib.layers.layer_norm(
-inputs=inputs, begin_norm_axis=-1, begin_params_axis=-1, scope=scope
+inputs=inputs, begin_norm_axis=-1, begin_params_axis=-1, scope=scope,
+**kwargs
)


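For reference, layer normalization over the last axis can be written out in a few lines of NumPy. This is only an illustration of what `tf.contrib.layers.layer_norm` computes with `begin_norm_axis=-1`; the trainable scale and shift are passed in as plain arrays here, and the epsilon value is an assumption:
```python
import numpy as np

def layer_norm_reference(x, gamma, beta, epsilon=1e-12):
    # Normalize over the last axis, then apply a learned scale and shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normalized = (x - mean) / np.sqrt(var + epsilon)
    return gamma * normalized + beta

# Example on a [batch_size, hidden_dim] activation.
x = np.random.randn(2, 4)
out = layer_norm_reference(x, gamma=np.ones(4), beta=np.zeros(4))
print(out.mean(axis=-1))  # ~0 per row
print(out.var(axis=-1))   # ~1 per row
```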
