-
Notifications
You must be signed in to change notification settings - Fork 656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fea/nn graph/block proxy func #5727
Merged
Merged
Changes from 38 commits
Commits
Show all changes
42 commits
Select commit
Hold shift + click to select a range
db98408
pass test on linear with training
strint 85b1cee
Refactor RuntimeCtx for multi-runtime
chengtbf 4d54492
Merge branch 'master' into dev_cc_refactor_runtime_context
chengtbf 6220e12
refactor inplace to support nn graph
strint 55d5305
block support iterator
strint b889a1a
block iter add check
strint b7ff9d5
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
chengtbf ba584bb
Merge branch 'master' into dev_cc_refactor_runtime_context
oneflow-ci-bot 814b403
fix scalar_mul op conf build
strint d4b37a4
Merge branch 'master' into dev_cc_refactor_runtime_context
oneflow-ci-bot 0f16ffc
merge multi runtime
strint c20c926
deal with inplace after merge master
strint 85c8845
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint c7dc232
add alexnet graph test
strint b3d4413
add cpu test and format
strint cf61b34
cout to glog
strint bf1a3cb
deal with Job run finish bug
strint 3c50a5b
refactor lazy deal with inplace
strint 3cb56dc
merge master
strint 80ad598
deal with 0D tensor
strint 4fe15cd
update data path
strint ad4b190
address review
strint f6afd94
deal with lazy default attr
strint 7dac683
mv according to ci
strint acf1ee8
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint 78d4720
merge master
strint 71518d8
fix for ci
strint a93d0bb
Merge branch 'master' into fea/nn_graph/train
oneflow-ci-bot bad4874
Merge branch 'master' into fea/nn_graph/train
oneflow-ci-bot 4177ea4
Merge branch 'master' into fea/nn_graph/train
oneflow-ci-bot 98a7320
Merge branch 'master' into fea/nn_graph/train
oneflow-ci-bot f3b0628
fix for ci limit
strint 481ed4e
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint c7044d7
Merge branch 'fea/nn_graph/train' of https://github.com/Oneflow-Inc/o…
strint 0eef6af
block proxy func
strint 7e5c911
Merge branch 'master' into fea/nn_graph/block_proxy_func
strint 88147f8
support module custom func and refacotr get attr of block
strint fa3c8ab
Merge branch 'fea/nn_graph/block_proxy_func' of https://github.com/On…
strint 68b8d22
Merge branch 'master' into fea/nn_graph/block_proxy_func
strint 0ccb9cd
auto format by CI
oneflow-ci-bot b84373c
Merge branch 'master' into fea/nn_graph/block_proxy_func
oneflow-ci-bot 8317866
Merge branch 'master' into fea/nn_graph/block_proxy_func
oneflow-ci-bot File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import os | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
import oneflow as flow | ||
import oneflow.unittest | ||
|
||
|
||
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") | ||
@flow.unittest.skip_unless_1n1d() | ||
class TestGraphBlock(flow.unittest.TestCase): | ||
def test_module_has_custom_func(test_case): | ||
class CustomModuleHasFunc(flow.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.data_mem = 10 | ||
|
||
def forward(self, x): | ||
return self._custom_func(x) | ||
|
||
def _custom_func(self, x): | ||
test_case.assertEqual(self.data_mem, 10) | ||
return x | ||
|
||
class CustomGraphHasFunc(flow.nn.Graph): | ||
def __init__(self): | ||
super().__init__() | ||
self.m = CustomModuleHasFunc() | ||
|
||
def build(self, x): | ||
return self.m(x) | ||
|
||
g = CustomGraphHasFunc() | ||
x = np.ones((10, 10)) | ||
x = flow.tensor(x, dtype=flow.float32) | ||
out = g(x) | ||
test_case.assertTrue(np.array_equal(x.numpy(), out.numpy())) | ||
|
||
def test_block_with_parameter(test_case): | ||
device = "cuda" | ||
linear = flow.nn.Linear(3, 8) | ||
linear = linear.to(device) | ||
flow.nn.init.constant_(linear.weight, 2.068758) | ||
flow.nn.init.constant_(linear.bias, 0.23) | ||
of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9) | ||
|
||
x = flow.Tensor( | ||
[ | ||
[-0.94630778, -0.83378579, -0.87060891], | ||
[2.0289922, -0.28708987, -2.18369248], | ||
[0.35217619, -0.67095644, -1.58943879], | ||
[0.08086036, -1.81075924, 1.20752494], | ||
[0.8901075, -0.49976737, -1.07153746], | ||
[-0.44872912, -1.07275683, 0.06256855], | ||
[-0.22556897, 0.74798368, 0.90416439], | ||
[0.48339456, -2.32742195, -0.59321527], | ||
], | ||
device=device, | ||
requires_grad=False, | ||
) | ||
|
||
class CustomModule(flow.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.linear = linear | ||
|
||
def forward(self, x): | ||
return self._forward_impl(x) | ||
|
||
def _forward_impl(self, x): | ||
test_case.assertTrue(isinstance(self.linear, flow.nn.graph.Block)) | ||
return self.linear(x) | ||
|
||
class LinearTrainGraph(flow.nn.Graph): | ||
def __init__(self): | ||
super().__init__() | ||
self.m = CustomModule() | ||
self.add_optimizer("sgd", of_sgd) | ||
|
||
def build(self, x): | ||
out = self.m(x) | ||
out = out.sum() | ||
out.backward() | ||
test_case.assertTrue(self.m.linear.weight.is_lazy) | ||
return out | ||
|
||
linear_t_g = LinearTrainGraph() | ||
|
||
linear_t_g(x) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
主要改动在这里,其它的是代码的整理、简化