# Batch Normalization

## definition
the method to control covariate shift at each layer 

\begin{equation}
{μ}_{Β} ←  \\
Var(W) = \sqrt{\frac{2} {n_{in} + n_{out}}} \\
\end{equation}

\begin{equation*}
    m = \frac{\sum{}^{} {}}{\sqrt{1-\frac{v^2}{c^2}‭}}
\end{equation*}

In [1]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pylab as plt

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reporductibility
torch.manual_seed(1)
if device == 'cuda':
    torch.cuda.manual_seed_all(1)

In [3]:
# parameters
lr = 0.01
epochs = 10
batch_size = 32

In [4]:
mnist_train = dsets.MNIST(root="../_datasets/", # 파일의 위치
                          train=True, # 학습 데이터 여부
                          transform=transforms.ToTensor(), # 데이터의 순서를 이미지에서 텐서에 맞도록 변환
                                                           # 일반 이미지 : 값 0~255, 순서 Height, Weight, Channel
                                                           # 변환된 이미지 : 값 0~1, 순서 Channel, Height, Weight
                          download=True)                   # 데이터가 없는 경우 다운을 받음
mnist_test = dsets.MNIST(root="../_datasets/", 
                          train=False, 
                          transform=transforms.ToTensor(),
                          download=True)

In [5]:
# dataset loader
train_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True)

test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          drop_last=True)

In [6]:
# nn layers
linear1 = torch.nn.Linear(784, 32, bias = True)
linear2 = torch.nn.Linear(32, 32, bias = True)
linear3 = torch.nn.Linear(32, 10, bias = True)
relu = torch.nn.ReLU()
bn1 = torch.nn.BatchNorm1d(32)
bn2 = torch.nn.BatchNorm1d(32)

nn_linear1 = torch.nn.Linear(784, 32, bias = True)
nn_linear2 = torch.nn.Linear(32, 32, bias = True)
nn_linear3 = torch.nn.Linear(32, 10, bias = True)

In [7]:
# model
bn_model = torch.nn.Sequential(linear1, bn1, relu,
                               linear2, bn2, relu,
                               linear3).to(device)
nn_model = torch.nn.Sequential(nn_linear1, relu,
                               nn_linear2, relu,
                               nn_linear3).to(device)

In [8]:
# cost, loss, optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)
bn_optimizer = torch.optim.Adam(bn_model.parameters(), lr=lr)
nn_optimizer = torch.optim.Adam(nn_model.parameters(), lr=lr)

In [14]:
# Save losses and accuracies every epoch
train_losses = []
train_accs = []

valid_losses = []
valid_accs = []

train_total_batch = len(train_loader)
test_total_batch = len(test_loader)

for epoch in range(epochs+1):
    bn_model.train()
    nn_model.train()
    
    for x, y in train_loader:
        x = x.view(-1, 28*28).to(device)
        y = y.to(device)
        
        bn_optimizer.zero_grad()
        bn_pred = bn_model(x)
        bn_loss = criterion(bn_pred, y) 
        bn_loss.backward()
        bn_optimizer.step()
        
        nn_optimizer.zero_grad()
        nn_pred = bn_model(x)
        nn_loss = criterion(nn_pred, y) 
        nn_loss.backward()
        nn_optimizer.step()
        
    with torch.no_grad():
        bn_model.eval()
        nn_model.eval()
        
        bn_lost, nn_loss, bn_acc, nn_acc = 0, 0, 0, 0
        for i, (x, y) in enumerate(train_loader):
            x = x.view(-1, 28*28).to(device)
            y = y.to(device)
            
            bn_pred = bn_model(x)
            bn_correct_pred = torch.argmax(bn_pred, 1) == y
            bn_loss += criterion(bn_pred, y)
            bn_acc += bn_correct_pred.float().mean()
        
            nn_pred = bn_model(x)
            nn_correct_pred = torch.argmax(bn_pred, 1) == y
            nn_loss += criterion(nn_pred, y)
            nn_acc += bn_correct_pred.float().mean()

        bn_loss = bn_loss / train_total_batch
        nn_loss = nn_loss / train_total_batch
        bn_acc  = bn_loss / train_total_batch
        nn_acc  = nn_loss / train_total_batch
        
        # Save valid losses/accuracies
        valid_losses.append([bn_loss, nn_loss])
        valid_accs.append([bn_acc, nn_acc])
        print('[Epoch %d-VALID] Batchnorm Loss(Acc): bn_loss:%.5f(bn_acc:%.2f) vs No Batchnorm Loss(Acc): nn_loss:%.5f(nn_acc:%.2f)' % (
                (epoch + 1), bn_loss.item(), bn_acc.item(), nn_loss.item(), nn_acc.item()))
        
print('Learning finished')

[Epoch 1-VALID] Batchnorm Loss(Acc): bn_loss:0.04811(bn_acc:0.00) vs No Batchnorm Loss(Acc): nn_loss:0.04808(nn_acc:0.00)
[Epoch 2-VALID] Batchnorm Loss(Acc): bn_loss:0.04370(bn_acc:0.00) vs No Batchnorm Loss(Acc): nn_loss:0.04369(nn_acc:0.00)
[Epoch 3-VALID] Batchnorm Loss(Acc): bn_loss:0.04125(bn_acc:0.00) vs No Batchnorm Loss(Acc): nn_loss:0.04125(nn_acc:0.00)
[Epoch 4-VALID] Batchnorm Loss(Acc): bn_loss:0.04038(bn_acc:0.00) vs No Batchnorm Loss(Acc): nn_loss:0.04035(nn_acc:0.00)
[Epoch 5-VALID] Batchnorm Loss(Acc): bn_loss:0.04072(bn_acc:0.00) vs No Batchnorm Loss(Acc): nn_loss:0.04070(nn_acc:0.00)
[Epoch 6-VALID] Batchnorm Loss(Acc): bn_loss:0.03541(bn_acc:0.00) vs No Batchnorm Loss(Acc): nn_loss:0.03537(nn_acc:0.00)
[Epoch 7-VALID] Batchnorm Loss(Acc): bn_loss:0.03800(bn_acc:0.00) vs No Batchnorm Loss(Acc): nn_loss:0.03796(nn_acc:0.00)
[Epoch 8-VALID] Batchnorm Loss(Acc): bn_loss:0.02959(bn_acc:0.00) vs No Batchnorm Loss(Acc): nn_loss:0.02945(nn_acc:0.00)
[Epoch 9-VALID] Batchnor

In [16]:
valid_accs

[[tensor(2.5657e-05, device='cuda:0'), tensor(2.5643e-05, device='cuda:0')],
 [tensor(2.3306e-05, device='cuda:0'), tensor(2.3303e-05, device='cuda:0')],
 [tensor(2.1999e-05, device='cuda:0'), tensor(2.1998e-05, device='cuda:0')],
 [tensor(2.1535e-05, device='cuda:0'), tensor(2.1519e-05, device='cuda:0')],
 [tensor(2.1716e-05, device='cuda:0'), tensor(2.1708e-05, device='cuda:0')],
 [tensor(1.8883e-05, device='cuda:0'), tensor(1.8866e-05, device='cuda:0')],
 [tensor(2.0266e-05, device='cuda:0'), tensor(2.0246e-05, device='cuda:0')],
 [tensor(1.5780e-05, device='cuda:0'), tensor(1.5709e-05, device='cuda:0')],
 [tensor(1.6397e-05, device='cuda:0'), tensor(1.6386e-05, device='cuda:0')],
 [tensor(1.6963e-05, device='cuda:0'), tensor(1.6962e-05, device='cuda:0')],
 [tensor(1.5599e-05, device='cuda:0'), tensor(1.5580e-05, device='cuda:0')]]

In [17]:
valid_losses

[[tensor(0.0481, device='cuda:0'), tensor(0.0481, device='cuda:0')],
 [tensor(0.0437, device='cuda:0'), tensor(0.0437, device='cuda:0')],
 [tensor(0.0412, device='cuda:0'), tensor(0.0412, device='cuda:0')],
 [tensor(0.0404, device='cuda:0'), tensor(0.0403, device='cuda:0')],
 [tensor(0.0407, device='cuda:0'), tensor(0.0407, device='cuda:0')],
 [tensor(0.0354, device='cuda:0'), tensor(0.0354, device='cuda:0')],
 [tensor(0.0380, device='cuda:0'), tensor(0.0380, device='cuda:0')],
 [tensor(0.0296, device='cuda:0'), tensor(0.0295, device='cuda:0')],
 [tensor(0.0307, device='cuda:0'), tensor(0.0307, device='cuda:0')],
 [tensor(0.0318, device='cuda:0'), tensor(0.0318, device='cuda:0')],
 [tensor(0.0292, device='cuda:0'), tensor(0.0292, device='cuda:0')]]