diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp index e9a5666c9c1..f6bc19416d5 100644 --- a/oneflow/core/actor/actor.cpp +++ b/oneflow/core/actor/actor.cpp @@ -320,7 +320,10 @@ void Actor::IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val) { produced_regst2reading_cnt_.at(regst) += val; } -void Actor::InitDeviceCtx(StreamContext* stream_ctx) { device_ctx_ = stream_ctx->device_ctx(); } +void Actor::InitDeviceCtx(StreamContext* stream_ctx) { + auto* provider = CHECK_NOTNULL(dynamic_cast(stream_ctx)); + device_ctx_ = provider->GetDeviceCtx(); +} void Actor::ForEachCurNaiveReadableDataRegst(std::function func) const { naive_consumed_rs_.ForEachFrontRegst([func](int64_t regst_desc_id, Regst* regst) { diff --git a/oneflow/core/actor/light_actor.cpp b/oneflow/core/actor/light_actor.cpp index cdd8d4249ce..b14aa557d65 100644 --- a/oneflow/core/actor/light_actor.cpp +++ b/oneflow/core/actor/light_actor.cpp @@ -592,7 +592,8 @@ class LightActor : public ActorBase, public KernelContext { std::shared_ptr NewDefaultDeviceCtx(const TaskProto& task_proto, StreamContext* stream_ctx) { - return stream_ctx->device_ctx(); + auto* provider = CHECK_NOTNULL(dynamic_cast(stream_ctx)); + return provider->GetDeviceCtx(); } template GetDeviceCtx() = 0; +}; + #define REGISTER_DEVICE_CONTEXT(device, creator) \ REGISTER_CLASS_CREATOR(int, device, DeviceCtx, creator, const ThreadCtx&) diff --git a/oneflow/core/stream/cpu_stream_context.cpp b/oneflow/core/stream/cpu_stream_context.cpp index 1fb725ccbc9..43570d751d6 100644 --- a/oneflow/core/stream/cpu_stream_context.cpp +++ b/oneflow/core/stream/cpu_stream_context.cpp @@ -26,18 +26,17 @@ namespace oneflow { class CpuStreamContext; -class CpuStreamContext : public StreamContext, public KernelObserverProvider { +class CpuStreamContext : public StreamContext, + public KernelObserverProvider, + public DeviceCtxProvider { public: OF_DISALLOW_COPY_AND_MOVE(CpuStreamContext); explicit CpuStreamContext(); virtual ~CpuStreamContext(); - Maybe OnActorThreadSetup() override; - Maybe OnActorThreadTeardown() override; - Maybe AddCallback(std::function callback) override; Maybe Sync() override; - std::shared_ptr device_ctx() override; + std::shared_ptr GetDeviceCtx() override; KernelObserver* GetKernelObserver() override; private: @@ -81,10 +80,6 @@ CpuStreamContext::CpuStreamContext() { CpuStreamContext::~CpuStreamContext() = default; -Maybe CpuStreamContext::OnActorThreadSetup() { return Maybe::Ok(); } - -Maybe CpuStreamContext::OnActorThreadTeardown() { return Maybe::Ok(); } - Maybe CpuStreamContext::AddCallback(std::function callback) { callback(); return Maybe::Ok(); @@ -92,7 +87,7 @@ Maybe CpuStreamContext::AddCallback(std::function callback) { Maybe CpuStreamContext::Sync() { return Maybe::Ok(); } -std::shared_ptr CpuStreamContext::device_ctx() { return device_ctx_; } +std::shared_ptr CpuStreamContext::GetDeviceCtx() { return device_ctx_; } KernelObserver* CpuStreamContext::GetKernelObserver() { return kernel_observer_.get(); } diff --git a/oneflow/core/stream/cuda_stream_context.cpp b/oneflow/core/stream/cuda_stream_context.cpp index 9fd1f6805da..555cf72ca57 100644 --- a/oneflow/core/stream/cuda_stream_context.cpp +++ b/oneflow/core/stream/cuda_stream_context.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/stream/cuda_stream_context.h" #include "oneflow/core/stream/cuda_graph_context.h" +#include "oneflow/core/stream/execution_context_hook.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/resource_desc.h" @@ -45,13 +46,15 @@ void SetAffinityByDevice(int64_t dev_id) { } #ifdef WITH_CUDA_GRAPHS -#define CUDA_STREAM_CONTEXT_IMPL_BASE \ - public \ - CudaStreamContext, public CudaGraphContext, public KernelObserverProvider +#define CUDA_STREAM_CONTEXT_IMPL_BASE \ + public \ + CudaStreamContext, public CudaGraphContext, public KernelObserverProvider, \ + public ExecutionContextHook, public DeviceCtxProvider #else -#define CUDA_STREAM_CONTEXT_IMPL_BASE \ - public \ - CudaStreamContext, public KernelObserverProvider +#define CUDA_STREAM_CONTEXT_IMPL_BASE \ + public \ + CudaStreamContext, public KernelObserverProvider, public ExecutionContextHook, \ + public DeviceCtxProvider #endif class CudaStreamContextImpl : CUDA_STREAM_CONTEXT_IMPL_BASE { @@ -60,12 +63,12 @@ class CudaStreamContextImpl : CUDA_STREAM_CONTEXT_IMPL_BASE { explicit CudaStreamContextImpl(const StreamId& stream_id); virtual ~CudaStreamContextImpl(); - Maybe OnActorThreadSetup() override; - Maybe OnActorThreadTeardown() override; + Maybe OnExecutionContextSetup() override; + Maybe OnExecutionContextTeardown() override; Maybe AddCallback(std::function callback) override; Maybe Sync() override; - std::shared_ptr device_ctx() override; + std::shared_ptr GetDeviceCtx() override; KernelObserver* GetKernelObserver() override; cudaStream_t cuda_stream() const override; @@ -100,7 +103,6 @@ class CudaStreamContextImpl : CUDA_STREAM_CONTEXT_IMPL_BASE { std::thread poller_thread_; StreamId stream_id_; std::shared_ptr device_ctx_; - bool is_graph_capturing_; std::unique_ptr kernel_observer_; #ifdef WITH_CUDA_GRAPHS std::unique_ptr cuda_graph_ctx_impl_; @@ -135,8 +137,7 @@ class DeviceCtxImpl : public DeviceCtx, public StreamContextProvider { } // namespace -CudaStreamContextImpl::CudaStreamContextImpl(const StreamId& stream_id) - : stream_id_(stream_id), is_graph_capturing_(false) { +CudaStreamContextImpl::CudaStreamContextImpl(const StreamId& stream_id) : stream_id_(stream_id) { CudaCurrentDeviceGuard guard(stream_id_.device_id().device_index()); CHECK_EQ(stream_id.device_id().device_type(), DeviceType::kGPU); cuda_event_flags_ = cudaEventDisableTiming; @@ -227,13 +228,13 @@ CudaStreamContextImpl::~CudaStreamContextImpl() { OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_)); } -Maybe CudaStreamContextImpl::OnActorThreadSetup() { +Maybe CudaStreamContextImpl::OnExecutionContextSetup() { SetAffinityByDevice(stream_id_.device_id().device_index()); OF_CUDA_CHECK(cudaSetDevice(stream_id_.device_id().device_index())); return Maybe::Ok(); } -Maybe CudaStreamContextImpl::OnActorThreadTeardown() { return Maybe::Ok(); } +Maybe CudaStreamContextImpl::OnExecutionContextTeardown() { return Maybe::Ok(); } Maybe CudaStreamContextImpl::AddCallback(std::function callback) { cudaEvent_t cuda_event = GetEvent(); @@ -285,7 +286,7 @@ Maybe CudaStreamContextImpl::Sync() { } } -std::shared_ptr CudaStreamContextImpl::device_ctx() { return device_ctx_; } +std::shared_ptr CudaStreamContextImpl::GetDeviceCtx() { return device_ctx_; } KernelObserver* CudaStreamContextImpl::GetKernelObserver() { return kernel_observer_.get(); } diff --git a/oneflow/core/stream/execution_context_hook.h b/oneflow/core/stream/execution_context_hook.h new file mode 100644 index 00000000000..3b1ce462d93 --- /dev/null +++ b/oneflow/core/stream/execution_context_hook.h @@ -0,0 +1,36 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_STREAM_EXECUTION_CONTEXT_HOOK_H_ +#define ONEFLOW_CORE_STREAM_EXECUTION_CONTEXT_HOOK_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/auto_registration_factory.h" + +namespace oneflow { + +class ExecutionContextHook { + public: + OF_DISALLOW_COPY_AND_MOVE(ExecutionContextHook); + ExecutionContextHook() = default; + virtual ~ExecutionContextHook() = default; + + virtual Maybe OnExecutionContextSetup() = 0; + virtual Maybe OnExecutionContextTeardown() = 0; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_STREAM_EXECUTION_CONTEXT_HOOK_H_ diff --git a/oneflow/core/stream/stream_context.h b/oneflow/core/stream/stream_context.h index a597d81c849..81472e89bc9 100644 --- a/oneflow/core/stream/stream_context.h +++ b/oneflow/core/stream/stream_context.h @@ -29,12 +29,8 @@ class StreamContext { StreamContext() = default; virtual ~StreamContext() = default; - virtual Maybe OnActorThreadSetup() = 0; - virtual Maybe OnActorThreadTeardown() = 0; - virtual Maybe AddCallback(std::function callback) = 0; virtual Maybe Sync() = 0; - virtual std::shared_ptr device_ctx() = 0; }; class StreamContextProvider { diff --git a/oneflow/core/thread/thread.cpp b/oneflow/core/thread/thread.cpp index 9c0ca227d41..addac76abc1 100644 --- a/oneflow/core/thread/thread.cpp +++ b/oneflow/core/thread/thread.cpp @@ -19,6 +19,7 @@ limitations under the License. #include "oneflow/core/actor/actor.h" #include "oneflow/core/actor/light_actor.h" #include "oneflow/core/stream/stream_context.h" +#include "oneflow/core/stream/execution_context_hook.h" #include "oneflow/core/graph/id_serialization.h" namespace oneflow { @@ -31,9 +32,10 @@ Thread::Thread(const StreamId& stream_id) : thrd_id_(SerializeStreamIdToInt64(st NewObj(stream_id.device_id().device_type(), stream_id); stream_ctx_.reset(stream_ctx); actor_thread_ = std::thread([this]() { - CHECK_JUST(stream_ctx_->OnActorThreadSetup()); + auto* hook = dynamic_cast(stream_ctx_.get()); + if (hook != nullptr) { CHECK_JUST(hook->OnExecutionContextSetup()); } PollMsgChannel(); - CHECK_JUST(stream_ctx_->OnActorThreadTeardown()); + if (hook != nullptr) { CHECK_JUST(hook->OnExecutionContextTeardown()); } }); }