diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a2911386f12af..7417b2a22c7b5 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -185,6 +185,10 @@ struct CppEmitter { /// Return the existing or a new name for a Value. StringRef getOrCreateName(Value val); + /// Return the existing or a new name for a loop induction variable of an + /// emitc::ForOp. + StringRef getOrCreateName(emitc::ForOp forOp); + // Returns the textual representation of a subscript operation. std::string getSubscriptName(emitc::SubscriptOp op); @@ -200,23 +204,39 @@ struct CppEmitter { /// Whether to map an mlir integer to a unsigned integer in C++. bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); - /// RAII helper function to manage entering/exiting C++ scopes. + /// Abstract RAII helper function to manage entering/exiting C++ scopes. struct Scope { + ~Scope() { emitter.labelInScopeCount.pop(); } + + private: + llvm::ScopedHashTableScope valueMapperScope; + llvm::ScopedHashTableScope blockMapperScope; + + protected: Scope(CppEmitter &emitter) : valueMapperScope(emitter.valueMapper), blockMapperScope(emitter.blockMapper), emitter(emitter) { - emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); } - ~Scope() { - emitter.valueInScopeCount.pop(); - emitter.labelInScopeCount.pop(); + CppEmitter &emitter; + }; + + /// RAII helper function to manage entering/exiting functions, while re-using + /// value names. + struct FunctionScope : Scope { + FunctionScope(CppEmitter &emitter) : Scope(emitter) { + // Re-use value names + emitter.resetValueCounter(); } + }; - private: - llvm::ScopedHashTableScope valueMapperScope; - llvm::ScopedHashTableScope blockMapperScope; - CppEmitter &emitter; + /// RAII helper function to manage entering/exiting emitc::forOp loops and + /// handle induction variable naming. + struct LoopScope : Scope { + LoopScope(CppEmitter &emitter) : Scope(emitter) { + emitter.increaseLoopNestingLevel(); + } + ~LoopScope() { emitter.decreaseLoopNestingLevel(); } }; /// Returns wether the Value is assigned to a C++ variable in the scope. @@ -264,6 +284,15 @@ struct CppEmitter { /// This emitter will only emit translation units whos id matches this value. StringRef willOnlyEmitTu() { return onlyTu; } + // Resets the value counter to 0 + void resetValueCounter(); + + // Increases the loop nesting level by 1 + void increaseLoopNestingLevel(); + + // Decreases the loop nesting level by 1 + void decreaseLoopNestingLevel(); + private: using ValueMapper = llvm::ScopedHashTable; using BlockMapper = llvm::ScopedHashTable; @@ -288,11 +317,19 @@ struct CppEmitter { /// Map from block to name of C++ label. BlockMapper blockMapper; - /// The number of values in the current scope. This is used to declare the - /// names of values in a scope. - std::stack valueInScopeCount; + /// Default values representing outermost scope + llvm::ScopedHashTableScope defaultValueMapperScope; + llvm::ScopedHashTableScope defaultBlockMapperScope; + std::stack labelInScopeCount; + /// Keeps track of the amount of nested loops the emitter currently operates + /// in. + uint64_t loopNestingLevel{0}; + + /// Emitter-level count of created values to enable unique identifiers. + unsigned int valueCount{0}; + /// State of the current expression being emitted. ExpressionOp emittedExpression; SmallVector emittedExpressionPrecedence; @@ -911,7 +948,6 @@ static LogicalResult printOperation(CppEmitter &emitter, } static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { - raw_indented_ostream &os = emitter.ostream(); // Utility function to determine whether a value is an expression that will be @@ -930,12 +966,12 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) return failure(); os << " "; - os << emitter.getOrCreateName(forOp.getInductionVar()); + os << emitter.getOrCreateName(forOp); os << " = "; if (failed(emitter.emitOperand(forOp.getLowerBound()))) return failure(); os << "; "; - os << emitter.getOrCreateName(forOp.getInductionVar()); + os << emitter.getOrCreateName(forOp); os << " < "; Value upperBound = forOp.getUpperBound(); bool upperBoundRequiresParentheses = requiresParentheses(upperBound); @@ -946,13 +982,15 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { if (upperBoundRequiresParentheses) os << ")"; os << "; "; - os << emitter.getOrCreateName(forOp.getInductionVar()); + os << emitter.getOrCreateName(forOp); os << " += "; if (failed(emitter.emitOperand(forOp.getStep()))) return failure(); os << ") {\n"; os.indent(); + CppEmitter::LoopScope lScope(emitter); + Region &forRegion = forOp.getRegion(); auto regionOps = forRegion.getOps(); @@ -1039,8 +1077,6 @@ static LogicalResult printOperation(CppEmitter &emitter, } static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { - CppEmitter::Scope scope(emitter); - for (Operation &op : moduleOp) { if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) return failure(); @@ -1052,8 +1088,6 @@ static LogicalResult printOperation(CppEmitter &emitter, TranslationUnitOp tu) { if (!emitter.shouldEmitTu(tu)) return success(); - CppEmitter::Scope scope(emitter); - for (Operation &op : tu) { if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) return failure(); @@ -1216,7 +1250,7 @@ static LogicalResult printOperation(CppEmitter &emitter, return functionOp.emitOpError() << "cannot emit array type as result type"; } - CppEmitter::Scope scope(emitter); + CppEmitter::FunctionScope scope(emitter); raw_indented_ostream &os = emitter.ostream(); if (failed(emitter.emitTypes(functionOp.getLoc(), functionOp.getFunctionType().getResults()))) @@ -1244,7 +1278,7 @@ static LogicalResult printOperation(CppEmitter &emitter, "with multiple blocks needs variables declared at top"); } - CppEmitter::Scope scope(emitter); + CppEmitter::FunctionScope scope(emitter); raw_indented_ostream &os = emitter.ostream(); if (functionOp.getSpecifiers()) { for (Attribute specifier : functionOp.getSpecifiersAttr()) { @@ -1278,7 +1312,6 @@ static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter, DeclareFuncOp declareFuncOp) { - CppEmitter::Scope scope(emitter); raw_indented_ostream &os = emitter.ostream(); auto functionOp = SymbolTable::lookupNearestSymbolFrom( @@ -1310,8 +1343,9 @@ static LogicalResult printOperation(CppEmitter &emitter, CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop, StringRef onlyTu, bool constantsAsVariables) : os(os), declareVariablesAtTop(declareVariablesAtTop), - onlyTu(onlyTu.str()), constantsAsVariables(constantsAsVariables) { - valueInScopeCount.push(0); + onlyTu(onlyTu.str()), constantsAsVariables(constantsAsVariables), + defaultValueMapperScope(valueMapper), + defaultBlockMapperScope(blockMapper) { labelInScopeCount.push(0); } @@ -1352,7 +1386,29 @@ StringRef CppEmitter::getOrCreateName(Value val) { assert(!hasDeferredEmission(val.getDefiningOp()) && "cacheDeferredOpResult should have been called on this value, " "update the emitOperation function."); - valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); + + valueMapper.insert(val, formatv("v{0}", ++valueCount)); + } + return *valueMapper.begin(val); +} + +/// Return the existing or a new name for a loop induction variable Value. +/// Loop induction variables follow natural naming: i, j, k,... +StringRef CppEmitter::getOrCreateName(emitc::ForOp forOp) { + Value val = forOp.getInductionVar(); + + if (!valueMapper.count(val)) { + + int64_t identifier = 'i' + loopNestingLevel; + + if (identifier >= 'i' && identifier <= 'z') { + valueMapper.insert(val, + formatv("{0}_{1}", (char)identifier, ++valueCount)); + } else { + // If running out of letters, continue with zX + valueMapper.insert( + val, formatv("z{0}_{1}", identifier - 'z' - 1, ++valueCount)); + } } return *valueMapper.begin(val); } @@ -1946,6 +2002,12 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef types) { return success(); } +void CppEmitter::resetValueCounter() { valueCount = 0; } + +void CppEmitter::increaseLoopNestingLevel() { loopNestingLevel++; } + +void CppEmitter::decreaseLoopNestingLevel() { loopNestingLevel--; } + LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop, StringRef onlyTu, diff --git a/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir b/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir index f4f02cec96421..f7d09362f4f0f 100644 --- a/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir +++ b/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir @@ -12,7 +12,7 @@ func.func @test() { return } // CPP-DEFAULT-LABEL: void test() { -// CPP-DEFAULT-NEXT: for (size_t v1 = (size_t) 0; v1 < (size_t) 10; v1 += (size_t) 1) { +// CPP-DEFAULT-NEXT: for (size_t [[V1:[^ ]*]] = (size_t) 0; [[V1]] < (size_t) 10; [[V1]] += (size_t) 1) { // CPP-DEFAULT-NEXT: } // CPP-DEFAULT-NEXT: return; // CPP-DEFAULT-NEXT: } diff --git a/mlir/test/Target/Cpp/for_loop_induction_vars.mlir b/mlir/test/Target/Cpp/for_loop_induction_vars.mlir new file mode 100644 index 0000000000000..c27c93b14583e --- /dev/null +++ b/mlir/test/Target/Cpp/for_loop_induction_vars.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +// CHECK-LABEL: test_for_siblings +func.func @test_for_siblings() { + %start = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t + %stop = "emitc.constant"() <{value = 10 : index}> : () -> !emitc.size_t + %step = "emitc.constant"() <{value = 1 : index}> : () -> !emitc.size_t + + %var1 = "emitc.variable"() <{value = 0 : index}> : () -> !emitc.lvalue + %var2 = "emitc.variable"() <{value = 0 : index}> : () -> !emitc.lvalue + + // CHECK: for (size_t [[ITER0:i_[0-9]*]] = {{.*}}; [[ITER0]] < {{.*}}; [[ITER0]] += {{.*}}) { + emitc.for %i0 = %start to %stop step %step { + // CHECK: for (size_t [[ITER1:j_[0-9]*]] = {{.*}}; [[ITER1]] < {{.*}}; [[ITER1]] += {{.*}}) { + emitc.for %i1 = %start to %stop step %step { + // CHECK: {{.*}} = [[ITER0]]; + "emitc.assign"(%var1,%i0) : (!emitc.lvalue, !emitc.size_t) -> () + // CHECK: {{.*}} = [[ITER1]]; + "emitc.assign"(%var2,%i1) : (!emitc.lvalue, !emitc.size_t) -> () + } + } + // CHECK: for (size_t [[ITER2:i_[0-9]*]] = {{.*}}; [[ITER2]] < {{.*}}; [[ITER2]] += {{.*}}) + emitc.for %ki2 = %start to %stop step %step { + // CHECK: for (size_t [[ITER3:j_[0-9]*]] = {{.*}}; [[ITER3]] < {{.*}}; [[ITER3]] += {{.*}}) + emitc.for %i3 = %start to %stop step %step { + %1 = emitc.call_opaque "f"() : () -> i32 + } + } + return +} + +// CHECK-LABEL: test_for_nesting +func.func @test_for_nesting() { + %start = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t + %stop = "emitc.constant"() <{value = 10 : index}> : () -> !emitc.size_t + %step = "emitc.constant"() <{value = 1 : index}> : () -> !emitc.size_t + + // CHECK-COUNT-18: for (size_t [[ITER:[i-z]_[0-9]*]] = {{.*}}; [[ITER]] < {{.*}}; [[ITER]] += {{.*}}) { + emitc.for %i0 = %start to %stop step %step { + emitc.for %i1 = %start to %stop step %step { + emitc.for %i2 = %start to %stop step %step { + emitc.for %i3 = %start to %stop step %step { + emitc.for %i4 = %start to %stop step %step { + emitc.for %i5 = %start to %stop step %step { + emitc.for %i6 = %start to %stop step %step { + emitc.for %i7 = %start to %stop step %step { + emitc.for %i8 = %start to %stop step %step { + emitc.for %i9 = %start to %stop step %step { + emitc.for %i10 = %start to %stop step %step { + emitc.for %i11 = %start to %stop step %step { + emitc.for %i12 = %start to %stop step %step { + emitc.for %i13 = %start to %stop step %step { + emitc.for %i14 = %start to %stop step %step { + emitc.for %i15 = %start to %stop step %step { + emitc.for %i16 = %start to %stop step %step { + emitc.for %i17 = %start to %stop step %step { + // CHECK: for (size_t [[ITERz0:z0_[0-9]*]] = {{.*}}; [[ITERz0]] < {{.*}}; [[ITERz0]] += {{.*}}) { + emitc.for %i18 = %start to %stop step %step { + // CHECK: for (size_t [[ITERz1:z1_[0-9]*]] = {{.*}}; [[ITERz1]] < {{.*}}; [[ITERz1]] += {{.*}}) { + emitc.for %i19 = %start to %stop step %step { + %0 = emitc.call_opaque "f"() : () -> i32 + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + return +}