Skip to content

Commit

Permalink
Merge pull request #772 from kuke/fix_fetch
Browse files Browse the repository at this point in the history
Solve the problem of fetching prediction and make data dim configurable
  • Loading branch information
Yibing Liu committed Mar 27, 2018
2 parents ab01a0b + b6baf32 commit 745591a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 14 deletions.
14 changes: 12 additions & 2 deletions fluid/DeepASR/infer_by_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def parse_args():
default=1,
help='The minimum sequence number of a batch data. '
'(default: %(default)d)')
parser.add_argument(
'--frame_dim',
type=int,
default=120 * 11,
help='Frame dimension of feature data. (default: %(default)d)')
parser.add_argument(
'--stacked_num',
type=int,
Expand All @@ -47,6 +52,11 @@ def parse_args():
type=int,
default=1024,
help='Hidden size of lstmp unit. (default: %(default)d)')
parser.add_argument(
'--class_num',
type=int,
default=1749,
help='Number of classes in label. (default: %(default)d)')
parser.add_argument(
'--learning_rate',
type=float,
Expand Down Expand Up @@ -99,10 +109,11 @@ def infer_from_ckpt(args):
raise IOError("Invalid checkpoint!")

prediction, avg_cost, accuracy = stacked_lstmp_model(
frame_dim=args.frame_dim,
hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim,
stacked_num=args.stacked_num,
class_num=1749,
class_num=args.class_num,
parallel=args.parallel)

infer_program = fluid.default_main_program().clone()
Expand Down Expand Up @@ -156,7 +167,6 @@ def infer_from_ckpt(args):
for index, sample in enumerate(infer_batch):
print("Decoding %d: " % (batch_id * args.batch_size + index),
decoder.decode(sample))

print(np.mean(infer_costs), np.mean(infer_accs))


Expand Down
23 changes: 13 additions & 10 deletions fluid/DeepASR/model_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import paddle.fluid as fluid


def stacked_lstmp_model(hidden_dim,
def stacked_lstmp_model(frame_dim,
hidden_dim,
proj_dim,
stacked_num,
class_num,
Expand All @@ -20,12 +21,13 @@ def stacked_lstmp_model(hidden_dim,
label data respectively. And in inference, only `feature` is needed.
Args:
hidden_dim(int): The hidden state's dimension of the LSTMP layer.
proj_dim(int): The projection size of the LSTMP layer.
stacked_num(int): The number of stacked LSTMP layers.
parallel(bool): Run in parallel or not, default `False`.
is_train(bool): Run in training phase or not, default `True`.
class_dim(int): The number of output classes.
frame_dim(int): The frame dimension of feature data.
hidden_dim(int): The hidden state's dimension of the LSTMP layer.
proj_dim(int): The projection size of the LSTMP layer.
stacked_num(int): The number of stacked LSTMP layers.
parallel(bool): Run in parallel or not, default `False`.
is_train(bool): Run in training phase or not, default `True`.
        class_num(int): The number of output classes.
"""

# network configuration
Expand Down Expand Up @@ -78,7 +80,7 @@ def _net_conf(feature, label):

# data feeder
feature = fluid.layers.data(
name="feature", shape=[-1, 120 * 11], dtype="float32", lod_level=1)
name="feature", shape=[-1, frame_dim], dtype="float32", lod_level=1)
label = fluid.layers.data(
name="label", shape=[-1, 1], dtype="int64", lod_level=1)

Expand All @@ -92,11 +94,12 @@ def _net_conf(feature, label):
feat_ = pd.read_input(feature)
label_ = pd.read_input(label)
prediction, avg_cost, acc = _net_conf(feat_, label_)
for out in [avg_cost, acc]:
for out in [prediction, avg_cost, acc]:
pd.write_output(out)

# get mean loss and acc through every devices.
avg_cost, acc = pd()
prediction, avg_cost, acc = pd()
prediction.stop_gradient = True
avg_cost = fluid.layers.mean(x=avg_cost)
acc = fluid.layers.mean(x=acc)
else:
Expand Down
13 changes: 12 additions & 1 deletion fluid/DeepASR/tools/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def parse_args():
default=1,
help='The minimum sequence number of a batch data. '
'(default: %(default)d)')
parser.add_argument(
'--frame_dim',
type=int,
default=120 * 11,
help='Frame dimension of feature data. (default: %(default)d)')
parser.add_argument(
'--stacked_num',
type=int,
Expand All @@ -46,6 +51,11 @@ def parse_args():
type=int,
default=1024,
help='Hidden size of lstmp unit. (default: %(default)d)')
parser.add_argument(
'--class_num',
type=int,
default=1749,
help='Number of classes in label. (default: %(default)d)')
parser.add_argument(
'--learning_rate',
type=float,
Expand Down Expand Up @@ -119,10 +129,11 @@ def profile(args):
"arg 'first_batches_to_skip' must not be smaller than 0.")

_, avg_cost, accuracy = stacked_lstmp_model(
frame_dim=args.frame_dim,
hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim,
stacked_num=args.stacked_num,
class_num=1749,
class_num=args.class_num,
parallel=args.parallel)

optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
Expand Down
13 changes: 12 additions & 1 deletion fluid/DeepASR/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def parse_args():
default=1,
help='The minimum sequence number of a batch data. '
'(default: %(default)d)')
parser.add_argument(
'--frame_dim',
type=int,
default=120 * 11,
help='Frame dimension of feature data. (default: %(default)d)')
parser.add_argument(
'--stacked_num',
type=int,
Expand All @@ -45,6 +50,11 @@ def parse_args():
type=int,
default=1024,
help='Hidden size of lstmp unit. (default: %(default)d)')
parser.add_argument(
'--class_num',
type=int,
default=1749,
help='Number of classes in label. (default: %(default)d)')
parser.add_argument(
'--pass_num',
type=int,
Expand Down Expand Up @@ -137,10 +147,11 @@ def train(args):
os.mkdir(args.infer_models)

prediction, avg_cost, accuracy = stacked_lstmp_model(
frame_dim=args.frame_dim,
hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim,
stacked_num=args.stacked_num,
class_num=1749,
class_num=args.class_num,
parallel=args.parallel)

# program for test
Expand Down

0 comments on commit 745591a

Please sign in to comment.