Skip to content

Commit

Permalink
[x86] fix im2col when datasize > 2^31 (the max value that int can express) (#7300)
Browse files Browse the repository at this point in the history

* [x86] fix im2col when datasize > 2^31 (the max value that int can express)

* fix windows ci
  • Loading branch information
mjp9527 committed Oct 19, 2021
1 parent f83d759 commit 61a6717
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 40 deletions.
72 changes: 40 additions & 32 deletions lite/backends/x86/math/avx/conv_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1141,31 +1141,36 @@ void im2col_s1<float>(const float* data_im,
(width + pad_left + pad_right - (dilation_w * (kernel_w - 1) + 1)) + 1;
const int in_channel_size = height * width;
const int out_channel_size = output_h * output_w;
const int output_plane_size = output_h * output_w * kernel_h * kernel_w;
memset(data_col, 0, output_plane_size * channels * sizeof(float));
const unsigned int output_plane_size =
output_h * output_w * kernel_h * kernel_w;
size_t tmp_size = static_cast<size_t>(output_plane_size);
size_t mem_size = tmp_size * channels * sizeof(float);
memset(data_col, 0, mem_size);
#pragma omp parallel for
for (int c = 0; c < channels; c++) {
int data_im_z = c * in_channel_size;
int data_col_z1 = c * output_plane_size;
unsigned int data_im_z = c * in_channel_size;
unsigned int data_col_z1 = c * output_plane_size;
for (int ky = 0, h_offset = 0; ky < kernel_h;
ky++, h_offset += dilation_h) {
int data_col_z2 = ky * out_channel_size * kernel_w;
unsigned int data_col_z2 = ky * out_channel_size * kernel_w;
for (int kx = 0, w_offset = 0; kx < kernel_w;
kx++, w_offset += dilation_w) {
int data_col_z3 = kx * out_channel_size;
int data_col_z = data_col_z1 + data_col_z2 + data_col_z3;
int oh_begin = std::max(((pad_top - h_offset)), 0);
int oh_end = std::min(((height + pad_bottom - h_offset)), output_h);
unsigned int data_col_z3 = kx * out_channel_size;
unsigned int data_col_z = data_col_z1 + data_col_z2 + data_col_z3;
unsigned int oh_begin = std::max(((pad_top - h_offset)), 0);
unsigned int oh_end =
std::min(((height + pad_bottom - h_offset)), output_h);
oh_end = std::max(oh_begin, oh_end);
int ow_begin = std::max(((pad_left - w_offset)), 0);
int ow_end = std::min(((width + pad_right - w_offset)), output_w);
unsigned int ow_begin = std::max(((pad_left - w_offset)), 0);
unsigned int ow_end =
std::min(((width + pad_right - w_offset)), output_w);
ow_end = std::max(ow_begin, ow_end);
int ih = oh_begin - pad_top + h_offset;
unsigned int ih = oh_begin - pad_top + h_offset;
for (int oh = oh_begin; oh < oh_end; ++oh, ++ih) {
int iw = ow_begin - pad_left + w_offset;
int ow = ow_begin;
int data_im_offset = data_im_z + ih * width;
int data_col_offset = data_col_z + oh * output_w;
unsigned int iw = ow_begin - pad_left + w_offset;
unsigned int ow = ow_begin;
unsigned int data_im_offset = data_im_z + ih * width;
unsigned int data_col_offset = data_col_z + oh * output_w;
const float* data_im_ptr = data_im + data_im_offset;
float* data_col_ptr = data_col + data_col_offset;
#ifdef __AVX__
Expand Down Expand Up @@ -1209,33 +1214,36 @@ void im2col_s2<float>(const float* data_im,
(width + pad_left + pad_right - (dilation_w * (kernel_w - 1) + 1)) / 2 +
1;
const int in_channel_size = height * width;
const int output_plane_size = output_h * output_w * kernel_h * kernel_w;
memset(data_col, 0, output_plane_size * channels * sizeof(float));
const unsigned int output_plane_size =
output_h * output_w * kernel_h * kernel_w;
size_t tmp_size = static_cast<size_t>(output_plane_size);
size_t mem_size = tmp_size * channels * sizeof(float);
memset(data_col, 0, mem_size);
#pragma omp parallel for
for (int c = 0; c < channels; c++) {
int data_im_z = c * in_channel_size;
int data_col_z1 = c * output_plane_size;
unsigned int data_im_z = c * in_channel_size;
unsigned int data_col_z1 = c * output_plane_size;
for (int ky = 0, h_offset = 0; ky < kernel_h;
ky++, h_offset += dilation_h) {
int data_col_z2 = ky * output_h * output_w * kernel_w;
unsigned int data_col_z2 = ky * output_h * output_w * kernel_w;
for (int kx = 0, w_offset = 0; kx < kernel_w;
kx++, w_offset += dilation_w) {
int data_col_z3 = kx * output_h * output_w;
int data_col_z = data_col_z1 + data_col_z2 + data_col_z3;
int oh_begin = std::max(((pad_top - h_offset + 1) / 2), 0);
int oh_end =
unsigned int data_col_z3 = kx * output_h * output_w;
unsigned int data_col_z = data_col_z1 + data_col_z2 + data_col_z3;
unsigned int oh_begin = std::max(((pad_top - h_offset + 1) / 2), 0);
unsigned int oh_end =
std::min(((height + pad_bottom - h_offset + 1) / 2), output_h);
oh_end = std::max(oh_begin, oh_end);
int ow_begin = std::max(((pad_left - w_offset + 1) / 2), 0);
int ow_end =
unsigned int ow_begin = std::max(((pad_left - w_offset + 1) / 2), 0);
unsigned int ow_end =
std::min(((width + pad_right - w_offset + 1) / 2), output_w);
ow_end = std::max(ow_begin, ow_end);
int ih = oh_begin * 2 - pad_top + h_offset;
unsigned int ih = oh_begin * 2 - pad_top + h_offset;
for (int oh = oh_begin; oh < oh_end; ++oh, ih += 2) {
int iw = ow_begin * 2 - pad_left + w_offset;
int ow = ow_begin;
int data_im_offset = data_im_z + ih * width;
int data_col_offset = data_col_z + oh * output_w;
unsigned int iw = ow_begin * 2 - pad_left + w_offset;
unsigned int ow = ow_begin;
unsigned int data_im_offset = data_im_z + ih * width;
unsigned int data_col_offset = data_col_z + oh * output_w;
const float* data_im_ptr = data_im + data_im_offset;
float* data_col_ptr = data_col + data_col_offset;
for (; ow + 3 < ow_end; ow += 4, iw += 8) {
Expand Down
16 changes: 8 additions & 8 deletions lite/kernels/x86/conv_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ void Conv2dCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& ctx = ctx_->As<X86Context>();
INIT_PARAM
bool flag_bias = (param.bias != nullptr);
int group_size_out = m * n;
int group_size_weights = m * k;
int group_size_coldata = n * k;
int channel_in_size = chin * hin * win;
int channel_out_size = chout * hout * wout;
unsigned int group_size_out = m * n;
unsigned int group_size_weights = m * k;
unsigned int group_size_coldata = n * k;
unsigned int channel_in_size = chin * hin * win;
unsigned int channel_out_size = chout * hout * wout;
auto paddings = *param.paddings;
auto dilations = *param.dilations;

Expand All @@ -135,9 +135,9 @@ void Conv2dCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
float* col_data = nullptr;

if (!flag_1x1gemm_) {
int col_size = group * group_size_coldata;
col_data = static_cast<float*>(
TargetMalloc(TARGET(kX86), col_size * sizeof(float)));
size_t col_size = group_size_coldata * group;
size_t col_data_size = static_cast<size_t>(col_size * sizeof(float));
col_data = static_cast<float*>(TargetMalloc(TARGET(kX86), col_data_size));
}
auto act_param = param.activation_param;
paddle::lite::x86::math::Blas<lite::TargetType::kX86> matmul(ctx);
Expand Down

0 comments on commit 61a6717

Please sign in to comment.