Refactor: remove module not required, call function directly #5754

Merged
merged 30 commits into master from refactor_moudle2functional on Aug 8, 2021
Changes from all commits
30 commits
ed717b3
gather mean, sum
doombeaker Aug 5, 2021
f33658c
abs, acos, acosh, slice
doombeaker Aug 5, 2021
7843f39
abs, acosh, argwhere, bmm, chunk, concat, slice
doombeaker Aug 5, 2021
0aa2122
constant, diag, eq, exp, expand, eye, flip, floor, gather_nd, >=
doombeaker Aug 5, 2021
73823bb
greater, less, less_equal, masked_fill, masked_select
doombeaker Aug 5, 2021
916b3de
asin, asinh, sin, cos, atan
doombeaker Aug 5, 2021
c2ab04c
std and related ops
doombeaker Aug 5, 2021
2ad3bbd
matmul, ne, negative, permute, repeat, ceil, expm1
doombeaker Aug 5, 2021
54e0933
reshape, round, sign, sinh, squeeze, stack, tile etc.
doombeaker Aug 5, 2021
ffd727f
Merge branch 'master' into refactor_moudle2functional
doombeaker Aug 5, 2021
f306147
Merge branch 'master' into refactor_moudle2functional
doombeaker Aug 6, 2021
d7576d9
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 6, 2021
13d51fa
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 6, 2021
f7cba92
fix argwhere default arg
doombeaker Aug 6, 2021
b7b5145
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 6, 2021
eda7bf2
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 6, 2021
e362939
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 6, 2021
93120c2
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 7, 2021
89c7fce
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 7, 2021
97a3db5
fix squeeze x not found bug
doombeaker Aug 7, 2021
1aeedd2
Merge branch 'master' into refactor_moudle2functional
doombeaker Aug 7, 2021
b24ae95
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 7, 2021
85bcd8f
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 7, 2021
544ae0a
Merge branch 'master' into refactor_moudle2functional
doombeaker Aug 8, 2021
0809077
remove Atanh()(input)
doombeaker Aug 8, 2021
7b5341d
refine modules(bmm, cast, diag, expm1)
doombeaker Aug 8, 2021
0b115d2
Merge branch 'master' into refactor_moudle2functional
doombeaker Aug 8, 2021
e7006bc
Merge branch 'master' into refactor_moudle2functional
doombeaker Aug 8, 2021
05e17e6
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 8, 2021
dfc29e0
Merge branch 'master' into refactor_moudle2functional
oneflow-ci-bot Aug 8, 2021
13 changes: 2 additions & 11 deletions python/oneflow/nn/modules/abs.py
@@ -15,19 +15,10 @@
"""
import oneflow as flow
from oneflow.framework.tensor import register_tensor_op
from oneflow.nn.module import Module


class Abs(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return flow.F.abs(x)


@register_tensor_op("abs")
def abs_op(x):
def abs_op(input):
"""Return the absolute value of each element in input tensor:math:`y = |x|` element-wise.

Args:
@@ -45,7 +36,7 @@ def abs_op(x):
tensor([1., 2., 3., 4.], dtype=oneflow.float32)

"""
return Abs()(x)
return flow.F.abs(input)


if __name__ == "__main__":
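
The public surface is unchanged by the rewrite: abs_op keeps its @register_tensor_op("abs") decorator, so both call forms end up in flow.F.abs. A small usage sketch (assuming abs_op is exported as oneflow.abs, in line with the oneflow.* cross-references used in the other files):

import numpy as np
import oneflow as flow

x = flow.Tensor(np.array([-1.0, 2.0, -3.0, 4.0]).astype(np.float32))
y1 = flow.abs(x)   # free-function form, now a direct flow.F.abs call
y2 = x.abs()       # tensor method registered by @register_tensor_op("abs")
# Both return tensor([1., 2., 3., 4.], dtype=oneflow.float32), matching the docstring above.
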
13 changes: 2 additions & 11 deletions python/oneflow/nn/modules/acos.py
@@ -15,19 +15,10 @@
"""
import oneflow as flow
from oneflow.framework.tensor import register_tensor_op
from oneflow.nn.module import Module


class Acos(Module):
def __init__(self) -> None:
super().__init__()

def forward(self, x):
return flow.F.acos(x)


@register_tensor_op("acos")
def acos_op(tensor):
def acos_op(input):
"""
Returns a new tensor with the inverse cosine of the elements of :attr:`input`.

@@ -51,7 +42,7 @@ def acos_op(tensor):
tensor([1.0472, 0.9273, 0.7954], dtype=oneflow.float32)

"""
return Acos()(tensor)
return flow.F.acos(input)


if __name__ == "__main__":
25 changes: 8 additions & 17 deletions python/oneflow/nn/modules/acosh.py
@@ -15,18 +15,9 @@
"""
import oneflow as flow
from oneflow.framework.tensor import register_tensor_op
from oneflow.nn.module import Module


class Acosh(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return flow.F.acosh(x)


def acosh_op(x):
def acosh_op(input):
"""Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`.

.. math::
@@ -52,40 +43,40 @@ def acosh_op(x):
tensor([0.9624, 1.6094, 1.9827], device='cuda:0', dtype=oneflow.float32)

"""
return Acosh()(x)
return flow.F.acosh(input)


@register_tensor_op("acosh")
def acosh_op_tensor(x):
def acosh_op_tensor(input):
"""

acosh() -> Tensor

See :func:`oneflow.acosh`

"""
return Acosh()(x)
return flow.F.acosh(input)


def arccosh_op(x):
def arccosh_op(input):
"""

See :func:`oneflow.acosh`

"""
return Acosh()(x)
return flow.F.acosh(input)


@register_tensor_op("arccosh")
def arccosh_op_tensor(x):
def arccosh_op_tensor(input):
"""

arccosh() -> Tensor

See :func:`oneflow.acosh`

"""
return Acosh()(x)
return flow.F.acosh(input)


if __name__ == "__main__":
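
All four thin wrappers here (acosh_op, acosh_op_tensor, arccosh_op, arccosh_op_tensor) now delegate straight to flow.F.acosh: the arccosh spellings are plain aliases, and the *_tensor variants exist only to attach the tensor methods via register_tensor_op. A brief sketch of the equivalent call sites (the oneflow.acosh / oneflow.arccosh export names are assumed from the "See :func:`oneflow.acosh`" cross-references):

import numpy as np
import oneflow as flow

x = flow.Tensor(np.array([1.5, 2.6, 3.7]).astype(np.float32))
a = flow.acosh(x)     # acosh_op
b = x.acosh()         # acosh_op_tensor, via @register_tensor_op("acosh")
c = flow.arccosh(x)   # arccosh_op, alias of acosh
d = x.arccosh()       # arccosh_op_tensor, via @register_tensor_op("arccosh")
# All four calls reduce to the same flow.F.acosh(input) op.
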
28 changes: 9 additions & 19 deletions python/oneflow/nn/modules/argwhere.py
@@ -22,26 +22,13 @@
from oneflow.nn.module import Module


class Argwhere(Module):
def __init__(self, dtype) -> None:
super().__init__()
if dtype == None:
dtype = flow.int32
self.dtype = dtype

def forward(self, x):
(res, size) = flow.F.argwhere(x, dtype=self.dtype)
slice_tup_list = [[0, int(size.numpy()), 1]]
return flow.slice(res, slice_tup_list=slice_tup_list)


def argwhere_op(x, dtype: Optional[flow.dtype] = None):
"""This operator finds the indices of input Tensor `x` elements that are non-zero.
def argwhere_op(input, dtype: Optional[flow.dtype] = flow.int32):
"""This operator finds the indices of input Tensor `input` elements that are non-zero.

It returns a list in which each element is a coordinate that points to a non-zero element in the condition.

Args:
x (oneflow.Tensor): The input Tensor.
input (oneflow.Tensor): The input Tensor.
dtype (Optional[flow.dtype], optional): The data type of output. Defaults to flow.int32.

Returns:
@@ -64,19 +51,22 @@ def argwhere_op(x, dtype: Optional[flow.dtype] = None):
[1, 2]], dtype=oneflow.int32)

"""
return Argwhere(dtype=dtype)(x)

(res, size) = flow.F.argwhere(input, dtype=dtype)
slice_tup_list = [[0, int(size.numpy()), 1]]
return flow.slice(res, slice_tup_list=slice_tup_list)


@register_tensor_op("argwhere")
def argwhere_tebsor_op(x, dtype: Optional[flow.dtype] = None):
def argwhere_tensor_op(input, dtype: Optional[flow.dtype] = flow.int32):
"""

argwhere() -> Tensor

See :func:`oneflow.argwhere`

"""
return Argwhere(dtype=dtype)(x)
return argwhere_op(input, dtype)


if __name__ == "__main__":
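
Two details are worth noting in this file. First, the default dtype moved out of the deleted Argwhere.__init__ (where None fell back to flow.int32) and into the function signature as flow.int32. Second, flow.F.argwhere returns a (res, size) pair, so the wrapper still slices res down to the size valid rows before returning. A hedged usage sketch (the input values are assumed; the output layout follows the docstring):

import numpy as np
import oneflow as flow

x = flow.Tensor(np.array([[0, 1, 0],
                          [2, 0, 2]]).astype(np.float32))
# flow.F.argwhere yields a fixed-capacity result plus a count; slicing to
# `size` rows means callers only ever see the valid coordinates.
coords = flow.argwhere(x)                  # dtype now defaults to flow.int32
coords64 = x.argwhere(dtype=flow.int64)    # tensor method, explicit dtype
# coords -> [[0, 1], [1, 0], [1, 2]] as an oneflow.int32 tensor.
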
19 changes: 5 additions & 14 deletions python/oneflow/nn/modules/atanh.py
@@ -15,15 +15,6 @@
"""
import oneflow as flow
from oneflow.framework.tensor import register_tensor_op
from oneflow.nn.module import Module


class Atanh(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return flow.F.atanh(x)


def atanh_op(input):
@@ -48,25 +39,25 @@ def atanh_op(input):
tensor([0.5493, 0.6931, 0.8673], dtype=oneflow.float32)

"""
return Atanh()(input)
return flow.F.atanh(input)


@register_tensor_op("atanh")
def atanh_op_tensor(x):
def atanh_op_tensor(input):
"""
atanh() -> Tensor
See :func:`oneflow.atanh`

"""
return Atanh()(x)
return flow.F.atanh(input)


def arctanh_op(input):
"""

Alias for :func:`oneflow.atanh`
"""
return Atanh()(input)
return flow.F.atanh(input)


@register_tensor_op("arctanh")
Expand All @@ -75,7 +66,7 @@ def arctanh_op_tensor(input):

Alias for :func:`oneflow.atanh`
"""
return Atanh()(input)
return flow.F.atanh(input)


if __name__ == "__main__":
23 changes: 7 additions & 16 deletions python/oneflow/nn/modules/bmm.py
@@ -15,21 +15,9 @@
"""
import oneflow as flow
from oneflow.framework.tensor import register_tensor_op
from oneflow.nn.module import Module


class BMM(Module):
def __init__(self) -> None:
super().__init__()

def forward(self, input, mat2):
assert (
input.shape[0] == mat2.shape[0] and input.shape[2] == mat2.shape[1]
), f"batch dim or matmul dim not match, please check input!"
return flow.F.batch_matmul(input, mat2)


def bmm_op(x, y):
def bmm_op(input, mat2):
"""
Performs a batch matrix-matrix product of matrices stored in input and mat2.

@@ -53,19 +41,22 @@ def bmm_op(x, y):
>>> of_out.shape
flow.Size([10, 3, 5])
"""
return BMM()(x, y)
assert (
input.shape[0] == mat2.shape[0] and input.shape[2] == mat2.shape[1]
), f"batch dim or matmul dim not match, please check input!"
return flow.F.batch_matmul(input, mat2)


@register_tensor_op("bmm")
def bmm_op_tensor(x, y):
def bmm_op_tensor(input, mat2):
"""

bmm() -> Tensor

See :func:`oneflow.bmm`

"""
return BMM()(x, y)
return flow.F.batch_matmul(input, mat2)


if __name__ == "__main__":
Expand Down
12 changes: 1 addition & 11 deletions python/oneflow/nn/modules/cast.py
@@ -15,16 +15,6 @@
"""
import oneflow as flow
from oneflow.framework.tensor import register_tensor_op
from oneflow.nn.module import Module


class Cast(Module):
def __init__(self, dtype: flow.dtype) -> None:
super().__init__()
self.dtype = dtype

def forward(self, x):
return flow.F.cast(x, dtype=self.dtype)


@register_tensor_op("cast")
@@ -51,7 +41,7 @@ def cast_op(x, dtype):
True

"""
return Cast(dtype)(x)
return flow.F.cast(x, dtype)


if __name__ == "__main__":
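
cast.py follows the same shape as the other files, except nothing has to stay behind: the dtype that used to be stored on the Cast module is simply passed through to flow.F.cast(x, dtype). A usage sketch (treating oneflow.cast as the free-function export is an assumption; the tensor method comes from @register_tensor_op("cast")):

import numpy as np
import oneflow as flow

np_arr = np.random.randn(2, 3).astype(np.float32)
x = flow.Tensor(np_arr, dtype=flow.float32)
y = x.cast(flow.int8)          # tensor method
z = flow.cast(x, flow.int8)    # assumed free-function spelling
print(np.array_equal(np_arr.astype(np.int8), y.numpy()))  # True, per the docstring
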