Try to fix var bug (#5973)
* fix var impl bug

* add more test_case

* refine

* add tensor method test

* fix comment

* auto format by CI

* fix comments

* fix ci error

* fix run ci error

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
3 people committed Aug 21, 2021
1 parent b24fd1b commit fc60e22
Showing 8 changed files with 103 additions and 67 deletions.
2 changes: 1 addition & 1 deletion python/oneflow/nn/modules/batchnorm.py
@@ -112,7 +112,7 @@ def forward(self, x):
if dim != 1:
reduce_axis.append(dim)
mean = x.mean(dim=reduce_axis, keepdim=False)
- variance = x.var(dim=reduce_axis, keepdim=False)
+ variance = x.var(dim=reduce_axis, unbiased=False, keepdim=False)
if self.training and self.track_running_stats:
running_mean = (
self.momentum * self.running_mean + (1 - self.momentum) * mean
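Editor's note: normalization layers normalize with the population (biased) variance of the current batch, so once `var` gained an `unbiased` flag defaulting to `True`, callers such as `BatchNorm` must pass `unbiased=False` to keep their previous behaviour. A minimal NumPy sketch of the distinction (illustrative only, not OneFlow code):

```python
import numpy as np

x = np.random.randn(8, 4).astype(np.float32)   # (batch, features)
mean = x.mean(axis=0)
var_population = x.var(axis=0)                 # divides by N      -> used for normalization
var_bessel = x.var(axis=0, ddof=1)             # divides by N - 1  -> Bessel's correction

x_hat = (x - mean) / np.sqrt(var_population + 1e-5)
print(np.abs(x_hat.mean(axis=0)).max())        # ~0: normalized features are centred
print(var_bessel / var_population)             # = N / (N - 1) = 8/7 per feature
```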
2 changes: 1 addition & 1 deletion python/oneflow/nn/modules/instancenorm.py
@@ -36,7 +36,7 @@ def _forward(self, x):
nd_params_shape = [1] * len(x.shape)
nd_params_shape[axis] = params_shape[0]
mean = x.mean(2, keepdim=True)
- variance = x.var(2, keepdim=True)
+ variance = x.var(2, unbiased=False, keepdim=True)
normalized = (x - mean) / flow.sqrt(variance + self.eps)
if self.weight is not None and params_shape[0] == self.weight.nelement():
weight = flow.reshape(self.weight, shape=nd_params_shape)
25 changes: 15 additions & 10 deletions python/oneflow/nn/modules/math_ops.py
@@ -68,7 +68,7 @@ def _mul(input, other):


@register_tensor_op("var")
- def variance_op(input, dim=None, keepdim=False):
+ def variance_op(input, dim=None, unbiased=True, keepdim=False):
"""Returns the variance of each row of the `input` tensor in the given dimension `dim`.
If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim`
@@ -78,6 +78,7 @@ def variance_op(input, dim=None, keepdim=False):
Args:
input (Tensor): the input tensor.
dim (int or tuple of python:ints): the dimension or dimensions to reduce. Defaults to None.
+ unbiased (bool, optional): whether to use Bessel’s correction (:math:`\delta N = 1`). Defaults to True.
keepdim (bool, optional): whether the output tensor has dim retained or not. Defaults to False.
Returns:
@@ -95,15 +96,19 @@ def variance_op(input, dim=None, keepdim=False):
>>> output = flow.var(input, 1, True)
"""

- axis = _check_axis(dim, input.shape)
- if isinstance(axis, list) and len(axis) == 0:
- return flow.zeros(input.shape)
- else:
- return flow.sub(
- flow.mean(flow.square(input), axis, keepdim),
- flow.square(flow.mean(input, axis, keepdim)),
- )
+ input_shape = input.shape
+ axis = _check_axis(dim, input_shape)
+ input_shape_dim = 1
+ for x in axis:
+ input_shape_dim *= input_shape[x]
+ if unbiased:
+ input_shape_dim -= 1
+ res = flow.sum(
+ flow.square(input - flow.mean(input, dim=axis, keepdim=True)),
+ dim=axis,
+ keepdim=keepdim,
+ )
+ return res / input_shape_dim


@register_tensor_op("sub")
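Editor's note: the removed path computed E[x²] − (E[x])², which is always the biased estimate (the old signature had no `unbiased` argument) and is also prone to cancellation error. The new code computes the sum of squared deviations from the mean over the reduced axes, divided by N, or N − 1 when `unbiased=True`. A NumPy-only sketch of that formula for reference (`var_reference` is a hypothetical helper, not part of OneFlow):

```python
import numpy as np

def var_reference(x, axis, unbiased=True, keepdim=False):
    # Number of reduced elements, minus 1 if Bessel's correction is requested.
    n = int(np.prod([x.shape[a] for a in axis]))
    if unbiased:
        n -= 1
    centered = x - x.mean(axis=tuple(axis), keepdims=True)
    return np.square(centered).sum(axis=tuple(axis), keepdims=keepdim) / n

x = np.random.randn(2, 3, 4)
for unbiased in (True, False):
    expected = x.var(axis=(1, 2), ddof=1 if unbiased else 0)
    assert np.allclose(var_reference(x, axis=[1, 2], unbiased=unbiased), expected)
```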
12 changes: 6 additions & 6 deletions python/oneflow/nn/modules/normalization.py
@@ -115,7 +115,7 @@ def forward(self, input: Tensor) -> Tensor:
input, shape=[origin_shape[0], self.num_groups, -1]
)
mean = flow.mean(reshape_to_1d, dim=2, keepdim=True)
- variance = flow.var(reshape_to_1d, dim=2, keepdim=True)
+ variance = flow.var(reshape_to_1d, dim=2, unbiased=False, keepdim=True)
normalized = (reshape_to_1d - mean) / flow.sqrt(variance + self.eps)
normalized = flow.reshape(
normalized, shape=[origin_shape[0], self.num_channels, -1]
@@ -203,15 +203,15 @@ class LayerNorm(Module):
array([[[[ 0.99997395, -0.99997395],
[-0.999947 , 0.999947 ]],
<BLANKLINE>
- [[-0.9999596 , 0.9999594 ],
+ [[-0.99995965, 0.9999595 ],
[ 0.999988 , -0.999988 ]]],
<BLANKLINE>
<BLANKLINE>
- [[[-0.9998343 , 0.9998341 ],
+ [[[-0.9998348 , 0.99983466],
[ 0.9999914 , -0.9999914 ]],
<BLANKLINE>
- [[ 0.99997866, -0.99997866],
- [ 0.9999646 , -0.9999646 ]]]], dtype=float32)
+ [[ 0.9999785 , -0.9999785 ],
+ [ 0.9999645 , -0.9999645 ]]]], dtype=float32)
"""

@@ -259,7 +259,7 @@ def forward(self, x):
if dim >= self.begin_norm_axis:
reduce_axis.append(dim)
mean = x.mean(dim=reduce_axis, keepdim=True)
- variance = x.var(dim=reduce_axis, keepdim=True)
+ variance = x.var(dim=reduce_axis, unbiased=False, keepdim=True)
axis = self.begin_norm_axis
params_shape = x.shape[self.begin_params_axis :]
weight = self.weight
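Editor's note: the last-digit changes in the LayerNorm doctest output above are consistent with the variance now being computed from mean-centred squared deviations rather than E[x²] − (E[x])²; the two forms are algebraically identical but round differently in float32. A quick illustration (NumPy only, not the docstring's actual input):

```python
import numpy as np

np.random.seed(0)
x = np.random.randn(4, 2, 2, 2).astype(np.float32)

v_old = (np.square(x).mean() - np.square(x.mean())).item()   # E[x^2] - (E[x])^2
v_new = (np.square(x - x.mean()).sum() / x.size).item()      # centred sum of squares / N
print(v_old, v_new, abs(v_old - v_new))  # algebraically equal; may differ in the last float32 digits
```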
48 changes: 0 additions & 48 deletions python/oneflow/test/modules/test_math_ops.py
@@ -25,54 +25,6 @@
import oneflow.unittest


def _test_variance_keepdim(test_case, shape, device):
np_arr = np.random.randn(*shape)
of_out = flow.Tensor(np_arr, device=flow.device(device)).var(0, True)
np_out = np.var(np_arr, 0, keepdims=True)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))


def _test_variance(test_case, shape, device):
np_arr = np.random.randn(*shape)
of_out = flow.var(flow.Tensor(np_arr, device=flow.device(device)), 1, False)
np_out = np.var(np_arr, 1, keepdims=False)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))


def _test_variance_backward(test_case, shape, device):
np_arr = np.array(
[
[
[-0.436214, -1.11672411, 0.78394664, 2.0621712],
[0.7716703, -1.35367316, -0.40694879, -1.72392356],
[-1.08482436, -0.20731248, 1.39633697, 0.32614333],
],
[
[-1.42467297, -1.78418015, 0.17861511, 0.12065858],
[2.03621124, -0.93674042, 0.1943963, 1.98559192],
[-0.00436223, 0.37788105, 0.47820872, 0.15467583],
],
]
)
x = flow.Tensor(np_arr, requires_grad=True, device=flow.device(device))
y = flow.var(x, False)
z = y.sum()
z.backward()
np_grad = 2 * (np_arr - np_arr.mean()) / np_arr.size
test_case.assertTrue(np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05))


@flow.unittest.skip_unless_1n1d()
class TestVariance(flow.unittest.TestCase):
def test_variance(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_variance, _test_variance_keepdim]
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])


@flow.unittest.skip_unless_1n1d()
class TestSinh(flow.unittest.TestCase):
@autotest()
59 changes: 59 additions & 0 deletions python/oneflow/test/modules/test_var.py
@@ -0,0 +1,59 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import oneflow as flow
from oneflow.test_utils.automated_test_util.generators import random
import oneflow.unittest
from automated_test_util import *


class TestVar(flow.unittest.TestCase):
@autotest()
def test_flow_var_all_dim_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
y = torch.var(x)
return y

@autotest()
def test_flow_var_one_dim_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=4).to(device)
y = torch.var(
x,
dim=random(low=0, high=4).to(int),
unbiased=random().to(bool),
keepdim=random().to(bool),
)
return y

@unittest.skip("var does not support 0-shape tensors currently")
@autotest()
def test_flow_var_0d_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(4, 2, 3, 0, 4).to(device)
y = torch.var(
x,
dim=random(low=0, high=4).to(int),
unbiased=random().to(bool),
keepdim=random().to(bool),
)
return y


if __name__ == "__main__":
unittest.main()
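Editor's note: the `@autotest()` harness above runs the same program under PyTorch and OneFlow and compares results, so the tests need no explicit expectations. For a direct look at the fixed operator, a hedged usage sketch (assumes a OneFlow build that includes this commit; semantics are intended to mirror `torch.var`, with `unbiased=True` as the default):

```python
import numpy as np
import oneflow as flow

np_x = np.random.randn(2, 3, 4).astype(np.float32)
x = flow.Tensor(np_x)

y_all = flow.var(x)                                   # all dims, Bessel-corrected
y_dim = x.var(dim=1, unbiased=False, keepdim=True)    # population variance along dim 1

print(np.allclose(y_all.numpy(), np_x.var(ddof=1), atol=1e-5))
print(np.allclose(y_dim.numpy(), np_x.var(axis=1, ddof=0, keepdims=True), atol=1e-5))
```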
18 changes: 18 additions & 0 deletions python/oneflow/test/tensor/test_tensor.py
@@ -881,6 +881,24 @@ def test_floor_tensor_with_random_data(test_case):
y = x.floor()
return y

@autotest()
def test_tensor_var_all_dim_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
y = x.var()
return y

@autotest()
def test_tensor_var_one_dim_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=4).to(device)
y = x.var(
dim=random(low=0, high=4).to(int),
unbiased=random().to(bool),
keepdim=random().to(bool),
)
return y

@flow.unittest.skip_unless_1n1d()
def test_norm_tensor_function(test_case):
input = flow.Tensor(
@@ -322,7 +322,9 @@ def check_tensor_equality(torch_tensor, flow_tensor, rtol=0.0001, atol=1e-05):
), f"OneFlow tensor doesn't have grad while PyTorch tensor has one, PyTorch tensor is\n {torch_tensor}\n, OneFlow tensor is\n{flow_tensor} "
torch_grad = torch_tensor.grad.detach().cpu().numpy()
flow_grad = flow_tensor.grad.numpy()
- if not np.allclose(torch_grad, flow_grad, rtol=rtol, atol=atol):
+ if not np.allclose(
+ torch_grad, flow_grad, rtol=rtol, atol=atol, equal_nan=True,
+ ):
print(
f"Grads are not equal. PyTorch grad: \n{torch_grad}\n, OneFlow grad: \n{flow_grad}"
)
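Editor's note on `equal_nan=True` above: the randomized var tests can legitimately produce NaN in the same positions under both frameworks (for example, `unbiased=True` over a reduced dimension of size 1 divides 0 by 0), and plain `allclose` treats NaN as unequal to itself. An illustrative NumPy-only demonstration:

```python
import numpy as np

torch_grad = np.array([0.5, np.nan, 1.0])
flow_grad = np.array([0.5, np.nan, 1.0])

print(np.allclose(torch_grad, flow_grad))                  # False: NaN != NaN
print(np.allclose(torch_grad, flow_grad, equal_nan=True))  # True: aligned NaNs count as equal
```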
