From e9283811964c95d9a49e2814bf565566b7a19871 Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Tue, 27 Jul 2021 12:56:39 +0800 Subject: [PATCH 01/32] feat(Tensor): support 0shape tensor --- oneflow/core/functional/impl/array_functor.cpp | 2 -- oneflow/user/kernels/concat_kernel.cpp | 1 + oneflow/user/kernels/empty_kernel.cpp | 4 ---- python/oneflow/framework/tensor.py | 6 ++++-- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index c1394ca3da1..e5085d9ed1f 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -863,8 +863,6 @@ class TensorGetItemFunctor { JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &target_dims)); CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << "Failed to prepare slice indices."; Shape target_shape(DimVector(target_dims.begin(), target_dims.end())); - CHECK_GT_OR_RETURN(target_shape.Count(0), 0) - << "Target shape is zero shape which was not supported yet."; std::vector start(ndims), end(ndims), step(ndims); for (int i = 0; i < ndims; ++i) { diff --git a/oneflow/user/kernels/concat_kernel.cpp b/oneflow/user/kernels/concat_kernel.cpp index 3619dba8cbe..2fe564317a9 100644 --- a/oneflow/user/kernels/concat_kernel.cpp +++ b/oneflow/user/kernels/concat_kernel.cpp @@ -59,6 +59,7 @@ class ConcatKernel final : public user_op::OpKernel { for (const auto& in_arg_pair : ctx->inputs()) { const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second); + if (in_tensor->shape().elem_cnt() == 0) { continue; } const int64_t in_cols = in_tensor->shape().Count(axis); CHECK_EQ(in_tensor->shape().elem_cnt(), rows * in_cols); if (in_cols > 0) { diff --git a/oneflow/user/kernels/empty_kernel.cpp b/oneflow/user/kernels/empty_kernel.cpp index 732d6d2d481..e056c83a7ba 100644 --- a/oneflow/user/kernels/empty_kernel.cpp +++ b/oneflow/user/kernels/empty_kernel.cpp @@ -27,10 +27,6 @@ class EmptyKernel final : public OpKernel { private: void Compute(user_op::KernelComputeContext* ctx) const override { - Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); - const int64_t elem_cnt = out_tensor->shape().elem_cnt(); - CHECK_GT(elem_cnt, 0); - // Do nothing } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } diff --git a/python/oneflow/framework/tensor.py b/python/oneflow/framework/tensor.py index 29500ebeef4..4f5c391af4f 100644 --- a/python/oneflow/framework/tensor.py +++ b/python/oneflow/framework/tensor.py @@ -62,7 +62,8 @@ def _local_tensor_numpy(eager_local_tensor): tuple(eager_local_tensor.shape), dtype=flow.convert_oneflow_dtype_to_numpy_dtype(eager_local_tensor.dtype), ) - copy_to_numpy(ndarray) + if ndarray.size != 0: + copy_to_numpy(ndarray) return ndarray @@ -77,7 +78,8 @@ def _copy_from_numpy_to_eager_local_tensor(eager_local_tensor, np_arr): assert tuple(eager_local_tensor.shape) == (1,) else: assert np_arr.shape == tuple(eager_local_tensor.shape) - copy_from_numpy(np_arr) + if np_arr.size != 0: + copy_from_numpy(np_arr) @register_local_tensor_method("_init_by_initializer_conf") From 06b893a5470109cef9bcb53ccac665ae76606dd9 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Tue, 27 Jul 2021 17:39:33 +0800 Subject: [PATCH 02/32] math binary broadcast support emoty tensor input --- oneflow/core/ndarray/ndarray_apply_broadcast_binary.h | 4 +++- oneflow/user/ops/math_binary_broadcast_ops.cpp | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h index eaf5a76a001..2250e059be9 100644 --- a/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary.h @@ -99,7 +99,9 @@ struct NdarrayApplyBroadcastBinary< CHECK_EQ(y.shape().NumAxes(), a.shape().NumAxes()); CHECK_EQ(y.shape().NumAxes(), b.shape().NumAxes()); for (int i = 0; i < y.shape().NumAxes(); ++i) { - CHECK_EQ(y.shape().At(i), std::max(a.shape().At(i), b.shape().At(i))); + CHECK_EQ(y.shape().At(i), (a.shape().At(i) == 0 || b.shape().At(i) == 0) + ? 0 + : std::max(a.shape().At(i), b.shape().At(i))); if (a.shape().At(i) != b.shape().At(i)) { CHECK(a.shape().At(i) == 1 || b.shape().At(i) == 1); } diff --git a/oneflow/user/ops/math_binary_broadcast_ops.cpp b/oneflow/user/ops/math_binary_broadcast_ops.cpp index 98008f9c877..641781191c9 100644 --- a/oneflow/user/ops/math_binary_broadcast_ops.cpp +++ b/oneflow/user/ops/math_binary_broadcast_ops.cpp @@ -47,7 +47,9 @@ Maybe InferTensorDescBinaryBroadcastNormal(user_op::InferContext* ctx) { CHECK_OR_RETURN(x_shape.At(i) == 1 || y_shape.At(i) == 1 || x_shape.At(i) == y_shape.At(i)) << "op: " << ctx->op_name() << ", type: " << ctx->op_type_name() << ", i: " << i << ", x_shape: " << x_shape << ", y_shape: " << y_shape; - out_shape.Set(i, std::max(x_shape.At(i), y_shape.At(i))); + out_shape.Set(i, (x_shape.At(i) == 0 || y_shape.At(i) == 0) + ? 0 + : std::max(x_shape.At(i), y_shape.At(i))); } *tensor_z->mut_shape() = out_shape; } From 5d04f27d0fbea366ea675e26a892ac2a1ec69ac9 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Tue, 27 Jul 2021 18:25:53 +0800 Subject: [PATCH 03/32] slice support empty tensor input and output --- oneflow/user/ops/slice_op.cpp | 4 ++-- python/oneflow/ops/array_ops.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index e409f7058b3..2538a7e21ec 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -49,10 +49,10 @@ Maybe InferSliceOpTensorDesc(user_op::InferContext* ctx) { int64_t start = RegulateSliceStart(start_vec.at(i), dim_size); int64_t stop = RegulateSliceStop(stop_vec.at(i), dim_size); if (step > 0) { - CHECK_LT_OR_RETURN(start, stop) << "slice start must be less than stop when step > 0" + CHECK_LE_OR_RETURN(start, stop) << "slice start must be less than stop when step > 0" ", otherwise empty result will be outputted."; } else { - CHECK_GT_OR_RETURN(start, stop) << "slice start must be more than stop when step < 0" + CHECK_GE_OR_RETURN(start, stop) << "slice start must be more than stop when step < 0" ", otherwise empty result will be outputted."; } const int64_t diff = (step > 0) ? (stop - start - 1) : (stop - start + 1); diff --git a/python/oneflow/ops/array_ops.py b/python/oneflow/ops/array_ops.py index f32c64bf5cc..10dbb0e0adc 100644 --- a/python/oneflow/ops/array_ops.py +++ b/python/oneflow/ops/array_ops.py @@ -37,14 +37,14 @@ def check_slice_tup_list(slice_tup_list, shape): if not all((isinstance(idx, int) or idx is None for idx in slice_tup)): raise ValueError("element of slice tuple must int or None") (start, stop, step) = slice_tup - if step is None: + if step is None or start == stop: step = 1 if step == 0: raise ValueError("slice step can't be 0") if start is None: start = 0 if step > 0 else np.iinfo(np.int64).max elif start < -dim_size or start >= dim_size: - raise ValueError("slice start must be in range [-size, size)") + start, stop, step = 0, 0, 1 if stop is None: stop = np.iinfo(np.int64).max if step > 0 else np.iinfo(np.int64).min elif stop < -dim_size - 1 or stop > dim_size: From acf3fcad73a05dcb128cf5eb578a2cf1ab738429 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Tue, 27 Jul 2021 18:33:23 +0800 Subject: [PATCH 04/32] fix check in slice --- oneflow/user/ops/slice_op.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index 2538a7e21ec..48ca98b5b70 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -49,11 +49,13 @@ Maybe InferSliceOpTensorDesc(user_op::InferContext* ctx) { int64_t start = RegulateSliceStart(start_vec.at(i), dim_size); int64_t stop = RegulateSliceStop(stop_vec.at(i), dim_size); if (step > 0) { - CHECK_LE_OR_RETURN(start, stop) << "slice start must be less than stop when step > 0" - ", otherwise empty result will be outputted."; + CHECK_LE_OR_RETURN(start, stop) + << "slice start must be less than or equal to stop when step > 0" + ", otherwise empty result will be outputted."; } else { - CHECK_GE_OR_RETURN(start, stop) << "slice start must be more than stop when step < 0" - ", otherwise empty result will be outputted."; + CHECK_GE_OR_RETURN(start, stop) + << "slice start must be more than or equal to stop when step < 0" + ", otherwise empty result will be outputted."; } const int64_t diff = (step > 0) ? (stop - start - 1) : (stop - start + 1); dim_vec[i] = diff / step + 1; From 04d29a02d4b7fe2f2144c743c20f6ec226c861df Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Tue, 27 Jul 2021 19:50:43 +0800 Subject: [PATCH 05/32] test(Cat): add 0shape cat module test --- python/oneflow/test/modules/test_concat.py | 9 +++++++++ .../torch_flow_dual_object.py | 18 +++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/python/oneflow/test/modules/test_concat.py b/python/oneflow/test/modules/test_concat.py index e5dea24b23a..8e19203a7e7 100644 --- a/python/oneflow/test/modules/test_concat.py +++ b/python/oneflow/test/modules/test_concat.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_concat_origin(test_case, device): @@ -132,6 +133,14 @@ def test_concat(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(n=10, auto_backward=False) + def test_0shape_concat(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 3, 2 ,4).to(device) + y = random_pytorch_tensor(4, 2, 3, random(0, 3) ,4).to(device) + z = torch.cat((x, y), dim=2) + return z + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index 232ea133e03..e2066d59220 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -78,9 +78,20 @@ def get_generator_value(x): return x for arg in args: - arg = get_generator_value(arg) - pytorch_args.append(get_pytorch_value(arg)) - oneflow_args.append(get_oneflow_value(arg)) + # TODO: refine codes + if isinstance(arg, tuple): + pytorch_tuple_args = [] + oneflow_tuple_args = [] + for t in arg: + t = get_generator_value(t) + pytorch_tuple_args.append(get_pytorch_value(t)) + oneflow_tuple_args.append(get_oneflow_value(t)) + pytorch_args.append(tuple(pytorch_tuple_args)) + oneflow_args.append(tuple(oneflow_tuple_args)) + else: + arg = get_generator_value(arg) + pytorch_args.append(get_pytorch_value(arg)) + oneflow_args.append(get_oneflow_value(arg)) for (key, value) in kwargs.items(): value = get_generator_value(value) if isinstance(value, Nothing): @@ -252,6 +263,7 @@ def new_f(test_case): except PyTorchDoesNotSupportError as e: if verbose: print(e) + n -= 1 continue if res is not None: if not isinstance(res, collections.abc.Sequence): From acefe4a9bfbda3227dc07eef20ade67c503bf4ec Mon Sep 17 00:00:00 2001 From: daquexian Date: Thu, 29 Jul 2021 17:14:58 +0800 Subject: [PATCH 06/32] fix return type error on gcc 4.8.5 Signed-off-by: daquexian --- oneflow/core/framework/py_remote_blob.cpp | 1 + oneflow/core/job_rewriter/pass_util.h | 1 + .../quantization_aware_training.cpp | 89 ++++++++++--------- 3 files changed, 47 insertions(+), 44 deletions(-) diff --git a/oneflow/core/framework/py_remote_blob.cpp b/oneflow/core/framework/py_remote_blob.cpp index 7da4754837e..12892d1f72e 100644 --- a/oneflow/core/framework/py_remote_blob.cpp +++ b/oneflow/core/framework/py_remote_blob.cpp @@ -197,6 +197,7 @@ int64_t EagerBlobTrait::split_axis() const { return INVALID_SPLIT_AXIS; } else { UNIMPLEMENTED(); + return 0; } } diff --git a/oneflow/core/job_rewriter/pass_util.h b/oneflow/core/job_rewriter/pass_util.h index 2b1ea58add2..23620b7b208 100644 --- a/oneflow/core/job_rewriter/pass_util.h +++ b/oneflow/core/job_rewriter/pass_util.h @@ -20,6 +20,7 @@ limitations under the License. namespace oneflow { #define INSERT_CHECK(expr) CHECK(expr.second) +#define INSERT_CHECK_OR_RETURN(expr) CHECK_OR_RETURN(expr.second) template bool IsKeyFound(const MapT& m, const KeyT& k) { diff --git a/oneflow/core/job_rewriter/quantization_aware_training.cpp b/oneflow/core/job_rewriter/quantization_aware_training.cpp index c888ed71218..b6d69af0a15 100644 --- a/oneflow/core/job_rewriter/quantization_aware_training.cpp +++ b/oneflow/core/job_rewriter/quantization_aware_training.cpp @@ -40,11 +40,12 @@ const std::string MUL_BIAS_SUFFIX = "-fake-quant-mul-bias"; const std::string OBSERVER_SUFFIX = "-fake-quant-observer"; const std::string TRAIN_STEP_SUFFIX = "-fake-train-step"; -void VerifyQATList(const OpTypeSet& op_list) { +Maybe VerifyQATList(const OpTypeSet& op_list) { for (const auto& op_type : op_list) { - CHECK(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr) + CHECK_OR_RETURN(user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type) != nullptr) << "Cannot find " << op_type << " of QuantAwareTraining list in OpRegistry."; } + return Maybe::Ok(); } HashMap scale_map; @@ -168,47 +169,47 @@ std::string QuantizationSchemeAttr4QatConfig(const QatConfig& qat_config) { } // TODO: refactor the following 4 methods by registration -std::string QuantizationFormulaAttr4QatConfig(const QatConfig& qat_config) { +Maybe QuantizationFormulaAttr4QatConfig(const QatConfig& qat_config) { const auto target_backend = qat_config.target_backend(); if (target_backend == "" || target_backend == "tensorrt") { - return "google"; + return std::string("google"); } else if (target_backend == "cambricon") { - return "cambricon"; + return std::string("cambricon"); } else { - UNIMPLEMENTED(); + UNIMPLEMENTED_THEN_RETURN(); } } -OpTypeSet Int8List4QatConfig(const QatConfig& qat_config) { +Maybe Int8List4QatConfig(const QatConfig& qat_config) { const auto target_backend = qat_config.target_backend(); if (target_backend == "") { - return {"add_n", "matmul", "batch_matmul", "conv2d", "avg_pool_2d", "max_pool_2d"}; + return OpTypeSet{"add_n", "matmul", "batch_matmul", "conv2d", "avg_pool_2d", "max_pool_2d"}; } else if (target_backend == "cambricon" || target_backend == "tensorrt") { - return {"conv2d", "matmul"}; + return OpTypeSet{"conv2d", "matmul"}; } else { - UNIMPLEMENTED(); + UNIMPLEMENTED_THEN_RETURN(); } } -OpTypeSet TransparentList4QatConfig(const QatConfig& qat_config) { +Maybe TransparentList4QatConfig(const QatConfig& qat_config) { const auto target_backend = qat_config.target_backend(); if (target_backend == "" || target_backend == "tensorrt") { - return {"reshape"}; + return OpTypeSet{"reshape"}; } else if (target_backend == "cambricon") { - return {}; + return OpTypeSet{}; } else { - UNIMPLEMENTED(); + UNIMPLEMENTED_THEN_RETURN(); } } -bool InsertQuantOpAfterInt8Ops4QatConfig(const QatConfig& qat_config) { +Maybe InsertQuantOpAfterInt8Ops4QatConfig(const QatConfig& qat_config) { const auto target_backend = qat_config.target_backend(); if (target_backend == "" || target_backend == "tensorrt") { return true; } else if (target_backend == "cambricon") { return false; } else { - UNIMPLEMENTED(); + UNIMPLEMENTED_THEN_RETURN(); } } @@ -226,7 +227,7 @@ user_op::UserOpConfWrapper MultiplyOp(const std::string& name, const std::string return op_wrapper; } -user_op::UserOpConfWrapper MinMaxObserver(const std::string& name, const std::string& input, +Maybe MinMaxObserver(const std::string& name, const std::string& input, const QatConfig& qat_config, const int64_t scope_symbol_id, OpConfMap* inserted_ops) { const auto op_wrapper = @@ -235,7 +236,7 @@ user_op::UserOpConfWrapper MinMaxObserver(const std::string& name, const std::st .Input("in", input) .Output("scale") .Output("zero_point") - .Attr("quantization_formula", QuantizationFormulaAttr4QatConfig(qat_config)) + .Attr("quantization_formula", *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) .Attr("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config)) .Attr("per_layer_quantization", PerLayerQuantizationAttr4Config(qat_config)) .ScopeSymbolId(scope_symbol_id) @@ -244,7 +245,7 @@ user_op::UserOpConfWrapper MinMaxObserver(const std::string& name, const std::st return op_wrapper; } -user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const std::string& input, +Maybe MovingMinMaxObserver(const std::string& name, const std::string& input, const std::string& train_step_lbn, const QatConfig& qat_config, const int64_t scope_symbol_id, @@ -276,7 +277,7 @@ user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const s .Output("zero_point") .Attr("training", GlobalJobDesc().IsTrain()) .Attr("stop_update_after_iters", qat_config.moving_min_max_stop_update_after_iters()) - .Attr("quantization_formula", QuantizationFormulaAttr4QatConfig(qat_config)) + .Attr("quantization_formula", *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) .Attr("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config)) .Attr("momentum", qat_config.moving_min_max_momentum()) .ScopeSymbolId(scope_symbol_id) @@ -285,7 +286,7 @@ user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const s return op_wrapper; } -user_op::UserOpConfWrapper FakeQuantOp(const std::string& name, const std::string& input, +Maybe FakeQuantOp(const std::string& name, const std::string& input, const std::string& scale, const std::string& zero_point, const QatConfig& qat_config, const int64_t scope_symbol_id, OpConfMap* inserted_ops) { @@ -295,7 +296,7 @@ user_op::UserOpConfWrapper FakeQuantOp(const std::string& name, const std::strin .Input("in", input) .Input("scale", scale) .Input("zero_point", zero_point) - .Attr("quantization_formula", QuantizationFormulaAttr4QatConfig(qat_config)) + .Attr("quantization_formula", *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) .Attr("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config)) .Output("out") .ScopeSymbolId(scope_symbol_id) @@ -329,15 +330,15 @@ Maybe GetScaleAndZeroPointLbn4Edge(OpEdge* edge, const std::string train_s const std::string observer_op_name = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX; if (IsWeightEdge(edge)) { const auto observer_op = - MinMaxObserver(observer_op_name, lbn, qat_config, scope_symbol_id, inserted_ops); - *scale = observer_op.output("scale", 0); - *zero_point = observer_op.output("zero_point", 0); + JUST(MinMaxObserver(observer_op_name, lbn, qat_config, scope_symbol_id, inserted_ops)); + *scale = observer_op->output("scale", 0); + *zero_point = observer_op->output("zero_point", 0); } else { CHECK_OR_RETURN(qat_config.has_moving_min_max_stop_update_after_iters()); - const auto observer_op = MovingMinMaxObserver(observer_op_name, lbn, train_step_lbn, - qat_config, scope_symbol_id, inserted_ops); - *scale = observer_op.output("scale", 0); - *zero_point = observer_op.output("zero_point", 0); + const auto observer_op = JUST(MovingMinMaxObserver(observer_op_name, lbn, train_step_lbn, + qat_config, scope_symbol_id, inserted_ops)); + *scale = observer_op->output("scale", 0); + *zero_point = observer_op->output("zero_point", 0); } } return Maybe::Ok(); @@ -374,9 +375,9 @@ class QuantAwareTraining final : public JobPass { HashSet downstream_white, Job* job) const; }; -bool IsNodeQuantizationEnabled(const OpNode& node) { +Maybe IsNodeQuantizationEnabled(const OpNode& node) { int64_t scope_symbol_id = node.op().op_conf().scope_symbol_id(); - CHECK(Global>::Get()->Has(scope_symbol_id)); + CHECK_OR_RETURN(Global>::Get()->Has(scope_symbol_id)); const Scope& scope = Global>::Get()->Get(scope_symbol_id); return scope.Bool("quantization_aware_training"); } @@ -384,20 +385,20 @@ bool IsNodeQuantizationEnabled(const OpNode& node) { Maybe QuantAwareTraining::Apply(Job* job, JobPassCtx* ctx) const { if (!IsEnabled(*ctx)) { return Maybe::Ok(); } const OpGraph op_graph(*job); - CHECK(GlobalJobDesc().DefaultDataType() == DataType::kFloat); + CHECK_OR_RETURN(GlobalJobDesc().DefaultDataType() == DataType::kFloat); const auto qat_config = ctx->job_desc().job_conf().qat_config(); - OpTypeSet int8_list = Int8List4QatConfig(qat_config); - OpTypeSet transparent_list = TransparentList4QatConfig(qat_config); + OpTypeSet int8_list = *JUST(Int8List4QatConfig(qat_config)); + OpTypeSet transparent_list = *JUST(TransparentList4QatConfig(qat_config)); // if `insert_quant_op_after_int8_ops` is false, // always insert quant op before int8 ops. // if `insert_quant_op_after_int8_ops` is true, // always insert quant op after int8 ops - bool insert_quant_op_after_int8_ops = InsertQuantOpAfterInt8Ops4QatConfig(qat_config); + bool insert_quant_op_after_int8_ops = JUST(InsertQuantOpAfterInt8Ops4QatConfig(qat_config)); - VerifyQATList(int8_list); - VerifyQATList(transparent_list); + JUST(VerifyQATList(int8_list)); + JUST(VerifyQATList(transparent_list)); std::function OpName4Node = [](OpNode* const& node) { return node->op().op_name(); @@ -456,7 +457,7 @@ Maybe QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config, const std::string lbn = GenLogicalBlobName(edge->lbis().front()); scale_map[lbn] = ReplaceSlashToDash4Lbn(lbn) + OBSERVER_SUFFIX + "/scale_0"; VLOG(3) << "set " << lbn << " to " << scale_map[lbn]; - INSERT_CHECK(white_set_edges.insert(edge)); + INSERT_CHECK_OR_RETURN(white_set_edges.insert(edge)); return Maybe::Ok(); }; auto PropagateScale = [](OpNode* node) -> Maybe { @@ -478,16 +479,16 @@ Maybe QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config, if (IsKeyFound(white_set, node)) { for (OpEdge* edge : node->in_edges()) { if (IsKeyFound(white_set, edge->src_node())) { continue; } - if (IsNodeQuantizationEnabled(*edge->dst_node())) { JUST(AddWhiteSetEdge(edge)); } + if (JUST(IsNodeQuantizationEnabled(*edge->dst_node()))) { JUST(AddWhiteSetEdge(edge)); } } if (IsNodeInList(int8_list, node)) { if (insert_quant_op_after_int8_ops) { OpNode* inference_node = JUST(GetInferenceOutputNode(op_graph, node)); - if (IsNodeQuantizationEnabled(*inference_node)) { + if (JUST(IsNodeQuantizationEnabled(*inference_node))) { for (OpEdge* edge : inference_node->out_edges()) { JUST(AddWhiteSetEdge(edge)); } } } else { - if (IsNodeQuantizationEnabled(*node)) { + if (JUST(IsNodeQuantizationEnabled(*node))) { for (OpEdge* edge : node->in_edges()) { if (white_set_edges.find(edge) == white_set_edges.end()) { JUST(AddWhiteSetEdge(edge)); @@ -535,10 +536,10 @@ Maybe QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config, JUST(GetScaleAndZeroPointLbn4Edge(edge, job->job_conf().train_conf().train_step_lbn(), &scale, &zero_point, qat_config, scope_symbol_id, &inserted_ops)); const std::string fake_quant_op_name = ReplaceSlashToDash4Lbn(lbn) + FAKE_QUANT_SUFFIX; - const auto fake_quant_op = FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point, qat_config, - scope_symbol_id, &inserted_ops); + const auto fake_quant_op = JUST(FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point, qat_config, + scope_symbol_id, &inserted_ops)); - const std::string fake_quant_op_output_name = fake_quant_op.output("out", 0); + const std::string fake_quant_op_output_name = fake_quant_op->output("out", 0); JUST(ReplaceInputLbn4DstNodeOfEdge(edge, fake_quant_op_output_name, &op_conf_cache)); } From b027fe8947b471ac1f26e3ac747f9fdf7dc7be6a Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Thu, 29 Jul 2021 10:09:52 +0000 Subject: [PATCH 07/32] auto format by CI --- .../quantization_aware_training.cpp | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/oneflow/core/job_rewriter/quantization_aware_training.cpp b/oneflow/core/job_rewriter/quantization_aware_training.cpp index b6d69af0a15..97ffaa44729 100644 --- a/oneflow/core/job_rewriter/quantization_aware_training.cpp +++ b/oneflow/core/job_rewriter/quantization_aware_training.cpp @@ -228,15 +228,17 @@ user_op::UserOpConfWrapper MultiplyOp(const std::string& name, const std::string } Maybe MinMaxObserver(const std::string& name, const std::string& input, - const QatConfig& qat_config, - const int64_t scope_symbol_id, OpConfMap* inserted_ops) { + const QatConfig& qat_config, + const int64_t scope_symbol_id, + OpConfMap* inserted_ops) { const auto op_wrapper = user_op::UserOpConfWrapperBuilder(name) .Op("min_max_observer") .Input("in", input) .Output("scale") .Output("zero_point") - .Attr("quantization_formula", *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) + .Attr("quantization_formula", + *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) .Attr("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config)) .Attr("per_layer_quantization", PerLayerQuantizationAttr4Config(qat_config)) .ScopeSymbolId(scope_symbol_id) @@ -245,11 +247,9 @@ Maybe MinMaxObserver(const std::string& name, const return op_wrapper; } -Maybe MovingMinMaxObserver(const std::string& name, const std::string& input, - const std::string& train_step_lbn, - const QatConfig& qat_config, - const int64_t scope_symbol_id, - OpConfMap* inserted_ops) { +Maybe MovingMinMaxObserver( + const std::string& name, const std::string& input, const std::string& train_step_lbn, + const QatConfig& qat_config, const int64_t scope_symbol_id, OpConfMap* inserted_ops) { const std::string moving_max_name = name + MOVING_MAX_SUFFIX; const std::string moving_min_name = name + MOVING_MIN_SUFFIX; const auto moving_max_var = @@ -277,7 +277,8 @@ Maybe MovingMinMaxObserver(const std::string& name, .Output("zero_point") .Attr("training", GlobalJobDesc().IsTrain()) .Attr("stop_update_after_iters", qat_config.moving_min_max_stop_update_after_iters()) - .Attr("quantization_formula", *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) + .Attr("quantization_formula", + *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) .Attr("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config)) .Attr("momentum", qat_config.moving_min_max_momentum()) .ScopeSymbolId(scope_symbol_id) @@ -287,16 +288,19 @@ Maybe MovingMinMaxObserver(const std::string& name, } Maybe FakeQuantOp(const std::string& name, const std::string& input, - const std::string& scale, const std::string& zero_point, - const QatConfig& qat_config, const int64_t scope_symbol_id, - OpConfMap* inserted_ops) { + const std::string& scale, + const std::string& zero_point, + const QatConfig& qat_config, + const int64_t scope_symbol_id, + OpConfMap* inserted_ops) { const auto op_wrapper = user_op::UserOpConfWrapperBuilder(name) .Op("fake_quantization") .Input("in", input) .Input("scale", scale) .Input("zero_point", zero_point) - .Attr("quantization_formula", *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) + .Attr("quantization_formula", + *JUST(QuantizationFormulaAttr4QatConfig(qat_config))) .Attr("quantization_scheme", QuantizationSchemeAttr4QatConfig(qat_config)) .Output("out") .ScopeSymbolId(scope_symbol_id) @@ -335,8 +339,8 @@ Maybe GetScaleAndZeroPointLbn4Edge(OpEdge* edge, const std::string train_s *zero_point = observer_op->output("zero_point", 0); } else { CHECK_OR_RETURN(qat_config.has_moving_min_max_stop_update_after_iters()); - const auto observer_op = JUST(MovingMinMaxObserver(observer_op_name, lbn, train_step_lbn, - qat_config, scope_symbol_id, inserted_ops)); + const auto observer_op = JUST(MovingMinMaxObserver( + observer_op_name, lbn, train_step_lbn, qat_config, scope_symbol_id, inserted_ops)); *scale = observer_op->output("scale", 0); *zero_point = observer_op->output("zero_point", 0); } @@ -536,8 +540,8 @@ Maybe QuantAwareTraining::InsertFakeQuantOp(const QatConfig& qat_config, JUST(GetScaleAndZeroPointLbn4Edge(edge, job->job_conf().train_conf().train_step_lbn(), &scale, &zero_point, qat_config, scope_symbol_id, &inserted_ops)); const std::string fake_quant_op_name = ReplaceSlashToDash4Lbn(lbn) + FAKE_QUANT_SUFFIX; - const auto fake_quant_op = JUST(FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point, qat_config, - scope_symbol_id, &inserted_ops)); + const auto fake_quant_op = JUST(FakeQuantOp(fake_quant_op_name, lbn, scale, zero_point, + qat_config, scope_symbol_id, &inserted_ops)); const std::string fake_quant_op_output_name = fake_quant_op->output("out", 0); From 55f3daf82404671d4ff5429650f5eb0980d32f49 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Thu, 29 Jul 2021 22:13:50 +0800 Subject: [PATCH 08/32] add module op test for empty tensor, cuda kernel support empty tensor --- oneflow/core/kernel/kernel_util.cu | 30 +++++++----- .../kernel/util/cuda_arithemetic_interface.cu | 8 ++-- .../core/kernel/util/cuda_dnn_interface.cu | 14 +++--- .../core/ndarray/ndarray_apply_binary_core.cu | 12 +++-- .../ndarray_apply_broadcast_binary_core.cu | 18 ++++--- .../ndarray_apply_broadcast_unary_core.cu | 2 +- .../core/ndarray/ndarray_apply_unary_core.cu | 4 +- oneflow/core/ndarray/ndarray_assign_core.cu | 2 +- oneflow/user/kernels/add_n_kernel.cu | 8 ++-- .../kernels/math_binary_elementwise_kernel.cu | 9 ++-- .../kernels/math_unary_elementwise_kernel.cu | 8 ++-- oneflow/user/kernels/slice_util.cu | 8 ++-- python/oneflow/test/modules/test_abs.py | 8 ++++ python/oneflow/test/modules/test_acos.py | 8 ++++ python/oneflow/test/modules/test_acosh.py | 8 ++++ .../oneflow/test/modules/test_activation.py | 48 +++++++++++++++++++ python/oneflow/test/modules/test_add.py | 12 +++++ python/oneflow/test/modules/test_argwhere.py | 2 +- python/oneflow/test/modules/test_atan.py | 8 ++++ python/oneflow/test/modules/test_atan2.py | 8 ++++ python/oneflow/test/modules/test_atanh.py | 8 ++++ python/oneflow/test/modules/test_cast.py | 11 +++++ python/oneflow/test/modules/test_ceil.py | 7 +++ python/oneflow/test/modules/test_clamp.py | 7 +++ python/oneflow/test/modules/test_concat.py | 6 +-- python/oneflow/test/modules/test_div.py | 8 ++++ python/oneflow/test/modules/test_eq.py | 9 ++++ python/oneflow/test/modules/test_expm1.py | 7 +++ python/oneflow/test/modules/test_fmod.py | 7 +++ python/oneflow/test/modules/test_greater.py | 9 ++++ python/oneflow/test/modules/test_ne.py | 12 +++++ python/oneflow/test/modules/test_negative.py | 11 +++++ python/oneflow/test/modules/test_pow.py | 2 +- python/oneflow/test/modules/test_reshape.py | 9 ++++ python/oneflow/test/modules/test_sign.py | 8 ++++ python/oneflow/test/modules/test_squeeze.py | 7 +++ python/oneflow/test/modules/test_sub.py | 12 +++++ python/oneflow/test/modules/test_tan.py | 8 ++++ python/oneflow/test/modules/test_transpose.py | 7 +++ python/oneflow/test/modules/test_triu.py | 8 ++++ python/oneflow/test/modules/test_unsqueeze.py | 7 +++ 41 files changed, 343 insertions(+), 52 deletions(-) diff --git a/oneflow/core/kernel/kernel_util.cu b/oneflow/core/kernel/kernel_util.cu index bb877089418..8960a3effbd 100644 --- a/oneflow/core/kernel/kernel_util.cu +++ b/oneflow/core/kernel/kernel_util.cu @@ -655,29 +655,35 @@ __global__ void CastOnGpu(const half* in, float* out, int64_t elem_ template void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num) { - if (std::is_same::value) { - Memcpy(ctx, out_dptr, in_dptr, elem_num * sizeof(T)); - } else { - CastOnGpu - <<cuda_stream()>>>( - in_dptr, out_dptr, elem_num); + if (elem_num > 0) { + if (std::is_same::value) { + Memcpy(ctx, out_dptr, in_dptr, elem_num * sizeof(T)); + } else { + CastOnGpu + <<cuda_stream()>>>( + in_dptr, out_dptr, elem_num); + } } } template<> void CopyElemOnGpu(DeviceCtx* ctx, const float* in_dptr, float16* out_dptr, int64_t elem_num) { - CastOnGpu - <<cuda_stream()>>>(in_dptr, reinterpret_cast(out_dptr), elem_num); + if (RoundUp(elem_num, 2) > 0) { + CastOnGpu + <<cuda_stream()>>>(in_dptr, reinterpret_cast(out_dptr), elem_num); + } } template<> void CopyElemOnGpu(DeviceCtx* ctx, const float16* in_dptr, float* out_dptr, int64_t elem_num) { - CastOnGpu - <<cuda_stream()>>>(reinterpret_cast(in_dptr), out_dptr, elem_num); + if (RoundUp(elem_num, 2) > 0) { + CastOnGpu + <<cuda_stream()>>>(reinterpret_cast(in_dptr), out_dptr, elem_num); + } } #define INSTANTIATE_COPY_ELEM_ON_GPU(T, U) \ diff --git a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu index 286bf6c3036..be19b41bbd3 100644 --- a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu +++ b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu @@ -69,9 +69,11 @@ void LaunchTransposeGpu(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeVie cur_stride *= x_shape.At(i); } for (int32_t i = 0; i < NDIMS; ++i) { x_strides.val[i] = buff[permutation[i]]; } - TransposeGpu - <<cuda_stream()>>>( - y_shape_struct, x_strides, elem_cnt, x, y); + if (elem_cnt > 0) { + TransposeGpu + <<cuda_stream()>>>( + y_shape_struct, x_strides, elem_cnt, x, y); + } } template diff --git a/oneflow/core/kernel/util/cuda_dnn_interface.cu b/oneflow/core/kernel/util/cuda_dnn_interface.cu index 97e7b9da779..a65e89edb07 100644 --- a/oneflow/core/kernel/util/cuda_dnn_interface.cu +++ b/oneflow/core/kernel/util/cuda_dnn_interface.cu @@ -132,12 +132,14 @@ template struct ReluHelper final { static void ReluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { CHECK_LE(n, GetMaxVal() / 2); - if (x == y) { - InplaceReluForwardGpu - <<cuda_stream()>>>(n, y); - } else { - ReluForwardGpu - <<cuda_stream()>>>(n, x, y); + if (n > 0) { + if (x == y) { + InplaceReluForwardGpu + <<cuda_stream()>>>(n, y); + } else { + ReluForwardGpu + <<cuda_stream()>>>(n, x, y); + } } } diff --git a/oneflow/core/ndarray/ndarray_apply_binary_core.cu b/oneflow/core/ndarray/ndarray_apply_binary_core.cu index c7dea072171..1a0da7e9dfa 100644 --- a/oneflow/core/ndarray/ndarray_apply_binary_core.cu +++ b/oneflow/core/ndarray/ndarray_apply_binary_core.cu @@ -40,14 +40,18 @@ struct NdarrayApplyBinaryCoreWrapper final { const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { size_t n = y.host_shape().HostElemNum(); - RUN_CUDA_KERNEL((NdarrayApplyBinaryApplyGpu), ctx, n, n, y.host_ptr(), - a.host_ptr(), b.host_ptr()); + if (n > 0) { + RUN_CUDA_KERNEL((NdarrayApplyBinaryApplyGpu), ctx, n, n, y.host_ptr(), + a.host_ptr(), b.host_ptr()); + } } static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); - RUN_CUDA_KERNEL((NdarrayApplyBinaryInplaceApplyGpu), ctx, n, n, y.host_ptr(), - x.host_ptr()); + if (n > 0) { + RUN_CUDA_KERNEL((NdarrayApplyBinaryInplaceApplyGpu), ctx, n, n, y.host_ptr(), + x.host_ptr()); + } } }; diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu index 9335521a5c0..6e7c683dda5 100644 --- a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu @@ -89,9 +89,11 @@ struct NdarrayApplyBroadcastBinaryCoreWrapper::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { size_t n = y.host_shape().HostElemNum(); - if (IsKernelSafeInt32(n) && PartialBroadcast(ctx, y, a, b)) { return; } - if (!IsKernelSafeInt32(n) && PartialBroadcast(ctx, y, a, b)) { return; } - RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc), ctx, n, y, a, b); + if (n > 0) { + if (IsKernelSafeInt32(n) && PartialBroadcast(ctx, y, a, b)) { return; } + if (!IsKernelSafeInt32(n) && PartialBroadcast(ctx, y, a, b)) { return; } + RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc), ctx, n, y, a, b); + } } template @@ -151,9 +153,13 @@ struct NdarrayApplyBroadcastInplaceBinaryCoreWrapper a(y.host_shape(), y.host_ptr()); using NBB = NdarrayApplyBroadcastBinaryCoreWrapper; - if (IsKernelSafeInt32(n) && NBB::template PartialBroadcast(ctx, y, a, x)) { return; } - if (!IsKernelSafeInt32(n) && NBB::template PartialBroadcast(ctx, y, a, x)) { return; } - RUN_CUDA_KERNEL((GpuInplaceBroadcastBinaryFunc), ctx, n, y, x); + if (n > 0) { + if (IsKernelSafeInt32(n) && NBB::template PartialBroadcast(ctx, y, a, x)) { return; } + if (!IsKernelSafeInt32(n) && NBB::template PartialBroadcast(ctx, y, a, x)) { + return; + } + RUN_CUDA_KERNEL((GpuInplaceBroadcastBinaryFunc), ctx, n, y, x); + } } }; diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu index 31e9df6c5b6..d45340b6cac 100644 --- a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu @@ -30,7 +30,7 @@ template class unary_func> struct NdarrayApplyBroadcastUnaryCoreWrapper final { static void Apply(DeviceCtx* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); - RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc), ctx, n, y, x); + if (n > 0) { RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc), ctx, n, y, x); } } }; diff --git a/oneflow/core/ndarray/ndarray_apply_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_unary_core.cu index 2b6963b66e1..aa832fa81bd 100644 --- a/oneflow/core/ndarray/ndarray_apply_unary_core.cu +++ b/oneflow/core/ndarray/ndarray_apply_unary_core.cu @@ -31,7 +31,9 @@ template class unary_func> struct NdarrayApplyUnaryCoreWrapper final { static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray& y) { size_t n = y.host_shape().HostElemNum(); - RUN_CUDA_KERNEL((NdarrayApplyUnaryInplaceApplyGpu), ctx, n, y.host_ptr(), n); + if (n > 0) { + RUN_CUDA_KERNEL((NdarrayApplyUnaryInplaceApplyGpu), ctx, n, y.host_ptr(), n); + } } }; diff --git a/oneflow/core/ndarray/ndarray_assign_core.cu b/oneflow/core/ndarray/ndarray_assign_core.cu index 3ee26a91aad..7d11e919625 100644 --- a/oneflow/core/ndarray/ndarray_assign_core.cu +++ b/oneflow/core/ndarray/ndarray_assign_core.cu @@ -33,7 +33,7 @@ struct NdarrayAssignCoreWrapper final { static void Assign(DeviceCtx* ctx, const XpuVarNdarray& y, const XpuReducedNdarray& reduced) { size_t n = y.host_shape().HostElemNum(); - RUN_CUDA_KERNEL((NdarrayAssignGpu), ctx, n, y, reduced); + if (n > 0) { RUN_CUDA_KERNEL((NdarrayAssignGpu), ctx, n, y, reduced); } } }; diff --git a/oneflow/user/kernels/add_n_kernel.cu b/oneflow/user/kernels/add_n_kernel.cu index 0ae89377395..8c3fd9fbf51 100644 --- a/oneflow/user/kernels/add_n_kernel.cu +++ b/oneflow/user/kernels/add_n_kernel.cu @@ -56,10 +56,10 @@ struct GpuAddCaller { for (int32_t i = 0; i < N; ++i) { para.in[i] = ctx->Tensor4ArgNameAndIndex("in", i)->dptr(); } - - gpu_add - <<device_ctx()->cuda_stream()>>>( - n, para); + if (n > 0) { + gpu_add<<device_ctx()->cuda_stream()>>>(n, para); + } } }; diff --git a/oneflow/user/kernels/math_binary_elementwise_kernel.cu b/oneflow/user/kernels/math_binary_elementwise_kernel.cu index 6dc1e8cea97..f997cf5255c 100644 --- a/oneflow/user/kernels/math_binary_elementwise_kernel.cu +++ b/oneflow/user/kernels/math_binary_elementwise_kernel.cu @@ -52,9 +52,12 @@ class MathBinaryElementwiseGpuKernel final : public user_op::OpKernel { user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex("z", 0); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); - MathBinaryElementwiseForwardGpu - <<device_ctx()->cuda_stream()>>>( - n, tensor_x->dptr(), tensor_y->dptr(), tensor_z->mut_dptr()); + if (n > 0) { + MathBinaryElementwiseForwardGpu + <<device_ctx()->cuda_stream()>>>(n, tensor_x->dptr(), tensor_y->dptr(), + tensor_z->mut_dptr()); + } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; diff --git a/oneflow/user/kernels/math_unary_elementwise_kernel.cu b/oneflow/user/kernels/math_unary_elementwise_kernel.cu index 32144bffe32..7ae25112772 100644 --- a/oneflow/user/kernels/math_unary_elementwise_kernel.cu +++ b/oneflow/user/kernels/math_unary_elementwise_kernel.cu @@ -46,9 +46,11 @@ class MathUnaryElementwiseGpuKernel final : public user_op::OpKernel { T* y = tensor_y->mut_dptr(); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); - MathUnaryElementwiseForwardGpu - <<device_ctx()->cuda_stream()>>>( - n, x, y); + if (n > 0) { + MathUnaryElementwiseForwardGpu + <<device_ctx()->cuda_stream()>>>(n, x, y); + } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; diff --git a/oneflow/user/kernels/slice_util.cu b/oneflow/user/kernels/slice_util.cu index 4ed08a84e18..a8f16549dbd 100644 --- a/oneflow/user/kernels/slice_util.cu +++ b/oneflow/user/kernels/slice_util.cu @@ -48,9 +48,11 @@ void LaunchSliceForward(DeviceCtx* ctx, const SliceParams& params, const T* enti int64_t elem_cnt = params.elem_cnt(); SliceIndexHelper entire_idx_cvtr(params.dims); SliceIndexHelper sliced_idx_cvtr(params.size); - SliceForwardGpu - <<cuda_stream()>>>( - elem_cnt, params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced); + if (elem_cnt > 0) { + SliceForwardGpu + <<cuda_stream()>>>( + elem_cnt, params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced); + } } template diff --git a/python/oneflow/test/modules/test_abs.py b/python/oneflow/test/modules/test_abs.py index 43ec34ae345..105beed194d 100644 --- a/python/oneflow/test/modules/test_abs.py +++ b/python/oneflow/test/modules/test_abs.py @@ -23,6 +23,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_abs_forward(test_case, device): @@ -81,6 +82,13 @@ def test_flow_tensor_abs_with_random_data(test_case): for device in ["cpu", "cuda"]: test_tensor_against_pytorch(test_case, "abs", device=device) + @autotest(auto_backward=False) + def test_abs_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.abs(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_acos.py b/python/oneflow/test/modules/test_acos.py index ad244a0a132..1047ed2b62b 100644 --- a/python/oneflow/test/modules/test_acos.py +++ b/python/oneflow/test/modules/test_acos.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_acos_impl(test_case, shape, device): @@ -50,6 +51,13 @@ def test_acos(test_case): for arg in GenArgList(arg_dict): _test_acos_impl(test_case, *arg) + @autotest(auto_backward=False) + def test_acos_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.acos(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_acosh.py b/python/oneflow/test/modules/test_acosh.py index 193424c6113..1d77e9d9f6a 100644 --- a/python/oneflow/test/modules/test_acosh.py +++ b/python/oneflow/test/modules/test_acosh.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_acosh_impl(test_case, shape, device): @@ -51,6 +52,13 @@ def test_acosh(test_case): for arg in GenArgList(arg_dict): _test_acosh_impl(test_case, *arg) + @autotest(auto_backward=False) + def test_acosh_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.acosh(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py index a54d81c5fb7..a3e4d28a430 100644 --- a/python/oneflow/test/modules/test_activation.py +++ b/python/oneflow/test/modules/test_activation.py @@ -70,6 +70,16 @@ def test_relu_module_with_random_data(test_case): y = m(x) return y + @autotest(auto_backward=False) + def test_relu_module_with_0shape_data(test_case): + m = torch.nn.ReLU() + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device) + y = m(x) + return y + def _test_relu6_impl(test_case, shape, device): np_input = np.random.randn(*shape) @@ -111,6 +121,16 @@ def test_relu6_module_with_random_data(test_case): y = m(x) return y + @autotest(auto_backward=False) + def test_relu6_module_with_0shape_data(test_case): + m = torch.nn.ReLU6() + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device) + y = m(x) + return y + def _test_tanh_nn_impl(test_case, shape, device): np_input = np.random.randn(*shape) @@ -163,6 +183,16 @@ def test_tanh_module_with_random_data(test_case): y = m(x) return y + @autotest(auto_backward=False) + def test_tanh_module_with_0shapedata(test_case): + m = torch.nn.Tanh() + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device) + y = m(x) + return y + @autotest() def test_flow_tanh_with_random_data(test_case): device = random_device() @@ -170,6 +200,14 @@ def test_flow_tanh_with_random_data(test_case): y = torch.tanh(x) return y + @unittest.skip("reshape has bug or auto test has bug") + @autotest(auto_backward=False) + def test_flow_tanh_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device) + y = flow.tanh(x) + return y + def _test_elu_function_impl(test_case, shape, device): m = flow.nn.ELU() @@ -209,6 +247,16 @@ def test_elu_module_with_random_data(test_case): y = m(x) return y + @autotest(auto_backward=False) + def test_elu_module_with_0shape_data(test_case): + m = torch.nn.ELU(alpha=random() | nothing()) + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device) + y = m(x) + return y + def _np_gelu(x): return 0.5 * x * (1 + special.erf(x / np.sqrt(2))) diff --git a/python/oneflow/test/modules/test_add.py b/python/oneflow/test/modules/test_add.py index cbc7a2a51f6..6dc2f29dcc6 100644 --- a/python/oneflow/test/modules/test_add.py +++ b/python/oneflow/test/modules/test_add.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_add_forward(test_case, shape, device): @@ -151,6 +152,17 @@ def test_add(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(auto_backward=False) + def test_0shape_add(test_case): + device = random_device() + x = random_pytorch_tensor(2, 0, 3).to(device) + y = random_pytorch_tensor(2, 1, 3).to(device) + out1 = x + y + out2 = x + 2 + out3 = 2 + x + out4 = torch.add(x, y) + return out1, out2, out3 + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_argwhere.py b/python/oneflow/test/modules/test_argwhere.py index 375c1f52660..ed9b5ae5527 100644 --- a/python/oneflow/test/modules/test_argwhere.py +++ b/python/oneflow/test/modules/test_argwhere.py @@ -38,7 +38,7 @@ class TestArgwhere(flow.unittest.TestCase): def test_argwhere(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_argwhere] - arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] + arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6), (2, 3, 0, 4)] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) diff --git a/python/oneflow/test/modules/test_atan.py b/python/oneflow/test/modules/test_atan.py index 91c602b7e84..396b018b95e 100644 --- a/python/oneflow/test/modules/test_atan.py +++ b/python/oneflow/test/modules/test_atan.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_atan(test_case, shape, device): @@ -70,6 +71,13 @@ def test_atan(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(auto_backward=False) + def test_atan_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.atan(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_atan2.py b/python/oneflow/test/modules/test_atan2.py index d8cba30fd05..861ad8dd24b 100644 --- a/python/oneflow/test/modules/test_atan2.py +++ b/python/oneflow/test/modules/test_atan2.py @@ -23,6 +23,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_atan2_forward(test_case, shape, scalar, device): @@ -132,6 +133,13 @@ def test_flow_atan2_with_random_data(test_case): device=device, ) + @autotest(auto_backward=False) + def test_atan2_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.atan2(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_atanh.py b/python/oneflow/test/modules/test_atanh.py index 966e8b029d8..af0e18bc5ac 100644 --- a/python/oneflow/test/modules/test_atanh.py +++ b/python/oneflow/test/modules/test_atanh.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_atanh_impl(test_case, shape, device): @@ -70,6 +71,13 @@ def test_atanh(test_case): _test_atanh_impl(test_case, *arg) _test_arctanh_impl(test_case, *arg) + @autotest(auto_backward=False) + def test_atanh_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.atanh(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_cast.py b/python/oneflow/test/modules/test_cast.py index 2d21a21429e..10b00e819fe 100644 --- a/python/oneflow/test/modules/test_cast.py +++ b/python/oneflow/test/modules/test_cast.py @@ -66,6 +66,17 @@ def test_cast(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + def test_cast_with_0shape_data(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + _test_cast_float2int, + _test_cast_int2float, + ] + arg_dict["device"] = ["cpu", "cuda"] + arg_dict["shape"] = [(2, 3, 0, 5)] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_ceil.py b/python/oneflow/test/modules/test_ceil.py index 002251d6662..b721ae25b44 100644 --- a/python/oneflow/test/modules/test_ceil.py +++ b/python/oneflow/test/modules/test_ceil.py @@ -54,6 +54,13 @@ def test_ceil_flow_with_random_data(test_case): y = torch.ceil(input) return y + @autotest(auto_backward=False) + def test_ceil_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.ceil(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_clamp.py b/python/oneflow/test/modules/test_clamp.py index 640408dd1d6..c22c39a6396 100644 --- a/python/oneflow/test/modules/test_clamp.py +++ b/python/oneflow/test/modules/test_clamp.py @@ -153,6 +153,13 @@ def test_clip_max_none_flow_with_random_data(test_case): ) return y + @autotest(auto_backward=False) + def test_clamp_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.clamp(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_concat.py b/python/oneflow/test/modules/test_concat.py index 8e19203a7e7..2097f04e00d 100644 --- a/python/oneflow/test/modules/test_concat.py +++ b/python/oneflow/test/modules/test_concat.py @@ -134,10 +134,10 @@ def test_concat(test_case): arg[0](test_case, *arg[1:]) @autotest(n=10, auto_backward=False) - def test_0shape_concat(test_case): + def test_concat_with_0shape_data(test_case): device = random_device() - x = random_pytorch_tensor(4, 2, 3, 2 ,4).to(device) - y = random_pytorch_tensor(4, 2, 3, random(0, 3) ,4).to(device) + x = random_pytorch_tensor(4, 2, 3, 2, 4).to(device) + y = random_pytorch_tensor(4, 2, 3, random(0, 3), 4).to(device) z = torch.cat((x, y), dim=2) return z diff --git a/python/oneflow/test/modules/test_div.py b/python/oneflow/test/modules/test_div.py index daba1e6f08f..dcf886a4887 100644 --- a/python/oneflow/test/modules/test_div.py +++ b/python/oneflow/test/modules/test_div.py @@ -23,6 +23,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_div_impl(test_case, shape, device): @@ -100,6 +101,13 @@ def test_sub_against_pytorch(test_case): device=arg[1], ) + @autotest(auto_backward=False) + def test_0shape_div(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.div(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_eq.py b/python/oneflow/test/modules/test_eq.py index c1feca9bbbe..71ba7caf4bf 100644 --- a/python/oneflow/test/modules/test_eq.py +++ b/python/oneflow/test/modules/test_eq.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_eq(test_case, shape, device): @@ -99,6 +100,14 @@ def test_eq(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(auto_backward=False) + def test_eq_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + z = torch.eq(y) + return z + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_expm1.py b/python/oneflow/test/modules/test_expm1.py index 454b5daf629..04821a2d7c2 100644 --- a/python/oneflow/test/modules/test_expm1.py +++ b/python/oneflow/test/modules/test_expm1.py @@ -54,6 +54,13 @@ def test_expm1_flow_with_random_data(test_case): y = torch.expm1(input) return y + @autotest(auto_backward=False) + def test_expm1_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.expm1(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_fmod.py b/python/oneflow/test/modules/test_fmod.py index e0ab4da4d81..d20381fd34f 100644 --- a/python/oneflow/test/modules/test_fmod.py +++ b/python/oneflow/test/modules/test_fmod.py @@ -91,6 +91,13 @@ def test_flow_fmod_with_random_data(test_case): other = random_pytorch_tensor().to(device) return torch.fmod(input, other) + @autotest(auto_backward=False) + def test_fmod_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.fmod(x, 2) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_greater.py b/python/oneflow/test/modules/test_greater.py index a7990b6374b..e977621dc8a 100644 --- a/python/oneflow/test/modules/test_greater.py +++ b/python/oneflow/test/modules/test_greater.py @@ -118,6 +118,15 @@ def test_tensor_greater_with_random_data(test_case): y2 = x1 > x2 return (y1, y2) + @autotest(auto_backward=False) + def test_greater_with_0shape_data(test_case): + device = random_device() + x1 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) + x2 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) + y1 = torch.gt(x1, x2) + y2 = x1 > x2 + return (y1, y2) + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_ne.py b/python/oneflow/test/modules/test_ne.py index 229ec5a8f79..02d7171c24e 100644 --- a/python/oneflow/test/modules/test_ne.py +++ b/python/oneflow/test/modules/test_ne.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_ne(test_case, shape, device): @@ -99,6 +100,17 @@ def test_ne(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(auto_backward=False) + def test_ne_with_0shape_data(test_case): + device = random_device() + x1 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) + x2 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) + y1 = torch.ne(x1, x2) + y2 = torch.ne(x1, 2) + y3 = torch.ne(x1, 2.0) + + return (y1, y2, y3) + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_negative.py b/python/oneflow/test/modules/test_negative.py index 7f8e7e45a16..534310bb38d 100644 --- a/python/oneflow/test/modules/test_negative.py +++ b/python/oneflow/test/modules/test_negative.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_negtive(test_case, shape, device): @@ -77,6 +78,16 @@ def test_negative(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(auto_backward=False) + def test_ne_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) + y1 = torch.negative(x) + y2 = torch.neg(x) + y3 = -x + + return (y1, y2, y3) + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_pow.py b/python/oneflow/test/modules/test_pow.py index 1a87b26ffad..2a177e4a10a 100644 --- a/python/oneflow/test/modules/test_pow.py +++ b/python/oneflow/test/modules/test_pow.py @@ -96,7 +96,7 @@ def test_x_grad_scalar(): class TestPow(flow.unittest.TestCase): def test_pow_forward(test_case): arg_dict = OrderedDict() - arg_dict["shape"] = [(2, 3), (2, 3, 4, 5)] + arg_dict["shape"] = [(2, 3), (2, 3, 4, 5), (2, 3, 0, 5)] arg_dict["scalar"] = [2.1, 0.8] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): diff --git a/python/oneflow/test/modules/test_reshape.py b/python/oneflow/test/modules/test_reshape.py index bf5cf34dab2..d268c1e920f 100644 --- a/python/oneflow/test/modules/test_reshape.py +++ b/python/oneflow/test/modules/test_reshape.py @@ -23,6 +23,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_reshape(test_case, device): @@ -83,6 +84,14 @@ def test_reshape_flow_with_random_data(test_case): y = torch.reshape(x, shape=(-1,)) return y + @unittest.skip("reshape has bug") + @autotest(auto_backward=False) + def test_reshape_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.reshape(x, shape=[0]) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_sign.py b/python/oneflow/test/modules/test_sign.py index f46a6165b40..50bf57501b7 100644 --- a/python/oneflow/test/modules/test_sign.py +++ b/python/oneflow/test/modules/test_sign.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_sign_impl(test_case, shape, device): @@ -47,6 +48,13 @@ def test_sign(test_case): for arg in GenArgList(arg_dict): _test_sign_impl(test_case, *arg) + @autotest(auto_backward=False) + def test_sign_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 3, 0, 4).to(device) + y = torch.sign(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_squeeze.py b/python/oneflow/test/modules/test_squeeze.py index a69f92cc3b1..7b158ee2983 100644 --- a/python/oneflow/test/modules/test_squeeze.py +++ b/python/oneflow/test/modules/test_squeeze.py @@ -108,6 +108,13 @@ def test_flow_squeeze_with_random_data(test_case): y = torch.squeeze(x, random(1, 3).to(int)) return y + @autotest(auto_backward=False) + def test_squeeze_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(3, 2, 1, 0).to(device) + y = torch.squeeze(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_sub.py b/python/oneflow/test/modules/test_sub.py index 623d33d91db..f01144595a9 100644 --- a/python/oneflow/test/modules/test_sub.py +++ b/python/oneflow/test/modules/test_sub.py @@ -23,6 +23,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_sub_impl(test_case, shape, device): @@ -110,6 +111,17 @@ def test_sub_against_pytorch(test_case): device=arg[1], ) + @autotest(auto_backward=False) + def test_sub_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(2, 0, 3).to(device) + y = random_pytorch_tensor(2, 1, 3).to(device) + out1 = x - y + out2 = x - 2 + out3 = 2 - x + out4 = torch.sub(x - y) + return out1, out2, out3, out4 + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_tan.py b/python/oneflow/test/modules/test_tan.py index bad582d05c7..03504cc2f89 100644 --- a/python/oneflow/test/modules/test_tan.py +++ b/python/oneflow/test/modules/test_tan.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_tan_impl(test_case, shape, device): @@ -51,6 +52,13 @@ def test_tan(test_case): for arg in GenArgList(arg_dict): _test_tan_impl(test_case, *arg) + @autotest(auto_backward=False) + def test_tan_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.tan(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_transpose.py b/python/oneflow/test/modules/test_transpose.py index 8ba4b13305f..675253df15f 100644 --- a/python/oneflow/test/modules/test_transpose.py +++ b/python/oneflow/test/modules/test_transpose.py @@ -102,6 +102,13 @@ def test_transpose_flow_with_random_data(test_case): y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) return y + @autotest(auto_backward=False) + def test_transpose_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 3, 0, 4).to(device) + y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_triu.py b/python/oneflow/test/modules/test_triu.py index 7c85f679600..b1561032e50 100644 --- a/python/oneflow/test/modules/test_triu.py +++ b/python/oneflow/test/modules/test_triu.py @@ -23,6 +23,7 @@ import oneflow as flow import oneflow.nn as nn import oneflow.unittest +from automated_test_util import * def _test_triu(test_case, diagonal, device): @@ -50,6 +51,13 @@ def test_triu(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(auto_backward=False) + def test_triu_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + y = torch.triu(2) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_unsqueeze.py b/python/oneflow/test/modules/test_unsqueeze.py index 1baf9cbae3b..953cbf217dc 100644 --- a/python/oneflow/test/modules/test_unsqueeze.py +++ b/python/oneflow/test/modules/test_unsqueeze.py @@ -81,6 +81,13 @@ def test_flow_unsqueeze_with_random_data(test_case): y = torch.unsqueeze(x, random(1, 3).to(int)) return y + @autotest(auto_backward=False) + def test_unsqueeze_with_0shape_data(test_case): + device = random_device() + x = random_pytorch_tensor(3, 2, 1, 0).to(device) + y = torch.unsqueeze(x, random(0, 2).to(int)) + return y + if __name__ == "__main__": unittest.main() From b287eb0357aed7cf1b98d3263076e81f38fe5ec4 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Fri, 30 Jul 2021 08:44:39 +0800 Subject: [PATCH 09/32] format --- python/oneflow/test/modules/test_ne.py | 1 - python/oneflow/test/modules/test_negative.py | 1 - 2 files changed, 2 deletions(-) diff --git a/python/oneflow/test/modules/test_ne.py b/python/oneflow/test/modules/test_ne.py index 02d7171c24e..48f80d87363 100644 --- a/python/oneflow/test/modules/test_ne.py +++ b/python/oneflow/test/modules/test_ne.py @@ -108,7 +108,6 @@ def test_ne_with_0shape_data(test_case): y1 = torch.ne(x1, x2) y2 = torch.ne(x1, 2) y3 = torch.ne(x1, 2.0) - return (y1, y2, y3) diff --git a/python/oneflow/test/modules/test_negative.py b/python/oneflow/test/modules/test_negative.py index 534310bb38d..4534545a224 100644 --- a/python/oneflow/test/modules/test_negative.py +++ b/python/oneflow/test/modules/test_negative.py @@ -85,7 +85,6 @@ def test_ne_with_0shape_data(test_case): y1 = torch.negative(x) y2 = torch.neg(x) y3 = -x - return (y1, y2, y3) From 5f984f180776a4573f29aa21767c82b82ebe2abb Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Fri, 30 Jul 2021 12:32:59 +0800 Subject: [PATCH 10/32] feat(ReduceOp): reduce op kernels support 0shape tensor --- oneflow/user/kernels/reduce_kernel.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp index 1d0a6083d32..22d7d7426ba 100644 --- a/oneflow/user/kernels/reduce_kernel.cpp +++ b/oneflow/user/kernels/reduce_kernel.cpp @@ -33,6 +33,8 @@ class ReduceKernel final : public user_op::OpKernel { user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto& axis = ctx->Attr>("axis"); + + if (input_tensor->shape().elem_cnt() == 0) { return; } const Shape& reduced_shape = CreateReducedShape(input_tensor->shape(), {axis.begin(), axis.end()}); NdarrayReduce::Reduce( From d1b3867792d7c38cc55836b1879d6aeffafa8a61 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Fri, 30 Jul 2021 13:57:13 +0800 Subject: [PATCH 11/32] delete files added by mistake --- python/oneflow/test/modules/test_acos.py | 63 ---------- python/oneflow/test/modules/test_acosh.py | 64 ---------- python/oneflow/test/modules/test_atan.py | 83 ------------- python/oneflow/test/modules/test_atan2.py | 145 ---------------------- python/oneflow/test/modules/test_atanh.py | 83 ------------- python/oneflow/test/modules/test_tan.py | 64 ---------- 6 files changed, 502 deletions(-) delete mode 100644 python/oneflow/test/modules/test_acos.py delete mode 100644 python/oneflow/test/modules/test_acosh.py delete mode 100644 python/oneflow/test/modules/test_atan.py delete mode 100644 python/oneflow/test/modules/test_atan2.py delete mode 100644 python/oneflow/test/modules/test_atanh.py delete mode 100644 python/oneflow/test/modules/test_tan.py diff --git a/python/oneflow/test/modules/test_acos.py b/python/oneflow/test/modules/test_acos.py deleted file mode 100644 index 1047ed2b62b..00000000000 --- a/python/oneflow/test/modules/test_acos.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import unittest -from collections import OrderedDict - -import numpy as np -from test_util import GenArgList - -import oneflow as flow -import oneflow.unittest -from automated_test_util import * - - -def _test_acos_impl(test_case, shape, device): - input = flow.Tensor( - np.random.rand(*shape) - 0.5, device=flow.device(device), requires_grad=True - ) - of_out = flow.acos(input) - np_out = np.arccos(input.numpy()) - test_case.assertTrue( - np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True) - ) - of_out = of_out.sum() - of_out.backward() - np_grad = -1.0 / np.sqrt(1 - np.square(input.numpy())) - test_case.assertTrue( - np.allclose(input.grad.numpy(), np_grad, 0.0001, 0.0001, equal_nan=True) - ) - - -@flow.unittest.skip_unless_1n1d() -class TestAcos(flow.unittest.TestCase): - def test_acos(test_case): - arg_dict = OrderedDict() - arg_dict["shape"] = [(2,), (2, 3), (2, 3, 4), (2, 4, 5, 6)] - arg_dict["device"] = ["cpu", "cuda"] - for arg in GenArgList(arg_dict): - _test_acos_impl(test_case, *arg) - - @autotest(auto_backward=False) - def test_acos_with_0shape_data(test_case): - device = random_device() - x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = torch.acos(x) - return y - - -if __name__ == "__main__": - unittest.main() diff --git a/python/oneflow/test/modules/test_acosh.py b/python/oneflow/test/modules/test_acosh.py deleted file mode 100644 index 1d77e9d9f6a..00000000000 --- a/python/oneflow/test/modules/test_acosh.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import unittest -from collections import OrderedDict - -import numpy as np -from test_util import GenArgList - -import oneflow as flow -import oneflow.unittest -from automated_test_util import * - - -def _test_acosh_impl(test_case, shape, device): - np_input = np.random.rand(*shape) + 2.0 - of_input = flow.Tensor( - np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True - ) - of_out = flow.acosh(of_input) - np_out = np.arccosh(np_input) - test_case.assertTrue( - np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001, equal_nan=True) - ) - of_out = of_out.sum() - of_out.backward() - np_grad = 1.0 / np.sqrt(np.square(np_input) - 1) - test_case.assertTrue( - np.allclose(of_input.grad.numpy(), np_grad, 0.0001, 0.0001, equal_nan=True) - ) - - -@flow.unittest.skip_unless_1n1d() -class TestAcosh(flow.unittest.TestCase): - def test_acosh(test_case): - arg_dict = OrderedDict() - arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)] - arg_dict["device"] = ["cpu", "cuda"] - for arg in GenArgList(arg_dict): - _test_acosh_impl(test_case, *arg) - - @autotest(auto_backward=False) - def test_acosh_with_0shape_data(test_case): - device = random_device() - x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = torch.acosh(x) - return y - - -if __name__ == "__main__": - unittest.main() diff --git a/python/oneflow/test/modules/test_atan.py b/python/oneflow/test/modules/test_atan.py deleted file mode 100644 index 396b018b95e..00000000000 --- a/python/oneflow/test/modules/test_atan.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import unittest -from collections import OrderedDict - -import numpy as np -from test_util import GenArgList - -import oneflow as flow -import oneflow.unittest -from automated_test_util import * - - -def _test_atan(test_case, shape, device): - np_input = np.random.randn(*shape) - of_input = flow.Tensor( - np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True - ) - of_out = flow.atan(of_input) - np_out = np.arctan(np_input) - test_case.assertTrue( - np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True) - ) - of_out = of_out.sum() - of_out.backward() - np_out_grad = 1 / (1 + np_input ** 2) - test_case.assertTrue( - np.allclose(of_input.grad.numpy(), np_out_grad, 1e-05, 1e-05, equal_nan=True) - ) - - -def _test_arctan(test_case, shape, device): - np_input = np.random.randn(*shape) - of_input = flow.Tensor( - np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True - ) - of_out = flow.arctan(of_input) - np_out = np.arctan(np_input) - test_case.assertTrue( - np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True) - ) - of_out = of_out.sum() - of_out.backward() - np_out_grad = 1 / (1 + np_input ** 2) - test_case.assertTrue( - np.allclose(of_input.grad.numpy(), np_out_grad, 1e-05, 1e-05, equal_nan=True) - ) - - -@flow.unittest.skip_unless_1n1d() -class TestAtan(flow.unittest.TestCase): - def test_atan(test_case): - arg_dict = OrderedDict() - arg_dict["test_fun"] = [_test_atan, _test_arctan] - arg_dict["shape"] = [(2,), (2, 3), (2, 3, 4), (2, 4, 5, 6)] - arg_dict["device"] = ["cpu", "cuda"] - for arg in GenArgList(arg_dict): - arg[0](test_case, *arg[1:]) - - @autotest(auto_backward=False) - def test_atan_with_0shape_data(test_case): - device = random_device() - x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = torch.atan(x) - return y - - -if __name__ == "__main__": - unittest.main() diff --git a/python/oneflow/test/modules/test_atan2.py b/python/oneflow/test/modules/test_atan2.py deleted file mode 100644 index 861ad8dd24b..00000000000 --- a/python/oneflow/test/modules/test_atan2.py +++ /dev/null @@ -1,145 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import unittest -from collections import OrderedDict - -import numpy as np -from automated_test_util import * -from test_util import GenArgList - -import oneflow as flow -import oneflow.unittest -from automated_test_util import * - - -def _test_atan2_forward(test_case, shape, scalar, device): - np_input_x = 10 * np.random.rand(*shape) - np_input_y = 10 * np.random.randn(*shape) - of_input_x = flow.Tensor(np_input_x, dtype=flow.float32, device=flow.device(device)) - of_input_y = flow.Tensor(np_input_y, dtype=flow.float32, device=flow.device(device)) - of_out = flow.atan2(of_input_x, of_input_y) - np_out = np.arctan2(np_input_x, np_input_y) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) - - -def _test_atan2_backward(test_case, device): - np_input_x = np.random.rand(2, 3) - np_input_y = np.random.rand(2, 3) - np_y_grad = -1 * np_input_x / (np_input_x * np_input_x + np_input_y * np_input_y) - np_x_grad = np_input_y / (np_input_x * np_input_x + np_input_y * np_input_y) - - def test_x_y_grad(): - of_input_x = flow.Tensor( - np_input_x, - dtype=flow.float32, - device=flow.device(device), - requires_grad=True, - ) - of_input_y = flow.Tensor( - np_input_y, - dtype=flow.float32, - device=flow.device(device), - requires_grad=True, - ) - of_out = flow.atan2(of_input_x, of_input_y) - of_out_sum = of_out.sum() - of_out_sum.backward() - test_case.assertTrue( - np.allclose(of_input_x.grad.numpy(), np_x_grad, 0.0001, 0.0001) - ) - test_case.assertTrue( - np.allclose(of_input_y.grad.numpy(), np_y_grad, 0.0001, 0.0001) - ) - - def test_x_grad(): - of_input_x = flow.Tensor( - np_input_x, - dtype=flow.float32, - device=flow.device(device), - requires_grad=True, - ) - of_input_y = flow.Tensor( - np_input_y, dtype=flow.float32, device=flow.device(device) - ) - of_out = flow.atan2(of_input_x, of_input_y) - of_out_sum = of_out.sum() - of_out_sum.backward() - test_case.assertTrue( - np.allclose(of_input_x.grad.numpy(), np_x_grad, 0.0001, 0.0001) - ) - - def test_y_grad(): - of_input_x = flow.Tensor( - np_input_x, dtype=flow.float32, device=flow.device(device) - ) - of_input_y = flow.Tensor( - np_input_y, - dtype=flow.float32, - device=flow.device(device), - requires_grad=True, - ) - of_out = flow.atan2(of_input_x, of_input_y) - of_out_sum = of_out.sum() - of_out_sum.backward() - test_case.assertTrue( - np.allclose(of_input_y.grad.numpy(), np_y_grad, 0.0001, 0.0001) - ) - - test_x_y_grad() - test_x_grad() - test_y_grad() - - -@flow.unittest.skip_unless_1n1d() -class TestAtan2(flow.unittest.TestCase): - def test_atan2_forward(test_case): - arg_dict = OrderedDict() - arg_dict["shape"] = [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)] - arg_dict["scalar"] = [2.1, 0.8] - arg_dict["device"] = ["cpu", "cuda"] - for arg in GenArgList(arg_dict): - _test_atan2_forward(test_case, *arg) - - def test_atan2_backward(test_case): - arg_dict = OrderedDict() - arg_dict["device"] = ["cpu", "cuda"] - for arg in GenArgList(arg_dict): - _test_atan2_backward(test_case, *arg) - - def test_flow_atan2_with_random_data(test_case): - for device in ["cpu", "cuda"]: - test_flow_against_pytorch( - test_case, - "atan2", - extra_annotations={"other": flow.Tensor}, - extra_generators={ - "input": random_tensor(ndim=1, dim1=1), - "other": random_tensor(ndim=1, dim1=1), - }, - device=device, - ) - - @autotest(auto_backward=False) - def test_atan2_with_0shape_data(test_case): - device = random_device() - x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = torch.atan2(x) - return y - - -if __name__ == "__main__": - unittest.main() diff --git a/python/oneflow/test/modules/test_atanh.py b/python/oneflow/test/modules/test_atanh.py deleted file mode 100644 index af0e18bc5ac..00000000000 --- a/python/oneflow/test/modules/test_atanh.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import unittest -from collections import OrderedDict - -import numpy as np -from test_util import GenArgList - -import oneflow as flow -import oneflow.unittest -from automated_test_util import * - - -def _test_atanh_impl(test_case, shape, device): - np_input = np.random.random(shape) - 0.5 - of_input = flow.Tensor( - np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True - ) - of_out = flow.atanh(of_input) - np_out = np.arctanh(np_input) - test_case.assertTrue( - np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001, equal_nan=True) - ) - of_out = of_out.sum() - of_out.backward() - np_out_grad = 1.0 / (1.0 - np.square(np_input)) - test_case.assertTrue( - np.allclose(of_input.grad.numpy(), np_out_grad, 0.0001, 0.0001, equal_nan=True) - ) - - -def _test_arctanh_impl(test_case, shape, device): - np_input = np.random.random(shape) - 0.5 - of_input = flow.Tensor( - np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True - ) - of_out = flow.arctanh(of_input) - np_out = np.arctanh(np_input) - test_case.assertTrue( - np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001, equal_nan=True) - ) - of_out = of_out.sum() - of_out.backward() - np_out_grad = 1.0 / (1.0 - np.square(np_input)) - test_case.assertTrue( - np.allclose(of_input.grad.numpy(), np_out_grad, 0.0001, 0.0001, equal_nan=True) - ) - - -@flow.unittest.skip_unless_1n1d() -class TestAtanh(flow.unittest.TestCase): - def test_atanh(test_case): - arg_dict = OrderedDict() - arg_dict["shape"] = [(2,), (2, 3), (2, 3, 4), (2, 4, 5, 6)] - arg_dict["device"] = ["cpu", "cuda"] - for arg in GenArgList(arg_dict): - _test_atanh_impl(test_case, *arg) - _test_arctanh_impl(test_case, *arg) - - @autotest(auto_backward=False) - def test_atanh_0shape_data(test_case): - device = random_device() - x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = torch.atanh(x) - return y - - -if __name__ == "__main__": - unittest.main() diff --git a/python/oneflow/test/modules/test_tan.py b/python/oneflow/test/modules/test_tan.py deleted file mode 100644 index 03504cc2f89..00000000000 --- a/python/oneflow/test/modules/test_tan.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import unittest -from collections import OrderedDict - -import numpy as np -from test_util import GenArgList - -import oneflow as flow -import oneflow.unittest -from automated_test_util import * - - -def _test_tan_impl(test_case, shape, device): - np_input = np.random.random(shape) - 0.5 - of_input = flow.Tensor( - np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True - ) - of_out = flow.tan(of_input) - np_out = np.tan(np_input) - test_case.assertTrue( - np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001, equal_nan=True) - ) - of_out = of_out.sum() - of_out.backward() - np_out_grad = 1 + np.square(np_out) - test_case.assertTrue( - np.allclose(of_input.grad.numpy(), np_out_grad, 0.0001, 0.0001, equal_nan=True) - ) - - -@flow.unittest.skip_unless_1n1d() -class TestTan(flow.unittest.TestCase): - def test_tan(test_case): - arg_dict = OrderedDict() - arg_dict["shape"] = [(2,), (2, 3), (2, 3, 4), (2, 4, 5, 6)] - arg_dict["device"] = ["cpu", "cuda"] - for arg in GenArgList(arg_dict): - _test_tan_impl(test_case, *arg) - - @autotest(auto_backward=False) - def test_tan_with_0shape_data(test_case): - device = random_device() - x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = torch.tan(x) - return y - - -if __name__ == "__main__": - unittest.main() From 0f0f127456ff6bd9703e9b736a93ea6cf596f120 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Fri, 30 Jul 2021 14:58:13 +0800 Subject: [PATCH 12/32] refine if --- oneflow/core/kernel/kernel_util.cu | 33 +++++++++---------- .../kernel/util/cuda_arithemetic_interface.cu | 9 +++-- .../core/kernel/util/cuda_dnn_interface.cu | 15 ++++----- oneflow/user/kernels/add_n_kernel.cu | 8 ++--- .../kernels/math_binary_elementwise_kernel.cu | 10 +++--- .../kernels/math_unary_elementwise_kernel.cu | 9 +++-- oneflow/user/kernels/slice_util.cu | 9 +++-- 7 files changed, 42 insertions(+), 51 deletions(-) diff --git a/oneflow/core/kernel/kernel_util.cu b/oneflow/core/kernel/kernel_util.cu index 8960a3effbd..8c517e4d599 100644 --- a/oneflow/core/kernel/kernel_util.cu +++ b/oneflow/core/kernel/kernel_util.cu @@ -655,35 +655,32 @@ __global__ void CastOnGpu(const half* in, float* out, int64_t elem_ template void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num) { - if (elem_num > 0) { - if (std::is_same::value) { - Memcpy(ctx, out_dptr, in_dptr, elem_num * sizeof(T)); - } else { - CastOnGpu - <<cuda_stream()>>>( - in_dptr, out_dptr, elem_num); - } + if (elem_num == 0) { return; } + if (std::is_same::value) { + Memcpy(ctx, out_dptr, in_dptr, elem_num * sizeof(T)); + } else { + CastOnGpu + <<cuda_stream()>>>( + in_dptr, out_dptr, elem_num); } } template<> void CopyElemOnGpu(DeviceCtx* ctx, const float* in_dptr, float16* out_dptr, int64_t elem_num) { - if (RoundUp(elem_num, 2) > 0) { - CastOnGpu - <<cuda_stream()>>>(in_dptr, reinterpret_cast(out_dptr), elem_num); - } + if (RoundUp(elem_num, 2) == 0) { return; } + CastOnGpu + <<cuda_stream()>>>(in_dptr, reinterpret_cast(out_dptr), elem_num); } template<> void CopyElemOnGpu(DeviceCtx* ctx, const float16* in_dptr, float* out_dptr, int64_t elem_num) { - if (RoundUp(elem_num, 2) > 0) { - CastOnGpu - <<cuda_stream()>>>(reinterpret_cast(in_dptr), out_dptr, elem_num); - } + if (RoundUp(elem_num, 2) == 0) { return; } + CastOnGpu + <<cuda_stream()>>>(reinterpret_cast(in_dptr), out_dptr, elem_num); } #define INSTANTIATE_COPY_ELEM_ON_GPU(T, U) \ diff --git a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu index be19b41bbd3..2707d624a1c 100644 --- a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu +++ b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu @@ -69,11 +69,10 @@ void LaunchTransposeGpu(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeVie cur_stride *= x_shape.At(i); } for (int32_t i = 0; i < NDIMS; ++i) { x_strides.val[i] = buff[permutation[i]]; } - if (elem_cnt > 0) { - TransposeGpu - <<cuda_stream()>>>( - y_shape_struct, x_strides, elem_cnt, x, y); - } + if (elem_cnt == 0) { return; } + TransposeGpu + <<cuda_stream()>>>( + y_shape_struct, x_strides, elem_cnt, x, y); } template diff --git a/oneflow/core/kernel/util/cuda_dnn_interface.cu b/oneflow/core/kernel/util/cuda_dnn_interface.cu index a65e89edb07..85c755a6147 100644 --- a/oneflow/core/kernel/util/cuda_dnn_interface.cu +++ b/oneflow/core/kernel/util/cuda_dnn_interface.cu @@ -132,14 +132,13 @@ template struct ReluHelper final { static void ReluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y) { CHECK_LE(n, GetMaxVal() / 2); - if (n > 0) { - if (x == y) { - InplaceReluForwardGpu - <<cuda_stream()>>>(n, y); - } else { - ReluForwardGpu - <<cuda_stream()>>>(n, x, y); - } + if (n == 0) { return; } + if (x == y) { + InplaceReluForwardGpu + <<cuda_stream()>>>(n, y); + } else { + ReluForwardGpu + <<cuda_stream()>>>(n, x, y); } } diff --git a/oneflow/user/kernels/add_n_kernel.cu b/oneflow/user/kernels/add_n_kernel.cu index 8c3fd9fbf51..cb9d1561176 100644 --- a/oneflow/user/kernels/add_n_kernel.cu +++ b/oneflow/user/kernels/add_n_kernel.cu @@ -56,10 +56,10 @@ struct GpuAddCaller { for (int32_t i = 0; i < N; ++i) { para.in[i] = ctx->Tensor4ArgNameAndIndex("in", i)->dptr(); } - if (n > 0) { - gpu_add<<device_ctx()->cuda_stream()>>>(n, para); - } + if (n == 0) { return; } + gpu_add + <<device_ctx()->cuda_stream()>>>( + n, para); } }; diff --git a/oneflow/user/kernels/math_binary_elementwise_kernel.cu b/oneflow/user/kernels/math_binary_elementwise_kernel.cu index f997cf5255c..a8d5cfe250c 100644 --- a/oneflow/user/kernels/math_binary_elementwise_kernel.cu +++ b/oneflow/user/kernels/math_binary_elementwise_kernel.cu @@ -52,12 +52,10 @@ class MathBinaryElementwiseGpuKernel final : public user_op::OpKernel { user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex("z", 0); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); - if (n > 0) { - MathBinaryElementwiseForwardGpu - <<device_ctx()->cuda_stream()>>>(n, tensor_x->dptr(), tensor_y->dptr(), - tensor_z->mut_dptr()); - } + if (n == 0) { return; } + MathBinaryElementwiseForwardGpu + <<device_ctx()->cuda_stream()>>>( + n, tensor_x->dptr(), tensor_y->dptr(), tensor_z->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; diff --git a/oneflow/user/kernels/math_unary_elementwise_kernel.cu b/oneflow/user/kernels/math_unary_elementwise_kernel.cu index 7ae25112772..5beb2360c30 100644 --- a/oneflow/user/kernels/math_unary_elementwise_kernel.cu +++ b/oneflow/user/kernels/math_unary_elementwise_kernel.cu @@ -46,11 +46,10 @@ class MathUnaryElementwiseGpuKernel final : public user_op::OpKernel { T* y = tensor_y->mut_dptr(); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); - if (n > 0) { - MathUnaryElementwiseForwardGpu - <<device_ctx()->cuda_stream()>>>(n, x, y); - } + if (n == 0) { return; } + MathUnaryElementwiseForwardGpu + <<device_ctx()->cuda_stream()>>>( + n, x, y); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; diff --git a/oneflow/user/kernels/slice_util.cu b/oneflow/user/kernels/slice_util.cu index a8f16549dbd..7edb16af416 100644 --- a/oneflow/user/kernels/slice_util.cu +++ b/oneflow/user/kernels/slice_util.cu @@ -48,11 +48,10 @@ void LaunchSliceForward(DeviceCtx* ctx, const SliceParams& params, const T* enti int64_t elem_cnt = params.elem_cnt(); SliceIndexHelper entire_idx_cvtr(params.dims); SliceIndexHelper sliced_idx_cvtr(params.size); - if (elem_cnt > 0) { - SliceForwardGpu - <<cuda_stream()>>>( - elem_cnt, params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced); - } + if (elem_cnt == 0) { return; } + SliceForwardGpu + <<cuda_stream()>>>( + elem_cnt, params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced); } template From 0eb7cff92d298e2e2fc294790a6a485d77126682 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Fri, 30 Jul 2021 15:10:28 +0800 Subject: [PATCH 13/32] refine if --- .../core/ndarray/ndarray_apply_binary_core.cu | 14 ++++++------- .../ndarray_apply_broadcast_binary_core.cu | 20 ++++++++----------- .../ndarray_apply_broadcast_unary_core.cu | 3 ++- .../core/ndarray/ndarray_apply_unary_core.cu | 5 ++--- oneflow/core/ndarray/ndarray_assign_core.cu | 3 ++- 5 files changed, 20 insertions(+), 25 deletions(-) diff --git a/oneflow/core/ndarray/ndarray_apply_binary_core.cu b/oneflow/core/ndarray/ndarray_apply_binary_core.cu index 1a0da7e9dfa..5ea4c577930 100644 --- a/oneflow/core/ndarray/ndarray_apply_binary_core.cu +++ b/oneflow/core/ndarray/ndarray_apply_binary_core.cu @@ -40,18 +40,16 @@ struct NdarrayApplyBinaryCoreWrapper final { const XpuVarNdarray::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { size_t n = y.host_shape().HostElemNum(); - if (n > 0) { - RUN_CUDA_KERNEL((NdarrayApplyBinaryApplyGpu), ctx, n, n, y.host_ptr(), - a.host_ptr(), b.host_ptr()); - } + if (n == 0) { return; } + RUN_CUDA_KERNEL((NdarrayApplyBinaryApplyGpu), ctx, n, n, y.host_ptr(), + a.host_ptr(), b.host_ptr()); } static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); - if (n > 0) { - RUN_CUDA_KERNEL((NdarrayApplyBinaryInplaceApplyGpu), ctx, n, n, y.host_ptr(), - x.host_ptr()); - } + if (n == 0) { return; } + RUN_CUDA_KERNEL((NdarrayApplyBinaryInplaceApplyGpu), ctx, n, n, y.host_ptr(), + x.host_ptr()); } }; diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu index 6e7c683dda5..c1ed93595cf 100644 --- a/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu @@ -89,11 +89,10 @@ struct NdarrayApplyBroadcastBinaryCoreWrapper::return_type>& y, const XpuVarNdarray& a, const XpuVarNdarray& b) { size_t n = y.host_shape().HostElemNum(); - if (n > 0) { - if (IsKernelSafeInt32(n) && PartialBroadcast(ctx, y, a, b)) { return; } - if (!IsKernelSafeInt32(n) && PartialBroadcast(ctx, y, a, b)) { return; } - RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc), ctx, n, y, a, b); - } + if (n == 0) { return; } + if (IsKernelSafeInt32(n) && PartialBroadcast(ctx, y, a, b)) { return; } + if (!IsKernelSafeInt32(n) && PartialBroadcast(ctx, y, a, b)) { return; } + RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc), ctx, n, y, a, b); } template @@ -153,13 +152,10 @@ struct NdarrayApplyBroadcastInplaceBinaryCoreWrapper a(y.host_shape(), y.host_ptr()); using NBB = NdarrayApplyBroadcastBinaryCoreWrapper; - if (n > 0) { - if (IsKernelSafeInt32(n) && NBB::template PartialBroadcast(ctx, y, a, x)) { return; } - if (!IsKernelSafeInt32(n) && NBB::template PartialBroadcast(ctx, y, a, x)) { - return; - } - RUN_CUDA_KERNEL((GpuInplaceBroadcastBinaryFunc), ctx, n, y, x); - } + if (n == 0) { return; } + if (IsKernelSafeInt32(n) && NBB::template PartialBroadcast(ctx, y, a, x)) { return; } + if (!IsKernelSafeInt32(n) && NBB::template PartialBroadcast(ctx, y, a, x)) { return; } + RUN_CUDA_KERNEL((GpuInplaceBroadcastBinaryFunc), ctx, n, y, x); } }; diff --git a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu index d45340b6cac..be6c625ad59 100644 --- a/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu +++ b/oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu @@ -30,7 +30,8 @@ template class unary_func> struct NdarrayApplyBroadcastUnaryCoreWrapper final { static void Apply(DeviceCtx* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); - if (n > 0) { RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc), ctx, n, y, x); } + if (n == 0) { return; } + RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc), ctx, n, y, x); } }; diff --git a/oneflow/core/ndarray/ndarray_apply_unary_core.cu b/oneflow/core/ndarray/ndarray_apply_unary_core.cu index aa832fa81bd..1b9192ba9e6 100644 --- a/oneflow/core/ndarray/ndarray_apply_unary_core.cu +++ b/oneflow/core/ndarray/ndarray_apply_unary_core.cu @@ -31,9 +31,8 @@ template class unary_func> struct NdarrayApplyUnaryCoreWrapper final { static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray& y) { size_t n = y.host_shape().HostElemNum(); - if (n > 0) { - RUN_CUDA_KERNEL((NdarrayApplyUnaryInplaceApplyGpu), ctx, n, y.host_ptr(), n); - } + if (n == 0) { return; } + RUN_CUDA_KERNEL((NdarrayApplyUnaryInplaceApplyGpu), ctx, n, y.host_ptr(), n); } }; diff --git a/oneflow/core/ndarray/ndarray_assign_core.cu b/oneflow/core/ndarray/ndarray_assign_core.cu index 7d11e919625..17d79f64c0f 100644 --- a/oneflow/core/ndarray/ndarray_assign_core.cu +++ b/oneflow/core/ndarray/ndarray_assign_core.cu @@ -33,7 +33,8 @@ struct NdarrayAssignCoreWrapper final { static void Assign(DeviceCtx* ctx, const XpuVarNdarray& y, const XpuReducedNdarray& reduced) { size_t n = y.host_shape().HostElemNum(); - if (n > 0) { RUN_CUDA_KERNEL((NdarrayAssignGpu), ctx, n, y, reduced); } + if (n == 0) { return; } + RUN_CUDA_KERNEL((NdarrayAssignGpu), ctx, n, y, reduced); } }; From 65c38d4a25794a7c3d38f3a06ac8dfb3536a577a Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Fri, 30 Jul 2021 15:28:39 +0800 Subject: [PATCH 14/32] feat(ConstantOp): constant ops support 0shape tensor --- oneflow/user/kernels/constant_kernel.cpp | 3 ++- python/oneflow/test/modules/test_constant.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/constant_kernel.cpp b/oneflow/user/kernels/constant_kernel.cpp index ff6f18a3c50..969f07ce920 100644 --- a/oneflow/user/kernels/constant_kernel.cpp +++ b/oneflow/user/kernels/constant_kernel.cpp @@ -30,7 +30,8 @@ class ConstantKernel final : public OpKernel { Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); bool is_floating_value = ctx->Attr("is_floating_value"); const int64_t elem_cnt = out_tensor->shape().elem_cnt(); - CHECK_GT(elem_cnt, 0); + CHECK_GE(elem_cnt, 0); + if (elem_cnt == 0) { return; } NewKernelUtil::Fill(ctx->device_ctx(), elem_cnt, is_floating_value ? static_cast(ctx->Attr("floating_value")) diff --git a/python/oneflow/test/modules/test_constant.py b/python/oneflow/test/modules/test_constant.py index deaeb060668..800b36feda6 100644 --- a/python/oneflow/test/modules/test_constant.py +++ b/python/oneflow/test/modules/test_constant.py @@ -18,11 +18,11 @@ from collections import OrderedDict import numpy as np -from test_util import GenArgList - import oneflow as flow + import oneflow.unittest -from oneflow.framework.tensor import register_tensor_op +from test_util import GenArgList +from automated_test_util import * def _test_ones(test_case, device, shape): @@ -117,7 +117,7 @@ def test_cast(test_case): _test_new_ones, ] arg_dict["device"] = ["cpu", "cuda"] - arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] + arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5), (2, 0, 4)] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) From 068d9fb7fcc98ddd1d7dc021e9d103fc53eac2f9 Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Fri, 30 Jul 2021 15:55:52 +0800 Subject: [PATCH 15/32] feat(ReshapeOp): reshape kernel support 0shape tensor --- oneflow/user/ops/reshape_op.cpp | 2 +- python/oneflow/test/modules/test_reshape.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/oneflow/user/ops/reshape_op.cpp b/oneflow/user/ops/reshape_op.cpp index 573028fb2e2..de4331180e9 100644 --- a/oneflow/user/ops/reshape_op.cpp +++ b/oneflow/user/ops/reshape_op.cpp @@ -48,7 +48,7 @@ Maybe LogicalTensorDescInferFn(user_op::InferContext* ctx) { *out_tensor_desc = in_tensor_desc; CHECK_GE_OR_RETURN(shape.NumAxes(), 1); DimVector dim_vec = {shape.dim_vec().begin(), shape.dim_vec().end()}; - FOR_RANGE(int32_t, i, 0, dim_vec.size()) { CHECK_GT_OR_RETURN(dim_vec.at(i), 0); } + FOR_RANGE(int32_t, i, 0, dim_vec.size()) { CHECK_GE_OR_RETURN(dim_vec.at(i), 0); } *out_shape = Shape(dim_vec); CHECK_EQ_OR_RETURN(out_shape->elem_cnt(), in_shape.elem_cnt()); return Maybe::Ok(); diff --git a/python/oneflow/test/modules/test_reshape.py b/python/oneflow/test/modules/test_reshape.py index d268c1e920f..f8e844c7f6f 100644 --- a/python/oneflow/test/modules/test_reshape.py +++ b/python/oneflow/test/modules/test_reshape.py @@ -84,12 +84,13 @@ def test_reshape_flow_with_random_data(test_case): y = torch.reshape(x, shape=(-1,)) return y - @unittest.skip("reshape has bug") @autotest(auto_backward=False) def test_reshape_with_0shape_data(test_case): device = random_device() - x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = torch.reshape(x, shape=[0]) + x = random_pytorch_tensor(4, 2, 0, 3).to(device) + y = torch.reshape( + x, shape=(random(0, 5).to(int).value(), 0, random(0, 5).to(int).value()) + ) return y From 70bb447ec34d5eaa8f54f5201beb5268e61a7ccc Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Fri, 30 Jul 2021 16:35:03 +0800 Subject: [PATCH 16/32] math binary and unary backward skip when elem equal to zeros --- oneflow/user/kernels/math_binary_elementwise_kernel.cu | 5 +++++ oneflow/user/kernels/math_unary_elementwise_kernel.cu | 3 +++ 2 files changed, 8 insertions(+) diff --git a/oneflow/user/kernels/math_binary_elementwise_kernel.cu b/oneflow/user/kernels/math_binary_elementwise_kernel.cu index a8d5cfe250c..29c62fdbad7 100644 --- a/oneflow/user/kernels/math_binary_elementwise_kernel.cu +++ b/oneflow/user/kernels/math_binary_elementwise_kernel.cu @@ -74,6 +74,7 @@ class MathBinaryElementwiseXGradGpuKernel final : public user_op::OpKernel { user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex("dx", 0); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); + if (n == 0) { return; } MathBinaryElementwiseBackwardXGradGpu <<device_ctx()->cuda_stream()>>>( n, tensor_x->dptr(), tensor_y->dptr(), tensor_dz->dptr(), @@ -96,6 +97,7 @@ class MathBinaryElementwiseYGradGpuKernel final : public user_op::OpKernel { user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex("dy", 0); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); + if (n == 0) { return; } MathBinaryElementwiseBackwardYGradGpu <<device_ctx()->cuda_stream()>>>( n, tensor_x->dptr(), tensor_y->dptr(), tensor_dz->dptr(), @@ -144,6 +146,7 @@ class MathBinaryElementwiseGpuHalfKernel final : public user_op::OpKernel { half* z = reinterpret_cast(tensor_z->mut_dptr()); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); + if (n == 0) { return; } MathBinaryElementwiseForwardGpu <<device_ctx()->cuda_stream()>>>( n, x, y, z); @@ -170,6 +173,7 @@ class MathBinaryElementwiseXGradGpuHalfKernel final : public user_op::OpKernel { half* dx = reinterpret_cast(tensor_dx->mut_dptr()); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); + if (n == 0) { return; } MathBinaryElementwiseBackwardXGradGpu <<device_ctx()->cuda_stream()>>>( n, x, y, dz, dx); @@ -196,6 +200,7 @@ class MathBinaryElementwiseYGradGpuHalfKernel final : public user_op::OpKernel { half* dy = reinterpret_cast(tensor_dy->mut_dptr()); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); + if (n == 0) { return; } MathBinaryElementwiseBackwardYGradGpu <<device_ctx()->cuda_stream()>>>( n, x, y, dz, dy); diff --git a/oneflow/user/kernels/math_unary_elementwise_kernel.cu b/oneflow/user/kernels/math_unary_elementwise_kernel.cu index 5beb2360c30..7daeb0bb4e2 100644 --- a/oneflow/user/kernels/math_unary_elementwise_kernel.cu +++ b/oneflow/user/kernels/math_unary_elementwise_kernel.cu @@ -71,6 +71,7 @@ class MathUnaryElementwiseGradGpuKernel final : public user_op::OpKernel { T* dx = tensor_dx->mut_dptr(); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); + if (n == 0) { return; } MathUnaryElementwiseBackwardGpu <<device_ctx()->cuda_stream()>>>( n, x, dy, dx); @@ -111,6 +112,7 @@ class MathUnaryElementwiseGpuHalfKernel final : public user_op::OpKernel { half* y = reinterpret_cast(tensor_y->mut_dptr()); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); + if (n == 0) { return; } MathUnaryElementwiseForwardGpu <<device_ctx()->cuda_stream()>>>( n, x, y); @@ -135,6 +137,7 @@ class MathUnaryElementwiseGradGpuHalfKernel final : public user_op::OpKernel { half* dx = reinterpret_cast(tensor_dx->mut_dptr()); int64_t n = tensor_x->shape().elem_cnt(); CHECK_LE(n, GetMaxVal() / 2); + if (n == 0) { return; } MathUnaryElementwiseBackwardGpu <<device_ctx()->cuda_stream()>>>( n, x, dy, dx); From 102db91686073c1fd7289ea129f2ea7664868ec4 Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Fri, 30 Jul 2021 18:26:54 +0800 Subject: [PATCH 17/32] fix(ReduceOp): fix reduce not memset bug --- oneflow/user/kernels/reduce_kernel.cpp | 11 ++++++++++- python/oneflow/test/modules/test_sum.py | 17 ++++++++++++----- .../torch_flow_dual_object.py | 1 - 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp index 22d7d7426ba..b6c235bb5e2 100644 --- a/oneflow/user/kernels/reduce_kernel.cpp +++ b/oneflow/user/kernels/reduce_kernel.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" #include "oneflow/core/ndarray/xpu_var_ndarray.h" +#include "oneflow/core/kernel/kernel_util.h" namespace oneflow { @@ -34,7 +35,15 @@ class ReduceKernel final : public user_op::OpKernel { user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const auto& axis = ctx->Attr>("axis"); - if (input_tensor->shape().elem_cnt() == 0) { return; } + if (input_tensor->shape().elem_cnt() == 0) { + if (output_tensor->shape().elem_cnt() != 0) { + AutoMemset( + ctx->device_ctx(), output_tensor->mut_dptr(), 0, + output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()), + output_tensor->mem_case()); + } + return; + } const Shape& reduced_shape = CreateReducedShape(input_tensor->shape(), {axis.begin(), axis.end()}); NdarrayReduce::Reduce( diff --git a/python/oneflow/test/modules/test_sum.py b/python/oneflow/test/modules/test_sum.py index e9927eeda88..90a457dd639 100644 --- a/python/oneflow/test/modules/test_sum.py +++ b/python/oneflow/test/modules/test_sum.py @@ -69,12 +69,19 @@ def test_sum(test_case): for arg in GenArgList(arg_dict): _test_sum_impl(test_case, *arg) + @autotest() def test_sum_against_pytorch(test_case): - arg_dict = OrderedDict() - arg_dict["test_type"] = [test_flow_against_pytorch, test_tensor_against_pytorch] - arg_dict["device"] = ["cpu", "cuda"] - for arg in GenArgList(arg_dict): - arg[0](test_case, "sum", device=arg[1]) + device = random_device() + x = random_pytorch_tensor(4, random(0, 5), 2).to(device) + y = torch.sum(x) + return y + + @autotest(auto_backward=False) + def test_sum_with_0shape_tensor(test_case): + device = random_device() + x = random_pytorch_tensor(4, 0, 2).to("cuda:0") + y = torch.sum(x) + return y if __name__ == "__main__": diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index b39b535e7c6..9c55aaba78f 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -277,7 +277,6 @@ def new_f(test_case): except PyTorchDoesNotSupportError as e: if verbose: print(e) - n -= 1 loop += 1 continue if res is not None: From a1d9b42a61d679d89cf58694e6ee653ce1f979ff Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Fri, 30 Jul 2021 22:32:35 +0800 Subject: [PATCH 18/32] support getitem output empty tensor --- oneflow/api/python/functional/indexing.cpp | 2 -- oneflow/core/functional/tensor_index.cpp | 2 -- oneflow/user/ops/slice_op.cpp | 20 ++++++++++---------- python/oneflow/ops/array_ops.py | 2 +- python/oneflow/test/modules/test_sum.py | 2 +- 5 files changed, 12 insertions(+), 16 deletions(-) diff --git a/oneflow/api/python/functional/indexing.cpp b/oneflow/api/python/functional/indexing.cpp index 0457fcaca8e..b8e53d84437 100644 --- a/oneflow/api/python/functional/indexing.cpp +++ b/oneflow/api/python/functional/indexing.cpp @@ -50,8 +50,6 @@ Maybe PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, CHECK_OR_RETURN(_PyEval_SliceIndex(obj->stop, stop)) << "Invalid slice " << PyStringAsString(PyObject_Repr(object)); } - CHECK_LT_OR_RETURN(*start, *stop) - << "Slice stop must be greater than start since 0 size shape is not allowed currently."; return Maybe::Ok(); } diff --git a/oneflow/core/functional/tensor_index.cpp b/oneflow/core/functional/tensor_index.cpp index 251635eee27..dcc3f33dff1 100644 --- a/oneflow/core/functional/tensor_index.cpp +++ b/oneflow/core/functional/tensor_index.cpp @@ -65,10 +65,8 @@ Maybe PrepareSliceIndices(const TensorIndex& index, const Shape& shape, } CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims; if (index_item.IsSlice()) { - CHECK_GT_OR_RETURN(shape.At(dim), 0) << "Slice cannot be applied to a 0-dim tensor."; const auto& slice = index_item.slice(); int64_t step = std::min(slice.step(), shape.At(dim)); - CHECK_GT_OR_RETURN(step, 0) << "Step must be greater than zero."; int64_t end = std::min(slice.end(), shape.At(dim)); int64_t start = std::min(slice.start(), shape.At(dim)); if (start < 0) { start += shape.At(dim); } diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index 48ca98b5b70..c3ebf065114 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -40,22 +40,22 @@ Maybe InferSliceOpTensorDesc(user_op::InferContext* ctx) { DimVector dim_vec(ndim); FOR_RANGE(size_t, i, 0, dim_vec.size()) { const int64_t dim_size = x_shape.At(i); - if (dim_size == 0) { + const int64_t step = step_vec.at(i); + int64_t start = start_vec.at(i); + int64_t stop = stop_vec.at(i); + if (dim_size == 0 || start == stop) { dim_vec[i] = 0; continue; } - const int64_t step = step_vec.at(i); CHECK_NE_OR_RETURN(step, 0) << "slice step cannot be 0"; - int64_t start = RegulateSliceStart(start_vec.at(i), dim_size); - int64_t stop = RegulateSliceStop(stop_vec.at(i), dim_size); + start = RegulateSliceStart(start, dim_size); + stop = RegulateSliceStop(stop, dim_size); if (step > 0) { - CHECK_LE_OR_RETURN(start, stop) - << "slice start must be less than or equal to stop when step > 0" - ", otherwise empty result will be outputted."; + CHECK_LT_OR_RETURN(start, stop) << "slice start must be less than when step > 0" + ", otherwise empty result will be outputted."; } else { - CHECK_GE_OR_RETURN(start, stop) - << "slice start must be more than or equal to stop when step < 0" - ", otherwise empty result will be outputted."; + CHECK_GT_OR_RETURN(start, stop) << "slice start must be more than when step < 0" + ", otherwise empty result will be outputted."; } const int64_t diff = (step > 0) ? (stop - start - 1) : (stop - start + 1); dim_vec[i] = diff / step + 1; diff --git a/python/oneflow/ops/array_ops.py b/python/oneflow/ops/array_ops.py index 10dbb0e0adc..1c7eb578965 100644 --- a/python/oneflow/ops/array_ops.py +++ b/python/oneflow/ops/array_ops.py @@ -37,7 +37,7 @@ def check_slice_tup_list(slice_tup_list, shape): if not all((isinstance(idx, int) or idx is None for idx in slice_tup)): raise ValueError("element of slice tuple must int or None") (start, stop, step) = slice_tup - if step is None or start == stop: + if step is None: step = 1 if step == 0: raise ValueError("slice step can't be 0") diff --git a/python/oneflow/test/modules/test_sum.py b/python/oneflow/test/modules/test_sum.py index 90a457dd639..63030c54c7d 100644 --- a/python/oneflow/test/modules/test_sum.py +++ b/python/oneflow/test/modules/test_sum.py @@ -79,7 +79,7 @@ def test_sum_against_pytorch(test_case): @autotest(auto_backward=False) def test_sum_with_0shape_tensor(test_case): device = random_device() - x = random_pytorch_tensor(4, 0, 2).to("cuda:0") + x = random_pytorch_tensor(4, 0, 2).to(device) y = torch.sum(x) return y From f5f389ef560397813e16003abb201cff6b2fd6cc Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Fri, 30 Jul 2021 22:35:11 +0800 Subject: [PATCH 19/32] fix comment --- oneflow/user/ops/slice_op.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index c3ebf065114..a39d973ca59 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -51,10 +51,10 @@ Maybe InferSliceOpTensorDesc(user_op::InferContext* ctx) { start = RegulateSliceStart(start, dim_size); stop = RegulateSliceStop(stop, dim_size); if (step > 0) { - CHECK_LT_OR_RETURN(start, stop) << "slice start must be less than when step > 0" + CHECK_LT_OR_RETURN(start, stop) << "slice start must be less than stop when step > 0" ", otherwise empty result will be outputted."; } else { - CHECK_GT_OR_RETURN(start, stop) << "slice start must be more than when step < 0" + CHECK_GT_OR_RETURN(start, stop) << "slice start must be more than stop when step < 0" ", otherwise empty result will be outputted."; } const int64_t diff = (step > 0) ? (stop - start - 1) : (stop - start + 1); From 8e50de3c77272b9eeffd05a78ab33522bd8b0f56 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Sat, 31 Jul 2021 08:57:07 +0800 Subject: [PATCH 20/32] getitem support input is empty --- oneflow/core/functional/tensor_index.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/oneflow/core/functional/tensor_index.cpp b/oneflow/core/functional/tensor_index.cpp index dcc3f33dff1..0fa495fe745 100644 --- a/oneflow/core/functional/tensor_index.cpp +++ b/oneflow/core/functional/tensor_index.cpp @@ -73,8 +73,9 @@ Maybe PrepareSliceIndices(const TensorIndex& index, const Shape& shape, if (start < 0) { start = 0; } if (end < 0) { end += shape.At(dim); } if (end < start) { end = start; } + if (start == end) { step = 1; } slice_indices->emplace_back(start, end, step); - int64_t length = (end - start + step - 1) / step; + int64_t length = start == end ? 0 : (end - start + step - 1) / step; target_dims->emplace_back(length); dim++; } else if (index_item.IsInteger()) { From 9ccf9293823644e2b30e07e088a87d49ee18de71 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Sat, 31 Jul 2021 10:20:23 +0800 Subject: [PATCH 21/32] reduce_like kernel support empty --- oneflow/user/kernels/reduce_like_kernels.cpp | 8 ++++++++ python/oneflow/test/modules/test_abs.py | 2 +- python/oneflow/test/modules/test_add.py | 9 +++------ 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp index d456857fcbd..5eff8bd47f1 100644 --- a/oneflow/user/kernels/reduce_like_kernels.cpp +++ b/oneflow/user/kernels/reduce_like_kernels.cpp @@ -39,6 +39,14 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel { user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0); const auto& axis = ctx->Attr>("axis"); + if (tensor_x->shape().elem_cnt() == 0) { + if (tensor_y->shape().elem_cnt() != 0) { + AutoMemset(ctx->device_ctx(), tensor_y->mut_dptr(), 0, + tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()), + tensor_y->mem_case()); + } + return; + } if (axis.empty()) { CHECK_EQ(tensor_x->shape(), tensor_y->shape()); Memcpy(ctx->device_ctx(), tensor_y->mut_dptr(), tensor_x->dptr(), diff --git a/python/oneflow/test/modules/test_abs.py b/python/oneflow/test/modules/test_abs.py index 105beed194d..1682f51bbe2 100644 --- a/python/oneflow/test/modules/test_abs.py +++ b/python/oneflow/test/modules/test_abs.py @@ -82,7 +82,7 @@ def test_flow_tensor_abs_with_random_data(test_case): for device in ["cpu", "cuda"]: test_tensor_against_pytorch(test_case, "abs", device=device) - @autotest(auto_backward=False) + @autotest() def test_abs_with_0shape_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) diff --git a/python/oneflow/test/modules/test_add.py b/python/oneflow/test/modules/test_add.py index 6dc2f29dcc6..1a9c2c95f4a 100644 --- a/python/oneflow/test/modules/test_add.py +++ b/python/oneflow/test/modules/test_add.py @@ -152,16 +152,13 @@ def test_add(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(auto_backward=False) + @autotest() def test_0shape_add(test_case): device = random_device() x = random_pytorch_tensor(2, 0, 3).to(device) y = random_pytorch_tensor(2, 1, 3).to(device) - out1 = x + y - out2 = x + 2 - out3 = 2 + x - out4 = torch.add(x, y) - return out1, out2, out3 + out = x + y + return out if __name__ == "__main__": From a787aa1f420ab1b5b0b4032ee14d7be52ac9eaa0 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Sat, 31 Jul 2021 13:42:07 +0800 Subject: [PATCH 22/32] fix op test bug --- python/oneflow/test/modules/test_activation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py index cad3a489637..2c9fe612bbe 100644 --- a/python/oneflow/test/modules/test_activation.py +++ b/python/oneflow/test/modules/test_activation.py @@ -200,12 +200,12 @@ def test_flow_tanh_with_random_data(test_case): y = torch.tanh(x) return y - @unittest.skip("reshape has bug or auto test has bug") + @autotest(auto_backward=False) def test_flow_tanh_with_0shape_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 3, 0, 3).to(device) - y = flow.tanh(x) + y = torch.tanh(x) return y From 42326970d73e0c68c44a881eec79b636fa25164d Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Sat, 31 Jul 2021 16:39:37 +0800 Subject: [PATCH 23/32] feat(ReduceOp): refine reduce ops initialize value --- oneflow/user/kernels/reduce_kernel.cpp | 8 ++++---- oneflow/user/kernels/reduce_like_kernels.cpp | 8 +++++--- python/oneflow/nn/modules/reduce_ops.py | 6 ++++++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp index b6c235bb5e2..97413e8fcb1 100644 --- a/oneflow/user/kernels/reduce_kernel.cpp +++ b/oneflow/user/kernels/reduce_kernel.cpp @@ -37,10 +37,10 @@ class ReduceKernel final : public user_op::OpKernel { if (input_tensor->shape().elem_cnt() == 0) { if (output_tensor->shape().elem_cnt() != 0) { - AutoMemset( - ctx->device_ctx(), output_tensor->mut_dptr(), 0, - output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()), - output_tensor->mem_case()); + for(int32_t dim: axis) { + CHECK_EQ(output_tensor->shape().At(dim), 1); + } + KernelUtil::Set(ctx->device_ctx(), UnitOfBinaryFunc::Val(), output_tensor->mut_dptr()); } return; } diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp index 5eff8bd47f1..50b1649f759 100644 --- a/oneflow/user/kernels/reduce_like_kernels.cpp +++ b/oneflow/user/kernels/reduce_like_kernels.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/ndarray/binary_func.h" #include "oneflow/core/ndarray/ndarray_util.h" namespace oneflow { @@ -41,9 +42,10 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel { const auto& axis = ctx->Attr>("axis"); if (tensor_x->shape().elem_cnt() == 0) { if (tensor_y->shape().elem_cnt() != 0) { - AutoMemset(ctx->device_ctx(), tensor_y->mut_dptr(), 0, - tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()), - tensor_y->mem_case()); + for(int32_t dim: axis) { + CHECK_EQ(tensor_y->shape().At(dim), 1); + } + KernelUtil::Set(ctx->device_ctx(), 0, tensor_y->mut_dptr()); } return; } diff --git a/python/oneflow/nn/modules/reduce_ops.py b/python/oneflow/nn/modules/reduce_ops.py index 867ef0323d0..f43fbb02e43 100644 --- a/python/oneflow/nn/modules/reduce_ops.py +++ b/python/oneflow/nn/modules/reduce_ops.py @@ -114,6 +114,9 @@ def __init__( self._op = _build_reduce_op("reduce_min", keepdims) def forward(self, input): + # TODO: moves this check in functor + if input.shape.numel() == 0: + raise RuntimeError("operation does not have an identity.") axis_checked = _check_axis(self.axis, input.shape) if len(axis_checked) == 0: return input @@ -151,6 +154,9 @@ def __init__( self._op = _build_reduce_op("reduce_max", keepdims) def forward(self, input): + # TODO: moves this check in functor + if input.shape.numel() == 0: + raise RuntimeError("operation does not have an identity.") axis_checked = _check_axis(self.axis, input.shape) if len(axis_checked) == 0: return input From ebde36505273f56e62032332a4f1543ca4cd3691 Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Sat, 31 Jul 2021 16:41:14 +0800 Subject: [PATCH 24/32] format code --- oneflow/user/kernels/reduce_kernel.cpp | 7 +++---- oneflow/user/kernels/reduce_like_kernels.cpp | 6 ++---- python/oneflow/test/modules/test_activation.py | 1 - 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp index 97413e8fcb1..7989d6e20bc 100644 --- a/oneflow/user/kernels/reduce_kernel.cpp +++ b/oneflow/user/kernels/reduce_kernel.cpp @@ -37,10 +37,9 @@ class ReduceKernel final : public user_op::OpKernel { if (input_tensor->shape().elem_cnt() == 0) { if (output_tensor->shape().elem_cnt() != 0) { - for(int32_t dim: axis) { - CHECK_EQ(output_tensor->shape().At(dim), 1); - } - KernelUtil::Set(ctx->device_ctx(), UnitOfBinaryFunc::Val(), output_tensor->mut_dptr()); + for (int32_t dim : axis) { CHECK_EQ(output_tensor->shape().At(dim), 1); } + KernelUtil::Set(ctx->device_ctx(), UnitOfBinaryFunc::Val(), + output_tensor->mut_dptr()); } return; } diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp index 50b1649f759..42492319754 100644 --- a/oneflow/user/kernels/reduce_like_kernels.cpp +++ b/oneflow/user/kernels/reduce_like_kernels.cpp @@ -42,10 +42,8 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel { const auto& axis = ctx->Attr>("axis"); if (tensor_x->shape().elem_cnt() == 0) { if (tensor_y->shape().elem_cnt() != 0) { - for(int32_t dim: axis) { - CHECK_EQ(tensor_y->shape().At(dim), 1); - } - KernelUtil::Set(ctx->device_ctx(), 0, tensor_y->mut_dptr()); + for (int32_t dim : axis) { CHECK_EQ(tensor_y->shape().At(dim), 1); } + KernelUtil::Set(ctx->device_ctx(), 0, tensor_y->mut_dptr()); } return; } diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py index 2c9fe612bbe..6bff8656e07 100644 --- a/python/oneflow/test/modules/test_activation.py +++ b/python/oneflow/test/modules/test_activation.py @@ -200,7 +200,6 @@ def test_flow_tanh_with_random_data(test_case): y = torch.tanh(x) return y - @autotest(auto_backward=False) def test_flow_tanh_with_0shape_data(test_case): device = random_device() From 2611d194e0f3aa1315fa5dc8bf0e56236f970ba2 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Sat, 31 Jul 2021 18:37:39 +0800 Subject: [PATCH 25/32] fix triu bug when input is empty --- oneflow/user/kernels/triu_kernel.cu | 1 + python/oneflow/test/modules/test_activation.py | 1 - python/oneflow/test/modules/test_triu.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/user/kernels/triu_kernel.cu b/oneflow/user/kernels/triu_kernel.cu index bf2173d4d40..2dd732982d6 100644 --- a/oneflow/user/kernels/triu_kernel.cu +++ b/oneflow/user/kernels/triu_kernel.cu @@ -90,6 +90,7 @@ class GpuTriuKernel final : public user_op::OpKernel { const int64_t num_cols = shape.At(shape.NumAxes() - 1); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0); const int32_t elem_cnt = shape.elem_cnt(); + if (elem_cnt == 0) { return; } if (num_cols % (kCudaWarpSize * 2) == 0) { const int64_t total_rows = elem_cnt / num_cols; TriuWarpProcessRowGpu<< Date: Sat, 31 Jul 2021 18:48:20 +0800 Subject: [PATCH 26/32] test(AbsOp): fix test bug --- oneflow/user/kernels/reduce_kernel.cpp | 1 - oneflow/user/kernels/reduce_like_kernels.cpp | 1 - python/oneflow/test/modules/test_eq.py | 6 +++--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp index 7989d6e20bc..afdb1606e51 100644 --- a/oneflow/user/kernels/reduce_kernel.cpp +++ b/oneflow/user/kernels/reduce_kernel.cpp @@ -37,7 +37,6 @@ class ReduceKernel final : public user_op::OpKernel { if (input_tensor->shape().elem_cnt() == 0) { if (output_tensor->shape().elem_cnt() != 0) { - for (int32_t dim : axis) { CHECK_EQ(output_tensor->shape().At(dim), 1); } KernelUtil::Set(ctx->device_ctx(), UnitOfBinaryFunc::Val(), output_tensor->mut_dptr()); } diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp index 42492319754..25fcb961569 100644 --- a/oneflow/user/kernels/reduce_like_kernels.cpp +++ b/oneflow/user/kernels/reduce_like_kernels.cpp @@ -42,7 +42,6 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel { const auto& axis = ctx->Attr>("axis"); if (tensor_x->shape().elem_cnt() == 0) { if (tensor_y->shape().elem_cnt() != 0) { - for (int32_t dim : axis) { CHECK_EQ(tensor_y->shape().At(dim), 1); } KernelUtil::Set(ctx->device_ctx(), 0, tensor_y->mut_dptr()); } return; diff --git a/python/oneflow/test/modules/test_eq.py b/python/oneflow/test/modules/test_eq.py index 71ba7caf4bf..b838397c2cd 100644 --- a/python/oneflow/test/modules/test_eq.py +++ b/python/oneflow/test/modules/test_eq.py @@ -103,9 +103,9 @@ def test_eq(test_case): @autotest(auto_backward=False) def test_eq_with_0shape_data(test_case): device = random_device() - x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - z = torch.eq(y) + x = random_pytorch_tensor(3, 2, 0, 3).to(device) + y = random_pytorch_tensor(3, 2, 0, 3).to(device) + z = torch.eq(x, y) return z From 3c124a93271652fc2b91e614506cd66acaa19b3c Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Sat, 31 Jul 2021 19:27:15 +0800 Subject: [PATCH 27/32] test(DivOp): fix test bug --- python/oneflow/test/modules/test_div.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/oneflow/test/modules/test_div.py b/python/oneflow/test/modules/test_div.py index dcf886a4887..40fddb3e73e 100644 --- a/python/oneflow/test/modules/test_div.py +++ b/python/oneflow/test/modules/test_div.py @@ -105,8 +105,9 @@ def test_sub_against_pytorch(test_case): def test_0shape_div(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = torch.div(x) - return y + y = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) + z = x / y + return z if __name__ == "__main__": From 891095ba24b73ca1d24346deead5ee7135531d26 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Sat, 31 Jul 2021 21:27:36 +0800 Subject: [PATCH 28/32] fix clamp bug --- oneflow/user/kernels/clip_by_value_kernel.cu | 2 ++ python/oneflow/test/modules/test_clamp.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/oneflow/user/kernels/clip_by_value_kernel.cu b/oneflow/user/kernels/clip_by_value_kernel.cu index df9f9528982..5a5acffd84d 100644 --- a/oneflow/user/kernels/clip_by_value_kernel.cu +++ b/oneflow/user/kernels/clip_by_value_kernel.cu @@ -36,12 +36,14 @@ template struct ClipKernelUtil { template static void Forward(DeviceCtx* ctx, F clip_func, const int64_t n, const T* x, T* y) { + if (n == 0) { return; } RUN_CUDA_KERNEL((CudaClipForward), ctx, n, clip_func, n, x, y); } template static void Backward(DeviceCtx* ctx, F clip_func, const int64_t n, const T* x, const T* dy, T* dx) { + if (n == 0) { return; } RUN_CUDA_KERNEL((CudaClipBackward), ctx, n, clip_func, n, x, dy, dx); } }; diff --git a/python/oneflow/test/modules/test_clamp.py b/python/oneflow/test/modules/test_clamp.py index c22c39a6396..afcc68cd430 100644 --- a/python/oneflow/test/modules/test_clamp.py +++ b/python/oneflow/test/modules/test_clamp.py @@ -157,7 +157,7 @@ def test_clip_max_none_flow_with_random_data(test_case): def test_clamp_with_0shape_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) - y = torch.clamp(x) + y = torch.clamp(x, min=random().to(float), max=random().to(float)) return y From 472ff2bd9d6b03148c77bc5c04dc945385579961 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Sun, 1 Aug 2021 08:11:53 +0800 Subject: [PATCH 29/32] fix test_sub bug --- python/oneflow/test/modules/test_sub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/test/modules/test_sub.py b/python/oneflow/test/modules/test_sub.py index f01144595a9..db9f9a9e025 100644 --- a/python/oneflow/test/modules/test_sub.py +++ b/python/oneflow/test/modules/test_sub.py @@ -119,7 +119,7 @@ def test_sub_with_0shape_data(test_case): out1 = x - y out2 = x - 2 out3 = 2 - x - out4 = torch.sub(x - y) + out4 = torch.sub(x, y) return out1, out2, out3, out4 From 1558ace3264776c87de8ee6246c7d3bc74689519 Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Sun, 1 Aug 2021 12:17:39 +0800 Subject: [PATCH 30/32] fix(ReduceOp): fix reduce op memset bug --- oneflow/user/kernels/reduce_kernel.cpp | 6 ++++-- oneflow/user/kernels/reduce_like_kernels.cpp | 6 +++++- python/oneflow/test/modules/test_sum.py | 4 ++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp index afdb1606e51..e7f724ae56e 100644 --- a/oneflow/user/kernels/reduce_kernel.cpp +++ b/oneflow/user/kernels/reduce_kernel.cpp @@ -37,8 +37,10 @@ class ReduceKernel final : public user_op::OpKernel { if (input_tensor->shape().elem_cnt() == 0) { if (output_tensor->shape().elem_cnt() != 0) { - KernelUtil::Set(ctx->device_ctx(), UnitOfBinaryFunc::Val(), - output_tensor->mut_dptr()); + AutoMemset( + ctx->device_ctx(), output_tensor->mut_dptr(), 0, + output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()), + output_tensor->mem_case()); } return; } diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp index 25fcb961569..9a3f0267a25 100644 --- a/oneflow/user/kernels/reduce_like_kernels.cpp +++ b/oneflow/user/kernels/reduce_like_kernels.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/core/ndarray/ndarray_util.h" +#include "oneflow/core/kernel/kernel_util.h" namespace oneflow { @@ -42,7 +43,10 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel { const auto& axis = ctx->Attr>("axis"); if (tensor_x->shape().elem_cnt() == 0) { if (tensor_y->shape().elem_cnt() != 0) { - KernelUtil::Set(ctx->device_ctx(), 0, tensor_y->mut_dptr()); + AutoMemset( + ctx->device_ctx(), tensor_y->mut_dptr(), 0, + tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()), + tensor_y->mem_case()); } return; } diff --git a/python/oneflow/test/modules/test_sum.py b/python/oneflow/test/modules/test_sum.py index 63030c54c7d..a28f179413a 100644 --- a/python/oneflow/test/modules/test_sum.py +++ b/python/oneflow/test/modules/test_sum.py @@ -79,8 +79,8 @@ def test_sum_against_pytorch(test_case): @autotest(auto_backward=False) def test_sum_with_0shape_tensor(test_case): device = random_device() - x = random_pytorch_tensor(4, 0, 2).to(device) - y = torch.sum(x) + x = random_pytorch_tensor(4, 4, 3, 0, 2).to(device) + y = torch.sum(x, dim=random(0, 3)) return y From 7a0629276c51a1f8da3ac7a731e19aff1370911a Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Sun, 1 Aug 2021 04:26:59 +0000 Subject: [PATCH 31/32] auto format by CI --- oneflow/user/kernels/reduce_kernel.cpp | 2 +- oneflow/user/kernels/reduce_like_kernels.cpp | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp index e7f724ae56e..b6c235bb5e2 100644 --- a/oneflow/user/kernels/reduce_kernel.cpp +++ b/oneflow/user/kernels/reduce_kernel.cpp @@ -37,7 +37,7 @@ class ReduceKernel final : public user_op::OpKernel { if (input_tensor->shape().elem_cnt() == 0) { if (output_tensor->shape().elem_cnt() != 0) { - AutoMemset( + AutoMemset( ctx->device_ctx(), output_tensor->mut_dptr(), 0, output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()), output_tensor->mem_case()); diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp index 9a3f0267a25..1cab5da469f 100644 --- a/oneflow/user/kernels/reduce_like_kernels.cpp +++ b/oneflow/user/kernels/reduce_like_kernels.cpp @@ -43,10 +43,9 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel { const auto& axis = ctx->Attr>("axis"); if (tensor_x->shape().elem_cnt() == 0) { if (tensor_y->shape().elem_cnt() != 0) { - AutoMemset( - ctx->device_ctx(), tensor_y->mut_dptr(), 0, - tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()), - tensor_y->mem_case()); + AutoMemset(ctx->device_ctx(), tensor_y->mut_dptr(), 0, + tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()), + tensor_y->mem_case()); } return; } From bccf28d90fabd421d8ec735ea6c67068a93fc057 Mon Sep 17 00:00:00 2001 From: liufengwei <2472937968@qq.com> Date: Sun, 1 Aug 2021 12:50:46 +0800 Subject: [PATCH 32/32] fix random --- oneflow/user/kernels/reduce_kernel.cpp | 2 +- oneflow/user/kernels/reduce_like_kernels.cpp | 7 +++---- python/oneflow/test/modules/test_sum.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/oneflow/user/kernels/reduce_kernel.cpp b/oneflow/user/kernels/reduce_kernel.cpp index e7f724ae56e..b6c235bb5e2 100644 --- a/oneflow/user/kernels/reduce_kernel.cpp +++ b/oneflow/user/kernels/reduce_kernel.cpp @@ -37,7 +37,7 @@ class ReduceKernel final : public user_op::OpKernel { if (input_tensor->shape().elem_cnt() == 0) { if (output_tensor->shape().elem_cnt() != 0) { - AutoMemset( + AutoMemset( ctx->device_ctx(), output_tensor->mut_dptr(), 0, output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()), output_tensor->mem_case()); diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp index 9a3f0267a25..1cab5da469f 100644 --- a/oneflow/user/kernels/reduce_like_kernels.cpp +++ b/oneflow/user/kernels/reduce_like_kernels.cpp @@ -43,10 +43,9 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel { const auto& axis = ctx->Attr>("axis"); if (tensor_x->shape().elem_cnt() == 0) { if (tensor_y->shape().elem_cnt() != 0) { - AutoMemset( - ctx->device_ctx(), tensor_y->mut_dptr(), 0, - tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()), - tensor_y->mem_case()); + AutoMemset(ctx->device_ctx(), tensor_y->mut_dptr(), 0, + tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()), + tensor_y->mem_case()); } return; } diff --git a/python/oneflow/test/modules/test_sum.py b/python/oneflow/test/modules/test_sum.py index a28f179413a..87462f56d40 100644 --- a/python/oneflow/test/modules/test_sum.py +++ b/python/oneflow/test/modules/test_sum.py @@ -80,7 +80,7 @@ def test_sum_against_pytorch(test_case): def test_sum_with_0shape_tensor(test_case): device = random_device() x = random_pytorch_tensor(4, 4, 3, 0, 2).to(device) - y = torch.sum(x, dim=random(0, 3)) + y = torch.sum(x, dim=np.random.randint(0, 3)) return y