UserKernel remove job_desc (#6144)
* UserKernel remove job_desc

* fix

* fix

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
liujuncheng and oneflow-ci-bot committed Sep 3, 2021
1 parent ae7abaa commit 8d48b0e
Showing 12 changed files with 33 additions and 60 deletions.
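In a nutshell: the commit stops threading a JobDesc through every kernel-side constructor and context, leaving KernelConf as the single configuration argument. Below is a minimal, hypothetical mirror of the signature change — JobDesc and KernelConf here are stand-in structs, not OneFlow's types, and the real classes are far larger:

struct JobDesc {};     // stand-in
struct KernelConf {};  // stand-in

// Before this commit: every construction site had to supply a JobDesc,
// which the kernel then carried around for the rest of its life.
struct EagerKernelOld {
  EagerKernelOld(const JobDesc* job_desc, const KernelConf& conf) : job_desc_(job_desc) {}
  const JobDesc* job_desc_;
};

// After this commit: a KernelConf alone is enough, and the single-argument
// constructor is marked explicit to forbid accidental conversions.
struct EagerKernelNew {
  explicit EagerKernelNew(const KernelConf& conf) {}
};

int main() {
  KernelConf conf;
  JobDesc job;
  EagerKernelOld old_kernel(&job, conf);  // old call shape
  EagerKernelNew new_kernel(conf);        // new call shape
  (void)old_kernel;
  (void)new_kernel;
}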
2 changes: 1 addition & 1 deletion oneflow/core/eager/opkernel_object.cpp
@@ -48,7 +48,7 @@ void OpKernelObject::NewPartialInitializedKernel(
     const ParallelDesc* parallel_desc) {
   KernelConf kernel_conf;
   op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf);
-  kernel_.reset(new EagerKernel(job_desc_.get(), kernel_conf));
+  kernel_.reset(new EagerKernel(kernel_conf));
 }
 
 Maybe<void> SystemOpKernelObject::ResetKernel(
1 change: 0 additions & 1 deletion oneflow/core/framework/op_kernel.h
@@ -204,7 +204,6 @@ class KernelComputeContext {
                                                   int32_t index) const = 0;
   virtual DeviceType device_type() const = 0;
   virtual const ParallelContext& parallel_ctx() const = 0;
-  virtual const JobDesc& job_desc() const = 0;
 
   virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;
   virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;
7 changes: 4 additions & 3 deletions oneflow/core/framework/op_kernel_infer_cache.cpp
@@ -21,10 +21,10 @@ namespace oneflow {
 
 namespace user_op {
 
-OpKernelInferCache::OpKernelInferCache(const KernelConf& kernel_conf, const JobDesc& job_desc) {
+OpKernelInferCache::OpKernelInferCache(const KernelConf& kernel_conf, const void* scope) {
   const OperatorConf& op_conf = kernel_conf.op_attribute().op_conf();
   std::shared_ptr<Operator> op = CHECK_JUST(ConstructOp(op_conf));
-  cache_key_.scope = &job_desc;
+  cache_key_.scope = scope;
   cache_key_.op_conf_sym = op->GetOpConfWithoutOpNameAndLbn();
   cache_key_.ibn_idx2shape_sym.resize(op->input_bns().size());
   cache_key_.dtype_signature_sym = SymbolOf(kernel_conf.dtype_signature());
@@ -57,7 +57,8 @@ void OpKernelInferCache::UpdateCacheKey(KernelInferContext* ctx) {
 }
 
 void OpKernelInferCache::UpdateCacheValue(KernelInferContext* ctx) {
-  if (cached_key2value_.size() >= max_size_) { Reset(); }
+  // TODO: make max size configurable
+  if (cached_key2value_.size() >= kReleaseInIndependentThreadThreshold) { Reset(); }
   auto* cache_value = new OpInferCacheValue();
   cache_value->obn_idx2shape_sym.resize(ctx->outputs().size());
   FOR_RANGE(int, i, 0, ctx->outputs().size()) {
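Two things happen in the UpdateCacheValue hunk: the per-instance max_size_ member (deleted from the header below) gives way to the class-wide kReleaseInIndependentThreadThreshold constant, and a TODO notes that the cap should eventually be configurable. Here is a self-contained toy — not OneFlow code — of the same eviction policy: the memo map is cleared wholesale once it hits a compile-time cap, trading occasional recomputation for bounded memory.

#include <cstddef>
#include <cstdio>
#include <unordered_map>

class CappedCache {
 public:
  // Mirrors kReleaseInIndependentThreadThreshold = 4096 from the header below.
  static constexpr std::size_t kThreshold = 4096;

  int GetOrCompute(int key) {
    auto it = cache_.find(key);
    if (it != cache_.end()) { return it->second; }        // cache hit
    if (cache_.size() >= kThreshold) { cache_.clear(); }  // wholesale reset, like Reset()
    const int value = key * key;  // stand-in for the real shape-inference work
    cache_.emplace(key, value);
    return value;
  }

 private:
  std::unordered_map<int, int> cache_;
};

int main() {
  CappedCache cache;
  long long sum = 0;
  for (int i = 0; i < 10000; ++i) { sum += cache.GetOrCompute(i % 5000); }
  std::printf("%lld\n", sum);
}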
3 changes: 1 addition & 2 deletions oneflow/core/framework/op_kernel_infer_cache.h
@@ -34,7 +34,7 @@ class OpKernelInferCache final {
   using KeyStorage = std::list<std::unique_ptr<KeyType>>;
   static constexpr size_t kReleaseInIndependentThreadThreshold = 4096;
 
-  OpKernelInferCache(const KernelConf& kernel_conf, const JobDesc& job_desc);
+  OpKernelInferCache(const KernelConf& kernel_conf, const void* scope);
   ~OpKernelInferCache() = default;
 
   bool IsCacheHit() const;
@@ -47,7 +47,6 @@
   KeyType cache_key_;
   HashMap cached_key2value_;
   KeyStorage key_storage_;
-  size_t max_size_;
 };
 
 }  // namespace user_op
3 changes: 1 addition & 2 deletions oneflow/core/kernel/eager_kernel.h
@@ -24,7 +24,7 @@ namespace oneflow {
 class EagerKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(EagerKernel);
-  EagerKernel(const JobDesc* job_desc, const KernelConf& kernel_conf);
+  explicit EagerKernel(const KernelConf& kernel_conf);
   ~EagerKernel() = default;
 
   void Infer(std::function<Blob*(const std::string&)> BnInOp2Blob) const;
@@ -37,7 +37,6 @@ class EagerKernel final : public Kernel {
   void InitOpKernel(const KernelConf& kernel_conf);
   void ForwardDataContent(const KernelContext* kernel_ctx) const override { UNIMPLEMENTED(); }
   std::unique_ptr<const user_op::OpKernel> kernel_;
-  const JobDesc* job_desc_;
 };
 
 }  // namespace oneflow
4 changes: 2 additions & 2 deletions oneflow/core/kernel/kernel.cpp
@@ -35,15 +35,15 @@ Kernel::Kernel() = default;
 
 Kernel::~Kernel() = default;
 
-void Kernel::InitBase(const JobDesc* job_desc, const KernelConf& kernel_conf) {
+void Kernel::InitBase(const KernelConf& kernel_conf) {
   if (shape_infer_helper_) { return; }
   kernel_conf_ = kernel_conf;
   shape_infer_helper_.reset(
       new RuntimeBlobShapeInferHelper(this->op_conf(), this->kernel_conf(), this));
 }
 
 void Kernel::Init(const KernelConf& kernel_conf, KernelContext* ctx) {
-  InitBase(ctx->job_desc(), kernel_conf);
+  InitBase(kernel_conf);
   VirtualKernelInit(ctx);
 }
 
2 changes: 1 addition & 1 deletion oneflow/core/kernel/kernel.h
@@ -52,7 +52,7 @@ class Kernel {
 
  protected:
   Kernel();
-  void InitBase(const JobDesc* job_desc, const KernelConf&);
+  void InitBase(const KernelConf&);
   virtual void VirtualKernelInit(KernelContext* ctx) {}
 
   virtual void ForwardHeader(const KernelContext* ctx) const;
62 changes: 23 additions & 39 deletions oneflow/core/kernel/user_kernel.cpp
@@ -67,8 +67,7 @@ void FillTensorDescWithBlob(const Blob* blob, user_op::NaiveTensorDesc* tensor_d
 
 class UserKernelBaseContext {
  public:
-  UserKernelBaseContext(const KernelConf& kernel_conf, const JobDesc& job_desc)
-      : job_desc_(job_desc) {
+  explicit UserKernelBaseContext(const KernelConf& kernel_conf) {
     CHECK(kernel_conf.has_user_conf());
     CHECK(kernel_conf.op_attribute().op_conf().has_user_conf());
 
@@ -94,7 +93,6 @@ class UserKernelBaseContext {
   DeviceType device_type() const { return device_type_; }
   const std::string& device_tag() const { return device_tag_; }
   const ParallelContext& parallel_ctx() const { return parallel_ctx_; }
-  const JobDesc& job_desc() const { return job_desc_; }
   const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                         int32_t index) const {
     auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));
@@ -112,7 +110,6 @@ class UserKernelBaseContext {
   std::string device_tag_;
   ParallelContext parallel_ctx_;
   HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;
-  const JobDesc& job_desc_;
 };
 
 class KernelCreateContext final : public user_op::KernelCreateContext {
@@ -132,11 +129,10 @@ class KernelCreateContext final : public user_op::KernelCreateContext {
 
 class UserKernelInitContext final : public user_op::KernelInitContext {
  public:
-  explicit UserKernelInitContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf,
-                                 const JobDesc& job_desc)
+  explicit UserKernelInitContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf)
       : user_op_conf_(kernel_conf.op_attribute().op_conf()),
         device_ctx_(device_ctx),
-        base_ctx_(UserKernelBaseContext(kernel_conf, job_desc)),
+        base_ctx_(UserKernelBaseContext(kernel_conf)),
         parallel_desc_(kernel_conf.op_attribute().parallel_conf_signature().op_parallel_conf()) {
     nd_sbp_signature_ = new cfg::NdSbpSignature(kernel_conf.op_attribute().nd_sbp_signature());
     if (kernel_conf.op_attribute().has_sbp_signature()) {
@@ -209,9 +205,8 @@
 
 class UserKernelOpInferContext : public user_op::InferContext {
  public:
-  UserKernelOpInferContext(const KernelConf& kernel_conf, const JobDesc* job_desc)
+  explicit UserKernelOpInferContext(const KernelConf& kernel_conf)
       : user_op_conf_(kernel_conf.op_attribute().op_conf()),
-        job_desc_(job_desc),
         parallel_ctx_(kernel_conf.parallel_ctx()),
         nd_sbp_signature_(kernel_conf.op_attribute().nd_sbp_signature()),
         parallel_desc_(kernel_conf.op_attribute().parallel_conf_signature().op_parallel_conf()) {
@@ -291,10 +286,6 @@
 
   const ArgVec& inputs() const override { return inputs_; }
   const ArgVec& outputs() const override { return outputs_; }
-  const JobDesc* job_desc() const override {
-    CHECK_NOTNULL(job_desc_);
-    return job_desc_;
-  }
   const ParallelContext& parallel_ctx() const override { return parallel_ctx_; };
   const ParallelDesc& parallel_desc() const override { return parallel_desc_; }
   const cfg::SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
@@ -361,7 +352,6 @@
   }
 
   user_op::UserOpConfWrapper user_op_conf_;
-  const JobDesc* job_desc_;
   ArgVec inputs_;
   ArgVec outputs_;
   ParallelContext parallel_ctx_;
@@ -375,12 +365,11 @@
 
 class UserKernelInferContext final : public user_op::KernelInferContext {
  public:
-  explicit UserKernelInferContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf,
-                                  const JobDesc& job_desc)
+  explicit UserKernelInferContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf)
       : user_op_conf_(kernel_conf.op_attribute().op_conf()),
         device_ctx_(device_ctx),
-        base_ctx_(UserKernelBaseContext(kernel_conf, job_desc)),
-        op_infer_ctx_(kernel_conf, &job_desc) {
+        base_ctx_(UserKernelBaseContext(kernel_conf)),
+        op_infer_ctx_(kernel_conf) {
     auto InitArg2Blob = [this](const PbMap<std::string, UserOpConf::ListString>& arg_map) {
       for (auto it = arg_map.begin(); it != arg_map.end(); ++it) {
         const std::string& arg_name = it->first;
@@ -490,11 +479,10 @@ BnTensorPair MakeBnTensorPair(const std::string& bn,
 
 class UserKernelComputeContext final : public user_op::KernelComputeContext {
  public:
-  explicit UserKernelComputeContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf,
-                                    const JobDesc& job_desc)
+  explicit UserKernelComputeContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf)
       : user_op_conf_(kernel_conf.op_attribute().op_conf()),
         device_ctx_(device_ctx),
-        base_ctx_(kernel_conf, job_desc) {
+        base_ctx_(kernel_conf) {
     auto InitInOrOut = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map) {
       for (const auto& it : arg_map) {
         const std::string& arg_name = it.first;
@@ -552,7 +540,6 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext {
 
   DeviceType device_type() const override { return base_ctx_.device_type(); }
   const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); }
-  const JobDesc& job_desc() const override { return base_ctx_.job_desc(); }
 
   const ArgVec& inputs() const override { return base_ctx_.inputs(); }
   const ArgVec& outputs() const override { return base_ctx_.outputs(); }
@@ -573,9 +560,9 @@
 
 class UserKernelRegContext final : public user_op::KernelRegContext {
  public:
-  explicit UserKernelRegContext(const KernelConf& kernel_conf, const JobDesc& job_desc)
+  explicit UserKernelRegContext(const KernelConf& kernel_conf)
       : user_op_conf_(kernel_conf.op_attribute().op_conf()),
-        base_ctx_(UserKernelBaseContext(kernel_conf, job_desc)) {}
+        base_ctx_(UserKernelBaseContext(kernel_conf)) {}
   ~UserKernelRegContext() = default;
 
   DeviceType device_type() const override { return base_ctx_.device_type(); }
@@ -603,23 +590,23 @@ class UserKernelRegContext final : public user_op::KernelRegContext {
 UserKernel::~UserKernel() = default;
 
 void UserKernel::InitUserKernel(DeviceCtx* device_ctx) {
-  ctx_.reset(new UserKernelComputeContext(device_ctx, kernel_conf(), job_desc()));
-  infer_ctx_.reset(new UserKernelInferContext(device_ctx, kernel_conf(), job_desc()));
-  infer_cache_.reset(new user_op::OpKernelInferCache(kernel_conf(), job_desc()));
+  ctx_.reset(new UserKernelComputeContext(device_ctx, kernel_conf()));
+  infer_ctx_.reset(new UserKernelInferContext(device_ctx, kernel_conf()));
+  infer_cache_.reset(new user_op::OpKernelInferCache(kernel_conf(), this));
   {
     const std::string& op_type_name =
         kernel_conf().op_attribute().op_conf().user_conf().op_type_name();
     const user_op::OpKernelRegistryResult* kernel_reg_val =
         CHECK_JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(
-            op_type_name, UserKernelRegContext(kernel_conf(), job_desc())));
+            op_type_name, UserKernelRegContext(kernel_conf())));
     CHECK_NOTNULL(kernel_reg_val);
     KernelCreateContext create_ctx(kernel_conf());
     kernel_.reset(kernel_reg_val->create_fn(&create_ctx));
   }
 
 #ifdef WITH_CUDA_GRAPHS
   if (ParseBooleanFromEnv("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", false)) {
-    UserKernelInitContext init_ctx(device_ctx, kernel_conf(), job_desc());
+    UserKernelInitContext init_ctx(device_ctx, kernel_conf());
     CudaDeviceCtx* cuda_device_ctx = dynamic_cast<CudaDeviceCtx*>(device_ctx);
     const auto* cuda_graph_support = dynamic_cast<const user_op::CudaGraphSupport*>(kernel_.get());
     if (cuda_device_ctx != nullptr) {
@@ -637,7 +624,7 @@ void UserKernel::InitUserKernel(DeviceCtx* device_ctx) {
 }
 
 std::shared_ptr<user_op::OpKernelState> UserKernel::CreateOpKernelState(DeviceCtx* device_ctx) {
-  UserKernelInitContext init_ctx(device_ctx, kernel_conf(), job_desc());
+  UserKernelInitContext init_ctx(device_ctx, kernel_conf());
   return kernel_->CreateOpKernelState(&init_ctx);
 }
 
@@ -682,7 +669,6 @@ bool UserKernel::IsCudaGraphSupported() const {
 }
 
 void UserKernel::VirtualKernelInit(KernelContext* ctx) {
-  job_desc_ = ctx->job_desc();
   InitUserKernel(ctx->device_ctx());
   CHECK(opkernel_state_.get() == nullptr);
   opkernel_state_ = CreateOpKernelState(ctx->device_ctx());
@@ -731,24 +717,23 @@ NEW_REGISTER_KERNEL(OperatorConf::kUserConf, UserKernel).SetIsMatchedPred([](con
   return true;
 });
 
-EagerKernel::EagerKernel(const JobDesc* job_desc, const KernelConf& kernel_conf)
-    : job_desc_(job_desc) {
-  InitBase(job_desc, kernel_conf);
+EagerKernel::EagerKernel(const KernelConf& kernel_conf) {
+  InitBase(kernel_conf);
   InitOpKernel(kernel_conf);
 }
 
 void EagerKernel::InitOpKernel(const KernelConf& kernel_conf) {
   const std::string& op_type_name = kernel_conf.op_attribute().op_conf().user_conf().op_type_name();
   auto kernel_reg_val = CHECK_JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(
-      op_type_name, UserKernelRegContext(kernel_conf, *job_desc_)));
+      op_type_name, UserKernelRegContext(kernel_conf)));
   CHECK_NOTNULL(kernel_reg_val);
   KernelCreateContext create_ctx(kernel_conf);
   kernel_.reset(kernel_reg_val->create_fn(&create_ctx));
 }
 
 void EagerKernel::Infer(std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   if (kernel_conf().all_blobs_are_static()) { return; }
-  UserKernelInferContext infer_ctx(nullptr, kernel_conf(), *job_desc_);
+  UserKernelInferContext infer_ctx(nullptr, kernel_conf());
   infer_ctx.UpdateArg2Tensor(BnInOp2Blob);
   auto* op_infer_ctx = dynamic_cast<UserKernelOpInferContext*>(infer_ctx.MutOpInferContext());
   if (op_infer_ctx) { op_infer_ctx->UpdateArg2TensorDesc(BnInOp2Blob); }
@@ -762,8 +747,7 @@ std::shared_ptr<user_op::OpKernelState> EagerKernel::EagerForward(
   if (old_opkernel_state) {
     new_opkernel_state = old_opkernel_state;
   } else {
-    CHECK_NOTNULL(job_desc_);
-    UserKernelInitContext init_ctx(device_ctx, kernel_conf(), *job_desc_);
+    UserKernelInitContext init_ctx(device_ctx, kernel_conf());
     new_opkernel_state = kernel_->CreateOpKernelState(&init_ctx);
   }
 
@@ -773,7 +757,7 @@
   }
 
   // TODO(lixinqi): refactor to a lightweight KernelComputeContext
-  UserKernelComputeContext compute_ctx(device_ctx, kernel_conf(), *job_desc_);
+  UserKernelComputeContext compute_ctx(device_ctx, kernel_conf());
   compute_ctx.UpdateTensorWithCorrBlob(BnInOp2Blob);
   kernel_->Compute(&compute_ctx, new_opkernel_state.get());
   return new_opkernel_state;
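Note the scope argument in InitUserKernel above: the infer cache is now constructed with the kernel instance itself (this) where a JobDesc used to go, matching the const void* scope field introduced in op_kernel_infer_cache.h. A toy illustration — simplified types, not OneFlow code — of keying a cache on an opaque instance pointer, so entries owned by different kernels can never collide even when their logical inputs match:

#include <cstdio>

struct CacheKey {
  const void* scope;  // opaque identity; the pointer is never dereferenced
  int shape_token;    // stand-in for the real shape/dtype symbols in the key
  bool operator==(const CacheKey& o) const {
    return scope == o.scope && shape_token == o.shape_token;
  }
};

struct Kernel {
  CacheKey MakeKey(int shape_token) const { return CacheKey{this, shape_token}; }
};

int main() {
  Kernel a, b;
  // Same logical inputs, different owners: the keys compare unequal.
  std::printf("%s\n", a.MakeKey(1) == b.MakeKey(1) ? "collide" : "distinct");
}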
3 changes: 0 additions & 3 deletions oneflow/core/kernel/user_kernel.h
@@ -50,8 +50,6 @@ class UserKernel final : public Kernel {
 
   bool IsStateless() const override;
 
-  const JobDesc& job_desc() const { return *job_desc_; }
-
   std::shared_ptr<user_op::OpKernelState> opkernel_state_;
   std::unique_ptr<const user_op::OpKernel> kernel_;
   std::unique_ptr<UserKernelComputeContext> ctx_;
@@ -60,7 +58,6 @@ class UserKernel final : public Kernel {
 #ifdef WITH_CUDA_GRAPHS
   std::unique_ptr<CudaGraphContext> cuda_graph_ctx_;
 #endif  // WITH_CUDA_GRAPHS
-  const JobDesc* job_desc_;
 };
 
 }  // namespace oneflow
1 change: 0 additions & 1 deletion oneflow/user/kernels/conv_cudnn_kernels.cpp
@@ -34,7 +34,6 @@ struct CudnnConvArgsAndAlgo final {
   CudnnConvArgs args;
   PerfT algo_perf;
 
-  // TODO(hanbinbin): remove arg job_desc and set cudnn_conv config as args of CudnnConvArgsAndAlgo
   CudnnConvArgsAndAlgo(const user_op::Tensor* x, const user_op::Tensor* w, const user_op::Tensor* y,
                        user_op::Tensor* buf, const user_op::KernelComputeContext* ctx,
                        DeviceCtx* device_ctx, bool has_forced_algo, int32_t forced_algo)
1 change: 0 additions & 1 deletion oneflow/user/kernels/deconv_cudnn_kernel.cpp
@@ -31,7 +31,6 @@ struct CudnnDeConvArgsAndAlgo final {
   CudnnConvArgs args;
   PerfT algo_perf;
 
-  // TODO(hanbinbin): remove arg job_desc and set cudnn_conv config as args of CudnnDeConvArgsAndAlgo
   CudnnDeConvArgsAndAlgo(const user_op::Tensor* x, const user_op::Tensor* w,
                          const user_op::Tensor* y, user_op::Tensor* buf,
4 changes: 0 additions & 4 deletions oneflow/user/kernels/stateful_local_opkernel.h
@@ -356,10 +356,6 @@ class LocalUserKernelComputeContext final : public user_op::KernelComputeContext
     UNIMPLEMENTED();
     return *(const ParallelContext*)nullptr;
   }
-  const JobDesc& job_desc() const override {
-    UNIMPLEMENTED();
-    return *(const JobDesc*)nullptr;
-  }
 
   const ArgVec& inputs() const override { return base_ctx_.inputs(); };
   const ArgVec& outputs() const override { return base_ctx_.outputs(); };
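The override deleted here existed only to satisfy the pure-virtual job_desc() in KernelComputeContext; with that accessor removed from op_kernel.h above, the stub and its hazard go with it. For reference, a hypothetical reduction (stand-in type, not OneFlow code) of why such a stub is fragile:

struct JobDesc {};  // stand-in

// Binding a reference through a null pointer is undefined behavior as soon as
// the body executes, even if no caller ever touches the returned reference.
const JobDesc& BrokenStub() {
  return *(const JobDesc*)nullptr;  // UB: null dereference
}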
