Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[metal] fix_elementwise #7467

Merged
merged 2 commits into from
Nov 3, 2021

Conversation

xiaoxiaohehe001
Copy link
Collaborator

针对metalelementwise系列代码包括进行了重写和整合,其中包括:
add、sub、mul、div全面支持mps框架,支持broadcast
对elementwise代码进行了整合,目前通过ElementwiseImageCompute单个类的调用即可

@zhangjun zhangjun changed the title [Paddle-Lite Metal] fix_elementwise [metal] fix_elementwise Oct 28, 2021
Comment on lines 92 to 93
auto op_type = KernelBase::op_type();
op_ = op_type;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接取值, ele_type_ = KernelBase::op_type()

std::shared_ptr<MetalBuffer> params_buffer_;
DDim last_input_dims_{};

std::string op_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成ele_type_

id<MTLComputePipelineState> pipline_;
std::string function_name_;
MetalContext* metal_context_;

int op_num;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用准确的名字,operation_type_

Comment on lines 30 to 50
bool InputsValid(const MetalImage* input_x_, const MetalImage* input_y_) {
auto x_dims = input_x_->dim_;
auto y_dims = input_y_->dim_;

// check data layout
if (input_x_->transpose_ != input_y_->transpose_) return false;
// check data dims equal
if (x_dims == y_dims) return true;

if (x_dims[0] == y_dims[0] && x_dims[3] == y_dims[3]) {
//[1 32 1 3]
if (x_dims[1] == y_dims[1] && (x_dims[2] == 1 || y_dims[2] == 1)) return true;
//[1 1 32 3]
if (x_dims[2] == y_dims[2] && (x_dims[1] == 1 || y_dims[1] == 1)) return true;
//[1 1 1 3]
if ((x_dims[1] == 1 && x_dims[2] == 1) || (y_dims[1] == 1 && y_dims[2] == 1)) return true;
}
return false;
}

void ElementwiseImageCompute::PrepareForRun() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码格式化一下


params_buffer_ =
std::make_shared<MetalBuffer>(metal_context_, sizeof(element_params), &element_params);

function_name_ = fuse_flag_ ? "elementwise_add_relu" : "elementwise_add";

function_name_ = fuse_flag_ ? "elementwise_relu" : "elementwise_";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elementwise_改成elementwise

@zhangjun zhangjun merged commit 15d14bc into PaddlePaddle:develop Nov 3, 2021
zhangjun pushed a commit to zhangjun/Paddle-Lite that referenced this pull request Nov 27, 2021
* fix_elementwise

* fix_elementwise
zhangjun added a commit that referenced this pull request Nov 29, 2021
* [metal] fix_elementwise (#7467)

* fix_elementwise

* fix_elementwise

* [metal] fix_elementwise (#7737)

* fix_elementwise

* fix_element

Co-authored-by: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants