<a href="https://colab.research.google.com/github/Hongyongmin/Edwith-Pytorch/blob/main/9_4_Batch_Normalization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pylab as plt

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
  device = 'cuda'

# Seed the default generator, and additionally every CUDA device when the
# GPU was selected above.
torch.manual_seed(1)
if device == 'cuda':
  torch.cuda.manual_seed_all(1)

In [None]:
# Hyperparameters used by the cells below: mini-batch size for the data
# loaders, number of passes over the training set, and the Adam step size.
batch_size = 32
training_epochs = 10
learning_rate = 1e-2

In [None]:
# MNIST train and test splits. Both use the same root directory, convert
# images with ToTensor, and pass download=True.
mnist_train = dsets.MNIST(
    root='MNIST_data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

mnist_test = dsets.MNIST(
    root='MNIST_data/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [None]:
# Mini-batch iterators over the two splits. Only the training loader
# shuffles; both drop the final incomplete batch (drop_last=True).
train_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
)

In [None]:
# Layers for the batch-normalised network: 784 -> 32 -> 32 -> 10, with a
# BatchNorm1d over the 32 features after each of the first two linear layers.
linear1 = torch.nn.Linear(in_features=784, out_features=32, bias=True)
linear2 = torch.nn.Linear(in_features=32, out_features=32, bias=True)
linear3 = torch.nn.Linear(in_features=32, out_features=10, bias=True)
relu = torch.nn.ReLU()
bn1 = torch.nn.BatchNorm1d(num_features=32)
bn2 = torch.nn.BatchNorm1d(num_features=32)

# Independent linear layers of the same sizes for the network that has no
# batch normalisation.
nn_linear1 = torch.nn.Linear(in_features=784, out_features=32, bias=True)
nn_linear2 = torch.nn.Linear(in_features=32, out_features=32, bias=True)
nn_linear3 = torch.nn.Linear(in_features=32, out_features=10, bias=True)

In [None]:
# Network with batch normalisation: Linear -> BatchNorm -> ReLU twice,
# followed by the output layer.
bn_model = torch.nn.Sequential(
    linear1, bn1, relu,
    linear2, bn2, relu,
    linear3,
)
bn_model = bn_model.to(device)

# Baseline network with the same layer sizes but no batch normalisation.
nn_model = torch.nn.Sequential(
    nn_linear1, relu,
    nn_linear2, relu,
    nn_linear3,
)
nn_model = nn_model.to(device)

In [None]:
# Cross-entropy loss shared by both models, placed on the selected device.
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(device)

# One Adam optimizer per model, both using the same learning rate.
bn_optimizer = torch.optim.Adam(params=bn_model.parameters(), lr=learning_rate)
nn_optimizer = torch.optim.Adam(params=nn_model.parameters(), lr=learning_rate)

In [None]:
# Per-epoch metric history. Each entry is a two-element list
# [batch-norm model value, plain model value].
train_losses = []
train_accs = []

valid_losses = []
valid_accs = []

# Number of mini-batches per pass; used to turn summed per-batch metrics
# into per-batch averages.
train_total_batch = len(train_loader)
test_total_batch = len(test_loader)

for epoch in range(training_epochs):
  # Training mode makes the BatchNorm layers use per-batch statistics.
  # nn_model contains only Linear/ReLU layers, so the call is a no-op for
  # it; it is toggled as well to keep both models handled symmetrically.
  bn_model.train()
  nn_model.train()

  # ---- optimisation pass: both models see exactly the same batches ----
  for X, Y in train_loader:
    # Flatten each 28x28 image into a 784-element row for the first Linear.
    X = X.view(-1, 28 * 28).to(device)
    Y = Y.to(device)

    bn_optimizer.zero_grad()
    bn_prediction = bn_model(X)
    bn_loss = criterion(bn_prediction, Y)
    bn_loss.backward()
    bn_optimizer.step()

    nn_optimizer.zero_grad()
    nn_prediction = nn_model(X)
    nn_loss = criterion(nn_prediction, Y)
    nn_loss.backward()
    nn_optimizer.step()

  # ---- evaluation: no gradients, BatchNorm in eval mode ----
  with torch.no_grad():
    bn_model.eval()
    nn_model.eval()

    # Metrics on the training set.
    bn_loss, nn_loss, bn_acc, nn_acc = 0, 0, 0, 0
    for X, Y in train_loader:
      X = X.view(-1, 28 * 28).to(device)
      Y = Y.to(device)

      bn_prediction = bn_model(X)
      bn_correct_prediction = torch.argmax(bn_prediction, 1) == Y
      bn_loss += criterion(bn_prediction, Y)
      bn_acc += bn_correct_prediction.float().mean()

      nn_prediction = nn_model(X)
      nn_correct_prediction = torch.argmax(nn_prediction, 1) == Y
      nn_loss += criterion(nn_prediction, Y)
      nn_acc += nn_correct_prediction.float().mean()

    bn_loss = bn_loss / train_total_batch
    nn_loss = nn_loss / train_total_batch
    bn_acc = bn_acc / train_total_batch
    nn_acc = nn_acc / train_total_batch

    train_losses.append([bn_loss, nn_loss])
    train_accs.append([bn_acc, nn_acc])
    print(
        '[Epoch %d-TRAIN] Batchnorm Loss(Acc): bn_loss:%.5f(bn_acc:%.2f) vs No Batchnorm Loss(Acc): nn_loss:%.5f(nn_acc:%.2f)'
        % ((epoch + 1), bn_loss.item(), bn_acc.item(), nn_loss.item(), nn_acc.item()))

    # Metrics on the test set. The accumulators must start from zero again,
    # otherwise the training averages computed above leak into the
    # validation sums.
    bn_loss, nn_loss, bn_acc, nn_acc = 0, 0, 0, 0
    for X, Y in test_loader:
      X = X.view(-1, 28 * 28).to(device)
      Y = Y.to(device)

      bn_prediction = bn_model(X)
      bn_correct_prediction = torch.argmax(bn_prediction, 1) == Y
      bn_loss += criterion(bn_prediction, Y)
      bn_acc += bn_correct_prediction.float().mean()

      nn_prediction = nn_model(X)
      nn_correct_prediction = torch.argmax(nn_prediction, 1) == Y
      nn_loss += criterion(nn_prediction, Y)
      nn_acc += nn_correct_prediction.float().mean()

    # Average over the number of *test* batches; dividing by the training
    # batch count would shrink every validation metric by
    # test_total_batch / train_total_batch.
    bn_loss = bn_loss / test_total_batch
    nn_loss = nn_loss / test_total_batch
    bn_acc = bn_acc / test_total_batch
    nn_acc = nn_acc / test_total_batch

    # Save valid losses/acc
    valid_losses.append([bn_loss, nn_loss])
    valid_accs.append([bn_acc, nn_acc])
    print(
        '[Epoch %d-VALID] Batchnorm Loss(Acc): bn_loss:%.5f(bn_acc:%.2f) vs No Batchnorm Loss(Acc): nn_loss:%.5f(nn_acc:%.2f)'
        % ((epoch + 1), bn_loss.item(), bn_acc.item(), nn_loss.item(), nn_acc.item()))
    print()

print('Learning finished')

[Epoch 1-TRAIN] Batchnorm Loss(Acc): bn_loss:0.10442(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.19405(nn_acc:0.94)
[Epoch 1-VALID] Batchnorm Loss(Acc): bn_loss:0.02028(bn_acc:0.16) vs No Batchnorm Loss(Acc): nn_loss:0.03679(nn_acc:0.16)

[Epoch 2-TRAIN] Batchnorm Loss(Acc): bn_loss:0.08397(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.18143(nn_acc:0.95)
[Epoch 2-VALID] Batchnorm Loss(Acc): bn_loss:0.01756(bn_acc:0.16) vs No Batchnorm Loss(Acc): nn_loss:0.03740(nn_acc:0.16)

[Epoch 3-TRAIN] Batchnorm Loss(Acc): bn_loss:0.07141(bn_acc:0.98) vs No Batchnorm Loss(Acc): nn_loss:0.13216(nn_acc:0.96)
[Epoch 3-VALID] Batchnorm Loss(Acc): bn_loss:0.01660(bn_acc:0.16) vs No Batchnorm Loss(Acc): nn_loss:0.03098(nn_acc:0.16)

[Epoch 4-TRAIN] Batchnorm Loss(Acc): bn_loss:0.06997(bn_acc:0.98) vs No Batchnorm Loss(Acc): nn_loss:0.14292(nn_acc:0.96)
[Epoch 4-VALID] Batchnorm Loss(Acc): bn_loss:0.01678(bn_acc:0.16) vs No Batchnorm Loss(Acc): nn_loss:0.03436(nn_acc:0.16)

[Epoch 5-TRAIN] Batc