Skip to content

Commit

Permalink
Refine StreamContext (#6191)
Browse files Browse the repository at this point in the history
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
liujuncheng and oneflow-ci-bot committed Sep 8, 2021
1 parent 9f30b5f commit a0ba3b0
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 33 deletions.
5 changes: 4 additions & 1 deletion oneflow/core/actor/actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,10 @@ void Actor::IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val) {
produced_regst2reading_cnt_.at(regst) += val;
}

void Actor::InitDeviceCtx(StreamContext* stream_ctx) { device_ctx_ = stream_ctx->device_ctx(); }
void Actor::InitDeviceCtx(StreamContext* stream_ctx) {
auto* provider = CHECK_NOTNULL(dynamic_cast<DeviceCtxProvider*>(stream_ctx));
device_ctx_ = provider->GetDeviceCtx();
}

void Actor::ForEachCurNaiveReadableDataRegst(std::function<void(const Regst*)> func) const {
naive_consumed_rs_.ForEachFrontRegst([func](int64_t regst_desc_id, Regst* regst) {
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/actor/light_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,8 @@ class LightActor : public ActorBase, public KernelContext {

std::shared_ptr<DeviceCtx> NewDefaultDeviceCtx(const TaskProto& task_proto,
StreamContext* stream_ctx) {
return stream_ctx->device_ctx();
auto* provider = CHECK_NOTNULL(dynamic_cast<DeviceCtxProvider*>(stream_ctx));
return provider->GetDeviceCtx();
}

template<int kernel_exec, int inplace, typename IndexType, typename RegstIndex,
Expand Down
9 changes: 9 additions & 0 deletions oneflow/core/device/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ class DeviceCtx {
private:
};

class DeviceCtxProvider {
public:
OF_DISALLOW_COPY_AND_MOVE(DeviceCtxProvider);
DeviceCtxProvider() = default;
virtual ~DeviceCtxProvider() = default;

virtual std::shared_ptr<DeviceCtx> GetDeviceCtx() = 0;
};

#define REGISTER_DEVICE_CONTEXT(device, creator) \
REGISTER_CLASS_CREATOR(int, device, DeviceCtx, creator, const ThreadCtx&)

Expand Down
15 changes: 5 additions & 10 deletions oneflow/core/stream/cpu_stream_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,17 @@ namespace oneflow {

class CpuStreamContext;

class CpuStreamContext : public StreamContext, public KernelObserverProvider {
class CpuStreamContext : public StreamContext,
public KernelObserverProvider,
public DeviceCtxProvider {
public:
OF_DISALLOW_COPY_AND_MOVE(CpuStreamContext);
explicit CpuStreamContext();
virtual ~CpuStreamContext();

Maybe<void> OnActorThreadSetup() override;
Maybe<void> OnActorThreadTeardown() override;

Maybe<void> AddCallback(std::function<void()> callback) override;
Maybe<void> Sync() override;
std::shared_ptr<DeviceCtx> device_ctx() override;
std::shared_ptr<DeviceCtx> GetDeviceCtx() override;
KernelObserver* GetKernelObserver() override;

private:
Expand Down Expand Up @@ -81,18 +80,14 @@ CpuStreamContext::CpuStreamContext() {

CpuStreamContext::~CpuStreamContext() = default;

Maybe<void> CpuStreamContext::OnActorThreadSetup() { return Maybe<void>::Ok(); }

Maybe<void> CpuStreamContext::OnActorThreadTeardown() { return Maybe<void>::Ok(); }

Maybe<void> CpuStreamContext::AddCallback(std::function<void()> callback) {
callback();
return Maybe<void>::Ok();
}

Maybe<void> CpuStreamContext::Sync() { return Maybe<void>::Ok(); }

std::shared_ptr<DeviceCtx> CpuStreamContext::device_ctx() { return device_ctx_; }
std::shared_ptr<DeviceCtx> CpuStreamContext::GetDeviceCtx() { return device_ctx_; }

KernelObserver* CpuStreamContext::GetKernelObserver() { return kernel_observer_.get(); }

Expand Down
31 changes: 16 additions & 15 deletions oneflow/core/stream/cuda_stream_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/stream/cuda_stream_context.h"
#include "oneflow/core/stream/cuda_graph_context.h"
#include "oneflow/core/stream/execution_context_hook.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
Expand Down Expand Up @@ -45,13 +46,15 @@ void SetAffinityByDevice(int64_t dev_id) {
}

#ifdef WITH_CUDA_GRAPHS
#define CUDA_STREAM_CONTEXT_IMPL_BASE \
public \
CudaStreamContext, public CudaGraphContext, public KernelObserverProvider
#define CUDA_STREAM_CONTEXT_IMPL_BASE \
public \
CudaStreamContext, public CudaGraphContext, public KernelObserverProvider, \
public ExecutionContextHook, public DeviceCtxProvider
#else
#define CUDA_STREAM_CONTEXT_IMPL_BASE \
public \
CudaStreamContext, public KernelObserverProvider
#define CUDA_STREAM_CONTEXT_IMPL_BASE \
public \
CudaStreamContext, public KernelObserverProvider, public ExecutionContextHook, \
public DeviceCtxProvider
#endif

class CudaStreamContextImpl : CUDA_STREAM_CONTEXT_IMPL_BASE {
Expand All @@ -60,12 +63,12 @@ class CudaStreamContextImpl : CUDA_STREAM_CONTEXT_IMPL_BASE {
explicit CudaStreamContextImpl(const StreamId& stream_id);
virtual ~CudaStreamContextImpl();

Maybe<void> OnActorThreadSetup() override;
Maybe<void> OnActorThreadTeardown() override;
Maybe<void> OnExecutionContextSetup() override;
Maybe<void> OnExecutionContextTeardown() override;

Maybe<void> AddCallback(std::function<void()> callback) override;
Maybe<void> Sync() override;
std::shared_ptr<DeviceCtx> device_ctx() override;
std::shared_ptr<DeviceCtx> GetDeviceCtx() override;
KernelObserver* GetKernelObserver() override;

cudaStream_t cuda_stream() const override;
Expand Down Expand Up @@ -100,7 +103,6 @@ class CudaStreamContextImpl : CUDA_STREAM_CONTEXT_IMPL_BASE {
std::thread poller_thread_;
StreamId stream_id_;
std::shared_ptr<DeviceCtx> device_ctx_;
bool is_graph_capturing_;
std::unique_ptr<KernelObserver> kernel_observer_;
#ifdef WITH_CUDA_GRAPHS
std::unique_ptr<GenericCudaGraphContext> cuda_graph_ctx_impl_;
Expand Down Expand Up @@ -135,8 +137,7 @@ class DeviceCtxImpl : public DeviceCtx, public StreamContextProvider {

} // namespace

CudaStreamContextImpl::CudaStreamContextImpl(const StreamId& stream_id)
: stream_id_(stream_id), is_graph_capturing_(false) {
CudaStreamContextImpl::CudaStreamContextImpl(const StreamId& stream_id) : stream_id_(stream_id) {
CudaCurrentDeviceGuard guard(stream_id_.device_id().device_index());
CHECK_EQ(stream_id.device_id().device_type(), DeviceType::kGPU);
cuda_event_flags_ = cudaEventDisableTiming;
Expand Down Expand Up @@ -227,13 +228,13 @@ CudaStreamContextImpl::~CudaStreamContextImpl() {
OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_));
}

Maybe<void> CudaStreamContextImpl::OnActorThreadSetup() {
Maybe<void> CudaStreamContextImpl::OnExecutionContextSetup() {
SetAffinityByDevice(stream_id_.device_id().device_index());
OF_CUDA_CHECK(cudaSetDevice(stream_id_.device_id().device_index()));
return Maybe<void>::Ok();
}

Maybe<void> CudaStreamContextImpl::OnActorThreadTeardown() { return Maybe<void>::Ok(); }
Maybe<void> CudaStreamContextImpl::OnExecutionContextTeardown() { return Maybe<void>::Ok(); }

Maybe<void> CudaStreamContextImpl::AddCallback(std::function<void()> callback) {
cudaEvent_t cuda_event = GetEvent();
Expand Down Expand Up @@ -285,7 +286,7 @@ Maybe<void> CudaStreamContextImpl::Sync() {
}
}

std::shared_ptr<DeviceCtx> CudaStreamContextImpl::device_ctx() { return device_ctx_; }
std::shared_ptr<DeviceCtx> CudaStreamContextImpl::GetDeviceCtx() { return device_ctx_; }

KernelObserver* CudaStreamContextImpl::GetKernelObserver() { return kernel_observer_.get(); }

Expand Down
36 changes: 36 additions & 0 deletions oneflow/core/stream/execution_context_hook.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_STREAM_EXECUTION_CONTEXT_HOOK_H_
#define ONEFLOW_CORE_STREAM_EXECUTION_CONTEXT_HOOK_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/common/auto_registration_factory.h"

namespace oneflow {

class ExecutionContextHook {
public:
OF_DISALLOW_COPY_AND_MOVE(ExecutionContextHook);
ExecutionContextHook() = default;
virtual ~ExecutionContextHook() = default;

virtual Maybe<void> OnExecutionContextSetup() = 0;
virtual Maybe<void> OnExecutionContextTeardown() = 0;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_STREAM_EXECUTION_CONTEXT_HOOK_H_
4 changes: 0 additions & 4 deletions oneflow/core/stream/stream_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,8 @@ class StreamContext {
StreamContext() = default;
virtual ~StreamContext() = default;

virtual Maybe<void> OnActorThreadSetup() = 0;
virtual Maybe<void> OnActorThreadTeardown() = 0;

virtual Maybe<void> AddCallback(std::function<void()> callback) = 0;
virtual Maybe<void> Sync() = 0;
virtual std::shared_ptr<DeviceCtx> device_ctx() = 0;
};

class StreamContextProvider {
Expand Down
6 changes: 4 additions & 2 deletions oneflow/core/thread/thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/actor/actor.h"
#include "oneflow/core/actor/light_actor.h"
#include "oneflow/core/stream/stream_context.h"
#include "oneflow/core/stream/execution_context_hook.h"
#include "oneflow/core/graph/id_serialization.h"

namespace oneflow {
Expand All @@ -31,9 +32,10 @@ Thread::Thread(const StreamId& stream_id) : thrd_id_(SerializeStreamIdToInt64(st
NewObj<int, StreamContext, const StreamId&>(stream_id.device_id().device_type(), stream_id);
stream_ctx_.reset(stream_ctx);
actor_thread_ = std::thread([this]() {
CHECK_JUST(stream_ctx_->OnActorThreadSetup());
auto* hook = dynamic_cast<ExecutionContextHook*>(stream_ctx_.get());
if (hook != nullptr) { CHECK_JUST(hook->OnExecutionContextSetup()); }
PollMsgChannel();
CHECK_JUST(stream_ctx_->OnActorThreadTeardown());
if (hook != nullptr) { CHECK_JUST(hook->OnExecutionContextTeardown()); }
});
}

Expand Down

0 comments on commit a0ba3b0

Please sign in to comment.