Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/conformer/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ model_config:

learning_config:
augmentations:
use_tf: True
after:
time_masking:
num_masks: 10
Expand All @@ -77,7 +78,7 @@ learning_config:
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
test_paths:
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test

optimizer_config:
warmup_steps: 40000
Expand Down
16 changes: 15 additions & 1 deletion scripts/create_tfrecords.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from tensorflow_asr.configs.config import Config
from tensorflow_asr.utils.utils import preprocess_paths
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer

modes = ["train", "eval", "test"]

parser = argparse.ArgumentParser(prog="TFRecords Creation")

parser.add_argument("--mode", "-m", type=str, default=None, help=f"Mode in {modes}")

parser.add_argument("--config", type=str, default=None, help="The file path of model configuration file")

parser.add_argument("--tfrecords_dir", type=str, default=None, help="Directory to tfrecords")

parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")

parser.add_argument("--shuffle", default=False, action="store_true", help="Shuffle data or not")

parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")

parser.add_argument("transcripts", nargs="+", type=str, default=None, help="Paths to transcript files")

args = parser.parse_args()
Expand All @@ -37,8 +44,15 @@
transcripts = preprocess_paths(args.transcripts)
tfrecords_dir = preprocess_paths(args.tfrecords_dir)

config = Config(args.config)
if args.subwords and os.path.exists(args.subwords):
print("Loading subwords ...")
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
else:
raise ValueError("subwords must be set")

ASRTFRecordDataset(
data_paths=transcripts, tfrecords_dir=tfrecords_dir,
speech_featurizer=None, text_featurizer=None,
speech_featurizer=None, text_featurizer=text_featurizer,
stage=args.mode, shuffle=args.shuffle, tfrecords_shards=args.tfrecords_shards
).create_tfrecords()
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
requirements = [
"tensorflow-datasets>=3.2.1,<4.0.0",
"tensorflow-addons>=0.10.0",
"tensorflow-io>=0.17.0",
"setuptools>=47.1.1",
"librosa>=0.8.0",
"soundfile>=0.10.3",
Expand All @@ -35,7 +36,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.7.0",
version="0.7.1",
author="Huy Le Nguyen",
author_email="nlhuy.cs.16@gmail.com",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
41 changes: 38 additions & 3 deletions tensorflow_asr/augmentations/augments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import nlpaug.flow as naf

from .signal_augment import SignalCropping, SignalLoudness, SignalMask, SignalNoise, \
SignalPitch, SignalShift, SignalSpeed, SignalVtlp
from .spec_augment import FreqMasking, TimeMasking
from .spec_augment import FreqMasking, TimeMasking, TFFreqMasking, TFTimeMasking


AUGMENTATIONS = {
Expand All @@ -32,12 +33,34 @@
"vtlp": SignalVtlp
}

TFAUGMENTATIONS = {
"freq_masking": TFFreqMasking,
"time_masking": TFTimeMasking,
}


class TFAugmentationExecutor:
    """Run a fixed sequence of TensorFlow augmentations, one after another."""

    def __init__(self, augmentations: list):
        # Objects exposing an ``augment(tensor)`` method, applied in list order.
        self.augmentations = augmentations

    @tf.function
    def augment(self, inputs):
        """Feed ``inputs`` through every augmentation in order.

        Args:
            inputs: value handed to the first augmentation's ``augment``.

        Returns:
            The output of the last augmentation, or ``inputs`` unchanged when
            the list of augmentations is empty.
        """
        result = inputs
        for step in self.augmentations:
            result = step.augment(result)
        return result


class Augmentation:
def __init__(self, config: dict = None):
    """Build the "before" and "after" augmentation pipelines from ``config``.

    Args:
        config: optional mapping with the keys ``use_tf`` (bool, defaults to
            False), ``before`` and ``after`` (each a mapping of augmentation
            name to keyword arguments). A missing or empty config yields
            empty pipelines.
    """
    if not config:
        config = {}
    # Read with ``get`` rather than ``pop`` so the caller's dict is not
    # mutated and can safely be reused to build another Augmentation.
    self.use_tf = config.get("use_tf", False)
    # Pick the parser once; each pipeline is built exactly one time instead
    # of being parsed and then immediately overwritten.
    build = self.tf_parse if self.use_tf else self.parse
    # ``or {}`` covers a key that is present but empty (loaded as None).
    self.before = build(config.get("before") or {})
    self.after = build(config.get("after") or {})

@staticmethod
def parse(config: dict) -> list:
Expand All @@ -50,3 +73,15 @@ def parse(config: dict) -> list:
aug = au(**value) if value is not None else au()
augmentations.append(aug)
return naf.Sometimes(augmentations)

@staticmethod
def tf_parse(config: dict) -> "TFAugmentationExecutor":
    """Instantiate the TensorFlow augmentations named in ``config``.

    Args:
        config: mapping of augmentation name to a dict of constructor
            keyword arguments (or None to use the constructor defaults).
            None or an empty mapping produces an executor with no steps.

    Returns:
        A TFAugmentationExecutor wrapping the augmentations in config order.

    Raises:
        KeyError: if a name is not registered in TFAUGMENTATIONS.
    """
    augmentations = []
    # ``or {}`` tolerates a section that is present but empty (None).
    for key, value in (config or {}).items():
        au = TFAUGMENTATIONS.get(key, None)
        if au is None:
            raise KeyError(f"No tf augmentation named: {key}\n"
                           f"Available tf augmentations: {TFAUGMENTATIONS.keys()}")
        aug = au(**value) if value is not None else au()
        augmentations.append(aug)
    return TFAugmentationExecutor(augmentations)
66 changes: 64 additions & 2 deletions tensorflow_asr/augmentations/spec_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Augmentation on spectrogram: http://arxiv.org/abs/1904.08779 """

import numpy as np
import tensorflow as tf

from nlpaug.flow import Sequential
from nlpaug.util import Action
from nlpaug.model.spectrogram import Spectrogram
from nlpaug.augmenter.spectrogram import SpectrogramAugmenter

from ..utils.utils import shape_list

# ---------------------------- FREQ MASKING ----------------------------


Expand Down Expand Up @@ -75,6 +80,35 @@ def __init__(self,
def substitute(self, data):
return self.flow.augment(data)


class TFFreqMasking:
    """Frequency masking for spectrograms implemented with TensorFlow ops."""

    def __init__(self, num_masks: int = 1, mask_factor: float = 27):
        # Number of independent masks applied per call.
        self.num_masks = num_masks
        # Upper bound on the width (in frequency bins) of a single mask.
        self.mask_factor = mask_factor

    @tf.function
    def augment(self, spectrogram: tf.Tensor):
        """
        Masking the frequency channels (shape[1])
        Args:
            spectrogram: shape (T, num_feature_bins, V)
        Returns:
            frequency masked spectrogram
        """
        T, F, V = shape_list(spectrogram, out_type=tf.int32)
        # Integer ``tf.random.uniform`` treats ``maxval`` as exclusive and
        # requires minval < maxval. The +1 makes ``mask_factor`` itself
        # reachable and keeps ``mask_factor == 0`` valid; ``int`` guards
        # against a float coming from the config file.
        max_width = int(self.mask_factor) + 1
        for _ in range(self.num_masks):
            f = tf.random.uniform([], minval=0, maxval=max_width, dtype=tf.int32)
            f = tf.minimum(f, F)
            # f0 ranges over [0, F - f] inclusive. Without the +1 the bound
            # is 0 when f == F (which raises), and the last frequency bin
            # could never be masked.
            f0 = tf.random.uniform([], minval=0, maxval=(F - f + 1), dtype=tf.int32)
            mask = tf.concat([
                tf.ones([T, f0, V], dtype=spectrogram.dtype),
                tf.zeros([T, f, V], dtype=spectrogram.dtype),
                tf.ones([T, F - f0 - f, V], dtype=spectrogram.dtype)
            ], axis=1)
            spectrogram = spectrogram * mask
        return spectrogram


# ---------------------------- TIME MASKING ----------------------------


Expand All @@ -101,9 +135,8 @@ def mask(self, data: np.ndarray) -> np.ndarray:
"""
spectrogram = data.copy()
time = np.random.randint(0, self.mask_factor + 1)
time = min(time, spectrogram.shape[0])
time0 = np.random.randint(0, spectrogram.shape[0] - time + 1)
time = min(time, int(self.p_upperbound * spectrogram.shape[0]))
time0 = np.random.randint(0, spectrogram.shape[0] - time + 1)
spectrogram[time0:time0 + time, :, :] = 0
return spectrogram

Expand Down Expand Up @@ -139,3 +172,32 @@ def __init__(self,

def substitute(self, data):
return self.flow.augment(data)


class TFTimeMasking:
    """Time masking for spectrograms implemented with TensorFlow ops."""

    def __init__(self, num_masks: int = 1, mask_factor: float = 100, p_upperbound: float = 1.0):
        # Number of independent masks applied per call.
        self.num_masks = num_masks
        # Upper bound on the length (in frames) of a single mask.
        self.mask_factor = mask_factor
        # A single mask may cover at most this fraction of the time axis.
        self.p_upperbound = p_upperbound

    @tf.function
    def augment(self, spectrogram: tf.Tensor):
        """
        Masking the time channel (shape[0])
        Args:
            spectrogram: shape (T, num_feature_bins, V)
        Returns:
            time masked spectrogram
        """
        T, F, V = shape_list(spectrogram, out_type=tf.int32)
        # Integer ``tf.random.uniform`` treats ``maxval`` as exclusive and
        # requires minval < maxval. The +1 makes ``mask_factor`` itself
        # reachable and keeps ``mask_factor == 0`` valid; ``int`` guards
        # against a float coming from the config file.
        max_length = int(self.mask_factor) + 1
        # Loop-invariant cap on the mask length. It is also clamped to T so
        # that a p_upperbound above 1.0 cannot produce a negative offset range.
        cap = tf.cast(tf.cast(T, dtype=tf.float32) * self.p_upperbound, dtype=tf.int32)
        cap = tf.minimum(cap, T)
        for _ in range(self.num_masks):
            t = tf.random.uniform([], minval=0, maxval=max_length, dtype=tf.int32)
            t = tf.minimum(t, cap)
            # t0 ranges over [0, T - t] inclusive. Without the +1 the bound
            # is 0 when t == T (a short utterance with p_upperbound == 1.0),
            # which raises, and the last frame could never be masked.
            t0 = tf.random.uniform([], minval=0, maxval=(T - t + 1), dtype=tf.int32)
            mask = tf.concat([
                tf.ones([t0, F, V], dtype=spectrogram.dtype),
                tf.zeros([t, F, V], dtype=spectrogram.dtype),
                tf.ones([T - t0 - t, F, V], dtype=spectrogram.dtype)
            ], axis=0)
            spectrogram = spectrogram * mask
        return spectrogram
1 change: 1 addition & 0 deletions tensorflow_asr/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, config: dict = None):
self.eval_paths = preprocess_paths(config.pop("eval_paths", None))
self.test_paths = preprocess_paths(config.pop("test_paths", None))
self.tfrecords_dir = preprocess_paths(config.pop("tfrecords_dir", None))
self.use_tf = config.pop("use_tf", False)
for k, v in config.items(): setattr(self, k, v)


Expand Down
4 changes: 2 additions & 2 deletions tensorflow_asr/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
# limitations under the License.

from .base_dataset import BaseDataset
from .asr_dataset import ASRTFRecordDataset, ASRSliceDataset, ASRTFRecordTestDataset, ASRSliceTestDataset
__all__ = ['BaseDataset', 'ASRTFRecordDataset', 'ASRSliceDataset', 'ASRTFRecordTestDataset', 'ASRSliceTestDataset']
from .asr_dataset import ASRDataset, ASRTFRecordDataset, ASRSliceDataset
__all__ = ['BaseDataset', 'ASRDataset', 'ASRTFRecordDataset', 'ASRSliceDataset']
Loading