Feat: support no_grad decorator #5947

Closed · wants to merge 4 commits

Changes from 2 commits
1 change: 1 addition & 0 deletions docs/source/oneflow.rst
@@ -122,5 +122,6 @@ oneflow
zeros,
zeros_like,
is_nonzero,
+   no_grad,

.. autofunction:: oneflow.data.load_mnist(train_batch_size=100, test_batch_size=100, data_format='NCHW')
@@ -26,11 +26,12 @@ namespace oneflow {
namespace autograd {

ONEFLOW_API_PYBIND11_MODULE("autograd", m) {
-  py::class_<NoGradGuard, std::shared_ptr<NoGradGuard>>(m, "no_grad")
+  py::class_<NoGradGuard, std::shared_ptr<NoGradGuard>>(m, "NoGradGuard")
.def(py::init([]() { return std::make_shared<NoGradGuard>(); }))
.def("__enter__", [](const NoGradGuard& no_grad_obj) {})
.def("__exit__", [](const NoGradGuard& no_grad_obj, const py::object& type,
const py::object& value, const py::object& traceback) {});
m.def("autograd_mode", &GradMode::is_enabled);
}

} // namespace autograd
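
For orientation, a minimal sketch (not part of the diff) of how these bindings surface on the Python side. It assumes the module compiles into oneflow._oneflow_internal, and that NoGradGuard follows the usual RAII pattern of toggling grad mode in its C++ constructor/destructor, which the deliberately empty __enter__/__exit__ suggest: the with statement then only controls the guard object's lifetime.

    from oneflow._oneflow_internal.autograd import NoGradGuard, autograd_mode

    print(autograd_mode())      # True: gradients are recorded by default
    with NoGradGuard():
        print(autograd_mode())  # False while the guard object is alive
    print(autograd_mode())      # True again once the guard is destroyed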
3 changes: 2 additions & 1 deletion python/oneflow/__init__.py
@@ -22,7 +22,7 @@
Size = oneflow._oneflow_internal.Size
device = oneflow._oneflow_internal.device
placement = oneflow._oneflow_internal.placement
-no_grad = oneflow._oneflow_internal.autograd.no_grad
+autograd_mode = oneflow._oneflow_internal.autograd.autograd_mode
locals()["dtype"] = oneflow._oneflow_internal.dtype
locals()["char"] = oneflow._oneflow_internal.char
locals()["float16"] = oneflow._oneflow_internal.float16
@@ -114,6 +114,7 @@ def _SyncOnMasterFn():
register_docstr()
del register_docstr
del docstr
+from oneflow.autograd import no_grad
import oneflow.nn.image
import oneflow.nn.modules.acosh
import oneflow.nn.modules.activation
1 change: 1 addition & 0 deletions python/oneflow/autograd/__init__.py
@@ -15,3 +15,4 @@
"""

from oneflow.autograd.autograd import backward, grad
+from oneflow.autograd.autograd_mode import no_grad
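
With this export in place, the top-level alias wired up in python/oneflow/__init__.py above and the autograd namespace should resolve to the same class. A quick check (sketch, not in the PR):

    import oneflow as flow
    from oneflow.autograd import no_grad

    assert flow.no_grad is no_grad  # one class behind both import paths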
62 changes: 62 additions & 0 deletions python/oneflow/autograd/autograd_mode.py
@@ -0,0 +1,62 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from oneflow._oneflow_internal.autograd import NoGradGuard


class no_grad(NoGradGuard):
r"""
Context manager that disables gradient calculation.

Disabling gradient calculation is useful for inference, when you are sure that
you will not call Tensor.backward(). It will reduce memory consumption for computations
that would otherwise have requires_grad=True.

In this mode, the result of every computation will have requires_grad=False, even when
the inputs have requires_grad=True.

This context manager is thread local; it will not affect computation in other threads.

Also functions as a decorator. (Make sure to instantiate with parentheses.)

.. code-block:: python

>>> import oneflow as flow
>>> x = flow.ones(2, 3, requires_grad=True)
>>> with flow.no_grad():
... y = x * x
>>> y.requires_grad
False
>>> @flow.no_grad()
... def no_grad_func(x):
... return x * x
>>> y = no_grad_func(x)
>>> y.requires_grad
False
"""

    def __call__(self, func):
        def wrapper(*args, **kwargs):
            # Run the wrapped function with gradient recording disabled.
            with NoGradGuard():
                return func(*args, **kwargs)

        return wrapper


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
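
A possible refinement, not part of this PR: the plain wrapper in __call__ discards the wrapped function's metadata (__name__, __doc__). Wrapping with functools.wraps would preserve it; a sketch under that assumption:

    import functools

    from oneflow._oneflow_internal.autograd import NoGradGuard


    class no_grad(NoGradGuard):
        def __call__(self, func):
            @functools.wraps(func)  # keep func.__name__, __doc__, etc.
            def wrapper(*args, **kwargs):
                with NoGradGuard():
                    return func(*args, **kwargs)

            return wrapper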