-
Notifications
You must be signed in to change notification settings - Fork 656
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add autotest tan tanh floor arctanh * add autotest tan floor tan * Add autotest for log1p * Code format * delete no use import * add eye op alignment * amend pytorch test * delete merge error branch * Delete log1p.py delete merge error branch * amend code because the list of master change * amend eye test code * amend eye docsting * amend eye docsting * amend eye docstring * autotest test_eye * auto format by CI * amend param of eye Co-authored-by: Zhenhua <huangzhenhua@zhejianglab.com> Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
- Loading branch information
1 parent
5b63e76
commit e81dafc
Showing
4 changed files
with
168 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,8 @@ oneflow | |
clip, | ||
diag, | ||
enable_eager_execution, | ||
expand, | ||
expand, | ||
eye, | ||
flatten, | ||
function_config, | ||
gather, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
from typing import Union | ||
|
||
import oneflow as flow | ||
from oneflow.nn.module import Module | ||
from oneflow.framework.tensor import register_tensor_op | ||
|
||
|
||
class Eye(Module): | ||
def __init__( | ||
self, | ||
n: int, | ||
m: int = None, | ||
device: Union[str, flow.device] = "cpu", | ||
requires_grad: bool = False, | ||
) -> None: | ||
super().__init__() | ||
self.n = n | ||
self.m = m | ||
self.device = device | ||
self.requires_grad = requires_grad | ||
|
||
def forward(self): | ||
n = self.n | ||
m = self.m | ||
if m is None: | ||
m = n | ||
|
||
if m == n: | ||
res = flow.diag(flow.ones(n)) | ||
elif m < n: | ||
tmp = flow.ones(m) | ||
input1 = flow.zeros((n - m, m)) | ||
input2 = flow.diag(tmp) | ||
res = flow.cat([input2, input1], dim=0) | ||
else: | ||
tmp = flow.ones(n) | ||
input1 = flow.zeros((n, m - n)) | ||
input2 = flow.diag(tmp) | ||
res = flow.cat([input2, input1], dim=1) | ||
|
||
res.requires_grad = self.requires_grad | ||
if isinstance(self.device, str): | ||
device = flow.device(self.device) | ||
else: | ||
device = self.device | ||
re = res.to(device) | ||
return re | ||
|
||
|
||
def eye_op( | ||
n, m=None, device: Union[str, flow.device] = "cpu", requires_grad: bool = False, | ||
): | ||
"""This operator creates a 2-D Tensor with ones on the diagonal and zeros elsewhere. | ||
Args: | ||
n (int): the number of rows. | ||
m (Optional[int], optional): the number of colums with default being n. Defaults to None. | ||
Keyword args: | ||
device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor. | ||
requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`. | ||
Returns: | ||
oneflow.Tensor: The result Blob with ones on the diagonal and zeros elsewhere. | ||
For example: | ||
.. code-block:: python | ||
>>> import oneflow as flow | ||
>>> out = flow.eye(3, 3) | ||
>>> out | ||
tensor([[1., 0., 0.], | ||
[0., 1., 0.], | ||
[0., 0., 1.]], dtype=oneflow.float32) | ||
""" | ||
return Eye(n, m, device, requires_grad)() | ||
|
||
|
||
if __name__ == "__main__": | ||
import doctest | ||
|
||
doctest.testmod(raise_on_error=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import unittest | ||
from collections import OrderedDict | ||
|
||
import numpy as np | ||
from automated_test_util import * | ||
from test_util import GenArgList | ||
|
||
import oneflow as flow | ||
|
||
|
||
def _test_eye_forward(test_case, device, n, m): | ||
output = flow.eye(n, m, device=device) | ||
np_out = np.eye(n, m) | ||
test_case.assertTrue(np.array_equal(output.numpy(), np_out)) | ||
|
||
|
||
def _test_eye_backward(test_case, device, n, m): | ||
x = flow.eye(n, m, device=device) | ||
x.requires_grad = True | ||
y = x.sum() | ||
y.backward() | ||
test_case.assertTrue(np.array_equal(x.grad.numpy(), np.ones([n, m]))) | ||
|
||
|
||
@flow.unittest.skip_unless_1n1d() | ||
class TestEye(flow.unittest.TestCase): | ||
def test_eye(test_case): | ||
arg_dict = OrderedDict() | ||
arg_dict["test_fun"] = [ | ||
_test_eye_forward, | ||
_test_eye_backward, | ||
] | ||
arg_dict["device"] = ["cpu", "cuda"] | ||
arg_dict["n"] = [4, 3, 2] | ||
arg_dict["m"] = [4, 3, 2] | ||
for arg in GenArgList(arg_dict): | ||
arg[0](test_case, *arg[1:]) | ||
|
||
@autotest() | ||
def test_eye_with_random_data(test_case): | ||
n = random().to(int) | ||
m = random().to(int) | ||
x = torch.eye(n=n, m=m) | ||
device = random_device() | ||
x.to(device) | ||
x = random_pytorch_tensor().to(device) | ||
return x | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |