Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ struct CppEmitter {
LogicalResult emitVariableDeclaration(OpResult result,
bool trailingSemicolon);

/// Emits a declaration of a variable with the given type and name.
LogicalResult emitVariableDeclaration(Location loc, Type type,
StringRef name);

/// Emits the variable declaration and assignment prefix for 'op'.
/// - emits separate variable followed by std::tie for multi-valued operation;
/// - emits single type followed by variable for single result;
Expand Down Expand Up @@ -623,14 +627,12 @@ static LogicalResult printOperation(CppEmitter &emitter,
os << " " << functionOp.getName();

os << "(";
if (failed(interleaveCommaWithError(
functionOp.getArguments(), os,
[&](BlockArgument arg) -> LogicalResult {
if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
return failure();
os << " " << emitter.getOrCreateName(arg);
return success();
})))
if (failed(interleaveCommaWithError(functionOp.getArguments(), os,
[&](BlockArgument arg) -> LogicalResult {
return emitter.emitVariableDeclaration(
functionOp.getLoc(), arg.getType(),
emitter.getOrCreateName(arg));
})))
return failure();
os << ") {\n";
os.indent();
Expand Down Expand Up @@ -893,9 +895,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
return result.getDefiningOp()->emitError(
"result variable for the operation already declared");
}
if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
result.getType(),
getOrCreateName(result))))
return failure();
os << " " << getOrCreateName(result);
if (trailingSemicolon)
os << ";\n";
return success();
Expand Down Expand Up @@ -977,6 +980,23 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
return success();
}

LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
StringRef name) {
if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitType(loc, arrType.getElementType())))
return failure();
os << " " << name;
for (auto dim : arrType.getShape()) {
os << "[" << dim << "]";
}
return success();
}
if (failed(emitType(loc, type)))
return failure();
os << " " << name;
return success();
}

LogicalResult CppEmitter::emitType(Location loc, Type type) {
if (auto iType = dyn_cast<IntegerType>(type)) {
switch (iType.getWidth()) {
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Target/Cpp/common-cpp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,8 @@ func.func @apply(%arg0: i32) -> !emitc.ptr<i32> {
%1 = emitc.apply "*"(%0) : (!emitc.ptr<i32>) -> (i32)
return %0 : !emitc.ptr<i32>
}

// CHECK: void array_type(int32_t v1[3], float v2[10][20])
func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) {
return
}
3 changes: 3 additions & 0 deletions mlir/test/Target/Cpp/variable.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ func.func @emitc_variable() {
%c4 = "emitc.variable"(){value = 255 : ui8} : () -> ui8
%c5 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i32>
%c6 = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr<i32>
%c7 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<3x7xi32>
return
}
// CPP-DEFAULT: void emitc_variable() {
Expand All @@ -19,6 +20,7 @@ func.func @emitc_variable() {
// CPP-DEFAULT-NEXT: uint8_t [[V4:[^ ]*]] = 255;
// CPP-DEFAULT-NEXT: int32_t* [[V5:[^ ]*]];
// CPP-DEFAULT-NEXT: int32_t* [[V6:[^ ]*]] = NULL;
// CPP-DEFAULT-NEXT: int32_t [[V7:[^ ]*]][3][7];

// CPP-DECLTOP: void emitc_variable() {
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
Expand All @@ -28,6 +30,7 @@ func.func @emitc_variable() {
// CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]];
// CPP-DECLTOP-NEXT: int32_t* [[V5:[^ ]*]];
// CPP-DECLTOP-NEXT: int32_t* [[V6:[^ ]*]];
// CPP-DECLTOP-NEXT: int32_t [[V7:[^ ]*]][3][7];
// CPP-DECLTOP-NEXT: ;
// CPP-DECLTOP-NEXT: [[V1]] = 42;
// CPP-DECLTOP-NEXT: [[V2]] = -1;
Expand Down