Skip to content

Commit

Permalink
[TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip (#…
Browse files Browse the repository at this point in the history
…17083)

* [TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip

Prior to this commit, all symbolic variables were printed identically,
regardless of whether the underlying variable was a `tir.Var` or
`tir.SizeVar`.  As a result, numeric simplifications that rely on a
`tir.SizeVar` being non-negative may be skipped after a round-trip
through TVMScript.

This commit updates the TVMScript printing and parsing of Relax
functions to use `var = T.int64(is_size_var=True)` for `tir.SizeVar`,
matching how `tir.SizeVar` is parsed for TIR functions.  As an added
benefit, this also allows Relax functions `R.Prim` arguments other
than `int64` to be benefit.  This may be useful in the future, such as
to specify the fill value for `R.full`.

* Remove strict=True argument, not available until python 3.10

* lint fix

* Fix breakage in unit tests
  • Loading branch information
Lunderberg committed Jun 13, 2024
1 parent 0fb5365 commit 5618628
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 6 deletions.
46 changes: 41 additions & 5 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@ def bind_assign_value(
"Expected the same dtype for TIR vars "
f"but got {value.dtype} vs {prev_value.dtype}",
)
return prev_value
if not isinstance(value, type(prev_value)):
self.report_error(
node,
f"Expected the same IR type for TIR vars "
f"but existing value {type(value)} is mismatched "
f"to previous {type(prev_value)}",
)
value = prev_value
IRBuilder.name(var_name, value)
return value

Expand Down Expand Up @@ -144,18 +151,47 @@ def is_recursive(node: doc.FunctionDef) -> bool:
return False


def collect_symbolic_var_from_prelude(
self: Parser, node: doc.FunctionDef, symbolic_vars: Dict[str, tir.Var]
) -> Dict[str, tir.Var]:
prelude_vars = {}
for stmt in node.body:
if isinstance(stmt, doc.Assign) and all(
isinstance(target, doc.Name) and target.id in symbolic_vars for target in stmt.targets
):
values = self.eval_expr(stmt.value)

try:
iter(values)
except TypeError:
values = [values]

assert len(stmt.targets) == len(values)
for target, value in zip(stmt.targets, values):
name = target.id
prelude_vars[name] = value

return {**symbolic_vars, **prelude_vars}


def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None:
# Collect symbolic vars from parameters
symbolic_vars = set()
symbolic_vars = {}
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation is required for function parameters.")
param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation)
symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars())

for var_name in param_sinfo_proxy.get_symbolic_vars():
if var_name not in symbolic_vars:
symbolic_vars[var_name] = tir.Var(var_name, "int64")

# Update symbolic vars based on
symbolic_vars = collect_symbolic_var_from_prelude(self, node, symbolic_vars)

# Define symbolic vars to the current var_table frame
for var_name in symbolic_vars:
self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False)
for var_name, var in symbolic_vars.items():
self.var_table.add(var_name, var, allow_shadowing=False)


@dispatch.register(token="relax", type_name="FunctionDef")
Expand Down
3 changes: 2 additions & 1 deletion src/script/printer/relax/tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
#include <tvm/ir/expr.h>

#include "../tir/utils.h"
#include "./utils.h"

namespace tvm {
Expand Down Expand Up @@ -59,7 +60,7 @@ Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) {
}
IdDoc var = d->Define(n, GetRef<Frame>(f), n->name_hint.empty() ? "v" : n->name_hint);
var->source_paths.push_back(n_p);
f->stmts.push_back(AssignDoc(var, TIR(d, DType2Str(n->dtype))->Call({}), NullOpt));
f->stmts.push_back(AssignDoc(var, PrintVarCreation(n, n_p, d), NullOpt));
}
if (Optional<ExprDoc> doc = d->GetVarDoc(n)) {
return doc.value();
Expand Down
28 changes: 28 additions & 0 deletions tests/python/tvmscript/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4088,6 +4088,32 @@ def func(A: R.Object):
yield make_ir_generator(subclass)


def relax_symbolic_size_var():
"""Relax symbolic variables may be SizeVar"""
N = tvm.tir.SizeVar("N", "int64")

@R.function
def func(A: R.Tensor([N], "float16")):
B: R.Tensor([N], "float16") = A
return B

return func


def relax_float_symbolic_var():
"""Relax symbolic variables may hold any dtype"""

@R.function
def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")):
N = T.int64()
threshold = T.float16()

B = A >= R.prim_value(threshold / T.cast(N, "float16"))
return B

return func


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
Expand Down Expand Up @@ -4174,6 +4200,8 @@ def func(A: R.Object):
return_zero_private_with_attr,
*op_of_literal(),
*relax_match_cast_struct_info_proxy(),
relax_symbolic_size_var,
relax_float_symbolic_var,
)

relax_ir_generator = tvm.testing.parameter(
Expand Down

0 comments on commit 5618628

Please sign in to comment.