Skip to content

[MLIR][NVVM] Add prefetch Ops #141737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@ constexpr int kSharedMemoryAlignmentBit = 128;

/// NVVM memory space identifiers.
enum NVVMMemorySpace {
/// Generic memory space identifier.
kGenericMemorySpace = 0,
/// Global memory space identifier.
kGlobalMemorySpace = 1,
/// Shared memory space identifier.
kSharedMemorySpace = 3,
/// Constant memory space identifier.
kConstantMemorySpace = 4,
/// Local memory space identifier.
kLocalMemorySpace = 5,
/// Tensor memory space identifier.
/// Tensor memory is available only in arch-accelerated
/// variants from sm100 onwards.
Expand Down
74 changes: 74 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;

Expand Down Expand Up @@ -118,6 +119,25 @@ class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
let mnemonic = attrMnemonic;
}

// Cache Eviction Priority enum definitions
def EvictNormal : I32EnumCase<"EvictNormal", 0, "evict_normal">;
def EvictFirst : I32EnumCase<"EvictFirst", 1, "evict_first">;
def EvictLast : I32EnumCase<"EvictLast", 2, "evict_last">;
def EvictUnchanged : I32EnumCase<"EvictUnchanged", 3, "evict_unchanged">;
def NoAllocate : I32EnumCase<"NoAllocate", 4, "no_allocate">;

def CacheEvictionPriority : I32Enum<"CacheEvictionPriority",
"NVVM Cache Eviction Priority",
[EvictNormal, EvictFirst, EvictLast,
EvictUnchanged, NoAllocate]> {
let cppNamespace = "::mlir::NVVM";
}

def CacheEvictionPriorityAttr : EnumAttr<NVVM_Dialect, CacheEvictionPriority,
"cache_eviction_priority"> {
let assemblyFormat = "$value";
}

//===----------------------------------------------------------------------===//
// NVVM intrinsic operations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2333,6 +2353,60 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// NVVM Prefetch Op
//===----------------------------------------------------------------------===//

def PrefetchCacheLevelL1 : I32EnumCase<"L1", 0, "L1">;
def PrefetchCacheLevelL2 : I32EnumCase<"L2", 1, "L2">;

def PrefetchCacheLevel : I32Enum<"PrefetchCacheLevel",
"NVVM Prefetch Cache Level",
[PrefetchCacheLevelL1, PrefetchCacheLevelL2]> {
let cppNamespace = "::mlir::NVVM";
}

def PrefetchCacheLevelAttr : EnumAttr<NVVM_Dialect, PrefetchCacheLevel, "prefetch_cache_level"> {
let assemblyFormat = "$value";
}

def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
let summary = "Brings the cache line containing an address into the specified cache level";
let description = [{
Operand `addr` can be a global, local or generic address pointer. No
operation is performed if `addr` maps to a `shared` memory location.

The `cacheLevel` attribute specifies the cache level to which the cache line
containing the specified address is brought.

`uniform` can be specified after the `cacheLevel` to indicate that the
prefetch is performed to the specified uniform cache level. If `uniform` is
specified, `addr` must be a generic address pointer and no operation is
performed if `addr` maps to a `const`, `local`, or `shared` memory location.

The `evictPriority` attribute is optional and specifies the cache eviction
priority when `cacheLevel` is L2.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
}];
let arguments = (ins PrefetchCacheLevelAttr:$cacheLevel,
UnitAttr:$uniform,
AnyTypeOf<[LLVM_PointerGlobal,
LLVM_PointerLocal,
LLVM_PointerGeneric]>:$addr,
OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority);
let assemblyFormat = "`level` `=` $cacheLevel (`uniform` $uniform^)? `,` $addr (`,` `evict_priority` `=` $evictPriority^)? attr-dict `:` type($addr)";
let hasVerifier = 1;

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::PrefetchOp &op);
}];
let llvmBuilder = [{
auto intId = NVVM::PrefetchOp::getIntrinsicID(op);
createIntrinsicCall(builder, intId, $addr);
}];
}

def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {
Expand Down
78 changes: 78 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,42 @@ LogicalResult NVVM::VoteSyncOp::verify() {
return success();
}

LogicalResult NVVM::PrefetchOp::verify() {
using MemSpace = NVVM::NVVMMemorySpace;
using CacheLevel = NVVM::PrefetchCacheLevel;

unsigned addressSpace =
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();

if (getUniform()) {
if (getCacheLevel() != CacheLevel::L1)
return emitOpError("unsupported cache level, the only supported uniform "
"cache level is L1");

if (addressSpace != MemSpace::kGenericMemorySpace)
return emitOpError(
"prefetch to uniform cache requires a generic pointer");
}

if (evictPriority) {
if (getCacheLevel() != CacheLevel::L2)
return emitOpError(
"cache eviction priority supported only for cache level L2");

if (addressSpace != MemSpace::kGlobalMemorySpace)
return emitOpError("cache eviction priority requires a global pointer");

if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
*evictPriority != NVVM::CacheEvictionPriority::EvictLast)
return emitOpError(
"unsupported cache eviction priority, only evict_last and "
"evict_normal are supported");
}

return success();
}

/// Packs the given `field` into the `result`.
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
static llvm::Value *
Expand Down Expand Up @@ -1734,6 +1770,48 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
return {ids[type], args};
}

llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
using MemSpace = NVVM::NVVMMemorySpace;
using CacheLevel = NVVM::PrefetchCacheLevel;

NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
std::optional<NVVM::CacheEvictionPriority> evictPriority =
op.getEvictPriority();
unsigned addressSpace =
llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
.getAddressSpace();

if (op.getUniform() && cacheLevel == CacheLevel::L1)
return llvm::Intrinsic::nvvm_prefetchu_L1;

if (evictPriority && cacheLevel == CacheLevel::L2) {
switch (*evictPriority) {
case NVVM::CacheEvictionPriority::EvictLast:
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
case NVVM::CacheEvictionPriority::EvictNormal:
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
default:
llvm_unreachable("Invalid cache eviction priority");
}
}

switch (addressSpace) {
case MemSpace::kGenericMemorySpace:
return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
: llvm::Intrinsic::nvvm_prefetch_L2;
case MemSpace::kGlobalMemorySpace:
return cacheLevel == CacheLevel::L1
? llvm::Intrinsic::nvvm_prefetch_global_L1
: llvm::Intrinsic::nvvm_prefetch_global_L2;
case MemSpace::kLocalMemorySpace:
return cacheLevel == CacheLevel::L1
? llvm::Intrinsic::nvvm_prefetch_local_L1
: llvm::Intrinsic::nvvm_prefetch_local_L2;
default:
llvm_unreachable("Invalid pointer address space");
}
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Dialect/LLVMIR/nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,29 @@ func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c:
return
}

// CHECK-LABEL: @prefetch
func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
// CHECK: nvvm.prefetch level = L1, %{{.*}}
nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0>
// CHECK: nvvm.prefetch level = L1, %{{.*}}
nvvm.prefetch level = L1, %local_ptr : !llvm.ptr<5>
// CHECK: nvvm.prefetch level = L1, %{{.*}}
nvvm.prefetch level = L1, %global_ptr : !llvm.ptr<1>
// CHECK: nvvm.prefetch level = L2, %{{.*}}
nvvm.prefetch level = L2, %gen_ptr : !llvm.ptr<0>
// CHECK: nvvm.prefetch level = L2, %{{.*}}
nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5>
// CHECK: nvvm.prefetch level = L2, %{{.*}}
nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1>
// CHECK: nvvm.prefetch level = L2, %{{.*}}
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
// CHECK: nvvm.prefetch level = L2, %{{.*}}
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
// CHECK: nvvm.prefetch level = L1 uniform, %{{.*}}
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
return
}

// -----

// Just check these don't emit errors.
Expand Down
47 changes: 47 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

llvm.func @prefetch_L1(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
// CHECK-LABEL: define void @prefetch_L1(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
// CHECK-NEXT: call void @llvm.nvvm.prefetch.L1(ptr %0)
// CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L1(ptr addrspace(5) %1)
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L1(ptr addrspace(1) %2)
// CHECK-NEXT: ret void
// CHECK-NEXT: }
nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0>
nvvm.prefetch level = L1, %local_ptr : !llvm.ptr<5>
nvvm.prefetch level = L1, %global_ptr : !llvm.ptr<1>
llvm.return
}

llvm.func @prefetch_L2(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
// CHECK-LABEL: define void @prefetch_L2(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
// CHECK-NEXT: call void @llvm.nvvm.prefetch.L2(ptr %0)
// CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L2(ptr addrspace(5) %1)
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2(ptr addrspace(1) %2)
// CHECK-NEXT: ret void
// CHECK-NEXT: }
nvvm.prefetch level = L2, %gen_ptr : !llvm.ptr<0>
nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5>
nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1>
llvm.return
}

llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) {
// CHECK-LABEL: define void @prefetch_L2_eviction_priority(ptr addrspace(1) %0) {
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.last(ptr addrspace(1) %0)
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0)
// CHECK-NEXT: ret void
// CHECK-NEXT: }
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
llvm.return
}

llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
// CHECK-LABEL: define void @prefetch_L1_uniform(ptr %0) {
// CHECK-NEXT: call void @llvm.nvvm.prefetchu.L1(ptr %0)
// CHECK-NEXT: ret void
// CHECK-NEXT: }
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
llvm.return
}
64 changes: 64 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,67 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
%res = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
llvm.return
}

// -----

llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
llvm.return
}

// -----

llvm.func @nvvm_prefetch_L2_with_evict_last_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
// expected-error @below {{cache eviction priority requires a global pointer}}
nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_last : !llvm.ptr<5>
llvm.return
}

// -----

llvm.func @nvvm_prefetch_L2_with_evict_normal_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
// expected-error @below {{cache eviction priority requires a global pointer}}
nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_normal : !llvm.ptr<5>
llvm.return
}

// -----

llvm.func @nvvm_prefetch_L2_with_invalid_evict_first(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_first : !llvm.ptr<1>
llvm.return
}

// -----

llvm.func @nvvm_prefetch_L2_with_invalid_evict_unchanged(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_unchanged : !llvm.ptr<1>
llvm.return
}

// -----

llvm.func @nvvm_prefetch_L2_with_invalid_no_allocate(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
nvvm.prefetch level = L2, %global_ptr, evict_priority = no_allocate : !llvm.ptr<1>
llvm.return
}

// -----

llvm.func @nvvm_prefetch_uniform_with_L2(%gen_ptr: !llvm.ptr) {
// expected-error @below {{unsupported cache level, the only supported uniform cache level is L1}}
nvvm.prefetch level = L2 uniform, %gen_ptr : !llvm.ptr
llvm.return
}

// -----

llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{prefetch to uniform cache requires a generic pointer}}
nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1>
llvm.return
}