-
Notifications
You must be signed in to change notification settings - Fork 964
/
test_keras.py
38 lines (30 loc) · 1.16 KB
/
test_keras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from __future__ import absolute_import
from __future__ import print_function
import os
import sys
import six
from conversion_imagenet import TestModels
def get_test_table():
    """Build the Keras model -> emitter test matrix handed to TestModels.

    Returns:
        None under Python 3 (no custom table is supplied; what TestModels
        does with None is decided by TestModels itself).
        Otherwise a dict ``{'keras': {model_name: [emitter, ...]}}``:
          * when the TEST_ONNX environment variable equals 'true'
            (case-insensitive), only the ONNX emitter is exercised for
            vgg16 and vgg19;
          * in every other case, vgg19 and inception_v3 are run through the
            caffe, cntk, coreml, keras, mxnet, pytorch and tensorflow
            emitters.
    """
    # Guard clause: no table at all on Python 3.
    if six.PY3:
        return None

    # An unset variable falls back to '' and therefore never matches.
    wants_onnx = os.environ.get('TEST_ONNX', '').lower() == 'true'
    if wants_onnx:
        return {
            'keras': {
                'vgg16': [TestModels.onnx_emit],
                'vgg19': [TestModels.onnx_emit],
                # NOTE: nasnet is intentionally disabled for ONNX.
                # 'nasnet': [TestModels.onnx_emit],
            },
        }

    def every_framework_emitter():
        # A fresh list per model, so no two entries share one list object.
        return [
            TestModels.caffe_emit,
            TestModels.cntk_emit,
            TestModels.coreml_emit,
            TestModels.keras_emit,
            TestModels.mxnet_emit,
            TestModels.pytorch_emit,
            TestModels.tensorflow_emit,
        ]

    return {
        'keras': {
            'vgg19': every_framework_emitter(),
            'inception_v3': every_framework_emitter(),
        },
    }
def test_keras():
    """Run the Keras parse/convert tests over the table from get_test_table()."""
    runner = TestModels(get_test_table())
    # Look up the driver before the parser, matching the original call order.
    run_framework = runner._test_function
    run_framework('keras', runner.keras_parse)
if __name__ == '__main__':
    # Allow running this module directly, outside a test runner.
    test_keras()