Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MLIR][SPIRV] Start module combiner.
This commit adds a new library that merges/combines a number of spv modules into a combined one. The library has a single entry point: combine(...). To combine a number of MLIR spv modules, we move all the module-level ops from all the input modules into one big combined module. To that end, the combination process can proceed in 2 phases: (1) resolving conflicts between pairs of ops from different modules (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO) This patch implements only the first phase. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D90477
- Loading branch information
1 parent
30e130c
commit 90a8260
Showing
10 changed files
with
1,047 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
//===- ModuleCombiner.h - MLIR SPIR-V Module Combiner -----------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file declares the entry point to the SPIR-V module combiner library. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_SPIRV_MODULECOMBINER_H_ | ||
#define MLIR_DIALECT_SPIRV_MODULECOMBINER_H_ | ||
|
||
#include "mlir/Dialect/SPIRV/SPIRVModule.h" | ||
#include "llvm/ADT/ArrayRef.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
|
||
namespace mlir { | ||
class OpBuilder; | ||
|
||
namespace spirv { | ||
class ModuleOp; | ||
|
||
/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops | ||
/// from all the input modules into one big combined module. To that end, the | ||
/// combination process proceeds in 2 phases: | ||
/// | ||
/// (1) resolve conflicts between pairs of ops from different modules | ||
/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO) | ||
/// | ||
/// For the conflict resolution phase, the following rules are employed to | ||
/// resolve such conflicts: | ||
/// | ||
/// - If 2 spv.func's have the same symbol name, then rename one of the | ||
/// functions. | ||
/// - If an spv.func and another op have the same symbol name, then rename the | ||
/// other symbol. | ||
/// - If none of the 2 conflicting ops are spv.func, then rename either. | ||
/// | ||
/// In all cases, the references to the updated symbol are also updated to | ||
/// reflect the change. | ||
/// | ||
/// \param modules the list of modules to combine. Input modules are not | ||
/// modified. | ||
/// \param combinedMdouleBuilder an OpBuilder to be used for | ||
/// building up the combined module. | ||
/// \param symbRenameListener a listener that gets called everytime a symbol in | ||
/// one of the input modules is renamed. The arguments | ||
/// passed to the listener are: the input | ||
/// spirv::ModuleOp that contains the renamed symbol, | ||
/// a StringRef to the old symbol name, and a | ||
/// StringRef to the new symbol name. Note that it is | ||
/// the responsibility of the caller to properly | ||
/// retain the storage underlying the passed | ||
/// StringRefs if the listener callback outlives this | ||
/// function call. | ||
/// | ||
/// \return the combined module. | ||
OwningSPIRVModuleRef | ||
combine(llvm::MutableArrayRef<ModuleOp> modules, | ||
OpBuilder &combinedModuleBuilder, | ||
llvm::function_ref<void(ModuleOp, StringRef, StringRef)> | ||
symbRenameListener); | ||
} // namespace spirv | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_SPIRV_MODULECOMBINER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
add_subdirectory(ModuleCombiner) |
11 changes: 11 additions & 0 deletions
11
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
add_mlir_dialect_library(MLIRSPIRVModuleCombiner | ||
ModuleCombiner.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRSPIRV | ||
MLIRSupport | ||
) |
181 changes: 181 additions & 0 deletions
181
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements the the SPIR-V module combiner library. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/SPIRV/ModuleCombiner.h" | ||
|
||
#include "mlir/Dialect/SPIRV/SPIRVOps.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/SymbolTable.h" | ||
#include "llvm/ADT/ArrayRef.h" | ||
#include "llvm/ADT/StringExtras.h" | ||
|
||
using namespace mlir; | ||
|
||
static constexpr unsigned maxFreeID = 1 << 20; | ||
|
||
static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, | ||
spirv::ModuleOp combinedModule) { | ||
SmallString<64> newSymName(oldSymName); | ||
newSymName.push_back('_'); | ||
|
||
while (lastUsedID < maxFreeID) { | ||
std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); | ||
|
||
if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) { | ||
newSymName += llvm::utostr(lastUsedID); | ||
break; | ||
} | ||
} | ||
|
||
return newSymName; | ||
} | ||
|
||
/// Check if a symbol with the same name as op already exists in source. If so, | ||
/// rename op and update all its references in target. | ||
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, | ||
spirv::ModuleOp target, | ||
spirv::ModuleOp source, | ||
unsigned &lastUsedID) { | ||
if (!SymbolTable::lookupSymbolIn(source, op.getName())) | ||
return success(); | ||
|
||
StringRef oldSymName = op.getName(); | ||
SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target); | ||
|
||
if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) | ||
return op.emitError("unable to update all symbol uses for ") | ||
<< oldSymName << " to " << newSymName; | ||
|
||
SymbolTable::setSymbolName(op, newSymName); | ||
return success(); | ||
} | ||
|
||
namespace mlir { | ||
namespace spirv { | ||
|
||
// TODO Properly test symbol rename listener mechanism. | ||
|
||
OwningSPIRVModuleRef | ||
combine(llvm::MutableArrayRef<spirv::ModuleOp> modules, | ||
OpBuilder &combinedModuleBuilder, | ||
llvm::function_ref<void(ModuleOp, StringRef, StringRef)> | ||
symRenameListener) { | ||
unsigned lastUsedID = 0; | ||
|
||
if (modules.empty()) | ||
return nullptr; | ||
|
||
auto addressingModel = modules[0].addressing_model(); | ||
auto memoryModel = modules[0].memory_model(); | ||
|
||
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>( | ||
modules[0].getLoc(), addressingModel, memoryModel); | ||
combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody()); | ||
|
||
// In some cases, a symbol in the (current state of the) combined module is | ||
// renamed in order to maintain the conflicting symbol in the input module | ||
// being merged. For example, if the conflict is between a global variable in | ||
// the current combined module and a function in the input module, the global | ||
// varaible is renamed. In order to notify listeners of the symbol updates in | ||
// such cases, we need to keep track of the module from which the renamed | ||
// symbol in the combined module originated. This map keeps such information. | ||
DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap; | ||
|
||
for (auto module : modules) { | ||
if (module.addressing_model() != addressingModel || | ||
module.memory_model() != memoryModel) { | ||
module.emitError( | ||
"input modules differ in addressing model and/or memory model"); | ||
return nullptr; | ||
} | ||
|
||
spirv::ModuleOp moduleClone = module.clone(); | ||
|
||
// In the combined module, rename all symbols that conflict with symbols | ||
// from the current input module. This renmaing applies to all ops except | ||
// for spv.funcs. This way, if the conflicting op in the input module is | ||
// non-spv.func, we rename that symbol instead and maintain the spv.func in | ||
// the combined module name as it is. | ||
for (auto &op : combinedModule.getBlock().without_terminator()) { | ||
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) { | ||
StringRef oldSymName = symbolOp.getName(); | ||
|
||
if (!isa<FuncOp>(op) && | ||
failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, | ||
lastUsedID))) | ||
return nullptr; | ||
|
||
StringRef newSymName = symbolOp.getName(); | ||
|
||
if (symRenameListener && oldSymName != newSymName) { | ||
spirv::ModuleOp originalModule = | ||
symNameToModuleMap.lookup(oldSymName); | ||
|
||
if (!originalModule) { | ||
module.emitError("unable to find original ModuleOp for symbol ") | ||
<< oldSymName; | ||
return nullptr; | ||
} | ||
|
||
symRenameListener(originalModule, oldSymName, newSymName); | ||
|
||
// Since the symbol name is updated, there is no need to maintain the | ||
// entry that assocaites the old symbol name with the original module. | ||
symNameToModuleMap.erase(oldSymName); | ||
// Instead, add a new entry to map the new symbol name to the original | ||
// module in case it gets renamed again later. | ||
symNameToModuleMap[newSymName] = originalModule; | ||
} | ||
} | ||
} | ||
|
||
// In the current input module, rename all symbols that conflict with | ||
// symbols from the combined module. This includes renaming spv.funcs. | ||
for (auto &op : moduleClone.getBlock().without_terminator()) { | ||
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) { | ||
StringRef oldSymName = symbolOp.getName(); | ||
|
||
if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, | ||
lastUsedID))) | ||
return nullptr; | ||
|
||
StringRef newSymName = symbolOp.getName(); | ||
|
||
if (symRenameListener && oldSymName != newSymName) { | ||
symRenameListener(module, oldSymName, newSymName); | ||
|
||
// Insert the module associated with the symbol name. | ||
auto emplaceResult = | ||
symNameToModuleMap.try_emplace(symbolOp.getName(), module); | ||
|
||
// If an entry with the same symbol name is already present, this must | ||
// be a problem with the implementation, specially clean-up of the map | ||
// while iterating over the combined module above. | ||
if (!emplaceResult.second) { | ||
module.emitError("did not expect to find an entry for symbol ") | ||
<< symbolOp.getName(); | ||
return nullptr; | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Clone all the module's ops to the combined module. | ||
for (auto &op : moduleClone.getBlock().without_terminator()) | ||
combinedModuleBuilder.insert(op.clone()); | ||
} | ||
|
||
return combinedModule; | ||
} | ||
|
||
} // namespace spirv | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s | ||
|
||
// CHECK: module { | ||
// CHECK-NEXT: spv.module Logical GLSL450 { | ||
// CHECK-NEXT: spv.specConstant @m1_sc | ||
// CHECK-NEXT: spv.specConstant @m2_sc | ||
// CHECK-NEXT: spv.func @variable_init_spec_constant | ||
// CHECK-NEXT: spv._reference_of @m2_sc | ||
// CHECK-NEXT: spv.Variable init | ||
// CHECK-NEXT: spv.Return | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
module { | ||
spv.module Logical GLSL450 { | ||
spv.specConstant @m1_sc = 42.42 : f32 | ||
} | ||
|
||
spv.module Logical GLSL450 { | ||
spv.specConstant @m2_sc = 42 : i32 | ||
spv.func @variable_init_spec_constant() -> () "None" { | ||
%0 = spv._reference_of @m2_sc : i32 | ||
%1 = spv.Variable init(%0) : !spv.ptr<i32, Function> | ||
spv.Return | ||
} | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
module { | ||
spv.module Physical64 GLSL450 { | ||
} | ||
|
||
// expected-error @+1 {{input modules differ in addressing model and/or memory model}} | ||
spv.module Logical GLSL450 { | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
module { | ||
spv.module Logical Simple { | ||
} | ||
|
||
// expected-error @+1 {{input modules differ in addressing model and/or memory model}} | ||
spv.module Logical GLSL450 { | ||
} | ||
} |
Oops, something went wrong.