diff --git a/example/speech_recognition/README.md b/example/speech_recognition/README.md index 69961b1bdc5c..12e415cf6ad7 100644 --- a/example/speech_recognition/README.md +++ b/example/speech_recognition/README.md @@ -123,3 +123,16 @@ The new file should implement two functions, prepare_data() and arch(), for buil Run the following line after preparing the files.
python main.py --configfile custom.cfg --archfile arch_custom
+ +*** +## **Further More** +You can prepare the full LibriSpeech dataset by following the instructions at https://github.com/baidu-research/ba-dls-deepspeech +```bash +git clone https://github.com/baidu-research/ba-dls-deepspeech +cd ba-dls-deepspeech +./download.sh +./flac_to_wav.sh +python create_desc_json.py /path/to/ba-dls-deepspeech/LibriSpeech/train-clean-100 train_corpus.json +python create_desc_json.py /path/to/ba-dls-deepspeech/LibriSpeech/dev-clean validation_corpus.json +python create_desc_json.py /path/to/ba-dls-deepspeech/LibriSpeech/test-clean test_corpus.json +``` diff --git a/example/speech_recognition/arch_deepspeech.py b/example/speech_recognition/arch_deepspeech.py index 92f1002a2f01..99d96c6c281b 100644 --- a/example/speech_recognition/arch_deepspeech.py +++ b/example/speech_recognition/arch_deepspeech.py @@ -1,6 +1,12 @@ +# pylint: disable=C0111, too-many-statements, too-many-locals +# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme +# pylint: disable=superfluous-parens, no-member, invalid-name +""" +architecture file for deep speech 2 model +""" import json import math - +import argparse import mxnet as mx from stt_layer_batchnorm import batchnorm @@ -13,6 +19,9 @@ def prepare_data(args): + """ + set actual shape of data + """ rnn_type = args.config.get("arch", "rnn_type") num_rnn_layer = args.config.getint("arch", "num_rnn_layer") num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list")) @@ -20,26 +29,29 @@ def prepare_data(args): batch_size = args.config.getint("common", "batch_size") if rnn_type == 'lstm': - init_c = [('l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) for l in range(num_rnn_layer)] - init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in range(num_rnn_layer)] + init_c = [('l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) + for l in range(num_rnn_layer)] + init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) + for l in range(num_rnn_layer)] elif rnn_type == 'bilstm': - forward_init_c = [('forward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) for l in - range(num_rnn_layer)] - backward_init_c = [('backward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) for l in - range(num_rnn_layer)] + forward_init_c = [('forward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) + for l in range(num_rnn_layer)] + backward_init_c = [('backward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) + for l in range(num_rnn_layer)] init_c = forward_init_c + backward_init_c - forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in - range(num_rnn_layer)] - backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in - range(num_rnn_layer)] + forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) + for l in range(num_rnn_layer)] + backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) + for l in range(num_rnn_layer)] init_h = forward_init_h + backward_init_h elif rnn_type == 'gru': - init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in range(num_rnn_layer)] + init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) + for l in range(num_rnn_layer)] elif rnn_type == 'bigru': - forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in - range(num_rnn_layer)] - backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for
l in - range(num_rnn_layer)] + forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) + for l in range(num_rnn_layer)] + backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) + for l in range(num_rnn_layer)] init_h = forward_init_h + backward_init_h else: raise Exception('network type should be one of the lstm,bilstm,gru,bigru') @@ -51,115 +63,146 @@ def prepare_data(args): return init_states -def arch(args): - mode = args.config.get("common", "mode") - if mode == "train": - channel_num = args.config.getint("arch", "channel_num") - conv_layer1_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim"))) - conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride"))) - conv_layer2_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim"))) - conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride"))) - - rnn_type = args.config.get("arch", "rnn_type") - num_rnn_layer = args.config.getint("arch", "num_rnn_layer") - num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list")) - - is_batchnorm = args.config.getboolean("arch", "is_batchnorm") - - seq_len = args.config.getint('arch', 'max_t_count') - num_label = args.config.getint('arch', 'max_label_length') - - num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers") - num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list")) - act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list")) - # model symbol generation - # input preparation - data = mx.sym.Variable('data') - label = mx.sym.Variable('label') - - net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0)) - net = conv(net=net, - channels=channel_num, - filter_dimension=conv_layer1_filter_dim, - stride=conv_layer1_stride, - no_bias=is_batchnorm - ) - if is_batchnorm: - # batch norm normalizes axis 1 - net = batchnorm(net) - - net = conv(net=net, - channels=channel_num, - filter_dimension=conv_layer2_filter_dim, - stride=conv_layer2_stride, - no_bias=is_batchnorm - ) - if is_batchnorm: - # batch norm normalizes axis 1 - net = batchnorm(net) - net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3)) - net = mx.sym.Reshape(data=net, shape=(0, 0, -3)) - seq_len_after_conv_layer1 = int( - math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1 - seq_len_after_conv_layer2 = int( - math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) / conv_layer2_stride[0])) + 1 - net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1) - if rnn_type == "bilstm": - net = bi_lstm_unroll(net=net, +def arch(args, seq_len=None): + """ + define deep speech 2 network + """ + if isinstance(args, argparse.Namespace): + mode = args.config.get("common", "mode") + if mode == "train": + channel_num = args.config.getint("arch", "channel_num") + conv_layer1_filter_dim = \ + tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim"))) + conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride"))) + conv_layer2_filter_dim = \ + tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim"))) + conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride"))) + + rnn_type = args.config.get("arch", "rnn_type") + num_rnn_layer = args.config.getint("arch", "num_rnn_layer") + num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list")) + + is_batchnorm = 
args.config.getboolean("arch", "is_batchnorm") + is_bucketing = args.config.getboolean("arch", "is_bucketing") + + if seq_len is None: + seq_len = args.config.getint('arch', 'max_t_count') + + num_label = args.config.getint('arch', 'max_label_length') + + num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers") + num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list")) + act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list")) + # model symbol generation + # input preparation + data = mx.sym.Variable('data') + label = mx.sym.Variable('label') + + net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0)) + net = conv(net=net, + channels=channel_num, + filter_dimension=conv_layer1_filter_dim, + stride=conv_layer1_stride, + no_bias=is_batchnorm, + name='conv1') + if is_batchnorm: + # batch norm normalizes axis 1 + net = batchnorm(net, name="conv1_batchnorm") + + net = conv(net=net, + channels=channel_num, + filter_dimension=conv_layer2_filter_dim, + stride=conv_layer2_stride, + no_bias=is_batchnorm, + name='conv2') + # if is_batchnorm: + # # batch norm normalizes axis 1 + # net = batchnorm(net, name="conv2_batchnorm") + + net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3)) + net = mx.sym.Reshape(data=net, shape=(0, 0, -3)) + seq_len_after_conv_layer1 = int( + math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1 + seq_len_after_conv_layer2 = int( + math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) + / conv_layer2_stride[0])) + 1 + net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1) + if rnn_type == "bilstm": + net = bi_lstm_unroll(net=net, + seq_len=seq_len_after_conv_layer2, + num_hidden_lstm_list=num_hidden_rnn_list, + num_lstm_layer=num_rnn_layer, + dropout=0., + is_batchnorm=is_batchnorm, + is_bucketing=is_bucketing) + elif rnn_type == "gru": + net = gru_unroll(net=net, seq_len=seq_len_after_conv_layer2, - num_hidden_lstm_list=num_hidden_rnn_list, - num_lstm_layer=num_rnn_layer, + num_hidden_gru_list=num_hidden_rnn_list, + num_gru_layer=num_rnn_layer, dropout=0., - is_batchnorm=is_batchnorm) - elif rnn_type == "gru": - net = gru_unroll(net=net, - seq_len=seq_len_after_conv_layer2, - num_hidden_gru_list=num_hidden_rnn_list, - num_gru_layer=num_rnn_layer, - dropout=0., - is_batchnorm=is_batchnorm) - elif rnn_type == "bigru": - net = bi_gru_unroll(net=net, + is_batchnorm=is_batchnorm, + is_bucketing=is_bucketing) + elif rnn_type == "bigru": + net = bi_gru_unroll(net=net, + seq_len=seq_len_after_conv_layer2, + num_hidden_gru_list=num_hidden_rnn_list, + num_gru_layer=num_rnn_layer, + dropout=0., + is_batchnorm=is_batchnorm, + is_bucketing=is_bucketing) + else: + raise Exception('rnn_type should be one of the followings, bilstm,gru,bigru') + + # rear fc layers + net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2, + num_layer=num_rear_fc_layers, prefix="rear", + num_hidden_list=num_hidden_rear_fc_list, + act_type_list=act_type_rear_fc_list, + is_batchnorm=is_batchnorm) + # warpctc layer + net = warpctc_layer(net=net, seq_len=seq_len_after_conv_layer2, - num_hidden_gru_list=num_hidden_rnn_list, - num_gru_layer=num_rnn_layer, - dropout=0., - is_batchnorm=is_batchnorm) + label=label, + num_label=num_label, + character_classes_count= + (args.config.getint('arch', 'n_classes') + 1)) + args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2)) + return net + elif mode == 'load' or mode == 'predict': + conv_layer1_filter_dim = \ + 
tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim"))) + conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride"))) + conv_layer2_filter_dim = \ + tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim"))) + conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride"))) + if seq_len is None: + seq_len = args.config.getint('arch', 'max_t_count') + seq_len_after_conv_layer1 = int( + math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1 + seq_len_after_conv_layer2 = int( + math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) + / conv_layer2_stride[0])) + 1 + + args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2)) else: - raise Exception('rnn_type should be one of the followings, bilstm,gru,bigru') - - # rear fc layers - net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2, num_layer=num_rear_fc_layers, prefix="rear", - num_hidden_list=num_hidden_rear_fc_list, act_type_list=act_type_rear_fc_list, - is_batchnorm=is_batchnorm) - if is_batchnorm: - hidden_all = [] - # batch norm normalizes axis 1 - for seq_index in range(seq_len_after_conv_layer2): - hidden = net[seq_index] - hidden = batchnorm(hidden) - hidden_all.append(hidden) - net = hidden_all - - # warpctc layer - net = warpctc_layer(net=net, - seq_len=seq_len_after_conv_layer2, - label=label, - num_label=num_label, - character_classes_count=(args.config.getint('arch', 'n_classes') + 1) - ) - args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2)) - return net + raise Exception('mode must be the one of the followings - train,predict,load') else: - conv_layer1_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim"))) - conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride"))) - conv_layer2_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim"))) - conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride"))) - seq_len = args.config.getint('arch', 'max_t_count') - seq_len_after_conv_layer1 = int( - math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1 - seq_len_after_conv_layer2 = int( - math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) / conv_layer2_stride[0])) + 1 - args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2)) + raise Exception('type of args should be one of the argparse.' 
+ + 'Namespace for fixed length model or integer for variable length model') + + +class BucketingArch(object): + def __init__(self, args): + self.args = args + def sym_gen(self, seq_len): + args = self.args + net = arch(args, seq_len) + init_states = prepare_data(args) + init_state_names = [x[0] for x in init_states] + init_state_names.insert(0, 'data') + return net, init_state_names, ('label',) + def get_sym_gen(self): + return self.sym_gen diff --git a/example/speech_recognition/deepspeech.cfg b/example/speech_recognition/deepspeech.cfg index 13cf578c679a..224cb7bed9ce 100644 --- a/example/speech_recognition/deepspeech.cfg +++ b/example/speech_recognition/deepspeech.cfg @@ -4,12 +4,12 @@ mode = train #ex: gpu0,gpu1,gpu2,gpu3 context = gpu0,gpu1,gpu2 # checkpoint prefix, check point will be saved under checkpoints folder with prefix -prefix = deep +prefix = deep_bucket # when mode is load or predict, model will be loaded from the file name with model_file under checkpoints -model_file = deepspeechn_epoch1n_batch-0009 +model_file = deep_bucket-0001 batch_size = 12 # log will be saved by the log_filename -log_filename = deep.log +log_filename = deep_bucket.log # checkpoint set n to save checkpoints after n epoch save_checkpoint_every_n_epoch = 1 save_checkpoint_every_n_batch = 1000 @@ -18,6 +18,7 @@ tensorboard_log_dir = tblog/deep # if random_seed is -1 then it gets random seed from timestamp mx_random_seed = -1 random_seed = -1 +kvstore_option = device [data] train_json = ./train_corpus_all.json @@ -50,22 +51,18 @@ rnn_type = bigru #vanilla_lstm or fc_lstm (no effect when network_type is gru, bigru) lstm_type = fc_lstm is_batchnorm = True +is_bucketing = True +#[[0,2.3],[10,5.8],[10.8,25],[13.8,50],[15.1,75][15.8,90][29.7,100] +buckets = [200, 300, 400, 500, 600, 700, 800, 900, 1599] [train] num_epoch = 70 learning_rate = 0.0003 # constant learning rate annealing by factor learning_rate_annealing = 1.1 -# supports only sgd and adam -optimizer = sgd -# for sgd -momentum = 0.9 -# set to 0 to disable gradient clipping -clip_gradient = 0 initializer = Xavier init_scale = 2 factor_type = in -weight_decay = 0. # show progress every how nth batches show_every = 100 save_optimizer_states = True @@ -78,3 +75,23 @@ enable_logging_validation_metric = True [load] load_optimizer_states = True is_start_from_batch = True + +[optimizer] +optimizer = sgd +# define parameters for optimizer +# optimizer_params_dictionary should use " not ' as string wrapper +# sgd/nag +optimizer_params_dictionary={"momentum":0.9} +# dcasgd +# optimizer_params_dictionary={"momentum":0.9, "lamda":1.0} +# adam +# optimizer_params_dictionary={"beta1":0.9,"beta2":0.999} +# adagrad +# optimizer_params_dictionary={"eps":1e-08} +# rmsprop +# optimizer_params_dictionary={"gamma1":0.9, "gamma2":0.9,"epsilon":1e-08} +# adadelta +# optimizer_params_dictionary={"rho":0.95, "epsilon":1e-08} +# set to 0 to disable gradient clipping +clip_gradient = 100 +weight_decay = 0. 
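The new `[optimizer]` section stores `optimizer_params_dictionary` (and the `buckets` option in `[arch]`) as JSON strings, which is why the comment insists on double quotes. Below is a minimal sketch of how such values decode, assuming the training code parses them with `json.loads` the same way the rest of this example parses list-valued options; the exact loading code is not shown in this hunk, and the literal values here are only illustrative.

```python
import json

# Values as they appear in deepspeech.cfg (illustrative only).
optimizer_params_dictionary = '{"momentum":0.9}'
buckets = '[200, 300, 400, 500, 600, 700, 800, 900, 1599]'

# json.loads() accepts only strict JSON, so double quotes are mandatory;
# a single-quoted string such as "{'momentum':0.9}" would fail to parse.
optimizer_params = json.loads(optimizer_params_dictionary)  # -> {'momentum': 0.9}
bucket_list = json.loads(buckets)                            # -> [200, 300, ..., 1599]
print(optimizer_params, bucket_list)
```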
diff --git a/example/speech_recognition/default.cfg b/example/speech_recognition/default.cfg index 853a04aebbdd..96adfa766ef5 100644 --- a/example/speech_recognition/default.cfg +++ b/example/speech_recognition/default.cfg @@ -18,6 +18,7 @@ tensorboard_log_dir = tblog/libri_sample # if random_seed is -1 then it gets random seed from timestamp mx_random_seed = -1 random_seed = -1 +kvstore_option = device [data] train_json = ./Libri_sample.json @@ -50,24 +51,17 @@ rnn_type = bigru #vanilla_lstm or fc_lstm (no effect when network_type is gru, bigru) lstm_type = fc_lstm is_batchnorm = True +is_bucketing = False +buckets = [] [train] num_epoch = 70 - learning_rate = 0.005 # constant learning rate annealing by factor learning_rate_annealing = 1.1 -# supports only sgd and adam -optimizer = adam -# for sgd -momentum = 0.9 -# set to 0 to disable gradient clipping -clip_gradient = 0 - initializer = Xavier init_scale = 2 factor_type = in -weight_decay = 0.00001 # show progress every nth batches show_every = 1 save_optimizer_states = True @@ -80,3 +74,23 @@ enable_logging_validation_metric = True [load] load_optimizer_states = True is_start_from_batch = False + +[optimizer] +optimizer = adam +# define parameters for optimizer +# optimizer_params_dictionary should use " not ' as string wrapper +# sgd/nag +# optimizer_params_dictionary={"momentum":0.9} +# dcasgd +# optimizer_params_dictionary={"momentum":0.9, "lamda":1.0} +# adam +optimizer_params_dictionary={"beta1":0.9,"beta2":0.999} +# adagrad +# optimizer_params_dictionary={"eps":1e-08} +# rmsprop +# optimizer_params_dictionary={"gamma1":0.9, "gamma2":0.9,"epsilon":1e-08} +# adadelta +# optimizer_params_dictionary={"rho":0.95, "epsilon":1e-08} +# set to 0 to disable gradient clipping +clip_gradient = 0 +weight_decay = 0. 
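In `default.cfg`, bucketing is switched off and `buckets` is left empty. When bucketing is enabled with an empty list, the `BucketSTTIter` added later in this patch derives buckets from the data itself: utterance durations are converted to frame counts and every length that occurs at least `batch_size` times becomes a bucket; longer utterances are discarded with a warning. A small illustration of that derivation, using made-up durations rather than data from the example:

```python
import numpy as np

durations = [1.2, 1.2, 1.2, 2.5, 2.5, 7.3]  # hypothetical utterance lengths in seconds
batch_size = 2

# Same derivation as BucketSTTIter: express lengths in 10 ms frames, then keep
# every length that occurs at least batch_size times as a bucket.
data_lengths = [int(d * 100) for d in durations]
buckets = sorted(i for i, j in enumerate(np.bincount(data_lengths)) if j >= batch_size)
print(buckets)  # [120, 250]; the 7.3 s utterance would be discarded with a warning
```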
diff --git a/example/speech_recognition/main.py b/example/speech_recognition/main.py index 398a8a537e01..589cd74aa40f 100644 --- a/example/speech_recognition/main.py +++ b/example/speech_recognition/main.py @@ -1,33 +1,34 @@ +# pylint: disable=C0111, too-many-statements, too-many-locals, too-few-public-methods, too-many-branches +# pylint: too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme +# pylint: disable=superfluous-parens, no-member, invalid-name, anomalous-backslash-in-string, redefined-outer-name +import json +import os import sys - -sys.path.insert(0, "../../python") +from collections import namedtuple +from datetime import datetime from config_util import parse_args, parse_contexts, generate_file_path from train import do_training import mxnet as mx from stt_io_iter import STTIter from label_util import LabelUtil from log_util import LogUtil - import numpy as np from stt_datagenerator import DataGenerator from stt_metric import STTMetric -from datetime import datetime from stt_bi_graphemes_util import generate_bi_graphemes_dictionary -######################################## -########## FOR JUPYTER NOTEBOOK -import os +from stt_bucketing_module import STTBucketingModule +from stt_io_bucketingiter import BucketSTTIter +sys.path.insert(0, "../../python") + + + + # os.environ['MXNET_ENGINE_TYPE'] = "NaiveEngine" os.environ['MXNET_ENGINE_TYPE'] = "ThreadedEnginePerDevice" os.environ['MXNET_ENABLE_GPU_P2P'] = "0" - -class WHCS: - width = 0 - height = 0 - channel = 0 - stride = 0 - +WHCS = namedtuple("WHCS", ["width", "height", "channel", "stride"]) class ConfigLogger(object): def __init__(self, log): @@ -45,6 +46,8 @@ def write(self, data): def load_data(args): mode = args.config.get('common', 'mode') + if mode not in ['train', 'predict', 'load']: + raise Exception('mode must be the one of the followings - train,predict,load') batch_size = args.config.getint('common', 'batch_size') whcs = WHCS() @@ -57,7 +60,6 @@ def load_data(args): is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes') overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files') language = args.config.get('data', 'language') - is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes') labelUtil = LabelUtil.getInstance() if language == "en": @@ -65,25 +67,20 @@ def load_data(args): try: labelUtil.load_unicode_set("resources/unicodemap_en_baidu_bi_graphemes.csv") except: - raise Exception("There is no resources/unicodemap_en_baidu_bi_graphemes.csv. Please set overwrite_meta_files at train section True") + raise Exception("There is no resources/unicodemap_en_baidu_bi_graphemes.csv." 
+ + " Please set overwrite_meta_files at train section True") else: labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv") else: raise Exception("Error: Language Type: %s" % language) args.config.set('arch', 'n_classes', str(labelUtil.get_count())) - if mode == 'predict': - test_json = args.config.get('data', 'test_json') - datagen = DataGenerator(save_dir=save_dir, model_name=model_name) - datagen.load_train_data(test_json) - datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')), - np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std'))) - elif mode =="train" or mode == "load": + if mode == "train" or mode == "load": data_json = args.config.get('data', 'train_json') val_json = args.config.get('data', 'val_json') datagen = DataGenerator(save_dir=save_dir, model_name=model_name) datagen.load_train_data(data_json) - #test bigramphems + # test bigramphems if overwrite_meta_files and is_bi_graphemes: generate_bi_graphemes_dictionary(datagen.train_texts) @@ -95,62 +92,50 @@ def load_data(args): normalize_target_k = args.config.getint('train', 'normalize_target_k') datagen.sample_normalize(normalize_target_k, True) else: - datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')), - np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std'))) + datagen.get_meta_from_file( + np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')), + np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std'))) datagen.load_validation_data(val_json) elif mode == "load": # get feat_mean and feat_std to normalize dataset - datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')), - np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std'))) + datagen.get_meta_from_file( + np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')), + np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std'))) datagen.load_validation_data(val_json) - else: - raise Exception( - 'Define mode in the cfg file first. train or predict or load can be the candidate for the mode.') + elif mode == 'predict': + test_json = args.config.get('data', 'test_json') + datagen = DataGenerator(save_dir=save_dir, model_name=model_name) + datagen.load_train_data(test_json) + datagen.get_meta_from_file( + np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')), + np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std'))) is_batchnorm = args.config.getboolean('arch', 'is_batchnorm') - if batch_size == 1 and is_batchnorm: + if batch_size == 1 and is_batchnorm and (mode == 'train' or mode == 'load'): raise Warning('batch size 1 is too small for is_batchnorm') # sort file paths by its duration in ascending order to implement sortaGrad - if mode == "train" or mode == "load": max_t_count = datagen.get_max_seq_length(partition="train") - max_label_length = datagen.get_max_label_length(partition="train",is_bi_graphemes=is_bi_graphemes) + max_label_length = \ + datagen.get_max_label_length(partition="train", is_bi_graphemes=is_bi_graphemes) elif mode == "predict": max_t_count = datagen.get_max_seq_length(partition="test") - max_label_length = datagen.get_max_label_length(partition="test",is_bi_graphemes=is_bi_graphemes) - else: - raise Exception( - 'Define mode in the cfg file first. 
train or predict or load can be the candidate for the mode.') + max_label_length = \ + datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes) args.config.set('arch', 'max_t_count', str(max_t_count)) args.config.set('arch', 'max_label_length', str(max_label_length)) from importlib import import_module prepare_data_template = import_module(args.config.get('arch', 'arch_file')) init_states = prepare_data_template.prepare_data(args) - if mode == "train": - sort_by_duration = True - else: - sort_by_duration = False - - data_loaded = STTIter(partition="train", - count=datagen.count, - datagen=datagen, - batch_size=batch_size, - num_label=max_label_length, - init_states=init_states, - seq_length=max_t_count, - width=whcs.width, - height=whcs.height, - sort_by_duration=sort_by_duration, - is_bi_graphemes=is_bi_graphemes) - - if mode == 'predict': - return data_loaded, args - else: - validation_loaded = STTIter(partition="validation", - count=datagen.val_count, + sort_by_duration = (mode == "train") + is_bucketing = args.config.getboolean('arch', 'is_bucketing') + if is_bucketing: + buckets = json.loads(args.config.get('arch', 'buckets')) + data_loaded = BucketSTTIter(partition="train", + count=datagen.count, datagen=datagen, batch_size=batch_size, num_label=max_label_length, @@ -158,37 +143,87 @@ def load_data(args): seq_length=max_t_count, width=whcs.width, height=whcs.height, - sort_by_duration=False, - is_bi_graphemes=is_bi_graphemes) + sort_by_duration=sort_by_duration, + is_bi_graphemes=is_bi_graphemes, + buckets=buckets) + else: + data_loaded = STTIter(partition="train", + count=datagen.count, + datagen=datagen, + batch_size=batch_size, + num_label=max_label_length, + init_states=init_states, + seq_length=max_t_count, + width=whcs.width, + height=whcs.height, + sort_by_duration=sort_by_duration, + is_bi_graphemes=is_bi_graphemes) + + if mode == 'train' or mode == 'load': + if is_bucketing: + validation_loaded = BucketSTTIter(partition="validation", + count=datagen.val_count, + datagen=datagen, + batch_size=batch_size, + num_label=max_label_length, + init_states=init_states, + seq_length=max_t_count, + width=whcs.width, + height=whcs.height, + sort_by_duration=False, + is_bi_graphemes=is_bi_graphemes, + buckets=buckets) + else: + validation_loaded = STTIter(partition="validation", + count=datagen.val_count, + datagen=datagen, + batch_size=batch_size, + num_label=max_label_length, + init_states=init_states, + seq_length=max_t_count, + width=whcs.width, + height=whcs.height, + sort_by_duration=False, + is_bi_graphemes=is_bi_graphemes) return data_loaded, validation_loaded, args + elif mode == 'predict': + return data_loaded, args def load_model(args, contexts, data_train): # load model from model_name prefix and epoch of model_num_epoch with gpu contexts of contexts mode = args.config.get('common', 'mode') load_optimizer_states = args.config.getboolean('load', 'load_optimizer_states') - is_start_from_batch = args.config.getboolean('load','is_start_from_batch') + is_start_from_batch = args.config.getboolean('load', 'is_start_from_batch') from importlib import import_module symbol_template = import_module(args.config.get('arch', 'arch_file')) - model_loaded = symbol_template.arch(args) + is_bucketing = args.config.getboolean('arch', 'is_bucketing') if mode == 'train': + if is_bucketing: + bucketing_arch = symbol_template.BucketingArch(args) + model_loaded = bucketing_arch.get_sym_gen() + else: + model_loaded = symbol_template.arch(args) model_num_epoch = None - else: + 
elif mode == 'load' or mode == 'predict': model_file = args.config.get('common', 'model_file') model_name = os.path.splitext(model_file)[0] - model_num_epoch = int(model_name[-4:]) + if is_bucketing: + bucketing_arch = symbol_template.BucketingArch(args) + model_loaded = bucketing_arch.get_sym_gen() + else: + model_path = 'checkpoints/' + str(model_name[:-5]) - model_path = 'checkpoints/' + str(model_name[:-5]) - - data_names = [x[0] for x in data_train.provide_data] - label_names = [x[0] for x in data_train.provide_label] + data_names = [x[0] for x in data_train.provide_data] + label_names = [x[0] for x in data_train.provide_label] - model_loaded = mx.module.Module.load(prefix=model_path, epoch=model_num_epoch, context=contexts, - data_names=data_names, label_names=label_names, - load_optimizer_states=load_optimizer_states) + model_loaded = mx.module.Module.load( + prefix=model_path, epoch=model_num_epoch, context=contexts, + data_names=data_names, label_names=label_names, + load_optimizer_states=load_optimizer_states) if is_start_from_batch: import re model_num_epoch = int(re.findall('\d+', model_file)[0]) @@ -198,7 +233,8 @@ def load_model(args, contexts, data_train): if __name__ == '__main__': if len(sys.argv) <= 1: - raise Exception('cfg file path must be provided. ex)python main.py --configfile examplecfg.cfg') + raise Exception('cfg file path must be provided. ' + + 'ex)python main.py --configfile examplecfg.cfg') args = parse_args(sys.argv[1]) # set parameters from cfg file # give random seed @@ -206,9 +242,9 @@ def load_model(args, contexts, data_train): mx_random_seed = args.config.getint('common', 'mx_random_seed') # random seed for shuffling data list if random_seed != -1: - random.seed(random_seed) + np.random.seed(random_seed) # set mx.random.seed to give seed for parameter initialization - if mx_random_seed !=-1: + if mx_random_seed != -1: mx.random.seed(mx_random_seed) else: mx.random.seed(hash(datetime.now())) @@ -220,22 +256,23 @@ def load_model(args, contexts, data_train): mode = args.config.get('common', 'mode') if mode not in ['train', 'predict', 'load']: raise Exception( - 'Define mode in the cfg file first. train or predict or load can be the candidate for the mode.') + 'Define mode in the cfg file first. 
' + + 'train or predict or load can be the candidate for the mode.') # get meta file where character to number conversions are defined contexts = parse_contexts(args) num_gpu = len(contexts) batch_size = args.config.getint('common', 'batch_size') - # check the number of gpus is positive divisor of the batch size for data parallel if batch_size % num_gpu != 0: raise Exception('num_gpu should be positive divisor of batch_size') - - if mode == "predict": - data_train, args = load_data(args) - elif mode == "train" or mode == "load": + if mode == "train" or mode == "load": data_train, data_val, args = load_data(args) + elif mode == "predict": + data_train, args = load_data(args) + is_batchnorm = args.config.getboolean('arch', 'is_batchnorm') + is_bucketing = args.config.getboolean('arch', 'is_bucketing') # log current config config_logger = ConfigLogger(log) @@ -246,25 +283,56 @@ def load_model(args, contexts, data_train): # if mode is 'train', it trains the model if mode == 'train': - data_names = [x[0] for x in data_train.provide_data] - label_names = [x[0] for x in data_train.provide_label] - module = mx.mod.Module(model_loaded, context=contexts, data_names=data_names, label_names=label_names) + if is_bucketing: + module = STTBucketingModule( + sym_gen=model_loaded, + default_bucket_key=data_train.default_bucket_key, + context=contexts) + else: + data_names = [x[0] for x in data_train.provide_data] + label_names = [x[0] for x in data_train.provide_label] + module = mx.mod.Module(model_loaded, context=contexts, + data_names=data_names, label_names=label_names) do_training(args=args, module=module, data_train=data_train, data_val=data_val) # if mode is 'load', it loads model from the checkpoint and continues the training. elif mode == 'load': - do_training(args=args, module=model_loaded, data_train=data_train, data_val=data_val, begin_epoch=model_num_epoch+1) + do_training(args=args, module=model_loaded, data_train=data_train, data_val=data_val, + begin_epoch=model_num_epoch + 1) # if mode is 'predict', it predict label from the input by the input model elif mode == 'predict': # predict through data - model_loaded.bind(for_training=False, data_shapes=data_train.provide_data, - label_shapes=data_train.provide_label) + if is_bucketing: + max_t_count = args.config.getint('arch', 'max_t_count') + load_optimizer_states = args.config.getboolean('load', 'load_optimizer_states') + model_file = args.config.get('common', 'model_file') + model_name = os.path.splitext(model_file)[0] + model_num_epoch = int(model_name[-4:]) + + model_path = 'checkpoints/' + str(model_name[:-5]) + model = STTBucketingModule( + sym_gen=model_loaded, + default_bucket_key=data_train.default_bucket_key, + context=contexts) + + model.bind(data_shapes=data_train.provide_data, + label_shapes=data_train.provide_label, + for_training=True) + _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch) + model.set_params(arg_params, aux_params) + model_loaded = model + else: + model_loaded.bind(for_training=False, data_shapes=data_train.provide_data, + label_shapes=data_train.provide_label) max_t_count = args.config.getint('arch', 'max_t_count') - eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=max_t_count) - is_batchnorm = args.config.getboolean('arch', 'is_batchnorm') - if is_batchnorm : + eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu) + if is_batchnorm: for nbatch, data_batch in enumerate(data_train): - # when is_train = False it leads to high cer when batch_norm - 
model_loaded.forward(data_batch, is_train=True) + model_loaded.forward(data_batch, is_train=False) model_loaded.update_metric(eval_metric, data_batch.label) - else : - model_loaded.score(eval_data=data_train, num_batch=None, eval_metric=eval_metric, reset=True) + else: + model_loaded.score(eval_data=data_train, num_batch=None, + eval_metric=eval_metric, reset=True) + else: + raise Exception( + 'Define mode in the cfg file first. ' + + 'train or predict or load can be the candidate for the mode') diff --git a/example/speech_recognition/stt_bucketing_module.py b/example/speech_recognition/stt_bucketing_module.py new file mode 100644 index 000000000000..00ba8df19bd5 --- /dev/null +++ b/example/speech_recognition/stt_bucketing_module.py @@ -0,0 +1,13 @@ +import mxnet as mx + + +class STTBucketingModule(mx.mod.BucketingModule): + + def save_checkpoint(self, prefix, epoch, save_optimizer_states=False): + symbol, data_names, label_names = self._sym_gen(self._default_bucket_key) + symbol.save('%s-symbol.json' % prefix) + param_name = '%s-%04d.params' % (prefix, epoch) + self.save_params(param_name) + if save_optimizer_states: + state_name = '%s-%04d.states' % (prefix, epoch) + self._curr_module.save_optimizer_states(state_name) \ No newline at end of file diff --git a/example/speech_recognition/stt_datagenerator.py b/example/speech_recognition/stt_datagenerator.py index 390de432e751..a13490fc3435 100644 --- a/example/speech_recognition/stt_datagenerator.py +++ b/example/speech_recognition/stt_datagenerator.py @@ -32,7 +32,7 @@ def __init__(self, save_dir, model_name, step=10, window=20, max_freq=8000, desc # 1d 161 length of array filled with 1s self.feats_std = np.ones((self.feat_dim,)) self.max_input_length = 0 - self.max_length_list_in_batch =[] + self.max_length_list_in_batch = [] # 1d 161 length of array filled with random value #[0.0, 1.0) self.rng = random.Random() @@ -146,10 +146,11 @@ def get_max_seq_length(self, partition): "Must be train/validation/test") max_duration_indexes = durations.index(max(durations)) max_seq_length = self.featurize(audio_paths[max_duration_indexes]).shape[0] - self.max_seq_length=max_seq_length + self.max_seq_length = max_seq_length return max_seq_length - def prepare_minibatch(self, audio_paths, texts, overwrite=False, is_bi_graphemes=False): + def prepare_minibatch(self, audio_paths, texts, overwrite=False, + is_bi_graphemes=False, seq_length=-1): """ Featurize a minibatch of audio, zero pad them and return a dictionary Params: audio_paths (list(str)): List of paths to audio files @@ -167,7 +168,10 @@ def prepare_minibatch(self, audio_paths, texts, overwrite=False, is_bi_graphemes feature_dim = features[0].shape[1] mb_size = len(features) # Pad all the inputs so that they are all the same length - x = np.zeros((mb_size, self.max_seq_length, feature_dim)) + if seq_length == -1: + x = np.zeros((mb_size, self.max_seq_length, feature_dim)) + else: + x = np.zeros((mb_size, seq_length, feature_dim)) y = np.zeros((mb_size, self.max_label_length)) labelUtil = LabelUtil.getInstance() label_lengths = [] @@ -223,10 +227,13 @@ def sample_normalize(self, k_samples=1000, overwrite=False): next_feat_squared = np.square(next_feat) feat_vertically_stacked = np.concatenate((feat, next_feat)).reshape(-1, dim) feat = np.sum(feat_vertically_stacked, axis=0, keepdims=True) - feat_squared_vertically_stacked = np.concatenate((feat_squared, next_feat_squared)).reshape(-1, dim) + feat_squared_vertically_stacked = np.concatenate( + (feat_squared, next_feat_squared)).reshape(-1, dim) 
feat_squared = np.sum(feat_squared_vertically_stacked, axis=0, keepdims=True) - count = count + float(next_feat.shape[0]) + count += float(next_feat.shape[0]) self.feats_mean = feat / float(count) self.feats_std = np.sqrt(feat_squared / float(count) - np.square(self.feats_mean)) - np.savetxt(generate_file_path(self.save_dir, self.model_name, 'feats_mean'), self.feats_mean) - np.savetxt(generate_file_path(self.save_dir, self.model_name, 'feats_std'), self.feats_std) + np.savetxt( + generate_file_path(self.save_dir, self.model_name, 'feats_mean'), self.feats_mean) + np.savetxt( + generate_file_path(self.save_dir, self.model_name, 'feats_std'), self.feats_std) diff --git a/example/speech_recognition/stt_io_bucketingiter.py b/example/speech_recognition/stt_io_bucketingiter.py new file mode 100644 index 000000000000..e50715d60861 --- /dev/null +++ b/example/speech_recognition/stt_io_bucketingiter.py @@ -0,0 +1,145 @@ +from __future__ import print_function +import mxnet as mx +import sys +sys.path.insert(0, "../../python") + +import bisect +import random +import numpy as np + +BATCH_SIZE = 1 +SEQ_LENGTH = 0 +NUM_GPU = 1 + + +def get_label(buf, num_lable): + ret = np.zeros(num_lable) + for i in range(len(buf)): + ret[i] = int(buf[i]) + return ret + + +class BucketSTTIter(mx.io.DataIter): + def __init__(self, count, datagen, batch_size, num_label, init_states, seq_length, width, height, + sort_by_duration=True, + is_bi_graphemes=False, + partition="train", + buckets=[] + ): + super(BucketSTTIter, self).__init__() + + self.maxLabelLength = num_label + # global param + self.batch_size = batch_size + self.count = count + self.num_label = num_label + self.init_states = init_states + self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states] + self.width = width + self.height = height + self.datagen = datagen + self.label = None + self.is_bi_graphemes = is_bi_graphemes + # self.partition = datagen.partition + if partition == 'train': + durations = datagen.train_durations + audio_paths = datagen.train_audio_paths + texts = datagen.train_texts + elif partition == 'validation': + durations = datagen.val_durations + audio_paths = datagen.val_audio_paths + texts = datagen.val_texts + elif partition == 'test': + durations = datagen.test_durations + audio_paths = datagen.test_audio_paths + texts = datagen.test_texts + else: + raise Exception("Invalid partition to load metadata. " + "Must be train/validation/test") + # if sortagrad + if sort_by_duration: + durations, audio_paths, texts = datagen.sort_by_duration(durations, + audio_paths, + texts) + else: + durations = durations + audio_paths = audio_paths + texts = texts + self.trainDataList = zip(durations, audio_paths, texts) + + self.trainDataIter = iter(self.trainDataList) + self.is_first_epoch = True + + data_lengths = [int(d*100) for d in durations] + if len(buckets) == 0: + buckets = [i for i, j in enumerate(np.bincount(data_lengths)) + if j >= batch_size] + if len(buckets) == 0: + raise Exception('There is no valid buckets. It may occured by large batch_size for each buckets. 
max bincount:%d batch_size:%d' % (max(np.bincount(data_lengths)), batch_size)) + buckets.sort() + ndiscard = 0 + self.data = [[] for _ in buckets] + for i, sent in enumerate(data_lengths): + buck = bisect.bisect_left(buckets, sent) + if buck == len(buckets): + ndiscard += 1 + continue + self.data[buck].append(self.trainDataList[i]) + if ndiscard != 0: + print("WARNING: discarded %d sentences longer than the largest bucket."% ndiscard) + + self.buckets = buckets + self.nddata = [] + self.ndlabel = [] + self.default_bucket_key = max(buckets) + + self.idx = [] + for i, buck in enumerate(self.data): + self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)]) + self.curr_idx = 0 + + self.provide_data = [('data', (self.batch_size, self.default_bucket_key , width * height))] + init_states + self.provide_label = [('label', (self.batch_size, self.maxLabelLength))] + + #self.reset() + + def reset(self): + """Resets the iterator to the beginning of the data.""" + self.curr_idx = 0 + random.shuffle(self.idx) + for buck in self.data: + np.random.shuffle(buck) + + def next(self): + """Returns the next batch of data.""" + if self.curr_idx == len(self.idx): + raise StopIteration + i, j = self.idx[self.curr_idx] + self.curr_idx += 1 + + audio_paths = [] + texts = [] + for duration, audio_path, text in self.data[i][j:j+self.batch_size]: + audio_paths.append(audio_path) + texts.append(text) + + if self.is_first_epoch: + data_set = self.datagen.prepare_minibatch(audio_paths, texts, overwrite=True, + is_bi_graphemes=self.is_bi_graphemes, + seq_length=self.buckets[i] + ) + else: + data_set = self.datagen.prepare_minibatch(audio_paths, texts, overwrite=False, + is_bi_graphemes=self.is_bi_graphemes, + seq_length=self.buckets[i]) + + data_all = [mx.nd.array(data_set['x'])] + self.init_state_arrays + label_all = [mx.nd.array(data_set['y'])] + + self.label = label_all + provide_data = [('data', (self.batch_size, self.buckets[i], self.width * self.height))] + self.init_states + + return mx.io.DataBatch(data_all, label_all, pad=0, + bucket_key=self.buckets[i], + provide_data=provide_data, + provide_label=self.provide_label) diff --git a/example/speech_recognition/stt_io_iter.py b/example/speech_recognition/stt_io_iter.py index 70c31ce92dde..51029167bc4e 100644 --- a/example/speech_recognition/stt_io_iter.py +++ b/example/speech_recognition/stt_io_iter.py @@ -103,7 +103,6 @@ def __iter__(self): data_batch = SimpleBatch(data_names, data_all, label_names, label_all) yield data_batch - self.is_first_epoch = False def reset(self): pass diff --git a/example/speech_recognition/stt_layer_batchnorm.py b/example/speech_recognition/stt_layer_batchnorm.py index 86e75aa49557..5b73f4f9f890 100644 --- a/example/speech_recognition/stt_layer_batchnorm.py +++ b/example/speech_recognition/stt_layer_batchnorm.py @@ -6,7 +6,7 @@ def batchnorm(net, beta=None, eps=0.001, momentum=0.9, - fix_gamma=True, + fix_gamma=False, use_global_stats=False, output_mean_var=False, name=None): @@ -18,7 +18,8 @@ def batchnorm(net, momentum=momentum, fix_gamma=fix_gamma, use_global_stats=use_global_stats, - output_mean_var=output_mean_var + output_mean_var=output_mean_var, + name=name ) else: net = mx.sym.BatchNorm(data=net, @@ -26,6 +27,7 @@ def batchnorm(net, momentum=momentum, fix_gamma=fix_gamma, use_global_stats=use_global_stats, - output_mean_var=output_mean_var + output_mean_var=output_mean_var, + name=name ) return net diff --git a/example/speech_recognition/stt_layer_conv.py b/example/speech_recognition/stt_layer_conv.py 
index 5ec292557f04..ab0035e4803b 100644 --- a/example/speech_recognition/stt_layer_conv.py +++ b/example/speech_recognition/stt_layer_conv.py @@ -8,20 +8,22 @@ def conv(net, weight=None, bias=None, act_type="relu", - no_bias=False + no_bias=False, + name=None ): # 2d convolution's input should have the shape of 4D (batch_size,1,seq_len,feat_dim) if weight is None or bias is None: # ex) filter_dimension = (41,11) , stride=(2,2) - net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, no_bias=no_bias) + net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, no_bias=no_bias, + name=name) elif weight is None or bias is not None: net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, bias=bias, - no_bias=no_bias) + no_bias=no_bias, name=name) elif weight is not None or bias is None: net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, weight=weight, - no_bias=no_bias) + no_bias=no_bias, name=name) else: net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, weight=weight, - bias=bias, no_bias=no_bias) + bias=bias, no_bias=no_bias, name=name) net = mx.sym.Activation(data=net, act_type=act_type) return net diff --git a/example/speech_recognition/stt_layer_fc.py b/example/speech_recognition/stt_layer_fc.py index b3db2034a3ad..f435922426c5 100644 --- a/example/speech_recognition/stt_layer_fc.py +++ b/example/speech_recognition/stt_layer_fc.py @@ -8,29 +8,30 @@ def fc(net, act_type, weight=None, bias=None, - no_bias=False + no_bias=False, + name=None ): # when weight and bias doesn't have specific name if weight is None and bias is None: - net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, no_bias=no_bias) + net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, no_bias=no_bias, name=name) # when weight doesn't have specific name but bias has elif weight is None and bias is not None: if no_bias: - net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, no_bias=no_bias) + net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, no_bias=no_bias, name=name) else: - net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, bias=bias, no_bias=no_bias) + net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, bias=bias, no_bias=no_bias, name=name) # when bias doesn't have specific name but weight has elif weight is not None and bias is None: - net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, no_bias=no_bias) + net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, no_bias=no_bias, name=name) # when weight and bias specific name else: if no_bias: - net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, no_bias=no_bias) + net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, no_bias=no_bias, name=name) else: - net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, bias=bias, no_bias=no_bias) + net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, bias=bias, no_bias=no_bias, name=name) # activation if act_type is not None: - net = mx.sym.Activation(data=net, act_type=act_type) + net = mx.sym.Activation(data=net, act_type=act_type, name="%s_activation" % name) return net @@ -41,7 +42,7 @@ def sequence_fc(net, num_hidden_list=[], act_type_list=[], is_batchnorm=False, - dropout_rate=0 + dropout_rate=0, ): if num_layer == len(num_hidden_list) == 
len(act_type_list): if num_layer > 0: @@ -81,13 +82,16 @@ def sequence_fc(net, num_hidden=num_hidden_list[layer_index], act_type=None, weight=weight_list[layer_index], - no_bias=is_batchnorm + no_bias=is_batchnorm, + name="%s_t%d_l%d_fc" % (prefix, seq_index, layer_index) ) # last layer doesn't have batchnorm hidden = batchnorm(net=hidden, gamma=gamma_list[layer_index], - beta=beta_list[layer_index]) - hidden = mx.sym.Activation(data=hidden, act_type=act_type_list[layer_index]) + beta=beta_list[layer_index], + name="%s_t%d_l%d_batchnorm" % (prefix, seq_index, layer_index)) + hidden = mx.sym.Activation(data=hidden, act_type=act_type_list[layer_index], + name="%s_t%d_l%d_activation" % (prefix, seq_index, layer_index)) else: hidden = fc(net=hidden, num_hidden=num_hidden_list[layer_index], diff --git a/example/speech_recognition/stt_layer_gru.py b/example/speech_recognition/stt_layer_gru.py index 8b044746dfcf..89af1c72216d 100644 --- a/example/speech_recognition/stt_layer_gru.py +++ b/example/speech_recognition/stt_layer_gru.py @@ -15,7 +15,7 @@ "param_blocks"]) -def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., is_batchnorm=False, gamma=None, beta=None): +def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., is_batchnorm=False, gamma=None, beta=None, name=None): """ GRU Cell symbol Reference: @@ -31,7 +31,10 @@ def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., is_ name="t%d_l%d_gates_i2h" % (seqidx, layeridx)) if is_batchnorm: - i2h = batchnorm(net=i2h, gamma=gamma, beta=beta) + if name is not None: + i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name) + else: + i2h = batchnorm(net=i2h, gamma=gamma, beta=beta) h2h = mx.sym.FullyConnected(data=prev_state.h, weight=param.gates_h2h_weight, bias=param.gates_h2h_bias, @@ -53,15 +56,15 @@ def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., is_ weight=param.trans_h2h_weight, bias=param.trans_h2h_bias, num_hidden=num_hidden, - name="t%d_l%d_trans_i2h" % (seqidx, layeridx)) + name="t%d_l%d_trans_h2h" % (seqidx, layeridx)) h_trans = htrans_i2h + htrans_h2h h_trans_active = mx.sym.Activation(h_trans, act_type="tanh") next_h = prev_state.h + update_gate * (h_trans_active - prev_state.h) return GRUState(h=next_h) -def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False, prefix="", - direction="forward"): +def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False, prefix="", + direction="forward", is_bucketing=False): if num_gru_layer > 0: param_cells = [] last_states = [] @@ -81,9 +84,14 @@ def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_ if is_batchnorm: batchnorm_gamma = [] batchnorm_beta = [] - for seqidx in range(seq_len): - batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx)) - batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx)) + if is_bucketing: + for l in range(num_gru_layer): + batchnorm_gamma.append(mx.sym.Variable(prefix + "l%d_i2h_gamma" % l)) + batchnorm_beta.append(mx.sym.Variable(prefix + "l%d_i2h_beta" % l)) + else: + for seqidx in range(seq_len): + batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx)) + batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx)) hidden_all = [] for seqidx in range(seq_len): @@ -103,19 +111,33 @@ def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_ else: dp_ratio = 
dropout if is_batchnorm: - next_state = gru(num_hidden_gru_list[i], indata=hidden, - prev_state=last_states[i], - param=param_cells[i], - seqidx=k, layeridx=i, dropout=dp_ratio, - is_batchnorm=is_batchnorm, - gamma=batchnorm_gamma[k], - beta=batchnorm_beta[k]) + if is_bucketing: + next_state = gru(num_hidden_gru_list[i], indata=hidden, + prev_state=last_states[i], + param=param_cells[i], + seqidx=k, layeridx=i, dropout=dp_ratio, + is_batchnorm=is_batchnorm, + gamma=batchnorm_gamma[i], + beta=batchnorm_beta[i], + name=prefix + ("t%d_l%d" % (seqidx, i)) + ) + else: + next_state = gru(num_hidden_gru_list[i], indata=hidden, + prev_state=last_states[i], + param=param_cells[i], + seqidx=k, layeridx=i, dropout=dp_ratio, + is_batchnorm=is_batchnorm, + gamma=batchnorm_gamma[k], + beta=batchnorm_beta[k], + name=prefix + ("t%d_l%d" % (seqidx, i)) + ) else: next_state = gru(num_hidden_gru_list[i], indata=hidden, prev_state=last_states[i], param=param_cells[i], seqidx=k, layeridx=i, dropout=dp_ratio, - is_batchnorm=is_batchnorm) + is_batchnorm=is_batchnorm, + name=prefix) hidden = next_state.h last_states[i] = next_state # decoder @@ -133,7 +155,7 @@ def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_ return net -def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False): +def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False, is_bucketing=False): if num_gru_layer > 0: net_forward = gru_unroll(net=net, num_gru_layer=num_gru_layer, @@ -142,7 +164,8 @@ def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., dropout=dropout, is_batchnorm=is_batchnorm, prefix="forward_", - direction="forward") + direction="forward", + is_bucketing=is_bucketing) net_backward = gru_unroll(net=net, num_gru_layer=num_gru_layer, seq_len=seq_len, @@ -150,7 +173,8 @@ def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., dropout=dropout, is_batchnorm=is_batchnorm, prefix="backward_", - direction="backward") + direction="backward", + is_bucketing=is_bucketing) hidden_all = [] for i in range(seq_len): hidden_all.append(mx.sym.Concat(*[net_forward[i], net_backward[i]], dim=1)) @@ -159,7 +183,7 @@ def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., def bi_gru_unroll_two_input_two_output(net1, net2, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., - is_batchnorm=False): + is_batchnorm=False, is_bucketing=False): if num_gru_layer > 0: net_forward = gru_unroll(net=net1, num_gru_layer=num_gru_layer, @@ -168,7 +192,8 @@ def bi_gru_unroll_two_input_two_output(net1, net2, num_gru_layer, seq_len, num_h dropout=dropout, is_batchnorm=is_batchnorm, prefix="forward_", - direction="forward") + direction="forward", + is_bucketing=is_bucketing) net_backward = gru_unroll(net=net2, num_gru_layer=num_gru_layer, seq_len=seq_len, @@ -176,7 +201,8 @@ def bi_gru_unroll_two_input_two_output(net1, net2, num_gru_layer, seq_len, num_h dropout=dropout, is_batchnorm=is_batchnorm, prefix="backward_", - direction="backward") + direction="backward", + is_bucketing=is_bucketing) return net_forward, net_backward else: return net1, net2 diff --git a/example/speech_recognition/stt_layer_lstm.py b/example/speech_recognition/stt_layer_lstm.py index 19e37369b1b0..93b4ca09b908 100644 --- a/example/speech_recognition/stt_layer_lstm.py +++ b/example/speech_recognition/stt_layer_lstm.py @@ -16,7 +16,7 @@ "param_blocks"]) -def vanilla_lstm(num_hidden, indata, prev_state, param, 
seqidx, layeridx, is_batchnorm=False, gamma=None, beta=None): +def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, is_batchnorm=False, gamma=None, beta=None, name=None): """LSTM Cell symbol""" i2h = mx.sym.FullyConnected(data=indata, weight=param.i2h_weight, @@ -24,7 +24,10 @@ def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, is_bat num_hidden=num_hidden * 4, name="t%d_l%d_i2h" % (seqidx, layeridx)) if is_batchnorm: - i2h = batchnorm(net=i2h, gamma=gamma, beta=beta) + if name is not None: + i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name) + else: + i2h = batchnorm(net=i2h, gamma=gamma, beta=beta) h2h = mx.sym.FullyConnected(data=prev_state.h, weight=param.h2h_weight, bias=param.h2h_bias, @@ -43,7 +46,7 @@ def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, is_bat def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., num_hidden_proj=0, is_batchnorm=False, - gamma=None, beta=None): + gamma=None, beta=None, name=None): """LSTM Cell symbol""" # dropout input if dropout > 0.: @@ -55,7 +58,10 @@ def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., nu num_hidden=num_hidden * 4, name="t%d_l%d_i2h" % (seqidx, layeridx)) if is_batchnorm: - i2h = batchnorm(net=i2h, gamma=gamma, beta=beta) + if name is not None: + i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name) + else: + i2h = batchnorm(net=i2h, gamma=gamma, beta=beta) h2h = mx.sym.FullyConnected(data=prev_state.h, weight=param.h2h_weight, @@ -96,7 +102,7 @@ def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., nu def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0., num_hidden_proj=0, - lstm_type='fc_lstm', is_batchnorm=False, prefix="", direction="forward"): + lstm_type='fc_lstm', is_batchnorm=False, prefix="", direction="forward", is_bucketing=False): if num_lstm_layer > 0: param_cells = [] last_states = [] @@ -121,9 +127,14 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0., if is_batchnorm: batchnorm_gamma = [] batchnorm_beta = [] - for seqidx in range(seq_len): - batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx)) - batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx)) + if is_bucketing: + for l in range(num_lstm_layer): + batchnorm_gamma.append(mx.sym.Variable(prefix + "l%d_i2h_gamma" % l)) + batchnorm_beta.append(mx.sym.Variable(prefix + "l%d_i2h_beta" % l)) + else: + for seqidx in range(seq_len): + batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx)) + batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx)) hidden_all = [] for seqidx in range(seq_len): @@ -145,18 +156,20 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0., if lstm_type == 'fc_lstm': if is_batchnorm: - next_state = lstm(num_hidden_lstm_list[i], - indata=hidden, - prev_state=last_states[i], - param=param_cells[i], - seqidx=k, - layeridx=i, - dropout=dp, - num_hidden_proj=num_hidden_proj, - is_batchnorm=is_batchnorm, - gamma=batchnorm_gamma[k], - beta=batchnorm_beta[k] - ) + if is_bucketing: + next_state = lstm(num_hidden_lstm_list[i], + indata=hidden, + prev_state=last_states[i], + param=param_cells[i], + seqidx=k, + layeridx=i, + dropout=dp, + num_hidden_proj=num_hidden_proj, + is_batchnorm=is_batchnorm, + gamma=batchnorm_gamma[i], + beta=batchnorm_beta[i], + name=prefix + ("t%d_l%d" % (seqidx, i)) + ) else: 
                        next_state = lstm(num_hidden_lstm_list[i],
                                          indata=hidden,
@@ -166,7 +179,8 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
                                          layeridx=i,
                                          dropout=dp,
                                          num_hidden_proj=num_hidden_proj,
-                                          is_batchnorm=is_batchnorm
+                                          is_batchnorm=is_batchnorm,
+                                          name=prefix + ("t%d_l%d" % (seqidx, i))
                                           )
                 elif lstm_type == 'vanilla_lstm':
                     if is_batchnorm:
@@ -175,15 +189,17 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
                                                   param=param_cells[i],
                                                   seqidx=k, layeridx=i,
                                                   is_batchnorm=is_batchnorm,
-                                                  gamma=batchnorm_gamma[k],
-                                                  beta=batchnorm_beta[k]
+                                                  gamma=batchnorm_gamma[i],
+                                                  beta=batchnorm_beta[i],
+                                                  name=prefix + ("t%d_l%d" % (seqidx, i))
                                                   )
                     else:
                         next_state = vanilla_lstm(num_hidden_lstm_list[i], indata=hidden,
                                                   prev_state=last_states[i], param=param_cells[i],
                                                   seqidx=k, layeridx=i,
-                                                  is_batchnorm=is_batchnorm
+                                                  is_batchnorm=is_batchnorm,
+                                                  name=prefix + ("t%d_l%d" % (seqidx, i))
                                                   )
                 else:
                     raise Exception("lstm type %s error" % lstm_type)
@@ -206,7 +222,7 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
 def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0., num_hidden_proj=0,
-                   lstm_type='fc_lstm', is_batchnorm=False):
+                   lstm_type='fc_lstm', is_batchnorm=False, is_bucketing=False):
     if num_lstm_layer > 0:
         net_forward = lstm_unroll(net=net,
                                   num_lstm_layer=num_lstm_layer,
@@ -217,7 +233,8 @@ def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0
                                   lstm_type=lstm_type,
                                   is_batchnorm=is_batchnorm,
                                   prefix="forward_",
-                                  direction="forward")
+                                  direction="forward",
+                                  is_bucketing=is_bucketing)
         net_backward = lstm_unroll(net=net,
                                    num_lstm_layer=num_lstm_layer,
@@ -228,7 +245,8 @@ def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0
                                    lstm_type=lstm_type,
                                    is_batchnorm=is_batchnorm,
                                    prefix="backward_",
-                                   direction="backward")
+                                   direction="backward",
+                                   is_bucketing=is_bucketing)
         hidden_all = []
         for i in range(seq_len):
             hidden_all.append(mx.sym.Concat(*[net_forward[i], net_backward[i]], dim=1))
@@ -239,7 +257,9 @@ def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0
 # bilistm_2to1
 def bi_lstm_unroll_two_input_two_output(net1, net2, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
                                         num_hidden_proj=0,
-                                        lstm_type='fc_lstm', is_batchnorm=False):
+                                        lstm_type='fc_lstm',
+                                        is_batchnorm=False,
+                                        is_bucketing=False):
     if num_lstm_layer > 0:
         net_forward = lstm_unroll(net=net1,
                                   num_lstm_layer=num_lstm_layer,
@@ -250,7 +270,8 @@ def bi_lstm_unroll_two_input_two_output(net1, net2, num_lstm_layer, seq_len, num
                                   lstm_type=lstm_type,
                                   is_batchnorm=is_batchnorm,
                                   prefix="forward_",
-                                  direction="forward")
+                                  direction="forward",
+                                  is_bucketing=is_bucketing)
         net_backward = lstm_unroll(net=net2,
                                    num_lstm_layer=num_lstm_layer,
@@ -261,7 +282,8 @@ def bi_lstm_unroll_two_input_two_output(net1, net2, num_lstm_layer, seq_len, num
                                    lstm_type=lstm_type,
                                    is_batchnorm=is_batchnorm,
                                    prefix="backward_",
-                                   direction="backward")
+                                   direction="backward",
+                                   is_bucketing=is_bucketing)
         return net_forward, net_backward
     else:
         return net1, net2
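Hedged usage sketch for the `is_bucketing` flag added to the unroll helpers above: it is meant to be passed down from the architecture's per-bucket symbol generator. The snippet below shows one plausible way such a generator could call `bi_lstm_unroll`; the surrounding wiring and the numeric values are illustrative assumptions, not the example's actual `arch_deepspeech.py` code.

```python
# Illustrative only: per-bucket symbol generator using the updated helper.
import mxnet as mx
from stt_layer_lstm import bi_lstm_unroll

def sym_gen(seq_len):
    """Build one unrolled symbol for a given bucket length (sketch)."""
    data = mx.sym.Variable('data')
    # Split (batch, seq_len, feat) into per-time-step symbols, which is the
    # list-like form the unroll helpers index as net[seqidx].
    net = mx.sym.SliceChannel(data=data, num_outputs=seq_len,
                              axis=1, squeeze_axis=True)
    # is_bucketing=True switches batch-norm parameters to per-layer naming,
    # so every bucket length binds against the same set of weights.
    return bi_lstm_unroll(net,
                          num_lstm_layer=3,                    # placeholder values
                          seq_len=seq_len,
                          num_hidden_lstm_list=[1024, 1024, 1024],
                          dropout=0.3,
                          is_batchnorm=True,
                          is_bucketing=True)
```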
diff --git a/example/speech_recognition/stt_metric.py b/example/speech_recognition/stt_metric.py
index 0fc2bd11d906..1c5f4408a60e 100644
--- a/example/speech_recognition/stt_metric.py
+++ b/example/speech_recognition/stt_metric.py
@@ -19,12 +19,11 @@ def check_label_shapes(labels, preds, shape=0):
 class STTMetric(mx.metric.EvalMetric):
-    def __init__(self, batch_size, num_gpu, seq_length, is_epoch_end=False, is_logging=True):
+    def __init__(self, batch_size, num_gpu, is_epoch_end=False, is_logging=True):
         super(STTMetric, self).__init__('STTMetric')
         self.batch_size = batch_size
         self.num_gpu = num_gpu
-        self.seq_length = seq_length
         self.total_n_label = 0
         self.total_l_dist = 0
         self.is_epoch_end = is_epoch_end
@@ -37,15 +36,17 @@ def update(self, labels, preds):
         log = LogUtil().getlogger()
         labelUtil = LabelUtil.getInstance()
         self.batch_loss = 0.
+
         for label, pred in zip(labels, preds):
             label = label.asnumpy()
             pred = pred.asnumpy()
-            for i in range(int(int(self.batch_size) / int(self.num_gpu))):
+            seq_length = len(pred) / int(int(self.batch_size) / int(self.num_gpu))
+            for i in range(int(int(self.batch_size) / int(self.num_gpu))):
                 l = remove_blank(label[i])
                 p = []
-                for k in range(int(self.seq_length)):
+                for k in range(int(seq_length)):
                     p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
                 p = pred_best(p)
@@ -60,7 +61,7 @@ def update(self, labels, preds):
                 self.num_inst += 1
                 self.sum_metric += this_cer
                 if self.is_epoch_end:
-                    loss = ctc_loss(l, pred, i, int(self.seq_length), int(self.batch_size), int(self.num_gpu))
+                    loss = ctc_loss(l, pred, i, int(seq_length), int(self.batch_size), int(self.num_gpu))
                     self.batch_loss += loss
                     if self.is_logging:
                         log.info("loss: %f " % loss)
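Why `seq_length` can be dropped from `STTMetric`: the CTC output `pred` arrives flattened with one row per (time step, utterance) pair, so the per-batch sequence length can be recovered from `len(pred)` instead of being fixed from the config's `max_t_count`. That is what lets the metric work with bucketed batches whose lengths vary. A small worked example with made-up numbers:

```python
# Made-up shapes, only to show the arithmetic used in STTMetric.update().
import numpy as np

batch_size, num_gpu = 8, 2            # hypothetical values
seq_length, num_classes = 300, 29     # hypothetical values

batch_per_gpu = int(batch_size / num_gpu)                    # 4 utterances per GPU slice
pred = np.zeros((seq_length * batch_per_gpu, num_classes))   # flattened CTC output

recovered = len(pred) / batch_per_gpu                        # 300.0, cast to int before use
assert int(recovered) == seq_length

# Row layout: time step k of utterance i sits at pred[k * batch_per_gpu + i]
row = 5 * batch_per_gpu + 2                                  # time step 5, utterance 2
```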
diff --git a/example/speech_recognition/train.py b/example/speech_recognition/train.py
index 37f00fc4dd90..708e3e03acf0 100644
--- a/example/speech_recognition/train.py
+++ b/example/speech_recognition/train.py
@@ -7,7 +7,8 @@
 from stt_metric import STTMetric
 #tensorboard setting
 from tensorboard import SummaryWriter
-import numpy as np
+import json
+
 def get_initializer(args):
@@ -28,6 +29,7 @@ def __init__(self, learning_rate=0.001):
     def __call__(self, num_update):
         return self.learning_rate
+
+
 def do_training(args, module, data_train, data_val, begin_epoch=0):
     from distutils.dir_util import mkpath
     from log_util import LogUtil
@@ -35,7 +37,7 @@ def do_training(args, module, data_train, data_val, begin_epoch=0):
     log = LogUtil().getlogger()
     mkpath(os.path.dirname(get_checkpoint_path(args)))
-    seq_len = args.config.get('arch', 'max_t_count')
+    #seq_len = args.config.get('arch', 'max_t_count')
     batch_size = args.config.getint('common', 'batch_size')
     save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
     save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
@@ -44,21 +46,22 @@ def do_training(args, module, data_train, data_val, begin_epoch=0):
     contexts = parse_contexts(args)
     num_gpu = len(contexts)
-    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,is_logging=enable_logging_validation_metric,is_epoch_end=True)
+    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_validation_metric,is_epoch_end=True)
     # tensorboard setting
-    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,is_logging=enable_logging_train_metric,is_epoch_end=False)
+    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_train_metric,is_epoch_end=False)
-    optimizer = args.config.get('train', 'optimizer')
-    momentum = args.config.getfloat('train', 'momentum')
+    optimizer = args.config.get('optimizer', 'optimizer')
     learning_rate = args.config.getfloat('train', 'learning_rate')
     learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')
     mode = args.config.get('common', 'mode')
     num_epoch = args.config.getint('train', 'num_epoch')
-    clip_gradient = args.config.getfloat('train', 'clip_gradient')
-    weight_decay = args.config.getfloat('train', 'weight_decay')
+    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
+    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
     save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
     show_every = args.config.getint('train', 'show_every')
+    optimizer_params_dictionary = json.loads(args.config.get('optimizer', 'optimizer_params_dictionary'))
+    kvstore_option = args.config.get('common', 'kvstore_option')
     n_epoch=begin_epoch
     if clip_gradient == 0:
@@ -75,24 +78,14 @@ def do_training(args, module, data_train, data_val, begin_epoch=0):
     lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)
     def reset_optimizer(force_init=False):
-        if optimizer == "sgd":
-            module.init_optimizer(kvstore='device',
-                                  optimizer=optimizer,
-                                  optimizer_params={'lr_scheduler': lr_scheduler,
-                                                    'momentum': momentum,
-                                                    'clip_gradient': clip_gradient,
-                                                    'wd': weight_decay},
-                                  force_init=force_init)
-        elif optimizer == "adam":
-            module.init_optimizer(kvstore='device',
-                                  optimizer=optimizer,
-                                  optimizer_params={'lr_scheduler': lr_scheduler,
-                                                    #'momentum': momentum,
-                                                    'clip_gradient': clip_gradient,
-                                                    'wd': weight_decay},
-                                  force_init=force_init)
-        else:
-            raise Exception('Supported optimizers are sgd and adam. If you want to implement others define them in train.py')
+        optimizer_params = {'lr_scheduler': lr_scheduler,
+                            'clip_gradient': clip_gradient,
+                            'wd': weight_decay}
+        optimizer_params.update(optimizer_params_dictionary)
+        module.init_optimizer(kvstore=kvstore_option,
+                              optimizer=optimizer,
+                              optimizer_params=optimizer_params,
+                              force_init=force_init)
     if mode == "train":
         reset_optimizer(force_init=True)
     else:
@@ -101,15 +94,23 @@ def reset_optimizer(force_init=False):
     #tensorboard setting
     tblog_dir = args.config.get('common', 'tensorboard_log_dir')
     summary_writer = SummaryWriter(tblog_dir)
+
+
+    if mode == "train":
+        sort_by_duration = True
+    else:
+        sort_by_duration = False
+
+    if not sort_by_duration:
+        data_train.reset()
+
     while True:
         if n_epoch >= num_epoch:
             break
-        loss_metric.reset()
         log.info('---------train---------')
         for nbatch, data_batch in enumerate(data_train):
-            module.forward_backward(data_batch)
             module.update()
             # tensorboard setting
@@ -136,6 +137,7 @@ def reset_optimizer(force_init=False):
         assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'
         data_train.reset()
+        data_train.is_first_epoch = False
         # tensorboard setting
         train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()