-
Notifications
You must be signed in to change notification settings - Fork 324
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
End-to-end testing with various model configurations (#85)
* End-to-end testing with various model configurations * Reorganize unit and integration tests * More reorganization, add system tests pytest coverage 57% -> 80%
- Loading branch information
Michael Denkowski
committed
Jul 26, 2017
1 parent
6c22eb2
commit 7da864e
Showing
34 changed files
with
398 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
[pytest] | ||
addopts = --cov sockeye test -v | ||
addopts = --cov sockeye test/unit test/integration -v |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You may not | ||
# use this file except in compliance with the License. A copy of the License | ||
# is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed on | ||
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
|
||
import os | ||
import random | ||
import sys | ||
from tempfile import TemporaryDirectory | ||
from typing import Optional, Tuple | ||
from unittest.mock import patch | ||
|
||
import mxnet as mx | ||
import numpy as np | ||
|
||
import sockeye.bleu | ||
import sockeye.constants as C | ||
import sockeye.train | ||
import sockeye.translate | ||
import sockeye.utils | ||
|
||
|
||
def gaussian_vector(shape, return_symbol=False):
    """
    Draw a tensor of samples from a standard normal distribution
    (diagonal covariance).

    :param shape: Shape of the tensor to generate.
    :param return_symbol: If True, return an mx.sym.Symbol; otherwise return
        a NumPy array.
    :return: A tensor of normally distributed random values.
    """
    if return_symbol:
        return mx.sym.random_normal(shape=shape)
    return np.random.normal(size=shape)
|
||
|
||
def integer_vector(shape, max_value, return_symbol=False):
    """
    Draw a tensor of random non-negative integers in [0, max_value].

    Note the values are produced by rounding uniform draws, so the NumPy
    result has float dtype even though all entries are integral.

    :param shape: Shape of the tensor to generate.
    :param max_value: Maximum integer value.
    :param return_symbol: If True, return an mx.sym.Symbol; otherwise return
        a NumPy array.
    :return: A tensor of integral random values.
    """
    if return_symbol:
        return mx.sym.round(mx.sym.random_uniform(shape=shape) * max_value)
    return np.round(np.random.uniform(size=shape) * max_value)
|
||
|
||
def uniform_vector(shape, min_value=0, max_value=1, return_symbol=False):
    """
    Draw a tensor of uniformly distributed random values.

    :param shape: Shape of the tensor to generate.
    :param min_value: Minimum possible value.
    :param max_value: Maximum possible value (exclusive).
    :param return_symbol: If True, return an mx.sym.Symbol; otherwise return
        a NumPy array.
    :return: A tensor of uniformly distributed random values.
    """
    if return_symbol:
        return mx.sym.random_uniform(low=min_value, high=max_value, shape=shape)
    return np.random.uniform(low=min_value, high=max_value, size=shape)
|
||
|
||
def generate_random_sentence(vocab_size, max_len):
    """
    Generate a random "sentence" as a list of integer word ids.

    :param vocab_size: Number of ordinary words in the "vocabulary". Due to
        the reserved special symbols (BOS, EOS, UNK) this is *not* the
        maximum possible id value.
    :param max_len: Maximum sentence length.
    :return: List of word ids of random length in [1, max_len].
    """
    num_words = random.randint(1, max_len)
    # Ids 0-2 are reserved for special symbols, so ordinary words occupy
    # the range [3, vocab_size + 2].
    sentence = []
    for _ in range(num_words):
        sentence.append(random.randint(3, vocab_size + 2))
    return sentence
|
||
|
||
# Alphabet for the synthetic digit-copy corpora.
_DIGITS = "0123456789"


def generate_digits_file(source_path: str,
                         target_path: str,
                         line_count: int = 100,
                         line_length: int = 9,
                         sort_target: bool = False):
    """
    Write a parallel corpus of random digit sequences.

    Each source line is a space-separated run of 1..line_length random digits.
    The corresponding target line holds the same digits, sorted ascending when
    sort_target is True (otherwise identical to the source line).

    :param source_path: Output path for the source side.
    :param target_path: Output path for the target side.
    :param line_count: Number of sentence pairs to write.
    :param line_length: Maximum digits per line.
    :param sort_target: If True, sort each target line's digits.
    """
    with open(source_path, "w") as source_out, open(target_path, "w") as target_out:
        for _ in range(line_count):
            length = random.randint(1, line_length)
            digits = [random.choice(_DIGITS) for _ in range(length)]
            source_out.write(" ".join(digits) + "\n")
            target_digits = sorted(digits) if sort_target else digits
            target_out.write(" ".join(target_digits) + "\n")
|
||
|
||
# Shared training CLI args; filled in by run_train_translate via str.format.
# ``{max_len}`` receives the caller's max_seq_len — the previous hard-coded
# "--max-seq-len 27,236" silently discarded the ``max_len`` keyword that
# run_train_translate passes to .format(), ignoring the parameter entirely.
_TRAIN_PARAMS_COMMON = "--use-cpu --max-seq-len {max_len} --source {train_source} --target {train_target}" \
                       " --validation-source {dev_source} --validation-target {dev_target} --output {model}"
|
||
|
||
_TRANSLATE_PARAMS_COMMON = "--use-cpu --models {model} --input {input} --output {output}" | ||
|
||
|
||
def run_train_translate(train_params: str,
                        translate_params: str,
                        train_source_path: str,
                        train_target_path: str,
                        dev_source_path: str,
                        dev_target_path: str,
                        max_seq_len: int = 10,
                        work_dir: Optional[str] = None) -> Tuple[float, float]:
    """
    Train a model and translate a dev set inside a temporary directory.
    Report the final training perplexity and the dev BLEU score.

    :param train_params: Extra command line args for sockeye.train.
    :param translate_params: Extra command line args for sockeye.translate.
    :param train_source_path: Path to the training source corpus.
    :param train_target_path: Path to the training target corpus.
    :param dev_source_path: Path to the dev source corpus (input to translation).
    :param dev_target_path: Path to the dev target corpus (BLEU reference).
    :param max_seq_len: Maximum sequence length, passed as ``max_len`` when
        formatting _TRAIN_PARAMS_COMMON. NOTE(review): verify the template
        actually contains a ``{max_len}`` placeholder — str.format silently
        drops unused keyword arguments, so a hard-coded value there would
        ignore this parameter.
    :param work_dir: Optional parent directory for the temporary working dir
        (the temp dir itself is removed when the context exits).
    :return: (perplexity, bleu)
    """
    with TemporaryDirectory(dir=work_dir, prefix="test_train_translate.") as work_dir:

        # Train the model by building an argv-style command line and invoking
        # sockeye.train.main() with sys.argv patched (no subprocess needed,
        # which keeps coverage measurement working).
        model_path = os.path.join(work_dir, "model")
        params = "{} {} {}".format(sockeye.train.__file__,
                                   _TRAIN_PARAMS_COMMON.format(train_source=train_source_path,
                                                               train_target=train_target_path,
                                                               dev_source=dev_source_path,
                                                               dev_target=dev_target_path,
                                                               model=model_path,
                                                               max_len=max_seq_len),
                                   train_params)
        with patch.object(sys, "argv", params.split()):
            sockeye.train.main()

        # Translate the dev source corpus with the freshly trained model,
        # again via a patched sys.argv.
        out_path = os.path.join(work_dir, "out.txt")
        params = "{} {} {}".format(sockeye.translate.__file__,
                                   _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                   input=dev_source_path,
                                                                   output=out_path),
                                   translate_params)
        with patch.object(sys, "argv", params.split()):
            sockeye.translate.main()

        # Perplexity of the last recorded checkpoint, read back from the
        # metrics file the trainer wrote into the model directory.
        checkpoints = sockeye.utils.read_metrics_points(path=os.path.join(model_path, C.METRICS_NAME),
                                                        model_path=model_path,
                                                        metric=C.PERPLEXITY)
        perplexity = checkpoints[-1][0]

        # BLEU of the translated dev output against the dev target reference.
        bleu = sockeye.bleu.corpus_bleu(open(out_path, "r").readlines(),
                                        open(dev_target_path, "r").readlines())

        return perplexity, bleu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You may not | ||
# use this file except in compliance with the License. A copy of the License | ||
# is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed on | ||
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You may not | ||
# use this file except in compliance with the License. A copy of the License | ||
# is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed on | ||
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
|
||
import os | ||
from tempfile import TemporaryDirectory | ||
|
||
import pytest | ||
|
||
from test.common import generate_digits_file, run_train_translate | ||
|
||
# Corpus sizes for the generated digit-copy task: small enough to keep the
# integration test fast, large enough for training to run a few updates.
_TRAIN_LINE_COUNT = 100
_DEV_LINE_COUNT = 10
_LINE_MAX_LENGTH = 9


# Each parametrize entry is a (train_params, translate_params) pair of CLI
# argument strings appended to the common args by run_train_translate.
@pytest.mark.parametrize("train_params, translate_params", [
    # "Vanilla" LSTM encoder-decoder with attention
    ("--encoder rnn --rnn-num-layers 1 --rnn-cell-type lstm --rnn-num-hidden 16 --num-embed 8 --attention-type mlp"
     " --attention-num-hidden 16 --batch-size 8 --loss cross-entropy --optimized-metric perplexity --max-updates 10"
     " --checkpoint-frequency 10 --optimizer adam --initial-learning-rate 0.01",
     "--beam-size 2"),
    # "Kitchen sink" LSTM encoder-decoder with attention
    ("--encoder rnn --rnn-num-layers 4 --rnn-cell-type lstm --rnn-num-hidden 16 --rnn-residual-connections"
     " --num-embed 16 --attention-type coverage --attention-num-hidden 16 --weight-tying --attention-use-prev-word"
     " --context-gating --layer-normalization --batch-size 8 --loss smoothed-cross-entropy"
     " --smoothed-cross-entropy-alpha 0.1 --normalize-loss --optimized-metric perplexity --max-updates 10"
     " --checkpoint-frequency 10 --dropout 0.1 --optimizer adam --initial-learning-rate 0.01",
     "--beam-size 2"),
    # Convolutional embedding encoder + LSTM encoder-decoder with attention
    ("--encoder rnn-with-conv-embed --conv-embed-max-filter-width 3 --conv-embed-num-filters 4 4 8"
     " --conv-embed-pool-stride 2 --conv-embed-num-highway-layers 1 --rnn-num-layers 1 --rnn-cell-type lstm"
     " --rnn-num-hidden 16 --num-embed 8 --attention-num-hidden 16 --batch-size 8 --loss cross-entropy"
     " --optimized-metric perplexity --max-updates 10 --checkpoint-frequency 10 --optimizer adam"
     " --initial-learning-rate 0.01",
     "--beam-size 2"),
])
def test_seq_copy(train_params, translate_params):
    """
    Task: copy short sequences of digits.

    Generates small synthetic train/dev corpora, then runs a full
    train + translate round trip for each model configuration.
    """
    with TemporaryDirectory(prefix="test_seq_copy") as work_dir:
        # Simple digits files for train/dev data
        train_source_path = os.path.join(work_dir, "train.src")
        train_target_path = os.path.join(work_dir, "train.tgt")
        dev_source_path = os.path.join(work_dir, "dev.src")
        dev_target_path = os.path.join(work_dir, "dev.tgt")
        generate_digits_file(train_source_path, train_target_path, _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH)
        generate_digits_file(dev_source_path, dev_target_path, _DEV_LINE_COUNT, _LINE_MAX_LENGTH)
        # Test model configuration
        # Ignore return values (perplexity and BLEU) for integration test
        run_train_translate(train_params,
                            translate_params,
                            train_source_path,
                            train_target_path,
                            dev_source_path,
                            dev_target_path,
                            max_seq_len=_LINE_MAX_LENGTH + 1,
                            work_dir=work_dir)
Oops, something went wrong.