Skip to content

Commit 158beb1

Browse files
Wolfram70rorth
authored andcommitted
[MLIR][NVVM] Add prefetch Ops (llvm#141737)
This change adds `prefetch` and `prefetch.uniform` Ops to the NVVM dialect for the `prefetch` and `prefetchu` group of instructions. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu
1 parent dd7008a commit 158beb1

File tree

6 files changed

+290
-0
lines changed

6 files changed

+290
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,16 @@ constexpr int kSharedMemoryAlignmentBit = 128;
3636

3737
/// NVVM memory space identifiers.
3838
enum NVVMMemorySpace {
39+
/// Generic memory space identifier.
40+
kGenericMemorySpace = 0,
3941
/// Global memory space identifier.
4042
kGlobalMemorySpace = 1,
4143
/// Shared memory space identifier.
4244
kSharedMemorySpace = 3,
4345
/// Constant memory space identifier.
4446
kConstantMemorySpace = 4,
47+
/// Local memory space identifier.
48+
kLocalMemorySpace = 5,
4549
/// Tensor memory space identifier.
4650
/// Tensor memory is available only in arch-accelerated
4751
/// variants from sm100 onwards.

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
2525
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
2626
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
2727
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
28+
def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
2829
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
2930
def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
3031

@@ -118,6 +119,25 @@ class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
118119
let mnemonic = attrMnemonic;
119120
}
120121

122+
// Cache Eviction Priority enum definitions
123+
def EvictNormal : I32EnumCase<"EvictNormal", 0, "evict_normal">;
124+
def EvictFirst : I32EnumCase<"EvictFirst", 1, "evict_first">;
125+
def EvictLast : I32EnumCase<"EvictLast", 2, "evict_last">;
126+
def EvictUnchanged : I32EnumCase<"EvictUnchanged", 3, "evict_unchanged">;
127+
def NoAllocate : I32EnumCase<"NoAllocate", 4, "no_allocate">;
128+
129+
def CacheEvictionPriority : I32Enum<"CacheEvictionPriority",
130+
"NVVM Cache Eviction Priority",
131+
[EvictNormal, EvictFirst, EvictLast,
132+
EvictUnchanged, NoAllocate]> {
133+
let cppNamespace = "::mlir::NVVM";
134+
}
135+
136+
def CacheEvictionPriorityAttr : EnumAttr<NVVM_Dialect, CacheEvictionPriority,
137+
"cache_eviction_priority"> {
138+
let assemblyFormat = "$value";
139+
}
140+
121141
//===----------------------------------------------------------------------===//
122142
// NVVM intrinsic operations
123143
//===----------------------------------------------------------------------===//
@@ -2333,6 +2353,60 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
23332353
let hasVerifier = 1;
23342354
}
23352355

2356+
//===----------------------------------------------------------------------===//
2357+
// NVVM Prefetch Op
2358+
//===----------------------------------------------------------------------===//
2359+
2360+
def PrefetchCacheLevelL1 : I32EnumCase<"L1", 0, "L1">;
2361+
def PrefetchCacheLevelL2 : I32EnumCase<"L2", 1, "L2">;
2362+
2363+
def PrefetchCacheLevel : I32Enum<"PrefetchCacheLevel",
2364+
"NVVM Prefetch Cache Level",
2365+
[PrefetchCacheLevelL1, PrefetchCacheLevelL2]> {
2366+
let cppNamespace = "::mlir::NVVM";
2367+
}
2368+
2369+
def PrefetchCacheLevelAttr : EnumAttr<NVVM_Dialect, PrefetchCacheLevel, "prefetch_cache_level"> {
2370+
let assemblyFormat = "$value";
2371+
}
2372+
2373+
def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
2374+
let summary = "Brings the cache line containing an address into the specified cache level";
2375+
let description = [{
2376+
Operand `addr` can be a global, local or generic address pointer. No
2377+
operation is performed if `addr` maps to a `shared` memory location.
2378+
2379+
The `cacheLevel` attribute specifies the cache level to which the cache line
2380+
containing the specified address is brought.
2381+
2382+
`uniform` can be specified after the `cacheLevel` to indicate that the
2383+
prefetch is performed to the specified uniform cache level. If `uniform` is
2384+
specified, `addr` must be a generic address pointer and no operation is
2385+
performed if `addr` maps to a `const`, `local`, or `shared` memory location.
2386+
2387+
The `evictPriority` attribute is optional and specifies the cache eviction
2388+
priority when `cacheLevel` is L2.
2389+
2390+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
2391+
}];
2392+
let arguments = (ins PrefetchCacheLevelAttr:$cacheLevel,
2393+
UnitAttr:$uniform,
2394+
AnyTypeOf<[LLVM_PointerGlobal,
2395+
LLVM_PointerLocal,
2396+
LLVM_PointerGeneric]>:$addr,
2397+
OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority);
2398+
let assemblyFormat = "`level` `=` $cacheLevel (`uniform` $uniform^)? `,` $addr (`,` `evict_priority` `=` $evictPriority^)? attr-dict `:` type($addr)";
2399+
let hasVerifier = 1;
2400+
2401+
let extraClassDeclaration = [{
2402+
static llvm::Intrinsic::ID getIntrinsicID(NVVM::PrefetchOp &op);
2403+
}];
2404+
let llvmBuilder = [{
2405+
auto intId = NVVM::PrefetchOp::getIntrinsicID(op);
2406+
createIntrinsicCall(builder, intId, $addr);
2407+
}];
2408+
}
2409+
23362410
def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
23372411
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
23382412
Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,42 @@ LogicalResult NVVM::VoteSyncOp::verify() {
12051205
return success();
12061206
}
12071207

1208+
LogicalResult NVVM::PrefetchOp::verify() {
1209+
using MemSpace = NVVM::NVVMMemorySpace;
1210+
using CacheLevel = NVVM::PrefetchCacheLevel;
1211+
1212+
unsigned addressSpace =
1213+
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1214+
std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1215+
1216+
if (getUniform()) {
1217+
if (getCacheLevel() != CacheLevel::L1)
1218+
return emitOpError("unsupported cache level, the only supported uniform "
1219+
"cache level is L1");
1220+
1221+
if (addressSpace != MemSpace::kGenericMemorySpace)
1222+
return emitOpError(
1223+
"prefetch to uniform cache requires a generic pointer");
1224+
}
1225+
1226+
if (evictPriority) {
1227+
if (getCacheLevel() != CacheLevel::L2)
1228+
return emitOpError(
1229+
"cache eviction priority supported only for cache level L2");
1230+
1231+
if (addressSpace != MemSpace::kGlobalMemorySpace)
1232+
return emitOpError("cache eviction priority requires a global pointer");
1233+
1234+
if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1235+
*evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1236+
return emitOpError(
1237+
"unsupported cache eviction priority, only evict_last and "
1238+
"evict_normal are supported");
1239+
}
1240+
1241+
return success();
1242+
}
1243+
12081244
/// Packs the given `field` into the `result`.
12091245
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
12101246
static llvm::Value *
@@ -1734,6 +1770,48 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
17341770
return {ids[type], args};
17351771
}
17361772

1773+
llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
1774+
using MemSpace = NVVM::NVVMMemorySpace;
1775+
using CacheLevel = NVVM::PrefetchCacheLevel;
1776+
1777+
NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
1778+
std::optional<NVVM::CacheEvictionPriority> evictPriority =
1779+
op.getEvictPriority();
1780+
unsigned addressSpace =
1781+
llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
1782+
.getAddressSpace();
1783+
1784+
if (op.getUniform() && cacheLevel == CacheLevel::L1)
1785+
return llvm::Intrinsic::nvvm_prefetchu_L1;
1786+
1787+
if (evictPriority && cacheLevel == CacheLevel::L2) {
1788+
switch (*evictPriority) {
1789+
case NVVM::CacheEvictionPriority::EvictLast:
1790+
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
1791+
case NVVM::CacheEvictionPriority::EvictNormal:
1792+
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
1793+
default:
1794+
llvm_unreachable("Invalid cache eviction priority");
1795+
}
1796+
}
1797+
1798+
switch (addressSpace) {
1799+
case MemSpace::kGenericMemorySpace:
1800+
return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
1801+
: llvm::Intrinsic::nvvm_prefetch_L2;
1802+
case MemSpace::kGlobalMemorySpace:
1803+
return cacheLevel == CacheLevel::L1
1804+
? llvm::Intrinsic::nvvm_prefetch_global_L1
1805+
: llvm::Intrinsic::nvvm_prefetch_global_L2;
1806+
case MemSpace::kLocalMemorySpace:
1807+
return cacheLevel == CacheLevel::L1
1808+
? llvm::Intrinsic::nvvm_prefetch_local_L1
1809+
: llvm::Intrinsic::nvvm_prefetch_local_L2;
1810+
default:
1811+
llvm_unreachable("Invalid pointer address space");
1812+
}
1813+
}
1814+
17371815
//===----------------------------------------------------------------------===//
17381816
// NVVMDialect initialization, type parsing, and registration.
17391817
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,29 @@ func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c:
596596
return
597597
}
598598

599+
// CHECK-LABEL: @prefetch
600+
func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
601+
// CHECK: nvvm.prefetch level = L1, %{{.*}}
602+
nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0>
603+
// CHECK: nvvm.prefetch level = L1, %{{.*}}
604+
nvvm.prefetch level = L1, %local_ptr : !llvm.ptr<5>
605+
// CHECK: nvvm.prefetch level = L1, %{{.*}}
606+
nvvm.prefetch level = L1, %global_ptr : !llvm.ptr<1>
607+
// CHECK: nvvm.prefetch level = L2, %{{.*}}
608+
nvvm.prefetch level = L2, %gen_ptr : !llvm.ptr<0>
609+
// CHECK: nvvm.prefetch level = L2, %{{.*}}
610+
nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5>
611+
// CHECK: nvvm.prefetch level = L2, %{{.*}}
612+
nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1>
613+
// CHECK: nvvm.prefetch level = L2, %{{.*}}
614+
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
615+
// CHECK: nvvm.prefetch level = L2, %{{.*}}
616+
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
617+
// CHECK: nvvm.prefetch level = L1 uniform, %{{.*}}
618+
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
619+
return
620+
}
621+
599622
// -----
600623

601624
// Just check these don't emit errors.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
llvm.func @prefetch_L1(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
4+
// CHECK-LABEL: define void @prefetch_L1(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
5+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.L1(ptr %0)
6+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L1(ptr addrspace(5) %1)
7+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L1(ptr addrspace(1) %2)
8+
// CHECK-NEXT: ret void
9+
// CHECK-NEXT: }
10+
nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0>
11+
nvvm.prefetch level = L1, %local_ptr : !llvm.ptr<5>
12+
nvvm.prefetch level = L1, %global_ptr : !llvm.ptr<1>
13+
llvm.return
14+
}
15+
16+
llvm.func @prefetch_L2(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
17+
// CHECK-LABEL: define void @prefetch_L2(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
18+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.L2(ptr %0)
19+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L2(ptr addrspace(5) %1)
20+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2(ptr addrspace(1) %2)
21+
// CHECK-NEXT: ret void
22+
// CHECK-NEXT: }
23+
nvvm.prefetch level = L2, %gen_ptr : !llvm.ptr<0>
24+
nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5>
25+
nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1>
26+
llvm.return
27+
}
28+
29+
llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) {
30+
// CHECK-LABEL: define void @prefetch_L2_eviction_priority(ptr addrspace(1) %0) {
31+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.last(ptr addrspace(1) %0)
32+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0)
33+
// CHECK-NEXT: ret void
34+
// CHECK-NEXT: }
35+
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
36+
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
37+
llvm.return
38+
}
39+
40+
llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
41+
// CHECK-LABEL: define void @prefetch_L1_uniform(ptr %0) {
42+
// CHECK-NEXT: call void @llvm.nvvm.prefetchu.L1(ptr %0)
43+
// CHECK-NEXT: ret void
44+
// CHECK-NEXT: }
45+
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
46+
llvm.return
47+
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,67 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
248248
%res = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
249249
llvm.return
250250
}
251+
252+
// -----
253+
254+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
255+
// expected-error @below {{cache eviction priority supported only for cache level L2}}
256+
nvvm.prefetch level = L1, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
257+
llvm.return
258+
}
259+
260+
// -----
261+
262+
llvm.func @nvvm_prefetch_L2_with_evict_last_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
263+
// expected-error @below {{cache eviction priority requires a global pointer}}
264+
nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_last : !llvm.ptr<5>
265+
llvm.return
266+
}
267+
268+
// -----
269+
270+
llvm.func @nvvm_prefetch_L2_with_evict_normal_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
271+
// expected-error @below {{cache eviction priority requires a global pointer}}
272+
nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_normal : !llvm.ptr<5>
273+
llvm.return
274+
}
275+
276+
// -----
277+
278+
llvm.func @nvvm_prefetch_L2_with_invalid_evict_first(%global_ptr: !llvm.ptr<1>) {
279+
// expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
280+
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_first : !llvm.ptr<1>
281+
llvm.return
282+
}
283+
284+
// -----
285+
286+
llvm.func @nvvm_prefetch_L2_with_invalid_evict_unchanged(%global_ptr: !llvm.ptr<1>) {
287+
// expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
288+
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_unchanged : !llvm.ptr<1>
289+
llvm.return
290+
}
291+
292+
// -----
293+
294+
llvm.func @nvvm_prefetch_L2_with_invalid_no_allocate(%global_ptr: !llvm.ptr<1>) {
295+
// expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
296+
nvvm.prefetch level = L2, %global_ptr, evict_priority = no_allocate : !llvm.ptr<1>
297+
llvm.return
298+
}
299+
300+
// -----
301+
302+
llvm.func @nvvm_prefetch_uniform_with_L2(%gen_ptr: !llvm.ptr) {
303+
// expected-error @below {{unsupported cache level, the only supported uniform cache level is L1}}
304+
nvvm.prefetch level = L2 uniform, %gen_ptr : !llvm.ptr
305+
llvm.return
306+
}
307+
308+
// -----
309+
310+
llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<1>) {
311+
// expected-error @below {{prefetch to uniform cache requires a generic pointer}}
312+
nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1>
313+
llvm.return
314+
}

0 commit comments

Comments
 (0)