In [8]:
from tvm import topi
from tvm import te
from tvm.contrib import tedd
from tvm.auto_scheduler import ComputeDAG

def img2col(input: te.Tensor, kernel_size: tuple, stride: int, padding: tuple) -> te.Tensor:
    """Unfold a single NHWC image into an im2col matrix.

    Each row of the result holds one sliding-window patch, so a convolution
    over ``input`` can be expressed as a matrix product with a flattened
    filter.

    Parameters
    ----------
    input : te.Tensor
        4-D tensor in NHWC layout. Only batch size 1 is supported.
    kernel_size : tuple
        ``(kernel_h, kernel_w)`` of the sliding window.
    stride : int
        Step of the window, applied to both spatial dimensions.
    padding : tuple
        ``(pad_h, pad_w)`` zero padding, applied symmetrically on each side.

    Returns
    -------
    te.Tensor
        2-D tensor of shape ``(conv_h * conv_w, kernel_h * kernel_w * c)``.
        Rows enumerate output positions in row-major order; columns enumerate
        ``(ky, kx, channel)`` with the channel index varying fastest.
    """
    # Local import: the file only imports submodules of tvm at the top level.
    from tvm import tir

    n, h, w, c = input.shape
    kernel_h, kernel_w = kernel_size
    pad_h, pad_w = padding
    assert n == 1, "img2col only supports batch size 1"

    # Number of window positions along each spatial axis (standard
    # convolution output-size formula).
    conv_h = (h + 2 * pad_h - kernel_h) // stride + 1
    conv_w = (w + 2 * pad_w - kernel_w) // stride + 1
    out_h = conv_h * conv_w
    out_w = kernel_h * kernel_w * c

    zero = tir.const(0, input.dtype)

    def _gather(ii, jj):
        # Row index -> top-left corner of the window in output coordinates.
        out_y = ii // conv_w
        out_x = ii % conv_w
        # Column index -> offset inside the window and channel.
        ky = jj // (kernel_w * c)
        kx = (jj // c) % kernel_w
        ch = jj % c
        # Coordinates in the unpadded input; may fall outside the image.
        y = out_y * stride + ky - pad_h
        x = out_x * stride + kx - pad_w
        inside = te.all(y >= 0, y < h, x >= 0, x < w)
        # if_then_else guards the load so padded positions read as zero
        # instead of indexing out of bounds.
        return te.if_then_else(inside, input(0, y, x, ch), zero)

    return te.compute((out_h, out_w), _gather, name="img2col")

# Symbolic inputs for the conv/batch-norm block built further down.
# NOTE: `input` shadows the Python builtin; the name is kept because the rest
# of the script refers to it.
input = te.placeholder(shape=(1, 64, 56, 56), name="input")

# Convolution filters, laid out as (out_channels, in_channels, kh, kw) to match
# topi.nn.conv2d's default NCHW data layout. The in_channels dimension must
# equal the channel count of the tensor the filter is applied to: weights_0
# and weights_shortcut both consume the 64-channel `input`, so their second
# dimension is 64 (it was 1, which made the channel reduction index past the
# declared filter extent).
weights_0 = te.placeholder(shape=(64, 64, 1, 1), name="weights_0")
weights_1 = te.placeholder(shape=(64, 64, 3, 3), name="weights_1")
weights_2 = te.placeholder(shape=(256, 64, 1, 1), name="weights_2")
weights_shortcut = te.placeholder(shape=(256, 64, 1, 1), name="weights_shortcut")

# Per-channel batch-norm parameters; each vector's length equals the output
# channel count of the convolution it follows.
mean_0 = te.placeholder(shape=(64,), name="mean_0")
mean_1 = te.placeholder(shape=(64,), name="mean_1")
mean_2 = te.placeholder(shape=(256,), name="mean_2")
mean_shortcut = te.placeholder(shape=(256,), name="mean_shortcut")

var_0 = te.placeholder(shape=(64,), name="var_0")
var_1 = te.placeholder(shape=(64,), name="var_1")
var_2 = te.placeholder(shape=(256,), name="var_2")
var_shortcut = te.placeholder(shape=(256,), name="var_shortcut")

gamma_0 = te.placeholder(shape=(64,), name="gamma_0")
gamma_1 = te.placeholder(shape=(64,), name="gamma_1")
gamma_2 = te.placeholder(shape=(256,), name="gamma_2")
gamma_shortcut = te.placeholder(shape=(256,), name="gamma_shortcut")

beta_0 = te.placeholder(shape=(64,), name="beta_0")
beta_1 = te.placeholder(shape=(64,), name="beta_1")
beta_2 = te.placeholder(shape=(256,), name="beta_2")
beta_shortcut = te.placeholder(shape=(256,), name="beta_shortcut")

# Main path, stage 0: 1x1 convolution -> batch norm -> ReLU.
conv0 = topi.nn.conv2d(input, weights_0, strides=1, padding=(0, 0), dilation=1)
bn0, _, _ = topi.nn.batch_norm(
    conv0, gamma_0, beta_0, mean_0, var_0
)
relu0 = topi.nn.relu(bn0)

# Main path, stage 1: 3x3 convolution (padding 1) -> batch norm -> ReLU.
conv1 = topi.nn.conv2d(relu0, weights_1, strides=1, padding=(1, 1), dilation=1)
bn1, _, _ = topi.nn.batch_norm(
    conv1, gamma_1, beta_1, mean_1, var_1
)
relu1 = topi.nn.relu(bn1)

# Main path, stage 2: 1x1 convolution -> batch norm. No activation here; the
# ReLU is applied after the two paths are summed.
conv2 = topi.nn.conv2d(relu1, weights_2, strides=1, padding=(0, 0), dilation=1)
bn2, _, _ = topi.nn.batch_norm(
    conv2, gamma_2, beta_2, mean_2, var_2
)

# Shortcut path: 1x1 convolution -> batch norm, applied directly to the block
# input.
conv_shortcut = topi.nn.conv2d(
    input, weights_shortcut, strides=1, padding=(0, 0), dilation=1
)
bn_shortcut, _, _ = topi.nn.batch_norm(
    conv_shortcut, gamma_shortcut, beta_shortcut, mean_shortcut, var_shortcut
)

# Merge: elementwise sum of both paths, then the final ReLU.
out_add = topi.add(bn2, bn_shortcut)
out_relu = topi.nn.relu(out_add)

# Default schedule rooted at the final ReLU of the block.
sch = te.create_schedule(out_relu.op)

'''
print(tvm.lower(sch, [input, weights_0, weights_1, weights_2, weights_shortcut,
                      mean_0, mean_1, mean_2, mean_shortcut,
                      var_0, var_1, var_2, var_shortcut,
                      gamma_0, gamma_1, gamma_2, gamma_shortcut,
                      beta_0, beta_1, beta_2, beta_shortcut, out_relu], simple_mode = True))
'''
#tedd.viz_dataflow_graph(sch,show_svg = True)

# Tensors handed to the auto-scheduler: every placeholder first, grouped by
# role, with the block output last.
dag_tensors = [
    input,
    weights_0, weights_1, weights_2, weights_shortcut,
    mean_0, mean_1, mean_2, mean_shortcut,
    var_0, var_1, var_2, var_shortcut,
    gamma_0, gamma_1, gamma_2, gamma_shortcut,
    beta_0, beta_1, beta_2, beta_shortcut,
    out_relu,
]
dag = ComputeDAG(dag_tensors)

# Kept as the last expression so a notebook cell displays the returned string.
initial_state = dag.get_init_state()
dag.print_python_code_from_state(initial_state)

'pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)\nconv2d_nchw_nn, conv2d_nchw_ff, conv2d_nchw_yy, conv2d_nchw_xx, conv2d_nchw_rc, conv2d_nchw_ry, conv2d_nchw_rx = tuple(conv2d_nchw.op.axis) + tuple(conv2d_nchw.op.reduce_axis)\nT_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)\nT_subtract_ax0, T_subtract_ax1, T_subtract_ax2, T_subtract_ax3 = tuple(T_subtract.op.axis) + tuple(T_subtract.op.reduce_axis)\nT_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)\nT_add_ax0, T_add_ax1, T_add_ax2, T_add_ax3 = tuple(T_add.op.axis) + tuple(T_add.op.reduce_axis)\ncompute_i0, compute_i1, compute_i2, compute_i3 = tuple(compute.op.axis) + tuple(compute.op.reduce_axis)\nT_divide_ax0, T_divide_ax1, T_divide_ax2, T_divide_ax3 = tuple(T_divide.op.axis) + tuple(T_divide.op.reduce_axis)\nT_reshape_ax0, T_reshape_