Skip to content

Commit

Permalink
[Ethos-U][TIR] Handle DeclBuffer in Ethos-U inputs (#15098)
Browse files Browse the repository at this point in the history
This is a subset of changes, being split out from
#14778 into independent portions.
  • Loading branch information
Lunderberg committed Jun 16, 2023
1 parent fa8a9f7 commit 7c4c913
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 14 deletions.
12 changes: 11 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,23 @@ def _replace_pointers(stmt):
)
return None

def _remove_buffer_decl(stmt):
if isinstance(stmt, tvm.tir.DeclBuffer):
if stmt.buffer.data in replace_output_pointer:
return stmt.body

def _post_transform(stmt):
# Replace operators with call_externs
result = _replace_operator(stmt)
# Remove operators that don't need compiling
result = result or _remove_no_compile(stmt)
# Replace necessary pointers that were removed in the previous step
return result or _replace_pointers(stmt)
result = result or _replace_pointers(stmt)
# Replace BufferDecl, since only the tir.Var data pointer is
# still used, and not the tir.Buffer
result = result or _remove_buffer_decl(stmt)

return result

def _ftransform(f, mod, ctx):
tvm.tir.stmt_functor.post_order_visit(f.body, _find_pointer_to_extent)
Expand Down
67 changes: 54 additions & 13 deletions src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include <unordered_map>
#include <unordered_set>

#include "../../transforms/ir_utils.h"

namespace tvm {

/*!
Expand All @@ -54,8 +56,47 @@ namespace ethosu {

namespace {

struct FlattenUnwrapResult {
std::vector<Stmt> seq;
std::vector<Stmt> rewrap_nest;
};

/*! \brief Utility function to flatten SeqStmt
*
* An AttrStmt or DeclBuffer may internally contain SeqStmt nodes that
* we want to flatten. Unlike SeqStmt::Flatten, this function unwraps
* these node types when encountered.
*
* \param stmt The tir::Stmt to be flattened.
* \return The flattened statements
*/
FlattenUnwrapResult FlattenUnwrap(const Stmt& stmt) {
std::vector<Stmt> seq_stmt;
std::vector<Stmt> rewrap_nest;
std::function<void(const Stmt&)> flatten_unwrap = [&](const Stmt& stmt) {
if (auto* ptr = stmt.as<DeclBufferNode>()) {
rewrap_nest.push_back(DeclBuffer(ptr->buffer, Evaluate(0)));
flatten_unwrap(ptr->body);
} else if (auto* ptr = stmt.as<SeqStmtNode>()) {
for (const auto& sub_stmt : ptr->seq) {
flatten_unwrap(sub_stmt);
}
} else if (auto* ptr = stmt.as<EvaluateNode>(); ptr && ptr->value.as<IntImmNode>()) {
// Skip
} else {
seq_stmt.push_back(stmt);
}
};
flatten_unwrap(stmt);
return FlattenUnwrapResult{seq_stmt, rewrap_nest};
}

/*! Returns the arguments of the given statement */
Array<PrimExpr> GetStmtArgs(const Stmt& stmt) {
Array<PrimExpr> GetStmtArgs(Stmt stmt) {
while (auto* ptr = stmt.as<DeclBufferNode>()) {
stmt = ptr->body;
}

auto attr{stmt.as<AttrStmtNode>()};
Stmt eval_stmt{attr ? attr->body : stmt};
auto eval{eval_stmt.as<EvaluateNode>()};
Expand Down Expand Up @@ -215,13 +256,13 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
};

Stmt VisitStmt_(const SeqStmtNode* op) override {
if (op->size() <= 1) {
auto [seq, rewrap_nest] = FlattenUnwrap(GetRef<Stmt>(op));

if (seq.size() <= 1) {
return StmtExprMutator::VisitStmt_(op);
}

auto seq_stmt{GetRef<SeqStmt>(op)};
std::vector<Stmt> new_seq(seq_stmt->size());
std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin());
std::vector<Stmt> new_seq(seq.begin(), seq.end());

// Reorder the copies and computes based on the cycle count
if (_reorder_by_cycles) {
Expand Down Expand Up @@ -324,9 +365,7 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
}
}

auto seq_stmt_node{CopyOnWrite(op)};
seq_stmt_node->seq = std::move(new_seq);
return Stmt{seq_stmt_node};
return MergeNest(rewrap_nest, SeqStmt::Flatten(new_seq));
}

bool stmt_is_global_copy(const Stmt& stmt) { return GetStmtType(stmt) == StmtType::global_copy; }
Expand Down Expand Up @@ -433,12 +472,13 @@ class MergeConstantsInfoExtractor : public StmtExprVisitor {
}

void VisitStmt_(const SeqStmtNode* op) override {
if (op->size() <= 1) {
std::vector<Stmt> seq_stmt = FlattenUnwrap(GetRef<Stmt>(op)).seq;

if (seq_stmt.size() <= 1) {
StmtExprVisitor::VisitStmt_(op);
return;
}

auto seq_stmt{GetRef<SeqStmt>(op)};
for (size_t i = 0; i < seq_stmt.size(); ++i) {
Stmt stmt{seq_stmt[i]};
switch (GetStmtType(stmt)) {
Expand Down Expand Up @@ -593,12 +633,13 @@ class MergeConstantsMutator : public StmtExprMutator {
}

Stmt VisitStmt_(const SeqStmtNode* op) override {
if (op->size() <= 1) {
std::vector<Stmt> seq_stmt = FlattenUnwrap(GetRef<Stmt>(op)).seq;

if (seq_stmt.size() <= 1) {
return StmtExprMutator::VisitStmt_(op);
}

Array<Stmt> new_seq{};
SeqStmt seq_stmt{GetRef<SeqStmt>(op)};
for (size_t i{0}; i < seq_stmt.size(); ++i) {
Stmt stmt{seq_stmt[i]};

Expand Down Expand Up @@ -628,7 +669,7 @@ class MergeConstantsMutator : public StmtExprMutator {
}
}
}
return SeqStmt(new_seq, op->span);
return SeqStmt::Flatten(new_seq);
}

/*! Returns the variables of the buffers written by copies */
Expand Down

0 comments on commit 7c4c913

Please sign in to comment.