In [None]:
import numpy as np
import tvm
import tvm.te as te
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T


def showmod(mod: tvm.ir.module.IRModule):
    """Print ``mod`` through ``IRModule.show`` with a fixed set of options.

    Output is black-formatted, the metadata section and object addresses are
    hidden, and expressions plus all struct info are printed verbosely.
    """
    display_options = {
        "black_format": True,
        "show_meta": False,
        "verbose_expr": True,
        "show_object_address": False,
        "show_all_struct_info": True,
    }
    mod.show(**display_options)


def createandshowmod(ops):
    """Build a PrimFunc from TE tensors, wrap it in an IRModule and print it.

    Parameters
    ----------
    ops : list of tvm.te.Tensor
        Tensors handed to ``te.create_prim_func`` as the function's arguments.
        The function gets ``global_symbol="test"`` and is stored in the
        module under the name "test".
    """
    prim_func = te.create_prim_func(ops)
    prim_func = prim_func.with_attrs({"global_symbol": "test"})
    showmod(tvm.IRModule({"test": prim_func}))


from tvm.topi.transform import *

## expand_dims

In [None]:
"""Expand the shape of an array.

Parameters
----------
a : tvm.te.Tensor
    The tensor to be expanded.

axis : int
    Position where the new axes are inserted.

num_newaxis : int
    Number of newly inserted axes.

Returns
-------
ret : tvm.te.Tensor
"""

A: te.Tensor = te.placeholder((128, 128), name="A")

# Insert two new axes at position 0, at position 1, and with a negative axis (-1).
# After the loop, `expand_dims_tensor` holds the axis=-1 result.
for insert_axis in (0, 1, -1):
    expand_dims_tensor = expand_dims(a=A, axis=insert_axis, num_newaxis=2)
    createandshowmod([A, expand_dims_tensor])

## expand_like

In [None]:
"""Expand an input array with the shape of second array.
This operation can always be composed of unsqueezing and
expanding dims on those unsqueezed axes.

Parameters
----------
a : tvm.te.Tensor
    The tensor to be expanded.
shape_like : tvm.te.Tensor
    The tensor with the target shape.
axis: list of int
    Axes to be expanded on.

Returns
-------
ret : tvm.te.Tensor
"""

"""The shape of the output tensor is the same as the shape of `shape_like`.

`axis` lists the axes of the *output* that are newly inserted (broadcast);
the remaining output axes map, in order, onto the axes of `A`. Hence
len(axis) + len(A.shape) must equal len(shape_like.shape). `axis` can be
negative, which means counting dimensions from the back.

For example, with A.shape = [256, 256], shape_like.shape =
[128, 128, 128, 128] and axis = [1, 2], the output shape is
[128, 128, 128, 128] and output[i0, i1, i2, i3] reads A[i0, i3].
With axis = [0, 1], output[i0, i1, i2, i3] reads A[i2, i3].

NOTE(review): the extents of the kept axes (256 in A vs 128 in shape_like)
do not appear to be checked, so only A[:128, :128] is read -- confirm.
"""
A = te.placeholder(shape=(256, 256), dtype="int32", name="A")
# Target-shape tensor. Like the original cell, it is not listed among the
# PrimFunc arguments; only A and the result are.
shape_like_tensor = te.placeholder(
    shape=(128, 128, 128, 128), dtype="int32", name="shape_like"
)

expand_like_tensor = expand_like(a=A, shape_like=shape_like_tensor, axis=[1, 2])
createandshowmod([A, expand_like_tensor])

expand_like_tensor = expand_like(a=A, shape_like=shape_like_tensor, axis=[0, 1])
createandshowmod([A, expand_like_tensor])

## transpose

In [None]:
"""Permute the dimensions of an array.

Parameters
----------
a : tvm.te.Tensor
    The tensor to be transposed.

axes: tuple of ints, optional
    By default, reverse the dimensions.

Returns
-------
ret : tvm.te.Tensor
"""

# Use distinct extents so the permutation shows up in the output shape:
# axes=(1, 0, 2) swaps the first two axes, (64, 128, 256) -> (128, 64, 256).
A: te.Tensor = te.placeholder(shape=(64, 128, 256), dtype="float32", name="A")
createandshowmod([A, transpose(a=A, axes=(1, 0, 2))])

## flip

In [None]:
"""Flip/reverse elements of an array in a particular axis.

Parameters
----------
a : tvm.te.Tensor
    The tensor to be flipped.

axis : int, optional
    The axis along which the tensor will be reversed.

Returns
-------
ret : tvm.te.Tensor
"""

# Define the input here instead of reusing `A` leaked from the transpose cell,
# so this cell also works on a fresh kernel or when run out of order.
A: te.Tensor = te.placeholder(shape=(256, 256, 256), dtype="float32", name="A")
createandshowmod([A, flip(A, axis=0)])
createandshowmod([A, flip(A, axis=1)])

## reverse_sequence

In [None]:
"""Reverse the tensor for variable length slices.
Input is first sliced along batch axis and then elements are reversed along seq axis.

Parameters
----------
a : tvm.te.Tensor
   The tensor to be reversed.

seq_lengths : tvm.te.Tensor
   A 1D Tensor with length a.dims[batch_axis]
   Must be one of the following types: int32, int64
   if seq_lengths[i] > a.dims[seq_axis], it is rounded to a.dims[seq_axis]
   if seq_lengths[i] < 1, it is rounded to 1

seq_axis : int, optional
   The axis along which the elements will be reversed. Default is 1.

batch_axis : int, optional
   The axis along which the tensor will be sliced. Default is 0.

Returns
-------
ret : tvm.te.Tensor
   The computed result of same shape and type as of input.

"""

A: te.Tensor = te.placeholder(shape=(1, 2, 3), dtype="float32", name="A")
# TODO
# seq_lengths needs length a.dims[batch_axis] = A.shape[0] = 1.
# `(1,)` is a 1-element shape tuple; a bare `(1)` is just the int 1.
seq_lengths: te.Tensor = te.placeholder(shape=(1,), dtype="int64", name="seq_lengths")
createandshowmod(
    [A, seq_lengths, reverse_sequence(A, seq_lengths, seq_axis=1, batch_axis=0)]
)

## strided_slice

In [None]:
"""Slice of an array.

Parameters
----------
a : tvm.te.Tensor
    The tensor to be sliced.

begin : list of int
    The indices to begin with in the slicing.

end : list of int
    Indices indicating end of the slice.

strides : list of int, optional
    Specifies the stride values, it can be negative
    in that case, the input tensor will be reversed
    in that particular axis.

axes : list of int, optional
    Axes along which slicing is applied. When it is specified, begin, end,
    strides, and axes need to be lists of integers of the same length.

slice_mode : str, optional
    The slice mode [end, size].
    end - The ending indices for the slice [default].
    size - The input strides will be ignored, input end in this mode indicates
    the size of a slice starting at the location specified by begin. If end[i]
    is -1, all remaining elements in that dimension are included in the slice.

assume_inbound: bool, optional
    A flag to indicate if all indices are assumed to be inbound

Returns
-------
ret : tvm.te.Tensor
"""

A: te.Tensor = te.placeholder(shape=(12, 13, 14), dtype="float32", name="A")
begin = [0, 1, 2]
end = [11, 12, 13]
# For positive strides and in-bound indices, each output extent is
# ceil((end - begin) / stride):
#   ceil((11 - 0) / 2) = 6, ceil((12 - 1) / 3) = 4, ceil((13 - 2) / 4) = 3
# so the result has shape (6, 4, 3).
strides = [2, 3, 4]
axes = [0, 1, 2]  # axes of A that begin/end/strides refer to (each in 0..2)
slice_mode = "end"
assume_inbound = False
ret: te.Tensor = strided_slice(
    a=A,
    begin=begin,
    end=end,
    strides=strides,
    axes=axes,
    slice_mode=slice_mode,
    assume_inbound=assume_inbound,
)
createandshowmod([A, ret])

## dynamic_strided_slice

In [None]:
"""Slice of an array.

Parameters
----------
a : tvm.te.Tensor
    The tensor to be sliced.

begin : tvm.te.Tensor
    The indices to begin with in the slicing.

end : tvm.te.Tensor
    Indices indicating end of the slice.

strides : tvm.te.Tensor
    Specifies the stride values, it can be negative
    in that case, the input tensor will be reversed
    in that particular axis.

output_shape: list of PrimExpr
    Specifies the output shape

Returns
-------
ret : tvm.te.Tensor
"""

A: te.Tensor = te.placeholder(shape=(12, 13, 14), dtype="float32", name="A")
# NOTE: begin/end/strides are passed as plain Python lists here rather than
# tvm.te.Tensor; the notes below refer to the const_vector buffers in the IR.
begin = [1, 2, 3]
end = [11, 12, 13]
strides = [2, 3, 4]
# output_shape is supplied by the caller and should equal the slice extents,
# ceil((end - begin) / stride):
#   ceil((11 - 1) / 2) = 5, ceil((12 - 2) / 3) = 4, ceil((13 - 3) / 4) = 3
ret = dynamic_strided_slice(
    a=A, begin=begin, end=end, strides=strides, output_shape=(5, 4, 3)
)
# Equivalent computation (pseudocode):
#   for i, j, k in grid(5, 4, 3):
#       # const_vector stores the strides array
#       # const_vector_1 stores the begin array
#       ret[i, j, k] = A[begin[0] + i * strides[0],
#                        begin[1] + j * strides[1],
#                        begin[2] + k * strides[2]]
createandshowmod([A, ret])

## strided_set

In [None]:
"""Set slice of an array.

Parameters
----------
a : tvm.te.Tensor
    The tensor whose slice is set.

v : tvm.te.Tensor
    The values to set

begin: tvm.te.Tensor
    The indices to begin with in the slicing.

end: tvm.te.Tensor
    Indices indicating end of the slice.

strides: tvm.te.Tensor, optional
    Specifies the stride values, it can be negative
    in that case, the input tensor will be reversed
    in that particular axis.

Returns
-------
ret : tvm.te.Tensor
"""

# `np` comes from the notebook's top-level import cell.
A: te.Tensor = te.placeholder(shape=(12, 13, 14), dtype="float32", name="A")
# v must match the shape of the selected slice: (5, 3, 2).
v = te.placeholder(shape=(5, 3, 2), dtype="float32", name="v")
begin = const_vector(np.array([1, 3, 3], dtype="int32"))
# (11 - 1) // 2 = 5, (12 - 3) // 3 = 3, (11 - 3) // 4 = 2
end = const_vector(np.array([11, 12, 11], dtype="int32"))
strides = const_vector(np.array([2, 3, 4], dtype="int32"))
# const_vector stores the begin array
# const_vector_1 stores the end array
# const_vector_2 stores the strides array
createandshowmod(
    [
        A,
        v,
        begin,
        end,
        strides,
        strided_set(a=A, v=v, begin=begin, end=end, strides=strides),
    ]
)

## reshape

In [None]:
"""Reshape the array.

Parameters
----------
a : tvm.te.Tensor
    The tensor to be reshaped.

newshape : tuple of ints
    The new shape; it should contain the same number of elements as `a`.

Returns
-------
ret : tvm.te.Tensor
"""

A: te.Tensor = te.placeholder((64, 64), name="A")
# Reshaping (64, 64) -> (64, 64) would be a no-op. Use another shape with the
# same element count (64 * 64 == 32 * 128) so the index remapping is visible.
createandshowmod([A, reshape(a=A, newshape=(32, 128))])
# TODO

## squeeze

In [None]:
"""Remove single-dimensional entries from the shape of an array.

Parameters
----------
a : tvm.te.Tensor
    The tensor to be squeezed.

axis : None or int or tuple of ints, optional
    Selects a subset of the single-dimensional entries in the shape.
    If an axis is selected with shape entry greater than one, an error is raised.

Returns
-------
squeezed : tvm.te.Tensor
"""

# Drop the leading extent-1 axis of a (1, 64) input.
A: te.Tensor = te.placeholder((1, 64), name="A")
squeezed = squeeze(a=A, axis=0)
createandshowmod([A, squeezed])

## concatenate

In [None]:
"""Join a sequence of arrays along an existing axis.

Parameters
----------
a_tuple : tuple of tvm.te.Tensor
    The arrays to concatenate

axis : int, optional
    The axis along which the arrays will be joined. Default is 0.

Returns
-------
ret : tvm.te.Tensor
"""


# All axes except the joined one must have matching extents.
A: te.Tensor = te.placeholder((16, 64), name="A")
B: te.Tensor = te.placeholder((32, 64), name="B")
# axis=0: axis 1 matches (64) -> result (16 + 32, 64) = (48, 64).
concat_tensor = concatenate((A, B), axis=0)
createandshowmod([A, B, concat_tensor])

# axis=1: axis 0 must match, and A/B differ there (16 vs 32), so join A with
# a 16-row input instead -> result (16, 64 + 32) = (16, 96).
# NOTE(review): TOPI does not appear to reject mismatched shapes; (A, B) on
# axis=1 would silently read only the first 16 rows of B.
C: te.Tensor = te.placeholder((16, 32), name="C")
concat_tensor = concatenate((A, C), axis=1)
createandshowmod([A, C, concat_tensor])

## stack

In [None]:
"""Join a sequence of tensors along a new axis.

Parameters
----------
tensors : tuple or list of tvm.te.Tensor
    The tensors to be stacked. All tensors must have the same shape.

axis : int, optional
    The axis in the resulting tensor along which the input tensors will be stacked.
    Negative values wrap around. Default is 0.

Returns
-------
ret : tvm.te.Tensor
    The stacked tensor with an additional dimension compared to the input tensors.
"""


# Two inputs of identical shape (1, 2, 3); stacking on axis=1 inserts a new
# axis of extent 2 at position 1 of the result.
A: te.Tensor = te.placeholder((1, 2, 3), name="A")
B: te.Tensor = te.placeholder((1, 2, 3), name="B")
stack_inputs = [A, B]
stack_tensor = stack(stack_inputs, axis=1)
createandshowmod(stack_inputs + [stack_tensor])

## split

In [None]:
"""Split an array into multiple sub-arrays.

Parameters
----------
ary : tvm.te.Tensor
    The tensor to be split.

indices_or_sections : int or 1-D array
    An int splits `ary` into that many equal sections along `axis`; a list
    gives the indices along `axis` at which to split (numpy-style).

axis : int
    The axis along which to split.

Returns
-------
ret : tuple of tvm.te.Tensor
"""

A: te.Tensor = te.placeholder((16, 32, 64), name="A")
# Splitting axis 2 at index 2 gives two outputs: (16, 32, 2) and (16, 32, 62).
split_tensors = split(A, indices_or_sections=[2], axis=2)
# split returns a sequence of tensors. Unpack it so create_prim_func receives a
# flat list of tensors instead of a list nested inside the argument list.
createandshowmod([A, *split_tensors])
# TODO

## take

In [None]:
"""Take elements from an array along an axis.

Parameters
----------
a : tvm.te.Tensor
    The source array.

indices : tvm.te.Tensor
    The indices of the values to extract.

axis : int, optional
    The axis over which to select values. By default,
    the flattened input array is used.

batch_dims : int
    The number of batch dimensions. By default is 0.

mode : str, optional
    Specifies how out-of-bound indices will behave.
    clip - clip to the range (default)
    wrap - wrap around the indices
    fast - no clip or wrap around (user must make sure indices are in-bound)

Returns
-------
ret : tvm.te.Tensor
"""

A: te.Tensor = te.placeholder((16, 32, 64), name="A")
# NOTE(review): demo not written yet. It presumably needs an integer `indices`
# placeholder passed to take(a=A, indices=..., axis=...) and to createandshowmod.
# TODO