Refactor: remove modules not required, call functions directly (#5754)
* gather, mean, sum

* abs, acos, acosh, slice

* abs, acosh, argwhere, bmm, chunk, concat, slice

* constant, diag, eq, exp, expand, eye, flip, floor, gather_nd, >=

* greater, less, less_equal, masked_fill, masked_select

* asin, asinh, sin, cos, atan

* std and related ops

* matmul, ne, negative, permute, repeat, ceil, expm1

* reshape, round, sign, sinh, squeeze, stack, tile, unsqueeze, where, transpose, triu, etc.

* fix argwhere default arg

* fix squeeze x not found bug

* remove Atanh()(input)

* refine modules(bmm, cast, diag, expm1)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
doombeaker and oneflow-ci-bot committed Aug 8, 2021
1 parent 9f8522f commit fc8d535
Showing 45 changed files with 637 additions and 1,299 deletions.
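
Every file in this commit follows the same edit: a stateless nn.Module subclass whose forward() only forwarded to the functional API is deleted, and the exported op function calls flow.F.* directly. A minimal before/after sketch of that pattern, using abs as the example (it mirrors the first diff below; schematic, not the full file contents):

import oneflow as flow
from oneflow.nn.module import Module


# Before: every call constructs a throwaway Module just to reach flow.F.abs.
class Abs(Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return flow.F.abs(x)


def abs_op_before(x):
    return Abs()(x)


# After: the op function calls the functional API directly; no Module object is created.
def abs_op_after(input):
    return flow.F.abs(input)
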
13 changes: 2 additions & 11 deletions python/oneflow/nn/modules/abs.py
@@ -15,19 +15,10 @@
 """
 import oneflow as flow
 from oneflow.framework.tensor import register_tensor_op
-from oneflow.nn.module import Module
-
-
-class Abs(Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        return flow.F.abs(x)
 
 
 @register_tensor_op("abs")
-def abs_op(x):
+def abs_op(input):
     """Return the absolute value of each element in input tensor:math:`y = |x|` element-wise.
 
     Args:
@@ -45,7 +36,7 @@ def abs_op(x):
         tensor([1., 2., 3., 4.], dtype=oneflow.float32)
     """
-    return Abs()(x)
+    return flow.F.abs(input)
 
 
 if __name__ == "__main__":
13 changes: 2 additions & 11 deletions python/oneflow/nn/modules/acos.py
@@ -15,19 +15,10 @@
 """
 import oneflow as flow
 from oneflow.framework.tensor import register_tensor_op
-from oneflow.nn.module import Module
-
-
-class Acos(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x):
-        return flow.F.acos(x)
 
 
 @register_tensor_op("acos")
-def acos_op(tensor):
+def acos_op(input):
     """
     Returns a new tensor with the inverse cosine of the elements of :attr:`input`.
 
@@ -51,7 +42,7 @@ def acos_op(tensor):
         tensor([1.0472, 0.9273, 0.7954], dtype=oneflow.float32)
     """
-    return Acos()(tensor)
+    return flow.F.acos(input)
 
 
 if __name__ == "__main__":
25 changes: 8 additions & 17 deletions python/oneflow/nn/modules/acosh.py
@@ -15,18 +15,9 @@
 """
 import oneflow as flow
 from oneflow.framework.tensor import register_tensor_op
-from oneflow.nn.module import Module
-
-
-class Acosh(Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        return flow.F.acosh(x)
-
-
-def acosh_op(x):
+def acosh_op(input):
     """Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`.
 
     .. math::
@@ -52,40 +43,40 @@ def acosh_op(x):
         tensor([0.9624, 1.6094, 1.9827], device='cuda:0', dtype=oneflow.float32)
     """
-    return Acosh()(x)
+    return flow.F.acosh(input)
 
 
 @register_tensor_op("acosh")
-def acosh_op_tensor(x):
+def acosh_op_tensor(input):
     """
     acosh() -> Tensor
 
     See :func:`oneflow.acosh`
     """
-    return Acosh()(x)
+    return flow.F.acosh(input)
 
 
-def arccosh_op(x):
+def arccosh_op(input):
     """
     See :func:`oneflow.acosh`
     """
-    return Acosh()(x)
+    return flow.F.acosh(input)
 
 
 @register_tensor_op("arccosh")
-def arccosh_op_tensor(x):
+def arccosh_op_tensor(input):
    """
    arccosh() -> Tensor

    See :func:`oneflow.acosh`
    """
-    return Acosh()(x)
+    return flow.F.acosh(input)
 
 
 if __name__ == "__main__":
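
Note the aliasing pattern in acosh.py (and again in atanh.py below): one functional call now backs several exported names, with acosh and arccosh also registered as tensor methods via register_tensor_op. A small usage sketch, assuming an OneFlow build from around this commit and constructing the tensor from a NumPy array as this repo's docstrings do:

import numpy as np
import oneflow as flow

x = flow.Tensor(np.array([1.5, 2.6, 3.7], dtype=np.float32))

# The free function and the registered tensor methods should agree;
# arccosh is simply an alias of acosh.
print(flow.acosh(x))
print(x.acosh())
print(x.arccosh())
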
28 changes: 9 additions & 19 deletions python/oneflow/nn/modules/argwhere.py
@@ -22,26 +22,13 @@
 from oneflow.nn.module import Module
 
 
-class Argwhere(Module):
-    def __init__(self, dtype) -> None:
-        super().__init__()
-        if dtype == None:
-            dtype = flow.int32
-        self.dtype = dtype
-
-    def forward(self, x):
-        (res, size) = flow.F.argwhere(x, dtype=self.dtype)
-        slice_tup_list = [[0, int(size.numpy()), 1]]
-        return flow.slice(res, slice_tup_list=slice_tup_list)
-
-
-def argwhere_op(x, dtype: Optional[flow.dtype] = None):
-    """This operator finds the indices of input Tensor `x` elements that are non-zero.
+def argwhere_op(input, dtype: Optional[flow.dtype] = flow.int32):
+    """This operator finds the indices of input Tensor `input` elements that are non-zero.
     It returns a list in which each element is a coordinate that points to a non-zero element in the condition.
 
     Args:
-        x (oneflow.Tensor): The input Tensor.
+        input (oneflow.Tensor): The input Tensor.
         dtype (Optional[flow.dtype], optional): The data type of output. Defaults to None.
 
     Returns:
@@ -64,19 +51,22 @@ def argwhere_op(x, dtype: Optional[flow.dtype] = None):
                 [1, 2]], dtype=oneflow.int32)
     """
-    return Argwhere(dtype=dtype)(x)
+
+    (res, size) = flow.F.argwhere(input, dtype=dtype)
+    slice_tup_list = [[0, int(size.numpy()), 1]]
+    return flow.slice(res, slice_tup_list=slice_tup_list)
 
 
 @register_tensor_op("argwhere")
-def argwhere_tebsor_op(x, dtype: Optional[flow.dtype] = None):
+def argwhere_tensor_op(input, dtype: Optional[flow.dtype] = flow.int32):
     """
     argwhere() -> Tensor
 
     See :func:`oneflow.argwhere`
     """
-    return Argwhere(dtype=dtype)(x)
+    return argwhere_op(input, dtype)
 
 
 if __name__ == "__main__":
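
The argwhere refactor keeps the one non-trivial piece of the removed class: flow.F.argwhere returns both a result buffer and the actual count of non-zero coordinates, so the result is trimmed with flow.slice before it is returned, and the dtype default moves from None to flow.int32. A usage sketch under the same assumptions as above:

import numpy as np
import oneflow as flow

x = flow.Tensor(np.array([[0, 1, 0],
                          [2, 0, 2]], dtype=np.float32))

# One row per non-zero element; with the new default the indices come back as int32.
coords = flow.argwhere(x)
print(coords)  # coordinates [0, 1], [1, 0], [1, 2]
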
19 changes: 5 additions & 14 deletions python/oneflow/nn/modules/atanh.py
@@ -15,15 +15,6 @@
 """
 import oneflow as flow
 from oneflow.framework.tensor import register_tensor_op
-from oneflow.nn.module import Module
-
-
-class Atanh(Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        return flow.F.atanh(x)
 
 
 def atanh_op(input):
@@ -48,25 +39,25 @@ def atanh_op(input):
         tensor([0.5493, 0.6931, 0.8673], dtype=oneflow.float32)
     """
-    return Atanh()(input)
+    return flow.F.atanh(input)
 
 
 @register_tensor_op("atanh")
-def atanh_op_tensor(x):
+def atanh_op_tensor(input):
     """
     atanh() -> Tensor
 
     See :func:`oneflow.atanh`
     """
-    return Atanh()(x)
+    return flow.F.atanh(input)
 
 
 def arctanh_op(input):
     """
     Alias for :func:`oneflow.atanh`
     """
-    return Atanh()(input)
+    return flow.F.atanh(input)
 
 
 @register_tensor_op("arctanh")
@@ -75,7 +66,7 @@ def arctanh_op_tensor(input):
     Alias for :func:`oneflow.atanh`
     """
-    return Atanh()(input)
+    return flow.F.atanh(input)
 
 
 if __name__ == "__main__":
23 changes: 7 additions & 16 deletions python/oneflow/nn/modules/bmm.py
@@ -15,21 +15,9 @@
 """
 import oneflow as flow
 from oneflow.framework.tensor import register_tensor_op
-from oneflow.nn.module import Module
-
-
-class BMM(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, input, mat2):
-        assert (
-            input.shape[0] == mat2.shape[0] and input.shape[2] == mat2.shape[1]
-        ), f"batch dim or matmul dim not match, please check input!"
-        return flow.F.batch_matmul(input, mat2)
 
 
-def bmm_op(x, y):
+def bmm_op(input, mat2):
     """
     Performs a batch matrix-matrix product of matrices stored in input and mat2.
 
@@ -53,19 +41,22 @@ def bmm_op(x, y):
         >>> of_out.shape
         flow.Size([10, 3, 5])
     """
-    return BMM()(x, y)
+    assert (
+        input.shape[0] == mat2.shape[0] and input.shape[2] == mat2.shape[1]
+    ), f"batch dim or matmul dim not match, please check input!"
+    return flow.F.batch_matmul(input, mat2)
 
 
 @register_tensor_op("bmm")
-def bmm_op_tensor(x, y):
+def bmm_op_tensor(input, mat2):
     """
     bmm() -> Tensor
 
     See :func:`oneflow.bmm`
     """
-    return BMM()(x, y)
+    return flow.F.batch_matmul(input, mat2)
 
 
 if __name__ == "__main__":
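
The shape check that used to live in BMM.forward now sits in bmm_op itself: both operands must share the batch dimension and the inner dimensions must match, i.e. (b, n, m) x (b, m, p) -> (b, n, p). A sketch under the same assumptions as above:

import numpy as np
import oneflow as flow

a = flow.Tensor(np.random.randn(10, 3, 4).astype(np.float32))  # (b, n, m)
b = flow.Tensor(np.random.randn(10, 4, 5).astype(np.float32))  # (b, m, p)

out = flow.bmm(a, b)  # batch matrix-matrix product
print(out.shape)      # flow.Size([10, 3, 5])

# A mismatched inner dimension trips the assert carried over from the old Module:
# flow.bmm(a, flow.Tensor(np.random.randn(10, 3, 5).astype(np.float32)))  # AssertionError
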
12 changes: 1 addition & 11 deletions python/oneflow/nn/modules/cast.py
@@ -15,16 +15,6 @@
 """
 import oneflow as flow
 from oneflow.framework.tensor import register_tensor_op
-from oneflow.nn.module import Module
-
-
-class Cast(Module):
-    def __init__(self, dtype: flow.dtype) -> None:
-        super().__init__()
-        self.dtype = dtype
-
-    def forward(self, x):
-        return flow.F.cast(x, dtype=self.dtype)
 
 
 @register_tensor_op("cast")
@@ -51,7 +41,7 @@ def cast_op(x, dtype):
         True
     """
-    return Cast(dtype)(x)
+    return flow.F.cast(x, dtype)
 
 
 if __name__ == "__main__":
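
cast keeps its signature; only the indirection through the Cast module is gone. A short sketch of the tensor-method form registered by register_tensor_op("cast"), under the same assumptions as above:

import numpy as np
import oneflow as flow

x = flow.Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))

y = x.cast(flow.int32)  # same values, converted dtype
print(y.dtype)
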
