In [None]:
import tvm

import tvm.te as te

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


def showmod(mod: tvm.ir.module.IRModule):
    """Pretty-print ``mod`` with this notebook's fixed display settings.

    Output is black-formatted with verbose expressions and every struct info
    shown. Metadata and object addresses are suppressed so the printed IR is
    stable across runs.
    """
    display_options = dict(
        black_format=True,
        show_meta=False,
        verbose_expr=True,
        show_object_address=False,
        show_all_struct_info=True,
    )
    mod.show(**display_options)


def createandshowmod(ops, name="test"):
    """Wrap TE tensors in a single-function IRModule and pretty-print it.

    Parameters
    ----------
    ops : list
        Tensors handed to ``te.create_prim_func``; they become the PrimFunc's
        parameters (in this notebook: the inputs followed by the output).
    name : str, optional
        Used both as the PrimFunc's ``global_symbol`` attribute and as its key
        in the IRModule. Keeping the two in one place guarantees they match.
        Defaults to ``"test"``.

    Returns nothing on purpose: notebook cells end with this call, and a
    return value would be echoed as an extra cell output.
    """
    prim_func = te.create_prim_func(ops).with_attrs({"global_symbol": name})
    mod = tvm.IRModule({name: prim_func})
    showmod(mod)


from tvm.topi.nn.bitserial_conv2d import *

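#### bitserial_conv2d_nchw

Bit-serial 2-D convolution on an NCHW input with an OIHW kernel. The TE compute is converted into a PrimFunc and printed as TVMScript.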

In [None]:
# NCHW input: (batch, in_channel, height, width), stored as uint8.
data: te.Tensor = te.placeholder((16, 8, 224, 224), dtype="uint8", name="data")
# OIHW kernel: (out_channel, in_channel, kernel_h, kernel_w), stored as uint8.
kernel: te.Tensor = te.placeholder((64, 8, 4, 4), dtype="uint8", name="kernel")
stride = (3, 3)
padding = (1, 1)
activation_bits = 4
weight_bits = 4
# The op accumulates popcounts shifted by (b1 + b2) over a 4x4 window and
# 4x4 bit-planes. A single term can reach 8 << 6 = 512, so a uint8 output
# overflows. int16 (the operator's default out_dtype) holds the full range.
out_dtype = "int16"
# Word type used for bit-packing; 8 input channels fit in one uint8 word.
pack_dtype = "uint8"
unipolar = True

# pack_dtype comes before out_dtype in the operator's signature. Passing the
# trailing options by keyword prevents silently swapping the two dtypes.
output = bitserial_conv2d_nchw(
    data,
    kernel,
    stride,
    padding,
    activation_bits,
    weight_bits,
    pack_dtype=pack_dtype,
    out_dtype=out_dtype,
    unipolar=unipolar,
)

createandshowmod([data, kernel, output])

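#### bitserial_conv2d_nhwc

The same bit-serial convolution on an NHWC input with an HWIO kernel. The TE compute is converted into a PrimFunc and printed as TVMScript.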

In [None]:
# NHWC input: (batch, height, width, in_channel), stored as uint8.
data: te.Tensor = te.placeholder((16, 224, 224, 8), dtype="uint8", name="data")
# HWIO kernel: (kernel_h, kernel_w, in_channel, out_channel), stored as uint8.
kernel: te.Tensor = te.placeholder((4, 4, 8, 64), dtype="uint8", name="kernel")
stride = (3, 3)
padding = (1, 1)
activation_bits = 4
weight_bits = 4
# The op accumulates popcounts shifted by (b1 + b2) over a 4x4 window and
# 4x4 bit-planes. A single term can reach 8 << 6 = 512, so a uint8 output
# overflows. int16 (the operator's default out_dtype) holds the full range.
out_dtype = "int16"
# Word type used for bit-packing; 8 input channels fit in one uint8 word.
pack_dtype = "uint8"
unipolar = True

# pack_dtype comes before out_dtype in the operator's signature. Passing the
# trailing options by keyword prevents silently swapping the two dtypes.
output = bitserial_conv2d_nhwc(
    data,
    kernel,
    stride,
    padding,
    activation_bits,
    weight_bits,
    pack_dtype=pack_dtype,
    out_dtype=out_dtype,
    unipolar=unipolar,
)

createandshowmod([data, kernel, output])