From 58b599cc57995682c900c1b953857a41f7f42a41 Mon Sep 17 00:00:00 2001 From: QingshuChen Date: Fri, 22 Jul 2022 16:41:59 +0800 Subject: [PATCH] add xpu lars_momentum/pow2_decay (#44448) *test=kunlun --- cmake/external/xpu.cmake | 4 +- .../optimizers/lars_momentum_op_xpu.cc | 115 ++++++++++++++++ .../pow2_decay_with_linear_warmup_op_xpu.cc | 84 ++++++++++++ .../fluid/platform/device/xpu/xpu2_op_list.h | 6 + .../unittests/xpu/get_test_cover_info.py | 1 + .../xpu/test_coalesce_tensor_op_xpu.py | 128 ++++++++++++++++++ ...st_pow2_decay_with_linear_warmup_op_xpu.py | 96 +++++++++++++ 7 files changed, 432 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc create mode 100644 paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_coalesce_tensor_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_pow2_decay_with_linear_warmup_op_xpu.py diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 81128ccf3b6a0..ad4471071e1fc 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so") if(NOT DEFINED XPU_BASE_URL) set(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") - set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220718") + set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220719") else() set(XPU_BASE_URL "${XPU_BASE_URL}") endif() @@ -19,7 +19,7 @@ endif() if(NOT DEFINED XPU_XDNN_BASE_URL) set(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev") - set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220718") + set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220719") else() set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}") endif() diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc new file mode 100644 index 0000000000000..626e071c20c13 --- /dev/null +++ b/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class LarsMomentumOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    bool multi_precision = ctx.Attr<bool>("multi_precision");
+    auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
+    auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
+    auto param = ctx.MultiInput<framework::LoDTensor>("Param");
+    auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
+    auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
+    auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
+    auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
+    auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
+    auto master_param_out =
+        ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
+    T mu = static_cast<T>(ctx.Attr<float>("mu"));
+    T lars_coeff = ctx.Attr<float>("lars_coeff");
+    T epsilon = ctx.Attr<float>("epsilon");
+    T rescale_grad = ctx.Attr<float>("rescale_grad");
+
+    std::vector<T*> param_list;
+    std::vector<T*> grad_list;
+    std::vector<T*> param_out_list;
+    std::vector<T*> velocity_list;
+    std::vector<T*> velocity_out_list;
+    std::vector<float*> lrs;
+    std::vector<int> param_sizes;
+
+    std::vector<float*> master_param_list;
+    std::vector<float*> master_param_out_list;
+    int op_num = param.size();
+    for (int i = 0; i < op_num; ++i) {
+      param_list.push_back(const_cast<T*>(param[i]->data<T>()));
+      grad_list.push_back(const_cast<T*>(grad[i]->data<T>()));
+      param_out_list.push_back(param_out[i]->mutable_data<T>(ctx.GetPlace()));
+      velocity_list.push_back(const_cast<T*>(velocity[i]->data<T>()));
+      velocity_out_list.push_back(
+          velocity_out[i]->mutable_data<T>(ctx.GetPlace()));
+      lrs.push_back(const_cast<float*>(learning_rate[i]->data<float>()));
+      param_sizes.push_back(param[i]->numel());
+
+      PADDLE_ENFORCE_EQ(
+          param_list[i],
+          param_out_list[i],
+          platform::errors::InvalidArgument(
+              "Input(Param) and Output(ParamOut) must be the same Tensors."));
+      PADDLE_ENFORCE_EQ(velocity_list[i],
+                        velocity_out_list[i],
+                        platform::errors::InvalidArgument(
+                            "Input(Velocity) and Output(VelocityOut) must be "
+                            "the same Tensors."));
+      if (multi_precision) {
+        master_param_list.push_back(
+            const_cast<float*>(master_param[i]->data<float>()));
+        master_param_out_list.push_back(
+            master_param_out[i]->mutable_data<float>(ctx.GetPlace()));
+        PADDLE_ENFORCE_EQ(master_param_list[i],
+                          master_param_out_list[i],
+                          platform::errors::InvalidArgument(
+                              "Input(MasterParam) and Output(MasterParamOut) "
+                              "must be the same Tensors."));
+      } else {
+        master_param_list.push_back(nullptr);
+        master_param_out_list.push_back(nullptr);
+      }
+    }
+
+    auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
+    int r = lars_momentum(dev_ctx.x_context(),
+                          param_list,
+                          grad_list,
+                          velocity_list,
+                          lrs,
+                          master_param_list,
+                          param_out_list,
+                          velocity_out_list,
+                          master_param_out_list,
+                          weight_decay_arr,
+                          param_sizes,
+                          mu,
+                          lars_coeff,
+                          epsilon,
+                          rescale_grad);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "lars_momentum");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(lars_momentum, ops::LarsMomentumOpXPUKernel<float>);
+#endif
diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc
new file mode 100644
index 0000000000000..4a13e226df8ce
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc
@@ -0,0 +1,84 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
+#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class Pow2DecayWithLinearWarmupXPUOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const {
+    const auto *lr = ctx.Input<framework::Tensor>("LearningRate");
+    const auto *step = ctx.Input<framework::Tensor>("Step");
+    auto *lr_out = ctx.Output<framework::Tensor>("LearningRateOut");
+    auto *step_out = ctx.Output<framework::Tensor>("StepOut");
+    PADDLE_ENFORCE_EQ(
+        lr,
+        lr_out,
+        platform::errors::InvalidArgument("Input(LearningRate) and "
+                                          "Output(LearningRateOut) "
+                                          "must be the same."));
+    PADDLE_ENFORCE_NOT_NULL(lr,
+                            platform::errors::InvalidArgument(
+                                "Input(LearningRate) should not be nullptr."));
+    PADDLE_ENFORCE_EQ(step,
+                      step_out,
+                      platform::errors::InvalidArgument(
+                          "Input(Step) and Output(StepOut) must be the same."));
+    PADDLE_ENFORCE_NOT_NULL(step,
+                            platform::errors::InvalidArgument(
+                                "Input(Step) should not be nullptr."));
+    PADDLE_ENFORCE_EQ(
+        step->IsInitialized(),
+        true,
+        platform::errors::InvalidArgument("Input(Step) must be initialized."));
+
+    auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
+    auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
+    PADDLE_ENFORCE_LE(warmup_steps,
+                      total_steps,
+                      platform::errors::InvalidArgument(
+                          "warmup_steps must not be larger than total_steps."));
+    auto base_lr = ctx.Attr<float>("base_lr");
+    auto end_lr = ctx.Attr<float>("end_lr");
+
+    auto *lr_data = lr_out->data<T>();
+    auto *step_data = step_out->data<int64_t>();
+    auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
+    int r = xpu::pow2_decay_with_linear_warmup(dev_ctx.x_context(),
+                                               lr_data,
+                                               step_data,
+                                               warmup_steps,
+                                               total_steps,
+                                               base_lr,
+                                               end_lr);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow2_decay_with_linear_warmup");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(pow2_decay_with_linear_warmup,
+                       ops::Pow2DecayWithLinearWarmupXPUOpKernel<float>);
+#endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 8cae8cfe534ef..1d4a5bf74b8df 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -71,6 +71,8 @@ XPUOpMap& get_kl2_ops() {
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace())})},
     {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"coalesce_tensor",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"concat_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
@@ -255,6 +257,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"label_smooth", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"lars_momentum", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"layer_norm_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"layer_norm_grad", @@ -334,6 +338,8 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::FP16, XPUPlace())})}, {"pow", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"pow_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"pow2_decay_with_linear_warmup", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"prior_box", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"range", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), diff --git a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py index bcaa8055b25cd..b58cc6a6bb31f 100644 --- a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py +++ b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py @@ -88,6 +88,7 @@ 'dropout_float16', 'dropout_grad_float16', "grad_add_float32", # no api for grad_add, skip + "lars_momentum_float32", "resnet_unit", "resnet_unit_grad" ] diff --git a/python/paddle/fluid/tests/unittests/xpu/test_coalesce_tensor_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_coalesce_tensor_op_xpu.py new file mode 100644 index 0000000000000..2f2fb2c83e4d4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_coalesce_tensor_op_xpu.py @@ -0,0 +1,128 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid import core +import sys + +sys.path.append("..") +from op_test import OpTest + +alignment = 256 +import paddle +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + +paddle.enable_static() + + +class XPUTestCoalesceTensorOp(XPUOpTestWrapper): + + def __init__(self): + self.op_name = 'coalesce_tensor' + self.use_dynamic_create_class = False + + class TestAllocContinuousSpace(XPUOpTest): + + def setUp(self): + self.op_type = "coalesce_tensor" + self.use_xpu = True + self.dtype, self.fluid_dtype = self.init_dtype() + attrs = self.init_attr() + self.copy_data = attrs["copy_data"] + self.constant = attrs["constant"] + self.set_constant = attrs["set_constant"] + self.Inputs = self.init_input() + self.Outputs, self.FusedOutput = self.init_output( + self.Inputs, self.set_constant, self.constant) + self.inputs = {'Input': self.Inputs} + self.attrs = attrs + self.outputs = { + 'Output': self.Outputs, + 'FusedOutput': self.FusedOutput + } + + def init_dtype(self): + return np.float32, int(core.VarDesc.VarType.FP32) + + def init_input(self): + inputs = [] + inputs.append(("x1", np.random.random([20, 3]).astype(self.dtype))) + inputs.append(("x2", np.random.random([20]).astype(self.dtype))) + inputs.append(("x3", np.random.random([1]).astype(self.dtype))) + inputs.append(("x4", np.random.random([200, + 30]).astype(self.dtype))) + inputs.append(("x5", np.random.random([30]).astype(self.dtype))) + inputs.append(("x6", np.random.random([1]).astype(self.dtype))) + return inputs + + def init_attr(self): + return { + "copy_data": True, + "set_constant": False, + "constant": 0.0, + "dtype": self.fluid_dtype + } + + def init_output(self, input_list, set_constant, constant): + inputs = [] + outputs = input_list + + for input in input_list: + length = len(input[1].flatten()) + aligned_len = (length + alignment) / alignment * alignment + out = np.zeros(int(aligned_len)) + out[0:length] = input[1].flatten() + inputs.append(out) + + coalesce_tensor_var = np.concatenate([input for input in inputs]) + if set_constant: + coalesce_tensor_var = np.ones( + (len(coalesce_tensor_var))) * constant + outputs = [(out[0], + np.ones(out[1].shape).astype(self.dtype) * constant) + for out in outputs] + return outputs, coalesce_tensor_var + + def test_check_output(self): + self.check_output_with_place(place=core.XPUPlace(0), + no_check_set=["FusedOutput"], + atol=1e-5) + + class TestAllocContinuousSpace2(TestAllocContinuousSpace): + + def init_attr(self): + return { + "copy_data": False, + "set_constant": True, + "constant": 0.5, + "dtype": self.fluid_dtype, + "user_defined_size_of_dtype": 2 + } + + def test_check_output(self): + self.check_output_with_place(place=core.XPUPlace(0), + no_check_set=["FusedOutput"], + atol=1e-5) + + +support_types = get_xpu_op_support_types('coalesce_tensor') +for stype in support_types: + create_test_class(globals(), XPUTestCoalesceTensorOp, stype) + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_pow2_decay_with_linear_warmup_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_pow2_decay_with_linear_warmup_op_xpu.py new file mode 100644 index 0000000000000..e1bba6ed1bd57 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_pow2_decay_with_linear_warmup_op_xpu.py @@ -0,0 +1,96 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid.contrib.layers.nn import pow2_decay_with_linear_warmup +from paddle.optimizer.lr import LinearWarmup +from paddle.optimizer.lr import PolynomialDecay +import unittest +import sys + +sys.path.append("..") + +from op_test import OpTest +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import record_op_test + + +def gen_pow2_warmup_op_lr(warmup_steps, total_steps, base_lr, end_lr, place): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, base_lr, + end_lr) + exe = paddle.static.Executor(place) + with paddle.static.scope_guard(paddle.static.Scope()): + exe.run(startup) + while True: + lr_np = exe.run(main, fetch_list=[lr])[0] + yield lr_np[0] + + +class Pow2Warmup(LinearWarmup): + + def __init__(self, warmup_steps, total_steps, base_lr, end_lr): + assert total_steps > warmup_steps + lr_sch = PolynomialDecay(learning_rate=base_lr, + decay_steps=total_steps - warmup_steps, + end_lr=end_lr, + power=2) + + super(Pow2Warmup, self).__init__(learning_rate=lr_sch, + warmup_steps=warmup_steps, + start_lr=0.0, + end_lr=base_lr) + + +def gen_pow2_warmup_py_lr(warmup_steps, total_steps, base_lr, end_lr, place): + lr_sch = Pow2Warmup(warmup_steps, total_steps, base_lr, end_lr) + lr_sch.step() + while True: + yield lr_sch() + lr_sch.step() + + +class TestPowWarmup(unittest.TestCase): + + def setUp(self): + paddle.enable_static() + self.op_type = 'pow2_decay_with_linear_warmup' + self.params = { + 'warmup_steps': 30, + 'total_steps': 100, + 'base_lr': 0.02, + 'end_lr': 0.001, + } + self.step_num = 1000 + + def check_with_place(self, place): + kwargs = dict(self.params) + kwargs['place'] = place + lr_sch_op = gen_pow2_warmup_op_lr(**kwargs) + lr_sch_py = gen_pow2_warmup_py_lr(**kwargs) + for i, (lr_op, lr_py) in enumerate(zip(lr_sch_op, lr_sch_py)): + self.assertLess(abs(lr_op - lr_py), 1e-6) + if i > self.step_num: + break + + def test_main(self): + self.check_with_place(paddle.XPUPlace(0)) + + +record_op_test("pow2_decay_with_linear_warmup", "float32") + +if __name__ == "__main__": + unittest.main()