Skip to content

Commit

Permalink
fix v2 infer interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Aug 31, 2017
1 parent 45ced9d commit 09e903e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 0 additions & 1 deletion paddle/gserver/layers/CrossEntropyOverBeam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ void CostForOneSequence::calValidExpandStep() {
if (start + beamSize_ == findEnd) return;
goldColIds_[i] = findEnd - start;
}

if (goldColIds_[beams_->expansionCount - 1] != -1) goldAsExtraPath_ = false;
}

Expand Down
7 changes: 5 additions & 2 deletions python/paddle/v2/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def iter_infer_field(self, field, **kwargs):
item = [each_result[each_field] for each_field in field]
yield item

def infer(self, input, field='value', **kwargs):
def infer(self, input, field='value', flatten_result=True, **kwargs):
"""
Infer a data by model.
:param input: input data batch. Should be python iterable object.
Expand All @@ -83,7 +83,10 @@ def infer(self, input, field='value', **kwargs):
retv = [[] for i in xrange(len(result))]
for i, item in enumerate(result):
retv[i].append(item)
retv = [numpy.concatenate(out) for out in retv]

if flatten_result:
retv = [numpy.concatenate(out) for out in retv]

if len(retv) == 1:
return retv[0]
else:
Expand Down

0 comments on commit 09e903e

Please sign in to comment.