Skip to content

Commit

Permalink
Fix fill_constant_batch_size_like_op when input is LoDTensor. (#10943)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingqing01 committed May 28, 2018
1 parent bf869e4 commit 91bd583
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/operators/fill_constant_batch_size_like_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::LoDTensor>("Input");
if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
// set the correct batch size for the LoDTensor.
auto odims = out->dims();
int output_dim_idx = ctx.Attr<int>("output_dim_idx");
odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
out->mutable_data<T>(odims, ctx.GetPlace());
}
out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<float>("value");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,27 @@ def test_check_output(self):
self.check_output()


class TestFillConstantBatchSizeLikeWithLoDTensor(OpTest):
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.inputs = {
'Input': (np.random.random((31, 28)).astype("float32"),
[[0, 9, 23, 31]])
}
self.attrs = {
'value': 3.5,
'shape': [-1, 16],
'input_dim_idx': 0,
'output_dim_idx': 0
}

out = np.random.random((3, 16)).astype("float32")
out.fill(3.5)
self.outputs = {'Out': out}

def test_check_output(self):
self.check_output()


if __name__ == "__main__":
unittest.main()

0 comments on commit 91bd583

Please sign in to comment.