Skip to content

Commit

Permalink
Fea/checkpoint stop identity (#6216)
Browse files Browse the repository at this point in the history
* Primitive (#6183)

* Add Primitive

* #ifdef WITH_CUDA

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Disable implicit boxing when parallel num eq one (#6188)

* mv_boxing_folder_to_core

* minor fix

* disable_implicit_boxing_when_parallel_num_eq_one

* Update eager_consistent_op_interpreter.cpp

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Lazy support Scalar (#6181)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Fix LayerNorm check bug (#6196)

* fix(Layernorm): fix check bug

* fix judge whether cpu or not

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* add glu op (#6065)

* add glu op

* del glu_op export,align with torch

* mod glu_op

* mov op logic to C++

* Solve problems

* solve conflict

* delete gradient functor

* add ndim check

* add GLU test

* delete blank line

* format

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: Zhenhua <huangzhenhua@zhejianglab.com>

* Primitive based copy task node (#6195)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* KernelState (#6198)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* container_util: fix VectorAt, remove useless MutMapAt (#6172)

* fcontainer_util: fix VectorAt, remove useless MutMapAt

* fcontainer_util: format

* MapAt: add default value version

* format

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Refine StreamContext (#6191)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Cpu symetric s to s (#6153)

* mv_boxing_folder_to_core

* minor fix

* cpu_symetric_s_to_s

* add test case

* auto format by CI

* minor fix

* refine

* Update eager_nccl_kernels.cpp

* minor fix

* fix bug

* minor fix

* Update oneflow/user/kernels/eager_nccl_kernels.cpp

Co-authored-by: daquexian <daquexian566@gmail.com>

* Update eager_nccl_kernels.cpp

* Update eager_nccl_kernels.cpp

* minor fix

* Update eager_nccl_kernels.cpp

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: daquexian <daquexian566@gmail.com>

* fix bug (#6197)

Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* fix consistent tensor zeros (#6202)

Signed-off-by: daquexian <daquexian566@gmail.com>

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* [Feat.] nn.Graph support grad acc with input/output tensor (#6155)

* nn.Graph support grad acc with input/output tensor

* dirty pass grad acc

* revert tensor.backward hack

* fix indent

* default S0 -> B

* Pack op/kernel support scalar input

* nn.Graph output pack support loss scalar

* add test script

* pass test

* Lazy build output eager tensors after job complete

* non scalar output test

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Dev eliminate gcc warnings (#6199)

* fix gcc warning

* refine

* fix comment

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* StreamContextAdapter (#6205)

* StreamContextAdapter

* refine

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Autotest generate input tensor (#6206)

* Add tensor yaml, support export tensor functional api.

* refine

* Remove packed functor signature

* remove unused file

* Refine

* refine

* add activation op import

* reinit oneflow init.py

* add oneflow abs and exp

* add oneflow abs and exp

* add acos

* add arccosh

* add more op

* add more ops

* add more op

* add more ops

* add log1p

* add more smaples

* add more ops

* add more ops

* add more ops

* add more ops

* Complete tensor functional apis.

* Fix pybind call

* add more ops

* add ops done

* Add target of_functional_tensor_obj

* Disable throw visibility warnings

* fix target link

* fix

* fix incorrect use of flow.Tensor.

* Fix error merge

* fix

* fix add unittest

* refine

* refine

* fix

* fix

* add tensor doc

* auto format by CI

* refine

* Fix

* Add doc for python function

* refine

* add tensor method docstring

* fix some bug

* fix docs bug

* Fix

* auto format by CI

* Tensor->tensor

* Tensor->tensor

* refine Tensor->tensor

* fix

* fix

* fix

* fix conflict

* fix bug

* fix ci bug

* fix

* delete diag op

* fix conflict

* Fix segment

* fix

* merge

* autotest framework generate input tensor

* autotest framework generate input tensor

* fix bug

* fix impl bug

* refine

* refine

* refine

* fix

* fix

* fix comments

* delete useless

* fix ci error

* fix ci error

Co-authored-by: hjchen2 <chenhoujiangcug@gmail.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Cleanup KernelUtil (#6212)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Rename flow to oneflow in user hint (#6190)

* style(*): rename flow to oneflow in user hint

* fix(*): fix doctest

* auto format by CI

* remove ddp speed test

Signed-off-by: daquexian <daquexian566@gmail.com>

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: daquexian <daquexian566@gmail.com>

* merg and refactor

* refact code

* add io identity for activation checkpointing

Co-authored-by: Juncheng <liujuncheng1022@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: binbinHan <han_binbin@163.com>
Co-authored-by: cheng cheng <472491134@qq.com>
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: QiangX-man <87475073+QiangX-man@users.noreply.github.com>
Co-authored-by: Zhenhua <huangzhenhua@zhejianglab.com>
Co-authored-by: Twice <i@twice.moe>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: daquexian <daquexian566@gmail.com>
Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com>
Co-authored-by: Luyang <flowingsun007@163.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: hjchen2 <chenhoujiangcug@gmail.com>
  • Loading branch information
15 people committed Sep 9, 2021
1 parent 9a2b3b0 commit 716cc69
Show file tree
Hide file tree
Showing 123 changed files with 1,931 additions and 1,126 deletions.
11 changes: 6 additions & 5 deletions ci/test/test_speed_multi_client.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 50 | check_relative_speed 0.9 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 50 | check_relative_speed 0.9 | write_to_file_and_print

python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.8 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.8 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.8 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.72 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.72 | write_to_file_and_print
# TODO: restore ddp speed test after allocator bug is fixed
# python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.8 | write_to_file_and_print
# python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.8 | write_to_file_and_print
# python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.8 | write_to_file_and_print
# python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.72 | write_to_file_and_print
# python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 50 --ddp | check_relative_speed 0.72 | write_to_file_and_print

result="GPU Name: `nvidia-smi --query-gpu=name --format=csv,noheader -i 0` \n\n `cat result`"
# escape newline for github actions: https://github.community/t/set-output-truncates-multiline-strings/16852/2
Expand Down
1 change: 1 addition & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Functional operations for neural networks
.. autofunction:: prelu
.. autofunction:: log_sigmoid
.. autofunction:: gelu
.. autofunction:: glu
.. autofunction:: softsign
.. autofunction:: softmax
.. autofunction:: softplus
Expand Down
1 change: 1 addition & 0 deletions docs/source/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Operators for neural networks
Embedding,
Flatten,
GELU,
GLU,
GroupNorm,
Hardsigmoid,
Hardswish,
Expand Down
2 changes: 1 addition & 1 deletion oneflow/api/python/autograd/autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ bool IsScalarTensor(const one::Tensor& tensor) {
// Checks and sets default value for initial gradients based on out_grads
// If output is the tensor whose size is greater than 1, out_grad's shape must be same as output's.
// If output is a scalar tensor, out_grad will also be a scaler or empty(will be inited to
// `flow.ones([1])`).
// `oneflow.ones([1])`).
Maybe<one::TensorTuple> CheckAndInitOutGrads(const one::TensorTuple& outputs,
const one::TensorTuple& out_grads) {
size_t grad_size = out_grads.empty() ? outputs.size() : out_grads.size();
Expand Down
4 changes: 2 additions & 2 deletions oneflow/api/python/framework/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct ShapeExportUtil final {

static int GetItem(const Shape& shape, int idx) {
const int len = shape.dim_vec().size();
if (idx < -len || idx >= len) { throw py::index_error("flow.Size index out of range"); }
if (idx < -len || idx >= len) { throw py::index_error("oneflow.Size index out of range"); }
if (idx < 0) { idx += len; }
return shape.At(idx);
}
Expand All @@ -65,7 +65,7 @@ struct ShapeExportUtil final {
static std::string ToString(const Shape& shape) {
std::stringstream ss;
int32_t idx = 0;
ss << "flow.Size([";
ss << "oneflow.Size([";
for (int64_t dim : shape.dim_vec()) {
ss << dim;
if (++idx != shape.dim_vec().size()) { ss << ", "; }
Expand Down
11 changes: 8 additions & 3 deletions oneflow/api/python/utils/tensor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,16 @@ namespace oneflow {
namespace one {

Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
const auto& tensor = JUST(t->AsMirroredTensor());
CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
std::shared_ptr<MirroredTensor> local_tensor;
if (t->is_local()) {
local_tensor = JUST(t->AsMirroredTensor());
} else {
local_tensor = JUST(t->cur_rank_phy_tensor());
}
CHECK_OR_RETURN(local_tensor->is_eager()) << "eager tensors supported only";
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
JUST(builder->AccessBlobByCallback(
tensor,
local_tensor,
[](uint64_t of_blob_ptr) {
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
of_blob->AsyncAutoMemset(0);
Expand Down
42 changes: 20 additions & 22 deletions oneflow/core/actor/actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/actor/actor.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/thread/thread_manager.h"
#include "oneflow/core/job/runtime_job_descs.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/stream/stream_context.h"

namespace oneflow {
Expand All @@ -27,29 +25,27 @@ namespace {
class KernelContextImpl : public KernelContext {
public:
OF_DISALLOW_COPY_AND_MOVE(KernelContextImpl);
explicit KernelContextImpl(DeviceCtx* device_ctx)
: device_ctx_(device_ctx), state_(nullptr), stream_kernel_observer_(nullptr) {
auto* stream_context_provider = dynamic_cast<StreamContextProvider*>(device_ctx);
if (stream_context_provider != nullptr) {
auto* kernel_observer_provider =
dynamic_cast<KernelObserverProvider*>(stream_context_provider->GetStreamContext());
if (kernel_observer_provider != nullptr) {
stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver();
}
explicit KernelContextImpl(StreamContext* stream_ctx, DeviceCtx* device_ctx)
: stream_ctx_(stream_ctx),
device_ctx_(device_ctx),
state_(nullptr),
stream_kernel_observer_(nullptr) {
auto* kernel_observer_provider = dynamic_cast<KernelObserverProvider*>(stream_ctx_);
if (kernel_observer_provider != nullptr) {
stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver();
}
}
~KernelContextImpl() = default;

StreamContext* stream_ctx() const override { return stream_ctx_; }

DeviceCtx* device_ctx() const override { return device_ctx_; }

Blob* BnInOp2Blob(const std::string& bn) const override { return bn_in_op2blob_fn_(bn); }

void* state() const override { return state_; }
const std::shared_ptr<KernelState>& state() const override { return state_; }

void set_state(void* state) override {
CHECK(state_ == nullptr);
state_ = state;
}
void set_state(std::shared_ptr<KernelState> state) override { state_ = std::move(state); }

void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override;
void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override;
Expand All @@ -65,9 +61,10 @@ class KernelContextImpl : public KernelContext {
}

private:
StreamContext* stream_ctx_;
DeviceCtx* device_ctx_;
std::function<Blob*(const std::string&)> bn_in_op2blob_fn_;
void* state_;
std::shared_ptr<KernelState> state_;
KernelObserver* stream_kernel_observer_;
};

Expand Down Expand Up @@ -127,9 +124,7 @@ void CheckInplaceRegstDescId(const TaskProto& task_proto) {

} // namespace

Actor::~Actor() {
for (ExecKernel& ek : exec_kernel_vec_) { ek.kernel->DestroyState(ek.kernel_ctx->state()); }
}
Actor::~Actor() = default;

void Actor::Init(const JobDesc* job_desc, const TaskProto& task_proto, StreamContext* stream_ctx) {
job_desc_ = job_desc;
Expand All @@ -142,7 +137,7 @@ void Actor::Init(const JobDesc* job_desc, const TaskProto& task_proto, StreamCon
}
for (const ExecNodeProto& node : task_proto.exec_sequence().exec_node()) {
ExecKernel ek;
ek.kernel_ctx.reset(new KernelContextImpl(device_ctx_.get()));
ek.kernel_ctx.reset(new KernelContextImpl(stream_ctx, device_ctx_.get()));
ek.kernel = ConstructKernel(node.kernel_conf(), ek.kernel_ctx.get());
exec_kernel_vec_.push_back(std::move(ek));
}
Expand Down Expand Up @@ -327,7 +322,10 @@ void Actor::IncreaseReadingCnt4ProducedRegst(Regst* regst, int64_t val) {
produced_regst2reading_cnt_.at(regst) += val;
}

void Actor::InitDeviceCtx(StreamContext* stream_ctx) { device_ctx_ = stream_ctx->device_ctx(); }
void Actor::InitDeviceCtx(StreamContext* stream_ctx) {
auto* provider = CHECK_NOTNULL(dynamic_cast<DeviceCtxProvider*>(stream_ctx));
device_ctx_ = provider->GetDeviceCtx();
}

void Actor::ForEachCurNaiveReadableDataRegst(std::function<void(const Regst*)> func) const {
naive_consumed_rs_.ForEachFrontRegst([func](int64_t regst_desc_id, Regst* regst) {
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/actor/case_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class CaseActor final : public Actor {

void CaseActor::VirtualActorInit(const TaskProto& task_proto) {
CHECK_EQ(1, exec_kernel_vec().size());
case_status_ = static_cast<CaseStatus*>(exec_kernel_vec().at(0).kernel_ctx->state());
case_status_ =
CHECK_NOTNULL(dynamic_cast<CaseStatus*>(exec_kernel_vec().at(0).kernel_ctx->state().get()));
const int32_t output_bns_size =
task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().output_bns_size();
FOR_RANGE(int64_t, i, 0, output_bns_size) {
Expand Down
6 changes: 4 additions & 2 deletions oneflow/core/actor/esac_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/actor/actor.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/kernel/esac_kernel.h"

namespace oneflow {

Expand Down Expand Up @@ -49,7 +50,7 @@ class EsacActor final : public Actor {
int64_t GetCurProcessedRegstDescId() const;

RegstSlot consumed_rs_;
int64_t cur_processed_regst_desc_id_;
int64_t cur_processed_regst_desc_id_{};
HashMap<int64_t, int64_t> regst_desc_id2in_bn_id_;
};

Expand Down Expand Up @@ -88,7 +89,8 @@ void EsacActor::Act() {
CHECK(cur_regst);
int64_t in_bn_id = InBnId4RegstDescId(cur_processed_regst_desc_id_);
CHECK_EQ(exec_kernel_vec().size(), 1);
*static_cast<int64_t*>(exec_kernel_vec().at(0).kernel_ctx->state()) = in_bn_id;
CHECK_NOTNULL(dynamic_cast<EsacKernelState*>(exec_kernel_vec().at(0).kernel_ctx->state().get()))
->value = in_bn_id;
AsyncLaunchKernel([&](int64_t regst_desc_id) -> Regst* {
if (cur_processed_regst_desc_id_ != regst_desc_id) { return nullptr; }
return cur_regst;
Expand Down
49 changes: 22 additions & 27 deletions oneflow/core/actor/light_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ struct RegstState {
struct KernelInfo {
std::unique_ptr<const Kernel> kernel;
HashMap<std::string, Blob*> bn_in_op2blob;
void* state = nullptr;
std::shared_ptr<KernelState> state;
};

template<typename IndexType, int max_size>
Expand Down Expand Up @@ -189,12 +189,6 @@ size_t GetConsumerCount(const TaskProto& task) {

#ifdef WITH_CUDA_GRAPHS

CudaGraphContext* GetCUDAGraphContext(DeviceCtx* device_ctx) {
auto* provider = dynamic_cast<StreamContextProvider*>(device_ctx);
if (provider == nullptr) { return nullptr; }
return dynamic_cast<CudaGraphContext*>(provider->GetStreamContext());
}

bool IsCUDAGraphSupported(const Kernel* kernel) {
auto* user_kernel = dynamic_cast<const UserKernel*>(kernel);
return (user_kernel != nullptr && user_kernel->IsCudaGraphSupported());
Expand All @@ -207,20 +201,17 @@ template<int exec_kernel, int inplace, typename IndexType, typename RegstIndex,
class LightActor : public ActorBase, public KernelContext {
public:
OF_DISALLOW_COPY_AND_MOVE(LightActor);
explicit LightActor(std::shared_ptr<DeviceCtx> device_ctx)
: thread_(nullptr), device_ctx_(std::move(device_ctx)), stream_kernel_observer_(nullptr) {
auto* stream_context_provider = dynamic_cast<StreamContextProvider*>(device_ctx_.get());
if (stream_context_provider != nullptr) {
auto* kernel_observer_provider =
dynamic_cast<KernelObserverProvider*>(stream_context_provider->GetStreamContext());
if (kernel_observer_provider != nullptr) {
stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver();
}
explicit LightActor(StreamContext* stream_ctx, std::shared_ptr<DeviceCtx> device_ctx)
: thread_(nullptr),
stream_ctx_(stream_ctx),
device_ctx_(std::move(device_ctx)),
stream_kernel_observer_(nullptr) {
auto* kernel_observer_provider = dynamic_cast<KernelObserverProvider*>(stream_ctx_);
if (kernel_observer_provider != nullptr) {
stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver();
}
}
~LightActor() override {
if (exec_kernel) { kernel_info_[0]->kernel->DestroyState(kernel_info_[0]->state); }
}
~LightActor() override = default;

void Init(const JobDesc* job_desc, const TaskProto& task_proto,
StreamContext* stream_ctx) override {
Expand All @@ -231,7 +222,7 @@ class LightActor : public ActorBase, public KernelContext {
const KernelConf& kernel_conf = task_proto.exec_sequence().exec_node(0).kernel_conf();
kernel_info_[0]->kernel = ConstructKernel(kernel_conf, this);
#ifdef WITH_CUDA_GRAPHS
cuda_graph_ctx_[0] = GetCUDAGraphContext(device_ctx_.get());
cuda_graph_ctx_[0] = dynamic_cast<CudaGraphContext*>(stream_ctx);
if (cuda_graph_ctx_[0] != nullptr && kernel_conf.all_blobs_are_static()
&& IsCUDAGraphSupported(kernel_info_[0]->kernel.get())) {
cuda_graph_exec_[0].reset(new CudaGraphExecutable());
Expand Down Expand Up @@ -498,6 +489,8 @@ class LightActor : public ActorBase, public KernelContext {
}
}

StreamContext* stream_ctx() const override { return stream_ctx_; }

DeviceCtx* device_ctx() const override { return device_ctx_.get(); }

Blob* BnInOp2Blob(const std::string& bn) const override {
Expand All @@ -513,18 +506,18 @@ class LightActor : public ActorBase, public KernelContext {
}
}

void* state() const override {
const std::shared_ptr<KernelState>& state() const override {
if (exec_kernel) {
return kernel_info_[0]->state;
} else {
return nullptr;
static const std::shared_ptr<KernelState> null_state;
return null_state;
}
}

void set_state(void* state) override {
void set_state(std::shared_ptr<KernelState> state) override {
CHECK(exec_kernel);
CHECK(kernel_info_[0]->state == nullptr);
kernel_info_[0]->state = state;
kernel_info_[0]->state = std::move(state);
}

void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override {
Expand Down Expand Up @@ -585,6 +578,7 @@ class LightActor : public ActorBase, public KernelContext {
std::unique_ptr<CudaGraphExecutable> cuda_graph_exec_[exec_kernel];
CudaGraphContext* cuda_graph_ctx_[exec_kernel]{};
#endif
StreamContext* stream_ctx_;
std::shared_ptr<DeviceCtx> device_ctx_;
std::vector<ActorMsg> sync_post_act_msgs_;
std::vector<ActorMsg> async_post_act_msgs_;
Expand All @@ -594,15 +588,16 @@ class LightActor : public ActorBase, public KernelContext {

std::shared_ptr<DeviceCtx> NewDefaultDeviceCtx(const TaskProto& task_proto,
StreamContext* stream_ctx) {
return stream_ctx->device_ctx();
auto* provider = CHECK_NOTNULL(dynamic_cast<DeviceCtxProvider*>(stream_ctx));
return provider->GetDeviceCtx();
}

template<int kernel_exec, int inplace, typename IndexType, typename RegstIndex,
typename StateContainer>
ActorBase* NewLightActor(const TaskProto& task_proto, StreamContext* stream_ctx,
std::shared_ptr<DeviceCtx> device_ctx) {
return new LightActor<kernel_exec, inplace, IndexType, RegstIndex, StateContainer>(
std::move(device_ctx));
stream_ctx, std::move(device_ctx));
}

template<int kernel_exec, int inplace, typename IndexType>
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/actor/reentrant_lock_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class ReentrantLockActor final : public Actor {

void ReentrantLockActor::VirtualActorInit(const TaskProto& task_proto) {
CHECK_EQ(1, exec_kernel_vec().size());
reentrant_lock_status_ =
static_cast<ReentrantLockStatus*>(exec_kernel_vec().at(0).kernel_ctx->state());
reentrant_lock_status_ = CHECK_NOTNULL(
dynamic_cast<ReentrantLockStatus*>(exec_kernel_vec().at(0).kernel_ctx->state().get()));
act_id_ = 0;
const auto& kernel_conf = task_proto.exec_sequence().exec_node().Get(0).kernel_conf();
const auto& ibns = kernel_conf.op_attribute().input_bns();
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/actor/wait_and_send_ids_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class WaitAndSendIdsActor final : public Actor {

void WaitAndSendIdsActor::VirtualActorInit(const TaskProto& task_proto) {
CHECK_EQ(exec_kernel_vec().size(), 1);
wait_and_send_ids_status_ =
static_cast<WaitAndSendIdsStatus*>(exec_kernel_vec().at(0).kernel_ctx->state());
wait_and_send_ids_status_ = CHECK_NOTNULL(
dynamic_cast<WaitAndSendIdsStatus*>(exec_kernel_vec().at(0).kernel_ctx->state().get()));
wait_and_send_ids_status_->buffer_status_ = kBufferStatusSuccess;
wait_and_send_ids_status_->in_id_ = 0;
wait_and_send_ids_status_->out_idx_ = 0;
Expand Down

0 comments on commit 716cc69

Please sign in to comment.