New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dygraph Recompute #32516
Merged
Merged
Dygraph Recompute #32516
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,4 @@ | |
|
||
from .fs import LocalFS, HDFSClient | ||
from .ps_util import DistributedInfer | ||
from .recompute import recompute |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
from paddle.fluid import core | ||
from paddle.autograd import PyLayer | ||
from paddle.fluid import framework | ||
import contextlib | ||
|
||
import logging | ||
logging.basicConfig( | ||
format='%(asctime)s %(levelname)-8s %(message)s', | ||
datefmt='%Y-%m-%d %H:%M:%S') | ||
|
||
|
||
def detach_variable(inputs): | ||
out = [] | ||
for inp in inputs: | ||
if not isinstance(inp, core.VarBase): | ||
out.append(inp) | ||
continue | ||
|
||
x = inp.detach() | ||
x.stop_gradient = inp.stop_gradient | ||
out.append(x) | ||
return tuple(out) | ||
|
||
|
||
def check_recompute_necessary(inputs): | ||
if not any(input_.stop_gradient == False for input_ in inputs | ||
if isinstance(input_, paddle.Tensor)): | ||
logging.warn( | ||
"[Recompute]: None of the inputs to current recompute block need grad, " | ||
"therefore there is NO need to recompute this block in backward !") | ||
|
||
|
||
@contextlib.contextmanager | ||
def swith_rng_state(rng_state): | ||
orig_cuda_rng_state = paddle.get_cuda_rng_state() | ||
paddle.set_cuda_rng_state(rng_state) | ||
try: | ||
yield | ||
finally: | ||
paddle.set_cuda_rng_state(orig_cuda_rng_state) | ||
|
||
|
||
class RecomputeFunction(PyLayer): | ||
@staticmethod | ||
def forward(ctx, run_function, preserve_rng_state, *args): | ||
check_recompute_necessary(args) | ||
|
||
# store for recomputing | ||
ctx.run_function = run_function | ||
ctx.preserve_rng_state = preserve_rng_state | ||
|
||
# NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input | ||
# the order of tensors in backward()'s output should be the same as tensors in forward()'s input | ||
# None tensor inputs will be filtered in backward inputs. | ||
|
||
# save input for backward | ||
ctx.inputs = [] | ||
ctx.tensor_indices = [] | ||
tensor_inputs = [] | ||
for i, arg in enumerate(args): | ||
if paddle.is_tensor(arg): | ||
tensor_inputs.append(arg) | ||
ctx.tensor_indices.append(i) | ||
ctx.inputs.append(None) | ||
else: | ||
ctx.inputs.append(arg) | ||
ctx.save_for_backward(*tensor_inputs) | ||
|
||
# NOTE recompute with restore RNG only support one senario where one process for one cuda gpu. | ||
# one process with multiple gpu and mix-gpu-cpu senarios are not support | ||
if ctx.preserve_rng_state: | ||
cur_device = paddle.get_device() | ||
if 'gpu:' not in cur_device: | ||
raise RuntimeError( | ||
"Recompute with RNG perserve is not support current device: {}.". | ||
format(cur_device)) | ||
ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state() | ||
|
||
# TODO support AMP | ||
|
||
with paddle.no_grad(): | ||
outputs = run_function(*args) | ||
|
||
return outputs | ||
|
||
@staticmethod | ||
def backward(ctx, *args): | ||
with paddle.fluid.dygraph.guard(): | ||
# TODO need to check the recompute calling is vaild or not | ||
|
||
# Restore inputs | ||
inputs = list(ctx.inputs) | ||
tensor_indices = ctx.tensor_indices | ||
tensors = ctx.saved_tensor() | ||
for i, idx in enumerate(tensor_indices): | ||
inputs[idx] = tensors[i] | ||
|
||
# paddle.enable_grad() | ||
tracer = framework._dygraph_tracer() | ||
tracer._has_grad = True | ||
|
||
# TODO support AMP | ||
|
||
if ctx.preserve_rng_state: | ||
with swith_rng_state(ctx.fw_cuda_rng_state): | ||
detached_inputs = detach_variable(tuple(inputs)) | ||
outputs = ctx.run_function(*detached_inputs) | ||
else: | ||
detached_inputs = detach_variable(tuple(inputs)) | ||
outputs = ctx.run_function(*detached_inputs) | ||
|
||
if isinstance(outputs, core.VarBase): | ||
outputs = (outputs, ) | ||
assert len(outputs) == len(args) | ||
|
||
# run backward() with only tensor that requires grad | ||
forward_outputs_with_grad = [] | ||
backward_inputs = list(args) | ||
for i in range(len(outputs)): | ||
if isinstance(outputs[i], | ||
core.VarBase) and not outputs[i].stop_gradient: | ||
forward_outputs_with_grad.append(outputs[i]) | ||
if len(forward_outputs_with_grad) == 0: | ||
raise RuntimeError( | ||
"none of output has requires_grad=True, this recompute() is not necessary" | ||
) | ||
|
||
assert len(backward_inputs) == len( | ||
forward_outputs_with_grad | ||
), "number of forward outputs is [{}], but the backward got [{}] inputs".format( | ||
len(forward_outputs_with_grad), len(backward_inputs)) | ||
|
||
# actually backward | ||
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) | ||
|
||
grads = list(inp._grad_ivar() for inp in detached_inputs | ||
if isinstance(inp, core.VarBase)) | ||
|
||
return grads | ||
|
||
|
||
def recompute(function, *args, **kwargs): | ||
""" | ||
recompute intermediate activations to save then memory. | ||
|
||
Args: | ||
function: layer of sequence of layers that describes part of forward pass of the model whose | ||
intermediate activations will be released to save memory in forward stage and will be recomputed | ||
in backward stage for gradient calculation. | ||
preserve_rng_state(bool, optional): if preserve the RNG state of forward and restore it in backward. | ||
args: inputs to the function | ||
|
||
Returns: | ||
Output of function on args | ||
""" | ||
# Hack to mix *args with **kwargs in a python 2.7-compliant way | ||
preserve = kwargs.pop('preserve_rng_state', True) | ||
if kwargs: | ||
raise ValueError("Unexpected keyword arguments: " + ",".join( | ||
arg for arg in kwargs)) | ||
|
||
return RecomputeFunction.apply(function, preserve, *args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
176 changes: 176 additions & 0 deletions
176
python/paddle/fluid/tests/unittests/test_dygraph_recompute.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import unittest | ||
import numpy as np | ||
|
||
import paddle | ||
from paddle.autograd import PyLayer | ||
from paddle.distributed.fleet.utils import recompute | ||
import random | ||
|
||
import paddle.fluid.layers as layers | ||
|
||
|
||
def get_fc_block(block_idx, input_size, is_last=False): | ||
block_name = "block_" + str(block_idx) | ||
block = paddle.nn.Sequential( | ||
(block_name + "_fc_0", paddle.nn.Linear( | ||
input_size, input_size, bias_attr=False)), | ||
(block_name + "_dropout", paddle.nn.Dropout(p=0.5)), | ||
(block_name + "_relu_1", paddle.nn.ReLU()), | ||
(block_name + "_fc_1", paddle.nn.Linear( | ||
input_size, input_size, bias_attr=False)), | ||
(block_name + "_relu_2", paddle.nn.ReLU()), ) | ||
if is_last: | ||
block.add_sublayer( | ||
block_name + "_fc_2", | ||
paddle.nn.Linear( | ||
input_size, 1, bias_attr=False)) # add sublayer | ||
else: | ||
block.add_sublayer( | ||
block_name + "_fc_2", | ||
paddle.nn.Linear( | ||
input_size, input_size, bias_attr=False)) # add sublayer | ||
return block | ||
|
||
|
||
class Naive_fc_net(paddle.nn.Layer): | ||
def __init__(self, | ||
input_size=10, | ||
recompute_blocks=[1, 3], | ||
recompute_kwargs={}): | ||
super(Naive_fc_net, self).__init__() | ||
self.recompute_blocks = recompute_blocks | ||
self.recompute_kwargs = recompute_kwargs | ||
self.runfunc0 = get_fc_block(0, input_size, is_last=False) | ||
self.runfunc1 = get_fc_block(1, input_size, is_last=False) | ||
self.runfunc2 = get_fc_block(2, input_size, is_last=False) | ||
self.runfunc3 = get_fc_block(3, input_size, is_last=False) | ||
self.runfunc4 = get_fc_block(4, input_size, is_last=True) | ||
|
||
def forward(self, inputs): | ||
|
||
if 0 in self.recompute_blocks: | ||
inputs = recompute(self.runfunc0, inputs) | ||
else: | ||
inputs = self.runfunc0(inputs) | ||
|
||
if 1 in self.recompute_blocks: | ||
inputs = recompute(self.runfunc1, inputs) | ||
else: | ||
inputs = self.runfunc1(inputs) | ||
|
||
if 2 in self.recompute_blocks: | ||
inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) | ||
else: | ||
inputs = self.runfunc2(inputs) | ||
|
||
if 3 in self.recompute_blocks: | ||
inputs = recompute(self.runfunc3, inputs) | ||
else: | ||
inputs = self.runfunc3(inputs) | ||
|
||
if 4 in self.recompute_blocks: | ||
inputs = recompute(self.runfunc4, inputs) | ||
else: | ||
inputs = self.runfunc4(inputs) | ||
|
||
return inputs | ||
|
||
|
||
def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): | ||
gen = paddle.seed(10) | ||
gen.manual_seed(10) | ||
np.random.seed(10) | ||
random.seed(10) | ||
|
||
if cuda_state: | ||
paddle.set_cuda_rng_state(cuda_state) | ||
|
||
batch_size, input_size = 1, 10 | ||
model = Naive_fc_net( | ||
input_size, | ||
recompute_blocks=recompute_block, | ||
recompute_kwargs=recompute_kwargs) | ||
loss_fn = paddle.nn.MSELoss(reduction='mean') | ||
optimizer = paddle.optimizer.SGD(learning_rate=0.01, | ||
parameters=model.parameters()) | ||
|
||
loss_ = [] | ||
param_ = [] | ||
grad_ = [] | ||
for step in range(10): | ||
x_data = np.random.randn(batch_size, input_size).astype(np.float32) | ||
x = paddle.to_tensor(x_data) | ||
# x.stop_gradient = False | ||
y_pred = model(x) | ||
loss = y_pred.mean() | ||
|
||
loss_.append(np.asarray(loss).tolist()) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
param_.append(np.asarray(model.parameters()[9]).tolist()) | ||
grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist()) | ||
|
||
optimizer.clear_grad() | ||
return loss_, param_, grad_ | ||
|
||
|
||
class TestPyLayer(unittest.TestCase): | ||
def test_fc_net_with_dropout(self): | ||
def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): | ||
self.assertEqual(loss_ref, loss) | ||
self.assertEqual(param_ref, param) | ||
self.assertEqual(grad_ref, grad) | ||
|
||
cuda_state = paddle.get_cuda_rng_state() | ||
# without recompute | ||
loss_ref, param_ref, grad_ref = run_model( | ||
cuda_state, recompute_block=[]) | ||
|
||
# recompute second block | ||
loss, param, grad = run_model(cuda_state, recompute_block=[1, 3]) | ||
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) | ||
|
||
# recompute fourth block | ||
loss, param, grad = run_model(cuda_state, recompute_block=[3]) | ||
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) | ||
|
||
# recompute second to fourth block | ||
loss, param, grad = run_model(cuda_state, recompute_block=[1, 2, 3]) | ||
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) | ||
|
||
# recompute second & fourth block | ||
loss, param, grad = run_model(cuda_state, recompute_block=[1, 3]) | ||
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) | ||
|
||
def test_recompute_kwargs(self): | ||
paddle.set_device("gpu") | ||
kwargs = {"is_test": False} | ||
with self.assertRaises(ValueError): | ||
loss_ref, param_ref, grad_ref = run_model( | ||
None, recompute_block=[2], recompute_kwargs=kwargs) | ||
|
||
def test_recompute_cpu_rng(self): | ||
paddle.set_device("cpu") | ||
with self.assertRaises(RuntimeError): | ||
loss_ref, param_ref, grad_ref = run_model(None, recompute_block=[2]) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
swith_rng_state -> switch_rng_state