
Resolve issues #154 and #155 (#156)
- Fix incorrect padding in Transformer example (#154)
- Fix unnecessary shuffle for dev & test in Transformer example (#155)
huzecong authored and gpengzhi committed Aug 15, 2019
1 parent a678cf3 commit fd39580
Showing 4 changed files with 25 additions and 23 deletions.
13 changes: 9 additions & 4 deletions examples/transformer/README.md
@@ -105,7 +105,7 @@ python bleu_main.py --reference=data/en_de/test.de --translation=temp/test.outpu

### Results

-* On **IWSLT'15**, the implementation achieves around `BLEU_cased=28.44` and `BLEU_uncased=29.21` (reported by
+* On **IWSLT'15**, the implementation achieves around `BLEU_cased=29.00` and `BLEU_uncased=29.82` (reported by
[bleu_main.py](./bleu_main.py)), which are comparable to the base_single_gpu results by the
[official implementation](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py)
(`28.12` and `28.97`, respectively, as reported [here](https://github.com/tensorflow/tensor2tensor/pull/611)).
@@ -118,9 +118,14 @@ python bleu_main.py --reference=data/en_de/test.de --translation=temp/test.outpu
### Example training log

```
-12:57:06,611:INFO:step: 500, loss: 7.4818
-12:58:24,629:INFO:step: 1000, loss: 6.8003
-12:59:42,661:INFO:step: 1500, loss: 6.3096
+2019-08-14 16:37:48,346:INFO:Begin running with train_and_evaluate mode
+2019-08-14 16:39:10,780:INFO:step: 500, loss: 7.4967
+2019-08-14 16:40:34,075:INFO:step: 1000, loss: 6.7844
+2019-08-14 16:41:57,523:INFO:step: 1500, loss: 6.3648
+2019-08-14 16:43:21,424:INFO:step: 2000, loss: 5.8466
+2019-08-14 16:48:31,190:INFO:epoch: 0, eval_bleu 2.0754
+2019-08-14 16:48:31,191:INFO:epoch: 0, best bleu: 2.0754
+2019-08-14 16:48:31,191:INFO:Saving model to ./outputs/best-model.ckpt
```
Using an NVIDIA GTX 1080Ti, the model usually converges within 5 hours (~15 epochs) on **IWSLT'15**.

2 changes: 1 addition & 1 deletion examples/transformer/preprocess_data.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-#!/usr/bin/env bash

###########################################################################

8 changes: 5 additions & 3 deletions examples/transformer/transformer_main.py
@@ -67,8 +67,9 @@ def main():
# Load data
vocab = tx.data.Vocab(config_data.vocab_file)
data_hparams = {
# "batch_size" is ignored for train since we use dynamic batching
# "batch_size" is ignored for train since we use dynamic batching.
"batch_size": config_data.test_batch_size,
"pad_id": vocab.pad_token_id,
"bos_id": vocab.bos_token_id,
"eos_id": vocab.eos_token_id,
}
@@ -78,8 +79,9 @@
config_data.input_dir,
f"{config_data.filename_prefix}{split}.npy"
),
-hparams=data_hparams,
-device=device
+# Only shuffle during training.
+hparams={**data_hparams, "shuffle": split == "train"},
+device=device,
) for split in ["train", "valid", "test"]
}
print(f"Training data size: {len(datasets['train'])}")
25 changes: 10 additions & 15 deletions examples/transformer/utils/data_utils.py
@@ -81,29 +81,24 @@ def __init__(self, filename: str, hparams=None,
def default_hparams():
return {
**tx.data.DataBase.default_hparams(),
"pad_id": 0,
"bos_id": 1,
"eos_id": 2,
}

def collate(self, examples: List[Example]) -> tx.data.Batch:
-src_seqs = [ex[0] for ex in examples]
-tgt_seqs = [ex[1] for ex in examples]
-max_src_len = max(map(len, src_seqs))
-max_tgt_len = max(map(len, tgt_seqs))
-# Add EOS token by setting pad_length to max length + 1.
+# Add EOS tokens.
+src_seqs = [ex[0].tolist() + [self._hparams.eos_id] for ex in examples]
+tgt_seqs = [ex[1].tolist() + [self._hparams.eos_id] for ex in examples]
+# Pad sentences to equal length.
source, _ = tx.data.padded_batch(
-src_seqs, pad_length=(max_src_len + 1),
-pad_value=self._hparams.eos_id,
-)
+src_seqs, pad_value=self._hparams.pad_id)
target_output, _ = tx.data.padded_batch(
-tgt_seqs, pad_length=(max_tgt_len + 1),
-pad_value=self._hparams.eos_id,
-)
+tgt_seqs, pad_value=self._hparams.pad_id)
# Add BOS token to the target inputs.
target_input = np.pad(
-target_output[:, :max_tgt_len], ((0, 0), (1, 0)),
-"constant", constant_values=self._hparams.bos_id,
-)
+target_output[:, :-1], ((0, 0), (1, 0)),
+mode="constant", constant_values=self._hparams.bos_id)
source, target_input, target_output = [
torch.from_numpy(x).to(device=self.device)
for x in [source, target_input, target_output]
@@ -112,5 +107,5 @@ def collate(self, examples: List[Example]) -> tx.data.Batch:
len(examples),
source=source,
target_input=target_input,
-target_output=target_output
+target_output=target_output,
)
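To see what the rewritten `collate` now produces, here is a minimal sketch with toy token ids, using plain numpy in place of `tx.data.padded_batch`; the `pad_id`/`bos_id`/`eos_id` values mirror the defaults above, and everything else is made up for illustration.

```python
import numpy as np

# Toy version of the fixed collate logic: append EOS, pad with a dedicated
# pad_id (not eos_id), then shift right and prepend BOS for the decoder input.
pad_id, bos_id, eos_id = 0, 1, 2
tgt_seqs = [[5, 6, 7], [8, 9]]                      # made-up token ids
tgt_seqs = [seq + [eos_id] for seq in tgt_seqs]     # add EOS tokens
max_len = max(len(seq) for seq in tgt_seqs)
target_output = np.array(
    [seq + [pad_id] * (max_len - len(seq)) for seq in tgt_seqs])
target_input = np.pad(target_output[:, :-1], ((0, 0), (1, 0)),
                      mode="constant", constant_values=bos_id)
print(target_output)    # [[5 6 7 2], [8 9 2 0]] -- padding is pad_id, EOS is eos_id
print(target_input)     # [[1 5 6 7], [1 8 9 2]] -- BOS prepended, last column dropped
```

Padding with `pad_id` keeps padded positions distinguishable from the real end-of-sentence token, which appears to be the padding problem addressed by #154.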
