Skip to content

Commit 1aed6ad

Browse files
jsjodinskatrak
andauthored
[MLIR][OpenMP] Enable multiple variables for target teams reductions (llvm#134903)
This patch enables multiple reductions to be used in a reduction clause inside target regions for GPU offloading. --------- Co-authored-by: Sergio Afonso <safonsof@amd.com>
1 parent 40f9bb9 commit 1aed6ad

File tree

2 files changed

+137
-6
lines changed

2 files changed

+137
-6
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4690,12 +4690,18 @@ static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
46904690
template <typename OpTy>
46914691
static uint64_t getReductionDataSize(OpTy &op) {
46924692
if (op.getNumReductionVars() > 0) {
4693-
assert(op.getNumReductionVars() == 1 &&
4694-
"Only 1 reduction variable currently supported");
4695-
mlir::Type reductionVarTy = op.getReductionVars()[0].getType();
4693+
SmallVector<omp::DeclareReductionOp> reductions;
4694+
collectReductionDecls(op, reductions);
4695+
4696+
llvm::SmallVector<mlir::Type> members;
4697+
members.reserve(reductions.size());
4698+
for (omp::DeclareReductionOp &red : reductions)
4699+
members.push_back(red.getType());
46964700
Operation *opp = op.getOperation();
4701+
auto structType = mlir::LLVM::LLVMStructType::getLiteral(
4702+
opp->getContext(), members, /*isPacked=*/false);
46974703
DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
4698-
return getTypeByteSize(reductionVarTy, dl);
4704+
return getTypeByteSize(structType, dl);
46994705
}
47004706
return 0;
47014707
}
@@ -4791,8 +4797,6 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
47914797
(maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
47924798
combinedMaxThreadsVal = maxThreadsVal;
47934799

4794-
// Calculate reduction data size, limited to single reduction variable for
4795-
// now.
47964800
int32_t reductionDataSize = 0;
47974801
if (isGPU && capturedOp) {
47984802
if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// Only check the overall shape of the code and the presence of relevant
4+
// runtime calls. Actual IR checking is done at the OpenMPIRBuilder level.
5+
6+
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true } {
7+
omp.private {type = private} @_QFEj_private_i32 : i32
8+
omp.declare_reduction @add_reduction_f32 : f32 init {
9+
^bb0(%arg0: f32):
10+
%0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
11+
omp.yield(%0 : f32)
12+
} combiner {
13+
^bb0(%arg0: f32, %arg1: f32):
14+
%0 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<contract>} : f32
15+
omp.yield(%0 : f32)
16+
}
17+
omp.declare_reduction @add_reduction_f64 : f64 init {
18+
^bb0(%arg0: f64):
19+
%0 = llvm.mlir.constant(0.000000e+00 : f64) : f64
20+
omp.yield(%0 : f64)
21+
} combiner {
22+
^bb0(%arg0: f64, %arg1: f64):
23+
%0 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<contract>} : f64
24+
omp.yield(%0 : f64)
25+
}
26+
llvm.func @_QQmain() attributes {fir.bindc_name = "reduction", frame_pointer = #llvm.framePointerKind<all>, omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>, target_cpu = "gfx1030", target_features = #llvm.target_features<["+16-bit-insts", "+ci-insts", "+dl-insts", "+dot1-insts", "+dot10-insts", "+dot2-insts", "+dot5-insts", "+dot6-insts", "+dot7-insts", "+dpp", "+gfx10-3-insts", "+gfx10-insts", "+gfx8-insts", "+gfx9-insts", "+gws", "+image-insts", "+s-memrealtime", "+s-memtime-inst", "+vmem-to-lds-load-insts", "+wavefrontsize32"]>} {
27+
%0 = llvm.mlir.constant(1 : i64) : i64
28+
%1 = llvm.alloca %0 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr<5>
29+
%2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
30+
%3 = llvm.mlir.constant(1 : i64) : i64
31+
%4 = llvm.alloca %3 x i32 {bindc_name = "j"} : (i64) -> !llvm.ptr<5>
32+
%5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
33+
%6 = llvm.mlir.constant(1 : i64) : i64
34+
%7 = llvm.alloca %6 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr<5>
35+
%8 = llvm.addrspacecast %7 : !llvm.ptr<5> to !llvm.ptr
36+
%9 = llvm.mlir.constant(1 : i64) : i64
37+
%10 = llvm.alloca %9 x f32 {bindc_name = "ce4"} : (i64) -> !llvm.ptr<5>
38+
%11 = llvm.addrspacecast %10 : !llvm.ptr<5> to !llvm.ptr
39+
%12 = llvm.mlir.constant(1 : i64) : i64
40+
%13 = llvm.alloca %12 x f32 {bindc_name = "ce3"} : (i64) -> !llvm.ptr<5>
41+
%14 = llvm.addrspacecast %13 : !llvm.ptr<5> to !llvm.ptr
42+
%15 = llvm.mlir.constant(1 : i64) : i64
43+
%16 = llvm.alloca %15 x f64 {bindc_name = "ce2"} : (i64) -> !llvm.ptr<5>
44+
%17 = llvm.addrspacecast %16 : !llvm.ptr<5> to !llvm.ptr
45+
%18 = llvm.mlir.constant(1 : i64) : i64
46+
%19 = llvm.alloca %18 x f64 {bindc_name = "ce1"} : (i64) -> !llvm.ptr<5>
47+
%20 = llvm.addrspacecast %19 : !llvm.ptr<5> to !llvm.ptr
48+
%21 = llvm.mlir.constant(0.000000e+00 : f32) : f32
49+
%22 = llvm.mlir.constant(0.000000e+00 : f64) : f64
50+
%23 = llvm.mlir.constant(1 : i64) : i64
51+
%24 = llvm.mlir.constant(1 : i64) : i64
52+
%25 = llvm.mlir.constant(1 : i64) : i64
53+
%26 = llvm.mlir.constant(1 : i64) : i64
54+
%27 = llvm.mlir.constant(1 : i64) : i64
55+
%28 = llvm.mlir.constant(1 : i64) : i64
56+
%29 = llvm.mlir.constant(1 : i64) : i64
57+
llvm.store %22, %20 : f64, !llvm.ptr
58+
llvm.store %22, %17 : f64, !llvm.ptr
59+
llvm.store %21, %14 : f32, !llvm.ptr
60+
llvm.store %21, %11 : f32, !llvm.ptr
61+
%30 = omp.map.info var_ptr(%20 : !llvm.ptr, f64) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "ce1"}
62+
%31 = omp.map.info var_ptr(%17 : !llvm.ptr, f64) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "ce2"}
63+
%32 = omp.map.info var_ptr(%14 : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "ce3"}
64+
%33 = omp.map.info var_ptr(%11 : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "ce4"}
65+
%34 = omp.map.info var_ptr(%5 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "j"}
66+
omp.target map_entries(%30 -> %arg0, %31 -> %arg1, %32 -> %arg2, %33 -> %arg3, %34 -> %arg4 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) {
67+
%35 = llvm.mlir.constant(1.000000e+00 : f32) : f32
68+
%36 = llvm.mlir.constant(1.000000e+00 : f64) : f64
69+
%37 = llvm.mlir.constant(1000 : i32) : i32
70+
%38 = llvm.mlir.constant(1 : i32) : i32
71+
omp.teams reduction(@add_reduction_f64 %arg0 -> %arg5, @add_reduction_f64 %arg1 -> %arg6, @add_reduction_f32 %arg2 -> %arg7, @add_reduction_f32 %arg3 -> %arg8 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) {
72+
omp.parallel {
73+
omp.distribute {
74+
omp.wsloop reduction(@add_reduction_f64 %arg5 -> %arg9, @add_reduction_f64 %arg6 -> %arg10, @add_reduction_f32 %arg7 -> %arg11, @add_reduction_f32 %arg8 -> %arg12 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) {
75+
omp.simd private(@_QFEj_private_i32 %arg4 -> %arg13 : !llvm.ptr) reduction(@add_reduction_f64 %arg9 -> %arg14, @add_reduction_f64 %arg10 -> %arg15, @add_reduction_f32 %arg11 -> %arg16, @add_reduction_f32 %arg12 -> %arg17 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) {
76+
omp.loop_nest (%arg18) : i32 = (%38) to (%37) inclusive step (%38) {
77+
llvm.store %arg18, %arg13 : i32, !llvm.ptr
78+
%39 = llvm.load %arg14 : !llvm.ptr -> f64
79+
%40 = llvm.fadd %39, %36 {fastmathFlags = #llvm.fastmath<contract>} : f64
80+
llvm.store %40, %arg14 : f64, !llvm.ptr
81+
%41 = llvm.load %arg15 : !llvm.ptr -> f64
82+
%42 = llvm.fadd %41, %36 {fastmathFlags = #llvm.fastmath<contract>} : f64
83+
llvm.store %42, %arg15 : f64, !llvm.ptr
84+
%43 = llvm.load %arg16 : !llvm.ptr -> f32
85+
%44 = llvm.fadd %43, %35 {fastmathFlags = #llvm.fastmath<contract>} : f32
86+
llvm.store %44, %arg16 : f32, !llvm.ptr
87+
%45 = llvm.load %arg17 : !llvm.ptr -> f32
88+
%46 = llvm.fadd %45, %35 {fastmathFlags = #llvm.fastmath<contract>} : f32
89+
llvm.store %46, %arg17 : f32, !llvm.ptr
90+
omp.yield
91+
}
92+
} {omp.composite}
93+
} {omp.composite}
94+
} {omp.composite}
95+
omp.terminator
96+
} {omp.composite}
97+
omp.terminator
98+
}
99+
omp.terminator
100+
}
101+
llvm.return
102+
}
103+
}
104+
105+
// CHECK: kernel_environment =
106+
// CHECK-SAME: i32 24, i32 1024
107+
// CHECK: call void @[[OUTLINED:__omp_offloading_[A-Za-z0-9_.]*]]
108+
// CHECK: %[[MASTER:.+]] = call i32 @__kmpc_nvptx_teams_reduce_nowait_v2
109+
// CHECK: icmp eq i32 %[[MASTER]], 1
110+
// CHECK: i1 %{{.+}}, label %[[THEN:[A-Za-z0-9_.]*]], label %[[DONE:[A-Za-z0-9_.]*]]
111+
// CHECK: [[THEN]]:
112+
// CHECK-NEXT: %[[FINAL_RHS0:[A-Za-z0-9_.]*]] = load double
113+
// CHECK-NEXT: %[[FINAL_LHS0:[A-Za-z0-9_.]*]] = load double
114+
// CHECK-NEXT: %[[FINAL_RESULT0:[A-Za-z0-9_.]*]] = fadd contract double %[[FINAL_LHS0]], %[[FINAL_RHS0]]
115+
// CHECK-NEXT: store double %[[FINAL_RESULT0]]
116+
// CHECK-NEXT: %[[FINAL_RHS1:[A-Za-z0-9_.]*]] = load double
117+
// CHECK-NEXT: %[[FINAL_LHS1:[A-Za-z0-9_.]*]] = load double
118+
// CHECK-NEXT: %[[FINAL_RESULT1:[A-Za-z0-9_.]*]] = fadd contract double %[[FINAL_LHS1]], %[[FINAL_RHS1]]
119+
// CHECK-NEXT: store double %[[FINAL_RESULT1]]
120+
// CHECK-NEXT: %[[FINAL_RHS2:[A-Za-z0-9_.]*]] = load float
121+
// CHECK-NEXT: %[[FINAL_LHS2:[A-Za-z0-9_.]*]] = load float
122+
// CHECK-NEXT: %[[FINAL_RESULT2:[A-Za-z0-9_.]*]] = fadd contract float %[[FINAL_LHS2]], %[[FINAL_RHS2]]
123+
// CHECK-NEXT: store float %[[FINAL_RESULT2]]
124+
// CHECK-NEXT: %[[FINAL_RHS3:[A-Za-z0-9_.]*]] = load float
125+
// CHECK-NEXT: %[[FINAL_LHS3:[A-Za-z0-9_.]*]] = load float
126+
// CHECK-NEXT: %[[FINAL_RESULT3:[A-Za-z0-9_.]*]] = fadd contract float %[[FINAL_LHS3]], %[[FINAL_RHS3]]
127+
// CHECK-NEXT: store float %[[FINAL_RESULT3]]

0 commit comments

Comments
 (0)