Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tensorsplit_op #7258

Merged
merged 8 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ oneflow
diag,
diagonal,
movedim,
tensor_split,
hsplit,
vsplit,
div,
dot,
eq,
Expand Down
21 changes: 21 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,27 @@
]
bind_python: True

- name: "tensor_split"
signature: [
"TensorTuple (Tensor input, Int32 indices_or_sections, Int32 dim=0) => TensorSplitInt",
"TensorTuple (Tensor input, Int32List indices_or_sections, Int32 dim=0) => TensorSplitVec",
]
bind_python: True

- name: "hsplit"
signature: [
"TensorTuple (Tensor input, Int32 indices_or_sections) => HsplitInt",
"TensorTuple (Tensor input, Int32List indices_or_sections) => HsplitVec",
]
bind_python: True

- name: "vsplit"
signature: [
"TensorTuple (Tensor input, Int32 indices_or_sections) => VsplitInt",
"TensorTuple (Tensor input, Int32List indices_or_sections) => VsplitVec",
]
bind_python: True

- name: "negative"
signature: "Tensor (Tensor x) => Negative"
bind_python: True
Expand Down
129 changes: 129 additions & 0 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,129 @@ class MovedimIntFunctor {
}
};

class TensorSplitVecFunctor {
public:
TensorSplitVecFunctor() = default;
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,
const std::vector<int32_t>& indices_or_sections,
const int32_t& dim) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考上个pr的comment:#7275 (comment)

int32_t ndim = input->shape()->NumAxes();
CHECK_OR_RETURN((dim>=-ndim)&&(dim<ndim))<< "Dimension out of range (expected to be in range of ["
<<-ndim<<","<< ndim-1 <<"], but got "<<dim<<")";
int32_t pos_dim = dim>=0?dim:dim+ndim;

std::vector<int64_t> start(ndim, 0);
std::vector<int64_t> stop(ndim);
std::vector<int64_t> step(ndim, 1);
for(int32_t i=0; i<ndim; i++){
stop[i] = input->shape()->At(i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one::Tensor有个更简单的接口:input->dim(i)

reference:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/framework/tensor.h#L49

}

int32_t num_indices = indices_or_sections.size();
TensorTuple output(num_indices+1);
for(int32_t i = 0; i < num_indices; i++){
int32_t end_idx = indices_or_sections[i];
stop[pos_dim] = end_idx;
output[i] = JUST(Slice(input, start, stop, step));
start[pos_dim] = end_idx;
}
stop[pos_dim] = input->shape()->At(ndim-1);
Copy link
Contributor

@wyushun wyushun Jan 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

output[num_indices] = JUST(Slice(input, start, stop, step));

return output;
}
};

class TensorSplitIntFunctor {
public:
TensorSplitIntFunctor() = default;
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,
const int32_t& indices_or_sections,
const int32_t& dim) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int32_t ndim = input->shape()->NumAxes();
CHECK_OR_RETURN((dim>=-ndim)&&(dim<ndim))<< "Dimension out of range (expected to be in range of ["
<<-ndim<<","<< ndim-1 <<"], but got "<<dim<<")";
CHECK_OR_RETURN(indices_or_sections > 0) <<"number of sections must be larger than 0, got ,"<< indices_or_sections <<");";
int32_t pos_dim = dim>=0?dim:dim+ndim;

const auto dim_size = input->shape()->At(pos_dim);
int64_t min_split_size = dim_size / indices_or_sections;
int64_t num_splits_one_extra = dim_size % indices_or_sections;

std::vector<int64_t> start(ndim, 0);
std::vector<int64_t> stop(ndim);
std::vector<int64_t> step(ndim, 1);
for(int32_t i=0; i<ndim; i++){
stop[i] = input->shape()->At(i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
stop[pos_dim] = 0;

TensorTuple output(indices_or_sections);
for(int32_t i = 0; i < indices_or_sections; i++){
int64_t split_size = (i < num_splits_one_extra) ? (min_split_size + 1) : min_split_size;
stop[pos_dim] += split_size;
output[i] = JUST(Slice(input, start, stop, step));
start[pos_dim] += split_size;
}

return output;
}
};

class HsplitIntFunctor {
public:
HsplitIntFunctor() = default;
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,
const int32_t& indices_or_sections) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int32_t ndim = input->shape()->NumAxes();
Copy link
Contributor

@wyushun wyushun Jan 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHECK_OR_RETURN(ndim>=1)<<"torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "<<ndim <<" dimensions!";
CHECK_OR_RETURN(indices_or_sections>0) << "indices_or_sections must greater than 0";
int32_t dim = (ndim == 1) ? 0 : 1;
CHECK_OR_RETURN(input->shape()->At(dim)% indices_or_sections == 0) << "torch.hsplit attempted to split along dimension " << dim
<<", but the size of the dimension " << input->shape()->At(dim) <<
" is not divisible by the split_size " <<indices_or_sections<< "!";
return TensorSplitInt(input, indices_or_sections, dim);
}
};

class HsplitVecFunctor {
public:
HsplitVecFunctor() = default;
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,
const std::vector<int32_t>& indices_or_sections) const {
int32_t ndim = input->shape()->NumAxes();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHECK_OR_RETURN(ndim>=1)<<"torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "<<ndim <<" dimensions!";
int32_t dim = (ndim == 1) ? 0 : 1;
return TensorSplitVec(input, indices_or_sections, dim);
}
};

class VsplitIntFunctor {
public:
VsplitIntFunctor() = default;
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,
const int32_t& indices_or_sections) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int32_t ndim = input->shape()->NumAxes();
CHECK_OR_RETURN(ndim>=2)<<"torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "<<ndim <<" dimensions!";
CHECK_OR_RETURN(indices_or_sections>0) << "indices_or_sections must greater than 0";
CHECK_OR_RETURN(input->shape()->At(0)% indices_or_sections == 0) << "torch.vsplit attempted to split along dimension " << 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider input->dim()

<<", but the size of the dimension " << input->shape()->At(0) <<
" is not divisible by the split_size " <<indices_or_sections<< "!";
return TensorSplitInt(input, indices_or_sections, 0);
}
};

class VsplitVecFunctor {
public:
VsplitVecFunctor() = default;
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input,
const std::vector<int32_t>& indices_or_sections) const {
int32_t ndim = input->shape()->NumAxes();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider input->ndim()

CHECK_OR_RETURN(ndim>=2)<<"torch.vsplit requires a tensor with at least 1 dimension, but got a tensor with "<<ndim <<" dimensions!";
return TensorSplitVec(input, indices_or_sections, 0);
}
};

class ErfinvFunctor {
public:
ErfinvFunctor() { op_ = CHECK_JUST(one::OpBuilder("erfinv").Input("x").Output("y").Build()); }
Expand Down Expand Up @@ -1886,6 +2009,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<DotFunctor>("Dot");
m.add_functor<MovedimVecFunctor>("MovedimVec");
m.add_functor<MovedimIntFunctor>("MovedimInt");
m.add_functor<TensorSplitVecFunctor>("TensorSplitVec");
m.add_functor<TensorSplitIntFunctor>("TensorSplitInt");
m.add_functor<HsplitIntFunctor>("HsplitInt");
m.add_functor<HsplitVecFunctor>("HsplitVec");
m.add_functor<VsplitIntFunctor>("VsplitInt");
m.add_functor<VsplitVecFunctor>("VsplitVec");
m.add_functor<ErfinvFunctor>("Erfinv");
m.add_functor<ErfinvInplaceFunctor>("ErfinvInplace");
m.add_functor<CumsumFunctor>("Cumsum");
Expand Down
3 changes: 3 additions & 0 deletions python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def is_deprecated(func_or_class):
from oneflow._C import narrow
from oneflow._C import unsqueeze
from oneflow._C import permute
from oneflow._C import tensor_split
from oneflow._C import hsplit
from oneflow._C import vsplit
from oneflow._C import concat
from oneflow._C import concat as cat
from oneflow._C import to
Expand Down
124 changes: 124 additions & 0 deletions python/oneflow/framework/docstr/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,130 @@
""",
)

add_docstr(
oneflow.tensor_split,
r"""
Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension
dim according to the indices or number of sections specified by indices_or_sections .
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.tensor_split.html#torch.tensor_split

Args:
input (Tensor): the input tensor.
indices_or_sections (int or a list): If indices_or_sections is an integer n , input is split into n sections
along dimension dim.If input is divisible by n along dimension dim, each section will be of equal size,
input.size (dim) / n. If input is not divisible by n, the sizes of the first int(input.size(dim) % n).
sections will have size int(input.size(dim) / n) + 1, and the rest will have size int(input.size(dim) / n).
If indices_or_sections is a list or tuple of ints, then input is split along dimension dim at each of the indices in
the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors
input[:2], input[2:3], and input[3:].If indices_or_sections is a tensor, it must be a zero-dimensional or
one-dimensional long tensor on the CPU.
dim (int): dimension along which to split the tensor.

Returns:
oneflow.TensorTuple: the output Tensor.

For example:

.. code-block:: python

>>> import oneflow as flow

>>> input = flow.rand(3,4,5)
>>> output = flow.tensor_split(input,(2,3),2)
>>> output[0].size()
oneflow.Size([3, 4, 2])
>>> output[1].size()
oneflow.Size([3, 4, 1])
>>> output[2].size()
oneflow.Size([3, 4, 2])
""",
)

add_docstr(
oneflow.hsplit,
r"""
Splits input, a tensor with one or more dimensions, into multiple tensors horizontally according to indices_or_sections.
Each split is a view of input.
If input is one dimensional this is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0)
(the split dimension is zero), and if input has two or more dimensions it’s equivalent to calling
torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), except that if indices_or_sections
is an integer it must evenly divide the split dimension or a runtime error will be thrown.

Args:
input (Tensor): the input tensor.
indices_or_sections (int or a list): If indices_or_sections is an integer n , input is split into n sections
along dimension dim.If input is divisible by n along dimension dim, each section will be of equal size,
input.size (dim) / n. If input is not divisible by n, the sizes of the first int(input.size(dim) % n).
sections will have size int(input.size(dim) / n) + 1, and the rest will have size int(input.size(dim) / n).
If indices_or_sections is a list or tuple of ints, then input is split along dimension dim at each of the indices in
the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors
input[:2], input[2:3], and input[3:].If indices_or_sections is a tensor, it must be a zero-dimensional or
one-dimensional long tensor on the CPU.

Returns:
oneflow.TensorTuple: the output Tensor.

For example:

.. code-block:: python

>>> import oneflow as flow

>>> input = flow.rand(3,4,5,6)
>>> output = flow.hsplit(input,(1,3))
>>> output[0].size()
oneflow.Size([3, 1, 5, 6])
>>> output[1].size()
oneflow.Size([3, 2, 5, 6])
>>> output[2].size()
oneflow.Size([3, 1, 5, 6])
>>> output[3].size()
""",
)

add_docstr(
oneflow.vsplit,
r"""
Splits input, a tensor with two or more dimensions, into multiple tensors vertically according to indices_or_sections.
Each split is a view of input.
This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is 0),
except that if indices_or_sections is an integer it must evenly divide the split dimension or a runtime error will be thrown.

Args:
input (Tensor): the input tensor.
indices_or_sections (int or a list): If indices_or_sections is an integer n , input is split into n sections
along dimension dim.If input is divisible by n along dimension dim, each section will be of equal size,
input.size (dim) / n. If input is not divisible by n, the sizes of the first int(input.size(dim) % n).
sections will have size int(input.size(dim) / n) + 1, and the rest will have size int(input.size(dim) / n).
If indices_or_sections is a list or tuple of ints, then input is split along dimension dim at each of the indices in
the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors
input[:2], input[2:3], and input[3:].If indices_or_sections is a tensor, it must be a zero-dimensional or
one-dimensional long tensor on the CPU.

Returns:
oneflow.TensorTuple: the output Tensor.

For example:

.. code-block:: python

>>> import oneflow as flow

>>> input = flow.rand(3,4,5,6)
>>> output = flow.vsplit(input,(1,3))
>>> output[0].size()
oneflow.Size([1, 4, 5, 6])
>>> output[1].size()
oneflow.Size([2, 4, 5, 6])
>>> output[2].size()
oneflow.Size([1, 4, 5, 6])
>>> output[3].size()
""",
)



add_docstr(
oneflow.eye,
"""oneflow.eye(n, m, *, device=None, requires_grad=False, placement=None, sbp) -> Tensor
Expand Down
53 changes: 53 additions & 0 deletions python/oneflow/test/modules/test_hsplit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from oneflow.test_utils.automated_test_util import *
import oneflow as flow
import oneflow.unittest


class TestHsplitVec(flow.unittest.TestCase):
@autotest(check_graph=False)
def test_flow_hsplit_vec(test_case):
device = random_device()
x = random_pytorch_tensor(
ndim=4,
dim1=random(3, 6),
dim2=random(3, 6),
dim3=random(3, 6),
dim4=random(3, 6),
).to(device)
z = torch.hsplit(x, (1,2))
return z[0]

class TestHsplitInt(flow.unittest.TestCase):
@autotest(check_graph=False)
def test_flow_hsplit_int(test_case):
device = random_device()
x = random_pytorch_tensor(
ndim=4,
dim1=random(3, 6),
dim2=random(3, 6),
dim3=random(3, 6),
dim4=random(3, 6),
).to(device)
split = random(1, 3).to(int)
z = torch.hsplit(x, split)
return z[0]


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考一下这个:#7275 (comment)

if __name__ == "__main__":
unittest.main()
Loading