In [1]:
import onnx
# Path of the pruned ResNet-18 checkpoint exported to ONNX.
# NOTE(review): "GL_16_PR_0.5" presumably encodes pruning granularity 16 and
# pruning ratio 0.5 - confirm against the export script.
onnx_model_path = "sparse_resnet18_best_onnx/resnet18_GL_16_PR_0.5_ckpt_best.onnx"
onnx_model = onnx.load(onnx_model_path)

In [2]:
from tvm import relay
# Static input shape handed to the importer: the graph input named 'data'
# is fixed to a single 3x224x224 tensor.
input_shapes = {'data': (1, 3, 224, 224)}
mod, params = relay.frontend.from_onnx(onnx_model, input_shapes)

In [3]:
import tvm

# Bind the imported parameters into 'main' by name, then wrap the resulting
# function in a fresh IRModule so the later passes run on a module whose
# weights are part of the function itself.
imported_main = mod['main']
const_main = relay.build_module.bind_params_by_name(imported_main, params)
const_mod = tvm.ir.IRModule({'main': const_main})

In [4]:
# Ask ConvertLayout to rewrite every nn.conv2d to NHWC data layout and let
# TVM choose the matching kernel layout ('default').
desired_layouts = {'nn.conv2d': ['NHWC', 'default']}

# Passes applied in order: drop unused functions, convert the layout, then
# fold constants.
layout_passes = [
    relay.transform.RemoveUnusedFunctions(),
    relay.transform.ConvertLayout(desired_layouts),
    relay.transform.FoldConstant(),
]
seq = tvm.transform.Sequential(layout_passes)

# NOTE(review): opt_level=3 is presumably required so that Sequential does not
# skip ConvertLayout - confirm against the installed TVM version.
with tvm.transform.PassContext(opt_level=3):
    const_mod2 = seq(const_mod)

In [15]:
# Show the body of the layout-converted 'main' as Relay text, one list item
# per line, to inspect the result of the pass pipeline.
const_mod2['main'].body.astext().splitlines()

['#[version = "0.0.5"]',
 'free_var %data: Tensor[(1, 3, 224, 224), float32];',
 '%0 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 224, 224, 3), float32] */;',
 '%1 = nn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(7, 7, 3, 64), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], kernel_size=[7, 7], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(1, 112, 112, 64), float32] */;',
 '%2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1, 1, 1, 64), float32] */) /* ty=Tensor[(1, 112, 112, 64), float32] */;',
 '%3 = nn.relu(%2) /* ty=Tensor[(1, 112, 112, 64), float32] */;',
 '%4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1], layout="NHWC") /* ty=Tensor[(1, 56, 56, 64), float32] */;',
 '%5 = nn.conv2d(%4, meta[relay.Constant][2] /* ty=Tensor[(3, 3, 64, 64), float32] */, padding=[1, 1, 1, 1], kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(1, 56, 56, 64), float32] */;',
 '%6 = add(%5, meta[re

In [7]:
# Build the conv2d -> sparse_conv2d rewrite pass.
# Arguments as used here: layout "NHWC", kernel size 3, block size 16 x 1,
# and 0.4 (presumably the sparsity threshold a weight must exceed to be
# converted - TODO confirm against the TVM pass documentation).
# NOTE(review): both calls below go through private TVM entry points
# (`_ffi_api`, `_run_opt_pass`); they may change between TVM releases.
sparsify_pass = relay.transform._ffi_api.Conv2dToSparse2("NHWC", 3, 16, 1, 0.4)

# Apply the pass to the layout-converted main function.
newfunc = relay.data_dep_optimization.utils._run_opt_pass(
    const_mod2['main'],
    sparsify_pass,
)

In [8]:
# Show the body of the function returned by the sparsify pass as Relay text,
# one list item per line, to see which convolutions were rewritten.
newfunc.body.astext().splitlines()

['#[version = "0.0.5"]',
 'free_var %data: Tensor[(1, 3, 224, 224), float32];',
 '%0 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 224, 224, 3), float32] */;',
 '%1 = nn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(7, 7, 3, 64), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], kernel_size=[7, 7], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(1, 112, 112, 64), float32] */;',
 '%2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1, 1, 1, 64), float32] */) /* ty=Tensor[(1, 112, 112, 64), float32] */;',
 '%3 = nn.relu(%2) /* ty=Tensor[(1, 112, 112, 64), float32] */;',
 '%4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1], layout="NHWC") /* ty=Tensor[(1, 56, 56, 64), float32] */;',
 '%5 = nn.sparse_conv2d(%4, meta[relay.Constant][2] /* ty=Tensor[(1152, 16, 1), float32] */, meta[relay.Constant][3] /* ty=Tensor[(1152), int64] */, meta[relay.Constant][4] /* ty=Tensor[(5), int64] */, kernel_size=3) /* ty=Tensor[(1, 56, 

In [13]:
from scipy import sparse
# Collected sparse weights, one scipy BSR matrix per nn.sparse_conv2d call.
# Re-created on every run of this cell so re-execution does not duplicate entries.
tot = []


def fvisit(e):
    """Append the BSR weight of an ``nn.sparse_conv2d`` call to ``tot``.

    Intended as the callback for ``relay.analysis.post_order_visit``.

    Parameters
    ----------
    e : relay.Expr
        The expression node currently being visited. Anything that is not a
        call to ``nn.sparse_conv2d`` is ignored.

    Notes
    -----
    ``e.args[1:]`` is passed to scipy as the ``(data, indices, indptr)``
    triple, and every one of those arguments must be a constant (``.data``
    is read on each).
    """
    if not isinstance(e, relay.Call):
        return
    # A call's callee is not always an operator; callees without a ``name``
    # attribute (e.g. a function or global var) are skipped instead of
    # raising AttributeError.
    if getattr(e.op, "name", None) != 'nn.sparse_conv2d':
        return
    bsr_parts = tuple(arg.data.numpy() for arg in e.args[1:])
    # NOTE(review): no explicit shape is passed, so scipy infers the matrix
    # shape from the index arrays; the column count may be under-reported if
    # the trailing block columns are entirely pruned - verify if shapes matter.
    tot.append(sparse.bsr_matrix(bsr_parts))


relay.analysis.post_order_visit(newfunc, fvisit)

In [14]:
tot

[<64x576 sparse matrix of type '<class 'numpy.float32'>'
 	with 18432 stored elements (blocksize = 16x1) in Block Sparse Row format>,
 <64x576 sparse matrix of type '<class 'numpy.float32'>'
 	with 18432 stored elements (blocksize = 16x1) in Block Sparse Row format>,
 <64x576 sparse matrix of type '<class 'numpy.float32'>'
 	with 18432 stored elements (blocksize = 16x1) in Block Sparse Row format>,
 <64x576 sparse matrix of type '<class 'numpy.float32'>'
 	with 18432 stored elements (blocksize = 16x1) in Block Sparse Row format>,
 <128x1152 sparse matrix of type '<class 'numpy.float32'>'
 	with 73728 stored elements (blocksize = 16x1) in Block Sparse Row format>,
 <128x1152 sparse matrix of type '<class 'numpy.float32'>'
 	with 73728 stored elements (blocksize = 16x1) in Block Sparse Row format>,
 <128x1152 sparse matrix of type '<class 'numpy.float32'>'
 	with 73728 stored elements (blocksize = 16x1) in Block Sparse Row format>,
 <256x2304 sparse matrix of type '<class 'numpy.float32'