[MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM #140573


Merged

Conversation

momchil-velikov
Collaborator

No description provided.

@llvmbot
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

Changes

Patch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140573.diff

5 Files Affected:

  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17],
+                                   [-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999,  1941,   685, -2879 )
+// CHECK: ( -3705,  2952,   987,  -685 )
+// CHECK: (  2565,  4157, -1589,  -357 )
+// CHECK: (  2383, -2252,    32, -1365 )
+  return
+}
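For reference, the three affine maps in `#packed_maps` index A as (d0, d2), B as (d1, d2) and C as (d0, d1), i.e. the contraction computes C += A·Bᵀ with the reduction (K) dimension contiguous in both operands, which is the operand layout the SMMLA lowering consumes. The CHECK values above can be reproduced with a small NumPy model (NumPy is not part of this patch; this is only a sketch of the expected arithmetic):

```python
import numpy as np

# Accumulator, LHS and RHS test data copied from the smmla-4x8x4 test above.
acc = np.array([[-44,  20,  44, -46],
                [ -8,  25, -34,  26],
                [-20, -36,  -3,  39],
                [-48, -31, -25, -21]], dtype=np.int32)

lhs = np.array([[-35, -27, -36, -31,  23, -34,  -8, -33],
                [-20,  17, -32, -47,  37,  22,  -7, -21],
                [ -7, -35,  20,  -4,  39,  46, -23,  40],
                [ 40,  27,  37,  43,  38,  -6,  37,  49]], dtype=np.int8)

rhs = np.array([[-17, -50,  -1,  48, -13,  22,  39,  33],
                [-35, -24,  37, -32,  33,  30, -11, -17],
                [-28,  31,   3, -44, -15, -27,  22,  35],
                [-23,  39,  48,  26, -23,  32, -39, -38]], dtype=np.int8)

# Both operands are sign-extended to i32 before the contraction (arith.extsi),
# and the maps select A[d0,d2] * B[d1,d2] summed over d2, i.e. A @ B.T.
result = lhs.astype(np.int32) @ rhs.astype(np.int32).T + acc
print(result)
```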
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c256 = arith.constant 256 : i32
+  func.call @setArmVLBits(%c256) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46,  -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39, -48, -31, -25, -21],
+                                   [-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49],
+                                   [-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17]]> : vector<8x8xi32>
+  %acc_m = memref.alloca() : memref<8x8xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+  vector.print %acc4 : vector<[4]xi32>
+  vector.print %acc5 : vector<[4]xi32>
+  vector.print %acc6 : vector<[4]xi32>
+  vector.print %acc7 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38],
+                                   [ -3,   9,  43, -30, -32,  39,  41, -39],
+                                   [-13, -21, -25,  27,  47, -36, -11, -11],
+                                   [ -4, -20,  36,  11,  13, -23,  24, -13],
+                                   [-20,  30,  -5,   1,  42, -37, -22,  35],
+                                   [-22,  38,  -4,  44,  25, -31,  23, -39],
+                                   [-45,  -4, -31, -24,  14, -41, -47,  22]]> : vector<8x8xi8>
+
+  %lhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+  %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+  %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+  %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+  %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+  vector.print %lhs4 : vector<8xi8>
+  vector.print %lhs5 : vector<8xi8>
+  vector.print %lhs6 : vector<8xi8>
+  vector.print %lhs7 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-40, -11, -36,  36,  -1,  20,  14, -32],
+                                   [ 46, -45, -48, -46, -24,  31, -36,  22],
+                                   [  2,  36,  45, -29, -37, -49, -20, -35],
+                                   [ -6,  23,  23,  15,  20,   4,  -8,  -2],
+                                   [-35,  -6,  16,  49, -50,   9, -44,  13],
+                                   [ 24,   1,  -4, -44,  41,  15, -43,  44],
+                                   [ 44,   0, -10,  41,  22,  44, -40,   0],
+                                   [-33,  19,  27,  22,  38, -17,  23,  -9]]> : vector<8x8xi8>
+
+  %rhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+  vector.print %u4 : vector<[4]xi32>
+  vector.print %u5 : vector<[4]xi32>
+  vector.print %u6 : vector<[4]xi32>
+  vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282,  2728,  -410, -1328,   882, -5498,   732 )
+// CHECK: (  1012, -4237,  4154,  2624,  5225, -2338,  2011,  1374 )
+// CHECK: (    -8, -1611,  2905,    -1, -1068, -3155, -2428,   153 )
+// CHECK: (  2034, -1768, -2092,   284,  -792,   -23,   668,  2172 )
+// CHECK: (  -248, -3728,  1214,   555,  -668, -2114, -1794,  2560 )
+// CHECK: ( -1484, -2642,   297,  1551,  -483,  3173,  -576,  2570 )
+// CHECK: (  3098, -7851,  1366,  1892,  -427, -4533,  -819,  4698 )
+// CHECK: (  -135,  1247,   765,  -479,  1245,  3074, -2281,   -23 )
+  return
+}
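The vs2 variant runs with the vector length set to 256 bits via `setArmVLBits`, so each `vector<[4]xi32>` holds 8 lanes and the 8x[4] accumulator covers the full 8x8 result. The arithmetic is unchanged, C += A·Bᵀ with both operands sign-extended, which a NumPy sketch (not part of the patch) can confirm against the CHECK lines:

```python
import numpy as np

# Test data copied from the smmla-8x8x8-vs2 test above.
acc = np.array([[-44,  20,  44, -46,  -8,  25, -34,  26],
                [-20, -36,  -3,  39, -48, -31, -25, -21],
                [-35, -27, -36, -31,  23, -34,  -8, -33],
                [-20,  17, -32, -47,  37,  22,  -7, -21],
                [ -7, -35,  20,  -4,  39,  46, -23,  40],
                [ 40,  27,  37,  43,  38,  -6,  37,  49],
                [-17, -50,  -1,  48, -13,  22,  39,  33],
                [-35, -24,  37, -32,  33,  30, -11, -17]], dtype=np.int32)

lhs = np.array([[-28,  31,   3, -44, -15, -27,  22,  35],
                [-23,  39,  48,  26, -23,  32, -39, -38],
                [ -3,   9,  43, -30, -32,  39,  41, -39],
                [-13, -21, -25,  27,  47, -36, -11, -11],
                [ -4, -20,  36,  11,  13, -23,  24, -13],
                [-20,  30,  -5,   1,  42, -37, -22,  35],
                [-22,  38,  -4,  44,  25, -31,  23, -39],
                [-45,  -4, -31, -24,  14, -41, -47,  22]], dtype=np.int8)

rhs = np.array([[-40, -11, -36,  36,  -1,  20,  14, -32],
                [ 46, -45, -48, -46, -24,  31, -36,  22],
                [  2,  36,  45, -29, -37, -49, -20, -35],
                [ -6,  23,  23,  15,  20,   4,  -8,  -2],
                [-35,  -6,  16,  49, -50,   9, -44,  13],
                [ 24,   1,  -4, -44,  41,  15, -43,  44],
                [ 44,   0, -10,  41,  22,  44, -40,   0],
                [-33,  19,  27,  22,  38, -17,  23,  -9]], dtype=np.int8)

# Signed 8x8x8 contraction: sign-extend both sides, then C += A @ B.T.
result = lhs.astype(np.int32) @ rhs.astype(np.int32).T + acc
print(result)
```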
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175,  82,  99],
+                                   [221,  25, 164,  97, 156, 221, 218, 177],
+                                   [171, 160, 219, 191, 144,  45, 161, 210],
+                                   [223, 165, 123,  99, 108,  86,  37,  92]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: (  -7613,  -8386, -15938,  -6521 )
+// CHECK: (   9468,  18750,   9199,   5764 )
+// CHECK: (  33655,  41064,  48900,  31627 )
+  return
+}
+
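The summla test differs from the smmla one only in the extensions applied before the contraction: the LHS is sign-extended (`arith.extsi`) while the RHS is zero-extended (`arith.extui`), so the RHS byte constants are interpreted as values in 0..255. A NumPy sketch (not part of the patch) of this signed-by-unsigned product reproduces the CHECK values:

```python
import numpy as np

# Test data copied from the summla-4x8x4 test above.
acc = np.array([[-44,  20,  44, -46],
                [ -8,  25, -34,  26],
                [-20, -36,  -3,  39],
                [-48, -31, -25, -21]], dtype=np.int32)

# LHS bytes are signed (arith.extsi).
lhs = np.array([[-35, -27, -36, -31,  23, -34,  -8, -33],
                [-20,  17, -32, -47,  37,  22,  -7, -21],
                [ -7, -35,  20,  -4,  39,  46, -23,  40],
                [ 40,  27,  37,  43,  38,  -6,  37,  49]], dtype=np.int8)

# RHS bytes are unsigned (arith.extui).
rhs = np.array([[125, 171, 138, 187, 108, 175,  82,  99],
                [221,  25, 164,  97, 156, 221, 218, 177],
                [171, 160, 219, 191, 144,  45, 161, 210],
                [223, 165, 123,  99, 108,  86,  37,  92]], dtype=np.uint8)

# Mixed-sign widening product, then C += A @ B.T.
result = lhs.astype(np.int32) @ rhs.astype(np.int32).T + acc
print(result)
```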
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]

@llvmbot
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-mlir

+                                   [-33,  19,  27,  22,  38, -17,  23,  -9]]> : vector<8x8xi8>
+
+  %rhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+  vector.print %u4 : vector<[4]xi32>
+  vector.print %u5 : vector<[4]xi32>
+  vector.print %u6 : vector<[4]xi32>
+  vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282,  2728,  -410, -1328,   882, -5498,   732 )
+// CHECK: (  1012, -4237,  4154,  2624,  5225, -2338,  2011,  1374 )
+// CHECK: (    -8, -1611,  2905,    -1, -1068, -3155, -2428,   153 )
+// CHECK: (  2034, -1768, -2092,   284,  -792,   -23,   668,  2172 )
+// CHECK: (  -248, -3728,  1214,   555,  -668, -2114, -1794,  2560 )
+// CHECK: ( -1484, -2642,   297,  1551,  -483,  3173,  -576,  2570 )
+// CHECK: (  3098, -7851,  1366,  1892,  -427, -4533,  -819,  4698 )
+// CHECK: (  -135,  1247,   765,  -479,  1245,  3074, -2281,   -23 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175,  82,  99],
+                                   [221,  25, 164,  97, 156, 221, 218, 177],
+                                   [171, 160, 219, 191, 144,  45, 161, 210],
+                                   [223, 165, 123,  99, 108,  86,  37,  92]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: (  -7613,  -8386, -15938,  -6521 )
+// CHECK: (   9468,  18750,   9199,   5764 )
+// CHECK: (  33655,  41064,  48900,  31627 )
+  return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]
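The expected values in the summla test above (contraction-summla-4x8x4.mlir) can be cross-checked independently of MLIR. The test computes acc + sext(lhs) x zext(rhs) transposed, i.e. each LHS row is multiplied against each RHS *row*, because the affine maps index both operands as (row, k). A scalar reference model in Python (a sketch, using the constant test data quoted above):

```python
# Scalar reference model for the summla (signed x unsigned) test above.
# result[i][j] = acc[i][j] + sum_k sext(lhs[i][k]) * zext(rhs[j][k])
# Note the RHS is indexed [j][k]: the #packed_maps transpose the RHS, so
# LHS rows are dotted with RHS rows, not columns.

acc = [[-44,  20,  44, -46],
       [ -8,  25, -34,  26],
       [-20, -36,  -3,  39],
       [-48, -31, -25, -21]]

lhs = [[-35, -27, -36, -31,  23, -34,  -8, -33],
       [-20,  17, -32, -47,  37,  22,  -7, -21],
       [ -7, -35,  20,  -4,  39,  46, -23,  40],
       [ 40,  27,  37,  43,  38,  -6,  37,  49]]

# RHS values are all in [0, 255]: they exercise the unsigned (zext) side.
rhs = [[125, 171, 138, 187, 108, 175,  82,  99],
       [221,  25, 164,  97, 156, 221, 218, 177],
       [171, 160, 219, 191, 144,  45, 161, 210],
       [223, 165, 123,  99, 108,  86,  37,  92]]

result = [[acc[i][j] + sum(lhs[i][k] * rhs[j][k] for k in range(8))
           for j in range(4)]
          for i in range(4)]

for row in result:
    print(row)
```

Running this reproduces the four CHECK lines of the summla test, starting with `( -27190, -28812, -30502, -23575 )`.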

@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/vector-contract-i8mm-integration branch from 194c1c7 to 87f2964 Compare May 27, 2025 16:22
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/vector-contract-i8mm-transform-dialect branch from 251f93e to 4ab7765 Compare May 27, 2025 16:22
@banach-space
Contributor

Thanks - great to finally be reaching this stage! I have a few high-level questions and suggestions:

1. Why is the scalable dimension always [4]?

From the current tests, it looks like the scalable dim is always [4]. Could you remind me why that value is chosen?

2. Reduce duplication in the 4x8x4 tests

The current tests differ only in terms of input/output and extsi vs extui. It should be possible to reduce duplication by extracting shared logic into helpers, and writing 4 separate entry points (set via entry_point) to isolate the differences.

For example:

func.func @main_smmla() {
  // Init LHS, RHS, ACC
  // CHECK-LINES for LHS
  print(lhs);
  // CHECK-LINES for RHS
  print(rhs);
 
  arith.extsi (lhs)
  arith.extsi (rhs)
  vector.contract
  
  // CHECK-LINES for ACC
  print(acc);
}

This would keep the test logic focused and easier to maintain.

3. Add checks for generated IR (LLVM dialect)

It would be good to verify that the lowered IR includes the correct SVE MMLA intrinsics. For example:

// CHECK-COUNT-4: llvm.intr.smmla

This would help confirm both correctness and that the expected number of operations are emitted.

4. Consider toggling VL within tests
Have you considered toggling the scalable vector length (VL) within the test? That would allow verifying behaviour for multiple VL values.

From what I can tell, this would only work if the inputs are generated inside a loop, similar to this example:

// Allocate memory.
%mem1 = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
// Fill each "row" of "mem1" with row number.
//
// For example, assuming an SVL of 128-bits:
//
// 0, 0, 0, 0
// 1, 1, 1, 1
// 2, 2, 2, 2
// 3, 3, 3, 3
//
%init_0 = arith.constant 0 : i32
scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
  %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
  vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
  %val_next = arith.addi %val, %c1_i32 : i32
  scf.yield %val_next : i32
}

That might be a nice validation of the "scalability" aspect.
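The row-fill loop above can be modelled in a few lines to make the intended memory contents explicit. A scalar sketch (Python; assumes an SVL of 128 bits, so `%svl_s` is 4 lanes of i32):

```python
# Scalar model of the row-fill loop above (hypothetical sketch):
# every SVE-register-wide row of mem1 is splatted with its row number.
svl_s = 4  # assumed: 128-bit SVL => 4 x i32 lanes per scalable vector

mem1 = [[row] * svl_s for row in range(svl_s)]

for row in mem1:
    print(row)
```

With a different SVL, only `svl_s` changes, which is exactly the "scalability" aspect the loop-based initialisation would validate.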

@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/vector-contract-i8mm-transform-dialect branch from 4ab7765 to 2d2bea9 Compare June 4, 2025 17:07
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/vector-contract-i8mm-integration branch from 87f2964 to a34bec5 Compare June 4, 2025 17:07
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/vector-contract-i8mm-transform-dialect branch from 2d2bea9 to d177af4 Compare June 6, 2025 09:10
@momchil-velikov momchil-velikov requested a review from ftynse as a code owner June 6, 2025 09:10
Base automatically changed from users/momchil-velikov/vector-contract-i8mm-transform-dialect to main June 6, 2025 09:54
@banach-space
Contributor

Thanks for all the updates, @momchil-velikov 🙏🏻

Just for context: Momchil and I discussed this offline - in its current form, the test crashes (after producing valid results). I’ve created a minimal repro here:

Given that these tests (i.e., SVE e2e tests that require emulation) are only run by clang-aarch64-full-2stage - which we closely monitor - I'm inclined to land this with a temporary workaround (on top of other updates suggested below).


Looking at the current version of the tests, I do think there’s room to improve code reuse and to make the test structure more VLA-friendly (Vector-Length Agnostic). For instance, the function below mixes fixed and scalable shapes:

func.func private @prepareAccTestData(%in: vector<4x4xi32>) -> vector<4x[4]xi32> {
 // ...  
}

It also returns a scalable vector from a function - something we haven’t independently tested and should probably avoid until validated.


Since this comment is getting long and I’m effectively suggesting a refactor (after many solid improvements already), I’ve gone ahead and rewritten the tests myself. Below is a single test file that combines:

  • contraction-smmla-4x8x4.mlir,
  • contraction-ummla-4x8x4.mlir,
  • contraction-summla-4x8x4.mlir,
  • contraction-usmmla-4x8x4.mlir.

Summary of changes:

To me, this format accomplishes 3 key goals:

  • Maximizes code reuse
  • Verifies codegen paths for all i8mm instructions
  • Smoke-tests run-time correctness under scalable vector lengths

If this makes sense to you, please feel free to reuse the file. Alternatively, I’d be happy to push it to your branch directly.

–Andrzej


NEW TEST FILE

// REQUIRES: arm-emulator

// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
// DEFINE: -o %t

// DEFINE: %{entry_point} = main

// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils

// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run}

#packed_maps = [
  affine_map<(d0, d1, d2) -> (d0, d2)>,
  affine_map<(d0, d1, d2) -> (d1, d2)>,
  affine_map<(d0, d1, d2) -> (d0, d1)>
]

//=============================================================================
// Helper methods to allocate+initialise test data
//=============================================================================
// Allocate and initialise a memref of 16 x vscale elements of type i32. This
// matches the requirements for the i8mm accumulator, which is precisely
//   * 4 x [4] elements.
func.func private @getFlatMemRef_i32() -> memref<?xi32> {
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %vscale = vector.vscale

  %c16_vscale = arith.muli %vscale, %c16 : index
  %flat_mem = memref.alloc(%c16_vscale) : memref<?xi32>
  %vector_i32 = llvm.intr.stepvector : vector<[16]xi32>

  vector.transfer_write %vector_i32, %flat_mem[%c0] : vector<[16]xi32>, memref<?xi32>
  return %flat_mem : memref<?xi32>
}

// Allocate and initialise a memref of 32 x vscale elements of type i8. This
// matches the requirements for the i8mm RHS, which is precisely
//   * [4] x 8 elements.
func.func private @getFlatMemRef_i8_scalable() -> memref<?xi8> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %vscale = vector.vscale

  %vscale_times_32 = arith.muli %vscale, %c32 : index
  %flat_mem = memref.alloc(%vscale_times_32) : memref<?xi8>
  %vector_i8 = llvm.intr.stepvector : vector<[32]xi8>

  vector.transfer_write %vector_i8, %flat_mem[%c0] : vector<[32]xi8>, memref<?xi8>
  return %flat_mem : memref<?xi8>
}

// Allocate and initialise a memref of 32 elements of type i8. This
// matches the requirements for the i8mm LHS, which is precisely:
//   * 4 x 8 elements.
func.func private @getFlatMemRef_i8_fixed() -> memref<?xi8> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index

  %flat_mem = memref.alloc(%c32) : memref<?xi8>
  %vector_i8 = llvm.intr.stepvector : vector<32xi8>

  vector.transfer_write %vector_i8, %flat_mem[%c0] : vector<32xi8>, memref<?xi8>
  return %flat_mem : memref<?xi8>
}

//=============================================================================
// Main entry point for test.
//=============================================================================
func.func @main() {
  // NOTE: Update this value to some other valid value of VL (i.e. supported by
  // SVE) to see the impact of "scalability".
  // FIXME: https://github.com/llvm/llvm-project/issues/143670
  %c128 = arith.constant 128 : i32
  func.call @setArmVLBits(%c128) : (i32) -> ()

  %c0_idx = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %c0_i8 = arith.constant 0 : i8

  //---------------------------------------------------------------------------
  // 1. GENERATE TEST DATA
  //---------------------------------------------------------------------------
  // 1.1. Accumulator test data
  %acc_flat = func.call @getFlatMemRef_i32() : () -> memref<?xi32>
  %flat_vec = vector.transfer_read %acc_flat[%c0_idx], %c0_i32 {in_bounds = [true]} : memref<?xi32>, vector<[16]xi32>
  %acc = vector.shape_cast %flat_vec : vector<[16]xi32> to vector<4x[4]xi32>

  // 1.2. LHS test data
  %lhs_flat = func.call @getFlatMemRef_i8_fixed() : () -> memref<?xi8>
  %lhs_flat_vec = vector.transfer_read %lhs_flat[%c0_idx], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<32xi8>
  %lhs = vector.shape_cast %lhs_flat_vec : vector<32xi8> to vector<4x8xi8>

  // 1.3. RHS test data
  %rhs_flat = func.call @getFlatMemRef_i8_scalable() : () -> memref<?xi8>
  %rhs_flat_vec = vector.transfer_read %rhs_flat[%c0_idx], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8>
  %rhs = vector.shape_cast %rhs_flat_vec : vector<[32]xi8> to vector<[4]x8xi8>

  //---------------------------------------------------------------------------
  // 2. "EXTEND" THE RHS + LHS VECTORS
  // This is what i8mm expects.
  //---------------------------------------------------------------------------
  %lhs_si = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
  %lhs_ui = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>

  %rhs_si = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
  %rhs_ui = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>

  //---------------------------------------------------------------------------
  // 3. MATRIX MULTIPLICATION
  //---------------------------------------------------------------------------
  // 3.1. SMMLA
  // CHECK-IR-COUNT-4: arm_sve.intr.smmla
  %res_smmla = vector.contract {indexing_maps = #packed_maps,
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>} %lhs_si, %rhs_si, %acc
    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>

  // 3.2. UMMLA
  // CHECK-IR-COUNT-4: arm_sve.intr.ummla
  %res_ummla = vector.contract {indexing_maps = #packed_maps,
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>} %lhs_ui, %rhs_ui, %acc
    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>

  // 3.3. USMMLA
  // CHECK-IR-COUNT-4: arm_sve.intr.usmmla
  %res_usmmla = vector.contract {indexing_maps = #packed_maps,
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>} %lhs_ui, %rhs_si, %acc
    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>

  // 3.4. SUMMLA
  // CHECK-IR-COUNT-4: arm_sve.intr.usmmla
  %res_summla = vector.contract {indexing_maps = #packed_maps,
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>} %lhs_si, %rhs_ui, %acc
    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>

  //---------------------------------------------------------------------------
  // 4. DISPLAY THE RESULTS OF THE MULTIPLICATION
  // TODO: Implement this and use instead:
  //    * vector.print %2 : vector<4x[4]xi32>
  //---------------------------------------------------------------------------
  vector.print str "RESULT (smmla):\n"
  %s0 = vector.extract %res_smmla[0] : vector<[4]xi32> from vector<4x[4]xi32>
  %s1 = vector.extract %res_smmla[1] : vector<[4]xi32> from vector<4x[4]xi32>
  %s2 = vector.extract %res_smmla[2] : vector<[4]xi32> from vector<4x[4]xi32>
  %s3 = vector.extract %res_smmla[3] : vector<[4]xi32> from vector<4x[4]xi32>
  vector.print %s0 : vector<[4]xi32>
  vector.print %s1 : vector<[4]xi32>
  vector.print %s2 : vector<[4]xi32>
  vector.print %s3 : vector<[4]xi32>

  vector.print str "RESULT (ummla):\n"
  %u0 = vector.extract %res_ummla[0] : vector<[4]xi32> from vector<4x[4]xi32>
  %u1 = vector.extract %res_ummla[1] : vector<[4]xi32> from vector<4x[4]xi32>
  %u2 = vector.extract %res_ummla[2] : vector<[4]xi32> from vector<4x[4]xi32>
  %u3 = vector.extract %res_ummla[3] : vector<[4]xi32> from vector<4x[4]xi32>
  vector.print %u0 : vector<[4]xi32>
  vector.print %u1 : vector<[4]xi32>
  vector.print %u2 : vector<[4]xi32>
  vector.print %u3 : vector<[4]xi32>

  vector.print str "RESULT (usmmla):\n"
  %us0 = vector.extract %res_usmmla[0] : vector<[4]xi32> from vector<4x[4]xi32>
  %us1 = vector.extract %res_usmmla[1] : vector<[4]xi32> from vector<4x[4]xi32>
  %us2 = vector.extract %res_usmmla[2] : vector<[4]xi32> from vector<4x[4]xi32>
  %us3 = vector.extract %res_usmmla[3] : vector<[4]xi32> from vector<4x[4]xi32>
  vector.print %us0 : vector<[4]xi32>
  vector.print %us1 : vector<[4]xi32>
  vector.print %us2 : vector<[4]xi32>
  vector.print %us3 : vector<[4]xi32>

  vector.print str "RESULT (summla):\n"
  %su0 = vector.extract %res_summla[0] : vector<[4]xi32> from vector<4x[4]xi32>
  %su1 = vector.extract %res_summla[1] : vector<[4]xi32> from vector<4x[4]xi32>
  %su2 = vector.extract %res_summla[2] : vector<[4]xi32> from vector<4x[4]xi32>
  %su3 = vector.extract %res_summla[3] : vector<[4]xi32> from vector<4x[4]xi32>
  vector.print %su0 : vector<[4]xi32>
  vector.print %su1 : vector<[4]xi32>
  vector.print %su2 : vector<[4]xi32>
  vector.print %su3 : vector<[4]xi32>

  // With all inputs non-negative, the results are identical for all four
  // extension combinations.
  // TODO: Use negative numbers to demonstrate the run-time difference between e.g. UMMLA and SMMLA.
  //CHECK-4: ( 140, 365, 590, 815, 1040, 1265, 1490, 1715 )
  //CHECK-4: ( 372, 1109, 1846, 2583, 3320, 4057, 4794, 5531 )
  //CHECK-4: ( 604, 1853, 3102, 4351, 5600, 6849, 8098, 9347 )
  //CHECK-4: ( 836, 2597, 4358, 6119, 7880, 9641, 11402, 13163 )

  //---------------------------------------------------------------------------
  // 5. WORKAROUND
  // This extra printing should not be required, but the test crashes without it.
  // FIXME: https://github.com/llvm/llvm-project/issues/143670
  //---------------------------------------------------------------------------
  %res_smmla_flat = vector.shape_cast %res_smmla : vector<4x[4]xi32> to vector<[16]xi32>
  vector.transfer_write %res_smmla_flat, %acc_flat[%c0_idx] : vector<[16]xi32>, memref<?xi32>
  %acc_cast = memref.cast %acc_flat : memref<?xi32> to memref<*xi32>
  call @printMemrefI32(%acc_cast) : (memref<*xi32>) -> ()

  //---------------------------------------------------------------------------
  // 6. BUFFER DEALLOCATION
  //---------------------------------------------------------------------------
  memref.dealloc %acc_flat : memref<?xi32>
  memref.dealloc %rhs_flat : memref<?xi8>
  memref.dealloc %lhs_flat : memref<?xi8>

  return
}

func.func private @printMemrefI32(%ptr : memref<*xi32>)
func.func private @setArmVLBits(%bits : i32)
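The commented-out expected values in the file above follow directly from the stepvector initialisation. A scalar reference model (a sketch in Python; it assumes vscale = 2, i.e. VL = 256, which is what produces the 8-column rows shown in the CHECK-4 comments):

```python
# Reference model for the stepvector-initialised i8mm test above.
# With all inputs non-negative, all four extension combinations agree:
# result[i][j] = acc[i][j] + sum_k lhs[i][k] * rhs[j][k]
vscale = 2          # assumed VL = 256 => the [4] scalable dim holds 8 lanes
rows = 4            # fixed M dimension
cols = 4 * vscale   # scalable N dimension
K = 8               # reduction dimension

# stepvector initialisation, reshaped exactly as the shape_casts do
acc = [[cols * i + j for j in range(cols)] for i in range(rows)]
lhs = [[K * i + k for k in range(K)] for i in range(rows)]
rhs = [[K * j + k for k in range(K)] for j in range(cols)]

result = [[acc[i][j] + sum(lhs[i][k] * rhs[j][k] for k in range(K))
           for j in range(cols)]
          for i in range(rows)]

for row in result:
    print(row)
```

The first printed row is `[140, 365, 590, 815, 1040, 1265, 1490, 1715]`, matching the first CHECK-4 comment; changing `vscale` shows how the expected output scales with VL.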

@momchil-velikov
Collaborator Author

momchil-velikov commented Jun 12, 2025

Thanks for the workaround, I applied it and finished the refactoring.

  1. Why is the scalable dimension always [4]?

Initially, no reason other than to keep the test data from eating too much screen space.
However, I'm not convinced of LLVM's ability to handle unnatural vector sizes, e.g. <vscale x 2 x i32> or <vscale x 6 x i32>.

Contributor

@banach-space banach-space left a comment


Great updates, thank you! I really like how this is checking both VL=128 and VL=256.

The overall logic LGTM, but since the underlying concepts are pretty complex (scalable vectors + i8mm), it deserves considerably more documentation. Suggestions inline.

Contributor

@banach-space banach-space left a comment


LGTM % some clarifications in the docs.

I appreciate that my comment is rather long, but that's just to clarify my view-point/rationale. The actual request is small. Also, to avoid confusion - I posted that in a thread that I "unresolved" (GitHub can be "funny" how it displays things).

[nit] Use m, n and k in affine maps as we do here https://mlir.llvm.org/docs/Dialects/Linalg/#linalgmatmul-linalgmatmulop (and in a few other places).

@momchil-velikov momchil-velikov merged commit 7eda827 into main Jun 17, 2025
7 checks passed
@momchil-velikov momchil-velikov deleted the users/momchil-velikov/vector-contract-i8mm-integration branch June 17, 2025 10:03
ajaden-codes pushed a commit to Jaddyen/llvm-project that referenced this pull request Jun 17, 2025