Skip to content

Commit 20a104a

Browse files
authored
[mlir] allow function type cloning to fail (llvm#136300)
`FunctionOpInterface` assumed the fact that the function type (attribute of the operation) can be cloned with arbirary lists of function arguments and results to support argument and result list mutation. This is not always correct, in particular, LLVM dialect functions require exactly one result making it impossible to erase the result. Allow function type cloning to fail and propagate this failure through various APIs that use it. The common assumption is that existing IR has not been modified. Fixes llvm#131142.
1 parent 35e6ca4 commit 20a104a

File tree

10 files changed

+105
-44
lines changed

10 files changed

+105
-44
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ def LLVMFunctionType : LLVMType<"LLVMFunction", "func"> {
104104
bool isVarArg() const { return getVarArg(); }
105105

106106
/// Returns a clone of this function type with the given argument
107-
/// and result types.
107+
/// and result types. Returns null if the resulting function type would
108+
/// not verify.
108109
LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;
109110

110111
/// Returns the result type of the function as an ArrayRef, enabling better

mlir/include/mlir/Interfaces/FunctionInterfaces.td

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -255,79 +255,105 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
255255
BlockArgListType getArguments() { return getFunctionBody().getArguments(); }
256256

257257
/// Insert a single argument of type `argType` with attributes `argAttrs` and
258-
/// location `argLoc` at `argIndex`.
259-
void insertArgument(unsigned argIndex, ::mlir::Type argType, ::mlir::DictionaryAttr argAttrs,
260-
::mlir::Location argLoc) {
261-
insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc});
258+
/// location `argLoc` at `argIndex`. Returns failure if the function cannot be
259+
/// updated to have the new signature.
260+
::llvm::LogicalResult insertArgument(
261+
unsigned argIndex, ::mlir::Type argType, ::mlir::DictionaryAttr argAttrs,
262+
::mlir::Location argLoc) {
263+
return insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc});
262264
}
263265

264266
/// Inserts arguments with the listed types, attributes, and locations at the
265267
/// listed indices. `argIndices` must be sorted. Arguments are inserted in the
266268
/// order they are listed, such that arguments with identical index will
267-
/// appear in the same order that they were listed here.
268-
void insertArguments(::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes,
269-
::llvm::ArrayRef<::mlir::DictionaryAttr> argAttrs,
270-
::llvm::ArrayRef<::mlir::Location> argLocs) {
269+
/// appear in the same order that they were listed here. Returns failure if
270+
/// the function cannot be updated to have the new signature.
271+
::llvm::LogicalResult insertArguments(
272+
::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes,
273+
::llvm::ArrayRef<::mlir::DictionaryAttr> argAttrs,
274+
::llvm::ArrayRef<::mlir::Location> argLocs) {
271275
unsigned originalNumArgs = $_op.getNumArguments();
272276
::mlir::Type newType = $_op.getTypeWithArgsAndResults(
273277
argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{});
278+
if (!newType)
279+
return ::llvm::failure();
274280
::mlir::function_interface_impl::insertFunctionArguments(
275281
$_op, argIndices, argTypes, argAttrs, argLocs,
276282
originalNumArgs, newType);
283+
return ::llvm::success();
277284
}
278285

279-
/// Insert a single result of type `resultType` at `resultIndex`.
280-
void insertResult(unsigned resultIndex, ::mlir::Type resultType,
281-
::mlir::DictionaryAttr resultAttrs) {
282-
insertResults({resultIndex}, {resultType}, {resultAttrs});
286+
/// Insert a single result of type `resultType` at `resultIndex`.Returns
287+
/// failure if the function cannot be updated to have the new signature.
288+
::llvm::LogicalResult insertResult(
289+
unsigned resultIndex, ::mlir::Type resultType,
290+
::mlir::DictionaryAttr resultAttrs) {
291+
return insertResults({resultIndex}, {resultType}, {resultAttrs});
283292
}
284293

285294
/// Inserts results with the listed types at the listed indices.
286295
/// `resultIndices` must be sorted. Results are inserted in the order they are
287296
/// listed, such that results with identical index will appear in the same
288-
/// order that they were listed here.
289-
void insertResults(::llvm::ArrayRef<unsigned> resultIndices, ::mlir::TypeRange resultTypes,
290-
::llvm::ArrayRef<::mlir::DictionaryAttr> resultAttrs) {
297+
/// order that they were listed here. Returns failure if the function
298+
/// cannot be updated to have the new signature.
299+
::llvm::LogicalResult insertResults(
300+
::llvm::ArrayRef<unsigned> resultIndices,
301+
::mlir::TypeRange resultTypes,
302+
::llvm::ArrayRef<::mlir::DictionaryAttr> resultAttrs) {
291303
unsigned originalNumResults = $_op.getNumResults();
292304
::mlir::Type newType = $_op.getTypeWithArgsAndResults(
293305
/*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes);
306+
if (!newType)
307+
return ::llvm::failure();
294308
::mlir::function_interface_impl::insertFunctionResults(
295309
$_op, resultIndices, resultTypes, resultAttrs,
296310
originalNumResults, newType);
311+
return ::llvm::success();
297312
}
298313

299-
/// Erase a single argument at `argIndex`.
300-
void eraseArgument(unsigned argIndex) {
314+
/// Erase a single argument at `argIndex`. Returns failure if the function
315+
/// cannot be updated to have the new signature.
316+
::llvm::LogicalResult eraseArgument(unsigned argIndex) {
301317
::llvm::BitVector argsToErase($_op.getNumArguments());
302318
argsToErase.set(argIndex);
303-
eraseArguments(argsToErase);
319+
return eraseArguments(argsToErase);
304320
}
305321

306-
/// Erases the arguments listed in `argIndices`.
307-
void eraseArguments(const ::llvm::BitVector &argIndices) {
322+
/// Erases the arguments listed in `argIndices`. Returns failure if the
323+
/// function cannot be updated to have the new signature.
324+
::llvm::LogicalResult eraseArguments(const ::llvm::BitVector &argIndices) {
308325
::mlir::Type newType = $_op.getTypeWithoutArgs(argIndices);
326+
if (!newType)
327+
return ::llvm::failure();
309328
::mlir::function_interface_impl::eraseFunctionArguments(
310329
$_op, argIndices, newType);
330+
return ::llvm::success();
311331
}
312332

313-
/// Erase a single result at `resultIndex`.
314-
void eraseResult(unsigned resultIndex) {
333+
/// Erase a single result at `resultIndex`. Returns failure if the function
334+
/// cannot be updated to have the new signature.
335+
LogicalResult eraseResult(unsigned resultIndex) {
315336
::llvm::BitVector resultsToErase($_op.getNumResults());
316337
resultsToErase.set(resultIndex);
317-
eraseResults(resultsToErase);
338+
return eraseResults(resultsToErase);
318339
}
319340

320-
/// Erases the results listed in `resultIndices`.
321-
void eraseResults(const ::llvm::BitVector &resultIndices) {
341+
/// Erases the results listed in `resultIndices`. Returns failure if the
342+
/// function cannot be updated to have the new signature.
343+
::llvm::LogicalResult eraseResults(const ::llvm::BitVector &resultIndices) {
322344
::mlir::Type newType = $_op.getTypeWithoutResults(resultIndices);
345+
if (!newType)
346+
return ::llvm::failure();
323347
::mlir::function_interface_impl::eraseFunctionResults(
324348
$_op, resultIndices, newType);
349+
return ::llvm::success();
325350
}
326351

327352
/// Return the type of this function with the specified arguments and
328353
/// results inserted. This is used to update the function's signature in
329354
/// the `insertArguments` and `insertResults` methods. The arrays must be
330-
/// sorted by increasing index.
355+
/// sorted by increasing index. Return nullptr if the updated type would
356+
/// not be valid.
331357
::mlir::Type getTypeWithArgsAndResults(
332358
::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes,
333359
::llvm::ArrayRef<unsigned> resultIndices, ::mlir::TypeRange resultTypes) {
@@ -341,7 +367,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
341367

342368
/// Return the type of this function without the specified arguments and
343369
/// results. This is used to update the function's signature in the
344-
/// `eraseArguments` and `eraseResults` methods.
370+
/// `eraseArguments` and `eraseResults` methods. Return nullptr if the
371+
/// updated type would not be valid.
345372
::mlir::Type getTypeWithoutArgsAndResults(
346373
const ::llvm::BitVector &argIndices, const ::llvm::BitVector &resultIndices) {
347374
::llvm::SmallVector<::mlir::Type> argStorage, resultStorage;

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
125125
// Perform signature modification
126126
rewriter.modifyOpInPlace(
127127
gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
128-
static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
129-
argIndices, argTypes, argAttrs, argLocs);
128+
LogicalResult inserted =
129+
static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
130+
argIndices, argTypes, argAttrs, argLocs);
131+
(void)inserted;
132+
assert(succeeded(inserted) &&
133+
"expected GPU funcs to support inserting any argument");
130134
});
131135
} else {
132136
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ updateFuncOp(func::FuncOp func,
9292
}
9393

9494
// Erase the results.
95-
func.eraseResults(erasedResultIndices);
95+
if (failed(func.eraseResults(erasedResultIndices)))
96+
return failure();
9697

9798
// Add the new arguments to the entry block if the function is not external.
9899
if (func.isExternal())

mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
113113
}
114114

115115
// Update function.
116-
funcOp.eraseResults(erasedResultIndices);
116+
if (failed(funcOp.eraseResults(erasedResultIndices)))
117+
return failure();
117118
returnOp.getOperandsMutable().assign(newReturnValues);
118119

119120
// Update function calls.

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,10 @@ LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
232232

233233
LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
234234
TypeRange results) const {
235-
assert(results.size() == 1 && "expected a single result type");
235+
if (results.size() != 1 || !isValidResultType(results[0]))
236+
return {};
237+
if (!llvm::all_of(inputs, isValidArgumentType))
238+
return {};
236239
return get(results[0], llvm::to_vector(inputs), isVarArg());
237240
}
238241

mlir/lib/Query/Query.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,11 @@ static Operation *extractFunction(std::vector<Operation *> &ops,
8888
// Remove unused function arguments
8989
size_t currentIndex = 0;
9090
while (currentIndex < funcOp.getNumArguments()) {
91+
// Erase if possible.
9192
if (funcOp.getArgument(currentIndex).use_empty())
92-
funcOp.eraseArgument(currentIndex);
93-
else
94-
++currentIndex;
93+
if (succeeded(funcOp.eraseArgument(currentIndex)))
94+
continue;
95+
++currentIndex;
9596
}
9697

9798
return funcOp;

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,8 +698,11 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
698698

699699
// 3. Functions
700700
for (auto &f : list.functions) {
701-
f.funcOp.eraseArguments(f.nonLiveArgs);
702-
f.funcOp.eraseResults(f.nonLiveRets);
701+
// Some functions may not allow erasing arguments or results. These calls
702+
// return failure in such cases without modifying the function, so it's okay
703+
// to proceed.
704+
(void)f.funcOp.eraseArguments(f.nonLiveArgs);
705+
(void)f.funcOp.eraseResults(f.nonLiveRets);
703706
}
704707

705708
// 4. Operands

mlir/test/IR/test-func-erase-result.mlir

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-func-erase-result -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-func-erase-result -split-input-file -verify-diagnostics | FileCheck %s
22

33
// CHECK: func private @f(){{$}}
44
// CHECK-NOT: attributes{{.*}}result
@@ -66,3 +66,8 @@ func.func private @f() -> (
6666
f32 {test.erase_this_result},
6767
tensor<3xf32>
6868
)
69+
70+
// -----
71+
72+
// expected-error @below {{failed to erase results}}
73+
llvm.func @llvm_func(!llvm.ptr, i64)

mlir/test/lib/IR/TestFunc.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,12 @@ struct TestFuncInsertArg
4545
: unknownLoc);
4646
}
4747
func->removeAttr("test.insert_args");
48-
func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
49-
locsToInsert);
48+
if (succeeded(func.insertArguments(indicesToInsert, typesToInsert,
49+
attrsToInsert, locsToInsert)))
50+
continue;
51+
52+
emitError(func->getLoc()) << "failed to insert arguments";
53+
return signalPassFailure();
5054
}
5155
}
5256
};
@@ -79,7 +83,12 @@ struct TestFuncInsertResult
7983
: DictionaryAttr::get(&getContext()));
8084
}
8185
func->removeAttr("test.insert_results");
82-
func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
86+
if (succeeded(func.insertResults(indicesToInsert, typesToInsert,
87+
attrsToInsert)))
88+
continue;
89+
90+
emitError(func->getLoc()) << "failed to insert results";
91+
return signalPassFailure();
8392
}
8493
}
8594
};
@@ -100,7 +109,10 @@ struct TestFuncEraseArg
100109
for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
101110
if (func.getArgAttr(argIndex, "test.erase_this_arg"))
102111
indicesToErase.set(argIndex);
103-
func.eraseArguments(indicesToErase);
112+
if (succeeded(func.eraseArguments(indicesToErase)))
113+
continue;
114+
emitError(func->getLoc()) << "failed to erase arguments";
115+
return signalPassFailure();
104116
}
105117
}
106118
};
@@ -122,7 +134,10 @@ struct TestFuncEraseResult
122134
for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
123135
if (func.getResultAttr(resultIndex, "test.erase_this_result"))
124136
indicesToErase.set(resultIndex);
125-
func.eraseResults(indicesToErase);
137+
if (succeeded(func.eraseResults(indicesToErase)))
138+
continue;
139+
emitError(func->getLoc()) << "failed to erase results";
140+
return signalPassFailure();
126141
}
127142
}
128143
};

0 commit comments

Comments
 (0)