diff --git a/oneflow/core/kernel/cuda_graph_support.h b/oneflow/core/kernel/cuda_graph_support.h index 402bac7dce7..975dd08680c 100644 --- a/oneflow/core/kernel/cuda_graph_support.h +++ b/oneflow/core/kernel/cuda_graph_support.h @@ -19,13 +19,16 @@ namespace oneflow { namespace user_op { class KernelInitContext; +class OpKernelState; class CudaGraphSupport { public: CudaGraphSupport() = default; virtual ~CudaGraphSupport() = default; - virtual bool IsCudaGraphSupported(KernelInitContext* ctx) const { return true; } + virtual bool IsCudaGraphSupported(KernelInitContext* ctx, OpKernelState* state) const { + return true; + } }; } // namespace user_op diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 7aa0422acff..1f24ba3e6e1 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -603,24 +603,6 @@ void UserKernel::InitUserKernel(StreamContext* stream_ctx, DeviceCtx* device_ctx KernelCreateContext create_ctx(kernel_conf()); kernel_.reset(kernel_reg_val->create_fn(&create_ctx)); } - -#ifdef WITH_CUDA_GRAPHS - if (ParseBooleanFromEnv("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", false)) { - UserKernelInitContext init_ctx(device_ctx, kernel_conf()); - cuda_graph_ctx_ = dynamic_cast(stream_ctx); - const auto* cuda_graph_support = dynamic_cast(kernel_.get()); - if (cuda_graph_ctx_ != nullptr) { - if (cuda_graph_support != nullptr && cuda_graph_support->IsCudaGraphSupported(&init_ctx)) { - cuda_graph_exec_.reset(new CudaGraphExecutable()); - LOG(INFO) << "CUDA Graphs Kernel: " << op_conf().name() << " (" - << op_conf().user_conf().op_type_name() << ")"; - } else { - LOG(INFO) << "CUDA Graphs not supported: " << op_conf().name() << " (" - << op_conf().user_conf().op_type_name() << ")"; - } - } - } -#endif // WITH_CUDA_GRAPHS } std::shared_ptr UserKernel::CreateOpKernelState(DeviceCtx* device_ctx) { @@ -672,6 +654,24 @@ void UserKernel::VirtualKernelInit(KernelContext* ctx) { InitUserKernel(ctx->stream_ctx(), ctx->device_ctx()); CHECK(opkernel_state_.get() == nullptr); opkernel_state_ = CreateOpKernelState(ctx->device_ctx()); +#ifdef WITH_CUDA_GRAPHS + if (ParseBooleanFromEnv("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", false)) { + UserKernelInitContext init_ctx(ctx->device_ctx(), kernel_conf()); + cuda_graph_ctx_ = dynamic_cast(ctx->stream_ctx()); + const auto* cuda_graph_support = dynamic_cast(kernel_.get()); + if (cuda_graph_ctx_ != nullptr) { + if (cuda_graph_support != nullptr + && cuda_graph_support->IsCudaGraphSupported(&init_ctx, opkernel_state_.get())) { + cuda_graph_exec_.reset(new CudaGraphExecutable()); + LOG(INFO) << "CUDA Graphs Kernel: " << op_conf().name() << " (" + << op_conf().user_conf().op_type_name() << ")"; + } else { + LOG(INFO) << "CUDA Graphs not supported: " << op_conf().name() << " (" + << op_conf().user_conf().op_type_name() << ")"; + } + } + } +#endif // WITH_CUDA_GRAPHS } void UserKernel::ForwardDataContent(KernelContext* ctx) const { diff --git a/oneflow/user/kernels/conv_cudnn_kernels.cpp b/oneflow/user/kernels/conv_cudnn_kernels.cpp index 7afd04e7410..9172464b987 100644 --- a/oneflow/user/kernels/conv_cudnn_kernels.cpp +++ b/oneflow/user/kernels/conv_cudnn_kernels.cpp @@ -193,7 +193,8 @@ class ConvGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphS } } - bool IsCudaGraphSupported(user_op::KernelInitContext* ctx) const override { + bool IsCudaGraphSupported(user_op::KernelInitContext* ctx, + user_op::OpKernelState* state) const override { return Global::Get() ->resource() .cudnn_conf() @@ -270,7 +271,8 @@ class ConvDataGradGpuKernel final : public user_op::OpKernel, public user_op::Cu args.params.max_ws_size, beta, args.xdesc.Get(), dx->mut_dptr())); } - bool IsCudaGraphSupported(user_op::KernelInitContext* ctx) const override { + bool IsCudaGraphSupported(user_op::KernelInitContext* ctx, + user_op::OpKernelState* state) const override { return Global::Get() ->resource() .cudnn_conf() @@ -334,7 +336,8 @@ class ConvFilterGradGpuKernel final : public user_op::OpKernel, public user_op:: args.params.max_ws_size, CudnnSPZeroPtr(), args.wdesc.Get(), filter_diff->mut_dptr())); } - bool IsCudaGraphSupported(user_op::KernelInitContext* ctx) const override { + bool IsCudaGraphSupported(user_op::KernelInitContext* ctx, + user_op::OpKernelState* state) const override { return Global::Get() ->resource() .cudnn_conf()