In [259]:
import tvm
# Import `te` explicitly; it is used below and should not depend on leaking
# through the star import on the next line.
from tvm import te
from tvm.topi.math import *

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


def showmod(mod: tvm.ir.module.IRModule) -> None:
    """Pretty-print an IRModule as TVMScript.

    Output is black-formatted, with verbose expressions and all struct info
    shown, while metadata and object addresses are hidden.

    Args:
        mod (tvm.ir.module.IRModule): The module to render.
    """
    show_kwargs = {
        "black_format": True,
        "show_meta": False,
        "verbose_expr": True,
        "show_object_address": False,
        "show_all_struct_info": True,
    }
    mod.show(**show_kwargs)


def createandshowmod(ops, name: str = "test") -> None:
    """Build a single-PrimFunc IRModule from tensor expressions and print it.

    Args:
        ops (List[Union[te.Tensor, tvm.tir.Var]]): Tensor expressions (in the
            cells below: the input placeholder followed by the op's output)
            passed to ``te.create_prim_func`` to build the TensorIR PrimFunc.
        name (str): Global symbol of the PrimFunc and its key in the module.
            Defaults to ``"test"``.
    """
    prim_func = te.create_prim_func(ops).with_attrs({"global_symbol": name})
    # Keep the module key identical to the function's global symbol.
    mod = tvm.IRModule({name: prim_func})
    showmod(mod)


# Two 128x128 int32 placeholder inputs; only `A` is used in the cells shown here.
# NOTE(review): int32 is also fed to floating-point ops below (exp, sin, log, ...);
# confirm that is intended.
A: te.Tensor = te.placeholder(name="A", shape=(128, 128), dtype="int32")
B: te.Tensor = te.placeholder(name="B", shape=(128, 128), dtype="int32")

## identity

In [260]:
createandshowmod(ops=[A, identity(A)])

## negative

In [261]:
createandshowmod(ops=[A, negative(A)])

## exp

In [262]:
createandshowmod(ops=[A, exp(A)])

## erf

In [263]:
createandshowmod(ops=[A, erf(A)])

## tanh

In [264]:
createandshowmod(ops=[A, tanh(A)])

## tan

In [265]:
createandshowmod(ops=[A, tan(A)])

## cos

In [266]:
createandshowmod(ops=[A, cos(A)])

## cosh

In [267]:
createandshowmod(ops=[A, cosh(A)])

## sin

In [268]:
createandshowmod(ops=[A, sin(A)])

## sinh

In [269]:
createandshowmod(ops=[A, sinh(A)])

## acos

In [270]:
createandshowmod(ops=[A, acos(A)])

## acosh

In [271]:
createandshowmod(ops=[A, acosh(A)])

## asin

In [272]:
createandshowmod(ops=[A, asin(A)])

## asinh

In [273]:
createandshowmod(ops=[A, asinh(A)])

## atan

In [274]:
createandshowmod(ops=[A, atan(A)])

## atanh

In [275]:
createandshowmod(ops=[A, atanh(A)])

## floor

In [276]:
createandshowmod(ops=[A, floor(A)])

## ceil

In [277]:
createandshowmod(ops=[A, ceil(A)])

## sign

In [278]:
# This section demonstrates `sign`; it must not repeat the `ceil` cell above.
createandshowmod([A, sign(A)])

## trunc

In [279]:
createandshowmod(ops=[A, trunc(A)])

## abs

In [280]:
# NOTE(review): `abs` presumably resolves to topi.math's version via the star
# import (shadowing the Python builtin) — confirm.
createandshowmod(ops=[A, abs(A)])

## isnan

In [281]:
createandshowmod(ops=[A, isnan(A)])

## isfinite

In [282]:
createandshowmod(ops=[A, isfinite(A)])

## isinf

In [283]:
createandshowmod(ops=[A, isinf(A)])

## round

In [284]:
# NOTE(review): `round` presumably resolves to topi.math's version via the star
# import (shadowing the Python builtin) — confirm.
createandshowmod(ops=[A, round(A)])

## log

In [285]:
createandshowmod(ops=[A, log(A)])

## log2

In [286]:
createandshowmod(ops=[A, log2(A)])

## log10

In [287]:
createandshowmod(ops=[A, log10(A)])

## sqrt

In [288]:
createandshowmod(ops=[A, sqrt(A)])

## rsqrt

In [289]:
createandshowmod(ops=[A, rsqrt(A)])

## sigmoid

In [290]:
createandshowmod(ops=[A, sigmoid(A)])

## left_shift

In [291]:
createandshowmod(ops=[A, left_shift(A, n=4)])

## right_shift

In [292]:
createandshowmod(ops=[A, right_shift(A, n=4)])

## clip

In [293]:
# Bounds for `clip`; annotated as PrimExpr but given as plain Python ints.
a_min: tvm.tir.PrimExpr = 2
a_max: tvm.tir.PrimExpr = 4
createandshowmod(ops=[A, clip(A, a_min, a_max)])

## fixed_point_multiply

In [294]:
"""Fixed point multiplication between data and a fixed point
constant expressed as multiplier * 2^(-shift), where multiplier
is a Q-number with 31 fractional bits
"""

multiplier: int = 3
shift: int = 2  # 3 * 2^(-2) = 3/4
createandshowmod([A, fixed_point_multiply(A, multiplier, shift)])

## fixed_point_multiply_per_axis

In [295]:
# TODO

## cast

In [296]:
createandshowmod(ops=[A, cast(A, "float32")])

## reinterpret

In [297]:
createandshowmod(ops=[A, reinterpret(A, "float32")])

## fast_exp

In [298]:
createandshowmod(ops=[A, fast_exp(A)])

## fast_tanh

In [299]:
createandshowmod(ops=[A, fast_tanh(A)])

## fast_erf

In [300]:
createandshowmod(ops=[A, fast_erf(A)])

## ceil_log2

In [301]:
# TODO