[Bug] EmbedLayerNormalization cannot work in relax.frontend.onnx #17895

@erdayang

Description

This is the implementation of the EmbedLayerNormalization op in relax.frontend.onnx.

class EmbedLayerNormalization(OnnxOpConverter):
    """Converts a microsoft contrib EmbedLayerNormalization node into a Relax expression."""

    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        input_ids = inputs[0]
        segment_ids = inputs[1]
        word_emb = inputs[2]
        pos_emb = inputs[3]
        segment_emb = inputs[4]
        gamma = inputs[5]
        beta = inputs[6]
        mask = inputs[7]
        pos_ids = inputs[8]

        epsilon = attr.get("epsilon", 1e-12)

        (batch_size, seq_len) = [dim.value for dim in input_ids.struct_info.shape]

        if segment_ids:
            assert segment_emb

        if pos_ids is None:
            pos_ids = relax.const([list(range(seq_len))] * batch_size, dtype="int64")
        # TODO(jwfromm) Replace with relax ops once take has better support.
        word_vec = bb.emit_te(topi.take, word_emb, input_ids, 0)
        if segment_ids:
            segment_vec = bb.emit_te(topi.take, segment_emb, segment_ids, 0)
        pos_vec = bb.emit_te(topi.take, pos_emb, pos_ids, 0)

        vec_sum = relax.op.add(word_vec, pos_vec)
        if segment_ids:
            vec_sum = relax.op.add(vec_sum, segment_vec)

        ln = relax.op.nn.layer_norm(vec_sum, gamma, beta, axes=-1, epsilon=epsilon)

        mask_index = relax.const(_np.zeros((batch_size,), dtype="int64"))
        if mask:
        # Calculate the number of words per sentence.
            mask_index = relax.op.sum(mask, axis=1)

        return relax.Tuple([ln, mask_index])

The following line throws an error because dim does not have a value attribute. I printed the type of dim, which shows as SizeVar, and after checking the definition of SizeVar I can confirm that it indeed has no value attribute. A SizeVar appears when a dimension is symbolic (dynamic); only concrete IntImm dimensions carry a .value.

(batch_size, seq_len) = [dim.value for dim in input_ids.struct_info.shape]
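
To illustrate the distinction, here is a minimal sketch (assuming a TVM build with relax): a concrete dimension is a tir.IntImm that carries a Python integer in .value, while a symbolic dimension is a tir.SizeVar with no .value at all, so the list comprehension above fails as soon as any shape dimension is dynamic.

from tvm import relax, tir

# A tensor whose batch dimension is symbolic, as it would appear when the
# ONNX model declares a dynamic batch size.
batch = tir.SizeVar("batch_size", "int64")
x = relax.Var("input_ids", relax.TensorStructInfo([batch, 128], "int64"))

for dim in x.struct_info.shape:
    # SizeVar: hasattr(dim, "value") is False, so dim.value raises AttributeError
    # IntImm:  hasattr(dim, "value") is True and dim.value == 128
    print(type(dim), hasattr(dim, "value"))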

Additionally, the following line seems to require batch_size to be a concrete Python int, which a symbolic dimension cannot provide:

mask_index = relax.const(_np.zeros((batch_size,), dtype="int64"))
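
As a minimal sketch of why (assuming numpy is available): np.zeros only accepts concrete integer dimensions, so a symbolic SizeVar cannot stand in for batch_size here.

import numpy as np
from tvm import tir

print(np.zeros((8,), dtype="int64"))  # fine: concrete Python int

symbolic_batch = tir.SizeVar("batch_size", "int64")
try:
    # numpy cannot interpret a symbolic TVM expression as an integer
    np.zeros((symbolic_batch,), dtype="int64")
except TypeError as err:
    print("symbolic dim rejected:", err)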

I'm not very familiar with TVM's type design; could you tell me whether this should be considered a bug?

The same error occurs in the Attention op:

(batch_size, seq_len, input_hidden_size) = [
    val.value for val in input_emb.struct_info.shape.values
]
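
For reference, a hypothetical helper (not part of TVM, just a sketch of one possible direction) that unwraps only concrete IntImm dimensions and fails with a clearer message on symbolic ones; callers that truly need a Python int, e.g. for np.zeros, would still have to require static shapes:

from tvm import tir

def dim_to_int(dim, name="dim"):
    """Return a Python int for a concrete IntImm dimension, else raise clearly."""
    if isinstance(dim, tir.IntImm):
        return int(dim.value)
    raise NotImplementedError(
        f"{name} is symbolic ({dim}); this converter currently requires a static shape"
    )

# Hypothetical usage against the failing lines:
# batch_size, seq_len = (dim_to_int(d) for d in input_ids.struct_info.shape)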

Labels: needs-triage, type: bug
