Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/conformer/masking/masking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import tensorflow as tf
from tensorflow_asr.utils.utils import shape_list
from tensorflow_asr.utils.utils import shape_list, get_reduced_length


def create_padding_mask(features, input_length, time_reduction_factor):
Expand All @@ -14,10 +14,10 @@ def create_padding_mask(features, input_length, time_reduction_factor):
[tf.Tensor]: with shape [B, Tquery, Tkey]
"""
batch_size, padded_time, _, _ = shape_list(features)
reduced_padded_time = tf.math.ceil(padded_time / time_reduction_factor)
reduced_padded_time = get_reduced_length(padded_time, time_reduction_factor)

def create_mask(length):
reduced_length = tf.math.ceil(length / time_reduction_factor)
reduced_length = get_reduced_length(length, time_reduction_factor)
mask = tf.ones([reduced_length, reduced_length], dtype=tf.float32)
return tf.pad(
mask,
Expand Down
5 changes: 3 additions & 2 deletions examples/conformer/masking/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from masking import create_padding_mask
from tensorflow_asr.runners.transducer_runners import TransducerTrainer, TransducerTrainerGA
from tensorflow_asr.losses.rnnt_losses import rnnt_loss
from tensorflow_asr.utils.utils import get_reduced_length


class TrainerWithMasking(TransducerTrainer):
Expand All @@ -17,7 +18,7 @@ def _train_step(self, batch):
tape.watch(logits)
per_train_loss = rnnt_loss(
logits=logits, labels=labels, label_length=label_length,
logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
logit_length=get_reduced_length(input_length, self.model.time_reduction_factor),
blank=self.text_featurizer.blank
)
train_loss = tf.nn.compute_average_loss(per_train_loss,
Expand All @@ -41,7 +42,7 @@ def _train_step(self, batch):
tape.watch(logits)
per_train_loss = rnnt_loss(
logits=logits, labels=labels, label_length=label_length,
logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
logit_length=get_reduced_length(input_length, self.model.time_reduction_factor),
blank=self.text_featurizer.blank
)
train_loss = tf.nn.compute_average_loss(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.5.0",
version="0.5.1",
author="Huy Le Nguyen",
author_email="nlhuy.cs.16@gmail.com",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_asr/datasets/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ Where `prediction` and `prediction_length` are the label prepended by blank and
**Outputs when iterating in test step**

```python
(path, signals, labels)
(path, features, input_lengths, labels)
```
32 changes: 24 additions & 8 deletions tensorflow_asr/datasets/asr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,15 @@ class ASRTFRecordTestDataset(ASRTFRecordDataset):
def preprocess(self, path, transcript):
with tf.device("/CPU:0"):
signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate)

features = self.speech_featurizer.extract(signal)
features = tf.convert_to_tensor(features, tf.float32)
input_length = tf.cast(tf.shape(features)[0], tf.int32)

label = self.text_featurizer.extract(transcript.decode("utf-8"))
return path, signal, tf.convert_to_tensor(label, dtype=tf.int32)
label = tf.convert_to_tensor(label, dtype=tf.int32)

return path, features, input_length, label

@tf.function
def parse(self, record):
Expand All @@ -256,7 +263,7 @@ def parse(self, record):
return tf.numpy_function(
self.preprocess,
inp=[example["audio"], example["transcript"]],
Tout=(tf.string, tf.float32, tf.int32)
Tout=(tf.string, tf.float32, tf.int32, tf.int32)
)

def process(self, dataset, batch_size):
Expand All @@ -273,10 +280,11 @@ def process(self, dataset, batch_size):
batch_size=batch_size,
padded_shapes=(
tf.TensorShape([]),
tf.TensorShape([None]),
tf.TensorShape(self.speech_featurizer.shape),
tf.TensorShape([]),
tf.TensorShape([None]),
),
padding_values=("", 0.0, self.text_featurizer.blank),
padding_values=("", 0.0, 0, self.text_featurizer.blank),
drop_remainder=True
)

Expand Down Expand Up @@ -304,15 +312,22 @@ class ASRSliceTestDataset(ASRDataset):
def preprocess(self, path, transcript):
with tf.device("/CPU:0"):
signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate)

features = self.speech_featurizer.extract(signal)
features = tf.convert_to_tensor(features, tf.float32)
input_length = tf.cast(tf.shape(features)[0], tf.int32)

label = self.text_featurizer.extract(transcript.decode("utf-8"))
return path, signal, tf.convert_to_tensor(label, dtype=tf.int32)
label = tf.convert_to_tensor(label, dtype=tf.int32)

return path, features, input_length, label

@tf.function
def parse(self, record):
return tf.numpy_function(
self.preprocess,
inp=[record[0], record[1]],
Tout=[tf.string, tf.float32, tf.int32]
Tout=[tf.string, tf.float32, tf.int32, tf.int32]
)

def process(self, dataset, batch_size):
Expand All @@ -329,10 +344,11 @@ def process(self, dataset, batch_size):
batch_size=batch_size,
padded_shapes=(
tf.TensorShape([]),
tf.TensorShape([None]),
tf.TensorShape(self.speech_featurizer.shape),
tf.TensorShape([]),
tf.TensorShape([None]),
),
padding_values=("", 0.0, self.text_featurizer.blank),
padding_values=("", 0.0, 0, self.text_featurizer.blank),
drop_remainder=True
)

Expand Down
8 changes: 8 additions & 0 deletions tensorflow_asr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@ def _build(self, *args, **kwargs):
@abc.abstractmethod
def call(self, inputs, training=False, **kwargs):
raise NotImplementedError()

# Abstract recognition hook on the base model. The body is only `pass`, so the
# base class supplies no behaviour here; it is declared with
# @abc.abstractmethod so that concrete models provide their own implementation.
# Parameters: `features`, `input_lengths`, plus arbitrary keyword options.
# NOTE(review): presumably the greedy-decoding entry point (CtcModel.recognize
# in this change set decodes greedily and names its second parameter
# `input_length`, singular) - confirm the expected return type and parameter
# naming with the implementers.
# NOTE(review): unlike `call` above, this placeholder does not raise
# NotImplementedError; a direct super() call would silently return None.
@abc.abstractmethod
def recognize(self, features, input_lengths, **kwargs):
pass

# Abstract beam-search recognition hook on the base model. The body is only
# `pass`, so the base class supplies no behaviour here; it is declared with
# @abc.abstractmethod so that concrete models provide their own implementation.
# Parameters: `features`, `input_lengths`, plus arbitrary keyword options.
# NOTE(review): CtcModel.recognize_beam in this change set accepts an extra
# `lm` flag and names its second parameter `input_length` (singular) -
# presumably such extras are meant to travel through **kwargs; confirm.
# NOTE(review): unlike `call` above, this placeholder does not raise
# NotImplementedError; a direct super() call would silently return None.
@abc.abstractmethod
def recognize_beam(self, features, input_lengths, **kwargs):
pass
11 changes: 7 additions & 4 deletions tensorflow_asr/models/contextnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
""" Ref: https://github.com/iankur/ContextNet """

from typing import List
from typing import List, Optional
import tensorflow as tf
from .transducer import Transducer
from ..utils.utils import merge_two_last_dims, get_reduced_length
Expand Down Expand Up @@ -234,8 +234,7 @@ def __init__(self,
)
self.dmodel = self.encoder.blocks[-1].dmodel
self.time_reduction_factor = 1
for block in self.encoder.blocks:
self.time_reduction_factor *= block.time_reduction_factor
for block in self.encoder.blocks: self.time_reduction_factor *= block.time_reduction_factor

def call(self, inputs, training=False, **kwargs):
features, input_length, prediction, prediction_length = inputs
Expand All @@ -244,8 +243,12 @@ def call(self, inputs, training=False, **kwargs):
outputs = self.joint_net([enc, pred], training=training, **kwargs)
return outputs

def encoder_inference(self, features):
def encoder_inference(self,
features: tf.Tensor,
input_length: Optional[tf.Tensor] = None,
with_batch: bool = False):
with tf.name_scope(f"{self.name}_encoder"):
if with_batch: return self.encoder([features, input_length], training=False)
input_length = tf.expand_dims(tf.shape(features)[0], axis=0)
outputs = tf.expand_dims(features, axis=0)
outputs = self.encoder([outputs, input_length], training=False)
Expand Down
33 changes: 12 additions & 21 deletions tensorflow_asr/models/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
import numpy as np
import tensorflow as tf

from . import Model
from ..featurizers.speech_featurizers import TFSpeechFeaturizer
from ..featurizers.text_featurizers import TextFeaturizer
from ..utils.utils import shape_list
from ..utils.utils import shape_list, get_reduced_length


class CtcModel(Model):
Expand All @@ -41,20 +42,15 @@ def call(self, inputs, training=False, **kwargs):
# -------------------------------- GREEDY -------------------------------------

@tf.function
def recognize(self, signals):

def extract_fn(signal): return self.speech_featurizer.tf_extract(signal)

features = tf.map_fn(extract_fn, signals,
fn_output_signature=tf.TensorSpec(self.speech_featurizer.shape, dtype=tf.float32))
def recognize(self, features: tf.Tensor, input_length: Optional[tf.Tensor]):
logits = self(features, training=False)
probs = tf.nn.softmax(logits)

def map_fn(prob): return tf.numpy_function(self.perform_greedy, inp=[prob], Tout=tf.string)
def map_fn(prob): return tf.numpy_function(self.__perform_greedy, inp=[prob], Tout=tf.string)

return tf.map_fn(map_fn, probs, fn_output_signature=tf.TensorSpec([], dtype=tf.string))

def perform_greedy(self, probs: np.ndarray):
def __perform_greedy(self, probs: np.ndarray):
from ctc_decoders import ctc_greedy_decoder
decoded = ctc_greedy_decoder(probs, vocabulary=self.text_featurizer.vocab_array)
return tf.convert_to_tensor(decoded, dtype=tf.string)
Expand All @@ -71,7 +67,7 @@ def recognize_tflite(self, signal):
features = self.speech_featurizer.tf_extract(signal)
features = tf.expand_dims(features, axis=0)
input_length = shape_list(features)[1]
input_length = input_length // self.base_model.time_reduction_factor
input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
input_length = tf.expand_dims(input_length, axis=0)
logits = self(features, training=False)
probs = tf.nn.softmax(logits)
Expand All @@ -85,25 +81,20 @@ def recognize_tflite(self, signal):
# -------------------------------- BEAM SEARCH -------------------------------------

@tf.function
def recognize_beam(self, signals, lm=False):

def extract_fn(signal): return self.speech_featurizer.tf_extract(signal)

features = tf.map_fn(extract_fn, signals,
fn_output_signature=tf.TensorSpec(self.speech_featurizer.shape, dtype=tf.float32))
def recognize_beam(self, features: tf.Tensor, input_length: Optional[tf.Tensor], lm: bool = False):
logits = self(features, training=False)
probs = tf.nn.softmax(logits)

def map_fn(prob): return tf.numpy_function(self.perform_beam_search, inp=[prob, lm], Tout=tf.string)
def map_fn(prob): return tf.numpy_function(self.__perform_beam_search, inp=[prob, lm], Tout=tf.string)

return tf.map_fn(map_fn, probs, dtype=tf.string)

def perform_beam_search(self, probs: np.ndarray, lm: bool = False):
def __perform_beam_search(self, probs: np.ndarray, lm: bool = False):
from ctc_decoders import ctc_beam_search_decoder
decoded = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=self.text_featurizer.vocab_array,
beam_size=self.text_featurizer.decoder_config["beam_width"],
beam_size=self.text_featurizer.decoder_config.beam_width,
ext_scoring_func=self.text_featurizer.scorer if lm else None
)
decoded = decoded[0][-1]
Expand All @@ -122,13 +113,13 @@ def recognize_beam_tflite(self, signal):
features = self.speech_featurizer.tf_extract(signal)
features = tf.expand_dims(features, axis=0)
input_length = shape_list(features)[1]
input_length = input_length // self.base_model.time_reduction_factor
input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
input_length = tf.expand_dims(input_length, axis=0)
logits = self(features, training=False)
probs = tf.nn.softmax(logits)
decoded = tf.keras.backend.ctc_decode(
y_pred=probs, input_length=input_length, greedy=False,
beam_width=self.text_featurizer.decoder_config["beam_width"]
beam_width=self.text_featurizer.decoder_config.beam_width
)
decoded = tf.cast(decoded[0][0][0], dtype=tf.int32)
transcript = self.text_featurizer.indices2upoints(decoded)
Expand Down
Loading