Fix bn track running stats #5393

Merged Jul 6, 2021
Changes from all commits

Commits (46):
eb694ab
refine and add test case
Flowingsun007 Jun 10, 2021
813f379
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 10, 2021
0aae06b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 10, 2021
4a38ccd
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 10, 2021
b58b849
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 11, 2021
b5e151a
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 11, 2021
0ff1651
support ellipsis type slice
Flowingsun007 Jun 11, 2021
1276e65
refine
Flowingsun007 Jun 11, 2021
b9066f9
refine
Flowingsun007 Jun 11, 2021
8f81967
support slice assign ellipsis type
Flowingsun007 Jun 11, 2021
8f8cee2
refine
Flowingsun007 Jun 11, 2021
a79dcdf
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 Jun 12, 2021
81a400e
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 12, 2021
e9e2fa4
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 Jun 13, 2021
dbdcf18
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 13, 2021
b11243f
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 Jun 13, 2021
9c7185b
Merge branch 'master' into dev_fix_slice_bug
oneflow-ci-bot Jun 13, 2021
c8c78fb
register fn to localtensor
Flowingsun007 Jun 13, 2021
f565929
Merge branch 'dev_fix_slice_bug' of https://github.com/Oneflow-Inc/on…
Flowingsun007 Jun 13, 2021
ebc25f0
Merge branch 'dev_fix_slice_bug'
Flowingsun007 Jun 13, 2021
475bbff
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 13, 2021
b69f554
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 15, 2021
e8cd9e3
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 16, 2021
a500a6d
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
5387b8f
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
756b0ed
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
34e9fd5
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
a5d67ac
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 21, 2021
e547b4b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 21, 2021
756e537
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 24, 2021
a39271b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 25, 2021
d5ecb51
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 27, 2021
db1b536
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 28, 2021
75cc02b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 28, 2021
634b968
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 29, 2021
6be7d0b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 29, 2021
d1eaabe
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 30, 2021
2ffb4ed
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jul 1, 2021
eca3dd6
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jul 5, 2021
4117078
support bn track_running_stats=False
Flowingsun007 Jul 5, 2021
a82eeb8
remove useless import
Flowingsun007 Jul 5, 2021
ec5cf86
Merge branch 'master' into fix_bn_track_running_stats
Flowingsun007 Jul 5, 2021
5a3921c
Merge branch 'master' into fix_bn_track_running_stats
Flowingsun007 Jul 5, 2021
d05d306
Merge branch 'master' into fix_bn_track_running_stats
Flowingsun007 Jul 5, 2021
dc151b2
Merge branch 'master' into fix_bn_track_running_stats
Flowingsun007 Jul 5, 2021
2d6eaaf
Merge branch 'master' into fix_bn_track_running_stats
Flowingsun007 Jul 6, 2021
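
Before the file diffs, here is a minimal usage sketch of the behavior this PR targets: with `track_running_stats=False` the module keeps no `running_mean`/`running_var` buffers and normalizes with the statistics of the current batch, in both training and eval mode. The constructor arguments follow the diff below (`device` and `dtype` are the newly added parameters); the snippet is illustrative, not part of the PR.

```python
import numpy as np
import oneflow as flow

# BatchNorm2d that never tracks running statistics: no running_mean/running_var
# buffers are registered, so normalization always uses the current batch.
m = flow.nn.BatchNorm2d(
    num_features=2,
    eps=1e-5,
    momentum=0.1,
    track_running_stats=False,
    device=flow.device("cpu"),
)

x = flow.Tensor(np.random.randn(4, 2, 8, 8).astype(np.float32))
y_train = m(x)   # normalized with batch statistics
m.eval()
y_eval = m(x)    # still batch statistics, since there are no running buffers
```

This mirrors the new `_test_batchnorm2d_track_running_stats` test added in the second file below.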
62 changes: 39 additions & 23 deletions oneflow/python/nn/modules/batchnorm.py
@@ -13,11 +13,11 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from typing import Union

import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
import oneflow._oneflow_internal as oneflow_api


class _NormBase(Module):
@@ -30,25 +30,32 @@ def __init__(
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
device: Union[str, flow.device] = None,
dtype: flow.dtype = None,
) -> None:
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
self.device = device
self.dtype = dtype

if self.affine:
self.weight = flow.nn.Parameter(flow.Tensor(num_features))
self.bias = flow.nn.Parameter(flow.Tensor(num_features))
self.weight = flow.nn.Parameter(
flow.Tensor(num_features, device=self.device)
)
self.bias = flow.nn.Parameter(flow.Tensor(num_features, device=self.device))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
if self.track_running_stats:
self.register_buffer(
"running_mean", flow.Tensor(num_features),
"running_mean", flow.Tensor(num_features, device=self.device),
)
self.register_buffer(
"running_var", flow.Tensor(num_features),
"running_var", flow.Tensor(num_features, device=self.device),
)
else:
self.register_parameter("running_mean", None)
@@ -99,40 +106,48 @@ def __init__(
momentum=0.1,
affine=True,
track_running_stats=True,
device=None,
dtype=None,
):
super().__init__(num_features, eps, momentum, affine, track_running_stats)
super().__init__(
num_features, eps, momentum, affine, track_running_stats, device, dtype
)

def forward(self, x):
if self.dtype is None:
self.dtype = x.dtype
if self.device is None:
self.device = x.device

self._check_input_dim(x)
reduce_axis = []
for dim in range(len(x.shape)):
if dim != 1:
reduce_axis.append(dim)
mean = x.mean(dim=reduce_axis, keepdim=False)
variance = x.var(dim=reduce_axis, keepdim=False)

if x.device == flow.device("cpu"):
if self.training:
reduce_axis = []
for dim in range(len(x.shape)):
if dim != 1:
reduce_axis.append(dim)
mean = x.mean(dim=reduce_axis, keepdim=False)
variance = x.var(dim=reduce_axis, keepdim=False)

if self.training and self.track_running_stats:
running_mean = (
self.momentum * self.running_mean + (1 - self.momentum) * mean
)
running_var = (
self.momentum * self.running_var + (1 - self.momentum) * variance
)

# update training parameters/buffers
# update training buffers
self.__setattr__("running_mean", flow.Tensor(running_mean))
self.__setattr__("running_var", flow.Tensor(running_var))

else:
mean = self.running_mean
variance = self.running_var
mean = mean if self.running_mean is None else self.running_mean
variance = variance if self.running_var is None else self.running_var

axis = 1
params_shape = [x.shape[axis]]
weight = self.weight
bias = self.bias

if len(mean.shape) == 1:
nd_params_shape = [1] * len(x.shape)
nd_params_shape[axis] = params_shape[0]
@@ -158,20 +173,21 @@ def forward(self, x):
affined = affined * weight
if self.bias:
affined = affined + bias
return affined
return affined.to(dtype=self.dtype)

else:
return flow.F.normalization(
res = flow.F.normalization(
x,
self.running_mean,
self.running_var,
self.running_mean if self.track_running_stats else mean,
self.running_var if self.track_running_stats else variance,
self.weight,
self.bias,
axis=1,
epsilon=self.eps,
momentum=self.momentum,
is_training=self.training,
)
return res.to(dtype=self.dtype, device=self.device)


@oneflow_export("nn.BatchNorm1d")
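
For reference, the CPU branch above reduces over every axis except the channel axis, falls back to the batch statistics when the running buffers are `None` (the `mean if self.running_mean is None else self.running_mean` lines), and then applies the affine parameters. Below is a standalone NumPy sketch of that logic using the textbook batch-norm formula; the helper name and shapes are illustrative only, not code from the PR.

```python
import numpy as np

def batchnorm_reference(x, weight, bias, running_mean=None, running_var=None, eps=1e-5):
    # Reduce over every axis except the channel axis (axis 1), mirroring the
    # reduce_axis loop in forward().
    reduce_axis = tuple(d for d in range(x.ndim) if d != 1)
    mean = x.mean(axis=reduce_axis)
    variance = x.var(axis=reduce_axis)

    # With track_running_stats=False the buffers stay None, so the batch
    # statistics are used; otherwise the stored running statistics win.
    mean = mean if running_mean is None else running_mean
    variance = variance if running_var is None else running_var

    # Broadcast the per-channel statistics and affine parameters to the input rank.
    shape = [1] * x.ndim
    shape[1] = x.shape[1]
    normalized = (x - mean.reshape(shape)) / np.sqrt(variance.reshape(shape) + eps)
    return normalized * weight.reshape(shape) + bias.reshape(shape)

# Example: 4D NCHW input with 2 channels and no running statistics.
x = np.random.randn(2, 2, 3, 4).astype(np.float32)
y = batchnorm_reference(x, np.ones(2, dtype=np.float32), np.zeros(2, dtype=np.float32))
```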
111 changes: 96 additions & 15 deletions oneflow/python/test/modules/test_batchnorm.py
@@ -44,8 +44,8 @@ def _test_batchnorm1d_2d_input(test_case, device):
dtype=np.float32,
)

m = flow.nn.BatchNorm1d(num_features=5, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
m = flow.nn.BatchNorm1d(
num_features=5, eps=1e-5, momentum=0.1, device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
@@ -85,8 +85,8 @@ def _test_batchnorm1d_3d_input(test_case, device):
dtype=np.float32,
)

m = flow.nn.BatchNorm1d(num_features=3, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
m = flow.nn.BatchNorm1d(
num_features=3, eps=1e-5, momentum=0.1, device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
@@ -154,8 +154,85 @@ def _test_batchnorm2d(test_case, device):
dtype=np.float32,
)

m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
m = flow.nn.BatchNorm2d(
num_features=2,
eps=1e-5,
momentum=0.1,
device=flow.device(device),
dtype=flow.float64,
)
x = flow.Tensor(input_arr, device=flow.device(device), dtype=flow.float32)
y = m(x)
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-04, 1e-04))


def _test_batchnorm2d_track_running_stats(test_case, device):
input_arr = np.array(
[
[
[
[-0.8791, 0.2553, 0.7403, -0.2859],
[0.8006, -1.7701, -0.9617, 0.1705],
[0.2842, 1.7825, 0.3365, -0.8525],
],
[
[0.7332, -0.0737, 0.7245, -0.6551],
[1.4461, -0.1827, 0.9737, -2.1571],
[0.4657, 0.7244, 0.3378, 0.1775],
],
],
[
[
[1.8896, 1.8686, 0.1896, 0.9817],
[-0.0671, 1.5569, 1.1449, 0.0086],
[-0.9468, -0.0124, 1.3227, -0.6567],
],
[
[-0.8472, 1.3012, -1.1065, 0.9348],
[1.0346, 1.5703, 0.2419, -0.7048],
[0.6957, -0.4523, -0.8819, 1.0164],
],
],
],
dtype=np.float32,
)

output = np.array(
[
[
[
[-1.1868, -0.0328, 0.4606, -0.5833],
[0.5220, -2.0933, -1.2709, -0.1190],
[-0.0034, 1.5209, 0.0498, -1.1598],
],
[
[0.5601, -0.3231, 0.5505, -0.9595],
[1.3404, -0.4424, 0.8233, -2.6035],
[0.2673, 0.5504, 0.1273, -0.0482],
],
],
[
[
[1.6299, 1.6085, -0.0996, 0.7062],
[-0.3608, 1.2914, 0.8723, -0.2837],
[-1.2557, -0.3051, 1.0531, -0.9606],
],
[
[-1.1698, 1.1818, -1.4536, 0.7807],
[0.8900, 1.4763, 0.0223, -1.0139],
[0.5190, -0.7375, -1.2078, 0.8700],
],
],
],
dtype=np.float32,
)

m = flow.nn.BatchNorm2d(
num_features=2,
eps=1e-5,
momentum=0.1,
track_running_stats=False,
device=flow.device(device),
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
@@ -223,8 +300,8 @@ def _test_batchnorm2d_4d_input(test_case, device):
dtype=np.float32,
)

m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
m = flow.nn.BatchNorm2d(
num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
@@ -292,8 +369,8 @@ def test_batchnorm2d_infer(test_case, device):
dtype=np.float32,
)

m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
m = flow.nn.BatchNorm2d(
num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
)
m.eval()
x = flow.Tensor(input_arr, device=flow.device(device))
@@ -362,8 +439,8 @@ def test_batchnorm2d_infer_4d_input(test_case, device):
dtype=np.float32,
)

m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
m = flow.nn.BatchNorm2d(
num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
)
m.eval()
x = flow.Tensor(input_arr, device=flow.device(device))
@@ -402,8 +479,8 @@ def _test_batchnorm2d_backward(test_case, device):
dtype=np.float32,
)

m = flow.nn.BatchNorm2d(num_features=2, eps=1e-5, momentum=0.1).to(
device=flow.device(device)
m = flow.nn.BatchNorm2d(
num_features=2, eps=1e-5, momentum=0.1, device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device), requires_grad=True)
y = m(x)
@@ -422,10 +499,11 @@ class TestBatchNorm(flow.unittest.TestCase):
def test_batchnorm(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_batchnorm2d,
_test_batchnorm1d_2d_input,
_test_batchnorm1d_3d_input,
_test_batchnorm2d,
_test_batchnorm2d_4d_input,
_test_batchnorm2d_track_running_stats,
test_batchnorm2d_infer,
test_batchnorm2d_infer_4d_input,
_test_batchnorm2d_backward,
@@ -447,12 +525,15 @@ def test_with_random_data(test_case):
"momentum": float,
"affine": bool,
"track_running_stats": bool,
"dtype": str,
"device": flow.device,
},
extra_generators={
"input": random_tensor(ndim=4, dim1=8),
"num_features": constant(8),
"eps": random(1e-6, 1),
"momentum": random(0, 1),
"track_running_stats": constant(True),
},
device=device,
training=training,
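
One closing note on the training path in `batchnorm.py` above: when `track_running_stats=True`, the CPU branch updates the buffers with the exponential moving average written in the diff, where `momentum` weights the stored running value and `1 - momentum` weights the current batch statistic, so the default `momentum=0.1` lets each new batch dominate the update. A short sketch of that update, with illustrative names:

```python
def update_running_stat(running, batch_stat, momentum=0.1):
    # new_running = momentum * old_running + (1 - momentum) * batch_stat,
    # exactly as written in the CPU training branch of forward().
    return momentum * running + (1 - momentum) * batch_stat
```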