[ONNX] Add imports for BERT contrib operators (#10949)
* EmbedLayerNormalization, Attention

* fix Attention

* SkipLayerNormalization

* fix dtype bug in Gelu

Co-authored-by: An Wang <anwang2009@gmail.com>

* missing parameterize_targets

* lint

* lint

* comments

* fix small thing

* factor out layer norm computation

* layernorm func

* add optional args to test

* upgrade onnxrt version

* no upgrade onnx

* fix tests

* int32

* fix tests

Co-authored-by: An Wang <anwang2009@gmail.com>
altanh and anwang2009 committed Apr 13, 2022
1 parent 814e856 commit 11b8cd3
Showing 2 changed files with 440 additions and 3 deletions.
224 changes: 221 additions & 3 deletions python/tvm/relay/frontend/onnx.py
@@ -329,6 +329,22 @@ def flatten_to_nd(x, x_shape, nd=3):
    return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)


def layer_norm(x, eps, gamma, beta):
    """Common function to handle layer norm"""
    eps_dtype = infer_type(x).checked_type.dtype

    u, s = _op.mean_variance(x, axis=-1, keepdims=True)
    output = _op.divide(
        _op.subtract(x, u),
        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
    )
    output = _op.multiply(output, gamma)
    if beta is not None:
        output = _op.add(output, beta)

    return output


class OnnxOpConverter(object):
"""A helper class for holding onnx op converters."""

@@ -807,9 +823,10 @@ def _impl_v1(cls, inputs, attr, params):
        x = inputs[0]

        # Declare consts
-        half = _expr.const(0.5)
-        one = _expr.const(1.0)
-        sqrt2 = _expr.const(math.sqrt(2))
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)

        # Compute gelu
        term1 = _op.multiply(half, x)
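
The change above makes the Gelu constants follow the input dtype instead of defaulting to float32, so non-float32 models no longer hit a dtype mismatch. For reference, a small NumPy sketch of the erf-based Gelu this converter expresses (scipy is assumed to be available; illustrative, not part of the diff):

import numpy as np
from scipy.special import erf

def gelu_ref(x):
    # 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

print(gelu_ref(np.linspace(-3, 3, 7, dtype="float32")))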
@@ -836,6 +853,201 @@ def _impl_v1(cls, inputs, attr, params):
        return Gelu._impl_v1([inp], attr, params)


class EmbedLayerNormalization(OnnxOpConverter):
    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
    This layer embeds the input tokens, sums them, and applies layer normalization.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        input_ids = inputs[0]
        segment_ids = inputs[1]
        word_emb = inputs[2]
        pos_emb = inputs[3]
        segment_emb = inputs[4]
        gamma = inputs[5]
        beta = inputs[6]

        mask = inputs[7]
        pos_ids = inputs[8]

        eps = attr.get("epsilon", 1e-12)

        (batch_size, seq_len) = infer_shape(input_ids)

        if segment_ids:
            assert segment_emb

        if pos_ids is None:
            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int32")

        word_vec = _op.take(word_emb, input_ids, axis=0)
        if segment_ids:
            segment_vec = _op.take(segment_emb, segment_ids, axis=0)
        pos_vec = _op.take(pos_emb, pos_ids, axis=0)

        vec_sum = _op.add(word_vec, pos_vec)
        if segment_ids:
            vec_sum = _op.add(vec_sum, segment_vec)

        ln = layer_norm(vec_sum, eps, gamma, beta)

        mask_index = _op.const(np.zeros((batch_size,), dtype="int32"))
        if mask:
            # calculate number of words per sentence
            mask_index = _op.sum(mask, axis=1)

        # TODO(@anwang2009): onnxruntime v1.10.0 requires a third output of vec_sum
        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index]), 2)
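
Roughly, the converter above amounts to the following NumPy sketch (assuming the optional mask and segment inputs are present and position ids take their default values; names are illustrative, not part of the diff):

import numpy as np

def embed_layer_norm_ref(input_ids, segment_ids, word_emb, pos_emb, segment_emb,
                         gamma, beta, mask, eps=1e-12):
    batch_size, seq_len = input_ids.shape
    # Default position ids are [0, 1, ..., seq_len - 1] for every sentence in the batch.
    pos_ids = np.tile(np.arange(seq_len), (batch_size, 1))
    vec_sum = word_emb[input_ids] + pos_emb[pos_ids] + segment_emb[segment_ids]
    u = vec_sum.mean(axis=-1, keepdims=True)
    s = vec_sum.var(axis=-1, keepdims=True)
    ln = (vec_sum - u) / np.sqrt(s + eps) * gamma + beta
    mask_index = mask.sum(axis=1).astype("int32")  # number of attended tokens per sentence
    return ln, mask_index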


class SkipLayerNormalization(OnnxOpConverter):
    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.
    This layer sums the two input tensors (along with optional bias), and applies layer
    normalization.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        skip = inputs[1]
        gamma = inputs[2]
        beta = inputs[3]
        bias = inputs[4]

        assert (
            beta is not None and bias is not None
        ), "SkipLayerNormalization import currently only supports required beta and bias"

        eps = attr.get("epsilon", 1e-12)

        x = _op.add(data, skip)
        if bias is not None:
            x = _op.add(x, bias)

        output = layer_norm(x, eps, gamma, beta)

        # onnxruntime doesn't compute the other outputs, despite the documentation
        placeholder = _op.const(0, dtype="float32")

        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
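
In NumPy terms, the imported SkipLayerNormalization is simply layer normalization applied to data + skip + bias (a sketch, assuming bias is provided as the converter requires):

import numpy as np

def skip_layer_norm_ref(data, skip, gamma, beta, bias, eps=1e-12):
    x = data + skip + bias
    u = x.mean(axis=-1, keepdims=True)
    s = x.var(axis=-1, keepdims=True)
    return (x - u) / np.sqrt(s + eps) * gamma + beta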


class Attention(OnnxOpConverter):
    """Operator converter for Attention from Microsoft onnxruntime contrib opset.
    This is the self-attention mechanism used in transformer models.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        num_heads = attr["num_heads"]
        assert (
            "qkv_hidden_sizes" not in attr
        ), "different hidden sizes for Q, K, V are not currently supported"
        assert "unidirectional" not in attr, "unidirectional attention not currently supported"

        # (batch, seq, in_hidden)
        input_emb = inputs[0]

        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
        weight = inputs[1]

        # (3 * out_hidden,)
        bias = inputs[2]

        # 1. (batch, 1, max_seq, max_seq)
        # 2. (batch, past_seq + seq)
        # 3. (batch, seq, past_seq + seq)
        # 4. (batch,)
        # 5. (2 * batch,)
        # For now, we only support case 2.
        mask_index = inputs[3]

        # (2, batch, num_heads, past_seq, head_size)
        past = inputs[4]

        # (batch, num_heads, seq, seq)
        extra_add = inputs[5]

        (batch_size, seq_len, _) = infer_shape(input_emb)
        (out_hidden_x3,) = infer_shape(bias)
        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
        out_hidden = out_hidden_x3 // 3
        assert (
            out_hidden % num_heads == 0
        ), "output hidden size should be divisible by number of attention heads"
        head_size = out_hidden // num_heads

        assert (
            mask_index is not None
        ), "Attention import currently only supports required mask_index"
        mask_index_shape = infer_shape(mask_index)
        assert (
            len(mask_index_shape) == 2
            and mask_index_shape[0] == batch_size
            and mask_index_shape[1] == seq_len
        ), "currently only support (batch_size, sequence_length) mask index"

        assert past is None, "past K, V state is not currently supported"
        assert extra_add is None, "extra add to QxK not currently supported"

        # split weight and biases and do the matmuls
        w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
        # need to merge batch dimensions since TVM matmul is 2D
        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)

        # massage tensors in preparation for batched matmul
        def massage(tensor):
            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))

            # (batch_size, num_heads, seq_len, head_size)
            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])

            # (batch_size * num_heads, seq_len, head_size)
            return _op.reverse_reshape(tensor, (-1, 0, 0))

        Q = massage(Q)
        K = massage(K)
        V = massage(V)

        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
        present = _op.stack([K_present, V_present], axis=0)

        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
        score_dtype = infer_type(att_scores).checked_type.dtype
        att_scores = _op.divide(
            att_scores,
            _op.const(np.sqrt(head_size), dtype=score_dtype),
        )
        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))

        # build the attention mask
        att_mask = _op.cast(mask_index, score_dtype)
        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))

        # apply the mask
        att_scores = _op.add(att_scores, att_mask)
        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))

        att_probs = _op.nn.softmax(att_scores, axis=-1)

        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
        output = _op.transpose(output, axes=[0, 2, 1, 3])
        output = _op.reshape(output, (0, 0, out_hidden))

        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
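
For the supported case (a (batch_size, seq_len) mask and no past state), the Relay graph built above corresponds to the NumPy sketch below; the `present` output is just K and V reshaped per head and stacked, so it is omitted here. Names and shapes are illustrative:

import numpy as np

def attention_ref(input_emb, weight, bias, mask, num_heads):
    batch, seq, _ = input_emb.shape
    out_hidden = bias.shape[0] // 3
    head_size = out_hidden // num_heads

    w_q, w_k, w_v = np.split(weight, 3, axis=1)
    b_q, b_k, b_v = np.split(bias, 3, axis=0)

    # Project to Q, K, V and split into heads: (batch, num_heads, seq, head_size).
    def project(w, b):
        x = input_emb @ w + b
        return x.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)

    q, k, v = project(w_q, b_q), project(w_k, b_k), project(w_v, b_v)

    # Scaled dot-product scores with an additive mask: padded positions get -10000.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    scores = scores + (1.0 - mask[:, None, None, :]) * -10000.0

    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)

    out = probs @ v  # (batch, num_heads, seq, head_size)
    return out.transpose(0, 2, 1, 3).reshape(batch, seq, out_hidden)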


class Gemm(OnnxOpConverter):
"""Operator converter for Gemm."""

@@ -4808,6 +5020,12 @@ def _get_convert_map(opset):
"Elu": Elu.get_converter(opset),
"Gelu": Gelu.get_converter(opset),
"BiasGelu": BiasGelu.get_converter(opset),
# TODO: We need a better way to handle different domains, in case
# of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
# are in the `com.microsoft` domain.
"EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
"SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
"Attention": Attention.get_converter(opset),
"Exp": Renamer("exp"),
"Greater": Renamer("greater"),
"GreaterOrEqual": Renamer("greater_equal"),