Skip to content

Commit

Permalink
[TIR] In SplitHostDevice, check for variables in thread extents (#16250)
Browse files Browse the repository at this point in the history
* [TIR] In SplitHostDevice, check for variables in thread extents

Otherwise, they would be undefined after being de-duplicated by
`ConvertSSA`.

* Revert #16236

The buf reported in #16237 can be resolved by tracking variable usage
in a thread extent.

* lint fixes

* Update TIR well-formed checker for env thread SSA requirements

Environment threads must reuse the same `tir::Var` across all
`AttrStmt` instances in a `PrimFunc`, but must not reuse across
separate `PrimFunc`s in an `IRModule`.

* Update ConvertSSA to handle environment threads' SSA requirements

* lint fix

* Updated docstrings for VerifyWellFormed

* Rely on script.Complete for read/writes

Avoids issue in cortexm unit tests resulting from read/write
annotations being present in the root block, followed by application
of BindParams.

* Typo fix

* Added structural equal comparison in unit test
  • Loading branch information
Lunderberg committed Jan 3, 2024
1 parent 8eec0bf commit eb15d04
Show file tree
Hide file tree
Showing 10 changed files with 1,383 additions and 33 deletions.
29 changes: 28 additions & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,40 @@ TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);

/*!
* \brief Verify if the given TIR is well-formed. The verification includes:
* - Check if expressions not contain vars that is defined outside the block.
*
* - All variables are defined prior to their point of use.
*
* - No variables are used outside of the scope of their definition.
*
* - Each variable has a single point of definition.
*
* - Expressions within a tir::Block may not reference variables
* defined outside the block. For example, for a block with iter
* vars `vi, vj = T.axis.remap('SS', [i,j])`, the statement
* `B[i,j] = A[i,j]` would be ill-formed, because it uses the loop
* variables `i` and `j` instead of the block variables `vi` and
* `vj`.
*
* \param func The PrimFunc to be verified.
* \param assert_mode The indicator if it raises an error when the function is not well-formed.
* \return Whether it is a well-formed TIR function.
*/
TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);

/*!
* \brief Verify if the TIR in the given IRMOdule is well-formed.
*
* In addition to the checks performed for each PrimFunc (see above),
* the following checks are performed:
*
* - The same TIR variable may not be defined in more than one function
*
* \param mod The IRModule to be verified.
* \param assert_mode The indicator if it raises an error when the function is not well-formed.
* \return Whether it is a well-formed TIR module.
*/
TVM_DLL bool VerifyWellFormed(const IRModule& mod, bool assert_mode = true);

/*!
* \brief Find the entry function of the given IRModule, i.e, functions marked by
* `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
Expand Down
19 changes: 8 additions & 11 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,12 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
}
}

// Step 3. Collect Access Region
Array<BufferRegion> reads, writes;
for (const te::Tensor& tensor : extern_op->inputs) {
// We have ICHECK before so it is not needed here.
reads.push_back(BufferRegion::FullRegion(info->tensor2buffers[tensor]));
}
for (const Buffer& buffer : extern_op->output_placeholders) {
writes.push_back(BufferRegion::FullRegion(buffer));
}
// The access region does not need to be collected here, as it will
// be generated with the later application of "script.Complete" in
// GenerateAndCompletePrimFunc. Waiting until later also handles
// the case where there is only a single BlockNode, which then
// becomes the root Block of the function, and should not have
// reads/writes filled in.

BufferSubstituter substituter(var_map, input_buffer_map);
Stmt body = substituter(extern_op->body);
Expand All @@ -442,8 +439,8 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
/*predicate=*/Bool(true),
/*block=*/
Block(/*iter_vars=*/{},
/*reads=*/std::move(reads),
/*writes=*/std::move(writes),
/*reads=*/{},
/*writes=*/{},
/*name_hint=*/info->FreshName(extern_op->name),
/*body=*/std::move(body),
/*init=*/NullOpt,
Expand Down
214 changes: 214 additions & 0 deletions src/tir/analysis/verify_well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,97 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include <exception>
#include <optional>
#include <tuple>
#include <variant>

#include "../ir/functor_common.h"
#include "../ir/tir_visitor_with_path.h"
#include "tvm/ir/module.h"

namespace tvm {
namespace tir {

namespace {

template <typename DerivedVerifier>
class Verifier : protected TIRVisitorWithPath {
public:
template <typename TirNodeRef>
static bool Verify(const TirNodeRef& node, bool assert_on_error) {
DerivedVerifier verifier(assert_on_error);
verifier(node);
return !verifier.has_error_;
}

protected:
explicit Verifier(bool assert_on_error) : assert_on_error_(assert_on_error) {}

/* \brief Helper class to handle the bool-or-assert handles
*
* Each verifier can either return a boolean, or assert on failure.
* To avoid needing to duplicate this logic at every step, the
* Verify() method can be used. Similar to `LOG(FATAL)` or
* `LOG(DEBUG)`, it returns an object that can accept streamed
* context information.
*
* If the error should be raised, then the context is collected
* identically to `LOG(FATAL)`. If a boolean is returned, or if the
* condition passes, then the streamed context is discarded.
*
* Usage:
*
* Verify(value == expected_value)
* << "ValueError: " << value
* << " was not the expected value of " << expected_value;
*/
class VerifyStream {
public:
explicit VerifyStream(bool log_fatal) {
if (log_fatal) {
log_.emplace();
}
}

VerifyStream(const VerifyStream&) = delete;
VerifyStream& operator=(const VerifyStream&) = delete;
VerifyStream(VerifyStream&& other) { std::swap(log_, other.log_); }
VerifyStream& operator=(VerifyStream&& other) {
std::swap(log_, other.log_);
return *this;
}

template <typename T>
VerifyStream& operator<<(T&& t) {
if (log_.has_value()) {
log_.value() << std::forward<T>(t);
}
return *this;
}

~VerifyStream() noexcept(false) {
if (log_.has_value()) {
LOG(FATAL) << log_->str();
}
}

std::optional<std::ostringstream> log_{std::nullopt};
};

// TODO(Lunderberg): Add the filename/linenum with
// std::source_location when C++20 is available.
VerifyStream Verify(bool condition) {
has_error_ = has_error_ || !condition;
return VerifyStream(!condition && assert_on_error_);
}

bool assert_on_error_;
bool has_error_{false};
};

} // namespace

/*! \brief Verify all Expr inside the block does not contain:
* 1. loop vars outside the current block.
* 2. block vars of parent blocks.
Expand Down Expand Up @@ -135,10 +220,135 @@ class BlockVarAccessVerifier : public StmtExprVisitor {
bool has_error_{false};
};

class UndefinedVarVerifier : public Verifier<UndefinedVarVerifier> {
public:
// Until templated-this arrives in C++23, the CRTP can't inject a
// constructor into the child class. Therefore, must explicitly add
// the constructor.
using Verifier::Verifier;

private:
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
Verifier::Visit(prim_func, path);
redefine_allowed_within_function_.clear();
}

void EnterDef(const IterVar& iter_var, ObjectPath path) override {
Verifier::EnterDef(iter_var, path);
if (iter_var->iter_type == IterVarType::kThreadIndex) {
redefine_allowed_within_function_.insert(iter_var->var);
}
}

void EnterDef(const Var& var, ObjectPath path) override {
bool redefine_is_allowed = redefine_allowed_within_function_.count(var);
{
auto it = currently_defined_.find(var);
Verify(it == currently_defined_.end() || redefine_is_allowed)
<< "ValueError: "
<< "TIR is ill-formed, "
<< "due to multiple nested definitions of variable " << var
<< ". It was first defined at " << it->second << ", and was re-defined at " << path;
}

{
auto it = previously_defined_.find(var);
Verify(it == previously_defined_.end() || redefine_is_allowed)
<< "ValueError: "
<< "TIR is ill-formed, "
<< "due to multiple definitions of variable " << var << ". It was first defined at "
<< it->second << ", and was later re-defined at " << path;
}

currently_defined_.insert({var, path});
}

void ExitDef(const Var& var, ObjectPath path) override {
auto active_def = currently_defined_.find(var);

currently_defined_.erase(active_def);
previously_defined_.insert({var, path});
}

void VisitExpr_(const VarNode* op, ObjectPath path) override {
auto var = GetRef<Var>(op);

auto active_def = currently_defined_.find(var);
auto verify = Verify(active_def != currently_defined_.end());
verify << "ValueError: "
<< "Invalid use of undefined variable " << var << " at " << path << ".";

// Check if there was a previous definition, and append the
// location to the error message if there was. This is to aid in
// debugging, by distinguishing between a variable that is
// currently out-of-scope, and a variable that never had a
// definition in the first place.
if (auto prev_def = previously_defined_.find(var); prev_def != previously_defined_.end()) {
verify << ". While this variable was previously defined at " << prev_def->second
<< ", this definition is no longer in-scope.";
}
}

// Variables that are defined in the currently-visited scope.
std::unordered_map<Var, ObjectPath, ObjectPtrHash, ObjectPtrEqual> currently_defined_;

// Variables that were previously defined, and are now out of scope.
std::unordered_map<Var, ObjectPath, ObjectPtrHash, ObjectPtrEqual> previously_defined_;

// Special variables that are allowed to be re-defined, so long as
// that re-definition occurs within the same PrimFunc. For example
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> redefine_allowed_within_function_;
};

/* \brief Verify unique tir::Var for each environment thread
*
* Environment threads, such as CUDA's `threadIdx.x`, are defined in
* TIR using an `AttrStmt` with the key `attr::thread_extent`. A
* `PrimFunc` may contain multiple such attributes for the same
* environment thread. However, all such attributes must use the same
* `tir::Var` for a given thread.
*/
class SingleEnvThreadVerifier : public Verifier<SingleEnvThreadVerifier> {
public:
using Verifier::Verifier;

private:
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
Verifier::Visit(prim_func, path);
env_thread_vars_.clear();
}

void EnterDef(const IterVar& iter_var, ObjectPath path) override {
if (iter_var->iter_type == IterVarType::kThreadIndex) {
if (auto it = env_thread_vars_.find(iter_var->thread_tag); it != env_thread_vars_.end()) {
const auto& [prev_var, prev_path] = it->second;
Verify(prev_var.same_as(iter_var->var))
<< "ValueError: "
<< "PrimFunc uses multiple distinct TIR variables "
<< " for the environment thread \"" << iter_var->thread_tag << "\". "
<< "While multiple tir::AttrStmt may define the same environment thread, "
<< "all definitions within a single PrimFunc must share the same tir::Var. "
<< "Binding of environment thread \"" << iter_var->thread_tag
<< "\" to the TIR variable " << iter_var->var << " at " << path
<< " conflicts with the previous binding to the TIR variable " << prev_var << " at "
<< path;
} else {
env_thread_vars_.insert({iter_var->thread_tag, {iter_var->var, path}});
}
}
}

std::unordered_map<String, std::tuple<Var, ObjectPath>> env_thread_vars_;
};

bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
if (!BlockVarAccessVerifier::Verify(func, assert_mode)) {
return false;
}

if (!UndefinedVarVerifier::Verify(func, assert_mode)) return false;
if (!SingleEnvThreadVerifier::Verify(func, assert_mode)) return false;

// TODO(Siyuan): add more checks here.
return true;
}
Expand All @@ -152,6 +362,10 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) {
}
}
}

if (!UndefinedVarVerifier::Verify(mod, assert_mode)) return false;
if (!SingleEnvThreadVerifier::Verify(mod, assert_mode)) return false;

return true;
}

Expand Down

0 comments on commit eb15d04

Please sign in to comment.