Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.6.2",
version="0.6.3",
author="Huy Le Nguyen",
author_email="nlhuy.cs.16@gmail.com",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
8 changes: 4 additions & 4 deletions tensorflow_asr/datasets/asr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,9 @@ def create(self, batch_size):


class ASRTFRecordTestDataset(ASRTFRecordDataset):
def preprocess(self, path, transcript):
def preprocess(self, path, audio, transcript):
with tf.device("/CPU:0"):
signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate)
signal = read_raw_audio(audio, self.speech_featurizer.sample_rate)

features = self.speech_featurizer.extract(signal)
features = tf.convert_to_tensor(features, tf.float32)
Expand All @@ -262,8 +262,8 @@ def parse(self, record):

return tf.numpy_function(
self.preprocess,
inp=[example["audio"], example["transcript"]],
Tout=(tf.string, tf.float32, tf.int32, tf.int32)
inp=[example["path"], example["audio"], example["transcript"]],
Tout=[tf.string, tf.float32, tf.int32, tf.int32]
)

def process(self, dataset, batch_size):
Expand Down
5 changes: 2 additions & 3 deletions tensorflow_asr/featurizers/text_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,11 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor:
clear_after_read=False, element_shape=tf.TensorShape([])
)

def cond(batch, total, transcripts): return tf.less(batch, total)
def cond(batch, total, _): return tf.less(batch, total)

def body(batch, total, transcripts):
upoints = self.indices2upoints(indices[batch])
_transcript = tf.strings.unicode_encode(upoints, "UTF-8")
transcripts = transcripts.write(batch, _transcript)
transcripts = transcripts.write(batch, tf.strings.unicode_encode(upoints, "UTF-8"))
return batch + 1, total, transcripts

_, _, transcripts = tf.while_loop(cond, body, loop_vars=[batch, total, transcripts])
Expand Down
6 changes: 2 additions & 4 deletions tensorflow_asr/models/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def __init__(self,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer
)
self.swish = tf.keras.layers.Activation(
tf.keras.activations.swish, name=f"{name}_swish_activation")
self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation")
self.do1 = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout_1")
self.ffn2 = tf.keras.layers.Dense(
input_dim, name=f"{name}_dense_2",
Expand Down Expand Up @@ -168,8 +167,7 @@ def __init__(self,
gamma_regularizer=kernel_regularizer,
beta_regularizer=bias_regularizer
)
self.swish = tf.keras.layers.Activation(
tf.keras.activations.swish, name=f"{name}_swish_activation")
self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation")
self.pw_conv_2 = tf.keras.layers.Conv2D(
filters=input_dim, kernel_size=1, strides=1,
padding="valid", name=f"{name}_pw_conv_2",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_asr/models/contextnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def get_activation(activation: str = "silu"):
activation = activation.lower()
if activation in ["silu", "swish"]: return tf.nn.silu
if activation in ["silu", "swish"]: return tf.nn.swish
elif activation == "relu": return tf.nn.relu
elif activation == "linear": return tf.keras.activations.linear
else: raise ValueError("activation must be either 'silu', 'swish', 'relu' or 'linear'")
Expand Down
41 changes: 35 additions & 6 deletions tensorflow_asr/models/streaming_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,36 @@ def recognize_tflite(self, signal, predicted, encoder_states, prediction_states)
hypothesis.states
)

def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states):
    """Run greedy streaming recognition and attach a time span to every emitted code point.

    The signal is turned into features with ``self.speech_featurizer.tf_extract``,
    encoded with ``self.encoder_inference`` and decoded with ``self._perform_greedy``.
    The decoded indices are mapped to rows of ``self.text_featurizer.upoints``; every
    entry of those rows that is not 0 is returned together with the start/end time of
    the row it came from.

    Args:
        signal: audio tensor; its first dimension is used as the number of samples.
        predicted: previous prediction passed through to ``_perform_greedy``.
        encoder_states: encoder states passed to ``encoder_inference``.
        prediction_states: prediction-network states passed to ``_perform_greedy``.

    Returns:
        A 6-tuple ``(transcript, start_times, end_times, prediction,
        new_encoder_states, new_prediction_states)`` where the first three entries
        are aligned 1-D tensors over the non-zero code points.
    """
    features = self.speech_featurizer.tf_extract(signal)
    encoded, new_encoder_states = self.encoder_inference(features, encoder_states)
    hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)

    indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
    # [None, max_subword_length]
    upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1))

    # Number of raw samples covered by one decoding step.
    num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32)
    total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step
    sample_rate = tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)

    # Sample offsets divided by the sample rate (presumably seconds -- confirm that
    # sample_rate is expressed in Hz).
    stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) / sample_rate
    # NOTE(review): this range starts one step later but stops at the same limit, so for
    # a non-empty signal it holds one entry fewer than `stime`. If a non-zero code point
    # can sit on the last row of `upoints`, the gather below would read past the end of
    # `etime` -- confirm against the number of rows `_perform_greedy` produces.
    etime = tf.range(
        total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32
    ) / sample_rate

    # Positions (row, column) of every code point that is not 0.
    non_blank = tf.where(tf.not_equal(upoints, 0))
    subword_width = tf.shape(upoints)[-1]

    def _per_upoint(times):
        # Copy each row's time across the sub-word axis, then keep the non-zero positions.
        tiled = tf.repeat(tf.expand_dims(times, axis=-1), subword_width, axis=-1)
        return tf.gather_nd(tiled, non_blank)

    non_blank_transcript = tf.gather_nd(upoints, non_blank)
    non_blank_stime = _per_upoint(stime)
    non_blank_etime = _per_upoint(etime)

    return (
        non_blank_transcript,
        non_blank_stime,
        non_blank_etime,
        hypothesis.prediction,
        new_encoder_states,
        hypothesis.states
    )

# -------------------------------- BEAM SEARCH -------------------------------------

@tf.function
Expand Down Expand Up @@ -325,15 +355,14 @@ def recognize_beam(self,

# -------------------------------- TFLITE -------------------------------------

def make_tflite_function(self, greedy: bool = True):
def make_tflite_function(self, timestamp: bool = True):
tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite
return tf.function(
self.recognize_tflite,
tflite_func,
input_signature=[
tf.TensorSpec([None], dtype=tf.float32),
tf.TensorSpec([], dtype=tf.int32),
tf.TensorSpec(self.encoder.get_initial_state().get_shape(),
dtype=tf.float32),
tf.TensorSpec(self.predict_net.get_initial_state().get_shape(),
dtype=tf.float32)
tf.TensorSpec(self.encoder.get_initial_state().get_shape(), dtype=tf.float32),
tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32)
]
)
Loading