Skip to content

Commit

Permalink
Fea/nn graph/graph name (#5413)
Browse files Browse the repository at this point in the history
* graph api

* add graph dummy test

* add test

* add recursive module mode

* graph.build test pass

* add detail check on graph inner node

* support config and train

* add repr for debug

* test buffer

* test buffer add

* refine test

* add comment

* refine test

* refactor Node to Block

* add named_state

* refine Graph.named_state()

* add state_tensortuple

* graph._compile()

* add mc session 0

* nn.graph: state tuple to private var; add BlockType; add simple multi client session

* NNGraphIf

* rm old graph.cpp

* nn.graph: add cpp NNGraph; export and call NNGraph

* add comment

* nn.Graph: rm prototype MultiClientSession

* nn.Graph: rm prototype MultiClientSession test

* nn.Graph: add TODO

* nn.Graph: hack to get Graph object name

* nn.Graph: get obj name

* nn.Graph: get obj name 2

* nn.Graph: format for review

* nn.Graph: format

* nn.Graph: format

* nn.Graph: pass flake8 check

* Update graph.py

* name with init count

* name with init count 2

Co-authored-by: Xinqi Li <lixinqi0703106@163.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 9, 2021
1 parent 1b8bca0 commit 041e441
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
12 changes: 10 additions & 2 deletions oneflow/python/nn/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Union

import oneflow._oneflow_internal
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
Expand All @@ -31,9 +30,11 @@
@oneflow_export("nn.Graph", "nn.graph.Graph")
@experimental_api
class Graph(object):
_child_init_cnt = dict()

def __init__(self):
self.config = GraphConfig()
self._name = id_util.UniqueStr(self.__class__.__name__ + "_")
self._generate_name()
self._c_nn_graph = oneflow._oneflow_internal.NNGraph(self._name)
self._blocks = OrderedDict()
self._optimizers = OrderedDict()
Expand Down Expand Up @@ -63,6 +64,13 @@ def add_optimizer(
optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf
)

def _generate_name(self):
child_name = self.__class__.__name__
if Graph._child_init_cnt.get(child_name) is None:
Graph._child_init_cnt[child_name] = 0
self._name = child_name + "_" + str(Graph._child_init_cnt[child_name])
Graph._child_init_cnt[child_name] += 1

def _named_state(self):
for _, b in self._blocks.items():
prefix = b.name + "."
Expand Down
36 changes: 29 additions & 7 deletions oneflow/python/test/graph/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def build(self, x):

# Graph init
g = CustomGraph()
# check _c_nn_graph init
test_case.assertEqual(g.name, g._c_nn_graph.name)
# g.m is Block
test_case.assertTrue(isinstance(g.m, flow.nn.graph.Block))
# g.m.name is "m"
Expand Down Expand Up @@ -128,19 +130,39 @@ def build(self, x):
# print repr of nn.Graph
print(repr(g))

def test_graph_compile(test_case):
class CustomGraph(flow.nn.Graph):
def test_graph_name(test_case):
class ACustomGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = CustomModule()
self.config.enable_auto_mixed_precision(True)

def build(self, x):
x = self.m(x)
return x

g = CustomGraph()
test_case.assertEqual(g.name, g._c_nn_graph.name)
class BCustomGraph(flow.nn.Graph):
def __init__(self):
super().__init__()

def build(self, x):
return x

class CBCustomGraph(BCustomGraph):
def __init__(self):
super().__init__()

def create_graph(cnt):
a = ACustomGraph()
test_case.assertEqual(a.name, "ACustomGraph_" + str(cnt))
b = BCustomGraph()
test_case.assertEqual(b.name, "BCustomGraph_" + str(cnt))
cb = CBCustomGraph()
test_case.assertEqual(cb.name, "CBCustomGraph_" + str(cnt))

flow.nn.Graph._child_init_cnt.clear()
for i in range(0, 3):
create_graph(i)
flow.nn.Graph._child_init_cnt.clear()
for i in range(0, 3):
create_graph(i)


if __name__ == "__main__":
Expand Down

0 comments on commit 041e441

Please sign in to comment.