Fea/nn graph/block proxy func #5727

Merged 42 commits on Aug 4, 2021

Changes from 38 commits

Commits (42):
db98408  pass test on linear with training (strint, Jul 29, 2021)
85b1cee  Refactor RuntimeCtx for multi-runtime (chengtbf, Jul 29, 2021)
4d54492  Merge branch 'master' into dev_cc_refactor_runtime_context (chengtbf, Jul 29, 2021)
6220e12  refactor inplace to support nn graph (strint, Jul 29, 2021)
55d5305  block support iterator (strint, Jul 29, 2021)
b889a1a  block iter add check (strint, Jul 29, 2021)
b7ff9d5  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (chengtbf, Jul 30, 2021)
ba584bb  Merge branch 'master' into dev_cc_refactor_runtime_context (oneflow-ci-bot, Jul 30, 2021)
814b403  fix scalar_mul op conf build (strint, Jul 30, 2021)
d4b37a4  Merge branch 'master' into dev_cc_refactor_runtime_context (oneflow-ci-bot, Jul 30, 2021)
0f16ffc  merge multi runtime (strint, Jul 30, 2021)
c20c926  deal with inplace after merge master (strint, Jul 30, 2021)
85c8845  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (strint, Jul 30, 2021)
c7dc232  add alexnet graph test (strint, Jul 30, 2021)
b3d4413  add cpu test and format (strint, Jul 30, 2021)
cf61b34  cout to glog (strint, Jul 31, 2021)
bf1a3cb  deal with Job run finish bug (strint, Aug 1, 2021)
3c50a5b  refactor lazy deal with inplace (strint, Aug 2, 2021)
3cb56dc  merge master (strint, Aug 3, 2021)
80ad598  deal with 0D tensor (strint, Aug 3, 2021)
4fe15cd  update data path (strint, Aug 3, 2021)
ad4b190  address review (strint, Aug 3, 2021)
f6afd94  deal with lazy default attr (strint, Aug 3, 2021)
7dac683  mv according to ci (strint, Aug 3, 2021)
acf1ee8  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (strint, Aug 3, 2021)
78d4720  merge master (strint, Aug 3, 2021)
71518d8  fix for ci (strint, Aug 3, 2021)
a93d0bb  Merge branch 'master' into fea/nn_graph/train (oneflow-ci-bot, Aug 3, 2021)
bad4874  Merge branch 'master' into fea/nn_graph/train (oneflow-ci-bot, Aug 3, 2021)
4177ea4  Merge branch 'master' into fea/nn_graph/train (oneflow-ci-bot, Aug 3, 2021)
98a7320  Merge branch 'master' into fea/nn_graph/train (oneflow-ci-bot, Aug 3, 2021)
f3b0628  fix for ci limit (strint, Aug 4, 2021)
481ed4e  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (strint, Aug 4, 2021)
c7044d7  Merge branch 'fea/nn_graph/train' of https://github.com/Oneflow-Inc/o… (strint, Aug 4, 2021)
0eef6af  block proxy func (strint, Aug 4, 2021)
7e5c911  Merge branch 'master' into fea/nn_graph/block_proxy_func (strint, Aug 4, 2021)
88147f8  support module custom func and refacotr get attr of block (strint, Aug 4, 2021)
fa3c8ab  Merge branch 'fea/nn_graph/block_proxy_func' of https://github.com/On… (strint, Aug 4, 2021)
68b8d22  Merge branch 'master' into fea/nn_graph/block_proxy_func (strint, Aug 4, 2021)
0ccb9cd  auto format by CI (oneflow-ci-bot, Aug 4, 2021)
b84373c  Merge branch 'master' into fea/nn_graph/block_proxy_func (oneflow-ci-bot, Aug 4, 2021)
8317866  Merge branch 'master' into fea/nn_graph/block_proxy_func (oneflow-ci-bot, Aug 4, 2021)
python/oneflow/nn/graph_block.py (45 additions, 40 deletions)
@@ -14,6 +14,7 @@
 limitations under the License.
 """
 from collections import OrderedDict
+from functools import partial
 from typing import Iterator, Optional, Set, Union
 
 import oneflow._oneflow_internal
@@ -207,56 +208,60 @@ def __getattr__(self, name: str):
         if name in self.__dict__:
             return self.__dict__[name]
         if self._type == BlockType.MODULE:
             # support get module
             if "_modules" in self.__dict__:
                 modules = self.__dict__["_modules"]
                 if name in modules:
                     return modules[name]
-            if "_parameters" in self.__dict__:
-                _parameters = self.__dict__["_parameters"]
-                if name in _parameters:
-                    p_block = _parameters[name]
-                    if self._is_executing_forward:
-                        if graph_build_util.lazy_mode.is_enabled():
-                            if p_block._lazy_origin is None:
-                                assert p_block._lazy_origin_builder is not None, (
-                                    repr(p_block)
-                                    + " has no lazy Tensor creation function."
-                                )
-                                with p_block.scope_context():
-                                    p_block._lazy_origin = (
-                                        p_block._lazy_origin_builder()
-                                    )
-                            return p_block._lazy_origin
-                        else:
-                            return p_block.origin
-                    else:
-                        return p_block
-            if "_buffers" in self.__dict__:
-                _buffers = self.__dict__["_buffers"]
-                if name in _buffers:
-                    b_block = _buffers[name]
-                    if self._is_executing_forward:
-                        if graph_build_util.lazy_mode.is_enabled():
-                            if b_block._lazy_origin is None:
-                                assert b_block._lazy_origin_builder is not None, (
-                                    repr(b_block)
-                                    + " has no lazy Tensor creation function."
-                                )
-                                with b_block.scope_context():
-                                    b_block._lazy_origin = (
-                                        b_block._lazy_origin_builder()
-                                    )
-                            return b_block._lazy_origin
-                        else:
-                            return b_block.origin
-                    else:
-                        return b_block
+            # support get parameter
+            p_state = self._get_in_states(name, "_parameters")
+            if p_state is not None:
+                return p_state
+            # support get buffer
+            b_state = self._get_in_states(name, "_buffers")
+            if b_state is not None:
+                return b_state
+            # support get normal attr
+            if name in self._origin.__dict__:
+                return self._origin.__dict__[name]
+            # support get function
+            if hasattr(self._origin, name):
+                return partial(getattr(self._origin.__class__, name), self)
Review comment (PR author): The main change is here; the rest is code cleanup and simplification.
         raise AttributeError(
             "'{}' object has no attribute '{}'".format(type(self).__name__, name)
         )
 
+    def _get_in_states(self, name, states_name):
+        if states_name not in self.__dict__:
+            return None
+
+        _states = self.__dict__[states_name]
+        if name not in _states:
+            return None
+
+        _s_block = _states[name]
+        if graph_build_util.lazy_mode.is_enabled():
+            # lazy
+            if _s_block._lazy_origin is None:
+                assert _s_block._lazy_origin_builder is not None, (
+                    repr(_s_block) + " has no lazy Tensor creation function."
+                )
+                assert self._is_executing_forward, (
+                    repr(_s_block)
+                    + "'s first access must happen in its nn.Module.forward() to generate the right scope."
+                )
+                with _s_block.scope_context():
+                    _s_block._lazy_origin = _s_block._lazy_origin_builder()
+            return _s_block._lazy_origin
+        elif (
+            not graph_build_util.lazy_mode.is_enabled()
+        ) and self._is_executing_forward:
+            # eager and inside nn.Graph.build()
+            return _s_block.origin
+        else:
+            # outside nn.Graph.build()
+            return _s_block

     def __repr__(self):
         lines = None
         if self._type == BlockType.MODULE:
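To see the new fallback in isolation: below is a minimal, self-contained sketch of the re-binding trick in plain Python. The class names are toy stand-ins, not OneFlow's actual Block; only the partial(getattr(self._origin.__class__, name), self) pattern is taken from the diff above.

from functools import partial

class MyModule:
    # stand-in for a user's nn.Module subclass
    def __init__(self):
        self.data_mem = 10

    def _custom_func(self, x):
        # `self` here is whatever object partial bound below, so this
        # attribute lookup is routed through that object's __getattr__
        return x + self.data_mem

class BlockLike:
    # toy proxy mirroring the fallback branch of Block.__getattr__
    def __init__(self, origin):
        self.__dict__["_origin"] = origin

    def __getattr__(self, name):
        origin = self.__dict__["_origin"]
        # plain data attributes come straight from the origin's __dict__
        if name in origin.__dict__:
            return origin.__dict__[name]
        # functions are looked up on the origin's class and re-bound so
        # that `self` inside them is the proxy, not the origin
        if hasattr(origin, name):
            return partial(getattr(origin.__class__, name), self)
        raise AttributeError(name)

b = BlockLike(MyModule())
print(b._custom_func(5))  # 15: MyModule._custom_func runs with the proxy as `self`

Binding the proxy instead of the origin matters because every attribute access inside the custom function then goes through the proxy's __getattr__ as well, which is what lets parameters and buffers resolve to lazy tensors during nn.Graph tracing.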
python/oneflow/test/graph/test_graph.py (4 additions, 2 deletions)
@@ -51,6 +51,7 @@ def forward(self, x):
         return x
 
 
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
 @flow.unittest.skip_unless_1n1d()
 class TestGraph(flow.unittest.TestCase):
     def test_add_nested_module(test_case):
@@ -201,9 +202,10 @@ def forward(self, x):
                     "pipeline_stage_id_hint"
                 ].at_int64
                 test_case.assertEqual(stage_int, 0)
+                out = self.conv1(x)
                 weight = self.conv1.weight
-                test_case.assertEqual(type(weight), flow.nn.graph.Block)
-                return self.conv1(x)
+                test_case.assertTrue(weight.is_lazy)
+                return out
 
         class SubModule1(flow.nn.Module):
             def __init__(self):
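The is_lazy assertion above relies on the three-way dispatch that _get_in_states applies to parameters and buffers. A hedged sketch of that dispatch follows; StateBlock and resolve_state are hypothetical stand-ins for the real parameter/buffer Block, with the lazy-mode and forward-execution globals reduced to plain flags.

from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class StateBlock:
    origin: Any                        # the eager Tensor
    build_lazy: Callable[[], Any]      # builder for the lazy Tensor
    lazy_origin: Optional[Any] = None

def resolve_state(s: StateBlock, lazy_enabled: bool, executing_forward: bool):
    if lazy_enabled:
        # lazy tracing inside nn.Graph.build(): build once, then reuse
        if s.lazy_origin is None:
            assert executing_forward, "first access must happen in forward()"
            s.lazy_origin = s.build_lazy()
        return s.lazy_origin
    if executing_forward:
        # eager code that still runs inside nn.Graph.build()
        return s.origin
    # ordinary access outside nn.Graph.build(): hand back the wrapper itself
    return s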
python/oneflow/test/graph/test_graph_block.py (new file, 109 additions)
@@ -0,0 +1,109 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest

import numpy as np

import oneflow as flow
import oneflow.unittest


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n1d()
class TestGraphBlock(flow.unittest.TestCase):
def test_module_has_custom_func(test_case):
class CustomModuleHasFunc(flow.nn.Module):
def __init__(self):
super().__init__()
self.data_mem = 10

def forward(self, x):
return self._custom_func(x)

def _custom_func(self, x):
test_case.assertEqual(self.data_mem, 10)
return x

class CustomGraphHasFunc(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = CustomModuleHasFunc()

def build(self, x):
return self.m(x)

g = CustomGraphHasFunc()
x = np.ones((10, 10))
x = flow.tensor(x, dtype=flow.float32)
out = g(x)
test_case.assertTrue(np.array_equal(x.numpy(), out.numpy()))

def test_block_with_parameter(test_case):
device = "cuda"
linear = flow.nn.Linear(3, 8)
linear = linear.to(device)
flow.nn.init.constant_(linear.weight, 2.068758)
flow.nn.init.constant_(linear.bias, 0.23)
of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)

x = flow.Tensor(
[
[-0.94630778, -0.83378579, -0.87060891],
[2.0289922, -0.28708987, -2.18369248],
[0.35217619, -0.67095644, -1.58943879],
[0.08086036, -1.81075924, 1.20752494],
[0.8901075, -0.49976737, -1.07153746],
[-0.44872912, -1.07275683, 0.06256855],
[-0.22556897, 0.74798368, 0.90416439],
[0.48339456, -2.32742195, -0.59321527],
],
device=device,
requires_grad=False,
)

class CustomModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.linear = linear

def forward(self, x):
return self._forward_impl(x)

def _forward_impl(self, x):
test_case.assertTrue(isinstance(self.linear, flow.nn.graph.Block))
return self.linear(x)

class LinearTrainGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = CustomModule()
self.add_optimizer("sgd", of_sgd)

def build(self, x):
out = self.m(x)
out = out.sum()
out.backward()
test_case.assertTrue(self.m.linear.weight.is_lazy)
return out

linear_t_g = LinearTrainGraph()

linear_t_g(x)


if __name__ == "__main__":
unittest.main()
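For reference, a condensed usage sketch combining what the two tests above verify. MyModule, MyGraph, and the helper name are hypothetical; the APIs used (flow.nn.Module, flow.nn.Linear, flow.nn.Graph, flow.tensor) are the ones exercised by the tests.

import numpy as np
import oneflow as flow

class MyModule(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 4)

    def forward(self, x):
        # custom helper methods work under nn.Graph because the Block
        # proxy re-binds them with itself as `self`
        return self._helper(x)

    def _helper(self, x):
        # inside nn.Graph.build(), self.linear resolves to the Block wrapper
        return self.linear(x)

class MyGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = MyModule()

    def build(self, x):
        return self.m(x)

g = MyGraph()
x = flow.tensor(np.ones((2, 3)), dtype=flow.float32)
print(g(x).numpy())  # (2, 4) output produced by the traced graph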