[Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(2) #18403
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
```diff
@@ -1103,8 +1103,8 @@ def main(
                 return gv

     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softmax(), example_args, {}, expected1)
-    verify_model(Softmax2(), example_args, {}, expected1)
+    verify_model(Softmax(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(Softmax2(), example_args, {}, expected1, run_ep_decomposition=True)


 def test_softsign():
```
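Throughout the diff, the tests gain a new `run_ep_decomposition=True` argument to `verify_model`. Judging by the name, the flag runs the exported program through ATen decompositions before converting it to Relax, so the expected IR in the changed tests is the decomposed form. A minimal sketch of that step in plain `torch.export` terms (how `verify_model` wires this up internally is an assumption here):

```python
import torch

class Softmax(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.softmax(x, dim=1)

example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

# Export the module, then lower composite ATen ops (softmax, softshrink,
# tril, ...) into their decomposed forms using the default decomp table.
ep = torch.export.export(Softmax(), example_args)
ep = ep.run_decompositions()
print(ep.graph)  # the decomposed aten-level graph
```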
```diff
@@ -1135,8 +1135,8 @@ def main(
                 return gv

     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softsign(), example_args, {}, expected_softsign)
-    verify_model(Softsign2(), example_args, {}, expected_softsign)
+    verify_model(Softsign(), example_args, {}, expected_softsign, run_ep_decomposition=True)
+    verify_model(Softsign2(), example_args, {}, expected_softsign, run_ep_decomposition=True)


 def test_softshrink():
```
```diff
@@ -1159,32 +1159,24 @@ def main(
             input: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
-                    input, R.const(0.5, "float32")
-                )
-                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
-                    input, R.const(0.5, "float32")
-                )
-                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
-                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)
-
-                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
-                    input, R.const(0.5, "float32")
-                )
-                lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
-                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
-                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
-                lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)
-
-                lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lv, R.const(0.5, "float32"))
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sign(input)
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    lv2, R.const(0.5, "float32")
+                )
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(input, lv3)
+                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    input, R.const(0.0, "float32")
+                )
+                lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv1, lv4, lv5)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
                 R.output(gv)
                 return gv

     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softshrink(), example_args, {}, expected_softshrink)
-    verify_model(Softshrink2(), example_args, {}, expected_softshrink)
+    verify_model(Softshrink(), example_args, {}, expected_softshrink, run_ep_decomposition=True)
+    verify_model(Softshrink2(), example_args, {}, expected_softshrink, run_ep_decomposition=True)


 def test_tril_triu():
```
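The new expected IR is the elementwise definition of softshrink with λ = 0.5: `where(|x| > λ, x − sign(x)·λ, 0)`. A quick NumPy check of that identity against PyTorch (a sanity sketch, not part of the test file):

```python
import numpy as np
import torch

x = np.random.randn(1, 3, 10, 10).astype("float32")
lambd = 0.5

# Mirrors the decomposed IR above: abs/greater, sign/multiply, then where.
decomposed = np.where(np.abs(x) > lambd, x - np.sign(x) * lambd, x * 0.0)
reference = torch.nn.functional.softshrink(torch.from_numpy(x), lambd).numpy()
np.testing.assert_allclose(decomposed, reference, rtol=1e-6)
```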
```diff
@@ -1198,16 +1190,27 @@ def forward(self, input):
     class expected_tril:
         @R.function
         def main(
-            input_1: R.Tensor((10, 10), dtype="float32")
+            input: R.Tensor((10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                lv: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2])
+                lv2: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1])
+                lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3)
+                lv5: R.Tensor((10, 10), dtype="bool") = R.less_equal(lv4, R.const(1, "int64"))
+                lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,)
                 R.output(gv)
                 return gv

-    verify_model(Tril(), example_args, {}, expected_tril)
+    verify_model(Tril(), example_args, {}, expected_tril, run_ep_decomposition=True)

     class Triu(Module):
         def forward(self, input):
```

Comment on lines +1201 to +1204 (Contributor): the second `R.arange` recomputes the same index vector, so `lv3` could reuse `lv` directly:

```python
lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv, axis=[-1])
```
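The decomposition replaces `R.tril` with an explicit index mask: with row index `i` and column index `j`, `tril(x, k)` keeps the entries where `j - i <= k`. A NumPy sketch of the same construction, checked against `torch.tril` (illustration only, not part of the test file):

```python
import numpy as np
import torch

x = np.random.randn(10, 10).astype("float32")
k = 1  # diagonal offset used by the test

j = np.arange(10)[None, :]  # column index, shape (1, 10) -- lv1 above
i = np.arange(10)[:, None]  # row index, shape (10, 1)    -- lv3 above
out = np.where(j - i <= k, x, 0.0)  # keep entries on/below the k-th diagonal

np.testing.assert_allclose(out, torch.tril(torch.from_numpy(x), k).numpy())
# triu is the mirror image: keep entries where j - i >= k.
```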
```diff
@@ -1217,16 +1220,27 @@ def forward(self, input):
     class expected_triu:
         @R.function
         def main(
-            input_1: R.Tensor((10, 10), dtype="float32")
+            input: R.Tensor((10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                lv: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2])
+                lv2: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1])
+                lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3)
+                lv5: R.Tensor((10, 10), dtype="bool") = R.greater_equal(lv4, R.const(1, "int64"))
+                lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,)
                 R.output(gv)
                 return gv

-    verify_model(Triu(), example_args, {}, expected_triu)
+    verify_model(Triu(), example_args, {}, expected_triu, run_ep_decomposition=True)


 operator_binary_1 = [
```

Comment on lines +1231 to +1234 (Contributor).
```diff
@@ -1501,7 +1515,7 @@ def main(
         torch.randn(64, 64, dtype=torch.float32),
         torch.randn(64, dtype=torch.float32),
     )
-    verify_model(DivModel(), example_args, {}, expected_div)
+    verify_model(DivModel(), example_args, {}, expected_div, run_ep_decomposition=True)

     # Case 2: Division with trunc rounding
     class DivTruncModel(torch.nn.Module):
```
```diff
@@ -1521,7 +1535,7 @@ def main(
             R.output(gv)
             return gv

-    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc)
+    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc, run_ep_decomposition=True)

     # Case 3: Division with floor rounding
     class DivFloorModel(torch.nn.Module):
```
```diff
@@ -1540,7 +1554,7 @@ def main(
             R.output(gv)
             return gv

-    verify_model(DivFloorModel(), example_args, {}, expected_div_floor)
+    verify_model(DivFloorModel(), example_args, {}, expected_div_floor, run_ep_decomposition=True)


 def test_batchnorm2d():
```
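The three division cases here correspond to `torch.div`'s `rounding_mode` argument; the modes only differ in how negative quotients are rounded. A short standalone reference:

```python
import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])

print(torch.div(a, b))                         # [ 3.5, -3.5]  true division
print(torch.div(a, b, rounding_mode="trunc"))  # [ 3., -3.]    round toward zero
print(torch.div(a, b, rounding_mode="floor"))  # [ 3., -4.]    round toward -inf
```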
```diff
@@ -1578,6 +1592,8 @@ def main(
                     epsilon=1e-05,
                     center=True,
                     scale=True,
+                    momentum=1e-05,
+                    training=False,
                 )
                 lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
                 gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
```
```diff
@@ -1593,7 +1609,7 @@ def main(
         "w3": model.bn.running_mean.detach().numpy(),
         "w4": model.bn.running_var.detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)


 def test_adaptive_avgpool1d():
```
```diff
@@ -1748,8 +1764,8 @@ def main(
         torch.randn(10, 10, dtype=torch.float32),
     )

-    verify_model(Addmm1(), example_args, {}, expected1)
-    verify_model(Addmm2(), example_args, {}, expected2)
+    verify_model(Addmm1(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(Addmm2(), example_args, {}, expected2, run_ep_decomposition=True)


 def test_avg_pool1d():
```
```diff
@@ -2054,8 +2070,10 @@ def main(
             inp_2: R.Tensor((4, 256, 512), dtype="float32"),
         ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
-                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0)
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    inp_1, inp_2, out_dtype="float32"
+                )
+                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(inp_0, lv)
                 gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,)
                 R.output(gv)
                 return gv
```
```diff
@@ -2076,7 +2094,9 @@ def main(
             inp_2: R.Tensor((4, 256, 512), dtype="float32"),
         ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    inp_1, inp_2, out_dtype="float32"
+                )
                 lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                     lv, R.const(2, "float32")
                 )
```
```diff
@@ -2100,14 +2120,16 @@ def main(
             inp_2: R.Tensor((4, 256, 512), dtype="float32"),
         ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    inp_1, inp_2, out_dtype="float32"
+                )
                 lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                     lv, R.const(2, "float32")
                 )
                 lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                     inp_0, R.const(3, "float32")
                 )
-                lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2)
+                lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv2, lv1)
                 gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,)
                 R.output(gv)
                 return gv
```
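Expected3 is `baddbmm`'s definition, `beta * input + alpha * bmm(b1, b2)` with `alpha=2` and `beta=3`; the decomposition emits the `beta` term first, which is why the `R.add` operand order flips. A quick PyTorch check of the identity:

```python
import torch

inp = torch.randn(4, 128, 512)
b1 = torch.randn(4, 128, 256)
b2 = torch.randn(4, 256, 512)

reference = torch.baddbmm(inp, b1, b2, alpha=2, beta=3)
decomposed = 3 * inp + 2 * torch.bmm(b1, b2)  # beta * input + alpha * (b1 @ b2)
torch.testing.assert_close(reference, decomposed)
```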
```diff
@@ -2122,20 +2144,23 @@ def main(
         example_args,
         {},
         Expected1,
+        run_ep_decomposition=True,
     )

     verify_model(
         BAddBMM2(),
         example_args,
         {},
         Expected2,
+        run_ep_decomposition=True,
     )

     verify_model(
         BAddBMM3(),
         example_args,
         {},
         Expected3,
+        run_ep_decomposition=True,
     )
```
```diff
@@ -2172,6 +2197,7 @@ def main(
         example_args,
         {},
         Expected,
+        run_ep_decomposition=True,
     )
```
Review comment: Using `R.zeros_like(input)` would be more idiomatic and clearer to express the intent of creating a zero tensor with the same shape and dtype as `input`.
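Applied to the softshrink IR above (assuming the comment targets the `lv5 = R.multiply(input, R.const(0.0, "float32"))` binding that builds the zero branch of the `R.where`), the suggestion would read roughly:

```python
# Hypothetical rewrite of the zero-filled "else" branch; R.zeros_like
# produces a tensor of zeros with input's shape and dtype directly.
lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros_like(input)
lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv1, lv4, lv5)
```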