<a href="https://colab.research.google.com/github/Redcoder815/Deep_Learning_PyTorch/blob/main/21BatchNormalization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
import torchvision.transforms as transforms
from torch.utils import data
from torchvision import datasets
import torch.optim as optim

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [3]:
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # Use is_grad_enabled to determine whether we are in training mode
    if not torch.is_grad_enabled():
        # In prediction mode, use mean and variance obtained by moving average
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of X, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the mean and variance using moving average
        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
        moving_var = (1.0 - momentum) * moving_var + momentum * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean.data, moving_var.data

In [4]:
class BatchNorm(nn.Module):
    # num_features: the number of outputs for a fully connected layer or the
    # number of output channels for a convolutional layer. num_dims: 2 for a
    # fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # The variables that are not model parameters are initialized to 0 and
        # 1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # If X is not on the main memory, copy moving_mean and moving_var to
        # the device where X is located
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # Save the updated moving_mean and moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.1)
        return Y

In [8]:
class BNLeNetScratch(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5), BatchNorm(6, num_dims=4),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), BatchNorm(16, num_dims=4),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(), nn.LazyLinear(120),
            BatchNorm(120, num_dims=2), nn.Sigmoid(), nn.LazyLinear(84),
            BatchNorm(84, num_dims=2), nn.Sigmoid(),
            nn.LazyLinear(num_classes))
    def forward(self, X):
        return self.net(X)

In [9]:
model = BNLeNetScratch()
model.to(device)

BNLeNetScratch(
  (net): Sequential(
    (0): LazyConv2d(0, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm()
    (2): Sigmoid()
    (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (4): LazyConv2d(0, 16, kernel_size=(5, 5), stride=(1, 1))
    (5): BatchNorm()
    (6): Sigmoid()
    (7): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): LazyLinear(in_features=0, out_features=120, bias=True)
    (10): BatchNorm()
    (11): Sigmoid()
    (12): LazyLinear(in_features=0, out_features=84, bias=True)
    (13): BatchNorm()
    (14): Sigmoid()
    (15): LazyLinear(in_features=0, out_features=10, bias=True)
  )
)

In [10]:
batch_size = 256

In [11]:
Transform = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [12]:
mnist_train = datasets.FashionMNIST(root="../data", train=True, transform=Transform, download=True)
mnist_val = datasets.FashionMNIST(root="../data", train=False, transform=Transform, download=True)

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=2)
val_iter = data.DataLoader(mnist_val, batch_size, shuffle=False, num_workers=2)

100%|██████████| 26.4M/26.4M [00:02<00:00, 11.1MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 191kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.54MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 29.6MB/s]


In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [14]:
max_epochs = 3

In [15]:
for epoch in range(max_epochs):
  model.train()
  train_loss_sum, train_accuracy_sum, n = 0.0, 0.0, 0
  for images, labels in train_iter:
    images, labels = images.to(device), labels.to(device)
    y_pred = model(images)
    l = criterion(y_pred, labels)
    optimizer.zero_grad()
    l.backward()
    optimizer.step()
    train_loss_sum += l
    predicted_labels = torch.argmax(y_pred, dim=1)
    train_accuracy_sum += (predicted_labels == labels).float().sum()
    n += labels.numel()

  model.eval()
  test_accuracy_sum, test_n = 0.0, 0
  with torch.no_grad():
    for images, labels in val_iter:
      images, labels = images.to(device), labels.to(device)
      y_pred = model(images)
      predicted_labels = torch.argmax(y_pred, dim=1)
      test_accuracy_sum += (predicted_labels == labels).float().sum()
      test_n += labels.numel()
  test_accuracy = test_accuracy_sum / test_n
  print(f'Epoch {epoch + 1}, Loss: {train_loss_sum / n:.4f}, Train Accuracy: {train_accuracy_sum / n:.4f}, Validation Accuracy: {test_accuracy:.4f}')

Epoch 1, Loss: 0.0038, Train Accuracy: 0.8124, Validation Accuracy: 0.7939
Epoch 2, Loss: 0.0018, Train Accuracy: 0.8618, Validation Accuracy: 0.8576
Epoch 3, Loss: 0.0014, Train Accuracy: 0.8833, Validation Accuracy: 0.8696


In [16]:
class BNLeNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5), nn.LazyBatchNorm2d(),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), nn.LazyBatchNorm2d(),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(), nn.LazyLinear(120), nn.LazyBatchNorm1d(),
            nn.Sigmoid(), nn.LazyLinear(84), nn.LazyBatchNorm1d(),
            nn.Sigmoid(), nn.LazyLinear(num_classes))
    def forward(self, X):
      self.net(X)