In [None]:
import tvm

import tvm.te as te

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

from tvm.relax.binding_rewrite import DataflowBlockRewrite
from tvm.relax.analysis import name_to_binding


def showmod(mod: "tvm.runtime.Scriptable") -> None:
    """Pretty-print a TVM object as TVMScript with this notebook's fixed options.

    The function only relies on ``mod.show(...)``; in this notebook it is
    called with IRModules, Relax functions and a dataflow block, so the
    annotation is the generic script-printable type rather than ``IRModule``.
    The annotation is a string so that defining the helper does not evaluate
    any TVM attribute.

    Parameters
    ----------
    mod
        Any object exposing a ``show`` method that accepts the keyword
        arguments used below.

    Returns
    -------
    None
        The rendering is a side effect of ``mod.show``.
    """
    mod.show(
        black_format=True,
        show_meta=False,
        verbose_expr=True,
        show_object_address=False,
        show_all_struct_info=True,
    )


def createandshowmod(ops, name: str = "test"):
    """Build a PrimFunc from ``ops``, wrap it in an IRModule and print it.

    Parameters
    ----------
    ops
        Forwarded unchanged to ``te.create_prim_func``
        (presumably a list of ``te.Tensor`` -- confirm against callers).
    name : str
        Used both as the PrimFunc's ``global_symbol`` attribute and as its key
        in the IRModule, so the two can never drift apart. Defaults to
        ``"test"``, which reproduces the previous hard-coded behaviour.

    Returns
    -------
    None
        The module is printed via ``showmod``.
    """
    prim_func = te.create_prim_func(ops).with_attrs({"global_symbol": name})
    showmod(tvm.IRModule({name: prim_func}))

In [None]:
# Minimal Relax module used as the rewrite target by the cells that follow.
# NOTE: despite the class name, `main` returns `R.add(x, x)`, not `x` itself.
@tvm.script.ir_module
class Identity:
    @R.function
    def main(x: R.Tensor([32, 32], "int32")) -> R.Tensor:
        # One dataflow block with a single binding; lv0 is its only output.
        with R.dataflow():
            lv0 = R.add(x, x)
            R.output(lv0)
        return lv0


# Print the parsed module as TVMScript.
showmod(Identity)

In [None]:
# Fetch `main` from the module and the first block of its body (the
# `R.dataflow()` block). Both names are reused when constructing the
# DataflowBlockRewrite further down.
root_fn = Identity["main"]
dfb = root_fn.body.blocks[0]

# showmod only calls `.show(...)` on its argument, so a block can be printed too.
showmod(dfb)

In [None]:
# Build a rewriter over the block `dfb` of the function `root_fn`.
rewrite = DataflowBlockRewrite(dfb, root_fn)
# `is_dfvar=True` means that the new variable is a dataflow variable,
# which can only be accessed inside the scope of the dataflow block.
# The parameter is taken from `root_fn` (rather than re-fetching
# `Identity["main"]`) so the added expression always belongs to the same
# function the rewriter was constructed with.
rewrite.add(root_fn.params[0], "unused", is_dfvar=True)

# Due to the immutable and copy-on-write nature of TVM AST nodes,
# the rewriting is not done in place. Instead, a new DataflowBlock
# is created and returned by mutated_dfb. Similarly, the new root
# Function is created and returned by mutated_root_fn. To apply this
# change to an IRModule, use mutate_irmodule, which rewrites the old
# function that was registered in the constructor and returns a new
# IRModule containing the mutated function.
showmod(rewrite.mutated_root_fn())
showmod(rewrite.mutate_irmodule(Identity))

In [None]:
# Relax module with a dead binding: `unused` is bound inside the dataflow
# block but is neither passed to R.output nor returned.
@tvm.script.ir_module
class IdentityUnused:
    @R.function
    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            lv0 = x
            # Bound but never read afterwards.
            unused = lv0
            R.output(lv0)
        return lv0


# Rewriter over the first body block of IdentityUnused.main.
root_fn = IdentityUnused["main"]
dfb = root_fn.body.blocks[0]
rewrite = DataflowBlockRewrite(dfb, root_fn)

# Variable name -> bindings table for the same function; dump every entry.
n2binding = name_to_binding(root_fn)
for name, binding in n2binding.items():
    print(name, "---", binding)

# Look the dead variable up once, report it, then drop its binding.
unused_var = n2binding["unused"][0].var
print("n2binding['unused'][0].var: ", unused_var)
rewrite.remove_unused(unused_var)

# Show the rewritten function and the module rebuilt around it.
showmod(rewrite.mutated_root_fn())
showmod(rewrite.mutate_irmodule(IdentityUnused))

lv0 --- [x: R.Tensor((32, 32), dtype="float32")
lv0: R.Tensor((32, 32), dtype="float32") = x]
unused --- [lv0: R.Tensor((32, 32), dtype="float32")
unused: R.Tensor((32, 32), dtype="float32") = lv0]
n2binding['unused'][0].var:  unused


In [None]:
# Relax module with five call_dps_packed bindings, used below to demonstrate
# replacing every use of one variable (lv0) with another (lv1).
@tvm.script.ir_module
class Lv0To1:
    @R.function
    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
        # Dependency sketch of the bindings below. It shows only the lv*
        # chain; "my_add" and "my_mul" additionally take x directly.
        #    x
        #    | \
        #    |  \
        #   lv0  lv1
        #  /   \
        # lv2  lv3
        #  \   /
        #   lv4
        with R.dataflow():
            lv0: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                "my_relu", (x,), R.Tensor((32, 32), dtype="float32")
            )
            # lv1 is not read by any later binding in this block.
            lv1: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")
            )
            lv2: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32")
            )
            lv3: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32")
            )
            lv4: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                "my_whatever", (lv2, lv3), R.Tensor((32, 32), dtype="float32")
            )
            R.output(lv4)
        return lv4


# Print the module before any rewriting.
showmod(Lv0To1)

# Resolve the two variables involved in the replacement by name.
root_fn = Lv0To1["main"]
n2binding = name_to_binding(root_fn)
lv0, lv1 = (n2binding[var_name][0].var for var_name in ("lv0", "lv1"))

# Rewriter over the first body block of the same function.
dfb = root_fn.body.blocks[0]
rewrite = DataflowBlockRewrite(dfb, root_fn)

# Redirect every use of lv0 to lv1, then drop the lv0 binding, which is no
# longer used by anything.
#    x                 x
#    | \               |
#    |  \              |
#   lv0  lv1          lv1
#  /   \       =>    /   \
# lv2  lv3          lv2  lv3
#  \   /             \   /
#   lv4               lv4
rewrite.replace_all_uses(lv0, lv1)
rewrite.remove_unused(var=lv0)
# NOTE (translated from the original author's comment; not verified here --
# TODO confirm against the TVM docs): replace_all_uses replaces every
# reference to a variable but does not keep the definition of the replaced
# variable. remove_all_unused deletes every node that is no longer
# referenced, including "temporarily replaced" variables -- e.g. lv1, if it
# was unused before the replacement, would be deleted as well.
# rewrite.remove_all_unused()

# Show the rewritten function and the module rebuilt around it.
showmod(rewrite.mutated_root_fn())
showmod(rewrite.mutate_irmodule(Lv0To1))