Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,21 @@ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
*/
TVM_DLL Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo);

/*!
* \brief Get the TIR variables that defined in the input function.
* The returned list is deduplicated - each TIR variable will appear at most once.
* \param func The function object to be analyzed.
* \return The list of TIR variables that are defined in the input function.
*/
TVM_DLL Array<tir::Var> DefinedSymbolicVars(const Function& func);

/*!
* \brief Get the TIR variables that are used but not defined in the input function.
* The returned list is deduplicated - each TIR variable will appear at most once.
* \param func The function object to be analyzed.
* \return The list of TIR variables that are used but not defined in the input function.
*/
TVM_DLL Array<tir::Var> FreeSymbolicVars(const Function& func);
//-----------------------------------
// General IR analysis
//-----------------------------------
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,40 @@ def tir_vars_in_struct_info(sinfo: StructInfo) -> List[tir.Var]:
return _ffi_api.TIRVarsInStructInfo(sinfo) # type: ignore


def defined_symbolic_vars(func: Function) -> List[Var]:
"""Get the TIR variables that defined in the input function.
The returned list is deduplicated - each TIR variable will appear at most once.

Parameters
----------
func : Function
The function object to be analyzed.

Returns
-------
ret : List[Var]
The list of symbolic variables that are defined in the input function.
"""
return _ffi_api.DefinedSymbolicVars(func) # type: ignore


def free_symbolic_vars(func: Function) -> List[Var]:
"""Get the TIR variables that are used but not defined in the input function.
The returned list is deduplicated - each TIR variable will appear at most once.

Parameters
----------
func : Function
The function object to be analyzed.

Returns
-------
ret : List[Var]
The list of symbolic variables that are used but not defined in the input function.
"""
return _ffi_api.FreeSymbolicVars(func) # type: ignore


def bound_vars(expr: Expr) -> List[Var]:
"""
Return all bound variables from expression expr.
Expand Down
119 changes: 119 additions & 0 deletions src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -904,5 +904,124 @@ Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo) {
TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo")
.set_body_typed([](const StructInfo& sinfo) { return TIRVarsInStructInfo(sinfo); });

class SymbolicVarCollector : public relax::ExprVisitor,
public relax::StructInfoVisitor,
public tir::ExprVisitor {
public:
static Array<tir::Var> Free(const Function& func) {
SymbolicVarCollector collector;
collector.VisitExpr(func);
Array<tir::Var> ret{collector.free_symbolic_var_.begin(), collector.free_symbolic_var_.end()};
return ret;
}

static Array<tir::Var> Defined(const Function& func) {
SymbolicVarCollector collector;
collector.VisitExpr(func);
Array<tir::Var> ret{collector.defined_symbolic_var_.begin(),
collector.defined_symbolic_var_.end()};
return ret;
}

private:
using relax::ExprVisitor::VisitExpr;
using relax::ExprVisitor::VisitExpr_;
using tir::ExprVisitor::VisitExpr;
using tir::ExprVisitor::VisitExpr_;

// Possible mode of visitor
enum class VisitMode {
/*! \brief Check all vars are well-defined. */
kDefault,
/*! \brief Match define the vars on first occurrence. */
kMatchVarDef,
};

void VisitExpr_(const FunctionNode* op) final {
WithMode(VisitMode::kMatchVarDef, [&]() {
ICHECK(mode_ == VisitMode::kMatchVarDef);
for (Var param : op->params) {
relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param));
}
});

relax::ExprVisitor::VisitExpr_(op);
}

void VisitBinding_(const MatchCastNode* binding) final {
WithMode(VisitMode::kMatchVarDef, [&]() { this->VisitStructInfo(binding->struct_info); });

relax::ExprVisitor::VisitBinding_(binding);
}

void VisitExprDepStructInfoField(const StructInfo& struct_info) {
return this->VisitStructInfo(struct_info);
}

void VisitStructInfo_(const FuncStructInfoNode* op) final {
if (op->params.defined()) {
WithMode(VisitMode::kMatchVarDef, [&]() {
ICHECK(mode_ == VisitMode::kMatchVarDef);
for (StructInfo param : op->params.value()) {
this->VisitStructInfo(param);
}
});
}
this->VisitStructInfo(op->ret);
}

void VisitStructInfoExprField(const Expr& expr) final {
relax::ExprVisitor::VisitExpr(expr);
if (auto* shape = expr.as<relax::ShapeExprNode>()) {
for (const auto& val : shape->values) {
this->VisitStructInfoExprField(val);
}
}
}

void VisitStructInfoExprField(const PrimExpr& expr) final {
if (mode_ == VisitMode::kMatchVarDef && expr->IsInstance<tir::VarNode>()) {
// populate symbolic var in first occurrence
const auto& var = Downcast<tir::Var>(expr);
if (defined_symbolic_var_.count(var) == 0) {
defined_symbolic_var_.insert(var);
}
}
tir::ExprVisitor::VisitExpr(expr);
}

void VisitExpr_(const tir::VarNode* op) final {
tir::Var var = GetRef<tir::Var>(op);
// default mode, check defined.
if (defined_symbolic_var_.count(var) == 0) {
free_symbolic_var_.insert(var);
}
}

// Run callback with mode.
template <typename FType>
void WithMode(VisitMode mode, FType callback) {
std::swap(mode_, mode);
callback();
std::swap(mode_, mode);
}

/*! \brief The current visit mode. */
VisitMode mode_ = VisitMode::kDefault;
/*! \brief The set of defined symbolic vars. */
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> defined_symbolic_var_;
/*! \brief The set of free/undefined symbolic vars. */
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> free_symbolic_var_;
};

Array<tir::Var> DefinedSymbolicVars(const Function& func) {
return SymbolicVarCollector::Defined(func);
}
Array<tir::Var> FreeSymbolicVars(const Function& func) { return SymbolicVarCollector::Free(func); }

TVM_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars);

TVM_REGISTER_GLOBAL("relax.analysis.FreeSymbolicVars").set_body_typed(FreeSymbolicVars);

} // namespace relax
} // namespace tvm
19 changes: 15 additions & 4 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,21 @@ class FunctionCreator : public ExprMutator {
body = builder_->Normalize(body);
body = builder_->Normalize(SeqExpr({new_block}, body));
group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
function_ = SymbolicVarRenewMutator::Renew(Function(/*params=*/params_, //
/*body=*/body, //
/*ret_struct_info=*/NullOpt, //
/*attrs=*/DictAttrs(group_attrs)));
Function function = Function(/*params=*/params_, //
/*body=*/body, //
/*ret_struct_info=*/NullOpt, //
/*attrs=*/DictAttrs(group_attrs));
Array<PrimExpr> free_vars =
FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; });
if (!free_vars.empty()) {
params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars)));
arguments_.push_back(ShapeExpr(free_vars));
function = Function(/*params=*/params_, //
/*body=*/body, //
/*ret_struct_info=*/NullOpt, //
/*attrs=*/DictAttrs(group_attrs));
}
function_ = SymbolicVarRenewMutator::Renew(function);
}
}

Expand Down
13 changes: 13 additions & 0 deletions tests/python/relax/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,5 +421,18 @@ def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")):
assert not has_reshape_pattern(reduction)


def test_reshape_pattern_reject_reduction():
@T.prim_func
def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")):
for i0, i1 in T.grid(4, 4):
with T.block("identity"):
vi0, vi1 = T.axis.remap("SR", [i0, i1])
with T.init():
B[vi0] = T.float32(0)
B[vi0] = B[vi0] + A[vi0, vi1]

assert not has_reshape_pattern(reduction)


if __name__ == "__main__":
tvm.testing.main()
26 changes: 25 additions & 1 deletion tests/python/relax/test_analysis_struct_info_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
"""Tests analysis functions of struct info"""

import pytest

import tvm
import tvm.testing
from tvm import relax as rx, TVMError
from tvm import TVMError
from tvm import relax as rx
from tvm import tir


Expand Down Expand Up @@ -574,5 +576,27 @@ def test_tir_vars_in_struct_info():
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(func), [n, m])


def test_symbolic_var_collector():
n, m, k, q, p = (
tir.Var("n", "int64"),
tir.Var("m", "int64"),
tir.Var("k", "int64"),
tir.Var("q", "int64"),
tir.Var("p", "int64"),
)
bb = rx.BlockBuilder()
x = rx.Var("x", rx.TensorStructInfo([m, m + n], "float32"))
with bb.function("main", [x]):
v0 = bb.match_cast(x, rx.TensorStructInfo([m, k], "float32"))
v1 = bb.emit(rx.call_dps_packed("test", x, rx.TensorStructInfo([p, q], "float32")))
bb.emit_func_output(rx.const(1))
func = bb.get()["main"]

defined_vars = set(rx.analysis.defined_symbolic_vars(func))
free_vars = set(rx.analysis.free_symbolic_vars(func))
assert defined_vars == {m, k}
assert free_vars == {n, p, q}


if __name__ == "__main__":
tvm.testing.main()
42 changes: 42 additions & 0 deletions tests/python/relax/test_transform_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,5 +1291,47 @@ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="floa
_check(Before, Expected)


def test_symbolic_shape_aware_fuse_2():
@I.ir_module
class Before:
@R.function
def main(s: R.Shape(["n"])):
n = T.int64()
with R.dataflow():
lv0 = R.emit_te(topi.full, [n, n], "float32", 0)
lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True)
gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
R.output(gv)
return gv

@I.ir_module
class Expected:
@R.function
def fused_full_trilu_broadcast_to(
s: R.Shape(["n"]),
) -> R.Tensor([1, 1, "n", "n"], "float32"):
R.func_attr({"Primitive": 1})
n = T.int64()
with R.dataflow():
lv0 = R.emit_te(topi.full, [n, n], "float32", 0)
lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True)
gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n])
R.output(gv)
return gv

@R.function
def main(s: R.Shape(["n"])) -> R.Tensor((1, 1, "n", "n"), dtype="float32"):
cls = Expected
n = T.int64()
with R.dataflow():
gv: R.Tensor([1, 1, n, n], "float32") = cls.fused_full_trilu_broadcast_to(
R.shape([n])
)
R.output(gv)
return gv

_check(Before, Expected)


if __name__ == "__main__":
tvm.testing.main()