Skip to content

Commit

Permalink
Update onnx_model_test with tests on cntk pretrained models
Browse files Browse the repository at this point in the history
  • Loading branch information
BowenBao committed Sep 19, 2018
1 parent 6f09c39 commit 0a3eb3b
Showing 1 changed file with 48 additions and 13 deletions.
61 changes: 48 additions & 13 deletions bindings/python/cntk/tests/onnx_model_test.py
Expand Up @@ -15,7 +15,7 @@
# To test models locally, create folder 'onnx_models' and put in model folders.
# For example.
# .
# +-- onnx_models
# +-- onnx_models # models stored in 'model.onnx' onnx format.
# | +-- model1
# | | +-- model.onnx
# | | +-- test_data_set_0
Expand All @@ -28,9 +28,18 @@
# | | | +-- output_0.pb
# | +-- model2
# ...
base_dir = 'onnx_models' if not 'CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY' in os.environ else os.path.join(os.environ['CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY'], 'onnx_models')
model_names = [dir for dir in os.listdir(base_dir)
if os.path.isdir(os.path.join(base_dir, dir))] if os.path.exists(base_dir) else []
# +-- PretrainedModelsV2 # models stored in '.model' CNTKv2 format.
# | +-- model1.model
# | +-- model2.model
# ...
def get_base_dir(base_dir):
return base_dir if not 'CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY' in os.environ else os.path.join(os.environ['CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY'], base_dir)
onnx_base_dir = get_base_dir('onnx_models')
onnx_model_names = [dir for dir in os.listdir(onnx_base_dir)
if os.path.isdir(os.path.join(onnx_base_dir, dir)) and os.path.exists(os.path.join(onnx_base_dir, dir, 'model.onnx'))] if os.path.exists(onnx_base_dir) else []
cntk_base_dir = get_base_dir('PretrainedModelsV2')
cntk_model_names = [dir for dir in os.listdir(cntk_base_dir)
if os.path.isfile(os.path.join(cntk_base_dir, dir)) and dir.rfind('.model') + len('.model') == len(dir)] if os.path.exists(cntk_base_dir) else []
input_filename_pattern = re.compile('input_[0-9]+.pb')
output_filename_pattern = re.compile('output_[0-9]+.pb')

Expand All @@ -46,8 +55,6 @@
'test_lstm_defaults',
'test_lstm_with_initial_bias',
'test_lstm_with_peepholes',
'test_max_example',
'test_min_example',
'test_reduce_log_sum',
'test_reduce_log_sum_asc_axes',
'test_reduce_log_sum_default',
Expand Down Expand Up @@ -95,8 +102,6 @@
'test_lstm_defaults',
'test_lstm_with_initial_bias',
'test_lstm_with_peepholes',
'test_max_example',
'test_min_example',
'test_reduce_log_sum',
'test_reduce_log_sum_asc_axes',
'test_reduce_log_sum_default',
Expand Down Expand Up @@ -134,16 +139,18 @@
'test_upsample_nearest',
]

@pytest.mark.parametrize('model_name, round_trip',
[(model_name, round_trip) for model_name in model_names for round_trip in [False, True]],
ids=['round_trip_' + model_name if round_trip else model_name for model_name in model_names for round_trip in [False, True]])
skip_cntk_model_names = []

@pytest.mark.parametrize('model_name, round_trip',
[(model_name, round_trip) for model_name in onnx_model_names for round_trip in [False, True]],
ids=['round_trip_' + model_name if round_trip else model_name for model_name in onnx_model_names for round_trip in [False, True]])
def test_onnx_model(model_name, round_trip):
if model_name in skip_model_names and not round_trip:
pytest.skip('Skip onnx model test. ')
if model_name in skip_round_trip_model_names and round_trip:
pytest.skip('Skip onnx model round trip test. ')

model_dir = os.path.join(base_dir, model_name)
model_dir = os.path.join(onnx_base_dir, model_name)
model = C.Function.load(os.path.join(model_dir, 'model.onnx'), format=C.ModelFormat.ONNX)

if round_trip:
Expand Down Expand Up @@ -190,4 +197,32 @@ def test_onnx_model(model_name, round_trip):
ref_outputs[i],
outputs[i],
rtol=1e-3,
atol=1e-4)
atol=1e-4)

@pytest.mark.parametrize('model_name',
[model_name for model_name in cntk_model_names],
ids=[model_name for model_name in cntk_model_names])
def test_cntk_model(model_name):
if model_name in skip_cntk_model_names:
pytest.skip('Skip cntk model test. ')
model_dir = os.path.join(cntk_base_dir, model_name)
model = C.Function.load(model_dir, format=C.ModelFormat.CNTKv2)

resave_model_path = 'model_resave.onnx'
model.save(resave_model_path, format=C.ModelFormat.ONNX)
reloaded_model = C.Function.load(resave_model_path, format=C.ModelFormat.ONNX)

np.random.seed(3)
input_shape = (1,) + model.arguments[0].shape
data_x = np.asarray(np.random.uniform(-1, 1, input_shape), dtype=np.float32)
data_y = model.eval({model.arguments[0]:data_x})
data_y_ = reloaded_model.eval({reloaded_model.arguments[0]:data_x})

np.testing.assert_equal(len(data_y), len(data_y_))
for i in range(len(data_y)):
np.testing.assert_equal(data_y[i].dtype, data_y_[i].dtype)
np.testing.assert_allclose(
data_y[i],
data_y_[i],
rtol=1e-3,
atol=1e-4)

0 comments on commit 0a3eb3b

Please sign in to comment.