[Relax] Additional unit tests for RemoveUnusedParameters #16574

Merged 2 commits on Feb 23, 2024

Changes from all commits
109 changes: 106 additions & 3 deletions tests/python/relax/test_transform_remove_unused_parameters.py
@@ -24,7 +24,14 @@ class BaseCompare(tvm.testing.CompareBeforeAfter):
transform = tvm.relax.transform.RemoveUnusedParameters()


class TestSimple(BaseCompare):
class TestRemoveUnusedRelaxParameter(BaseCompare):
"""A relax parameter may be removed

This is only allowed for internal function calls, where all
callsites can be updated. For externally-exposed functions, the
signature may not be modified.
"""

@I.ir_module
class Before:
@R.function
@@ -46,7 +53,15 @@ def func(A: R.Tensor) -> R.Tensor:
return A
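
The Before/Expected bodies of this test are collapsed in the diff view. A minimal sketch of the pattern being exercised (illustrative names and unannotated tensor types, not the actual collapsed code, assuming the test file's usual `from tvm.script import ir as I, relax as R, tir as T` imports) might look like:

@I.ir_module
class Before:
    @R.function
    def main(A: R.Tensor, B: R.Tensor) -> R.Tensor:
        # B is forwarded to the internal function, which never uses it.
        return Before.func(A, B)

    @R.function(private=True)
    def func(A: R.Tensor, _unused: R.Tensor) -> R.Tensor:
        return A


@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor, B: R.Tensor) -> R.Tensor:
        # The externally-exposed signature of main is unchanged; only the
        # call into the private function drops the unused argument.
        return Expected.func(A)

    @R.function(private=True)
    def func(A: R.Tensor) -> R.Tensor:
        return A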


class TestSymbolicVariables(BaseCompare):
class TestReplaceSymbolicVariables(BaseCompare):
"""If a parameter is only required for its symbolic variables, provide them directly

The relax parameter `A` isn't used by the subroutine. However,
its shape defines the symbolic variables `m` and `n`. When
removing the `R.Tensor` argument, we may need to provide
additional parameters to define the symbolic variables.
"""

@I.ir_module
class Before:
@R.function
@@ -78,7 +93,12 @@ def func(
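
The Before/Expected pair for this test is likewise collapsed. One plausible sketch of the situation the docstring describes (illustrative only; the actual test may, for instance, use a single `R.Shape` parameter rather than two `R.Prim` parameters):

@I.ir_module
class Before:
    @R.function
    def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
        return Before.func(A)

    @R.function(private=True)
    def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
        # A itself is unused; only the symbolic variables m and n taken
        # from its shape are needed to build the output.
        m = T.int64()
        n = T.int64()
        zeros = R.zeros(R.shape([m, n]), dtype="float32")
        return zeros


@I.ir_module
class Expected:
    @R.function
    def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
        m = T.int64()
        n = T.int64()
        # The tensor argument is dropped, but m and n must now be passed explicitly.
        return Expected.func(R.prim_value(m), R.prim_value(n))

    @R.function(private=True)
    def func(_m: R.Prim(value="m"), _n: R.Prim(value="n")) -> R.Tensor(["m", "n"], "float32"):
        m = T.int64()
        n = T.int64()
        zeros = R.zeros(R.shape([m, n]), dtype="float32")
        return zeros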


class TestNoExtraSymbolicVariables(BaseCompare):
"""Don't add symbolic variables if they can be inferred."""
"""Don't add symbolic variables if they can be inferred.

Even though some cases require adding new parameters to provide
symbolic variables, not every symbolic variable requires a
distinct parameter.
"""

@I.ir_module
class Before:
@@ -97,5 +117,88 @@ def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
Expected = Before
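
The collapsed `Before` module, whose `func` signature appears in the hunk header above, presumably uses the tensor argument directly, so its shape already provides `m` and `n` and the pass has nothing to add or remove. A rough, illustrative reconstruction:

@I.ir_module
class Before:
    @R.function
    def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
        return Before.func(A)

    @R.function(private=True)
    def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
        # A is used, and m/n are inferable from its shape, so the pass
        # should not introduce extra R.Prim or R.Shape parameters.
        m = T.int64()
        n = T.int64()
        zeros = R.zeros(R.shape([m, n]), dtype="float32")
        out = R.add(A, zeros)
        return out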


class TestRemoveExtraPrimVariables(BaseCompare):
"""Remove parameters that only serve to define existing symbolic variables

If an `R.Prim` parameter provides a definition of a symbolic
variable, but that symbolic variable can be determined from a
different parameter, then the `R.Prim` parameter can be removed.
"""

@I.ir_module
class Before:
@R.function
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
return Before.func(A, R.prim_value(m), R.prim_value(n))

@R.function(private=True)
def func(
A: R.Tensor(["m", "n"], "float32"), _m: R.Prim(value="m"), _n: R.Prim(value="n")
) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
zeros = R.zeros(R.shape([m, n]), dtype="float32")
out = R.add(A, zeros)
return out

@I.ir_module
class Expected:
@R.function
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
return Expected.func(A)

@R.function(private=True)
def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
zeros = R.zeros(R.shape([m, n]), dtype="float32")
out = R.add(A, zeros)
return out


class TestRemoveExtraShapeVariables(BaseCompare):
"""Remove parameters that only serve to define existing symbolic variables

If an `R.Shape` parameter provides a definition of a symbolic
variable, but that symbolic variable can be determined from a
different parameter, then the `R.Shape` parameter can be removed.
"""

@I.ir_module
class Before:
@R.function
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
return Before.func(A, R.shape([m, n]))

@R.function(private=True)
def func(
A: R.Tensor(["m", "n"], "float32"),
_: R.Shape(["m", "n"]),
) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
zeros = R.zeros(R.shape([m, n]), dtype="float32")
out = R.add(A, zeros)
return out

@I.ir_module
class Expected:
@R.function
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
return Expected.func(A)

@R.function(private=True)
def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
zeros = R.zeros(R.shape([m, n]), dtype="float32")
out = R.add(A, zeros)
return out
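
For reference, the pass under test can also be run outside the `CompareBeforeAfter` harness. A minimal usage sketch (not part of this diff; assumes the `Before`/`Expected` attributes above are plain `IRModule`s):

# Run the pass on one of the modules above and check the result structurally.
mod = TestRemoveExtraShapeVariables.Before
after = tvm.relax.transform.RemoveUnusedParameters()(mod)
tvm.ir.assert_structural_equal(after, TestRemoveExtraShapeVariables.Expected)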


if __name__ == "__main__":
tvm.testing.main()