diff --git a/example/speech_recognition/README.md b/example/speech_recognition/README.md
index 69961b1bdc5c..12e415cf6ad7 100644
--- a/example/speech_recognition/README.md
+++ b/example/speech_recognition/README.md
@@ -123,3 +123,16 @@ The new file should implement two functions, prepare_data() and arch(), for buil
Run the following line after preparing the files.
python main.py --configfile custom.cfg --archfile arch_custom
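+
+For reference, a custom architecture file only needs to expose `prepare_data()` and `arch()`. The sketch below is illustrative only (the layer choices, hidden size, and state names are placeholders modeled on `arch_deepspeech.py`, not a drop-in implementation):
+
+```python
+import mxnet as mx
+
+def prepare_data(args):
+    # return (name, shape) pairs describing the RNN initial states
+    batch_size = args.config.getint("common", "batch_size")
+    num_hidden = 256  # placeholder hidden size
+    return [('l0_init_h', (batch_size, num_hidden))]
+
+def arch(args, seq_len=None):
+    # build and return the training symbol
+    data = mx.sym.Variable('data')
+    label = mx.sym.Variable('label')
+    net = mx.sym.FullyConnected(data=data, num_hidden=256, name='fc1')
+    # ... stack recurrent layers here and feed `label` into a CTC-style loss ...
+    return net
+```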
+
+***
+## **Furthermore**
+You can prepare the full LibriSpeech dataset by following the instructions at https://github.com/baidu-research/ba-dls-deepspeech:
+```bash
+git clone https://github.com/baidu-research/ba-dls-deepspeech
+cd ba-dls-deepspeech
+./download.sh
+./flac_to_wav.sh
+python create_desc_json.py /path/to/ba-dls-deepspeech/LibriSpeech/train-clean-100 train_corpus.json
+python create_desc_json.py /path/to/ba-dls-deepspeech/LibriSpeech/dev-clean validation_corpus.json
+python create_desc_json.py /path/to/ba-dls-deepspeech/LibriSpeech/test-clean test_corpus.json
+```
diff --git a/example/speech_recognition/arch_deepspeech.py b/example/speech_recognition/arch_deepspeech.py
index 92f1002a2f01..99d96c6c281b 100644
--- a/example/speech_recognition/arch_deepspeech.py
+++ b/example/speech_recognition/arch_deepspeech.py
@@ -1,6 +1,12 @@
+# pylint: disable=C0111, too-many-statements, too-many-locals
+# pylint: disable=too-many-arguments, too-many-instance-attributes, redefined-outer-name, fixme
+# pylint: disable=superfluous-parens, no-member, invalid-name
+"""
+architecture file for deep speech 2 model
+"""
import json
import math
-
+import argparse
import mxnet as mx
from stt_layer_batchnorm import batchnorm
@@ -13,6 +19,9 @@
def prepare_data(args):
+ """
+    set the actual shapes of the RNN initial states
+ """
rnn_type = args.config.get("arch", "rnn_type")
num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))
@@ -20,26 +29,29 @@ def prepare_data(args):
batch_size = args.config.getint("common", "batch_size")
if rnn_type == 'lstm':
- init_c = [('l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) for l in range(num_rnn_layer)]
- init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in range(num_rnn_layer)]
+ init_c = [('l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l]))
+ for l in range(num_rnn_layer)]
+ init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
+ for l in range(num_rnn_layer)]
elif rnn_type == 'bilstm':
- forward_init_c = [('forward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) for l in
- range(num_rnn_layer)]
- backward_init_c = [('backward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l])) for l in
- range(num_rnn_layer)]
+ forward_init_c = [('forward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l]))
+ for l in range(num_rnn_layer)]
+ backward_init_c = [('backward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l]))
+ for l in range(num_rnn_layer)]
init_c = forward_init_c + backward_init_c
- forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in
- range(num_rnn_layer)]
- backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in
- range(num_rnn_layer)]
+ forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
+ for l in range(num_rnn_layer)]
+ backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
+ for l in range(num_rnn_layer)]
init_h = forward_init_h + backward_init_h
elif rnn_type == 'gru':
- init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in range(num_rnn_layer)]
+ init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
+ for l in range(num_rnn_layer)]
elif rnn_type == 'bigru':
- forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in
- range(num_rnn_layer)]
- backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l])) for l in
- range(num_rnn_layer)]
+ forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
+ for l in range(num_rnn_layer)]
+ backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
+ for l in range(num_rnn_layer)]
init_h = forward_init_h + backward_init_h
else:
raise Exception('network type should be one of the lstm,bilstm,gru,bigru')
@@ -51,115 +63,146 @@ def prepare_data(args):
return init_states
-def arch(args):
- mode = args.config.get("common", "mode")
- if mode == "train":
- channel_num = args.config.getint("arch", "channel_num")
- conv_layer1_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
- conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
- conv_layer2_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
- conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
-
- rnn_type = args.config.get("arch", "rnn_type")
- num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
- num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))
-
- is_batchnorm = args.config.getboolean("arch", "is_batchnorm")
-
- seq_len = args.config.getint('arch', 'max_t_count')
- num_label = args.config.getint('arch', 'max_label_length')
-
- num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
- num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list"))
- act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list"))
- # model symbol generation
- # input preparation
- data = mx.sym.Variable('data')
- label = mx.sym.Variable('label')
-
- net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
- net = conv(net=net,
- channels=channel_num,
- filter_dimension=conv_layer1_filter_dim,
- stride=conv_layer1_stride,
- no_bias=is_batchnorm
- )
- if is_batchnorm:
- # batch norm normalizes axis 1
- net = batchnorm(net)
-
- net = conv(net=net,
- channels=channel_num,
- filter_dimension=conv_layer2_filter_dim,
- stride=conv_layer2_stride,
- no_bias=is_batchnorm
- )
- if is_batchnorm:
- # batch norm normalizes axis 1
- net = batchnorm(net)
- net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
- net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
- seq_len_after_conv_layer1 = int(
- math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
- seq_len_after_conv_layer2 = int(
- math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) / conv_layer2_stride[0])) + 1
- net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)
- if rnn_type == "bilstm":
- net = bi_lstm_unroll(net=net,
+def arch(args, seq_len=None):
+ """
+ define deep speech 2 network
+ """
+ if isinstance(args, argparse.Namespace):
+ mode = args.config.get("common", "mode")
+ if mode == "train":
+ channel_num = args.config.getint("arch", "channel_num")
+ conv_layer1_filter_dim = \
+ tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
+ conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
+ conv_layer2_filter_dim = \
+ tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
+ conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
+
+ rnn_type = args.config.get("arch", "rnn_type")
+ num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
+ num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))
+
+ is_batchnorm = args.config.getboolean("arch", "is_batchnorm")
+ is_bucketing = args.config.getboolean("arch", "is_bucketing")
+
+ if seq_len is None:
+ seq_len = args.config.getint('arch', 'max_t_count')
+
+ num_label = args.config.getint('arch', 'max_label_length')
+
+ num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
+ num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list"))
+ act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list"))
+ # model symbol generation
+ # input preparation
+ data = mx.sym.Variable('data')
+ label = mx.sym.Variable('label')
+
+ net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
+ net = conv(net=net,
+ channels=channel_num,
+ filter_dimension=conv_layer1_filter_dim,
+ stride=conv_layer1_stride,
+ no_bias=is_batchnorm,
+ name='conv1')
+ if is_batchnorm:
+ # batch norm normalizes axis 1
+ net = batchnorm(net, name="conv1_batchnorm")
+
+ net = conv(net=net,
+ channels=channel_num,
+ filter_dimension=conv_layer2_filter_dim,
+ stride=conv_layer2_stride,
+ no_bias=is_batchnorm,
+ name='conv2')
+ # if is_batchnorm:
+ # # batch norm normalizes axis 1
+ # net = batchnorm(net, name="conv2_batchnorm")
+
+ net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
+ net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
+ seq_len_after_conv_layer1 = int(
+ math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
+ seq_len_after_conv_layer2 = int(
+ math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0])
+ / conv_layer2_stride[0])) + 1
+ net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)
+ if rnn_type == "bilstm":
+ net = bi_lstm_unroll(net=net,
+ seq_len=seq_len_after_conv_layer2,
+ num_hidden_lstm_list=num_hidden_rnn_list,
+ num_lstm_layer=num_rnn_layer,
+ dropout=0.,
+ is_batchnorm=is_batchnorm,
+ is_bucketing=is_bucketing)
+ elif rnn_type == "gru":
+ net = gru_unroll(net=net,
seq_len=seq_len_after_conv_layer2,
- num_hidden_lstm_list=num_hidden_rnn_list,
- num_lstm_layer=num_rnn_layer,
+ num_hidden_gru_list=num_hidden_rnn_list,
+ num_gru_layer=num_rnn_layer,
dropout=0.,
- is_batchnorm=is_batchnorm)
- elif rnn_type == "gru":
- net = gru_unroll(net=net,
- seq_len=seq_len_after_conv_layer2,
- num_hidden_gru_list=num_hidden_rnn_list,
- num_gru_layer=num_rnn_layer,
- dropout=0.,
- is_batchnorm=is_batchnorm)
- elif rnn_type == "bigru":
- net = bi_gru_unroll(net=net,
+ is_batchnorm=is_batchnorm,
+ is_bucketing=is_bucketing)
+ elif rnn_type == "bigru":
+ net = bi_gru_unroll(net=net,
+ seq_len=seq_len_after_conv_layer2,
+ num_hidden_gru_list=num_hidden_rnn_list,
+ num_gru_layer=num_rnn_layer,
+ dropout=0.,
+ is_batchnorm=is_batchnorm,
+ is_bucketing=is_bucketing)
+ else:
+            raise Exception('rnn_type should be one of the following: bilstm, gru, bigru')
+
+ # rear fc layers
+ net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2,
+ num_layer=num_rear_fc_layers, prefix="rear",
+ num_hidden_list=num_hidden_rear_fc_list,
+ act_type_list=act_type_rear_fc_list,
+ is_batchnorm=is_batchnorm)
+ # warpctc layer
+ net = warpctc_layer(net=net,
seq_len=seq_len_after_conv_layer2,
- num_hidden_gru_list=num_hidden_rnn_list,
- num_gru_layer=num_rnn_layer,
- dropout=0.,
- is_batchnorm=is_batchnorm)
+ label=label,
+ num_label=num_label,
+ character_classes_count=
+ (args.config.getint('arch', 'n_classes') + 1))
+ args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
+ return net
+ elif mode == 'load' or mode == 'predict':
+ conv_layer1_filter_dim = \
+ tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
+ conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
+ conv_layer2_filter_dim = \
+ tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
+ conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
+ if seq_len is None:
+ seq_len = args.config.getint('arch', 'max_t_count')
+ seq_len_after_conv_layer1 = int(
+ math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
+ seq_len_after_conv_layer2 = int(
+ math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0])
+ / conv_layer2_stride[0])) + 1
+
+ args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
else:
- raise Exception('rnn_type should be one of the followings, bilstm,gru,bigru')
-
- # rear fc layers
- net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2, num_layer=num_rear_fc_layers, prefix="rear",
- num_hidden_list=num_hidden_rear_fc_list, act_type_list=act_type_rear_fc_list,
- is_batchnorm=is_batchnorm)
- if is_batchnorm:
- hidden_all = []
- # batch norm normalizes axis 1
- for seq_index in range(seq_len_after_conv_layer2):
- hidden = net[seq_index]
- hidden = batchnorm(hidden)
- hidden_all.append(hidden)
- net = hidden_all
-
- # warpctc layer
- net = warpctc_layer(net=net,
- seq_len=seq_len_after_conv_layer2,
- label=label,
- num_label=num_label,
- character_classes_count=(args.config.getint('arch', 'n_classes') + 1)
- )
- args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
- return net
+        raise Exception('mode must be one of the following: train, predict, load')
else:
- conv_layer1_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
- conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
- conv_layer2_filter_dim = tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
- conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
- seq_len = args.config.getint('arch', 'max_t_count')
- seq_len_after_conv_layer1 = int(
- math.floor((seq_len - conv_layer1_filter_dim[0]) / conv_layer1_stride[0])) + 1
- seq_len_after_conv_layer2 = int(
- math.floor((seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) / conv_layer2_stride[0])) + 1
- args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
+        raise Exception('type of args should be argparse.Namespace for the fixed-length model ' +
+                        'or an integer sequence length for the variable-length model')
+
+
+class BucketingArch(object):
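+    """
+    Wrapper that exposes a sym_gen callable for mx.mod.BucketingModule: for each
+    bucket key (sequence length) it rebuilds the network via arch() and reports
+    the data and label names expected by the module.
+    """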
+    def __init__(self, args):
+        self.args = args
+
+    def sym_gen(self, seq_len):
+        args = self.args
+        net = arch(args, seq_len)
+        init_states = prepare_data(args)
+        init_state_names = [x[0] for x in init_states]
+        init_state_names.insert(0, 'data')
+        return net, init_state_names, ('label',)
+
+    def get_sym_gen(self):
+        return self.sym_gen
diff --git a/example/speech_recognition/deepspeech.cfg b/example/speech_recognition/deepspeech.cfg
index 13cf578c679a..224cb7bed9ce 100644
--- a/example/speech_recognition/deepspeech.cfg
+++ b/example/speech_recognition/deepspeech.cfg
@@ -4,12 +4,12 @@ mode = train
#ex: gpu0,gpu1,gpu2,gpu3
context = gpu0,gpu1,gpu2
# checkpoint prefix, check point will be saved under checkpoints folder with prefix
-prefix = deep
+prefix = deep_bucket
# when mode is load or predict, model will be loaded from the file name with model_file under checkpoints
-model_file = deepspeechn_epoch1n_batch-0009
+model_file = deep_bucket-0001
batch_size = 12
# log will be saved by the log_filename
-log_filename = deep.log
+log_filename = deep_bucket.log
# checkpoint set n to save checkpoints after n epoch
save_checkpoint_every_n_epoch = 1
save_checkpoint_every_n_batch = 1000
@@ -18,6 +18,7 @@ tensorboard_log_dir = tblog/deep
# if random_seed is -1 then it gets random seed from timestamp
mx_random_seed = -1
random_seed = -1
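+# kvstore type used for multi-device training ('device' aggregates gradients on the GPUs)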
+kvstore_option = device
[data]
train_json = ./train_corpus_all.json
@@ -50,22 +51,18 @@ rnn_type = bigru
#vanilla_lstm or fc_lstm (no effect when network_type is gru, bigru)
lstm_type = fc_lstm
is_batchnorm = True
+is_bucketing = True
+#[[0,2.3],[10,5.8],[10.8,25],[13.8,50],[15.1,75],[15.8,90],[29.7,100]]
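+# bucket keys are utterance lengths in 10 ms frames (duration * 100); utterances longer
+# than the largest bucket are discarded by the bucketing iterator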
+buckets = [200, 300, 400, 500, 600, 700, 800, 900, 1599]
[train]
num_epoch = 70
learning_rate = 0.0003
# constant learning rate annealing by factor
learning_rate_annealing = 1.1
-# supports only sgd and adam
-optimizer = sgd
-# for sgd
-momentum = 0.9
-# set to 0 to disable gradient clipping
-clip_gradient = 0
initializer = Xavier
init_scale = 2
factor_type = in
-weight_decay = 0.
# show progress every how nth batches
show_every = 100
save_optimizer_states = True
@@ -78,3 +75,23 @@ enable_logging_validation_metric = True
[load]
load_optimizer_states = True
is_start_from_batch = True
+
+[optimizer]
+optimizer = sgd
+# define parameters for optimizer
+# optimizer_params_dictionary must use double quotes (") rather than single quotes (') for strings
+# sgd/nag
+optimizer_params_dictionary={"momentum":0.9}
+# dcasgd
+# optimizer_params_dictionary={"momentum":0.9, "lamda":1.0}
+# adam
+# optimizer_params_dictionary={"beta1":0.9,"beta2":0.999}
+# adagrad
+# optimizer_params_dictionary={"eps":1e-08}
+# rmsprop
+# optimizer_params_dictionary={"gamma1":0.9, "gamma2":0.9,"epsilon":1e-08}
+# adadelta
+# optimizer_params_dictionary={"rho":0.95, "epsilon":1e-08}
+# set to 0 to disable gradient clipping
+clip_gradient = 100
+weight_decay = 0.
diff --git a/example/speech_recognition/default.cfg b/example/speech_recognition/default.cfg
index 853a04aebbdd..96adfa766ef5 100644
--- a/example/speech_recognition/default.cfg
+++ b/example/speech_recognition/default.cfg
@@ -18,6 +18,7 @@ tensorboard_log_dir = tblog/libri_sample
# if random_seed is -1 then it gets random seed from timestamp
mx_random_seed = -1
random_seed = -1
+kvstore_option = device
[data]
train_json = ./Libri_sample.json
@@ -50,24 +51,17 @@ rnn_type = bigru
#vanilla_lstm or fc_lstm (no effect when network_type is gru, bigru)
lstm_type = fc_lstm
is_batchnorm = True
+is_bucketing = False
+buckets = []
[train]
num_epoch = 70
-
learning_rate = 0.005
# constant learning rate annealing by factor
learning_rate_annealing = 1.1
-# supports only sgd and adam
-optimizer = adam
-# for sgd
-momentum = 0.9
-# set to 0 to disable gradient clipping
-clip_gradient = 0
-
initializer = Xavier
init_scale = 2
factor_type = in
-weight_decay = 0.00001
# show progress every nth batches
show_every = 1
save_optimizer_states = True
@@ -80,3 +74,23 @@ enable_logging_validation_metric = True
[load]
load_optimizer_states = True
is_start_from_batch = False
+
+[optimizer]
+optimizer = adam
+# define parameters for optimizer
+# optimizer_params_dictionary must use double quotes (") rather than single quotes (') for strings
+# sgd/nag
+# optimizer_params_dictionary={"momentum":0.9}
+# dcasgd
+# optimizer_params_dictionary={"momentum":0.9, "lamda":1.0}
+# adam
+optimizer_params_dictionary={"beta1":0.9,"beta2":0.999}
+# adagrad
+# optimizer_params_dictionary={"eps":1e-08}
+# rmsprop
+# optimizer_params_dictionary={"gamma1":0.9, "gamma2":0.9,"epsilon":1e-08}
+# adadelta
+# optimizer_params_dictionary={"rho":0.95, "epsilon":1e-08}
+# set to 0 to disable gradient clipping
+clip_gradient = 0
+weight_decay = 0.
diff --git a/example/speech_recognition/main.py b/example/speech_recognition/main.py
index 398a8a537e01..589cd74aa40f 100644
--- a/example/speech_recognition/main.py
+++ b/example/speech_recognition/main.py
@@ -1,33 +1,34 @@
+# pylint: disable=C0111, too-many-statements, too-many-locals, too-few-public-methods, too-many-branches
+# pylint: disable=too-many-arguments, too-many-instance-attributes, fixme
+# pylint: disable=superfluous-parens, no-member, invalid-name, anomalous-backslash-in-string, redefined-outer-name
+import json
+import os
import sys
-
-sys.path.insert(0, "../../python")
+from collections import namedtuple
+from datetime import datetime
from config_util import parse_args, parse_contexts, generate_file_path
from train import do_training
import mxnet as mx
from stt_io_iter import STTIter
from label_util import LabelUtil
from log_util import LogUtil
-
import numpy as np
from stt_datagenerator import DataGenerator
from stt_metric import STTMetric
-from datetime import datetime
from stt_bi_graphemes_util import generate_bi_graphemes_dictionary
-########################################
-########## FOR JUPYTER NOTEBOOK
-import os
+from stt_bucketing_module import STTBucketingModule
+from stt_io_bucketingiter import BucketSTTIter
+sys.path.insert(0, "../../python")
+
+
# os.environ['MXNET_ENGINE_TYPE'] = "NaiveEngine"
os.environ['MXNET_ENGINE_TYPE'] = "ThreadedEnginePerDevice"
os.environ['MXNET_ENABLE_GPU_P2P'] = "0"
-
-class WHCS:
- width = 0
- height = 0
- channel = 0
- stride = 0
-
+WHCS = namedtuple("WHCS", ["width", "height", "channel", "stride"])
class ConfigLogger(object):
def __init__(self, log):
@@ -45,6 +46,8 @@ def write(self, data):
def load_data(args):
mode = args.config.get('common', 'mode')
+ if mode not in ['train', 'predict', 'load']:
+        raise Exception('mode must be one of the following: train, predict, load')
batch_size = args.config.getint('common', 'batch_size')
whcs = WHCS()
@@ -57,7 +60,6 @@ def load_data(args):
is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
language = args.config.get('data', 'language')
- is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
labelUtil = LabelUtil.getInstance()
if language == "en":
@@ -65,25 +67,20 @@ def load_data(args):
try:
labelUtil.load_unicode_set("resources/unicodemap_en_baidu_bi_graphemes.csv")
except:
- raise Exception("There is no resources/unicodemap_en_baidu_bi_graphemes.csv. Please set overwrite_meta_files at train section True")
+            raise Exception("There is no resources/unicodemap_en_baidu_bi_graphemes.csv." +
+                            " Please set overwrite_meta_files to True in the train section")
else:
labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
else:
raise Exception("Error: Language Type: %s" % language)
args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
- if mode == 'predict':
- test_json = args.config.get('data', 'test_json')
- datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
- datagen.load_train_data(test_json)
- datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
- np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
- elif mode =="train" or mode == "load":
+ if mode == "train" or mode == "load":
data_json = args.config.get('data', 'train_json')
val_json = args.config.get('data', 'val_json')
datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
datagen.load_train_data(data_json)
- #test bigramphems
+            # test bi-graphemes
if overwrite_meta_files and is_bi_graphemes:
generate_bi_graphemes_dictionary(datagen.train_texts)
@@ -95,62 +92,50 @@ def load_data(args):
normalize_target_k = args.config.getint('train', 'normalize_target_k')
datagen.sample_normalize(normalize_target_k, True)
else:
- datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
- np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
+ datagen.get_meta_from_file(
+ np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
+ np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
datagen.load_validation_data(val_json)
elif mode == "load":
# get feat_mean and feat_std to normalize dataset
- datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
- np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
+ datagen.get_meta_from_file(
+ np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
+ np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
datagen.load_validation_data(val_json)
- else:
- raise Exception(
- 'Define mode in the cfg file first. train or predict or load can be the candidate for the mode.')
+ elif mode == 'predict':
+ test_json = args.config.get('data', 'test_json')
+ datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
+ datagen.load_train_data(test_json)
+ datagen.get_meta_from_file(
+ np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
+ np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
- if batch_size == 1 and is_batchnorm:
+ if batch_size == 1 and is_batchnorm and (mode == 'train' or mode == 'load'):
raise Warning('batch size 1 is too small for is_batchnorm')
# sort file paths by its duration in ascending order to implement sortaGrad
-
if mode == "train" or mode == "load":
max_t_count = datagen.get_max_seq_length(partition="train")
- max_label_length = datagen.get_max_label_length(partition="train",is_bi_graphemes=is_bi_graphemes)
+ max_label_length = \
+ datagen.get_max_label_length(partition="train", is_bi_graphemes=is_bi_graphemes)
elif mode == "predict":
max_t_count = datagen.get_max_seq_length(partition="test")
- max_label_length = datagen.get_max_label_length(partition="test",is_bi_graphemes=is_bi_graphemes)
- else:
- raise Exception(
- 'Define mode in the cfg file first. train or predict or load can be the candidate for the mode.')
+ max_label_length = \
+ datagen.get_max_label_length(partition="test", is_bi_graphemes=is_bi_graphemes)
args.config.set('arch', 'max_t_count', str(max_t_count))
args.config.set('arch', 'max_label_length', str(max_label_length))
from importlib import import_module
prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
init_states = prepare_data_template.prepare_data(args)
- if mode == "train":
- sort_by_duration = True
- else:
- sort_by_duration = False
-
- data_loaded = STTIter(partition="train",
- count=datagen.count,
- datagen=datagen,
- batch_size=batch_size,
- num_label=max_label_length,
- init_states=init_states,
- seq_length=max_t_count,
- width=whcs.width,
- height=whcs.height,
- sort_by_duration=sort_by_duration,
- is_bi_graphemes=is_bi_graphemes)
-
- if mode == 'predict':
- return data_loaded, args
- else:
- validation_loaded = STTIter(partition="validation",
- count=datagen.val_count,
+ sort_by_duration = (mode == "train")
+ is_bucketing = args.config.getboolean('arch', 'is_bucketing')
+ if is_bucketing:
+ buckets = json.loads(args.config.get('arch', 'buckets'))
+ data_loaded = BucketSTTIter(partition="train",
+ count=datagen.count,
datagen=datagen,
batch_size=batch_size,
num_label=max_label_length,
@@ -158,37 +143,87 @@ def load_data(args):
seq_length=max_t_count,
width=whcs.width,
height=whcs.height,
- sort_by_duration=False,
- is_bi_graphemes=is_bi_graphemes)
+ sort_by_duration=sort_by_duration,
+ is_bi_graphemes=is_bi_graphemes,
+ buckets=buckets)
+ else:
+ data_loaded = STTIter(partition="train",
+ count=datagen.count,
+ datagen=datagen,
+ batch_size=batch_size,
+ num_label=max_label_length,
+ init_states=init_states,
+ seq_length=max_t_count,
+ width=whcs.width,
+ height=whcs.height,
+ sort_by_duration=sort_by_duration,
+ is_bi_graphemes=is_bi_graphemes)
+
+ if mode == 'train' or mode == 'load':
+ if is_bucketing:
+ validation_loaded = BucketSTTIter(partition="validation",
+ count=datagen.val_count,
+ datagen=datagen,
+ batch_size=batch_size,
+ num_label=max_label_length,
+ init_states=init_states,
+ seq_length=max_t_count,
+ width=whcs.width,
+ height=whcs.height,
+ sort_by_duration=False,
+ is_bi_graphemes=is_bi_graphemes,
+ buckets=buckets)
+ else:
+ validation_loaded = STTIter(partition="validation",
+ count=datagen.val_count,
+ datagen=datagen,
+ batch_size=batch_size,
+ num_label=max_label_length,
+ init_states=init_states,
+ seq_length=max_t_count,
+ width=whcs.width,
+ height=whcs.height,
+ sort_by_duration=False,
+ is_bi_graphemes=is_bi_graphemes)
return data_loaded, validation_loaded, args
+ elif mode == 'predict':
+ return data_loaded, args
def load_model(args, contexts, data_train):
# load model from model_name prefix and epoch of model_num_epoch with gpu contexts of contexts
mode = args.config.get('common', 'mode')
load_optimizer_states = args.config.getboolean('load', 'load_optimizer_states')
- is_start_from_batch = args.config.getboolean('load','is_start_from_batch')
+ is_start_from_batch = args.config.getboolean('load', 'is_start_from_batch')
from importlib import import_module
symbol_template = import_module(args.config.get('arch', 'arch_file'))
- model_loaded = symbol_template.arch(args)
+ is_bucketing = args.config.getboolean('arch', 'is_bucketing')
if mode == 'train':
+ if is_bucketing:
+ bucketing_arch = symbol_template.BucketingArch(args)
+ model_loaded = bucketing_arch.get_sym_gen()
+ else:
+ model_loaded = symbol_template.arch(args)
model_num_epoch = None
- else:
+ elif mode == 'load' or mode == 'predict':
model_file = args.config.get('common', 'model_file')
model_name = os.path.splitext(model_file)[0]
-
model_num_epoch = int(model_name[-4:])
+ if is_bucketing:
+ bucketing_arch = symbol_template.BucketingArch(args)
+ model_loaded = bucketing_arch.get_sym_gen()
+ else:
+ model_path = 'checkpoints/' + str(model_name[:-5])
- model_path = 'checkpoints/' + str(model_name[:-5])
-
- data_names = [x[0] for x in data_train.provide_data]
- label_names = [x[0] for x in data_train.provide_label]
+ data_names = [x[0] for x in data_train.provide_data]
+ label_names = [x[0] for x in data_train.provide_label]
- model_loaded = mx.module.Module.load(prefix=model_path, epoch=model_num_epoch, context=contexts,
- data_names=data_names, label_names=label_names,
- load_optimizer_states=load_optimizer_states)
+ model_loaded = mx.module.Module.load(
+ prefix=model_path, epoch=model_num_epoch, context=contexts,
+ data_names=data_names, label_names=label_names,
+ load_optimizer_states=load_optimizer_states)
if is_start_from_batch:
import re
model_num_epoch = int(re.findall('\d+', model_file)[0])
@@ -198,7 +233,8 @@ def load_model(args, contexts, data_train):
if __name__ == '__main__':
if len(sys.argv) <= 1:
- raise Exception('cfg file path must be provided. ex)python main.py --configfile examplecfg.cfg')
+ raise Exception('cfg file path must be provided. ' +
+                        'e.g. python main.py --configfile examplecfg.cfg')
args = parse_args(sys.argv[1])
# set parameters from cfg file
# give random seed
@@ -206,9 +242,9 @@ def load_model(args, contexts, data_train):
mx_random_seed = args.config.getint('common', 'mx_random_seed')
# random seed for shuffling data list
if random_seed != -1:
- random.seed(random_seed)
+ np.random.seed(random_seed)
# set mx.random.seed to give seed for parameter initialization
- if mx_random_seed !=-1:
+ if mx_random_seed != -1:
mx.random.seed(mx_random_seed)
else:
mx.random.seed(hash(datetime.now()))
@@ -220,22 +256,23 @@ def load_model(args, contexts, data_train):
mode = args.config.get('common', 'mode')
if mode not in ['train', 'predict', 'load']:
raise Exception(
- 'Define mode in the cfg file first. train or predict or load can be the candidate for the mode.')
+ 'Define mode in the cfg file first. ' +
+            'It must be one of the following: train, predict, load.')
# get meta file where character to number conversions are defined
contexts = parse_contexts(args)
num_gpu = len(contexts)
batch_size = args.config.getint('common', 'batch_size')
-
# check the number of gpus is positive divisor of the batch size for data parallel
if batch_size % num_gpu != 0:
raise Exception('num_gpu should be positive divisor of batch_size')
-
- if mode == "predict":
- data_train, args = load_data(args)
- elif mode == "train" or mode == "load":
+ if mode == "train" or mode == "load":
data_train, data_val, args = load_data(args)
+ elif mode == "predict":
+ data_train, args = load_data(args)
+ is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
+ is_bucketing = args.config.getboolean('arch', 'is_bucketing')
# log current config
config_logger = ConfigLogger(log)
@@ -246,25 +283,56 @@ def load_model(args, contexts, data_train):
# if mode is 'train', it trains the model
if mode == 'train':
- data_names = [x[0] for x in data_train.provide_data]
- label_names = [x[0] for x in data_train.provide_label]
- module = mx.mod.Module(model_loaded, context=contexts, data_names=data_names, label_names=label_names)
+ if is_bucketing:
+ module = STTBucketingModule(
+ sym_gen=model_loaded,
+ default_bucket_key=data_train.default_bucket_key,
+ context=contexts)
+ else:
+ data_names = [x[0] for x in data_train.provide_data]
+ label_names = [x[0] for x in data_train.provide_label]
+ module = mx.mod.Module(model_loaded, context=contexts,
+ data_names=data_names, label_names=label_names)
do_training(args=args, module=module, data_train=data_train, data_val=data_val)
# if mode is 'load', it loads model from the checkpoint and continues the training.
elif mode == 'load':
- do_training(args=args, module=model_loaded, data_train=data_train, data_val=data_val, begin_epoch=model_num_epoch+1)
+ do_training(args=args, module=model_loaded, data_train=data_train, data_val=data_val,
+ begin_epoch=model_num_epoch + 1)
# if mode is 'predict', it predict label from the input by the input model
elif mode == 'predict':
# predict through data
- model_loaded.bind(for_training=False, data_shapes=data_train.provide_data,
- label_shapes=data_train.provide_label)
+ if is_bucketing:
+ max_t_count = args.config.getint('arch', 'max_t_count')
+ load_optimizer_states = args.config.getboolean('load', 'load_optimizer_states')
+ model_file = args.config.get('common', 'model_file')
+ model_name = os.path.splitext(model_file)[0]
+ model_num_epoch = int(model_name[-4:])
+
+ model_path = 'checkpoints/' + str(model_name[:-5])
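+            # the bucketing symbol is rebuilt from sym_gen, so only the saved
+            # parameters are restored from the checkpoint below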
+ model = STTBucketingModule(
+ sym_gen=model_loaded,
+ default_bucket_key=data_train.default_bucket_key,
+ context=contexts)
+
+ model.bind(data_shapes=data_train.provide_data,
+ label_shapes=data_train.provide_label,
+ for_training=True)
+ _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
+ model.set_params(arg_params, aux_params)
+ model_loaded = model
+ else:
+ model_loaded.bind(for_training=False, data_shapes=data_train.provide_data,
+ label_shapes=data_train.provide_label)
max_t_count = args.config.getint('arch', 'max_t_count')
- eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=max_t_count)
- is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
- if is_batchnorm :
+ eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu)
+ if is_batchnorm:
for nbatch, data_batch in enumerate(data_train):
- # when is_train = False it leads to high cer when batch_norm
- model_loaded.forward(data_batch, is_train=True)
+ model_loaded.forward(data_batch, is_train=False)
model_loaded.update_metric(eval_metric, data_batch.label)
- else :
- model_loaded.score(eval_data=data_train, num_batch=None, eval_metric=eval_metric, reset=True)
+ else:
+ model_loaded.score(eval_data=data_train, num_batch=None,
+ eval_metric=eval_metric, reset=True)
+ else:
+ raise Exception(
+ 'Define mode in the cfg file first. ' +
+            'It must be one of the following: train, predict, load')
diff --git a/example/speech_recognition/stt_bucketing_module.py b/example/speech_recognition/stt_bucketing_module.py
new file mode 100644
index 000000000000..00ba8df19bd5
--- /dev/null
+++ b/example/speech_recognition/stt_bucketing_module.py
@@ -0,0 +1,13 @@
+import mxnet as mx
+
+
+class STTBucketingModule(mx.mod.BucketingModule):
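+    """
+    BucketingModule with a save_checkpoint() helper that mirrors mx.mod.Module:
+    it stores the symbol of the default bucket, the shared parameters and,
+    optionally, the optimizer states of the currently bound module.
+    """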
+
+ def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
+ symbol, data_names, label_names = self._sym_gen(self._default_bucket_key)
+ symbol.save('%s-symbol.json' % prefix)
+ param_name = '%s-%04d.params' % (prefix, epoch)
+ self.save_params(param_name)
+ if save_optimizer_states:
+ state_name = '%s-%04d.states' % (prefix, epoch)
+ self._curr_module.save_optimizer_states(state_name)
\ No newline at end of file
diff --git a/example/speech_recognition/stt_datagenerator.py b/example/speech_recognition/stt_datagenerator.py
index 390de432e751..a13490fc3435 100644
--- a/example/speech_recognition/stt_datagenerator.py
+++ b/example/speech_recognition/stt_datagenerator.py
@@ -32,7 +32,7 @@ def __init__(self, save_dir, model_name, step=10, window=20, max_freq=8000, desc
# 1d 161 length of array filled with 1s
self.feats_std = np.ones((self.feat_dim,))
self.max_input_length = 0
- self.max_length_list_in_batch =[]
+ self.max_length_list_in_batch = []
# 1d 161 length of array filled with random value
#[0.0, 1.0)
self.rng = random.Random()
@@ -146,10 +146,11 @@ def get_max_seq_length(self, partition):
"Must be train/validation/test")
max_duration_indexes = durations.index(max(durations))
max_seq_length = self.featurize(audio_paths[max_duration_indexes]).shape[0]
- self.max_seq_length=max_seq_length
+ self.max_seq_length = max_seq_length
return max_seq_length
- def prepare_minibatch(self, audio_paths, texts, overwrite=False, is_bi_graphemes=False):
+ def prepare_minibatch(self, audio_paths, texts, overwrite=False,
+ is_bi_graphemes=False, seq_length=-1):
""" Featurize a minibatch of audio, zero pad them and return a dictionary
Params:
audio_paths (list(str)): List of paths to audio files
@@ -167,7 +168,10 @@ def prepare_minibatch(self, audio_paths, texts, overwrite=False, is_bi_graphemes
feature_dim = features[0].shape[1]
mb_size = len(features)
# Pad all the inputs so that they are all the same length
- x = np.zeros((mb_size, self.max_seq_length, feature_dim))
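+        # a non-negative seq_length is the bucket's padded length supplied by the
+        # bucketing iterator; -1 keeps the old behavior of padding to max_seq_length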
+ if seq_length == -1:
+ x = np.zeros((mb_size, self.max_seq_length, feature_dim))
+ else:
+ x = np.zeros((mb_size, seq_length, feature_dim))
y = np.zeros((mb_size, self.max_label_length))
labelUtil = LabelUtil.getInstance()
label_lengths = []
@@ -223,10 +227,13 @@ def sample_normalize(self, k_samples=1000, overwrite=False):
next_feat_squared = np.square(next_feat)
feat_vertically_stacked = np.concatenate((feat, next_feat)).reshape(-1, dim)
feat = np.sum(feat_vertically_stacked, axis=0, keepdims=True)
- feat_squared_vertically_stacked = np.concatenate((feat_squared, next_feat_squared)).reshape(-1, dim)
+ feat_squared_vertically_stacked = np.concatenate(
+ (feat_squared, next_feat_squared)).reshape(-1, dim)
feat_squared = np.sum(feat_squared_vertically_stacked, axis=0, keepdims=True)
- count = count + float(next_feat.shape[0])
+ count += float(next_feat.shape[0])
self.feats_mean = feat / float(count)
self.feats_std = np.sqrt(feat_squared / float(count) - np.square(self.feats_mean))
- np.savetxt(generate_file_path(self.save_dir, self.model_name, 'feats_mean'), self.feats_mean)
- np.savetxt(generate_file_path(self.save_dir, self.model_name, 'feats_std'), self.feats_std)
+ np.savetxt(
+ generate_file_path(self.save_dir, self.model_name, 'feats_mean'), self.feats_mean)
+ np.savetxt(
+ generate_file_path(self.save_dir, self.model_name, 'feats_std'), self.feats_std)
diff --git a/example/speech_recognition/stt_io_bucketingiter.py b/example/speech_recognition/stt_io_bucketingiter.py
new file mode 100644
index 000000000000..e50715d60861
--- /dev/null
+++ b/example/speech_recognition/stt_io_bucketingiter.py
@@ -0,0 +1,145 @@
+from __future__ import print_function
+import mxnet as mx
+import sys
+sys.path.insert(0, "../../python")
+
+import bisect
+import random
+import numpy as np
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 0
+NUM_GPU = 1
+
+
+def get_label(buf, num_label):
+    ret = np.zeros(num_label)
+ for i in range(len(buf)):
+ ret[i] = int(buf[i])
+ return ret
+
+
+class BucketSTTIter(mx.io.DataIter):
+ def __init__(self, count, datagen, batch_size, num_label, init_states, seq_length, width, height,
+ sort_by_duration=True,
+ is_bi_graphemes=False,
+ partition="train",
+ buckets=[]
+ ):
+ super(BucketSTTIter, self).__init__()
+
+ self.maxLabelLength = num_label
+ # global param
+ self.batch_size = batch_size
+ self.count = count
+ self.num_label = num_label
+ self.init_states = init_states
+ self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
+ self.width = width
+ self.height = height
+ self.datagen = datagen
+ self.label = None
+ self.is_bi_graphemes = is_bi_graphemes
+ # self.partition = datagen.partition
+ if partition == 'train':
+ durations = datagen.train_durations
+ audio_paths = datagen.train_audio_paths
+ texts = datagen.train_texts
+ elif partition == 'validation':
+ durations = datagen.val_durations
+ audio_paths = datagen.val_audio_paths
+ texts = datagen.val_texts
+ elif partition == 'test':
+ durations = datagen.test_durations
+ audio_paths = datagen.test_audio_paths
+ texts = datagen.test_texts
+ else:
+ raise Exception("Invalid partition to load metadata. "
+ "Must be train/validation/test")
+ # if sortagrad
+ if sort_by_duration:
+ durations, audio_paths, texts = datagen.sort_by_duration(durations,
+ audio_paths,
+ texts)
+        # materialize as a list so it stays indexable under Python 3 (zip returns an iterator)
+        self.trainDataList = list(zip(durations, audio_paths, texts))
+
+ self.trainDataIter = iter(self.trainDataList)
+ self.is_first_epoch = True
+
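+        # convert durations (seconds) into lengths in 10 ms frames and assign each
+        # utterance to the smallest bucket that can hold it; utterances longer than
+        # the largest bucket are discarded below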
+ data_lengths = [int(d*100) for d in durations]
+ if len(buckets) == 0:
+ buckets = [i for i, j in enumerate(np.bincount(data_lengths))
+ if j >= batch_size]
+ if len(buckets) == 0:
+                raise Exception('There are no valid buckets. The batch_size may be too large '
+                                'for every bucket. max bincount:%d batch_size:%d'
+                                % (max(np.bincount(data_lengths)), batch_size))
+ buckets.sort()
+ ndiscard = 0
+ self.data = [[] for _ in buckets]
+ for i, sent in enumerate(data_lengths):
+ buck = bisect.bisect_left(buckets, sent)
+ if buck == len(buckets):
+ ndiscard += 1
+ continue
+ self.data[buck].append(self.trainDataList[i])
+ if ndiscard != 0:
+            print("WARNING: discarded %d utterances longer than the largest bucket." % ndiscard)
+
+ self.buckets = buckets
+ self.nddata = []
+ self.ndlabel = []
+ self.default_bucket_key = max(buckets)
+
+ self.idx = []
+ for i, buck in enumerate(self.data):
+ self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
+ self.curr_idx = 0
+
+        self.provide_data = [('data', (self.batch_size, self.default_bucket_key, width * height))] + init_states
+ self.provide_label = [('label', (self.batch_size, self.maxLabelLength))]
+
+ #self.reset()
+
+ def reset(self):
+ """Resets the iterator to the beginning of the data."""
+ self.curr_idx = 0
+ random.shuffle(self.idx)
+ for buck in self.data:
+ np.random.shuffle(buck)
+
+ def next(self):
+ """Returns the next batch of data."""
+ if self.curr_idx == len(self.idx):
+ raise StopIteration
+ i, j = self.idx[self.curr_idx]
+ self.curr_idx += 1
+
+ audio_paths = []
+ texts = []
+ for duration, audio_path, text in self.data[i][j:j+self.batch_size]:
+ audio_paths.append(audio_path)
+ texts.append(text)
+
+        data_set = self.datagen.prepare_minibatch(audio_paths, texts,
+                                                  overwrite=self.is_first_epoch,
+                                                  is_bi_graphemes=self.is_bi_graphemes,
+                                                  seq_length=self.buckets[i])
+
+ data_all = [mx.nd.array(data_set['x'])] + self.init_state_arrays
+ label_all = [mx.nd.array(data_set['y'])]
+
+ self.label = label_all
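+        # advertise the per-bucket data shape so the BucketingModule binds the matching symbol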
+ provide_data = [('data', (self.batch_size, self.buckets[i], self.width * self.height))] + self.init_states
+
+ return mx.io.DataBatch(data_all, label_all, pad=0,
+ bucket_key=self.buckets[i],
+ provide_data=provide_data,
+ provide_label=self.provide_label)
diff --git a/example/speech_recognition/stt_io_iter.py b/example/speech_recognition/stt_io_iter.py
index 70c31ce92dde..51029167bc4e 100644
--- a/example/speech_recognition/stt_io_iter.py
+++ b/example/speech_recognition/stt_io_iter.py
@@ -103,7 +103,6 @@ def __iter__(self):
data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
yield data_batch
- self.is_first_epoch = False
def reset(self):
pass
diff --git a/example/speech_recognition/stt_layer_batchnorm.py b/example/speech_recognition/stt_layer_batchnorm.py
index 86e75aa49557..5b73f4f9f890 100644
--- a/example/speech_recognition/stt_layer_batchnorm.py
+++ b/example/speech_recognition/stt_layer_batchnorm.py
@@ -6,7 +6,7 @@ def batchnorm(net,
beta=None,
eps=0.001,
momentum=0.9,
- fix_gamma=True,
+ fix_gamma=False,
use_global_stats=False,
output_mean_var=False,
name=None):
@@ -18,7 +18,8 @@ def batchnorm(net,
momentum=momentum,
fix_gamma=fix_gamma,
use_global_stats=use_global_stats,
- output_mean_var=output_mean_var
+ output_mean_var=output_mean_var,
+ name=name
)
else:
net = mx.sym.BatchNorm(data=net,
@@ -26,6 +27,7 @@ def batchnorm(net,
momentum=momentum,
fix_gamma=fix_gamma,
use_global_stats=use_global_stats,
- output_mean_var=output_mean_var
+ output_mean_var=output_mean_var,
+ name=name
)
return net
diff --git a/example/speech_recognition/stt_layer_conv.py b/example/speech_recognition/stt_layer_conv.py
index 5ec292557f04..ab0035e4803b 100644
--- a/example/speech_recognition/stt_layer_conv.py
+++ b/example/speech_recognition/stt_layer_conv.py
@@ -8,20 +8,22 @@ def conv(net,
weight=None,
bias=None,
act_type="relu",
- no_bias=False
+ no_bias=False,
+ name=None
):
# 2d convolution's input should have the shape of 4D (batch_size,1,seq_len,feat_dim)
if weight is None or bias is None:
# ex) filter_dimension = (41,11) , stride=(2,2)
- net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, no_bias=no_bias)
+ net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, no_bias=no_bias,
+ name=name)
elif weight is None or bias is not None:
net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, bias=bias,
- no_bias=no_bias)
+ no_bias=no_bias, name=name)
elif weight is not None or bias is None:
net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, weight=weight,
- no_bias=no_bias)
+ no_bias=no_bias, name=name)
else:
net = mx.sym.Convolution(data=net, num_filter=channels, kernel=filter_dimension, stride=stride, weight=weight,
- bias=bias, no_bias=no_bias)
+ bias=bias, no_bias=no_bias, name=name)
net = mx.sym.Activation(data=net, act_type=act_type)
return net
diff --git a/example/speech_recognition/stt_layer_fc.py b/example/speech_recognition/stt_layer_fc.py
index b3db2034a3ad..f435922426c5 100644
--- a/example/speech_recognition/stt_layer_fc.py
+++ b/example/speech_recognition/stt_layer_fc.py
@@ -8,29 +8,30 @@ def fc(net,
act_type,
weight=None,
bias=None,
- no_bias=False
+ no_bias=False,
+ name=None
):
# when weight and bias doesn't have specific name
if weight is None and bias is None:
- net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, no_bias=no_bias)
+ net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, no_bias=no_bias, name=name)
# when weight doesn't have specific name but bias has
elif weight is None and bias is not None:
if no_bias:
- net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, no_bias=no_bias)
+ net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, no_bias=no_bias, name=name)
else:
- net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, bias=bias, no_bias=no_bias)
+ net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, bias=bias, no_bias=no_bias, name=name)
# when bias doesn't have specific name but weight has
elif weight is not None and bias is None:
- net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, no_bias=no_bias)
+ net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, no_bias=no_bias, name=name)
# when weight and bias specific name
else:
if no_bias:
- net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, no_bias=no_bias)
+ net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, no_bias=no_bias, name=name)
else:
- net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, bias=bias, no_bias=no_bias)
+ net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden, weight=weight, bias=bias, no_bias=no_bias, name=name)
# activation
if act_type is not None:
- net = mx.sym.Activation(data=net, act_type=act_type)
+ net = mx.sym.Activation(data=net, act_type=act_type, name="%s_activation" % name)
return net
@@ -41,7 +42,7 @@ def sequence_fc(net,
num_hidden_list=[],
act_type_list=[],
is_batchnorm=False,
- dropout_rate=0
+ dropout_rate=0,
):
if num_layer == len(num_hidden_list) == len(act_type_list):
if num_layer > 0:
@@ -81,13 +82,16 @@ def sequence_fc(net,
num_hidden=num_hidden_list[layer_index],
act_type=None,
weight=weight_list[layer_index],
- no_bias=is_batchnorm
+ no_bias=is_batchnorm,
+ name="%s_t%d_l%d_fc" % (prefix, seq_index, layer_index)
)
# last layer doesn't have batchnorm
hidden = batchnorm(net=hidden,
gamma=gamma_list[layer_index],
- beta=beta_list[layer_index])
- hidden = mx.sym.Activation(data=hidden, act_type=act_type_list[layer_index])
+ beta=beta_list[layer_index],
+ name="%s_t%d_l%d_batchnorm" % (prefix, seq_index, layer_index))
+ hidden = mx.sym.Activation(data=hidden, act_type=act_type_list[layer_index],
+ name="%s_t%d_l%d_activation" % (prefix, seq_index, layer_index))
else:
hidden = fc(net=hidden,
num_hidden=num_hidden_list[layer_index],
diff --git a/example/speech_recognition/stt_layer_gru.py b/example/speech_recognition/stt_layer_gru.py
index 8b044746dfcf..89af1c72216d 100644
--- a/example/speech_recognition/stt_layer_gru.py
+++ b/example/speech_recognition/stt_layer_gru.py
@@ -15,7 +15,7 @@
"param_blocks"])
-def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., is_batchnorm=False, gamma=None, beta=None):
+def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., is_batchnorm=False, gamma=None, beta=None, name=None):
"""
GRU Cell symbol
Reference:
@@ -31,7 +31,10 @@ def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., is_
name="t%d_l%d_gates_i2h" % (seqidx, layeridx))
if is_batchnorm:
- i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
+ if name is not None:
+ i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
+ else:
+ i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
h2h = mx.sym.FullyConnected(data=prev_state.h,
weight=param.gates_h2h_weight,
bias=param.gates_h2h_bias,
@@ -53,15 +56,15 @@ def gru(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., is_
weight=param.trans_h2h_weight,
bias=param.trans_h2h_bias,
num_hidden=num_hidden,
- name="t%d_l%d_trans_i2h" % (seqidx, layeridx))
+ name="t%d_l%d_trans_h2h" % (seqidx, layeridx))
h_trans = htrans_i2h + htrans_h2h
h_trans_active = mx.sym.Activation(h_trans, act_type="tanh")
next_h = prev_state.h + update_gate * (h_trans_active - prev_state.h)
return GRUState(h=next_h)
-def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False, prefix="",
- direction="forward"):
+def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False, prefix="",
+ direction="forward", is_bucketing=False):
if num_gru_layer > 0:
param_cells = []
last_states = []
@@ -81,9 +84,14 @@ def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_
if is_batchnorm:
batchnorm_gamma = []
batchnorm_beta = []
- for seqidx in range(seq_len):
- batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx))
- batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx))
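+            # with bucketing, share one set of BatchNorm gamma/beta per layer instead of
+            # per time step, since the unrolled length differs from bucket to bucket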
+ if is_bucketing:
+ for l in range(num_gru_layer):
+ batchnorm_gamma.append(mx.sym.Variable(prefix + "l%d_i2h_gamma" % l))
+ batchnorm_beta.append(mx.sym.Variable(prefix + "l%d_i2h_beta" % l))
+ else:
+ for seqidx in range(seq_len):
+ batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx))
+ batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx))
hidden_all = []
for seqidx in range(seq_len):
@@ -103,19 +111,33 @@ def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_
else:
dp_ratio = dropout
if is_batchnorm:
- next_state = gru(num_hidden_gru_list[i], indata=hidden,
- prev_state=last_states[i],
- param=param_cells[i],
- seqidx=k, layeridx=i, dropout=dp_ratio,
- is_batchnorm=is_batchnorm,
- gamma=batchnorm_gamma[k],
- beta=batchnorm_beta[k])
+ if is_bucketing:
+ next_state = gru(num_hidden_gru_list[i], indata=hidden,
+ prev_state=last_states[i],
+ param=param_cells[i],
+ seqidx=k, layeridx=i, dropout=dp_ratio,
+ is_batchnorm=is_batchnorm,
+ gamma=batchnorm_gamma[i],
+ beta=batchnorm_beta[i],
+ name=prefix + ("t%d_l%d" % (seqidx, i))
+ )
+ else:
+ next_state = gru(num_hidden_gru_list[i], indata=hidden,
+ prev_state=last_states[i],
+ param=param_cells[i],
+ seqidx=k, layeridx=i, dropout=dp_ratio,
+ is_batchnorm=is_batchnorm,
+ gamma=batchnorm_gamma[k],
+ beta=batchnorm_beta[k],
+ name=prefix + ("t%d_l%d" % (seqidx, i))
+ )
else:
next_state = gru(num_hidden_gru_list[i], indata=hidden,
prev_state=last_states[i],
param=param_cells[i],
seqidx=k, layeridx=i, dropout=dp_ratio,
- is_batchnorm=is_batchnorm)
+ is_batchnorm=is_batchnorm,
+ name=prefix)
hidden = next_state.h
last_states[i] = next_state
# decoder
@@ -133,7 +155,7 @@ def gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_
return net
-def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False):
+def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0., is_batchnorm=False, is_bucketing=False):
if num_gru_layer > 0:
net_forward = gru_unroll(net=net,
num_gru_layer=num_gru_layer,
@@ -142,7 +164,8 @@ def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0.,
dropout=dropout,
is_batchnorm=is_batchnorm,
prefix="forward_",
- direction="forward")
+ direction="forward",
+ is_bucketing=is_bucketing)
net_backward = gru_unroll(net=net,
num_gru_layer=num_gru_layer,
seq_len=seq_len,
@@ -150,7 +173,8 @@ def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0.,
dropout=dropout,
is_batchnorm=is_batchnorm,
prefix="backward_",
- direction="backward")
+ direction="backward",
+ is_bucketing=is_bucketing)
hidden_all = []
for i in range(seq_len):
hidden_all.append(mx.sym.Concat(*[net_forward[i], net_backward[i]], dim=1))
@@ -159,7 +183,7 @@ def bi_gru_unroll(net, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0.,
def bi_gru_unroll_two_input_two_output(net1, net2, num_gru_layer, seq_len, num_hidden_gru_list, dropout=0.,
- is_batchnorm=False):
+ is_batchnorm=False, is_bucketing=False):
if num_gru_layer > 0:
net_forward = gru_unroll(net=net1,
num_gru_layer=num_gru_layer,
@@ -168,7 +192,8 @@ def bi_gru_unroll_two_input_two_output(net1, net2, num_gru_layer, seq_len, num_h
dropout=dropout,
is_batchnorm=is_batchnorm,
prefix="forward_",
- direction="forward")
+ direction="forward",
+ is_bucketing=is_bucketing)
net_backward = gru_unroll(net=net2,
num_gru_layer=num_gru_layer,
seq_len=seq_len,
@@ -176,7 +201,8 @@ def bi_gru_unroll_two_input_two_output(net1, net2, num_gru_layer, seq_len, num_h
dropout=dropout,
is_batchnorm=is_batchnorm,
prefix="backward_",
- direction="backward")
+ direction="backward",
+ is_bucketing=is_bucketing)
return net_forward, net_backward
else:
return net1, net2
diff --git a/example/speech_recognition/stt_layer_lstm.py b/example/speech_recognition/stt_layer_lstm.py
index 19e37369b1b0..93b4ca09b908 100644
--- a/example/speech_recognition/stt_layer_lstm.py
+++ b/example/speech_recognition/stt_layer_lstm.py
@@ -16,7 +16,7 @@
"param_blocks"])
-def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, is_batchnorm=False, gamma=None, beta=None):
+def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, is_batchnorm=False, gamma=None, beta=None, name=None):
"""LSTM Cell symbol"""
i2h = mx.sym.FullyConnected(data=indata,
weight=param.i2h_weight,
@@ -24,7 +24,10 @@ def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, is_bat
num_hidden=num_hidden * 4,
name="t%d_l%d_i2h" % (seqidx, layeridx))
if is_batchnorm:
- i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
+ if name is not None:
+ i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
+ else:
+ i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
h2h = mx.sym.FullyConnected(data=prev_state.h,
weight=param.h2h_weight,
bias=param.h2h_bias,
@@ -43,7 +46,7 @@ def vanilla_lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, is_bat
def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., num_hidden_proj=0, is_batchnorm=False,
- gamma=None, beta=None):
+ gamma=None, beta=None, name=None):
"""LSTM Cell symbol"""
# dropout input
if dropout > 0.:
@@ -55,7 +58,10 @@ def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., nu
num_hidden=num_hidden * 4,
name="t%d_l%d_i2h" % (seqidx, layeridx))
if is_batchnorm:
- i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
+ if name is not None:
+ i2h = batchnorm(net=i2h, gamma=gamma, beta=beta, name="%s_batchnorm" % name)
+ else:
+ i2h = batchnorm(net=i2h, gamma=gamma, beta=beta)
h2h = mx.sym.FullyConnected(data=prev_state.h,
weight=param.h2h_weight,
@@ -96,7 +102,7 @@ def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0., nu
def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0., num_hidden_proj=0,
- lstm_type='fc_lstm', is_batchnorm=False, prefix="", direction="forward"):
+ lstm_type='fc_lstm', is_batchnorm=False, prefix="", direction="forward", is_bucketing=False):
if num_lstm_layer > 0:
param_cells = []
last_states = []
@@ -121,9 +127,14 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
if is_batchnorm:
batchnorm_gamma = []
batchnorm_beta = []
- for seqidx in range(seq_len):
- batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx))
- batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx))
+ if is_bucketing:
+ for l in range(num_lstm_layer):
+ batchnorm_gamma.append(mx.sym.Variable(prefix + "l%d_i2h_gamma" % l))
+ batchnorm_beta.append(mx.sym.Variable(prefix + "l%d_i2h_beta" % l))
+ else:
+ for seqidx in range(seq_len):
+ batchnorm_gamma.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % seqidx))
+ batchnorm_beta.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % seqidx))
hidden_all = []
for seqidx in range(seq_len):
@@ -145,18 +156,20 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
if lstm_type == 'fc_lstm':
if is_batchnorm:
- next_state = lstm(num_hidden_lstm_list[i],
- indata=hidden,
- prev_state=last_states[i],
- param=param_cells[i],
- seqidx=k,
- layeridx=i,
- dropout=dp,
- num_hidden_proj=num_hidden_proj,
- is_batchnorm=is_batchnorm,
- gamma=batchnorm_gamma[k],
- beta=batchnorm_beta[k]
- )
+ if is_bucketing:
+ next_state = lstm(num_hidden_lstm_list[i],
+ indata=hidden,
+ prev_state=last_states[i],
+ param=param_cells[i],
+ seqidx=k,
+ layeridx=i,
+ dropout=dp,
+ num_hidden_proj=num_hidden_proj,
+ is_batchnorm=is_batchnorm,
+ gamma=batchnorm_gamma[i],
+ beta=batchnorm_beta[i],
+ name=prefix + ("t%d_l%d" % (seqidx, i))
+ )
else:
next_state = lstm(num_hidden_lstm_list[i],
indata=hidden,
@@ -166,7 +179,8 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
layeridx=i,
dropout=dp,
num_hidden_proj=num_hidden_proj,
- is_batchnorm=is_batchnorm
+ is_batchnorm=is_batchnorm,
+ name=prefix + ("t%d_l%d" % (seqidx, i))
)
elif lstm_type == 'vanilla_lstm':
if is_batchnorm:
@@ -175,15 +189,17 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
param=param_cells[i],
seqidx=k, layeridx=i,
is_batchnorm=is_batchnorm,
- gamma=batchnorm_gamma[k],
- beta=batchnorm_beta[k]
+ gamma=batchnorm_gamma[i],
+ beta=batchnorm_beta[i],
+ name=prefix + ("t%d_l%d" % (seqidx, i))
)
else:
next_state = vanilla_lstm(num_hidden_lstm_list[i], indata=hidden,
prev_state=last_states[i],
param=param_cells[i],
seqidx=k, layeridx=i,
- is_batchnorm=is_batchnorm
+ is_batchnorm=is_batchnorm,
+ name=prefix + ("t%d_l%d" % (seqidx, i))
)
else:
raise Exception("lstm type %s error" % lstm_type)
@@ -206,7 +222,7 @@ def lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0., num_hidden_proj=0,
- lstm_type='fc_lstm', is_batchnorm=False):
+ lstm_type='fc_lstm', is_batchnorm=False, is_bucketing=False):
if num_lstm_layer > 0:
net_forward = lstm_unroll(net=net,
num_lstm_layer=num_lstm_layer,
@@ -217,7 +233,8 @@ def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0
lstm_type=lstm_type,
is_batchnorm=is_batchnorm,
prefix="forward_",
- direction="forward")
+ direction="forward",
+ is_bucketing=is_bucketing)
net_backward = lstm_unroll(net=net,
num_lstm_layer=num_lstm_layer,
@@ -228,7 +245,8 @@ def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0
lstm_type=lstm_type,
is_batchnorm=is_batchnorm,
prefix="backward_",
- direction="backward")
+ direction="backward",
+ is_bucketing=is_bucketing)
hidden_all = []
for i in range(seq_len):
hidden_all.append(mx.sym.Concat(*[net_forward[i], net_backward[i]], dim=1))
@@ -239,7 +257,9 @@ def bi_lstm_unroll(net, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0
# bilistm_2to1
def bi_lstm_unroll_two_input_two_output(net1, net2, num_lstm_layer, seq_len, num_hidden_lstm_list, dropout=0.,
num_hidden_proj=0,
- lstm_type='fc_lstm', is_batchnorm=False):
+ lstm_type='fc_lstm',
+ is_batchnorm=False,
+ is_bucketing=False):
if num_lstm_layer > 0:
net_forward = lstm_unroll(net=net1,
num_lstm_layer=num_lstm_layer,
@@ -250,7 +270,8 @@ def bi_lstm_unroll_two_input_two_output(net1, net2, num_lstm_layer, seq_len, num
lstm_type=lstm_type,
is_batchnorm=is_batchnorm,
prefix="forward_",
- direction="forward")
+ direction="forward",
+ is_bucketing=is_bucketing)
net_backward = lstm_unroll(net=net2,
num_lstm_layer=num_lstm_layer,
@@ -261,7 +282,8 @@ def bi_lstm_unroll_two_input_two_output(net1, net2, num_lstm_layer, seq_len, num
lstm_type=lstm_type,
is_batchnorm=is_batchnorm,
prefix="backward_",
- direction="backward")
+ direction="backward",
+ is_bucketing=is_bucketing)
return net_forward, net_backward
else:
return net1, net2
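
The core of the bucketing change in these LSTM unroll helpers is where the i2h batchnorm parameters come from: per layer when bucketing, per time step otherwise. A minimal sketch of that idea, using a hypothetical helper name but the same variable-naming scheme as the diff:

```python
import mxnet as mx

def make_batchnorm_params(prefix, num_layers, seq_len, is_bucketing):
    """Hypothetical helper, not part of the example. With bucketing, the i2h
    batchnorm gamma/beta are created once per layer, so the same parameters
    are reused at every time step and for every bucket length. Without
    bucketing they are created per time step, which ties the unrolled symbol
    to one fixed seq_len."""
    gammas, betas = [], []
    if is_bucketing:
        for l in range(num_layers):
            gammas.append(mx.sym.Variable(prefix + "l%d_i2h_gamma" % l))
            betas.append(mx.sym.Variable(prefix + "l%d_i2h_beta" % l))
    else:
        for t in range(seq_len):
            gammas.append(mx.sym.Variable(prefix + "t%d_i2h_gamma" % t))
            betas.append(mx.sym.Variable(prefix + "t%d_i2h_beta" % t))
    return gammas, betas
```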
diff --git a/example/speech_recognition/stt_metric.py b/example/speech_recognition/stt_metric.py
index 0fc2bd11d906..1c5f4408a60e 100644
--- a/example/speech_recognition/stt_metric.py
+++ b/example/speech_recognition/stt_metric.py
@@ -19,12 +19,11 @@ def check_label_shapes(labels, preds, shape=0):
class STTMetric(mx.metric.EvalMetric):
- def __init__(self, batch_size, num_gpu, seq_length, is_epoch_end=False, is_logging=True):
+ def __init__(self, batch_size, num_gpu, is_epoch_end=False, is_logging=True):
super(STTMetric, self).__init__('STTMetric')
self.batch_size = batch_size
self.num_gpu = num_gpu
- self.seq_length = seq_length
self.total_n_label = 0
self.total_l_dist = 0
self.is_epoch_end = is_epoch_end
@@ -37,15 +36,17 @@ def update(self, labels, preds):
log = LogUtil().getlogger()
labelUtil = LabelUtil.getInstance()
self.batch_loss = 0.
+
for label, pred in zip(labels, preds):
label = label.asnumpy()
pred = pred.asnumpy()
- for i in range(int(int(self.batch_size) / int(self.num_gpu))):
+ seq_length = len(pred) / int(int(self.batch_size) / int(self.num_gpu))
+ for i in range(int(int(self.batch_size) / int(self.num_gpu))):
l = remove_blank(label[i])
p = []
- for k in range(int(self.seq_length)):
+ for k in range(int(seq_length)):
p.append(np.argmax(pred[k * int(int(self.batch_size) / int(self.num_gpu)) + i]))
p = pred_best(p)
@@ -60,7 +61,7 @@ def update(self, labels, preds):
self.num_inst += 1
self.sum_metric += this_cer
if self.is_epoch_end:
- loss = ctc_loss(l, pred, i, int(self.seq_length), int(self.batch_size), int(self.num_gpu))
+ loss = ctc_loss(l, pred, i, int(seq_length), int(self.batch_size), int(self.num_gpu))
self.batch_loss += loss
if self.is_logging:
log.info("loss: %f " % loss)
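
With variable-length buckets the metric can no longer rely on a fixed `seq_length` argument, so it recovers the time dimension from the prediction array itself. A minimal sketch of that computation, with a hypothetical helper name and an illustrative shape:

```python
import numpy as np

def infer_seq_length(pred, batch_size, num_gpu):
    """Hypothetical helper mirroring the computation in STTMetric.update():
    pred is indexed as pred[t * batch_per_gpu + i], i.e. it has
    seq_length * batch_per_gpu rows, so seq_length can be recovered from the
    array instead of being passed in."""
    batch_per_gpu = batch_size // num_gpu
    return len(pred) // batch_per_gpu

# e.g. 300 time steps, batch of 4 split over 2 GPUs -> 2 utterances per GPU
pred = np.zeros((300 * 2, 29))
assert infer_seq_length(pred, batch_size=4, num_gpu=2) == 300
```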
diff --git a/example/speech_recognition/train.py b/example/speech_recognition/train.py
index 37f00fc4dd90..708e3e03acf0 100644
--- a/example/speech_recognition/train.py
+++ b/example/speech_recognition/train.py
@@ -7,7 +7,8 @@
from stt_metric import STTMetric
#tensorboard setting
from tensorboard import SummaryWriter
-import numpy as np
+import json
+
def get_initializer(args):
@@ -28,6 +29,7 @@ def __init__(self, learning_rate=0.001):
def __call__(self, num_update):
return self.learning_rate
+
def do_training(args, module, data_train, data_val, begin_epoch=0):
from distutils.dir_util import mkpath
from log_util import LogUtil
@@ -35,7 +37,7 @@ def do_training(args, module, data_train, data_val, begin_epoch=0):
log = LogUtil().getlogger()
mkpath(os.path.dirname(get_checkpoint_path(args)))
- seq_len = args.config.get('arch', 'max_t_count')
+ #seq_len = args.config.get('arch', 'max_t_count')
batch_size = args.config.getint('common', 'batch_size')
save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
@@ -44,21 +46,22 @@ def do_training(args, module, data_train, data_val, begin_epoch=0):
contexts = parse_contexts(args)
num_gpu = len(contexts)
- eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,is_logging=enable_logging_validation_metric,is_epoch_end=True)
+ eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_validation_metric,is_epoch_end=True)
# tensorboard setting
- loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,is_logging=enable_logging_train_metric,is_epoch_end=False)
+ loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_train_metric,is_epoch_end=False)
- optimizer = args.config.get('train', 'optimizer')
- momentum = args.config.getfloat('train', 'momentum')
+ optimizer = args.config.get('optimizer', 'optimizer')
learning_rate = args.config.getfloat('train', 'learning_rate')
learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')
mode = args.config.get('common', 'mode')
num_epoch = args.config.getint('train', 'num_epoch')
- clip_gradient = args.config.getfloat('train', 'clip_gradient')
- weight_decay = args.config.getfloat('train', 'weight_decay')
+ clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
+ weight_decay = args.config.getfloat('optimizer', 'weight_decay')
save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
show_every = args.config.getint('train', 'show_every')
+ optimizer_params_dictionary = json.loads(args.config.get('optimizer', 'optimizer_params_dictionary'))
+ kvstore_option = args.config.get('common', 'kvstore_option')
n_epoch=begin_epoch
if clip_gradient == 0:
@@ -75,24 +78,14 @@ def do_training(args, module, data_train, data_val, begin_epoch=0):
lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)
def reset_optimizer(force_init=False):
- if optimizer == "sgd":
- module.init_optimizer(kvstore='device',
- optimizer=optimizer,
- optimizer_params={'lr_scheduler': lr_scheduler,
- 'momentum': momentum,
- 'clip_gradient': clip_gradient,
- 'wd': weight_decay},
- force_init=force_init)
- elif optimizer == "adam":
- module.init_optimizer(kvstore='device',
- optimizer=optimizer,
- optimizer_params={'lr_scheduler': lr_scheduler,
- #'momentum': momentum,
- 'clip_gradient': clip_gradient,
- 'wd': weight_decay},
- force_init=force_init)
- else:
- raise Exception('Supported optimizers are sgd and adam. If you want to implement others define them in train.py')
+ optimizer_params = {'lr_scheduler': lr_scheduler,
+ 'clip_gradient': clip_gradient,
+ 'wd': weight_decay}
+ optimizer_params.update(optimizer_params_dictionary)
+ module.init_optimizer(kvstore=kvstore_option,
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ force_init=force_init)
if mode == "train":
reset_optimizer(force_init=True)
else:
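
The optimizer settings are now read from their own config section, with any optimizer-specific keyword arguments supplied as a JSON dictionary that is merged into `optimizer_params`. A sketch of how such a section could be read; the section and option names follow the keys used above, while the concrete values are illustrative only:

```python
import json
from configparser import ConfigParser  # Python 3

# Hypothetical .cfg excerpt; values are placeholders.
CFG = """
[common]
kvstore_option = device

[optimizer]
optimizer = adam
clip_gradient = 0
weight_decay = 0.00001
optimizer_params_dictionary = {"beta1": 0.9}
"""

config = ConfigParser()
config.read_string(CFG)

optimizer = config.get('optimizer', 'optimizer')
optimizer_params = {'clip_gradient': config.getfloat('optimizer', 'clip_gradient'),
                    'wd': config.getfloat('optimizer', 'weight_decay')}
# optimizer-specific kwargs ride along as a JSON dictionary
optimizer_params.update(json.loads(config.get('optimizer', 'optimizer_params_dictionary')))
# module.init_optimizer(kvstore=config.get('common', 'kvstore_option'),
#                       optimizer=optimizer, optimizer_params=optimizer_params)
```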
@@ -101,15 +94,23 @@ def reset_optimizer(force_init=False):
#tensorboard setting
tblog_dir = args.config.get('common', 'tensorboard_log_dir')
summary_writer = SummaryWriter(tblog_dir)
+
+
+ if mode == "train":
+ sort_by_duration = True
+ else:
+ sort_by_duration = False
+
+ if not sort_by_duration:
+ data_train.reset()
+
while True:
if n_epoch >= num_epoch:
break
-
loss_metric.reset()
log.info('---------train---------')
for nbatch, data_batch in enumerate(data_train):
-
module.forward_backward(data_batch)
module.update()
# tensorboard setting
@@ -136,6 +137,7 @@ def reset_optimizer(force_init=False):
assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'
data_train.reset()
+ data_train.is_first_epoch = False
# tensorboard setting
train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()