Skip to content

Commit

Permalink
[ONNX] Export get/set attribute nodes (pytorch#50768)
Browse files Browse the repository at this point in the history
Fix get/set attributes when getting/setting a model parameter.
This PR also fixes inplace ops in If blocks.

ghstack-source-id: 6ed01725c9ea350544853498ddef7c0d8641ccae
Pull Request resolved: pytorch#51517
  • Loading branch information
BowenBao committed Feb 3, 2021
1 parent d7dd12a commit 287482e
Show file tree
Hide file tree
Showing 9 changed files with 835 additions and 158 deletions.
8 changes: 6 additions & 2 deletions scripts/onnx/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pytest "${args[@]}" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_caffe2.py" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_caffe2_quantized.py" \
"${test_paths[@]}"

# onnxruntime only support py3
Expand All @@ -68,10 +69,12 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test1* ]]; then
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset7" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset8" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_IRv4_old_jit_API" \
"$top_dir/test/onnx/test_custom_ops.py" \
"$top_dir/test/onnx/test_models_onnxruntime.py" \
"$top_dir/test/onnx/test_utility_funs.py" \
"$top_dir/test/onnx/test_pytorch_onnx_caffe2.py"
"$top_dir/test/onnx/test_pytorch_onnx_caffe2.py" \
"$top_dir/test/onnx/test_pytorch_onnx_caffe2_quantized.py"
fi
if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
# Update the loop for new opsets
Expand All @@ -80,7 +83,8 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
done
pytest "${args[@]}" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference"
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_IRv4_old_jit_API"
fi

# Our CI expects both coverage.xml and .coverage to be within test/
Expand Down
12 changes: 12 additions & 0 deletions test/onnx/test_pytorch_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ def wrapper(self):
return wrapper
return script_dec


# Disable tests for old jit passes.
def disableOldJitPassesTest():
def script_dec(func):
def wrapper(self):
if not self.use_new_jit_passes:
raise unittest.SkipTest("Skip test for old jit API.")
return func(self)
return wrapper
return script_dec


# skips tests for opset_versions listed in unsupported_opset_versions.
# if the caffe2 test cannot be run for a specific version, add this wrapper
# (for example, an op was modified but the change is not supported in caffe2)
Expand Down

0 comments on commit 287482e

Please sign in to comment.