## E02. Folding batch norm into Linear layer Weights & Biases

In [23]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [24]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [25]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [26]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [162]:
# Let's train a deeper network
# The classes we create here are the same API as nn.Module in PyTorch

class Linear:
  
  def __init__(self, fan_in, fan_out, bias=True):
    self.in_features = fan_in
    self.out_features = fan_out
    self.weight = torch.randn((self.in_features, self.out_features), generator=g) #/ fan_in**0.5
    # self.weight = torch.zeros(fan_in, fan_out) 
    self.bias = torch.zeros(self.out_features) if bias else None
  
  def __call__(self, x):
    self.out = x @ self.weight
    if self.bias is not None:
      self.out += self.bias
    return self.out
  
  def parameters(self):
    return [self.weight] + ([] if self.bias is None else [self.bias])


class BatchNorm1d:
  
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    # self.gamma = torch.zeros(dim)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update')
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)
  
  def __call__(self, x):
    # calculate the forward pass
    if self.training:
      self.xmean = x.mean(0, keepdim=True) # batch mean
      self.xvar = x.var(0, keepdim=True) # batch variance
    else:
      self.xmean = self.running_mean
      self.xvar = self.running_var
    xhat = (x - self.xmean) / torch.sqrt(self.xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    # update the buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * self.xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * self.xvar
    return self.out
  
  def parameters(self):
    return [self.gamma, self.beta]

class Tanh:
  def __call__(self, x):
    self.out = torch.tanh(x)
    return self.out
  def parameters(self):
    return []


n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 100 # the number of neurons in the hidden layer of the MLP
g = torch.Generator().manual_seed(2147483647) # for reproducibility

C  = torch.randn((vocab_size, n_embd),            generator=g)

layers = [
  Linear(n_embd * block_size, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(           n_hidden, vocab_size, bias=False), BatchNorm1d(vocab_size), 
]

with torch.no_grad():
  # last layer: make less confident
  # layers[-1].weight *= 0.1
  layers[-1].gamma *= 0.1
  # all other layers: apply gain
  for layer in layers[:-1]:
    if isinstance(layer, Linear):
      layer.weight *= 1.0 #5/3

parameters = [C] + [p for layer in layers for p in layer.parameters()]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

16424


In [163]:
# same optimization as last time
max_steps = 200000
batch_size = 32
lossi = []
# Update-to-data ratio
ud = []

for i in range(max_steps):
    # minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

    # forward pass
    emb = C[Xb] # embed the characters into vectors
    x = emb.view(emb.shape[0], -1) # concatenate the vectors
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, Yb) # loss function

    # backward pass
    # for layer in layers:
        # layer.out.retain_grad() # AFTER_DEBUG: would take out retain_graph
    for p in parameters:
        p.grad = None
    loss.backward()

    # update
    lr = 1.0
    # lr = 0.1 if i < 100000 else 0.01 # step learning rate decay
    for p in parameters:
        p.data += -lr * p.grad

    # track stats
    if i % 10000 == 0: # print every once in a while
        print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())
    with torch.no_grad():
        ud.append([((lr*p.grad).std() / p.data.std()).log10().item() for p in parameters])
    if i >= 1000:
        break # AFTER_DEBUG: would take out obviously to run full optimization

      0/ 200000: 3.3138


### Checking the forward pass

In [164]:
for layer in layers:
    if isinstance(layer, BatchNorm1d):
        layer.training = False

In [182]:
with torch.no_grad():
    # minibatch construct
    ix = torch.tensor([5])
    print(ix)
    Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y
    print(Xb)
    print(Yb)

    # forward pass
    emb = C[Xb] # embed the characters into vectors
    print(emb)
    x = emb.view(emb.shape[0], -1) # concatenate the vectors
    print(x)
    for layer in layers:
        x = layer(x)
        print(x, type(layer))
        if isinstance(layer, BatchNorm1d):
            print(layer.training)
            print(layer.xmean)
            print(layer.xvar)
    print(x)
    loss = F.cross_entropy(x, Yb) # loss function
    print(loss)

tensor([5])
tensor([[ 8,  5, 14]])
tensor([7])
tensor([[[-1.4627,  0.9584, -0.8036, -1.4208,  1.1574, -0.7327,  0.9706,
          -1.0113, -0.7093,  1.3238],
         [ 0.1966,  0.4489,  1.1327, -0.9615,  0.8308, -0.0270,  1.2264,
           0.5786, -1.7541, -1.0490],
         [-0.0664, -0.3270, -1.9017, -0.5343,  0.1581,  0.7689,  0.8360,
          -1.1002, -0.8201,  0.0174]]])
tensor([[-1.4627,  0.9584, -0.8036, -1.4208,  1.1574, -0.7327,  0.9706, -1.0113,
         -0.7093,  1.3238,  0.1966,  0.4489,  1.1327, -0.9615,  0.8308, -0.0270,
          1.2264,  0.5786, -1.7541, -1.0490, -0.0664, -0.3270, -1.9017, -0.5343,
          0.1581,  0.7689,  0.8360, -1.1002, -0.8201,  0.0174]])
tensor([[-15.0269,   5.2077,  -2.7655,  -7.2338,   0.4678,  -7.5516, -10.3462,
          -4.6439,  -0.2803,   0.0743,   6.7710,  13.8418,   3.3735,   7.7241,
          -3.3188,  -2.4831,  -6.1476,  -3.8310,  -0.2850,  -8.0273,  -6.3942,
          -0.1582,  -0.2489,   1.6168,  -6.0590,  -0.1325,  -0.4848,  -5.

## Let's remove the BatchNorm

In [183]:
# new_layers = [
#   Linear(n_embd * block_size, n_hidden, bias=False), Tanh(),
#   Linear(           n_hidden, n_hidden, bias=False), Tanh(),
#   Linear(           n_hidden, vocab_size, bias=False), 
# ]


In [184]:
new_layers = []
with torch.no_grad():
    for i in range(len(layers)):
        if isinstance(layers[i], Linear):
            # Create a new Linear layer with the same parameters as the original layer
            new_layer = Linear(layers[i].in_features, layers[i].out_features, bias = True)
            
            # Modify the weights and biases of the new layer
            gamma = torch.tensor(layers[i+1].gamma)
            beta = torch.tensor(layers[i+1].beta)
            mean = torch.tensor(layers[i+1].xmean)
            weight = torch.tensor(layers[i].weight)
            std = torch.sqrt(torch.tensor(layers[i+1].xvar) + torch.tensor(layers[i+1].eps))
            new_layer.weight = weight
            new_layer.weight = (new_layer.weight / std) * gamma
            new_layer.bias = beta - (gamma * (mean / std))
            
            new_layers.append(new_layer)
        
        if isinstance(layers[i], Tanh):
            new_layer = Tanh()
            new_layers.append(new_layer)

  gamma = torch.tensor(layers[i+1].gamma)
  beta = torch.tensor(layers[i+1].beta)
  mean = torch.tensor(layers[i+1].xmean)
  weight = torch.tensor(layers[i].weight)
  std = torch.sqrt(torch.tensor(layers[i+1].xvar) + torch.tensor(layers[i+1].eps))


In [185]:
with torch.no_grad():
    # minibatch construct
    ix = torch.tensor([5])
    print(ix)
    Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y
    print(Xb)
    print(Yb)

    # forward pass
    emb = C[Xb] # embed the characters into vectors
    print(emb)
    x = emb.view(emb.shape[0], -1) # concatenate the vectors
    print(x)
    for layer in new_layers:
        x = layer(x)
        print(x, type(layer))
    loss = F.cross_entropy(x, Yb) # loss function
    print(loss)

tensor([5])
tensor([[ 8,  5, 14]])
tensor([7])
tensor([[[-1.4627,  0.9584, -0.8036, -1.4208,  1.1574, -0.7327,  0.9706,
          -1.0113, -0.7093,  1.3238],
         [ 0.1966,  0.4489,  1.1327, -0.9615,  0.8308, -0.0270,  1.2264,
           0.5786, -1.7541, -1.0490],
         [-0.0664, -0.3270, -1.9017, -0.5343,  0.1581,  0.7689,  0.8360,
          -1.1002, -0.8201,  0.0174]]])
tensor([[-1.4627,  0.9584, -0.8036, -1.4208,  1.1574, -0.7327,  0.9706, -1.0113,
         -0.7093,  1.3238,  0.1966,  0.4489,  1.1327, -0.9615,  0.8308, -0.0270,
          1.2264,  0.5786, -1.7541, -1.0490, -0.0664, -0.3270, -1.9017, -0.5343,
          0.1581,  0.7689,  0.8360, -1.1002, -0.8201,  0.0174]])
tensor([[-1.2995,  0.3919,  0.0431, -1.7861,  0.2265, -0.9747, -1.6103, -0.3442,
          0.9947,  0.0622,  1.7110,  1.4984,  0.1674,  1.7259, -0.4804, -0.7882,
         -1.2727,  0.3393,  0.2982, -1.1001, -0.2605,  0.1389,  0.0096, -0.9884,
         -1.1252,  0.3007, -0.3896, -1.1593,  1.3266, -0.7473,  0.4