diff --git a/tensorflow_asr/models/conformer.py b/tensorflow_asr/models/conformer.py index 5c923da65a..5a02fc2b87 100755 --- a/tensorflow_asr/models/conformer.py +++ b/tensorflow_asr/models/conformer.py @@ -19,6 +19,7 @@ from .layers.subsampling import VggSubsampling, Conv2dSubsampling from .layers.positional_encoding import PositionalEncoding, PositionalEncodingConcat from .layers.multihead_attention import MultiHeadAttention, RelPositionMultiHeadAttention +from ..utils.utils import shape_list L2 = tf.keras.regularizers.l2(1e-6) @@ -180,14 +181,15 @@ def __init__(self, def call(self, inputs, training=False, **kwargs): outputs = self.ln(inputs, training=training) - outputs = tf.expand_dims(outputs, axis=2) + B, T, E = shape_list(outputs) + outputs = tf.reshape(outputs, [B, T, 1, E]) outputs = self.pw_conv_1(outputs, training=training) outputs = self.glu(outputs) outputs = self.dw_conv(outputs, training=training) outputs = self.bn(outputs, training=training) outputs = self.swish(outputs) outputs = self.pw_conv_2(outputs, training=training) - outputs = tf.squeeze(outputs, axis=2) + outputs = tf.reshape(outputs, [B, T, E]) outputs = self.do(outputs, training=training) outputs = self.res_add([inputs, outputs]) return outputs