Skip to content

Commit

Permalink
[Relax] Additional unit tests for RemoveUnusedParameters (#16574)
Browse files Browse the repository at this point in the history
* [Relax] Additional unit tests for RemoveUnusedParameters

Verifying behavior for subroutines that receive `R.Prim` or `R.Shape`
parameters, if the symbolic variables defined by those parameters are
already defined by another parameter.

* Typo fix
  • Loading branch information
Lunderberg committed Feb 23, 2024
1 parent 33a6f75 commit faa6628
Showing 1 changed file with 106 additions and 3 deletions.
109 changes: 106 additions & 3 deletions tests/python/relax/test_transform_remove_unused_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@ class BaseCompare(tvm.testing.CompareBeforeAfter):
transform = tvm.relax.transform.RemoveUnusedParameters()


class TestSimple(BaseCompare):
class TestRemoveUnusedRelaxParameter(BaseCompare):
"""A relax parameter may be removed
This is only allowed for internal function calls, where all
callsites can be updated. For externally-exposed functions, the
signature may not be modified.
"""

@I.ir_module
class Before:
@R.function
Expand All @@ -46,7 +53,15 @@ def func(A: R.Tensor) -> R.Tensor:
return A


class TestSymbolicVariables(BaseCompare):
class TestReplaceSymbolicVariables(BaseCompare):
"""If a parameter is only required for its symbolic variables, provide them directly
The relax parameter `A` isn't used by the subroutine. However,
its shape defines the symbolic variables `m` and `n`. When
removing the `R.Tensor` argument, we may need to provide
additional parameters to define the symbolic variables.
"""

@I.ir_module
class Before:
@R.function
Expand Down Expand Up @@ -78,7 +93,12 @@ def func(


class TestNoExtraSymbolicVariables(BaseCompare):
"""Don't add symbolic variables if they can be inferred."""
"""Don't add symbolic variables if they can be inferred.
Even though some cases require adding new parameters to provide
symbolic variables, not every symbolic variable requires a
distinct parameter.
"""

@I.ir_module
class Before:
Expand All @@ -97,5 +117,88 @@ def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
Expected = Before


class TestRemoveExtraPrimVariables(BaseCompare):
"""Remove parameters that only serve to define existing symbolic variables
If a `R.Prim` parameter provies a definition of a symbolic
variable, but that symbolic variable can be determined from a
different parameter, then the `R.Prim` parameter can be removed.
"""

@I.ir_module
class Before:
@R.function
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
return Before.func(A, R.prim_value(m), R.prim_value(n))

@R.function(private=True)
def func(
A: R.Tensor(["m", "n"], "float32"), _m: R.Prim(value="m"), _n: R.Prim(value="n")
) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
zeros = R.zeros(R.shape([m, n]), dtype="float32")
out = R.add(A, zeros)
return out

@I.ir_module
class Expected:
@R.function
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
return Expected.func(A)

@R.function(private=True)
def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
zeros = R.zeros(R.shape([m, n]), dtype="float32")
out = R.add(A, zeros)
return out


class TestRemoveExtraShapeVariables(BaseCompare):
"""Remove parameters that only serve to define existing symbolic variables
If a `R.Shape` parameter provides a definition of a symbolic
variable, but that symbolic variable can be determined from a
different parameter, then the `R.Shape` parameter can be removed.
"""

@I.ir_module
class Before:
@R.function
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
return Before.func(A, R.shape([m, n]))

@R.function(private=True)
def func(
A: R.Tensor(["m", "n"], "float32"),
_: R.Shape(["m", "n"]),
) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
zeros = R.zeros(R.shape([m, n]), dtype="float32")
out = R.add(A, zeros)
return out

@I.ir_module
class Expected:
@R.function
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
return Expected.func(A)

@R.function(private=True)
def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
zeros = R.zeros(R.shape([m, n]), dtype="float32")
out = R.add(A, zeros)
return out


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit faa6628

Please sign in to comment.