From 9a331fbb72f8f29fad5e36e8708293bce91763f5 Mon Sep 17 00:00:00 2001
From: Bowen Chen
Date: Mon, 5 Jul 2021 17:58:41 +0800
Subject: [PATCH] [Random Generator] Part2: Migrate functional dropout (#5378)

* add random generator

* reformat

* refactor: allow auto generator

* refactor: remove kAUTO, update python api and test

* refactor: use member initializer lists, fix build issue when cpu only

* handle exception given invalid device

* add dropout functor; add OpExprInterpContext; refactor random_mask_like based on random_generator

* refactor random generator

* disable generator's copy and move constructors

* reformat

* fix bad merge

* refine

* fix cpu only build

* auto format by CI

* refactor

* use global generator when no generator is specified in functional api

* refine

Co-authored-by: oneflow-ci-bot
Co-authored-by: Houjiang Chen
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/api/python/functional/python_arg.cpp  | 12 ++++
 .../autograd/gradient_funcs/adaptive_pool.cpp |  2 +-
 .../core/autograd/gradient_funcs/bias_add.cpp |  4 +-
 .../gradient_funcs/broadcast_binary_ops.cpp   | 12 ++--
 .../core/autograd/gradient_funcs/deconv.cpp   |  3 +-
 .../core/autograd/gradient_funcs/squeeze.cpp  |  3 +-
 .../gradient_funcs/tensor_scalar_binary.cpp   |  3 +-
 .../autograd/gradient_funcs/unsqueeze.cpp     |  3 +-
 oneflow/core/eager/eager_oneflow.cpp          |  6 +-
 .../local_call_opkernel_phy_instr_operand.h   | 11 +--
 .../core/eager/opkernel_instruction_type.cpp  |  4 ++
 .../core/framework/instructions_builder.cpp   |  5 +-
 oneflow/core/framework/instructions_builder.h |  2 +-
 oneflow/core/framework/op_interpreter.h       | 44 +++++++++---
 .../eager_consistent_op_interpreter.cpp       | 25 +++----
 .../eager_mirrored_op_interpreter.cpp         | 30 +++++----
 .../op_interpreter/op_interpreter.cpp         | 23 ++++---
 .../op_interpreter/op_interpreter_util.cpp    |  8 +--
 .../op_interpreter/op_interpreter_util.h      | 12 +++-
 oneflow/core/framework/random_generator.cpp   | 47 ++++++++++++-
 oneflow/core/framework/random_generator.h     | 14 +++-
 oneflow/core/functional/functional_api.yaml   |  5 ++
 oneflow/core/functional/impl/nn_functor.cpp   | 42 ++++++++++++
 oneflow/core/functional/value_types.h         |  7 ++
 oneflow/python/nn/modules/dropout.py          | 31 ++-------
 oneflow/user/kernels/dropout_kernel.cpp       | 38 -----------
 .../user/kernels/random_mask_generator.cpp    |  5 +-
 oneflow/user/kernels/random_mask_generator.cu | 52 +++-----------
 oneflow/user/kernels/random_mask_generator.h  | 15 ++---
 .../user/kernels/random_mask_like_kernel.cpp  | 32 +++++++++
 .../user/kernels/random_mask_like_kernel.h    | 67 +++++++++++++++++++
 .../user/kernels/stateful_local_opkernel.h    |  1 +
 tools/generate_functional_api.py              |  4 ++
 33 files changed, 372 insertions(+), 200 deletions(-)
 create mode 100644 oneflow/user/kernels/random_mask_like_kernel.cpp
 create mode 100644 oneflow/user/kernels/random_mask_like_kernel.h

diff --git a/oneflow/api/python/functional/python_arg.cpp b/oneflow/api/python/functional/python_arg.cpp
index 57d4280f678..f6a2daf7985 100644
--- a/oneflow/api/python/functional/python_arg.cpp
+++ b/oneflow/api/python/functional/python_arg.cpp
@@ -23,6 +23,7 @@ limitations under the License.
#include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/user_op_attr.cfg.h" +#include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/scalar.h" namespace py = pybind11; @@ -155,6 +156,17 @@ Maybe PythonArg::ObjectAs() const { } } +template<> +Maybe> PythonArg::ObjectAs>() + const { + return detail::cast>(Borrow()); +} + +template<> +Maybe PythonArg::ObjectAs() const { + return *JUST(detail::cast>(Borrow())); +} + } // namespace functional } // namespace one } // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp b/oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp index 96bf197fda9..9512f030c24 100644 --- a/oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp +++ b/oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp @@ -64,7 +64,7 @@ Maybe AdaptivePool::Apply(const AdaptivePoolInterpState* ctx, const Tensor const std::shared_ptr& x = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op_, {x, out_grads.at(0)}, {})); + in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op_, {x, out_grads.at(0)})); return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/bias_add.cpp b/oneflow/core/autograd/gradient_funcs/bias_add.cpp index 6e59f14982a..04378ee553e 100644 --- a/oneflow/core/autograd/gradient_funcs/bias_add.cpp +++ b/oneflow/core/autograd/gradient_funcs/bias_add.cpp @@ -68,8 +68,8 @@ class BiasAdd : public OpExprGradFunction { JUST(OpInterpUtil::Dispatch(*backward_bias_op_, {out_grads.at(0)}, attrs)); } if (ctx->input_requires_grad) { - in_grads->at(0) = JUST( - OpInterpUtil::Dispatch(*backward_input_op_, {out_grads.at(0)}, /*attrs=*/{})); + in_grads->at(0) = + JUST(OpInterpUtil::Dispatch(*backward_input_op_, {out_grads.at(0)})); } return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp index 63b50e89cff..24fdc1dc614 100644 --- a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp +++ b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp @@ -134,8 +134,7 @@ class BroadcastSub : public BroadcastBinaryGrad { in_grads->resize(2); if (x->requires_grad()) { in_grads->at(0) = JUST(x_grad_op_->forward(out_grads.at(0), x)); } if (y->requires_grad()) { - const auto& grad = - JUST(OpInterpUtil::Dispatch(*y_grad_mul_op_, {out_grads.at(0)}, /*attrs=*/{})); + const auto& grad = JUST(OpInterpUtil::Dispatch(*y_grad_mul_op_, {out_grads.at(0)})); in_grads->at(1) = JUST(y_grad_op_->forward(grad, y)); } return Maybe::Ok(); @@ -169,12 +168,12 @@ class BroadcastMul : public BroadcastBinaryGrad { in_grads->resize(2); if (x->requires_grad()) { const auto& x_grad = - JUST(OpInterpUtil::Dispatch(*x_grad_mul_op_, {out_grads.at(0), y}, /*attrs=*/{})); + JUST(OpInterpUtil::Dispatch(*x_grad_mul_op_, {out_grads.at(0), y})); in_grads->at(0) = JUST(x_grad_op_->forward(x_grad, x)); } if (y->requires_grad()) { const auto& y_grad = - JUST(OpInterpUtil::Dispatch(*y_grad_mul_op_, {out_grads.at(0), x}, /*attrs=*/{})); + JUST(OpInterpUtil::Dispatch(*y_grad_mul_op_, {out_grads.at(0), x})); in_grads->at(1) = JUST(y_grad_op_->forward(y_grad, y)); } return Maybe::Ok(); @@ -208,12 +207,11 @@ class BroadcastDiv : public BroadcastBinaryGrad { in_grads->resize(2); if (x->requires_grad()) { const auto& x_grad = - JUST(OpInterpUtil::Dispatch(*x_grad_div_op_, {out_grads.at(0), y}, /*attrs=*/{})); + 
JUST(OpInterpUtil::Dispatch(*x_grad_div_op_, {out_grads.at(0), y})); in_grads->at(0) = JUST(x_grad_op_->forward(x_grad, x)); } if (y->requires_grad()) { - in_grads->at(1) = - JUST(OpInterpUtil::Dispatch(*y_grad_op_, {out_grads.at(0), z, y}, /*attrs=*/{})); + in_grads->at(1) = JUST(OpInterpUtil::Dispatch(*y_grad_op_, {out_grads.at(0), z, y})); } return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/deconv.cpp b/oneflow/core/autograd/gradient_funcs/deconv.cpp index dddc48ba283..fa2e92551f4 100644 --- a/oneflow/core/autograd/gradient_funcs/deconv.cpp +++ b/oneflow/core/autograd/gradient_funcs/deconv.cpp @@ -98,8 +98,7 @@ Maybe DeConvolutionNd::Apply(const DeConvolutionNdInterpState* ctx, if (ctx->weight_requires_grad) { int idx = ctx->activation_requires_grad; const auto& x = ctx->SavedTensors().at(idx); - in_grads->at(1) = - JUST(OpInterpUtil::Dispatch(*weight_grad_op_, {x, out_grads.at(0)}, /*attrs=*/{})); + in_grads->at(1) = JUST(OpInterpUtil::Dispatch(*weight_grad_op_, {x, out_grads.at(0)})); } return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/squeeze.cpp b/oneflow/core/autograd/gradient_funcs/squeeze.cpp index 9a38e9402bb..a69a600394f 100644 --- a/oneflow/core/autograd/gradient_funcs/squeeze.cpp +++ b/oneflow/core/autograd/gradient_funcs/squeeze.cpp @@ -64,8 +64,7 @@ Maybe Squeeze::Apply(const SqueezeInterpState* ctx, const TensorTuple& out const std::shared_ptr& like = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = - JUST(OpInterpUtil::Dispatch(*grad_op_, {out_grads.at(0), like}, /*attrs*/ {})); + in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op_, {out_grads.at(0), like})); return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp b/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp index 90b1908c8cd..fa3a6264aff 100644 --- a/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp +++ b/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp @@ -89,8 +89,7 @@ class TensorScalarSub : public TensorScalarAddOrSub { TensorTuple* in_grads) const override { in_grads->resize(2); if (ctx->x_requires_grad) { - in_grads->at(0) = - JUST(OpInterpUtil::Dispatch(*identity_op_, {out_grads.at(0)}, /*attrs=*/{})); + in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*identity_op_, {out_grads.at(0)})); } if (ctx->scalar_requires_grad) { int32_t num_axes = out_grads.at(0)->shape()->NumAxes(); diff --git a/oneflow/core/autograd/gradient_funcs/unsqueeze.cpp b/oneflow/core/autograd/gradient_funcs/unsqueeze.cpp index 793e06b7c25..3faf92e3d5c 100644 --- a/oneflow/core/autograd/gradient_funcs/unsqueeze.cpp +++ b/oneflow/core/autograd/gradient_funcs/unsqueeze.cpp @@ -64,8 +64,7 @@ Maybe Unsqueeze::Apply(const UnsqueezeInterpState* ctx, const TensorTuple& const std::shared_ptr& like = ctx->SavedTensors().at(0); in_grads->resize(1); - in_grads->at(0) = - JUST(OpInterpUtil::Dispatch(*grad_op_, {out_grads.at(0), like}, /*attrs*/ {})); + in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op_, {out_grads.at(0), like})); return Maybe::Ok(); } diff --git a/oneflow/core/eager/eager_oneflow.cpp b/oneflow/core/eager/eager_oneflow.cpp index cf59c06a8da..ab02e3271d1 100644 --- a/oneflow/core/eager/eager_oneflow.cpp +++ b/oneflow/core/eager/eager_oneflow.cpp @@ -68,11 +68,11 @@ Maybe StorageAdd(const EagerSymbol& symbol) { Maybe EagerOneflow::RunPhysicalInstruction( const std::shared_ptr& cluster_instruction) { vm::InstructionMsgList instruction_list; - const auto& eage_instructions = 
cluster_instruction->eager_instruction(); - for (const auto& instr_proto : eage_instructions.instruction_list().instruction()) { + const auto& eager_instructions = cluster_instruction->eager_instruction(); + for (const auto& instr_proto : eager_instructions.instruction_list().instruction()) { instruction_list.EmplaceBack(ObjectMsgPtr::New(instr_proto)); } - return RunPhysicalInstruction(&instruction_list, eage_instructions.eager_symbol_list()); + return RunPhysicalInstruction(&instruction_list, eager_instructions.eager_symbol_list()); } Maybe EagerOneflow::RunPhysicalInstruction( diff --git a/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h b/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h index 5ae00c19dcc..a4aa14f734e 100644 --- a/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h +++ b/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/vm/instruction_operand.msg.h" namespace oneflow { @@ -48,13 +49,15 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand { LocalCallOpKernelPhyInstrOperand(const std::shared_ptr& opkernel, const one::EagerBlobObjectListPtr& inputs, - const one::EagerBlobObjectListPtr& outputs, const AttrMap& attrs) - : opkernel_(opkernel), inputs_(inputs), outputs_(outputs), attrs_(attrs) {} + const one::EagerBlobObjectListPtr& outputs, + const one::OpExprInterpContext& op_interp_ctx_) + : opkernel_(opkernel), inputs_(inputs), outputs_(outputs), op_interp_ctx_(op_interp_ctx_) {} const one::StatefulLocalOpKernel& opkernel() const { return *opkernel_; } const one::EagerBlobObjectListPtr& inputs() const { return inputs_; } const one::EagerBlobObjectListPtr& outputs() const { return outputs_; } - const AttrMap& attrs() const { return attrs_; } + const AttrMap& attrs() const { return op_interp_ctx_.attrs; } + const one::OpExprInterpContext& op_interp_ctx() const { return op_interp_ctx_; } one::StatefulLocalOpKernel* mut_opkernel() { return opkernel_.get(); } @@ -84,7 +87,7 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand { std::shared_ptr opkernel_; one::EagerBlobObjectListPtr inputs_; one::EagerBlobObjectListPtr outputs_; - const AttrMap attrs_; + const one::OpExprInterpContext op_interp_ctx_; const user_op::OpKernel* user_opkernel_; }; diff --git a/oneflow/core/eager/opkernel_instruction_type.cpp b/oneflow/core/eager/opkernel_instruction_type.cpp index ad68caceb46..be3d853ea11 100644 --- a/oneflow/core/eager/opkernel_instruction_type.cpp +++ b/oneflow/core/eager/opkernel_instruction_type.cpp @@ -550,6 +550,10 @@ struct LocalCallOpKernelUtil final { static inline void TryInitOpKernelState(LocalCallOpKernelPhyInstrOperand* operand, DeviceCtx* device_ctx, user_op::OpKernelState** state) { + if (operand->op_interp_ctx().state) { + *state = operand->op_interp_ctx().state.get(); + return; + } operand->mut_opkernel()->TryInitOpKernelState(operand->user_opkernel(), device_ctx, operand->inputs(), operand->outputs(), state); } diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index 08a773e8c90..5ed0c3f9adb 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -653,13 +653,14 @@ Maybe InstructionsBuilder::BuildRecvInstruction( Maybe 
InstructionsBuilder::LocalCallOpKernel( const std::shared_ptr& opkernel, const one::EagerBlobObjectListPtr& input_eager_blob_objects, - const one::EagerBlobObjectListPtr& output_eager_blob_objects, const AttrMap& attrs, + const one::EagerBlobObjectListPtr& output_eager_blob_objects, + const one::OpExprInterpContext& ctx, const std::shared_ptr& parallel_desc_sym, const std::string& instr_type_name) { ObjectMsgPtr instruction = ObjectMsgPtr::New(instr_type_name); auto phy_instr_operand = std::make_shared( - opkernel, input_eager_blob_objects, output_eager_blob_objects, attrs); + opkernel, input_eager_blob_objects, output_eager_blob_objects, ctx); *instruction->mut_parallel_desc() = parallel_desc_sym; *instruction->mutable_phy_instr_operand() = phy_instr_operand; instruction_list_->EmplaceBack(std::move(instruction)); diff --git a/oneflow/core/framework/instructions_builder.h b/oneflow/core/framework/instructions_builder.h index b614b1db986..5917660ae60 100644 --- a/oneflow/core/framework/instructions_builder.h +++ b/oneflow/core/framework/instructions_builder.h @@ -247,7 +247,7 @@ class InstructionsBuilder : public std::enable_shared_from_this LocalCallOpKernel(const std::shared_ptr& opkernel, const one::EagerBlobObjectListPtr& input_eager_blob_objects, const one::EagerBlobObjectListPtr& output_eager_blob_objects, - const AttrMap& attrs, + const one::OpExprInterpContext& ctx, const std::shared_ptr& parallel_desc_sym, const std::string& instr_type_name); diff --git a/oneflow/core/framework/op_interpreter.h b/oneflow/core/framework/op_interpreter.h index 7c4cbc7307b..2cea0930aea 100644 --- a/oneflow/core/framework/op_interpreter.h +++ b/oneflow/core/framework/op_interpreter.h @@ -20,6 +20,7 @@ limitations under the License. #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/op_kernel.h" namespace oneflow { namespace one { @@ -41,17 +42,27 @@ class OpExprInterpState { TensorTuple saved_tensors_; }; +struct OpExprInterpContext { + AttrMap attrs; + std::shared_ptr state; +}; + class OpExprInterpreter { public: OpExprInterpreter() = default; virtual ~OpExprInterpreter() = default; - virtual Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const = 0; + Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs, + const AttrMap& attrs) const { + return Apply(op, inputs, outputs, OpExprInterpContext{attrs, nullptr}); + } Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs) const { return Apply(op, inputs, outputs, AttrMap{}); } + + virtual Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs, + const OpExprInterpContext& ctx) const = 0; }; #define FOR_EACH_BUILTIN_OPS(_macro) \ @@ -66,13 +77,13 @@ class OpExprInterpreter { #define DECLARE_NORMAL_APPLY_FUNC(op_type) \ virtual Maybe ApplyImpl(const op_type##Expr& op_expr, const TensorTuple& inputs, \ - TensorTuple* outputs, const AttrMap& attrs) const + TensorTuple* outputs, const OpExprInterpContext& ctx) const #define DECLARE_PURE_VIRTUAL_APPLY_FUNC(op_type) DECLARE_NORMAL_APPLY_FUNC(op_type) = 0; #define DECLARE_OVERRIDE_APPLY_FUNC(op_type) \ Maybe ApplyImpl(const op_type##Expr& op_expr, const TensorTuple& inputs, \ - TensorTuple* outputs, const AttrMap& attrs) const override; + TensorTuple* outputs, const OpExprInterpContext& ctx) const override; class LazyInterpreter : public OpExprInterpreter { public: @@ 
-80,7 +91,12 @@ class LazyInterpreter : public OpExprInterpreter { virtual ~LazyInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const override; + const AttrMap& attrs) const { + return Apply(op_expr, inputs, outputs, OpExprInterpContext{attrs, nullptr}); + } + + Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, + const OpExprInterpContext& ctx) const override; private: DECLARE_NORMAL_APPLY_FUNC(BuiltinOp); @@ -93,7 +109,12 @@ class EagerInterpreter : public OpExprInterpreter { virtual ~EagerInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const override; + const AttrMap& attrs) const { + return Apply(op_expr, inputs, outputs, OpExprInterpContext{attrs, nullptr}); + } + + Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, + const OpExprInterpContext& ctx) const override; private: FOR_EACH_BUILTIN_OPS(DECLARE_PURE_VIRTUAL_APPLY_FUNC); @@ -131,12 +152,17 @@ class AutogradInterpreter { virtual ~AutogradInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const; + const AttrMap& attrs) const { + return Apply(op_expr, inputs, outputs, OpExprInterpContext{attrs, nullptr}); + } - Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs) const { - return Apply(op, inputs, outputs, AttrMap{}); + Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs) const { + return Apply(op_expr, inputs, outputs, OpExprInterpContext{}); } + Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, + const OpExprInterpContext& ctx) const; + private: std::shared_ptr internal_; }; diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index 910b727d10e..66080dd26a2 100644 --- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp @@ -33,10 +33,11 @@ namespace oneflow { namespace one { Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrMap& attrs) { + TensorTuple* outputs, const OpExprInterpContext& ctx) { CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); const auto& placement_scope = JUST(GetCurrentScope())->placement_scope(); - const auto& infer_args = JUST(ConsistentTensorMetaInferArgs::New(inputs, placement_scope, attrs)); + const auto& infer_args = + JUST(ConsistentTensorMetaInferArgs::New(inputs, placement_scope, ctx.attrs)); const auto& result = JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args)); const auto& output_tensor_metas = result->output_tensor_metas(); @@ -71,56 +72,56 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, const auto& instr_type_name = JUST(GetLocalCallInstructionName(parallel_desc->device_tag())); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects, - attrs, parallel_desc, instr_type_name); + ctx, parallel_desc, instr_type_name); })); return Maybe::Ok(); } Maybe EagerConsistentInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* 
outputs, - const AttrMap& attrs) const { - return Interpret(op_expr, inputs, outputs, attrs); + const OpExprInterpContext& ctx) const { + OF_UNIMPLEMENTED(); } Maybe EagerConsistentInterpreter::ApplyImpl(const VariableOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerConsistentInterpreter::ApplyImpl(const CastToMirroredOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerConsistentInterpreter::ApplyImpl(const CastFromMirroredOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerConsistentInterpreter::ApplyImpl(const DistributeSplitOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerConsistentInterpreter::ApplyImpl(const DistributeCloneOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerConsistentInterpreter::ApplyImpl(const DistributeConcatOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } Maybe EagerConsistentInterpreter::ApplyImpl(const DistributeAddOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } diff --git a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp index e21f94ccf7e..e045100b721 100644 --- a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp @@ -48,7 +48,8 @@ Maybe TensorImpl4Tensor(const std::shared_ptr& Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, const Symbol& default_device, TensorTuple* outputs, - const AttrMap& attrs) { + const OpExprInterpContext& ctx) { + const auto& attrs = ctx.attrs; std::shared_ptr input_eager_blob_objects = std::make_shared(inputs.size()); for (int i = 0; i < inputs.size(); i++) { @@ -126,7 +127,7 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in } } return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects, - attrs, op_parallel_desc, instr_type_name); + ctx, op_parallel_desc, instr_type_name); })); return Maybe::Ok(); } @@ -139,12 +140,13 @@ Maybe RunEmptyOp(TensorTuple* outputs) { const auto& device = tensor_impl->device(); const auto empty_expr = JUST(op_expr_helper::EmptyOp(*shape, data_type)); std::shared_ptr inputs = std::make_shared(); - JUST(NaiveInterpret(*empty_expr, *inputs, device, outputs, AttrMap{})); + JUST(NaiveInterpret(*empty_expr, *inputs, device, outputs, + OpExprInterpContext{AttrMap{}, nullptr})); return Maybe::Ok(); } static Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrMap& attrs) { + TensorTuple* outputs, const OpExprInterpContext& ctx) { CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); Symbol default_device; if (inputs.empty()) { @@ -152,18 
+154,18 @@ static Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTu } else { default_device = JUST(inputs.at(0)->device()); } - return NaiveInterpret(user_op_expr, inputs, default_device, outputs, attrs); + return NaiveInterpret(user_op_expr, inputs, default_device, outputs, ctx); } Maybe EagerMirroredInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { - return NaiveInterpret(op_expr, inputs, outputs, attrs); + const OpExprInterpContext& ctx) const { + return NaiveInterpret(op_expr, inputs, outputs, ctx); } Maybe EagerMirroredInterpreter::ApplyImpl(const VariableOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { OF_UNIMPLEMENTED(); } @@ -176,13 +178,13 @@ static Maybe BuildAndRunMirroredCastInstruction(const BuiltinOpExpr& op_ex Maybe EagerMirroredInterpreter::ApplyImpl(const CastToMirroredOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { return BuildAndRunMirroredCastInstruction(op_expr, inputs, outputs); } Maybe EagerMirroredInterpreter::ApplyImpl(const CastFromMirroredOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { return BuildAndRunMirroredCastInstruction(op_expr, inputs, outputs); } @@ -195,13 +197,13 @@ static Maybe BuildAndRunDistributeSplitOrCloneInstruction(const BuiltinOpE Maybe EagerMirroredInterpreter::ApplyImpl(const DistributeSplitOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs); } Maybe EagerMirroredInterpreter::ApplyImpl(const DistributeCloneOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs); } @@ -214,13 +216,13 @@ static Maybe BuildAndRunDistributeConcatAndAddInstruction(const BuiltinOpE Maybe EagerMirroredInterpreter::ApplyImpl(const DistributeConcatOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs); } Maybe EagerMirroredInterpreter::ApplyImpl(const DistributeAddOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrMap& attrs) const { + const OpExprInterpContext& ctx) const { return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs); } diff --git a/oneflow/core/framework/op_interpreter/op_interpreter.cpp b/oneflow/core/framework/op_interpreter/op_interpreter.cpp index 115f6e64f1c..eb4e13dcc64 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/op_interpreter.cpp @@ -34,10 +34,10 @@ namespace oneflow { namespace one { Maybe LazyInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrMap& attrs) const { + TensorTuple* outputs, const OpExprInterpContext& ctx) const { #define APPLY_IF(op_type) \ if (const auto* op = dynamic_cast(&op_expr)) { \ - return ApplyImpl(*op, inputs, outputs, attrs); \ + return ApplyImpl(*op, inputs, outputs, ctx); \ } APPLY_IF(FunctionOp); @@ 
-49,10 +49,10 @@ Maybe LazyInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inp } Maybe LazyInterpreter::ApplyImpl(const BuiltinOpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrMap& attrs) const { + TensorTuple* outputs, const OpExprInterpContext& ctx) const { CHECK_EQ_OR_RETURN(inputs.size(), op_expr.input_size()); const auto& scope = JUST(GetCurrentScope()); - auto op_conf = JUST(OpInterpUtil::GenBuiltinOpConf(op_expr, attrs)); + auto op_conf = JUST(OpInterpUtil::GenBuiltinOpConf(op_expr, ctx.attrs)); int64_t symbol_id = JUST(scope->symbol_id()); op_conf->set_scope_symbol_id(symbol_id); if (!op_conf->has_device_tag()) { @@ -93,17 +93,17 @@ Maybe LazyInterpreter::ApplyImpl(const BuiltinOpExpr& op_expr, const Tenso } Maybe LazyInterpreter::ApplyImpl(const FunctionOpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrMap& attrs) const { + TensorTuple* outputs, const OpExprInterpContext& ctx) const { // TODO(hjchen2) UNIMPLEMENTED(); return Maybe::Ok(); } Maybe EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrMap& attrs) const { + TensorTuple* outputs, const OpExprInterpContext& ctx) const { #define APPLY_IF(op_type) \ if (const auto* op = dynamic_cast(&op_expr)) { \ - return ApplyImpl(*op, inputs, outputs, attrs); \ + return ApplyImpl(*op, inputs, outputs, ctx); \ } APPLY_IF(UserOp); @@ -122,7 +122,8 @@ Maybe EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& in } Maybe EagerInterpreter::ApplyImpl(const FunctionOpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrMap& attrs) const { + TensorTuple* outputs, + const OpExprInterpContext& ctx) const { // TODO(hjchen2) UNIMPLEMENTED(); return Maybe::Ok(); @@ -144,7 +145,7 @@ Maybe DetermineRequiresGrad(TensorTuple* outputs, const bool& requires_gra } // namespace Maybe AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrMap& attrs) const { + TensorTuple* outputs, const OpExprInterpContext& ctx) const { bool requires_grad = false; if (autograd::GradMode::is_enabled() && !JUST(op_expr.IsGradDisabled())) { requires_grad = @@ -153,13 +154,13 @@ Maybe AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& } { autograd::AutoGradMode mode(false); - JUST(internal_->Apply(op_expr, inputs, outputs, attrs)); + JUST(internal_->Apply(op_expr, inputs, outputs, ctx)); JUST(DetermineIsLeaf(outputs, inputs.size() == 0, requires_grad)); JUST(DetermineRequiresGrad(outputs, requires_grad)); } if (requires_grad) { const auto& grad_closure = JUST(op_expr.GetOrCreateOpGradClosure()); - JUST(grad_closure->Capture(inputs, *outputs, attrs)); + JUST(grad_closure->Capture(inputs, *outputs, ctx.attrs)); auto backward_fn = std::make_shared(const TensorTuple&, TensorTuple*, bool)>>( diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp index c5c8f7c36ab..34c269a8d36 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp +++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp @@ -66,17 +66,17 @@ std::shared_ptr BuildLazyInterpreter() { template<> /*static*/ Maybe OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, - const AttrMap& attrs) { + const OpExprInterpContext& ctx) { auto outputs = std::make_shared(op_expr.output_size()); - JUST(JUST(GetInterpreter())->Apply(op_expr, inputs, outputs.get(), attrs)); + 
JUST(JUST(GetInterpreter())->Apply(op_expr, inputs, outputs.get(), ctx)); return outputs; } template<> /*static*/ Maybe OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, - const AttrMap& attrs) { - return JUST(Dispatch(op_expr, inputs, attrs))->at(0); + const OpExprInterpContext& ctx) { + return JUST(Dispatch(op_expr, inputs, ctx))->at(0); } /*static*/ Maybe OpInterpUtil::AddOpAndInferOpAttribute( diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.h b/oneflow/core/framework/op_interpreter/op_interpreter_util.h index 6e5c1741578..aeccc89abc5 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter_util.h +++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.h @@ -34,18 +34,24 @@ class OpInterpUtil { static Maybe GetInterpreter(); template - static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const AttrMap& attrs); + static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const AttrMap& attrs) { + return Dispatch(op_expr, inputs, OpExprInterpContext{attrs, nullptr}); + } template static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs) { - return Dispatch(op_expr, inputs, AttrMap{}); + return Dispatch(op_expr, inputs, OpExprInterpContext{AttrMap{}, nullptr}); } - static Maybe GenBuiltinOpConf(const BuiltinOpExpr& op_expr, const AttrMap& attrs); + template + static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, + const OpExprInterpContext& ctx); static Maybe AddOpAndInferOpAttribute(const OperatorConf& op_conf, const bool is_mirrored_strategy_enabled); + static Maybe GenBuiltinOpConf(const BuiltinOpExpr& op_expr, const AttrMap& attrs); + static Maybe BuildTensor( const std::shared_ptr& blob_attr, const std::shared_ptr& parallel_attr, diff --git a/oneflow/core/framework/random_generator.cpp b/oneflow/core/framework/random_generator.cpp index acf15c9ef03..d14a66bf836 100644 --- a/oneflow/core/framework/random_generator.cpp +++ b/oneflow/core/framework/random_generator.cpp @@ -65,6 +65,24 @@ void ManualSeed(uint64_t seed) { auto_gen->set_current_seed(seed); } +Maybe GetDefaultGenerator(const std::string& device) { + std::shared_ptr gen_impl; + if (device == "cpu") { + gen_impl = GetDefaultDeviceGenerator(); + } +#ifdef WITH_CUDA + else if (device == "cuda") { + gen_impl = GetDefaultDeviceGenerator(); + } +#endif // WITH_CUDA + else if (device == "auto") { + gen_impl = GetDefaultAutoGenerator(); + } else { + UNIMPLEMENTED_THEN_RETURN() << " device unimplemented, device name: " << device; + } + return std::make_shared(gen_impl); +} + std::shared_ptr CreateAutoGenerator(uint64_t seed) { return std::make_shared(seed); } @@ -87,14 +105,37 @@ const std::shared_ptr>& GetDefaultDeviceGenerat template Maybe> TryGetDeviceGenerator( - const std::shared_ptr& generator) { - if (auto auto_gen = std::dynamic_pointer_cast(generator)) { + const std::shared_ptr& gen_impl) { + if (auto auto_gen = std::dynamic_pointer_cast(gen_impl)) { return auto_gen->template GetDeviceGenerator(); } - auto device_gen = std::dynamic_pointer_cast>(generator); + auto device_gen = std::dynamic_pointer_cast>(gen_impl); CHECK_NOTNULL_OR_RETURN(device_gen); return device_gen; } +template Maybe> TryGetDeviceGenerator( + const std::shared_ptr& generator); + +#ifdef WITH_CUDA +template Maybe> TryGetDeviceGenerator( + const std::shared_ptr& generator); +#endif // WITH_CUDA + +template +Maybe> TryGetDeviceGenerator( + const std::shared_ptr& generator) { + CHECK_NOTNULL_OR_RETURN(generator); + return 
TryGetDeviceGenerator(generator->get_impl()); +} + +template Maybe> TryGetDeviceGenerator( + const std::shared_ptr& generator); + +#ifdef WITH_CUDA +template Maybe> TryGetDeviceGenerator( + const std::shared_ptr& generator); +#endif // WITH_CUDA + } // namespace one } // namespace oneflow diff --git a/oneflow/core/framework/random_generator.h b/oneflow/core/framework/random_generator.h index e2ade46d9e6..da56b202b66 100644 --- a/oneflow/core/framework/random_generator.h +++ b/oneflow/core/framework/random_generator.h @@ -128,7 +128,9 @@ class Generator final { static constexpr uint64_t default_rng_seed_val = 67280421310721; public: + OF_DISALLOW_COPY_AND_MOVE(Generator); Generator() = default; + Generator(std::shared_ptr gen_impl) : gen_impl_(gen_impl) {} Maybe Init(const std::string& device, uint64_t seed); @@ -144,16 +146,22 @@ class Generator final { // Reset current seed by the default seed, and returns it. uint64_t seed(); + const std::shared_ptr& get_impl() const { return gen_impl_; } + private: std::shared_ptr gen_impl_; }; void ManualSeed(uint64_t seed); -const std::shared_ptr& GetDefaultAutoGenerator(); +std::shared_ptr CreateGenerator(const std::string& device, uint64_t seed); + +Maybe GetDefaultGenerator(const std::string& device); std::shared_ptr CreateAutoGenerator(uint64_t seed); +const std::shared_ptr& GetDefaultAutoGenerator(); + template std::shared_ptr> CreateDeviceGenerator(uint64_t seed); @@ -164,6 +172,10 @@ template Maybe> TryGetDeviceGenerator( const std::shared_ptr& generator); +template +Maybe> TryGetDeviceGenerator( + const std::shared_ptr& generator); + } // namespace one } // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 5807e407966..74eb602a24a 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -540,6 +540,11 @@ signature: "Tensor ClipByScalarMaxGrad(Tensor dy, Tensor x, *, Scalar max)" bind_python: False +- name: "dropout" + signature: + "Tensor Dropout(Tensor x, *, Float p, Generator generator=None)" + bind_python: True + - name: "pad" signature: "Tensor Pad(Tensor x, *, Int64List pad, String mode=\"constant\", Scalar value=0)" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 687d7e77eb8..34c10472f32 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -21,10 +21,13 @@ limitations under the License. 
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/op_interpreter.h" +#include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/functional/impl/unary_functor.h" #include "oneflow/core/functional/scalar.h" +#include "oneflow/user/kernels/random_mask_like_kernel.h" namespace oneflow { namespace one { @@ -350,6 +353,44 @@ class PadFunctor { std::shared_ptr replicate_pad_; }; +class DropoutFunctor { + public: + DropoutFunctor() { + random_mask_like_op_ = + CHECK_JUST(one::OpBuilder("random_mask_like").Input("like").Output("out").Build()); + dropout_op_ = + CHECK_JUST(one::OpBuilder("dropout").Input("in").Input("mask").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const float& p, + const Optional& generator) const { + MutableAttrMap random_mask_like_attrs; + JUST(random_mask_like_attrs.SetAttr("rate", p)); + + std::shared_ptr gen; + if (!generator) { + gen = JUST(one::GetDefaultGenerator("auto")); + } else { + gen = JUST(generator.value()); + } + + JUST(random_mask_like_attrs.SetAttr("seed", gen->current_seed())); + const auto& random_mask_like_state = std::make_shared(gen); + + const auto& mask = JUST(OpInterpUtil::Dispatch( + *random_mask_like_op_, {x}, + OpExprInterpContext{.attrs = random_mask_like_attrs, .state = random_mask_like_state})); + float scale = 1.0; + if (p != 1.0) { scale = 1.0 / (1.0 - p); } + MutableAttrMap dropout_attrs; + JUST(dropout_attrs.SetAttr("scale", scale)); + return OpInterpUtil::Dispatch(*dropout_op_, {x, mask}, dropout_attrs); + } + + private: + std::shared_ptr random_mask_like_op_; + std::shared_ptr dropout_op_; +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -366,6 +407,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("SparseSoftmaxCrossEntropy"); m.add_functor("Normalization"); m.add_functor("Pad"); + m.add_functor("Dropout"); }; } // namespace functional diff --git a/oneflow/core/functional/value_types.h b/oneflow/core/functional/value_types.h index f47035a8e1d..1270bd28f3e 100644 --- a/oneflow/core/functional/value_types.h +++ b/oneflow/core/functional/value_types.h @@ -34,6 +34,7 @@ class AttrValue; namespace one { class Tensor; class TensorTuple; +class Generator; namespace functional { class Scalar; @@ -76,6 +77,9 @@ enum ValueType { kATTR_MAP, kDTYPE, kSHAPE, + kGENERATOR, + kGENERATOR_REF, + kGENERATOR_MAYBE, }; #define VALUE_TYPE_OF_IMPL(cpp_type, value_type) \ @@ -122,6 +126,9 @@ VALUE_TYPE_OF_IMPL(std::shared_ptr, kATTR_REF); VALUE_TYPE_OF_IMPL(AttrMap, kATTR_MAP); VALUE_TYPE_OF_IMPL(DataType, kDTYPE); VALUE_TYPE_OF_IMPL(Shape, kSHAPE); +VALUE_TYPE_OF_IMPL(one::Generator, kGENERATOR); +VALUE_TYPE_OF_IMPL(std::shared_ptr, kGENERATOR_REF); +VALUE_TYPE_OF_IMPL(Maybe, kGENERATOR_MAYBE); #undef VALUE_TYPE_OF_IMPL diff --git a/oneflow/python/nn/modules/dropout.py b/oneflow/python/nn/modules/dropout.py index b16fc6195fa..3f44a1c8b7c 100644 --- a/oneflow/python/nn/modules/dropout.py +++ b/oneflow/python/nn/modules/dropout.py @@ -90,37 +90,18 @@ class Dropout(_DropoutNd): """ - def __init__(self, p: float = 0.5, inplace: bool = False): + def __init__(self, p: float = 0.5, inplace: bool = False, generator=None): _DropoutNd.__init__(self, p, inplace) - if self.p == 1.0: - scale = 1.0 - else: - scale = float(1.0 / (1.0 - self.p)) - - seed = 
random.randint(-sys.maxsize, sys.maxsize) - self._op = ( - flow.builtin_op("dropout") - .Input("in") - .Input("mask") - .Output("out") - .Attr("scale", scale) - .Build() - ) - self._mask_op = ( - flow.builtin_op("random_mask_like") - .Input("like") - .Output("out") - .Attr("rate", self.p) - .Attr("seed", seed) - .Build() - ) + self.p = p + if generator is None: + generator = flow.Generator() + self.generator = generator def forward(self, x): if self.p == 0.0 or not self.training: return x - mask = self._mask_op(x)[0] - return self._op(x, mask)[0] + return flow.F.dropout(x, self.p, self.generator) if __name__ == "__main__": diff --git a/oneflow/user/kernels/dropout_kernel.cpp b/oneflow/user/kernels/dropout_kernel.cpp index f5c34e8eac7..a641c55eacc 100644 --- a/oneflow/user/kernels/dropout_kernel.cpp +++ b/oneflow/user/kernels/dropout_kernel.cpp @@ -15,7 +15,6 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/op_kernel_state_wrapper.h" -#include "oneflow/user/kernels/random_mask_generator.h" #include "oneflow/core/kernel/kernel_util.h" namespace oneflow { @@ -100,42 +99,5 @@ class DropoutGradKernelCPU final : public user_op::OpKernel { REGISTER_DROPOUT_GRAD_KERNEL_CPU(float) REGISTER_DROPOUT_GRAD_KERNEL_CPU(double) -template -class RandomMaskLikeKernel final : public user_op::OpKernel { - public: - RandomMaskLikeKernel() = default; - ~RandomMaskLikeKernel() = default; - - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - int64_t seed = ctx->Attr("seed"); - return std::make_shared>>(seed); - } - - private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - const user_op::Tensor* like = ctx->Tensor4ArgNameAndIndex("like", 0); - user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); - int64_t elem_cnt = like->shape().elem_cnt(); - int8_t* mask = out->mut_dptr(); - auto* random_mask_generator = - dynamic_cast>*>(state); - CHECK_NOTNULL(random_mask_generator); - random_mask_generator->Mutable()->Generate(ctx->device_ctx(), elem_cnt, - ctx->Attr("rate"), mask); - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -#define REGISTER_RANDOM_MASK_LIKE_KERNEL(device) \ - REGISTER_USER_KERNEL("random_mask_like") \ - .SetCreateFn>() \ - .SetIsMatchedHob(user_op::HobDeviceTag() == device); - -REGISTER_RANDOM_MASK_LIKE_KERNEL(DeviceType::kCPU) -#ifdef WITH_CUDA -REGISTER_RANDOM_MASK_LIKE_KERNEL(DeviceType::kGPU) -#endif - } // namespace } // namespace oneflow diff --git a/oneflow/user/kernels/random_mask_generator.cpp b/oneflow/user/kernels/random_mask_generator.cpp index 88a70bba285..bf7ae1e3a75 100644 --- a/oneflow/user/kernels/random_mask_generator.cpp +++ b/oneflow/user/kernels/random_mask_generator.cpp @@ -22,7 +22,10 @@ void RandomMaskGenerator::Generate(DeviceCtx* device_ctx, cons CHECK_GE(n, 0); std::uniform_real_distribution random_distribution(GetZeroVal(), GetOneVal()); - for (int64_t i = 0; i < n; ++i) { mask[i] = random_distribution(mt19937_generator_) > rate; } + const auto& cpu_generator = CHECK_JUST(one::TryGetDeviceGenerator(generator_)); + for (int64_t i = 0; i < n; ++i) { + mask[i] = random_distribution(cpu_generator->generator()) > rate; + } } template class RandomMaskGenerator; diff --git a/oneflow/user/kernels/random_mask_generator.cu b/oneflow/user/kernels/random_mask_generator.cu index 588420ad240..715456b9b62 100644 --- a/oneflow/user/kernels/random_mask_generator.cu +++ 
b/oneflow/user/kernels/random_mask_generator.cu @@ -28,35 +28,10 @@ union Pack { int8_t b_value[sizeof(PackType)]; }; -int GetThreadNum(const cudaDeviceProp& prop) { - switch (prop.major) { - case 3: // Kepler - return 2 * 192; - case 5: // Maxwell - return 2 * 128; - case 6: // Pascal - if ((prop.minor == 1) || (prop.minor == 2)) { - return 2 * 128; - } else { - return 2 * 64; - } - case 7: // Volta and Turing - return 2 * 64; - default: return 2 * 64; - } -} - __device__ int8_t GenMask(curandState* state, const float rate) { return curand_uniform(state) >= rate; } -__global__ void SetupKernel(int64_t seed, curandState* state) { - const int id = blockIdx.x * blockDim.x + threadIdx.x; - size_t local_seed = (static_cast(seed) + 0x9e3779b9U + (static_cast(id) << 6U) - + (static_cast(id) >> 2U)); - curand_init(local_seed, 0, 0, &state[id]); -} - __global__ void GenerateGpu(curandState* state, const int64_t n, const float rate, int8_t* mask) { const int id = blockIdx.x * blockDim.x + threadIdx.x; curandState localState = state[id]; @@ -75,26 +50,17 @@ __global__ void GenerateGpu(curandState* state, const int64_t n, const float rat } // namespace -RandomMaskGenerator::RandomMaskGenerator(int64_t seed) { - cudaDeviceProp prop; - OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, 0)); - block_num_ = prop.multiProcessorCount; - thread_num_ = GetThreadNum(prop); - OF_CUDA_CHECK(cudaMalloc(&curand_states_, block_num_ * thread_num_ * sizeof(curandState))); - SetupKernel<<>>(seed, curand_states_); -} - -RandomMaskGenerator::~RandomMaskGenerator() { - OF_CUDA_CHECK(cudaFree(curand_states_)); -} - void RandomMaskGenerator::Generate(DeviceCtx* device_ctx, const int64_t n, const float rate, int8_t* mask) { - const int32_t elem_cnt_per_block = thread_num_ * sizeof(PackType) * kMinPackPerThread; - const int32_t block_num = - std::min(static_cast((n + elem_cnt_per_block - 1) / elem_cnt_per_block), block_num_); - GenerateGpu<<cuda_stream()>>>(curand_states_, n, rate, - mask); + const auto& cuda_gen = CHECK_JUST(one::TryGetDeviceGenerator(generator_)); + const auto& block_num = cuda_gen->block_num(); + const auto& thread_num = cuda_gen->thread_num(); + auto* curand_states = cuda_gen->curand_states(); + const int32_t elem_cnt_per_block = thread_num * sizeof(PackType) * kMinPackPerThread; + const int32_t block_num_final = + std::min(static_cast((n + elem_cnt_per_block - 1) / elem_cnt_per_block), block_num); + GenerateGpu<<cuda_stream()>>>(curand_states, n, rate, + mask); } template class RandomMaskGenerator; diff --git a/oneflow/user/kernels/random_mask_generator.h b/oneflow/user/kernels/random_mask_generator.h index b1d87a78db7..3e2e56fa4e6 100644 --- a/oneflow/user/kernels/random_mask_generator.h +++ b/oneflow/user/kernels/random_mask_generator.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "oneflow/core/common/data_type.h" #include "oneflow/core/device/device_context.h" +#include "oneflow/core/framework/random_generator.h" #ifdef WITH_CUDA #include #include @@ -32,13 +33,13 @@ template<> class RandomMaskGenerator final { public: OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator); - RandomMaskGenerator(int64_t seed) : mt19937_generator_(seed) {} - ~RandomMaskGenerator() {} + RandomMaskGenerator(const std::shared_ptr& generator) : generator_(generator) {} + ~RandomMaskGenerator() = default; void Generate(DeviceCtx* device_ctx, int64_t n, float rate, int8_t* mask); private: - std::mt19937 mt19937_generator_; + const std::shared_ptr generator_; }; #ifdef WITH_CUDA @@ -46,15 +47,13 @@ template<> class RandomMaskGenerator final { public: OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator); - RandomMaskGenerator(int64_t seed); - ~RandomMaskGenerator(); + RandomMaskGenerator(const std::shared_ptr& generator) : generator_(generator) {} + ~RandomMaskGenerator() = default; void Generate(DeviceCtx* device_ctx, int64_t n, float rate, int8_t* mask); private: - curandState* curand_states_; - int32_t block_num_; - int32_t thread_num_; + const std::shared_ptr generator_; }; #endif diff --git a/oneflow/user/kernels/random_mask_like_kernel.cpp b/oneflow/user/kernels/random_mask_like_kernel.cpp new file mode 100644 index 00000000000..d1063385e19 --- /dev/null +++ b/oneflow/user/kernels/random_mask_like_kernel.cpp @@ -0,0 +1,32 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/user/kernels/random_mask_like_kernel.h" + +namespace oneflow { + +namespace { +#define REGISTER_RANDOM_MASK_LIKE_KERNEL(device) \ + REGISTER_USER_KERNEL("random_mask_like") \ + .SetCreateFn>() \ + .SetIsMatchedHob(user_op::HobDeviceTag() == device); + +REGISTER_RANDOM_MASK_LIKE_KERNEL(DeviceType::kCPU) +#ifdef WITH_CUDA +REGISTER_RANDOM_MASK_LIKE_KERNEL(DeviceType::kGPU) +#endif +} // namespace + +} // namespace oneflow diff --git a/oneflow/user/kernels/random_mask_like_kernel.h b/oneflow/user/kernels/random_mask_like_kernel.h new file mode 100644 index 00000000000..bd1b5de6e36 --- /dev/null +++ b/oneflow/user/kernels/random_mask_like_kernel.h @@ -0,0 +1,67 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_USER_KERNEL_RANDOM_MASK_GENERATOR_H_ +#define ONEFLOW_USER_KERNEL_RANDOM_MASK_GENERATOR_H_ +#include "oneflow/user/kernels/random_mask_generator.h" +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +class RandomMaskLikeKernelState : public user_op::OpKernelState { + public: + explicit RandomMaskLikeKernelState(const std::shared_ptr& generator) + : generator_(generator) {} + + const std::shared_ptr& generator() const { return generator_; } + + private: + std::shared_ptr generator_; +}; + +namespace { + +template +class RandomMaskLikeKernel final : public user_op::OpKernel { + public: + RandomMaskLikeKernel() = default; + ~RandomMaskLikeKernel() = default; + + std::shared_ptr CreateOpKernelState( + user_op::KernelInitContext* ctx) const override { + const auto generator = CHECK_JUST(one::Generator::New("auto", ctx->Attr("seed"))); + return std::make_shared(generator); + } + + private: + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + const user_op::Tensor* like = ctx->Tensor4ArgNameAndIndex("like", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + int64_t elem_cnt = like->shape().elem_cnt(); + int8_t* mask = out->mut_dptr(); + auto* random_mask_like_state = dynamic_cast(state); + CHECK_NOTNULL(random_mask_like_state); + const auto& generator = random_mask_like_state->generator(); + CHECK_NOTNULL(generator); + auto random_mask_like_gen = std::make_shared>(generator); + random_mask_like_gen->Generate(ctx->device_ctx(), elem_cnt, ctx->Attr("rate"), mask); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +} // namespace +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNEL_RANDOM_MASK_GENERATOR_H_ diff --git a/oneflow/user/kernels/stateful_local_opkernel.h b/oneflow/user/kernels/stateful_local_opkernel.h index 520b74985f5..60093824be4 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.h +++ b/oneflow/user/kernels/stateful_local_opkernel.h @@ -22,6 +22,7 @@ limitations under the License. #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/user_op_kernel_registry.h" #include "oneflow/core/framework/arg_tuple.h" +#include "oneflow/core/framework/op_interpreter.h" namespace oneflow { diff --git a/tools/generate_functional_api.py b/tools/generate_functional_api.py index 409858934ad..97a7598e84e 100644 --- a/tools/generate_functional_api.py +++ b/tools/generate_functional_api.py @@ -64,6 +64,7 @@ #include "oneflow/core/common/optional.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/scalar.h" namespace oneflow {{ @@ -147,6 +148,7 @@ "BoolList", "DataType", "Shape", + "Generator", } generic_type_aliases = { @@ -173,6 +175,7 @@ "BoolList": "const std::vector&", "DataType": "const DataType&", "Shape": "const Shape&", + "Generator": "const std::shared_ptr&", **generic_type_aliases, } @@ -191,6 +194,7 @@ "BoolList": "const Optional>&", "DataType": "const Optional&", "Shape": "const Optional&", + "Generator": "const Optional&", **{k: "const Optional<{0}>".format(v) for k, v in generic_type_aliases.items()}, }
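
Taken together, the pieces above compose as follows. This is a minimal usage
sketch, not part of the patch: it assumes the generated functional API is
reachable through an aggregate header, and that an Optional<one::Generator>
argument accepts a std::shared_ptr<one::Generator>.

    // Usage sketch only; the functional header path is an assumption.
    #include "oneflow/core/framework/random_generator.h"
    #include "oneflow/core/functional/functional.h"

    namespace oneflow {
    namespace one {

    Maybe<Tensor> DropoutExample(const std::shared_ptr<Tensor>& x) {
      // Seed the process-wide auto generator so the sampled mask is reproducible.
      ManualSeed(42);
      // GetDefaultGenerator("auto") is the same fallback DropoutFunctor uses when
      // the caller passes no generator; passing it explicitly is equivalent.
      const auto& gen = JUST(GetDefaultGenerator("auto"));
      return functional::Dropout(x, /*p=*/0.5f, gen);
    }

    }  // namespace one
    }  // namespace oneflow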
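
The interpreter-side pattern DropoutFunctor introduces generalizes to any op
whose kernel needs shared state: pre-build the OpKernelState and hand it to
Dispatch through OpExprInterpContext, so TryInitOpKernelState reuses it rather
than constructing state inside the kernel. A hedged sketch follows; the op
expression and its attrs are hypothetical, and only the plumbing mirrors this
patch (assumed to live in oneflow::one::functional, like the functors above):

    // The op behind "op" and its "seed" attr are hypothetical examples.
    Maybe<Tensor> DispatchWithSharedState(const std::shared_ptr<OpExpr>& op,
                                          const std::shared_ptr<Tensor>& x,
                                          const std::shared_ptr<Generator>& gen) {
      MutableAttrMap attrs;
      JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
      // TryInitOpKernelState (opkernel_instruction_type.cpp above) prefers this
      // pre-built state, so every dispatch samples from the same generator
      // instead of reseeding from the attr.
      const auto& state = std::make_shared<RandomMaskLikeKernelState>(gen);
      return OpInterpUtil::Dispatch<Tensor>(
          *op, {x}, OpExprInterpContext{.attrs = attrs, .state = state});
    }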