batchnorm from scratch with autograd gives very different gradient from mx.nd.BatchNorm #12369

Closed · RogerChern opened this issue Aug 27, 2018 · 9 comments

@RogerChern (Author) commented Aug 27, 2018

Description

A batch norm implemented from scratch with autograd gives a very different gradient from mx.nd.BatchNorm. The forward results are OK.

Environment info (Required)

macOS 10.13 with the MXNet 1.2.1 CPU build from pip.

Package used (Python/R/Scala/Julia):
Python


Minimum reproducible example

import mxnet as mx


def batch_norm_nd(x, gamma, beta, eps=1e-5):
    mean = mx.nd.mean(x, axis=(0, 2, 3), keepdims=True)
    var = mx.nd.mean((x - mean) ** 2, axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / mx.nd.sqrt(var + eps)

    return x_hat * gamma + beta

if __name__ == "__main__":
    x = mx.nd.random_uniform(low=1, high=2, shape=(2, 16, 4, 4))
    gamma = mx.nd.ones(shape=(1, 16, 1, 1))
    beta = mx.nd.zeros(shape=(1, 16, 1, 1))
    mmean = mx.nd.zeros(shape=(1, 16, 1, 1))
    mvar = mx.nd.zeros(shape=(1, 16, 1, 1))
    x.attach_grad()
    gamma.attach_grad()
    beta.attach_grad()

    with mx.autograd.record(train_mode=True):
        y = mx.nd.BatchNorm(x, gamma, beta, mmean, mvar, fix_gamma=False, use_global_stats=False)
    y.backward(mx.nd.ones_like(y))
    y2 = y.copy()
    x2_grad = x.grad.copy()

    with mx.autograd.record(train_mode=True):
        y = batch_norm_nd(x, gamma, beta)
    y.backward(mx.nd.ones_like(y))
    y1 = y.copy()
    x1_grad = x.grad.copy()

    print((y2 / y1)[0, 1])
    print((x2_grad / x1_grad)[0, 1])

results:

[[0.99354386 0.9935453  0.993546   0.9935485 ]
 [0.99354345 0.9935435  0.993581   0.9935487 ]
 [0.9935372  0.99354607 0.9935438  0.9935436 ]
 [0.9935449  0.9935456  0.993545   0.9935423 ]]
<NDArray 4x4 @cpu(0)>

[[-3.6692393 -3.6692448 -3.669247  -3.669256 ]
 [-3.6692376 -3.6692383 -3.6693766 -3.6692567]
 [-3.6692145 -3.6692476 -3.669239  -3.6692383]
 [-3.669243  -3.6692457 -3.6692433 -3.6692333]]
<NDArray 4x4 @cpu(0)>


@RogerChern (Author) commented Aug 27, 2018
This comment has been minimized.

@ankkhedia (Contributor) commented Aug 27, 2018

@mxnet-label-bot [Bug, NDArray]

@anirudhacharya (Contributor) commented Jun 11, 2019

might be related - #14710

@piyushghai (Contributor) commented Jun 11, 2019

I don't reckon this is related to the tagged issue, Anirudh.
In this case, the MRE provided runs imperative NDArray operations in both cases.

@piyushghai (Contributor) commented Jun 11, 2019

I posted this question on the MXNet Discuss Forum as well to get a wider audience. https://discuss.mxnet.io/t/grads-from-batchnorm-implemented-from-scratch-different-from-mx-nd-batchnorm/4167

@thomelane (Contributor) commented Jun 12, 2019

The gradients in this example are tiny (smaller than float epsilon), so I think the variance in the ratio is to be expected here.
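
To see why: an all-ones head gradient corresponds to the loss sum(y), and the sum of a batch-normalized output over the normalization axes reduces to a constant (N * beta per channel), so the analytical gradient with respect to x is exactly zero. Both implementations therefore return nothing but float32 rounding noise, and the elementwise ratio of two noise tensors can land anywhere. A minimal sketch (same setup as the MRE; only the printed quantity changes, and the comments are illustrative) that compares absolute magnitudes instead of ratios:

import mxnet as mx


def batch_norm_nd(x, gamma, beta, eps=1e-5):
    # Same from-scratch implementation as in the MRE.
    mean = mx.nd.mean(x, axis=(0, 2, 3), keepdims=True)
    var = mx.nd.mean((x - mean) ** 2, axis=(0, 2, 3), keepdims=True)
    return (x - mean) / mx.nd.sqrt(var + eps) * gamma + beta


x = mx.nd.random_uniform(low=1, high=2, shape=(2, 16, 4, 4))
gamma = mx.nd.ones(shape=(1, 16, 1, 1))
beta = mx.nd.zeros(shape=(1, 16, 1, 1))
mmean = mx.nd.zeros(shape=(1, 16, 1, 1))
mvar = mx.nd.zeros(shape=(1, 16, 1, 1))
x.attach_grad()

# From-scratch version, all-ones head gradient.
with mx.autograd.record(train_mode=True):
    y = batch_norm_nd(x, gamma, beta)
y.backward(mx.nd.ones_like(y))
print(mx.nd.max(mx.nd.abs(x.grad)))  # tiny: pure rounding noise

# Native operator, same head gradient (x.grad gets overwritten).
with mx.autograd.record(train_mode=True):
    y = mx.nd.BatchNorm(x, gamma, beta, mmean, mvar,
                        fix_gamma=False, use_global_stats=False)
y.backward(mx.nd.ones_like(y))
print(mx.nd.max(mx.nd.abs(x.grad)))  # also tiny, but a different noise pattern

Both printed values should come out many orders of magnitude below the forward activations, which is consistent with the gradient ratios above being noise rather than a real discrepancy.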

@NRauschmayr (Contributor) commented Jun 12, 2019

Also, in your example you set eps = 1e-5, but mx.nd.BatchNorm uses a default value of 1e-3:

DMLC_DECLARE_FIELD(eps).set_default(1e-3f)
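
A quick way to confirm this is to make the two eps values agree; the forward ratio of ~0.9935 above is consistent with the two implementations dividing by sqrt(var + 1e-3) and sqrt(var + 1e-5) respectively. A minimal sketch, assuming batch_norm_nd and the x / gamma / beta / mmean / mvar arrays from the MRE above (the y_scratch / y_native names are just for illustration):

# train_mode=True so the operator normalizes with batch statistics,
# as in the original snippet.
with mx.autograd.record(train_mode=True):
    # Give the from-scratch version the operator's default eps (1e-3) ...
    y_scratch = batch_norm_nd(x, gamma, beta, eps=1e-3)
    # ... and leave the native operator at its default.
    y_native = mx.nd.BatchNorm(x, gamma, beta, mmean, mvar,
                               fix_gamma=False, use_global_stats=False)

print((y_native / y_scratch)[0, 1])  # elementwise ratio should now be ~1

Alternatively, keep eps=1e-5 in batch_norm_nd and pass eps=1e-5 explicitly to mx.nd.BatchNorm, as the updated snippet further down does.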

@piyushghai (Contributor) commented Jun 13, 2019

Going by the above comments, I don't think this is a bug in the implementation of BatchNorm or in autograd's differentiation module, so there's nothing to be fixed here.

@mxnet-label-bot Update [Question, NDArray, Autograd]

@RogerChern Can this issue be closed?
Please feel free to re-open it if it has been closed in error.

Thanks!

@marcoabreu added the Autograd and Question labels and removed the Bug label Jun 13, 2019

@RogerChern (Author) commented Jun 13, 2019

Cool, I now get the correct result with the following snippet.

import mxnet as mx


def batch_norm_nd(x, gamma, beta, eps=1e-5):
    mean = mx.nd.mean(x, axis=(0, 2, 3), keepdims=True)
    var = mx.nd.mean((x - mean) ** 2, axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / mx.nd.sqrt(var + eps)

    return x_hat * gamma + beta

if __name__ == "__main__":
    x1 = mx.nd.random_normal(0.3, 2, shape=(2, 16, 32, 32))
    x2 = x1.copy()
    gamma = mx.nd.ones(shape=(1, 16, 1, 1))
    beta = mx.nd.zeros(shape=(1, 16, 1, 1))
    mmean = mx.nd.zeros(shape=(1, 16, 1, 1))
    mvar = mx.nd.ones(shape=(1, 16, 1, 1))
    x1.attach_grad()
    x2.attach_grad()
    gamma.attach_grad()
    beta.attach_grad()

    grad = mx.nd.random_normal(0, 1, shape=(2, 16, 32, 32))
    with mx.autograd.record(train_mode=True):
        y1 = batch_norm_nd(x1, gamma, beta)
    y1.backward(grad)

    with mx.autograd.record(train_mode=True):
        y2 = mx.nd.BatchNorm(x2, gamma, beta, mmean, mvar, fix_gamma=False, use_global_stats=False, eps=1e-5)
    y2.backward(grad)

    print("--------------------autograd grad scale----------------------")
    print(x1.grad[0, 1])
    print("\n\n")

    print("--------------------forward native/autograd----------------------")
    print((y2 / y1)[0, 1])
    print("\n\n")

    print("--------------------backward native/autograd----------------------")
    print((x2.grad / x1.grad)[0, 1])
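
The key changes relative to the original MRE: eps=1e-5 is now passed explicitly to mx.nd.BatchNorm, and the head gradient is a random tensor instead of all ones, so the input gradient is no longer analytically zero and the forward/backward ratios are meaningful.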

@RogerChern closed this Jun 13, 2019
