
Support 0shape tensor #5620

Merged: 42 commits, Aug 1, 2021 (the diff below shows changes from 27 of the 42 commits)

Commits
e928381
feat(Tensor): support 0shape tensor
wyg1997 Jul 27, 2021
06b893a
math binary broadcast support empty tensor input
liufengwei0103 Jul 27, 2021
5d04f27
slice support empty tensor input and output
liufengwei0103 Jul 27, 2021
acf3fca
fix check in slice
liufengwei0103 Jul 27, 2021
04d29a0
test(Cat): add 0shape cat module test
wyg1997 Jul 27, 2021
acefe4a
fix return type error on gcc 4.8.5
daquexian Jul 29, 2021
61fa201
Merge branch 'master' into fix_no_return
oneflow-ci-bot Jul 29, 2021
b027fe8
auto format by CI
oneflow-ci-bot Jul 29, 2021
a07615e
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 29, 2021
9dc2e33
Merge branch 'fix_no_return' of https://github.com/Oneflow-Inc/oneflo…
liufengwei0103 Jul 29, 2021
55f3daf
add module op test for empty tensor, cuda kernel support empty tensor
liufengwei0103 Jul 29, 2021
b287eb0
format
liufengwei0103 Jul 30, 2021
5f984f1
feat(ReduceOp): reduce op kernels support 0shape tensor
wyg1997 Jul 30, 2021
08431da
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
liufengwei0103 Jul 30, 2021
d1b3867
delete files added by mistake
liufengwei0103 Jul 30, 2021
0f0f127
refine if
liufengwei0103 Jul 30, 2021
0eb7cff
refine if
liufengwei0103 Jul 30, 2021
65c38d4
feat(ConstantOp): constant ops support 0shape tensor
wyg1997 Jul 30, 2021
068d9fb
feat(ReshapeOp): reshape kernel support 0shape tensor
wyg1997 Jul 30, 2021
70bb447
math binary and unary backward skip when elem equal to zeros
liufengwei0103 Jul 30, 2021
102db91
fix(ReduceOp): fix reduce not memset bug
wyg1997 Jul 30, 2021
46b3a3c
Merge remote-tracking branch 'origin/feat-0shape_tensor' into feat-0s…
wyg1997 Jul 30, 2021
a1d9b42
support getitem output empty tensor
liufengwei0103 Jul 30, 2021
f5f389e
fix comment
liufengwei0103 Jul 30, 2021
8e50de3
getitem support input is empty
liufengwei0103 Jul 31, 2021
9ccf929
reduce_like kernel support empty
liufengwei0103 Jul 31, 2021
a787aa1
fix op test bug
liufengwei0103 Jul 31, 2021
4232697
feat(ReduceOp): refine reduce ops initialize value
wyg1997 Jul 31, 2021
ebde365
format code
wyg1997 Jul 31, 2021
f4fa2f5
Merge branch 'master' into feat-0shape_tensor
liufengwei0103 Jul 31, 2021
682c30c
Merge branch 'master' into feat-0shape_tensor
oneflow-ci-bot Jul 31, 2021
2611d19
fix triu bug when input is empty
liufengwei0103 Jul 31, 2021
194c4a0
Merge branch 'feat-0shape_tensor' of https://github.com/Oneflow-Inc/o…
liufengwei0103 Jul 31, 2021
effe78c
test(AbsOp): fix test bug
wyg1997 Jul 31, 2021
3c124a9
test(DivOp): fix test bug
wyg1997 Jul 31, 2021
891095b
fix clamp bug
liufengwei0103 Jul 31, 2021
472ff2b
fix test_sub bug
liufengwei0103 Aug 1, 2021
1558ace
fix(ReduceOp): fix reduce op memset bug
wyg1997 Aug 1, 2021
7a06292
auto format by CI
oneflow-ci-bot Aug 1, 2021
bef483e
Merge branch 'master' into feat-0shape_tensor
oneflow-ci-bot Aug 1, 2021
bccf28d
fix random
liufengwei0103 Aug 1, 2021
ab5b043
Merge branch 'feat-0shape_tensor' of https://github.com/Oneflow-Inc/o…
liufengwei0103 Aug 1, 2021
2 changes: 0 additions & 2 deletions oneflow/api/python/functional/indexing.cpp
@@ -50,8 +50,6 @@ Maybe<void> PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop,
CHECK_OR_RETURN(_PyEval_SliceIndex(obj->stop, stop))
<< "Invalid slice " << PyStringAsString(PyObject_Repr(object));
}
-CHECK_LT_OR_RETURN(*start, *stop)
-    << "Slice stop must be greater than start since 0 size shape is not allowed currently.";
return Maybe<void>::Ok();
}

2 changes: 0 additions & 2 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -863,8 +863,6 @@ class TensorGetItemFunctor {
JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &target_dims));
CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << "Failed to prepare slice indices.";
Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));
-CHECK_GT_OR_RETURN(target_shape.Count(0), 0)
-    << "Target shape is zero shape which was not supported yet.";

std::vector<int64_t> start(ndims), end(ndims), step(ndims);
for (int i = 0; i < ndims; ++i) {
5 changes: 2 additions & 3 deletions oneflow/core/functional/tensor_index.cpp
@@ -65,18 +65,17 @@ Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
}
CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
if (index_item.IsSlice()) {
-CHECK_GT_OR_RETURN(shape.At(dim), 0) << "Slice cannot be applied to a 0-dim tensor.";
[wyg1997 marked this conversation as resolved.]
const auto& slice = index_item.slice();
int64_t step = std::min(slice.step(), shape.At(dim));
-CHECK_GT_OR_RETURN(step, 0) << "Step must be greater than zero.";
[Review comment from a Contributor:] This shouldn't be deleted; we just don't support a negative step.

[Reply:] Fixed.

int64_t end = std::min(slice.end(), shape.At(dim));
int64_t start = std::min(slice.start(), shape.At(dim));
if (start < 0) { start += shape.At(dim); }
if (start < 0) { start = 0; }
if (end < 0) { end += shape.At(dim); }
if (end < start) { end = start; }
+if (start == end) { step = 1; }
slice_indices->emplace_back(start, end, step);
-int64_t length = (end - start + step - 1) / step;
+int64_t length = start == end ? 0 : (end - start + step - 1) / step;
target_dims->emplace_back(length);
dim++;
} else if (index_item.IsInteger()) {
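Note on the new length rule: a minimal standalone sketch (illustrative names, not OneFlow's API) of the computation above. Once start, end, and step are normalized, an empty slice (start == end) yields length 0; the explicit special case also keeps the division safe when a zero-size dimension has clamped step to 0.

```cpp
#include <cassert>
#include <cstdint>

// Length of a normalized slice with 0 <= start <= end <= dim_size.
int64_t SliceLength(int64_t start, int64_t end, int64_t step) {
  // ceil((end - start) / step), with start == end short-circuiting to 0
  return start == end ? 0 : (end - start + step - 1) / step;
}

int main() {
  assert(SliceLength(1, 4, 2) == 2);  // x[1:4:2] -> indices 1, 3
  assert(SliceLength(3, 3, 1) == 0);  // x[3:3] -> empty result
  assert(SliceLength(0, 0, 0) == 0);  // zero-size dim: safe even if step was clamped to 0
}
```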
3 changes: 3 additions & 0 deletions oneflow/core/kernel/kernel_util.cu
@@ -655,6 +655,7 @@ __global__ void CastOnGpu<half, float>(const half* in, float* out, int64_t elem_

template<typename T, typename U>
void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_num) {
+if (elem_num == 0) { return; }
if (std::is_same<T, U>::value) {
Memcpy<DeviceType::kGPU>(ctx, out_dptr, in_dptr, elem_num * sizeof(T));
} else {
@@ -667,6 +668,7 @@ void CopyElemOnGpu(DeviceCtx* ctx, const T* in_dptr, U* out_dptr, int64_t elem_n
template<>
void CopyElemOnGpu<float, float16>(DeviceCtx* ctx, const float* in_dptr, float16* out_dptr,
int64_t elem_num) {
+if (RoundUp(elem_num, 2) == 0) { return; }
CastOnGpu<float, half>
<<<BlocksNum4ThreadsNum(RoundUp(elem_num, 2) / 2), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(in_dptr, reinterpret_cast<half*>(out_dptr), elem_num);
@@ -675,6 +677,7 @@ void CopyElemOnGpu<float, float16>(DeviceCtx* ctx, const float* in_dptr, float16
template<>
void CopyElemOnGpu<float16, float>(DeviceCtx* ctx, const float16* in_dptr, float* out_dptr,
int64_t elem_num) {
+if (RoundUp(elem_num, 2) == 0) { return; }
CastOnGpu<half, float>
<<<BlocksNum4ThreadsNum(RoundUp(elem_num, 2) / 2), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(reinterpret_cast<const half*>(in_dptr), out_dptr, elem_num);
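The `if (elem_num == 0) { return; }` guards added in this and the following CUDA files all address the same failure mode: with zero elements the computed block count is 0, and a `<<<0, ...>>>` launch is an invalid CUDA configuration. A self-contained sketch of the pattern (illustrative kernel and names, not OneFlow code):

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Grid-stride fill kernel standing in for the real kernels in this PR.
__global__ void FillGpu(int64_t n, float* out) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    out[i] = 1.0f;
  }
}

void Fill(cudaStream_t stream, int64_t n, float* out) {
  if (n == 0) { return; }  // without this, blocks == 0 -> invalid launch configuration
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  FillGpu<<<blocks, threads, 0, stream>>>(n, out);
}
```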
1 change: 1 addition & 0 deletions oneflow/core/kernel/util/cuda_arithemetic_interface.cu
@@ -69,6 +69,7 @@ void LaunchTransposeGpu(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeVie
cur_stride *= x_shape.At(i);
}
for (int32_t i = 0; i < NDIMS; ++i) { x_strides.val[i] = buff[permutation[i]]; }
+if (elem_cnt == 0) { return; }
TransposeGpu<NDIMS, T>
<<<SMBlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
y_shape_struct, x_strides, elem_cnt, x, y);
1 change: 1 addition & 0 deletions oneflow/core/kernel/util/cuda_dnn_interface.cu
@@ -132,6 +132,7 @@ template<typename T>
struct ReluHelper final {
static void ReluForward(DeviceCtx* ctx, const int64_t n, const T* x, T* y) {
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
if (x == y) {
InplaceReluForwardGpu<T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, y);
2 changes: 2 additions & 0 deletions oneflow/core/ndarray/ndarray_apply_binary_core.cu
@@ -40,12 +40,14 @@ struct NdarrayApplyBinaryCoreWrapper<DeviceType::kGPU, T, binary_func> final {
const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,
const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {
size_t n = y.host_shape().HostElemNum();
+if (n == 0) { return; }
RUN_CUDA_KERNEL((NdarrayApplyBinaryApplyGpu<T, binary_func>), ctx, n, n, y.host_ptr(),
a.host_ptr(), b.host_ptr());
}
static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
const XpuVarNdarray<const T>& x) {
size_t n = y.host_shape().HostElemNum();
+if (n == 0) { return; }
RUN_CUDA_KERNEL((NdarrayApplyBinaryInplaceApplyGpu<T, binary_func>), ctx, n, n, y.host_ptr(),
x.host_ptr());
}
4 changes: 3 additions & 1 deletion oneflow/core/ndarray/ndarray_apply_broadcast_binary.h
@@ -99,7 +99,9 @@ struct NdarrayApplyBroadcastBinary<
CHECK_EQ(y.shape().NumAxes(), a.shape().NumAxes());
CHECK_EQ(y.shape().NumAxes(), b.shape().NumAxes());
for (int i = 0; i < y.shape().NumAxes(); ++i) {
-CHECK_EQ(y.shape().At(i), std::max(a.shape().At(i), b.shape().At(i)));
+CHECK_EQ(y.shape().At(i), (a.shape().At(i) == 0 || b.shape().At(i) == 0)
+    ? 0
+    : std::max(a.shape().At(i), b.shape().At(i)));
if (a.shape().At(i) != b.shape().At(i)) {
CHECK(a.shape().At(i) == 1 || b.shape().At(i) == 1);
}
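The adjusted check encodes a zero-aware broadcast rule: per dimension, the sizes must still be equal or include a 1, but a 0 on either side forces a 0 in the output instead of the maximum. A minimal sketch with hypothetical names:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Per-dimension broadcast with zero-size support (sketch, not OneFlow's API).
int64_t BroadcastDim(int64_t a, int64_t b) {
  assert(a == b || a == 1 || b == 1);  // mirrors the CHECKs above
  return (a == 0 || b == 0) ? 0 : std::max(a, b);
}

int main() {
  assert(BroadcastDim(0, 1) == 0);  // (0,) op (1,) -> (0,)
  assert(BroadcastDim(1, 5) == 5);  // ordinary broadcasting is unchanged
  assert(BroadcastDim(0, 0) == 0);
}
```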
2 changes: 2 additions & 0 deletions oneflow/core/ndarray/ndarray_apply_broadcast_binary_core.cu
@@ -89,6 +89,7 @@ struct NdarrayApplyBroadcastBinaryCoreWrapper<DeviceType::kGPU, T, NDIMS, binary
const XpuVarNdarray<typename BinaryFuncTrait<binary_func, T>::return_type>& y,
const XpuVarNdarray<const T>& a, const XpuVarNdarray<const T>& b) {
size_t n = y.host_shape().HostElemNum();
+if (n == 0) { return; }
if (IsKernelSafeInt32(n) && PartialBroadcast<int32_t>(ctx, y, a, b)) { return; }
if (!IsKernelSafeInt32(n) && PartialBroadcast<int64_t>(ctx, y, a, b)) { return; }
RUN_CUDA_KERNEL((GpuBroadcastBinaryFunc<T, NDIMS, binary_func>), ctx, n, y, a, b);
@@ -151,6 +152,7 @@ struct NdarrayApplyBroadcastInplaceBinaryCoreWrapper<DeviceType::kGPU, T, NDIMS,
size_t n = y.host_shape().HostElemNum();
XpuVarNdarray<const T> a(y.host_shape(), y.host_ptr());
using NBB = NdarrayApplyBroadcastBinaryCoreWrapper<DeviceType::kGPU, T, NDIMS, binary_func>;
+if (n == 0) { return; }
if (IsKernelSafeInt32(n) && NBB::template PartialBroadcast<int32_t>(ctx, y, a, x)) { return; }
if (!IsKernelSafeInt32(n) && NBB::template PartialBroadcast<int64_t>(ctx, y, a, x)) { return; }
RUN_CUDA_KERNEL((GpuInplaceBroadcastBinaryFunc<T, NDIMS, binary_func>), ctx, n, y, x);
1 change: 1 addition & 0 deletions oneflow/core/ndarray/ndarray_apply_broadcast_unary_core.cu
@@ -30,6 +30,7 @@ template<typename T, int NDIMS, template<typename> class unary_func>
struct NdarrayApplyBroadcastUnaryCoreWrapper<DeviceType::kGPU, T, NDIMS, unary_func> final {
static void Apply(DeviceCtx* ctx, const XpuVarNdarray<T>& y, const XpuVarNdarray<const T>& x) {
size_t n = y.host_shape().HostElemNum();
+if (n == 0) { return; }
RUN_CUDA_KERNEL((GpuBroadcastUnaryFunc<T, NDIMS, unary_func>), ctx, n, y, x);
}
};
1 change: 1 addition & 0 deletions oneflow/core/ndarray/ndarray_apply_unary_core.cu
@@ -31,6 +31,7 @@ template<typename T, template<typename> class unary_func>
struct NdarrayApplyUnaryCoreWrapper<DeviceType::kGPU, T, unary_func> final {
static void InplaceApply(DeviceCtx* ctx, const XpuVarNdarray<T>& y) {
size_t n = y.host_shape().HostElemNum();
+if (n == 0) { return; }
RUN_CUDA_KERNEL((NdarrayApplyUnaryInplaceApplyGpu<T, unary_func>), ctx, n, y.host_ptr(), n);
}
};
1 change: 1 addition & 0 deletions oneflow/core/ndarray/ndarray_assign_core.cu
@@ -33,6 +33,7 @@ struct NdarrayAssignCoreWrapper<DeviceType::kGPU, T, NDIMS> final {
static void Assign(DeviceCtx* ctx, const XpuVarNdarray<T>& y,
const XpuReducedNdarray<T, NDIMS>& reduced) {
size_t n = y.host_shape().HostElemNum();
+if (n == 0) { return; }
RUN_CUDA_KERNEL((NdarrayAssignGpu<T, NDIMS>), ctx, n, y, reduced);
}
};
2 changes: 1 addition & 1 deletion oneflow/user/kernels/add_n_kernel.cu
@@ -56,7 +56,7 @@ struct GpuAddCaller {
for (int32_t i = 0; i < N; ++i) {
para.in[i] = ctx->Tensor4ArgNameAndIndex("in", i)->dptr<T>();
}

+if (n == 0) { return; }
gpu_add<T, N>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, para);
1 change: 1 addition & 0 deletions oneflow/user/kernels/concat_kernel.cpp
@@ -59,6 +59,7 @@ class ConcatKernel final : public user_op::OpKernel {
for (const auto& in_arg_pair : ctx->inputs()) {
const user_op::Tensor* in_tensor =
ctx->Tensor4ArgNameAndIndex(in_arg_pair.first, in_arg_pair.second);
+if (in_tensor->shape().elem_cnt() == 0) { continue; }
const int64_t in_cols = in_tensor->shape().Count(axis);
CHECK_EQ(in_tensor->shape().elem_cnt(), rows * in_cols);
if (in_cols > 0) {
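With the added `continue`, an input with zero elements simply contributes nothing to the concatenation. A rough CPU-side illustration of the effect (hypothetical, flattened to 1-D copies):

```cpp
#include <cassert>
#include <vector>

int main() {
  std::vector<float> a;                       // tensor of shape (0, 3)
  std::vector<float> b = {1, 2, 3, 4, 5, 6};  // tensor of shape (2, 3)
  std::vector<float> out;
  for (const auto* in : {&a, &b}) {
    if (in->empty()) { continue; }  // the added guard skips empty inputs
    out.insert(out.end(), in->begin(), in->end());
  }
  assert(out.size() == 6);  // concat of (0,3) and (2,3) on axis 0 -> (2,3)
}
```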
3 changes: 2 additions & 1 deletion oneflow/user/kernels/constant_kernel.cpp
@@ -30,7 +30,8 @@ class ConstantKernel final : public OpKernel {
Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0);
bool is_floating_value = ctx->Attr<bool>("is_floating_value");
const int64_t elem_cnt = out_tensor->shape().elem_cnt();
-CHECK_GT(elem_cnt, 0);
+CHECK_GE(elem_cnt, 0);
+if (elem_cnt == 0) { return; }
NewKernelUtil<device_type>::Fill(ctx->device_ctx(), elem_cnt,
is_floating_value
? static_cast<T>(ctx->Attr<double>("floating_value"))
4 changes: 0 additions & 4 deletions oneflow/user/kernels/empty_kernel.cpp
@@ -27,10 +27,6 @@ class EmptyKernel final : public OpKernel {

private:
void Compute(user_op::KernelComputeContext* ctx) const override {
-Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0);
-const int64_t elem_cnt = out_tensor->shape().elem_cnt();
-CHECK_GT(elem_cnt, 0);
-
// Do nothing
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
6 changes: 6 additions & 0 deletions oneflow/user/kernels/math_binary_elementwise_kernel.cu
@@ -52,6 +52,7 @@ class MathBinaryElementwiseGpuKernel final : public user_op::OpKernel {
user_op::Tensor* tensor_z = ctx->Tensor4ArgNameAndIndex("z", 0);
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathBinaryElementwiseForwardGpu<BinaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, tensor_x->dptr<T>(), tensor_y->dptr<T>(), tensor_z->mut_dptr<T>());
@@ -73,6 +74,7 @@ class MathBinaryElementwiseXGradGpuKernel final : public user_op::OpKernel {
user_op::Tensor* tensor_dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathBinaryElementwiseBackwardXGradGpu<BinaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, tensor_x->dptr<T>(), tensor_y->dptr<T>(), tensor_dz->dptr<T>(),
@@ -95,6 +97,7 @@ class MathBinaryElementwiseYGradGpuKernel final : public user_op::OpKernel {
user_op::Tensor* tensor_dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathBinaryElementwiseBackwardYGradGpu<BinaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, tensor_x->dptr<T>(), tensor_y->dptr<T>(), tensor_dz->dptr<T>(),
@@ -143,6 +146,7 @@ class MathBinaryElementwiseGpuHalfKernel final : public user_op::OpKernel {
half* z = reinterpret_cast<half*>(tensor_z->mut_dptr<float16>());
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathBinaryElementwiseForwardGpu<BinaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, x, y, z);
@@ -169,6 +173,7 @@ class MathBinaryElementwiseXGradGpuHalfKernel final : public user_op::OpKernel {
half* dx = reinterpret_cast<half*>(tensor_dx->mut_dptr<float16>());
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathBinaryElementwiseBackwardXGradGpu<BinaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, x, y, dz, dx);
@@ -195,6 +200,7 @@ class MathBinaryElementwiseYGradGpuHalfKernel final : public user_op::OpKernel {
half* dy = reinterpret_cast<half*>(tensor_dy->mut_dptr<float16>());
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathBinaryElementwiseBackwardYGradGpu<BinaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, x, y, dz, dy);
4 changes: 4 additions & 0 deletions oneflow/user/kernels/math_unary_elementwise_kernel.cu
@@ -46,6 +46,7 @@ class MathUnaryElementwiseGpuKernel final : public user_op::OpKernel {
T* y = tensor_y->mut_dptr<T>();
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathUnaryElementwiseForwardGpu<UnaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, x, y);
@@ -70,6 +71,7 @@ class MathUnaryElementwiseGradGpuKernel final : public user_op::OpKernel {
T* dx = tensor_dx->mut_dptr<T>();
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathUnaryElementwiseBackwardGpu<UnaryFunctor, T>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, x, dy, dx);
@@ -110,6 +112,7 @@ class MathUnaryElementwiseGpuHalfKernel final : public user_op::OpKernel {
half* y = reinterpret_cast<half*>(tensor_y->mut_dptr<float16>());
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathUnaryElementwiseForwardGpu<UnaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, x, y);
@@ -134,6 +137,7 @@ class MathUnaryElementwiseGradGpuHalfKernel final : public user_op::OpKernel {
half* dx = reinterpret_cast<half*>(tensor_dx->mut_dptr<float16>());
int64_t n = tensor_x->shape().elem_cnt();
CHECK_LE(n, GetMaxVal<int32_t>() / 2);
+if (n == 0) { return; }
MathUnaryElementwiseBackwardGpu<UnaryFunctor, half>
<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, x, dy, dx);
11 changes: 11 additions & 0 deletions oneflow/user/kernels/reduce_kernel.cpp
@@ -16,6 +16,7 @@ limitations under the License.
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/ndarray/ndarray_util.h"
#include "oneflow/core/ndarray/xpu_var_ndarray.h"
+#include "oneflow/core/kernel/kernel_util.h"

namespace oneflow {

Expand All @@ -33,6 +34,16 @@ class ReduceKernel final : public user_op::OpKernel {
user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const auto& axis = ctx->Attr<std::vector<int32_t>>("axis");

+if (input_tensor->shape().elem_cnt() == 0) {
+  if (output_tensor->shape().elem_cnt() != 0) {
+    AutoMemset(
+        ctx->device_ctx(), output_tensor->mut_dptr<T>(), 0,
+        output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()),
+        output_tensor->mem_case());
+  }
+  return;
+}
[wyg1997 marked this conversation as resolved.]
const Shape& reduced_shape =
CreateReducedShape(input_tensor->shape(), {axis.begin(), axis.end()});
NdarrayReduce<device_type, T, BinaryFunc>::Reduce(
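The memset matters because a reduction over an empty input can still have a non-empty output: reducing shape (0, 3) over axis 0 yields shape (3,), and since the kernel never reads an input element it never writes the output either, leaving the buffer uninitialized. For a sum-style reduction the correct result is the identity, 0. A CPU sketch of the semantics:

```cpp
#include <cassert>
#include <vector>

int main() {
  // Sum over axis 0 of an input with shape (0, 3) -> output shape (3,).
  const int cols = 3;
  std::vector<float> input;               // zero rows: no elements at all
  std::vector<float> output(cols, 0.0f);  // must start zeroed (the memset)
  const int rows = static_cast<int>(input.size()) / cols;  // 0
  for (int i = 0; i < rows; ++i) {        // loop body never executes
    for (int j = 0; j < cols; ++j) { output[j] += input[i * cols + j]; }
  }
  assert(output == std::vector<float>(cols, 0.0f));  // sum over empty axis is 0
}
```

Other reductions need different fill values (the max of an empty set is not 0), which the later commit "feat(ReduceOp): refine reduce ops initialize value" appears to address.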
8 changes: 8 additions & 0 deletions oneflow/user/kernels/reduce_like_kernels.cpp
@@ -39,6 +39,14 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel {
user_op::Tensor* tensor_x = ctx->Tensor4ArgNameAndIndex("x", 0);
user_op::Tensor* tensor_y = ctx->Tensor4ArgNameAndIndex("y", 0);
const auto& axis = ctx->Attr<std::vector<int32_t>>("axis");
+if (tensor_x->shape().elem_cnt() == 0) {
+  if (tensor_y->shape().elem_cnt() != 0) {
+    AutoMemset(ctx->device_ctx(), tensor_y->mut_dptr<T>(), 0,
+               tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()),
+               tensor_y->mem_case());
+  }
+  return;
+}
[wyg1997 marked this conversation as resolved.]
if (axis.empty()) {
CHECK_EQ(tensor_x->shape(), tensor_y->shape());
Memcpy<device_type>(ctx->device_ctx(), tensor_y->mut_dptr(), tensor_x->dptr(),
1 change: 1 addition & 0 deletions oneflow/user/kernels/slice_util.cu
@@ -48,6 +48,7 @@ void LaunchSliceForward(DeviceCtx* ctx, const SliceParams& params, const T* enti
int64_t elem_cnt = params.elem_cnt();
SliceIndexHelper<NDIM> entire_idx_cvtr(params.dims);
SliceIndexHelper<NDIM> sliced_idx_cvtr(params.size);
+if (elem_cnt == 0) { return; }
SliceForwardGpu<T, NDIM>
<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
elem_cnt, params, entire_idx_cvtr, sliced_idx_cvtr, entire, sliced);
4 changes: 3 additions & 1 deletion oneflow/user/ops/math_binary_broadcast_ops.cpp
@@ -47,7 +47,9 @@ Maybe<void> InferTensorDescBinaryBroadcastNormal(user_op::InferContext* ctx) {
CHECK_OR_RETURN(x_shape.At(i) == 1 || y_shape.At(i) == 1 || x_shape.At(i) == y_shape.At(i))
<< "op: " << ctx->op_name() << ", type: " << ctx->op_type_name() << ", i: " << i
<< ", x_shape: " << x_shape << ", y_shape: " << y_shape;
-out_shape.Set(i, std::max(x_shape.At(i), y_shape.At(i)));
+out_shape.Set(i, (x_shape.At(i) == 0 || y_shape.At(i) == 0)
+    ? 0
+    : std::max(x_shape.At(i), y_shape.At(i)));
}
*tensor_z->mut_shape() = out_shape;
}
2 changes: 1 addition & 1 deletion oneflow/user/ops/reshape_op.cpp
@@ -48,7 +48,7 @@ Maybe<void> LogicalTensorDescInferFn(user_op::InferContext* ctx) {
*out_tensor_desc = in_tensor_desc;
CHECK_GE_OR_RETURN(shape.NumAxes(), 1);
DimVector dim_vec = {shape.dim_vec().begin(), shape.dim_vec().end()};
-FOR_RANGE(int32_t, i, 0, dim_vec.size()) { CHECK_GT_OR_RETURN(dim_vec.at(i), 0); }
+FOR_RANGE(int32_t, i, 0, dim_vec.size()) { CHECK_GE_OR_RETURN(dim_vec.at(i), 0); }
*out_shape = Shape(dim_vec);
CHECK_EQ_OR_RETURN(out_shape->elem_cnt(), in_shape.elem_cnt());
return Maybe<void>::Ok();
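Relaxing `CHECK_GT` to `CHECK_GE` lets the target shape contain zero-size dimensions, while the element-count check below it still guarantees the reshape is size-preserving: (0, 4) -> (0, 2, 2) is legal because both sides hold 0 elements. A sketch of the invariant:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Element count as the product of dims (illustrative helper, not OneFlow's).
int64_t ElemCnt(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  assert(ElemCnt({0, 4}) == ElemCnt({0, 2, 2}));  // both 0: reshape allowed
  assert(ElemCnt({2, 4}) == ElemCnt({2, 2, 2}));  // non-empty case unchanged
}
```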
10 changes: 6 additions & 4 deletions oneflow/user/ops/slice_op.cpp
@@ -40,14 +40,16 @@ Maybe<void> InferSliceOpTensorDesc(user_op::InferContext* ctx) {
DimVector dim_vec(ndim);
FOR_RANGE(size_t, i, 0, dim_vec.size()) {
const int64_t dim_size = x_shape.At(i);
-if (dim_size == 0) {
+const int64_t step = step_vec.at(i);
+int64_t start = start_vec.at(i);
+int64_t stop = stop_vec.at(i);
+if (dim_size == 0 || start == stop) {
dim_vec[i] = 0;
continue;
}
-const int64_t step = step_vec.at(i);
CHECK_NE_OR_RETURN(step, 0) << "slice step cannot be 0";
-int64_t start = RegulateSliceStart(start_vec.at(i), dim_size);
-int64_t stop = RegulateSliceStop(stop_vec.at(i), dim_size);
+start = RegulateSliceStart(start, dim_size);
+stop = RegulateSliceStop(stop, dim_size);
if (step > 0) {
CHECK_LT_OR_RETURN(start, stop) << "slice start must be less than stop when step > 0"
", otherwise empty result will be outputted.";
3 changes: 2 additions & 1 deletion python/oneflow/framework/tensor.py
@@ -40,7 +40,8 @@ def _tensor_numpy(eager_local_tensor):
tuple(eager_local_tensor.shape),
dtype=flow.convert_oneflow_dtype_to_numpy_dtype(eager_local_tensor.dtype),
)
-copy_to_numpy(ndarray)
+if ndarray.size != 0:
+    copy_to_numpy(ndarray)
return ndarray


2 changes: 1 addition & 1 deletion python/oneflow/ops/array_ops.py
@@ -44,7 +44,7 @@ def check_slice_tup_list(slice_tup_list, shape):
if start is None:
start = 0 if step > 0 else np.iinfo(np.int64).max
elif start < -dim_size or start >= dim_size:
-raise ValueError("slice start must be in range [-size, size)")
+start, stop, step = 0, 0, 1
if stop is None:
stop = np.iinfo(np.int64).max if step > 0 else np.iinfo(np.int64).min
elif stop < -dim_size - 1 or stop > dim_size:
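A note on this last change: an out-of-range start used to raise a ValueError, but now collapses the slice to the canonical empty slice (0, 0, 1), matching Python semantics where `list(range(3))[5:10]` is `[]`. A sketch of the normalization rule (hypothetical helper, written in C++ for consistency with the rest of the PR):

```cpp
#include <cassert>
#include <cstdint>
#include <tuple>

using Slice = std::tuple<int64_t, int64_t, int64_t>;  // (start, stop, step)

// Out-of-range start -> empty slice instead of an error (sketch).
Slice NormalizeStart(int64_t start, int64_t stop, int64_t step, int64_t dim_size) {
  if (start < -dim_size || start >= dim_size) { return Slice{0, 0, 1}; }
  return Slice{start, stop, step};
}

int main() {
  assert((NormalizeStart(5, 10, 1, 3) == Slice{0, 0, 1}));  // like x[5:10] on size 3
  assert((NormalizeStart(1, 3, 1, 3) == Slice{1, 3, 1}));   // in range: unchanged
}
```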