In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.score4 = nn.Conv2d(23, 22, 1)
        self.score_fr = nn.Conv2d(22, 21, 1, bias=False)
    
    def forward(self, inp):
        h = inp
        h = self.score4(h)
        h = self.score_fr(h)

        return h
        

model = Test()

In [2]:
model.train()
input_shape = [1, 23, 13, 19]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data)
scripted_model

Test(
  original_name=Test
  (score4): Conv2d(original_name=Conv2d)
  (score_fr): Conv2d(original_name=Conv2d)
)

In [3]:
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

input_name = 'input0'
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
mod

IRModuleNode( {GlobalVar(main): FunctionNode([Var(input0, ty=TensorType([1, 23, 13, 19], float32)), Var(score4.weight, ty=TensorType([22, 23, 1, 1], float32)), Var(score4.bias, ty=TensorType([22], float32)), Var(score_fr.weight, ty=TensorType([21, 22, 1, 1], float32))], TensorType([1, 21, 13, 19], float32), CallNode(Op(nn.conv2d), [CallNode(Op(nn.bias_add), [CallNode(Op(nn.conv2d), [Var(input0, ty=TensorType([1, 23, 13, 19], float32)), Var(score4.weight, ty=TensorType([22, 23, 1, 1], float32))], relay.attrs.Conv2DAttrs(000001F122E72048), [TensorType([1, 23, 13, 19], float32), TensorType([22, 23, 1, 1], float32)]), Var(score4.bias, ty=TensorType([22], float32))], relay.attrs.BiasAddAttrs(000001F122FEF698), [TensorType([1, 22, 13, 19], float32), TensorType([22], float32)]), Var(score_fr.weight, ty=TensorType([21, 22, 1, 1], float32))], relay.attrs.Conv2DAttrs(000001F122E71808), [TensorType([1, 22, 13, 19], float32), TensorType([21, 22, 1, 1], float32)]), [], (nullptr))})

In [4]:
fn = mod['main']
# visualize(fn)

In [5]:
import numpy

def work_on_fn(pass_cls):
    def apply_pass(fn_or_mod):
        if isinstance(fn_or_mod, tvm.IRModule):
            return pass_cls()(fn_or_mod)
        if isinstance(fn_or_mod, tvm.relay.Function):
            return pass_cls()(
                       tvm.IRModule({'main': fn_or_mod}))['main']
        raise NotImplemented("unsupporded type {}".format(type(fn_or_mod)))
    return apply_pass

infer_type = work_on_fn(tvm.relay.transform.InferType)
to_graph_normal_form = work_on_fn(tvm.relay.transform.ToGraphNormalForm)
dead_code_elimination = work_on_fn(tvm.relay.transform.DeadCodeElimination)
eliminate_common_subexpr = work_on_fn(tvm.relay.transform.EliminateCommonSubexpr)

class ShapeConstDedupMutator(tvm.relay.ExprMutator):
    def __init__(self):
        super().__init__()
        self.shape_consts = {}

    def visit_call(self, call):
        if (isinstance(call.op, tvm.ir.Op) 
            and call.op.name in {"reshape", "broadcast_to", "collapse_sum_to"}
            and isinstance(call.args[1], tvm.relay.Constant)):
            # assert list(call.attrs.newshape) == list(call.args[1].data.asnumpy())
            new_fn = self.visit(call.op)
            new_args = [self.visit(arg) for arg in call.args]
            const = new_args[1]
            assert const.data.dtype.startswith('int') and len(const.data.shape)==1
            key = tuple(const.data.asnumpy())
            if key in self.shape_consts:
                new_args[1] = self.shape_consts[key]
            else:
                self.shape_consts[key] = new_args[1]
            return tvm.relay.Call(new_fn, new_args, call.attrs)
        return super().visit_call(call)


class TransposeDedupMutator(tvm.relay.ExprMutator):
    def visit_call(self, call):
        if (isinstance(call.op, tvm.ir.Op) and call.op.name == "transpose"
            and isinstance(call.args[0], tvm.relay.Call) 
            and isinstance(call.args[0].op, tvm.ir.Op) and call.args[0].op.name == "transpose"):
            axes = [call.args[0].attrs.axes[int(i)] for i in call.attrs.axes]
            new_inp = self.visit(call.args[0].args[0])
            if axes == list(range(len(axes))): # neutral permutation, should really do this separately...
                return new_inp
            return tvm.relay.transpose(new_inp, axes)
        return super().visit_call(call)

#@tvm.relay.transform.function_pass(opt_level=1)
#def TransposeDedup(fn, mod, ctx):
#    return TransposeDedupMutator().visit(fn)

class ZeroZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        self.zeros = tvm.relay.dataflow_pattern.is_op("zeros")(tvm.relay.dataflow_pattern.wildcard())
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (self.zeros + self.other_tensor) | (self.other_tensor + self.zeros)
        self.require_type = True

    def callback(self, pre, post, node_map):
        rt = node_map[self.pattern][0]
        ot = node_map[self.other_tensor][0]
        if (ot._checked_type_ == rt._checked_type_):
            return ot
        else:
            return tvm.relay.broadcast_to(ot, list(rt._checked_type_.shape))

class ZeroZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        self.ones = tvm.relay.dataflow_pattern.is_op("zeros")(tvm.relay.dataflow_pattern.wildcard()) | tvm.relay.dataflow_pattern.is_constant()
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (self.ones + self.other_tensor) | (self.other_tensor + self.ones)
        self.require_type = True

    def callback(self, pre, post, node_map):
        rt = node_map[self.pattern][0]
        ones = node_map[self.ones][0]
        ot = node_map[self.other_tensor][0]
        if isinstance(ones, tvm.relay.Constant):
            val = ones.data.asnumpy()
            if not ((val == 0) if numpy.isscalar(val) else (val == 0).all()):
                return rt
        # I don't know why I don't reliably get checked types here...
        if (((rt._checked_type_ is not None) and (ot._checked_type_ == rt._checked_type_))
            or (rt.type_args[0] == rt.type_args[1])):
            return ot
        elif (rt._checked_type_ is not None):
            return tvm.relay.broadcast_to(ot, list(rt._checked_type_.shape))
        return rt

class OneZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        self.ones = tvm.relay.dataflow_pattern.is_op("ones")(tvm.relay.dataflow_pattern.wildcard()) | tvm.relay.dataflow_pattern.is_constant()
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (self.ones * self.other_tensor) | (self.other_tensor * self.ones)
        self.require_type = True

    def callback(self, pre, post, node_map):
        global val
        rt = node_map[self.pattern][0]
        ones = node_map[self.ones][0]
        ot = node_map[self.other_tensor][0]
        if isinstance(ones, tvm.relay.Constant):
            val = ones.data.asnumpy()
            if not ((val == 1) if numpy.isscalar(val) else (val == 1).all()):
                return rt
        if (((rt._checked_type_ is not None) and (ot._checked_type_ == rt._checked_type_))
            or (rt.type_args[0] == rt.type_args[1])):
            return ot
        if (rt._checked_type_ is not None):
            return tvm.relay.broadcast_to(ot, list(rt._checked_type_.shape))
        return rt


class LikeZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        self.translations_with_dt = {'zeros_like': tvm.relay.zeros,
                                     'ones_like': tvm.relay.ones}
        self.data_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = ((tvm.relay.dataflow_pattern.is_op("zeros_like")
                        | tvm.relay.dataflow_pattern.is_op("ones_like")
                        )(self.data_tensor)
                        ) | ((
                        tvm.relay.dataflow_pattern.is_op("collapse_sum_like")
                        | tvm.relay.dataflow_pattern.is_op("reshape_like")
                        | tvm.relay.dataflow_pattern.is_op("broadcast_to_like")
                       )(self.data_tensor, self.pattern_tensor))
        self.require_type = True

    def callback(self, pre, post, node_map):
        data = node_map[self.data_tensor][0]
        res = node_map[self.pattern][0]
        if res.op.name in self.translations_with_dt:
            ret = self.translations_with_dt[res.op.name](list(res.type_args[0].shape),
                                                              res.type_args[0].dtype) # which dtype?
            return ret
        if (res.type_args[0] is not None and res.type_args[0] == res.type_args[1]):
            return data
        if res.op.name == 'broadcast_to_like':
            return tvm.relay.broadcast_to(data, list(res.type_args[1].shape))
        if res.op.name == 'reshape_like':
            return tvm.relay.reshape(data, list(res.type_args[1].shape))
        if res.op.name == 'collapse_sum_like':
            return tvm.relay.collapse_sum_to(data, list(res.type_args[1].shape))
        return res

In [6]:
fn = TransposeDedupMutator().visit(fn)
fn = infer_type(fn)

output_type = fn.body.checked_type

if isinstance(output_type, tvm.relay.TensorType):
    gr_out = tvm.relay.var("gr:out", output_type)
    fn_for_gr = tvm.relay.Function(list(fn.params) + [gr_out], tvm.relay.sum(fn.body * gr_out))
else:
    # we can try to handle tuples of tensors, but our nesting patience ends there
    assert (isinstance(output_type, tvm.relay.TupleType) and
            all([isinstance(f, tvm.relay.TensorType) for f in output_type.fields]))
    gr_outs = [tvm.relay.var(f"gr:out:{i}", t) for i, t in enumerate(output_type.fields)]
    prods_with_gr_out = [tvm.relay.sum(tvm.relay.TupleGetItem(fn.body, i) * go_i)
                         for i, go_i in enumerate(gr_outs)]
    s = prods_with_gr_out[0]
    for p in prods_with_gr_out[1:]:
        s = s + p
    fn_for_gr = tvm.relay.Function(list(fn.params) + gr_outs, s)
fn_for_gr = infer_type(fn_for_gr)
# visualize(fn_for_gr)

In [7]:
grfn = tvm.relay.transform.gradient(fn_for_gr, mode='first_order')
grfn = to_graph_normal_form(grfn)

# Now we have (sum(orig_out * grad_out), (grad_inp_1, ..., grad_inp_n, grad_grad_out, gr_dropout ...))
# but we only want orig_out and grad_inp_1, ..., grad_inp_n
def is_aux_input(p):
    return p.name_hint.startswith('dropout:') or p.name_hint.startswith('gr:out:')

# the gr_out and dropout parameters will have gradients computed, but we do not want that
grads_to_keep = tvm.relay.Tuple([g for p, g in zip(grfn.params, grfn.body.fields[1].fields)
                                   if not is_aux_input(p)])

assert grfn.body.fields[0].op.name == 'sum'
assert grfn.body.fields[0].args[0].op.name == 'multiply'
if isinstance(output_type, tvm.relay.TensorType):
    orig_out = grfn.body.fields[0].args[0].args[0]
else:
    assert isinstance(output_type, tvm.relay.TupleType)
    orig_out = grfn.body.fields[0].args[0].args[0].tuple_value
out_and_grad = tvm.relay.Tuple([orig_out, grads_to_keep])
out_and_grad_fn = tvm.relay.Function(grfn.params, out_and_grad)
out_and_grad_fn = infer_type(out_and_grad_fn)
out_and_grad_fn = dead_code_elimination(out_and_grad_fn)
out_and_grad_fn = eliminate_common_subexpr(out_and_grad_fn)
out_and_grad_fn = infer_type(out_and_grad_fn)
# out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(LikeZapp(), out_and_grad_fn)
# out_and_grad_fn = infer_type(out_and_grad_fn)
# out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(ZeroZapp(), out_and_grad_fn)
# out_and_grad_fn = infer_type(out_and_grad_fn)
# out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(OneZapp(), out_and_grad_fn)
# out_and_grad_fn = infer_type(out_and_grad_fn)
# out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(OneZapp(), out_and_grad_fn)
# out_and_grad_fn = infer_type(out_and_grad_fn)
# out_and_grad_fn = dead_code_elimination(out_and_grad_fn)
# out_and_grad_fn = eliminate_common_subexpr(out_and_grad_fn)

# tvm.relay.analysis.all_type_vars(grfn)

In [8]:
orig_out = out_and_grad_fn.body.fields[0]
grad_ins = out_and_grad_fn.body.fields[1]

color_dict = {}
def color(n, c):
    if n in color_dict:
        return
    color_dict[n] = c
    for a in getattr(n, 'args', []):
        color(a, c)
    for a in getattr(n, 'fields', []):
        color(a, c)
    for nam in ('body', 'tuple_value'):
        b = getattr(n, nam, None)
        if b is not None:
            color(b, c)

color(orig_out, {'color': 'red'})
seen = set()
def color_crossings(n, c):
    if n in seen:
        return
    seen.add(n)
    if n in color_dict:
        color_dict[n] = c
        return
    for a in getattr(n, 'args', []):
        color_crossings(a, c)
    for a in getattr(n, 'fields', []):
        color_crossings(a, c)
    for nam in ('body', 'tuple_value'):
        b = getattr(n, nam, None)
        if b is not None:
            color_crossings(b, c)

color_crossings(grad_ins, {'color': 'blue'})
# visualize(out_and_grad_fn, node_attr_dict=color_dict)
color_dict

{CallNode(Op(nn.conv2d), [CallNode(Op(nn.bias_add), [CallNode(Op(nn.conv2d), [Var(input0, ty=TensorType([1, 23, 13, 19], float32)), Var(score4.weight, ty=TensorType([22, 23, 1, 1], float32))], relay.attrs.Conv2DAttrs(000001F122E72048), [TensorType([1, 23, 13, 19], float32), TensorType([22, 23, 1, 1], float32)]), Var(score4.bias, ty=TensorType([22], float32))], relay.attrs.BiasAddAttrs(000001F122FEF698), [TensorType([1, 22, 13, 19], float32), TensorType([22], float32)]), Var(score_fr.weight, ty=TensorType([21, 22, 1, 1], float32))], relay.attrs.Conv2DAttrs(000001F122E71808), [TensorType([1, 22, 13, 19], float32), TensorType([21, 22, 1, 1], float32)]): {'color': 'blue'},
 CallNode(Op(nn.bias_add), [CallNode(Op(nn.conv2d), [Var(input0, ty=TensorType([1, 23, 13, 19], float32)), Var(score4.weight, ty=TensorType([22, 23, 1, 1], float32))], relay.attrs.Conv2DAttrs(000001F122E72048), [TensorType([1, 23, 13, 19], float32), TensorType([22, 23, 1, 1], float32)]), Var(score4.bias, ty=TensorType([2

In [9]:
nodes_to_capture = [n for n, v in color_dict.items() 
                    if v['color'] == 'blue' and not isinstance(n, (tvm.relay.Constant, tvm.relay.Var))]
capture_tup = tvm.relay.Tuple(nodes_to_capture)
nodes_to_capture_idx = {n:i for i, n in enumerate(nodes_to_capture)}
capture_vars = [tvm.relay.var(f"input:captures:{i}", type_annotation=nodes_to_capture[i].checked_type)
                for i, n in enumerate(nodes_to_capture)]

grads_in = out_and_grad_fn.body.fields[1]

needed_vars = set()
class GradientOnlyMutator(tvm.relay.ExprMutator):
    def __init__(self):
        super().__init__()

    def visit_var(self, var):
        needed_vars.add(var)
        return var

    def visit(self, expr):
        if expr in nodes_to_capture_idx:
            return capture_vars[nodes_to_capture_idx[expr]]
        return super().visit(expr)

grads_in_only = GradientOnlyMutator().visit(grads_in)
gr_only_fn = tvm.relay.Function(sorted(needed_vars) + capture_vars, grads_in_only)
gr_only_fn = infer_type(gr_only_fn)

# TODO: check against output of original
fn_for_gr_input_names = {p.name_hint for p in fn_for_gr.params}
needed_var_names = {v.name_hint for v in needed_vars}

assert needed_var_names <= fn_for_gr_input_names
inputs_to_keep = [n for n in needed_vars if not is_aux_input(n)]

capture_tup = tvm.relay.Tuple([n for n in nodes_to_capture])
fw_and_cap_params = [p for p in out_and_grad_fn.params if not p.name_hint.startswith('gr:out:')]

fw_and_cap_fn = tvm.relay.Function(fw_and_cap_params,
                                   tvm.relay.Tuple((out_and_grad_fn.body.fields[0],) + (capture_tup,)))
# visualize(gr_only_fn)

In [10]:
fw_and_cap_fn_flattened = tvm.relay.Function(fw_and_cap_fn.params, tvm.relay.Tuple([fw_and_cap_fn.body.fields[0]]
                                                + list(fw_and_cap_fn.body.fields[1].fields)))

target = tvm.target.cuda()
target_host = 'llvm'
ctx = tvm.context(str(target), 0)

fw_and_cap_mod = tvm.IRModule({"main": fw_and_cap_fn_flattened})
with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = tvm.relay.build(fw_and_cap_mod,
                                         target=target,
                                         target_host=target_host,
                                         params={})
fw_and_cap_compiled_module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
fw_and_cap_compiled_module.set_input(**params)

gr_only_mod = tvm.IRModule({"main": gr_only_fn})
with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = tvm.relay.build(gr_only_mod,
                                     target=target,
                                     target_host=target_host,
                                     params={})
gr_only_compiled_module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
gr_only_compiled_module.set_input(**params) # we do have funny const tensors from TVM :/

Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32, workload=('conv2d_nchw.cuda', ('TENSOR', (1, 23, 13, 19), 'float32'), ('TENSOR', (22, 23, 1, 1), 'float32'), (1, 1), (0, 0, 0, 0), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32, workload=('conv2d_nchw.cuda', ('TENSOR', (1, 22, 13, 19), 'float32'), ('TENSOR', (21, 22, 1, 1), 'float32'), (1, 1), (0, 0, 0, 0), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
  del sys.path[0]
Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32, workload=('group_conv2d_nchw.cuda', ('TENSOR', (1, 23, 13, 19), 'float32'), ('TENSOR', (506, 1, 13, 19), 'float32'), (1, 1), (0, 0, 0, 0), (1, 1), 23, 'float32'). A fallback configuration is us

TVMError: Traceback (most recent call last):
  File "F:\tvm_0.7_vs2019\src\ir\error.cc", line 132
TVMError: 
Error(s) have occurred. The program has been annotated with them:

In `main`: 
#[version = "0.0.5"]
fn (%score4.bias: Tensor[(22), float32], %score_fr.weight: Tensor[(21, 22, 1, 1), float32], %score4.weight: Tensor[(22, 23, 1, 1), float32], %input0: Tensor[(1, 23, 13, 19), float32], %gr:out: Tensor[(1, 21, 13, 19), float32], %input:captures:0: Tensor[(1, 21, 13, 19), float32], %input:captures:1: Tensor[(1, 22, 13, 19), float32], %input:captures:2: Tensor[(1, 22, 13, 19), float32]) -> (Tensor[(1, 23, 13, 19), float32], Tensor[(22, 23, 1, 1), float32], Tensor[(22), float32], Tensor[(21, 22, 1, 1), float32], Tensor[(1, 21, 13, 19), float32]) {
  %0 = zeros_like(%input0);
  %1 = zeros_like(%input:captures:2);
  %2 = zeros_like(%input:captures:1);
  %3 = zeros_like(%input:captures:0);
  %4 = multiply(%input:captures:0, %gr:out);
  %5 = zeros_like(%4);
  %6 = sum(%4);
  %7 = ones_like(%6);
  %8 = expand_dims(%7, axis=0);
  %9 = expand_dims(%8, axis=1);
  %10 = expand_dims(%9, axis=2);
  %11 = expand_dims(%10, axis=3);
  %12 = broadcast_to_like(%11, %4);
  %13 = add(%5, %12);
  %14 = multiply(%13, %gr:out);
  %15 = collapse_sum_like(%14, %input:captures:0);
  %16 = add(%3, %15);
  %17 = nn.conv2d_transpose(%16, %score_fr.weight, padding=[0, 0, 0, 0]);
  %18 = add(%2, %17);
  %19 = collapse_sum_like(%18, %input:captures:2);
  %20 = add(%1, %19);
  %21 = nn.conv2d_transpose(%20, %score4.weight, padding=[0, 0, 0, 0]);
  %22 = add(%0, %21);
  %23 = zeros_like(%score4.weight);
  %24 = reshape(%input0, newshape=[1, -1, 0, 0]);
  %25 = tile(%20, reps=[1, 23, 1, 1]);
  %26 = reshape(%25, newshape=[-1, 1, 0, 0]);
  %27 = nn.conv2d(%24, %26, padding=[0, 0, 0, 0], groups=23);
  %28 = reshape(%27, newshape=[1, 23, 22, 1, 1]);
  %29 = sum(%28, axis=[0]);
  %30 = transpose(%29, axes=[1, 0, 2, 3]);
  %31 = add(%23, %30);
  %32 = zeros_like(%score4.bias);
  %33 = expand_dims(%32, axis=0, num_newaxis=3);
  %34 = layout_transform(%33, src_layout="CHWN", dst_layout="NCHW");
  %35 = sum(%18, axis=[0, 2, 3], exclude=True);
  %36 = add(%34, %35) tensor type `Tensor[(22, 1, 13, 19), float32]` has 4 dimensions, while `Tensor[(22), float32]` has 1 dimensions; unable to unify: `Tensor[(22, 1, 13, 19), float32]` and `Tensor[(22), float32]`; ;
  %37 = zeros_like(%score_fr.weight);
  %38 = reshape(%input:captures:1, newshape=[1, -1, 0, 0]);
  %39 = tile(%16, reps=[1, 22, 1, 1]);
  %40 = reshape(%39, newshape=[-1, 1, 0, 0]);
  %41 = nn.conv2d(%38, %40, padding=[0, 0, 0, 0], groups=22);
  %42 = reshape(%41, newshape=[1, 22, 21, 1, 1]);
  %43 = sum(%42, axis=[0]);
  %44 = transpose(%43, axes=[1, 0, 2, 3]);
  %45 = add(%37, %44);
  %46 = zeros_like(%gr:out);
  %47 = multiply(%13, %input:captures:0);
  %48 = collapse_sum_like(%47, %gr:out);
  %49 = add(%46, %48);
  (%22, %31, %36, %45, %49)
}

In [None]:
from torch.utils import dlpack

def tensor_to_tvm(t):
    return tvm.nd.from_dlpack(dlpack.to_dlpack(t))
def tensor_from_tvm(a):
    return (dlpack.from_dlpack(a.to_dlpack()))

model_params_tvm = {k: tensor_to_tvm(v) for k, v in model.state_dict().items()}

model.cuda()
inp_c = input_data.cuda()
inp_tvm = tensor_to_tvm(inp_c)

torch.manual_seed(12345)

fw_and_cap_compiled_module.set_input(input_name, inp_tvm)
fw_and_cap_compiled_module.set_input(**model_params_tvm)
fw_and_cap_compiled_module.run()
# inp_c

In [None]:
torch.manual_seed(12345)
model.train()
numpy.abs(fw_and_cap_compiled_module.get_output(0).asnumpy()-model(inp_c).detach().cpu().numpy()).max()

In [None]:
gr_out_c = torch.randn([1, 21, 28, 40]).cuda()

num_captures = len(capture_vars)
num_regular_outputs = len(fw_and_cap_fn_flattened.body.fields) - num_captures
captured_values = {v.name_hint: fw_and_cap_compiled_module.get_output(num_regular_outputs + i) for i, v in enumerate(capture_vars)}

gr_only_compiled_module.set_input(**model_params_tvm)
gr_only_compiled_module.set_input(**captured_values)
gr_only_compiled_module.set_input('gr:out', tensor_to_tvm(gr_out_c))
gr_only_compiled_module.set_input(input_name, inp_tvm)
gr_only_compiled_module.run()

torch.manual_seed(12345)
model.train()
inp_c_rq = inp_c.requires_grad_()
for p in model.parameters():
    p.requires_grad_()
res = model(inp_c_rq)
grads_pt = torch.autograd.grad(res, [inp_c_rq] + list(model.parameters()), gr_out_c, allow_unused=True, retain_graph=True)
grads_pt

In [None]:
for i, g_pt in enumerate(grads_pt):
    print(numpy.abs(gr_only_compiled_module.get_output(i).asnumpy() - g_pt.cpu().numpy()).max())

In [None]:
w = gr_only_compiled_module.get_output(1)
w = tensor_from_tvm(w)
w