Skip to content

Commit

Permalink
[MLIR][SPIRV] Start module combiner.
Browse files Browse the repository at this point in the history
This commit adds a new library that merges/combines a number of spv
modules into a combined one. The library has a single entry point:
combine(...).

To combine a number of MLIR spv modules, we move all the module-level ops
from all the input modules into one big combined module. To that end, the
combination process can proceed in 2 phases:

  (1) resolving conflicts between pairs of ops from different modules
  (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)

This patch implements only the first phase.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D90477
  • Loading branch information
ergawy authored and antiagainst committed Oct 30, 2020
1 parent 30e130c commit 90a8260
Show file tree
Hide file tree
Showing 10 changed files with 1,047 additions and 0 deletions.
69 changes: 69 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
@@ -0,0 +1,69 @@
//===- ModuleCombiner.h - MLIR SPIR-V Module Combiner -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the entry point to the SPIR-V module combiner library.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
#define MLIR_DIALECT_SPIRV_MODULECOMBINER_H_

#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
class OpBuilder;

namespace spirv {
class ModuleOp;

/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
/// from all the input modules into one big combined module. To that end, the
/// combination process proceeds in 2 phases:
///
/// (1) resolve conflicts between pairs of ops from different modules
/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
///
/// For the conflict resolution phase, the following rules are employed to
/// resolve such conflicts:
///
/// - If 2 spv.func's have the same symbol name, then rename one of the
/// functions.
/// - If an spv.func and another op have the same symbol name, then rename the
/// other symbol.
/// - If none of the 2 conflicting ops are spv.func, then rename either.
///
/// In all cases, the references to the updated symbol are also updated to
/// reflect the change.
///
/// \param modules the list of modules to combine. Input modules are not
/// modified.
/// \param combinedMdouleBuilder an OpBuilder to be used for
/// building up the combined module.
/// \param symbRenameListener a listener that gets called everytime a symbol in
/// one of the input modules is renamed. The arguments
/// passed to the listener are: the input
/// spirv::ModuleOp that contains the renamed symbol,
/// a StringRef to the old symbol name, and a
/// StringRef to the new symbol name. Note that it is
/// the responsibility of the caller to properly
/// retain the storage underlying the passed
/// StringRefs if the listener callback outlives this
/// function call.
///
/// \return the combined module.
OwningSPIRVModuleRef
combine(llvm::MutableArrayRef<ModuleOp> modules,
OpBuilder &combinedModuleBuilder,
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
symbRenameListener);
} // namespace spirv
} // namespace mlir

#endif // MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SPIRV/CMakeLists.txt
Expand Up @@ -34,5 +34,6 @@ add_mlir_dialect_library(MLIRSPIRV
MLIRTransforms
)

add_subdirectory(Linking)
add_subdirectory(Serialization)
add_subdirectory(Transforms)
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt
@@ -0,0 +1 @@
add_subdirectory(ModuleCombiner)
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt
@@ -0,0 +1,11 @@
add_mlir_dialect_library(MLIRSPIRVModuleCombiner
ModuleCombiner.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV

LINK_LIBS PUBLIC
MLIRIR
MLIRSPIRV
MLIRSupport
)
181 changes: 181 additions & 0 deletions mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -0,0 +1,181 @@
//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the the SPIR-V module combiner library.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/ModuleCombiner.h"

#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringExtras.h"

using namespace mlir;

static constexpr unsigned maxFreeID = 1 << 20;

static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
spirv::ModuleOp combinedModule) {
SmallString<64> newSymName(oldSymName);
newSymName.push_back('_');

while (lastUsedID < maxFreeID) {
std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();

if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
newSymName += llvm::utostr(lastUsedID);
break;
}
}

return newSymName;
}

/// Check if a symbol with the same name as op already exists in source. If so,
/// rename op and update all its references in target.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
spirv::ModuleOp target,
spirv::ModuleOp source,
unsigned &lastUsedID) {
if (!SymbolTable::lookupSymbolIn(source, op.getName()))
return success();

StringRef oldSymName = op.getName();
SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);

if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
return op.emitError("unable to update all symbol uses for ")
<< oldSymName << " to " << newSymName;

SymbolTable::setSymbolName(op, newSymName);
return success();
}

namespace mlir {
namespace spirv {

// TODO Properly test symbol rename listener mechanism.

OwningSPIRVModuleRef
combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
OpBuilder &combinedModuleBuilder,
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
symRenameListener) {
unsigned lastUsedID = 0;

if (modules.empty())
return nullptr;

auto addressingModel = modules[0].addressing_model();
auto memoryModel = modules[0].memory_model();

auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
modules[0].getLoc(), addressingModel, memoryModel);
combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());

// In some cases, a symbol in the (current state of the) combined module is
// renamed in order to maintain the conflicting symbol in the input module
// being merged. For example, if the conflict is between a global variable in
// the current combined module and a function in the input module, the global
// varaible is renamed. In order to notify listeners of the symbol updates in
// such cases, we need to keep track of the module from which the renamed
// symbol in the combined module originated. This map keeps such information.
DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;

for (auto module : modules) {
if (module.addressing_model() != addressingModel ||
module.memory_model() != memoryModel) {
module.emitError(
"input modules differ in addressing model and/or memory model");
return nullptr;
}

spirv::ModuleOp moduleClone = module.clone();

// In the combined module, rename all symbols that conflict with symbols
// from the current input module. This renmaing applies to all ops except
// for spv.funcs. This way, if the conflicting op in the input module is
// non-spv.func, we rename that symbol instead and maintain the spv.func in
// the combined module name as it is.
for (auto &op : combinedModule.getBlock().without_terminator()) {
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
StringRef oldSymName = symbolOp.getName();

if (!isa<FuncOp>(op) &&
failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
lastUsedID)))
return nullptr;

StringRef newSymName = symbolOp.getName();

if (symRenameListener && oldSymName != newSymName) {
spirv::ModuleOp originalModule =
symNameToModuleMap.lookup(oldSymName);

if (!originalModule) {
module.emitError("unable to find original ModuleOp for symbol ")
<< oldSymName;
return nullptr;
}

symRenameListener(originalModule, oldSymName, newSymName);

// Since the symbol name is updated, there is no need to maintain the
// entry that assocaites the old symbol name with the original module.
symNameToModuleMap.erase(oldSymName);
// Instead, add a new entry to map the new symbol name to the original
// module in case it gets renamed again later.
symNameToModuleMap[newSymName] = originalModule;
}
}
}

// In the current input module, rename all symbols that conflict with
// symbols from the combined module. This includes renaming spv.funcs.
for (auto &op : moduleClone.getBlock().without_terminator()) {
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
StringRef oldSymName = symbolOp.getName();

if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
lastUsedID)))
return nullptr;

StringRef newSymName = symbolOp.getName();

if (symRenameListener && oldSymName != newSymName) {
symRenameListener(module, oldSymName, newSymName);

// Insert the module associated with the symbol name.
auto emplaceResult =
symNameToModuleMap.try_emplace(symbolOp.getName(), module);

// If an entry with the same symbol name is already present, this must
// be a problem with the implementation, specially clean-up of the map
// while iterating over the combined module above.
if (!emplaceResult.second) {
module.emitError("did not expect to find an entry for symbol ")
<< symbolOp.getName();
return nullptr;
}
}
}
}

// Clone all the module's ops to the combined module.
for (auto &op : moduleClone.getBlock().without_terminator())
combinedModuleBuilder.insert(op.clone());
}

return combinedModule;
}

} // namespace spirv
} // namespace mlir
50 changes: 50 additions & 0 deletions mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
@@ -0,0 +1,50 @@
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s

// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.specConstant @m1_sc
// CHECK-NEXT: spv.specConstant @m2_sc
// CHECK-NEXT: spv.func @variable_init_spec_constant
// CHECK-NEXT: spv._reference_of @m2_sc
// CHECK-NEXT: spv.Variable init
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

module {
spv.module Logical GLSL450 {
spv.specConstant @m1_sc = 42.42 : f32
}

spv.module Logical GLSL450 {
spv.specConstant @m2_sc = 42 : i32
spv.func @variable_init_spec_constant() -> () "None" {
%0 = spv._reference_of @m2_sc : i32
%1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
spv.Return
}
}
}

// -----

module {
spv.module Physical64 GLSL450 {
}

// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
spv.module Logical GLSL450 {
}
}

// -----

module {
spv.module Logical Simple {
}

// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
spv.module Logical GLSL450 {
}
}

0 comments on commit 90a8260

Please sign in to comment.