From 79e58a07be19cacce7837c1ae8f15ea3aafab51b Mon Sep 17 00:00:00 2001 From: Zhiting Hu Date: Fri, 18 Jan 2019 19:07:25 -0500 Subject: [PATCH] updated transfer_main --- examples/transformer/README.md | 14 +++++++------- examples/transformer/transformer_main.py | 3 ++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/transformer/README.md b/examples/transformer/README.md index 5882a56a..e04adfb7 100644 --- a/examples/transformer/README.md +++ b/examples/transformer/README.md @@ -54,11 +54,11 @@ To only evaluate a model checkpoint without training, first load the checkpoint ``` python transformer_main.py --run_mode=test --config_data=config_iwslt15 --model_dir=./outputs ``` -The latest checkpoint in `./outputs` is used. Generated samples are in the file `./outputs/test.output`. +The latest checkpoint in `./outputs` is used. Generated samples are in the file `./outputs/test.output.hyp`, and reference sentences are in the file `./outputs/test.output.ref`. Next, decode the samples with respective decoder, and evaluate with `bleu_tool`: ``` -../../bin/utils/spm_decode --infile ./outputs/test.output.src --outfile temp/test.output.spm --model temp/run_en_vi_spm/data/spm-codes.32000.model --input_format=piece +../../bin/utils/spm_decode --infile ./outputs/test.output.hyp --outfile temp/test.output.spm --model temp/run_en_vi_spm/data/spm-codes.32000.model --input_format=piece python bleu_tool.py --reference=data/en_vi/test.vi --translation=temp/test.output.spm ``` @@ -69,7 +69,7 @@ For WMT'14, the corresponding cmds are: python transformer_main.py --run_mode=test --config_data=config_wmt14 --log_dir=./outputs # BPE decoding -cat outputs/test.output.src | sed -E 's/(@@ )|(@@ ?$)//g' > temp/test.output.bpe +cat outputs/test.output.hyp | sed -E 's/(@@ )|(@@ ?$)//g' > temp/test.output.bpe # Evaluates BLEU python bleu_tool.py --reference=data/en_de/test.de --translation=temp/test.output.bpe @@ -151,17 +151,17 @@ Test with the following cmd: python
transformer_main.py --run_mode=test --config_data=custom_config_data --model_dir=./outputs ``` -Generated samples on the test set are in `outputs/test.output`. If you've used `bpe` or `spm` encoding in the data preprocessing step, the generated samples in `outputs/test.output` are in the respective encoding too. To decode, use the respective cmd: +Generated samples on the test set are in `outputs/test.output.hyp`, and reference sentences are in `outputs/test.output.ref`. If you've used `bpe` or `spm` encoding in the data preprocessing step, the text in these files is in the respective encoding too. To decode, use the respective cmd: ``` # BPE decoding -cat outputs/test.output | sed -E 's/(@@ )|(@@ ?$)//g' > temp/test.output.final +cat outputs/test.output.hyp | sed -E 's/(@@ )|(@@ ?$)//g' > temp/test.output.hyp.final # SPM decoding (take `iwslt15_en_vi` for example) -../../bin/utils/spm_decode --infile ./outputs/test.output --outfile temp/test.output.final --model temp/run_en_vi_spm/data/spm-codes.32000.model --input_format=piece +../../bin/utils/spm_decode --infile ./outputs/test.output.hyp --outfile temp/test.output.hyp.final --model temp/run_en_vi_spm/data/spm-codes.32000.model --input_format=piece ``` Finally, to evaluate the BLEU score against the ground truth on the test set: ``` -python bleu_tool.py --reference=you_reference_file --translation=temp/test.output.final +python bleu_tool.py --reference=your_reference_file --translation=temp/test.output.hyp.final ``` E.g., in the `iwslt15_en_vi` example, with `--reference=data/en_vi/test.vi` diff --git a/examples/transformer/transformer_main.py b/examples/transformer/transformer_main.py index 5a7251c3..ced0ed5d 100644 --- a/examples/transformer/transformer_main.py +++ b/examples/transformer/transformer_main.py @@ -207,7 +207,8 @@ def _eval_epoch(sess, epoch, mode): hwords = tx.utils.str_join(hwords) rwords = tx.utils.str_join(rwords) hyp_fn, ref_fn = tx.utils.write_paired_text( - hwords, rwords, fname, mode='s') +
hwords, rwords, fname, mode='s', + src_fname_suffix='hyp', tgt_fname_suffix='ref') logger.info('Test output writtn to file: %s', hyp_fn) print('Test output writtn to file: %s' % hyp_fn)