
[Frontend][ONNX] LSTM Support #4825

Merged: 6 commits merged on Feb 7, 2020
223 changes: 217 additions & 6 deletions python/tvm/relay/frontend/onnx.py
@@ -32,6 +32,55 @@
__all__ = ['from_onnx']


Review comment (Contributor): I really like this design - very sleek.

class onnx_input():

""" Dual purpose list or dictionary access object."""

def __init__(self):
self.input_keys = []
self.input_dict = {}

def __getitem__(self, item):
if isinstance(item, int):
return self.input_dict[self.input_keys[item]]
if isinstance(item, str):
if item not in self.input_keys:
return None
return self.input_dict[item]
if isinstance(item, slice):
keys = self.input_keys[item]
return [self.input_dict[key] for key in keys]

raise ValueError("Only integer, string, and slice accesses allowed.")

def __setitem__(self, item, value):
if isinstance(item, int):
self.input_dict[self.input_keys[item]] = value
elif isinstance(item, str):
if item not in self.input_dict:
self.input_keys.append(item)
self.input_dict[item] = value
else:
raise ValueError("Only integer and string indexed writes allowed.")

def keys(self):
return self.input_keys

def __len__(self):
return len(self.input_keys)

def __iter__(self):
self.n = 0
return self

def __next__(self):
if self.n < len(self.input_keys):
output = self.input_dict[self.input_keys[self.n]]
self.n += 1
return output

raise StopIteration
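
As a side note, a minimal usage sketch (not part of this diff) of the dual access pattern this class provides; the values below are made-up placeholders:

# Illustrative only: onnx_input acts like a dict for named access and like a
# list for positional or slice access, and returns None for absent optional inputs.
inputs = onnx_input()
inputs['X'] = 'data'
inputs['W'] = 'weight'
assert inputs[0] == 'data'                  # positional access
assert inputs['W'] == 'weight'              # named access
assert inputs['B'] is None                  # missing optional input
assert inputs[0:2] == ['data', 'weight']    # slice access
assert len(inputs) == 2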


def get_numpy(tensor_proto):
"""Grab data in TensorProto and convert to numpy array."""
try:
@@ -664,13 +713,24 @@ def _impl_v1(cls, inputs, attr, params):
return inputs[len(inputs) - 1]


class Affine(OnnxOpConverter):
""" Operator converter for Affine transformation.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = _expr.const(attr.get('alpha', 1.0))
beta = _expr.const(attr.get('beta', 0.0))
return (alpha * inputs[0]) + beta


class ThresholdedRelu(OnnxOpConverter):
""" Operator converter for ThresholdedRelu.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
-        alpha = float(attr.get('alpha', 0.0))
+        alpha = float(attr.get('alpha', 1.0))
alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
return inputs[0] * mask
@@ -893,7 +953,7 @@ class Maximum(OnnxOpConverter):
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs")
_max = inputs[0]
for i in range(1, len(inputs)):
@@ -905,7 +965,7 @@ class Minimum(OnnxOpConverter):
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs")
_min = inputs[0]
for i in range(1, len(inputs)):
@@ -917,7 +977,7 @@ class Mean(OnnxOpConverter):
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs")
# avoid overflow
concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
@@ -1190,6 +1250,151 @@ def expand_shape(in_shape, shape):
return _op.broadcast_to(inputs[0], shape=tuple(shape))


class LSTM(OnnxOpConverter):
""" Operator converter for LSTM.
"""

@classmethod
def _activation_helper(cls, activation, alpha, beta):
convert_map = _get_convert_map(1)
attrs = {}
if alpha is not None:
attrs['alpha'] = alpha
if beta is not None:
attrs['beta'] = beta
return lambda x: convert_map[activation.decode("utf-8")]([x], attrs, {})

@classmethod
def _activation_needs_alpha(cls, activation):
needs_alpha = [
"Affine",
"LeakyRelu",
"ThresholdedRelu",
"ScaledTanh",
"HardSigmoid",
"Elu",
]
return activation.decode("utf-8") in needs_alpha

@classmethod
def _activation_needs_beta(cls, activation):
needs_beta = [
"Affine",
"ScaledTanh",
"HardSigmoid",
]
return activation.decode("utf-8") in needs_beta

@classmethod
def _impl_v7(cls, inputs, attr, params):
# Unpack inputs, note that if optional and not provided then value will be None.
X = inputs[0]
W = inputs[1]
Review thread on the weight handling:

@soiferj (Contributor), Feb 6, 2020:
Is there any case when the weights won't be constant? If they're constant, we can remove some operations from the graph and compute them here (like squeeze). By constant, I mean we can call infer_value on it.

Contributor Author:
I think in almost all cases it'd be safe to assume the weights are constant. However, the fold constant pass in Relay will eliminate all operations on the weights anyway. Since treating the weights as non-constant is slightly more flexible, I prefer it.

@soiferj (Contributor):
Sounds good. Thanks a lot for the updates.
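
For reference, a small sketch (not part of this change) of the point made above: when the LSTM weights are constant initializers, binding the imported params and running Relay's constant-folding pass evaluates the weight-shaping ops (the squeeze and split on W, R, and B) at compile time. Here `mod` and `params` are assumed to come from relay.frontend.from_onnx.

# Illustrative only: fold away constant weight transformations after import.
from tvm import relay

mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = relay.transform.FoldConstant()(mod)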

R = inputs[2]
B = inputs['B']
# Sequence length currently unused as it can be inferred from shapes.
#sequence_lens = inputs['sequence_lens']
h_0 = inputs['initial_h']
c_0 = inputs['initial_c']
P = inputs['P']

num_directions = infer_shape(W)[0]
W_dtype = infer_type(W).type_annotation.dtype

if num_directions != 1:
raise NotImplementedError("Bidirectional LSTMs not yet supported.")
# Remove num_directions axis from weights.
W = _op.squeeze(W, axis=[0])
R = _op.squeeze(R, axis=[0])
if B is not None:
B = _op.squeeze(B, axis=[0])

X_shape = infer_shape(X)
hidden_size = infer_shape(R)[-1]
batch_size = X_shape[1]

# Initialize state if not provided.
# Otherwise remove bidirectional axis.
if h_0 is None:
h_0 = _op.zeros((batch_size, hidden_size), W_dtype)
else:
h_0 = _op.squeeze(h_0, axis=[0])
if c_0 is None:
c_0 = _op.zeros((batch_size, hidden_size), W_dtype)
else:
c_0 = _op.squeeze(c_0, axis=[0])

if P is not None:
P = _op.squeeze(P, axis=[0])
p_i, p_o, p_f = _op.split(P, 3)
H_t = h_0
C_t = c_0
h_list = []

if 'activations' in attr:
activations = attr['activations']
if len(activations) != 3:
raise NotImplementedError("LSTM assumes 3 activation functions are provided")
alpha_loc = 0
alphas = attr.get('activation_alpha', [])
if isinstance(alphas, float):
alphas = [alphas]
beta_loc = 0
betas = attr.get('activation_beta', [])
if isinstance(betas, float):
betas = [betas]
acts = []
for i in range(3):
alpha = None
beta = None
activation = activations[i]
if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
alpha = alphas[alpha_loc]
alpha_loc += 1
if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
beta = betas[beta_loc]
beta_loc += 1
acts.append(cls._activation_helper(activation, alpha, beta))
f_act, g_act, h_act = acts
else:
f_act = _op.sigmoid
g_act = _op.tanh
h_act = _op.tanh

X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
for step in X_steps:
step = _op.squeeze(step, axis=[0])
gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R)
if B is not None:
WB, RB = _op.split(B, 2)
gates += WB + RB
i, o, f, c = _op.split(gates, 4, axis=-1)
if P is not None:
i = f_act(i + p_i * C_t)
f = f_act(f + p_f * C_t)

else:
i = f_act(i)
f = f_act(f)
c = g_act(c)
C = f * C_t + i * c
if P is not None:
o = f_act(o + p_o * C)
else:
o = f_act(o)
H = o * h_act(C)
H_t = H
C_t = C
h_list.append(_op.expand_dims(H, axis=0))
# Concatenate outputs and add back in direction axis.
concatenated = _op.concatenate(h_list, 0)
output = _op.expand_dims(concatenated, axis=1)
H_t = _op.expand_dims(H_t, axis=0)
C_t = _op.expand_dims(C_t, axis=0)

return _expr.TupleWrapper(_expr.Tuple((output, H_t, C_t)), 3)
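
For completeness, a minimal end-to-end sketch (not part of this diff, loosely mirroring the unit tests) that builds a single-node ONNX LSTM graph and imports it through relay.frontend.from_onnx; all names and shapes below are illustrative assumptions:

# Illustrative only: a forward (num_directions == 1) LSTM with no bias,
# peepholes, or initial state, imported into Relay via the converter above.
from onnx import helper, TensorProto
from tvm import relay

seq_len, batch, input_size, hidden_size = 4, 1, 8, 16

lstm_node = helper.make_node(
    'LSTM', inputs=['X', 'W', 'R'], outputs=['Y', 'Y_h', 'Y_c'],
    hidden_size=hidden_size)

graph = helper.make_graph(
    [lstm_node], 'lstm_example',
    inputs=[
        helper.make_tensor_value_info('X', TensorProto.FLOAT,
                                      [seq_len, batch, input_size]),
        helper.make_tensor_value_info('W', TensorProto.FLOAT,
                                      [1, 4 * hidden_size, input_size]),
        helper.make_tensor_value_info('R', TensorProto.FLOAT,
                                      [1, 4 * hidden_size, hidden_size]),
    ],
    outputs=[
        helper.make_tensor_value_info('Y', TensorProto.FLOAT,
                                      [seq_len, 1, batch, hidden_size]),
        helper.make_tensor_value_info('Y_h', TensorProto.FLOAT,
                                      [1, batch, hidden_size]),
        helper.make_tensor_value_info('Y_c', TensorProto.FLOAT,
                                      [1, batch, hidden_size]),
    ])
model = helper.make_model(graph, producer_name='lstm_example')

# The converter returns a 3-tuple (Y, Y_h, Y_c), so the resulting Relay function
# yields the full hidden-state sequence plus the final h and c states.
shape_dict = {'X': (seq_len, batch, input_size),
              'W': (1, 4 * hidden_size, input_size),
              'R': (1, 4 * hidden_size, hidden_size)}
mod, params = relay.frontend.from_onnx(model, shape_dict)
print(mod)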


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -1203,7 +1408,7 @@ def _get_convert_map(opset):
return {
# defs/experimental
'Identity': Renamer('copy'),
-        # 'Affine'
+        'Affine': Affine.get_converter(opset),
'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
'ScaledTanh': ScaledTanh.get_converter(opset),
'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
@@ -1281,6 +1486,8 @@ def _get_convert_map(opset):
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten': Flatten.get_converter(opset),
'LRN': LRN.get_converter(opset),
# Recurrent Layers
'LSTM': LSTM.get_converter(opset),

# defs/reduction
'ReduceMax': ReduceMax.get_converter(opset),
@@ -1414,7 +1621,11 @@ def from_onnx(self, graph, opset):
for node in graph.node:
op_name = node.op_type
attr = self._parse_attr(node.attribute)
-            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
+            # Create and populate onnx input object.
+            inputs = onnx_input()
+            for i in node.input:
+                if i != '':
+                    inputs[i] = self._nodes[self._renames.get(i, i)]
if op_name == "Constant":
t_proto = self._parse_attr(node.attribute)["value"]
self._num_param += 1