[OpBugFix]support OpenCL fc shape size >2 and in_num_col_dims > 1 #8465
Thanks for your contribution!
}

} else {
output.y = 0.0f;
This needs to handle the half-precision case as well: output.y = (CL_DTYPE)0;
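The suggestion can be illustrated with a small host-side C++ sketch. It assumes CL_DTYPE is a macro that the kernel build options set to float or half; plain C++ has no half type, so the sketch defaults it to float. PackLanes is a hypothetical helper written only for this illustration and is not part of the PR.

```cpp
#include <array>
#include <cstddef>

// Assumption: CL_DTYPE mirrors the macro defined through the OpenCL build
// options (float or half). Standard C++ has no half type, so default to float.
#ifndef CL_DTYPE
#define CL_DTYPE float
#endif

// Hypothetical helper: packs up to four scalars into one 4-lane pixel.
// Lanes at or beyond `valid` are filled with a zero of the kernel's own
// data type instead of a hard-coded float literal.
inline std::array<CL_DTYPE, 4> PackLanes(const CL_DTYPE* src,
                                         std::size_t valid) {
  std::array<CL_DTYPE, 4> out{};
  for (std::size_t lane = 0; lane < 4; ++lane) {
    out[lane] = (lane < valid) ? src[lane] : (CL_DTYPE)0;
  }
  return out;
}
```

Writing the padding value as (CL_DTYPE)0 keeps the literal's type tied to whatever precision the kernel is built for, so the same source reads consistently in both the float and half builds.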
output.z = input2.w;
}
} else {
output.z = 0.0f;
Same as above: use (CL_DTYPE)0 here as well.
output.w = input3.w;
}
} else {
output.w = 0.0f;
Same as above: use (CL_DTYPE)0 here as well.
output.w = input3.w;
}
} else {
output.w = 0.0f;
Same as above: use (CL_DTYPE)0 here as well.
}

} else {
output.y = 0.0f;
Same as above: use (CL_DTYPE)0 here as well.
output.z = input2.w;
}
} else {
output.z = 0.0f;
Same as above: use (CL_DTYPE)0 here as well.
@@ -53,6 +53,8 @@ class FcImageCompute : public KernelLite<TARGET(kOpenCL),
// convert weights from cpu to gpu
auto w_cpu_t = std::unique_ptr<Tensor>(new Tensor);
w_gpu_t_ = std::unique_ptr<Tensor>(new Tensor);
layout_input_image_ = std::unique_ptr<Tensor>(new Tensor);
Please move these two lines into the if check further down; they do not need to run when the input is 2-D.
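A minimal sketch of the requested restructuring, assuming the condition is the same rank check that appears later in the diff (x_dims.size() > 2 && x_dims.size() <= 4). The Tensor struct and the FcSetupSketch class are hypothetical stand-ins, and the boolean flag merely represents the extra layout-kernel setup; none of this is the PR's actual code.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for the framework's Tensor type.
struct Tensor {};

// Hypothetical, reduced kernel class: only what is needed to show
// that the layout resources are created on the >2-D path alone.
class FcSetupSketch {
 public:
  void Prepare(const std::vector<std::size_t>& x_dims) {
    // Assumed condition, copied from the rank check shown in the diff.
    if (x_dims.size() > 2 && x_dims.size() <= 4) {
      layout_input_image_ = std::unique_ptr<Tensor>(new Tensor);
      needs_layout_kernel_ = true;  // placeholder for extra kernel setup
    }
  }

  bool HasLayoutImage() const { return layout_input_image_ != nullptr; }
  bool NeedsLayoutKernel() const { return needs_layout_kernel_; }

 private:
  std::unique_ptr<Tensor> layout_input_image_;
  bool needs_layout_kernel_ = false;
};
```

With this shape, a 2-D input leaves layout_input_image_ empty and skips the extra setup entirely, which is the saving the comment asks for.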
@@ -170,22 +173,113 @@ class FcImageCompute : public KernelLite<TARGET(kOpenCL),
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
kernel_ = context.cl_context()->GetKernel(kernel_key.str());

context.cl_context()->AddKernel(
I suggest moving the compilation of these two kernels into the if check as well.
auto x_dims = param.input->dims();
auto out_dims = param.output->dims();
cl::Image2D* x_img_src = DATA_GPU(param.input);
cl::NDRange layout_gws;
Move the layout_gws variable inside the if check at Line 200.
cl::NDRange layout_gws;
if (x_dims.size() > 2 && x_dims.size() <= 4) {
int in_num_col_dims = param.in_num_col_dims;
std::vector<size_t> new_dims = {1, 1, 1, 1};
Expanding the original dims to 4-D is a common operation and is already defined in image_helper.h, so you can call it directly: auto new_dims = Broadcast2GpuShape(x_dims);
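For readers unfamiliar with this kind of helper, here is a rough sketch of what expanding a shape to 4-D usually means. ExpandTo4D is a hypothetical stand-in and not the library's Broadcast2GpuShape; it assumes the shape is right-aligned and padded with leading 1s, which matches the {1, 1, 1, 1} initial value in the diff but has not been checked against the real helper.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a 4-D shape expansion helper.
// Assumption: dims are right-aligned and padded with leading 1s.
inline std::vector<std::size_t> ExpandTo4D(
    const std::vector<std::size_t>& dims) {
  if (dims.size() > 4) {
    throw std::invalid_argument("rank must be <= 4");
  }
  std::vector<std::size_t> out(4, 1);
  const std::size_t offset = 4 - dims.size();
  for (std::size_t i = 0; i < dims.size(); ++i) {
    out[offset + i] = dims[i];
  }
  return out;
}
```

Under that assumption, a 3-D shape {2, 3, 5} becomes {1, 2, 3, 5}, and a 4-D shape is returned unchanged.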
several comments
LGTM
LGTM
Support shape size > 2 and in_num_col_dims > 1 in fc.
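As a reminder of what in_num_col_dims controls, the sketch below flattens an N-D input shape into the 2-D matrix shape that fc multiplies with the weights: the first in_num_col_dims dimensions are multiplied together to give the row count M, and the remaining dimensions give the column count K. FlattenForFc is a hypothetical helper for illustration only and is not code from this PR.

```cpp
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical helper: returns {M, K} for an fc input of shape x_dims.
// M = product of the first in_num_col_dims dims, K = product of the rest.
inline std::pair<std::size_t, std::size_t> FlattenForFc(
    const std::vector<std::size_t>& x_dims, int in_num_col_dims) {
  if (in_num_col_dims < 1 ||
      static_cast<std::size_t>(in_num_col_dims) >= x_dims.size()) {
    throw std::invalid_argument("in_num_col_dims out of range");
  }
  const std::size_t split = static_cast<std::size_t>(in_num_col_dims);
  std::size_t m = 1;
  std::size_t k = 1;
  for (std::size_t i = 0; i < x_dims.size(); ++i) {
    if (i < split) {
      m *= x_dims[i];
    } else {
      k *= x_dims[i];
    }
  }
  return {m, k};
}
```

For an input of shape {2, 3, 4, 5}, in_num_col_dims = 1 gives a 2 x 60 matrix, while in_num_col_dims = 2 gives a 6 x 20 matrix.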