Revert no_grad changes and add new implementation (#26902)
willthefrog committed Sep 2, 2020
1 parent dd28cad commit 89ef291
Showing 9 changed files with 138 additions and 39 deletions.
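
The user-visible effect of the revert, before the per-file diffs: framework code goes back to applying `no_grad` as a bare decorator (no parentheses), while the public `paddle.no_grad` is re-pointed at the class-based `no_grad_`, which is written with parentheses and can also be used as a context manager. A minimal sketch of the two call styles after this commit (the toy functions are hypothetical; only the import paths come from the diff):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.disable_static()

    # fluid-side API after this commit: a plain function, applied without parentheses.
    @fluid.dygraph.no_grad
    def fluid_style(x):
        return x * 2

    # Public API after this commit: paddle.no_grad is the no_grad_ class, so it is
    # instantiated with parentheses, whether used as a decorator or a context.
    @paddle.no_grad()
    def paddle_style(x):
        return x * 3

    t = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    a = fluid_style(t)      # a.stop_gradient is expected to be True
    with paddle.no_grad():  # context-manager form of the class
        b = paddle_style(t)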
2 changes: 1 addition & 1 deletion python/paddle/__init__.py
@@ -259,7 +259,7 @@
from .fluid.dygraph.base import enable_dygraph as disable_static #DEFINE_ALIAS
from .fluid.dygraph.base import disable_dygraph as enable_static #DEFINE_ALIAS
from .fluid.framework import in_dygraph_mode as in_dynamic_mode #DEFINE_ALIAS
from .fluid.dygraph.base import no_grad #DEFINE_ALIAS
from .fluid.dygraph.base import no_grad_ as no_grad #DEFINE_ALIAS

from . import jit
from . import static
8 changes: 4 additions & 4 deletions python/paddle/fluid/clip.py
@@ -129,7 +129,7 @@ def __init__(self, need_clip=None):
    def __str__(self):
        raise NotImplementedError()

    @imperative_base.no_grad()
    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        raise NotImplementedError

@@ -258,7 +258,7 @@ def __init__(self, max, min=None, need_clip=None):
    def __str__(self):
        return "Gradient Clip By Value, min = %f, max=%f" % (self.min, self.max)

    @imperative_base.no_grad()
    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        params_and_grads = []
        for p, g in params_grads:
@@ -413,7 +413,7 @@ def __init__(self, clip_norm, need_clip=None):
    def __str__(self):
        return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm

    @imperative_base.no_grad()
    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        params_and_grads = []
        for p, g in params_grads:
@@ -565,7 +565,7 @@ def __init__(self, clip_norm, group_name="default_group", need_clip=None):
    def __str__(self):
        return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)

    @imperative_base.no_grad()
    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        params_and_grads = []
        sum_square_list = []
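Each `_dygraph_clip` above is now wrapped with the bare `@imperative_base.no_grad`, so the temporary tensors it creates while scaling gradients stay out of the autograd trace. As a rough, framework-free sketch of what the global-norm variant computes (NumPy only; the helper name is made up and this is not the actual implementation):

    import numpy as np

    def clip_by_global_norm(params_grads, clip_norm):
        # Global-norm clipping: compute the joint norm of all gradients, then
        # scale every gradient by clip_norm / max(global_norm, clip_norm).
        sum_square = sum(
            float(np.sum(np.square(g))) for _, g in params_grads if g is not None)
        global_norm = np.sqrt(sum_square)
        scale = clip_norm / max(global_norm, clip_norm)
        return [(p, None if g is None else g * scale) for p, g in params_grads]

    # Plain arrays standing in for parameters and gradients.
    pg = [("w", np.array([3.0, 4.0])), ("b", np.array([0.0]))]
    clipped = clip_by_global_norm(pg, clip_norm=1.0)  # norm 5.0 -> scaled down to 1.0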
82 changes: 78 additions & 4 deletions python/paddle/fluid/dygraph/base.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import inspect
import decorator
import contextlib
import functools
import inspect
import sys
import numpy as np
from paddle.fluid import core
@@ -26,8 +27,8 @@
from ..data_feeder import convert_dtype

__all__ = [
    'no_grad', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph', 'enabled',
    'to_variable'
    'no_grad', 'no_grad_', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph',
    'enabled', 'to_variable'
]


@@ -167,7 +168,80 @@ def disable_dygraph():
_functional_dygraph_context_manager = None


class no_grad:
@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
    tracer = framework._dygraph_tracer()
    if tracer:
        mode = tracer._train_mode
        tracer._train_mode = is_train
        try:
            yield
        finally:
            tracer._train_mode = mode
    else:
        yield


def no_grad(func=None):
    """
    :api_attr: imperative

    Create a context which disables dygraph gradient calculation.
    In this mode, the result of every computation will have `stop_gradient=True`.

    Also functions as a decorator. (Make sure to instantiate without parenthesis.)

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        # use as generator

        data = np.array([[2, 3], [4, 5]]).astype('float32')
        with fluid.dygraph.guard():
            l0 = fluid.Linear(2, 2)  # l0.weight.gradient() is None
            l1 = fluid.Linear(2, 2)
            with fluid.dygraph.no_grad():
                # l1.weight.stop_gradient is False
                tmp = l1.weight * 2  # tmp.stop_gradient is True
            x = fluid.dygraph.to_variable(data)
            y = l0(x) + tmp
            o = l1(y)
            o.backward()
            print(tmp.gradient() is None)  # True
            print(l0.weight.gradient() is None)  # False

        # use as decorator

        @fluid.dygraph.no_grad
        def test_layer():
            with fluid.dygraph.guard():
                inp = np.ones([3, 1024], dtype='float32')
                t = fluid.dygraph.base.to_variable(inp)
                linear1 = fluid.Linear(1024, 4, bias_attr=False)
                linear2 = fluid.Linear(4, 4)
                ret = linear1(t)
                dy_ret = linear2(ret)

        test_layer()

    """
    if func is None:
        return _switch_tracer_mode_guard_(is_train=False)
    else:

        @decorator.decorator
        def __impl__(func, *args, **kwargs):
            with _switch_tracer_mode_guard_(is_train=False):
                return func(*args, **kwargs)

        return __impl__(func)


class no_grad_:
    """
    :api_attr: imperative
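The `no_grad` function added above follows a common dual-use pattern: called with no argument it returns a context manager, and applied directly to a callable it returns a wrapped version that runs inside that context (the real code routes the wrapper through `decorator.decorator` so the original signature survives, which the tests below check with `inspect.getargspec`). A framework-free sketch of the same pattern, using a module-level flag in place of Paddle's tracer state:

    import contextlib
    import functools

    _grad_enabled = True  # stand-in for tracer._train_mode in this sketch

    @contextlib.contextmanager
    def _no_grad_guard():
        global _grad_enabled
        prev, _grad_enabled = _grad_enabled, False
        try:
            yield
        finally:
            _grad_enabled = prev

    def no_grad(func=None):
        # Called with no argument: act as a context manager.
        if func is None:
            return _no_grad_guard()

        # Applied directly to a callable: act as a decorator.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with _no_grad_guard():
                return func(*args, **kwargs)

        return wrapper

    @no_grad
    def forward(x):
        return x * 2  # body runs with _grad_enabled == False

    with no_grad():
        y = forward(3)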
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/math_op_patch.py
@@ -41,7 +41,7 @@ def monkey_patch_math_varbase():
    The difference is, in dygraph mode, use auto-generated op functions for better performance.
    """

    @no_grad()
    @no_grad
    def create_tensor(value, dtype, shape):
        out = _varbase_creator(dtype=dtype)
        out = core.ops.fill_constant(out, 'dtype', dtype, 'shape', shape,
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/parallel.py
@@ -445,7 +445,7 @@ def _split_tensors(self, coalesced_grads_and_grad_vars):
                self._reshape_inplace(x=g_var, shape=g_shape)
                assert g_var.shape == g_shape

    @no_grad()
    @no_grad
    def apply_collective_grads(self):
        """
        AllReduce the Parameters' gradient.
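`apply_collective_grads` averages each parameter's gradient across trainers before the optimizer step, and the bare `@no_grad` keeps that arithmetic out of the trace. Conceptually, with NumPy standing in for the collective op (simple averaging over `nranks` is an assumption here, not a quote of the implementation):

    import numpy as np

    def apply_collective_grads_sketch(per_worker_grads, nranks):
        # Stand-in for allreduce + scale: sum each parameter's gradient over
        # all workers, then divide by the number of ranks.
        summed = [np.sum(np.stack(grads), axis=0) for grads in zip(*per_worker_grads)]
        return [g / float(nranks) for g in summed]

    # Two workers, two parameters each.
    worker0 = [np.array([1.0, 2.0]), np.array([0.5])]
    worker1 = [np.array([3.0, 4.0]), np.array([1.5])]
    averaged = apply_collective_grads_sketch([worker0, worker1], nranks=2)
    # averaged -> [array([2., 3.]), array([1.])]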
8 changes: 4 additions & 4 deletions python/paddle/fluid/optimizer.py
@@ -61,7 +61,7 @@ class Optimizer(object):
    but need to use one of it's implementation.
    """

    @imperative_base.no_grad()
    @imperative_base.no_grad
    def __init__(self,
                 learning_rate,
                 parameter_list=None,
@@ -897,7 +897,7 @@ def clear_gradients(self):
            if p.trainable:
                p.clear_gradient()

    @imperative_base.no_grad()
    @imperative_base.no_grad
    def minimize(self,
                 loss,
                 startup_program=None,
@@ -1015,7 +1015,7 @@ def __init__(self,
            name=name)
        self.type = "sgd"

    @no_grad()
    @no_grad
    def _append_optimize_op(self, block, param_and_grad):
        lr = self._create_param_lr(param_and_grad)
        if framework.in_dygraph_mode():
@@ -1552,7 +1552,7 @@ def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
        dgc_op._set_attr(op_maker.kOpRoleVarAttrName(),
                         [param_var.name, grad_var.name])

    @imperative_base.no_grad()
    @imperative_base.no_grad
    def apply_gradients(self, params_grads):
        params_grads = sorted(params_grads, key=lambda x: x[0].name)
        params_grads, table_param_and_grad, table_optimize_op = \
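The optimizer entry points above (`__init__`, `minimize`, SGD's `_append_optimize_op`, `apply_gradients`) get the same bare decorator: the parameter update itself should not be recorded in the forward trace. A toy, hand-rolled SGD step under the context-manager form shows the intent (the model, data, and the use of `set_value` are illustrative assumptions, not code from this commit):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        linear = fluid.Linear(2, 1)
        x = fluid.dygraph.to_variable(np.ones([4, 2], dtype='float32'))
        loss = fluid.layers.reduce_mean(linear(x))
        loss.backward()

        lr = 0.1
        with paddle.no_grad():  # mirrors what the decorated methods achieve
            for param in linear.parameters():
                grad = param.gradient()  # numpy array, may be None
                if grad is not None:
                    param.set_value(param.numpy() - lr * grad)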
67 changes: 46 additions & 21 deletions python/paddle/fluid/tests/unittests/test_imperative_decorator.py
@@ -28,7 +28,7 @@ def setUp(self):
    def get_tracer_mode(self):
        assert fluid.in_dygraph_mode(), "Dygraph mode must be enabled"

    @paddle.no_grad()
    @fluid.dygraph.no_grad
    def no_grad_func(self, a):
        self.assertEqual(self.tracer._train_mode, False)
        return a
@@ -56,35 +56,17 @@ def test_main(self):
        def need_no_grad_func(a, b=1):
            return a + b

        decorated_func = paddle.no_grad()(need_no_grad_func)
        decorated_func = fluid.dygraph.no_grad(need_no_grad_func)
        self.assertTrue(
            str(inspect.getargspec(decorated_func)) ==
            str(inspect.getargspec(need_no_grad_func)))

        self.assertEqual(self.tracer._train_mode, self.init_mode)

        def test_gen():
            for i in range(3):
                yield i

        a = 0
        for i in test_gen():
            a += i

        @paddle.no_grad()
        def test_wrapped_gen():
            for i in range(3):
                yield i

        b = 0
        for i in test_wrapped_gen():
            b += i

        self.assertEqual(a, b)

        with fluid.dygraph.guard():
            self.check_not_support_rlt(False)

        paddle.enable_static()
        with new_program_scope():
            self.check_not_support_rlt(True)

@@ -94,5 +76,48 @@ def setUp(self):
        self.init_mode = False


class TestNoGradClass(unittest.TestCase):
    @paddle.no_grad()
    def no_grad_func(self, a):
        self.assertEqual(self.tracer._train_mode, False)
        return a

    def test_main(self):
        paddle.disable_static()

        self.tracer = framework._dygraph_tracer()
        self.tracer._train_mode = True

        self.assertEqual(self.no_grad_func(1), 1)
        self.assertEqual(self.no_grad_func.__name__, "no_grad_func")

        def need_no_grad_func(a, b=1):
            return a + b

        decorated_func = paddle.no_grad()(need_no_grad_func)
        self.assertEqual(
            str(inspect.getargspec(decorated_func)),
            str(inspect.getargspec(need_no_grad_func)))

        def test_gen():
            for i in range(3):
                yield i

        a = 0
        for i in test_gen():
            a += i

        @paddle.no_grad()
        def test_wrapped_gen():
            for i in range(3):
                yield i

        b = 0
        for i in test_wrapped_gen():
            b += i

        self.assertEqual(a, b)


if __name__ == '__main__':
    unittest.main()
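
Both test classes assert that decorating leaves the function signature intact (the `inspect.getargspec` comparisons). That property is the reason the function form in base.py goes through `decorator.decorator` rather than a plain closure; a small comparison sketch, with `functools.wraps` as the naive baseline and the same third-party `decorator` package the diff imports:

    import functools
    import inspect

    import decorator  # third-party package, as imported in dygraph/base.py

    def naive_wrapper(func):
        # Plain closure: the wrapper's own signature is (*args, **kwargs).
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper

    @decorator.decorator
    def keeping_wrapper(func, *args, **kwargs):
        # decorator.decorator rebuilds a wrapper with func's exact signature.
        return func(*args, **kwargs)

    def need_no_grad_func(a, b=1):
        return a + b

    print(inspect.getfullargspec(naive_wrapper(need_no_grad_func)))
    print(inspect.getfullargspec(keeping_wrapper(need_no_grad_func)))
    # Only the second spec matches need_no_grad_func's (a, b=1), which is the
    # property the getargspec assertions in the tests above rely on.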
4 changes: 2 additions & 2 deletions python/paddle/optimizer/optimizer.py
@@ -98,7 +98,7 @@ class Optimizer(object):
    """

    @imperative_base.no_grad()
    @imperative_base.no_grad
    def __init__(self,
                 learning_rate,
                 parameters=None,
@@ -812,7 +812,7 @@ def clear_grad(self):
            if p.trainable:
                p.clear_gradient()

    @imperative_base.no_grad()
    @imperative_base.no_grad
    def minimize(self,
                 loss,
                 startup_program=None,
2 changes: 1 addition & 1 deletion python/paddle/optimizer/sgd.py
@@ -85,7 +85,7 @@ def __init__(self,
            name=name)
        self.type = "sgd"

    @no_grad()
    @no_grad
    def _append_optimize_op(self, block, param_and_grad):
        lr = self._create_param_lr(param_and_grad)
        if framework.in_dygraph_mode():
