In [None]:
import tvm
import torch
import tvm.script
import tvm.script.tir as T
import tvm.script.relax as R
from tvm.relax.frontend.torch import from_exported_program
import tvm.meta_schedule as ms
import tvm.relax as rx

In [None]:
class torchModule(torch.nn.Module):
    """One fully-connected layer followed by a ReLU activation.

    Args:
        in_features: size of the last dimension of the input tensor.
        out_features: size of the last dimension of the output tensor.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super(torchModule, self).__init__()
        # The attribute names `linear` / `relu` determine the submodule and
        # state_dict names (e.g. `linear.weight`), so keep them stable.
        self.linear = torch.nn.Linear(in_features, out_features)
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.linear(x)
        activated = self.relu(projected)
        return activated



In [None]:
# Build the Linear(10 -> 10) + ReLU model, trace it with one (1, 10) example
# input via torch.export, and import the exported graph as a Relax IRModule.
my_model = torchModule(10, 10)
x = torch.rand([1, 10], dtype=torch.float32)

exported_program = torch.export.export(my_model, (x,))
irmod = from_exported_program(exported_program)
irmod.show()

In [None]:
# Lower to TIR: LegalizeOps rewrites the high-level Relax operator calls in
# `irmod` into calls to generated TIR PrimFuncs, which is what MetaSchedule
# tunes in the next cell.
legalize = rx.transform.LegalizeOps()
tir_mod = legalize(irmod)
tir_mod.show()

In [None]:
# Tune the TIR PrimFuncs in `tir_mod` for a single-core LLVM CPU target.
# Tuning state is kept in the ./tune_tmp work directory; the returned
# database holds the measured records.
database = ms.tune_tir(
    mod=tir_mod,
    target="llvm --num-cores=1",
    work_dir="./tune_tmp",
    max_trials_global=32,
    num_trials_per_iter=32,
)



In [None]:
# Apply the tuned schedules to the legalized Relax model.
#
# Only the PrimFuncs inside `tir_mod` were tuned, so querying the database with
# the Relax module `irmod` finds no record (compile_tir would return None).
# compile_tir also cannot take `tir_mod` directly: it normalizes its input to a
# single-function module, and `tir_mod` holds the Relax `main` plus several
# PrimFuncs. MetaScheduleApplyDatabase instead looks up every PrimFunc of the
# module in the current database and substitutes the best schedule found.
# The pass reads the target from the enclosing `with` scope when it is
# created, so both construction and application happen inside it.
llvm_target = tvm.target.Target("llvm --num-cores=1")
with llvm_target, database:
    tuned_tir_mod = rx.transform.MetaScheduleApplyDatabase()(tir_mod)
tuned_tir_mod.show()

In [None]:
# Hand-written TVMScript module holding one 128x128x128 float32 matmul,
# C = A @ B, used below to compare untuned vs. MetaSchedule-tuned run time.
@tvm.script.ir_module
class MyTirModule:
    # The PrimFunc is registered in the module as `matmul`, but its
    # "global_symbol" attribute is "main"; the timing cells look the compiled
    # function up as "main". "tir.noalias" marks A, B and C as non-aliasing.
    @T.prim_func
    def matmul(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(128, 128, 128):
            with T.block("C"):
                # i and j are spatial ("S") axes; k is the reduction ("R") axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Zero C[vi, vj] before the reduction over k starts.
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] += A[vi, vk] * B[vk, vj]


In [None]:
# Pretty-print the hand-written module as TVMScript before tuning.
MyTirModule.show()

In [None]:
# Tune the 128x128 matmul for the same single-core LLVM target, reusing the
# ./tune_tmp work directory from the Relax tuning run above.
database = ms.tune_tir(
    mod=MyTirModule,
    target="llvm --num-cores=1",
    work_dir="./tune_tmp",
    max_trials_global=64,
    num_trials_per_iter=64,
)

In [None]:
# Rebuild the best tuned schedule for MyTirModule from the database.
# compile_tir returns None when the database has no record for this workload,
# so fail here with a clear message instead of an AttributeError on `sch.mod`.
sch = ms.tir_integration.compile_tir(database, MyTirModule, "llvm --num-cores=1")
if sch is None:
    raise RuntimeError(
        "No tuning record for MyTirModule was found in the database; "
        "check that the tune_tir cell above ran successfully."
    )

In [None]:
# `print` shows only the Schedule handle; `.mod.show()` renders the scheduled
# TIR module that will be built and timed below.
print(sch)
sch.mod.show()

tir.Schedule(0xa3dfff8)


In [None]:
import numpy as np

# Inputs for the 128x128 matmul, uniform in [0, 1) like np.random.rand, but
# drawn from a seeded generator so Restart & Run All reproduces the same data.
rng = np.random.default_rng(0)
a_nd = tvm.nd.array(rng.random((128, 128), dtype=np.float32))
b_nd = tvm.nd.array(rng.random((128, 128), dtype=np.float32))
# Output buffer written by the untuned kernel.
c_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32"))


In [None]:
# Baseline: build the untuned module and time its "main" function.
# Running it also writes the untuned result into c_nd.
lib = tvm.build(MyTirModule, target="llvm")
f_timer_before = lib.time_evaluator("main", tvm.cpu())
timing_before = f_timer_before(a_nd, b_nd, c_nd)
print("Time cost of MyModule before tuning: %.3f ms" % (timing_before.mean * 1000))


Time cost of MyModule before tuning: 2.356 ms


In [None]:
# Tuned: build the scheduled module and time it on the same inputs, writing
# into a fresh output buffer so it can be compared with the baseline result.
c_nd_2 = tvm.nd.array(np.zeros((128, 128), dtype="float32"))

lib = tvm.build(sch.mod, target="llvm")
f_timer_after = lib.time_evaluator("main", tvm.cpu())
timing_after = f_timer_after(a_nd, b_nd, c_nd_2)
print("Time cost of MyModule after tuning: %.3f ms" % (timing_after.mean * 1000))

Time cost of MyModule after tuning: 0.127 ms


In [None]:
import tvm.testing

# Check both kernels against an independent NumPy reference (computed in
# float64). A tuned-vs-untuned comparison alone would miss a mistake in the
# MyTirModule definition itself, since both builds would reproduce it.
c_ref = a_nd.numpy().astype("float64") @ b_nd.numpy().astype("float64")
tvm.testing.assert_allclose(c_nd.numpy(), c_ref, atol=1e-5, rtol=1e-5)
tvm.testing.assert_allclose(c_nd_2.numpy(), c_ref, atol=1e-5, rtol=1e-5)
# Tuned and untuned kernels must also agree with each other.
tvm.testing.assert_allclose(c_nd.numpy(), c_nd_2.numpy(), atol=1e-5, rtol=1e-5)