Skip to content

Commit

Permalink
[mlir][sparse] support sparsifying sparse kernels to sparse-iterator-based loop (llvm#95858)
Browse files Browse the repository at this point in the history
  • Loading branch information
PeimingLiu committed Jun 17, 2024
1 parent c67ecf3 commit a02010b
Show file tree
Hide file tree
Showing 40 changed files with 745 additions and 474 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class LevelSet {
assert(i < 64);
return (bits & (1 << i)) != 0;
}

  // Returns an exclusive upper bound on the set levels, i.e., 1 + the index
  // of the highest set bit (0 when the set is empty).
  unsigned max() const { return 64 - llvm::countl_zero(bits); }
  // Returns the number of levels present in the set (population count).
  unsigned count() const { return llvm::popcount(bits); }
  // Returns true iff no level is present in the set.
  bool empty() const { return bits == 0; }
};
Expand Down
28 changes: 24 additions & 4 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,10 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
```
}];

let arguments = (ins AnySparseTensor:$tensor,
Optional<AnySparseIterator>:$parentIter,
LevelAttr:$loLvl, LevelAttr:$hiLvl);
let results = (outs AnySparseIterSpace:$extractedSpace);

let extraClassDeclaration = [{
std::pair<Level, Level> getLvlRange() {
Expand All @@ -1506,10 +1510,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
}
}];

let arguments = (ins AnySparseTensor:$tensor,
Optional<AnySparseIterator>:$parentIter,
LevelAttr:$loLvl, LevelAttr:$hiLvl);
let results = (outs AnySparseIterSpace:$extractedSpace);
let builders = [
// Construct a 1-D iteration space.
OpBuilder<(ins "Value":$tensor, "Value":$parentIter,
"sparse_tensor::Level":$loLvl),
[{
build($_builder, $_state, tensor, parentIter, loLvl, loLvl + 1);
}]>,
// Construct a 1-D root iteration space
OpBuilder<(ins "Value":$tensor),
[{
build($_builder, $_state, tensor, nullptr, 0);
}]>
];

let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
" attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
"`->` qualified(type($extractedSpace))";
Expand Down Expand Up @@ -1594,6 +1608,12 @@ def IterateOp : SparseTensor_Op<"iterate",
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
];

let extraClassDeclaration = [{
unsigned getSpaceDim() {
return getIterSpace().getType().getSpaceDim();
Expand Down
17 changes: 16 additions & 1 deletion mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
"any-storage-any-loop",
"Enable sparse parallelization for any storage and loop."))};
PassOptions::Option<mlir::SparseEmitStrategy> emitStrategy{
*this, "sparse-emit-strategy",
::llvm::cl::desc(
"Emit functional code or interfaces (to debug) for sparse loops"),
::llvm::cl::init(mlir::SparseEmitStrategy::kFunctional),
llvm::cl::values(
clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
"Emit functional code (with scf.for/while)."),
clEnumValN(mlir::SparseEmitStrategy::kSparseIterator,
"sparse-iterator",
"Emit (experimental) loops (with sparse.iterate)."),
clEnumValN(
mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
"Emit non-functional but easy-to-read interfaces to debug."))};

PassOptions::Option<bool> enableRuntimeLibrary{
*this, "enable-runtime-library",
Expand Down Expand Up @@ -143,7 +157,8 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {

/// Projects out the subset of pipeline options consumed by
/// `createSparsificationPass`: the parallelization strategy, the loop
/// emit strategy, and whether the runtime library is enabled.
SparsificationOptions sparsificationOptions() const {
  return SparsificationOptions(parallelization, emitStrategy,
                               enableRuntimeLibrary);
}

/// Projects out the options for `createConvertVectorToLLVMPass`.
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ enum class ReinterpretMapScope {
/// Defines the strategy used by the sparsifier to emit sparse loops.
/// (The previous comment was copied from `ReinterpretMapScope` above and
/// did not describe this enum.)
enum class SparseEmitStrategy {
  kFunctional,     // generate fully inlined (and functional) sparse iteration
  kSparseIterator, // generate (experimental) loop using sparse iterator.
  kDebugInterface, // generate only place-holder for sparse iteration
};

Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
"mlir::SparseEmitStrategy::kFunctional",
"Emit functional code or interfaces (to debug) for sparse loops", [{llvm::cl::values(
clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
"Emit functional code."),
"Emit functional code (with scf.for/while)."),
clEnumValN(mlir::SparseEmitStrategy::kSparseIterator, "sparse-iterator",
"Emit (experimental) loops (with sparse.iterate)."),
clEnumValN(mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
"Emit non-functional but easy-to-read interfaces to debug."))}]>,
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2300,6 +2300,41 @@ void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
results.add<RemoveUnusedLvlCrds>(context);
}

/// Builds an `IterateOp` over `iterSpace` that uses the coordinates of all
/// levels of the space, delegating to the overload taking an explicit
/// `LevelSet`.
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs) {
  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
  // All ones. Use a 64-bit shift: `LevelSet` holds up to 64 levels, so the
  // original `(1 << rank) - 1` (a 32-bit `int` shift) was undefined behavior
  // for rank >= 31. Saturate to all-ones when rank >= 64, since shifting by
  // the full bit width is also undefined.
  LevelSet set(rank >= 64 ? ~0ULL : (1ULL << rank) - 1);
  build(builder, odsState, iterSpace, initArgs, set);
}

/// Builds an `IterateOp` over `iterSpace` with loop-carried values `initArgs`,
/// using only the level coordinates selected by `crdUsedLvls`. Also creates
/// the body block with its arguments in the order: iterator, used
/// coordinates, then the loop-carried values.
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs,
                      LevelSet crdUsedLvls) {
  // Restore the caller's insertion point after `createBlock` moves it below.
  OpBuilder::InsertionGuard guard(builder);

  odsState.addOperands(iterSpace);
  odsState.addOperands(initArgs);
  // Encode the used-coordinate set as a 64-bit integer attribute
  // (one bit per level).
  odsState.getOrAddProperties<Properties>().crdUsedLvls =
      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
  Region *bodyRegion = odsState.addRegion();
  // The op yields the same types as its loop-carried values.
  odsState.addTypes(initArgs.getTypes());
  Block *bodyBlock = builder.createBlock(bodyRegion);

  // First block argument: the sparse iterator over the iteration space.
  bodyBlock->addArgument(
      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
      odsState.location);

  // Followed by one index argument per used coordinate.
  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
    bodyBlock->addArgument(builder.getIndexType(), odsState.location);

  // Followed by the user-provided loop-carried arguments.
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());
}

ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument iterator;
OpAsmParser::UnresolvedOperand iterSpace;
Expand Down Expand Up @@ -2384,6 +2419,9 @@ LogicalResult IterateOp::verify() {
return emitOpError(
"mismatch in number of loop-carried values and defined values");
}
if (getCrdUsedLvls().max() > getSpaceDim())
return emitOpError("required out-of-bound coordinates");

return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
// replace sparse_tensor.yield with scf.yield.
rewriter.eraseOp(yieldOp);
rewriter.create<scf::YieldOp>(loc, yields);

const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(
op, whileOp.getResults().drop_front(it->getCursor().size()),
Expand Down Expand Up @@ -192,6 +191,8 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {

/// Populates `patterns` with the rewrites that lower sparse-iteration ops
/// (`extract_iteration_space`, `iterate`) to SCF loops, using `converter`
/// for type conversion.
void mlir::populateLowerSparseIterationToSCFPatterns(
    TypeConverter &converter, RewritePatternSet &patterns) {

  // Also register IterateOp canonicalization patterns (e.g., removal of
  // unused level coordinates) so they apply alongside the lowering patterns.
  IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
      converter, patterns.getContext());
}
6 changes: 5 additions & 1 deletion mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,11 @@ static bool getAllTidLvlsInLatPoints(
}
// If we just need one loop condition and the condition is not imposed on a
// non-unique level, the loop can be generated by a for loop.
return numloopCond == 1 && !hasNonUnique;
// Or, if we are generating sparse-iterator-based loops, we always generate
// `sparse_tensor.iterate` regardless whether the level is unique or not.
return numloopCond == 1 &&
(!hasNonUnique || env.options().sparseEmitStrategy ==
SparseEmitStrategy::kSparseIterator);
}

/// Starts a loop sequence at given level. Returns true if
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ class SparsificationAndBufferizationPass
pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary));
pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
pm.addPass(createSparsificationPass(sparsificationOptions));
if (sparsificationOptions.sparseEmitStrategy ==
SparseEmitStrategy::kSparseIterator) {
pm.addNestedPass<func::FuncOp>(createSparseSpaceCollapsePass());
pm.addNestedPass<func::FuncOp>(createLowerSparseIterationToSCFPass());
}

pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
/*enableConvert=*/true));
Expand Down
Loading

0 comments on commit a02010b

Please sign in to comment.