Skip to content

Commit 0babff9

Browse files
authored
[flang] Lower REDUCE intrinsic with no DIM argument and rank 1 (llvm#94652)
This patch lowers the `REDUCE` intrinsic call to the runtime equivalent for scalar results. Call with array result will follow.
1 parent fade04f commit 0babff9

File tree

9 files changed

+877
-22
lines changed

9 files changed

+877
-22
lines changed

Diff for: flang/include/flang/Optimizer/Builder/IntrinsicCall.h

+2
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ struct IntrinsicLibrary {
328328
void genRandomNumber(llvm::ArrayRef<fir::ExtendedValue>);
329329
void genRandomSeed(llvm::ArrayRef<fir::ExtendedValue>);
330330
fir::ExtendedValue genReduce(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
331+
fir::ExtendedValue genReduceDim(mlir::Type,
332+
llvm::ArrayRef<fir::ExtendedValue>);
331333
fir::ExtendedValue genRepeat(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
332334
fir::ExtendedValue genReshape(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
333335
mlir::Value genRRSpacing(mlir::Type resultType,

Diff for: flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h

+180-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "flang/Optimizer/Builder/FIRBuilder.h"
2323
#include "flang/Optimizer/Dialect/FIRDialect.h"
2424
#include "flang/Optimizer/Dialect/FIRType.h"
25+
#include "flang/Runtime/reduce.h"
2526
#include "mlir/IR/BuiltinTypes.h"
2627
#include "mlir/IR/MLIRContext.h"
2728
#include "llvm/ADT/SmallVector.h"
@@ -52,6 +53,34 @@ namespace fir::runtime {
5253
using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
5354
using FuncTypeBuilderFunc = mlir::FunctionType (*)(mlir::MLIRContext *);
5455

56+
#define REDUCTION_OPERATION_MODEL(T) \
57+
template <> \
58+
constexpr TypeBuilderFunc \
59+
getModel<Fortran::runtime::ReductionOperation<T>>() { \
60+
return [](mlir::MLIRContext *context) -> mlir::Type { \
61+
TypeBuilderFunc f{getModel<T>()}; \
62+
auto refTy = fir::ReferenceType::get(f(context)); \
63+
return mlir::FunctionType::get(context, {refTy, refTy}, refTy); \
64+
}; \
65+
}
66+
67+
#define REDUCTION_CHAR_OPERATION_MODEL(T) \
68+
template <> \
69+
constexpr TypeBuilderFunc \
70+
getModel<Fortran::runtime::ReductionCharOperation<T>>() { \
71+
return [](mlir::MLIRContext *context) -> mlir::Type { \
72+
TypeBuilderFunc f{getModel<T>()}; \
73+
auto voidTy = fir::LLVMPointerType::get( \
74+
context, mlir::IntegerType::get(context, 8)); \
75+
auto size_tTy = \
76+
mlir::IntegerType::get(context, 8 * sizeof(std::size_t)); \
77+
auto refTy = fir::ReferenceType::get(f(context)); \
78+
return mlir::FunctionType::get( \
79+
context, {refTy, size_tTy, refTy, refTy, size_tTy, size_tTy}, \
80+
voidTy); \
81+
}; \
82+
}
83+
5584
//===----------------------------------------------------------------------===//
5685
// Type builder models
5786
//===----------------------------------------------------------------------===//
@@ -75,14 +104,24 @@ constexpr TypeBuilderFunc getModel<unsigned int>() {
75104
return mlir::IntegerType::get(context, 8 * sizeof(unsigned int));
76105
};
77106
}
78-
79107
template <>
80108
constexpr TypeBuilderFunc getModel<short int>() {
81109
return [](mlir::MLIRContext *context) -> mlir::Type {
82110
return mlir::IntegerType::get(context, 8 * sizeof(short int));
83111
};
84112
}
85113
template <>
114+
constexpr TypeBuilderFunc getModel<short int *>() {
115+
return [](mlir::MLIRContext *context) -> mlir::Type {
116+
TypeBuilderFunc f{getModel<short int>()};
117+
return fir::ReferenceType::get(f(context));
118+
};
119+
}
120+
template <>
121+
constexpr TypeBuilderFunc getModel<const short int *>() {
122+
return getModel<short int *>();
123+
}
124+
template <>
86125
constexpr TypeBuilderFunc getModel<int>() {
87126
return [](mlir::MLIRContext *context) -> mlir::Type {
88127
return mlir::IntegerType::get(context, 8 * sizeof(int));
@@ -96,6 +135,17 @@ constexpr TypeBuilderFunc getModel<int &>() {
96135
};
97136
}
98137
template <>
138+
constexpr TypeBuilderFunc getModel<int *>() {
139+
return getModel<int &>();
140+
}
141+
template <>
142+
constexpr TypeBuilderFunc getModel<const int *>() {
143+
return [](mlir::MLIRContext *context) -> mlir::Type {
144+
TypeBuilderFunc f{getModel<int>()};
145+
return fir::ReferenceType::get(f(context));
146+
};
147+
}
148+
template <>
99149
constexpr TypeBuilderFunc getModel<char *>() {
100150
return [](mlir::MLIRContext *context) -> mlir::Type {
101151
return fir::ReferenceType::get(mlir::IntegerType::get(context, 8));
@@ -130,6 +180,43 @@ constexpr TypeBuilderFunc getModel<signed char>() {
130180
};
131181
}
132182
template <>
183+
constexpr TypeBuilderFunc getModel<signed char *>() {
184+
return [](mlir::MLIRContext *context) -> mlir::Type {
185+
TypeBuilderFunc f{getModel<signed char>()};
186+
return fir::ReferenceType::get(f(context));
187+
};
188+
}
189+
template <>
190+
constexpr TypeBuilderFunc getModel<const signed char *>() {
191+
return getModel<signed char *>();
192+
}
193+
template <>
194+
constexpr TypeBuilderFunc getModel<char16_t>() {
195+
return [](mlir::MLIRContext *context) -> mlir::Type {
196+
return mlir::IntegerType::get(context, 8 * sizeof(char16_t));
197+
};
198+
}
199+
template <>
200+
constexpr TypeBuilderFunc getModel<char16_t *>() {
201+
return [](mlir::MLIRContext *context) -> mlir::Type {
202+
TypeBuilderFunc f{getModel<char16_t>()};
203+
return fir::ReferenceType::get(f(context));
204+
};
205+
}
206+
template <>
207+
constexpr TypeBuilderFunc getModel<char32_t>() {
208+
return [](mlir::MLIRContext *context) -> mlir::Type {
209+
return mlir::IntegerType::get(context, 8 * sizeof(char32_t));
210+
};
211+
}
212+
template <>
213+
constexpr TypeBuilderFunc getModel<char32_t *>() {
214+
return [](mlir::MLIRContext *context) -> mlir::Type {
215+
TypeBuilderFunc f{getModel<char32_t>()};
216+
return fir::ReferenceType::get(f(context));
217+
};
218+
}
219+
template <>
133220
constexpr TypeBuilderFunc getModel<unsigned char>() {
134221
return [](mlir::MLIRContext *context) -> mlir::Type {
135222
return mlir::IntegerType::get(context, 8 * sizeof(unsigned char));
@@ -175,6 +262,10 @@ constexpr TypeBuilderFunc getModel<long *>() {
175262
return getModel<long &>();
176263
}
177264
template <>
265+
constexpr TypeBuilderFunc getModel<const long *>() {
266+
return getModel<long *>();
267+
}
268+
template <>
178269
constexpr TypeBuilderFunc getModel<long long>() {
179270
return [](mlir::MLIRContext *context) -> mlir::Type {
180271
return mlir::IntegerType::get(context, 8 * sizeof(long long));
@@ -199,6 +290,10 @@ constexpr TypeBuilderFunc getModel<long long *>() {
199290
return getModel<long long &>();
200291
}
201292
template <>
293+
constexpr TypeBuilderFunc getModel<const long long *>() {
294+
return getModel<long long *>();
295+
}
296+
template <>
202297
constexpr TypeBuilderFunc getModel<unsigned long>() {
203298
return [](mlir::MLIRContext *context) -> mlir::Type {
204299
return mlir::IntegerType::get(context, 8 * sizeof(unsigned long));
@@ -228,6 +323,27 @@ constexpr TypeBuilderFunc getModel<double *>() {
228323
return getModel<double &>();
229324
}
230325
template <>
326+
constexpr TypeBuilderFunc getModel<const double *>() {
327+
return getModel<double *>();
328+
}
329+
template <>
330+
constexpr TypeBuilderFunc getModel<long double>() {
331+
return [](mlir::MLIRContext *context) -> mlir::Type {
332+
return mlir::FloatType::getF80(context);
333+
};
334+
}
335+
template <>
336+
constexpr TypeBuilderFunc getModel<long double *>() {
337+
return [](mlir::MLIRContext *context) -> mlir::Type {
338+
TypeBuilderFunc f{getModel<long double>()};
339+
return fir::ReferenceType::get(f(context));
340+
};
341+
}
342+
template <>
343+
constexpr TypeBuilderFunc getModel<const long double *>() {
344+
return getModel<long double *>();
345+
}
346+
template <>
231347
constexpr TypeBuilderFunc getModel<float>() {
232348
return [](mlir::MLIRContext *context) -> mlir::Type {
233349
return mlir::FloatType::getF32(context);
@@ -245,6 +361,10 @@ constexpr TypeBuilderFunc getModel<float *>() {
245361
return getModel<float &>();
246362
}
247363
template <>
364+
constexpr TypeBuilderFunc getModel<const float *>() {
365+
return getModel<float *>();
366+
}
367+
template <>
248368
constexpr TypeBuilderFunc getModel<bool>() {
249369
return [](mlir::MLIRContext *context) -> mlir::Type {
250370
return mlir::IntegerType::get(context, 1);
@@ -258,20 +378,48 @@ constexpr TypeBuilderFunc getModel<bool &>() {
258378
};
259379
}
260380
template <>
381+
constexpr TypeBuilderFunc getModel<std::complex<float>>() {
382+
return [](mlir::MLIRContext *context) -> mlir::Type {
383+
return mlir::ComplexType::get(mlir::FloatType::getF32(context));
384+
};
385+
}
386+
template <>
261387
constexpr TypeBuilderFunc getModel<std::complex<float> &>() {
262388
return [](mlir::MLIRContext *context) -> mlir::Type {
263-
auto ty = mlir::ComplexType::get(mlir::FloatType::getF32(context));
264-
return fir::ReferenceType::get(ty);
389+
TypeBuilderFunc f{getModel<std::complex<float>>()};
390+
return fir::ReferenceType::get(f(context));
391+
};
392+
}
393+
template <>
394+
constexpr TypeBuilderFunc getModel<std::complex<float> *>() {
395+
return getModel<std::complex<float> &>();
396+
}
397+
template <>
398+
constexpr TypeBuilderFunc getModel<const std::complex<float> *>() {
399+
return getModel<std::complex<float> *>();
400+
}
401+
template <>
402+
constexpr TypeBuilderFunc getModel<std::complex<double>>() {
403+
return [](mlir::MLIRContext *context) -> mlir::Type {
404+
return mlir::ComplexType::get(mlir::FloatType::getF64(context));
265405
};
266406
}
267407
template <>
268408
constexpr TypeBuilderFunc getModel<std::complex<double> &>() {
269409
return [](mlir::MLIRContext *context) -> mlir::Type {
270-
auto ty = mlir::ComplexType::get(mlir::FloatType::getF64(context));
271-
return fir::ReferenceType::get(ty);
410+
TypeBuilderFunc f{getModel<std::complex<double>>()};
411+
return fir::ReferenceType::get(f(context));
272412
};
273413
}
274414
template <>
415+
constexpr TypeBuilderFunc getModel<std::complex<double> *>() {
416+
return getModel<std::complex<double> &>();
417+
}
418+
template <>
419+
constexpr TypeBuilderFunc getModel<const std::complex<double> *>() {
420+
return getModel<std::complex<double> *>();
421+
}
422+
template <>
275423
constexpr TypeBuilderFunc getModel<c_float_complex_t>() {
276424
return [](mlir::MLIRContext *context) -> mlir::Type {
277425
return fir::ComplexType::get(context, sizeof(float));
@@ -332,6 +480,33 @@ constexpr TypeBuilderFunc getModel<void>() {
332480
};
333481
}
334482

483+
REDUCTION_OPERATION_MODEL(std::int8_t)
484+
REDUCTION_OPERATION_MODEL(std::int16_t)
485+
REDUCTION_OPERATION_MODEL(std::int32_t)
486+
REDUCTION_OPERATION_MODEL(std::int64_t)
487+
REDUCTION_OPERATION_MODEL(Fortran::common::int128_t)
488+
489+
REDUCTION_OPERATION_MODEL(float)
490+
REDUCTION_OPERATION_MODEL(double)
491+
REDUCTION_OPERATION_MODEL(long double)
492+
493+
REDUCTION_OPERATION_MODEL(std::complex<float>)
494+
REDUCTION_OPERATION_MODEL(std::complex<double>)
495+
496+
REDUCTION_CHAR_OPERATION_MODEL(char)
497+
REDUCTION_CHAR_OPERATION_MODEL(char16_t)
498+
REDUCTION_CHAR_OPERATION_MODEL(char32_t)
499+
500+
template <>
501+
constexpr TypeBuilderFunc
502+
getModel<Fortran::runtime::ReductionDerivedTypeOperation>() {
503+
return [](mlir::MLIRContext *context) -> mlir::Type {
504+
auto voidTy =
505+
fir::LLVMPointerType::get(context, mlir::IntegerType::get(context, 8));
506+
return mlir::FunctionType::get(context, {voidTy, voidTy, voidTy}, voidTy);
507+
};
508+
}
509+
335510
template <typename...>
336511
struct RuntimeTableKey;
337512
template <typename RT, typename... ATs>

Diff for: flang/include/flang/Optimizer/Builder/Runtime/Reduction.h

+16
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,22 @@ void genIParityDim(fir::FirOpBuilder &builder, mlir::Location loc,
224224
mlir::Value resultBox, mlir::Value arrayBox, mlir::Value dim,
225225
mlir::Value maskBox);
226226

227+
/// Generate call to `Reduce` intrinsic runtime routine. This is the version
228+
/// that does not take a dim argument and store the result in the provided
229+
/// result value. This is used for COMPLEX, CHARACTER and DERIVED TYPES.
230+
void genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
231+
mlir::Value arrayBox, mlir::Value operation, mlir::Value maskBox,
232+
mlir::Value identity, mlir::Value ordered,
233+
mlir::Value resultBox);
234+
235+
/// Generate call to `Reduce` intrinsic runtime routine. This is the version
236+
/// that does not take a dim argument and return a scalare result. This is used
237+
/// for REAL, INTEGER and LOGICAL TYPES.
238+
mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
239+
mlir::Value arrayBox, mlir::Value operation,
240+
mlir::Value maskBox, mlir::Value identity,
241+
mlir::Value ordered);
242+
227243
} // namespace fir::runtime
228244

229245
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_REDUCTION_H

Diff for: flang/lib/Optimizer/Builder/IntrinsicCall.cpp

+57-3
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,8 @@ static constexpr IntrinsicHandler handlers[]{
526526
{"operation", asAddr},
527527
{"dim", asValue},
528528
{"mask", asBox, handleDynamicOptional},
529-
{"identity", asValue},
530-
{"ordered", asValue}}},
529+
{"identity", asAddr, handleDynamicOptional},
530+
{"ordered", asValue, handleDynamicOptional}}},
531531
/*isElemental=*/false},
532532
{"repeat",
533533
&I::genRepeat,
@@ -5736,7 +5736,61 @@ void IntrinsicLibrary::genRandomSeed(llvm::ArrayRef<fir::ExtendedValue> args) {
57365736
fir::ExtendedValue
57375737
IntrinsicLibrary::genReduce(mlir::Type resultType,
57385738
llvm::ArrayRef<fir::ExtendedValue> args) {
5739-
TODO(loc, "intrinsic: reduce");
5739+
assert(args.size() == 6);
5740+
5741+
fir::BoxValue arrayTmp = builder.createBox(loc, args[0]);
5742+
mlir::Value array = fir::getBase(arrayTmp);
5743+
mlir::Value operation = fir::getBase(args[1]);
5744+
int rank = arrayTmp.rank();
5745+
assert(rank >= 1);
5746+
5747+
mlir::Type ty = array.getType();
5748+
mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
5749+
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
5750+
5751+
// Handle optional arguments
5752+
bool absentDim = isStaticallyAbsent(args[2]);
5753+
5754+
auto mask = isStaticallyAbsent(args[3])
5755+
? builder.create<fir::AbsentOp>(
5756+
loc, fir::BoxType::get(builder.getI1Type()))
5757+
: builder.createBox(loc, args[3]);
5758+
5759+
mlir::Value identity =
5760+
isStaticallyAbsent(args[4])
5761+
? builder.create<fir::AbsentOp>(loc, fir::ReferenceType::get(eleTy))
5762+
: fir::getBase(args[4]);
5763+
5764+
mlir::Value ordered = isStaticallyAbsent(args[5])
5765+
? builder.createBool(loc, false)
5766+
: fir::getBase(args[5]);
5767+
5768+
// We call the type specific versions because the result is scalar
5769+
// in the case below.
5770+
if (absentDim || rank == 1) {
5771+
if (fir::isa_complex(eleTy) || fir::isa_derived(eleTy)) {
5772+
mlir::Value result = builder.createTemporary(loc, eleTy);
5773+
fir::runtime::genReduce(builder, loc, array, operation, mask, identity,
5774+
ordered, result);
5775+
if (fir::isa_derived(eleTy))
5776+
return result;
5777+
return builder.create<fir::LoadOp>(loc, result);
5778+
}
5779+
if (fir::isa_char(eleTy)) {
5780+
// Create mutable fir.box to be passed to the runtime for the result.
5781+
fir::MutableBoxValue resultMutableBox =
5782+
fir::factory::createTempMutableBox(builder, loc, eleTy);
5783+
mlir::Value resultIrBox =
5784+
fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
5785+
fir::runtime::genReduce(builder, loc, array, operation, mask, identity,
5786+
ordered, resultIrBox);
5787+
// Handle cleanup of allocatable result descriptor and return
5788+
return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
5789+
}
5790+
return fir::runtime::genReduce(builder, loc, array, operation, mask,
5791+
identity, ordered);
5792+
}
5793+
TODO(loc, "reduce with array result");
57405794
}
57415795

57425796
// REPEAT

0 commit comments

Comments
 (0)