Validate NNGraph input/output by registering tensors (#6240)
* Validate NNGraph input/output by registering tensors

* LazyJobInstructionType: skip sending PushCb/PullCb when the corresponding NNGraph input/output is not valid

* Add test script

* Fix bug of empty static vector

* Fix clang build issue

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
chengtbf and oneflow-ci-bot committed Sep 12, 2021
1 parent 1b901ee commit 8f67c6b
Showing 8 changed files with 211 additions and 38 deletions.
18 changes: 11 additions & 7 deletions oneflow/api/python/framework/nn_graph.cpp
@@ -31,13 +31,17 @@ ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) {
py::class_<NNGraph, std::shared_ptr<NNGraph>>(m, "CNNGraph")
.def(py::init<const std::string&>())
.def_property_readonly("name", &NNGraph::job_name)
.def("register_input_op_names",
[](NNGraph& graph, const std::vector<std::string>& input_op_names) {
return graph.RegisterInputOpNames(input_op_names).GetOrThrow();
})
.def("register_output_op_names",
[](NNGraph& graph, const std::vector<std::string>& output_op_names) {
return graph.RegisterOutputOpNames(output_op_names).GetOrThrow();
.def(
"register_input_op_names_and_tensors",
[](NNGraph& graph, const std::vector<std::string>& input_op_names,
const std::vector<std::shared_ptr<one::Tensor>>& input_tensors) {
return graph.RegisterInputOpNamesAndTensors(input_op_names, input_tensors).GetOrThrow();
})
.def("register_output_op_names_and_tensors",
[](NNGraph& graph, const std::vector<std::string>& output_op_names,
const std::vector<std::shared_ptr<one::Tensor>>& output_tensors) {
return graph.RegisterOutputOpNamesAndTensors(output_op_names, output_tensors)
.GetOrThrow();
})
.def("register_variable_op_names_and_tensors",
[](NNGraph& graph, const std::vector<std::string>& variable_op_names,
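For orientation, the binding pattern above can be reduced to a standalone pybind11 sketch. The `Graph` and `Tensor` types below are hypothetical stand-ins, not OneFlow's classes; the point is only how a method taking op names together with their tensors is exposed to Python:

// Illustrative sketch only: a mock stand-in for NNGraph's pybind11 binding.
// `Graph` and `Tensor` are hypothetical types, not OneFlow's.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
#include <vector>

namespace py = pybind11;

struct Tensor {};  // placeholder for one::Tensor

struct Graph {
  std::vector<std::string> input_op_names;
  std::vector<std::shared_ptr<Tensor>> input_tensors;
};

PYBIND11_MODULE(graph_sketch, m) {
  py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
      .def(py::init<>())
      .def("register_input_op_names_and_tensors",
           [](Graph& g, const std::vector<std::string>& names,
              const std::vector<std::shared_ptr<Tensor>>& tensors) {
             // Registering names and tensors together lets the graph later
             // derive a per-input validity mask from the tensors.
             g.input_op_names = names;
             g.input_tensors = tensors;
           });
}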
54 changes: 32 additions & 22 deletions oneflow/core/eager/lazy_job_instruction_type.cpp
@@ -100,11 +100,17 @@ class RunLazyJobInstructionType final : public InstructionType {
OF_PROFILER_RANGE_PUSH("Send all buffers to BufferMgr");
const auto& job_name = job_instance->job_name();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
for (const auto& op_name : cur_nn_graph->inputs_op_names()) {
buffer_mgr->Get(GetInputBufferName(job_name, op_name))->Send(job_instance);
for (int i = 0; i < cur_nn_graph->inputs_op_names().size(); ++i) {
if (cur_nn_graph->inputs_valid().at(i)) {
const std::string& input_op_name = cur_nn_graph->inputs_op_names().at(i);
buffer_mgr->Get(GetInputBufferName(job_name, input_op_name))->Send(job_instance);
}
}
for (const auto& op_name : cur_nn_graph->outputs_op_names()) {
buffer_mgr->Get(GetOutputBufferName(job_name, op_name))->Send(job_instance);
for (int i = 0; i < cur_nn_graph->outputs_op_names().size(); ++i) {
if (cur_nn_graph->outputs_valid().at(i)) {
const std::string& output_op_name = cur_nn_graph->outputs_op_names().at(i);
buffer_mgr->Get(GetOutputBufferName(job_name, output_op_name))->Send(job_instance);
}
}
buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Send(job_instance);
buffer_mgr->Get(GetSourceTickBufferName(job_name))->Send(job_instance);
@@ -138,28 +144,32 @@ class RunLazyJobInstructionType final : public InstructionType {
HashMap<std::string, std::function<void(int64_t)>> push_cbs;
CHECK_EQ(nn_graph->inputs_op_names().size(), phy_instr_operand->inputs()->size());
for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) {
const auto* blob = &phy_instr_operand->inputs()->at(i)->blob();
if (!blob) { continue; }
const auto& op_name = nn_graph->inputs_op_names().at(i);
const auto& PushCb = [blob](int64_t of_blob_ptr) {
OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
of_blob->mut_blob()->CopyHeaderFrom(of_blob->mut_device_ctx(), blob);
of_blob->mut_blob()->CopyDataContentFrom(of_blob->mut_device_ctx(), blob);
};
CHECK(push_cbs.emplace(op_name, PushCb).second);
if (nn_graph->inputs_valid().at(i)) {
const auto* blob = &phy_instr_operand->inputs()->at(i)->blob();
CHECK(blob != nullptr);
const auto& op_name = nn_graph->inputs_op_names().at(i);
const auto& PushCb = [blob](int64_t of_blob_ptr) {
OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
of_blob->mut_blob()->CopyHeaderFrom(of_blob->mut_device_ctx(), blob);
of_blob->mut_blob()->CopyDataContentFrom(of_blob->mut_device_ctx(), blob);
};
CHECK(push_cbs.emplace(op_name, PushCb).second);
}
}
HashMap<std::string, std::function<void(int64_t)>> pull_cbs;
CHECK_EQ(nn_graph->outputs_op_names().size(), phy_instr_operand->outputs()->size());
for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) {
auto* mut_blob = phy_instr_operand->outputs()->at(i)->mut_blob();
if (!mut_blob) { continue; }
const auto& op_name = nn_graph->outputs_op_names().at(i);
const auto& PullCb = [mut_blob](int64_t of_blob_ptr) {
OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
mut_blob->CopyHeaderFrom(of_blob->mut_device_ctx(), &of_blob->blob());
mut_blob->CopyDataContentFrom(of_blob->mut_device_ctx(), &of_blob->blob());
};
CHECK(pull_cbs.emplace(op_name, PullCb).second);
if (nn_graph->outputs_valid().at(i)) {
auto* mut_blob = phy_instr_operand->outputs()->at(i)->mut_blob();
CHECK(mut_blob != nullptr);
const auto& op_name = nn_graph->outputs_op_names().at(i);
const auto& PullCb = [mut_blob](int64_t of_blob_ptr) {
OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
mut_blob->CopyHeaderFrom(of_blob->mut_device_ctx(), &of_blob->blob());
mut_blob->CopyDataContentFrom(of_blob->mut_device_ctx(), &of_blob->blob());
};
CHECK(pull_cbs.emplace(op_name, PullCb).second);
}
}
const auto& FinishCb = [this, instruction]() {
auto* device_ctx = GetLazyJobDeviceCtx(instruction);
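The gating above follows one pattern: iterate the name and validity vectors in lockstep and register a callback only when the slot is valid on this rank. A self-contained sketch of that pattern, using standard containers in place of OneFlow's HashMap and a print in place of the blob copies (all names here are illustrative):

#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Sketch: build per-op callbacks, skipping ops whose tensor holds no data on
// the current rank, mirroring the PushCb/PullCb gating above.
using Callback = std::function<void(int64_t)>;

std::unordered_map<std::string, Callback> MakeCallbacks(
    const std::vector<std::string>& op_names, const std::vector<bool>& valid) {
  assert(op_names.size() == valid.size());
  std::unordered_map<std::string, Callback> cbs;
  for (size_t i = 0; i < op_names.size(); ++i) {
    if (!valid[i]) { continue; }  // skip ops that hold no data on this rank
    const std::string& name = op_names[i];
    cbs.emplace(name, [name](int64_t blob_ptr) {
      std::cout << "copy blob for " << name << " at " << blob_ptr << "\n";
    });
  }
  return cbs;
}

int main() {
  // "in1" is invalid on this rank, so only "in0" gets a callback.
  auto cbs = MakeCallbacks({"in0", "in1"}, {true, false});
  for (auto& kv : cbs) { kv.second(0); }
  return 0;
}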
10 changes: 10 additions & 0 deletions oneflow/core/eager/lazy_job_instruction_type_test.cpp
@@ -57,6 +57,16 @@ class NoArgNoRetMockNNGraph : public NNGraphIf {
return empty;
}

const std::vector<bool>& inputs_valid() const override {
static std::vector<bool> empty;
return empty;
}

const std::vector<bool>& outputs_valid() const override {
static std::vector<bool> empty;
return empty;
}

private:
const std::string job_name_;
};
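One commit bullet is "fix bug of empty static vector": a mock that returns a const reference needs a referent with stable lifetime, which a function-local static provides. A minimal standalone illustration of the idiom (not OneFlow code):

#include <vector>

// Safe: the function-local static is constructed once on first call and
// lives until program exit, so the returned reference never dangles.
const std::vector<bool>& EmptyValidVector() {
  static std::vector<bool> empty;
  return empty;
}

// By contrast, returning a braced temporary from a function with a reference
// return type would bind the reference to a destroyed temporary.

int main() { return static_cast<int>(EmptyValidVector().size()); }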
47 changes: 45 additions & 2 deletions oneflow/core/framework/nn_graph.cpp
@@ -36,6 +36,23 @@ limitations under the License.

namespace oneflow {

namespace {

Maybe<bool> GetTensorValidInCurRank(const std::shared_ptr<one::Tensor>& tensor) {
if (tensor->is_consistent()) {
const auto& parallel_id = JUST(GetParallelId4CurrentProcessCtx(JUST(tensor->parallel_desc())));
if (parallel_id->has_value()) {
return true;
} else {
return false;
}
} else {
return true;
}
}

} // namespace

NNGraph::~NNGraph() {
VLOG(2) << "graph destructor Try to close c nn graph name " << name_ << "." << std::endl;
CHECK_JUST(Close());
@@ -57,15 +74,41 @@ const std::vector<std::string>& NNGraph::inputs_op_names() const { return input_

const std::vector<std::string>& NNGraph::outputs_op_names() const { return output_op_names_; }

const std::vector<bool>& NNGraph::inputs_valid() const { return input_tensors_valid_; }

const std::vector<bool>& NNGraph::outputs_valid() const { return output_tensors_valid_; }

int64_t NNGraph::variable_op_size() const { return variable_op_name2eager_blob_.size(); }

Maybe<void> NNGraph::RegisterInputOpNames(const std::vector<std::string>& input_op_names) {
Maybe<void> NNGraph::RegisterInputOpNamesAndTensors(
const std::vector<std::string>& input_op_names,
const std::vector<std::shared_ptr<one::Tensor>>& input_tensors) {
CHECK_EQ_OR_RETURN(input_op_names.size(), input_tensors.size());
CHECK_OR_RETURN(input_op_names_.empty())
<< " The input tensors of nn.Graph " << name_ << " are register repeatedly.";
CHECK_OR_RETURN(input_tensors_valid_.empty());
input_op_names_.assign(input_op_names.begin(), input_op_names.end());
input_tensors_valid_.reserve(input_tensors.size());
for (const auto& input_tensor : input_tensors) {
input_tensors_valid_.push_back(JUST(GetTensorValidInCurRank(input_tensor)));
}
CHECK_EQ_OR_RETURN(input_tensors_valid_.size(), input_tensors.size());
return Maybe<void>::Ok();
}

Maybe<void> NNGraph::RegisterOutputOpNames(const std::vector<std::string>& output_op_names) {
Maybe<void> NNGraph::RegisterOutputOpNamesAndTensors(
const std::vector<std::string>& output_op_names,
const std::vector<std::shared_ptr<one::Tensor>>& output_tensors) {
CHECK_EQ_OR_RETURN(output_op_names.size(), output_tensors.size());
CHECK_OR_RETURN(output_op_names_.empty())
<< " The output tensors of nn.Graph " << name_ << " are register repeatedly.";
CHECK_OR_RETURN(output_tensors_valid_.empty());
output_op_names_.assign(output_op_names.begin(), output_op_names.end());
output_tensors_valid_.reserve(output_tensors.size());
for (const auto& output_tensor : output_tensors) {
output_tensors_valid_.push_back(JUST(GetTensorValidInCurRank(output_tensor)));
}
CHECK_EQ_OR_RETURN(output_tensors_valid_.size(), output_tensors.size());
return Maybe<void>::Ok();
}

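GetTensorValidInCurRank above boils down to one predicate: a local tensor is always valid, while a consistent (global) tensor is valid only on ranks that appear in its placement, i.e. ranks that have a parallel id. A condensed sketch with mocked types (hypothetical names, no Maybe/JUST error handling):

#include <optional>

// Mock stand-in for one::Tensor plus the placement lookup, for illustration.
struct MockTensor {
  bool is_consistent;              // consistent (global) vs. purely local
  std::optional<int> parallel_id;  // set iff this rank is in the placement
};

// Mirrors GetTensorValidInCurRank without the Maybe/JUST plumbing.
bool ValidInCurRank(const MockTensor& t) {
  if (!t.is_consistent) { return true; }  // local tensors always hold data
  return t.parallel_id.has_value();       // global: only on participating ranks
}

int main() {
  MockTensor local{false, std::nullopt};
  MockTensor global_off_rank{true, std::nullopt};
  return (ValidInCurRank(local) && !ValidInCurRank(global_off_rank)) ? 0 : 1;
}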
18 changes: 13 additions & 5 deletions oneflow/core/framework/nn_graph.h
@@ -33,13 +33,19 @@ class NNGraph final : public NNGraphIf {
: name_(name), runtime_inited_(false), is_closed_(false) {}
~NNGraph();

const std::string& job_name() const { return name_; }
const std::vector<std::string>& inputs_op_names() const;
const std::vector<std::string>& outputs_op_names() const;
const std::string& job_name() const override { return name_; }
const std::vector<std::string>& inputs_op_names() const override;
const std::vector<std::string>& outputs_op_names() const override;
const std::vector<bool>& inputs_valid() const override;
const std::vector<bool>& outputs_valid() const override;
int64_t variable_op_size() const;

Maybe<void> RegisterInputOpNames(const std::vector<std::string>& input_op_names);
Maybe<void> RegisterOutputOpNames(const std::vector<std::string>& output_op_names);
Maybe<void> RegisterInputOpNamesAndTensors(
const std::vector<std::string>& input_op_names,
const std::vector<std::shared_ptr<one::Tensor>>& input_tensors);
Maybe<void> RegisterOutputOpNamesAndTensors(
const std::vector<std::string>& output_op_names,
const std::vector<std::shared_ptr<one::Tensor>>& output_tensors);
Maybe<void> RegisterVariableOpNamesAndTensors(
const std::vector<std::string>& variable_op_names,
const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors);
@@ -56,6 +62,8 @@ class NNGraph final : public NNGraphIf {
std::string name_;
std::vector<std::string> input_op_names_;
std::vector<std::string> output_op_names_;
std::vector<bool> input_tensors_valid_;
std::vector<bool> output_tensors_valid_;
HashMap<std::string, Blob*> variable_op_name2eager_blob_;
HashSet<std::string> variable_op_names_;
Job job_;
2 changes: 2 additions & 0 deletions oneflow/core/framework/nn_graph_if.h
@@ -28,6 +28,8 @@ class NNGraphIf {
virtual const std::string& job_name() const = 0;
virtual const std::vector<std::string>& inputs_op_names() const = 0;
virtual const std::vector<std::string>& outputs_op_names() const = 0;
virtual const std::vector<bool>& inputs_valid() const = 0;
virtual const std::vector<bool>& outputs_valid() const = 0;

protected:
NNGraphIf() = default;
8 changes: 6 additions & 2 deletions python/oneflow/nn/graph/graph.py
@@ -494,8 +494,12 @@ def _build_graph(self, *args):
self._rebuild_outputs(out2name)

# Register input/output/variable/buffer to _c_nn_graph
self._c_nn_graph.register_input_op_names(arg_op_names)
self._c_nn_graph.register_output_op_names(output_op_names)
self._c_nn_graph.register_input_op_names_and_tensors(
arg_op_names, convert_to_tensor_tuple(self._flatten_io("input", *args))
)
self._c_nn_graph.register_output_op_names_and_tensors(
output_op_names, self._outputs_tensor_tuple
)
self._c_nn_graph.register_variable_op_names_and_tensors(
state_op_names, self._states_tensor_tuple
)
92 changes: 92 additions & 0 deletions python/oneflow/test/graph/test_graph_buffer_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import time
import unittest
import numpy as np

import oneflow as flow
import oneflow.unittest


def _test_graph_buffer_limit(test_case):
class StageLayerModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = flow.nn.Linear(10, 8, False)
self.linear2 = flow.nn.Linear(8, 10, False)
flow.nn.init.constant_(self.linear1.weight, 0.023)
flow.nn.init.constant_(self.linear2.weight, 1.23)

def forward(self, x):
out0 = self.linear1(x)
out0 = out0 + 1.0
out0 = out0 * 2.0
out1 = self.linear2(out0)
return out1

P0 = flow.placement("cuda", {0: [0]})
P1 = flow.placement("cuda", {0: [1]})
PT = flow.placement("cuda", {0: [0, 1]})
B = flow.sbp.broadcast

class PipelineModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.layer_0 = StageLayerModule()
self.layer_1 = StageLayerModule()
self.layer_0.to_consistent(P0, B)
self.layer_1.to_consistent(P1, B)

def forward(self, x):
# stage 0
in0 = x.to_consistent(P0, B)
out0 = self.layer_0(in0)
# stage 1
in1 = out0.to_consistent(P1, B)
out1 = self.layer_1(in1)
return out1

pp_m = PipelineModule()
pp_m.eval()

class PipelineGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.pp_m = pp_m

def build(self, x):
return self.pp_m(x)

pp_g = PipelineGraph()

for i in range(500):
x = flow.randn(16, 10)
x = x.to_consistent(P0, B)
out = pp_g(x)
# print(out.to_local().mean())


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n2d()
class TestGraphPipelineBufferLimit(oneflow.unittest.TestCase):
def test_graph_buffer_limit(test_case):
_test_graph_buffer_limit(test_case)


if __name__ == "__main__":
unittest.main()
