In [1]:
import torch
from torch import nn
from torch import fx
from torch.nn import functional as F

from d2l import torch as d2l

In [3]:
# This is needed for deferring annotation parsing in TVMScript
from __future__ import annotations
import numpy as np
import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.script import relax as R
from tvm.script import tir as T

from tvm import te
from tvm import topi

In [4]:
import IPython

In [5]:
import sys
sys.path.append('..')

定义模型

In [6]:
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

In [7]:
data = torch.rand(256, 1, 28, 28)
with torch.autograd.profiler.profile(use_cuda=False) as prof:
    # for i in range(10):
    net(data)
print(prof)

----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::conv2d         0.70%     180.000us        57.81%      14.848ms      14.848ms             1  
           aten::convolution         1.28%     329.000us        57.11%      14.668ms      14.668ms             1  
          aten::_convolution         0.17%      44.000us        55.83%      14.339ms      14.339ms             1  
    aten::mkldnn_convolution        55.52%      14.260ms        55.66%      14.295ms      14.295ms             1  
                 aten::empty         0.11%      27.000us         0.11%      27.000us      27.000us             1  
           aten::as_strided_         0.03%       8.000us         0.03%       8.0

转换为计算图，并打印一些计算图信息

In [8]:
fx_module = fx.symbolic_trace(net)

In [9]:
# fx_module 包含一个简单的计算图，可以打印成表格便于查看。我们的目标是将此图转换为 IRModule。
fx_module.graph.print_tabular()

opcode       name     target    args        kwargs
-----------  -------  --------  ----------  --------
placeholder  input_1  input     ()          {}
call_module  _0       0         (input_1,)  {}
call_module  _1       1         (_0,)       {}
call_module  _2       2         (_1,)       {}
call_module  _3       3         (_2,)       {}
call_module  _4       4         (_3,)       {}
call_module  _5       5         (_4,)       {}
call_module  _6       6         (_5,)       {}
call_module  _7       7         (_6,)       {}
call_module  _8       8         (_7,)       {}
call_module  _9       9         (_8,)       {}
call_module  _10      10        (_9,)       {}
call_module  _11      11        (_10,)      {}
output       output   output    (_11,)      {}


In [10]:
# 为了展示torch中nn函数的计算图
conv2d = fx.symbolic_trace(net[0])
sigmoid = fx.symbolic_trace(net[1])
pooling = fx.symbolic_trace(net[2])
flatten = fx.symbolic_trace(net[6])
linear  = fx.symbolic_trace(net[7])

In [11]:
conv2d.graph.print_tabular()

opcode         name     target                                                     args                                                kwargs
-------------  -------  ---------------------------------------------------------  --------------------------------------------------  --------
placeholder    input_1  input                                                      ()                                                  {}
get_attr       weight   weight                                                     ()                                                  {}
get_attr       bias     bias                                                       ()                                                  {}
call_function  conv2d   <built-in method conv2d of type object at 0x7fb7924f8780>  (input_1, weight, bias, (1, 1), (2, 2), (1, 1), 1)  {}
output         output   output                                                     (conv2d,)                                           {}


定义torch中函数/Module与tir中函数/Module的映射

In [12]:
from relax_wrapper import map_param, from_fx

In [13]:
def te_avgpool2d(A: te.Tensor, kernel_size, stride) -> te.Tensor:
    # print("avg pool2d: ", A.shape)
    assert(len(A.shape) == 4) # batch_size, channel, height, width
    m = A.shape[-2]
    n = A.shape[-1]
    shape = (A.shape[0], A.shape[1], te.indexdiv(m-kernel_size, stride) + 1, te.indexdiv(n-kernel_size, stride) + 1)
    hk = te.reduce_axis((0, kernel_size), name="hk") # row
    wk = te.reduce_axis((0, kernel_size), name="wk") # col
    return te.compute(shape, lambda b, c, h, w: te.sum(A[b, c, h+hk, w+wk], axis=[hk, wk]), name="avgpool2d")

In [14]:
# 定义映射Module OP, 这里用的三种方式
# 1. 使用TE自定义映射函数
# 2. 使用topi中包含的TE函数
# 3. 使用relax.op中包含的函数

# relax.nn.flatten和relax.nn.conv2d利用现成的pass降级还存在问题, 直接用topi.nn替代了

def map_nn_relu_op(bb: relax.BlockBuilder, node_map, node: fx.node.Node, nn_mod: nn.ReLU):
    A = node_map[node.args[0]]
    return bb.emit(relax.op.relu(A))

def map_nn_sigmoid_op(bb: relax.BlockBuilder, node_map, node: fx.node.Node, nn_mod: nn.Sigmoid):
    A = node_map[node.args[0]]
    return bb.emit_te(topi.sigmoid, A)

def map_nn_flatten_op(bb: relax.BlockBuilder, node_map, node: fx.node.Node, nn_mod: nn.Flatten):
    A = node_map[node.args[0]]
    # return bb.emit(relax.op.flatten(A))
    return bb.emit_te(topi.nn.flatten, A)

def map_nn_linear_op(bb: relax.BlockBuilder, node_map, node: fx.node.Node, nn_mod: nn.Linear):
    x = node_map[node.args[0]]
    w = map_param(nn_mod.weight)
    if nn_mod.bias is not None:
        b = map_param(nn_mod.bias)
    y = bb.emit(relax.op.dense(x, w))
    return bb.emit(relax.op.add(y, b))

def map_nn_conv2d_op(bb: relax.BlockBuilder, node_map, node: fx.node.Node, nn_mod: nn.Conv2d):
    # print("conv2d: out_channels: ", nn_mod.out_channels)
    # print("conv2d: kernel_size: ", nn_mod.kernel_size)
    # print("conv2d: stride: ", nn_mod.stride)
    # print("conv2d: padding: ", nn_mod.padding)

    x = node_map[node.args[0]]
    w = map_param(nn_mod.weight)
    if nn_mod.bias is not None:
        b = map_param(nn_mod.bias)
    # return bb.emit(relax.op.conv2d(x, w, channels=nn_mod.out_channels,
    #             kernel_size=nn_mod.kernel_size,
    #             strides=nn_mod.stride, padding=nn_mod.padding))
    return bb.emit_te(topi.nn.conv2d, x, w, nn_mod.stride, nn_mod.padding, nn_mod.dilation)

def map_nn_avgpool2d_op(bb: relax.BlockBuilder, node_map, node: fx.node.Node, nn_mod: nn.AvgPool2d):
    # print("avgpool2d: kernel_size: ", nn_mod.kernel_size)
    # print("avgpool2d: stride: ", nn_mod.stride)
    A = node_map[node.args[0]]
    return bb.emit_te(te_avgpool2d, A, nn_mod.kernel_size, nn_mod.stride)

转换为tir，并计算当前基线的运行性能

In [24]:
LeNetModule = from_fx(
    fx.symbolic_trace(net),
    input_shapes = [(256, 1, 28, 28)],
    call_function_map={
    },
    call_module_map={
        torch.nn.Linear: map_nn_linear_op,
        torch.nn.Sigmoid: map_nn_sigmoid_op,
        torch.nn.ReLU: map_nn_relu_op,
        torch.nn.Flatten: map_nn_flatten_op,
        torch.nn.Conv2d: map_nn_conv2d_op,
        torch.nn.AvgPool2d: map_nn_avgpool2d_op
    },
)

In [16]:
from tvm.relax.testing import transform
from tvm.relax.transform.tuning_api import Trace

In [18]:
# 降级转为tir
target = tvm.target.Target("llvm")
with target, tvm.transform.PassContext(trace=Trace(LeNetModule), opt_level=0):
    seq = tvm.transform.Sequential(
    [
        transform.LowerWithRelayOpStrategyPass(target)
    ])
    NewLeNetModule = seq(LeNetModule)
# 合并函数
for var, func in LeNetModule.functions.items():
    if not isinstance(func, tvm.tir.function.PrimFunc):
        continue
    if var in NewLeNetModule.functions.keys():
        continue
    NewLeNetModule[var] = func
print(len(NewLeNetModule.functions))

16


In [19]:
assert relax.analysis.well_formed(NewLeNetModule)

a_np = np.random.rand(256, 1, 28, 28).astype("float32")
a_nd = tvm.nd.array(a_np)

exec = relax.vm.build(NewLeNetModule, target=target)
vm = relax.VirtualMachine(exec, tvm.cpu())

f_timer_baseline = vm.time_evaluator("main", tvm.cpu())
print("Time cost of baseline LeNetModule: %f s" % (f_timer_baseline(a_nd).mean))

Time cost of baseline LeNetModule: 0.084377 s


开始优化：基本算子：conv2d：
1. 定义优化后的conv2d算子；
2. 修改relax函数：改变main中的call_tir指向的var为新的conv2d函数
3. 合并函数

In [20]:
def conv2d_opt_function_(func: tvm.tir.PrimFunc, new_symbolname):
    # 1. 定义新的func, 并指定其名称
    sch = tvm.tir.Schedule(IRModule({new_symbolname: func.with_attr({"global_symbol": new_symbolname})}))
    # 2. 优化pad_temp block
    block_pad = sch.get_block("pad_temp", func_name=new_symbolname)
    pad_i, pad_j, pad_m, pad_n = sch.get_loops(block=block_pad)
    sch.parallel(pad_i)
    sch.unroll(pad_j)
    sch.vectorize(pad_n)
    # 3. 优化conv2d_nchw block
    block_nchw = sch.get_block("conv2d_nchw", func_name=new_symbolname)
    # batch/out_channel/pad_h/pad_w/in_channel/kernel_h/kernel_w
    i, j, m, n, k_i, k_h, k_w = sch.get_loops(block=block_nchw)
    sch.reorder(i, k_i, k_h, k_w, j, m, n)
    sch.parallel(i)
    sch.unroll(k_i)
    sch.vectorize(n)
    return sch.mod[new_symbolname]

def sigmoid_opt_function_(func: tvm.tir.PrimFunc, new_symbolname):
    # 1. 定义新的func, 并指定其名称
    sch = tvm.tir.Schedule(IRModule({new_symbolname: func.with_attr({"global_symbol": new_symbolname})}))
    # 2. 并行化、向量化
    block_compute = sch.get_block("compute", func_name=new_symbolname)
    loop_axis = sch.get_loops(block=block_compute)
    sch.parallel(loop_axis[0])
    sch.vectorize(loop_axis[-1]) 
    return sch.mod[new_symbolname]

# [REMOVE]
def tir_func_optimizer_schedule(mod: tvm.ir.module.IRModule):
    # 1. 移除没有使用的Functions
    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
    # 2. 针对不同的op进行优化
    for global_var, func in new_mod.functions.items():
        func_name = func.attrs["global_symbol"]
        opt_func_name = func_name + "_opt"
        if func_name.startswith("conv2d") and not func_name.endswith("opt"):
            new_mod[opt_func_name] = conv2d_opt_function_(func, opt_func_name)
    return new_mod

# 完成对main函数中对tir函数调用的修改
@relax.expr_functor.mutator
class LeNetModuleRewriter(relax.PyExprMutator):
    def __init__(self, mod: IRModule) -> None:
        super().__init__()
        self.mod_ = mod
        self.functions = [call.attrs["global_symbol"] for call in mod.functions.values()]
        self.dense_op = tvm.ir.Op.get("relax.nn.dense")
        self.add_op = tvm.ir.Op.get("relax.add")

    # 处理relax func中的call
    def visit_call_(self, call: relax.expr.Call) -> relax.expr.Expr:
        call: relax.expr.Call = self.visit_expr_post_order(call)
        call_tir_op = tvm.ir.Op.get("relax.call_tir")
        # 完成算子替换
        if call.op == call_tir_op:
            # print("[args len]:", len(call.args), "[args 0]:", call.args[0], type(call.args[0]), len(call.args[1]), call.args[2])
            # print(self.mod_[call.args[0]])
            tir_func: tvm.tir.PrimFunc = self.mod_[call.args[0]]
            func_name = tir_func.attrs["global_symbol"]
            opt_func_name = func_name + "_opt"
            if func_name.endswith("opt"): # opt func
                return call
            if opt_func_name in self.functions: # already opt
                return call
            # print(opt_func_name, func_name)
            if func_name.startswith("conv2d"):
                global_var = self.builder_.add_func(conv2d_opt_function_(tir_func, opt_func_name),
                        opt_func_name)
            elif func_name.startswith("sigmoid"):
                global_var = self.builder_.add_func(sigmoid_opt_function_(tir_func, opt_func_name),
                        opt_func_name)
            else:
                global_var = self.builder_.add_func(tir_func, func_name)
            return relax.call_tir(func=global_var, args=[x for x in call.args[1]], shape=call.args[2], dtype="float32")
        
        return call 

    def transform(self) -> IRModule:
        for global_var, func in self.mod_.functions.items():
            
            if not isinstance(func, relax.Function):
                # 复制全部的tir function
                # self.builder_.add_func(func, func.attrs["global_symbol"])
                continue
            # 处理relax function, 找到call指令并优化函数, 和改变call的对象
            updated_func = self.visit_expr(func)
            updated_func = relax.analysis.remove_all_unused(updated_func)
            self.builder_.update_func(global_var, updated_func)
        return self.builder_.get()

@tvm.ir.transform.module_pass(opt_level=2, name="LeNetModuleRewriter")
class LeNetModuleRewriterPass:
    """The wrapper for the LeNetModuleRewriter pass."""
    def transform_module(self, mod, ctx):
        return LeNetModuleRewriter(mod).transform()

In [21]:
NewLeNetModule_OPT = LeNetModuleRewriterPass()(NewLeNetModule)

In [22]:
assert relax.analysis.well_formed(NewLeNetModule_OPT)

print("total functions number: ", len(NewLeNetModule_OPT.functions))

a_np = np.random.rand(256, 1, 28, 28).astype("float32")
a_nd = tvm.nd.array(a_np)

exec = relax.vm.build(NewLeNetModule_OPT, target=target)
vm = relax.VirtualMachine(exec, tvm.cpu())

f_timer_baseline = vm.time_evaluator("main", tvm.cpu())
print("Time cost of conv2d opt LeNetModule: %f s" % (f_timer_baseline(a_nd).mean))

total functions number:  16
Time cost of conv2d opt LeNetModule: 0.044542 s


开始优化：算子融合

In [23]:
# 完成对main函数中对tir函数调用的修改
@relax.expr_functor.mutator
class LeNetModuleFuseRewriter(relax.PyExprMutator):
    def __init__(self, mod: IRModule) -> None:
        super().__init__()
        self.mod_ = mod
        self.functions = [call.attrs["global_symbol"] for call in mod.functions.values()]
        self.call_tir_op = tvm.ir.Op.get("relax.call_tir")
        self.dense_op = tvm.ir.Op.get("relax.nn.dense")
        self.add_op = tvm.ir.Op.get("relax.add")
        self.counter = {"fused_dense_add": 0}

    def visit_call_(self, call: relax.expr.Call) -> relax.expr.Expr:
        call: relax.expr.Call = self.visit_expr_post_order(call)
        def match_call(node, op):
            if not isinstance(node, relax.Call):
                return False
            return node.op == op
        # 完成算子替换
        if match_call(call, self.call_tir_op):
            # print("[args len]:", len(call.args), "[args 0]:", call.args[0], type(call.args[0]), len(call.args[1]), call.args[2])
            # print(self.mod_[call.args[0]])
            tir_func: tvm.tir.PrimFunc = self.mod_[call.args[0]]
            func_name = tir_func.attrs["global_symbol"]
            opt_func_name = func_name + "_opt"
            if func_name.endswith("opt"): # opt func
                return call
            if opt_func_name in self.functions: # already opt
                return call
            if func_name.startswith("conv2d"):
                global_var = self.builder_.add_func(conv2d_opt_function_(tir_func, opt_func_name),
                        opt_func_name)
            elif func_name.startswith("sigmoid"):
                global_var = self.builder_.add_func(sigmoid_opt_function_(tir_func, opt_func_name),
                        opt_func_name)
            else:
                global_var = self.builder_.add_func(tir_func, func_name)
            return relax.call_tir(func=global_var, args=[x for x in call.args[1]], shape=call.args[2], dtype="float32")
        # 完成算子融合, relax.op
        # pattern match dense => add
        if not match_call(call, self.add_op):
            return call
        # 通过add的第0个参数找到dense
        value = self.lookup_binding(call.args[0])
        if value is None:
            return call
        if not match_call(value, self.dense_op):
            return call
        x = value.args[0]
        w = value.args[1]
        b = call.args[1]
        # 注意，参数绑定的是每个函数, 所需需要为我们的fused创建新的const
        # construct a new fused primitive function
        param_x = relax.Var("x", x.shape_, x._checked_type_)
        param_w = relax.Var("w", w.shape_, w._checked_type_)
        param_b = relax.Var("b", b.shape_, b._checked_type_)

        bb = relax.BlockBuilder()
        fn_name = "fused_dense_add%d" % (self.counter["fused_dense_add"])
        self.counter["fused_dense_add"] += 1
        # [NOTE]注意这里参数绑定的问题
        # [NOTE]注意这里emit_func_output需要指定func的output和input
        fn_output = None
        with bb.function(fn_name, [param_x, param_w, param_b]):
            with bb.dataflow():
                lv0 = bb.emit(relax.op.dense(param_x, param_w))
                lv1 = bb.emit(relax.op.add(lv0, param_b))
                assert fn_output is None
                fn_output = bb.emit_output(lv1)
            bb.emit_func_output(fn_output)

        # Add Primitive attribute to the fused funtions
        fused_fn = bb.get()[fn_name].with_attr("global_symbol", fn_name)
        fused_fn = fused_fn.with_attr("Primitive", 1)
        normalized = self.builder_.normalize(fused_fn)
        global_var = self.builder_.add_func(normalized, fn_name)
        #[NOTE] 
        return relax.Call(global_var, [x, w, b], None, None)

    def transform(self) -> IRModule:
        for global_var, func in self.mod_.functions.items():
            if not isinstance(func, relax.Function):
                continue
            # avoid already fused primitive functions
            if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
                continue
            updated_func = self.visit_expr(func)
            updated_func = relax.analysis.remove_all_unused(updated_func)
            self.builder_.update_func(global_var, updated_func)

        return self.builder_.get()

@tvm.ir.transform.module_pass(opt_level=2, name="LeNetModuleFuseRewriter")
class LeNetModuleFuseRewriterPass:
    """The wrapper for the LeNetModuleFuseRewriter pass."""
    def transform_module(self, mod, ctx):
        return LeNetModuleFuseRewriter(mod).transform()

In [25]:
FuseLeNetModule = LeNetModuleFuseRewriterPass()(LeNetModule)

In [28]:
# 降级转为tir
target = tvm.target.Target("llvm")
with target, tvm.transform.PassContext(trace=Trace(LeNetModule), opt_level=0):
    seq = tvm.transform.Sequential(
    [
        transform.LowerWithRelayOpStrategyPass(target)
    ])
    FuseLeNetModuleTIR = seq(FuseLeNetModule)
# 合并函数
for var, func in FuseLeNetModule.functions.items():
    if not isinstance(func, tvm.tir.function.PrimFunc):
        continue
    if var in FuseLeNetModuleTIR.functions.keys():
        continue
    FuseLeNetModuleTIR[var] = func
# 【******】将合成的relax函数转化为tir函数
FuseLeNetModuleFinal = relax.transform.FuseTIR()(FuseLeNetModuleTIR)
assert relax.analysis.well_formed(FuseLeNetModuleFinal)

In [29]:
a_np = np.random.rand(256, 1, 28, 28).astype("float32")
a_nd = tvm.nd.array(a_np)

exec = relax.vm.build(FuseLeNetModuleFinal, target="llvm")
vm = relax.VirtualMachine(exec, tvm.cpu())

f_timer_baseline = vm.time_evaluator("main", tvm.cpu())
print("Time cost of baseline LeNetModule: %f s" % (f_timer_baseline(a_nd).mean))

Time cost of baseline LeNetModule: 0.044521 s
