Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edit onnx parser to infer values in post order #5755

Merged
merged 2 commits into from
Jun 12, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 116 additions & 3 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,29 @@
from .. import function as _function
from .. import op as _op
from .. import vision as _vision

from ..function import Function
from ..expr import Call, Let
from ..expr import If, Tuple, TupleGetItem
from ..expr import RefCreate, RefRead, RefWrite
from ..expr_functor import ExprFunctor
from ..adt import Match, Clause

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import infer_type, infer_value, infer_value_simulated, get_name
from .common import infer_type, get_name
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated

__all__ = ['from_onnx']

g = None

def infer_value(input_val, params, mod=None):
return g.infer_value(input_val, params, mod)

def infer_value_simulated(input_val, params):
return g.infer_value_simulated(input_val, params)

class onnx_input():
""" Dual purpose list or dictionary access object."""
Expand Down Expand Up @@ -1879,8 +1896,7 @@ def _get_convert_map(opset):
'NonZero': NonZero.get_converter(opset),
}


class GraphProto(object):
class GraphProto(ExprFunctor):
"""A helper class for handling Relay expression copying from pb2.GraphProto.
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto

Expand All @@ -1902,6 +1918,101 @@ def __init__(self, shape, dtype):
self._shape = shape if shape else {}
self._dtype = dtype

#For infering Values
self._tmp_params = {}
self._infer_simulated = True
self._mod = None
super(GraphProto, self).__init__()

def infer_value(self, input_val, params, mod=None):
self._tmp_params = params
self._infer_simulated = False
self._mod = mod
return self.visit(input_val).data
#return _infer_value(input_val, params, mod)

def infer_value_simulated(self, input_val, params):
self._tmp_params = params
self._infer_simulated = True
return self.visit(input_val).data
#return _infer_value_simulated(input_val, params)

def infer(self, expr):
if self._infer_simulated:
out = _infer_value_simulated(expr, self._tmp_params)
else:
out = _infer_value(expr, self._tmp_params)
return _expr.const(out.asnumpy())

def visit_function(self, fn):
new_params = [self.visit(x) for x in fn.params]
new_body = self.visit(fn.body)
return self.infer(Function(
list(new_params),
new_body,
fn.ret_type,
fn.type_params,
fn.attrs))

def visit_let(self, let):
newvar = self.visit(let.var)
newval = self.visit(let.value)
newbody = self.visit(let.body)
return self.infer(Let(newvar, newval, newbody))

def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return self.infer(Call(new_fn, new_args, call.attrs))

def visit_var(self, var):
return self.infer(var)

def visit_global_id(self, global_var):
return self.infer(global_var)

def visit_if(self, ite):
return self.infer(If(
self.visit(ite.cond),
self.visit(ite.true_branch),
self.visit(ite.false_branch)))

def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])

def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
if not tuple_value.same_as(op.tuple_value):
return self.infer(TupleGetItem(tuple_value, op.index))
return self.infer(op)

def visit_global_var(self, gvar):
return self.infer(gvar)

def visit_op(self, op):
return op

def visit_constant(self, const):
return const

def visit_constructor(self, con):
return con

def visit_match(self, m):
return self.infer(Match(
self.visit(m.data),
[Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
complete=m.complete))

def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))

def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))

def visit_ref_read(self, r):
return RefRead(self.visit(r.ref))

def from_onnx(self, graph, opset):
"""Construct Relay expression from ONNX graph.

Expand Down Expand Up @@ -2160,6 +2271,7 @@ def from_onnx(model,
warnings.warn(str(e))
except ImportError:
pass
global g
g = GraphProto(shape, dtype)
graph = model.graph
if opset is None:
Expand All @@ -2168,4 +2280,5 @@ def from_onnx(model,
except AttributeError:
opset = 1
mod, params = g.from_onnx(graph, opset)
g = None
return mod, params