Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream wait #8571

Merged
merged 134 commits into from Aug 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
134 commits
Select commit Hold shift + click to select a range
6e8e9c9
ThreadLocalGuard
lixinqi May 12, 2022
08e9178
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi May 14, 2022
f59d17d
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi May 18, 2022
3eb809a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi May 18, 2022
55c163c
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi May 20, 2022
8aa2e8f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 1, 2022
7612597
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 6, 2022
de5f971
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 8, 2022
8e86949
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 9, 2022
2ca0707
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 16, 2022
8537b7e
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 16, 2022
55c5160
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 17, 2022
e643eb1
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 18, 2022
eccdfe6
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 20, 2022
043accc
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 20, 2022
97b0eef
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 20, 2022
1591853
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 23, 2022
ba6f2d7
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 23, 2022
5e1a86a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 23, 2022
1ee004c
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 24, 2022
e853c71
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 24, 2022
c5afe82
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 24, 2022
14226d6
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 26, 2022
754d6a7
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 27, 2022
acb7c98
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 28, 2022
5916848
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 28, 2022
913f6f5
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 1, 2022
fa3867e
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 2, 2022
61bee99
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 3, 2022
7eb2d72
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 4, 2022
9554f41
stream_wait
lixinqi Jul 5, 2022
7557ace
Merge branch 'master' into stream_wait
lixinqi Jul 5, 2022
797974a
Instruction::Prescheduleable
lixinqi Jul 5, 2022
e978387
env var ONEFLOW_VM_ENABLE_STREAM_WAIT
lixinqi Jul 6, 2022
5862a95
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 6, 2022
99fa4ad
fix static check error
clackhan Jul 6, 2022
60175fa
Merge branch 'master' into stream_wait
clackhan Jul 6, 2022
29ad00c
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 6, 2022
7297192
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 6, 2022
0a54078
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 8, 2022
cec8a1d
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 11, 2022
b50e236
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 11, 2022
4962c17
Merge branch 'master' into stream_wait
lixinqi Jul 13, 2022
b4f9b31
fix conflicts
lixinqi Jul 13, 2022
a6c5d07
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 15, 2022
b6b73a2
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 16, 2022
43197bb
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 16, 2022
4453c58
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 17, 2022
582e11f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 18, 2022
b46e762
Merge branch 'master' into stream_wait
lixinqi Jul 18, 2022
d5f032f
enable StreamWait
lixinqi Jul 18, 2022
1b3b6c7
Merge branch 'master' into stream_wait
lixinqi Jul 18, 2022
4001637
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 19, 2022
1dbac7c
merge master
lixinqi Jul 19, 2022
1c6bd69
Merge branch 'stream_wait' of github.com:Oneflow-Inc/oneflow into str…
lixinqi Jul 19, 2022
7fdc675
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 19, 2022
df91ad2
Merge branch 'master' into stream_wait
lixinqi Jul 19, 2022
c2a47ac
do not use an object after std::move
lixinqi Jul 19, 2022
1555f70
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 20, 2022
58df59b
Merge branch 'master' into stream_wait
lixinqi Jul 20, 2022
cea5d58
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 22, 2022
29fd659
Merge branch 'master' into stream_wait
lixinqi Jul 22, 2022
ccbddef
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 22, 2022
2562e16
Merge branch 'master' into stream_wait
lixinqi Jul 22, 2022
cc31bd2
refactor Instruction::Done
lixinqi Jul 24, 2022
deb692b
Fix typo in oneflow/core/framework/instructions_builder.cpp
daquexian Jul 25, 2022
45208ed
support stream_wait in AccesBlobByCallback
lixinqi Jul 25, 2022
c914f2f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 25, 2022
98e05ef
put flow._C.stream_touch(buffers) into post_forward_hook
lixinqi Jul 25, 2022
6b7885f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 25, 2022
09489b2
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 26, 2022
ad38490
merge master
lixinqi Jul 26, 2022
ee14204
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 27, 2022
4720413
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 28, 2022
c939c3d
merge master
lixinqi Jul 28, 2022
97b697d
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 28, 2022
2cccecb
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 28, 2022
755199c
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 29, 2022
31a5022
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 29, 2022
49ce984
Merge branch 'stream_wait' of github.com:Oneflow-Inc/oneflow into str…
lixinqi Jul 29, 2022
f441aa7
merge master
lixinqi Jul 29, 2022
c0ef53b
no event query for StreamWait
lixinqi Jul 29, 2022
d690538
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 29, 2022
7ed2811
Merge branch 'master' into stream_wait
lixinqi Jul 29, 2022
a3a6056
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 29, 2022
2cda60c
Update oneflow/core/framework/instructions_builder.cpp
lixinqi Jul 30, 2022
364911e
Merge branch 'master' into stream_wait
lixinqi Jul 30, 2022
eb15c55
auto format by CI
oneflow-ci-bot Jul 30, 2022
b4fd1bd
Merge branch 'master' into stream_wait
lixinqi Jul 30, 2022
dcaacc6
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 30, 2022
e3161dd
Merge branch 'master' into stream_wait
lixinqi Jul 30, 2022
b2aacf5
Merge branch 'master' into stream_wait
mergify[bot] Jul 30, 2022
9c1d602
merge master
lixinqi Jul 31, 2022
b1d640c
Merge branch 'stream_wait' of github.com:Oneflow-Inc/oneflow into str…
lixinqi Jul 31, 2022
40e4312
Merge branch 'master' into stream_wait
lixinqi Jul 31, 2022
541e570
Merge branch 'master' into stream_wait
mergify[bot] Jul 31, 2022
f6efbb7
include cuda_runtime_api.h
lixinqi Jul 31, 2022
a350960
Merge branch 'stream_wait' of github.com:Oneflow-Inc/oneflow into str…
lixinqi Jul 31, 2022
0d37dfa
Merge branch 'master' into stream_wait
mergify[bot] Jul 31, 2022
4d0194d
replace cuda_stream_api.h with cuda_stream.h
lixinqi Jul 31, 2022
b3b3a1e
Merge branch 'stream_wait' of github.com:Oneflow-Inc/oneflow into str…
lixinqi Jul 31, 2022
20d591b
using default flags for cudaStreamWaitEvent
lixinqi Jul 31, 2022
700c39a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 31, 2022
314e037
Merge branch 'master' into stream_wait
lixinqi Jul 31, 2022
1d73197
passing zero to 3rd argument of cudaStreamWaitEvent
lixinqi Aug 1, 2022
b6ed0fe
Merge branch 'master' into stream_wait
lixinqi Aug 1, 2022
1c6f65f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Aug 1, 2022
9247cba
Merge branch 'master' into stream_wait
mergify[bot] Aug 1, 2022
b1ae914
Merge branch 'master' into stream_wait
mergify[bot] Aug 1, 2022
ef3da78
Merge branch 'master' into stream_wait
mergify[bot] Aug 1, 2022
1d3c62f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Aug 2, 2022
612b397
merge master
lixinqi Aug 2, 2022
11c2021
fix complier complaints
lixinqi Aug 2, 2022
50bc3ed
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Aug 2, 2022
7af07cf
merge master
lixinqi Aug 2, 2022
486c43a
Merge branch 'master' into stream_wait
lixinqi Aug 2, 2022
3aee226
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Aug 2, 2022
1cf3aba
Merge branch 'master' into stream_wait
lixinqi Aug 2, 2022
f44ad7e
Merge branch 'master' into stream_wait
mergify[bot] Aug 2, 2022
d540ee0
Merge branch 'master' into stream_wait
lixinqi Aug 2, 2022
f828447
Merge branch 'master' into stream_wait
ouyangyu Aug 3, 2022
4de4d3b
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Aug 3, 2022
6f4156e
Merge branch 'master' into stream_wait
lixinqi Aug 3, 2022
d5548bc
fix bug in StreamWaitInstructionPolicy::InitInstructionStatus
lixinqi Aug 3, 2022
ff27216
Merge branch 'stream_wait' of github.com:Oneflow-Inc/oneflow into str…
lixinqi Aug 3, 2022
3b4dc75
Merge branch 'master' into stream_wait
mergify[bot] Aug 3, 2022
8691ced
Merge branch 'master' into stream_wait
mergify[bot] Aug 3, 2022
8164635
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Aug 3, 2022
627f363
merge master
lixinqi Aug 3, 2022
b142ea1
Merge branch 'master' into stream_wait
mergify[bot] Aug 3, 2022
a53003e
Merge branch 'master' into stream_wait
lixinqi Aug 4, 2022
3effc13
Merge branch 'master' into stream_wait
ouyangyu Aug 4, 2022
f8ad010
Merge branch 'master' into stream_wait
ouyangyu Aug 4, 2022
39527a5
Merge branch 'master' into stream_wait
mergify[bot] Aug 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions oneflow/core/common/env_var/vm.h
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
namespace oneflow {

DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_COMPUTE_ON_WORKER_THREAD, true);
DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_ENABLE_STREAM_WAIT, true);
DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_VM_PENDING_HANDLE_WINDOW_SIZE, 10)
DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_ENABLE_SCHEDULE_YIELD, true)

Expand Down
57 changes: 50 additions & 7 deletions oneflow/core/framework/instructions_builder.cpp
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/blocking_counter.h"
#include "oneflow/core/common/singleton_ptr.h"
#include "oneflow/core/common/env_var/vm.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/vm/access_blob_arg_cb_instruction_policy.h"
#include "oneflow/core/vm/ep_record_event_instruction_policy.h"
Expand All @@ -37,6 +38,7 @@ limitations under the License.
#include "oneflow/core/vm/lazy_job_instruction_policy.h"
#include "oneflow/core/vm/global_sync_instruction_policy.h"
#include "oneflow/core/vm/op_call_instruction_policy.h"
#include "oneflow/core/vm/stream_wait_instruction_policy.h"
#include "oneflow/core/vm/touch_tensors_instruction_policy.h"
#include "oneflow/core/vm/virtual_machine.h"
#include "oneflow/core/vm/vm_util.h"
Expand All @@ -47,6 +49,8 @@ limitations under the License.
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/framework/stream_need_soft_sync.h"
#include "oneflow/core/framework/stream_is_comm_net_stream.h"
#include "oneflow/core/framework/stream_support_stream_wait.h"
#include "oneflow/core/framework/stream_on_independent_thread.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/platform/include/pthread_fork.h"
Expand Down Expand Up @@ -379,8 +383,7 @@ Maybe<void> InstructionsBuilder::ReleaseTensor(
return Maybe<void>::Ok();
}
if (last_used_stream != producer_stream) {
JUST(SoftSyncStream({JUST(eager_blob_object->compute_local_dep_object())}, "mut",
last_used_stream));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原来用SoftSyncStream 为何改成 RecordEvent

JUST(RecordEvent({JUST(eager_blob_object->compute_local_dep_object())}, last_used_stream));
}
Optional<Symbol<Stream>> stream{};
if (*one::CurrentDevVmDepObjectConsumeMode() == one::DevVmDepObjectConsumeMode::NONE) {
Expand Down Expand Up @@ -486,23 +489,62 @@ Maybe<void> InstructionsBuilder::SoftSyncStream(const vm::EagerBlobObjectList& e
JUST(ForEachEagerBlobObjectsNeedingSoftSync(
eager_blob_objects, stream,
[&](Symbol<Stream> last_used_stream, auto&& dep_objects) -> Maybe<void> {
return SoftSyncStream(std::move(dep_objects), "mut", last_used_stream);
return SoftSyncStreamBetween(std::move(dep_objects), last_used_stream, stream);
}));
for (const auto& eager_blob_object : eager_blob_objects) {
eager_blob_object->set_last_used_stream(stream);
}
return Maybe<void>::Ok();
}

Maybe<void> InstructionsBuilder::SoftSyncStream(
namespace {

bool SupportingStreamWait(Symbol<Stream> from_stream, Symbol<Stream> to_stream) {
if (unlikely(!ThreadLocalEnvBool<ONEFLOW_VM_ENABLE_STREAM_WAIT>())) { return false; }
DeviceType from_device_type = from_stream->device()->enum_type();
DeviceType to_device_type = from_stream->device()->enum_type();
return from_stream->device() == to_stream->device() && from_device_type == DeviceType::kCUDA
&& StreamSupportStreamWait::Visit(from_stream->stream_type(), from_device_type)
&& StreamSupportStreamWait::Visit(to_stream->stream_type(), to_device_type)
&& !StreamOnIndependentThread::Visit(from_stream->stream_type())
&& !StreamOnIndependentThread::Visit(to_stream->stream_type());
}

} // namespace

Maybe<void> InstructionsBuilder::SoftSyncStreamBetween(
small_vector<intrusive::shared_ptr<LocalDepObject>, kOpArgsReservedSize>&& dependences,
Symbol<Stream> from_stream, Symbol<Stream> to_stream) {
CHECK(from_stream != to_stream) << "synchronization is unnecessary";
if (SupportingStreamWait(from_stream, to_stream)) {
JUST(StreamWait(std::move(dependences), from_stream, to_stream));
} else {
JUST(RecordEvent(std::move(dependences), from_stream));
}
return Maybe<void>::Ok();
}

Maybe<void> InstructionsBuilder::StreamWait(
small_vector<intrusive::shared_ptr<LocalDepObject>, kOpArgsReservedSize>&& dependences,
Symbol<Stream> from_stream, Symbol<Stream> to_stream) {
auto* from_vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(from_stream));
auto* to_vm_stream = JUST(Singleton<VirtualMachine>::Get()->GetVmStream(to_stream));
auto instruction = intrusive::make_shared<vm::Instruction>(
to_vm_stream, std::make_unique<vm::StreamWaitInstructionPolicy>(
std::move(dependences), from_vm_stream, to_vm_stream));
instruction_list_->EmplaceBack(std::move(instruction));
return Maybe<void>::Ok();
}

Maybe<void> InstructionsBuilder::RecordEvent(
small_vector<intrusive::shared_ptr<LocalDepObject>, kOpArgsReservedSize>&&
compute_local_dep_objects,
const std::string& modifier, Symbol<Stream> last_used_stream) {
Symbol<Stream> last_used_stream) {
DeviceType device_type = last_used_stream->device()->enum_type();
if (!NeedSoftSync::Visit(last_used_stream->stream_type(), device_type)) {
return Maybe<void>::Ok();
}
OF_PROFILER_RANGE_GUARD("SoftStream");
std::string modifier = "mut";
StreamType stream_type = last_used_stream->stream_type();
auto instruction = intrusive::make_shared<vm::Instruction>(
JUST(Singleton<VirtualMachine>::Get()->GetVmStream(last_used_stream)),
Expand Down Expand Up @@ -588,7 +630,6 @@ Maybe<void> InstructionsBuilder::AccessBlobByCallback(
const std::string& modifier) {
const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object = JUST(tensor->eager_blob_object());
Symbol<Device> device = JUST(GetDevice(tensor));
Symbol<Stream> stream = JUST(GetDefaultStreamByDevice(device));
// Do not use producer_stream or last_used_stream.
// Bug case when using producer_stream or last_used_stream:
//
Expand All @@ -599,6 +640,8 @@ Maybe<void> InstructionsBuilder::AccessBlobByCallback(
// ```
// `ndarray` may not be ones because instruction AccessBlobByCallback is prescheduled before
// oneflow.ones actually finished.
Symbol<Stream> stream = JUST(GetDefaultStreamByDevice(device));
JUST(SoftSyncStream({eager_blob_object}, stream));
auto instruction = intrusive::make_shared<vm::Instruction>(
// Never replace `stream` with producer_stream or last_used_stream.
JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream)),
Expand Down
15 changes: 11 additions & 4 deletions oneflow/core/framework/instructions_builder.h
Expand Up @@ -140,11 +140,18 @@ class InstructionsBuilder : public std::enable_shared_from_this<InstructionsBuil
private:
Maybe<void> SoftSyncStream(const vm::EagerBlobObjectList& eager_blob_objects,
Symbol<Stream> stream);
Maybe<void> SoftSyncStream(small_vector<intrusive::shared_ptr<LocalDepObject>,
kOpArgsReservedSize>&& compute_local_dep_objects,
const std::string& modifier, Symbol<Stream> stream);
Maybe<void> SoftSyncStreamBetween(
small_vector<intrusive::shared_ptr<LocalDepObject>, kOpArgsReservedSize>&& dependences,
Symbol<Stream> from_stream, Symbol<Stream> to_stream);

Maybe<void> StreamWait(
small_vector<intrusive::shared_ptr<LocalDepObject>, kOpArgsReservedSize>&& dependences,
Symbol<Stream> from_stream, Symbol<Stream> to_stream);

Maybe<void> RecordEvent(small_vector<intrusive::shared_ptr<LocalDepObject>, kOpArgsReservedSize>&&
compute_local_dep_objects,
Symbol<Stream> stream);

private:
vm::InstructionList* instruction_list_;
};

Expand Down
45 changes: 45 additions & 0 deletions oneflow/core/framework/stream_support_stream_wait.h
@@ -0,0 +1,45 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_STREAM_SUPPORT_STREAM_WAIT_H_
#define ONEFLOW_CORE_FRAMEWORK_STREAM_SUPPORT_STREAM_WAIT_H_

#include <glog/logging.h>
#include "oneflow/core/common/stream_type.h"

namespace oneflow {

struct StreamSupportStreamWait : public StreamTypeVisitor<StreamSupportStreamWait> {
static bool VisitCompute(DeviceType device_type) { return Supported(device_type); }
static bool VisitHost2Device(DeviceType device_type) { return false; }
static bool VisitDevice2Host(DeviceType device_type) { return false; }
static bool VisitAsyncedDevice2Host(DeviceType device_type) {
return VisitDevice2Host(device_type);
}
static bool VisitSyncedLaunchedCommNet(DeviceType device_type) { return Supported(device_type); }
static bool VisitAsyncedLaunchedCommNet(DeviceType device_type) { return Supported(device_type); }
static bool VisitBarrier(DeviceType device_type) { return false; }
static bool VisitCriticalSection(DeviceType device_type) { return false; }
static bool VisitLazyJobLauncher(DeviceType device_type) { return false; }
static bool VisitPinnedCompute(DeviceType device_type) { return VisitCompute(device_type); }
static bool VisitTmpCompute(DeviceType device_type) { return VisitCompute(device_type); }

private:
static bool Supported(DeviceType device_type) { return device_type == kCUDA; }
};

} // namespace oneflow

#endif // ONEFLOW_CORE_FRAMEWORK_STREAM_SUPPORT_STREAM_WAIT_H_
3 changes: 2 additions & 1 deletion oneflow/core/vm/instruction.cpp
Expand Up @@ -49,7 +49,8 @@ void Instruction::DeleteStatusAndClearEdges() {
}

bool Instruction::Done() const {
return stream_policy().QueryInstructionStatusDone(stream(), status_buffer());
return stream_policy().QueryInstructionStatusDone(stream(), status_buffer())
&& in_edges().empty();
}

StreamPolicy* Instruction::mut_stream_policy() { return mut_stream()->mut_stream_policy(); }
Expand Down
6 changes: 6 additions & 0 deletions oneflow/core/vm/instruction_policy.h
Expand Up @@ -29,11 +29,17 @@ namespace oneflow {
namespace vm {

class EagerBlobObject;
class Stream;

class InstructionPolicy {
public:
virtual ~InstructionPolicy() = default;

// Same stream.
virtual bool Prescheduleable(const vm::Stream* src, const vm::Stream* dst) const {
return src == dst;
}

virtual const DependenceVector& input_dependences() const = 0;
virtual const DependenceVector& output_dependences() const = 0;
virtual Dependence* stream_sequential_dependence() const { return stream_sequential_dependence_; }
Expand Down
93 changes: 93 additions & 0 deletions oneflow/core/vm/stream_wait_instruction_policy.cpp
@@ -0,0 +1,93 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/vm/stream_wait_instruction_policy.h"
#include "oneflow/core/vm/ep_event.h"
#include "oneflow/core/vm/instruction.h"
#include "oneflow/core/vm/stream.h"
#include "oneflow/core/ep/cuda/cuda_event.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include "oneflow/core/ep/cuda/cuda_device.h"
#include "oneflow/core/vm/ep_stream_policy_base.h"
#include "oneflow/core/vm/ep_optional_event_record_status_querier.h"

namespace oneflow {
namespace vm {

StreamWaitInstructionPolicy::StreamWaitInstructionPolicy(
small_vector<intrusive::shared_ptr<LocalDepObject>, kOpArgsReservedSize>&& dependences,
vm::Stream* from_vm_stream, vm::Stream* to_vm_stream)
: dependences_(std::move(dependences)),
input_dependences_(),
output_dependences_(),
from_vm_stream_(from_vm_stream) {
for (const auto& dep : dependences_) { output_dependences_.push_back(dep.get()); }
stream_sequential_dependence_ = to_vm_stream->schedule_local_dep_object().get();
}

bool StreamWaitInstructionPolicy::Prescheduleable(const Stream* src, const Stream* dst) const {
return &src->thread_ctx() == &dst->thread_ctx();
}

void StreamWaitInstructionPolicy::InitInstructionStatus(Instruction* instruction) {
auto* stream = mut_from_vm_stream();
auto* ep_stream_policy_base =
CHECK_NOTNULL(dynamic_cast<EpStreamPolicyBase*>(instruction->mut_stream_policy()));
ep_stream_policy_base->InitInstructionStatus(*stream, instruction->mut_status_buffer());
auto* ep_event_provider = ep_stream_policy_base->ep_event_provider();
const auto& ep_event = CHECK_NOTNULL(ep_event_provider)->GetReusedEpEvent();
mut_ep_event() = ep_event;
}

void StreamWaitInstructionPolicy::DeleteInstructionStatus(Instruction* instruction) {
auto* stream = mut_from_vm_stream();
instruction->stream_policy().DeleteInstructionStatus(*stream, instruction->mut_status_buffer());
mut_ep_event().reset();
}

void StreamWaitInstructionPolicy::Compute(vm::Instruction* instruction) {
const auto& ep_event = mut_ep_event();
{
// Record event.
auto* from_naive_stream_policy =
dynamic_cast<EpStreamPolicyBase*>(mut_from_vm_stream()->mut_stream_policy());
CHECK_NOTNULL(from_naive_stream_policy);
auto* from_stream = from_naive_stream_policy->stream();
from_stream->RecordEvent(ep_event->mut_event());
}
{
// Wait event.
auto* to_ep_stream_policy_base =
dynamic_cast<EpStreamPolicyBase*>(instruction->mut_stream()->mut_stream_policy());
CHECK_NOTNULL(to_ep_stream_policy_base);
auto* to_ep_stream = to_ep_stream_policy_base->stream();
CHECK_EQ(ep_event->mut_device(), to_ep_stream->device())
<< "only support waiting events from same device";
ep_event->mut_device()->SetAsActiveDevice();
#ifdef WITH_CUDA

auto* ep_cuda_event = CHECK_NOTNULL(dynamic_cast<ep::CudaEvent*>(ep_event->mut_event()));
auto* ep_cuda_stream = CHECK_NOTNULL(dynamic_cast<ep::CudaStream*>(to_ep_stream));

OF_CUDA_CHECK(
cudaStreamWaitEvent(ep_cuda_stream->cuda_stream(), ep_cuda_event->cuda_event(), 0));
#else
UNIMPLEMENTED();
#endif // WITH_CUDA
}
}

} // namespace vm
} // namespace oneflow
66 changes: 66 additions & 0 deletions oneflow/core/vm/stream_wait_instruction_policy.h
@@ -0,0 +1,66 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef ONEFLOW_CORE_VM_STREAM_WAIT_INSTRUCTION_POLICY_H_
#define ONEFLOW_CORE_VM_STREAM_WAIT_INSTRUCTION_POLICY_H_

#include <functional>
#include "oneflow/core/eager/local_dep_object.h"
#include "oneflow/core/vm/instruction_policy.h"
#include "oneflow/core/common/op_args_reserved_size.h"
#include "oneflow/core/common/small_vector.h"

namespace oneflow {
class EpEvent;
namespace vm {

class Stream;

class StreamWaitInstructionPolicy final : public vm::InstructionPolicy {
public:
StreamWaitInstructionPolicy(
small_vector<intrusive::shared_ptr<LocalDepObject>, kOpArgsReservedSize>&& dependences,
vm::Stream* from_vm_stream, vm::Stream* to_vm_stream);
~StreamWaitInstructionPolicy() = default;

std::string DebugName(const vm::Instruction&) const override { return "StreamWait"; }

bool Prescheduleable(const Stream* src, const Stream* dst) const override;
void InitInstructionStatus(Instruction* instruction) override;
void DeleteInstructionStatus(Instruction* instruction) override;
Maybe<void> Prepare(vm::Instruction* instruction) override { return Maybe<void>::Ok(); }
void Compute(vm::Instruction* instruction) override;

const DependenceVector& input_dependences() const override { return input_dependences_; }
const DependenceVector& output_dependences() const override { return output_dependences_; }

void ForEachInputEagerBlobObjects(void (*DoEach)(EagerBlobObject*)) const override {}

private:
vm::Stream* mut_from_vm_stream() { return from_vm_stream_; }
std::shared_ptr<EpEvent>& mut_ep_event() { return ep_event_; }

small_vector<intrusive::shared_ptr<LocalDepObject>, kOpArgsReservedSize> dependences_;
DependenceVector input_dependences_;
DependenceVector output_dependences_;
vm::Stream* from_vm_stream_;
std::shared_ptr<EpEvent> ep_event_;
};

} // namespace vm
} // namespace oneflow

#endif // ONEFLOW_CORE_VM_STREAM_WAIT_INSTRUCTION_POLICY_H_
2 changes: 1 addition & 1 deletion oneflow/core/vm/virtual_machine_engine.cpp
Expand Up @@ -252,7 +252,7 @@ void VirtualMachineEngine::ConsumeDependences(Instruction* instruction) {
}

bool VirtualMachineEngine::EdgeDispatchable(const Instruction* src, const Instruction* dst) const {
return (&src->stream() == &dst->stream()) /* same stream*/
return dst->instruction_policy().Prescheduleable(&src->stream(), &dst->stream())
&& !src->dispatched_instruction_hook().empty() /* dispatched */;
}

Expand Down