In [1]:
from mxnet import nd
def pure_batch_norm(X, gamma, beta, eps = 1e-5):
    assert len(X.shape) in (2, 4)
    # if dense layer (batch_size x feature)
    if len(X.shape) == 2:
        # EX and DX in every input dimension
        mean = X.mean(axis = 0)
        variance = ((X - mean) ** 2).mean(axis = 0)
    # if conv layer (batch_size x channel x height x width)
    else:
        # EX and DX in every channels
        mean = X.mean(axis = (0, 2, 3), keepdims = True)
        variance = ((X - mean) ** 2).mean(axis = (0, 2, 3), keepdims = True)
        
    print('mean: ', mean, '\nvariance: ', variance, '\nmean shape: ', mean.shape, '\nvariance shape', variance.shape)
    # normalize
    X_hat = (X - mean) / nd.sqrt(variance + eps)
    # 
    return gamma.reshape(mean.shape) * X_hat + beta.reshape(mean.shape)

  import OpenSSL.SSL


In [2]:
A = nd.arange(6).reshape((3, 2))
A


[[ 0.  1.]
 [ 2.  3.]
 [ 4.  5.]]
<NDArray 3x2 @cpu(0)>

In [3]:
pure_batch_norm(A, gamma = nd.array([1, 1]), beta = nd.array([0, 0]))

mean:  
[ 2.  3.]
<NDArray 2 @cpu(0)> 
variance:  
[ 2.66666675  2.66666675]
<NDArray 2 @cpu(0)> 
mean shape:  (2,) 
variance shape (2,)



[[-1.22474265 -1.22474265]
 [ 0.          0.        ]
 [ 1.22474265  1.22474265]]
<NDArray 3x2 @cpu(0)>

In [4]:
restore = pure_batch_norm(A, gamma = nd.sqrt(nd.array([2.66666675, 2.66666675])), beta = nd.array([2, 3]))
restore

mean:  
[ 2.  3.]
<NDArray 2 @cpu(0)> 
variance:  
[ 2.66666675  2.66666675]
<NDArray 2 @cpu(0)> 
mean shape:  (2,) 
variance shape (2,)



[[  3.57627869e-06   1.00000358e+00]
 [  2.00000000e+00   3.00000000e+00]
 [  3.99999642e+00   4.99999619e+00]]
<NDArray 3x2 @cpu(0)>

In [5]:
nd.array([1, 2]) * nd.array([[1, 2],[3, 4], [5, 6]]) + nd.array([2, 3])


[[  3.   7.]
 [  5.  11.]
 [  7.  15.]]
<NDArray 3x2 @cpu(0)>

In [6]:
B = nd.arange(54).reshape((2, 3, 3, 3))
B


[[[[  0.   1.   2.]
   [  3.   4.   5.]
   [  6.   7.   8.]]

  [[  9.  10.  11.]
   [ 12.  13.  14.]
   [ 15.  16.  17.]]

  [[ 18.  19.  20.]
   [ 21.  22.  23.]
   [ 24.  25.  26.]]]


 [[[ 27.  28.  29.]
   [ 30.  31.  32.]
   [ 33.  34.  35.]]

  [[ 36.  37.  38.]
   [ 39.  40.  41.]
   [ 42.  43.  44.]]

  [[ 45.  46.  47.]
   [ 48.  49.  50.]
   [ 51.  52.  53.]]]]
<NDArray 2x3x3x3 @cpu(0)>

In [7]:
pure_batch_norm(B, gamma = nd.array([1, 1, 1]), beta = nd.array([0, 0, 0]))

mean:  
[[[[ 17.5]]

  [[ 26.5]]

  [[ 35.5]]]]
<NDArray 1x3x1x1 @cpu(0)> 
variance:  
[[[[ 188.91667175]]

  [[ 188.91667175]]

  [[ 188.91667175]]]]
<NDArray 1x3x1x1 @cpu(0)> 
mean shape:  (1, 3, 1, 1) 
variance shape (1, 3, 1, 1)



[[[[-1.27321839 -1.20046306 -1.12770772]
   [-1.05495238 -0.98219699 -0.90944171]
   [-0.83668637 -0.76393104 -0.6911757 ]]

  [[-1.27321839 -1.20046306 -1.12770772]
   [-1.05495238 -0.98219699 -0.90944171]
   [-0.83668637 -0.76393104 -0.6911757 ]]

  [[-1.27321839 -1.20046306 -1.12770772]
   [-1.05495238 -0.98219699 -0.90944171]
   [-0.83668637 -0.76393104 -0.6911757 ]]]


 [[[ 0.6911757   0.76393104  0.83668637]
   [ 0.90944171  0.98219699  1.05495238]
   [ 1.12770772  1.20046306  1.27321839]]

  [[ 0.6911757   0.76393104  0.83668637]
   [ 0.90944171  0.98219699  1.05495238]
   [ 1.12770772  1.20046306  1.27321839]]

  [[ 0.6911757   0.76393104  0.83668637]
   [ 0.90944171  0.98219699  1.05495238]
   [ 1.12770772  1.20046306  1.27321839]]]]
<NDArray 2x3x3x3 @cpu(0)>

In [8]:
def batch_norm(X, gamma, beta, is_training, moving_mean, moving_variance, 
               eps = 1e-5, moving_momentum = 0.9):
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        mean = X.mean(axis = 0)
        variance = ((X - mean) ** 2).mean(axis = 0)
    else:
        mean = X.mean(axis = (0, 2, 3), keepdims = True)
        variance = ((X - mean) ** 2).mean(axis = (0, 2, 3),keepdims = True)
    # transform in order to broadcast correct
    moving_mean = moving_mean.reshape(mean.shape)
    moving_variance = moving_variance.reshape(mean.shape)
        
    if is_training:
        X_hat = (X - mean) / nd.sqrt(variance + eps)
        # update global EX and DX
        moving_mean[:] = moving_momentum * moving_mean + (1.0 - moving_momentum) * mean
        moving_variance[:] = moving_momentum * moving_variance + (1.0 - moving_momentum) * variance
    else:
        X_hat = (X - moving_mean) / nd.sqrt(moving_variance + eps)
        
    return gamma.reshape(mean.shape) * X_hat + beta.reshape(mean.shape)

In [9]:
import sys
sys.path.append('..')
import utils
ctx = utils.try_gpu()
ctx

gpu(0)

In [10]:
weight_scale = .01

# output channels = 20, kernel = (5, 5)
c1 = 20
w1 = nd.random.normal(shape = (c1, 1, 5, 5), scale = weight_scale, ctx = ctx)
b1 = nd.zeros(c1, ctx = ctx)

# batch_norm 1
gamma1 = nd.random.normal(shape = c1, scale = weight_scale, ctx = ctx)
beta1 = nd.random.normal(shape = c1, scale = weight_scale, ctx = ctx)
moving_mean1 = nd.zeros(c1, ctx = ctx)
moving_variance1 = nd.zeros(c1, ctx = ctx)

# output channels = 50, kernel = (3, 3)
c2 = 50
w2 = nd.random.normal(shape = (c2, c1, 3, 3), scale = weight_scale, ctx = ctx)
b2 = nd.zeros(shape = c2, ctx = ctx)

# batch_norm 2
gamma2 = nd.random.normal(shape = c2, scale = weight_scale, ctx = ctx)
beta2 = nd.random.normal(shape = c2, scale = weight_scale, ctx = ctx)
moving_mean2 = nd.zeros(c2, ctx =ctx)
moving_variance2 = nd.zeros(c2, ctx = ctx)

# output dims = 128
o3 = 128
w3 = nd.random.normal(shape = (1250, o3), scale = weight_scale, ctx = ctx)
b3 = nd.zeros(o3, ctx = ctx)

# output dims = 10
o4 = 10
w4 = nd.random.normal(shape = (o3, o4), scale = weight_scale, ctx = ctx)
b4 = nd.zeros(o4, ctx = ctx)

# note: the moving_* need not to update
params = [w1, b1, gamma1, beta1, 
         w2, b2, gamma2, beta2, 
         w3, b3, w4, b4]

for param in params:
    param.attach_grad()

In [11]:
def net(X, is_training = False, verbose = False):
    X = X.as_in_context(w1.context)
    # 1st conv
    h1_conv = nd.Convolution(
        data = X, weight = w1, bias = b1, kernel = w1.shape[2:], num_filter = c1)
    # add batch_norm
    h1_bn = batch_norm(h1_conv, gamma1, beta1, is_training, 
                       moving_mean1, moving_variance1)
    h1_activation = nd.relu(h1_bn)
    h1 = nd.Pooling(data = h1_activation, pool_type = 'max', 
                    kernel = (2, 2), stride = (2, 2))
    # 2nd conv
    h2_conv = nd.Convolution(
        data = h1, weight = w2, bias = b2, kernel = w2.shape[2:], num_filter = c2)
    h2_bn = batch_norm(h2_conv, gamma2, beta2, is_training, 
                       moving_mean2, moving_variance2)
    h2_activation = nd.relu(h2_bn)
    h2 = nd.Pooling(data = h2_activation, pool_type = 'max', 
                    kernel = (2, 2), stride = (2, 2))
    h2_flat = nd.flatten(h2)
    # 1st dense
    h3_linear = nd.dot(h2_flat, w3) + b3
    h3 = nd.relu(h3_linear)
    # 2nd dense
    h4_linear = nd.dot(h3, w4) + b4
    if verbose:
        print('1st conv block:', h1.shape)
        print('2nd conv block:', h2.shape)
        print('1st dense:', h3.shape)
        print('2nd dense:', h4_linear.shape)
        print('output:', h4_linear)
    return h4_linear

In [12]:
from mxnet import autograd as ag
from mxnet import gluon

batch_size = 256
train_data, test_data = utils.load_data_fashion_mnist(batch_size)

softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

learning_rate = 0.2

for epoch in range(5):
    train_loss = 0.
    train_acc = 0.
    for data, label in train_data:
        label = label.as_in_context(ctx)
        with ag.record():
            output = net(data, is_training = True, verbose = False)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        utils.SGD(params, learning_rate / batch_size)
        
        train_loss += nd.mean(loss).asscalar()
        train_acc += utils.accuracy(output, label)
    
    test_acc = utils.evaluate_accuracy(test_data, net, ctx)
    print('Epoch %d, Loss: %f, Train acc %f, Test acc %f' % (
            epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))

Epoch 0, Loss: 2.070894, Train acc 0.222990, Test acc 0.625501
Epoch 1, Loss: 0.586744, Train acc 0.773371, Test acc 0.814303
Epoch 2, Loss: 0.419263, Train acc 0.842231, Test acc 0.865184
Epoch 3, Loss: 0.357987, Train acc 0.866536, Test acc 0.869992
Epoch 4, Loss: 0.325181, Train acc 0.878506, Test acc 0.887220


In [13]:
print(moving_mean1)


[  1.38687454e-02  -5.99591732e-01  -2.79458463e-01  -2.35398367e-01
   4.75273356e-02   6.03083931e-02  -1.18171982e-02   1.72615852e-02
   3.60988233e-05   1.10024847e-02   8.05155188e-02   1.93556115e-01
  -4.80096880e-03  -1.68642491e-01   1.85707323e-02   9.81035650e-01
   2.23618627e+00   5.15063852e-03  -6.44020140e-02  -2.01391190e-01]
<NDArray 20 @gpu(0)>
