add more vision model compatibility tests (#7942)
* add more vision model compatibility tests

* add more model tests

* auto format by CI

* Update python/oneflow/framework/docstr/tensor.py

Co-authored-by: Yao Chi <later@usopp.net>

* refine

* skip efficientnet

* fix ci

* fix ci

* Fix nn function conv3d ci bug (#7961)

* fix reduce_sum scalar check bug

* fix ci oom bug

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: Yao Chi <later@usopp.net>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
5 people committed Apr 15, 2022
1 parent c8d9657 commit 39cd242
Showing 33 changed files with 8,155 additions and 137 deletions.
16 changes: 16 additions & 0 deletions python/oneflow/framework/docstr/tensor.py
@@ -1569,6 +1569,22 @@
""",
)

add_docstr(
oneflow.Tensor.reshape_as,
"""
Tensor.reshape_as(other) -> Tensor

Returns this tensor reshaped to the same shape as ``other``.
``self.reshape_as(other)`` is equivalent to ``self.reshape(other.size())``.
This method returns a view if ``other.size()`` is compatible with the current
shape. See :func:`oneflow.Tensor.view` for when it is possible to return a view.
See :func:`oneflow.reshape` for more information about reshape.

Args:
    other (oneflow.Tensor): The result tensor has the same shape as ``other``.
""",
)
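
A minimal usage sketch of the documented method (shapes chosen arbitrarily for illustration):

import oneflow as flow

x = flow.randn(2, 3, 4)
y = flow.randn(4, 6)
z = x.reshape_as(y)  # equivalent to x.reshape(4, 6), since y.size() == (4, 6)
assert tuple(z.shape) == (4, 6)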

add_docstr(
oneflow.Tensor.view,
"""
5 changes: 5 additions & 0 deletions python/oneflow/framework/tensor.py
@@ -963,6 +963,10 @@ def _reshape(self, *shape):
return flow._C.reshape(self, new_shape)


def _reshape_as(self, other):
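# Reshape self to the target tensor's shape; other.size() supplies the new shape.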
return _reshape(self, other.size())


def _view(self, *shape):
if len(shape) == 1:
new_shape = shape[0]
@@ -1237,6 +1241,7 @@ def RegisterMethods():
Tensor.le = _le
Tensor.to_local = _to_local
Tensor.reshape = _reshape
Tensor.reshape_as = _reshape_as
Tensor.view = _view
Tensor.sort = _sort
Tensor.type_as = _type_as
2 changes: 1 addition & 1 deletion python/oneflow/nn/functional/__init__.py
@@ -52,7 +52,7 @@
from oneflow._C import threshold
from oneflow._C import silu
from oneflow._C import mish
from oneflow._C import layer_norm
from oneflow.nn.modules.normalization import layer_norm
from oneflow._C import dropout
from oneflow._C import smooth_l1_loss
from oneflow._C import pad
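With this change, flow.nn.functional.layer_norm resolves to the Python implementation added in normalization.py (shown below) rather than the flow._C binding. A minimal call sketch, assuming the signature layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05) from that file:

import oneflow as flow

x = flow.randn(2, 5, 10)
out = flow.nn.functional.layer_norm(x, (10,), weight=flow.ones(10), bias=flow.zeros(10))
assert out.shape == x.shape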
47 changes: 41 additions & 6 deletions python/oneflow/nn/modules/conv.py
@@ -22,6 +22,8 @@
from oneflow.nn.module import Module
from oneflow.nn.modules.utils import _pair, _single, _triple

from typing import Union


def slice(x, begin, size):
ndim = len(x.shape)
@@ -71,6 +73,27 @@ def split(cls, x, axis, split_num):
return result_list


def get_padding(padding, kernel_size, dilation, stride):
valid_padding_strings = {"same", "valid"}
if isinstance(padding, str):
if padding not in valid_padding_strings:
raise ValueError(
"Invalid padding string {!r}, should be one of {}".format(
padding, valid_padding_strings
)
)
if padding == "same" and any(s != 1 for s in list(stride)):
raise ValueError("padding='same' is not supported for strided convolutions")

out_padding = [0] * len(kernel_size)
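# For "same" padding, the per-dimension total is dilation * (kernel_size - 1),
# and half of it (rounded down) is applied to each side; e.g. kernel_size=3,
# dilation=1 gives total_padding=2 and a symmetric pad of 1. Note that i runs
# in reverse, so out_padding is stored in reversed dimension order, mirroring
# the reversed padding layout PyTorch uses for its "same" computation.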
if padding == "same":
for d, k, i in zip(dilation, kernel_size, range(len(kernel_size) - 1, -1, -1)):
total_padding = d * (k - 1)
left_pad = total_padding // 2
out_padding[i] = left_pad
return out_padding


class Conv1d(Module):
"""The interface is consistent with PyTorch.
The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html#conv1d
@@ -168,7 +191,7 @@ def __init__(
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: _size_1_t = 0,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
@@ -179,8 +202,12 @@
self.padding_mode = padding_mode
self.kernel_size = _single(kernel_size)
self.stride = _single(stride)
self.padding = _single(padding)
self.dilation = _single(dilation)
self.padding = (
get_padding(padding, self.kernel_size, self.dilation, self.stride)
if isinstance(padding, str)
else _single(padding)
)
self.groups = groups
self.channel_pos = "channels_first"
assert in_channels % groups == 0
@@ -353,7 +380,7 @@ def __init__(
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: _size_2_t = 0,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
@@ -364,8 +391,12 @@
self.padding_mode = padding_mode
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.padding = (
get_padding(padding, self.kernel_size, self.dilation, self.stride)
if isinstance(padding, str)
else _pair(padding)
)
self.groups = groups

if os.getenv("ONEFLOW_ENABLE_NHWC") == "1":
@@ -535,7 +566,7 @@ def __init__(
out_channels: int,
kernel_size: _size_3_t,
stride: _size_3_t = 1,
padding: _size_3_t = 0,
padding: Union[str, _size_3_t] = 0,
dilation: _size_3_t = 1,
groups: int = 1,
bias: bool = True,
@@ -547,8 +578,12 @@
self.padding_mode = padding_mode
self.kernel_size = _triple(kernel_size)
self.stride = _triple(stride)
self.padding = _triple(padding)
self.dilation = _triple(dilation)
self.padding = (
get_padding(padding, self.kernel_size, self.dilation, self.stride)
if isinstance(padding, str)
else _triple(padding)
)
self.groups = groups
self.channel_pos = "channels_first"
assert in_channels % groups == 0, "in_channels must be divisible by groups"
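A short usage sketch of the new string-valued padding (layer and input sizes arbitrary):

import oneflow as flow
import oneflow.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding="same")
x = flow.randn(1, 3, 32, 32)
y = conv(x)
assert tuple(y.shape) == (1, 8, 32, 32)  # "same" padding preserves spatial size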
129 changes: 64 additions & 65 deletions python/oneflow/nn/modules/normalization.py
@@ -133,6 +133,69 @@ def extra_repr(self) -> str:
)


def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
assert len(input.shape) > len(
normalized_shape
), "Input tensor dim must greater than normalized dim!"
begin_norm_axis = len(input.shape) - len(normalized_shape)
begin_params_axis = len(input.shape) - len(normalized_shape)

elementwise_affine = weight is not None and bias is not None

for i in range(0, len(normalized_shape)):
if input.shape[i + begin_params_axis] != normalized_shape[i]:
raise RuntimeError(
f"Given normalized_shape={normalized_shape}, expected input with shape [*, {str(normalized_shape)[1:-1]}], but got input of size {x.shape}"
)

if not input.is_cuda:
reduce_axis = []
for dim in range(len(input.shape)):
if dim >= begin_norm_axis:
reduce_axis.append(dim)
mean = input.mean(dim=reduce_axis, keepdim=True)
variance = input.var(dim=reduce_axis, unbiased=False, keepdim=True)
params_shape = input.shape[begin_params_axis:]
if len(mean.shape) == 1:
nd_params_shape = [1] * len(input.shape)
nd_params_shape[begin_norm_axis] = params_shape[0]
mean = flow.reshape(mean, shape=nd_params_shape)
variance = flow.reshape(variance, nd_params_shape)
if weight is not None and params_shape[0] == weight.nelement():
weight = flow.reshape(weight, shape=nd_params_shape)
if bias is not None and params_shape[0] == bias.nelement():
bias = flow.reshape(bias, shape=nd_params_shape)
elif len(mean.shape) == len(input.shape):
pass
else:
raise ValueError(
"shape of mean and variance should be 1D or has number of axes and x's"
)
variance += eps
normalized = (input - mean) * variance.rsqrt()
if elementwise_affine:
normalized = normalized * weight + bias
return normalized
else:
if elementwise_affine:
res = flow._C.layer_norm_affine(
input,
weight,
bias,
begin_norm_axis=begin_norm_axis,
begin_params_axis=begin_params_axis,
epsilon=eps,
)
else:
res = flow._C.layer_norm(
input,
begin_norm_axis=begin_norm_axis,
begin_params_axis=begin_params_axis,
epsilon=eps,
)
return res


class LayerNorm(Module):
"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
@@ -239,78 +302,14 @@ def __init__(
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
self.begin_norm_axis = 1
self.begin_params_axis = 1

def reset_parameters(self) -> None:
if self.elementwise_affine:
init.ones_(self.weight)
init.zeros_(self.bias)

def forward(self, x):
assert len(x.shape) > len(
self.normalized_shape
), "Input tensor dim must greater than normalized dim!"
self.begin_norm_axis = len(x.shape) - len(self.normalized_shape)
self.begin_params_axis = len(x.shape) - len(self.normalized_shape)

for i in range(0, len(self.normalized_shape)):
if x.shape[i + self.begin_params_axis] != self.normalized_shape[i]:
raise RuntimeError(
f"Given normalized_shape={self.normalized_shape}, expected input with shape [*, {str(self.normalized_shape)[1:-1]}], but got input of size {x.shape}"
)

if not x.is_cuda:
reduce_axis = []
for dim in range(len(x.shape)):
if dim >= self.begin_norm_axis:
reduce_axis.append(dim)
mean = x.mean(dim=reduce_axis, keepdim=True)
variance = x.var(dim=reduce_axis, unbiased=False, keepdim=True)
params_shape = x.shape[self.begin_params_axis :]
weight = self.weight
bias = self.bias
if len(mean.shape) == 1:
nd_params_shape = [1] * len(x.shape)
nd_params_shape[self.begin_norm_axis] = params_shape[0]
mean = flow.reshape(mean, shape=nd_params_shape)
variance = flow.reshape(variance, nd_params_shape)
if (
self.weight is not None
and params_shape[0] == self.weight.nelement()
):
weight = flow.reshape(self.weight, shape=nd_params_shape)
if self.bias is not None and params_shape[0] == self.bias.nelement():
bias = flow.reshape(self.bias, shape=nd_params_shape)
elif len(mean.shape) == len(x.shape):
pass
else:
raise ValueError(
"shape of mean and variance should be 1D or has number of axes and x's"
)
variance += self.eps
normalized = (x - mean) * variance.rsqrt()
if self.elementwise_affine:
normalized = normalized * weight + bias
return normalized
else:
if self.elementwise_affine:
res = flow._C.layer_norm_affine(
x,
self.weight,
self.bias,
begin_norm_axis=self.begin_norm_axis,
begin_params_axis=self.begin_params_axis,
epsilon=self.eps,
)
else:
res = flow._C.layer_norm(
x,
begin_norm_axis=self.begin_norm_axis,
begin_params_axis=self.begin_params_axis,
epsilon=self.eps,
)
return res
return layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

def extra_repr(self) -> str:
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(
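Since the module now delegates to the functional layer_norm above, the two spellings should agree; a minimal CPU-path check (values arbitrary):

import oneflow as flow
import oneflow.nn as nn

x = flow.randn(4, 10)
m = nn.LayerNorm(10)
out_module = m(x)
out_func = flow.nn.functional.layer_norm(x, (10,), m.weight, m.bias, m.eps)

# The CPU branch computes (x - mean) * rsqrt(var + eps) explicitly; with the
# default weight of ones and bias of zeros the manual form should match too.
manual = (x - x.mean(dim=1, keepdim=True)) * (
    x.var(dim=1, unbiased=False, keepdim=True) + m.eps
).rsqrt()
assert (out_module - out_func).abs().max().item() < 1e-6
assert (out_module - manual).abs().max().item() < 1e-5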
9 changes: 0 additions & 9 deletions python/oneflow/test/expensive/pytorch_alexnet.py
@@ -15,18 +15,12 @@
"""
import torch
import torch.nn as nn
from _internally_replaced_utils import load_state_dict_from_url
from typing import Any


__all__ = ["AlexNet", "alexnet"]


model_urls = {
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
}


class AlexNet(nn.Module):
def __init__(self, num_classes: int = 1000) -> None:
super(AlexNet, self).__init__()
@@ -73,7 +67,4 @@ def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> A
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = AlexNet(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls["alexnet"], progress=progress)
model.load_state_dict(state_dict)
return model
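
With the download path removed, this test helper always builds the model with freshly initialized weights; a minimal smoke test in the same spirit (import path hypothetical, batch size arbitrary):

import torch
from pytorch_alexnet import alexnet  # hypothetical import path for this test file

model = alexnet(num_classes=10)
x = torch.randn(2, 3, 224, 224)
logits = model(x)
assert logits.shape == (2, 10)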
(remaining changed files not shown)
