Skip to content

Commit

Permalink
[0D-Tensor] CINN supports fill_constant, fix infershape and pass (#…
Browse files Browse the repository at this point in the history
…55563)

* [0D-Tensor] CINN supports fill_constant, fix infershape and pass

* fix infershape of fill_constant

* add back fill_constant to zero_tensor_trick_pass
  • Loading branch information
jiahy0825 committed Jul 26, 2023
1 parent 97ec1d8 commit f5830c0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
17 changes: 17 additions & 0 deletions paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class ExpandZeroDimPass : public ProgramPass {
if (instr->op_type == "transpose") {
builder.AppendInstruction(HandleTranspose(instr));
continue;
} else if (instr->op_type == "fill_constant") {
builder.AppendInstruction(HandleFillConstant(instr));
continue;
}
for (auto& input : instr->inputs) {
if (input->shape.empty()) {
Expand Down Expand Up @@ -101,6 +104,20 @@ class ExpandZeroDimPass : public ProgramPass {
}
return new_instr;
}

// Before: out-0D = fill_constant([], 123.456, "out", "float32")
// After: out-1D = fill_constant([1], 123.456, "out", "float32")
Instruction HandleFillConstant(const Instruction& instr) {
Instruction new_instr = instr;
std::vector<int32_t> shape =
new_instr.GetAttrs<std::vector<int32_t>>("shape");
if (shape.empty()) {
shape.push_back(1);
VLOG(4) << "Change fill_constant's attribute shape from [] to [1]";
}
new_instr.SetAttr<std::vector<int32_t>>("shape", shape);
return new_instr;
}
};

} // namespace pass
Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/hlir/op/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,6 @@ std::vector<shape_t> InferShapeForFillConstant(
const framework::AttrMapType &attrs) {
CHECK(attrs.count("shape"));
auto shape = absl::get<std::vector<int>>(attrs.at("shape"));
CHECK(!shape.empty()) << "shape attr is empty!";
return {shape};
}

Expand Down
27 changes: 27 additions & 0 deletions test/cinn/ops/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,33 @@ def init_input(self):
self.target_shape = []


@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestFillConstantOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.target_shape = ()

def build_paddle_program(self, target):
out = paddle.full([], 123.456, "float32")

self.paddle_outputs = [out]

def build_cinn_program(self, target):
builder = NetBuilder("fill_constant_op")
out = builder.fill_constant([], 123.456, "out", "float32")

prog = builder.build()
res = self.get_cinn_output(prog, target, [], [], [out])

self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)

def test_check_results(self):
self.check_outputs_and_grads()


@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
Expand Down

0 comments on commit f5830c0

Please sign in to comment.