Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support variable input batch and SortaGrad. #74

Merged
merged 4 commits into from
Jun 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 63 additions & 35 deletions deep_speech_2/audio_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import random
import soundfile
import numpy as np
import itertools
import os

RANDOM_SEED = 0
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(self,
self.__stride_ms__ = stride_ms
self.__window_ms__ = window_ms
self.__max_frequency__ = max_frequency
self.__epoc__ = 0
self.__random__ = random.Random(RANDOM_SEED)
# load vocabulary (dictionary)
self.__vocab_dict__, self.__vocab_list__ = \
Expand Down Expand Up @@ -245,43 +247,56 @@ def __padding_batch__(self, batch, padding_to=-1, flatten=False):
new_batch.append((padded_audio, text))
return new_batch

def instance_reader_creator(self,
manifest_path,
sort_by_duration=True,
shuffle=False):
def __batch_shuffle__(self, manifest, batch_size):
    """
    Sort the manifest by duration and shuffle it batch-wise.

    The instances have different lengths and they cannot be
    combined into a single matrix multiplication. It usually
    sorts the training examples by length and combines only
    similarly-sized instances into minibatches, pads with
    silence when necessary so that all instances in a batch
    have the same length. This batch shuffle function is used
    to make similarly-sized instances into minibatches and
    make a batch-wise shuffle.

    1. Sort the audio clips by duration.
    2. Generate a random number `k`, k in [0, batch_size).
    3. Set aside the first `k` instances so that the minibatch boundaries
       differ between calls, then group the rest into minibatches of
       exactly batch_size instances.
    4. Shuffle the minibatches.
    5. Append the instances that did not fill a complete minibatch,
       followed by the `k` instances set aside in step 3.

    The returned list always contains every instance of `manifest`
    exactly once.

    NOTE(review): batch_size is genuinely required here, which is why the
    parameter was not renamed to batch_shuffle_size.

    :param manifest: Manifest entries; each entry must provide a
                     "duration" key. The list is sorted in place.
    :type manifest: list
    :param batch_size: Batch size. This size is also used for generate
                       a random number for batch shuffle.
    :type batch_size: int
    :return: Batch shuffled manifest.
    :rtype: list
    """
    manifest.sort(key=lambda x: x["duration"])
    shift_len = self.__random__.randint(0, batch_size - 1)
    # Group the shifted manifest into full batches. zip() silently drops the
    # trailing instances that cannot fill a whole batch. Materialize the
    # result so that shuffle() also works where zip() returns an iterator.
    batches = list(zip(*[iter(manifest[shift_len:])] * batch_size))
    self.__random__.shuffle(batches)
    batch_manifest = [instance for batch in batches for instance in batch]
    # Index of the first instance dropped by zip(). Slicing from an explicit
    # start index (instead of a negative length) stays correct when nothing
    # was dropped: `manifest[-0:]` would be the whole list and would
    # duplicate every instance.
    tail_start = shift_len + len(batch_manifest)
    batch_manifest.extend(manifest[tail_start:])
    batch_manifest.extend(manifest[:shift_len])
    return batch_manifest

def instance_reader_creator(self, manifest):
"""
Instance reader creator for audio data. Create a callable function to
produce instances of data.

Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
tokenized and indexed transcription text.

:param manifest_path: Filepath of manifest for audio clip files.
:type manifest_path: basestring
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:param manifest: Manifest entries of the audio clips, iterated in the given order.
:type manifest: list
:return: Data reader function.
:rtype: callable
"""
if sort_by_duration and shuffle:
sort_by_duration = False
logger.warn("When shuffle set to true, "
"sort_by_duration is forced to set False.")

def reader():
# read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path,
max_duration=self.__max_duration__,
min_duration=self.__min_duration__)
# sort (by duration) or shuffle manifest
if sort_by_duration:
manifest.sort(key=lambda x: x["duration"])
if shuffle:
self.__random__.shuffle(manifest)
# extract spectrogram feature
for instance in manifest:
spectrogram = self.__audio_featurize__(
Expand All @@ -296,8 +311,8 @@ def batch_reader_creator(self,
batch_size,
padding_to=-1,
flatten=False,
sort_by_duration=True,
shuffle=False):
sortagrad=False,
batch_shuffle=False):
"""
Batch data reader creator for audio data. Create a callable function to
produce batches of data.
Expand All @@ -317,20 +332,32 @@ def batch_reader_creator(self,
:param flatten: If set True, audio data will be flatten to be a 1-dim
ndarray. Otherwise, 2-dim ndarray. Default is False.
:type flatten: bool
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:param sortagrad: Sort the audio clips by duration in the first epoch
if set True.
:type sortagrad: bool
:param batch_shuffle: Shuffle the audio clips if set True. It is
not a thorough instance-wise shuffle, but a
specific batch-wise shuffle. For more details,
please see `__batch_shuffle__` function.
:type batch_shuffle: bool
:return: Batch reader function, producing batches of data when called.
:rtype: callable
"""

def batch_reader():
instance_reader = self.instance_reader_creator(
# read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path,
sort_by_duration=sort_by_duration,
shuffle=shuffle)
max_duration=self.__max_duration__,
min_duration=self.__min_duration__)

# sort (by duration) or shuffle manifest
if self.__epoc__ == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"])
elif batch_shuffle:
manifest = self.__batch_shuffle__(manifest, batch_size)

instance_reader = self.instance_reader_creator(manifest)
batch = []
for instance in instance_reader():
batch.append(instance)
Expand All @@ -339,6 +366,7 @@ def batch_reader():
batch = []
if len(batch) > 0:
yield self.__padding_batch__(batch, padding_to, flatten)
self.__epoc__ += 1

return batch_reader

Expand Down
65 changes: 24 additions & 41 deletions deep_speech_2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,27 @@ def train():
"""
DeepSpeech2 training.
"""

# initialize data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
def data_generator():
return DataGenerator(
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)

train_generator = data_generator()
test_generator = data_generator()
# create network config
dict_size = data_generator.vocabulary_size()
dict_size = train_generator.vocabulary_size()
# paddle.data_type.dense_array is used for variable batch input.
# the size 161 * 161 is only a placeholder value and the real shape
# of input batch data will be set at each batch.
audio_data = paddle.layer.data(
name="audio_spectrogram",
height=161,
width=2000,
type=paddle.data_type.dense_vector(322000))
name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size))
Expand All @@ -136,28 +140,16 @@ def train():
cost=cost, parameters=parameters, update_equation=optimizer)

# prepare data reader
train_batch_reader_sortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=True,
shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=True)
test_batch_reader = data_generator.batch_reader_creator(
sortagrad=True if args.init_model_path is None else False,
batch_shuffle=True)
test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=False)
feeding = data_generator.data_name_feeding()
batch_shuffle=False)
feeding = train_generator.data_name_feeding()

# create event handler
def event_handler(event):
Expand All @@ -183,17 +175,8 @@ def event_handler(event):
time.time() - start_time, event.pass_id, result.cost)

# run train
# first pass with sortagrad
if args.use_sortagrad:
trainer.train(
reader=train_batch_reader_sortagrad,
event_handler=event_handler,
num_passes=1,
feeding=feeding)
args.num_passes -= 1
# other passes without sortagrad
trainer.train(
reader=train_batch_reader_nosortagrad,
reader=train_batch_reader,
event_handler=event_handler,
num_passes=args.num_passes,
feeding=feeding)
Expand Down