Skip to content

[mlir][Vector] Add vector.to_elements op #141457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 70 additions & 15 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -790,40 +790,95 @@ def Vector_FMAOp :
}];
}

def Vector_ToElementsOp : Vector_Op<"to_elements", [
Pure,
TypesMatchWith<"operand element type matches result types",
"source", "elements", "SmallVector<Type>("
"::llvm::cast<VectorType>($_self).getNumElements(), "
"::llvm::cast<VectorType>($_self).getElementType())">]> {
let summary = "operation that decomposes a vector into all its scalar elements";
let description = [{
This operation decomposes all the scalar elements from a vector. The
decomposed scalar elements are returned in row-major order. The number of
scalar results must match the number of elements in the input vector type.
All the result elements have the same result type, which must match the
element type of the input vector. Scalable vectors are not supported.
Comment on lines +801 to +805
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it important that it decomposes into all elements? This op could be really useful for unrolling a dimension if we could do it dimwise. Something like:

%0:16 = vector.to_elements %v : vector<16x4xf32> -> vector<4xf32>

This should have the exact same semantics as vector.extract, just doing multiple extracts at once.

I would much rather have this form of the operation, it is much closer to vector.extract and works for N-D vectors much better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that keeping the symmetry with from_elements is valuable. I'm not sure I follow the suggestion, but is it doing something that chaining extract / extract_strided_slice / shape_cast / to_elements cannot achieve?


Examples:

```mlir
// Decompose a 0-D vector.
%0 = vector.to_elements %v0 : vector<f32>
// %0 = %v0[0]

// Decompose a 1-D vector.
%0:2 = vector.to_elements %v1 : vector<2xf32>
// %0#0 = %v1[0]
// %0#1 = %v1[1]

// Decompose a 2-D.
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
// %0#0 = %v2[0, 0]
// %0#1 = %v2[0, 1]
// %0#2 = %v2[0, 2]
// %0#3 = %v2[1, 0]
// %0#4 = %v2[1, 1]
// %0#5 = %v2[1, 2]

// Decompose a 3-D vector.
%0:6 = vector.to_elements %v3 : vector<3x1x2xf32>
// %0#0 = %v3[0, 0, 0]
// %0#1 = %v3[0, 0, 1]
// %0#2 = %v3[1, 0, 0]
// %0#3 = %v3[1, 0, 1]
// %0#4 = %v3[2, 0, 0]
// %0#5 = %v3[2, 0, 1]
```
}];

let arguments = (ins AnyVectorOfAnyRank:$source);
let results = (outs Variadic<AnyType>:$elements);
let assemblyFormat = "$source attr-dict `:` type($source)";
}

def Vector_FromElementsOp : Vector_Op<"from_elements", [
Pure,
TypesMatchWith<"operand types match result element type",
"result", "elements", "SmallVector<Type>("
"dest", "elements", "SmallVector<Type>("
"::llvm::cast<VectorType>($_self).getNumElements(), "
"::llvm::cast<VectorType>($_self).getElementType())">]> {
let summary = "operation that defines a vector from scalar elements";
let description = [{
This operation defines a vector from one or multiple scalar elements. The
number of elements must match the number of elements in the result type.
All elements must have the same type, which must match the element type of
the result vector type.

`elements` are a flattened version of the result vector in row-major order.
scalar elements are arranged in row-major within the vector. The number of
elements must match the number of elements in the result type. All elements
must have the same type, which must match the element type of the result
vector type. Scalable vectors are not supported.

Example:
Examples:

```mlir
// %f1
// Define a 0-D vector.
%0 = vector.from_elements %f1 : vector<f32>
// [%f1, %f2]
// [%f1]

// Define a 1-D vector.
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]
// [%f1, %f2]

// Define a 2-D vector.
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]

// Define a 3-D vector.
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
```

Note, scalable vectors are not supported.
}];

let arguments = (ins Variadic<AnyType>:$elements);
let results = (outs AnyFixedVectorOfAnyRank:$result);
let assemblyFormat = "$elements attr-dict `:` type($result)";
let results = (outs AnyFixedVectorOfAnyRank:$dest);
let assemblyFormat = "$elements attr-dict `:` type($dest)";
let hasCanonicalizer = 1;
}

Expand Down
24 changes: 20 additions & 4 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1896,7 +1896,24 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {

// -----

func.func @invalid_from_elements(%a: f32) {
func.func @to_elements_wrong_num_results(%a: vector<1x1x2xf32>) {
// expected-error @+1 {{operation defines 2 results but was provided 4 to bind}}
%0:4 = vector.to_elements %a : vector<1x1x2xf32>
return
}

// -----

func.func @to_elements_wrong_result_type(%a: vector<2xf32>) -> i32 {
// expected-error @+3 {{use of value '%0' expects different type than prior uses: 'i32'}}
// expected-note @+1 {{prior use here}}
%0:2 = vector.to_elements %a : vector<2xf32>
return %0#0 : i32
}

// -----

func.func @from_elements_wrong_num_operands(%a: f32) {
// expected-error @+1 {{'vector.from_elements' number of operands and types do not match: got 1 operands and 2 types}}
vector.from_elements %a : vector<2xf32>
return
Expand All @@ -1905,16 +1922,15 @@ func.func @invalid_from_elements(%a: f32) {
// -----

// expected-note @+1 {{prior use here}}
func.func @invalid_from_elements(%a: f32, %b: i32) {
func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
vector.from_elements %a, %b : vector<2xf32>
return
}

// -----

func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
// expected-error @+1 {{'result' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
// expected-error @+1 {{'dest' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
vector.from_elements %a, %b : vector<[2]xf32>
return
}
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,25 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
}

// CHECK-LABEL: func @to_elements(
// CHECK-SAME: %[[A_VEC:.*]]: vector<f32>, %[[B_VEC:.*]]: vector<4xf32>,
// CHECK-SAME: %[[C_VEC:.*]]: vector<1xf32>, %[[D_VEC:.*]]: vector<2x2xf32>)
func.func @to_elements(%a_vec : vector<f32>, %b_vec : vector<4xf32>, %c_vec : vector<1xf32>, %d_vec : vector<2x2xf32>)
-> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
// CHECK: %[[A_ELEMS:.*]] = vector.to_elements %[[A_VEC]] : vector<f32>
%0 = vector.to_elements %a_vec : vector<f32>
// CHECK: %[[B_ELEMS:.*]]:4 = vector.to_elements %[[B_VEC]] : vector<4xf32>
%1:4 = vector.to_elements %b_vec : vector<4xf32>
// CHECK: %[[C_ELEMS:.*]] = vector.to_elements %[[C_VEC]] : vector<1xf32>
%2 = vector.to_elements %c_vec : vector<1xf32>
// CHECK: %[[D_ELEMS:.*]]:4 = vector.to_elements %[[D_VEC]] : vector<2x2xf32>
%3:4 = vector.to_elements %d_vec : vector<2x2xf32>
// CHECK: return %[[A_ELEMS]], %[[B_ELEMS]]#0, %[[B_ELEMS]]#1, %[[B_ELEMS]]#2,
// CHECK-SAME: %[[B_ELEMS]]#3, %[[C_ELEMS]], %[[D_ELEMS]]#0, %[[D_ELEMS]]#1,
// CHECK-SAME: %[[D_ELEMS]]#2, %[[D_ELEMS]]#3
return %0, %1#0, %1#1, %1#2, %1#3, %2, %3#0, %3#1, %3#2, %3#3 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
}

// CHECK-LABEL: func @from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
Expand Down
Loading