<a href="https://colab.research.google.com/github/XueyanZhang/MachineLearningCompilation/blob/master/6_Integration_with_ML_Framework.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Integration with Machine Learning Frameworks

In [1]:
!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels

import numpy as np
import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.script import relax as R
from tvm.script import tir as T

import torch
import torch.nn as nn
from torch import fx
from torch.nn import functional as F

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly-0.12.dev878%2Bgd65d311af-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.1 MB)
Installing collected packages: mlc-ai-nightly
Successfully installed mlc-ai-nightly-0.12.dev878+gd65d311af


# Build IRModule via Builder

In previous chapters, we wrote IRModules manually in TVMScript.

That approach does not scale to large systems, so here we look at ways to build an IRModule programmatically.

## Tensor Expression

Tensor Expression (TE) is a domain-specific language for describing tensor computations. We can use it to create TensorIR functions.

In [2]:
from tvm import te

# create input
f32 = "float32"
A = te.placeholder((128, 128), name="A", dtype="float32")
B = te.placeholder((128, 128), name="B", dtype="float32")
print(type(A))
print(A.shape)

<class 'tvm.te.tensor.Tensor'>
[128, 128]


In [3]:
# define computation
def te_matmul(A: te.Tensor, B: te.Tensor) -> te.Tensor:
    assert A.shape[1] == B.shape[0]
    m = A.shape[0]
    n = B.shape[1]
    k = te.reduce_axis((0, A.shape[1]), name='k')
    return te.compute((m, n), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul")
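
To make the semantics of `te.compute` with a `te.reduce_axis` concrete, here is a plain NumPy sketch of the loop nest that `te_matmul` describes: the lambda arguments `i` and `j` range over the output shape, and `k` is a reduction axis summed inside each output element. This is only an illustrative reference model (it roughly corresponds to the loops in the generated TensorIR), not code produced by TVM; `reference_matmul` is a hypothetical helper name.

```python
import numpy as np


def reference_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Same shape check as te_matmul: inner dimensions must agree.
    assert a.shape[1] == b.shape[0]
    m, n = a.shape[0], b.shape[1]
    c = np.empty((m, n), dtype=a.dtype)
    for i in range(m):                   # spatial axis i
        for j in range(n):               # spatial axis j
            acc = a.dtype.type(0)        # te.sum starts the reduction from 0
            for k in range(a.shape[1]):  # reduction axis k
                acc += a[i, k] * b[k, j]
            c[i, j] = acc
    return c


a = np.random.rand(4, 3).astype("float32")
b = np.random.rand(3, 5).astype("float32")
np.testing.assert_allclose(reference_matmul(a, b), a @ b, rtol=1e-5)
```

The explicit loops are deliberately naive; the point of TE is that you state only the per-element rule and let TVM produce (and later transform) the loop nest.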

In [8]:
# create result
C = te_matmul(A, B)
print(C)


Tensor(shape=[128, 128], op.name=matmul)


In [9]:
# create TensorIR function
te.create_prim_func([A, B, C]).show()



In [10]:
# create relu func
def te_relu(A: te.Tensor) -> te.Tensor:
    return te.compute(A.shape, lambda *i: te.max(A(*i), 0), name="relu")

Here `*i` collects an index of arbitrary rank, so `te_relu` works for inputs of any shape. Here are some examples:

In [11]:
# 1D input
X1 = te.placeholder((10, ), dtype=f32, name='X1')
Y1 = te_relu(X1)
te.create_prim_func([X1, Y1])

# from tvm.script import tir as T

@T.prim_func
def main(X1: T.Buffer((10,), "float32"), relu: T.Buffer((10,), "float32")):
    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i0 in range(10):
        with T.block("relu"):
            v_i0 = T.axis.spatial(10, i0)
            T.reads(X1[v_i0])
            T.writes(relu[v_i0])
            relu[v_i0] = T.max(X1[v_i0], T.float32(0))

In [12]:
# 2D input
X2 = te.placeholder((10, 10), dtype=f32, name='X2')
Y2 = te_relu(X2)
te.create_prim_func([X2, Y2])

# from tvm.script import tir as T

@T.prim_func
def main(X2: T.Buffer((10, 10), "float32"), relu: T.Buffer((10, 10), "float32")):
    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i0, i1 in T.grid(10, 10):
        with T.block("relu"):
            v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
            T.reads(X2[v_i0, v_i1])
            T.writes(relu[v_i0, v_i1])
            relu[v_i0, v_i1] = T.max(X2[v_i0, v_i1], T.float32(0))
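
The `*i` trick is ordinary Python variadic arguments: `te.compute` calls the lambda with one index variable per output dimension, and `*i` gathers however many there are. The toy NumPy analogue below shows the same relu lambda working for 1D and 2D shapes. `toy_compute` and `toy_relu` are hypothetical names; `toy_compute` only mimics the calling convention of `te.compute` and evaluates eagerly instead of building an expression.

```python
import itertools

import numpy as np


def toy_compute(shape, fcompute):
    # Hypothetical eager stand-in for te.compute: call fcompute once per
    # output index, passing one integer per dimension.
    out = np.empty(shape, dtype="float32")
    for idx in itertools.product(*(range(n) for n in shape)):
        out[idx] = fcompute(*idx)
    return out


def toy_relu(x: np.ndarray) -> np.ndarray:
    # *i collects an index tuple of any rank, just like in te_relu
    return toy_compute(x.shape, lambda *i: max(x[i], 0))


x1 = np.linspace(-1, 1, 10, dtype="float32")                    # 1D input
x2 = np.linspace(-1, 1, 100, dtype="float32").reshape(10, 10)   # 2D input
assert np.array_equal(toy_relu(x1), np.maximum(x1, 0))
assert np.array_equal(toy_relu(x2), np.maximum(x2, 0))
```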

## Fusion

In [13]:
# fuse matmul w/ relu
C = te_matmul(A, B)
D = te_relu(C)

te.create_prim_func([A, B, D]).show()



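Note that only A, B, and D are passed in; C is omitted.

In the output above, the generated prim func allocates a temporary buffer for C (`matmul`) internally.

We can still pass C in as an argument, but then the function expects the caller to provide a buffer for C as well. Passing in only the inputs and the final output is generally preferable, since it leaves more room for fusion inside the function.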

In [14]:
te.create_prim_func([A, B, C, D]).show()



Note that C (`matmul`) now appears as a function parameter instead of an internal temporary buffer.
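
The difference between the two signatures can be sketched in NumPy. This is only an analogy for the function interfaces, not TVM output: the hypothetical `fused_abd` mirrors `create_prim_func([A, B, D])`, allocating C as an internal temporary, while the hypothetical `fused_abcd` mirrors `create_prim_func([A, B, C, D])`, where the caller must supply a buffer for C and can inspect it afterward. Both use destination-passing style (output buffers are passed in and written in place), as TensorIR functions do.

```python
import numpy as np


def fused_abd(a, b, d):
    # Like create_prim_func([A, B, D]): C is a temporary the caller never sees.
    c = np.empty((a.shape[0], b.shape[1]), dtype=a.dtype)
    np.matmul(a, b, out=c)
    np.maximum(c, 0, out=d)


def fused_abcd(a, b, c, d):
    # Like create_prim_func([A, B, C, D]): the caller owns C, so the
    # intermediate matmul result must be written to that buffer.
    np.matmul(a, b, out=c)
    np.maximum(c, 0, out=d)


a = np.random.randn(8, 8).astype("float32")
b = np.random.randn(8, 8).astype("float32")
d1 = np.empty((8, 8), dtype="float32")
c2 = np.empty((8, 8), dtype="float32")
d2 = np.empty((8, 8), dtype="float32")

fused_abd(a, b, d1)
fused_abcd(a, b, c2, d2)
assert np.allclose(d1, d2)      # same final result either way
assert np.allclose(c2, a @ b)   # only the second form exposes C to the caller
```

Because the caller can observe C in the second form, a compiler is obliged to materialize it in caller-visible memory, which is one reason the inputs-and-final-output signature is usually preferred.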

# Build IRModule via BlockBuilder 