In [354]:
import numpy as np
import tvm
from tvm.script import tir #stands for tensor intermediate representation
from tvm.ir.module import IRModule

In [355]:
@tvm.script.ir_module  # marks this class as a TVM IRModule written in TVMScript
class TensorModule:

    @tir.prim_func  # marks this function as a TensorIR primitive function
    def TensorFunction(array: tir.Buffer((4,), "float32")):
        # `array` is a tir.Buffer: a tensor with a fixed shape and data type.
        # The shape is written as the 1-tuple (4,); a bare (4) is just the
        # integer 4 in Python, not a tuple.

        # Function metadata:
        #   global_symbol - the name that identifies the function (it is the
        #                   name used to look the function up after building)
        #   tir.noalias   - states that the buffer arguments do not alias
        tir.func_attr({"global_symbol": "TensorFunction", "tir.noalias": True})

        # Increment every element of the buffer by 1, in place.
        for i in range(4):
            array[i] += 1
        

In [356]:
# A whole module has to be built before any function in it can be called.
# target="llvm" selects the LLVM backend (LLVM originally stood for
# "Low Level Virtual Machine").
builtModule = tvm.build(TensorModule, target="llvm")

# Look up one function of the built module by name; the result is a packed
# function that can be called directly.
functionToRun = builtModule["TensorFunction"]

In [357]:
# Create the input tensor. The dtype is spelled out as "float32" (rather than
# the NumPy char code "f") so it visibly matches the float32 buffer that
# TensorFunction declares.
array = tvm.nd.array(np.array([1, 2, 3, 4], dtype="float32"))

# The built function is handed the buffer itself (a handle to it, not a copy)
# and updates it in place; nothing is returned, so read the result from `array`.
functionToRun(array)
print(array)

# Sanity check: every element must have been incremented by exactly 1.
np.testing.assert_allclose(array.numpy(), [2.0, 3.0, 4.0, 5.0])

[2. 3. 4. 5.]


In [358]:
@tvm.script.ir_module
class SuperTensorModule:

    @tir.prim_func
    def matmul(
        A: tir.Buffer((5, 5), "float32"),
        B: tir.Buffer((5, 5), "float32"),
        C: tir.Buffer((5, 5), "float32"),
    ):
        tir.func_attr({"global_symbol": "matmul", "tir.noalias": True})

        # tir.grid(5, 5, 5) is shorthand for three nested loops of extent 5.
        for i, j, k in tir.grid(5, 5, 5):
            # A block is the basic unit of TensorIR computation; the axis
            # mapping below has to be declared inside one.
            with tir.block("C"):
                # Bind the loop indices to block axes: "SSR" makes i and j
                # spatial axes and k a reduce axis.
                vi, vj, vk = tir.axis.remap("SSR", [i, j, k])
                # Zero the accumulator before the first partial product is added.
                with tir.init():
                    C[vi, vj] = tir.float32(0.0)
                # Add one term of the dot product of row vi of A and column vj of B.
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    @tir.prim_func
    def relu(A: tir.Buffer((5, 5), "float32")):
        tir.func_attr({"global_symbol": "relu", "tir.noalias": True})

        # In-place ReLU: negative entries are overwritten with zero, all other
        # entries are left untouched.
        for i, j in tir.grid(5, 5):
            with tir.block("A"):
                vi, vj = tir.axis.remap("SS", [i, j])
                if A[vi, vj] < 0:
                    A[vi, vj] = tir.float32(0.0)
        

In [359]:
# Build SuperTensorModule for the LLVM target. This rebinds builtModule, so
# from here on it refers to the module containing matmul and relu.
builtModule = tvm.build(SuperTensorModule, target="llvm")

In [360]:
# Inputs for matmul: two 5x5 float32 matrices holding the values 0..24.
A = tvm.nd.array(np.arange(25, dtype="float32").reshape(5, 5))
B = tvm.nd.array(np.arange(25, dtype="float32").reshape(5, 5))
# Output buffer for matmul; the dtype is given explicitly so it visibly matches
# the float32 buffer that matmul declares for C.
C = tvm.nd.empty((5, 5), dtype="float32")
# Input for relu: the values -12..12, so negative, zero and positive entries
# are all present.
D = tvm.nd.array(np.arange(-12, 13, dtype="float32").reshape(5, 5))

In [361]:
# Run matmul and check the result against NumPy's matrix product.
builtModule["matmul"](A, B, C)
print(C)
np.testing.assert_allclose(C.numpy(), A.numpy() @ B.numpy())

# relu overwrites D in place, so keep a copy of its input to check against.
relu_input = D.numpy()
builtModule["relu"](D)
print(D)
np.testing.assert_allclose(D.numpy(), np.maximum(relu_input, 0))

[[ 150.  160.  170.  180.  190.]
 [ 400.  435.  470.  505.  540.]
 [ 650.  710.  770.  830.  890.]
 [ 900.  985. 1070. 1155. 1240.]
 [1150. 1260. 1370. 1480. 1590.]]
[[ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  1.  2.]
 [ 3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12.]]


In [362]:
# Print the module back out as TVMScript source text.
print(SuperTensorModule.script())

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def matmul(A: T.Buffer((5, 5), "float32"), B: T.Buffer((5, 5), "float32"), C: T.Buffer((5, 5), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j, k in T.grid(5, 5, 5):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(C[vi, vj])
                with T.init():
                    C[vi, vj] = T.float32(0.0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    @T.prim_func
    def relu(A: T.Buffer((5, 5), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j in T.grid(5, 5):
            with T.block("A"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(A[vi, vj])
     

In [363]:
# Create a schedule: a wrapper around the module through which its tensor
# functions can be transformed.
schedule = tvm.tir.Schedule(SuperTensorModule)

# Fetch block "C" of matmul. Blocks are the basic unit of computation, and they
# are what schedule transformations operate on.
block_C = schedule.get_block("C", func_name="matmul")

# Get the loops surrounding block C (here the i, j, k loops of matmul).
i, j, k = schedule.get_loops(block_C)

# Split loop j into an outer loop j0 and an inner loop j1 of extent 2;
# factors=[None, 2] leaves the outer extent to be inferred. The stated
# motivation is to reduce cache stride.
j0, j1 = schedule.split(j, factors=[None, 2])

In [364]:
# Show the scheduled module as TVMScript to inspect the current loop structure.
print(schedule.mod.script())

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def matmul(A: T.Buffer((5, 5), "float32"), B: T.Buffer((5, 5), "float32"), C: T.Buffer((5, 5), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j_0, j_1, k in T.grid(5, 3, 2, 5):
            with T.block("C"):
                vi = T.axis.spatial(5, i)
                vj = T.axis.spatial(5, j_0 * 2 + j_1)
                vk = T.axis.reduce(5, k)
                T.where(j_0 * 2 + j_1 < 5)
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(C[vi, vj])
                with T.init():
                    C[vi, vj] = T.float32(0.0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    @T.prim_func
    def relu(A: T.Buffer((5, 5), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j in T.grid(5, 5):
            with T.block("A")

In [365]:
# Reorder the loops of matmul to i, j0, k, j1, i.e. move the reduction loop k
# outside the inner split loop j1 (again for caching purposes), then print the
# resulting module.
schedule.reorder(i, j0, k, j1)
print(schedule.mod.script())

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def matmul(A: T.Buffer((5, 5), "float32"), B: T.Buffer((5, 5), "float32"), C: T.Buffer((5, 5), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j_0, k, j_1 in T.grid(5, 3, 5, 2):
            with T.block("C"):
                vi = T.axis.spatial(5, i)
                vj = T.axis.spatial(5, j_0 * 2 + j_1)
                vk = T.axis.reduce(5, k)
                T.where(j_0 * 2 + j_1 < 5)
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(C[vi, vj])
                with T.init():
                    C[vi, vj] = T.float32(0.0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    @T.prim_func
    def relu(A: T.Buffer((5, 5), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j in T.grid(5, 5):
            with T.block("A")

In [366]:
@tvm.script.ir_module
class TensorModuleII:

    @tir.prim_func
    def mm_relu(
        A: tir.Buffer((5, 5), "float32"),
        B: tir.Buffer((5, 5), "float32"),
        X: tir.Buffer((5, 5), "float32"),
    ):
        tir.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})

        # Intermediate buffer holding the matrix product before ReLU is applied.
        Y = tir.alloc_buffer((5, 5), "float32")

        # Stage 1: Y = A @ B
        for i, j, k in tir.grid(5, 5, 5):
            with tir.block("Y"):
                # Each block axis is declared explicitly here (an alternative
                # to tir.axis.remap). Spatial axes index independent output
                # elements; a reduce axis combines several elements into one,
                # which introduces a dependency across its iterations.
                vi = tir.axis.spatial(5, i)
                vj = tir.axis.spatial(5, j)
                vk = tir.axis.reduce(5, k)
                with tir.init():
                    Y[vi, vj] = tir.float32(0.0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

        # Stage 2: X = max(Y, 0), element by element
        for i, j in tir.grid(5, 5):
            with tir.block("X"):
                vi = tir.axis.spatial(5, i)
                vj = tir.axis.spatial(5, j)
                X[vi, vj] = tir.max(Y[vi, vj], tir.float32(0.0))
                

In [367]:
newSchedule = tvm.tir.Schedule(TensorModuleII)

# Producer block "Y" (the matrix product) and the three loops around it.
block_Y = newSchedule.get_block("Y", func_name="mm_relu")
iY, jY, kY = newSchedule.get_loops(block_Y)

# Consumer block "X" (the ReLU) and the two loops around it.
block_X = newSchedule.get_block("X", func_name="mm_relu")
iX, jX = newSchedule.get_loops(block_X)

# Move the consumer block X under loop jY of its producer, so that each
# X[vi, vj] is computed right after the reduction for Y[vi, vj] has finished.
newSchedule.reverse_compute_at(block_X, jY)
print(newSchedule.mod.script())

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer((5, 5), "float32"), B: T.Buffer((5, 5), "float32"), X: T.Buffer((5, 5), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        Y = T.alloc_buffer((5, 5))
        for i, j in T.grid(5, 5):
            for k in range(5):
                with T.block("Y"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    T.reads(A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    with T.init():
                        Y[vi, vj] = T.float32(0.0)
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            with T.block("X"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Y[vi, vj])
                T.writes(X[vi, vj])
                X[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))
