Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XPU]. Fixed the mode error in pad3d. #9506

Merged
merged 2 commits into from
Nov 9, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 71 additions & 6 deletions lite/kernels/xpu/pad3d_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,25 @@ void Pad3dCompute<T>::Run() {
auto* in_data = x->template data<T>();
auto* out = param.Out;
T* out_data = out->template mutable_data<T>(TARGET(kXPU));
bool is_ncdhw;
if (data_format == "NCDHW") {
is_ncdhw = true;
} else if (data_format == "NDHWC") {
is_ncdhw = false;
} else {
LOG(FATAL) << "xpu unsupport data_format: " << data_format;
}
// trans format
std::vector<int> padding(6);
padding[0] = pads[4];
padding[1] = pads[5];
padding[2] = pads[2];
padding[3] = pads[3];
padding[4] = pads[0];
padding[5] = pads[1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pads可能是一个元素吗?


if (mode == "reflect" || mode == "constant" || mode == "replicate" ||
mode == "circular") {
if (data_format == "NCDHW") {
if (mode == "constant") {
if (is_ncdhw) {
std::vector<int> pad_left = {0, 0, pads[4], pads[2], pads[0]};
std::vector<int> pad_right = {0, 0, pads[5], pads[3], pads[1]};

Expand All @@ -50,7 +65,6 @@ void Pad3dCompute<T>::Run() {
int w_shape = in_dims[4];

std::vector<int> xshape = {n_shape, c_shape, d_shape, h_shape, w_shape};

int r = xdnn::pad<T>(ctx.GetRawContext(),
in_data,
out_data,
Expand All @@ -59,7 +73,7 @@ void Pad3dCompute<T>::Run() {
pad_right,
value);
CHECK_EQ(r, 0);
} else if (data_format == "NDHWC") {
} else {
std::vector<int> pad_left = {0, pads[4], pads[2], pads[0], 0};
std::vector<int> pad_right = {0, pads[5], pads[3], pads[1], 0};

Expand All @@ -79,7 +93,58 @@ void Pad3dCompute<T>::Run() {
value);
CHECK_EQ(r, 0);
}

} else if (mode == "reflect") {
int r = 0;
if (is_ncdhw) {
int r = xdnn::reflection_pad3d<T>(ctx.GetRawContext(),
in_data,
out_data,
in_dims[0],
in_dims[1],
in_dims[2],
in_dims[3],
in_dims[4],
padding,
is_ncdhw);
} else {
int r = xdnn::reflection_pad3d<T>(ctx.GetRawContext(),
in_data,
out_data,
in_dims[0],
in_dims[4],
in_dims[1],
in_dims[2],
in_dims[3],
padding,
is_ncdhw);
}
CHECK_EQ(r, 0);
} else if (mode == "replicate") {
int r = 0;
if (is_ncdhw) {
int r = xdnn::replication_pad3d<T>(ctx.GetRawContext(),
in_data,
out_data,
in_dims[0],
in_dims[1],
in_dims[2],
in_dims[3],
in_dims[4],
padding,
is_ncdhw);
} else {
int r = xdnn::replication_pad3d<T>(ctx.GetRawContext(),
in_data,
out_data,
in_dims[0],
in_dims[4],
in_dims[1],
in_dims[2],
in_dims[3],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议使用前面的 n c d h w这些变量作为参数 ---- 理论上不应该存在is_ncdhw为两种值时调用两次api

padding,
is_ncdhw);
}
CHECK_EQ(r, 0);
} else {
LOG(FATAL) << "xpu unsupport mode: " << mode;
}
Expand Down