[Frontend][ONNX] LSTM Support #4825
Changes from all commits: 62e5abb, 5f3f658, 4f0c2ab, 26acb09, 2f892b8, 475852a
@@ -32,6 +32,55 @@
__all__ = ['from_onnx']


class onnx_input():
    """ Dual purpose list or dictionary access object."""

    def __init__(self):
        self.input_keys = []
        self.input_dict = {}

    def __getitem__(self, item):
        if isinstance(item, int):
            return self.input_dict[self.input_keys[item]]
        if isinstance(item, str):
            if item not in self.input_keys:
                return None
            return self.input_dict[item]
        if isinstance(item, slice):
            keys = self.input_keys[item]
            return [self.input_dict[key] for key in keys]

        raise ValueError("Only integer, string, and slice accesses allowed.")

    def __setitem__(self, item, value):
        if isinstance(item, int):
            self.input_dict[self.input_keys[item]] = value
        elif isinstance(item, str):
            if item not in self.input_dict:
                self.input_keys.append(item)
            self.input_dict[item] = value
        else:
            raise ValueError("Only integer and string indexed writes allowed.")

    def keys(self):
        return self.input_keys

    def __len__(self):
        return len(self.input_keys)

    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n < len(self.input_keys):
            output = self.input_dict[self.input_keys[self.n]]
            self.n += 1
            return output

        raise StopIteration


def get_numpy(tensor_proto):
    """Grab data in TensorProto and convert to numpy array."""
    try:
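For context, the class above behaves like a list and a dict at the same time, so converter code can index inputs either positionally or by ONNX input name, and a lookup of an absent optional input returns None instead of raising. A minimal sketch of the intended usage (the values here are placeholder strings, not real relay expressions):

inputs = onnx_input()
inputs['X'] = 'x_expr'
inputs['W'] = 'w_expr'

assert inputs[0] == 'x_expr'                  # positional access
assert inputs['W'] == 'w_expr'                # name-based access
assert inputs['B'] is None                    # absent optional input reads as None
assert inputs[0:2] == ['x_expr', 'w_expr']    # slice access
assert len(inputs) == 2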
@@ -664,13 +713,24 @@ def _impl_v1(cls, inputs, attr, params):
        return inputs[len(inputs) - 1]


class Affine(OnnxOpConverter):
    """ Operator converter for Affine transformation.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = _expr.const(attr.get('alpha', 1.0))
        beta = _expr.const(attr.get('beta', 0.0))
        return (alpha * inputs[0]) + beta


class ThresholdedRelu(OnnxOpConverter):
    """ Operator converter for ThresholdedRelu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        alpha = float(attr.get('alpha', 0.0))
+        alpha = float(attr.get('alpha', 1.0))
        alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
        mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
        return inputs[0] * mask
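For reference, these two converters implement y = alpha * x + beta (Affine) and y = x if x > alpha else 0 (ThresholdedRelu, whose ONNX default for alpha is 1.0, which the corrected default above reflects). A plain NumPy equivalent to check the relay graphs against (the helper names are illustrative only):

import numpy as np

def affine_ref(x, alpha=1.0, beta=0.0):
    return alpha * x + beta

def thresholded_relu_ref(x, alpha=1.0):
    return x * (x > alpha).astype("float32")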
@@ -893,7 +953,7 @@ class Maximum(OnnxOpConverter):
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _max = inputs[0]
        for i in range(1, len(inputs)):

@@ -905,7 +965,7 @@ class Minimum(OnnxOpConverter):
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _min = inputs[0]
        for i in range(1, len(inputs)):

@@ -917,7 +977,7 @@ class Mean(OnnxOpConverter):
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        # avoid overflow
        concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
@@ -1190,6 +1250,151 @@ def expand_shape(in_shape, shape):
        return _op.broadcast_to(inputs[0], shape=tuple(shape))


class LSTM(OnnxOpConverter):
    """ Operator converter for LSTM.
    """

    @classmethod
    def _activation_helper(cls, activation, alpha, beta):
        convert_map = _get_convert_map(1)
        attrs = {}
        if alpha is not None:
            attrs['alpha'] = alpha
        if beta is not None:
            attrs['beta'] = beta
        return lambda x: convert_map[activation.decode("utf-8")]([x], attrs, {})

    @classmethod
    def _activation_needs_alpha(cls, activation):
        needs_alpha = [
            "Affine",
            "LeakyRelu",
            "ThresholdedRelu",
            "ScaledTanh",
            "HardSigmoid",
            "Elu",
        ]
        return activation.decode("utf-8") in needs_alpha

    @classmethod
    def _activation_needs_beta(cls, activation):
        needs_beta = [
            "Affine",
            "ScaledTanh",
            "HardSigmoid",
        ]
        return activation.decode("utf-8") in needs_beta

    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        # Unpack inputs, note that if optional and not provided then value will be None.
        X = inputs[0]
        W = inputs[1]
Review comment: Is there any case when the weights won't be constant? If they're constant, we can remove some operations from the graph and compute them here (like squeeze). By constant, I mean we can call …

Author reply: I think in almost all cases it'd be safe to assume weights are constant. However, the fold constant pass in relay will eliminate all operations on the weights anyway. Since treating the weights as a non-constant is slightly more flexible, I prefer it.

Review comment: Sounds good. Thanks a lot for the updates.
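The merged converter keeps the squeeze in the relay graph and relies on the FoldConstant pass; a sketch of the alternative the reviewer describes, assuming the weight may arrive as a relay Constant and reusing the `_expr`/`_op` aliases already imported in this file (the helper name is hypothetical):

import numpy as np

def _squeeze_weight(weight):
    # If the weight is already a constant, fold out the num_directions
    # axis at conversion time instead of emitting a squeeze op.
    if isinstance(weight, _expr.Constant):
        return _expr.const(np.squeeze(weight.data.asnumpy(), axis=0))
    # Otherwise leave the op in the graph; relay's FoldConstant pass
    # removes it later when the weight turns out to be constant anyway.
    return _op.squeeze(weight, axis=[0])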
        R = inputs[2]
        B = inputs['B']
        # Sequence length currently unused as it can be inferred from shapes.
        #sequence_lens = inputs['sequence_lens']
        h_0 = inputs['initial_h']
        c_0 = inputs['initial_c']
        P = inputs['P']

        num_directions = infer_shape(W)[0]
        W_dtype = infer_type(W).type_annotation.dtype

        if num_directions != 1:
            raise NotImplementedError("Bidirectional LSTMs not yet supported.")
        # Remove num_directions axis from weights.
        W = _op.squeeze(W, axis=[0])
        R = _op.squeeze(R, axis=[0])
        if B is not None:
            B = _op.squeeze(B, axis=[0])

        X_shape = infer_shape(X)
        hidden_size = infer_shape(R)[-1]
        batch_size = X_shape[1]

        # Initialize state if not provided.
        # Otherwise remove bidirectional axis.
        if h_0 is None:
            h_0 = _op.zeros((batch_size, hidden_size), W_dtype)
        else:
            h_0 = _op.squeeze(h_0, axis=[0])
        if c_0 is None:
            c_0 = _op.zeros((batch_size, hidden_size), W_dtype)
        else:
            c_0 = _op.squeeze(c_0, axis=[0])

        if P is not None:
            P = _op.squeeze(P, axis=[0])
            p_i, p_o, p_f = _op.split(P, 3)
        H_t = h_0
        C_t = c_0
        h_list = []

        if 'activations' in attr:
            activations = attr['activations']
            if len(activations) != 3:
                raise NotImplementedError("LSTM assumes 3 activation functions are provided")
            alpha_loc = 0
            alphas = attr.get('activation_alpha', [])
            if isinstance(alphas, float):
                alphas = [alphas]
            beta_loc = 0
            betas = attr.get('activation_beta', [])
            if isinstance(betas, float):
                betas = [betas]
            acts = []
            for i in range(3):
                alpha = None
                beta = None
                activation = activations[i]
                if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
                    alpha = alphas[alpha_loc]
                    alpha_loc += 1
                if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
                    beta = betas[beta_loc]
                    beta_loc += 1
                acts.append(cls._activation_helper(activation, alpha, beta))
            f_act, g_act, h_act = acts
        else:
            f_act = _op.sigmoid
            g_act = _op.tanh
            h_act = _op.tanh

        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
        for step in X_steps:
            step = _op.squeeze(step, axis=[0])
            gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R)
            if B is not None:
                WB, RB = _op.split(B, 2)
                gates += WB + RB
            i, o, f, c = _op.split(gates, 4, axis=-1)
            if P is not None:
                i = f_act(i + p_i * C_t)
                f = f_act(f + p_f * C_t)
            else:
                i = f_act(i)
                f = f_act(f)
            c = g_act(c)
            C = f * C_t + i * c
            if P is not None:
                o = f_act(o + p_o * C)
            else:
                o = f_act(o)
            H = o * h_act(C)
            H_t = H
            C_t = C
            h_list.append(_op.expand_dims(H, axis=0))
        # Concatenate outputs and add back in direction axis.
        concatenated = _op.concatenate(h_list, 0)
        output = _op.expand_dims(concatenated, axis=1)
        H_t = _op.expand_dims(H_t, axis=0)
        C_t = _op.expand_dims(C_t, axis=0)

        return _expr.TupleWrapper(_expr.Tuple((output, H_t, C_t)), 3)


# compatible operators that do NOT require any conversion.
_identity_list = []
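As a cross-check of the gate math in the loop above, here is a minimal NumPy sketch of one ONNX LSTM step with the default activations (sigmoid, tanh, tanh) and the ONNX i, o, f, c gate order; the names and shapes are illustrative only:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step_ref(x_t, h_prev, c_prev, W, R, B=None, P=None):
    # W: (4*hidden, input), R: (4*hidden, hidden), B: (8*hidden,), P: (3*hidden,)
    gates = x_t @ W.T + h_prev @ R.T
    if B is not None:
        wb, rb = np.split(B, 2)
        gates = gates + wb + rb
    i, o, f, c = np.split(gates, 4, axis=-1)   # ONNX gate order: i, o, f, c
    if P is not None:
        p_i, p_o, p_f = np.split(P, 3)
        i = sigmoid(i + p_i * c_prev)
        f = sigmoid(f + p_f * c_prev)
    else:
        i = sigmoid(i)
        f = sigmoid(f)
    c = np.tanh(c)
    c_new = f * c_prev + i * c
    o = sigmoid(o + p_o * c_new) if P is not None else sigmoid(o)
    h_new = o * np.tanh(c_new)
    return h_new, c_new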
@@ -1203,7 +1408,7 @@ def _get_convert_map(opset):
    return {
        # defs/experimental
        'Identity': Renamer('copy'),
-        # 'Affine'
+        'Affine': Affine.get_converter(opset),
        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
        'ScaledTanh': ScaledTanh.get_converter(opset),
        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),

@@ -1281,6 +1486,8 @@ def _get_convert_map(opset):
        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
        'Flatten': Flatten.get_converter(opset),
        'LRN': LRN.get_converter(opset),
        # Recurrent Layers
        'LSTM': LSTM.get_converter(opset),

        # defs/reduction
        'ReduceMax': ReduceMax.get_converter(opset),
@@ -1414,7 +1621,11 @@ def from_onnx(self, graph, opset):
        for node in graph.node:
            op_name = node.op_type
            attr = self._parse_attr(node.attribute)
-            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
+            # Create and populate onnx input object.
+            inputs = onnx_input()
+            for i in node.input:
+                if i != '':
+                    inputs[i] = self._nodes[self._renames.get(i, i)]
            if op_name == "Constant":
                t_proto = self._parse_attr(node.attribute)["value"]
                self._num_param += 1
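To illustrate the new input handling: ONNX encodes omitted optional inputs as empty strings in node.input, so the loop above simply never populates them and the converter's named lookups return None. A small self-contained sketch (the node below is hypothetical):

# Hypothetical LSTM node that omits the optional B and sequence_lens inputs.
node_input = ['X', 'W', 'R', '', '', 'initial_h', 'initial_c']
nodes = {name: name + '_expr' for name in node_input if name != ''}

inputs = onnx_input()
for i in node_input:
    if i != '':
        inputs[i] = nodes[i]

assert inputs['B'] is None                    # omitted optional input
assert inputs['initial_h'] == 'initial_h_expr'
assert inputs[0] == 'X_expr'                  # positional access still works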
Review comment: I really like this design - very sleek.