Skip to content

Commit

Permalink
disable concise scoping when the scope stmt is explicitly annotated
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest committed Dec 21, 2023
1 parent f36a093 commit 45ec030
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/script/printer/tir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ Doc DoConciseScoping(const Optional<ExprDoc>& lhs, const ExprDoc& rhs, Array<Stm
}
}

bool AllowConciseScoping(const IRDocsifier& d) {
bool AllowConciseScoping(const IRDocsifier& d, const ObjectRef& obj) {
if (d->cfg.defined()) {
if (d->cfg->obj_to_annotate.count(obj)) {
// if the object requires annotation, do not fold this frame
return false;
}
}
ICHECK(!d->frames.empty());
if (const auto* f = d->frames.back().as<TIRFrameNode>()) {
return f->allow_concise_scoping;
Expand Down Expand Up @@ -69,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
// Step 1. Type annotation
Optional<ExprDoc> type_doc = d->AsDoc<ExprDoc>(stmt->var->type_annotation, //
p->Attr("var")->Attr("type_annotation"));
Expand Down Expand Up @@ -105,7 +111,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AssertStmt>(
"", [](tir::AssertStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
ExprDoc msg = d->AsDoc<ExprDoc>(stmt->message, p->Attr("message"));
With<TIRFrame> f(d, stmt);
Expand All @@ -129,7 +135,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
namespace {
Doc DeclBufferDoc(tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d,
BufferVarDefinition var_definitions) {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d,
var_definitions);
With<TIRFrame> f(d, stmt);
Expand Down Expand Up @@ -203,7 +209,7 @@ bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) {
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Allocate>( //
"", [](tir::Allocate stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt_p);
if (d->cfg->syntax_sugar && IsAllocateDeclBufferPattern(stmt.get())) {
return DeclBufferDoc(Downcast<tir::DeclBuffer>(stmt->body), stmt_p->Attr("body"), d,
BufferVarDefinition::DataPointer);
Expand Down Expand Up @@ -261,7 +267,7 @@ ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) {
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AllocateConst>(
"", [](tir::AllocateConst stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var);
Array<ExprDoc> args;
Array<String> kwargs_keys;
Expand Down Expand Up @@ -379,7 +385,7 @@ ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& at
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferRealize>( //
"", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
ExprDoc rhs = DocsifyBufferRealize(stmt.get(), NullOpt, p, d);
With<TIRFrame> f(d, stmt);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
Expand All @@ -389,7 +395,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AttrStmt>( //
"", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
bool concise = AllowConciseScoping(d, stmt);
Optional<ExprDoc> lhs = NullOpt;
Optional<ExprDoc> rhs = NullOpt;
Optional<tir::Var> define_var = NullOpt;
Expand Down
25 changes: 25 additions & 0 deletions tests/python/tvmscript/test_tvmscript_printer_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,28 @@ def main():
T.evaluate(6)
T.evaluate(7) # annotation 7"""
)


def test_disable_concise_scoping_when_scope_annotated():
@T.prim_func
def _func():
x = 1
y = x + 1
T.evaluate(y - 1)

result = _func.with_attr("global_symbol", "main").script(
obj_to_annotate={
_func.body.body: "annotation 1",
}
)
assert (
result
== """# from tvm.script import tir as T
@T.prim_func
def main():
x: T.int32 = 1
# annotation 1
with T.LetStmt(x + 1) as y:
T.evaluate(y - 1)"""
)

0 comments on commit 45ec030

Please sign in to comment.