-
Notifications
You must be signed in to change notification settings - Fork 14k
[MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM #140573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM #140573
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-sve Author: Momchil Velikov (momchil-velikov) ChangesPatch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140573.diff 5 Files Affected:
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17],
+ [-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999, 1941, 685, -2879 )
+// CHECK: ( -3705, 2952, 987, -685 )
+// CHECK: ( 2565, 4157, -1589, -357 )
+// CHECK: ( 2383, -2252, 32, -1365 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c256 = arith.constant 256 : i32
+ func.call @setArmVLBits(%c256) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26],
+ [-20, -36, -3, 39, -48, -31, -25, -21],
+ [-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49],
+ [-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17]]> : vector<8x8xi32>
+ %acc_m = memref.alloca() : memref<8x8xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+ vector.print %acc4 : vector<[4]xi32>
+ vector.print %acc5 : vector<[4]xi32>
+ vector.print %acc6 : vector<[4]xi32>
+ vector.print %acc7 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38],
+ [ -3, 9, 43, -30, -32, 39, 41, -39],
+ [-13, -21, -25, 27, 47, -36, -11, -11],
+ [ -4, -20, 36, 11, 13, -23, 24, -13],
+ [-20, 30, -5, 1, 42, -37, -22, 35],
+ [-22, 38, -4, 44, 25, -31, 23, -39],
+ [-45, -4, -31, -24, 14, -41, -47, 22]]> : vector<8x8xi8>
+
+ %lhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+ %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+ %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+ %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+ %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+ vector.print %lhs4 : vector<8xi8>
+ vector.print %lhs5 : vector<8xi8>
+ vector.print %lhs6 : vector<8xi8>
+ vector.print %lhs7 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-40, -11, -36, 36, -1, 20, 14, -32],
+ [ 46, -45, -48, -46, -24, 31, -36, 22],
+ [ 2, 36, 45, -29, -37, -49, -20, -35],
+ [ -6, 23, 23, 15, 20, 4, -8, -2],
+ [-35, -6, 16, 49, -50, 9, -44, 13],
+ [ 24, 1, -4, -44, 41, 15, -43, 44],
+ [ 44, 0, -10, 41, 22, 44, -40, 0],
+ [-33, 19, 27, 22, 38, -17, 23, -9]]> : vector<8x8xi8>
+
+ %rhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+  // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+ vector.print %u4 : vector<[4]xi32>
+ vector.print %u5 : vector<[4]xi32>
+ vector.print %u6 : vector<[4]xi32>
+ vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 )
+// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 )
+// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 )
+// CHECK: ( 2034, -1768, -2092, 284, -792, -23, 668, 2172 )
+// CHECK: ( -248, -3728, 1214, 555, -668, -2114, -1794, 2560 )
+// CHECK: ( -1484, -2642, 297, 1551, -483, 3173, -576, 2570 )
+// CHECK: ( 3098, -7851, 1366, 1892, -427, -4533, -819, 4698 )
+// CHECK: ( -135, 1247, 765, -479, 1245, 3074, -2281, -23 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175, 82, 99],
+ [221, 25, 164, 97, 156, 221, 218, 177],
+ [171, 160, 219, 191, 144, 45, 161, 210],
+ [223, 165, 123, 99, 108, 86, 37, 92]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: ( -7613, -8386, -15938, -6521 )
+// CHECK: ( 9468, 18750, 9199, 5764 )
+// CHECK: ( 33655, 41064, 48900, 31627 )
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Momchil Velikov (momchil-velikov) ChangesPatch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140573.diff 5 Files Affected:
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17],
+ [-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999, 1941, 685, -2879 )
+// CHECK: ( -3705, 2952, 987, -685 )
+// CHECK: ( 2565, 4157, -1589, -357 )
+// CHECK: ( 2383, -2252, 32, -1365 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c256 = arith.constant 256 : i32
+ func.call @setArmVLBits(%c256) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+
+ // Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26],
+ [-20, -36, -3, 39, -48, -31, -25, -21],
+ [-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49],
+ [-17, -50, -1, 48, -13, 22, 39, 33],
+ [-35, -24, 37, -32, 33, 30, -11, -17]]> : vector<8x8xi32>
+ %acc_m = memref.alloca() : memref<8x8xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+ vector.print %acc4 : vector<[4]xi32>
+ vector.print %acc5 : vector<[4]xi32>
+ vector.print %acc6 : vector<[4]xi32>
+ vector.print %acc7 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-28, 31, 3, -44, -15, -27, 22, 35],
+ [-23, 39, 48, 26, -23, 32, -39, -38],
+ [ -3, 9, 43, -30, -32, 39, 41, -39],
+ [-13, -21, -25, 27, 47, -36, -11, -11],
+ [ -4, -20, 36, 11, 13, -23, 24, -13],
+ [-20, 30, -5, 1, 42, -37, -22, 35],
+ [-22, 38, -4, 44, 25, -31, 23, -39],
+ [-45, -4, -31, -24, 14, -41, -47, 22]]> : vector<8x8xi8>
+
+ %lhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+ %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+ %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+ %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+ %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+ vector.print %lhs4 : vector<8xi8>
+ vector.print %lhs5 : vector<8xi8>
+ vector.print %lhs6 : vector<8xi8>
+ vector.print %lhs7 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[-40, -11, -36, 36, -1, 20, 14, -32],
+ [ 46, -45, -48, -46, -24, 31, -36, 22],
+ [ 2, 36, 45, -29, -37, -49, -20, -35],
+ [ -6, 23, 23, 15, 20, 4, -8, -2],
+ [-35, -6, 16, 49, -50, 9, -44, 13],
+ [ 24, 1, -4, -44, 41, 15, -43, 44],
+ [ 44, 0, -10, 41, 22, 44, -40, 0],
+ [-33, 19, 27, 22, 38, -17, 23, -9]]> : vector<8x8xi8>
+
+ %rhs_m = memref.alloca() : memref<8x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+ %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+  // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+ %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+ vector.print %u4 : vector<[4]xi32>
+ vector.print %u5 : vector<[4]xi32>
+ vector.print %u6 : vector<[4]xi32>
+ vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 )
+// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 )
+// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 )
+// CHECK: ( 2034, -1768, -2092, 284, -792, -23, 668, 2172 )
+// CHECK: ( -248, -3728, 1214, 555, -668, -2114, -1794, 2560 )
+// CHECK: ( -1484, -2642, 297, 1551, -483, 3173, -576, 2570 )
+// CHECK: ( 3098, -7851, 1366, 1892, -427, -4533, -819, 4698 )
+// CHECK: ( -135, 1247, 765, -479, 1245, 3074, -2281, -23 )
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+ %acc_cst = arith.constant dense<[[-44, 20, 44, -46],
+ [ -8, 25, -34, 26],
+ [-20, -36, -3, 39],
+ [-48, -31, -25, -21]]> : vector<4x4xi32>
+ %acc_m = memref.alloca() : memref<4x4xi32>
+ vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+ %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+ %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+ %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+ vector.print str "ACC:\n"
+ %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %acc0 : vector<[4]xi32>
+ vector.print %acc1 : vector<[4]xi32>
+ vector.print %acc2 : vector<[4]xi32>
+ vector.print %acc3 : vector<[4]xi32>
+
+ // LHS test data
+ %lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
+ [-20, 17, -32, -47, 37, 22, -7, -21],
+ [ -7, -35, 20, -4, 39, 46, -23, 40],
+ [ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
+
+ %lhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+ %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+ vector.print str "LHS:\n"
+ %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+ %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+ %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+ %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+ vector.print %lhs0 : vector<8xi8>
+ vector.print %lhs1 : vector<8xi8>
+ vector.print %lhs2 : vector<8xi8>
+ vector.print %lhs3 : vector<8xi8>
+
+ // RHS test data
+ %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175, 82, 99],
+ [221, 25, 164, 97, 156, 221, 218, 177],
+ [171, 160, 219, 191, 144, 45, 161, 210],
+ [223, 165, 123, 99, 108, 86, 37, 92]]> : vector<4x8xi8>
+
+ %rhs_m = memref.alloca() : memref<4x8xi8>
+ vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+ %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+ %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+ vector.print str "RHS:\n"
+ %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+ %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %rhs0 : vector<[16]xi8>
+ vector.print %rhs1 : vector<[16]xi8>
+
+ %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+ // Matrix multiplication
+ %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+ %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+ %2 = vector.contract {indexing_maps = #packed_maps,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %0, %1, %acc
+ : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+ // Display the result of the multiplication
+ vector.print str "Result:\n"
+ %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+ %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+ vector.print %u0 : vector<[4]xi32>
+ vector.print %u1 : vector<[4]xi32>
+ vector.print %u2 : vector<[4]xi32>
+ vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: ( -7613, -8386, -15938, -6521 )
+// CHECK: ( 9468, 18750, 9199, 5764 )
+// CHECK: ( 33655, 41064, 48900, 31627 )
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]
|
194c1c7
to
87f2964
Compare
251f93e
to
4ab7765
Compare
Thanks - great to finally be reaching this stage! I have a few high-level questions and suggestions: 1. Why is the scalable dimension always [4]? From the current tests, it looks like the scalable dim is always 2. Reduce duplication in the 4x8x4 tests The current tests differ only in terms of input/output and For example: func.func @main_smmla() {
// Init LHS, RHS, ACC
// CHECK-LINES for LHS
print(lhs);
// CHECK-LINES for RHS
print(rhs);
arith.extsi (lhs)
arith.extsi (rhs)
vector.contract
// CHECK-LINES for ACC
print(acc);
} This would keep the test logic focused and easier to maintain. 3. Add checks for generated IR (LLVM dialect) It would be good to verify that the lowered IR includes the correct SME MMLA intrinsics. For example: // CHECK-COUNT-4: llvm.intr.smmla This would help confirm both correctness and that the expected number of operations are emitted. 4. Consider toggling VL within tests From what I can tell, this would only work if the inputs are generated inside a loop, similar to this example: llvm-project/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir Lines 19 to 37 in 88f61f2
That might be a nice validation of the "scalability" aspect. |
4ab7765
to
2d2bea9
Compare
87f2964
to
a34bec5
Compare
2d2bea9
to
d177af4
Compare
a34bec5
to
915fbf9
Compare
Thanks for all the updates, @momchil-velikov 🙏🏻 Just for context: Momchil and I discussed this offline - in its current form, the test crashes (after producing valid results). I’ve created a minimal repro here: Given that these tests (i.e., SVE e2e tests that require emulation) are only run by clang-aarch64-full-2stage - which we closely monitor - I'm inclined to land this with a temporary workaround (on top of other updates suggested below). Looking at the current version of the tests, I do think there’s room to improve code reuse and to make the test structure more VLA-friendly (Vector-Length Agnostic). For instance, the function below mixes fixed and scalable shapes: func.func private @prepareAccTestData(%in: vector<4x4xi32>) -> vector<4x[4]xi32> {
// ...
} It also returns a scalable vector from a function - something we haven’t independently tested and should probably avoid until validated. Since this comment is getting long and I’m effectively suggesting a refactor (after many solid improvements already), I’ve gone ahead and rewritten the tests myself. Below is a single test file that combines:
Summary of changes:
To me, this format accomplishes 3 key goals:
If this makes sense to you, please feel free to reuse the file. Alternatively, I’d be happy to push it to your branch directly. –Andrzej NEW TEST FILE // REQUIRES: arm-emulator
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
// DEFINE: -o %t
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run}
#packed_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
]
//=============================================================================
// Helper methods to allocate+initialise test data
//=============================================================================
// Allocate and initialise a memref of 16 x vscale elements of type: i32. This
// matches the requirements for the accumulator for i8mm, it is precisely
// * 4 x [4] elements.
func.func private @getFlatMemRef_i32() -> memref<?xi32> {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%vscale = vector.vscale
%c16_vscale = arith.muli %vscale, %c16 : index
%flat_mem = memref.alloc(%c16_vscale) : memref<?xi32>
%vector_i32 = llvm.intr.stepvector : vector<[16]xi32>
vector.transfer_write %vector_i32, %flat_mem[%c0] : vector<[16]xi32>, memref<?xi32>
return %flat_mem : memref<?xi32>
}
// Allocate and initialise a memref of 32 x vscale elements of type: i8. This
// matches the requirements for the RHS for i8mm, it is precisely
// * [4] x 8 elements.
func.func private @getFlatMemRef_i8_scalable() -> memref<?xi8> {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%vscale = vector.vscale
%vscale_times_32 = arith.muli %vscale, %c32 : index
%flat_mem = memref.alloc(%vscale_times_32) : memref<?xi8>
%vector_i32 = llvm.intr.stepvector : vector<[32]xi8>
vector.transfer_write %vector_i32, %flat_mem[%c0] : vector<[32]xi8>, memref<?xi8>
return %flat_mem : memref<?xi8>
}
// Allocate and initialise a memref of 32 elements of type: i8. This
// matches the requirements for the LHS for i8mm, it is precisely:
// * 4 x 8 elements.
func.func private @getFlatMemRef_i8_fixed() -> memref<?xi8> {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%flat_mem = memref.alloc(%c32) : memref<?xi8>
%vector_i32 = llvm.intr.stepvector : vector<32xi8>
vector.transfer_write %vector_i32, %flat_mem[%c0] : vector<32xi8>, memref<?xi8>
return %flat_mem : memref<?xi8>
}
//=============================================================================
// Main entry point for test.
//=============================================================================
func.func @main() {
// NOTE: Update this value to some other valid value of VL (i.e. supported by
// SVE) to see the impact of "scalability".
// FIXME: https://github.com/llvm/llvm-project/issues/143670
%c128 = arith.constant 128 : i32
func.call @setArmVLBits(%c128) : (i32) -> ()
%c0_idx = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%c0_i8 = arith.constant 0 : i8
//---------------------------------------------------------------------------
// 1. GENERATE TEST DATA
//---------------------------------------------------------------------------
// 1.1. Accumulator test data
%acc_flat = func.call @getFlatMemRef_i32() : () -> memref<?xi32>
%flat_vec = vector.transfer_read %acc_flat[%c0_idx], %c0_i32 {in_bounds = [true]} : memref<?xi32>, vector<[16]xi32>
%acc = vector.shape_cast %flat_vec : vector<[16]xi32> to vector<4x[4]xi32>
// 1.2. LHS test data
%lhs_flat = func.call @getFlatMemRef_i8_fixed() : () -> memref<?xi8>
%lhs_flat_vec = vector.transfer_read %lhs_flat[%c0_idx], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<32xi8>
%lhs = vector.shape_cast %lhs_flat_vec : vector<32xi8> to vector<4x8xi8>
// 1.3. RHS test data
%rhs_flat = func.call @getFlatMemRef_i8_scalable() : () -> memref<?xi8>
%rhs_flat_vec = vector.transfer_read %rhs_flat[%c0_idx], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8>
%rhs = vector.shape_cast %rhs_flat_vec : vector<[32]xi8> to vector<[4]x8xi8>
//---------------------------------------------------------------------------
// 2. "EXTEND" THE RHS + LHS VECTORS
// This is what i8mm expects.
//---------------------------------------------------------------------------
%lhs_si = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
%lhs_ui = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
%rhs_si = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
%rhs_ui = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
//---------------------------------------------------------------------------
// 3. MATRIX MULTIPLICATION
//---------------------------------------------------------------------------
// 3.1. SMMLA
// CHECK-IR-COUNT-4: arm_sve.intr.smmla
%res_smmla = vector.contract {indexing_maps = #packed_maps,
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs_si, %rhs_si, %acc
: vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
// 3.2. UMMLA
// CHECK-IR-COUNT-4: arm_sve.intr.ummla
%res_ummla = vector.contract {indexing_maps = #packed_maps,
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs_ui, %rhs_ui, %acc
: vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
// 3.3. USMMLA
// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
%res_usmmla = vector.contract {indexing_maps = #packed_maps,
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs_ui, %rhs_si, %acc
: vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
// 3.4. SUMMLA
// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
%res_summla = vector.contract {indexing_maps = #packed_maps,
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs_si, %rhs_ui, %acc
: vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
//---------------------------------------------------------------------------
// 4. DISPLAY THE RESULTS OF THE MULTIPLICATION
// TODO: Implement this and use instead:
// * vector.print %2 : vector<4x[4]xi32>
//---------------------------------------------------------------------------
vector.print str "RESULT (smmla):\n"
%s0 = vector.extract %res_smmla[0] : vector<[4]xi32> from vector<4x[4]xi32>
%s1 = vector.extract %res_smmla[1] : vector<[4]xi32> from vector<4x[4]xi32>
%s2 = vector.extract %res_smmla[2] : vector<[4]xi32> from vector<4x[4]xi32>
%s3 = vector.extract %res_smmla[3] : vector<[4]xi32> from vector<4x[4]xi32>
vector.print %s0 : vector<[4]xi32>
vector.print %s1 : vector<[4]xi32>
vector.print %s2 : vector<[4]xi32>
vector.print %s3 : vector<[4]xi32>
vector.print str "RESULT (ummla):\n"
%u0 = vector.extract %res_ummla[0] : vector<[4]xi32> from vector<4x[4]xi32>
%u1 = vector.extract %res_ummla[1] : vector<[4]xi32> from vector<4x[4]xi32>
%u2 = vector.extract %res_ummla[2] : vector<[4]xi32> from vector<4x[4]xi32>
%u3 = vector.extract %res_ummla[3] : vector<[4]xi32> from vector<4x[4]xi32>
vector.print %u0 : vector<[4]xi32>
vector.print %u1 : vector<[4]xi32>
vector.print %u2 : vector<[4]xi32>
vector.print %u3 : vector<[4]xi32>
vector.print str "RESULT (usmmla):\n"
%us0 = vector.extract %res_usmmla[0] : vector<[4]xi32> from vector<4x[4]xi32>
%us1 = vector.extract %res_usmmla[1] : vector<[4]xi32> from vector<4x[4]xi32>
%us2 = vector.extract %res_usmmla[2] : vector<[4]xi32> from vector<4x[4]xi32>
%us3 = vector.extract %res_usmmla[3] : vector<[4]xi32> from vector<4x[4]xi32>
vector.print %us0 : vector<[4]xi32>
vector.print %us1 : vector<[4]xi32>
vector.print %us2 : vector<[4]xi32>
vector.print %us3 : vector<[4]xi32>
vector.print str "RESULT (summla):\n"
%su0 = vector.extract %res_summla[0] : vector<[4]xi32> from vector<4x[4]xi32>
%su1 = vector.extract %res_summla[1] : vector<[4]xi32> from vector<4x[4]xi32>
%su2 = vector.extract %res_summla[2] : vector<[4]xi32> from vector<4x[4]xi32>
%su3 = vector.extract %res_summla[3] : vector<[4]xi32> from vector<4x[4]xi32>
vector.print %su0 : vector<[4]xi32>
vector.print %su1 : vector<[4]xi32>
vector.print %su2 : vector<[4]xi32>
vector.print %su3 : vector<[4]xi32>
// With all inputs positive, the results are identical for all types of extensions.
// TODO: Use negative numbers to demonstrate the run-time difference between e.g. UMMLA and SMMLA.
//CHECK-4: ( 140, 365, 590, 815, 1040, 1265, 1490, 1715 )
//CHECK-4: ( 372, 1109, 1846, 2583, 3320, 4057, 4794, 5531 )
//CHECK-4: ( 604, 1853, 3102, 4351, 5600, 6849, 8098, 9347 )
//CHECK-4: ( 836, 2597, 4358, 6119, 7880, 9641, 11402, 13163 )
//---------------------------------------------------------------------------
// 5. WORKAROUND
// This extra printing should not be required, but the test crashes without it.
// FIXME: https://github.com/llvm/llvm-project/issues/143670
//---------------------------------------------------------------------------
%res_smmla_flat = vector.shape_cast %res_smmla : vector<4x[4]xi32> to vector<[16]xi32>
vector.transfer_write %res_smmla_flat, %acc_flat[%c0_idx] : vector<[16]xi32>, memref<?xi32>
%acc_cast = memref.cast %acc_flat : memref<?xi32> to memref<*xi32>
call @printMemrefI32(%acc_cast) : (memref<*xi32>) -> ()
//---------------------------------------------------------------------------
// 6. BUFFER DEALLOCATION
//---------------------------------------------------------------------------
memref.dealloc %acc_flat : memref<?xi32>
memref.dealloc %rhs_flat : memref<?xi8>
memref.dealloc %lhs_flat : memref<?xi8>
return
}
func.func private @printMemrefI32(%ptr : memref<*xi32>)
func.func private @setArmVLBits(%bits : i32) |
Thanks for the workaround, I applied it and finished the refactoring.
Initially, no reason other then keep test data from eating too much screen space. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great updates, thank you! I really like how this is checking both VL=128
and VL=256
.
The overall logic LGTM, but since the underlying concepts are pretty complex (scalable vectors + i8mm), it deserves considerably more documentation. Suggestions inline.
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % some clarifications in the docs.
I appreciate that my comment is rather long, but that's just to clarify my view-point/rationale. The actual request is small. Also, to avoid confusion - I posted that in a thread that I "unresolved" (GitHub can be "funny" how it displays things).
[nit] Use m
, n
and k
in affine maps as we do here https://mlir.llvm.org/docs/Dialects/Linalg/#linalgmatmul-linalgmatmulop (and in a few other places).
No description provided.