diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index 6929ca52cb013..e1d2db328d179 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -41,16 +41,23 @@ def detach_variable(inputs): def check_recompute_necessary(inputs): - if not any(input_.stop_gradient == False for input_ in inputs - if isinstance(input_, (core.eager.Tensor, paddle.Tensor))): + if not any( + input_.stop_gradient == False + for input_ in inputs + if isinstance(input_, (core.eager.Tensor, paddle.Tensor)) + ): logger.warning( "[Recompute]: None of the inputs to current recompute block need grad, " - "therefore there is NO need to recompute this block in backward !") + "therefore there is NO need to recompute this block in backward !" + ) @contextlib.contextmanager def swith_rng_state_tracker(rng_state, tracker): - from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker + from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( + get_rng_state_tracker, + ) + orig_cuda_rng_state = paddle.get_cuda_rng_state() orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker() @@ -64,10 +71,11 @@ def swith_rng_state_tracker(rng_state, tracker): class LegacyRecomputeFunction(LegacyPyLayer): - @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): - from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker + from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( + get_rng_state_tracker, + ) # store for recomputing ctx.run_function = run_function @@ -96,30 +104,37 @@ def forward(ctx, run_function, preserve_rng_state, *args): cur_device = paddle.get_device() if 'gpu:' not in cur_device: raise RuntimeError( - "Recompute with RNG perserve is not support current device: {}." - .format(cur_device)) + "Recompute with RNG perserve is not support current device: {}.".format( + cur_device + ) + ) ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( - ).get_states_tracker() + ctx.fwd_cuda_rng_state_tracker = ( + get_rng_state_tracker().get_states_tracker() + ) # TODO support AMP tracer = framework._dygraph_tracer() - ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + ctx.is_fw_autocast = ( + False if tracer._amp_level == core.AmpLevel.O0 else True + ) if tracer._amp_level == core.AmpLevel.O2: ctx.amp_level = 'O2' elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): ctx.amp_level = 'O1' else: - raise ValueError("unsupported amp level: {}".format( - tracer._amp_level)) + raise ValueError( + "unsupported amp level: {}".format(tracer._amp_level) + ) if tracer._amp_dtype == 'float16': ctx.amp_dtype = 'float16' elif tracer._amp_dtype in ('bfloat16', 'float32'): ctx.amp_dtype = 'bfloat16' else: - raise ValueError("unsupported amp dtype: {}".format( - tracer._amp_dtype)) + raise ValueError( + "unsupported amp dtype: {}".format(tracer._amp_dtype) + ) ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() @@ -129,7 +144,10 @@ def forward(ctx, run_function, preserve_rng_state, *args): @staticmethod def backward(ctx, *args): - from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker + from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( + get_rng_state_tracker, + ) + with paddle.fluid.dygraph.guard(): # TODO need to check the recompute calling is vaild or not @@ -147,27 +165,31 @@ def backward(ctx, *args): # NOTE support AMP # need restore auto_cast state as well as w/b list if ctx.preserve_rng_state: - with swith_rng_state_tracker(ctx.fw_cuda_rng_state, - ctx.fwd_cuda_rng_state_tracker): + with swith_rng_state_tracker( + ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker + ): with paddle.amp.auto_cast( - enable=ctx.is_fw_autocast, - custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list, - level=ctx.amp_level, - dtype=ctx.amp_dtype): + enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level, + dtype=ctx.amp_dtype, + ): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) else: - with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, - custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list, - level=ctx.amp_level, - dtype=ctx.amp_dtype): + with paddle.amp.auto_cast( + enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level, + dtype=ctx.amp_dtype, + ): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, core.VarBase): - outputs = (outputs, ) + outputs = (outputs,) assert len(outputs) == len(args) # run backward() with only tensor that requires grad @@ -178,8 +200,10 @@ def backward(ctx, *args): # the following backward_inputs_with_grad is used to avoid this case. backward_inputs_with_grad = [] for i in range(len(outputs)): - if isinstance(outputs[i], - core.VarBase) and not outputs[i].stop_gradient: + if ( + isinstance(outputs[i], core.VarBase) + and not outputs[i].stop_gradient + ): forward_outputs_with_grad.append(outputs[i]) backward_inputs_with_grad.append(args[i]) @@ -190,19 +214,24 @@ def backward(ctx, *args): # actually backward with paddle.amp.auto_cast(enable=False): - paddle.autograd.backward(forward_outputs_with_grad, - backward_inputs_with_grad) + paddle.autograd.backward( + forward_outputs_with_grad, backward_inputs_with_grad + ) - grads = list(inp._grad_ivar() for inp in detached_inputs - if isinstance(inp, core.VarBase)) + grads = list( + inp._grad_ivar() + for inp in detached_inputs + if isinstance(inp, core.VarBase) + ) return grads class RecomputeFunction(PyLayer): - @staticmethod def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): - from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker + from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( + get_rng_state_tracker, + ) # store for recomputing ctx.run_function = run_function @@ -232,30 +261,37 @@ def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): cur_device = paddle.get_device() if 'gpu:' not in cur_device: raise RuntimeError( - "Recompute with RNG perserve is not support current device: {}." - .format(cur_device)) + "Recompute with RNG perserve is not support current device: {}.".format( + cur_device + ) + ) ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( - ).get_states_tracker() + ctx.fwd_cuda_rng_state_tracker = ( + get_rng_state_tracker().get_states_tracker() + ) # TODO support AMP tracer = framework._dygraph_tracer() - ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + ctx.is_fw_autocast = ( + False if tracer._amp_level == core.AmpLevel.O0 else True + ) if tracer._amp_level == core.AmpLevel.O2: ctx.amp_level = 'O2' elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): ctx.amp_level = 'O1' else: - raise ValueError("unsupported amp level: {}".format( - tracer._amp_level)) + raise ValueError( + "unsupported amp level: {}".format(tracer._amp_level) + ) if tracer._amp_dtype == 'float16': ctx.amp_dtype = 'float16' elif tracer._amp_dtype in ('bfloat16', 'float32'): ctx.amp_dtype = 'bfloat16' else: - raise ValueError("unsupported amp dtype: {}".format( - tracer._amp_dtype)) + raise ValueError( + "unsupported amp dtype: {}".format(tracer._amp_dtype) + ) ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() @@ -265,7 +301,10 @@ def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): @staticmethod def backward(ctx, *args): - from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker + from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( + get_rng_state_tracker, + ) + with paddle.fluid.dygraph.guard(): # TODO need to check the recompute calling is vaild or not @@ -283,28 +322,33 @@ def backward(ctx, *args): # NOTE support AMP # need restore auto_cast state as well as w/b list if ctx.preserve_rng_state: - with swith_rng_state_tracker(ctx.fw_cuda_rng_state, - ctx.fwd_cuda_rng_state_tracker): + with swith_rng_state_tracker( + ctx.fw_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker + ): with paddle.amp.auto_cast( - enable=ctx.is_fw_autocast, - custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list, - level=ctx.amp_level, - dtype=ctx.amp_dtype): + enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level, + dtype=ctx.amp_dtype, + ): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs, - **ctx.kwargs) + outputs = ctx.run_function( + *detached_inputs, **ctx.kwargs + ) else: - with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, - custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list, - level=ctx.amp_level, - dtype=ctx.amp_dtype): + with paddle.amp.auto_cast( + enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level, + dtype=ctx.amp_dtype, + ): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): - outputs = (outputs, ) + outputs = (outputs,) assert len(outputs) == len(args) # run backward() with only tensor that requires grad @@ -315,10 +359,10 @@ def backward(ctx, *args): # the following backward_inputs_with_grad is used to avoid this case. backward_inputs_with_grad = [] for i in range(len(outputs)): - if isinstance( - outputs[i], - (core.VarBase, - core.eager.Tensor)) and not outputs[i].stop_gradient: + if ( + isinstance(outputs[i], (core.VarBase, core.eager.Tensor)) + and not outputs[i].stop_gradient + ): forward_outputs_with_grad.append(outputs[i]) backward_inputs_with_grad.append(args[i]) @@ -329,17 +373,22 @@ def backward(ctx, *args): # actually backward with paddle.amp.auto_cast(enable=False): - paddle.autograd.backward(forward_outputs_with_grad, - backward_inputs_with_grad) + paddle.autograd.backward( + forward_outputs_with_grad, backward_inputs_with_grad + ) if in_dygraph_mode(): grads = tuple( - inp._grad_ivar() for inp in detached_inputs - if isinstance(inp, (core.VarBase, core.eager.Tensor))) + inp._grad_ivar() + for inp in detached_inputs + if isinstance(inp, (core.VarBase, core.eager.Tensor)) + ) else: grads = list( - inp._grad_ivar() for inp in detached_inputs - if isinstance(inp, (core.VarBase, core.eager.Tensor))) + inp._grad_ivar() + for inp in detached_inputs + if isinstance(inp, (core.VarBase, core.eager.Tensor)) + ) return grads @@ -363,13 +412,10 @@ def recompute(function, *args, **kwargs): Examples: .. code-block:: python - import numpy as np import paddle from paddle.distributed.fleet.utils import recompute import random - # required: gpu - def get_fc_block(block_idx, input_size, is_last=False): block_name = "block_" + str(block_idx) block = paddle.nn.Sequential( @@ -391,10 +437,7 @@ def get_fc_block(block_idx, input_size, is_last=False): block_name + "_fc_2", paddle.nn.Linear(input_size, input_size, bias_attr=False) ) - return block - - class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], @@ -408,7 +451,6 @@ def __init__(self, input_size=10, self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4] - def forward(self, inputs): nums = len(self.total_func) for i in range(nums): @@ -417,15 +459,12 @@ def forward(self, inputs): else: inputs = self.total_func[i](inputs) return inputs - def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): gen = paddle.seed(10) gen.manual_seed(10) - np.random.seed(10) random.seed(10) if cuda_state: paddle.set_cuda_rng_state(cuda_state) - batch_size, input_size = 1, 10 model = Naive_fc_net( input_size, @@ -436,29 +475,24 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): param_ = [] grad_ = [] for _ in range(5): - x_data = np.random.randn(batch_size, input_size).astype(np.float32) - x = paddle.to_tensor(x_data) + x = paddle.rand(shape=[batch_size, input_size], dtype="float32") y_pred = model(x) loss = y_pred.mean() - loss_.append(np.asarray(loss).tolist()) + loss_.append(loss.item()) loss.backward() optimizer.step() - param_.append(np.asarray(model.parameters()[9]).tolist()) - grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist()) + param_.append(model.parameters()[9]) + grad_.append(model.parameters()[3]._grad_ivar()) optimizer.clear_grad() - return loss_, param_, grad_ - cuda_state = paddle.get_cuda_rng_state() # without recompute loss_ref, param_ref, grad_ref = run_model( cuda_state, recompute_block=[] ) - loss, param, grad = run_model(cuda_state, recompute_block=[1, 2]) print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss)) # The result of the recompute_loss should be the same as the normal_loss. - """ # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop('preserve_rng_state', True) @@ -497,7 +531,6 @@ def recompute_sequential(ctx, functions, *args, **kwargs): preserve_rng_state = ctx.get('preserve_rng_state', True) def _run_func(begin, end, funcs): - def do_run(input): for i in range(begin, end + 1): input = funcs[i](input) @@ -513,8 +546,10 @@ def do_run(input): end = -1 for begin in range(0, segment_size * (segments - 1), segment_size): end = begin + segment_size - 1 - args = recompute(_run_func(begin, end, functions), - *args, - preserve_rng_state=preserve_rng_state, - **kwargs) + args = recompute( + _run_func(begin, end, functions), + *args, + preserve_rng_state=preserve_rng_state, + **kwargs + ) return _run_func(end + 1, len(functions) - 1, functions)(args) diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 93fc890d05af5..30afae2b432e5 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -22,14 +22,108 @@ from . import log_util # noqa: F401 from . import hybrid_parallel_util # noqa: F401 -__all__ = [ #noqa - "LocalFS", "recompute", "DistributedInfer", "HDFSClient" -] +__all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"] # noqa -@deprecated(since="2.4.0", - update_to="paddle.distributed.fleet.recompute", - level=1, - reason="Please use new recompute API(fleet.recompute) ") def recompute(function, *args, **kwargs): + """ + recompute intermediate activations to save then memory. + Parameters: + function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model + whose intermediate activations will be released to save memory in forward stage and will be recomputed + in backward stage for gradient calculation. + *args(Tensor): inputs to the function. + **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to + indicate whether to save the forward rng. If it is True, then the last forward rng value will be + restored when the forward recalculation of backpropagation is performed. The default + preserve_rng_state is True. + Returns: + Output of function on args. + + Examples: + .. code-block:: python + + import paddle + from paddle.distributed.fleet.utils import recompute + import random + # required: gpu + def get_fc_block(block_idx, input_size, is_last=False): + block_name = "block_" + str(block_idx) + block = paddle.nn.Sequential( + (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), + (block_name + "_relu_1", paddle.nn.ReLU()), + (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_relu_2", paddle.nn.ReLU()), + ) + if is_last: + block.add_sublayer( + block_name + "_fc_2", + paddle.nn.Linear( + input_size, 1, bias_attr=False + ) + ) + else: + block.add_sublayer( + block_name + "_fc_2", + paddle.nn.Linear(input_size, input_size, bias_attr=False) + ) + return block + class Naive_fc_net(paddle.nn.Layer): + def __init__(self, input_size=10, + recompute_blocks=[1, 3], + recompute_kwargs={}): + super(Naive_fc_net, self).__init__() + self.recompute_blocks = recompute_blocks + self.recompute_kwargs = recompute_kwargs + self.runfunc0 = get_fc_block(0, input_size, is_last=False) + self.runfunc1 = get_fc_block(1, input_size, is_last=False) + self.runfunc2 = get_fc_block(2, input_size, is_last=False) + self.runfunc3 = get_fc_block(3, input_size, is_last=False) + self.runfunc4 = get_fc_block(4, input_size, is_last=True) + self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4] + def forward(self, inputs): + nums = len(self.total_func) + for i in range(nums): + if i in self.recompute_blocks: + inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True}) + else: + inputs = self.total_func[i](inputs) + return inputs + def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): + gen = paddle.seed(10) + gen.manual_seed(10) + random.seed(10) + if cuda_state: + paddle.set_cuda_rng_state(cuda_state) + batch_size, input_size = 1, 10 + model = Naive_fc_net( + input_size, + recompute_blocks=recompute_block, + recompute_kwargs=recompute_kwargs) + optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) + loss_ = [] + param_ = [] + grad_ = [] + for _ in range(5): + x = paddle.rand(shape=[batch_size, input_size], dtype="float32") + y_pred = model(x) + loss = y_pred.mean() + loss_.append(loss.item()) + loss.backward() + optimizer.step() + param_.append(model.parameters()[9]) + grad_.append(model.parameters()[3]._grad_ivar()) + optimizer.clear_grad() + return loss_, param_, grad_ + cuda_state = paddle.get_cuda_rng_state() + # without recompute + loss_ref, param_ref, grad_ref = run_model( + cuda_state, recompute_block=[] + ) + loss, param, grad = run_model(cuda_state, recompute_block=[1, 2]) + print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss)) + # The result of the recompute_loss should be the same as the normal_loss. + """ + return fleet.recompute.recompute(function, *args, **kwargs)