Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <tvm/tir/index_map.h>

#include <functional>
#include <set>
#include <utility>

namespace tvm {
Expand Down Expand Up @@ -494,6 +495,17 @@ struct VarUsageInfo {
*/
VarUsageInfo CollectVarUsage(const Expr& expr);

/*!
* \brief Get the used variables in an expression.
*
* This function collects all variables that are referenced within the given expression.
*
* \param expr The expression to analyze
*
* \return A set of variable nodes that are used in the expression
*/
TVM_DLL std::set<const VarNode*> GetUsedVars(const Expr& expr);

/*!
* \brief Remove unused statements inside DataflowBlocks.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
free_symbolic_vars,
free_vars,
get_static_type,
used_vars,
get_var2val,
has_reshape_pattern,
name_to_binding,
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,26 @@ def all_vars(expr: Expr) -> List[Var]:
return _ffi_api.all_vars(expr)


def used_vars(expr: Expr) -> List[Var]:
"""
Return all variables used in an expression.

This function collects all variable references within the given expression,
which is useful for analyzing variable dependencies.

Parameters
----------
expr: Expr
The expression to analyze.

Returns
-------
ret: List[Var]
List of variables used in the expression.
"""
return _ffi_api.used_vars(expr) # type: ignore


def all_global_vars(expr: Expr) -> List[GlobalVar]:
"""
Return all global variables from expression expr.
Expand Down
21 changes: 20 additions & 1 deletion src/relax/analysis/udchain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,29 @@ ffi::Map<Var, ffi::Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb) {

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("relax.analysis.udchain", DataflowBlockUseDef);
refl::GlobalDef()
.def("relax.analysis.udchain", DataflowBlockUseDef)
.def("relax.analysis.used_vars", [](const Expr& expr) {
auto used_vars = GetUsedVars(expr);
ffi::Array<Var> result;
for (const VarNode* var_node : used_vars) {
result.push_back(ffi::GetRef<Var>(var_node));
}
return result;
Comment on lines +128 to +132
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved performance, especially when used_vars is large, consider building a std::vector with pre-reserved capacity and then constructing the ffi::Array from it. This avoids potential reallocations that can occur with repeated push_back calls on ffi::Array.

        std::vector<Var> result_vec;
        result_vec.reserve(used_vars.size());
        for (const VarNode* var_node : used_vars) {
          result_vec.push_back(ffi::GetRef<Var>(var_node));
        }
        return ffi::Array<Var>(std::move(result_vec));

});
}

VarUsageInfo CollectVarUsage(const Expr& expr) { return UDChain::Collect(expr); }

std::set<const VarNode*> GetUsedVars(const Expr& expr) {
class UsedVars : public ExprVisitor {
public:
std::set<const VarNode*> used_vars;
void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
} visitor;
visitor.VisitExpr(expr);
return std::move(visitor.used_vars);
}

} // namespace relax
} // namespace tvm
12 changes: 1 addition & 11 deletions src/relax/ir/binding_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/

#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/binding_rewrite.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr.h>
Expand Down Expand Up @@ -134,17 +135,6 @@ class UpdateDFB : public ExprMutator {
}
};

// TODO(masahi): Consider moving this to analysis
std::set<const VarNode*> GetUsedVars(Expr val) {
class UsedVars : public ExprVisitor {
public:
std::set<const VarNode*> used_vars;
void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
} uvar{};
uvar.VisitExpr(val);
return std::move(uvar.used_vars);
}

void DataflowBlockRewriteNode::Add(Binding binding) {
auto [var, val] = [binding] {
if (auto vb = binding.as<VarBindingNode>()) {
Expand Down
23 changes: 23 additions & 0 deletions tests/python/relax/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from typing import List, Set, Union

import pytest
import tvm
import tvm.testing
from tvm import relax as rx
Expand All @@ -26,6 +27,7 @@
all_vars,
bound_vars,
free_vars,
used_vars,
has_reshape_pattern,
name_to_binding,
remove_all_unused,
Expand Down Expand Up @@ -61,6 +63,27 @@ def test_use_def():
assert set(udc[gv0]) == set()


@pytest.mark.parametrize(
"expr_fn, expected_var_names",
[
(lambda x, y, z: rx.op.add(x, y), {"x", "y"}),
(lambda x, y, z: rx.op.multiply(x, x), {"x"}),
(lambda x, y, z: rx.Tuple([x, y, z]), {"x", "y", "z"}),
],
ids=["binary_op", "self_reference", "tuple"],
)
def test_used_vars(expr_fn, expected_var_names):
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", R.Tensor([m, n], "float16"))
y = rx.Var("y", R.Tensor([n], "float16"))
z = rx.Var("z", R.Tensor([m], "float16"))

expr = expr_fn(x, y, z)
result = used_vars(expr)
assert var_name_set(result) == expected_var_names


def test_chained_remove_all_unused():
@tvm.script.ir_module
class IdentityUnused:
Expand Down
Loading