UserKernel remove job_desc #6144

Merged: 9 commits, Sep 3, 2021
2 changes: 1 addition & 1 deletion oneflow/core/eager/opkernel_object.cpp
@@ -48,7 +48,7 @@ void OpKernelObject::NewPartialInitializedKernel(
const ParallelDesc* parallel_desc) {
KernelConf kernel_conf;
op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf);
- kernel_.reset(new EagerKernel(job_desc_.get(), kernel_conf));
+ kernel_.reset(new EagerKernel(kernel_conf));
}

Maybe<void> SystemOpKernelObject::ResetKernel(
1 change: 0 additions & 1 deletion oneflow/core/framework/op_kernel.h
@@ -204,7 +204,6 @@ class KernelComputeContext {
int32_t index) const = 0;
virtual DeviceType device_type() const = 0;
virtual const ParallelContext& parallel_ctx() const = 0;
- virtual const JobDesc& job_desc() const = 0;

virtual const std::vector<std::pair<std::string, int32_t>>& inputs() const = 0;
virtual const std::vector<std::pair<std::string, int32_t>>& outputs() const = 0;
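Note: with job_desc() removed from KernelComputeContext, anything a kernel needs at compute time has to travel with the op itself rather than with a job-level descriptor. A minimal, self-contained sketch of that pattern, using stand-in types rather than the real user_op API:

```cpp
// Stand-in sketch, not the real OneFlow user_op API: per-kernel knobs are
// read from the op's own attributes instead of a global JobDesc.
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>

class ComputeContext {
 public:
  explicit ComputeContext(std::unordered_map<std::string, int64_t> attrs)
      : attrs_(std::move(attrs)) {}
  // Attribute lookup, assumed to be parsed from the op conf (hypothetical helper).
  int64_t Attr(const std::string& name) const { return attrs_.at(name); }

 private:
  std::unordered_map<std::string, int64_t> attrs_;
};

// A kernel that previously reached for ctx->job_desc() now asks the op conf.
void Compute(const ComputeContext& ctx) {
  const int64_t axis = ctx.Attr("axis");
  assert(axis >= 0);
  // ... the actual computation would use `axis` here ...
}

int main() {
  ComputeContext ctx({{"axis", 1}});
  Compute(ctx);
  return 0;
}
```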
7 changes: 4 additions & 3 deletions oneflow/core/framework/op_kernel_infer_cache.cpp
@@ -21,10 +21,10 @@ namespace oneflow {

namespace user_op {

- OpKernelInferCache::OpKernelInferCache(const KernelConf& kernel_conf, const JobDesc& job_desc) {
+ OpKernelInferCache::OpKernelInferCache(const KernelConf& kernel_conf, const void* scope) {
const OperatorConf& op_conf = kernel_conf.op_attribute().op_conf();
std::shared_ptr<Operator> op = CHECK_JUST(ConstructOp(op_conf));
- cache_key_.scope = &job_desc;
+ cache_key_.scope = scope;
cache_key_.op_conf_sym = op->GetOpConfWithoutOpNameAndLbn();
cache_key_.ibn_idx2shape_sym.resize(op->input_bns().size());
cache_key_.dtype_signature_sym = SymbolOf(kernel_conf.dtype_signature());
@@ -57,7 +57,8 @@ void OpKernelInferCache::UpdateCacheKey(KernelInferContext* ctx) {
}

void OpKernelInferCache::UpdateCacheValue(KernelInferContext* ctx) {
- if (cached_key2value_.size() >= max_size_) { Reset(); }
+ // TODO: make max size configurable
+ if (cached_key2value_.size() >= kReleaseInIndependentThreadThreshold) { Reset(); }
auto* cache_value = new OpInferCacheValue();
cache_value->obn_idx2shape_sym.resize(ctx->outputs().size());
FOR_RANGE(int, i, 0, ctx->outputs().size()) {
3 changes: 1 addition & 2 deletions oneflow/core/framework/op_kernel_infer_cache.h
@@ -34,7 +34,7 @@ class OpKernelInferCache final {
using KeyStorage = std::list<std::unique_ptr<KeyType>>;
static constexpr size_t kReleaseInIndependentThreadThreshold = 4096;

- OpKernelInferCache(const KernelConf& kernel_conf, const JobDesc& job_desc);
+ OpKernelInferCache(const KernelConf& kernel_conf, const void* scope);
~OpKernelInferCache() = default;

bool IsCacheHit() const;
@@ -47,7 +47,6 @@ KernelComputeContext
KeyType cache_key_;
HashMap cached_key2value_;
KeyStorage key_storage_;
- size_t max_size_;
};

} // namespace user_op
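Note: the cache key now carries an opaque const void* scope (in practice the owning kernel) instead of a JobDesc&, and the fixed kReleaseInIndependentThreadThreshold takes over from the removed max_size_ member. A simplified, self-contained model of that policy, not the actual OpKernelInferCache:

```cpp
// Simplified model of the infer-cache policy: the key holds an opaque scope
// pointer, and the map is flushed wholesale once it crosses a fixed threshold.
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct ShapeInfo { std::string shape; };          // stand-in for OpInferCacheValue

struct CacheKey {
  const void* scope = nullptr;                    // opaque owner, e.g. the kernel itself
  std::string op_signature;                       // stand-in for op conf / shape symbols
  bool operator==(const CacheKey& o) const {
    return scope == o.scope && op_signature == o.op_signature;
  }
};

struct CacheKeyHash {
  size_t operator()(const CacheKey& k) const {
    return std::hash<const void*>()(k.scope) ^ std::hash<std::string>()(k.op_signature);
  }
};

class InferCache {
 public:
  static constexpr size_t kMaxEntries = 4096;     // mirrors kReleaseInIndependentThreadThreshold

  const ShapeInfo* Find(const CacheKey& key) const {
    auto it = map_.find(key);
    return it == map_.end() ? nullptr : it->second.get();
  }

  void Insert(const CacheKey& key, ShapeInfo value) {
    if (map_.size() >= kMaxEntries) { map_.clear(); }  // wholesale reset, as in UpdateCacheValue
    map_[key] = std::make_unique<ShapeInfo>(std::move(value));
  }

 private:
  std::unordered_map<CacheKey, std::unique_ptr<ShapeInfo>, CacheKeyHash> map_;
};

int main() {
  InferCache cache;
  int dummy_kernel = 0;                           // pretend this is the owning kernel
  CacheKey key{&dummy_kernel, "relu:1x3x224x224"};
  if (cache.Find(key) == nullptr) { cache.Insert(key, ShapeInfo{"1x3x224x224"}); }
  return cache.Find(key) != nullptr ? 0 : 1;
}
```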
3 changes: 1 addition & 2 deletions oneflow/core/kernel/eager_kernel.h
@@ -24,7 +24,7 @@ namespace oneflow {
class EagerKernel final : public Kernel {
public:
OF_DISALLOW_COPY_AND_MOVE(EagerKernel);
- EagerKernel(const JobDesc* job_desc, const KernelConf& kernel_conf);
+ explicit EagerKernel(const KernelConf& kernel_conf);
~EagerKernel() = default;

void Infer(std::function<Blob*(const std::string&)> BnInOp2Blob) const;
@@ -37,7 +37,6 @@ class EagerKernel final : public Kernel {
void InitOpKernel(const KernelConf& kernel_conf);
void ForwardDataContent(const KernelContext* kernel_ctx) const override { UNIMPLEMENTED(); }
std::unique_ptr<const user_op::OpKernel> kernel_;
- const JobDesc* job_desc_;
};

} // namespace oneflow
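Note: since the constructor is down to a single argument it is marked explicit, as are the other single-argument context constructors touched below. A small generic illustration of what that prevents (not OneFlow code):

```cpp
// Why the single-argument constructor gains `explicit`: it blocks accidental
// implicit conversions from a conf-like value to the kernel type.
struct Conf {};

struct ImplicitKernel {
  ImplicitKernel(const Conf&) {}              // converting constructor
};

struct ExplicitKernel {
  explicit ExplicitKernel(const Conf&) {}     // conversion must be spelled out
};

void TakeImplicit(ImplicitKernel) {}
void TakeExplicit(ExplicitKernel) {}

int main() {
  Conf conf;
  TakeImplicit(conf);                          // compiles: Conf silently becomes a kernel
  // TakeExplicit(conf);                       // would not compile
  TakeExplicit(ExplicitKernel(conf));          // intent is visible at the call site
  return 0;
}
```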
4 changes: 2 additions & 2 deletions oneflow/core/kernel/kernel.cpp
@@ -35,15 +35,15 @@ Kernel::Kernel() = default;

Kernel::~Kernel() = default;

- void Kernel::InitBase(const JobDesc* job_desc, const KernelConf& kernel_conf) {
+ void Kernel::InitBase(const KernelConf& kernel_conf) {
if (shape_infer_helper_) { return; }
kernel_conf_ = kernel_conf;
shape_infer_helper_.reset(
new RuntimeBlobShapeInferHelper(this->op_conf(), this->kernel_conf(), this));
}

void Kernel::Init(const KernelConf& kernel_conf, KernelContext* ctx) {
- InitBase(ctx->job_desc(), kernel_conf);
+ InitBase(kernel_conf);
VirtualKernelInit(ctx);
}

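Note: Kernel::Init now forwards only the KernelConf, and InitBase stays idempotent because it is guarded by shape_infer_helper_. A tiny sketch of that guard pattern with simplified stand-in names:

```cpp
// Simplified model of the InitBase guard: the helper is built exactly once,
// so calling the init path repeatedly stays cheap and safe.
#include <cassert>
#include <memory>

struct ShapeInferHelper {};  // stand-in for RuntimeBlobShapeInferHelper

class KernelLike {
 public:
  void InitBase() {
    if (helper_) { return; }  // already initialized, keep the first helper
    helper_ = std::make_unique<ShapeInferHelper>();
  }
  const ShapeInferHelper* helper() const { return helper_.get(); }

 private:
  std::unique_ptr<ShapeInferHelper> helper_;
};

int main() {
  KernelLike k;
  k.InitBase();
  const ShapeInferHelper* first = k.helper();
  k.InitBase();                 // second call is a no-op
  assert(k.helper() == first);
  return 0;
}
```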
2 changes: 1 addition & 1 deletion oneflow/core/kernel/kernel.h
@@ -52,7 +52,7 @@ class Kernel {

protected:
Kernel();
- void InitBase(const JobDesc* job_desc, const KernelConf&);
+ void InitBase(const KernelConf&);
virtual void VirtualKernelInit(KernelContext* ctx) {}

virtual void ForwardHeader(const KernelContext* ctx) const;
62 changes: 23 additions & 39 deletions oneflow/core/kernel/user_kernel.cpp
@@ -67,8 +67,7 @@ void FillTensorDescWithBlob(const Blob* blob, user_op::NaiveTensorDesc* tensor_d

class UserKernelBaseContext {
public:
- UserKernelBaseContext(const KernelConf& kernel_conf, const JobDesc& job_desc)
- : job_desc_(job_desc) {
+ explicit UserKernelBaseContext(const KernelConf& kernel_conf) {
CHECK(kernel_conf.has_user_conf());
CHECK(kernel_conf.op_attribute().op_conf().has_user_conf());

@@ -94,7 +93,6 @@ class UserKernelBaseContext {
DeviceType device_type() const { return device_type_; }
const std::string& device_tag() const { return device_tag_; }
const ParallelContext& parallel_ctx() const { return parallel_ctx_; }
- const JobDesc& job_desc() const { return job_desc_; }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const {
auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index));
@@ -112,7 +110,6 @@ class UserKernelBaseContext {
std::string device_tag_;
ParallelContext parallel_ctx_;
HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;
- const JobDesc& job_desc_;
};

class KernelCreateContext final : public user_op::KernelCreateContext {
@@ -132,11 +129,10 @@ class KernelCreateContext final : public user_op::KernelCreateContext {

class UserKernelInitContext final : public user_op::KernelInitContext {
public:
- explicit UserKernelInitContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf,
- const JobDesc& job_desc)
+ explicit UserKernelInitContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf)
: user_op_conf_(kernel_conf.op_attribute().op_conf()),
device_ctx_(device_ctx),
- base_ctx_(UserKernelBaseContext(kernel_conf, job_desc)),
+ base_ctx_(UserKernelBaseContext(kernel_conf)),
parallel_desc_(kernel_conf.op_attribute().parallel_conf_signature().op_parallel_conf()) {
nd_sbp_signature_ = new cfg::NdSbpSignature(kernel_conf.op_attribute().nd_sbp_signature());
if (kernel_conf.op_attribute().has_sbp_signature()) {
@@ -209,9 +205,8 @@

class UserKernelOpInferContext : public user_op::InferContext {
public:
- UserKernelOpInferContext(const KernelConf& kernel_conf, const JobDesc* job_desc)
+ explicit UserKernelOpInferContext(const KernelConf& kernel_conf)
: user_op_conf_(kernel_conf.op_attribute().op_conf()),
- job_desc_(job_desc),
parallel_ctx_(kernel_conf.parallel_ctx()),
nd_sbp_signature_(kernel_conf.op_attribute().nd_sbp_signature()),
parallel_desc_(kernel_conf.op_attribute().parallel_conf_signature().op_parallel_conf()) {
@@ -291,10 +286,6 @@

const ArgVec& inputs() const override { return inputs_; }
const ArgVec& outputs() const override { return outputs_; }
- const JobDesc* job_desc() const override {
- CHECK_NOTNULL(job_desc_);
- return job_desc_;
- }
const ParallelContext& parallel_ctx() const override { return parallel_ctx_; };
const ParallelDesc& parallel_desc() const override { return parallel_desc_; }
const cfg::SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
@@ -361,7 +352,6 @@
}

user_op::UserOpConfWrapper user_op_conf_;
- const JobDesc* job_desc_;
ArgVec inputs_;
ArgVec outputs_;
ParallelContext parallel_ctx_;
Expand All @@ -375,12 +365,11 @@ class UserKernelOpInferContext : public user_op::InferContext {

class UserKernelInferContext final : public user_op::KernelInferContext {
public:
- explicit UserKernelInferContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf,
- const JobDesc& job_desc)
+ explicit UserKernelInferContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf)
: user_op_conf_(kernel_conf.op_attribute().op_conf()),
device_ctx_(device_ctx),
- base_ctx_(UserKernelBaseContext(kernel_conf, job_desc)),
- op_infer_ctx_(kernel_conf, &job_desc) {
+ base_ctx_(UserKernelBaseContext(kernel_conf)),
+ op_infer_ctx_(kernel_conf) {
auto InitArg2Blob = [this](const PbMap<std::string, UserOpConf::ListString>& arg_map) {
for (auto it = arg_map.begin(); it != arg_map.end(); ++it) {
const std::string& arg_name = it->first;
Expand Down Expand Up @@ -490,11 +479,10 @@ BnTensorPair MakeBnTensorPair(const std::string& bn,

class UserKernelComputeContext final : public user_op::KernelComputeContext {
public:
- explicit UserKernelComputeContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf,
- const JobDesc& job_desc)
+ explicit UserKernelComputeContext(DeviceCtx* device_ctx, const KernelConf& kernel_conf)
: user_op_conf_(kernel_conf.op_attribute().op_conf()),
device_ctx_(device_ctx),
- base_ctx_(kernel_conf, job_desc) {
+ base_ctx_(kernel_conf) {
auto InitInOrOut = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map) {
for (const auto& it : arg_map) {
const std::string& arg_name = it.first;
Expand Down Expand Up @@ -552,7 +540,6 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext {

DeviceType device_type() const override { return base_ctx_.device_type(); }
const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); }
- const JobDesc& job_desc() const override { return base_ctx_.job_desc(); }

const ArgVec& inputs() const override { return base_ctx_.inputs(); }
const ArgVec& outputs() const override { return base_ctx_.outputs(); }
@@ -573,9 +560,9 @@

class UserKernelRegContext final : public user_op::KernelRegContext {
public:
- explicit UserKernelRegContext(const KernelConf& kernel_conf, const JobDesc& job_desc)
+ explicit UserKernelRegContext(const KernelConf& kernel_conf)
: user_op_conf_(kernel_conf.op_attribute().op_conf()),
- base_ctx_(UserKernelBaseContext(kernel_conf, job_desc)) {}
+ base_ctx_(UserKernelBaseContext(kernel_conf)) {}
~UserKernelRegContext() = default;

DeviceType device_type() const override { return base_ctx_.device_type(); }
@@ -603,23 +590,23 @@
UserKernel::~UserKernel() = default;

void UserKernel::InitUserKernel(DeviceCtx* device_ctx) {
- ctx_.reset(new UserKernelComputeContext(device_ctx, kernel_conf(), job_desc()));
- infer_ctx_.reset(new UserKernelInferContext(device_ctx, kernel_conf(), job_desc()));
- infer_cache_.reset(new user_op::OpKernelInferCache(kernel_conf(), job_desc()));
+ ctx_.reset(new UserKernelComputeContext(device_ctx, kernel_conf()));
+ infer_ctx_.reset(new UserKernelInferContext(device_ctx, kernel_conf()));
+ infer_cache_.reset(new user_op::OpKernelInferCache(kernel_conf(), this));
{
const std::string& op_type_name =
kernel_conf().op_attribute().op_conf().user_conf().op_type_name();
const user_op::OpKernelRegistryResult* kernel_reg_val =
CHECK_JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(
- op_type_name, UserKernelRegContext(kernel_conf(), job_desc())));
+ op_type_name, UserKernelRegContext(kernel_conf())));
CHECK_NOTNULL(kernel_reg_val);
KernelCreateContext create_ctx(kernel_conf());
kernel_.reset(kernel_reg_val->create_fn(&create_ctx));
}

#ifdef WITH_CUDA_GRAPHS
if (ParseBooleanFromEnv("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", false)) {
- UserKernelInitContext init_ctx(device_ctx, kernel_conf(), job_desc());
+ UserKernelInitContext init_ctx(device_ctx, kernel_conf());
CudaDeviceCtx* cuda_device_ctx = dynamic_cast<CudaDeviceCtx*>(device_ctx);
const auto* cuda_graph_support = dynamic_cast<const user_op::CudaGraphSupport*>(kernel_.get());
if (cuda_device_ctx != nullptr) {
@@ -637,7 +624,7 @@ void UserKernel::InitUserKernel(DeviceCtx* device_ctx) {
}

std::shared_ptr<user_op::OpKernelState> UserKernel::CreateOpKernelState(DeviceCtx* device_ctx) {
- UserKernelInitContext init_ctx(device_ctx, kernel_conf(), job_desc());
+ UserKernelInitContext init_ctx(device_ctx, kernel_conf());
return kernel_->CreateOpKernelState(&init_ctx);
}

@@ -682,7 +669,6 @@ bool UserKernel::IsCudaGraphSupported() const {
}

void UserKernel::VirtualKernelInit(KernelContext* ctx) {
- job_desc_ = ctx->job_desc();
InitUserKernel(ctx->device_ctx());
CHECK(opkernel_state_.get() == nullptr);
opkernel_state_ = CreateOpKernelState(ctx->device_ctx());
@@ -731,24 +717,23 @@ NEW_REGISTER_KERNEL(OperatorConf::kUserConf, UserKernel).SetIsMatchedPred([](con
return true;
});

- EagerKernel::EagerKernel(const JobDesc* job_desc, const KernelConf& kernel_conf)
- : job_desc_(job_desc) {
- InitBase(job_desc, kernel_conf);
+ EagerKernel::EagerKernel(const KernelConf& kernel_conf) {
+ InitBase(kernel_conf);
InitOpKernel(kernel_conf);
}

void EagerKernel::InitOpKernel(const KernelConf& kernel_conf) {
const std::string& op_type_name = kernel_conf.op_attribute().op_conf().user_conf().op_type_name();
auto kernel_reg_val = CHECK_JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(
- op_type_name, UserKernelRegContext(kernel_conf, *job_desc_)));
+ op_type_name, UserKernelRegContext(kernel_conf)));
CHECK_NOTNULL(kernel_reg_val);
KernelCreateContext create_ctx(kernel_conf);
kernel_.reset(kernel_reg_val->create_fn(&create_ctx));
}

void EagerKernel::Infer(std::function<Blob*(const std::string&)> BnInOp2Blob) const {
if (kernel_conf().all_blobs_are_static()) { return; }
- UserKernelInferContext infer_ctx(nullptr, kernel_conf(), *job_desc_);
+ UserKernelInferContext infer_ctx(nullptr, kernel_conf());
infer_ctx.UpdateArg2Tensor(BnInOp2Blob);
auto* op_infer_ctx = dynamic_cast<UserKernelOpInferContext*>(infer_ctx.MutOpInferContext());
if (op_infer_ctx) { op_infer_ctx->UpdateArg2TensorDesc(BnInOp2Blob); }
@@ -762,8 +747,7 @@ std::shared_ptr<user_op::OpKernelState> EagerKernel::EagerForward(
if (old_opkernel_state) {
new_opkernel_state = old_opkernel_state;
} else {
- CHECK_NOTNULL(job_desc_);
- UserKernelInitContext init_ctx(device_ctx, kernel_conf(), *job_desc_);
+ UserKernelInitContext init_ctx(device_ctx, kernel_conf());
new_opkernel_state = kernel_->CreateOpKernelState(&init_ctx);
}

Expand All @@ -773,7 +757,7 @@ std::shared_ptr<user_op::OpKernelState> EagerKernel::EagerForward(
}

// TODO(lixinqi): refactor to a lightweight KernelComputeContext
- UserKernelComputeContext compute_ctx(device_ctx, kernel_conf(), *job_desc_);
+ UserKernelComputeContext compute_ctx(device_ctx, kernel_conf());
compute_ctx.UpdateTensorWithCorrBlob(BnInOp2Blob);
kernel_->Compute(&compute_ctx, new_opkernel_state.get());
return new_opkernel_state;
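Note: InitUserKernel now passes this as the infer-cache scope, so cache entries are keyed by kernel instance identity instead of the job descriptor. A small sketch of why pointer identity is enough to keep two kernels with identical op confs from sharing entries (stand-in types, not the real cache key):

```cpp
// Sketch: using the owning object's address as the cache scope gives each
// kernel instance its own namespace of cache entries.
#include <cassert>

struct CacheKeyLite {
  const void* scope;      // owning kernel, compared by address only
  int op_conf_symbol;     // stand-in for the interned op conf / shape symbols
  bool operator==(const CacheKeyLite& other) const {
    return scope == other.scope && op_conf_symbol == other.op_conf_symbol;
  }
};

struct KernelLite {
  CacheKeyLite MakeKey(int op_conf_symbol) const { return {this, op_conf_symbol}; }
};

int main() {
  KernelLite a;
  KernelLite b;
  // Identical op confs, but different owners, therefore different cache keys.
  assert(!(a.MakeKey(42) == b.MakeKey(42)));
  assert(a.MakeKey(42) == a.MakeKey(42));
  return 0;
}
```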
3 changes: 0 additions & 3 deletions oneflow/core/kernel/user_kernel.h
@@ -50,8 +50,6 @@ class UserKernel final : public Kernel {

bool IsStateless() const override;

- const JobDesc& job_desc() const { return *job_desc_; }

std::shared_ptr<user_op::OpKernelState> opkernel_state_;
std::unique_ptr<const user_op::OpKernel> kernel_;
std::unique_ptr<UserKernelComputeContext> ctx_;
@@ -60,7 +58,6 @@ class UserKernel final : public Kernel {
#ifdef WITH_CUDA_GRAPHS
std::unique_ptr<CudaGraphContext> cuda_graph_ctx_;
#endif // WITH_CUDA_GRAPHS
- const JobDesc* job_desc_;
};

} // namespace oneflow
1 change: 0 additions & 1 deletion oneflow/user/kernels/conv_cudnn_kernels.cpp
@@ -34,7 +34,6 @@ struct CudnnConvArgsAndAlgo final {
CudnnConvArgs args;
PerfT algo_perf;

- // TODO(hanbinbin): remove arg job_desc and set cudnn_conv config as args of CudnnConvArgsAndAlgo
CudnnConvArgsAndAlgo(const user_op::Tensor* x, const user_op::Tensor* w, const user_op::Tensor* y,
user_op::Tensor* buf, const user_op::KernelComputeContext* ctx,
DeviceCtx* device_ctx, bool has_forced_algo, int32_t forced_algo)
1 change: 0 additions & 1 deletion oneflow/user/kernels/deconv_cudnn_kernel.cpp
@@ -31,7 +31,6 @@ struct CudnnDeConvArgsAndAlgo final {
CudnnConvArgs args;
PerfT algo_perf;

- // TODO(hanbinbin): remove arg job_desc and set cudnn_conv config as args of
// CudnnDeConvArgsAndAlgo
CudnnDeConvArgsAndAlgo(const user_op::Tensor* x, const user_op::Tensor* w,
const user_op::Tensor* y, user_op::Tensor* buf,
4 changes: 0 additions & 4 deletions oneflow/user/kernels/stateful_local_opkernel.h
@@ -356,10 +356,6 @@ class LocalUserKernelComputeContext final : public user_op::KernelComputeContext
UNIMPLEMENTED();
return *(const ParallelContext*)nullptr;
}
- const JobDesc& job_desc() const override {
- UNIMPLEMENTED();
- return *(const JobDesc*)nullptr;
- }

const ArgVec& inputs() const override { return base_ctx_.inputs(); };
const ArgVec& outputs() const override { return base_ctx_.outputs(); };