Skip to content

Commit

Permalink
[PIR] [DynamicShape] Add infer_symbolic and unit test for Conv2dOp (#…
Browse files Browse the repository at this point in the history
…62798)

* conv2d

* fix build bugs
  • Loading branch information
zhangbopd committed Mar 22, 2024
1 parent 38bbcf8 commit 65126fa
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,137 @@
#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h"

namespace {

inline void UpdatePaddingAndDilation(
std::vector<symbol::DimExpr> *paddings,
std::vector<symbol::DimExpr> *dilation,
const std::string padding_algorithm,
const std::vector<symbol::DimExpr> data_dims,
const std::vector<int> &strides,
const std::vector<symbol::DimExpr> &ksize) {
// set padding size == data_dims.size() * 2
if (paddings->size() == data_dims.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
symbol::DimExpr copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
}
}

// when padding_algorithm is "VALID" or "SAME"
symbol::DimExpr zero{0};
symbol::DimExpr one{1};
symbol::DimExpr two{2};
if (padding_algorithm == "SAME") {
symbol::DimExprBuilder builder{nullptr};
for (size_t i = 0; i < data_dims.size(); ++i) {
symbol::DimExpr out_size = (data_dims[i] + strides[i] - 1) / strides[i];
symbol::DimExpr pad_sum = builder.Max(
(out_size - one) * strides[i] + ksize[i] - data_dims[i], zero);

symbol::DimExpr pad_0 = pad_sum / two;
symbol::DimExpr pad_1 = pad_sum - pad_0;

*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;

// dilation
*(dilation->begin() + i) = one;
}

} else if (padding_algorithm == "VALID") {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = zero;
}
}
}

} // namespace
namespace paddle::dialect {

bool Conv2dOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
const std::vector<int> strides =
paddle::dialect::details::GetVectorAttr<int>(op, "strides");

std::vector<int> paddings =
paddle::dialect::details::GetVectorAttr<int>(op, "paddings");

std::vector<int> dilations =
paddle::dialect::details::GetVectorAttr<int>(op, "dilations");

const auto &attributes = op->attributes();
const std::string data_format =
attributes.at("data_format").dyn_cast<pir::StrAttribute>().AsString();

const std::string padding_algorithm = attributes.at("padding_algorithm")
.dyn_cast<pir::StrAttribute>()
.AsString();

const auto in_s_or_d =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
const auto filter_s_or_d =
shape_analysis->GetShapeOrDataForValue(op->operand_source(1));

const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

std::vector<symbol::DimExpr> in_data_dims =
channel_last ? std::vector<symbol::DimExpr>(in_s_or_d.shape().begin() + 1,
in_s_or_d.shape().end() - 1)
: std::vector<symbol::DimExpr>(in_s_or_d.shape().begin() + 2,
in_s_or_d.shape().end());

std::vector<symbol::DimExpr> filter_data_dims = std::vector<symbol::DimExpr>(
filter_s_or_d.shape().begin() + 2, filter_s_or_d.shape().end());

std::vector<symbol::DimExpr> ksize = filter_data_dims;

std::vector<symbol::DimExpr> new_paddings;
for (const auto &i : paddings) {
new_paddings.push_back(symbol::DimExpr{i});
}
std::vector<symbol::DimExpr> new_dilations;
for (const auto &i : dilations) {
new_dilations.push_back(symbol::DimExpr{i});
}

UpdatePaddingAndDilation(&new_paddings,
&new_dilations,
padding_algorithm,
in_data_dims,
strides,
ksize);

const symbol::ShapeOrDataDimExprs &shape_data = [&] {
std::vector<symbol::DimExpr> out_s_or_d({in_s_or_d.shape()[0]});
if (!channel_last) {
out_s_or_d.push_back(filter_s_or_d.shape()[0]);
}

for (size_t i = 0; i < in_data_dims.size(); ++i) {
if (!in_data_dims[i].isa<int64_t>() ||
!filter_s_or_d.shape()[i + 2].isa<int64_t>()) {
out_s_or_d.push_back(shape_analysis->GetNextSymName());
} else {
const symbol::DimExpr dkernel =
new_dilations[i] * (filter_data_dims[i] - 1) + 1;
symbol::DimExpr output_size = (in_data_dims[i] + new_paddings[2 * i] +
new_paddings[2 * i + 1] - dkernel) /
strides[i] +
1;
out_s_or_d.push_back(output_size);
}
}
if (channel_last) {
out_s_or_d.push_back(filter_s_or_d.shape()[0]);
}

return symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_s_or_d)};
}();

shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);

return true;
}

Expand Down
28 changes: 28 additions & 0 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_binary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,33 @@ def test_eval_symbolic(self):
return True


class Conv2dNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.conv = paddle.nn.Conv2D(4, 6, (3, 3))

def forward(self, x):
z = paddle.empty(shape=[2, 4, 8, 8])
out = self.conv(z)
return out


class Conv2dOpInferSymbolicShapeTest(TestBase):
def prepare_data(self):
self.expected = ['shape[2, 6, 6, 6], data[NULL]']

def test_eval_symbolic(self):
net = Conv2dNet()

x_spec = InputSpec(shape=[None, None, None], dtype='float32')

input_spec = [x_spec]
net = apply_to_static(net, False, input_spec)
net.eval()
check_infer_results(net, input_spec, 'pd_op.conv2d', self.expected)

return True


if __name__ == '__main__':
unittest.main()

0 comments on commit 65126fa

Please sign in to comment.