Skip to content

Commit 1b232fa

Browse files
author
Jeff Niu
authored
[mlir][ods] Allow sharding of op definitions (llvm#89423)
Adds an option `-op-shard-count=N` to `mlir-tblgen -gen-op-defs` that divides the op class definitions and the op list into N segments, e.g.

```
// mlir-tblgen -gen-op-defs -op-shard-count=2

void FooDialect::initialize() {
  addOperations<
  >();
  addOperations<
  >();
}
```

When the segments are split across multiple source files, this can significantly improve compile time for dialects with a large opset.
1 parent c3def59 commit 1b232fa

File tree

14 files changed

+519
-49
lines changed

14 files changed

+519
-49
lines changed

mlir/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
185185
add_subdirectory(tools/mlir-linalg-ods-gen)
186186
add_subdirectory(tools/mlir-pdll)
187187
add_subdirectory(tools/mlir-tblgen)
188+
add_subdirectory(tools/mlir-src-sharder)
188189
set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
189190
set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
190191
set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
191192
set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
193+
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
194+
set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")
192195

193196
add_subdirectory(include/mlir)
194197
add_subdirectory(lib)

mlir/cmake/modules/AddMLIR.cmake

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
55
tablegen(MLIR ${ARGV})
66
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
77
PARENT_SCOPE)
8+
9+
# Get the current set of include paths for this td file.
10+
cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
11+
get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
12+
list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
13+
# Filter out any empty include items.
14+
list(REMOVE_ITEM tblgen_includes "")
15+
16+
# Build the absolute path for the current input file.
17+
if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
18+
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
19+
else()
20+
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
21+
endif()
22+
23+
# Append the includes used for this file to the tablegen_compile_commands
24+
# file.
25+
file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
26+
"--- !FileInfo:\n"
27+
" filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
28+
" includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
29+
)
830
endfunction()
931

1032
# Clear out any pre-existing compile_commands file before processing. This
@@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace)
149171
add_dependencies(mlir-headers MLIR${dialect}IncGen)
150172
endfunction()
151173

174+
# Declare sharded dialect operation declarations and definitions.
#
# Arguments:
#   ops_target  - base name; `<ops_target>.td` is the TableGen input and
#                 `<ops_target>.cpp` is the source file that gets sharded.
#   shard_count - number of shards the op definitions are divided into; must
#                 be at least 1.
#
# Generates `<ops_target>.h.inc` and `<ops_target>.cpp.inc` with
# `-op-shard-count=<shard_count>`, then runs the source sharder once per shard
# to produce `<ops_target>.<index>.cpp` for index 0 .. shard_count - 1. The
# generated shard sources are grouped under the public tablegen target
# `MLIR<ops_target>ShardGen` and their names are returned to the caller's
# scope in `SHARDED_SRCS`.
function(add_sharded_ops ops_target shard_count)
  if(shard_count LESS 1)
    message(FATAL_ERROR
      "add_sharded_ops: shard_count must be at least 1, got '${shard_count}'")
  endif()

  set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
  mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
  mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})

  set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)
  # `foreach(<var> RANGE <stop>)` includes <stop> itself, which would emit one
  # source more than there are shards. Shard indices are zero-based, so stop
  # at shard_count - 1.
  math(EXPR last_shard_index "${shard_count} - 1")
  foreach(index RANGE 0 ${last_shard_index})
    set(SHARDED_SRC ${ops_target}.${index}.cpp)
    list(APPEND SHARDED_SRCS ${SHARDED_SRC})
    tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
    set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
  endforeach()

  add_public_tablegen_target(MLIR${ops_target}ShardGen)
  set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
endfunction()
189+
152190
# Declare an interface in the include directory
153191
function(add_mlir_interface interface)
154192
set(LLVM_TARGET_DEFINITIONS ${interface}.td)

mlir/cmake/modules/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
3939
# Refer to the best host mlir-tblgen, which might be a host-optimized version
4040
set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
4141
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
42+
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")
4243

4344
configure_file(
4445
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
@@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
7778
# if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
7879
set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
7980
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
81+
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)
8082

8183
configure_file(
8284
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in

mlir/cmake/modules/MLIRConfig.cmake.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
1111
set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
1212
set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
1313
set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
14+
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
1415
set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
1516
set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
1617
set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")

mlir/include/mlir/TableGen/CodeGenHelpers.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,22 @@ class NamespaceEmitter {
9999
///
100100
class StaticVerifierFunctionEmitter {
101101
public:
102+
/// Create a constraint uniquer with a unique prefix derived from the record
103+
/// keeper with an optional tag.
102104
StaticVerifierFunctionEmitter(raw_ostream &os,
103-
const llvm::RecordKeeper &records);
105+
const llvm::RecordKeeper &records,
106+
StringRef tag = "");
107+
108+
/// Collect and unique all the constraints used by operations.
109+
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
104110

105111
/// Collect and unique all compatible type, attribute, successor, and region
106112
/// constraints from the operations in the file and emit them at the top of
107113
/// the generated file.
108114
///
109115
/// Constraints that do not meet the restriction that they can only reference
110116
/// `$_self` and `$_op` are not uniqued.
111-
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
117+
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
112118

113119
/// Unique all compatible type and attribute constraints from a pattern file
114120
/// and emit them at the top of the generated file.
@@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter {
177183
/// Emit pattern constraints.
178184
void emitPatternConstraints();
179185

180-
/// Collect and unique all the constraints used by operations.
181-
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
182186
/// Collect and unique all pattern constraints.
183187
void collectPatternConstraints(ArrayRef<DagLeaf> constraints);
184188

mlir/lib/TableGen/CodeGenHelpers.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ using namespace mlir::tblgen;
2424

2525
/// Generate a unique label based on the current file name to prevent name
2626
/// collisions if multiple generated files are included at once.
27-
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
27+
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
28+
StringRef tag) {
2829
// Use the input file name when generating a unique name.
2930
std::string inputFilename = records.getInputFilename();
3031

@@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
3334
nameRef.consume_back(".td");
3435

3536
// Sanitize any invalid characters.
36-
std::string uniqueName;
37+
std::string uniqueName(tag);
3738
for (char c : nameRef) {
3839
if (llvm::isAlnum(c) || c == '_')
3940
uniqueName.push_back(c);
@@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
4445
}
4546

4647
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
47-
raw_ostream &os, const llvm::RecordKeeper &records)
48-
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
48+
raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
49+
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
4950

5051
void StaticVerifierFunctionEmitter::emitOpConstraints(
51-
ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
52-
collectOpConstraints(opDefs);
53-
if (emitDecl)
54-
return;
55-
52+
ArrayRef<llvm::Record *> opDefs) {
5653
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
5754
emitTypeConstraints();
5855
emitAttrConstraints();
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Exercises op sharding with a shard count of two over a dialect that has
// three ops. The first RUN line inspects the generated definitions, the
// second the generated declarations.
// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
2+
// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS
3+
4+
include "mlir/IR/OpBase.td"
5+
6+
// Minimal dialect; its name and C++ namespace show up in the checked
// registration hook names below.
def Test_Dialect : Dialect {
7+
let name = "test";
8+
let cppNamespace = "test";
9+
}
10+
11+
// Base class for the ops of the test dialect.
class Test_Op<string mnemonic, list<Trait> traits = []>
12+
: Op<Test_Dialect, mnemonic, traits>;
13+
14+
// Three trivial ops to be distributed over the two shards.
def OpA : Test_Op<"a">;
15+
def OpB : Test_Op<"b">;
16+
def OpC : Test_Op<"c">;
17+
18+
// The declarations output is expected to mention all three ops, followed by
// the dialect-wide registration hook and one hook per shard.
// DECLS: OpA
19+
// DECLS: OpB
20+
// DECLS: OpC
21+
// DECLS: registerTestDialectOperations(
22+
// DECLS: registerTestDialectOperations0(
23+
// DECLS: registerTestDialectOperations1(
24+
25+
// First shard of the definitions output; expected to hold the dialect-wide
// registration hook, the hook for shard zero, and the OpA and OpB adaptors.
// DEFS-LABEL: GET_OP_DEFS_0
26+
// DEFS: void test::registerTestDialectOperations(
27+
// DEFS: void test::registerTestDialectOperations0(
28+
// DEFS: OpAAdaptor
29+
// DEFS: OpBAdaptor
30+
31+
// Second shard; expected to hold the hook for shard one and the OpC adaptor.
// DEFS-LABEL: GET_OP_DEFS_1
32+
// DEFS: void test::registerTestDialectOperations1(
33+
// DEFS: OpCAdaptor
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Build rules for the mlir-src-sharder tool.
# LLVM components and MLIR libraries the tool links against.
set(LLVM_LINK_COMPONENTS Support)
2+
set(LIBS MLIRSupport)
3+
4+
# Register the tool under the MLIR_SRC_SHARDER prefix, the same prefix that
# `tablegen(MLIR_SRC_SHARDER ...)` invocations refer to.
add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
5+
mlir-src-sharder.cpp
6+
7+
DEPENDS
8+
${LIBS}
9+
)
10+
11+
# Group the target with the other tablegen tools and link the MLIR libraries.
set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")
12+
target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})
13+
14+
mlir_check_all_link_libraries(mlir-src-sharder)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
//===- mlir-src-sharder.cpp - A tool for sharding generated source files --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Support/FileUtilities.h"
10+
#include "mlir/Support/LogicalResult.h"
11+
#include "llvm/Support/CommandLine.h"
12+
#include "llvm/Support/InitLLVM.h"
13+
#include "llvm/Support/MemoryBuffer.h"
14+
#include "llvm/Support/ToolOutputFile.h"
15+
16+
using namespace mlir;
17+
18+
/// Create a dependency file for `-d` option.
19+
///
20+
/// This functionality is generally only for the benefit of the build system,
21+
/// and is modeled after the same option in TableGen.
22+
static LogicalResult createDependencyFile(StringRef outputFilename,
23+
StringRef dependencyFile) {
24+
if (outputFilename == "-") {
25+
llvm::errs() << "error: the option -d must be used together with -o\n";
26+
return failure();
27+
}
28+
29+
std::string errorMessage;
30+
std::unique_ptr<llvm::ToolOutputFile> outputFile =
31+
openOutputFile(dependencyFile, &errorMessage);
32+
if (!outputFile) {
33+
llvm::errs() << errorMessage << "\n";
34+
return failure();
35+
}
36+
37+
outputFile->os() << outputFilename << ":\n";
38+
outputFile->keep();
39+
return success();
40+
}
41+
42+
int main(int argc, char **argv) {
43+
// FIXME: This is necessary because we link in TableGen, which defines its
44+
// options as static variables.. some of which overlap with our options.
45+
llvm::cl::ResetCommandLineParser();
46+
47+
llvm::cl::opt<unsigned> opShardIndex(
48+
"op-shard-index", llvm::cl::desc("The current shard index"));
49+
llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
50+
llvm::cl::desc("<input file>"),
51+
llvm::cl::init("-"));
52+
llvm::cl::opt<std::string> outputFilename(
53+
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
54+
llvm::cl::init("-"));
55+
llvm::cl::list<std::string> includeDirs(
56+
"I", llvm::cl::desc("Directory of include files"),
57+
llvm::cl::value_desc("directory"), llvm::cl::Prefix);
58+
llvm::cl::opt<std::string> dependencyFilename(
59+
"d", llvm::cl::desc("Dependency filename"),
60+
llvm::cl::value_desc("filename"), llvm::cl::init(""));
61+
llvm::cl::opt<bool> writeIfChanged(
62+
"write-if-changed",
63+
llvm::cl::desc("Only write to the output file if it changed"));
64+
65+
llvm::InitLLVM y(argc, argv);
66+
llvm::cl::ParseCommandLineOptions(argc, argv);
67+
68+
// Open the input file.
69+
std::string errorMessage;
70+
std::unique_ptr<llvm::MemoryBuffer> inputFile =
71+
openInputFile(inputFilename, &errorMessage);
72+
if (!inputFile) {
73+
llvm::errs() << errorMessage << "\n";
74+
return 1;
75+
}
76+
77+
// Write the output to a buffer.
78+
std::string outputStr;
79+
llvm::raw_string_ostream os(outputStr);
80+
os << "#define GET_OP_DEFS_" << opShardIndex << "\n"
81+
<< inputFile->getBuffer();
82+
83+
// Determine whether we need to write the output file.
84+
bool shouldWriteOutput = true;
85+
if (writeIfChanged) {
86+
// Only update the real output file if there are any differences. This
87+
// prevents recompilation of all the files depending on it if there aren't
88+
// any.
89+
if (auto existingOrErr =
90+
llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true))
91+
if (std::move(existingOrErr.get())->getBuffer() == os.str())
92+
shouldWriteOutput = false;
93+
}
94+
95+
// Populate the output file if necessary.
96+
if (shouldWriteOutput) {
97+
std::unique_ptr<llvm::ToolOutputFile> outputFile =
98+
openOutputFile(outputFilename, &errorMessage);
99+
if (!outputFile) {
100+
llvm::errs() << errorMessage << "\n";
101+
return 1;
102+
}
103+
outputFile->os() << os.str();
104+
outputFile->keep();
105+
}
106+
107+
// Always write the depfile, even if the main output hasn't changed. If it's
108+
// missing, Ninja considers the output dirty.
109+
if (!dependencyFilename.empty())
110+
if (failed(createDependencyFile(outputFilename, dependencyFilename)))
111+
return 1;
112+
113+
return 0;
114+
}

0 commit comments

Comments
 (0)