In [None]:
import tvm

import tvm.te as te

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


def showmod(mod: tvm.ir.module.IRModule):
    """Pretty-print ``mod`` via ``IRModule.show`` with a fixed set of options.

    The options enable black formatting, ``verbose_expr`` and
    ``show_all_struct_info``, and disable the metadata section and object
    addresses. Nothing is returned.
    """
    display_options = dict(
        black_format=True,
        show_meta=False,
        verbose_expr=True,
        show_object_address=False,
        show_all_struct_info=True,
    )
    mod.show(**display_options)


def createandshowmod(ops, name="test"):
    """Build a PrimFunc from the TE tensors ``ops`` and pretty-print it.

    Parameters
    ----------
    ops
        The tensors handed to ``te.create_prim_func`` (inputs followed by
        outputs); their order becomes the PrimFunc's parameter order.
    name
        Used both as the PrimFunc's ``global_symbol`` attribute and as its key
        in the wrapping IRModule, so the two can never disagree.
        Defaults to ``"test"``.

    Returns ``None``; the module is only displayed via ``showmod``.
    """
    prim_func = te.create_prim_func(ops).with_attrs({"global_symbol": name})
    showmod(tvm.IRModule({name: prim_func}))


from tvm.topi.nn.batch_norm import *

#### batch_norm
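$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$

Here $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ are per-channel statistics along `axis`. In inference mode (`training=False`) they are taken from `moving_mean` and `moving_var`. $\gamma$ is applied only when `scale=True`, and $\beta$ only when `center=True`.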

In [None]:
# Batch-norm configuration. ``axis`` selects the channel dimension of ``data``;
# the per-channel parameter tensors below are sized from it.
axis = 0
epsilon = 1e-5
center = False  # per TOPI batch_norm docs: when False, beta is not added
scale = True  # per TOPI batch_norm docs: when True, output is multiplied by gamma
training = False  # False: normalize with moving_mean / moving_var (inference)
momentum = 0.1  # only affects the updated running stats when training is True

dtype = "float32"
data_shape = (128, 128)
# TOPI's batch_norm documents gamma, beta, moving_mean and moving_var as 1-D
# tensors of shape (C,), with C = data.shape[axis]; it reshapes them so they
# broadcast along ``axis``. Deriving the shape keeps them in sync with axis.
channel_shape = (data_shape[axis],)

data: tvm.te.Tensor = te.placeholder(shape=data_shape, dtype=dtype, name="data")
gamma: tvm.te.Tensor = te.placeholder(shape=channel_shape, dtype=dtype, name="gamma")
beta: tvm.te.Tensor = te.placeholder(shape=channel_shape, dtype=dtype, name="beta")
moving_mean: tvm.te.Tensor = te.placeholder(
    shape=channel_shape, dtype=dtype, name="moving_mean"
)
moving_var: tvm.te.Tensor = te.placeholder(
    shape=channel_shape, dtype=dtype, name="moving_var"
)

output, output_moving_mean, output_moving_var = batch_norm(
    data,
    gamma,
    beta,
    moving_mean,
    moving_var,
    axis=axis,
    epsilon=epsilon,
    center=center,
    scale=scale,
    training=training,
    momentum=momentum,
)

# Inputs first, then the three outputs returned by batch_norm.
createandshowmod(
    [
        data,
        gamma,
        beta,
        moving_mean,
        moving_var,
        output,
        output_moving_mean,
        output_moving_var,
    ]
)