Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in Multi-Client src tick output order #6221

Merged
merged 10 commits into from
Sep 11, 2021
1 change: 1 addition & 0 deletions oneflow/core/framework/nn_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ Maybe<void> NNGraph::CompileAndInitRuntime() {
<< " , compile time: " << (GetCurTime() - start) / 1000000000.0 << " seconds.\n";
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
TeePersistentLogStream::Create("job_" + name_ + "_plan")->Write(plan_);
PlanUtil::ToDotFile(plan_, "job_" + name_ + "_plan.dot");
}
// TODO(chengcheng): test collective boxing for multi-job.
PlanUtil::GenCollectiveBoxingPlan(&job_, &plan_);
Expand Down
43 changes: 43 additions & 0 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "oneflow/core/job/scope.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job_rewriter/calculation_pass.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h"
#include "oneflow/core/graph/stream_index_getter_registry_manager.h"
Expand Down Expand Up @@ -542,6 +543,48 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node
}
}

void TaskGraph::AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank() {
if (!CHECK_JUST(GlobalMultiClientEnv())) { return; }
HashMap<int64_t, TaskNode*> rank_id2src_tick;
HashMap<int64_t, TaskNode*> rank_id2dst_tick;
HashMap<int64_t, HashSet<TaskNode*>> rank_id2input_output_nodes;

ForEachNode([&](TaskNode* node) {
if (node->GetTaskType() == TaskType::kSrcSubsetTick) {
CHECK(rank_id2src_tick.emplace(node->machine_id(), node).second);
} else if (node->GetTaskType() == TaskType::kDstSubsetTick) {
CHECK(rank_id2dst_tick.emplace(node->machine_id(), node).second);
} else if (node->GetTaskType() == TaskType::kNormalForward) {
auto* forward_node = reinterpret_cast<NormalForwardCompTaskNode*>(node);
CHECK(forward_node);
if (forward_node->op()->op_conf().has_input_conf()
|| forward_node->op()->op_conf().has_output_conf()) {
CHECK(rank_id2input_output_nodes[node->machine_id()].insert(node).second);
}
}
});

auto AddCtrlEdge = [&](TaskNode* src, TaskNode* dst) {
std::string ctrl_regst_name;
src->BuildCtrlRegstDesc(dst, &ctrl_regst_name);
TaskEdge* edge = NewEdge();
Connect<TaskNode>(src, edge, dst);
src->BindEdgeWithProducedRegst(edge, ctrl_regst_name);
};

for (auto& pair : rank_id2src_tick) {
int64_t rank_id = pair.first;
TaskNode* src = pair.second;
for (TaskNode* io_task : rank_id2input_output_nodes[rank_id]) { AddCtrlEdge(src, io_task); }
}

for (auto& pair : rank_id2dst_tick) {
int64_t rank_id = pair.first;
TaskNode* dst = pair.second;
for (TaskNode* io_task : rank_id2input_output_nodes[rank_id]) { AddCtrlEdge(io_task, dst); }
}
}

void TaskGraph::RemoveEmptyRegsts() {
ForEachNode([&](TaskNode* node) { node->EraseUninitializedShapeProducedBlob(); });
ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); });
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/graph/task_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {

const char* TypeName() const override { return "TaskGraph"; }
void RemoveEmptyRegsts();
void AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank();
void MergeChainAndAddOrderingCtrlEdgeInSameChain();

void EnableInplaceMemSharing(const std::function<bool(const std::string&, const std::string&)>&
Expand Down
4 changes: 4 additions & 0 deletions oneflow/core/job/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
task_gph->TopoForEachNode(&TaskNode::Build);
task_gph->RemoveEmptyRegsts();
// NOTE(chengcheng):
// In Multi-Client, each rank has its own src_tick/dst_tick and input/output with callback,
// which need to be forced sequenced.
task_gph->AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank();
task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain();
auto IsReachable = Global<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();
if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); }
Expand Down
117 changes: 117 additions & 0 deletions python/oneflow/test/graph/test_graph_pipeline_delay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import time
import unittest
import numpy as np

import oneflow as flow
import oneflow.unittest


def _test_graph_pipeline_delay_output(test_case):
    """Regression test for Multi-Client src/dst tick output order.

    Builds a 2-stage pipeline over 2 CUDA devices and artificially delays
    rank 1 around the graph call; without the per-rank src/dst tick ctrl
    edges, the free-tensor output could be consumed out of order.
    """

    class StageLayerModule(flow.nn.Module):
        # One pipeline stage: linear -> add -> mul -> linear.
        def __init__(self):
            super().__init__()
            self.linear1 = flow.nn.Linear(10, 8, False)
            self.linear2 = flow.nn.Linear(8, 10)
            flow.nn.init.constant_(self.linear1.weight, 0.023)
            flow.nn.init.constant_(self.linear2.weight, 1.23)

        def forward(self, x):
            out0 = self.linear1(x)
            out0 = out0 + 1.0
            out0 = out0 * 2.0
            out1 = self.linear2(out0)
            return out1

    P0 = flow.placement("cuda", {0: [0]})
    P1 = flow.placement("cuda", {0: [1]})
    B = flow.sbp.broadcast

    class PipelineModule(flow.nn.Module):
        # Two stages, each placed on its own rank/device.
        def __init__(self):
            super().__init__()
            self.layer_0 = StageLayerModule()
            self.layer_1 = StageLayerModule()
            self.layer_0.to_consistent(P0, B)
            self.layer_1.to_consistent(P1, B)

        def forward(self, x):
            # stage 0
            in0 = x.to_consistent(P0, B)
            out0 = self.layer_0(in0)
            # stage 1
            in1 = out0.to_consistent(P1, B)
            out1 = self.layer_1(in1)
            return out1

    pp_m = PipelineModule()
    pp_m.train()
    of_sgd = flow.optim.SGD(pp_m.parameters(), lr=0.001)

    class PipelineGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.pp_m = pp_m
            self.pp_m.layer_0.config.stage_id = 0
            self.pp_m.layer_1.config.stage_id = 1
            self.config.set_gradient_accumulation_steps(4)
            self.add_optimizer(of_sgd)

        def build(self, x, y):
            pp_out = self.pp_m(x)
            loss = pp_out.mean()
            loss.backward()
            # NOTE: free_out is a "free" tensor output computed outside the
            # optimized module and placed on stage 1's rank; returning it
            # alongside loss exercises the output-order guarantee that the
            # src/dst tick ctrl edges enforce.
            y = x + y
            free_out = y.to_consistent(P1, B)
            return loss, free_out

    pp_g = PipelineGraph()
    rank = flow.env.get_rank()
    for i in range(3):
        x = flow.randn(16, 10)
        y = flow.randn(16, 10)
        x = x.to_consistent(P0, B)
        y = y.to_consistent(P0, B)
        # Delay rank 1 before and after the graph call to provoke any
        # ordering bug between ticks and input/output callbacks.
        if rank == 1:
            time.sleep(2)
        loss_pack_4, free_out = pp_g(x, y)
        if rank == 1:
            time.sleep(2)
        print(
            "rank: ",
            rank,
            "packed loss with 4 micro-batch = ",
            loss_pack_4.to_local(),
        )
        print(
            "rank: ",
            rank,
            "packed image with 4 micro-batch = ",
            free_out.to_local(),
        )


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n2d()
class TestGraphPipelineDelayOutput(oneflow.unittest.TestCase):
    """Pipeline output-order test; needs 1 node / 2 CUDA devices, so it is
    skipped in CPU-only CI runs."""

    def test_graph_pipeline_delay_output(test_case):
        _test_graph_pipeline_delay_output(test_case)


if __name__ == "__main__":
    unittest.main()