-
Notifications
You must be signed in to change notification settings - Fork 660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
tensorsplit_op #7258
tensorsplit_op #7258
Conversation
lcylcy
commented
Jan 14, 2022
•
edited
Loading
edited
TensorSplitVecFunctor() = default; | ||
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input, | ||
const std::vector<int32_t>& indices_or_sections, | ||
const int32_t& dim) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考上个pr的comment:#7275 (comment)
std::vector<int64_t> stop(ndim); | ||
std::vector<int64_t> step(ndim, 1); | ||
for(int32_t i=0; i<ndim; i++){ | ||
stop[i] = input->shape()->At(i); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one::Tensor有个更简单的接口:input->dim(i)
reference:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/framework/tensor.h#L49
output[i] = JUST(Slice(input, start, stop, step)); | ||
start[pos_dim] = end_idx; | ||
} | ||
stop[pos_dim] = input->shape()->At(ndim-1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const int32_t& indices_or_sections, | ||
const int32_t& dim) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::vector<int64_t> stop(ndim); | ||
std::vector<int64_t> step(ndim, 1); | ||
for(int32_t i=0; i<ndim; i++){ | ||
stop[i] = input->shape()->At(i); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
HsplitIntFunctor() = default; | ||
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input, | ||
const int32_t& indices_or_sections) const { | ||
int32_t ndim = input->shape()->NumAxes(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
public: | ||
HsplitIntFunctor() = default; | ||
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input, | ||
const int32_t& indices_or_sections) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
HsplitVecFunctor() = default; | ||
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input, | ||
const std::vector<int32_t>& indices_or_sections) const { | ||
int32_t ndim = input->shape()->NumAxes(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
public: | ||
VsplitIntFunctor() = default; | ||
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input, | ||
const int32_t& indices_or_sections) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int32_t ndim = input->shape()->NumAxes(); | ||
CHECK_OR_RETURN(ndim>=2)<<"torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "<<ndim <<" dimensions!"; | ||
CHECK_OR_RETURN(indices_or_sections>0) << "indices_or_sections must greater than 0"; | ||
CHECK_OR_RETURN(input->shape()->At(0)% indices_or_sections == 0) << "torch.vsplit attempted to split along dimension " << 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider using input->dim(0) here instead of input->shape()->At(0)
VsplitVecFunctor() = default; | ||
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& input, | ||
const std::vector<int32_t>& indices_or_sections) const { | ||
int32_t ndim = input->shape()->NumAxes(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider using input->ndim() here instead of input->shape()->NumAxes()
class TestHsplitVec(flow.unittest.TestCase): | ||
@autotest(check_graph=False) | ||
def test_flow_hsplit_vec(test_case): | ||
device = random_device() | ||
x = random_pytorch_tensor( | ||
ndim=4, | ||
dim1=random(3, 6), | ||
dim2=random(3, 6), | ||
dim3=random(3, 6), | ||
dim4=random(3, 6), | ||
).to(device) | ||
z = torch.hsplit(x, (1,2)) | ||
return z[0] | ||
|
||
class TestHsplitInt(flow.unittest.TestCase): | ||
@autotest(check_graph=False) | ||
def test_flow_hsplit_int(test_case): | ||
device = random_device() | ||
x = random_pytorch_tensor( | ||
ndim=4, | ||
dim1=random(3, 6), | ||
dim2=random(3, 6), | ||
dim3=random(3, 6), | ||
dim4=random(3, 6), | ||
).to(device) | ||
split = random(1, 3).to(int) | ||
z = torch.hsplit(x, split) | ||
return z[0] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考一下这个:#7275 (comment)
class TestTorchSplitVec(flow.unittest.TestCase): | ||
@autotest(check_graph=False) | ||
def test_flow_tensor_split_vec(test_case): | ||
device = random_device() | ||
x = random_pytorch_tensor( | ||
ndim=4, | ||
dim1=random(3, 6), | ||
dim2=random(3, 6), | ||
dim3=random(3, 6), | ||
dim4=random(3, 6), | ||
).to(device) | ||
dim = random(-3, 3).to(int) | ||
z = torch.tensor_split(x, (1,2),dim) | ||
return z[0] | ||
|
||
class TestTorchSplitInt(flow.unittest.TestCase): | ||
@autotest(check_graph=False) | ||
def test_flow_tensor_split_int(test_case): | ||
device = random_device() | ||
x = random_pytorch_tensor( | ||
ndim=4, | ||
dim1=random(3, 6), | ||
dim2=random(3, 6), | ||
dim3=random(3, 6), | ||
dim4=random(3, 6), | ||
).to(device) | ||
split = random(-3, 3).to(int) | ||
dim = random(-3, 3).to(int) | ||
z = torch.tensor_split(x, split,dim) | ||
return z[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
class TestVsplitVec(flow.unittest.TestCase): | ||
@autotest(check_graph=False) | ||
def test_flow_vsplit_vec(test_case): | ||
device = random_device() | ||
x = random_pytorch_tensor( | ||
ndim=4, | ||
dim1=random(3, 6), | ||
dim2=random(3, 6), | ||
dim3=random(3, 6), | ||
dim4=random(3, 6), | ||
).to(device) | ||
z = torch.vsplit(x, (1,2)) | ||
return z[0] | ||
|
||
class TestVsplitInt(flow.unittest.TestCase): | ||
@autotest(check_graph=False) | ||
def test_flow_vsplit_int(test_case): | ||
device = random_device() | ||
x = random_pytorch_tensor( | ||
ndim=4, | ||
dim1=random(3, 6), | ||
dim2=random(3, 6), | ||
dim3=random(3, 6), | ||
dim4=random(3, 6), | ||
).to(device) | ||
split = random(1, 3).to(int) | ||
z = torch.vsplit(x, split) | ||
return z[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
review done,写的太棒了,我写了一些comments,大多都和上个pr(implement as strided)的雷同,你看情况自己酌情修改哈,我直接给你approve了,这样可以尽可能提高效率,你自己做足测试保证正确性就好~ @lcylcy
好的 |
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
CI failed when running job: cuda-module. PR label automerge has been removed |
7c016a3
to
b41b8b2
Compare
Speed stats:
|