diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 1b4ec9eae9367..d4dadc12d41de 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -96,6 +96,10 @@ struct CppEmitter { LogicalResult emitVariableDeclaration(OpResult result, bool trailingSemicolon); + /// Emits a declaration of a variable with the given type and name. + LogicalResult emitVariableDeclaration(Location loc, Type type, + StringRef name); + /// Emits the variable declaration and assignment prefix for 'op'. /// - emits separate variable followed by std::tie for multi-valued operation; /// - emits single type followed by variable for single result; @@ -623,14 +627,12 @@ static LogicalResult printOperation(CppEmitter &emitter, os << " " << functionOp.getName(); os << "("; - if (failed(interleaveCommaWithError( - functionOp.getArguments(), os, - [&](BlockArgument arg) -> LogicalResult { - if (failed(emitter.emitType(functionOp.getLoc(), arg.getType()))) - return failure(); - os << " " << emitter.getOrCreateName(arg); - return success(); - }))) + if (failed(interleaveCommaWithError(functionOp.getArguments(), os, + [&](BlockArgument arg) -> LogicalResult { + return emitter.emitVariableDeclaration( + functionOp.getLoc(), arg.getType(), + emitter.getOrCreateName(arg)); + }))) return failure(); os << ") {\n"; os.indent(); @@ -893,9 +895,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, return result.getDefiningOp()->emitError( "result variable for the operation already declared"); } - if (failed(emitType(result.getOwner()->getLoc(), result.getType()))) + if (failed(emitVariableDeclaration(result.getOwner()->getLoc(), + result.getType(), + getOrCreateName(result)))) return failure(); - os << " " << getOrCreateName(result); if (trailingSemicolon) os << ";\n"; return success(); @@ -977,6 +980,23 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { return success(); } +LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, + StringRef name) { + if (auto arrType = dyn_cast(type)) { + if (failed(emitType(loc, arrType.getElementType()))) + return failure(); + os << " " << name; + for (auto dim : arrType.getShape()) { + os << "[" << dim << "]"; + } + return success(); + } + if (failed(emitType(loc, type))) + return failure(); + os << " " << name; + return success(); +} + LogicalResult CppEmitter::emitType(Location loc, Type type) { if (auto iType = dyn_cast(type)) { switch (iType.getWidth()) { diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir index b537e7098deb5..a87b33a10844d 100644 --- a/mlir/test/Target/Cpp/common-cpp.mlir +++ b/mlir/test/Target/Cpp/common-cpp.mlir @@ -89,3 +89,8 @@ func.func @apply(%arg0: i32) -> !emitc.ptr { %1 = emitc.apply "*"(%0) : (!emitc.ptr) -> (i32) return %0 : !emitc.ptr } + +// CHECK: void array_type(int32_t v1[3], float v2[10][20]) +func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) { + return +} diff --git a/mlir/test/Target/Cpp/variable.mlir b/mlir/test/Target/Cpp/variable.mlir index 77a060a32f9d4..5d061a6c87505 100644 --- a/mlir/test/Target/Cpp/variable.mlir +++ b/mlir/test/Target/Cpp/variable.mlir @@ -9,6 +9,7 @@ func.func @emitc_variable() { %c4 = "emitc.variable"(){value = 255 : ui8} : () -> ui8 %c5 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr %c6 = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr + %c7 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<3x7xi32> return } // CPP-DEFAULT: void emitc_variable() { @@ -19,6 +20,7 @@ func.func @emitc_variable() { // CPP-DEFAULT-NEXT: uint8_t [[V4:[^ ]*]] = 255; // CPP-DEFAULT-NEXT: int32_t* [[V5:[^ ]*]]; // CPP-DEFAULT-NEXT: int32_t* [[V6:[^ ]*]] = NULL; +// CPP-DEFAULT-NEXT: int32_t [[V7:[^ ]*]][3][7]; // CPP-DECLTOP: void emitc_variable() { // CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]]; @@ -28,6 +30,7 @@ func.func @emitc_variable() { // CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]]; // CPP-DECLTOP-NEXT: int32_t* [[V5:[^ ]*]]; // CPP-DECLTOP-NEXT: int32_t* [[V6:[^ ]*]]; +// CPP-DECLTOP-NEXT: int32_t [[V7:[^ ]*]][3][7]; // CPP-DECLTOP-NEXT: ; // CPP-DECLTOP-NEXT: [[V1]] = 42; // CPP-DECLTOP-NEXT: [[V2]] = -1;