In [ ]:
import tvm
from tvm.topi.math import *

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


def showmod(mod: tvm.ir.module.IRModule):
    """Display an IRModule using this notebook's fixed printing options.

    Args:
        mod (tvm.ir.module.IRModule): The module whose ``show`` method is
            called with the options below.
    """
    # Collected in one place so the display settings are easy to scan.
    display_options = {
        "black_format": True,
        "show_meta": False,
        "verbose_expr": True,
        "show_object_address": False,
        "show_all_struct_info": True,
    }
    mod.show(**display_options)


def createandshowmod(ops):
    """Build a TensorIR PrimFunc from ``ops``, wrap it in an IRModule, and print it.

    Args:
        ops (List[Union[_tensor.Tensor, tvm.tir.Var]]): The tensor expression
            arguments handed to ``te.create_prim_func``.
    """
    prim_func = te.create_prim_func(ops)
    # The function is registered under the same name ("test") both as its
    # global_symbol attribute and as its key in the module.
    prim_func = prim_func.with_attrs({"global_symbol": "test"})
    showmod(tvm.IRModule({"test": prim_func}))


# Two 128x128 int32 placeholder tensors. A is the input passed to the example
# cells; B is not referenced by any cell visible in this section.
# NOTE(review): `te` is not imported explicitly in this file; presumably it
# comes in through the wildcard import from tvm.topi.math — confirm.
A: te.Tensor = te.placeholder(shape=(128, 128), dtype="int32", name="A")
B: te.Tensor = te.placeholder(shape=(128, 128), dtype="int32", name="B")

## identity

In [ ]:
# Build and print the PrimFunc for `identity(A)`.
createandshowmod([A, identity(A)])

## negative

In [ ]:
# Build and print the PrimFunc for `negative(A)`.
createandshowmod([A, negative(A)])

## exp

In [ ]:
# Build and print the PrimFunc for `exp(A)`.
createandshowmod([A, exp(A)])

## erf

In [ ]:
# Build and print the PrimFunc for `erf(A)`.
createandshowmod([A, erf(A)])

## tanh

In [ ]:
# Build and print the PrimFunc for `tanh(A)`.
createandshowmod([A, tanh(A)])

## tan

In [ ]:
# Build and print the PrimFunc for `tan(A)`.
createandshowmod([A, tan(A)])

## cos

In [ ]:
# Build and print the PrimFunc for `cos(A)`.
createandshowmod([A, cos(A)])

## cosh

In [ ]:
# Build and print the PrimFunc for `cosh(A)`.
createandshowmod([A, cosh(A)])

## sin

In [ ]:
# Build and print the PrimFunc for `sin(A)`.
createandshowmod([A, sin(A)])

## sinh

In [ ]:
# Build and print the PrimFunc for `sinh(A)`.
createandshowmod([A, sinh(A)])

## acos

In [ ]:
# Build and print the PrimFunc for `acos(A)`.
createandshowmod([A, acos(A)])

## acosh

In [ ]:
# Build and print the PrimFunc for `acosh(A)`.
createandshowmod([A, acosh(A)])

## asin

In [ ]:
# Build and print the PrimFunc for `asin(A)`.
createandshowmod([A, asin(A)])

## asinh

In [ ]:
# Build and print the PrimFunc for `asinh(A)`.
createandshowmod([A, asinh(A)])

## atan

In [ ]:
# Build and print the PrimFunc for `atan(A)`.
createandshowmod([A, atan(A)])

## atanh

In [ ]:
# Build and print the PrimFunc for `atanh(A)`.
createandshowmod([A, atanh(A)])

## floor

In [ ]:
# Build and print the PrimFunc for `floor(A)`.
createandshowmod([A, floor(A)])

## ceil

In [ ]:
# Build and print the PrimFunc for `ceil(A)`.
createandshowmod([A, ceil(A)])

## sign

In [ ]:
# Build and print the PrimFunc for `sign(A)`. This cell sits under the
# "sign" heading, so it demonstrates `sign` rather than repeating `ceil`.
createandshowmod([A, sign(A)])

## trunc

In [ ]:
# Build and print the PrimFunc for `trunc(A)`.
createandshowmod([A, trunc(A)])

## abs

In [ ]:
# Build and print the PrimFunc for `abs(A)`.
# NOTE(review): `abs` here is presumably the topi operator brought in by the
# wildcard import, shadowing the Python builtin — confirm.
createandshowmod([A, abs(A)])

## isnan

In [ ]:
# Build and print the PrimFunc for `isnan(A)`.
createandshowmod([A, isnan(A)])

## isfinite

In [ ]:
# Build and print the PrimFunc for `isfinite(A)`.
createandshowmod([A, isfinite(A)])

## isinf

In [ ]:
# Build and print the PrimFunc for `isinf(A)`.
createandshowmod([A, isinf(A)])

## round

In [ ]:
# Build and print the PrimFunc for `round(A)`.
# NOTE(review): `round` here is presumably the topi operator brought in by the
# wildcard import, shadowing the Python builtin — confirm.
createandshowmod([A, round(A)])

## log

In [ ]:
# Build and print the PrimFunc for `log(A)`.
createandshowmod([A, log(A)])

## log2

In [ ]:
# Build and print the PrimFunc for `log2(A)`.
createandshowmod([A, log2(A)])

## log10

In [ ]:
# Build and print the PrimFunc for `log10(A)`.
createandshowmod([A, log10(A)])

## sqrt

In [ ]:
# Build and print the PrimFunc for `sqrt(A)`.
createandshowmod([A, sqrt(A)])

## rsqrt

In [ ]:
# Build and print the PrimFunc for `rsqrt(A)`.
createandshowmod([A, rsqrt(A)])

## sigmoid

In [ ]:
# Build and print the PrimFunc for `sigmoid(A)`.
createandshowmod([A, sigmoid(A)])

## left_shift

In [ ]:
# Build and print the PrimFunc for `left_shift` applied to A with n=4.
createandshowmod([A, left_shift(A, n=4)])

## right_shift

In [ ]:
# Build and print the PrimFunc for `right_shift` applied to A with n=4.
createandshowmod([A, right_shift(A, n=4)])

## clip

In [ ]:
# Bounds passed to `clip`; plain Python ints are used for both.
a_max: tvm.tir.PrimExpr = 4
a_min: tvm.tir.PrimExpr = 2
# Build and print the PrimFunc for `clip(A, a_min, a_max)`; note that the
# lower bound is passed before the upper bound.
createandshowmod([A, clip(A, a_min, a_max)])

## fixed_point_multiply

In [ ]:
"""Fixed point multiplication between data and a fixed point
constant expressed as multiplier * 2^(-shift), where multiplier
is a Q-number with 31 fractional bits
"""

# NOTE(review): the description string above says the multiplier is a Q-number
# with 31 fractional bits, so the integer 3 would presumably represent
# 3 * 2^-31 rather than 3; the "3/4" arithmetic in the comment below may not
# match the value actually computed — confirm against the TVM documentation.
multiplier: int = 3
shift: int = 2  # 3 * 2^(-2) = 3/4
# Build and print the PrimFunc for `fixed_point_multiply(A, multiplier, shift)`.
createandshowmod([A, fixed_point_multiply(A, multiplier, shift)])

## fixed_point_multiply_per_axis

In [ ]:
# TODO: add a fixed_point_multiply_per_axis example.

## cast

In [ ]:
# Build and print the PrimFunc for `cast` of A (declared int32) to "float32".
createandshowmod([A, cast(A, "float32")])

## reinterpret

In [ ]:
# Build and print the PrimFunc for `reinterpret` of A (declared int32) as "float32".
createandshowmod([A, reinterpret(A, "float32")])

## fast_exp

In [ ]:
# Build and print the PrimFunc for `fast_exp(A)`.
createandshowmod([A, fast_exp(A)])

## fast_tanh

In [ ]:
# Build and print the PrimFunc for `fast_tanh(A)`.
createandshowmod([A, fast_tanh(A)])

## fast_erf

In [ ]:
# Build and print the PrimFunc for `fast_erf(A)`.
createandshowmod([A, fast_erf(A)])

## ceil_log2

In [ ]:
# TODO: add a ceil_log2 example.