Skip to content

Commit

Permalink
Dev eye (#5583)
Browse files Browse the repository at this point in the history
* add autotest tan tanh floor arctanh

* add autotest tan floor tan

* Add autotest  for log1p

* Code format

* delete no use import

* add eye op alignment

* amend pytorch test

* delete merge error branch

* Delete log1p.py

delete merge error branch

* amend code because the list of master change

* amend eye test code

* amend eye docsting

* amend eye docsting

* amend eye docstring

* autotest test_eye

* auto format by CI

* amend param of eye

Co-authored-by: Zhenhua <huangzhenhua@zhejianglab.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
4 people committed Jul 30, 2021
1 parent 5b63e76 commit e81dafc
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 1 deletion.
3 changes: 2 additions & 1 deletion docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ oneflow
clip,
diag,
enable_eager_execution,
expand,
expand,
eye,
flatten,
function_config,
gather,
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ def Sync():
from oneflow.nn.modules.squeeze import squeeze_op as squeeze
from oneflow.nn.modules.stack import stack
from oneflow.nn.modules.tan import tan_op as tan
from oneflow.nn.modules.eye import eye_op as eye
from oneflow.nn.modules.tensor_buffer import gen_tensor_buffer
from oneflow.nn.modules.tensor_buffer import (
tensor_buffer_to_tensor_op as tensor_buffer_to_tensor,
Expand Down
99 changes: 99 additions & 0 deletions python/oneflow/nn/modules/eye.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Union

import oneflow as flow
from oneflow.nn.module import Module
from oneflow.framework.tensor import register_tensor_op


class Eye(Module):
def __init__(
self,
n: int,
m: int = None,
device: Union[str, flow.device] = "cpu",
requires_grad: bool = False,
) -> None:
super().__init__()
self.n = n
self.m = m
self.device = device
self.requires_grad = requires_grad

def forward(self):
n = self.n
m = self.m
if m is None:
m = n

if m == n:
res = flow.diag(flow.ones(n))
elif m < n:
tmp = flow.ones(m)
input1 = flow.zeros((n - m, m))
input2 = flow.diag(tmp)
res = flow.cat([input2, input1], dim=0)
else:
tmp = flow.ones(n)
input1 = flow.zeros((n, m - n))
input2 = flow.diag(tmp)
res = flow.cat([input2, input1], dim=1)

res.requires_grad = self.requires_grad
if isinstance(self.device, str):
device = flow.device(self.device)
else:
device = self.device
re = res.to(device)
return re


def eye_op(
n, m=None, device: Union[str, flow.device] = "cpu", requires_grad: bool = False,
):
"""This operator creates a 2-D Tensor with ones on the diagonal and zeros elsewhere.
Args:
n (int): the number of rows.
m (Optional[int], optional): the number of colums with default being n. Defaults to None.
Keyword args:
device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor.
requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`.
Returns:
oneflow.Tensor: The result Blob with ones on the diagonal and zeros elsewhere.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> out = flow.eye(3, 3)
>>> out
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=oneflow.float32)
"""
return Eye(n, m, device, requires_grad)()


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
66 changes: 66 additions & 0 deletions python/oneflow/test/modules/test_eye.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np
from automated_test_util import *
from test_util import GenArgList

import oneflow as flow


def _test_eye_forward(test_case, device, n, m):
output = flow.eye(n, m, device=device)
np_out = np.eye(n, m)
test_case.assertTrue(np.array_equal(output.numpy(), np_out))


def _test_eye_backward(test_case, device, n, m):
x = flow.eye(n, m, device=device)
x.requires_grad = True
y = x.sum()
y.backward()
test_case.assertTrue(np.array_equal(x.grad.numpy(), np.ones([n, m])))


@flow.unittest.skip_unless_1n1d()
class TestEye(flow.unittest.TestCase):
def test_eye(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_eye_forward,
_test_eye_backward,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["n"] = [4, 3, 2]
arg_dict["m"] = [4, 3, 2]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])

@autotest()
def test_eye_with_random_data(test_case):
n = random().to(int)
m = random().to(int)
x = torch.eye(n=n, m=m)
device = random_device()
x.to(device)
x = random_pytorch_tensor().to(device)
return x


if __name__ == "__main__":
unittest.main()

0 comments on commit e81dafc

Please sign in to comment.