Skip to content

Commit

Permalink
nn.graph: add cpp NNGraph; export and call NNGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
strint committed Jul 5, 2021
1 parent 05a8dca commit bbf0ace
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 24 deletions.
30 changes: 30 additions & 0 deletions oneflow/api/python/framework/nn_graph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <string>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/nn_graph_if.h"

namespace py = pybind11;

namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
using namespace oneflow;
py::class_<NNGraph, std::shared_ptr<NNGraph>>(m, "NNGraph")
.def(py::init<const std::string&>())
.def_property_readonly("name", &NNGraph::job_name);
}
} // namespace oneflow
9 changes: 8 additions & 1 deletion oneflow/core/framework/nn_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,11 @@ limitations under the License.
*/
#include "oneflow/core/framework/nn_graph_if.h"

namespace oneflow {}
#include "oneflow/core/common/util.h"
namespace oneflow {

const std::vector<std::string>& NNGraph::inputs_op_names() const { UNIMPLEMENTED(); }

const std::vector<std::string>& NNGraph::outputs_op_names() const { UNIMPLEMENTED(); }

} // namespace oneflow
13 changes: 13 additions & 0 deletions oneflow/core/framework/nn_graph_if.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ class NNGraphIf {
NNGraphIf() = default;
};

class NNGraph final : public NNGraphIf {
public:
NNGraph() = delete;
explicit NNGraph(const std::string& name) : name_(name) {}

const std::string& job_name() const { return name_; }
const std::vector<std::string>& inputs_op_names() const;
const std::vector<std::string>& outputs_op_names() const;

private:
std::string name_;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_IF_H_
3 changes: 3 additions & 0 deletions oneflow/python/framework/session_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

_is_multi_client_mode = True


class Session(object):
def __init__(self, sess_id):
self.job_name2function_desc_ = {}
Expand Down Expand Up @@ -521,6 +522,7 @@ def _GetDefaultConfigProto():
config_proto.session_id = session_ctx.GetDefaultSession().id
return config_proto


class MultiClientSession(Session):
def __init__(self, sess_id):
self.status_ = SessionStatus.OPEN
Expand Down Expand Up @@ -548,6 +550,7 @@ def Close(self):
self.resource_ = None
oneflow._oneflow_internal.ClearSessionById(self.id)


if _is_multi_client_mode:
new_session = MultiClientSession(oneflow._oneflow_internal.NewSessionId())
else:
Expand Down
37 changes: 24 additions & 13 deletions oneflow/python/nn/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
class Graph(object):
def __init__(self):
self._name = id_util.UniqueStr(self.__class__.__name__ + "_")
self._c_nn_graph = oneflow._oneflow_internal.NNGraph(self._name)

This comment has been minimized.

Copy link
@strint

strint Jul 5, 2021

Author Contributor

导出和调用c++ NNGraph

self.config = GraphConfig()
self._blocks = OrderedDict()
self._optimizers = OrderedDict()
Expand All @@ -54,7 +55,7 @@ def training(self):

def build(self, *args):
raise NotImplementedError()

def add_optimizer(
self,
name: str,
Expand All @@ -73,7 +74,6 @@ def train(self, mode: bool = True):
assert block.type == BlockType.MODULE
block.origin.train(mode)


def _named_state(self):
for _, b in self._blocks.items():
prefix = b.name + "."
Expand All @@ -83,15 +83,19 @@ def _named_state(self):
b_gen = b.origin.named_buffers()
for n, b in b_gen:
yield prefix + n, b

def _compile(self):
print("try to compile")
assert not self._is_compiled, "nn.Graph " + self_name + " has already been compiled."
self._state_tensortuple = tensor_tuple_util.convert_to_tensor_tuple(tuple(t for _, t in self._named_state()))
assert not self._is_compiled, (
"nn.Graph " + self_name + " has already been compiled."
)
self._state_tensortuple = tensor_tuple_util.convert_to_tensor_tuple(
tuple(t for _, t in self._named_state())
)
sess = session_ctx.GetDefaultSession()
sess.TryInit()
# do job compile

self._is_compiled = True

def __call__(self, *args):
Expand Down Expand Up @@ -156,18 +160,24 @@ def __repr__(self):
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str


class BlockType:
NONE = "NONE"
MODULE = "MODULE"
PARAMETER = "PARAMETER"
BUFFER = "BUFFER"


@oneflow_export("nn.graph.Block")
@experimental_api
class Block(object):
def __init__(self, prefix: str = "", name: str = "" , value: Union[Module, Parameter, Tensor] = None):
def __init__(
self,
prefix: str = "",
name: str = "",
value: Union[Module, Parameter, Tensor] = None,
):
assert not isinstance(value, Block)
self._name = name
self._name_prefix = prefix
Expand All @@ -181,11 +191,11 @@ def __init__(self, prefix: str = "", name: str = "" , value: Union[Module, Param
self._parameters = OrderedDict()
self._buffers = OrderedDict()
for n, m in list(value.named_children()):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m))
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m))
for n, p in list(value.named_parameters("", False)):
self.__setattr__(n, Block(self._name_prefix + self._name + "." , n, p))
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, p))
for n, b in list(value.named_buffers("", False)):
self.__setattr__(n, Block(self._name_prefix + self._name + "." , n, b))
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b))
elif isinstance(value, Parameter):
self._type = BlockType.PARAMETER
elif isinstance(value, Tensor):
Expand All @@ -203,7 +213,7 @@ def __init__(self, prefix: str = "", name: str = "" , value: Union[Module, Param
@property
def name(self):
return self._name

@property
def name_prefix(self):
return self._name_prefix
Expand Down Expand Up @@ -314,6 +324,7 @@ def _append_child(d):
main_str += ")"
return main_str


@oneflow_export("nn.graph.GraphConfig")
@experimental_api
class GraphConfig(FunctionConfig):
Expand Down Expand Up @@ -343,7 +354,7 @@ def _train(self, mode: bool = True):
@experimental_api
class BlockConfig(object):
def __init__(self):
# TODO(): implement config for block
# TODO(): implement config for block
pass


Expand Down
17 changes: 7 additions & 10 deletions oneflow/python/test/graph/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def build(self, x):

# Graph init
g = CustomGraph()
# g.m is Block
# g.m is Block
test_case.assertTrue(isinstance(g.m, flow.nn.graph.Block))
# g.m.name is "m"
test_case.assertEqual(g.m.name, "m")
Expand All @@ -83,19 +83,15 @@ def build(self, x):
test_case.assertTrue(
isinstance(g.m._buffers["dummy_buff"], flow.nn.graph.Block)
)
# conv1 is Block
test_case.assertTrue(
isinstance(g.m.layer.conv1, flow.nn.graph.Block)
)
# conv1 is Block
test_case.assertTrue(isinstance(g.m.layer.conv1, flow.nn.graph.Block))
# conv1.name is "conv1"
test_case.assertEqual(g.m.layer.conv1.name, "conv1")
# conv1.weight is Tensor, Graph.build(...) need weight to be Tensor
test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.Tensor))
# conv1._parameters["weight"] is Block
test_case.assertTrue(
isinstance(
g.m.layer.conv1._parameters["weight"], flow.nn.graph.Block
)
isinstance(g.m.layer.conv1._parameters["weight"], flow.nn.graph.Block)
)
# conv1.kernel_size is original data in original module
test_case.assertEqual(g.m.layer.conv1.kernel_size, (5, 5))
Expand Down Expand Up @@ -158,8 +154,9 @@ def build(self, x):
return x

g = CustomGraph()
g._compile()

test_case.assertEqual(g.name, g._c_nn_graph.name)
# g._compile()

# TODO(): test_add_optimizer


Expand Down
18 changes: 18 additions & 0 deletions oneflow/python/test/graph/test_multi_client_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import oneflow as flow


@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
Expand All @@ -10,9 +26,11 @@
class TestMultiClientSession(flow.unittest.TestCase):
def test_multi_client_sessioin(test_case):
import oneflow.python.framework.session_context as sc

s = sc.GetDefaultSession()
print("default session id ", s.id)
print("default session type ", type(s))


if __name__ == "__main__":
unittest.main()

0 comments on commit bbf0ace

Please sign in to comment.