From 817c6f00e74488bd218b1769edc8efa5d1795429 Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 8 Nov 2020 19:50:20 +0700 Subject: [PATCH] :zap: Fixed tflite conformer --- tensorflow_asr/models/conformer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow_asr/models/conformer.py b/tensorflow_asr/models/conformer.py index 5c923da65a..5a02fc2b87 100755 --- a/tensorflow_asr/models/conformer.py +++ b/tensorflow_asr/models/conformer.py @@ -19,6 +19,7 @@ from .layers.subsampling import VggSubsampling, Conv2dSubsampling from .layers.positional_encoding import PositionalEncoding, PositionalEncodingConcat from .layers.multihead_attention import MultiHeadAttention, RelPositionMultiHeadAttention +from ..utils.utils import shape_list L2 = tf.keras.regularizers.l2(1e-6) @@ -180,14 +181,15 @@ def __init__(self, def call(self, inputs, training=False, **kwargs): outputs = self.ln(inputs, training=training) - outputs = tf.expand_dims(outputs, axis=2) + B, T, E = shape_list(outputs) + outputs = tf.reshape(outputs, [B, T, 1, E]) outputs = self.pw_conv_1(outputs, training=training) outputs = self.glu(outputs) outputs = self.dw_conv(outputs, training=training) outputs = self.bn(outputs, training=training) outputs = self.swish(outputs) outputs = self.pw_conv_2(outputs, training=training) - outputs = tf.squeeze(outputs, axis=2) + outputs = tf.reshape(outputs, [B, T, E]) outputs = self.do(outputs, training=training) outputs = self.res_add([inputs, outputs]) return outputs