This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Gluon BatchNorm beta=False not working properly #11774

Open
roywei opened this issue Jul 16, 2018 · 3 comments · Fixed by #12625

@roywei
Member

roywei commented Jul 16, 2018

Description
I came here from #10401.
That issue has no reproducible code, so I'm not sure what the use case there is.
While developing the MXNet backend for Keras, I ran into a similar issue.

I used this unit test and was able to reproduce the problem:

Config
machine: mac
mxnet: latest master (pip install mxnet --pre)
keras: 2.2.0 with TensorFlow backend

The purpose of the test:
Create normally distributed random data of shape (1000, 3, 4, 4), with mean 5.0 and standard deviation 10.0.
Use a BatchNorm layer to normalize it so that each channel has mean close to 0 and std close to 1.
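For reference, the per-channel result the test expects can be sketched in plain NumPy. This is a minimal sketch of standard batch normalization with no learned scale (gamma) or shift (beta): each channel on axis 1 is normalized with the mean and variance computed over the batch and spatial axes (0, 2, 3). The helper name `batchnorm_no_affine` and `eps=1e-5` are illustrative assumptions, not taken from either library.

```python
import numpy as np


def batchnorm_no_affine(x, eps=1e-5):
    """Normalize an NCHW array per channel, with no gamma (scale) or beta (center)."""
    # One mean and one variance per channel, taken over batch and spatial axes.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
# Same shape and distribution as the test: mean 5.0, standard deviation 10.0.
x = rng.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
out = batchnorm_no_affine(x)
print("Mean:", out.mean(axis=(0, 2, 3)))
print("Std:", out.std(axis=(0, 2, 3)))
```

Because nothing is learned here, the output depends only on the batch statistics, which is the behavior the Keras run with `center=False, scale=False` is expected to match.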

Reproducible code:
The first part of the code shows the Keras implementation with the TensorFlow backend; the mean is close to 0 and the std is close to 1.
The second part shows the Gluon implementation; when both scale and center are False, it throws an error.

import mxnet as mx
import numpy as np
from keras.layers import normalization
from keras.models import Sequential
from mxnet import autograd
from mxnet import gluon
from mxnet import nd
from numpy.testing import assert_allclose


"""
Keras test
"""
model = Sequential()
norm = normalization.BatchNormalization(center=False, scale=False,
                                        input_shape=(3, 4, 4))
model.add(norm)
model.compile(loss='mse', optimizer='sgd')

# centered on 5.0, standard deviation 10.0 (np.random.normal's scale is the std)
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
model.fit(x, x, epochs=4, verbose=0, batch_size=32)
out = model.predict(x)
print("Results from keras:")
print( "Mean: %s" % np.mean(out, axis=(0, 2, 3)))
print( "Std: %s" % np.std(out, axis=(0, 2, 3)))
assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)



"""
Gluon test
"""
print(mx.__version__)
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
x_nd = nd.array(x)
net = gluon.nn.Sequential()
ctx = mx.cpu()

with net.name_scope():
    net.add(gluon.nn.BatchNorm(scale=False, center=False))
net.collect_params().initialize(mx.init.Normal(sigma=.1), ctx=ctx)
mse = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(), 'sgd')

epochs = 4

for e in range(epochs):
    cumulative_loss = 0
    for i in range(1000):
        data = x_nd[i, :]
        label = x_nd[i, :]
        with autograd.record():
            output = net(data)
            loss = mse(output, label)
        loss.backward(retain_graph=True)
        trainer.step(32)
        cumulative_loss += nd.sum(loss).asscalar()

out = net(x_nd).asnumpy()
print("Results from gluon:")
print( "Mean: %s" % np.mean(out, axis=(0, 2, 3)))
print( "Std: %s" % np.std(out, axis=(0, 2, 3)))
assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
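Note that the final `out = net(x_nd)` call runs outside `autograd.record()`, so Gluon's BatchNorm normalizes with its running (moving) statistics rather than with statistics of the input batch. A minimal NumPy sketch of that inference path follows, assuming MXNet's documented update `moving = moving * momentum + batch_stat * (1 - momentum)` with Gluon's defaults (`momentum=0.9`, `epsilon=1e-5`, running mean initialized to 0 and running variance to 1). The helper names and the mini-batch size of 32 are hypothetical choices for illustration.

```python
import numpy as np

MOMENTUM = 0.9  # assumed: Gluon BatchNorm's default momentum
EPS = 1e-5      # assumed: Gluon BatchNorm's default epsilon


def update_running_stats(running_mean, running_var, batch):
    """Blend per-channel statistics of an NCHW batch into the running values."""
    batch_mean = batch.mean(axis=(0, 2, 3))
    batch_var = batch.var(axis=(0, 2, 3))
    running_mean = running_mean * MOMENTUM + batch_mean * (1 - MOMENTUM)
    running_var = running_var * MOMENTUM + batch_var * (1 - MOMENTUM)
    return running_mean, running_var


def inference_batchnorm(x, running_mean, running_var, beta=0.0):
    """Normalize with running statistics, then add the shift beta (0 when center=False)."""
    mean = running_mean.reshape(1, -1, 1, 1)
    var = running_var.reshape(1, -1, 1, 1)
    return (x - mean) / np.sqrt(var + EPS) + beta


rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))

# Running mean starts at 0 and running variance at 1.
running_mean, running_var = np.zeros(3), np.ones(3)
for epoch in range(4):
    for start in range(0, len(x), 32):  # hypothetical mini-batch size of 32
        running_mean, running_var = update_running_stats(
            running_mean, running_var, x[start:start + 32])

out = inference_batchnorm(x, running_mean, running_var)
print("Mean:", out.mean(axis=(0, 2, 3)))
print("Std:", out.std(axis=(0, 2, 3)))

shifted = inference_batchnorm(x, running_mean, running_var, beta=3.0)
print("Mean with beta=3.0:", shifted.mean(axis=(0, 2, 3)))
```

In this sketch a nonzero `beta` moves every channel mean by exactly that amount and leaves the std unchanged, which resembles the pattern in the `center=True` output reported below (mean around 3, std around 1).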

Error Message:

Traceback (most recent call last):
  File "/Users/lawei/Documents/Notebooks/keras/gluon_batchnorm.py", line 56, in <module>
    loss.backward(retain_graph=True)
  File "/Users/lawei/anaconda3/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 2130, in backward
    ctypes.c_void_p(0)))
  File "/Users/lawei/anaconda3/lib/python3.6/site-packages/mxnet/base.py", line 210, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [16:31:24] src/imperative/imperative.cc:285: Check failed: !AGInfo::IsNone(*i) Cannot differentiate node because it is not in a computational graph. You need to set is_recording to true or use autograd.record() to save computational graphs for backward. If you want to differentiate the same graph twice, you need to pass retain_graph=True to backward.

Stack trace returned 6 entries:
[bt] (0) 0   libmxnet.so                         0x0000000111dfdeb4 libmxnet.so + 20148
[bt] (1) 1   libmxnet.so                         0x0000000111dfdc6f libmxnet.so + 19567
[bt] (2) 2   libmxnet.so                         0x0000000112f69389 MXNDListFree + 548393
[bt] (3) 3   libmxnet.so                         0x0000000112ebe1fd MXAutogradBackwardEx + 893
[bt] (4) 4   libffi.6.dylib                      0x0000000110670884 ffi_call_unix64 + 76
[bt] (5) 5   ???                                 0x00007fff4fcabf40 0x0 + 140734532075328

My Questions:

  1. gluon.nn.BatchNorm does not work with both scale and center (beta) set to False; this needs a fix.
  2. When changing to scale=False, center=True, I was able to get the std close to 1, but the mean is not close to 0 (see the output below). What is the workaround to normalize this data to mean 0 and std 1? (I assume the current BatchNorm implementation is correct and some logic is not handled when beta is None/False.)
  3. How can this be done with the Symbolic API?
Results from keras:
Mean: [-0.00930524  0.00754305  0.0025016 ]
Std: [ 0.99344313  0.99861413  1.00193286]
1.3.0
Results from gluon:
Mean: [ 3.16370535  2.95543957  3.03811169]
Std: [ 1.04609287  0.93967396  0.98776978]
Traceback (most recent call last):
  File "/Users/lawei/Documents/Notebooks/keras/gluon_batchnorm.py", line 65, in <module>
    assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
  File "/Users/lawei/anaconda3/lib/python3.6/site-packages/numpy/testing/utils.py", line 1395, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/Users/lawei/anaconda3/lib/python3.6/site-packages/numpy/testing/utils.py", line 778, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0.1

(mismatch 100.0%)
 x: array([ 3.163705,  2.95544 ,  3.038112], dtype=float32)
 y: array(0.0)
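One possible explanation for question 2, offered as a hypothesis rather than a confirmed diagnosis: with `center=True`, beta is trainable, and the test's L2 loss compares the output against `label`, which is the input itself. Minimizing that loss over beta alone pulls each channel's beta toward that channel's input mean (about 5.0), which would shift the output mean away from 0 while leaving the std near 1. The NumPy sketch below assumes training-mode batch statistics, full-batch gradient steps, and a learning rate of 0.1; none of these values come from the original test.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))

# Training-mode normalization without scale (batch statistics, eps assumed 1e-5).
mean = x.mean(axis=(0, 2, 3), keepdims=True)
std = np.sqrt(x.var(axis=(0, 2, 3), keepdims=True) + 1e-5)
x_hat = (x - mean) / std

# Hypothetical setup: learn only beta (center=True, scale=False) with an
# L2 loss whose target is the input itself, as in the test above.
beta = np.zeros((1, 3, 1, 1))
lr = 0.1  # assumed learning rate
for step in range(200):
    residual = (x_hat + beta) - x                          # output - label
    grad = residual.mean(axis=(0, 2, 3), keepdims=True)    # d(0.5 * mean sq. error)/d(beta)
    beta -= lr * grad

print("Learned beta:", beta.ravel())
print("Closed-form optimum:", (x - x_hat).mean(axis=(0, 2, 3)))
```

If this reading holds, the gap from the Keras result comes from beta fitting the test's target rather than from the normalization step itself; in the Keras test, `center=False` leaves no beta to fit.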
@roywei
Member Author

roywei commented Jul 16, 2018

@sandeep-krishnamurthy, could you help label this issue Gluon and Bug? Thanks!

@sandeep-krishnamurthy
Contributor

Resolving via #12625

@roywei
Member Author

roywei commented Oct 2, 2019

I'm reopening this issue because the fix PR was reverted, and a user reported that this does not work in keras-mxnet (#12789).

I've verified that the reproducible code above still does not work for either keras-mxnet or Gluon.

roywei reopened this Oct 2, 2019