Skip to content

Commit 857ac4c

Browse files
authored
[MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for SIMD directive (llvm#118751)
This patch: - Adds a new `nontemporal` attribute to the fir.load and fir.store operations in the FIR dialect. - Adds a `lower-nontemporal` pass, which runs before the FIRToLLVM conversion pass and adds the nontemporal attribute to loads and stores of the list items specified in the nontemporal clause of the SIMD directive. - Sets `UnitAttr:$nontemporal` on llvm.load and llvm.store operations during FIR to LLVM dialect conversion when the corresponding fir.load or fir.store operation has the nontemporal attribute. - Attaches `nontemporal` metadata to load and store instructions that have the nontemporal attribute during LLVM dialect to LLVM IR translation.
1 parent e4332e4 commit 857ac4c

File tree

12 files changed

+418
-23
lines changed

12 files changed

+418
-23
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
305305
}];
306306

307307
let arguments = (ins AnyReferenceLike:$memref,
308-
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);
308+
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal);
309309

310310
let builders = [OpBuilder<(ins "mlir::Value":$refVal)>,
311311
OpBuilder<(ins "mlir::Type":$resTy, "mlir::Value":$refVal)>];
@@ -337,9 +337,8 @@ def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface,
337337
`%p`, is undefined or null.
338338
}];
339339

340-
let arguments = (ins AnyType:$value,
341-
AnyReferenceLike:$memref,
342-
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);
340+
let arguments = (ins AnyType:$value, AnyReferenceLike:$memref,
341+
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal);
343342

344343
let builders = [OpBuilder<(ins "mlir::Value":$value, "mlir::Value":$memref)>];
345344

flang/include/flang/Optimizer/OpenMP/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::fun
8181
];
8282
}
8383

84+
// Pass registered under the command-line name `lower-nontemporal` and
// anchored on mlir::func::FuncOp. It declares the OpenMP dialect as a
// dependent dialect.
def LowerNontemporalPass : Pass<"lower-nontemporal", "mlir::func::FuncOp"> {
85+
let summary =
86+
"Adds nontemporal attribute to loads and stores performed on "
87+
"the list items specified in the nontemporal clause of omp.simd.";
88+
let dependentDialects = ["mlir::omp::OpenMPDialect"];
89+
}
90+
8491
// Needs to be scheduled on Module as we create functions in it
8592
def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
8693
let summary = "Lower workshare construct";

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3569,8 +3569,13 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
35693569
} else {
35703570
mlir::LLVM::StoreOp storeOp =
35713571
rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
3572+
35723573
if (isVolatile)
35733574
storeOp.setVolatile_(true);
3575+
3576+
if (store.getNontemporal())
3577+
storeOp.setNontemporal(true);
3578+
35743579
newOp = storeOp;
35753580
}
35763581
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())

flang/lib/Optimizer/OpenMP/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
88
MapInfoFinalization.cpp
99
MarkDeclareTarget.cpp
1010
LowerWorkshare.cpp
11+
LowerNontemporal.cpp
1112

1213
DEPENDS
1314
FIRDialect
@@ -17,7 +18,7 @@ add_flang_library(FlangOpenMPTransforms
1718
LINK_LIBS
1819
FIRAnalysis
1920
FIRBuilder
20-
FIRCodeGen
21+
FIRCodeGenDialect
2122
FIRDialect
2223
FIRDialectSupport
2324
FIRSupport
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//===- LowerNontemporal.cpp -------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Add nontemporal attributes to loads and stores of variables marked as
10+
// nontemporal.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
15+
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
16+
#include "flang/Optimizer/OpenMP/Passes.h"
17+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
18+
#include "llvm/ADT/TypeSwitch.h"
19+
20+
using namespace mlir;
21+
22+
namespace flangomp {
23+
#define GEN_PASS_DEF_LOWERNONTEMPORALPASS
24+
#include "flang/Optimizer/OpenMP/Passes.h.inc"
25+
} // namespace flangomp
26+
27+
namespace {
28+
/// Pass that sets the `nontemporal` attribute on the fir::LoadOp and
/// fir::StoreOp operations nested inside an omp::SimdOp whose address traces
/// back to one of that simd operation's nontemporal variables.
class LowerNontemporalPass
29+
: public flangomp::impl::LowerNontemporalPassBase<LowerNontemporalPass> {
30+
/// Marks the matching loads and stores nested in \p simdOp as nontemporal.
/// Returns immediately when \p simdOp has no nontemporal variables.
void addNonTemporalAttr(omp::SimdOp simdOp) {
31+
if (simdOp.getNontemporalVars().empty())
32+
return;
33+
34+
// Follows the defining-op chain of `operand` upwards and returns the value
// it ends on. fir::ArrayCoorOp, fir::cg::XArrayCoorOp and fir::LoadOp are
// stepped through via their memref, fir::BoxAddrOp via its value. The walk
// stops at the first value produced by any other kind of operation, or at
// a value that has no defining operation.
std::function<mlir::Value(mlir::Value)> getBaseOperand =
35+
[&](mlir::Value operand) -> mlir::Value {
36+
auto *defOp = operand.getDefiningOp();
37+
while (defOp) {
38+
llvm::TypeSwitch<Operation *>(defOp)
39+
.Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp, fir::LoadOp>(
40+
[&](auto op) {
41+
operand = op.getMemref();
42+
defOp = operand.getDefiningOp();
43+
})
44+
.Case<fir::BoxAddrOp>([&](auto op) {
45+
operand = op.getVal();
46+
defOp = operand.getDefiningOp();
47+
})
48+
// Any other producer ends the walk; `operand` keeps its current value.
.Default([&](auto op) { defOp = nullptr; });
49+
}
50+
return operand;
51+
};
52+
53+
// walk through the operations and mark the load and store as nontemporal
54+
simdOp->walk([&](Operation *op) {
55+
// Stays null for anything that is not a fir::LoadOp or fir::StoreOp, so
// such operations are ignored by the check below.
mlir::Value operand = nullptr;
56+
57+
if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
58+
operand = loadOp.getMemref();
59+
else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
60+
operand = storeOp.getMemref();
61+
62+
// Skip load and store operations involving boxes (allocatable or pointer
63+
// types).
64+
if (operand && !(fir::isAllocatableType(operand.getType()) ||
65+
fir::isPointerType((operand.getType())))) {
66+
// From here on `operand` is the traced-back base value, not the
// address the load/store uses directly.
operand = getBaseOperand(operand);
67+
68+
// TODO : Handling of nontemporal clause inside atomic construct
69+
if (llvm::is_contained(simdOp.getNontemporalVars(), operand)) {
70+
// The op is cast again because the earlier loadOp/storeOp bindings
// are scoped to their own if statements.
if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
71+
loadOp.setNontemporal(true);
72+
else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
73+
storeOp.setNontemporal(true);
74+
}
75+
}
76+
});
77+
}
78+
79+
/// Pass entry point: applies addNonTemporalAttr to every omp::SimdOp nested
/// under the operation the pass is running on.
void runOnOperation() override {
80+
Operation *op = getOperation();
81+
op->walk([&](omp::SimdOp simdOp) { addNonTemporalAttr(simdOp); });
82+
}
83+
};
84+
} // namespace

flang/lib/Optimizer/Passes/Pipelines.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,11 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
353353
config.ApproxFuncFPMath, config.NoSignedZerosFPMath, config.UnsafeFPMath,
354354
""}));
355355

356+
if (config.EnableOpenMP) {
357+
pm.addNestedPass<mlir::func::FuncOp>(
358+
flangomp::createLowerNontemporalPass());
359+
}
360+
356361
fir::addFIRToLLVMPass(pm, config);
357362
}
358363

flang/test/Fir/basic-program.fir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ func.func @_QQmain() {
149149
// PASSES-NEXT: CompilerGeneratedNamesConversion
150150
// PASSES-NEXT: 'func.func' Pipeline
151151
// PASSES-NEXT: FunctionAttr
152+
// PASSES-NEXT: LowerNontemporalPass
152153
// PASSES-NEXT: FIRToLLVMLowering
153154
// PASSES-NEXT: ReconcileUnrealizedCasts
154155
// PASSES-NEXT: LLVMIRLoweringPass
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Test conversion of the nontemporal attribute on fir.load and fir.store to the LLVM dialect
2+
// RUN: fir-opt --fir-to-llvm-ir %s | FileCheck %s --check-prefixes=CHECK-LABEL,CHECK
3+
4+
// CHECK-LABEL: llvm.func @_QPtest()
5+
// CHECK: %[[CONST_VAL:.*]] = llvm.mlir.constant(1 : i64) : i64
6+
// CHECK: %[[VAL1:.*]] = llvm.alloca %[[CONST_VAL]] x i32 {bindc_name = "n"} : (i64) -> !llvm.ptr
7+
// CHECK: %[[CONST_VAL1:.*]] = llvm.mlir.constant(1 : i64) : i64
8+
// CHECK: %[[VAL2:.*]] = llvm.alloca %[[CONST_VAL1]] x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
9+
// CHECK: %[[CONST_VAL2:.*]] = llvm.mlir.constant(1 : i64) : i64
10+
// CHECK: %[[VAL3:.*]] = llvm.alloca %[[CONST_VAL2]] x i32 {bindc_name = "c"} : (i64) -> !llvm.ptr
11+
// CHECK: %[[CONST_VAL3:.*]] = llvm.mlir.constant(1 : i64) : i64
12+
// CHECK: %[[VAL4:.*]] = llvm.alloca %[[CONST_VAL3]] x i32 {bindc_name = "b"} : (i64) -> !llvm.ptr
13+
// CHECK: %[[CONST_VAL4:.*]] = llvm.mlir.constant(1 : i64) : i64
14+
// CHECK: %[[VAL5:.*]] = llvm.alloca %[[CONST_VAL4]] x i32 {bindc_name = "a"} : (i64) -> !llvm.ptr
15+
// CHECK: %[[CONST_VAL5:.*]] = llvm.mlir.constant(1 : i32) : i32
16+
// CHECK: %[[VAL6:.*]] = llvm.load %[[VAL1]] : !llvm.ptr -> i32
17+
// CHECK: omp.simd nontemporal(%[[VAL5]], %[[VAL3]] : !llvm.ptr, !llvm.ptr) private(@_QFtestEi_private_i32 %[[VAL2]] -> %arg0 : !llvm.ptr) {
18+
// CHECK: omp.loop_nest (%{{.*}}) : i32 = (%[[CONST_VAL5]]) to (%[[VAL6]]) inclusive step (%[[CONST_VAL5]]) {
19+
// CHECK: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr
20+
// CHECK: %[[VAL8:.*]] = llvm.load %[[VAL5]] {nontemporal} : !llvm.ptr -> i32
21+
// CHECK: %[[VAL9:.*]] = llvm.load %[[VAL4]] : !llvm.ptr -> i32
22+
// CHECK: %[[VAL10:.*]] = llvm.add %[[VAL8]], %[[VAL9]] : i32
23+
// CHECK: llvm.store %[[VAL10]], %[[VAL3]] {nontemporal} : i32, !llvm.ptr
24+
// CHECK: omp.yield
25+
// CHECK: }
26+
// CHECK: }
27+
28+
// Scalar input: %0 ("a") and %2 ("c") are the nontemporal list items of the
// omp.simd. The load of %0 and the store to %2 already carry {nontemporal};
// the load of %1 ("b") and the store to the private loop variable do not.
func.func @_QPtest() {
29+
%c1_i32 = arith.constant 1 : i32
30+
%0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtestEa"}
31+
%1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFtestEb"}
32+
%2 = fir.alloca i32 {bindc_name = "c", uniq_name = "_QFtestEc"}
33+
%3 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtestEi"}
34+
%4 = fir.alloca i32 {bindc_name = "n", uniq_name = "_QFtestEn"}
35+
%5 = fir.load %4 : !fir.ref<i32>
36+
omp.simd nontemporal(%0, %2 : !fir.ref<i32>, !fir.ref<i32>) private(@_QFtestEi_private_i32 %3 -> %arg0 : !fir.ref<i32>) {
37+
omp.loop_nest (%arg1) : i32 = (%c1_i32) to (%5) inclusive step (%c1_i32) {
38+
fir.store %arg1 to %arg0 : !fir.ref<i32>
39+
%6 = fir.load %0 {nontemporal}: !fir.ref<i32>
40+
%7 = fir.load %1 : !fir.ref<i32>
41+
%8 = arith.addi %6, %7 : i32
42+
fir.store %8 to %2 {nontemporal} : !fir.ref<i32>
43+
omp.yield
44+
}
45+
}
46+
return
47+
}
48+
49+
// CHECK-LABEL: llvm.func @_QPsimd_nontemporal_allocatable
50+
// CHECK: %[[CONST_VAL:.*]] = llvm.mlir.constant(1 : i64) : i64
51+
// CHECK: %[[ALLOCA2:.*]] = llvm.alloca %[[CONST_VAL]] x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
52+
// CHECK: %[[IDX_VAL:.*]] = llvm.mlir.constant(1 : i32) : i32
53+
// CHECK: %[[CONST_VAL1:.*]] = llvm.mlir.constant(0 : index) : i64
54+
// CHECK: %[[END_IDX:.*]] = llvm.mlir.constant(100 : i32) : i32
55+
// CHECK: omp.simd nontemporal(%[[ARG0:.*]] : !llvm.ptr) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %[[ALLOCA2]] -> %[[ARG2:.*]] : !llvm.ptr) {
56+
// CHECK: omp.loop_nest (%[[ARG3:.*]]) : i32 = (%[[IDX_VAL]]) to (%[[END_IDX]]) inclusive step (%[[IDX_VAL]]) {
57+
// CHECK: llvm.store %[[ARG3]], %[[ARG2]] : i32, !llvm.ptr
58+
// CHECK: %[[CONST_VAL2:.*]] = llvm.mlir.constant(48 : i32) : i32
59+
// CHECK: "llvm.intr.memcpy"(%[[ALLOCA1:.*]], %[[ARG0]], %[[CONST_VAL2]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
60+
// CHECK: %[[VAL1:.*]] = llvm.load %[[ARG2]] : !llvm.ptr -> i32
61+
// CHECK: %[[VAL2:.*]] = llvm.sext %[[VAL1]] : i32 to i64
62+
// CHECK: %[[VAL3:.*]] = llvm.getelementptr %[[ALLOCA1]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
63+
// CHECK: %[[VAL4:.*]] = llvm.load %[[VAL3]] : !llvm.ptr -> !llvm.ptr
64+
// CHECK: %[[VAL5:.*]] = llvm.getelementptr %[[ALLOCA1]][0, 7, %[[CONST_VAL1]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
65+
// CHECK: %[[VAL6:.*]] = llvm.load %[[VAL5]] : !llvm.ptr -> i64
66+
// CHECK: %[[VAL7:.*]] = llvm.getelementptr %[[ALLOCA1]][0, 7, %[[CONST_VAL1]], 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
67+
// CHECK: %[[VAL8:.*]] = llvm.load %[[VAL7]] : !llvm.ptr -> i64
68+
// CHECK: %[[VAL10:.*]] = llvm.mlir.constant(1 : i64) : i64
69+
// CHECK: %[[VAL11:.*]] = llvm.mlir.constant(0 : i64) : i64
70+
// CHECK: %[[VAL12:.*]] = llvm.sub %[[VAL2]], %[[VAL6]] overflow<nsw> : i64
71+
// CHECK: %[[VAL13:.*]] = llvm.mul %[[VAL12]], %[[VAL10]] overflow<nsw> : i64
72+
// CHECK: %[[VAL14:.*]] = llvm.mul %[[VAL13]], %[[VAL10]] overflow<nsw> : i64
73+
// CHECK: %[[VAL15:.*]] = llvm.add %[[VAL14]], %[[VAL11]] overflow<nsw> : i64
74+
// CHECK: %[[VAL16:.*]] = llvm.mul %[[VAL10]], %[[VAL8]] overflow<nsw> : i64
75+
// CHECK: %[[VAL17:.*]] = llvm.getelementptr %[[VAL4]][%[[VAL15]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
76+
// CHECK: %[[VAL18:.*]] = llvm.load %[[VAL17]] {nontemporal} : !llvm.ptr -> i32
77+
// CHECK: %[[VAL19:.*]] = llvm.load %{{.*}} : !llvm.ptr -> i32
78+
// CHECK: %[[VAL20:.*]] = llvm.add %[[VAL18]], %[[VAL19]] : i32
79+
// CHECK: llvm.store %[[VAL20]], %[[VAL17]] {nontemporal} : i32, !llvm.ptr
80+
// CHECK: omp.yield
81+
// CHECK: }
82+
// CHECK: }
83+
// CHECK: llvm.return
84+
85+
// Allocatable-array input: the nontemporal list item of the omp.simd is the
// descriptor reference %arg0 ("x"). The element load (%13) and the element
// store through %12 already carry {nontemporal}; the descriptor load (%7),
// the loop-variable accesses and the load of %arg1 ("y") do not.
func.func @_QPsimd_nontemporal_allocatable(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
86+
%c100 = arith.constant 100 : index
87+
%c1_i32 = arith.constant 1 : i32
88+
%c0 = arith.constant 0 : index
89+
%c100_i32 = arith.constant 100 : i32
90+
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimd_nontemporal_allocatableEi"}
91+
%1 = fir.allocmem !fir.array<?xi32>, %c100 {fir.must_be_heap = true, uniq_name = "_QFsimd_nontemporal_allocatableEx.alloc"}
92+
%2 = fircg.ext_embox %1(%c100) : (!fir.heap<!fir.array<?xi32>>, index) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
93+
fir.store %2 to %arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
94+
omp.simd nontemporal(%arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %0 -> %arg2 : !fir.ref<i32>) {
95+
omp.loop_nest (%arg3) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32) {
96+
fir.store %arg3 to %arg2 : !fir.ref<i32>
97+
%7 = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
98+
%8 = fir.load %arg2 : !fir.ref<i32>
99+
%9 = fir.convert %8 : (i32) -> i64
100+
%10 = fir.box_addr %7 : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
101+
%11:3 = fir.box_dims %7, %c0 : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
102+
%12 = fircg.ext_array_coor %10(%11#1) origin %11#0<%9> : (!fir.heap<!fir.array<?xi32>>, index, index, i64) -> !fir.ref<i32>
103+
%13 = fir.load %12 {nontemporal} : !fir.ref<i32>
104+
%14 = fir.load %arg1 : !fir.ref<i32>
105+
%15 = arith.addi %13, %14 : i32
106+
fir.store %15 to %12 {nontemporal} : !fir.ref<i32>
107+
omp.yield
108+
}
109+
}
110+
return
111+
}

0 commit comments

Comments
 (0)