# Batch Normalization

## definition
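Batch normalization is a method for reducing internal covariate shift: the inputs to each layer are normalized using the mean and variance of the current mini-batch, then rescaled and shifted by learnable parameters.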

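\begin{equation}
\begin{aligned}
\mu_{B} &\leftarrow \frac{1}{m} \sum_{i=1}^{m} x_i \\
\sigma_{B}^{2} &\leftarrow \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{B})^{2}
\end{aligned}
\end{equation}

\begin{equation*}
    \hat{x}_i \leftarrow \frac{x_i - \mu_{B}}{\sqrt{\sigma_{B}^{2} + \epsilon}}, \qquad
    y_i \leftarrow \gamma \hat{x}_i + \beta
\end{equation*}

Here $m$ is the mini-batch size, $\epsilon$ is a small constant for numerical stability, and $\gamma$, $\beta$ are the learnable scale and shift.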

In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pylab as plt

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Fix the random seeds for reproducibility (CPU generator first, then all GPUs).
torch.manual_seed(1)
if device == 'cuda':
    torch.cuda.manual_seed_all(1)

In [None]:
# Training hyperparameters
batch_size = 32  # samples per mini-batch
epochs = 10      # passes over the training set
lr = 0.01        # learning rate handed to the optimizers

## Prepare dataset
#### download

In [None]:
# MNIST train/test splits, stored under ../_datasets/ and downloaded when the
# files are not already there.
#
# transforms.ToTensor() reorders each sample from image layout to tensor layout:
#   ordinary image : values 0~255, order Height, Width, Channel
#   converted image: values 0~1,   order Channel, Height, Width
mnist_train = dsets.MNIST(
    root="../_datasets/",             # location of the files
    train=True,                       # use the training split
    transform=transforms.ToTensor(),
    download=True,                    # fetch the data if it is missing
)
mnist_test = dsets.MNIST(
    root="../_datasets/",
    train=False,                      # use the test split
    transform=transforms.ToTensor(),
    download=True,
)

#### iterator

In [None]:
# Mini-batch iterators over the two MNIST splits.
# Training batches are reshuffled every epoch; test batches keep dataset order.
# drop_last=True discards a final batch that is smaller than batch_size.
train_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

test_loader = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
)

## Build the model

In [None]:
class ClassifierWithBN(nn.Module):
    """Three-layer fully connected classifier (784 -> 32 -> 32 -> 10).

    Args:
        use_bn: When truthy, a ``BatchNorm1d(32)`` is inserted between each
            hidden ``Linear`` layer and its ``ReLU``. Otherwise the hidden
            layers are plain ``Linear`` + ``ReLU``.

    Every layer is moved to the module-level ``device`` on construction.
    """

    def __init__(self, use_bn=True):
        super().__init__()

        # Both hidden layers share one recipe; only the input width differs.
        self.layer1 = self._hidden_block(784, 32, use_bn)
        self.layer2 = self._hidden_block(32, 32, use_bn)

        # Output layer: one raw score per class, no activation.
        self.layer3 = nn.Linear(32, 10, bias=True).to(device)

    @staticmethod
    def _hidden_block(in_features, out_features, use_bn):
        """Build Linear -> [BatchNorm1d] -> ReLU as a Sequential on ``device``."""
        modules = [nn.Linear(in_features, out_features, bias=True)]
        if use_bn:
            modules.append(nn.BatchNorm1d(out_features))
        modules.append(nn.ReLU())
        return nn.Sequential(*modules).to(device)

    def forward(self, x):
        """Pass ``x`` through the three layers and return the 10 class scores.

        Args:
            x: Input whose last dimension is 784 (the first layer is
                ``Linear(784, 32)``).

        Returns:
            The output of the final ``Linear(32, 10)`` layer.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

In [None]:
# Two otherwise identical classifiers: one with batch norm, one without.
bn_model = ClassifierWithBN(use_bn=True)
nn_model = ClassifierWithBN(use_bn=False)

# One shared loss function; each model gets its own Adam optimizer with the
# same learning rate.
criterion = torch.nn.CrossEntropyLoss().to(device)
bn_optimizer, nn_optimizer = (
    torch.optim.Adam(model.parameters(), lr=lr)
    for model in (bn_model, nn_model)
)

In [None]:
# Per-epoch history of losses and accuracies.
train_losses = []
valid_losses = []
train_accs = []
valid_accs = []

# Number of mini-batches each loader yields in one pass.
train_total_batch, test_total_batch = len(train_loader), len(test_loader)

In [None]:


# Train both classifiers on the same mini-batches, then score both on the
# held-out test loader once per epoch.
for epoch in range(epochs):
    bn_model.train()
    nn_model.train()

    for x, y in train_loader:
        # Flatten each 28x28 image into a 784-vector for the Linear layers.
        x = x.view(-1, 28 * 28).to(device)
        y = y.to(device)

        # Batch-norm model update.
        bn_optimizer.zero_grad()
        bn_loss = criterion(bn_model(x), y)
        bn_loss.backward()
        bn_optimizer.step()

        # Plain model update: it must run its own forward pass so that the
        # gradients reach nn_model's parameters.
        nn_optimizer.zero_grad()
        nn_loss = criterion(nn_model(x), y)
        nn_loss.backward()
        nn_optimizer.step()

    # eval() makes BatchNorm use its running statistics instead of batch statistics.
    bn_model.eval()
    nn_model.eval()

    # Fresh accumulators every epoch, kept as Python floats so the history
    # lists do not hold on to tensors.
    bn_loss, nn_loss, bn_acc, nn_acc = 0.0, 0.0, 0.0, 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(-1, 28 * 28).to(device)
            y = y.to(device)

            bn_pred = bn_model(x)
            bn_correct_pred = torch.argmax(bn_pred, 1) == y
            bn_loss += criterion(bn_pred, y).item()
            bn_acc += bn_correct_pred.float().mean().item()

            nn_pred = nn_model(x)
            nn_correct_pred = torch.argmax(nn_pred, 1) == y
            nn_loss += criterion(nn_pred, y).item()
            nn_acc += nn_correct_pred.float().mean().item()

    # Average the per-batch values over the number of test batches.
    bn_loss = bn_loss / test_total_batch
    nn_loss = nn_loss / test_total_batch
    bn_acc = bn_acc / test_total_batch
    nn_acc = nn_acc / test_total_batch

    # Save valid losses/accuracies
    valid_losses.append([bn_loss, nn_loss])
    valid_accs.append([bn_acc, nn_acc])
    print('[Epoch %d-VALID] Batchnorm Loss(Acc): bn_loss:%.5f(bn_acc:%.2f) vs No Batchnorm Loss(Acc): nn_loss:%.5f(nn_acc:%.2f)' % (
            (epoch + 1), bn_loss, bn_acc, nn_loss, nn_acc))

print('Learning finished')

In [None]:
# Show the recorded accuracies: one [batch-norm, no-batch-norm] pair per epoch.
valid_accs

In [None]:
# Show the recorded losses: one [batch-norm, no-batch-norm] pair per epoch.
valid_losses