diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc
new file mode 100644
index 0000000000000..c95077fcbdb6b
--- /dev/null
+++ b/paddle/fluid/operators/average_accumulates_op.cc
@@ -0,0 +1,216 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/average_accumulates_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <>
+void GetAccumulators<paddle::platform::CPUDeviceContext>(
+    const framework::ExecutionContext& ctx, int64_t& num_updates_,
+    int64_t& num_accumulates_, int64_t& old_num_accumulates_) {
+  auto* in_old_num_accumulates = ctx.Input<Tensor>("in_old_num_accumulates");
+  auto* in_num_accumulates = ctx.Input<Tensor>("in_num_accumulates");
+  auto* in_num_updates = ctx.Input<Tensor>("in_num_updates");
+
+  old_num_accumulates_ = in_old_num_accumulates->data<int64_t>()[0];
+  num_accumulates_ = in_num_accumulates->data<int64_t>()[0];
+  num_updates_ = in_num_updates->data<int64_t>()[0];
+}
+
+template <>
+void SetAccumulators<paddle::platform::CPUDeviceContext>(
+    const framework::ExecutionContext& ctx, int64_t num_updates_,
+    int64_t num_accumulates_, int64_t old_num_accumulates_) {
+  auto* out_old_num_accumulates = ctx.Output<Tensor>("out_old_num_accumulates");
+  auto* out_num_accumulates = ctx.Output<Tensor>("out_num_accumulates");
+  auto* out_num_updates = ctx.Output<Tensor>("out_num_updates");
+
+  out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates_;
+  out_num_accumulates->data<int64_t>()[0] = num_accumulates_;
+  out_num_updates->data<int64_t>()[0] = num_updates_;
+}
+
+class AverageAccumulatesOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput("param"),
+        "Input (param) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_sum_1"),
+        "Input (sum_1) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_sum_2"),
+        "Input (sum_2) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_sum_3"),
+        "Input (sum_3) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_num_accumulates"),
+        "Input (in_num_accumulates) of average_accumulates op should "
+        "not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("in_old_num_accumulates"),
+                   "Input (old_num_accumulates) of average_accumulates op "
+                   "should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput("in_num_updates"),
+        "Input (num_updates) of average_accumulates op should not be null.");
+
+    PADDLE_ENFORCE(
+        ctx->HasOutput("out_sum_1"),
+        "Output (sum_1) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("out_sum_2"),
+        "Output (sum_2) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("out_sum_3"),
+        "Output (sum_3) of average_accumulates op should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("out_num_accumulates"),
+                   "Output (num_accumulates) of average_accumulates op should "
null."); + PADDLE_ENFORCE(ctx->HasOutput("out_old_num_accumulates"), + "Output (old_num_accumulates) of average_accumulates op " + "should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("out_num_updates"), + "Output (num_updates) of average_accumulates op should not be null."); + + auto in_dim = ctx->GetInputDim("param"); + + ctx->SetOutputDim("out_sum_1", in_dim); + ctx->SetOutputDim("out_sum_2", in_dim); + ctx->SetOutputDim("out_sum_3", in_dim); + ctx->SetOutputDim("out_num_accumulates", {1}); + ctx->SetOutputDim("out_old_num_accumulates", {1}); + ctx->SetOutputDim("out_num_updates", {1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("param")->type()), + ctx.GetPlace()); + } +}; + +class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("param", "(Tensor), The parameter to be accumulated."); + AddInput("in_sum_1", + "(Tensor), A tensor used to store the parameter " + "sums with the same shape as input(param)."); + AddInput("in_sum_2", + "(Tensor), A auxiliary tensor to help " + "accumulating sums of parameter values with the same shape as " + "input(param). It is used to avoid loss of precision due to too " + "many sums."); + AddInput("in_sum_3", + "(Tensor), A auxiliary tensor to help " + "accumulating sums of parameter values with the same shape as " + "input(param)."); + AddInput("in_num_accumulates", + "(Tensor), The accumulating times of current window with " + "shape [1]."); + AddInput( + "in_old_num_accumulates", + "(Tensor), The accumulating times of previous window with " + "shape [1]."); + AddInput("in_num_updates", + "(Tensor), The total number of batches used by trainning " + "before this batch with shape [1]."); + + AddOutput("out_sum_1", + "(Tensor), A tensor used to store the " + "parameter sums with the same shape as input(param)."); + AddOutput("out_sum_2", + "(Tensor), A auxiliary tensor to help " + "accumulating sums of parameter values with the same shape as " + "input(param). It is used to avoid loss of precision due to too " + "many sums."); + AddOutput("out_sum_3", + "(Tensor), A auxiliary tensor to help " + "accumulating sums of parameter values with the same shape as " + "input(param)."); + AddOutput( + "out_num_accumulates", + "(Tensor), The accumulating times of current window with " + "shape [1]."); + AddOutput( + "out_old_num_accumulates", + "(Tensor) The accumulating times of previous window with " + "shape [1]."); + AddOutput( + "out_num_updates", + "(Tensor), The total number of batches used by trainning " + "before this batch with shape [1]."); + + AddAttr("average_window", + "(float, default 0) " + "The rate of average window size relative to num_updates.") + .SetDefault(0); + AddAttr("max_average_window", + "(int64_t) " + "Maximum size of average window. It suggests that the " + "number of mini-batches " + "in one pass is appropriate value to set."); + AddAttr("min_average_window", + "(int64_t, default 10000L) " + "Minimu size of average window.") + .SetDefault(10000L); + + AddComment(R"DOC( +AverageAccumulates Operator. +Accumulate the sum of parameter whtin sliding window. The size of sliding window is +determined by 'average_window', 'max_average_window' and 'min_average_window'. 
+Memory is shared by Input(in_sum_1) and Output(out_sum_1), which together act as the accumulator 'sum_1'.
+'sum_2', 'sum_3', 'num_accumulates', 'old_num_accumulates' and 'num_updates' work in the same way.
+
+All the accumulators are initialized to zero before training.
+
+For each mini-batch in training, the accumulators are updated as follows:
+    num_updates += 1
+    num_accumulates += 1
+    sum_1 += param
+    if num_updates % kMaxNumAccumulates == 0:
+        sum_2 += sum_1
+        sum_1 = 0
+    if num_accumulates >= min_average_window &&
+       num_accumulates >= min(max_average_window, num_updates * average_window):
+        sum_3 = sum_1 + sum_2
+        sum_1 = 0
+        sum_2 = 0
+        old_num_accumulates = num_accumulates
+        num_accumulates = 0
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(average_accumulates, ops::AverageAccumulatesOp,
+                  ops::AverageAccumulatesOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    average_accumulates,
+    ops::AverageAccumulatesKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::AverageAccumulatesKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/average_accumulates_op.cu b/paddle/fluid/operators/average_accumulates_op.cu
new file mode 100644
index 0000000000000..270c46984465e
--- /dev/null
+++ b/paddle/fluid/operators/average_accumulates_op.cu
@@ -0,0 +1,63 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/average_accumulates_op.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+namespace paddle {
+namespace operators {
+template <>
+void GetAccumulators<paddle::platform::CUDADeviceContext>(
+    const framework::ExecutionContext& ctx, int64_t& num_updates_,
+    int64_t& num_accumulates_, int64_t& old_num_accumulates_) {
+  auto* in_old_num_accumulates = ctx.Input<Tensor>("in_old_num_accumulates");
+  auto* in_num_accumulates = ctx.Input<Tensor>("in_num_accumulates");
+  auto* in_num_updates = ctx.Input<Tensor>("in_num_updates");
+  auto stream = ctx.cuda_device_context().stream();
+  memory::Copy(platform::CPUPlace(), &old_num_accumulates_,
+               platform::CUDAPlace(), in_old_num_accumulates->data<int64_t>(),
+               sizeof(int64_t), stream);
+  memory::Copy(platform::CPUPlace(), &num_accumulates_, platform::CUDAPlace(),
+               in_num_accumulates->data<int64_t>(), sizeof(int64_t), stream);
+  memory::Copy(platform::CPUPlace(), &num_updates_, platform::CUDAPlace(),
+               in_num_updates->data<int64_t>(), sizeof(int64_t), stream);
+}
+
+template <>
+void SetAccumulators<paddle::platform::CUDADeviceContext>(
+    const framework::ExecutionContext& ctx, int64_t num_updates_,
+    int64_t num_accumulates_, int64_t old_num_accumulates_) {
+  auto stream = ctx.cuda_device_context().stream();
+  auto* out_old_num_accumulates = ctx.Output<Tensor>("out_old_num_accumulates");
+  auto* out_num_accumulates = ctx.Output<Tensor>("out_num_accumulates");
+  auto* out_num_updates = ctx.Output<Tensor>("out_num_updates");
+
+  memory::Copy(platform::CUDAPlace(), out_old_num_accumulates->data<int64_t>(),
+               platform::CPUPlace(), &old_num_accumulates_, sizeof(int64_t),
+               stream);
+  memory::Copy(platform::CUDAPlace(), out_num_accumulates->data<int64_t>(),
+               platform::CPUPlace(), &num_accumulates_, sizeof(int64_t),
+               stream);
+  memory::Copy(platform::CUDAPlace(), out_num_updates->data<int64_t>(),
+               platform::CPUPlace(), &num_updates_, sizeof(int64_t), stream);
+}
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    average_accumulates,
+    ops::AverageAccumulatesKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::AverageAccumulatesKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/average_accumulates_op.h b/paddle/fluid/operators/average_accumulates_op.h
new file mode 100644
index 0000000000000..f858109d1428d
--- /dev/null
+++ b/paddle/fluid/operators/average_accumulates_op.h
@@ -0,0 +1,113 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+template <typename DeviceContext>
+void GetAccumulators(const framework::ExecutionContext& ctx,
+                     int64_t& num_updates, int64_t& num_accumulates,
+                     int64_t& old_num_accumulates);
+
+template <typename DeviceContext>
+void SetAccumulators(const framework::ExecutionContext& ctx,
+                     int64_t num_updates, int64_t num_accumulates,
+                     int64_t old_num_accumulates);
+
+template <typename DeviceContext, typename T>
+class AverageAccumulatesKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    // It is used to avoid loss of precision
+    static const int64_t kMaxNumAccumulates = 16384;
+    // Get accumulators from input
+    int64_t num_updates = 0;
+    int64_t num_accumulates = 0;
+    int64_t old_num_accumulates = 0;
+    GetAccumulators<DeviceContext>(ctx, num_updates, num_accumulates,
+                                   old_num_accumulates);
+
+    // Get attrs
+    float average_window = ctx.Attr<float>("average_window");
+    int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
+    int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
+    min_average_window =
+        std::min<int64_t>(min_average_window, max_average_window);
+
+    // Get inputs
+    auto* param = ctx.Input<Tensor>("param");
+    auto* in_sum_1 = ctx.Input<Tensor>("in_sum_1");
+    auto* in_sum_2 = ctx.Input<Tensor>("in_sum_2");
+    auto* in_sum_3 = ctx.Input<Tensor>("in_sum_3");
+    auto param_tensor = EigenVector<T>::Flatten(*param);
+    auto in_sum_1_tensor = EigenVector<T>::Flatten(*in_sum_1);
+    auto in_sum_2_tensor = EigenVector<T>::Flatten(*in_sum_2);
+    auto in_sum_3_tensor = EigenVector<T>::Flatten(*in_sum_3);
+
+    // Get outputs
+    auto* out_sum_1 = ctx.Output<Tensor>("out_sum_1");
+    auto* out_sum_2 = ctx.Output<Tensor>("out_sum_2");
+    auto* out_sum_3 = ctx.Output<Tensor>("out_sum_3");
+    auto out_sum_1_tensor = EigenVector<T>::Flatten(*out_sum_1);
+    auto out_sum_2_tensor = EigenVector<T>::Flatten(*out_sum_2);
+    auto out_sum_3_tensor = EigenVector<T>::Flatten(*out_sum_3);
+
+    // Compute
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    math::SetConstant<DeviceContext, T> constant_functor;
+    ++num_updates;
+    ++num_accumulates;
+    out_sum_1_tensor.device(place) = in_sum_1_tensor + param_tensor;
+    out_sum_2_tensor.device(place) = in_sum_2_tensor;
+    out_sum_3_tensor.device(place) = in_sum_3_tensor;
+    if (num_updates % kMaxNumAccumulates == 0) {
+      // Move the sum to a different buffer to avoid loss of precision due to
+      // too many sums.
+      out_sum_2_tensor.device(place) = in_sum_2_tensor + in_sum_1_tensor;
+      constant_functor(ctx.template device_context<DeviceContext>(), out_sum_1,
+                       0.0);
+    }
+    if (num_accumulates >= min_average_window &&
+        num_accumulates >= std::min<int64_t>(max_average_window,
+                                             num_updates * average_window)) {
+      // Now the average window is too long, discard the old sum.
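+      // Fold the current sums into sum_3 and start a new window; the length
+      // of the finished window is kept in old_num_accumulates so the average
+      // can later be computed as
+      // (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates).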
+      out_sum_3_tensor.device(place) = in_sum_1_tensor + in_sum_2_tensor;
+      constant_functor(ctx.template device_context<DeviceContext>(), out_sum_1,
+                       0.0);
+      constant_functor(ctx.template device_context<DeviceContext>(), out_sum_2,
+                       0.0);
+      old_num_accumulates = num_accumulates;
+      num_accumulates = 0;
+    }
+
+    // Set accumulators to output
+    SetAccumulators<DeviceContext>(ctx, num_updates, num_accumulates,
+                                   old_num_accumulates);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 70ecffd910a46..3e78788f47055 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -918,6 +918,24 @@ def copy_param_info_from(self, other):
                 name=v.name)
             self.vars[new_p.name] = new_p
 
+    def clone_variable(self, var):
+        """
+        Clone a variable into current block.
+        Args:
+            var: the variable to be cloned.
+
+        Returns:
+            The new variable cloned from 'var' in current block.
+        """
+        assert isinstance(var, Variable)
+        return self.create_var(
+            name=var.name,
+            shape=var.shape,
+            dtype=var.dtype,
+            type=var.type,
+            lod_level=var.lod_level,
+            persistable=True)
+
 
 class Program(object):
     def __init__(self):
@@ -960,14 +978,14 @@ def clone(self, for_test=False):
         """Clone the Program object
 
         Set for_test to False when we want to clone the program for training.
-        Set for_test to True when we want to clone the program for testing. 
+        Set for_test to True when we want to clone the program for testing.
 
         Args:
             for_test(bool): Some operators, such as batch_norm and drop_out ops,
                 behave differently in training and testing. If for_test is True,
                 the is_test attributes in these operators will be set to True for
-                testing purposes, otherwise, they remain unchanged.  
-
+                testing purposes, otherwise, they remain unchanged.
+
 
         Returns(Program): The cloned Program object.
         """
diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py
index d7bad221c5fa7..f5c6b47d243dc 100644
--- a/python/paddle/fluid/layers/ops.py
+++ b/python/paddle/fluid/layers/ops.py
@@ -69,6 +69,7 @@
     'gaussian_random_batch_size_like',
     'cumsum',
     'scatter',
+    'sum',
 ] + __activations__
 
 for _OP in set(__all__):
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index a33760a528f66..180575c35dc6e 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from collections import defaultdict
-
+from paddle.fluid.framework import Program, program_guard
 import framework
 import layers
 from backward import append_backward
@@ -23,9 +23,11 @@
 from layer_helper import LayerHelper
 from regularizer import append_regularization_ops
 from clip import append_gradient_clip_ops, error_clip_callback
+from contextlib import contextmanager
 
 __all__ = [
-    'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Adadelta'
+    'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad',
+    'Adadelta', 'ModelAverage'
 ]
 
 
@@ -121,7 +123,12 @@ def _finish_update(self, block):
         """
         pass
 
-    def _add_accumulator(self, name, param, dtype=None, fill_value=0.0):
+    def _add_accumulator(self,
+                         name,
+                         param,
+                         dtype=None,
+                         fill_value=0.0,
+                         shape=None):
         """Utility function to add an accumulator for a parameter
 
         Args:
@@ -135,17 +142,19 @@ def _add_accumulator(self, name, param, dtype=None, fill_value=0.0):
             param.name in self._accumulators[name]):
             raise Exception("Accumulator {} already exists for parameter {}".
                             format(name, param.name))
-
+        if shape is None:
+            shape = param.shape
         assert isinstance(self.helper, LayerHelper)
         var = self.helper.create_global_variable(
             name=unique_name.generate(name),
             persistable=True,
             dtype=dtype or param.dtype,
             type=param.type,
-            shape=param.shape)
+            shape=shape)
         self.helper.set_variable_initializer(
             var, initializer=Constant(value=float(fill_value)))
         self._accumulators[name][param.name] = var
+        return var
 
     def _get_accumulator(self, name, param):
         """Utility function to fetch an accumulator for a parameter
@@ -797,3 +806,143 @@ def _append_optimize_op(self, block, param_and_grad):
 DecayedAdagrad = DecayedAdagradOptimizer
 Adadelta = AdadeltaOptimizer
 RMSProp = RMSPropOptimizer
+
+
+class ModelAverage(Optimizer):
+    """Accumulate the average of parameters within a sliding window. The
+    averaged result is saved in temporary variables and can be applied to the
+    parameter variables of the current model by calling the 'apply()' method.
+    The 'restore()' method is used to restore the parameter values of the
+    current model.
+
+    The size of the average window is determined by average_window_rate,
+    min_average_window, max_average_window and the current number of updates.
+
+    Args:
+        params_grads: A list of parameter-grad variable pairs.
+        average_window_rate: The rate of average window.
+        min_average_window: The minimum size of average window.
+        max_average_window: The maximum size of average window.
+
+    Examples:
+        ...
+        optimizer = fluid.optimizer.Momentum()
+        _, params_grads = optimizer.minimize(cost)
+        model_average = fluid.optimizer.ModelAverage(params_grads, 0.15,
+                                                     min_average_window=10000,
+                                                     max_average_window=20000)
+        for pass_id in range(args.pass_num):
+            for data in train_reader():
+                exe.run(fluid.default_main_program()...)
+
+            with model_average.apply(exe):
+                for data in test_reader():
+                    exe.run(inference_program...)
+ """ + + def __init__(self, + params_grads, + average_window_rate, + min_average_window=10000, + max_average_window=10000, + **kwargs): + super(ModelAverage, self).__init__(0.0, **kwargs) + self.average_window = average_window_rate + self.min_average_window = min_average_window + self.max_average_window = max_average_window + self.params_grads = params_grads + for param, grad in self.params_grads: + if grad is not None: + self._append_average_accumulate_op(param) + + self.apply_program = Program() + block = self.apply_program.global_block() + with program_guard(main_program=self.apply_program): + for param_grad in self.params_grads: + if param_grad[1] is not None: + self._add_average_apply_op(block, param_grad) + + self.restore_program = Program() + block = self.restore_program.global_block() + with program_guard(main_program=self.restore_program): + for param_grad in self.params_grads: + if param_grad[1] is not None: + self._add_average_restore_op(block, param_grad) + + def _add_average_apply_op(self, block, param_grad): + param = block.clone_variable(param_grad[0]) + grad = block.clone_variable(param_grad[1]) + sum_1 = block.clone_variable(self._get_accumulator('sum_1', param)) + sum_2 = block.clone_variable(self._get_accumulator('sum_2', param)) + sum_3 = block.clone_variable(self._get_accumulator('sum_3', param)) + num_accumulates = block.clone_variable( + self._get_accumulator('num_accumulates', param)) + old_num_accumulates = block.clone_variable( + self._get_accumulator('old_num_accumulates', param)) + num_updates = block.clone_variable( + self._get_accumulator('num_updates', param)) + # backup param value to grad + layers.assign(input=param, output=grad) + # param = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates) + tmp = layers.sum(x=[num_accumulates, old_num_accumulates]) + sum = layers.sum(x=[sum_1, sum_2, sum_3]) + tmp = layers.cast(x=tmp, dtype='float32') + sum = layers.cast(x=sum, dtype='float32') + layers.elementwise_div(x=sum, y=tmp, out=param) + + def _add_average_restore_op(self, block, param_grad): + param = block.clone_variable(param_grad[0]) + grad = block.clone_variable(param_grad[1]) + layers.assign(input=grad, output=param) + + def _append_average_accumulate_op(self, param): + self.helper = LayerHelper("average_accumulate") + sum_1 = self._add_accumulator('sum_1', param) + sum_2 = self._add_accumulator('sum_2', param) + sum_3 = self._add_accumulator('sum_3', param) + num_accumulates = self._add_accumulator( + 'num_accumulates', param, dtype='int64', shape=[1]) + old_num_accumulates = self._add_accumulator( + 'old_num_accumulates', param, dtype='int64', shape=[1]) + num_updates = self._add_accumulator( + 'num_updates', param, dtype='int64', shape=[1]) + + self.helper.append_op( + type='average_accumulates', + inputs={ + "param": param, + "in_sum_1": sum_1, + "in_sum_2": sum_2, + "in_sum_3": sum_3, + "in_num_accumulates": num_accumulates, + "in_old_num_accumulates": old_num_accumulates, + "in_num_updates": num_updates + }, + outputs={ + "out_sum_1": sum_1, + "out_sum_2": sum_2, + "out_sum_3": sum_3, + "out_num_accumulates": num_accumulates, + "out_old_num_accumulates": old_num_accumulates, + "out_num_updates": num_updates, + }, + attrs={ + "average_window": self.average_window, + "min_average_window": self.min_average_window, + "max_average_window": self.max_average_window, + }) + + @contextmanager + def apply(self, executor, need_restore=True): + """Apply average values to parameters of current model. 
+ """ + executor.run(self.apply_program) + try: + yield + finally: + if need_restore: + self.restore(executor) + + def restore(self, executor): + """Restore parameter values of current model. + """ + executor.run(self.restore_program)