Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][AMDGPU] Adding Vector transfer_read to load rewrite pattern #131803

Merged
merged 5 commits into from
Mar 21, 2025

Conversation

jerryyin
Copy link
Member

@jerryyin jerryyin commented Mar 18, 2025

This PR adds the Vector transfer_read to load rewrite pattern. The pattern creates a transfer read op lowering. A vector trasfer read op will be lowered to a combination of vector.load, arith.select and vector.broadcast if:

  • The transfer op is masked.
  • The memref is in buffer address space.
  • Other conditions introduced from TransferReadToVectorLoadLowering

The motivation of this PR is due to the lack of support of masked load from amdgpu backend. llvm.intr.masked.load lower to a series of conditional scalar loads refer to (scalarize-masked-mem-intrin pass). This PR will make it possible for masked transfer_read to be lowered towards buffer load with bounds check, allowing a more optimized global load accessing pattern compared with existing implementation of llvm.intr.masked.load on vectors.

@llvmbot
Copy link
Member

llvmbot commented Mar 18, 2025

@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Zhuoran Yin (jerryyin)

Changes

This PR adds the Vector -> AMDGPU conversion lowering, including a single lowering pattern. The single lowering pattern creates a transfer read op lowering. A vector trasfer read op will be lowered to a combination of vector.load, arith.select and vector.broadcast if:

  • The transfer op is masked.
  • The memref is in buffer address space.
  • Other conditions introduced from TransferReadToVectorLoadLowering

The motivation of this PR is due to the lack of support of masked load from amdgpu backend. llvm.intr.masked.load lower to a series of conditional scalar loads refer to (scalarize-masked-mem-intrin pass). This PR will allow masked transfer_read to be lowered towards buffer load with bounds check, allowing a more optimized global load accessing pattern compared with existing implementation of llvm.intr.masked.load on vectors.


Full diff: https://github.com/llvm/llvm-project/pull/131803.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+10)
  • (added) mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h (+24)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt (+18)
  • (added) mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp (+147)
  • (added) mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir (+68)
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ccd862f67c068..ed5e8de8787f7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -73,6 +73,7 @@
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..1845d0235183e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1333,6 +1333,16 @@ def ConvertVectorToArmSMEPass : Pass<"convert-vector-to-arm-sme"> {
   let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToAMDGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToAMDGPUPass : Pass<"convert-vector-to-amdgpu"> {
+  let summary = "Lower the operations from the vector dialect into the AMDGPU "
+                "dialect";
+  let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+} 
+
 //===----------------------------------------------------------------------===//
 // ArmSMEToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h b/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
new file mode 100644
index 0000000000000..be96061a23b08
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h
@@ -0,0 +1,24 @@
+//===- VectorToAMDGPU.h - Vector to AMDGPU dialect conversion ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
+#define MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateVectorToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOAMDGPU_VECTORTOAMDGPU_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index b6c21440c571c..1e4cbd2be4c96 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -66,6 +66,7 @@ add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
 add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMDGPU)
 add_subdirectory(VectorToArmSME)
 add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..2ad46c26d0a57
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMDGPU/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRVectorToAMDGPU
+  VectorToAMDGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMDGPU
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRAMDGPUDialect
+  MLIRVectorDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
new file mode 100644
index 0000000000000..248b84a7fdc98
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMDGPU/VectorToAMDGPU.cpp
@@ -0,0 +1,147 @@
+//===- VectorToAMDGPU.cpp - Vector to AMDGPU dialect conversion ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMDGPU/VectorToAMDGPU.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// This pattern supports lowering of:
+/// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
+/// `vector.broadcast` if all of the following hold:
+/// - The transfer op is masked.
+/// - The memref is in buffer address space.
+/// - Stride of most minor memref dimension must be 1.
+/// - Out-of-bounds masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+///   result type.
+/// - The permutation map doesn't perform permutation (broadcasting is allowed).
+/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
+/// pass.
+static LogicalResult
+transferPreconditions(PatternRewriter &rewriter,
+                      VectorTransferOpInterface xferOp,
+                      SmallVector<unsigned> &broadcastedDims,
+                      VectorType &unbroadcastedVectorType) {
+  if (!xferOp.getMask())
+    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
+
+  // Permutations are handled by VectorToSCF or
+  // populateVectorTransferPermutationMapLoweringPatterns.
+  // We let the 0-d corner case pass-through as it is supported.
+  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
+          &broadcastedDims))
+    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
+
+  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
+  if (!memRefType)
+    return rewriter.notifyMatchFailure(xferOp, "not a memref source");
+
+  Attribute addrSpace = memRefType.getMemorySpace();
+  if (!addrSpace ||
+      llvm::dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+          amdgpu::AddressSpace::FatRawBuffer)
+    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
+
+  // Non-unit strides are handled by VectorToSCF.
+  if (!memRefType.isLastDimUnitStride())
+    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
+
+  // If there is broadcasting involved then we first load the unbroadcasted
+  // vector, and then broadcast it with `vector.broadcast`.
+  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
+  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
+  for (unsigned i : broadcastedDims)
+    unbroadcastedVectorShape[i] = 1;
+  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
+      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
+
+  // `vector.load` supports vector types as memref's elements only when the
+  // resulting vector type is the same as the element type.
+  auto memrefElTy = memRefType.getElementType();
+  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
+    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
+
+  // Otherwise, element types of the memref and the vector must match.
+  if (!isa<VectorType>(memrefElTy) &&
+      memrefElTy != xferOp.getVectorType().getElementType())
+    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
+
+  // Out-of-bounds dims are handled by MaterializeTransferMask.
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
+
+  if (xferOp.getVectorType().getRank() != 1)
+    // vector.maskedload operates on 1-D vectors.
+    return rewriter.notifyMatchFailure(
+        xferOp, "vector type is not rank 1, can't create masked load, needs "
+                "VectorToSCF");
+
+  return success();
+}
+
+struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    SmallVector<unsigned> broadcastedDims;
+    VectorType unbroadcastedVectorType;
+    if (failed(transferPreconditions(rewriter, readOp, broadcastedDims,
+                                     unbroadcastedVectorType))) {
+      return failure();
+    }
+
+    Value fill = rewriter.create<vector::SplatOp>(
+        readOp.getLoc(), unbroadcastedVectorType, readOp.getPadding());
+    Value load = rewriter.create<vector::LoadOp>(
+        readOp.getLoc(), unbroadcastedVectorType, readOp.getSource(),
+        readOp.getIndices());
+    Value res = rewriter.create<arith::SelectOp>(
+        readOp.getLoc(), unbroadcastedVectorType, readOp.getMask(), load, fill);
+
+    // Insert a broadcasting op if required.
+    if (!broadcastedDims.empty()) {
+      res = rewriter.create<vector::BroadcastOp>(readOp.getLoc(),
+                                                 readOp.getVectorType(), res);
+    }
+
+    rewriter.replaceOp(readOp, res);
+
+    return success();
+  }
+};
+
+void mlir::populateVectorToAMDGPUConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadLowering>(patterns.getContext());
+}
+
+struct ConvertVectorToAMDGPUPass
+    : public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToAMDGPUConversionPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
diff --git a/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
new file mode 100644
index 0000000000000..30d9814cc0621
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMDGPU/vector-transfer-read-to-vector-load.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -convert-vector-to-amdgpu --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_regular(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_broadcasting(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+#broadcast_1d = affine_map<(d0, d1) -> (0)>
+func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+    {in_bounds = [true], permutation_map = #broadcast_1d}
+      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
+// CHECK: return %[[BROADCAST]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_scalar(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+    {in_bounds = [true]}
+      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
+  return %res : vector<1xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
+// CHECK: return %[[SELECT]] : vector<1xf32>

@jerryyin jerryyin changed the title Adding Vector to AMDGPU conversion lowering [MLIR][AMDGPU] Adding Vector to AMDGPU conversion lowering Mar 18, 2025
@kuhar kuhar self-requested a review March 18, 2025 14:20
@jerryyin jerryyin force-pushed the users/zyin/create-vector-to-amdgpu-conversion branch 2 times, most recently from db53cdf to f7ca23b Compare March 19, 2025 13:51
@llvmbot llvmbot added the bazel "Peripheral" support tier build system: utils/bazel label Mar 19, 2025
@jerryyin jerryyin requested review from kuhar and krzysz00 March 20, 2025 13:04
@jerryyin jerryyin force-pushed the users/zyin/create-vector-to-amdgpu-conversion branch from 8eae773 to 01beca8 Compare March 20, 2025 21:29
@jerryyin jerryyin changed the title [MLIR][AMDGPU] Adding Vector to AMDGPU conversion lowering [MLIR][AMDGPU] Adding Vector transfer_read to load rewrite pattern Mar 21, 2025
@jerryyin jerryyin merged commit ea03bde into main Mar 21, 2025
11 checks passed
@jerryyin jerryyin deleted the users/zyin/create-vector-to-amdgpu-conversion branch March 21, 2025 12:42
@jerryyin jerryyin removed backend:AMDGPU bazel "Peripheral" support tier build system: utils/bazel labels Mar 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants