Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for GridSample operator #2909

Merged
merged 55 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
75f9589
Basic GridSample parsing
gyulaz-htec Mar 12, 2024
a87d759
Fix where operator compute
gyulaz-htec Mar 18, 2024
2658929
Fix attribute parsing
gyulaz-htec Mar 18, 2024
a977331
Refactoring
gyulaz-htec Mar 18, 2024
0d26f8c
python formatting
gyulaz-htec Mar 18, 2024
bdaf8f2
GridSample verify tests
gyulaz-htec Mar 19, 2024
d4eee76
nearest_sampler
gyulaz-htec Mar 19, 2024
f6e2d20
grid_sampler
gyulaz-htec Mar 19, 2024
2bce181
linear_sampler
gyulaz-htec Mar 20, 2024
34a785d
Clean-up
gyulaz-htec Mar 20, 2024
6f282b3
GridSample border padding test
gyulaz-htec Mar 21, 2024
3d2e3a6
Fix border padding test
gyulaz-htec Mar 21, 2024
63e9d1e
Tidy fixes
gyulaz-htec Mar 22, 2024
d47721e
More GridSample tests
gyulaz-htec Mar 22, 2024
174a7c3
Fix license
gyulaz-htec Mar 22, 2024
d5c20be
GridSample tests for wrong inputs
gyulaz-htec Mar 22, 2024
466e858
Remove not needed throw
gyulaz-htec Mar 25, 2024
e70ba64
Fix cppcheck
gyulaz-htec Mar 25, 2024
ac00e20
Add reflection padding
gyulaz-htec Mar 26, 2024
65ea085
Fix cppcheck
gyulaz-htec Mar 27, 2024
b9f3160
GridSample handle other output types other than float32
gyulaz-htec Mar 27, 2024
bc5f546
Fix wrong shape type in GridSample parsing
gyulaz-htec Mar 28, 2024
a70f78e
Fix float to int conversion in reflect_coordinates
gyulaz-htec Mar 28, 2024
d6652cc
Exclude logical_and and where from pointwise fusion
gyulaz-htec Mar 28, 2024
bf35db4
Fix fuse_pointwise
gyulaz-htec Mar 28, 2024
cd390b9
Test disable pointwise fusion
gyulaz-htec Apr 8, 2024
c4e87fe
Fix codecov
gyulaz-htec Apr 10, 2024
d9f4745
Address review changes
gyulaz-htec Apr 11, 2024
8f7585b
Fix license
gyulaz-htec Apr 11, 2024
74ed9f6
Update onnx operator list
gyulaz-htec Apr 12, 2024
a4b6492
Eliminate gathernd from loops
gyulaz-htec Apr 15, 2024
66b7cf9
Merge branch 'develop' into grid_sample
gyulaz-htec Apr 17, 2024
cf79e31
Merge branch 'develop' into grid_sample
gyulaz-htec Apr 18, 2024
713d852
Coding style fixes
gyulaz-htec Apr 19, 2024
b8fbc94
Merge branch 'develop' into grid_sample
gyulaz-htec Apr 19, 2024
6ec8bb7
Merge branch 'develop' into grid_sample
gyulaz-htec Apr 22, 2024
8f35c36
Merge branch 'develop' into grid_sample
gyulaz-htec Apr 23, 2024
b867046
Merge branch 'develop' into grid_sample
gyulaz-htec May 6, 2024
33667b5
Revert pointwise fusion changes made for GridSample
gyulaz-htec May 7, 2024
244e1f3
Merge branch 'develop' into grid_sample
gyulaz-htec May 8, 2024
44b895d
Merge branch 'develop' into grid_sample
gyulaz-htec May 10, 2024
cfffabc
Merge branch 'develop' into grid_sample
gyulaz-htec May 13, 2024
d55c08b
Merge branch 'develop' into grid_sample
gyulaz-htec May 15, 2024
b9fef94
Merge branch 'develop' into grid_sample
gyulaz-htec May 17, 2024
7c4859e
Merge branch 'develop' into grid_sample
gyulaz-htec May 20, 2024
26e887e
Merge branch 'develop' into grid_sample
gyulaz-htec May 21, 2024
2202dd9
Fix gridsample test onnx loading
gyulaz-htec May 21, 2024
747c4e5
Merge branch 'develop' into grid_sample
gyulaz-htec May 21, 2024
a1cf519
Fix windows build issue
gyulaz-htec May 23, 2024
8bb3d24
Merge branch 'develop' into grid_sample
gyulaz-htec May 24, 2024
cd82d55
Merge branch 'develop' into grid_sample
gyulaz-htec May 28, 2024
feffd2f
Merge branch 'develop' into grid_sample
gyulaz-htec May 29, 2024
3652b30
Merge branch 'develop' into grid_sample
gyulaz-htec Jun 5, 2024
ec18285
Merge branch 'develop' into grid_sample
gyulaz-htec Jun 6, 2024
b6172d3
Merge branch 'develop' into grid_sample
TedThemistokleous Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/dev/onnx_operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,10 @@ Operator Support Matrix
| | | FP8, FP16, | |
| | | FP32, FP64 | |
+--------------------------+-----------+-----------------+------------------------------+
| GridSample | 👷 | 👷 | |
| GridSample | ✅ | UINT32, UINT64, | ``bicubic`` |
| | | INT32, INT64, | mode not supported, |
| | | FP16, FP32 | `5-D inputs` |
| | | FP64 | not supported |
+--------------------------+-----------+-----------------+------------------------------+
| GroupNormalization | ✅ | FP8, FP16, | ``stash_type`` |
| | | FP32, FP64 | not supported |
Expand Down
448 changes: 448 additions & 0 deletions src/onnx/parse_gridsample.cpp

Large diffs are not rendered by default.

317 changes: 317 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3832,6 +3832,323 @@ def greaterorequal_test():
return ([node], [x1, x2], [y])


@onnx_test()
def gridsample_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 4, 4])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 6, 6, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 6])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="linear",
padding_mode="zeros",
align_corners=0,
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_half_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [1, 1, 4, 4])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 6, 6, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 1, 6, 6])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="linear",
padding_mode="zeros",
align_corners=0,
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_int_test():
x = helper.make_tensor_value_info('x', TensorProto.INT32, [1, 1, 4, 4])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 6, 6, 2])
y = helper.make_tensor_value_info('y', TensorProto.INT32, [1, 1, 6, 6])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="linear",
padding_mode="zeros",
align_corners=0,
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_simple_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 2, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 1, 1, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 1, 1])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="linear",
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_bilinear_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
align_corners=0,
mode="linear",
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_aligncorners_true_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="linear",
align_corners=1,
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_nearest_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="nearest",
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_bicubic_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="cubic",
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_nearest_align_corners_0_additional_1_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="nearest",
align_corners=0,
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_nearest_align_corners_1_additional_1_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="nearest",
align_corners=1,
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_bilinear_align_corners_0_additional_1_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="linear",
align_corners=0,
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_bilinear_align_corners_1_additional_1_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="linear",
align_corners=1,
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_zeros_padding_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
padding_mode="zeros",
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_border_padding_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
padding_mode="border",
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_reflection_padding_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
padding_mode="reflection",
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_volumetric_nearest_align_corners_0_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4, 2])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
mode="nearest",
align_corners=0,
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_wrong_grid_type_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.INT32,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
)

return ([node], [x, grid], [y])


@onnx_test()
def gridsample_mismatching_dims_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 2, 2])
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
[1, 2, 4, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 4])

node = onnx.helper.make_node(
"GridSample",
inputs=["x", "grid"],
outputs=["y"],
)

return ([node], [x, grid], [y])


@onnx_test()
def group_conv_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 4, 16, 16])
Expand Down
25 changes: 25 additions & 0 deletions test/onnx/gridsample_aligncorners_true_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
 !gridsample_aligncorners_true_test:À
A
x
gridy"
GridSample*
align_corners *
mode"linear !gridsample_aligncorners_true_testZ
x




Z
grid




b
y




B
Expand Down
24 changes: 24 additions & 0 deletions test/onnx/gridsample_bicubic_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
 gridsample_bicubic_test:Ÿ
*
x
gridy"
GridSample*
mode"cubic gridsample_bicubic_testZ
x




Z
grid




b
y




B
Binary file not shown.
Loading
Loading