diff --git a/deep_speech_2/README.md b/deep_speech_2/README.md
new file mode 100644
index 0000000000..48ee9f9a1b
--- /dev/null
+++ b/deep_speech_2/README.md
@@ -0,0 +1,58 @@
+# Deep Speech 2 on PaddlePaddle
+
+## Quick Start
+
+### Installation
+
+Please replace `$PADDLE_INSTALL_DIR` with your PaddlePaddle installation directory.
+
+```
+pip install -r requirements.txt
+export LD_LIBRARY_PATH=$PADDLE_INSTALL_DIR/Paddle/third_party/install/warpctc/lib:$LD_LIBRARY_PATH
+```
+
+On some machines, `libsndfile1` also needs to be installed. Details to be added.
+
+### Preparing Dataset(s)
+
+```
+python librispeech.py
+```
+
+More help for arguments:
+
+```
+python librispeech.py --help
+```
+
+### Training
+
+For GPU training:
+
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4
+```
+
+For CPU training:
+
+```
+python train.py --trainer_count 8 --use_gpu False
+```
+
+More help for arguments:
+
+```
+python train.py --help
+```
+
+### Inference
+
+```
+python infer.py
+```
+
+More help for arguments:
+
+```
+python infer.py --help
+```
diff --git a/deep_speech_2/audio_data_utils.py b/deep_speech_2/audio_data_utils.py
new file mode 100644
index 0000000000..c717bcf182
--- /dev/null
+++ b/deep_speech_2/audio_data_utils.py
@@ -0,0 +1,383 @@
+"""
+    Provides a basic audio data preprocessing pipeline, and offers
+    both instance-level and batch-level data reader interfaces.
+"""
+import paddle.v2 as paddle
+import logging
+import json
+import random
+import soundfile
+import numpy as np
+import os
+
+RANDOM_SEED = 0
+logger = logging.getLogger(__name__)
+
+
+class DataGenerator(object):
+    """
+    DataGenerator provides a basic audio data preprocessing pipeline, and
+    offers both instance-level and batch-level data reader interfaces.
+    Normalized FFT features are used as the audio features here.
+
+    :param vocab_filepath: Vocabulary file path for indexing tokenized
+                           transcriptions.
+    :type vocab_filepath: basestring
+    :param normalizer_manifest_path: Manifest filepath for collecting feature
+                                     normalization statistics, e.g. mean, std.
+    :type normalizer_manifest_path: basestring
+    :param normalizer_num_samples: Number of instances sampled for collecting
+                                   feature normalization statistics.
+                                   Default is 100.
+    :type normalizer_num_samples: int
+    :param max_duration: Audio clips with duration (in seconds) greater than
+                         this will be discarded. Default is 20.0.
+    :type max_duration: float
+    :param min_duration: Audio clips with duration (in seconds) smaller than
+                         this will be discarded. Default is 0.0.
+    :type min_duration: float
+    :param stride_ms: Striding size (in milliseconds) for generating frames.
+                      Default is 10.0.
+    :type stride_ms: float
+    :param window_ms: Window size (in milliseconds) for frames. Default is 20.0.
+    :type window_ms: float
+    :param max_frequency: Maximum frequency for FFT features. FFT features of
+                          frequency larger than this will be discarded.
+                          If set to None, all features will be kept.
+                          Default is None.
+ :type max_frequency: float + """ + + def __init__(self, + vocab_filepath, + normalizer_manifest_path, + normalizer_num_samples=100, + max_duration=20.0, + min_duration=0.0, + stride_ms=10.0, + window_ms=20.0, + max_frequency=None): + self.__max_duration__ = max_duration + self.__min_duration__ = min_duration + self.__stride_ms__ = stride_ms + self.__window_ms__ = window_ms + self.__max_frequency__ = max_frequency + self.__random__ = random.Random(RANDOM_SEED) + # load vocabulary (dictionary) + self.__vocab_dict__, self.__vocab_list__ = \ + self.__load_vocabulary_from_file__(vocab_filepath) + # collect normalizer statistics + self.__mean__, self.__std__ = self.__collect_normalizer_statistics__( + manifest_path=normalizer_manifest_path, + num_samples=normalizer_num_samples) + + def __audio_featurize__(self, audio_filename): + """ + Preprocess audio data, including feature extraction, normalization etc.. + """ + features = self.__audio_basic_featurize__(audio_filename) + return self.__normalize__(features) + + def __text_featurize__(self, text): + """ + Preprocess text data, including tokenizing and token indexing etc.. + """ + return self.__convert_text_to_char_index__( + text=text, vocabulary=self.__vocab_dict__) + + def __audio_basic_featurize__(self, audio_filename): + """ + Compute basic (without normalization etc.) features for audio data. + """ + return self.__spectrogram_from_file__( + filename=audio_filename, + stride_ms=self.__stride_ms__, + window_ms=self.__window_ms__, + max_freq=self.__max_frequency__) + + def __collect_normalizer_statistics__(self, manifest_path, num_samples=100): + """ + Compute feature normalization statistics, i.e. mean and stddev. + """ + # read manifest + manifest = self.__read_manifest__( + manifest_path=manifest_path, + max_duration=self.__max_duration__, + min_duration=self.__min_duration__) + # sample for statistics + sampled_manifest = self.__random__.sample(manifest, num_samples) + # extract spectrogram feature + features = [] + for instance in sampled_manifest: + spectrogram = self.__audio_basic_featurize__( + instance["audio_filepath"]) + features.append(spectrogram) + features = np.hstack(features) + mean = np.mean(features, axis=1).reshape([-1, 1]) + std = np.std(features, axis=1).reshape([-1, 1]) + return mean, std + + def __normalize__(self, features, eps=1e-14): + """ + Normalize features to be of zero mean and unit stddev. + """ + return (features - self.__mean__) / (self.__std__ + eps) + + def __spectrogram_from_file__(self, + filename, + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + eps=1e-14): + """ + Laod audio data and calculate the log of spectrogram by FFT. 
+        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
+        """
+        audio, sample_rate = soundfile.read(filename)
+        if audio.ndim >= 2:
+            audio = np.mean(audio, 1)
+        if max_freq is None:
+            max_freq = sample_rate / 2
+        if max_freq > sample_rate / 2:
+            raise ValueError("max_freq must not be greater than half of "
+                             "sample rate.")
+        if stride_ms > window_ms:
+            raise ValueError("Stride size must not be greater than "
+                             "window size.")
+        stride_size = int(0.001 * sample_rate * stride_ms)
+        window_size = int(0.001 * sample_rate * window_ms)
+        spectrogram, freqs = self.__extract_spectrogram__(
+            audio,
+            window_size=window_size,
+            stride_size=stride_size,
+            sample_rate=sample_rate)
+        ind = np.where(freqs <= max_freq)[0][-1] + 1
+        return np.log(spectrogram[:ind, :] + eps)
+
+    def __extract_spectrogram__(self, samples, window_size, stride_size,
+                                sample_rate):
+        """
+        Compute the spectrogram by FFT for a discrete real signal.
+        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
+        """
+        # extract strided windows
+        truncate_size = (len(samples) - window_size) % stride_size
+        samples = samples[:len(samples) - truncate_size]
+        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
+        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
+        windows = np.lib.stride_tricks.as_strided(
+            samples, shape=nshape, strides=nstrides)
+        assert np.all(
+            windows[:, 1] == samples[stride_size:(stride_size + window_size)])
+        # window weighting, squared Fast Fourier Transform (fft), scaling
+        weighting = np.hanning(window_size)[:, None]
+        fft = np.fft.rfft(windows * weighting, axis=0)
+        fft = np.absolute(fft)**2
+        scale = np.sum(weighting**2) * sample_rate
+        fft[1:-1, :] *= (2.0 / scale)
+        fft[(0, -1), :] /= scale
+        # prepare fft frequency list
+        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
+        return fft, freqs
+
+    def __load_vocabulary_from_file__(self, vocabulary_path):
+        """
+        Load vocabulary from file.
+        """
+        if not os.path.exists(vocabulary_path):
+            raise ValueError("Vocabulary file %s not found." % vocabulary_path)
+        vocab_lines = []
+        with open(vocabulary_path, 'r') as file:
+            vocab_lines.extend(file.readlines())
+        vocab_list = [line[:-1] for line in vocab_lines]
+        vocab_dict = dict(
+            [(token, id) for (id, token) in enumerate(vocab_list)])
+        return vocab_dict, vocab_list
+
+    def __convert_text_to_char_index__(self, text, vocabulary):
+        """
+        Convert text string to a list of character index integers.
+        """
+        return [vocabulary[w] for w in text]
+
+    def __read_manifest__(self, manifest_path, max_duration, min_duration):
+        """
+        Load and parse manifest file.
+        """
+        manifest = []
+        for json_line in open(manifest_path):
+            try:
+                json_data = json.loads(json_line)
+            except Exception as e:
+                raise ValueError("Error reading manifest: %s" % str(e))
+            if (json_data["duration"] <= max_duration and
+                    json_data["duration"] >= min_duration):
+                manifest.append(json_data)
+        return manifest
+
+    def __padding_batch__(self, batch, padding_to=-1, flatten=False):
+        """
+        Pad the audio part of the features (only along the time axis -- the
+        column axis) with zeros, so that each instance in the batch shares
+        the same audio feature shape.
+
+        If `padding_to` is set to -1, the maximum column number in the batch
+        will be used as the target size. Otherwise, `padding_to` will be the
+        target size. Default is -1.
+
+        If `flatten` is set to True, audio data will be flattened into a 1-dim
+        ndarray. Default is False.
+ """ + new_batch = [] + # get target shape + max_length = max([audio.shape[1] for audio, text in batch]) + if padding_to != -1: + if padding_to < max_length: + raise ValueError("If padding_to is not -1, it should be greater" + " or equal to the original instance length.") + max_length = padding_to + # padding + for audio, text in batch: + padded_audio = np.zeros([audio.shape[0], max_length]) + padded_audio[:, :audio.shape[1]] = audio + if flatten: + padded_audio = padded_audio.flatten() + new_batch.append((padded_audio, text)) + return new_batch + + def instance_reader_creator(self, + manifest_path, + sort_by_duration=True, + shuffle=False): + """ + Instance reader creator for audio data. Creat a callable function to + produce instances of data. + + Instance: a tuple of a numpy ndarray of audio spectrogram and a list of + tokenized and indexed transcription text. + + :param manifest_path: Filepath of manifest for audio clip files. + :type manifest_path: basestring + :param sort_by_duration: Sort the audio clips by duration if set True + (for SortaGrad). + :type sort_by_duration: bool + :param shuffle: Shuffle the audio clips if set True. + :type shuffle: bool + :return: Data reader function. + :rtype: callable + """ + if sort_by_duration and shuffle: + sort_by_duration = False + logger.warn("When shuffle set to true, " + "sort_by_duration is forced to set False.") + + def reader(): + # read manifest + manifest = self.__read_manifest__( + manifest_path=manifest_path, + max_duration=self.__max_duration__, + min_duration=self.__min_duration__) + # sort (by duration) or shuffle manifest + if sort_by_duration: + manifest.sort(key=lambda x: x["duration"]) + if shuffle: + self.__random__.shuffle(manifest) + # extract spectrogram feature + for instance in manifest: + spectrogram = self.__audio_featurize__( + instance["audio_filepath"]) + transcript = self.__text_featurize__(instance["text"]) + yield (spectrogram, transcript) + + return reader + + def batch_reader_creator(self, + manifest_path, + batch_size, + padding_to=-1, + flatten=False, + sort_by_duration=True, + shuffle=False): + """ + Batch data reader creator for audio data. Creat a callable function to + produce batches of data. + + Audio features will be padded with zeros to make each instance in the + batch to share the same audio feature shape. + + :param manifest_path: Filepath of manifest for audio clip files. + :type manifest_path: basestring + :param batch_size: Instance number in a batch. + :type batch_size: int + :param padding_to: If set -1, the maximun column numbers in the batch + will be used as the target size for padding. + Otherwise, `padding_to` will be the target size. + Default is -1. + :type padding_to: int + :param flatten: If set True, audio data will be flatten to be a 1-dim + ndarray. Otherwise, 2-dim ndarray. Default is False. + :type flatten: bool + :param sort_by_duration: Sort the audio clips by duration if set True + (for SortaGrad). + :type sort_by_duration: bool + :param shuffle: Shuffle the audio clips if set True. + :type shuffle: bool + :return: Batch reader function, producing batches of data when called. 
+ :rtype: callable + """ + + def batch_reader(): + instance_reader = self.instance_reader_creator( + manifest_path=manifest_path, + sort_by_duration=sort_by_duration, + shuffle=shuffle) + batch = [] + for instance in instance_reader(): + batch.append(instance) + if len(batch) == batch_size: + yield self.__padding_batch__(batch, padding_to, flatten) + batch = [] + if len(batch) > 0: + yield self.__padding_batch__(batch, padding_to, flatten) + + return batch_reader + + def vocabulary_size(self): + """ + Get vocabulary size. + + :return: Vocabulary size. + :rtype: int + """ + return len(self.__vocab_list__) + + def vocabulary_dict(self): + """ + Get vocabulary in dict. + + :return: Vocabulary in dict. + :rtype: dict + """ + return self.__vocab_dict__ + + def vocabulary_list(self): + """ + Get vocabulary in list. + + :return: Vocabulary in list + :rtype: list + """ + return self.__vocab_list__ + + def data_name_feeding(self): + """ + Get feeddings (data field name and corresponding field id). + + :return: Feeding dict. + :rtype: dict + """ + feeding = { + "audio_spectrogram": 0, + "transcript_text": 1, + } + return feeding diff --git a/deep_speech_2/eng_vocab.txt b/deep_speech_2/eng_vocab.txt new file mode 100644 index 0000000000..8268f3f330 --- /dev/null +++ b/deep_speech_2/eng_vocab.txt @@ -0,0 +1,28 @@ +' + +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py new file mode 100644 index 0000000000..1c52c98fd8 --- /dev/null +++ b/deep_speech_2/infer.py @@ -0,0 +1,144 @@ +""" + Inference for a simplifed version of Baidu DeepSpeech2 model. +""" + +import paddle.v2 as paddle +from itertools import groupby +import distutils.util +import argparse +import gzip +from audio_data_utils import DataGenerator +from model import deep_speech2 + +parser = argparse.ArgumentParser( + description='Simplified version of DeepSpeech2 inference.') +parser.add_argument( + "--num_samples", + default=10, + type=int, + help="Number of samples for inference. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--normalizer_manifest_path", + default='./manifest.libri.train-clean-100', + type=str, + help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--decode_manifest_path", + default='./manifest.libri.test-clean', + type=str, + help="Manifest path for decoding. (default: %(default)s)") +parser.add_argument( + "--model_filepath", + default='./params.tar.gz', + type=str, + help="Model filepath. (default: %(default)s)") +args = parser.parse_args() + + +def remove_duplicate_and_blank(id_list, blank_id): + """ + Postprocessing for max-ctc-decoder. + - remove consecutive duplicate tokens. + - remove blanks. + """ + # remove consecutive duplicate tokens + id_list = [x[0] for x in groupby(id_list)] + # remove blanks + return [id for id in id_list if id != blank_id] + + +def best_path_decode(): + """ + Max-ctc-decoding for DeepSpeech2. 
+    """
+    # initialize data generator
+    data_generator = DataGenerator(
+        vocab_filepath='eng_vocab.txt',
+        normalizer_manifest_path=args.normalizer_manifest_path,
+        normalizer_num_samples=200,
+        max_duration=20.0,
+        min_duration=0.0,
+        stride_ms=10,
+        window_ms=20)
+    # create network config
+    dict_size = data_generator.vocabulary_size()
+    vocab_list = data_generator.vocabulary_list()
+    audio_data = paddle.layer.data(
+        name="audio_spectrogram",
+        height=161,
+        width=2000,
+        type=paddle.data_type.dense_vector(322000))
+    text_data = paddle.layer.data(
+        name="transcript_text",
+        type=paddle.data_type.integer_value_sequence(dict_size))
+    _, max_id = deep_speech2(
+        audio_data=audio_data,
+        text_data=text_data,
+        dict_size=dict_size,
+        num_conv_layers=args.num_conv_layers,
+        num_rnn_layers=args.num_rnn_layers,
+        rnn_size=args.rnn_layer_size)
+
+    # load parameters
+    parameters = paddle.parameters.Parameters.from_tar(
+        gzip.open(args.model_filepath))
+
+    # prepare infer data
+    feeding = data_generator.data_name_feeding()
+    test_batch_reader = data_generator.batch_reader_creator(
+        manifest_path=args.decode_manifest_path,
+        batch_size=args.num_samples,
+        padding_to=2000,
+        flatten=True,
+        sort_by_duration=False,
+        shuffle=False)
+    infer_data = test_batch_reader().next()
+
+    # run max-ctc-decoding
+    max_id_results = paddle.infer(
+        output_layer=max_id,
+        parameters=parameters,
+        input=infer_data,
+        field=['id'])
+
+    # postprocess
+    instance_length = len(max_id_results) / args.num_samples
+    instance_list = [
+        max_id_results[i * instance_length:(i + 1) * instance_length]
+        for i in xrange(0, args.num_samples)
+    ]
+    for i, instance in enumerate(instance_list):
+        id_list = remove_duplicate_and_blank(instance, dict_size)
+        output_transcript = ''.join([vocab_list[id] for id in id_list])
+        target_transcript = ''.join([vocab_list[id] for id in infer_data[i][1]])
+        print("Target Transcript: %s \nOutput Transcript: %s \n" %
+              (target_transcript, output_transcript))


def main():
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    best_path_decode()


if __name__ == '__main__':
    main()
diff --git a/deep_speech_2/librispeech.py b/deep_speech_2/librispeech.py
new file mode 100644
index 0000000000..676bbec5ce
--- /dev/null
+++ b/deep_speech_2/librispeech.py
@@ -0,0 +1,138 @@
+"""
+    Download, unpack and create manifest for the LibriSpeech dataset.
+
+    A manifest is a JSON file with each line containing one audio clip
+    filepath, its transcription text string, and its duration. It serves as a
+    unified interface to organize different data sets.
+"""
+
+import paddle.v2 as paddle
+from paddle.v2.dataset.common import md5file
+import os
+import wget
+import tarfile
+import argparse
+import soundfile
+import json
+
+DATA_HOME = os.path.expanduser('~/.cache2/paddle/dataset/speech')
+
+URL_ROOT = "http://www.openslr.org/resources/12"
+URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz"
+URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz"
+URL_DEV_CLEAN = URL_ROOT + "/dev-clean.tar.gz"
+URL_DEV_OTHER = URL_ROOT + "/dev-other.tar.gz"
+URL_TRAIN_CLEAN_100 = URL_ROOT + "/train-clean-100.tar.gz"
+URL_TRAIN_CLEAN_360 = URL_ROOT + "/train-clean-360.tar.gz"
+URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz"
+
+MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9"
+MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1"
+MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
+MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
+MD5_TRAIN_CLEAN_500 = "d1a0fd59409feb2c614ce4d30c387708"
+
+parser = argparse.ArgumentParser(
+    description='Download and prepare the LibriSpeech dataset.')
+parser.add_argument(
+    "--target_dir",
+    default=DATA_HOME + "/Libri",
+    type=str,
+    help="Directory to save the dataset. (default: %(default)s)")
+parser.add_argument(
+    "--manifest_prefix",
+    default="manifest.libri",
+    type=str,
+    help="Filepath prefix for output manifests. (default: %(default)s)")
+args = parser.parse_args()
+
+
+def download(url, md5sum, target_dir):
+    """
+    Download file from url to target_dir, and check md5sum.
+    """
+    if not os.path.exists(target_dir): os.makedirs(target_dir)
+    filepath = os.path.join(target_dir, url.split("/")[-1])
+    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
+        print("Downloading %s ..." % url)
+        wget.download(url, target_dir)
+        print("\nChecking MD5 checksum for %s ..." % filepath)
+        assert md5file(filepath) == md5sum, "MD5 checksum failed."
+    return filepath
+
+
+def unpack(filepath, target_dir):
+    """
+    Unpack the file to the target_dir.
+    """
+    print("Unpacking %s ..." % filepath)
+    tar = tarfile.open(filepath)
+    tar.extractall(target_dir)
+    tar.close()
+    return target_dir
+
+
+def create_manifest(data_dir, manifest_path):
+    """
+    Create a manifest file summarizing the dataset (a list of filepaths and
+    metadata).
+
+    Each line of the manifest contains one audio clip filepath, its
+    transcription text string, and its duration. The manifest file serves as
+    a unified interface to organize data sets.
+    """
+    print("Creating manifest %s ..." % manifest_path)
+    json_lines = []
+    for subfolder, _, filelist in os.walk(data_dir):
+        text_filelist = [
+            filename for filename in filelist if filename.endswith('trans.txt')
+        ]
+        if len(text_filelist) > 0:
+            text_filepath = os.path.join(data_dir, subfolder, text_filelist[0])
+            for line in open(text_filepath):
+                segments = line.strip().split()
+                text = ' '.join(segments[1:]).lower()
+                audio_filepath = os.path.join(data_dir, subfolder,
+                                              segments[0] + '.flac')
+                audio_data, samplerate = soundfile.read(audio_filepath)
+                duration = float(len(audio_data)) / samplerate
+                json_lines.append(
+                    json.dumps({
+                        'audio_filepath': audio_filepath,
+                        'duration': duration,
+                        'text': text
+                    }))
+    with open(manifest_path, 'w') as out_file:
+        for line in json_lines:
+            out_file.write(line + '\n')
+
+
+def prepare_dataset(url, md5sum, target_dir, manifest_path):
+    """
+    Download, unpack and create summary manifest file.
+ """ + filepath = download(url, md5sum, target_dir) + unpacked_dir = unpack(filepath, target_dir) + create_manifest(unpacked_dir, manifest_path) + + +def main(): + prepare_dataset( + url=URL_TEST_CLEAN, + md5sum=MD5_TEST_CLEAN, + target_dir=os.path.join(args.target_dir, "test-clean"), + manifest_path=args.manifest_prefix + ".test-clean") + prepare_dataset( + url=URL_DEV_CLEAN, + md5sum=MD5_DEV_CLEAN, + target_dir=os.path.join(args.target_dir, "dev-clean"), + manifest_path=args.manifest_prefix + ".dev-clean") + prepare_dataset( + url=URL_TRAIN_CLEAN_100, + md5sum=MD5_TRAIN_CLEAN_100, + target_dir=os.path.join(args.target_dir, "train-clean-100"), + manifest_path=args.manifest_prefix + ".train-clean-100") + + +if __name__ == '__main__': + main() diff --git a/deep_speech_2/model.py b/deep_speech_2/model.py new file mode 100644 index 0000000000..6b396900e6 --- /dev/null +++ b/deep_speech_2/model.py @@ -0,0 +1,136 @@ +""" + A simplifed version of Baidu DeepSpeech2 model. +""" + +import paddle.v2 as paddle + +#TODO: add bidirectional rnn. + + +def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride, + padding, act): + """ + Convolution layer with batch normalization. + """ + conv_layer = paddle.layer.img_conv( + input=input, + filter_size=filter_size, + num_channels=num_channels_in, + num_filters=num_channels_out, + stride=stride, + padding=padding, + act=paddle.activation.Linear(), + bias_attr=False) + return paddle.layer.batch_norm(input=conv_layer, act=act) + + +def bidirectional_simple_rnn_bn_layer(name, input, size, act): + """ + Bidirectonal simple rnn layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + """ + # input-hidden weights shared across bi-direcitonal rnn. + input_proj = paddle.layer.fc( + input=input, size=size, act=paddle.activation.Linear(), bias_attr=False) + # batch norm is only performed on input-state projection + input_proj_bn = paddle.layer.batch_norm( + input=input_proj, act=paddle.activation.Linear()) + # forward and backward in time + forward_simple_rnn = paddle.layer.recurrent( + input=input_proj_bn, act=act, reverse=False) + backward_simple_rnn = paddle.layer.recurrent( + input=input_proj_bn, act=act, reverse=True) + return paddle.layer.concat(input=[forward_simple_rnn, backward_simple_rnn]) + + +def conv_group(input, num_stacks): + """ + Convolution group with several stacking convolution layers. + """ + conv = conv_bn_layer( + input=input, + filter_size=(11, 41), + num_channels_in=1, + num_channels_out=32, + stride=(3, 2), + padding=(5, 20), + act=paddle.activation.BRelu()) + for i in xrange(num_stacks - 1): + conv = conv_bn_layer( + input=conv, + filter_size=(11, 21), + num_channels_in=32, + num_channels_out=32, + stride=(1, 2), + padding=(5, 10), + act=paddle.activation.BRelu()) + output_num_channels = 32 + output_height = 160 // pow(2, num_stacks) + 1 + return conv, output_num_channels, output_height + + +def rnn_group(input, size, num_stacks): + """ + RNN group with several stacking RNN layers. + """ + output = input + for i in xrange(num_stacks): + output = bidirectional_simple_rnn_bn_layer( + name=str(i), input=output, size=size, act=paddle.activation.BRelu()) + return output + + +def deep_speech2(audio_data, + text_data, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=256): + """ + The whole DeepSpeech2 model structure (a simplified version). + + :param audio_data: Audio spectrogram data layer. 
+ :type audio_data: LayerOutput + :param text_data: Transcription text data layer. + :type text_data: LayerOutput + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (number of RNN cells). + :type rnn_size: int + :return: Tuple of the cost layer and the max_id decoder layer. + :rtype: tuple of LayerOutput + """ + # convolution group + conv_group_output, conv_group_num_channels, conv_group_height = conv_group( + input=audio_data, num_stacks=num_conv_layers) + # convert data form convolution feature map to sequence of vectors + conv2seq = paddle.layer.block_expand( + input=conv_group_output, + num_channels=conv_group_num_channels, + stride_x=1, + stride_y=1, + block_x=1, + block_y=conv_group_height) + # rnn group + rnn_group_output = rnn_group( + input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers) + # output token distribution + fc = paddle.layer.fc( + input=rnn_group_output, + size=dict_size + 1, + act=paddle.activation.Linear(), + bias_attr=True) + # ctc cost + cost = paddle.layer.warp_ctc( + input=fc, + label=text_data, + size=dict_size + 1, + blank=dict_size, + norm_by_times=True) + # max decoder + max_id = paddle.layer.max_id(input=fc) + return cost, max_id diff --git a/deep_speech_2/requirements.txt b/deep_speech_2/requirements.txt new file mode 100644 index 0000000000..58a93debe4 --- /dev/null +++ b/deep_speech_2/requirements.txt @@ -0,0 +1,2 @@ +SoundFile==0.9.0.post1 +wget==3.2 diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py new file mode 100644 index 0000000000..ad6e5ffd1b --- /dev/null +++ b/deep_speech_2/train.py @@ -0,0 +1,190 @@ +""" + Trainer for a simplifed version of Baidu DeepSpeech2 model. +""" + +import paddle.v2 as paddle +import distutils.util +import argparse +import gzip +import time +import sys +from model import deep_speech2 +from audio_data_utils import DataGenerator +import numpy as np + +#TODO: add WER metric + +parser = argparse.ArgumentParser( + description='Simplified version of DeepSpeech2 trainer.') +parser.add_argument( + "--batch_size", default=32, type=int, help="Minibatch size.") +parser.add_argument( + "--num_passes", + default=20, + type=int, + help="Training pass number. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--adam_learning_rate", + default=5e-4, + type=float, + help="Learning rate for ADAM Optimizer. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--use_sortagrad", + default=False, + type=distutils.util.strtobool, + help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--trainer_count", + default=4, + type=int, + help="Trainer number. (default: %(default)s)") +parser.add_argument( + "--normalizer_manifest_path", + default='./manifest.libri.train-clean-100', + type=str, + help="Manifest path for normalizer. 
(default: %(default)s)") +parser.add_argument( + "--train_manifest_path", + default='./manifest.libri.train-clean-100', + type=str, + help="Manifest path for training. (default: %(default)s)") +parser.add_argument( + "--dev_manifest_path", + default='./manifest.libri.dev-clean', + type=str, + help="Manifest path for validation. (default: %(default)s)") +args = parser.parse_args() + + +def train(): + """ + DeepSpeech2 training. + """ + # initialize data generator + data_generator = DataGenerator( + vocab_filepath='eng_vocab.txt', + normalizer_manifest_path=args.normalizer_manifest_path, + normalizer_num_samples=200, + max_duration=20.0, + min_duration=0.0, + stride_ms=10, + window_ms=20) + + # create network config + dict_size = data_generator.vocabulary_size() + audio_data = paddle.layer.data( + name="audio_spectrogram", + height=161, + width=2000, + type=paddle.data_type.dense_vector(322000)) + text_data = paddle.layer.data( + name="transcript_text", + type=paddle.data_type.integer_value_sequence(dict_size)) + cost, _ = deep_speech2( + audio_data=audio_data, + text_data=text_data, + dict_size=dict_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_size=args.rnn_layer_size) + + # create parameters and optimizer + parameters = paddle.parameters.create(cost) + optimizer = paddle.optimizer.Adam( + learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400) + trainer = paddle.trainer.SGD( + cost=cost, parameters=parameters, update_equation=optimizer) + + # prepare data reader + train_batch_reader_sortagrad = data_generator.batch_reader_creator( + manifest_path=args.train_manifest_path, + batch_size=args.batch_size // args.trainer_count, + padding_to=2000, + flatten=True, + sort_by_duration=True, + shuffle=False) + train_batch_reader_nosortagrad = data_generator.batch_reader_creator( + manifest_path=args.train_manifest_path, + batch_size=args.batch_size // args.trainer_count, + padding_to=2000, + flatten=True, + sort_by_duration=False, + shuffle=True) + test_batch_reader = data_generator.batch_reader_creator( + manifest_path=args.dev_manifest_path, + batch_size=args.batch_size // args.trainer_count, + padding_to=2000, + flatten=True, + sort_by_duration=False, + shuffle=False) + feeding = data_generator.data_name_feeding() + + # create event handler + def event_handler(event): + global start_time + global cost_sum + global cost_counter + if isinstance(event, paddle.event.EndIteration): + cost_sum += event.cost + cost_counter += 1 + if event.batch_id % 50 == 0: + print "\nPass: %d, Batch: %d, TrainCost: %f" % ( + event.pass_id, event.batch_id, cost_sum / cost_counter) + cost_sum, cost_counter = 0.0, 0 + with gzip.open("params.tar.gz", 'w') as f: + parameters.to_tar(f) + else: + sys.stdout.write('.') + sys.stdout.flush() + if isinstance(event, paddle.event.BeginPass): + start_time = time.time() + cost_sum, cost_counter = 0.0, 0 + if isinstance(event, paddle.event.EndPass): + result = trainer.test(reader=test_batch_reader, feeding=feeding) + print "\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % ( + time.time() - start_time, event.pass_id, result.cost) + + # run train + # first pass with sortagrad + if args.use_sortagrad: + trainer.train( + reader=train_batch_reader_sortagrad, + event_handler=event_handler, + num_passes=1, + feeding=feeding) + args.num_passes -= 1 + # other passes without sortagrad + trainer.train( + reader=train_batch_reader_nosortagrad, + event_handler=event_handler, + num_passes=args.num_passes, + feeding=feeding) + + 
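+# Module-level defaults for the globals read by event_handler inside train().
+# This is a defensive sketch: the BeginPass branch re-initializes these values
+# at the start of every pass, so the defaults only matter if EndIteration were
+# ever to fire first.
+start_time = time.time()
+cost_sum = 0.0
+cost_counter = 0
+
+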
+def main(): + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) + train() + + +if __name__ == '__main__': + main()
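
The `DataGenerator` docstrings above describe instance-level and batch-level readers, but the diff only exercises them through `train.py` and `infer.py`. A minimal usage sketch, not part of the diff and assuming the manifests produced by `librispeech.py` (and the audio files they reference) are available locally:

```
# Sketch: pull one padded batch through DataGenerator's batch reader,
# assuming the ./manifest.libri.* manifests have been created by librispeech.py.
from audio_data_utils import DataGenerator

generator = DataGenerator(
    vocab_filepath='eng_vocab.txt',
    normalizer_manifest_path='./manifest.libri.train-clean-100',
    normalizer_num_samples=200)
batch_reader = generator.batch_reader_creator(
    manifest_path='./manifest.libri.dev-clean',
    batch_size=4,
    padding_to=2000,
    flatten=True,
    sort_by_duration=False,
    shuffle=False)
first_batch = next(batch_reader())
for padded_audio, transcript_ids in first_batch:
    # padded_audio: flattened, zero-padded spectrogram (161 x 2000 values);
    # transcript_ids: character indices into eng_vocab.txt.
    print("audio values: %d, transcript length: %d" %
          (len(padded_audio), len(transcript_ids)))
```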