[MLU]add mlu kernel for sqrt op (#43326)
cambriconhsq committed Jun 10, 2022
1 parent 8045fcf commit 6d3a68c
Showing 2 changed files with 132 additions and 0 deletions.
43 changes: 43 additions & 0 deletions paddle/fluid/operators/activation_op_mlu.cc
@@ -108,6 +108,43 @@ class ActivationGradMLUKernelV3 : public framework::OpKernel<T> {
  }
};

// For sqrt
template <typename T>
class SqrtMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");
    auto place = ctx.GetPlace();

    out->mutable_data<T>(place);

    MLUCnnlTensorDesc input_desc(*x);
    MLUCnnlTensorDesc output_desc(*out);

    cnnlComputationPreference_t prefer = CNNL_COMPUTATION_FAST;
    MLUCnnl::Sqrt(ctx, prefer, input_desc.get(), GetBasePtr(x),
                  output_desc.get(), GetBasePtr(out));
  }
};

template <typename T>
class SqrtGradMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out = ctx.Input<Tensor>("Out");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto place = ctx.GetPlace();

    dx->mutable_data<T>(place);

    MLUCnnlTensorDesc data_desc(*out);
    MLUCnnl::SqrtGrad(ctx, data_desc.get(), GetBasePtr(out), GetBasePtr(dout),
                      GetBasePtr(dx));
  }
};

} // namespace operators
} // namespace paddle
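An aside that is not part of the commit: the reason `SqrtGradMLUKernel` takes the forward output `Out` rather than the input `X` is the standard square-root backward identity. With out = sqrt(x), d(out)/d(x) = 1/(2·sqrt(x)) = 1/(2·out), so the gradient needs only the saved output. A NumPy sketch of what the two cnnl calls are presumably computing:

```python
import numpy as np

def sqrt_forward_ref(x):
    # Element-wise square root, mirroring MLUCnnl::Sqrt above.
    return np.sqrt(x)

def sqrt_grad_ref(out, dout):
    # out = sqrt(x)  =>  d(out)/d(x) = 1 / (2 * out),
    # so dx = dout / (2 * out) and x itself is never needed.
    return dout / (2.0 * out)
```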

@@ -170,3 +207,9 @@ REGISTER_OP_MLU_KERNEL(
    ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU, float>,
    ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU,
                                   paddle::platform::float16>);

// sqrt
REGISTER_OP_MLU_KERNEL(sqrt, ops::SqrtMLUKernel<float>,
                       ops::SqrtMLUKernel<paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(sqrt_grad, ops::SqrtGradMLUKernel<float>,
                       ops::SqrtGradMLUKernel<paddle::platform::float16>);
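
Once registered, the kernels are reached through the ordinary `paddle.sqrt` op under static graph mode. A minimal usage sketch, assuming a Paddle build with MLU support, a visible device, and that the standard static-graph API works unchanged on MLU places:

```python
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[11, 17], dtype='float32')
    out = paddle.sqrt(x)  # dispatches to SqrtMLUKernel on an MLU place

place = paddle.device.MLUPlace(0)
exe = fluid.Executor(place)
exe.run(startup_prog)

x_np = np.random.uniform(0.1, 1, [11, 17]).astype('float32')
out_np, = exe.run(main_prog, feed={'x': x_np}, fetch_list=[out])
np.testing.assert_allclose(out_np, np.sqrt(x_np), rtol=1e-5)
```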
89 changes: 89 additions & 0 deletions python/paddle/fluid/tests/unittests/mlu/test_sqrt_op_mlu.py
@@ -0,0 +1,89 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import sys

sys.path.append('..')
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
import paddle
import paddle.nn.functional as F

paddle.enable_static()
np.random.seed(10)


class TestSqrt(OpTest):

    def setUp(self):
        self.op_type = "sqrt"
        self.dtype = 'float32'
        self.set_mlu()
        self.python_api = paddle.sqrt

        np.random.seed(1023)
        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
        out = np.sqrt(x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}

    def set_mlu(self):
        self.__class__.use_mlu = True
        self.place = paddle.device.MLUPlace(0)

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['X'], 'Out', check_eager=False)

    def test_check_output(self):
        self.check_output_with_place(self.place)


class TestSqrtHalf(OpTest):

    def setUp(self):
        self.op_type = "sqrt"
        self.dtype = 'float16'
        self.set_mlu()
        self.python_api = paddle.sqrt

        np.random.seed(1023)
        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
        out = np.sqrt(x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}

    def set_mlu(self):
        self.__class__.use_mlu = True
        self.place = paddle.device.MLUPlace(0)

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['X'],
                                   'Out',
                                   check_eager=False,
                                   max_relative_error=0.85)

    def test_check_output(self):
        self.check_output_with_place(self.place)


if __name__ == "__main__":
    unittest.main()
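
One inference worth recording, since the commit itself does not explain it: `TestSqrtHalf` loosens the gradient check to `max_relative_error=0.85` because `check_grad_with_place` compares the kernel's analytic gradient against a finite-difference estimate, and central differences taken through a float16 forward pass are dominated by rounding noise. A standalone illustration of the effect:

```python
import numpy as np

def numeric_grad(f, x, dout, eps=1e-3):
    # Central-difference estimate of the gradient of sum(f(x) * dout) w.r.t. x,
    # perturbing one element at a time (the same idea OpTest's checker uses).
    grad = np.zeros(x.shape, dtype=np.float32)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + eps
        hi = (f(x).astype(np.float32) * dout).sum()
        x[idx] = orig - eps
        lo = (f(x).astype(np.float32) * dout).sum()
        x[idx] = orig
        grad[idx] = (hi - lo) / (2 * eps)
        it.iternext()
    return grad

# float16 inputs in the same range the test draws from
x = np.random.uniform(0.1, 1, [4, 5]).astype(np.float16)
dout = np.ones(x.shape, dtype=np.float32)

approx = numeric_grad(np.sqrt, x, dout)
exact = 0.5 / np.sqrt(x.astype(np.float32))
rel_err = np.abs(approx - exact) / np.abs(exact)
print(rel_err.max())  # far above what a float32 check would tolerate
```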
