In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import typing
import tvm
from tvm import relay
from tvm import relax
import tvm.contrib.graph_executor as runtime
import numpy as np

class MyRelaxModel(relax.frontend.nn.Module):
    """Small convolutional network: (Conv2D -> ReLU) applied twice.

    Both convolutions use a 5x5 kernel, stride 2, padding 2 and a bias term.
    Channel counts go 3 -> 32 -> 64.
    """

    def __init__(self):
        super().__init__()
        nn = relax.frontend.nn  # local shorthand for the Relax nn frontend
        # Attribute names and creation order are kept as-is.
        self.conv1 = nn.Conv2D(3, 32, kernel_size=5, stride=2, padding=2, bias=True)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2D(32, 64, kernel_size=5, stride=2, padding=2, bias=True)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through conv1, relu1, conv2 and relu2, in that order."""
        for layer in (self.conv1, self.relu1, self.conv2, self.relu2):
            x = layer(x)
        return x



input_shape = (1, 3, 128, 128)

# Export the "forward" method, declaring its single float32 input "x".
mod, params = MyRelaxModel().export_tvm(
    {
        "forward": {
            "x": relax.frontend.nn.spec.Tensor(input_shape, "float32"),
        }
    }
)

# Unlike in torch, model parameters are not initialized automatically, so we
# create random float32 values by hand. They are needed at profile time.
# Each entry of `params` is used as a (label, parameter) pair.
model_params_tvm = [
    tvm.nd.array(np.random.randn(*param.shape).astype(np.float32))
    for _, param in params
]
model_params_tvm_labels = [label for label, _ in params]

mod.show()

# Show the IR

In [3]:
import difflib

class IRDiffer:
    """Keep the two most recent textual IR snapshots and print a coloured diff.

    Every call to :meth:`update` shifts the current snapshot into
    ``last_state`` and stores the new one in ``current_state``.
    :meth:`show_diff` then prints a git-like diff between the two.
    """

    def __init__(self, init_state=None):
        """Create a differ, optionally seeded with a first snapshot.

        Args:
            init_state: Optional ``str`` or ``tvm.ir.IRModule`` passed to
                :meth:`update`. When ``None`` both states start empty.
        """
        self.last_state = None
        self.current_state = None
        if init_state is not None:
            self.update(init_state)

    def update(self, script):
        """Record a new snapshot and demote the previous one to ``last_state``.

        Args:
            script: Either the IR text itself (``str``) or a
                ``tvm.ir.IRModule``, which is rendered with ``.script()``.

        Raises:
            ValueError: If ``script`` is neither a ``str`` nor an ``IRModule``.
        """
        # Plain strings are checked first so text snapshots never touch tvm.
        if isinstance(script, str):
            pass
        elif isinstance(script, tvm.ir.IRModule):
            script = script.script()
        else:
            raise ValueError(f"Cannot parse type {type(script)}")

        self.last_state = self.current_state
        self.current_state = script

    def show_diff(self) -> None:
        """
        Prints the git-like diff between two relax IR states.

        Added lines are printed in green and removed lines in red. Characters
        that ``difflib.ndiff`` marks with ``^`` in its ``?`` hint lines are
        additionally highlighted. The hint lines themselves are not printed.

        Raises:
            ValueError: If fewer than two snapshots have been recorded, i.e.
                either state is not a ``str``.
        """
        # Both states must be strings. The previous `not A and B` form let the
        # "both missing" case through and crashed later with AttributeError.
        if not (isinstance(self.last_state, str) and isinstance(self.current_state, str)):
            raise ValueError(f"Can only compare two str, not {type(self.last_state)} and {type(self.current_state)}")
        line_color = {"+": 32, "-": 31}  # ANSI foreground codes: green / red

        diff_list = list(
            difflib.ndiff(
                self.last_state.splitlines(keepends=True),
                self.current_state.splitlines(keepends=True),
            )
        )
        styled: list[str] = []
        # Pair each diff line with its successor; "" pads the final line.
        for line, following in zip(diff_list, diff_list[1:] + [""]):
            marker = line[:1]
            if marker == "?":
                # Hint lines are consumed while styling the line before them.
                continue
            if not line.endswith("\n"):
                # A final line without a newline would otherwise run into the
                # next diff line once everything is joined.
                line += "\n"
            if marker == " ":
                styled.append(line)
            elif marker in line_color:
                color = line_color[marker]
                # Only a real "?" hint line carries column markers. A "^" that
                # merely appears in the next source line must not be used.
                hint = following if following.startswith("?") else ""
                chars = list(line)
                for idx, hint_char in enumerate(hint):
                    if hint_char == "^" and idx < len(chars):
                        chars[idx] = f"\x1b[97;{color+10};1m{chars[idx]}\x1b[0;{color}m"
                styled.append(f'\x1b[{color}m{"".join(chars)}\x1b[0m')
        print("".join(styled))

# Seed the differ with the untransformed module so that the first update()
# with a transformed module has a baseline to diff against.
irdiffer = IRDiffer(mod)

In [4]:
# Pass reference: https://tvm.apache.org/docs/reference/api/python/transform.html
#
# Ordered list of Relax passes. The next statements apply them one by one to
# `mod`, and the registered build pipeline wraps this same list in a
# tvm.transform.Sequential. Un-comment a pass to enable it; order matters.
transforms = [
    # Phase 1. Passes on high-level operator graph
    # There's not much we can do here with this model we defined
    # relax.transform.FuseTransposeMatmul(),

    # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
    relax.transform.LegalizeOps(),
    # relax.transform.AnnotateTIROpPattern(),
    # relax.transform.FoldConstant(),
    # relax.transform.ConvertToDataflow(),
    # relax.transform.FuseOps(),
    # relax.transform.FuseTIR(),

    # Phase 3. Passes on TIR
    # relax.transform.DeadCodeElimination(),

    # Phase 4. Lowering to VM bytecode
    # (these are the passes to enable when building fails with
    #  "CodeGenVM cannot handle this intrinsic now: Op(relax.call_tir)")
    # relax.transform.RewriteDataflowReshape(),
    # relax.transform.ToNonDataflow(),
    # relax.transform.RemovePurityChecking(),
    # relax.transform.CallTIRRewrite(),
    # relax.transform.StaticPlanBlockMemory(),
    # relax.transform.RewriteCUDAGraph(),
    # relax.transform.LowerAllocTensor(),
    # relax.transform.KillAfterLastUse(),
    # relax.transform.LowerRuntimeBuiltin(),
    # relax.transform.VMShapeLower(),
    # relax.transform.AttachGlobalSymbol(),
]


# Thread the module through each enabled pass in list order. `mod` itself is
# left untouched; the transformed result is bound to `new_mod`.
new_mod = mod
for relax_pass in transforms:
    new_mod = relax_pass(new_mod)

# Record the transformed module as the differ's newest snapshot.
irdiffer.update(new_mod)

# show_diff() raises ValueError when it does not hold two string snapshots.
try:
    irdiffer.show_diff()
except ValueError as err:
    print(f"failed, maybe uninitialized irdiffer. {err}")

  # from tvm.script import ir as I
[32m+ # from tvm.script import tir as T
[0m  # from tvm.script import relax as R
  
  @I.ir_module
  class Module:
[32m+     @T.prim_func(private=True)
[0m[32m+     def add(lv: T.Buffer((T.int64(1), T.int64(32), T.int64(64), T.int64(64)), "float32"), lv1: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(1), T.int64(32), T.int64(64), T.int64(64)), "float32")):
[0m[32m+         T.func_attr({"tir.noalias": T.bool(True)})
[0m[32m+         # with T.block("root"):
[0m[32m+         for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(64), T.int64(64)):
[0m[32m+             with T.block("T_add"):
[0m[32m+                 v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
[0m[32m+                 T.reads(lv[v_ax0, v_ax1, v_ax2, v_ax3], lv1[v_ax0, v_ax1, T.int64(0), T.int64(0)])
[0m[32m+                 T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
[0m[32m+     

In [5]:
# Print the full module after the enabled passes (the diff above shows only the changes).
new_mod.show()

# Profiling the optimized module

You might need to enable the lowering passes listed under "Phase 4" in the `transforms` list defined above:

    # relax.transform.ToNonDataflow(),
    # relax.transform.RemovePurityChecking(),
    # relax.transform.CallTIRRewrite(),
    # relax.transform.StaticPlanBlockMemory(),
    # relax.transform.RewriteCUDAGraph(),
    # relax.transform.LowerAllocTensor(),
    # relax.transform.KillAfterLastUse(),
    # relax.transform.LowerRuntimeBuiltin(),
    # relax.transform.VMShapeLower(),
    # relax.transform.AttachGlobalSymbol(),

Otherwise the VM code generator cannot lower the IR, and the build fails with `CodeGenVM cannot handle this intrinsic now: Op(relax.call_tir)`.

In [6]:
# Random float32 input matching the shape the model was exported with.
dummy_input = tvm.nd.array(np.random.randn(*input_shape).astype(np.float32))

# If you change the pipeline, change the name !
# A pipeline name can only be registered once per kernel session, so we first
# check whether it already exists and only register it when it does not.
# Note: `transforms` is looked up when the pass actually runs, so the pipeline
# always applies the list as it is defined at build time.
pipeline_name = "opt_test"
try:
    tvm.relax.get_pipeline(pipeline_name)
    print("Using previously defined pipeline")
except ValueError:
    # get_pipeline signals an unknown name with ValueError. Any other error is
    # a real problem and is allowed to propagate instead of being swallowed.
    @tvm.relax.register_pipeline(pipeline_name)
    def _pipeline(  # pylint: disable=too-many-arguments
        ext_mods: typing.Optional[typing.List[tvm.relax.frontend.nn.ExternModule]] = None,
    ):
        """Pipeline factory: return a module pass applying the notebook's `transforms`.

        Args:
            ext_mods: Optional list of extern modules. Accepted for signature
                compatibility; not used by this pipeline.

        Returns:
            A module-level pass that runs `transforms` sequentially.
        """
        ext_mods = ext_mods or []

        @tvm.transform.module_pass(opt_level=1)
        def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
            seq = tvm.transform.Sequential(
                transforms  # <-- Defined in previous cells /!\ you are applying these transforms
            )
            mod = seq(mod)
            return mod

        return _pipeline


In [8]:
# Compile the module for the "llvm" target using the pipeline registered above.
# NOTE(review): `new_mod` already had `transforms` applied in an earlier cell, and
# the registered pipeline applies the same list again here. Building from `mod`
# may be what is intended -- TODO confirm.
ex = relax.build(new_mod, target="llvm", pipeline=relax.get_pipeline(pipeline_name))  # You can try 'default_build'
# Load the executable on the CPU device with profiling enabled.
vm = relax.VirtualMachine(ex, tvm.cpu(), profile=True)
# Time the "forward" function, passing the input first and then the parameters
# in the order they were created.
evaluator = vm.time_evaluator("forward", dev=tvm.cpu(), min_repeat_ms=200)(
    dummy_input, *model_params_tvm
)
# Display the timing result.
# If the build fails with
#   TVMError: CodeGenVM cannot handle this intrinsic now: Op(relax.call_tir)
# enable the "Phase 4" lowering passes in `transforms`.
evaluator  # If error = TVMError: CodeGenVM cannot handle this intrinsic now: Op(relax.call_tir) --> apply phase 4 lowerings

TVMError: Traceback (most recent call last):
  6: _ZN3tvm7runtime13PackedFuncObj
  5: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::relax::ExecBuilder, tvm::IRModule)>::AssignTypedLambda<tvm::IRModule (*)(tvm::relax::ExecBuilder, tvm::IRModule)>(tvm::IRModule (*)(tvm::relax::ExecBuilder, tvm::IRModule), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  4: tvm::relax::relax_vm::VMCodeGen(tvm::relax::ExecBuilder, tvm::IRModule)
  3: tvm::relax::relax_vm::CodeGenVM::Run(tvm::relax::ExecBuilder, tvm::IRModule)
  2: tvm::relax::relax_vm::CodeGenVM::Codegen(tvm::relax::Function const&)
  1: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::SeqExprNode const*)
  0: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::CallNode const*)
  File "/home/arthur/Desktop/Projects/fromsource/d2l-tvm/tvm/src/relax/backend/vm/codegen_vm.cc", line 157
TVMError: CodeGenVM cannot handle this intrinsic now:
Op(relax.call_tir)