# Batch normalization 

The raw (before activation) outputs of the hidden layer - `hpreact` should ideally be not too small or not too big to avoid push to the extremes by `tanh`. This was the whole motivation for emphasis on correct initialization for network params: `W1, b1, W2, b2` etc

But what if: <br>
We just normalize `hpreact` directly? instead of trying to get the initialization right? Something akin to $X \rightarrow \frac{X-\mu}{\sigma}$ -- introduced by [Ioffe and Szegedy from Google brain in their paper on Batch Normalization](https://arxiv.org/pdf/1502.03167). 

In [8]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [9]:
words = open('names.txt', 'r').read().splitlines()

In [22]:
# character mapping 
stoi = {}
allletters = sorted(set("".join(words)))

stoi = {s:i+1 for i,s in enumerate(allletters)}
stoi['.'] = 0

itos = {i:s for s,i in stoi.items()}

vocab_size = len(itos)

In [23]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

torch.Size([182437, 3]) torch.Size([182437])
torch.Size([22781, 3]) torch.Size([22781])
torch.Size([22928, 3]) torch.Size([22928])


The paper for batch normalization can be found [here](https://arxiv.org/pdf/1502.03167). Section $3.1$ is most relevant to us. 

In [None]:
n_embed = 10
n_hidden = 200

g = torch.Generator().manual_seed(2147483647)
# lookup matrix
C = torch.randn((vocab_size,n_embed), generator=g)
# hidden layer - 100 neurons
W1 = torch.randn((block_size*n_embed,n_hidden), generator=g) * (5/3)/((n_embed*block_size)**0.5)   # use kaiming factor
b1 = torch.randn((n_hidden,), generator=g) * 0.01
# Output layer
W2 = torch.randn((n_hidden,vocab_size), generator=g ) * 0.1
b2 = torch.randn((vocab_size,), generator=g) * 0

# batch normalization params
bngain = torch.ones((1,n_hidden))
bnbias = torch.zeros((1,n_hidden))

# running mean and std
# W1 init using kaiming factor such that mean = 0, std = 1 so running must be init as:
bnmean_running = torch.zeros((1, n_hidden))
bnstd_running = torch.ones((1, n_hidden))

parameters = [C, W1, b1, W2, b2, bngain, bnbias]

for p in parameters:
    p.requires_grad = True

print(sum(pm.nelement() for pm in parameters))

12297


In [41]:
(5/3)/((n_embed*block_size)**0.5)

0.3042903097250923

<img src="../papers/batch_normalization.png" style="width:70%;">


Steps 1,2,3 are standard steps to mean center and normalize by std dev. Step 4 advocates for introducing a batch_gain and batch_bias as trainable parameters for the NN to learn from data. 
```
bngain = torch.ones((1,n_hidden))
bnbias = torch.zeros((1,n_hidden))
```
being init this way allows for __an exactly normal initialization__, and later as the NN trains, the normal constraint is relaxed. 

In [47]:
max_steps = 200000
batch_size = 32
lossi = []

for i in range(max_steps):

    # construct minibatch 
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1

    mu = hpreact.mean(dim=0, keepdim=True)
    std = hpreact.std(dim=0, keepdim=True)

    hpreact = bngain * ((hpreact-mu)/std) + bnbias

    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(input = logits, target=Yb)

    #backpass
    for p in parameters:
        p.grad = None
    loss.backward()

    # update
    lr = 0.1 if i < 150000 else 0.01 # step learning rate decay
    for p in parameters:
        p.data -= lr*p.grad
    
    # track loss
    if i % 10000 ==0:
        print(f'{i:7d}/{max_steps:7d}: loss = {loss.item():.4f}')
    lossi.append(loss.log10().item())

      0/ 200000: loss = 3.6993
  10000/ 200000: loss = 1.8060
  20000/ 200000: loss = 1.8391
  30000/ 200000: loss = 2.4287
  40000/ 200000: loss = 2.4655
  50000/ 200000: loss = 1.8893
  60000/ 200000: loss = 2.5348
  70000/ 200000: loss = 2.2075
  80000/ 200000: loss = 1.8094
  90000/ 200000: loss = 2.3167
 100000/ 200000: loss = 2.3735
 110000/ 200000: loss = 2.1251
 120000/ 200000: loss = 2.1736
 130000/ 200000: loss = 1.7423
 140000/ 200000: loss = 2.2378
 150000/ 200000: loss = 2.3517
 160000/ 200000: loss = 1.9762
 170000/ 200000: loss = 2.1643
 180000/ 200000: loss = 2.5964
 190000/ 200000: loss = 1.8833


Despite introducing BN and kaiming init factor, the loss on validation set has not reduced dramatically because of the simplicity of our network. It would seem that we have already exploited our basic architecture to the maximum. 

How would BN work in case of multiple hidden layers? 

- It is common to sprinkle BN after every linear (Y @ W1 + b1 type) layer 

__Interesting fact:__<br>
Before BN, mini batch allowed parallelization to enhance efficiency but still each data point progressed _independently_ (hpreact -> logits -> loss). But now, when we do BN, each datapoint in a become linked to all the other points in _that_ batch!<br>
This effect is may __sometimes be good__ as it introduces a regularization effect and prevents overfitting to any single datapoint.

However, this random coupling of examples in a batch is hard to predict mathematically and can lead to bizzare results. Instead: 
- Instance normalization
- Layer normalization 
- group normalization , may be preferred


### Fixing mean and std on training data for test time

By current design, the NN expects input batches (to calculate $\mu$ and $\sigma$) when in irl there may be individual examples and their forward passes. __This is similar to:__ finding training statistics and using them as it is on dev and test set for `scikit-learn`.

In [None]:
# explicit way

with torch.no_grad():
    # pass through training set to get statistics
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    bnmean = hpreact.mean(0, keepdim=True)
    bnstd = hpreact.std(0, keepdim=True)


In [53]:
@torch.no_grad()
def split_loss(split):
  x,y = {
    'train': (Xtr, Ytr),
    'val': (Xdev, Ydev),
    'test': (Xte, Yte),
  }[split]
  emb = C[x] 
  embcat = emb.view(emb.shape[0], -1) 
  
  hpreact = embcat @ W1 + b1
  hpreact = bngain * (hpreact-bnmean)/bnstd + bnbias   # normalize into training statistics 
  
  h = torch.tanh(hpreact)
  logits = h @ W2+ b2
  loss = F.cross_entropy(logits, y)
  # print(split, loss.item())
  return loss

print(f"{split_loss('train'):.4f}")
print(f"{split_loss('val'):.4f}")

2.0905
2.1481


This step of explicit calculation can be avoided by iteratively arriving at the batch mean _while_ it is training by using something as such in the body of the loop: 

```with torch.no_grad():
    bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
    bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi
```
where `bnmeani` and `bnstdi` are stats for that batch. 

We can verify that at the end of training, `bnmean_running` ~ `bnmean` and `bnstd_running` ~ `bnstd`!

### Rationale:

The above formula for `bnmean_running` and `bnstd_running` comes from Exponential moving averages (EMA). <br>
running  = $(1-\alpha)\times$ running + $\alpha \times$ new_value | at $\alpha = 0.001$

"Update the current estimate with 0.1% of the new value, and retain 99.9% of the current running estimate."

Mathematically unrolling the recursion gives: 

$running_t = \alpha x_t + \alpha(1-\alpha)x_{t-1} + \alpha(1-\alpha)^2x_{t-2} + ...$

In practice:

PyTorch's default for momentum in BatchNorm is 0.1, meaning $\alpha = 0.1$. Your version with __0.001 is more conservative and stable__, but slower to adapt.

In [None]:
# PSEUDO - CODE

for i in range(max_steps):
  
  # minibatch construct
  #....
  
  # forward pass
  emb = C[Xb] # embed the characters into vectors
  embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
  # Linear layer
  hpreact = embcat @ W1 #+ b1 # hidden layer pre-activation
  # BatchNorm layer
  # -------------------------------------------------------------
  bnmeani = hpreact.mean(0, keepdim=True) # i-th batch mean
  bnstdi = hpreact.std(0, keepdim=True) # i-th batch std
  hpreact = bngain * (hpreact - bnmeani) / bnstdi + bnbias
  with torch.no_grad():
    bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
    bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi
  # -------------------------------------------------------------
#.......

## Redundancy of bias b1

Consider these lines of code: 
```
hpreact = embcat @ W1 + b1 # hidden layer pre-activation
  # BatchNorm layer
  # -------------------------------------------------------------
  bnmeani = hpreact.mean(0, keepdim=True) # i-th batch mean
  bnstdi = hpreact.std(0, keepdim=True) # i-th batch std
  hpreact = bngain * (hpreact - bnmeani) / bnstdi + bnbias
```
bias `b1` is added to compute hpreact but subtracted in form of `bnmeani`, so might as well remove b1 altogether! This is applicable for all linear layers before a BN step. 
