Dygraph Recompute #32516

Merged
merged 3 commits on Apr 25, 2021
1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/utils/__init__.py
@@ -14,3 +14,4 @@

from .fs import LocalFS, HDFSClient
from .ps_util import DistributedInfer
from .recompute import recompute
177 changes: 177 additions & 0 deletions python/paddle/distributed/fleet/utils/recompute.py
@@ -0,0 +1,177 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid import core
from paddle.autograd import PyLayer
from paddle.fluid import framework
import contextlib

import logging
logging.basicConfig(
format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')


def detach_variable(inputs):
out = []
for inp in inputs:
if not isinstance(inp, core.VarBase):
out.append(inp)
continue

x = inp.detach()
x.stop_gradient = inp.stop_gradient
out.append(x)
return tuple(out)


def check_recompute_necessary(inputs):
    if not any(not input_.stop_gradient for input_ in inputs
               if isinstance(input_, paddle.Tensor)):
        logging.warning(
            "[Recompute]: None of the inputs to the current recompute block requires grad, "
            "so there is no need to recompute this block in the backward pass!")


@contextlib.contextmanager
def swith_rng_state(rng_state):
Review comment (Member): swith_rng_state -> switch_rng_state

orig_cuda_rng_state = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(rng_state)
try:
yield
finally:
paddle.set_cuda_rng_state(orig_cuda_rng_state)


class RecomputeFunction(PyLayer):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
check_recompute_necessary(args)

# store for recomputing
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state

        # NOTE the number of outputs of backward() should equal the number of tensors in forward()'s input;
        # the order of tensors in backward()'s output should match the order of tensors in forward()'s input.
        # Non-tensor inputs are filtered out of the backward inputs.

# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for i, arg in enumerate(args):
if paddle.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)

        # NOTE recompute with RNG state restoration only supports the scenario of one process per CUDA GPU;
        # one process with multiple GPUs and mixed GPU/CPU scenarios are not supported.
if ctx.preserve_rng_state:
cur_device = paddle.get_device()
if 'gpu:' not in cur_device:
                raise RuntimeError(
                    "Recompute with RNG state preservation does not support the current device: {}.".
                    format(cur_device))
ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()

# TODO support AMP

with paddle.no_grad():
outputs = run_function(*args)

return outputs

@staticmethod
def backward(ctx, *args):
with paddle.fluid.dygraph.guard():
            # TODO need to check whether the recompute call is valid or not

# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensor()
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]

# paddle.enable_grad()
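            # NOTE re-enable grad recording on the tracer so that the recomputed forward below builds a graph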
tracer = framework._dygraph_tracer()
tracer._has_grad = True

# TODO support AMP

if ctx.preserve_rng_state:
with swith_rng_state(ctx.fw_cuda_rng_state):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
else:
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, core.VarBase):
outputs = (outputs, )
assert len(outputs) == len(args)

# run backward() with only tensor that requires grad
forward_outputs_with_grad = []
backward_inputs = list(args)
for i in range(len(outputs)):
if isinstance(outputs[i],
core.VarBase) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
if len(forward_outputs_with_grad) == 0:
                raise RuntimeError(
                    "none of the outputs requires grad (stop_gradient=False), so this recompute() is not necessary"
                )

assert len(backward_inputs) == len(
forward_outputs_with_grad
), "number of forward outputs is [{}], but the backward got [{}] inputs".format(
len(forward_outputs_with_grad), len(backward_inputs))

# actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)

grads = list(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.VarBase))

return grads


def recompute(function, *args, **kwargs):
"""
recompute intermediate activations to save then memory.

Args:
function: layer of sequence of layers that describes part of forward pass of the model whose
intermediate activations will be released to save memory in forward stage and will be recomputed
in backward stage for gradient calculation.
preserve_rng_state(bool, optional): if preserve the RNG state of forward and restore it in backward.
args: inputs to the function

Returns:
Output of function on args
"""
# Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop('preserve_rng_state', True)
if kwargs:
raise ValueError("Unexpected keyword arguments: " + ",".join(
arg for arg in kwargs))

return RecomputeFunction.apply(function, preserve, *args)
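
A minimal usage sketch of the recompute API added above (the block, layer sizes, and tensor shapes are illustrative and not part of this PR; a GPU device is assumed because preserve_rng_state defaults to True and is only supported on CUDA):

import paddle
from paddle.distributed.fleet.utils import recompute

paddle.set_device("gpu")

# a small block whose intermediate activations are released in the forward stage
block = paddle.nn.Sequential(
    paddle.nn.Linear(16, 16),
    paddle.nn.ReLU(),
    paddle.nn.Linear(16, 1), )

x = paddle.randn([4, 16])
x.stop_gradient = False  # otherwise check_recompute_necessary() logs a warning
y = recompute(block, x)  # forward runs under paddle.no_grad(); only the inputs are saved
loss = y.mean()
loss.backward()          # RecomputeFunction.backward re-runs `block` to rebuild the graph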
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -174,6 +174,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel)
LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single)
LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
elseif(WITH_GPU)
if (${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
176 changes: 176 additions & 0 deletions python/paddle/fluid/tests/unittests/test_dygraph_recompute.py
@@ -0,0 +1,176 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils import recompute
import random

import paddle.fluid.layers as layers


def get_fc_block(block_idx, input_size, is_last=False):
block_name = "block_" + str(block_idx)
block = paddle.nn.Sequential(
(block_name + "_fc_0", paddle.nn.Linear(
input_size, input_size, bias_attr=False)),
(block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
(block_name + "_relu_1", paddle.nn.ReLU()),
(block_name + "_fc_1", paddle.nn.Linear(
input_size, input_size, bias_attr=False)),
(block_name + "_relu_2", paddle.nn.ReLU()), )
if is_last:
block.add_sublayer(
block_name + "_fc_2",
paddle.nn.Linear(
input_size, 1, bias_attr=False)) # add sublayer
else:
block.add_sublayer(
block_name + "_fc_2",
paddle.nn.Linear(
input_size, input_size, bias_attr=False)) # add sublayer
return block


class Naive_fc_net(paddle.nn.Layer):
def __init__(self,
input_size=10,
recompute_blocks=[1, 3],
recompute_kwargs={}):
super(Naive_fc_net, self).__init__()
self.recompute_blocks = recompute_blocks
self.recompute_kwargs = recompute_kwargs
self.runfunc0 = get_fc_block(0, input_size, is_last=False)
self.runfunc1 = get_fc_block(1, input_size, is_last=False)
self.runfunc2 = get_fc_block(2, input_size, is_last=False)
self.runfunc3 = get_fc_block(3, input_size, is_last=False)
self.runfunc4 = get_fc_block(4, input_size, is_last=True)

def forward(self, inputs):

if 0 in self.recompute_blocks:
inputs = recompute(self.runfunc0, inputs)
else:
inputs = self.runfunc0(inputs)

if 1 in self.recompute_blocks:
inputs = recompute(self.runfunc1, inputs)
else:
inputs = self.runfunc1(inputs)

if 2 in self.recompute_blocks:
inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs)
else:
inputs = self.runfunc2(inputs)

if 3 in self.recompute_blocks:
inputs = recompute(self.runfunc3, inputs)
else:
inputs = self.runfunc3(inputs)

if 4 in self.recompute_blocks:
inputs = recompute(self.runfunc4, inputs)
else:
inputs = self.runfunc4(inputs)

return inputs


def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
gen = paddle.seed(10)
gen.manual_seed(10)
np.random.seed(10)
random.seed(10)

if cuda_state:
paddle.set_cuda_rng_state(cuda_state)

batch_size, input_size = 1, 10
model = Naive_fc_net(
input_size,
recompute_blocks=recompute_block,
recompute_kwargs=recompute_kwargs)
loss_fn = paddle.nn.MSELoss(reduction='mean')
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
parameters=model.parameters())

loss_ = []
param_ = []
grad_ = []
for step in range(10):
x_data = np.random.randn(batch_size, input_size).astype(np.float32)
x = paddle.to_tensor(x_data)
# x.stop_gradient = False
y_pred = model(x)
loss = y_pred.mean()

loss_.append(np.asarray(loss).tolist())
loss.backward()
optimizer.step()

param_.append(np.asarray(model.parameters()[9]).tolist())
grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())

optimizer.clear_grad()
return loss_, param_, grad_


class TestPyLayer(unittest.TestCase):
def test_fc_net_with_dropout(self):
def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
self.assertEqual(loss_ref, loss)
self.assertEqual(param_ref, param)
self.assertEqual(grad_ref, grad)

cuda_state = paddle.get_cuda_rng_state()
# without recompute
loss_ref, param_ref, grad_ref = run_model(
cuda_state, recompute_block=[])

        # recompute second & fourth block
loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

# recompute fourth block
loss, param, grad = run_model(cuda_state, recompute_block=[3])
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

# recompute second to fourth block
loss, param, grad = run_model(cuda_state, recompute_block=[1, 2, 3])
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

# recompute second & fourth block
loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

def test_recompute_kwargs(self):
paddle.set_device("gpu")
kwargs = {"is_test": False}
with self.assertRaises(ValueError):
loss_ref, param_ref, grad_ref = run_model(
None, recompute_block=[2], recompute_kwargs=kwargs)

def test_recompute_cpu_rng(self):
paddle.set_device("cpu")
with self.assertRaises(RuntimeError):
loss_ref, param_ref, grad_ref = run_model(None, recompute_block=[2])


if __name__ == '__main__':
unittest.main()