In [20]:
import mxnet as mx
from mxnet import autograd,gluon,init,nd
from mxnet.gluon import loss as gloss,nn,data as gdata
import time

# 1. BN层定义
---
- 注意共有情况：
    - gamma & beta：各通道共用一组！！
    - moving_mean & moving_var：这个Batch中所有数据共用！！

In [21]:
def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):
    if not autograd.is_training(): # 如果不是训练，就直接用。
        X_hat = (X-moving_mean) / nd.sqrt(moving_var + eps)
    else: # 如果是训练，分情况讨论。
        assert len(X.shape) in (2,4)
        
        if len(X.shape) == 2:# 如果要处理的并不是图片
            mean = X.mean(axis=0)
            var = (X-mean**2).mean(axis=0)
        else: # 如果要处理的是图片
            mean = X.mean(axis=(0,2,3),keepdims=True)
            var = ((X-mean)**2).mean(axis=(0,2,3),keepdims=True)
        X_hat = (X-mean)/nd.sqrt(var + eps) # 就在这里，对输入X进行标准化
        
        moving_mean = momentum * moving_mean + (1-momentum) * moving_mean
        moving_var = momentum * moving_var + (1-momentum) + moving_var
    
    Y = gamma * X_hat + beta # 然后考虑放大缩小和偏移，计算出经过BN层后的Y
    
    return Y,moving_mean,moving_var

In [22]:
class BatchNorm(nn.Block):
    def __init__(self,num_features,num_dims,**kwargs):
        super(BatchNorm,self).__init__(**kwargs)
        
        # 首先要得到channel的shape,这里就是代表着输入了有多少张图。
        if num_dims == 2:
            shape = (1,num_features)
        else:
            shape = (1,num_features,1,1)
        
        # 然后带着这个shape，去生成 gamme & beta,moving_mean & moving_var
        self.gamma = self.params.get('gamma',shape = shape,init=init.One())
        self.beta = self.params.get('beta',shape = shape,init=init.Zero())
        
        self.moving_mean = nd.zeros(shape)
        self.moving_var = nd.zeros(shape)
            
        # 然后进行前向计算得到moving_mean & moving_var的值
    def forward(self,X): # 这里应该是重写！！原函数的__init__中应该有调用这个函数的句子。
        if self.moving_mean.context != X.context:
            self.moving_mean = self.moving_mean.copyto(X.context)
            self.moving_var = self.moving_var.copyto(X.context)
        
        Y,self.moving_mean,self.moving_var = batch_norm(X,self.gamma.data(),self.beta.data(),
                                                       self.moving_mean,self.moving_var,
                                                       eps=1e-5,momentum=0.9)
        return Y 
            
        

# 2. 读取数据 + 生成iter 
---

In [23]:
batch_size = 256
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)

transformer = gdata.vision.transforms.ToTensor()
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),batch_size=batch_size,shuffle=True)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),batch_size=batch_size,shuffle=False)

# 3. 定义网络
---

In [24]:
net = nn.Sequential()
net.add(nn.Conv2D(6,kernel_size=5), # 注意，输入的不是一位线性数据，而是一张图像
        BatchNorm(6,num_dims=4),# 所以这里的dims是4，（批次，channel,长宽）。注意前面的6和conv的输出通道数要对应。
       nn.Activation('sigmoid'),
       nn.MaxPool2D(pool_size=2,strides=2),
       
       nn.Conv2D(16,kernel_size=5),
       BatchNorm(16,num_dims=4),
       nn.Activation('sigmoid'),
       nn.MaxPool2D(pool_size=2,strides=2),
       
       nn.Dense(120),
       BatchNorm(120,num_dims=2), # 注意这里的num_features也是120，但是在全连层，dims=2！！只有行和列。行是批次。
       nn.Activation('sigmoid'),
       
       nn.Dense(84),
       BatchNorm(84,num_dims=2),
       nn.Activation('sigmoid'),
       
       nn.Dense(10))


In [25]:
# 看一下输出的样子！
X = nd.random.normal(shape=(1,1,28,28))
net.initialize()
for layer in net:
    X = layer(X) # 可以看到，BN层后，数据的形状是无变化的！只是规范化了数据的值。
    print(layer.name,'output shape:\t',X.shape)

conv4 output shape:	 (1, 6, 24, 24)
batchnorm8 output shape:	 (1, 6, 24, 24)
sigmoid8 output shape:	 (1, 6, 24, 24)
pool4 output shape:	 (1, 6, 12, 12)
conv5 output shape:	 (1, 16, 8, 8)
batchnorm9 output shape:	 (1, 16, 8, 8)
sigmoid9 output shape:	 (1, 16, 8, 8)
pool5 output shape:	 (1, 16, 4, 4)
dense6 output shape:	 (1, 120)
batchnorm10 output shape:	 (1, 120)
sigmoid10 output shape:	 (1, 120)
dense7 output shape:	 (1, 84)
batchnorm11 output shape:	 (1, 84)
sigmoid11 output shape:	 (1, 84)
dense8 output shape:	 (1, 10)


# 4. 准确率
---

In [26]:
def evaluate_accuracy(data_iter,net):
    acc_sum,n = nd.array([0]),0
    for X,y in data_iter:
        y_hat = net(X)
        acc_sum += (y_hat.argmax(axis=1) == y.astype('float32')).sum()
        n+=y.size
    return acc_sum.asscalar() / n

# 5. 训练
---

In [27]:
def train(net,train_iter,test_iter,
         batch_size,trainer,num_epochs):
    loss = gloss.SoftmaxCrossEntropyLoss()
    
    for epoch in range(num_epochs):
        # 别忘了这一行~要计算的是每一轮epoch所用的时间。
        train_l_sum,train_acc_sum,n,start = 0.0,0.0,0,time.time()
        for X,y in train_iter:
            with autograd.record():
                y_hat = net(X)
                l = loss(y_hat,y).sum()
            l.backward()
            trainer.step(batch_size)
            
            # 每一轮训练（单位是batch）都要统计这个。
            train_l_sum = l.asscalar()
            train_acc_sum = (y_hat.argmax(axis=1)==y.astype('float32')).sum().asscalar()
            n+=y.size
            
        # 但是test_acc是走完一轮epoch之后才测试的！
        # 也就是，每将数据集全部训练完后才测试net的精准度。
        test_acc = evaluate_accuracy(test_iter,net)
        print('train_l_sum:',train_l_sum / n)
        print('train_acc_sum:',train_acc_sum / n)
        print('test_acc:',test_acc)
        print('time consume:',time.time()-start)

In [28]:
lr,num_epochs,batch_size = 1.0,5,256
net.initialize(force_reinit = True,init = init.Xavier())
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':lr})
train(net,train_iter,test_iter,
     batch_size,trainer,num_epochs)

train_l_sum: nan
train_acc_sum: 0.00011666666666666667
test_acc: 0.1
time consume: 96.77768325805664
train_l_sum: nan
train_acc_sum: 0.00018333333333333334
test_acc: 0.1
time consume: 96.4963047504425
train_l_sum: nan
train_acc_sum: 0.00025
test_acc: 0.1
time consume: 97.957270860672
train_l_sum: nan
train_acc_sum: 0.00025
test_acc: 0.1
time consume: 96.63986539840698
train_l_sum: nan
train_acc_sum: 0.00015
test_acc: 0.1
time consume: 96.45490765571594
