<a href="https://colab.research.google.com/github/Bitdribble/dlwpt-code/blob/master/colab/PyTorchCh8_Batch_Normalization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

torch.set_printoptions(edgeitems=2, linewidth=75)
torch.manual_seed(123)

<torch._C.Generator at 0x7efb33cabb50>

In [2]:
# Data preparation
data_path = '.'
cifar10 = datasets.CIFAR10(data_path, train=True, download=True)
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified


In [3]:
# Normalize data
transformed_cifar10 = datasets.CIFAR10(
    data_path, train=True, download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616))
    ]))
transformed_cifar10_val = datasets.CIFAR10(
    data_path, train=False, download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),
                             (0.2470, 0.2435, 0.2616))
    ]))

In [4]:
# Restrict data to airplanes and birds
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']

cifar2 = [(img, label_map[label]) for img, label in transformed_cifar10 if label in [0, 2]]
cifar2_val = [(img, label_map[label]) for img, label in transformed_cifar10_val if label in [0, 2]]

In [5]:
device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
print(f"Training on device {device}.")

Training on device cuda.


In [6]:
def training_loop(n_epochs, device, optimizer, model, loss_fn, train_loader, log_epochs=0):
  for epoch in range(1, n_epochs + 1):
    loss_train = 0.0

    for imgs, labels in train_loader:
      imgs = imgs.to(device=device)
      labels = labels.to(device=device)

      outputs = model(imgs)
      loss = loss_fn(outputs, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      loss_train += loss.item()

    if log_epochs is not 0 and ((epoch+1) % log_epochs == 0 or (epoch+1) == n_epochs):
      print(f"{datetime.datetime.now()} Epoch {epoch+1}, "
            f"Training loss {loss_train / len(train_loader):.3f}")

def validate(model, device, train_loader, val_loader):
  for name, loader in [("train", train_loader), ("val", val_loader)]:
    correct = 0
    total = 0
    with torch.no_grad(): 
      for imgs, labels in loader:
        imgs = imgs.to(device=device)
        labels = labels.to(device=device)

        outputs = model(imgs)
        _, predicted = torch.max(outputs, dim=1) 

        total += labels.shape[0]
        correct += int((predicted == labels).sum())


    print(f"Accuracy {name}: {correct / total:.2f}")

In [7]:
# Batch Normalization allows us to increase the learning rate and make
# training less dependent on initialization and act as a regularizer, thus representing an
# alternative to dropout.
#
# Paper: https://arxiv.org/abs/1502.03167
#
# Batch normalization rescales the inputs to the activations
# of the network so that minibatches have a certain desirable distribution. Recalling
# the  mechanics  of  learning  and  the  role  of  nonlinear  activation  functions,  this
# helps avoid the inputs to activation functions being too far into the saturated portion
# of the function, thereby killing gradients and slowing training.
# 
# In  practical  terms,  batch  normalization  shifts  and  scales  an  intermediate  input
# using  the  mean  and  standard  deviation  collected  at  that  intermediate  location  over
# the samples of the minibatch. The regularization effect is a result of the fact that an
# individual  sample  and  its  downstream  activations  are  always  seen  by  the  model  as
# shifted  and  scaled,  depending  on  the  statistics  across  the  randomly  extracted  mini-
# batch.  This  is  in  itself  a  form  of  principled  augmentation.  The  authors  of  the  paper
# suggest  that  using  batch  normalization  eliminates  or  at  least  alleviates  the  need
# for dropout.
#
# Batch normalization in PyTorch is provided through the nn.BatchNorm1D,
# nn.BatchNorm2d, and nn.BatchNorm3d modules, depending on the dimensionality of
# the input. Since the aim for batch normalization is to rescale the inputs of the activa-
# tions, the natural location is after the linear transformation (convolution, in this case)
# and the activation, as shown here

class NetBatchNorm(nn.Module):
  def __init__(self, n_chans1=32):
    super().__init__()

    self.n_chans1 = n_chans1
    self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
    self.conv1_batchnorm = nn.BatchNorm2d(num_features=n_chans1)
    self.act1 = nn.Tanh()
    self.pool1 = nn.MaxPool2d(2)

    self.conv2 = nn.Conv2d(n_chans1, n_chans1//2, kernel_size=3, padding=1)
    self.conv2_batchnorm = nn.BatchNorm2d(num_features=n_chans1//2)
    self.act2 = nn.Tanh()
    self.pool2 = nn.MaxPool2d(2)

    self.fc1 = nn.Linear(8*8*(n_chans1//2), 32)
    self.act3 = nn.Tanh()

    self.fc2 = nn.Linear(32, 2)

  def forward(self, x):
    out = self.pool1(self.act1(self.conv1_batchnorm(self.conv1(x))))
    out = self.pool2(self.act2(self.conv2_batchnorm(self.conv2(out))))
    out = out.view(-1, 8*8*(self.n_chans1//2)) # In place of nn.Flatten()
    out = self.act3(self.fc1(out))
    out = self.fc2(out)
    return out

In [8]:
# Dropout is normally active during training, while during the evaluation of a
# trained model in production, dropout is bypassed or, equivalently, assigned a proba-
# bility  equal  to  zero.  This  is  controlled  through  the  train  property  of  the  Dropout
# module. Recall that PyTorch lets us switch between the two modalities by calling
#
# model.train()
#
# or
#
# model.eval()
#
# on any nn.Model subclass. The call will be automatically replicated on the submodules
# so  that  if  Dropout  is  among  them,  it  will  behave  accordingly  in  subsequent  forward
# and backward passes.

In [9]:
model = NetBatchNorm(n_chans1=32).to(device=device)
optimizer = optim.SGD(model.parameters(), lr=1e-2)
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,
                                           shuffle=True)
loss_fn = nn.CrossEntropyLoss()

model.train() # Set train mode

training_loop(
    n_epochs = 100,
    device=device,
    optimizer = optimizer,
    model = model,

    loss_fn = loss_fn,
    train_loader = train_loader,
    log_epochs = 10
)

2022-01-31 00:11:44.997435 Epoch 10, Training loss 0.281
2022-01-31 00:11:54.801778 Epoch 20, Training loss 0.223
2022-01-31 00:12:04.735258 Epoch 30, Training loss 0.177
2022-01-31 00:12:14.561481 Epoch 40, Training loss 0.136
2022-01-31 00:12:24.380778 Epoch 50, Training loss 0.105
2022-01-31 00:12:34.332574 Epoch 60, Training loss 0.082
2022-01-31 00:12:44.247117 Epoch 70, Training loss 0.059
2022-01-31 00:12:54.184916 Epoch 80, Training loss 0.039
2022-01-31 00:13:04.168679 Epoch 90, Training loss 0.031
2022-01-31 00:13:14.120098 Epoch 100, Training loss 0.023


In [10]:
# Just as for dropout, batch normalization needs to behave differently during training
# and inference. In fact, at inference time, we want to avoid having the output for a spe-
# cific input depend on the statistics of the other inputs we’re presenting to the model.
# As  such,  we  need  a  way  to  still  normalize,  but  this  time  fixing  the  normalization
# parameters once and for all.
#
# As  minibatches  are  processed,  in  addition  to  estimating  the  mean  and  standard
# deviation for the current minibatch, PyTorch also updates the running estimates for
# mean  and  standard  deviation  that  are  representative  of  the  whole  dataset,  as  an
# approximation. This way, when the user specifies
#
# model.eval()
# 
# and the model contains a batch normalization module, the running estimates are fro-
# zen  and  used  for  normalization. To unfreeze  running estimates  and  return  to using
# the minibatch statistics, we call model.train(), just as we did for dropout. 

In [None]:
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)

model.eval() # Set eval mode

validate(model, device, train_loader, val_loader)

Accuracy train: 1.00
Accuracy val: 0.88
