E02: BatchNorm, unlike other normalization layers like LayerNorm/GroupNorm etc. has the big advantage that after training, the batchnorm gamma/beta can be "folded into" the weights of the preceeding Linear layers, effectively erasing the need to forward it at test time. Set up a small 3-layer MLP with batchnorms, train the network, then "fold" the batchnorm gamma/beta into the preceeding Linear layer's W,b by creating a new W2, b2 and erasing the batch norm. Verify that this gives the same forward pass during inference. i.e. we see that the batchnorm is there just for stabilizing the training, and can be thrown out after training is done! pretty cool.

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
names = open("../names.txt", "r").read().splitlines()
names[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
# count the number of letters and prepare the index
letters = set()
for name in names:
    for c in name:
        letters.add(c)

letter_list = list(letters)
letter_list.append(".")
letter_list.sort()

stoi = {s:i for i, s in enumerate(letter_list)}
itos = {i:s for s, i in stoi.items()}

In [4]:
def prep_data(names, block_size):
    x, y = [], []
    for name in names:
    # add the necessary padding to the name
        modified_string = block_size * "." + name + "."
        for ch1, ch2, ch3, ch4 in zip(modified_string, modified_string[1:], modified_string[2:], modified_string[3:]):
            x.append([stoi[ch1], stoi[ch2], stoi[ch3]])
            y.append(stoi[ch4])

    X = torch.tensor(x)
    Y = torch.tensor(y)

    return X, Y

# The data will be split as such - 80% training, 10% dev, 10% test
BLOCK_SIZE = 3

import random
random.seed(42)
random.shuffle(names)

sample_size = len(names)
train_size = int(0.8 * sample_size)
dev_size = int(0.9 * sample_size)

print(train_size, dev_size)

Xtr, Ytr = prep_data(names[:train_size], BLOCK_SIZE)
Xdev, Ydev = prep_data(names[train_size:dev_size], BLOCK_SIZE)
Xtest, Ytest = prep_data(names[dev_size:], BLOCK_SIZE)

print(Xtr.shape, Xdev.shape, Xtest.shape)

25626 28829
torch.Size([182625, 3]) torch.Size([22655, 3]) torch.Size([22866, 3])


In [12]:
# torchify the code

class Linear:
    def __init__(self, fan_in, fan_out, gain=1, bias=True):
        self.weight = torch.randn((fan_in, fan_out)) / (fan_in**0.5)
        # seems like the biases should be torch.zeros?
        self.bias = torch.zeros(fan_out) if bias else None

    def __call__(self, input):
        self.out = input @ self.weight
        if self.bias != None:
            self.out += self.bias
        return self.out

    def parameters(self):
        return [self.weight] + ([self.bias] if self.bias != None else [])

    def shape(self):
        return self.weight.shape, self.bias.shape if self.bias != None else None

class Tanh:    
    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out

    def parameters(self):
        return []

class BatchNorm1d:
    def __init__(self, input_size, epsilon=1e-5, momentum=0.1):
        self.epsilon = epsilon
        self.momentum = momentum
        # initialise the running mean and running var
        self.training = True
       
        self.gamma = torch.ones(input_size)
        self.beta = torch.zeros(input_size)

        # for a normal distribution, the mean is 0 and the variance is 1
        self.running_mean = torch.zeros(input_size)
        self.running_var = torch.ones(input_size)

    # normalize the values
    def __call__(self, batch):
        # # when dim=0, the output will be (100, 100) -> (100) -> this means that there is one average for each column=
        # you only have to calculate the batch mean if you wanna update it
        if self.training:
            mean = batch.mean(0)
            var = batch.var(0)

        else:
            mean = self.running_mean
            var = self.running_var
            
        normalized = (batch - mean)/torch.sqrt(var + self.epsilon)
        self.out = normalized * self.gamma + self.beta
        
        # update the running values if training
        if self.training:
            with torch.no_grad():
                self.running_mean = self.running_mean * (1 - self.momentum) + mean * self.momentum
                self.running_var = self.running_var * (1 - self.momentum) + var * self.momentum

        return self.out

    def parameters(self):
        return [self.gamma] + [self.beta]


In [13]:
# hyperperameters are gonna be set here
EMBEDDING_DIMENSION = 10
HIDDEN_LAYER_SIZE = 100
VOCAB_SIZE = 27

g = torch.Generator().manual_seed(2147483647) # for reproducibility
enc = torch.randn((VOCAB_SIZE, EMBEDDING_DIMENSION), generator=g)


layers = [
    Linear(EMBEDDING_DIMENSION * BLOCK_SIZE, HIDDEN_LAYER_SIZE, bias=False),
    BatchNorm1d(HIDDEN_LAYER_SIZE, momentum=0.1),
    Tanh(),
    Linear(HIDDEN_LAYER_SIZE, HIDDEN_LAYER_SIZE, bias=False),
    BatchNorm1d(HIDDEN_LAYER_SIZE, momentum=0.1),
    Tanh(),
    Linear(HIDDEN_LAYER_SIZE, VOCAB_SIZE, bias=False),
    BatchNorm1d(VOCAB_SIZE, momentum=0.1),
]

# set the last layer to be less certain, so that the initial outputs are more uniform
layers[-1].gamma *= 0.1

# apply the right gain for batchnorm
for layer in layers[:-1]:
    if isinstance(layer, Linear):
        layer.weight *= 0.7

# count the number of layers and prepare for backprop
parameters = [enc]

for layer in layers:
    parameters += layer.parameters()

sum = enc.nelement()
for p in parameters:
    sum += p.nelement()
    p.requires_grad = True

sum

16694

In [14]:
# track the updates and the data
updates_to_data = []
# perform training
# perform a test run to check if the plotted graphs are eventually correct
for i in range(1000):
    random_sample = torch.randint(0, Xtr.shape[0], (32,), generator=g)
    # create the encoding first
    output = enc[Xtr[random_sample]].view(-1, BLOCK_SIZE * EMBEDDING_DIMENSION)

    for layer in layers:
        output = layer(output)

    loss = F.cross_entropy(output, Ytr[random_sample])

    if i == 0:
        print(f'The initial loss is {loss}')

    # for plotting
    for layer in layers:
        layer.out.retain_grad()

    for p in parameters:
        p.grad = None

    loss.backward()

    lr = 0.1 if i > 100000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad

    # what even are they tracking here man
    with torch.no_grad():
        updates_to_data.append([((lr*p.grad).std() / p.data.std()).log10().item() for p in parameters])
loss

The initial loss is 3.316511392593384


tensor(2.5477, grad_fn=<NllLossBackward0>)

In [15]:
def calculate_loss(X, Y):
    with torch.no_grad():
        for layer in layers:
            if isinstance(layer, BatchNorm1d):
                layer.training = False
            
        xout = enc[X].view(-1, EMBEDDING_DIMENSION*BLOCK_SIZE)
        for layer in layers:
            xout = layer(xout)

        loss = F.cross_entropy(xout, Y)

    return loss.item()
calculate_loss(Xdev, Ydev)

2.805344343185425

In [16]:
# layers = [
#     Linear(EMBEDDING_DIMENSION * BLOCK_SIZE, HIDDEN_LAYER_SIZE, bias=False),
#     BatchNorm1d(HIDDEN_LAYER_SIZE, momentum=0.1),
#     Tanh(),
#     Linear(HIDDEN_LAYER_SIZE, HIDDEN_LAYER_SIZE, bias=False),
#     BatchNorm1d(HIDDEN_LAYER_SIZE, momentum=0.1),
#     Tanh(),
#     Linear(HIDDEN_LAYER_SIZE, VOCAB_SIZE, bias=False),
#     BatchNorm1d(VOCAB_SIZE, momentum=0.1),
# ]

In [22]:
# my assumption is that 'folding the weights' means multiplying the weights by that of the previous batch
w2 = layers[0].weight * layers[1].gamma
b2 = layers[1].beta

w2, b2


(tensor([[-0.0004,  0.0451,  0.0592,  ..., -0.0636, -0.0807, -0.1227],
         [ 0.0448, -0.0233,  0.0886,  ..., -0.1760, -0.1986, -0.1342],
         [ 0.1465,  0.1298, -0.0687,  ..., -0.1589,  0.0764,  0.0351],
         ...,
         [-0.0287, -0.0839,  0.0246,  ...,  0.0157,  0.2501,  0.0295],
         [ 0.1064, -0.2199,  0.1551,  ..., -0.0513, -0.0155,  0.0283],
         [ 0.2185,  0.2005, -0.0183,  ...,  0.1906, -0.1868,  0.0493]],
        grad_fn=<MulBackward0>),
 tensor([ 9.6443e-04, -2.0612e-03, -3.2228e-03, -2.9549e-03,  8.0622e-04,
         -5.8683e-03,  1.2179e-03,  7.2280e-04,  1.0903e-03, -5.2481e-03,
         -9.2608e-04,  3.8336e-04, -3.6237e-03, -3.6296e-03,  2.6891e-03,
          2.0360e-03,  2.2279e-04, -1.4937e-03, -4.2076e-03, -1.3299e-04,
          2.3582e-03,  1.7673e-03, -3.5823e-03,  1.1244e-03,  2.0174e-03,
         -1.7253e-03,  3.8006e-04, -9.1260e-03,  1.3912e-03,  2.2438e-03,
          1.9858e-03, -2.0389e-03, -2.5857e-04,  5.5556e-03,  8.0743e-03,
        

In [26]:
class Layer:
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        self.out = x @ self.weight + self.bias
        return self.out

new_layers = []
for i in range(0, len(layers), 3):
    w2 = layers[i].weight * layers[i + 1].gamma
    # my Linear layers in the layers list does not have any bias 
    b2 = layers[i + 1].beta

    new_layer = Layer(w2, b2)
    new_layers.append(new_layer)
    new_layers.append(Tanh())

new_layers

[<__main__.Layer at 0x13f45d480>,
 <__main__.Tanh at 0x14a371600>,
 <__main__.Layer at 0x14a24db70>,
 <__main__.Tanh at 0x13f45f190>,
 <__main__.Layer at 0x13f402e00>,
 <__main__.Tanh at 0x14a2013f0>]

In [29]:
def calculate_loss_folded(X, Y):
    with torch.no_grad():
        xout = enc[X].view(-1, EMBEDDING_DIMENSION*BLOCK_SIZE)
        for layer in layers:
            xout = layer(xout)

        loss = F.cross_entropy(xout, Y)

    return loss.item()

# verify that it was done correctly
calculate_loss_folded(Xdev, Ydev), calculate_loss(Xdev, Ydev), calculate_loss_folded(Xdev, Ydev) == calculate_loss(Xdev, Ydev)

(2.805344343185425, 2.805344343185425, True)