Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ethos-U][TIR] Handle DeclBuffer in Ethos-U inputs #15098

Merged
merged 1 commit into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 11 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,23 @@ def _replace_pointers(stmt):
)
return None

def _remove_buffer_decl(stmt):
if isinstance(stmt, tvm.tir.DeclBuffer):
if stmt.buffer.data in replace_output_pointer:
return stmt.body

def _post_transform(stmt):
# Replace operators with call_externs
result = _replace_operator(stmt)
# Remove operators that don't need compiling
result = result or _remove_no_compile(stmt)
# Replace necessary pointers that were removed in the previous step
return result or _replace_pointers(stmt)
result = result or _replace_pointers(stmt)
# Replace BufferDecl, since only the tir.Var data pointer is
# still used, and not the tir.Buffer
result = result or _remove_buffer_decl(stmt)

return result

def _ftransform(f, mod, ctx):
tvm.tir.stmt_functor.post_order_visit(f.body, _find_pointer_to_extent)
Expand Down
67 changes: 54 additions & 13 deletions src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include <unordered_map>
#include <unordered_set>

#include "../../transforms/ir_utils.h"

namespace tvm {

/*!
Expand All @@ -54,8 +56,47 @@ namespace ethosu {

namespace {

struct FlattenUnwrapResult {
std::vector<Stmt> seq;
std::vector<Stmt> rewrap_nest;
};

/*! \brief Utility function to flatten SeqStmt
*
* An AttrStmt or DeclBuffer may internally contain SeqStmt nodes that
* we want to flatten. Unlike SeqStmt::Flatten, this function unwraps
* these node types when encountered.
*
* \param stmt The tir::Stmt to be flattened.
* \return The flattened statements
*/
FlattenUnwrapResult FlattenUnwrap(const Stmt& stmt) {
std::vector<Stmt> seq_stmt;
std::vector<Stmt> rewrap_nest;
std::function<void(const Stmt&)> flatten_unwrap = [&](const Stmt& stmt) {
if (auto* ptr = stmt.as<DeclBufferNode>()) {
rewrap_nest.push_back(DeclBuffer(ptr->buffer, Evaluate(0)));
flatten_unwrap(ptr->body);
} else if (auto* ptr = stmt.as<SeqStmtNode>()) {
for (const auto& sub_stmt : ptr->seq) {
flatten_unwrap(sub_stmt);
}
} else if (auto* ptr = stmt.as<EvaluateNode>(); ptr && ptr->value.as<IntImmNode>()) {
// Skip
} else {
seq_stmt.push_back(stmt);
}
};
flatten_unwrap(stmt);
return FlattenUnwrapResult{seq_stmt, rewrap_nest};
}

/*! Returns the arguments of the given statement */
Array<PrimExpr> GetStmtArgs(const Stmt& stmt) {
Array<PrimExpr> GetStmtArgs(Stmt stmt) {
while (auto* ptr = stmt.as<DeclBufferNode>()) {
stmt = ptr->body;
}

auto attr{stmt.as<AttrStmtNode>()};
Stmt eval_stmt{attr ? attr->body : stmt};
auto eval{eval_stmt.as<EvaluateNode>()};
Expand Down Expand Up @@ -215,13 +256,13 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
};

Stmt VisitStmt_(const SeqStmtNode* op) override {
if (op->size() <= 1) {
auto [seq, rewrap_nest] = FlattenUnwrap(GetRef<Stmt>(op));

if (seq.size() <= 1) {
return StmtExprMutator::VisitStmt_(op);
}

auto seq_stmt{GetRef<SeqStmt>(op)};
std::vector<Stmt> new_seq(seq_stmt->size());
std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin());
std::vector<Stmt> new_seq(seq.begin(), seq.end());

// Reorder the copies and computes based on the cycle count
if (_reorder_by_cycles) {
Expand Down Expand Up @@ -324,9 +365,7 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
}
}

auto seq_stmt_node{CopyOnWrite(op)};
seq_stmt_node->seq = std::move(new_seq);
return Stmt{seq_stmt_node};
return MergeNest(rewrap_nest, SeqStmt::Flatten(new_seq));
}

bool stmt_is_global_copy(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::global_copy; }
Expand Down Expand Up @@ -433,12 +472,13 @@ class MergeConstantsInfoExtractor : public StmtExprVisitor {
}

void VisitStmt_(const SeqStmtNode* op) override {
if (op->size() <= 1) {
std::vector<Stmt> seq_stmt = FlattenUnwrap(GetRef<Stmt>(op)).seq;

if (seq_stmt.size() <= 1) {
StmtExprVisitor::VisitStmt_(op);
return;
}

auto seq_stmt{GetRef<SeqStmt>(op)};
for (size_t i = 0; i < seq_stmt.size(); ++i) {
Stmt stmt{seq_stmt[i]};
switch (GetStmtType(stmt)) {
Expand Down Expand Up @@ -593,12 +633,13 @@ class MergeConstantsMutator : public StmtExprMutator {
}

Stmt VisitStmt_(const SeqStmtNode* op) override {
if (op->size() <= 1) {
std::vector<Stmt> seq_stmt = FlattenUnwrap(GetRef<Stmt>(op)).seq;

if (seq_stmt.size() <= 1) {
return StmtExprMutator::VisitStmt_(op);
}

Array<Stmt> new_seq{};
SeqStmt seq_stmt{GetRef<SeqStmt>(op)};
for (size_t i{0}; i < seq_stmt.size(); ++i) {
Stmt stmt{seq_stmt[i]};

Expand Down Expand Up @@ -628,7 +669,7 @@ class MergeConstantsMutator : public StmtExprMutator {
}
}
}
return SeqStmt(new_seq, op->span);
return SeqStmt::Flatten(new_seq);
}

/*! Returns the variables of the buffers written by copies */
Expand Down