From ad8ceda26ed078cf310bddcc86cfd40fd920d4d0 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 26 Apr 2022 14:20:57 +0800 Subject: [PATCH 01/11] fix false positive warning in gcc>=9 --- paddle/utils/variant.h | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/paddle/utils/variant.h b/paddle/utils/variant.h index a7546d094c2ff..ed579edf18772 100644 --- a/paddle/utils/variant.h +++ b/paddle/utils/variant.h @@ -1823,6 +1823,20 @@ struct dtor { template class destructor; +// gcc >= 9 has a bug that creates a false positive warning +// Reference: +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92145 +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89381 +#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 +#define USING_SUPER_ASSIGNMENT_OPERATOR \ +_Pragma("GCC diagnostic push") \ +_Pragma("GCC diagnostic ignored \"-Wdeprecated-copy\"") \ +using super::operator=; \ +_Pragma("GCC diagnostic pop") +#else +#define USING_SUPER_ASSIGNMENT_OPERATOR using super::operator=; +#endif + #define MPARK_VARIANT_DESTRUCTOR(destructible_trait, definition, destroy) \ template \ class destructor, destructible_trait> \ @@ -1831,7 +1845,7 @@ class destructor; \ public: \ MPARK_INHERITING_CTOR(destructor, super) \ - using super::operator=; \ + USING_SUPER_ASSIGNMENT_OPERATOR \ \ destructor(const destructor &) = default; \ destructor(destructor &&) = default; \ @@ -1867,7 +1881,7 @@ class constructor : public destructor { public: MPARK_INHERITING_CTOR(constructor, super) - using super::operator=; + USING_SUPER_ASSIGNMENT_OPERATOR protected: #ifndef MPARK_GENERIC_LAMBDAS @@ -1919,7 +1933,7 @@ class move_constructor; \ public: \ MPARK_INHERITING_CTOR(move_constructor, super) \ - using super::operator=; \ + USING_SUPER_ASSIGNMENT_OPERATOR \ \ move_constructor(const move_constructor &) = default; \ definition ~move_constructor() = default; \ @@ -1955,7 +1969,7 @@ class copy_constructor; \ public: \ MPARK_INHERITING_CTOR(copy_constructor, super) \ - using super::operator=; \ + USING_SUPER_ASSIGNMENT_OPERATOR \ \ definition copy_constructor(copy_constructor &&) = default; \ ~copy_constructor() = default; \ @@ -1984,7 +1998,7 @@ class assignment : public copy_constructor { public: MPARK_INHERITING_CTOR(assignment, super) - using super::operator=; + USING_SUPER_ASSIGNMENT_OPERATOR template inline /* auto & */ auto emplace(Args &&... 
args) @@ -2071,7 +2085,7 @@ class move_assignment; \ public: \ MPARK_INHERITING_CTOR(move_assignment, super) \ - using super::operator=; \ + USING_SUPER_ASSIGNMENT_OPERATOR \ \ move_assignment(const move_assignment &) = default; \ move_assignment(move_assignment &&) = default; \ @@ -2111,7 +2125,7 @@ class copy_assignment; \ public: \ MPARK_INHERITING_CTOR(copy_assignment, super) \ - using super::operator=; \ + USING_SUPER_ASSIGNMENT_OPERATOR \ \ copy_assignment(const copy_assignment &) = default; \ copy_assignment(copy_assignment &&) = default; \ @@ -2141,7 +2155,7 @@ class impl : public copy_assignment> { public: MPARK_INHERITING_CTOR(impl, super) - using super::operator=; + USING_SUPER_ASSIGNMENT_OPERATOR template inline void assign(Arg &&arg) { From 178dd270cfc0599bd0df5dab286c4915903325aa Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 26 Apr 2022 15:00:38 +0800 Subject: [PATCH 02/11] use more aggressive way --- paddle/utils/variant.h | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/paddle/utils/variant.h b/paddle/utils/variant.h index ed579edf18772..7b11ae1bee88c 100644 --- a/paddle/utils/variant.h +++ b/paddle/utils/variant.h @@ -13,6 +13,11 @@ #pragma once +#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-copy" +#endif + /* variant synopsis @@ -1823,20 +1828,6 @@ struct dtor { template class destructor; -// gcc >= 9 has a bug that creates a false positive warning -// Reference: -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92145 -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89381 -#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 -#define USING_SUPER_ASSIGNMENT_OPERATOR \ -_Pragma("GCC diagnostic push") \ -_Pragma("GCC diagnostic ignored \"-Wdeprecated-copy\"") \ -using super::operator=; \ -_Pragma("GCC diagnostic pop") -#else -#define USING_SUPER_ASSIGNMENT_OPERATOR using super::operator=; -#endif - #define MPARK_VARIANT_DESTRUCTOR(destructible_trait, definition, destroy) \ template \ class destructor, destructible_trait> \ @@ -1845,7 +1836,7 @@ _Pragma("GCC diagnostic pop") \ public: \ MPARK_INHERITING_CTOR(destructor, super) \ - USING_SUPER_ASSIGNMENT_OPERATOR \ + using super::operator=; \ \ destructor(const destructor &) = default; \ destructor(destructor &&) = default; \ @@ -1881,7 +1872,7 @@ class constructor : public destructor { public: MPARK_INHERITING_CTOR(constructor, super) - USING_SUPER_ASSIGNMENT_OPERATOR + using super::operator=; protected: #ifndef MPARK_GENERIC_LAMBDAS @@ -1933,7 +1924,7 @@ class move_constructor; \ public: \ MPARK_INHERITING_CTOR(move_constructor, super) \ - USING_SUPER_ASSIGNMENT_OPERATOR \ + using super::operator=; \ \ move_constructor(const move_constructor &) = default; \ definition ~move_constructor() = default; \ @@ -1969,7 +1960,7 @@ class copy_constructor; \ public: \ MPARK_INHERITING_CTOR(copy_constructor, super) \ - USING_SUPER_ASSIGNMENT_OPERATOR \ + using super::operator=; \ \ definition copy_constructor(copy_constructor &&) = default; \ ~copy_constructor() = default; \ @@ -1998,7 +1989,7 @@ class assignment : public copy_constructor { public: MPARK_INHERITING_CTOR(assignment, super) - USING_SUPER_ASSIGNMENT_OPERATOR + using super::operator=; template inline /* auto & */ auto emplace(Args &&... 
args) @@ -2085,7 +2076,7 @@ class move_assignment; \ public: \ MPARK_INHERITING_CTOR(move_assignment, super) \ - USING_SUPER_ASSIGNMENT_OPERATOR \ + using super::operator=; \ \ move_assignment(const move_assignment &) = default; \ move_assignment(move_assignment &&) = default; \ @@ -2125,7 +2116,7 @@ class copy_assignment; \ public: \ MPARK_INHERITING_CTOR(copy_assignment, super) \ - USING_SUPER_ASSIGNMENT_OPERATOR \ + using super::operator=; \ \ copy_assignment(const copy_assignment &) = default; \ copy_assignment(copy_assignment &&) = default; \ @@ -2155,7 +2146,7 @@ class impl : public copy_assignment> { public: MPARK_INHERITING_CTOR(impl, super) - USING_SUPER_ASSIGNMENT_OPERATOR + using super::operator=; template inline void assign(Arg &&arg) { @@ -2842,3 +2833,7 @@ struct hash { }; } // namespace std + +#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 +#pragma GCC diagnostic pop +#endif From 6cfd8ec327170e32c960015d24ca4e54498e9421 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 26 Apr 2022 19:45:26 +0800 Subject: [PATCH 03/11] stash --- paddle/fluid/operators/optimizers/asgd_op.cc | 71 +++++++++++++++ paddle/phi/kernels/asgd_kernel.h | 33 +++++++ paddle/phi/kernels/cpu/asgd_kernel.cc | 21 +++++ paddle/phi/kernels/gpu/asgd_kernel.cu | 21 +++++ paddle/phi/kernels/impl/asgd_kernel_impl.h | 70 ++++++++++++++ .../fluid/tests/unittests/model.pdparams | Bin 0 -> 299 bytes python/paddle/fluid/tests/unittests/opt.pdopt | Bin 0 -> 613 bytes .../fluid/tests/unittests/test_optimizer.py | 18 ++++ python/paddle/optimizer/__init__.py | 2 + python/paddle/optimizer/asgd.py | 86 ++++++++++++++++++ 10 files changed, 322 insertions(+) create mode 100644 paddle/fluid/operators/optimizers/asgd_op.cc create mode 100644 paddle/phi/kernels/asgd_kernel.h create mode 100644 paddle/phi/kernels/cpu/asgd_kernel.cc create mode 100644 paddle/phi/kernels/gpu/asgd_kernel.cu create mode 100644 paddle/phi/kernels/impl/asgd_kernel_impl.h create mode 100644 python/paddle/fluid/tests/unittests/model.pdparams create mode 100644 python/paddle/fluid/tests/unittests/opt.pdopt create mode 100644 python/paddle/optimizer/asgd.py diff --git a/paddle/fluid/operators/optimizers/asgd_op.cc b/paddle/fluid/operators/optimizers/asgd_op.cc new file mode 100644 index 0000000000000..ee6f13ba72551 --- /dev/null +++ b/paddle/fluid/operators/optimizers/asgd_op.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class AsgdOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); + } +}; + +class AsgdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", "(Tensor) Input parameter"); + AddInput("LearningRate", "(Tensor) Learning rate of SGD"); + AddInput("Grad", "(Tensor) Input gradient"); + AddInput("AvgParam", + "(Tensor) Average of parameter"); + AddInput("CurrentStep", + "(Tensor) Current step"); + AddInput("t0", + "(Tensor) point at which to start averaging"); + AddOutput("ParamOut", + "(Tensor, same with Param) " + "Output parameter, should share the same memory with Param"); + AddOutput("AvgParamOut", + "(Tensor, same with AvgParam) Average of parameter"); + AddOutput("CurrentStepOut", + "(Tensor) Increased step"); + + AddComment(R"DOC( +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(asgd, AsgdInferMetaFunctor, + PD_INFER_META(phi::SgdInferMeta)); +REGISTER_OPERATOR( + asgd, ops::AsgdOp, ops::AsgdOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + AsgdInferMetaFunctor); diff --git a/paddle/phi/kernels/asgd_kernel.h b/paddle/phi/kernels/asgd_kernel.h new file mode 100644 index 0000000000000..1b29bd2906f72 --- /dev/null +++ b/paddle/phi/kernels/asgd_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void AsgdKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + const DenseTensor& avg_param, + const DenseTensor& current_step, + const DenseTensor& t0, + DenseTensor* param_out, + DenseTensor* avg_param_out, + DenseTensor* current_step_out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/asgd_kernel.cc b/paddle/phi/kernels/cpu/asgd_kernel.cc new file mode 100644 index 0000000000000..b71e2b67d4852 --- /dev/null +++ b/paddle/phi/kernels/cpu/asgd_kernel.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/asgd_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/asgd_kernel_impl.h" + +PD_REGISTER_KERNEL(asgd, CPU, ALL_LAYOUT, phi::AsgdKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/asgd_kernel.cu b/paddle/phi/kernels/gpu/asgd_kernel.cu new file mode 100644 index 0000000000000..2b9e71988b523 --- /dev/null +++ b/paddle/phi/kernels/gpu/asgd_kernel.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/asgd_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/asgd_kernel_impl.h" + +PD_REGISTER_KERNEL(asgd, GPU, ALL_LAYOUT, phi::AsgdKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/asgd_kernel_impl.h b/paddle/phi/kernels/impl/asgd_kernel_impl.h new file mode 100644 index 0000000000000..33ba3eaa54410 --- /dev/null +++ b/paddle/phi/kernels/impl/asgd_kernel_impl.h @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/platform/place.h" +#include "paddle/phi/kernels/adadelta_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void AsgdKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + const DenseTensor& avg_param, + const DenseTensor& current_step, + const DenseTensor& t0, + DenseTensor* param_out, + DenseTensor* avg_param_out, + DenseTensor* current_step_out) { + dev_ctx.template Alloc(param_out); + dev_ctx.template Alloc(avg_param_out); + + auto eigen_param = EigenVector::Flatten(param); + auto eigen_grad = EigenVector::Flatten(grad); + auto eigen_avg_param = EigenVector::Flatten(avg_param); + auto eigen_current_step = EigenVector::Flatten(current_step); + auto eigen_t0 = EigenVector::Flatten(t0); + auto eigen_param_out = EigenVector::Flatten(*param_out); + auto eigen_avg_param_out = EigenVector::Flatten(*avg_param_out); + auto eigen_current_step_out = EigenVector::Flatten(*current_step_out); + auto& place = *dev_ctx.eigen_device(); + + if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) { + auto lr = learning_rate.data()[0]; + eigen_param_out.device(place) = eigen_param - lr * eigen_grad; + } else { + Eigen::DSizes dsize(param_out->numel()); + auto eigen_lr = EigenVector::Flatten(learning_rate); + eigen_param_out.device(place) = + eigen_param - eigen_lr.broadcast(dsize) * eigen_grad; + } + + if (eigen_current_step < eigen_t0) { + eigen_avg_param_out.device(place) = eigen_param_out; + } else { + // const auto mu = eigen_current_step - eigen_t0 + 1; + eigen_avg_param_out.device(place) = + eigen_avg_param + (eigen_param_out - eigen_avg_param) / eigen_current_step; + } + + // eigen_current_step_out = eigen_current_step + 1; + eigen_current_step++; +} + +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/model.pdparams b/python/paddle/fluid/tests/unittests/model.pdparams new file mode 100644 index 0000000000000000000000000000000000000000..aa9eb2496a41a2f51f382b8e1e8602e00b4b6984 GIT binary patch literal 299 zcmZo*nX16Z00y;FGgAT^lw>9r z6(v?q>EVelN=?qsD=sN2O)i--c}fo}SmBf&_Pi8`lF3uN8Cs_}GbT-GpAs}h!<*5Y z$(yNdN(O5WYf4FFK`KZSQ<~9~&JKt?a}T4<6hA*dum3;*CcGI+rX+Pbb1c|8bB~0@ z?VVGA+#WuJ8G9;WUzxQ%HV`s1LO#HIE1c59o|ghqGI@$OL+cc0#-u6jQ-Y>wcr$u4c{8<5$>8i^O)04? 
z0K0-I&16bv2SlE^htXz=pP!%Ce;@!8-V7yEk~*CQ*6Z2KdE37mg3R)vfE-lpq0Pt|_{r~^~ literal 0 HcmV?d00001 diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index ba1e9be815de6..c736753ff4fdd 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -69,6 +69,24 @@ def check_sgd_optimizer(optimizer_attr): self.assertEqual(len(opts), 1) self.assertEqual([op.type for op in opts], ["sgd"]) + def test_asgd_optimizer(self): + class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self._w = self.create_parameter([2, 3]) + self._b = self.create_parameter([2, 3]) + + def forward(self, x): + return x * self._w + self._b + + with paddle.fluid.dygraph.guard(): + model = MyLayer() + x = paddle.rand([10, 2, 3]) + loss = model(x) + adam = paddle.optimizer.Adam(parameters=model.parameters()) + loss.backward() + adam.step() + class TestOptimizerBackwardApplygrad(unittest.TestCase): def test_sgd_optimizer(self): diff --git a/python/paddle/optimizer/__init__.py b/python/paddle/optimizer/__init__.py index 07d2935bc7646..a813ec046e025 100644 --- a/python/paddle/optimizer/__init__.py +++ b/python/paddle/optimizer/__init__.py @@ -17,6 +17,7 @@ from .adam import Adam # noqa: F401 from .adamw import AdamW # noqa: F401 from .adamax import Adamax # noqa: F401 +from .asgd import ASGD # noqa: F401 from .rmsprop import RMSProp # noqa: F401 from .adadelta import Adadelta # noqa: F401 from .sgd import SGD # noqa: F401 @@ -30,6 +31,7 @@ 'Adam', 'AdamW', 'Adamax', + 'ASGD', 'RMSProp', 'Adadelta', 'SGD', diff --git a/python/paddle/optimizer/asgd.py b/python/paddle/optimizer/asgd.py new file mode 100644 index 0000000000000..ae2022f16a9b3 --- /dev/null +++ b/python/paddle/optimizer/asgd.py @@ -0,0 +1,86 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .optimizer import Optimizer +from ..fluid import core +from ..fluid import framework +from ..fluid.framework import Variable + +__all__ = [] + + +class ASGD(Optimizer): + r""" + """ + _avg_parameter_acc_str = "_avg_parameter" + + def __init__(self, + learning_rate, + t0=1e6, + parameters=None, + weight_decay=None, + grad_clip=None, + name=None): + assert learning_rate is not None + super(ASGD, self).__init__( + learning_rate=learning_rate, + parameters=parameters, + weight_decay=weight_decay, + grad_clip=grad_clip, + name=name) + self.type = "asgd" + self._t0 = t0 + self._default_dict = { + 't0': t0, + } + + def _create_accumulators(self, block, parameters): + if not isinstance(block, framework.Block): + raise TypeError("block is not instance of framework.Block.") + if isinstance(parameters, dict): + parameters = parameters.get('params') + + for p in parameters: + self._add_accumulator(self._avg_parameter_acc_str, p) + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) + + avg_parameter = self._get_accumulator( + self._avg_parameter_acc_str, param_and_grad[0]) + + # Create the adagrad optimizer op + asgd_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "LearningRate": self._create_param_lr(param_and_grad), + "AvgParam": avg_parameter + }, + outputs={"ParamOut": param_and_grad[0], + "AvgParamOut": avg_parameter + }, + attrs={'t0': self._t0}, + stop_gradient=True) + + return asgd_op + + def _update_param_group(self, parameters): + self._t0 = parameters.get('t0', self._default_dict['t0']) + parameters = parameters.get('params') + return parameters From c3b49d88f321fb55a657167443e3b116cf336280 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Mon, 2 May 2022 11:30:20 +0800 Subject: [PATCH 04/11] implement asgd optimizer --- paddle/fluid/operators/optimizers/asgd_op.cc | 9 +-- paddle/phi/infermeta/multiary.cc | 46 +++++++++++++ paddle/phi/infermeta/multiary.h | 10 +++ paddle/phi/kernels/asgd_kernel.h | 2 +- paddle/phi/kernels/cpu/asgd_kernel.cc | 46 ++++++++++++- paddle/phi/kernels/gpu/asgd_kernel.cu | 67 ++++++++++++++++++- paddle/phi/kernels/impl/asgd_kernel_impl.h | 34 +++++++--- .../fluid/tests/unittests/test_optimizer.py | 28 ++++++-- python/paddle/optimizer/asgd.py | 20 ++++-- 9 files changed, 235 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/optimizers/asgd_op.cc b/paddle/fluid/operators/optimizers/asgd_op.cc index ee6f13ba72551..938b07a270b3f 100644 --- a/paddle/fluid/operators/optimizers/asgd_op.cc +++ b/paddle/fluid/operators/optimizers/asgd_op.cc @@ -37,14 +37,12 @@ class AsgdOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Param", "(Tensor) Input parameter"); - AddInput("LearningRate", "(Tensor) Learning rate of SGD"); AddInput("Grad", "(Tensor) Input gradient"); + AddInput("LearningRate", "(Tensor) Learning rate of SGD"); AddInput("AvgParam", "(Tensor) Average of parameter"); AddInput("CurrentStep", "(Tensor) Current step"); - AddInput("t0", - "(Tensor) point at which to start averaging"); AddOutput("ParamOut", "(Tensor, same with Param) " "Output parameter, should share the same memory with Param"); @@ -53,6 +51,9 @@ class AsgdOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("CurrentStepOut", "(Tensor) Increased step"); + AddAttr("t0", + "(float, default 1e6) 
point at which to start averaging") + .SetDefault(0.95f); AddComment(R"DOC( )DOC"); } @@ -63,7 +64,7 @@ class AsgdOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(asgd, AsgdInferMetaFunctor, - PD_INFER_META(phi::SgdInferMeta)); + PD_INFER_META(phi::AsgdInferMeta)); REGISTER_OPERATOR( asgd, ops::AsgdOp, ops::AsgdOpMaker, paddle::framework::EmptyGradOpMaker, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 519d21b323fc2..55e16bfa49b49 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1926,6 +1926,52 @@ void RnnInferMeta(const MetaTensor& x, } } +void AsgdInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& avg_param, + const MetaTensor& current_step, + float t0, + MetaTensor* param_out, + MetaTensor* avg_param_out, + MetaTensor* current_step_out) { + PADDLE_ENFORCE_NOT_NULL(param_out, + phi::errors::InvalidArgument( + "Output(ParamOut) of SGDOp should not be null.")); + + auto lr_dims = learning_rate.dims(); + PADDLE_ENFORCE_EQ(phi::product(lr_dims), + 1, + phi::errors::InvalidArgument( + "Learning rate should have 1 element. But received " + "LearningRate dims [%s]", + phi::product(lr_dims))); + + auto current_step_dims = current_step.dims(); + PADDLE_ENFORCE_EQ(phi::product(current_step_dims), + 1, + phi::errors::InvalidArgument( + "Current step should have 1 element. But received " + "dims [%s]", + phi::product(current_step_dims))); + + auto param_dims = param.dims(); + auto avg_param_dims = avg_param.dims(); + PADDLE_ENFORCE_EQ(param_dims, + avg_param_dims, + phi::errors::InvalidArgument( + "Param and AvgParam should have the same dims. But received " + "[%s] and [%s]", + param_dims, avg_param_dims)); + + param_out->set_dims(param.dims()); + param_out->set_dtype(param.dtype()); + avg_param_out->set_dims(param.dims()); + avg_param_out->set_dtype(param.dtype()); + current_step_out->set_dims(current_step.dims()); + current_step_out->set_dtype(current_step.dtype()); +} + void SgdInferMeta(const MetaTensor& param, const MetaTensor& learning_rate, const MetaTensor& grad, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 65b5819b602ba..a56f4855476ea 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -122,6 +122,16 @@ void AddNInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void AsgdInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& avg_param, + const MetaTensor& current_step, + float t0, + MetaTensor* param_out, + MetaTensor* avg_param_out, + MetaTensor* current_step_out); + void AucInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& stat_pos, diff --git a/paddle/phi/kernels/asgd_kernel.h b/paddle/phi/kernels/asgd_kernel.h index 1b29bd2906f72..9881f7a3dfabc 100644 --- a/paddle/phi/kernels/asgd_kernel.h +++ b/paddle/phi/kernels/asgd_kernel.h @@ -25,7 +25,7 @@ void AsgdKernel(const Context& dev_ctx, const DenseTensor& grad, const DenseTensor& avg_param, const DenseTensor& current_step, - const DenseTensor& t0, + float t0, DenseTensor* param_out, DenseTensor* avg_param_out, DenseTensor* current_step_out); diff --git a/paddle/phi/kernels/cpu/asgd_kernel.cc b/paddle/phi/kernels/cpu/asgd_kernel.cc index b71e2b67d4852..4a2697a218c29 100644 --- a/paddle/phi/kernels/cpu/asgd_kernel.cc +++ 
b/paddle/phi/kernels/cpu/asgd_kernel.cc @@ -16,6 +16,50 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/asgd_kernel_impl.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void AsgdKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& learning_rate, + const DenseTensor& avg_param, + const DenseTensor& current_step, + float t0, + DenseTensor* param_out, + DenseTensor* avg_param_out, + DenseTensor* current_step_out) { + dev_ctx.template Alloc(param_out); + dev_ctx.template Alloc(avg_param_out); + dev_ctx.template Alloc(current_step_out); + + auto eigen_param = EigenVector::Flatten(param); + auto eigen_grad = EigenVector::Flatten(grad); + auto eigen_avg_param = EigenVector::Flatten(avg_param); + auto eigen_param_out = EigenVector::Flatten(*param_out); + auto eigen_avg_param_out = EigenVector::Flatten(*avg_param_out); + auto& place = *dev_ctx.eigen_device(); + + auto lr = learning_rate.data()[0]; + eigen_param_out.device(place) = eigen_param - lr * eigen_grad; + + T current_step_data = current_step.data()[0]; + + if (current_step_data <= t0) { + eigen_avg_param_out.device(place) = eigen_param_out; + } else { + const auto mu1 = 1 / (current_step_data - t0); + const auto mu2 = 1 - mu1; + eigen_avg_param_out.device(place) = + mu2 * eigen_avg_param + mu1 * eigen_param_out; + } + *current_step_out->mutable_data(dev_ctx.GetPlace()) = + current_step_data + 1; +} + +} // namespace phi PD_REGISTER_KERNEL(asgd, CPU, ALL_LAYOUT, phi::AsgdKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/asgd_kernel.cu b/paddle/phi/kernels/gpu/asgd_kernel.cu index 2b9e71988b523..8db5889af4c7a 100644 --- a/paddle/phi/kernels/gpu/asgd_kernel.cu +++ b/paddle/phi/kernels/gpu/asgd_kernel.cu @@ -12,10 +12,71 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/asgd_kernel.h" - #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/asgd_kernel_impl.h" +#include "paddle/phi/kernels/asgd_kernel.h" + +namespace phi { + +template +__global__ void ASGDKernel(const T* param, + const T* grad, + const T* learning_rate, + const T* avg_param, + const T* current_step, + float t0, + size_t num, + T* param_out, + T* avg_param_out) { + T lr = learning_rate[0]; + CUDA_KERNEL_LOOP(i, num) { param_out[i] = param[i] - lr * grad[i]; } + T current_step_data = current_step[0]; + if (current_step_data <= t0) { + memcpy(avg_param_out, param, num * sizeof(T)); + } else { + const auto mu1 = 1 / (current_step_data - t0); + const auto mu2 = 1 - mu1; + CUDA_KERNEL_LOOP(i, num) { + avg_param_out[i] = mu2 * avg_param[i] + mu1 * param_out[i]; + } + } +} + +template +__global__ void IncreaseStep(const T* step, T* step_out) { + *step_out = *step + 1; +} + +template +void AsgdKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + const DenseTensor& avg_param, + const DenseTensor& current_step, + float t0, + DenseTensor* param_out, + DenseTensor* avg_param_out, + DenseTensor* current_step_out) { + int block = 512; + int grid = (param.numel() + block - 1) / block; + + ASGDKernel<<>>( + param.data(), + grad.data(), + learning_rate.data(), + avg_param.data(), + current_step.data(), + t0, + param.numel(), + param_out->mutable_data(dev_ctx.GetPlace()), + avg_param_out->mutable_data(dev_ctx.GetPlace())); + + IncreaseStep<<<1, 1, 0, dev_ctx.stream()>>>( + current_step.data(), + current_step_out->mutable_data(dev_ctx.GetPlace())); +} + +} // namespace phi PD_REGISTER_KERNEL(asgd, GPU, ALL_LAYOUT, phi::AsgdKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/asgd_kernel_impl.h b/paddle/phi/kernels/impl/asgd_kernel_impl.h index 33ba3eaa54410..38d44a1a50681 100644 --- a/paddle/phi/kernels/impl/asgd_kernel_impl.h +++ b/paddle/phi/kernels/impl/asgd_kernel_impl.h @@ -13,6 +13,7 @@ // limitations under the License. 
#pragma once +#error "should not include this" #include "paddle/fluid/platform/place.h" #include "paddle/phi/kernels/adadelta_kernel.h" @@ -21,6 +22,25 @@ namespace phi { +// inline HOSTDEVICE bool is_less(const double* current_step, float t0) { +// printf("%d", __LINE__); +// printf("%lf", *current_step); +// printf("%f", t0); +// return current_step[0] < t0; +// } +// +// inline HOSTDEVICE bool is_less(const float* current_step, float t0) { +// printf("%d", __LINE__); +// #ifdef __CUDA_ARCH__ +// printf("cuda"); +// #else +// printf("cpu"); +// #endif +// printf("%f", *current_step); +// printf("%f", t0); +// return current_step[0] < t0; +// } +// template void AsgdKernel(const Context& dev_ctx, const DenseTensor& param, @@ -28,18 +48,18 @@ void AsgdKernel(const Context& dev_ctx, const DenseTensor& grad, const DenseTensor& avg_param, const DenseTensor& current_step, - const DenseTensor& t0, + float t0, DenseTensor* param_out, DenseTensor* avg_param_out, DenseTensor* current_step_out) { dev_ctx.template Alloc(param_out); dev_ctx.template Alloc(avg_param_out); + dev_ctx.template Alloc(current_step_out); auto eigen_param = EigenVector::Flatten(param); auto eigen_grad = EigenVector::Flatten(grad); auto eigen_avg_param = EigenVector::Flatten(avg_param); auto eigen_current_step = EigenVector::Flatten(current_step); - auto eigen_t0 = EigenVector::Flatten(t0); auto eigen_param_out = EigenVector::Flatten(*param_out); auto eigen_avg_param_out = EigenVector::Flatten(*avg_param_out); auto eigen_current_step_out = EigenVector::Flatten(*current_step_out); @@ -55,16 +75,14 @@ void AsgdKernel(const Context& dev_ctx, eigen_param - eigen_lr.broadcast(dsize) * eigen_grad; } - if (eigen_current_step < eigen_t0) { + if (current_step.data()[0] < t0) { eigen_avg_param_out.device(place) = eigen_param_out; } else { - // const auto mu = eigen_current_step - eigen_t0 + 1; + const auto mu = (1 - t0) + eigen_current_step; eigen_avg_param_out.device(place) = - eigen_avg_param + (eigen_param_out - eigen_avg_param) / eigen_current_step; + eigen_avg_param + mu * (eigen_param_out - eigen_avg_param); } - - // eigen_current_step_out = eigen_current_step + 1; - eigen_current_step++; + eigen_current_step_out.device(place) = 1 + eigen_current_step; } } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index c736753ff4fdd..8a2eb10252057 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -70,22 +70,38 @@ def check_sgd_optimizer(optimizer_attr): self.assertEqual([op.type for op in opts], ["sgd"]) def test_asgd_optimizer(self): + w_shape = [3, 4] class MyLayer(paddle.nn.Layer): def __init__(self): super(MyLayer, self).__init__() - self._w = self.create_parameter([2, 3]) - self._b = self.create_parameter([2, 3]) + self._w = self.create_parameter(w_shape, default_initializer=paddle.fluid.initializer.ConstantInitializer()) def forward(self, x): - return x * self._w + self._b + return x * self._w with paddle.fluid.dygraph.guard(): model = MyLayer() - x = paddle.rand([10, 2, 3]) + x = paddle.ones([1, 3, 4]) loss = model(x) - adam = paddle.optimizer.Adam(parameters=model.parameters()) + asgd = paddle.optimizer.ASGD(learning_rate=1., parameters=model.parameters(), t0=1) loss.backward() - adam.step() + + np_neg_ones = np.ones(w_shape) * -1 + asgd.step() + assert np.array_equal(model._w.numpy(), np_neg_ones) + assert np.array_equal(asgd.averaged_parameters()[0].numpy(), 
np_neg_ones) + asgd.step() + assert np.array_equal(model._w.numpy(), np_neg_ones * 2) + assert np.array_equal(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 2) + asgd.step() + assert np.array_equal(model._w.numpy(), np_neg_ones * 3) + assert np.array_equal(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3) + asgd.step() + assert np.array_equal(model._w.numpy(), np_neg_ones * 4) + assert np.array_equal(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3.5) + asgd.step() + assert np.array_equal(model._w.numpy(), np_neg_ones * 5) + assert np.array_equal(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 4) class TestOptimizerBackwardApplygrad(unittest.TestCase): diff --git a/python/paddle/optimizer/asgd.py b/python/paddle/optimizer/asgd.py index ae2022f16a9b3..9fa2a001d44cf 100644 --- a/python/paddle/optimizer/asgd.py +++ b/python/paddle/optimizer/asgd.py @@ -24,6 +24,7 @@ class ASGD(Optimizer): r""" """ _avg_parameter_acc_str = "_avg_parameter" + _current_step_acc_str = "_current_step" def __init__(self, learning_rate, @@ -51,8 +52,11 @@ def _create_accumulators(self, block, parameters): if isinstance(parameters, dict): parameters = parameters.get('params') + self._averaged_parameters = [] for p in parameters: - self._add_accumulator(self._avg_parameter_acc_str, p) + avg_param = self._add_accumulator(self._avg_parameter_acc_str, p) + self._averaged_parameters.append(avg_param) + self._add_accumulator(self._current_step_acc_str, p, shape=[1], device='cpu') def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) @@ -63,17 +67,21 @@ def _append_optimize_op(self, block, param_and_grad): avg_parameter = self._get_accumulator( self._avg_parameter_acc_str, param_and_grad[0]) - # Create the adagrad optimizer op + current_step = self._get_accumulator( + self._current_step_acc_str, param_and_grad[0]) + asgd_op = block.append_op( type=self.type, inputs={ "Param": param_and_grad[0], "Grad": param_and_grad[1], "LearningRate": self._create_param_lr(param_and_grad), - "AvgParam": avg_parameter + "AvgParam": avg_parameter, + "CurrentStep": current_step, }, outputs={"ParamOut": param_and_grad[0], - "AvgParamOut": avg_parameter + "AvgParamOut": avg_parameter, + "CurrentStepOut": current_step, }, attrs={'t0': self._t0}, stop_gradient=True) @@ -84,3 +92,7 @@ def _update_param_group(self, parameters): self._t0 = parameters.get('t0', self._default_dict['t0']) parameters = parameters.get('params') return parameters + + def averaged_parameters(self): + return self._averaged_parameters + From ef95457818a07d9debc12113d488dbc7575b5e31 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Mon, 2 May 2022 11:35:38 +0800 Subject: [PATCH 05/11] refine --- paddle/phi/kernels/gpu/asgd_kernel.cu | 3 +- paddle/phi/kernels/impl/asgd_kernel_impl.h | 88 ---------------------- 2 files changed, 2 insertions(+), 89 deletions(-) delete mode 100644 paddle/phi/kernels/impl/asgd_kernel_impl.h diff --git a/paddle/phi/kernels/gpu/asgd_kernel.cu b/paddle/phi/kernels/gpu/asgd_kernel.cu index 8db5889af4c7a..2eefd1831a40e 100644 --- a/paddle/phi/kernels/gpu/asgd_kernel.cu +++ b/paddle/phi/kernels/gpu/asgd_kernel.cu @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "paddle/phi/kernels/asgd_kernel.h" + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/asgd_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/impl/asgd_kernel_impl.h b/paddle/phi/kernels/impl/asgd_kernel_impl.h deleted file mode 100644 index 38d44a1a50681..0000000000000 --- a/paddle/phi/kernels/impl/asgd_kernel_impl.h +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#error "should not include this" - -#include "paddle/fluid/platform/place.h" -#include "paddle/phi/kernels/adadelta_kernel.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - -namespace phi { - -// inline HOSTDEVICE bool is_less(const double* current_step, float t0) { -// printf("%d", __LINE__); -// printf("%lf", *current_step); -// printf("%f", t0); -// return current_step[0] < t0; -// } -// -// inline HOSTDEVICE bool is_less(const float* current_step, float t0) { -// printf("%d", __LINE__); -// #ifdef __CUDA_ARCH__ -// printf("cuda"); -// #else -// printf("cpu"); -// #endif -// printf("%f", *current_step); -// printf("%f", t0); -// return current_step[0] < t0; -// } -// -template -void AsgdKernel(const Context& dev_ctx, - const DenseTensor& param, - const DenseTensor& learning_rate, - const DenseTensor& grad, - const DenseTensor& avg_param, - const DenseTensor& current_step, - float t0, - DenseTensor* param_out, - DenseTensor* avg_param_out, - DenseTensor* current_step_out) { - dev_ctx.template Alloc(param_out); - dev_ctx.template Alloc(avg_param_out); - dev_ctx.template Alloc(current_step_out); - - auto eigen_param = EigenVector::Flatten(param); - auto eigen_grad = EigenVector::Flatten(grad); - auto eigen_avg_param = EigenVector::Flatten(avg_param); - auto eigen_current_step = EigenVector::Flatten(current_step); - auto eigen_param_out = EigenVector::Flatten(*param_out); - auto eigen_avg_param_out = EigenVector::Flatten(*avg_param_out); - auto eigen_current_step_out = EigenVector::Flatten(*current_step_out); - auto& place = *dev_ctx.eigen_device(); - - if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) { - auto lr = learning_rate.data()[0]; - eigen_param_out.device(place) = eigen_param - lr * eigen_grad; - } else { - Eigen::DSizes dsize(param_out->numel()); - auto eigen_lr = EigenVector::Flatten(learning_rate); - eigen_param_out.device(place) = - eigen_param - eigen_lr.broadcast(dsize) * eigen_grad; - } - - if (current_step.data()[0] < t0) { - eigen_avg_param_out.device(place) = eigen_param_out; - } else { - const auto mu = (1 - t0) + eigen_current_step; - eigen_avg_param_out.device(place) = - eigen_avg_param + mu * (eigen_param_out - eigen_avg_param); - } - eigen_current_step_out.device(place) = 1 + eigen_current_step; -} - -} // namespace phi From 2c6f29d845c5c4b576be95b770fc19548fe2578b Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi 
<452565578@qq.com> Date: Tue, 3 May 2022 11:43:46 +0800 Subject: [PATCH 06/11] replace array_equal with allclose --- .../fluid/tests/unittests/test_optimizer.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index 8a2eb10252057..141967adb33e6 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -88,20 +88,20 @@ def forward(self, x): np_neg_ones = np.ones(w_shape) * -1 asgd.step() - assert np.array_equal(model._w.numpy(), np_neg_ones) - assert np.array_equal(asgd.averaged_parameters()[0].numpy(), np_neg_ones) + assert np.allclose(model._w.numpy(), np_neg_ones) + assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones) asgd.step() - assert np.array_equal(model._w.numpy(), np_neg_ones * 2) - assert np.array_equal(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 2) + assert np.allclose(model._w.numpy(), np_neg_ones * 2) + assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 2) asgd.step() - assert np.array_equal(model._w.numpy(), np_neg_ones * 3) - assert np.array_equal(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3) + assert np.allclose(model._w.numpy(), np_neg_ones * 3) + assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3) asgd.step() - assert np.array_equal(model._w.numpy(), np_neg_ones * 4) - assert np.array_equal(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3.5) + assert np.allclose(model._w.numpy(), np_neg_ones * 4) + assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3.5) asgd.step() - assert np.array_equal(model._w.numpy(), np_neg_ones * 5) - assert np.array_equal(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 4) + assert np.allclose(model._w.numpy(), np_neg_ones * 5) + assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 4) class TestOptimizerBackwardApplygrad(unittest.TestCase): From 8a48bfb048aafe91ffa0198d352da3777953f433 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 3 May 2022 16:14:02 +0800 Subject: [PATCH 07/11] add print to debug --- python/paddle/fluid/tests/unittests/test_optimizer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index 141967adb33e6..6816317ce0282 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -91,6 +91,9 @@ def forward(self, x): assert np.allclose(model._w.numpy(), np_neg_ones) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones) asgd.step() + print(f'w: {model._w.numpy()}') + print(f'-2: {np_neg_ones * 2}') + print(f'avg: {asgd.averaged_parameters()[0].numpy()}') assert np.allclose(model._w.numpy(), np_neg_ones * 2) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 2) asgd.step() From 506d05b9954f664d857e3f30ceb71368b4e55531 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 3 May 2022 19:47:26 +0800 Subject: [PATCH 08/11] add print to debug --- python/paddle/fluid/tests/unittests/test_optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index 6816317ce0282..c139d235b5ff5 
100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -90,10 +90,10 @@ def forward(self, x): asgd.step() assert np.allclose(model._w.numpy(), np_neg_ones) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones) + print(f'grad: {model._w.grad}') + print(f'w before step: {model._w.numpy()}') asgd.step() - print(f'w: {model._w.numpy()}') - print(f'-2: {np_neg_ones * 2}') - print(f'avg: {asgd.averaged_parameters()[0].numpy()}') + print(f'w after step: {model._w.numpy()}') assert np.allclose(model._w.numpy(), np_neg_ones * 2) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 2) asgd.step() From 0403b084627ca3909998cc38dd3b88b71dffbf69 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 3 May 2022 22:02:14 +0800 Subject: [PATCH 09/11] add print to debug --- python/paddle/fluid/tests/unittests/test_optimizer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index c139d235b5ff5..dc13157229002 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -87,7 +87,11 @@ def forward(self, x): loss.backward() np_neg_ones = np.ones(w_shape) * -1 + print(f'grad before step: {model._w.grad}') + print(f'w before step: {model._w.numpy()}') asgd.step() + print(f'w after step: {model._w.numpy()}') + print(f'grad after step: {model._w.grad}') assert np.allclose(model._w.numpy(), np_neg_ones) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones) print(f'grad: {model._w.grad}') From f3538735674d302c683d643dde0308beb96caef1 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Wed, 4 May 2022 15:45:38 +0800 Subject: [PATCH 10/11] update test --- .../fluid/tests/unittests/test_optimizer.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index dc13157229002..e2ab12c9beb3c 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -80,32 +80,42 @@ def forward(self, x): return x * self._w with paddle.fluid.dygraph.guard(): + np_neg_ones = np.ones(w_shape) * -1 + model = MyLayer() x = paddle.ones([1, 3, 4]) - loss = model(x) asgd = paddle.optimizer.ASGD(learning_rate=1., parameters=model.parameters(), t0=1) - loss.backward() - np_neg_ones = np.ones(w_shape) * -1 - print(f'grad before step: {model._w.grad}') - print(f'w before step: {model._w.numpy()}') + loss = model(x) + loss.backward() asgd.step() - print(f'w after step: {model._w.numpy()}') - print(f'grad after step: {model._w.grad}') assert np.allclose(model._w.numpy(), np_neg_ones) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones) - print(f'grad: {model._w.grad}') - print(f'w before step: {model._w.numpy()}') + asgd.clear_grad() + + loss = model(x) + loss.backward() asgd.step() - print(f'w after step: {model._w.numpy()}') assert np.allclose(model._w.numpy(), np_neg_ones * 2) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 2) + asgd.clear_grad() + + loss = model(x) + loss.backward() asgd.step() assert np.allclose(model._w.numpy(), np_neg_ones * 3) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3) + asgd.clear_grad() 
+ + loss = model(x) + loss.backward() asgd.step() assert np.allclose(model._w.numpy(), np_neg_ones * 4) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3.5) + asgd.clear_grad() + + loss = model(x) + loss.backward() asgd.step() assert np.allclose(model._w.numpy(), np_neg_ones * 5) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 4) From 3a21267a1f4f6e73704340c3b52f720765a3d4c0 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Wed, 4 May 2022 16:48:20 +0800 Subject: [PATCH 11/11] add print to debug --- python/paddle/fluid/tests/unittests/test_optimizer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index e2ab12c9beb3c..bd4e014e2b6b4 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -87,14 +87,18 @@ def forward(self, x): asgd = paddle.optimizer.ASGD(learning_rate=1., parameters=model.parameters(), t0=1) loss = model(x) + print(f'1: w grad before bw: {model._w.grad}') loss.backward() + print(f'1: w grad: {model._w.grad}') asgd.step() assert np.allclose(model._w.numpy(), np_neg_ones) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones) asgd.clear_grad() loss = model(x) + print(f'2: w grad before bw: {model._w.grad}') loss.backward() + print(f'2: w grad: {model._w.grad}') asgd.step() assert np.allclose(model._w.numpy(), np_neg_ones * 2) assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 2)