13 changes: 5 additions & 8 deletions examples/conformer/config.yml
```diff
@@ -24,14 +24,14 @@ speech_config:
   normalize_per_feature: False
 
 decoder_config:
-  vocabulary: ./vocabularies/librispeech_train_4_4076.subwords
+  vocabulary: ./vocabularies/librispeech/librispeech_train_4_1030.subwords
   target_vocab_size: 4096
   max_subword_length: 4
   blank_at_zero: True
   beam_width: 5
   norm_score: True
   corpus_files:
-    - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
+    - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
 
 model_config:
   name: conformer
@@ -74,7 +74,7 @@ learning_config:
         num_masks: 1
         mask_factor: 27
     data_paths:
-      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
+      - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
     tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
     shuffle: True
     cache: True
@@ -84,9 +84,7 @@ learning_config:
 
   eval_dataset_config:
     use_tf: True
-    data_paths:
-      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
-      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
+    data_paths: null
     tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
     shuffle: False
     cache: True
@@ -96,8 +94,7 @@ learning_config:
 
   test_dataset_config:
     use_tf: True
-    data_paths:
-      - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
+    data_paths: null
     tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
     shuffle: False
     cache: True
```
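Setting `data_paths: null` appears to leave the eval and test datasets unconfigured until a user points them at local transcript files. A hedged sketch of what a local override of the eval section might look like — the paths below are placeholders, not part of this PR:

```yaml
# Hypothetical local override: replace null with your own transcript list.
eval_dataset_config:
  use_tf: True
  data_paths:
    - /path/to/LibriSpeech/dev-clean/transcripts.tsv   # placeholder path
    - /path/to/LibriSpeech/dev-other/transcripts.tsv   # placeholder path
  tfrecords_dir: /path/to/LibriSpeech/tfrecords        # placeholder path
  shuffle: False
  cache: True
```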
5 changes: 5 additions & 0 deletions examples/conformer/train_tpu_keras_subword_conformer.py
```diff
@@ -46,6 +46,8 @@
 
 parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
 
+parser.add_argument("--saved", type=str, default=None, help="Path to saved model")
+
 args = parser.parse_args()
 
 tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
@@ -108,6 +110,9 @@
 conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size)
 conformer.summary(line_length=120)
 
+if args.saved:
+    conformer.load_weights(args.saved, by_name=True, skip_mismatch=True)
+
 optimizer = tf.keras.optimizers.Adam(
     TransformerSchedule(
         d_model=conformer.dmodel,
```
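The new `--saved` flag warm-starts training from an existing checkpoint. With `by_name=True, skip_mismatch=True`, Keras loads only layers whose names and weight shapes match, so a checkpoint trained with a different vocabulary size can still seed the encoder. A self-contained sketch of the same pattern in plain Keras (model and file names are illustrative, not from this PR; the identical change also lands in the ContextNet script below):

```python
import tensorflow as tf

# Minimal sketch of partial weight restoration via by_name/skip_mismatch.
def build_model(vocab_size: int) -> tf.keras.Model:
    return tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation="relu", name="encoder"),
        tf.keras.layers.Dense(vocab_size, name="classifier"),
    ])

pretrained = build_model(vocab_size=1030)
pretrained.build(input_shape=(None, 80))
pretrained.save_weights("pretrained.h5")  # HDF5 format is required for by_name loading

# A model with a different output size: "encoder" weights load, "classifier" is skipped.
model = build_model(vocab_size=4096)
model.build(input_shape=(None, 80))
model.load_weights("pretrained.h5", by_name=True, skip_mismatch=True)
```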
5 changes: 5 additions & 0 deletions examples/contextnet/train_tpu_keras_subword_contextnet.py
```diff
@@ -46,6 +46,8 @@
 
 parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
 
+parser.add_argument("--saved", type=str, default=None, help="Path to saved model")
+
 args = parser.parse_args()
 
 tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
@@ -108,6 +110,9 @@
 contextnet._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size)
 contextnet.summary(line_length=120)
 
+if args.saved:
+    contextnet.load_weights(args.saved, by_name=True, skip_mismatch=True)
+
 optimizer = tf.keras.optimizers.Adam(
     TransformerSchedule(
         d_model=contextnet.dmodel,
```
2 changes: 1 addition & 1 deletion setup.py
```diff
@@ -22,7 +22,7 @@
 
 setuptools.setup(
     name="TensorFlowASR",
-    version="0.7.8",
+    version="0.8.0",
     author="Huy Le Nguyen",
     author_email="nlhuy.cs.16@gmail.com",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
```
6 changes: 2 additions & 4 deletions tensorflow_asr/datasets/asr_dataset.py
```diff
@@ -22,7 +22,7 @@
 from .base_dataset import BaseDataset, BUFFER_SIZE, TFRECORD_SHARDS, AUTOTUNE
 from ..featurizers.speech_featurizers import load_and_convert_to_wav, read_raw_audio, tf_read_raw_audio, SpeechFeaturizer
 from ..featurizers.text_featurizers import TextFeaturizer
-from ..utils.utils import bytestring_feature, get_num_batches, preprocess_paths, get_nsamples_from_duration
+from ..utils.utils import bytestring_feature, get_num_batches, preprocess_paths
 
 
 class ASRDataset(BaseDataset):
@@ -54,9 +54,7 @@ def __init__(self,
     def compute_metadata(self):
         self.read_entries()
         for _, duration, indices in tqdm.tqdm(self.entries, desc=f"Computing metadata for entries in {self.stage} dataset"):
-            nsamples = get_nsamples_from_duration(duration, sample_rate=self.speech_featurizer.sample_rate)
-            # https://www.tensorflow.org/api_docs/python/tf/signal/frame
-            input_length = -(-nsamples // self.speech_featurizer.frame_step)
+            input_length = self.speech_featurizer.get_length_from_duration(duration)
             label = str(indices).split()
             label_length = len(label)
             self.speech_featurizer.update_length(input_length)
```
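The metadata pass now defers frame counting to `SpeechFeaturizer.get_length_from_duration` (defined in the featurizer diff below) instead of inlining a ceil-division. A worked check of that formula under assumed featurizer settings — 16 kHz audio, `nfft=512`, `frame_step=160`, `center=True` are typical values, not read from this PR:

```python
import math

# Worked example of get_length_from_duration (parameter values are assumptions).
duration, sample_rate, nfft, frame_step = 2.0, 16000, 512, 160

nsamples = math.ceil(duration * sample_rate)  # 32000 samples
nsamples += nfft                              # center mode pads nfft//2 samples on each side
input_length = 1 + (nsamples - nfft) // frame_step
print(input_length)                           # 1 + 32000 // 160 = 201 frames
```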
21 changes: 17 additions & 4 deletions tensorflow_asr/featurizers/speech_featurizers.py
```diff
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os
 import io
 import abc
 import six
+import math
 import numpy as np
 import librosa
 import soundfile as sf
@@ -219,6 +221,7 @@ def __init__(self, speech_config: dict):
         self.normalize_signal = speech_config.get("normalize_signal", True)
         self.normalize_feature = speech_config.get("normalize_feature", True)
         self.normalize_per_feature = speech_config.get("normalize_per_feature", False)
+        self.center = speech_config.get("center", True)
         # Length
         self.max_length = 0
@@ -232,6 +235,11 @@ def shape(self) -> list:
         """ The shape of extracted features """
         raise NotImplementedError()
 
+    def get_length_from_duration(self, duration):
+        nsamples = math.ceil(float(duration) * self.sample_rate)
+        if self.center: nsamples += self.nfft
+        return 1 + (nsamples - self.nfft) // self.frame_step  # https://www.tensorflow.org/api_docs/python/tf/signal/frame
+
     def update_length(self, length: int):
         self.max_length = max(self.max_length, length)
@@ -280,7 +288,7 @@ def shape(self) -> list:
     def stft(self, signal):
         return np.square(
             np.abs(librosa.core.stft(signal, n_fft=self.nfft, hop_length=self.frame_step,
-                                     win_length=self.frame_length, center=False, window="hann")))
+                                     win_length=self.frame_length, center=self.center, window="hann")))
 
     def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
         return librosa.power_to_db(S, ref=ref, amin=amin, top_db=top_db)
@@ -409,9 +417,14 @@ def shape(self) -> list:
         return [length, self.num_feature_bins, 1]
 
     def stft(self, signal):
-        return tf.square(
-            tf.abs(tf.signal.stft(signal, frame_length=self.frame_length,
-                                  frame_step=self.frame_step, fft_length=self.nfft, pad_end=True)))
+        if self.center: signal = tf.pad(signal, [[self.nfft // 2, self.nfft // 2]], mode="REFLECT")
+        window = tf.signal.hann_window(self.frame_length, periodic=True)
+        left_pad = (self.nfft - self.frame_length) // 2
+        right_pad = self.nfft - self.frame_length - left_pad
+        window = tf.pad(window, [[left_pad, right_pad]])
+        framed_signals = tf.signal.frame(signal, frame_length=self.nfft, frame_step=self.frame_step)
+        framed_signals *= window
+        return tf.square(tf.abs(tf.signal.rfft(framed_signals, [self.nfft])))
 
     def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
         if amin <= 0:
```
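The rewritten `TFSpeechFeaturizer.stft` mirrors librosa's `center=True` behavior: reflect-pad the signal by `nfft // 2` on each side, build a periodic Hann window of `frame_length` samples zero-padded to `nfft`, frame with hop `frame_step`, and take the squared magnitude of the rFFT. A self-contained sketch comparing this path against librosa's centered STFT — sizes and tolerances are assumptions for illustration, not values from the PR:

```python
import numpy as np
import librosa
import tensorflow as tf

nfft, frame_length, frame_step = 512, 400, 160
signal = np.random.randn(16000).astype(np.float32)

# librosa reference: centered power spectrogram, shape [freq, time]
ref = np.abs(librosa.stft(signal, n_fft=nfft, hop_length=frame_step,
                          win_length=frame_length, center=True, window="hann")) ** 2

# TF re-implementation, following the diff above
sig = tf.pad(signal, [[nfft // 2, nfft // 2]], mode="REFLECT")
window = tf.signal.hann_window(frame_length, periodic=True)
pad = nfft - frame_length
window = tf.pad(window, [[pad // 2, pad - pad // 2]])  # center the window inside the FFT frame
frames = tf.signal.frame(sig, frame_length=nfft, frame_step=frame_step) * window
power = tf.square(tf.abs(tf.signal.rfft(frames, [nfft])))  # shape [time, freq]

# Expect close agreement up to float32 rounding.
print(np.allclose(ref.T, power.numpy(), rtol=1e-2, atol=1e-2))
```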
4 changes: 2 additions & 2 deletions tensorflow_asr/losses/keras/ctc_losses.py
```diff
@@ -17,8 +17,8 @@
 
 
 class CtcLoss(tf.keras.losses.Loss):
-    def __init__(self, blank=0, global_batch_size=None, reduction=tf.keras.losses.Reduction.NONE, name=None):
-        super(CtcLoss, self).__init__(reduction=reduction, name=name)
+    def __init__(self, blank=0, global_batch_size=None, name=None):
+        super(CtcLoss, self).__init__(reduction=tf.keras.losses.Reduction.NONE, name=name)
         self.blank = blank
         self.global_batch_size = global_batch_size
```
4 changes: 2 additions & 2 deletions tensorflow_asr/losses/keras/rnnt_losses.py
```diff
@@ -17,8 +17,8 @@
 
 
 class RnntLoss(tf.keras.losses.Loss):
-    def __init__(self, blank=0, global_batch_size=None, reduction=tf.keras.losses.Reduction.NONE, name=None):
-        super(RnntLoss, self).__init__(reduction=reduction, name=name)
+    def __init__(self, blank=0, global_batch_size=None, name=None):
+        super(RnntLoss, self).__init__(reduction=tf.keras.losses.Reduction.NONE, name=name)
         self.blank = blank
         self.global_batch_size = global_batch_size
```
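Both `CtcLoss` and `RnntLoss` now pin `reduction=Reduction.NONE` instead of accepting it as a constructor argument. Under `tf.distribute`, Keras losses must return per-sample values so the caller can scale by the *global* batch size rather than the smaller per-replica batch; presumably that is the motivation for removing the override. A generic sketch of this pattern — the inner loss computation is a stand-in, not the library's CTC/RNN-T math:

```python
import tensorflow as tf

class ScaledLoss(tf.keras.losses.Loss):
    """Sketch of a distribution-friendly loss: per-sample values, manual scaling."""

    def __init__(self, global_batch_size=None, name=None):
        # Reduction.NONE is mandatory here: Keras must not average over the
        # per-replica batch when training under a distribution strategy.
        super().__init__(reduction=tf.keras.losses.Reduction.NONE, name=name)
        self.global_batch_size = global_batch_size

    def call(self, y_true, y_pred):
        # Stand-in per-sample loss; the real classes compute CTC / RNN-T here.
        per_sample = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
        return tf.nn.compute_average_loss(per_sample, global_batch_size=self.global_batch_size)
```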
57 changes: 40 additions & 17 deletions tensorflow_asr/utils/metrics.py
```diff
@@ -13,22 +13,12 @@
 # limitations under the License.
 
 from typing import Tuple
-import numpy as np
 import tensorflow as tf
 from nltk.metrics import distance
 from .utils import bytes_to_string
 
 
-def wer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
-    """Word Error Rate
-
-    Args:
-        decode (np.ndarray): array of prediction texts
-        target (np.ndarray): array of groundtruth texts
-
-    Returns:
-        tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text
-    """
+def _wer(decode, target):
     decode = bytes_to_string(decode)
     target = bytes_to_string(target)
     dis = 0.0
@@ -45,16 +35,20 @@ def wer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
     return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
 
 
-def cer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
-    """Character Error Rate
+def wer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    """Word Error Rate
 
     Args:
         decode (np.ndarray): array of prediction texts
         target (np.ndarray): array of groundtruth texts
 
     Returns:
-        tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
+        tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text
     """
+    return tf.numpy_function(_wer, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
+
+
+def _cer(decode, target):
     decode = bytes_to_string(decode)
     target = bytes_to_string(target)
     dis = 0
@@ -65,6 +59,36 @@ def cer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
     return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
 
 
+def cer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    """Character Error Rate
+
+    Args:
+        decode (np.ndarray): array of prediction texts
+        target (np.ndarray): array of groundtruth texts
+
+    Returns:
+        tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
+    """
+    return tf.numpy_function(_cer, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
+
+
+def tf_cer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    """TensorFlow Character Error Rate
+
+    Args:
+        decode (tf.Tensor): tensor shape [B]
+        target (tf.Tensor): tensor shape [B]
+
+    Returns:
+        tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
+    """
+    decode = tf.strings.bytes_split(decode)  # [B, N]
+    target = tf.strings.bytes_split(target)  # [B, M]
+    distances = tf.edit_distance(decode.to_sparse(), target.to_sparse(), normalize=False)  # [B]
+    lengths = tf.cast(target.row_lengths(axis=1), dtype=tf.float32)  # [B]
+    return tf.reduce_sum(distances), tf.reduce_sum(lengths)
+
+
 class ErrorRate(tf.keras.metrics.Metric):
     """ Metric for WER and CER """
 
@@ -75,10 +99,9 @@ def __init__(self, func, name="error_rate", **kwargs):
         self.func = func
 
     def update_state(self, decode: tf.Tensor, target: tf.Tensor):
-        n, d = tf.numpy_function(self.func, inp=[decode, target], Tout=[tf.float32, tf.float32])
+        n, d = self.func(decode, target)
         self.numerator.assign_add(n)
         self.denominator.assign_add(d)
 
     def result(self):
-        if self.denominator == 0.0: return 0.0
-        return (self.numerator / self.denominator) * 100
+        return tf.math.divide_no_nan(self.numerator, self.denominator) * 100
```
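The new `tf_cer` computes edit distance entirely in-graph: `tf.strings.bytes_split` yields ragged character tensors, `tf.edit_distance` consumes their sparse forms, and `ErrorRate.result` now divides the accumulated distance by the accumulated reference length via `divide_no_nan`. A quick usage check — the example strings are mine, not from the PR:

```python
import tensorflow as tf

decode = tf.constant(["helo wrld", "tensorflow"])
target = tf.constant(["hello world", "tensorflow"])

d = tf.strings.bytes_split(decode)  # RaggedTensor of characters, [B, N]
t = tf.strings.bytes_split(target)  # RaggedTensor of characters, [B, M]
distances = tf.edit_distance(d.to_sparse(), t.to_sparse(), normalize=False)
lengths = tf.cast(t.row_lengths(axis=1), tf.float32)

# "helo wrld" needs 2 edits to reach "hello world"; the second pair needs 0.
print(distances.numpy())  # [2. 0.]
print((tf.reduce_sum(distances) / tf.reduce_sum(lengths)).numpy() * 100)  # CER ≈ 9.52%
```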
4 changes: 0 additions & 4 deletions tensorflow_asr/utils/utils.py
```diff
@@ -238,7 +238,3 @@ def body(index, tfarray):
 
     index, tfarray = tf.while_loop(condition, body, loop_vars=[index, tfarray], swap_memory=False)
     return tfarray
-
-
-def get_nsamples_from_duration(duration, sample_rate=16000):
-    return math.ceil(float(duration) * sample_rate)
```
16 changes: 10 additions & 6 deletions tests/featurizer/test_speech_featurizer.py
```diff
@@ -16,7 +16,9 @@
 setup_environment()
 import librosa
 import numpy as np
+import tensorflow as tf
 import matplotlib.pyplot as plt
+import librosa.display
 from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio, TFSpeechFeaturizer, NumpySpeechFeaturizer
 
 
@@ -33,26 +35,28 @@ def main(argv):
         "normalize_feature": True,
         "normalize_per_feature": False,
         "num_feature_bins": 80,
+        "center": True
     }
     signal = read_raw_audio(speech_file, speech_conf["sample_rate"])
 
     nsf = NumpySpeechFeaturizer(speech_conf)
     sf = TFSpeechFeaturizer(speech_conf)
-    ft = nsf.stft(signal)
-    print(ft.shape, np.mean(ft))
-    ft = sf.stft(signal).numpy().T
+    nft = nsf.stft(signal)
+    print(nft.shape, np.mean(nft))
+    ft = sf.stft(signal).numpy()
     print(ft.shape, np.mean(ft))
-    ft = sf.extract(signal)
+    print(nft == ft)
+    ft = tf.squeeze(sf.extract(signal)).numpy().T
 
     plt.figure(figsize=(16, 2.5))
     ax = plt.gca()
     ax.set_title(f"{feature_type}", fontweight="bold")
-    librosa.display.specshow(ft.T, cmap="magma")
+    librosa.display.specshow(ft, cmap="magma")
     v1 = np.linspace(ft.min(), ft.max(), 8, endpoint=True)
     plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1)
     plt.tight_layout()
     # plt.savefig(argv[3])
-    # plt.show()
+    plt.show()
     # plt.figure(figsize=(15, 5))
     # for i in range(4):
     #     plt.subplot(2, 2, i + 1)
```
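The updated test prints `nft == ft` elementwise, which is a blunt comparison for floating-point spectrograms; a tolerance-based check is the usual pattern. A hedged sketch, assuming both arrays have first been aligned to the same `[time, freq]` layout:

```python
import numpy as np

# Sketch: tolerance-based parity check between the Numpy and TF featurizers.
def assert_spectrograms_close(nft: np.ndarray, ft: np.ndarray, atol: float = 1e-3) -> None:
    assert nft.shape == ft.shape, f"shape mismatch: {nft.shape} vs {ft.shape}"
    max_err = np.max(np.abs(nft - ft))
    assert np.allclose(nft, ft, atol=atol), f"max abs error {max_err:.6f} exceeds {atol}"
```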