
Commit

polish error_message, test=kunlun
tink2123 committed Oct 12, 2020
1 parent e8f44f5 commit 8107d6a
Showing 2 changed files with 32 additions and 20 deletions.
24 changes: 14 additions & 10 deletions paddle/fluid/operators/conv_op_xpu.cc
@@ -46,14 +46,14 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(dilations[0] == 1 && dilations[1] == 1,
                    "XPU only support dilation == 1.");
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    PADDLE_ENFORCE(
+    PADDLE_ENFORCE_EQ(
         xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
                      max_input->data<T>()) == xpu::Error_t::SUCCESS,
-        "XPU kernel error!");
-    PADDLE_ENFORCE(
+        true, platform::errors::InvalidArgument("XPU kernel error!"));
+    PADDLE_ENFORCE_EQ(
         xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
                      max_filter->data<T>()) == xpu::Error_t::SUCCESS,
-        "XPU kernel error!");
+        true, platform::errors::InvalidArgument("XPU kernel error!"));
     if (groups == 1) {
       int r = xpu::conv2d_forward_int16<float, float, float, float>(
           dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
@@ -62,15 +62,17 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
           output->data<float>(), nullptr, nullptr, xpu::Activation_t::LINEAR,
           // nullptr, nullptr);
           max_input->data<float>(), max_filter->data<float>());
-      PADDLE_ENFORCE(r == xpu::Error_t::SUCCESS, "XPU kernel error!");
+      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
+                        platform::errors::InvalidArgument("XPU kernel error!"));
     } else {
       int r = xpu::conv2d_int16_with_group<float, float, float>(
           dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
           output->data<float>(), batch_size, img_c, img_h, img_w, f, win_h,
           win_w, groups, strides[0], strides[1], paddings[0], paddings[1],
           // nullptr, nullptr);
           max_input->data<float>(), max_filter->data<float>());
-      PADDLE_ENFORCE(r == xpu::Error_t::SUCCESS, "XPU kernel error!");
+      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
+                        platform::errors::InvalidArgument("XPU kernel error!"));
     }
   }
 };
@@ -116,11 +118,11 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     auto& dev_ctx = context.template device_context<DeviceContext>();
     max_output_grad->Resize({4});
     max_output_grad->mutable_data<T>(context.GetPlace());
-    PADDLE_ENFORCE(
+    PADDLE_ENFORCE_EQ(
         xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
                      output_grad->numel(),
                      max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
-        "XPU kernel error!");
+        true, platform::errors::InvalidArgument("XPU kernel error!"));
     if (input_grad) {
       int r = xpu::conv2d_backward_int16(
           dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
@@ -129,7 +131,8 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
           filter.data<float>(), input_grad->data<float>(),
           // nullptr, nullptr,
           max_output_grad->data<float>(), max_filter->data<float>());
-      PADDLE_ENFORCE(r == xpu::Error_t::SUCCESS, "XPU kernel error!");
+      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
+                        platform::errors::InvalidArgument("XPU kernel error!"));
     }
     if (filter_grad) {
       int r = xpu::conv2d_backward_weight_int16(
@@ -139,7 +142,8 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
           input->data<float>(), filter_grad->data<float>(),
           // nullptr, nullptr,
           max_output_grad->data<float>(), max_input->data<float>());
-      PADDLE_ENFORCE(r == xpu::Error_t::SUCCESS, "XPU kernel error!");
+      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
+                        platform::errors::InvalidArgument("XPU kernel error!"));
     }
   }
 };
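For context, every C++ change in this file follows one pattern: a bare PADDLE_ENFORCE(cond, "XPU kernel error!") becomes PADDLE_ENFORCE_EQ(cond, true, platform::errors::InvalidArgument("XPU kernel error!")), so a failing XPU call now raises a typed InvalidArgument error instead of an untyped message string. The standalone snippet below is only a simplified sketch of that check-and-throw idea under that reading; it is not Paddle's actual macro, and all names in it are hypothetical.

// Simplified, hypothetical stand-in for the enforce-style check above; it is
// NOT Paddle's PADDLE_ENFORCE_EQ macro, only a sketch of the pattern.
#include <iostream>
#include <stdexcept>
#include <string>

namespace sketch {

// Stand-in for platform::errors::InvalidArgument: a typed error message.
struct InvalidArgument : public std::invalid_argument {
  explicit InvalidArgument(const std::string& msg)
      : std::invalid_argument(msg) {}
};

// Stand-in for PADDLE_ENFORCE_EQ(a, b, err): throw the typed error if a != b.
template <typename A, typename B, typename Err>
void EnforceEq(const A& a, const B& b, const Err& err) {
  if (!(a == b)) throw err;
}

}  // namespace sketch

int main() {
  const int kSuccess = 0;
  int status = kSuccess;  // pretend this came back from an XPU kernel call
  try {
    sketch::EnforceEq(status == kSuccess, true,
                      sketch::InvalidArgument("XPU kernel error!"));
    std::cout << "check passed" << std::endl;
  } catch (const sketch::InvalidArgument& e) {
    std::cout << "check failed: " << e.what() << std::endl;
  }
  return 0;
}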
28 changes: 18 additions & 10 deletions python/paddle/fluid/tests/unittests/xpu/test_conv2d_op_xpu.py
@@ -223,7 +223,9 @@ def test_check_output(self):
         paddle.enable_static()
         place = paddle.XPUPlace(0)
         self.check_output_with_place(
-            place, atol=1e-5, check_dygraph=(self.use_mkldnn == False))
+            place,
+            #atol=1e-5,
+            check_dygraph=(self.use_mkldnn == False))
 
     def test_check_grad(self):
         if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and
@@ -233,9 +235,10 @@ def test_check_grad(self):
         paddle.enable_static()
         place = paddle.XPUPlace(0)
         self.check_grad_with_place(
-            place, {'Input', 'Filter'},
+            place,
+            {'Input', 'Filter'},
             'Output',
-            max_relative_error=0.02,
+            #max_relative_error=0.02,
             check_dygraph=(self.use_mkldnn == False))
 
     def test_check_grad_no_filter(self):
@@ -246,9 +249,10 @@ def test_check_grad_no_filter(self):
         paddle.enable_static()
         place = paddle.XPUPlace(0)
         self.check_grad_with_place(
-            place, ['Input'],
+            place,
+            ['Input'],
             'Output',
-            max_relative_error=0.02,
+            #max_relative_error=0.02,
             no_grad_set=set(['Filter']),
             check_dygraph=(self.use_mkldnn == False))
 
@@ -436,7 +440,9 @@ def test_check_output(self):
         paddle.enable_static()
         place = paddle.XPUPlace(0)
         self.check_output_with_place(
-            place, atol=1e-5, check_dygraph=(self.use_mkldnn == False))
+            place,
+            #atol=1e-5,
+            check_dygraph=(self.use_mkldnn == False))
 
     def test_check_grad(self):
         # TODO(wangzhongpu): support mkldnn op in dygraph mode
@@ -446,9 +452,10 @@ def test_check_grad(self):
         paddle.enable_static()
         place = paddle.XPUPlace(0)
         self.check_grad_with_place(
-            place, {'Input', 'Filter'},
+            place,
+            {'Input', 'Filter'},
             'Output',
-            max_relative_error=0.02,
+            #max_relative_error=0.02,
             check_dygraph=(self.use_mkldnn == False))
 
     def test_check_grad_no_filter(self):
@@ -459,9 +466,10 @@ def test_check_grad_no_filter(self):
         paddle.enable_static()
         place = paddle.XPUPlace(0)
         self.check_grad_with_place(
-            place, ['Input'],
+            place,
+            ['Input'],
             'Output',
-            max_relative_error=0.02,
+            #max_relative_error=0.02,
             no_grad_set=set(['Filter']),
             check_dygraph=(self.use_mkldnn == False))
 
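On the test side, the change comments out the explicit atol and max_relative_error overrides, so the XPU output and gradient checks fall back to the OpTest defaults while dygraph checking stays disabled whenever MKLDNN is in use. The sketch below only illustrates the resulting call pattern; it assumes Paddle's op_test.OpTest harness and an XPU build, the class name is hypothetical, and the conv2d op setup is elided.

# Illustrative sketch only -- not part of the commit. Assumes Paddle's OpTest
# harness (python/paddle/fluid/tests/unittests/op_test.py) and an XPU build.
import numpy as np
import paddle
from op_test import OpTest


class SketchConv2DXPUTest(OpTest):  # hypothetical name, for illustration only
    def setUp(self):
        self.dtype = np.float32
        self.use_mkldnn = False
        # ... the real test also configures self.op_type, self.inputs,
        # self.attrs and self.outputs for the conv2d operator here.

    def test_check_grad(self):
        paddle.enable_static()
        place = paddle.XPUPlace(0)
        # max_relative_error is no longer overridden, so the OpTest default
        # tolerance applies; dygraph checking stays off when MKLDNN is used.
        self.check_grad_with_place(
            place, {'Input', 'Filter'},
            'Output',
            check_dygraph=(self.use_mkldnn == False))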
