-
Notifications
You must be signed in to change notification settings - Fork 756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add_Tensor.T_and_Tensor.t()_ops #7269
Changes from 1 commit
cb31153
7d40973
a9f32ae
3885c6d
2c4c5f4
e20adb5
4e238f3
bc3e571
110d759
92adf3d
6df17c8
50865f0
83307a0
6b1f075
a247d13
79bd17d
7faf4bf
1de25a3
79e7f8f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,6 +138,7 @@ oneflow | |
tile, | ||
to, | ||
transpose, | ||
t, | ||
tril, | ||
unsqueeze, | ||
permute, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -161,6 +161,8 @@ OneFlow Tensor Class | |
tril, | ||
triu, | ||
type_as, | ||
t, | ||
T, | ||
unfold, | ||
uniform_, | ||
unsqueeze, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2523,6 +2523,35 @@ class GenTensorBufferFunctor { | |
std::shared_ptr<OpExpr> op_; | ||
}; | ||
|
||
class TransposeAllDimPropertyFunctor { | ||
public: | ||
TransposeAllDimPropertyFunctor() {} | ||
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const { | ||
const int64_t ndim = x->shape()->NumAxes(); | ||
std::vector<int32_t> permute; | ||
permute.resize(ndim); | ||
std::iota(permute.begin(), permute.end(), 0); | ||
std::reverse(permute.begin(), permute.end()); | ||
return Transpose(x, permute); | ||
} | ||
}; | ||
|
||
class TransposeAllDimFunctionFunctor { | ||
public: | ||
TransposeAllDimFunctionFunctor() {} | ||
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const { | ||
const int64_t ndim = x->shape()->NumAxes(); | ||
CHECK_OR_RETURN(ndim <= 2) | ||
<< "RuntimeError: t() expects a tensor with <= 2 dimensions, but self is " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
<< ndim << "D"; | ||
std::vector<int32_t> permute; | ||
permute.resize(ndim); | ||
std::iota(permute.begin(), permute.end(), 0); | ||
std::reverse(permute.begin(), permute.end()); | ||
return Transpose(x, permute); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里直接Transpose(x, 0, 1)和这么写有区别吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 输入如果是0维或者1维tensor,这样应该不能处理吧 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我的意思是这个permute不就是一个0一个1两个元素吗,是不是直接传有0和1的vector更快点。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,我改一下 |
||
} | ||
}; | ||
|
||
} // namespace impl | ||
|
||
ONEFLOW_FUNCTION_LIBRARY(m) { | ||
|
@@ -2626,6 +2655,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { | |
m.add_functor<impl::TensorToTensorBufferFunctor>("TensorToTensorBuffer"); | ||
m.add_functor<impl::TensorBufferToTensorFunctor>("TensorBufferToTensor"); | ||
m.add_functor<impl::GenTensorBufferFunctor>("GenTensorBuffer"); | ||
m.add_functor<impl::TransposeAllDimPropertyFunctor>("TransposeAllDimProperty"); | ||
m.add_functor<impl::TransposeAllDimFunctionFunctor>("TransposeAllDimFunction"); | ||
}; | ||
|
||
} // namespace functional | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,3 +45,4 @@ | |
from .clamp import * | ||
from .erfinv import * | ||
from .swapaxes import * | ||
from .tensor_t import * |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,61 @@ | ||||
""" | ||||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); | ||||
you may not use this file except in compliance with the License. | ||||
You may obtain a copy of the License at | ||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0 | ||||
|
||||
Unless required by applicable law or agreed to in writing, software | ||||
distributed under the License is distributed on an "AS IS" BASIS, | ||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
See the License for the specific language governing permissions and | ||||
limitations under the License. | ||||
""" | ||||
import oneflow | ||||
from oneflow.framework.docstr.utils import add_docstr | ||||
|
||||
add_docstr( | ||||
oneflow.t, | ||||
""" | ||||
oneflow.t(input) → Tensor. | ||||
|
||||
Expects `input` to be <= 2-D tensor and transposes dimensions 0 and 1. | ||||
|
||||
0-D and 1-D tensors are returned as is. When input is a 2-D tensor this is equivalent to `transpose(input, 0, 1)`. | ||||
|
||||
Args: | ||||
input (oneflow.Tensor): An input tensor. | ||||
|
||||
For example: | ||||
|
||||
.. code-block:: python | ||||
|
||||
>>> import oneflow as flow | ||||
>>> import numpy as np | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
|
||||
>>> x = flow.randn() | ||||
>>> x | ||||
tensor(-0.2048, dtype=oneflow.float32) | ||||
>>> flow.t(x) | ||||
tensor(-0.2048, dtype=oneflow.float32) | ||||
>>> x = flow.randn(3) | ||||
>>> x | ||||
tensor([ 0.5034, -0.4999, 0.2721], dtype=oneflow.float32) | ||||
>>> flow.t(x) | ||||
tensor([ 0.5034, -0.4999, 0.2721], dtype=oneflow.float32) | ||||
>>> x = flow.randn(2,3) | ||||
>>> x | ||||
tensor([[ 0.1939, 0.6988, 1.0040], | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 你怎么保证换一台机器随机到的数还是这个呢? |
||||
[-0.2530, -1.5002, 0.1415]], dtype=oneflow.float32) | ||||
>>> y = flow.t(x) | ||||
>>> y | ||||
tensor([[ 0.1939, -0.2530], | ||||
[ 0.6988, -1.5002], | ||||
[ 1.0040, 0.1415]], dtype=oneflow.float32) | ||||
>>> y.shape | ||||
oneflow.Size([3, 2]) | ||||
|
||||
""", | ||||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import unittest | ||
from collections import OrderedDict | ||
|
||
import numpy as np | ||
|
||
from oneflow.test_utils.automated_test_util import * | ||
from test_util import GenArgList | ||
|
||
import oneflow as flow | ||
import oneflow.unittest | ||
|
||
|
||
@flow.unittest.skip_unless_1n1d() | ||
class TestTransposeAllDimFunction(flow.unittest.TestCase): | ||
@autotest(check_graph=True) | ||
def test_t_flow_with_random_data(test_case): | ||
device = random_device() | ||
x = random_pytorch_tensor(ndim=random(1, 2)).to(device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 既然只支持2维tensor,应该直接设置ndim为constant dim0和dim1的随机范围可以调大,还要测试下是否支持0shape tensor (random是左闭右开) ndim=constant(2).to(int), dim0=random(0, 64), dim1=random(0, 64) |
||
y = torch.t(x) | ||
return y | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -899,6 +899,20 @@ def test_transpose_tensor_with_random_data(test_case): | |
x = random_pytorch_tensor(ndim=4).to(device) | ||
y = x.transpose(dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) | ||
return y | ||
|
||
@autotest(check_graph=True) | ||
def test_t_tensor_with_random_data(test_case): | ||
device = random_device() | ||
x = random_pytorch_tensor(ndim=random(1, 2)).to(device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
y = x.t() | ||
return y | ||
|
||
@autotest(check_graph=True) | ||
def test_T_tensor_with_random_data(test_case): | ||
device = random_device() | ||
x = random_pytorch_tensor(ndim=random(1, 4)).to(device) | ||
y = x.T | ||
return y | ||
|
||
@flow.unittest.skip_unless_1n1d() | ||
def test_tensor_where(test_case): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有一个ndim的方法已经帮你写好这部分了,可以直接
可以参考 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/framework/tensor.h#L51