## Fold batchnorm beta & gamma into preceding layer's weights and biases
## In a three-layer perceptron

In [1]:
# Split up the dataset randomly into 80% train set, 10% dev set, 10% test set.
# Train the bigram and trigram models only on the training set. Evaluate them on dev and test splits. What can you see?
import random # shuffle the list of words to get an even distribution
import torch

words = open('../names.txt', 'r').read().splitlines()
random.seed(230)
random.shuffle(words)


In [3]:
# Map the words to indexes
chars = sorted(list(set(''.join(words)))) # get the unique characters through the set() method
stoi = {s:i +1 for i,s in enumerate(chars)} # string to index
stoi['.'] = 0 # end character

itos = {i:s for s,i in stoi.items()} # index to string
itos

{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z',
 0: '.'}

In [4]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  xs, ys = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      xs.append(context)
      ys.append(ix)
      context = context[1:] + [ix] # crop and append

  xs = torch.tensor(xs)
  ys = torch.tensor(ys)
  return xs, ys

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

In [53]:
## SET UP MODEL PARAMETERS

import torch.nn.functional as F

gain = 5/3 # we need a gain because we are using the tanh activation function; squashes, add a gain to get back to normal std
dim_emb = 2 # dimensionality of the embedding
n_hidden = 100 # number of hidden units

# set up the model drivers
g = torch.Generator().manual_seed(2147483648)
emb_lookup = torch.randn((len(chars) + 1, dim_emb), generator=g) #also written as 'C'. Can scale up dimensionality to capture more nuanced patterns
W1 = torch.randn((dim_emb * block_size, n_hidden), generator=g)  * (gain / (dim_emb * block_size) **0.5) 
W2 = torch.randn((n_hidden, n_hidden), generator=g) * (gain / n_hidden**0.5)
W3 = torch.randn((n_hidden, n_hidden), generator=g) * (gain / n_hidden**0.5) 
W4 = torch.randn((n_hidden, len(chars) + 1), generator=g) * (gain / n_hidden**0.5) # ending back up with 27 possible outputs / chars
b4 = torch.randn((len(chars) + 1), generator=g) * 0

bngain = torch.ones((1, n_hidden)) # gamma
bnbias = torch.zeros((1, n_hidden)) # beta
bngain_2 = torch.ones((1, n_hidden)) # gamma
bnbias_2 = torch.zeros((1, n_hidden)) # beta
bngain_3 = torch.ones((1, n_hidden)) # gamma
bnbias_3 = torch.zeros((1, n_hidden)) # beta

bnmean_running = torch.zeros((1, n_hidden)) # running mean
bnstd_running = torch.ones((1, n_hidden)) # running std
bnmean_running_2 = torch.zeros((1, n_hidden)) # running mean
bnstd_running_2 = torch.ones((1, n_hidden)) # running std
bnmean_running_3 = torch.zeros((1, n_hidden)) # running mean
bnstd_running_3 = torch.ones((1, n_hidden)) # running std

# put all of the parameters in one array for neatness -- you can sum all these to get total param count
parameters = [emb_lookup, W1, W2, W3, W4, b4, bngain, bnbias, bngain_2, bnbias_2, bngain_3, bnbias_3]
for p in parameters:
    p.requires_grad = True

In [54]:
max_steps = 40000
loss_values = []
learning_rates = []
steps = []

In [58]:
## TRAINING THE MODEL

# Basically the same thing as running a validation / test run except with gradient updates
# max_steps = 2000
# loss_values = []
# learning_rates_loss = {}

# for index, learning_rate in enumerate(lrs): 

#     gain = 5/3 # we need a gain because we are using the tanh activation function; squashes, add a gain to get back to normal std
#     dim_emb = 2 # dimensionality of the embedding
#     n_hidden = 100 # number of hidden units

#     # set up the model drivers
#     g = torch.Generator().manual_seed(2147483648)
#     emb_lookup = torch.randn((len(chars) + 1, dim_emb), generator=g) #also written as 'C'. Can scale up dimensionality to capture more nuanced patterns
#     W1 = torch.randn((dim_emb * block_size, n_hidden), generator=g)  * (gain / (dim_emb * block_size) **0.5) 
#     W2 = torch.randn((n_hidden, n_hidden), generator=g) * (gain / n_hidden**0.5)
#     W3 = torch.randn((n_hidden, n_hidden), generator=g) * (gain / n_hidden**0.5) 
#     W4 = torch.randn((n_hidden, len(chars) + 1), generator=g) * (gain / n_hidden**0.5) # ending back up with 27 possible outputs / chars
#     b4 = torch.randn((len(chars) + 1), generator=g) * 0

#     bngain = torch.ones((1, n_hidden)) # gamma
#     bnbias = torch.zeros((1, n_hidden)) # beta
#     bngain_2 = torch.ones((1, n_hidden)) # gamma
#     bnbias_2 = torch.zeros((1, n_hidden)) # beta
#     bngain_3 = torch.ones((1, n_hidden)) # gamma
#     bnbias_3 = torch.zeros((1, n_hidden)) # beta

#     bnmean_running = torch.zeros((1, n_hidden)) # running mean
#     bnstd_running = torch.ones((1, n_hidden)) # running std
#     bnmean_running_2 = torch.zeros((1, n_hidden)) # running mean
#     bnstd_running_2 = torch.ones((1, n_hidden)) # running std
#     bnmean_running_3 = torch.zeros((1, n_hidden)) # running mean
#     bnstd_running_3 = torch.ones((1, n_hidden)) # running std

#     # put all of the parameters in one array for neatness -- you can sum all these to get total param count
#     parameters = [emb_lookup, W1, W2, W3, W4, b4, bngain, bnbias, bngain_2, bnbias_2, bngain_3, bnbias_3]
#     for p in parameters:
#         p.requires_grad = True

#     learning_rates_loss[index] = {"learning_rate": float(learning_rate), "loss": [], "steps": []}
#     loss_accumulator = 0

for k in range(max_steps): #running a loop as we're running a bunch of forward passes to get loss and optimize
    ## Make a mini batch
    # In this example, we will grab 32 rows of the lookup table
    # This is random, which can help with generalization
    ix = torch.randint(0, Xtr.shape[0], (32,))
    
    ## Forward pass
    xs_embeddings = emb_lookup[Xtr[ix]] # feed into the two subsequent layers
    
    ### LAYER ONE
    # Linear layer
    pre_activations = xs_embeddings.view(-1, 6) @  W1 

    # Batch norm layer
    # estimate batch statistics as we go so that we can pass one input through easily during val later
    bnmeani = pre_activations.mean(0, keepdim=True)
    bnstdi = pre_activations.std(0, keepdim=True)
    with torch.no_grad():
        bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani # nudging a bit toward current batch
        bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi # no_grad as we aren't edting activations wrt these

    pre_activations = bngain * (pre_activations - bnmeani) / bnstdi + bnbias
    
    # Non-linearity
    hidden_layer = torch.tanh(pre_activations)

    ### LAYER TWO
    # Linear layer
    pre_activations_2 = hidden_layer @  W2

    # Batch norm layer
    # estimate batch statistics as we go so that we can pass one input through easily during val later
    bnmeani_2 = pre_activations_2.mean(0, keepdim=True)
    bnstdi_2 = pre_activations_2.std(0, keepdim=True)
    with torch.no_grad():
        bnmean_running_2 = 0.999 * bnmean_running_2 + 0.001 * bnmeani_2 # nudging a bit toward current batch
        bnstd_running_2 = 0.999 * bnstd_running_2 + 0.001 * bnstdi_2 # no_grad as we aren't edting activations wrt these

    pre_activations_2 = bngain_2 * (pre_activations_2 - bnmeani_2) / bnstdi_2 + bnbias_2
    
    # Non-linearity
    hidden_layer_2 = torch.tanh(pre_activations_2)


    ### LAYER THREE
        # Linear layer
    pre_activations_3 = hidden_layer_2 @  W3 

    # Batch norm layer
    # estimate batch statistics as we go so that we can pass one input through easily during val later
    bnmeani_3 = pre_activations_3.mean(0, keepdim=True)
    bnstdi_3 = pre_activations_3.std(0, keepdim=True)
    with torch.no_grad():
        bnmean_running_3 = 0.999 * bnmean_running_3 + 0.001 * bnmeani_3 # nudging a bit toward current batch
        bnstd_running_3 = 0.999 * bnstd_running_3 + 0.001 * bnstdi_3 # no_grad as we aren't edting activations wrt these

    pre_activations_3 = bngain_3 * (pre_activations_3 - bnmeani_3) / bnstdi_3 + bnbias_3
    
    # Non-linearity
    logits = pre_activations_3 @ W4 + b4 #matrix multiplication, give us the log counts
    loss = F.cross_entropy(logits, Ytr[ix])
    print(loss.item())

    # TODO: need to understand the L2 regularization from part one. Also less clear about functions below
    # For example: 1) why are we going through all the parameters, and 2) why aren't we updating weights.data, but all

    ## Backward pass   
    for p in parameters:
        p.grad = None
    loss.backward()
    
    ## Update
    for p in parameters:
        p.data += -(0.005) * p.grad

    # Accumulate loss
    # loss_accumulator += loss.item()

    # if (k + 1) % 10 == 0:
    #     # Every 10 steps, compute the average loss and reset the accumulator
    #     average_loss = loss_accumulator / 10
    #     learning_rates_loss[index]["loss"].append(average_loss)
    #     learning_rates_loss[index]["steps"].append(k)
    #     loss_accumulator = 0  # Reset accumulator


2.445842742919922
2.7085328102111816
2.4167888164520264
2.1270394325256348
2.358886241912842
2.3624606132507324
2.278529644012451
2.724210500717163
2.1609127521514893
2.0531558990478516
2.442042350769043
2.368424892425537
2.4778778553009033
2.141958236694336
2.5988142490386963
2.339914083480835
2.065579652786255
2.4576237201690674
2.2829031944274902
2.5133659839630127
2.735032320022583
2.3047752380371094
2.001492500305176
2.525320529937744
2.397590398788452
2.625375747680664
2.512467384338379
2.524991989135742
2.3298544883728027
2.2765605449676514
2.5837786197662354
2.1976919174194336
2.3998863697052
2.825549840927124
2.5879619121551514
2.5838894844055176
2.1795709133148193
2.3907976150512695
2.7272517681121826
2.601530075073242
2.427306890487671
2.449981689453125
2.3919339179992676
2.365147829055786
2.461855888366699
2.2649760246276855
2.4105236530303955
2.068049430847168
2.585148572921753
2.2853362560272217
2.4759533405303955
2.4172308444976807
2.7180638313293457
2.307762384414673
2.

In [66]:
# CALCULATE ADJUSTED WEIGHTS TO FOLD IN
# IE. adjust model parameters by what we learned in initial training run with batchnorm
# We apply a common formula for adjusting weights after batchnorm

epsilon = 1e-5 # this is a smoothing constant, preventing division by zero or near-zero values with normalizing

W1_adjusted = W1 * (bngain / torch.sqrt(bnstd_running + epsilon))
b1_adjusted = bnbias - bnmean_running * bngain / torch.sqrt(bnstd_running + epsilon)

W2_adjusted = W2 * (bngain_2 / torch.sqrt(bnstd_running_2 + epsilon))
b2_adjusted = bnbias_2 - bnmean_running_2 * bngain_2 / torch.sqrt(bnstd_running_2 + epsilon)

W3_adjusted = W3 * (bngain_3 / torch.sqrt(bnstd_running_3 + epsilon))
b3_adjusted = bnbias_3 - bnmean_running_3 * bngain_3 / torch.sqrt(bnstd_running_3 + epsilon)

In [67]:
# FOLD IN UPDATED WEIGHTS
W1 = W1_adjusted
W2 = W2_adjusted
W3 = W3_adjusted

In [60]:
# CHECK TOTAL LOSS ON TRAIN_SET FORWARD PASS

# Do a forward pass (without updates)
@torch.no_grad()

def split_loss(split):
    x,y = {
    'train': (Xtr, Ytr),
    'val': (Xdev, Ydev),
    'test': (Xte, Yte),
  }[split]

    xs_embeddings = emb_lookup[x]
    pre_activations = xs_embeddings.view(xs_embeddings.shape[0], -1) @  W1 
    pre_activations = bngain * (pre_activations - bnmean_running) / bnstd_running + bnbias
    hidden_layer = torch.tanh(pre_activations)

    pre_activations_2 = hidden_layer @  W2
    pre_activations_2 = bngain_2 * (pre_activations_2 - bnmean_running_2) / bnstd_running_2 + bnbias_2
    hidden_layer_2 = torch.tanh(pre_activations_2)

    pre_activations_3 = hidden_layer_2 @  W3
    pre_activations_3 = bngain_3 * (pre_activations_3 - bnmean_running_3) / bnstd_running_3 + bnbias_3
    
    logits = pre_activations_3 @ W4 + b4 # matrix multiplication, give us the log counts
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())

split_loss('train')
split_loss('val')


train 2.31453537940979
val 2.322866439819336


In [69]:
# VERIFY SIMILARITY OF LOSS AFTER FOLDING IN ADJUSTED WEIGHTS

# Do a forward pass (without updates)
@torch.no_grad()

def split_loss(split):
    x,y = {
    'train': (Xtr, Ytr),
    'val': (Xdev, Ydev),
    'test': (Xte, Yte),
  }[split]

    xs_embeddings = emb_lookup[x]
    pre_activations = xs_embeddings.view(xs_embeddings.shape[0], -1) @  W1 
    hidden_layer = torch.tanh(pre_activations)

    pre_activations_2 = hidden_layer @  W2
    hidden_layer_2 = torch.tanh(pre_activations_2)

    pre_activations_3 = hidden_layer_2 @  W3
    
    logits = pre_activations_3 @ W4 + b4 # matrix multiplication, give us the log counts
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())

split_loss('train')
split_loss('val')


train 3.300934076309204
val 3.324777603149414


In [None]:
# SAMPLE FROM MODEL

g = torch.Generator().manual_seed(2147483647 + 10)
for _ in range(20):
    out = []
    context = [0] * block_size
    while True:
        # Forward pass
        embeddings = emb_lookup[torch.tensor([context])]
        hidden_layer = torch.tanh(embeddings.view(1, -1) @ W1 + b1)
        logits = hidden_layer @ W2 + b2
        probs = F.softmax(logits, dim=1)

        # Sample from distribution
        ix = torch.multinomial(probs, num_samples=1, generator=g).item()

        # Shift context window and track samples
        context = context[1:] + [ix]
        out.append(ix)

        # Break if end token detected
        if ix == 0:
            break
    print(''.join(itos[i] for i in out))

mria.
mayahlieel.
ndynyal.
rethrstend.
leg.
aderedieliileli.
jelle.
eisennanar.
kayziohlara.
noshubergahi.
jest.
jair.
jelilentenof.
uba.
ghde.
jyleli.
ehs.
kay.
myskeyah.
hil.
