Skip to content

Commit 9e4d64b

Browse files
committed
[MLIR][NVVM] Add prefetch Ops
This change adds `prefetch.L1`, `prefetch.L2`, and `prefetch.L1.uniform` Ops to the NVVM dialect for the `prefetch` and `prefetchu` group of instructions. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu
1 parent 663aea2 commit 9e4d64b

File tree

5 files changed

+208
-0
lines changed

5 files changed

+208
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
2525
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
2626
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
2727
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
28+
def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
2829
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
2930
def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
3031

@@ -2333,6 +2334,90 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
23332334
let hasVerifier = 1;
23342335
}
23352336

2337+
//===----------------------------------------------------------------------===//
2338+
// NVVM Prefetch Ops
2339+
//===----------------------------------------------------------------------===//
2340+
2341+
def NVVM_PrefetchL1Op : NVVM_Op<"prefetch.L1"> {
2342+
let description = [{
2343+
Brings the cache line containing the specified address into L1 cache.
2344+
2345+
Operand `ptr` can be a global, local or generic address pointer.
2346+
No operation is performed if `ptr` maps to a `shared` memory location.
2347+
2348+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
2349+
}];
2350+
let arguments = (ins AnyTypeOf<[LLVM_PointerGlobal,
2351+
LLVM_PointerLocal,
2352+
LLVM_PointerGeneric]>:$ptr);
2353+
let assemblyFormat = "$ptr attr-dict `:` type($ptr)";
2354+
2355+
let extraClassDeclaration = [{
2356+
static llvm::Intrinsic::ID getIntrinsicID(llvm::Type *ptrType);
2357+
}];
2358+
let llvmBuilder = [{
2359+
auto intId = NVVM::PrefetchL1Op::getIntrinsicID($ptr->getType());
2360+
createIntrinsicCall(builder, intId, $ptr);
2361+
}];
2362+
}
2363+
2364+
def EvictLast : I32EnumAttrCase<"EvictLast", 0, "evict_last">;
2365+
def EvictNormal : I32EnumAttrCase<"EvictNormal", 1, "evict_normal">;
2366+
2367+
def EvictionPriority : I32EnumAttr<"EvictionPriority", "NVVM Eviction Priority",
2368+
[EvictLast, EvictNormal]> {
2369+
let genSpecializedAttr = 0;
2370+
let cppNamespace = "::mlir::NVVM";
2371+
}
2372+
2373+
def EvictionPriorityAttr : EnumAttr<NVVM_Dialect, EvictionPriority, "eviction_priority"> {
2374+
let assemblyFormat = "$value";
2375+
}
2376+
2377+
def NVVM_PrefetchL2Op : NVVM_Op<"prefetch.L2"> {
2378+
let description = [{
2379+
Brings the cache line containing the specified address into L2 cache.
2380+
2381+
Operand `ptr` can be a global, local or generic address pointer.
2382+
No operation is performed if `ptr` maps to a `shared` memory location.
2383+
2384+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
2385+
}];
2386+
let arguments = (ins AnyTypeOf<[LLVM_PointerGlobal,
2387+
LLVM_PointerLocal,
2388+
LLVM_PointerGeneric]>:$ptr,
2389+
OptionalAttr<EvictionPriorityAttr>:$evictionPriority);
2390+
let assemblyFormat = "$ptr (`,` `evict_priority` `=` $evictionPriority^)? attr-dict `:` type($ptr)";
2391+
let hasVerifier = 1;
2392+
2393+
let extraClassDeclaration = [{
2394+
static llvm::Intrinsic::ID getIntrinsicID(llvm::Type *ptrType, std::optional<NVVM::EvictionPriority> evictionPriority);
2395+
}];
2396+
let llvmBuilder = [{
2397+
auto intId = NVVM::PrefetchL2Op::getIntrinsicID($ptr->getType(), $evictionPriority);
2398+
createIntrinsicCall(builder, intId, $ptr);
2399+
}];
2400+
}
2401+
2402+
def NVVM_PrefetchL1UniformOp : NVVM_Op<"prefetch.L1.uniform"> {
2403+
let description = [{
2404+
Brings the cache line containing the specified address into L1 uniform
2405+
cache.
2406+
2407+
Operand `ptr` is a generic address pointer.
2408+
No operation is performed if `ptr` maps to a `const`, `local`, or `shared`
2409+
memory location.
2410+
2411+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
2412+
}];
2413+
let arguments = (ins LLVM_PointerGeneric:$ptr);
2414+
let assemblyFormat = "$ptr attr-dict `:` type($ptr)";
2415+
2416+
let llvmBuilder = [{
2417+
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_prefetchu_L1, $ptr);
2418+
}];
2419+
}
2420+
23362421
def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
23372422
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
23382423
Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,15 @@ LogicalResult NVVM::VoteSyncOp::verify() {
12051205
return success();
12061206
}
12071207

1208+
LogicalResult NVVM::PrefetchL2Op::verify() {
1209+
if (getEvictionPriority() &&
1210+
(llvm::cast<LLVM::LLVMPointerType>(getPtr().getType())
1211+
.getAddressSpace() != 1))
1212+
return emitOpError(
1213+
"prefetch with eviction priority requires a global pointer");
1214+
return success();
1215+
}
1216+
12081217
/// Packs the given `field` into the `result`.
12091218
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
12101219
static llvm::Value *
@@ -1712,6 +1721,42 @@ NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
17121721
return {ids[type], args};
17131722
}
17141723

1724+
llvm::Intrinsic::ID PrefetchL1Op::getIntrinsicID(llvm::Type *ptrType) {
1725+
switch (ptrType->getPointerAddressSpace()) {
1726+
case 0:
1727+
return llvm::Intrinsic::nvvm_prefetch_L1;
1728+
case 1:
1729+
return llvm::Intrinsic::nvvm_prefetch_global_L1;
1730+
case 5:
1731+
return llvm::Intrinsic::nvvm_prefetch_local_L1;
1732+
default:
1733+
llvm_unreachable("Invalid pointer address space");
1734+
}
1735+
}
1736+
1737+
llvm::Intrinsic::ID PrefetchL2Op::getIntrinsicID(
1738+
llvm::Type *ptrType,
1739+
std::optional<NVVM::EvictionPriority> evictionPriority) {
1740+
switch (ptrType->getPointerAddressSpace()) {
1741+
case 0:
1742+
return llvm::Intrinsic::nvvm_prefetch_L2;
1743+
case 1:
1744+
if (evictionPriority) {
1745+
if (*evictionPriority == NVVM::EvictionPriority::EvictLast)
1746+
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
1747+
else if (*evictionPriority == NVVM::EvictionPriority::EvictNormal)
1748+
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
1749+
else
1750+
llvm_unreachable("Invalid eviction priority");
1751+
}
1752+
return llvm::Intrinsic::nvvm_prefetch_global_L2;
1753+
case 5:
1754+
return llvm::Intrinsic::nvvm_prefetch_local_L2;
1755+
default:
1756+
llvm_unreachable("Invalid pointer address space");
1757+
}
1758+
}
1759+
17151760
//===----------------------------------------------------------------------===//
17161761
// NVVMDialect initialization, type parsing, and registration.
17171762
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,29 @@ func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i
587587
return
588588
}
589589

590+
// CHECK-LABEL: @prefetch
591+
func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
592+
// CHECK: nvvm.prefetch.L1 %{{.*}}
593+
nvvm.prefetch.L1 %gen_ptr : !llvm.ptr<0>
594+
// CHECK: nvvm.prefetch.L1 %{{.*}}
595+
nvvm.prefetch.L1 %local_ptr : !llvm.ptr<5>
596+
// CHECK: nvvm.prefetch.L1 %{{.*}}
597+
nvvm.prefetch.L1 %global_ptr : !llvm.ptr<1>
598+
// CHECK: nvvm.prefetch.L2 %{{.*}}
599+
nvvm.prefetch.L2 %gen_ptr : !llvm.ptr<0>
600+
// CHECK: nvvm.prefetch.L2 %{{.*}}
601+
nvvm.prefetch.L2 %local_ptr : !llvm.ptr<5>
602+
// CHECK: nvvm.prefetch.L2 %{{.*}}
603+
nvvm.prefetch.L2 %global_ptr : !llvm.ptr<1>
604+
// CHECK: nvvm.prefetch.L2 %{{.*}}
605+
nvvm.prefetch.L2 %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
606+
// CHECK: nvvm.prefetch.L2 %{{.*}}
607+
nvvm.prefetch.L2 %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
608+
// CHECK: nvvm.prefetch.L1.uniform %{{.*}}
609+
nvvm.prefetch.L1.uniform %gen_ptr : !llvm.ptr
610+
return
611+
}
612+
590613
// -----
591614

592615
// Just check these don't emit errors.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
llvm.func @prefetch_L1(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
4+
// CHECK-LABEL: define void @prefetch_L1(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
5+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.L1(ptr %0)
6+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L1(ptr addrspace(5) %1)
7+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L1(ptr addrspace(1) %2)
8+
// CHECK-NEXT: ret void
9+
// CHECK-NEXT: }
10+
nvvm.prefetch.L1 %gen_ptr : !llvm.ptr<0>
11+
nvvm.prefetch.L1 %local_ptr : !llvm.ptr<5>
12+
nvvm.prefetch.L1 %global_ptr : !llvm.ptr<1>
13+
llvm.return
14+
}
15+
16+
llvm.func @prefetch_L2(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
17+
// CHECK-LABEL: define void @prefetch_L2(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
18+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.L2(ptr %0)
19+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L2(ptr addrspace(5) %1)
20+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2(ptr addrspace(1) %2)
21+
// CHECK-NEXT: ret void
22+
// CHECK-NEXT: }
23+
nvvm.prefetch.L2 %gen_ptr : !llvm.ptr<0>
24+
nvvm.prefetch.L2 %local_ptr : !llvm.ptr<5>
25+
nvvm.prefetch.L2 %global_ptr : !llvm.ptr<1>
26+
llvm.return
27+
}
28+
29+
llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) {
30+
// CHECK-LABEL: define void @prefetch_L2_eviction_priority(ptr addrspace(1) %0) {
31+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.last(ptr addrspace(1) %0)
32+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0)
33+
// CHECK-NEXT: ret void
34+
// CHECK-NEXT: }
35+
nvvm.prefetch.L2 %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
36+
nvvm.prefetch.L2 %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
37+
llvm.return
38+
}
39+
40+
llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
41+
// CHECK-LABEL: define void @prefetch_L1_uniform(ptr %0) {
42+
// CHECK-NEXT: call void @llvm.nvvm.prefetchu.L1(ptr %0)
43+
// CHECK-NEXT: ret void
44+
// CHECK-NEXT: }
45+
nvvm.prefetch.L1.uniform %gen_ptr : !llvm.ptr
46+
llvm.return
47+
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,11 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
248248
%res = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
249249
llvm.return
250250
}
251+
252+
// -----
253+
254+
llvm.func @nvvm_prefetch_L2_with_evict_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
255+
// expected-error @below {{prefetch with eviction priority requires a global pointer}}
256+
nvvm.prefetch.L2 %local_ptr, evict_priority = evict_last : !llvm.ptr<5>
257+
llvm.return
258+
}

0 commit comments

Comments
 (0)