Skip to content

Commit

Permalink
[StableHLO][NFC] Add e2e scatter tests (iree-org#13781)
Browse files Browse the repository at this point in the history
These are based on the equivalent IREE tests for MHLO.

Issue: iree-org#12678
  • Loading branch information
kuhar authored and NatashaKnk committed Jul 6, 2023
1 parent 15583ef commit 540d05e
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/e2e/stablehlo_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir",
"round.mlir",
"rsqrt.mlir",
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"sine.mlir",
Expand Down Expand Up @@ -135,6 +136,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir",
"round.mlir",
"rsqrt.mlir",
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"sine.mlir",
Expand Down Expand Up @@ -202,6 +204,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir",
"round.mlir",
"rsqrt.mlir",
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"sine.mlir",
Expand Down Expand Up @@ -273,6 +276,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir",
"round.mlir",
"rsqrt.mlir",
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"sine.mlir",
Expand Down Expand Up @@ -360,6 +364,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir",
"round.mlir",
"rsqrt.mlir",
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"sine.mlir",
Expand Down Expand Up @@ -438,6 +443,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir",
"round.mlir",
"rsqrt.mlir",
"scatter.mlir",
"scatter_dynamic.mlir",
"select.mlir",
"sine.mlir",
Expand Down
6 changes: 6 additions & 0 deletions tests/e2e/stablehlo_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir"
"round.mlir"
"rsqrt.mlir"
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"sine.mlir"
Expand Down Expand Up @@ -122,6 +123,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir"
"round.mlir"
"rsqrt.mlir"
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"sine.mlir"
Expand Down Expand Up @@ -187,6 +189,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir"
"round.mlir"
"rsqrt.mlir"
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"sine.mlir"
Expand Down Expand Up @@ -253,6 +256,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir"
"round.mlir"
"rsqrt.mlir"
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"sine.mlir"
Expand Down Expand Up @@ -323,6 +327,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir"
"round.mlir"
"rsqrt.mlir"
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"sine.mlir"
Expand Down Expand Up @@ -397,6 +402,7 @@ iree_check_single_backend_test_suite(
"rng_uniform.mlir"
"round.mlir"
"rsqrt.mlir"
"scatter.mlir"
"scatter_dynamic.mlir"
"select.mlir"
"sine.mlir"
Expand Down
238 changes: 238 additions & 0 deletions tests/e2e/stablehlo_ops/scatter.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
func.func @scatter_update_scalar_1D() {
%arg0 = util.unfoldable_constant dense<0> : tensor<8xi32>
%arg1 = util.unfoldable_constant dense<[[1], [3], [4], [7]]> : tensor<4x1xi32>
%arg2 = util.unfoldable_constant dense<[9, 10, 11, 12]> : tensor<4xi32>
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
"stablehlo.return"(%arg4) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0],
index_vector_dim = 1,
>,
unique_indices = true
} : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> tensor<8xi32>
check.expect_eq_const(%0, dense<[0, 9, 0, 10, 11, 0, 0, 12]> : tensor<8xi32>) : tensor<8xi32>
return
}

func.func @scatter_repeated_update_scalar_1D() {
%arg0 = util.unfoldable_constant dense<0> : tensor<8xi32>
%arg1 = util.unfoldable_constant dense<[[1], [1], [7], [7]]> : tensor<4x1xi32>
%arg2 = util.unfoldable_constant dense<[9, 10, 11, 12]> : tensor<4xi32>
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
"stablehlo.return"(%arg4) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0],
index_vector_dim = 1,
>,
unique_indices = false
} : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> tensor<8xi32>
check.expect_eq_const(%0, dense<[0, 10, 0, 0, 0, 0, 0, 12]> : tensor<8xi32>) : tensor<8xi32>
return
}

func.func @scatter_update_scalar_2D() {
%arg0 = util.unfoldable_constant dense<0> : tensor<4x3xi32>
%arg1 = util.unfoldable_constant dense<[[0, 0], [1, 1], [2, 2]]> : tensor<3x2xi32>
%arg2 = util.unfoldable_constant dense<[1, 2, 3]> : tensor<3xi32>
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
"stablehlo.return"(%arg4) : (tensor<i32>) -> ()
}) {indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
inserted_window_dims = [0, 1],
scatter_dims_to_operand_dims = [0, 1],
index_vector_dim = 1
>,
unique_indices = true
} : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> tensor<4x3xi32>
check.expect_eq_const(%0, dense<[[1, 0, 0],
[0, 2, 0],
[0, 0, 3],
[0, 0, 0]]> : tensor<4x3xi32>) : tensor<4x3xi32>
return
}

func.func @scatter_update_slice_2D() {
%arg0 = util.unfoldable_constant dense<0> : tensor<6x3xi32>
%arg1 = util.unfoldable_constant dense<[[2], [4]]> : tensor<2x1xi32>
%arg2 = util.unfoldable_constant dense<[[1, 2, 3],
[4, 5, 6]]> : tensor<2x3xi32>
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
"stablehlo.return"(%arg4) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [1],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0],
index_vector_dim = 1,
>,
unique_indices = true
} : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
check.expect_eq_const(%0, dense<[[0, 0, 0],
[0, 0, 0],
[1, 2, 3],
[0, 0, 0],
[4, 5, 6],
[0, 0, 0]]> : tensor<6x3xi32>) : tensor<6x3xi32>
return
}

func.func @scatter_update_slice_partial_2D() {
%arg0 = util.unfoldable_constant dense<0> : tensor<6x3xi32>
%arg1 = util.unfoldable_constant dense<[[2], [4]]> : tensor<2x1xi32>
%arg2 = util.unfoldable_constant dense<[[1, 2],
[4, 5]]> : tensor<2x2xi32>
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
"stablehlo.return"(%arg4) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [1],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0],
index_vector_dim = 1,
>,
unique_indices = true
} : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x2xi32>) -> tensor<6x3xi32>
check.expect_eq_const(%0, dense<[[0, 0, 0],
[0, 0, 0],
[1, 2, 0],
[0, 0, 0],
[4, 5, 0],
[0, 0, 0]]> : tensor<6x3xi32>) : tensor<6x3xi32>
return
}

func.func @scatter_add_slice_2D() {
%arg0 = util.unfoldable_constant dense<1> : tensor<6x3xi32>
%arg1 = util.unfoldable_constant dense<[[2], [4]]> : tensor<2x1xi32>
%arg2 = util.unfoldable_constant dense<[[1, 2, 3],
[4, 5, 6]]> : tensor<2x3xi32>
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
%1 = stablehlo.add %arg3, %arg4 : tensor<i32>
"stablehlo.return"(%1) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [1],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0],
index_vector_dim = 1,
>,
unique_indices = true
} : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
check.expect_eq_const(%0, dense<[[1, 1, 1],
[1, 1, 1],
[2, 3, 4],
[1, 1, 1],
[5, 6, 7],
[1, 1, 1]]> : tensor<6x3xi32>) : tensor<6x3xi32>
return
}

func.func @scatter_1D_large() {
%original = util.unfoldable_constant dense<1> : tensor<1400xi32>
%update = util.unfoldable_constant dense<2> : tensor<1400xi32>
%init = tensor.empty() : tensor<1400xi32>
%indices = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
outs(%init : tensor<1400xi32>) {
^bb0(%arg0: i32):
%0 = linalg.index 0 : index
%1 = arith.index_cast %0 : index to i32
linalg.yield %1 : i32
} -> tensor<1400xi32>
%indices_reshaped = tensor.expand_shape %indices [[0, 1]] :
tensor<1400xi32> into tensor<1400x1xi32>
%result = "stablehlo.scatter"(%original, %indices_reshaped, %update)({
^bb0(%arg3 : tensor<i32>, %arg4 : tensor<i32>):
"stablehlo.return"(%arg4) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0],
index_vector_dim = 1,
>,
unique_indices = true
} : (tensor<1400xi32>, tensor<1400x1xi32>, tensor<1400xi32>) -> tensor<1400xi32>
check.expect_eq_const(%result, dense<2> : tensor<1400xi32>) : tensor<1400xi32>
return
}

func.func @scatter_2D_large() {
%original = util.unfoldable_constant dense<1> : tensor<200x300xi32>
%update = util.unfoldable_constant dense<2> : tensor<200x300xi32>
%init = tensor.empty() : tensor<200xi32>
%indices = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
outs(%init : tensor<200xi32>) {
^bb0(%arg0: i32):
%0 = linalg.index 0 : index
%1 = arith.index_cast %0 : index to i32
linalg.yield %1 : i32
} -> tensor<200xi32>
%indices_reshaped = tensor.expand_shape %indices [[0, 1]] :
tensor<200xi32> into tensor<200x1xi32>
%result = "stablehlo.scatter"(%original, %indices_reshaped, %update)({
^bb0(%arg3 : tensor<i32>, %arg4 : tensor<i32>):
"stablehlo.return"(%arg4) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [1],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0],
index_vector_dim = 1,
>,
unique_indices = true
} : (tensor<200x300xi32>, tensor<200x1xi32>, tensor<200x300xi32>) -> tensor<200x300xi32>
check.expect_eq_const(%result, dense<2> : tensor<200x300xi32>) : tensor<200x300xi32>
return
}

func.func @scatter_2D_large_permuted() {
%original = util.unfoldable_constant dense<1> : tensor<200x300xi32>
%update = util.unfoldable_constant dense<2> : tensor<300x200xi32>
%init = tensor.empty() : tensor<300xi32>
%indices = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
outs(%init : tensor<300xi32>) {
^bb0(%arg0: i32):
%0 = linalg.index 0 : index
%1 = arith.index_cast %0 : index to i32
linalg.yield %1 : i32
} -> tensor<300xi32>
%indices_reshaped = tensor.expand_shape %indices [[0, 1]] :
tensor<300xi32> into tensor<300x1xi32>
%result = "stablehlo.scatter"(%original, %indices_reshaped, %update)({
^bb0(%arg3 : tensor<i32>, %arg4 : tensor<i32>):
"stablehlo.return"(%arg4) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [1],
inserted_window_dims = [1],
scatter_dims_to_operand_dims = [1],
index_vector_dim = 1,
>,
unique_indices = true
} : (tensor<200x300xi32>, tensor<300x1xi32>, tensor<300x200xi32>) -> tensor<200x300xi32>
check.expect_eq_const(%result, dense<2> : tensor<200x300xi32>) : tensor<200x300xi32>
return
}

0 comments on commit 540d05e

Please sign in to comment.