In [1]:
# let's pick up where we left off in ngram_mlp_weight_initialization.
import torch
import torch.nn.functional as F
torch.set_default_device("mps")

In [2]:
# we need to include some stuff from before.
g = torch.Generator(device='mps').manual_seed(2147483647) # for reproducibility
# read in all the words
words = open('res/names.txt', 'r').read().splitlines()
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)

In [3]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
  X, Y = [], []

  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

In [4]:
# build the network from last time
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

tanh_gain = 5 / 3
C  = torch.randn((vocab_size, n_embd),            generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * tanh_gain / (n_embd * block_size)**0.5  # kaiming_normal
b1 = torch.rand(n_hidden,                         generator=g) / n_hidden**0.5  # pytorch uniform init
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * (2 / (n_hidden + vocab_size))**0.5  # xavier_normal
b2 = torch.rand(vocab_size,                       generator=g) / vocab_size**0.5  # pytorch uniform init

parameters = [C, W1, b1, W2, b2]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

11897


In [5]:
# it turns out that weight initialization isn't as important as it used to be (I know... bummer).
#   - for deep neural networks, weight initialization is too fragile and impractical to either ensure good
#     starting gradients or numerical stability while training the network.
#   - enter batch normalization to the rescue.

# batch normalization is a technique in machine learning that makes training neural networks faster and more stable.
# it works by standardizing the inputs to each layer so that they are unit gaussian — adjusting them to have a mean
# of zero and a variance of one — within a mini-batch of data. then, it scales and shifts the standardized values
# using learnable parameters. this reduces issues like vanishing gradients and helps the network learn more efficiently.
# think of it as keeping the data "well-behaved" and numerically stable as it flows through the layers!

# let's revisit the training process from last time.
max_steps = 200000
batch_size = 32
lossi = []

def step(epoch=0):
  global hpreact
  global h

  # minibatch construct
  ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
  Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

  # forward pass
  emb = C[Xb] # embed the characters into vectors
  embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
  hpreact = embcat @ W1 + b1 # hidden layer pre-activation
  h = torch.tanh(hpreact) # hidden layer
  logits = h @ W2 + b2 # output layer
  loss = F.cross_entropy(logits, Yb) # loss function

  # backward pass
  for p in parameters:
    p.grad = None
  loss.backward()

  # update
  lr = 0.1 if epoch < 100000 else 0.01 # step learning rate decay
  for p in parameters:
    p.data += -lr * p.grad

step()  # perform 1 step of gradient descent

print(hpreact.shape)
print("Mean:", hpreact.mean().item())
print("Standard Deviation:", hpreact.std().item())

torch.Size([32, 200])
Mean: -0.052187662571668625
Standard Deviation: 1.8892154693603516


In [6]:
# How do we standardize `hpreact` to be unit gaussian?
print(hpreact.mean(0, keepdim=True).shape)  # mean
print(hpreact.std(0, keepdim=True).shape)  # standard deviation
hpreact_unit = (hpreact - hpreact.mean(0, keepdim=True)) / hpreact.std(0, keepdim=True)
print("Mean:", hpreact_unit.mean().item())  # mean ~= 0
print("Standard Deviation:", hpreact_unit.std().item())  # std ~= 1

torch.Size([1, 200])
torch.Size([1, 200])
Mean: 0.0
Standard Deviation: 0.9843279123306274


In [7]:
# let's update the model and training process to use batch normalization.

# -----------------------------------------------------------------------------------------------------------
# add batch normalization - update model parameters
# -----------------------------------------------------------------------------------------------------------
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

tanh_gain = 5 / 3
C  = torch.randn((vocab_size, n_embd),            generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * tanh_gain / (n_embd * block_size)**0.5  # kaiming_normal
b1 = torch.rand(n_hidden,                         generator=g) / n_hidden**0.5  # pytorch uniform init
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * (2 / (n_hidden + vocab_size))**0.5  # xavier_normal
b2 = torch.rand(vocab_size,                       generator=g) / vocab_size**0.5  # pytorch uniform init

bngain = torch.ones((1, n_hidden))  # batch norm parameter (scale)
bnbias = torch.zeros((1, n_hidden))  # batch norm parameter (shift)

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

# -----------------------------------------------------------------------------------------------------------
# add batch normalization - update training process
# -----------------------------------------------------------------------------------------------------------
def step(epoch=0):

  # minibatch construct
  ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
  Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

  # forward pass
  emb = C[Xb] # embed the characters into vectors
  embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
  hpreact = embcat @ W1 + b1 # hidden layer pre-activation
  # --- batch normalization start ---
  hpreact_mean = hpreact.mean(0, keepdim=True)  # find the mean
  hpreact_std = hpreact.std(0, keepdim=True)  # find the standard deviation
  hpreact_gaus = (hpreact - hpreact_mean) / hpreact_std  # build unit gaussian
  hpreact_bn = bngain * hpreact_gaus + bnbias  # scale and shift the unit gaussian with learnable parameters
  # --- batch normalization end ---
  h = torch.tanh(hpreact_bn) # hidden layer
  logits = h @ W2 + b2 # output layer
  loss = F.cross_entropy(logits, Yb) # loss function

  # backward pass
  for p in parameters:
    p.grad = None
  loss.backward()

  # update
  lr = 0.1 if epoch < 100000 else 0.01 # step learning rate decay
  for p in parameters:
    p.data += -lr * p.grad

  # track stats
  if epoch % 10000 == 0: # print every once in a while
    print(f'{epoch:7d}/{max_steps:7d}: {loss.item():.4f}')
  lossi.append(loss.log10().item())

12297


In [8]:
# Now lets train the model with batch normalization
for epoch in range(max_steps):
    step(epoch)

      0/ 200000: 3.7462
  10000/ 200000: 2.2720
  20000/ 200000: 2.0809
  30000/ 200000: 2.3447
  40000/ 200000: 2.3300
  50000/ 200000: 2.3601
  60000/ 200000: 2.3727
  70000/ 200000: 1.9113
  80000/ 200000: 2.0702
  90000/ 200000: 2.2670
 100000/ 200000: 2.0914
 110000/ 200000: 2.1400
 120000/ 200000: 1.9095
 130000/ 200000: 1.9673
 140000/ 200000: 2.0318
 150000/ 200000: 2.1418
 160000/ 200000: 2.0933
 170000/ 200000: 1.9627
 180000/ 200000: 1.8675
 190000/ 200000: 2.2156


In [9]:
# evaluate train loss and validation loss
@torch.no_grad() # this decorator disables gradient tracking
def split_loss(split):
  x, y = {
    'train': (Xtr, Ytr),
    'val': (Xdev, Ydev),
    'test': (Xte, Yte),
  }[split]
  emb = C[x] # (N, block_size, n_embd)
  embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)
  hpreact = embcat @ W1 + b1
  # --- batch normalization start ---
  hpreact_mean = hpreact.mean(0, keepdim=True)  # find the mean
  hpreact_std = hpreact.std(0, keepdim=True)  # find the standard deviation
  hpreact_gaus = (hpreact - hpreact_mean) / hpreact_std  # build unit gaussian
  hpreact_bn = bngain * hpreact_gaus + bnbias  # scale and shift the unit gaussian with learnable parameters
  # --- batch normalization end ---
  h = torch.tanh(hpreact_bn) # (N, n_hidden)
  logits = h @ W2 + b2 # (N, vocab_size)
  loss = F.cross_entropy(logits, y)
  print(split, loss.item())

split_loss('train')
split_loss('val')

# cool! our validation loss originally was 2.2065, and now it's 2.1139 - roughly the same as what we got with
# weight initialization (2.1053). So it's working, but it's not really doing much "heavy lifting" because
# weight initialization is already normalizing the output distribution with respect to the input distribution.

train 2.07417631149292
val 2.1139535903930664


In [10]:
# btw... batch normalization creates a mathematical "coupling" (i.e. dependency) between all the samples of a
# normalized batch, which leads to strange bugs. It's apparently really easy to shoot yourself in the foot
# with this layer, as Karpathi confesses (he says to avoid it).

# This isn't desired and researchers have come up with alternatives:
#   - Group Normalization
#   - Layer Normalization
#   - Instance Normalization

# we won't be exploring these here, but they might be worth looking into.

In [11]:
# another challenge that batch normalization introduces is that we can no longer easily pass single inputs through
# our model, since the model now needs the mean and standard deviation of a batch in the forward calculation.
#   - to get around this, we can estimate the mean and standard deviation of the entire train dataset and use these
#     values in the forward pass with a single sample.
#   - we can estimate this mean and standard deviation while training the model.

# another problem is that the bias term of the layer we are applying batch normalization to is useless. we can understand
# this by looking at the equation of the pre-activation, `h = x @ W + b`, and the equation of batch normalization applied
# to `h`, `h_bn = (h - h_mean) / h_std`.
#   - whatever bias `b` we choose, it will be counteracted by subtracting `h_mean` from `h` in batch normaalization.
#   - so the bias serves no purpose! we can let batch normalization's tunable bias serve the equivalent purpose.

# Let's now fix these issues.

# -----------------------------------------------------------------------------------------------------------
# Use same model as before
# -----------------------------------------------------------------------------------------------------------
n_embd = 10  # the dimensionality of the character embedding vectors
n_hidden = 200  # the number of neurons in the hidden layer of the MLP

tanh_gain = 5 / 3
C  = torch.randn((vocab_size, n_embd),            generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * tanh_gain / (n_embd * block_size)**0.5  # kaiming_normal
#b1 = torch.rand(n_hidden,                         generator=g) / n_hidden**0.5  # pytorch uniform init
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * (2 / (n_hidden + vocab_size))**0.5  # xavier_normal
b2 = torch.rand(vocab_size,                       generator=g) / vocab_size**0.5  # pytorch uniform init

bngain = torch.ones((1, n_hidden))  # batch norm parameter (scale)
bnbias = torch.zeros((1, n_hidden))  # batch norm parameter (shift)

parameters = [C, W1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

# -----------------------------------------------------------------------------------------------------------
# calculate running mean and std for batch normalization
# -----------------------------------------------------------------------------------------------------------
bnmean_running = torch.zeros((1, n_hidden))
bnstd_running = torch.ones((1, n_hidden))
def step(epoch=0):
  global bnmean_running
  global bnstd_running

  # minibatch construct
  ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
  Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

  # forward pass
  emb = C[Xb] # embed the characters into vectors
  embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
  # Linear layer
  hpreact = embcat @ W1 #+ b1 # hidden layer pre-activation
  # BatchNorm layer
  # ------------------------------------------------------------------------------------------------------
  hpreact_mean = hpreact.mean(0, keepdim=True)  # find the mean
  hpreact_std = hpreact.std(0, keepdim=True)  # find the standard deviation
  hpreact_gaus = (hpreact - hpreact_mean) / (hpreact_std + 1e-5)  # build unit gaussian (avoid division by 0)
  hpreact_bn = bngain * hpreact_gaus + bnbias  # scale and shift the unit gaussian with learnable parameters
  with torch.no_grad(): # estimate batch normalization mean and standard deviation over the entire train set
    bnmean_running = 0.999 * bnmean_running + 0.001 * hpreact_mean
    bnstd_running = 0.999 * bnstd_running + 0.001 * hpreact_std
  # ------------------------------------------------------------------------------------------------------
  # Non-linearity
  h = torch.tanh(hpreact_bn) # hidden layer
  logits = h @ W2 + b2 # output layer
  loss = F.cross_entropy(logits, Yb) # loss function

  # backward pass
  for p in parameters:
    p.grad = None
  loss.backward()

  # update
  lr = 0.1 if epoch < 100000 else 0.01 # step learning rate decay
  for p in parameters:
    p.data += -lr * p.grad

  # track stats
  if epoch % 10000 == 0: # print every once in a while
    print(f'{epoch:7d}/{max_steps:7d}: {loss.item():.4f}')
  lossi.append(loss.log10().item())

12097


In [13]:
for epoch in range(max_steps):
    step(epoch)

      0/ 200000: 2.8258
  10000/ 200000: 2.0964
  20000/ 200000: 2.2610
  30000/ 200000: 2.4791
  40000/ 200000: 2.3386
  50000/ 200000: 2.1275
  60000/ 200000: 2.0692
  70000/ 200000: 1.8274
  80000/ 200000: 1.7437
  90000/ 200000: 2.2305
 100000/ 200000: 2.4006
 110000/ 200000: 2.2560
 120000/ 200000: 2.1189
 130000/ 200000: 1.8969
 140000/ 200000: 2.3872
 150000/ 200000: 2.0362
 160000/ 200000: 2.6009
 170000/ 200000: 2.3469
 180000/ 200000: 1.9456
 190000/ 200000: 2.1879


In [14]:
# evaluate train loss and validation loss
@torch.no_grad() # this decorator disables gradient tracking
def split_loss(split):
  x, y = {
    'train': (Xtr, Ytr),
    'val': (Xdev, Ydev),
    'test': (Xte, Yte),
  }[split]
  emb = C[x] # (N, block_size, n_embd)
  embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)
  hpreact = embcat @ W1 #+ b1
  # --- batch normalization start ---
  hpreact_gaus = (hpreact - bnmean_running) / bnstd_running  # build unit gaussian (use running mean and std instead)
  hpreact_bn = bngain * hpreact_gaus + bnbias  # scale and shift the unit gaussian with learnable parameters
  # --- batch normalization end ---
  h = torch.tanh(hpreact_bn) # (N, n_hidden)
  logits = h @ W2 + b2 # (N, vocab_size)
  loss = F.cross_entropy(logits, y)
  print(split, loss.item())

split_loss('train')
split_loss('val')

# awesome! we get basically the same validation loss, but we're using the mean and std over the whole train set instead
# of per batch, allowing us to feed the model single inputs, and we're removed one bias operation.

train 2.059697389602661
val 2.1048686504364014


In [34]:
# let's now sample from this model (since we spent all this time implementing a "fix" to allow us to do so...)
for _ in range(10):
    out = []
    context = [0] * block_size # initialize with all ...
    while True:
      # forward pass the neural net
      emb = C[torch.tensor([context])] # (N, block_size, n_embd)
      embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)
      hpreact = embcat @ W1 #+ b1
      # --- batch normalization start ---
      hpreact_gaus = (hpreact - bnmean_running) / bnstd_running  # build unit gaussian (use running mean and std instead)
      hpreact_bn = bngain * hpreact_gaus + bnbias  # scale and shift the unit gaussian with learnable parameters
      # --- batch normalization end ---
      h = torch.tanh(hpreact_bn) # (N, n_hidden)
      logits = h @ W2 + b2 # (N, vocab_size)
      probs = F.softmax(logits, dim=1)
      # sample from the distribution
      ix = torch.multinomial(probs, num_samples=1, generator=g).item()
      # shift the context window and track the samples
      context = context[1:] + [ix]
      out.append(ix)
      # if we sample the special '.' token, break
      if ix == 0:
        break

    print(''.join(itos[i] for i in out)) # decode and print the generated word

blin.
amyraine.
dhneus.
jolee.
damarken.
aym.
alaylameerediel.
ari.
kael.
mitan.
