Skip to content

Commit

Permalink
[TIR] Added PrettyPrint of ProducerStore/ProducerRealize nodes (#9259)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg authored Oct 13, 2021
1 parent 3229cb3 commit 617c712
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;

Expand Down Expand Up @@ -321,7 +323,9 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const AssertStmtNode* op) override;
Doc VisitStmt_(const StoreNode* op) override;
Doc VisitStmt_(const BufferStoreNode* op) override;
Doc VisitStmt_(const ProducerStoreNode* op) override;
Doc VisitStmt_(const BufferRealizeNode* op) override;
Doc VisitStmt_(const ProducerRealizeNode* op) override;
Doc VisitStmt_(const AllocateNode* op) override;
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Expand All @@ -342,7 +346,9 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintIterVar(const IterVarNode* op);
Doc PrintRange(const RangeNode* op);
Doc PrintBuffer(const BufferNode* op);
Doc PrintProducer(const DataProducerNode* op);
Doc BufferNode2Doc(const BufferNode* op, Doc doc);
Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc);
Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
Doc PrintBufferRegion(const BufferRegionNode* op);

Expand All @@ -361,6 +367,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc GetUniqueName(std::string prefix);
Doc AllocVar(const Var& var);
Doc AllocBuf(const Buffer& buffer);
Doc AllocProducer(const DataProducer& buffer);
/*!
* \brief special method to render vectors of docs with a separator
* \param vec vector of docs
Expand Down
47 changes: 47 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) {
return PrintRange(node.as<RangeNode>());
} else if (node->IsInstance<BufferNode>()) {
return PrintBuffer(node.as<BufferNode>());
} else if (node->IsInstance<DataProducerNode>()) {
return PrintProducer(node.as<DataProducerNode>());
} else if (node->IsInstance<StringObj>()) {
return PrintString(node.as<StringObj>());
} else if (node->IsInstance<BufferRegionNode>()) {
Expand Down Expand Up @@ -199,6 +201,19 @@ Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
}
}

Doc TIRTextPrinter::PrintProducer(const DataProducerNode* op) {
const DataProducer& prod = GetRef<DataProducer>(op);

if (meta_->InMeta(prod)) {
return meta_->GetMetaNode(prod);
} else if (memo_producer_.count(prod)) {
return memo_producer_[prod];
} else {
memo_producer_[prod] = AllocProducer(prod);
return DataProducerNode2Doc(op, memo_producer_[prod]);
}
}

Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
doc << Doc::Text(": Buffer(") << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", "
<< Print(buf->shape) << ", " << Print(buf->strides);
Expand All @@ -220,6 +235,11 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
return doc << ")";
}

Doc TIRTextPrinter::DataProducerNode2Doc(const DataProducerNode* prod, Doc doc) {
return doc << Doc::Text(": DataProducer(") << Print(prod->GetNameHint()) << ", "
<< PrintDType(prod->GetDataType()) << ", " << Print(prod->GetShape()) << ")";
}

Doc TIRTextPrinter::PrintBufferRegion(const BufferRegionNode* op) {
Doc doc;
doc << Print(op->buffer) << "[";
Expand Down Expand Up @@ -439,13 +459,26 @@ Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) {
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const ProducerStoreNode* op) {
Doc doc;
doc << Print(op->producer) << Print(op->indices) << " = " << Print(op->value);
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
Doc doc;
doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", "
<< Print(op->condition) << PrintBody(op->body) << ")";
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const ProducerRealizeNode* op) {
Doc doc;
doc << "producer_realize(" << Print(op->producer) << ", " << Print(op->bounds) << ", "
<< Print(op->condition) << ", " << PrintBody(op->body) << ")";
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
Doc doc;
auto scope = GetPtrStorageScope(op->buffer_var);
Expand Down Expand Up @@ -709,6 +742,20 @@ Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) {
return val;
}

Doc TIRTextPrinter::AllocProducer(const DataProducer& producer) {
const auto& it = memo_producer_.find(producer);
if (it != memo_producer_.end()) {
return it->second;
}
std::string name = producer->GetNameHint();
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "tensor_" + name;
}
Doc val = GetUniqueName(name);
memo_producer_[producer] = val;
return val;
}

Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq;
if (vec.size() != 0) {
Expand Down

0 comments on commit 617c712

Please sign in to comment.