# Batch Normalization

In [1]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

In [462]:
class Batch_Norm(nn.Module):
    """Batch normalization over the channel dimension of an (N, C, H, W) input.

    Mirrors the training-mode forward pass of ``nn.BatchNorm2d``: every
    channel is normalized with the mean and the *biased* variance taken over
    the batch and both spatial axes. The result is then scaled by ``r`` and
    shifted by ``b``. No running statistics are kept, so the statistics of
    the current batch are always used.

    Args:
        in_channels: number of channels ``C`` of the input.
        eps: constant added to the variance before the square root, for
            numerical stability.
    """

    def __init__(self, in_channels, eps=1e-05):
        super().__init__()
        # Learnable per-channel scale (gamma) and shift (beta), shaped to
        # broadcast against an (N, C, H, W) input.
        self.r = nn.Parameter(torch.ones(1, in_channels, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` per channel and apply the learnable affine transform."""
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        # Batch norm normalizes with the biased (population) variance. The
        # default unbiased estimator of Tensor.var only agrees with
        # nn.BatchNorm2d when the per-channel sample count is very large.
        var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
        return (x - mean) / torch.sqrt(var + self.eps) * self.r + self.b

In [463]:
def checkBatch(shape=(2, 2, 1000, 1000)):
    """Check the batch-norm formula and ``Batch_Norm`` against ``nn.BatchNorm2d``.

    A random input is normalized by ``nn.BatchNorm2d`` (training mode, so
    batch statistics are used). Two comparisons are printed:

    * ``'right'`` / ``'false'``: whether the explicit formula
      ``(x - mean) / sqrt(var + eps)`` matches the library output.
    * ``'you are right'`` / ``'you are false'``: whether ``Batch_Norm``
      matches the library output.

    Args:
        shape: (N, C, H, W) shape of the random test input. The default is
            the shape this check has always used.
    """
    x = torch.randn(*shape)
    channels = shape[1]

    bn = nn.BatchNorm2d(channels, affine=True)
    out = bn(x)

    mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
    # Batch norm normalizes with the biased variance. Using the default
    # unbiased estimator here would only match for very large inputs.
    var = torch.var(x, dim=(0, 2, 3), unbiased=False, keepdim=True)

    mybn = Batch_Norm(channels)
    myout = mybn(x)

    # Reference formula (eps matches the nn.BatchNorm2d default of 1e-5).
    calc = (x - mean) / torch.pow(var + 1e-5, 0.5)
    if np.allclose(calc.detach().numpy(), out.detach().numpy()):
        print('right')
    else:
        print('false')
    if np.allclose(myout.detach().numpy(), out.detach().numpy()):
        print('you are right')
    else:
        print('you are false')

In [464]:
# Run the batch-norm check; it prints its two comparison results.
checkBatch()

right
you are right


# Group Norm

In [9]:
class Group_Norm(nn.Module):
    """Group normalization for an (N, C, H, W) input, matching ``nn.GroupNorm``.

    The C channels are split into ``G`` consecutive groups of ``C // G``
    channels. Each (sample, group) slice is normalized with its own mean and
    biased variance, taken over that group's channels and both spatial axes.
    A per-channel scale ``r`` and shift ``b`` are then applied.

    Args:
        in_channels: number of channels ``C``; must be divisible by ``G``.
        G: number of groups (must be positive).
        eps: constant added to the variance before the square root, for
            numerical stability.

    Raises:
        ValueError: if ``G`` is not positive or does not divide
            ``in_channels``.
    """

    def __init__(self, in_channels, G=2, eps=1e-05):
        super().__init__()
        if G <= 0 or in_channels % G != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by G ({G})"
            )
        # Learnable per-channel scale (gamma) and shift (beta).
        self.r = nn.Parameter(torch.ones(1, in_channels, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
        self.eps = eps
        self.G = G

    def forward(self, x):
        """Normalize ``x`` per (sample, group) and apply the affine transform."""
        N, C, H, W = x.shape
        # Split the channel axis into (G, C // G). reshape also accepts
        # inputs that view() could not handle without a copy.
        grouped = x.reshape(N, self.G, C // self.G, H, W)

        mean = grouped.mean(dim=(2, 3, 4), keepdim=True)
        # Group norm uses the biased (population) variance, like nn.GroupNorm.
        var = grouped.var(dim=(2, 3, 4), unbiased=False, keepdim=True)

        normalized = (grouped - mean) / torch.sqrt(var + self.eps)
        return self.r * normalized.reshape(N, C, H, W) + self.b

In [18]:
def checkGN():
    """Compare Group_Norm (2 groups) with nn.GroupNorm on a random input.

    Prints 'right' when the two outputs agree under np.allclose, otherwise
    'wrong'.
    """
    sample = torch.randn(3, 64, 200, 200)

    expected = nn.GroupNorm(2, 64)(sample).detach().numpy()
    actual = Group_Norm(64)(sample).detach().numpy()

    print('right' if np.allclose(expected, actual) else 'wrong')

In [22]:
# Run the group-norm check; it prints 'right' or 'wrong'.
checkGN()

right
