In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
%matplotlib inline

In [None]:
# Prefer the first GPU when one is available; otherwise run everything on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Load one name per line. The context manager closes the file handle deterministically
# (a bare open(...).read() leaves it open until garbage collection).
with open('data/names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()
words[:8]

In [None]:
# Vocabulary: '.' (index 0, used as start/end token) followed by every character seen in the names.
all_chars = ['.'] + sorted(set(''.join(words)))
itos = dict(enumerate(all_chars))                # index -> character
stoi = {ch: i for i, ch in itos.items()}         # character -> index
vocab_size = len(itos)
print(itos)
print(vocab_size)

In [None]:
# build the dataset
block_size = 3  # number of previous characters used as context to predict the next one


def build_dataset(words):
    """Turn a list of names into (context, next-character) training pairs.

    Each name is terminated with '.', and the context starts as block_size copies of
    index 0. For every character, the current context window is recorded as an input
    row and the character's index as its target; the window then slides by one.

    Args:
        words: iterable of strings whose characters are all keys of `stoi`.

    Returns:
        (X, Y): X holds the context index rows (one list of block_size ints per
        character), Y the target indices; both are created on `device`.
    """
    contexts, targets = [], []
    for word in words:
        window = [0] * block_size
        for ch in word + '.':
            target = stoi[ch]
            contexts.append(window)
            targets.append(target)
            window = window[1:] + [target]  # new list, so rows already appended are untouched
    return torch.tensor(contexts, device=device), torch.tensor(targets, device=device)

# Shuffle a *copy* of the names with a dedicated, seeded RNG and split it 80/10/10.
# Shuffling `words` in place made the cell non-idempotent: every re-run reshuffled an
# already-shuffled list and silently produced a different split. Random(314) yields the
# same permutation as random.seed(314); random.shuffle(...) on the original list.
words_shuffled = list(words)
random.Random(314).shuffle(words_shuffled)
n1 = int(0.8 * len(words_shuffled))
n2 = int(0.9 * len(words_shuffled))
Xtr, Ytr = build_dataset(words_shuffled[:n1])
Xval, Yval = build_dataset(words_shuffled[n1:n2])
Xte, Yte = build_dataset(words_shuffled[n2:])

print(Xtr.shape, Ytr.shape)
print(Xval.shape, Yval.shape)
print(Xte.shape, Yte.shape)

In [None]:
# Sanity check: shows which device the training inputs were created on.
Xtr.device

In [None]:
# Baseline ("original") initialization of the MLP parameters.
n_embd = 10     # size of each character embedding
n_hidden = 200  # number of tanh neurons in the hidden layer

# All tensors are drawn from one seeded generator, so their values depend on the order of these calls.
g = torch.Generator(device=device).manual_seed(31416)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)  # embedding table
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)  # concatenated context -> hidden
b1 = torch.randn(n_hidden, generator=g, device=device)
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)  # hidden -> logits
# NOTE(review): b2 uses torch.rand (uniform in [0, 1)) while the other tensors use torch.randn — confirm intended.
b2 = torch.rand(len(all_chars), generator=g, device=device)
parameters = [C, W1, b1, W2, b2]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True

In [None]:
# Mini-batch SGD: max_steps steps of batch_size examples each; learning rate 0.1 for the
# first half of training and 0.01 afterwards.
max_steps = 100000
batch_size = 32
lossi = []  # log10 of every mini-batch loss, for plotting

for epoch in range(max_steps):  # one iteration = one mini-batch step (not a full pass over the data)

    # build minibatch: batch_size random training rows drawn independently with g
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)  # concatenate the context embeddings of each example
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None  # clear gradients from the previous step
    loss.backward()

    if epoch % 10000 == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01  # step decay at the halfway point
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad

In [None]:
plt.plot(range(len(lossi)), lossi)  # log10 loss of every training step

In [None]:
plt.plot(torch.tensor(lossi).reshape(-1, 1000).mean(1))  # mean log10 loss over windows of 1000 steps

In [None]:
# Map each split name to its (inputs, targets) tensors.
splits = {
    "train": (Xtr, Ytr),
    "val": (Xval, Yval),
    "test": (Xte, Yte)
}

@torch.no_grad()
def split_loss(split):
    """Print the mean cross-entropy of the current MLP (C, W1, b1, W2, b2) on a split.

    Args:
        split: a key of `splits` ("train", "val" or "test").
    """
    x, y = splits[split]
    emb = C[x]
    embcat = emb.view(emb.shape[0], -1)  # concatenate the context embeddings of each example
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())

split_loss('train')
# Compare design choices on the validation split; the test split should only be
# consulted for the final report, otherwise it stops being an unbiased estimate.
split_loss('val')
split_loss('test')

In [None]:
# sample of the model
for _ in range(20):
    xs = [0, 0, 0]
    letters = []
    while True:
        emb = C[torch.tensor(xs)]
        h = torch.tanh((emb.view(-1) @ W1 + b1))
        logits = h @ W2 + b2
        probs = F.softmax(logits, dim=0)
        ix = torch.multinomial(probs,  num_samples=1, replacement=True).item()
        letters.append(itos[ix])
        xs = xs[1:] + [ix]
        if ix == 0:        
            break
    print(''.join(letters))

## Improving initializations

The large drop in the loss between the first iterations and the rest of training is caused by the weight initialization.

Since we have no prior knowledge about the character probabilities, at initialization we would expect every character to be equally likely, which gives a loss of $-\log(1/\text{vocab\_size})$:

In [None]:
-torch.log(torch.tensor(1 / vocab_size)).item()  # loss of a uniform prediction over the vocabulary

This value is much lower than the losses observed in the first training iterations, which means the initial weights produce badly skewed ("crazy") output distributions.

In [None]:
# 4D example of the issue: all-zero logits give uniform probabilities
logits = torch.zeros(4)
probs = logits.softmax(dim=0)
loss = probs[2].log().neg()  # any index gives the same loss, log(4)
probs, loss

In [None]:
# A high value, selected: the confident prediction is right, so the loss is small
logits = torch.zeros(4)
logits[2] = 5.0
probs = logits.softmax(dim=0)
loss = probs[2].log().neg()
probs, loss

In [None]:
# A high value, unselected: the confident prediction is wrong, so the loss is large
logits = torch.zeros(4)
logits[2] = 5.0
probs = logits.softmax(dim=0)
loss = probs[1].log().neg()
probs, loss

In [None]:
# A random one: with unit-scale logits the loss stays moderate
logits = torch.randn(4)
probs = logits.softmax(dim=0)
loss = probs[1].log().neg()
logits, probs, loss

In [None]:
# A random one with big values: the loss can explode and varies wildly between runs
logits = 10 * torch.randn(4)
probs = logits.softmax(dim=0)
loss = probs[1].log().neg()
logits, probs, loss

In [None]:
# Even bigger logits: probabilities saturate to ~0/1 and the loss may even be inf
logits = 100 * torch.randn(4)
probs = logits.softmax(dim=0)
loss = probs[1].log().neg()
logits, probs, loss

- At initialization we want the predicted probabilities to be (close to) uniform, i.e. the logits should be close to zero.
- Small values in the weight matrices keep the logits small, so the untrained network already has a low loss.

In [None]:
# Re-create the original (unscaled) initialization and run a single training step
# (see the `break`) to inspect the logits the network produces at initialization.
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)
b1 = torch.randn(n_hidden, generator=g, device=device)
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)
b2 = torch.rand(len(all_chars), generator=g, device=device)
parameters = [C, W1, b1, W2, b2]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True

# max_steps only matters for the print/LR schedule here: the loop exits after one step.
max_steps = 200000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    if epoch % 10000 == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad

    # intentional: stop after the first step to inspect the state at initialization
    break

# logits of the first two examples, computed before the parameter update
print(logits[:2])

The logits contain really large values, which causes the very large loss. To solve this we will:
- Set all biases to zero
- Set all weights to small values 

In [None]:
# Same single-step inspection, but with a small W2 and a zero b2 so the initial logits are near zero.
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)
b1 = torch.randn(n_hidden, generator=g, device=device)
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)       * 0.01  # small output weights
b2 = torch.rand(len(all_chars), generator=g, device=device)                    * 0.0   # zero output bias (draw kept so g advances as before)
parameters = [C, W1, b1, W2, b2]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True

max_steps = 200000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    if epoch % 10000 == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad

    # intentional: stop after the first step to inspect the state at initialization
    break

# logits of the first two examples, computed before the parameter update
print(logits[:2])

As desired, the initial loss is now close to the theoretical value for a uniform prediction (about 3).

Let's run some iterations with the new initialization.

In [None]:
# Train for 1000 steps with the improved output-layer initialization (small W2, zero b2).
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)
b1 = torch.randn(n_hidden, generator=g, device=device)
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)       * 0.01
b2 = torch.rand(len(all_chars), generator=g, device=device)                    * 0.0
parameters = [C, W1, b1, W2, b2]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True

max_steps = 1000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # report 10 times per run (requires max_steps >= 10, otherwise the modulus is zero)
    if epoch % (max_steps // 10) == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad


In [None]:
for split_name in ('train', 'test'):
    split_loss(split_name)

As you can see, we now have a smoother loss curve, even at the beginning of training.

- Initialization problems waste the first iterations on undoing the bad initialization, instead of spending that effort on reducing the loss further.

- Now, both loss values are smaller than with the original initialization (2.22, 2.25)

## Second problem: the inputs to the activation function (tanh saturation)

In [None]:
# Single training step (see the `break`) with the improved output layer, kept to
# inspect the hidden activations h and pre-activations hpreact at initialization.
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)
b1 = torch.randn(n_hidden, generator=g, device=device)
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)       * 0.01
b2 = torch.rand(len(all_chars), generator=g, device=device)                    * 0.0
parameters = [C, W1, b1, W2, b2]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True

max_steps = 200000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    if epoch % 10000 == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad

    # intentional: stop after the first step so h / hpreact reflect the initialization
    break

In [None]:
# take a look at the values returned by the activation function
# (h is the tanh output of the last mini-batch computed above)
h

Many values are very close to 1 or -1. In all these cases the neuron is in the flat region of tanh, where it cannot learn: the local gradient $1 - \tanh^2(x)$ is very close to zero.

In [None]:
def show_tanh(lo=-5, hi=5, step=0.1):
    """Plot tanh(x) and its derivative 1 - tanh(x)**2 on [lo, hi) with the given step.

    The derivative is the local gradient tanh passes backwards: it is close to zero
    wherever |tanh(x)| is close to 1, which is why saturated neurons barely learn.
    The defaults reproduce the original [-5, 5) range with step 0.1.
    """
    x = torch.arange(lo, hi, step)
    y = x.tanh()
    der_y = 1 - y ** 2  # reuse y instead of recomputing tanh(x)
    plt.plot(x, y, label='tanh(x)')
    plt.plot(x, der_y, label='1 - tanh(x)^2 (derivative)')
    plt.xlabel('x')
    plt.legend()
    plt.show()

show_tanh()

In [None]:
# Histogram of all tanh outputs in the batch: mass piled up at -1 and +1 means saturation.
plt.hist(h.flatten().tolist(), bins=50)
plt.show()

In [None]:
# Most outputs sit where tanh is flat (no learning signal), so look at what feeds it:
# the pre-activation values.
plt.hist(hpreact.flatten().tolist(), bins=50)
plt.show()

In [None]:
# The pre-activations are very broad, which explains the saturated tanh outputs.
# Map of saturation per example (rows) and neuron (columns): white where |h| > 0.99.
plt.figure(figsize=(20, 10))
plt.imshow((h.abs() > 0.99).cpu(), cmap='gray', interpolation='nearest')

As you can see, the image is full of white dots, spread across all the neurons and examples.
- If a whole column were white, that neuron would be saturated for every example and therefore completely dead (it would never learn).
- The same behavior can be found in other activation functions, such as sigmoid and ReLU.

The solution to this problem is to keep the values of hpreact close to zero by scaling down the weights and the bias of the first layer.

In [None]:
# Single training step (see the `break`) with the hidden layer scaled down as well:
# W1 * 0.1 and b1 zeroed keep hpreact near zero so tanh is not saturated at init.
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)    * 0.1
b1 = torch.randn(n_hidden, generator=g, device=device)                         * 0.0
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)       * 0.01
b2 = torch.rand(len(all_chars), generator=g, device=device)                    * 0.0
parameters = [C, W1, b1, W2, b2]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True

max_steps = 200000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    if epoch % 10000 == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad

    # intentional: stop after the first step so h reflects the initialization
    break

In [None]:
# Distribution of the tanh outputs, then the saturation map (white where |h| > 0.99).
plt.hist(h.flatten().tolist(), bins=50)
plt.show()
plt.figure(figsize=(20, 10))
plt.imshow((h.abs() > 0.99).cpu(), cmap='gray', interpolation='nearest')

Using other values for initialization generates other distributions of h

In [None]:
# Single training step (see the `break`) with W1 * 0.2 instead of 0.1, to compare the resulting h distribution.
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)    * 0.2
b1 = torch.randn(n_hidden, generator=g, device=device)                         * 0.0
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)       * 0.01
b2 = torch.rand(len(all_chars), generator=g, device=device)                    * 0.0
parameters = [C, W1, b1, W2, b2]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True

max_steps = 200000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    if epoch % 10000 == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad

    # intentional: stop after the first step so h reflects the initialization
    break

In [None]:
# Distribution of the tanh outputs, then the saturation map (white where |h| > 0.99).
plt.hist(h.flatten().tolist(), bins=50)
plt.show()
plt.figure(figsize=(20, 10))
plt.imshow((h.abs() > 0.99).cpu(), cmap='gray', interpolation='nearest')

In [None]:
# Run the whole training
# Initialization with both fixes: small output layer (W2 * 0.01, b2 = 0) and an
# unsaturated hidden layer (W1 * 0.2, b1 = 0). Trains for max_steps mini-batch steps.
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)    * 0.2
b1 = torch.randn(n_hidden, generator=g, device=device)                         * 0.0
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)       * 0.01
b2 = torch.rand(len(all_chars), generator=g, device=device)                    * 0.0
parameters = [C, W1, b1, W2, b2]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True

max_steps = 20000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # report 10 times per run
    if epoch % (max_steps // 10) == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad

In [None]:
for split_name in ('train', 'test'):
    split_loss(split_name)

We can see a new improvement to the loss function (2.0985, 2.1490). This is because all the neurons are working most of the time in the training process.

**Loss log**:
- original: 

train 2.1261403560638428

test 2.184201955795288

- fix softmax confidently wrong:

train 2.098554849624634

test 2.151556968688965

- fix tanh layer too saturated at init

train 2.064349412918091

test 2.120842933654785



Since the network is shallow, the results are quite good even with bad initializations. In deeper networks these errors compound from layer to layer and can make the network unusable, or very hard to train.

In general, it is hard to choose these scaling factors by hand, but there are principled rules for setting them.

In [None]:
# Unit-Gaussian inputs times unit-Gaussian weights: compare input and output spread.
x = torch.randn(1000, 10)
w = torch.randn(10, 200)
y = x @ w
print(x.mean(), x.std())
print(y.mean(), y.std())
fig, (ax_x, ax_y) = plt.subplots(1, 2, figsize=(20, 5))
ax_x.hist(x.view(-1).tolist(), 50, density=True)
ax_y.hist(y.view(-1).tolist(), 50, density=True)
plt.show()

We can see that both are Gaussian with mean zero, but the standard deviation of the product is larger, so its values are more spread out.
- This is why the pre-activation values in the network were so frequently large.

**Question**: How to scale the weights for achieving a better behavior

Intuition: reduce the weight values

In [None]:
# Same experiment with weights shrunk by 0.2: the output spread shrinks accordingly.
x = torch.randn(1000, 10)
w = torch.randn(10, 200) * 0.2
y = x @ w
print(x.mean(), x.std())
print(y.mean(), y.std())
fig, (ax_x, ax_y) = plt.subplots(1, 2, figsize=(20, 5))
ax_x.hist(x.view(-1).tolist(), 50, density=True)
ax_y.hist(y.view(-1).tolist(), 50, density=True)
plt.show()

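Then, what scaling gives an output standard deviation of 1?
- From statistics: each output is a sum of `fan_in` products of independent unit-variance terms, so its variance is `fan_in`. Dividing the unit-Gaussian weights by the square root of the number of rows of `w` (the fan-in) brings the standard deviation back to 1.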

In [None]:
# Weights divided by sqrt(fan_in) = sqrt(10): the output std comes back to about 1.
x = torch.randn(1000, 10)
w = torch.randn(10, 200) / 10**0.5
y = x @ w
print(x.mean(), x.std())
print(y.mean(), y.std())
fig, (ax_x, ax_y) = plt.subplots(1, 2, figsize=(20, 5))
ax_x.hist(x.view(-1).tolist(), 50, density=True)
ax_y.hist(y.view(-1).tolist(), 50, density=True)
plt.show()

In very deep networks we need these activations to neither grow nor shrink too much from one layer to the next.

When we have activation functions, we also need to take them into account when initializing. Let's look at the following code:

In [None]:
# Fan-in scaling followed by tanh: the squashing nonlinearity shrinks the output std below 1.
x = torch.randn(1000, 10)
w = torch.randn(10, 200) / 10**0.5
y = torch.tanh(x @ w)
print(x.mean(), x.std())
print(y.mean(), y.std())
fig, (ax_x, ax_y) = plt.subplots(1, 2, figsize=(20, 5))
ax_x.hist(x.view(-1).tolist(), 50, density=True)
ax_y.hist(y.view(-1).tolist(), 50, density=True)
plt.show()

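In this [paper](papers/He_Delving_Deep_into_ICCV_2015_paper.pdf) (He et al., 2015), the problem is studied and a solution is proposed. To compensate for the nonlinearity, the $1/\sqrt{\text{fan\_in}}$ scale must be multiplied by an extra gain factor. For ReLU the paper derives a gain of $\sqrt{2}$, i.e. $\text{std} = \sqrt{2/\text{fan\_in}}$; PyTorch recommends a gain of $5/3$ for tanh.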

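PyTorch implements this initialization as *torch.nn.init.kaiming_normal_* (the old name without the trailing underscore is deprecated), and *torch.nn.init.calculate_gain* returns the recommended gain for each nonlinearity.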

In the past, tuning all these parameters was crucial, and the resulting networks were frequently fragile.
- Fortunately, we now have other tools that make the final quality less sensitive to the weight initialization.

One of these tools is batch normalization.

## Batch Normalization

Batch normalization was introduced in 2015 by a team at Google (Ioffe & Szegedy) as a way to make very deep networks trainable [PDF](papers/ioffe15.pdf)

The idea is the following: instead of trying to guarantee good statistics for the pre-activation values through the weight initialization, we directly normalize the pre-activations at every training step. Each neuron's values are shifted and scaled to zero mean and unit standard deviation over the current mini-batch.

- This is possible because the normalization is a differentiable operation, so it can be included in backpropagation.

Let's add this to our current implementation.

In [None]:
# Run the whole training
# First batch-norm experiment: standardize hpreact per neuron over the mini-batch,
# without learnable gain/bias. Runs a single step (see the `break`).
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416+1)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)    * 0.2
b1 = torch.randn(n_hidden, generator=g, device=device)                         * 0.0
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)       * 0.01
b2 = torch.rand(len(all_chars), generator=g, device=device)                    * 0.0
parameters = [C, W1, b1, W2, b2]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True


max_steps = 100000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1

    # doing this, every single neuron has mean 0 and stdev 1 for all its activations in the batch
    # (subtracting the batch mean also cancels b1; there is no epsilon, so a neuron with
    # zero spread across the batch would divide by zero)
    hpreact = (hpreact - hpreact.mean(0, keepdim=True)) / hpreact.std(0, keepdim=True)

    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    if epoch % 10000 == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad

    # intentional: a single step is enough to check the normalization runs
    break

A problem with this normalization is that it completely removes the effect of the bias and of any scaling of the weight matrix. We therefore add a learnable gain and bias back after the normalization.

In [None]:
# Run the whole training
# Batch norm with a learnable per-neuron gain (bngain, init 1) and shift (bnbias, init 0)
# applied after the standardization; both are trained like the other parameters.
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416+1)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)    * 0.2
b1 = torch.randn(n_hidden, generator=g, device=device)                         * 0.0
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)       * 0.01
b2 = torch.rand(len(all_chars), generator=g, device=device)                    * 0.0

bngain = torch.ones((1, n_hidden), device=device)
bnbias = torch.zeros((1, n_hidden), device=device)
parameters = [C, W1, b1, W2, b2, bngain, bnbias]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True


max_steps = 1000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1

    # doing this, every single neuron has mean 0 and stdev 1 for all its activations in the batch
    # (then rescaled by bngain and shifted by bnbias; b1 is cancelled by the mean subtraction)
    hpreact = bngain * (hpreact - hpreact.mean(0, keepdim=True)) / hpreact.std(0, keepdim=True) + bnbias

    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # report 10 times per run
    if epoch % (max_steps // 10) == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad
        


In [None]:
@torch.no_grad()
def split_loss2(split):
    """Print the loss on `split` for the batch-norm MLP.

    The normalization statistics are computed from the evaluated split itself (all of
    its rows at once), so each example's output depends on the rest of the split.
    """
    inputs, targets = splits[split]
    embcat = C[inputs].view(inputs.shape[0], -1)
    pre = embcat @ W1 + b1
    pre = bngain * (pre - pre.mean(0, keepdim=True)) / pre.std(0, keepdim=True) + bnbias
    logits = pre.tanh() @ W2 + b2
    print(split, F.cross_entropy(logits, targets).item())

split_loss2('train')
split_loss2('test')

In the evaluation code we estimate the mean and standard deviation from the whole evaluated split. But what happens if we need to process a single example?
- This shows an important problem of batch normalization: it couples the examples that happen to be in the same batch, so the output for one example depends on the others. Other normalization methods solve this issue.
- Another problem is which mean and standard deviation to use when the trained network is deployed.

The solution proposed in the paper for this last problem is the following:

In [None]:
# Run the whole training
# Batch norm with learnable gain/bias, plus exponential running averages of the
# per-neuron batch mean/std so the network can be evaluated without batch statistics.
n_embd = 10
n_hidden = 200

g = torch.Generator(device=device).manual_seed(31416+1)
C = torch.randn((vocab_size, n_embd), generator=g, device=device)
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g, device=device)    * 0.2
b1 = torch.randn(n_hidden, generator=g, device=device)                         * 0.0
W2 = torch.randn([n_hidden, len(all_chars)], generator=g, device=device)       * 0.01
b2 = torch.rand(len(all_chars), generator=g, device=device)                    * 0.0

bngain = torch.ones((1, n_hidden), device=device)
bnbias = torch.zeros((1, n_hidden), device=device)
parameters = [C, W1, b1, W2, b2, bngain, bnbias]

print("# Params: ", sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True

# Running statistics: buffers, not parameters (not in `parameters`, updated without grad).
bnmean_running = torch.zeros((1, n_hidden), device=device)
bnstd_running = torch.ones((1, n_hidden), device=device)

max_steps = 100000
batch_size = 32
lossi = []

for epoch in range(max_steps):

    # build minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1

    bnmean = hpreact.mean(0, keepdim=True)
    bnstd = hpreact.std(0, keepdim=True)
    hpreact = bngain * (hpreact - bnmean) / bnstd + bnbias

    # update the running estimates (momentum 0.01) outside the autograd graph
    with torch.no_grad():
        bnmean_running = 0.99 * bnmean_running + 0.01 * bnmean
        bnstd_running = 0.99 * bnstd_running + 0.01 * bnstd

    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)


    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    if epoch % 10000 == 0:
        print(f"{epoch:7d}/{max_steps:7d}: {loss.item():.4f}")
    lossi.append(loss.log10().item())

    learning_rate = 0.1 if epoch < max_steps / 2 else 0.01
    # update
    for p in parameters:
        p.data -= learning_rate * p.grad
    # No `break` here: the running statistics evaluated afterwards only converge after
    # many steps (after one step they would be ~1% of the first batch's mean).


In [None]:
# Sanity check: the running statistics keep the (1, n_hidden) shape they were created with.
bnmean_running.shape, bnstd_running.shape

In [None]:
@torch.no_grad()
def split_loss2(split):
    """Print the loss on `split`, normalizing with the running batch-norm statistics.

    Uses bnmean_running / bnstd_running accumulated during training instead of the
    statistics of the evaluated data, so every example is processed independently.
    """
    inputs, targets = splits[split]
    embcat = C[inputs].view(inputs.shape[0], -1)
    pre = embcat @ W1 + b1
    pre = bngain * (pre - bnmean_running) / bnstd_running + bnbias
    logits = pre.tanh() @ W2 + b2
    print(split, F.cross_entropy(logits, targets).item())

split_loss2('train')
split_loss2('test')

Despite its drawbacks, using batch statistics in batch normalization has a regularizing effect that further improves the results:
- The small batch-dependent jitter in each example's activations works as a kind of data augmentation.

## Batch normalization using torch.nn

[nn.BatchNorm1d](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)

In [None]:
# Seed the global RNG used by the nn layers' default initializers, and create a fresh
# generator for minibatch sampling. Previously this cell reused `g` from earlier cells,
# so its minibatches (and results) depended on everything that had run before it.
torch.manual_seed(31416)
g = torch.Generator(device=device).manual_seed(31416)


# Define the custom model class
class MyModel(nn.Module):
    """Character MLP: embedding -> flatten -> linear (no bias) -> BatchNorm1d -> tanh -> linear.

    The hidden Linear layer has no bias because BatchNorm1d subtracts the batch mean
    (cancelling any bias) and has its own learnable shift.
    """

    def __init__(self, n_embd, n_hidden):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, n_embd, device=device)
        self.flat = nn.Flatten(start_dim=1)  # concatenate the context embeddings of each example
        self.h_layer = nn.Linear(block_size*n_embd, n_hidden, bias=False, device=device)
        self.bn_layer = nn.BatchNorm1d(num_features=n_hidden, device=device)
        self.out_layer = nn.Linear(n_hidden, len(all_chars), device=device)

    def forward(self, xs):
        """Return one row of len(all_chars) logits per example.

        xs is presumably a (batch, block_size) tensor of character indices — confirm at call sites.
        """
        x = self.flat(self.embedding_layer(xs))
        x = self.h_layer(x)
        x = self.bn_layer(x).tanh()
        return self.out_layer(x)

# Instantiate the model
model = MyModel(n_embd = 10, n_hidden = 200)

# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

max_steps = 10000
batch_size = 32
lossi = []

model.train()  # BatchNorm1d normalizes with batch statistics and updates its running stats
for epoch in range(max_steps):
    # Forward pass
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
    Xb, Yb = Xtr[ix], Ytr[ix]

    outputs = model(Xb)
    loss = F.cross_entropy(outputs, Yb)

    if epoch % 1000 == 0:
        print(f"Epoch {epoch}, loss={loss.item()}")

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    lossi.append(loss.item())

    # Same step decay as the manual training loops (0.1, then 0.01 after half the steps)
    for group in optimizer.param_groups:
        group['lr'] = 0.1 if epoch < max_steps / 2 else 0.01

    # Update parameters
    optimizer.step()

print(loss.item())

In [None]:
# Name -> (inputs, targets) for every data split.
splits = dict(
    train=(Xtr, Ytr),
    val=(Xval, Yval),
    test=(Xte, Yte),
)

@torch.no_grad()
def split_loss2(split):
    """Print the cross-entropy of `model` on the given split."""
    inputs, targets = splits[split]
    print(split, F.cross_entropy(model(inputs), targets).item())

# eval mode: BatchNorm1d normalizes with its running statistics instead of batch statistics
model.eval()
split_loss2('train')
split_loss2('test')

In [None]:
# Average the per-step losses over consecutive windows of 1000 steps to smooth the curve.
avg_loss = torch.tensor(lossi).reshape(-1, 1000).mean(1)
plt.plot(avg_loss)