In [3]:
import os
import pickle
import requests
import numpy as np
import glob

# Resolve the working directory only once: `project_dir` is the directory the
# kernel started in and WORK_DIR is its parent. Without this guard, re-running
# the cell would chdir one level further up every time.
if 'WORK_DIR' not in globals():
    project_dir = os.path.abspath('.')
    WORK_DIR = os.path.abspath(os.path.join(project_dir, '../'))
os.chdir(WORK_DIR)
print("Current Working Directory:", os.getcwd())

Current Working Directory: /Users/david.amat/Documents/david/llm-learning/llm/gpt-from-scratch


### [START] Download Shakespeare

In [4]:
# download the tiny shakespeare dataset (skipped when input.txt is already cached)
input_file_path = os.path.join(WORK_DIR, 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Fetch and validate BEFORE opening the file for writing: opening first would
    # create an empty input.txt if the request fails, and an HTTP error page would
    # be saved as data. Either way the exists() check above would then treat the
    # bad file as a valid cache on every later run.
    response = requests.get(data_url, timeout=30)
    response.raise_for_status()
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)

In [5]:
# Read explicitly as UTF-8 so the text does not depend on the platform's default encoding.
with open(input_file_path, 'r', encoding='utf-8') as f:
    text = f.read()
print(f"length of dataset in characters: {len(text):,}")

# get all the unique characters that occur in this text (sorted, so ids are stable)
chars = sorted(set(text))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65


In [6]:
# Character-level tokenizer: every character gets the id of its position in the
# sorted `chars` list; `itos` and `stoi` are exact inverses of each other.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

def encode(s):
    """Turn a string into the list of its character ids (KeyError on an unseen character)."""
    return list(map(stoi.__getitem__, s))

def decode(l):
    """Turn a list of character ids back into the corresponding string."""
    return ''.join(map(itos.__getitem__, l))


In [7]:
print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


### [END] Download Shakespeare

### [START] Tik Token

In [8]:
# trying with tiktoken to see diferences
import tiktoken
enc = tiktoken.get_encoding('gpt2')

In [9]:
enc.encode("hii there")

[71, 4178, 612]

In [10]:
enc.n_vocab

50257

### [END] Tik Token

### [START] Tokenize Shakespeare

In [11]:
import torch
# Set print options to show 4 decimal places without scientific notation
torch.set_printoptions(sci_mode=False, precision=4, linewidth=120)

In [12]:
data =torch.tensor(encode(text), dtype=torch.long)

In [13]:
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [14]:
data[:10]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])

In [15]:
text[:10]

'First Citi'

### [END] Tokenize Shakespeare

### [START] Train Split

In [16]:
# create the train and test splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

In [17]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [18]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"Input: {context.tolist()} -> {target}")

Input: [18] -> 47
Input: [18, 47] -> 56
Input: [18, 47, 56] -> 57
Input: [18, 47, 56, 57] -> 58
Input: [18, 47, 56, 57, 58] -> 1
Input: [18, 47, 56, 57, 58, 1] -> 15
Input: [18, 47, 56, 57, 58, 1, 15] -> 47
Input: [18, 47, 56, 57, 58, 1, 15, 47] -> 58


### [END] Train Split

### [START] Batch

In [19]:
torch.manual_seed(1337)
batch_size = 4  # number of sequences sampled per batch
block_size = 8  # length of each sampled sequence (context window)

def get_batch(split):
    """Sample a random batch of (input, target) token sequences.

    `split == 'train'` samples from `train_data`; any other value samples from
    `val_data`. The module-level `batch_size` and `block_size` are read at call
    time. Returns `(x, y)`, both of shape (batch_size, block_size), where `y` is
    `x` shifted one position ahead (the next-token targets).
    """
    source = val_data if split != 'train' else train_data
    # Largest start is len - block_size - 1, so the shifted target window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs, targets
    

In [20]:
xb, yb = get_batch('train')

In [21]:
xb

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])

In [22]:
yb

tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])

### [END] Batch

### [START] Pytorch Bigram Model

In [23]:
import torch 
from copy import copy as cp
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

<torch._C.Generator at 0x135541d90>

In [24]:
class BiGramLanguageModel(nn.Module):
    """Bigram language model: the next-token logits depend only on the current token.

    The whole model is a (vocab_size, vocab_size) embedding table; row i is used
    directly as the logits for the token that follows token i.
    """

    def __init__(self, vocab_size):
        super().__init__()

        # Row i of this table = next-token logits predicted after token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss) for a batch of token ids.

        Args:
            idx: (Batch, Time) tensor of token ids.
            targets: optional (Batch, Time) tensor of next-token ids.

        Returns:
            logits: (Batch, Time, Channel) when targets is None, otherwise
                flattened to (Batch*Time, Channel).
            loss: scalar mean cross-entropy, or None when targets is None.
        """
        logits = self.token_embedding_table(idx)  # (Batch, Time, Channel)

        # If no targets provided, only return logits
        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects class scores as (N, C) and targets as (N,),
            # so merge the Batch and Time dimensions into N = Batch*Time.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        # Worked example for a single row (shapes: logits [32, 65], targets [32]):
        #   loss_i = -log( exp(logits[i, y_i]) / sum_j exp(logits[i, j]) )
        #   where y_i is the index of the target next character.
        # e.g. target = 43, logits[0, 43] = -0.1728 => exp(-0.1728) = 0.8413;
        #   denominator torch.sum(torch.exp(logits[0, :])) = 96.6660
        #   => loss_0 = -log(0.8413 / 96.6660) = 4.7441
        # F.cross_entropy(logits, targets) averages loss_i over all Batch*Time rows.

        return logits, loss

Cross Entropy Loss Formula for Batch

$$\text{CrossEntropyLoss} = \frac{1}{B \times T} \sum_{i=1}^{B} \sum_{t=1}^{T} -\log\left(\frac{\exp(\text{logits}_{i,t,y_{i,t}})}{\sum_{j=1}^{C} \exp(\text{logits}_{i,t,j})}\right)
$$

By default, `F.cross_entropy` averages the per-element losses over the batch (`reduction='mean'`). Pass `reduction='sum'` or `reduction='none'` to change this; the older `size_average` argument is deprecated.

In [25]:
m = BiGramLanguageModel(vocab_size)

In [26]:
m.token_embedding_table.weight.shape

torch.Size([65, 65])

In [27]:
logits = m.token_embedding_table(xb)
logits.shape

torch.Size([4, 8, 65])

In [28]:
B, T, C = logits.shape
logits = logits.view(B*T, C)

In [29]:
logits.shape

torch.Size([32, 65])

In [30]:
targets = cp(yb)
targets = targets.view(B*T)
targets.shape

torch.Size([32])

In [31]:
F.cross_entropy(logits, targets)

tensor(4.8786, grad_fn=<NllLossBackward0>)

In [32]:
flogits, floss = m(xb, yb)

In [33]:
floss

tensor(4.8786, grad_fn=<NllLossBackward0>)

### [END] Pytorch Bigram Model

### [START] Generate sequence from Bigram Model

In [34]:
class BiGramLanguageModel(nn.Module):
    """Bigram language model with autoregressive sampling.

    The whole model is a (vocab_size, vocab_size) embedding table; row i is used
    directly as the logits for the token that follows token i.
    """

    def __init__(self, vocab_size):
        super().__init__()

        # Row i of this table = next-token logits predicted after token i (idx).
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss) for a batch of token ids.

        Args:
            idx: (Batch, Time) tensor of token ids.
            targets: optional (Batch, Time) tensor of next-token ids.

        Returns:
            logits: (Batch, Time, Channel) when targets is None, otherwise
                flattened to (Batch*Time, Channel).
            loss: scalar mean cross-entropy, or None when targets is None.
        """
        logits = self.token_embedding_table(idx)  # (Batch, Time, Channel): logits per position

        # If no targets provided, only return logits
        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects class scores as (N, C) and targets as (N,),
            # so merge the Batch and Time dimensions into N = Batch*Time.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens, one at a time.

        Args:
            idx: (Batch, Time) long tensor used as the starting context.
            max_new_tokens: number of tokens to append.

        Returns:
            (Batch, Time + max_new_tokens) long tensor: the context followed by
            the sampled tokens.

        Runs under torch.no_grad(): sampling never needs gradients, and without
        it every step would record an autograd graph through the embedding table.
        """
        for _ in range(max_new_tokens):
            # Calling without targets => only logits with (B, T, C) shape
            logits, _ = self(idx)

            # Only the prediction at the last time step matters => (B, C)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)  # (B, C)

            # Sample one token per row from the distribution => (B, 1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Append the sampled token: (B, T) + (B, 1) => (B, T+1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [35]:
a = torch.tensor([[0.1, 1, 6], [0.1, 5, 6]], dtype=torch.float32)

# Samples column indices for each row, treating that row's values as NON-NEGATIVE
# WEIGHTS (unnormalized probabilities), NOT logits: index j is drawn with
# probability a[r, j] / a[r].sum(). Logits must be passed through softmax first
# (as `generate` does). replacement=True allows more samples than columns.
torch.multinomial(a, num_samples=15, replacement=True)

tensor([[2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2],
        [2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2]])

In [36]:
m = BiGramLanguageModel(vocab_size)
flogits, floss = m(xb)

# Get the last timestep
lflogits = flogits[:, -1, :] # B, C

fprobs = F.softmax(lflogits, dim=-1)

In [37]:
fprobs.shape

torch.Size([4, 65])

In [38]:
idx_next = torch.multinomial(fprobs, num_samples=1)

In [39]:
idx_next.shape

torch.Size([4, 1])

In [40]:
idx_next

tensor([[30],
        [54],
        [30],
        [23]])

In [41]:
yb[:, -1]

tensor([39,  1, 46, 39])

In [42]:
# Final Concatenation of the generated idx (idx_next) to the original input (xb)
torch.cat((xb, idx_next), dim=1)

tensor([[24, 43, 58,  5, 57,  1, 46, 43, 30],
        [44, 53, 56,  1, 58, 46, 39, 58, 54],
        [52, 58,  1, 58, 46, 39, 58,  1, 30],
        [25, 17, 27, 10,  0, 21,  1, 54, 23]])

### [END] Generate sequence from Bigram Model

### [START] Try the generator of sequence

In [43]:
m = BiGramLanguageModel(vocab_size)

In [44]:
t_init = torch.zeros((1,1), dtype=torch.long)
gen = m.generate(idx=t_init, max_new_tokens=100)

In [45]:
decode(gen[0].tolist())

"\nnZnZt.vr,E'':CpBxFNYWFgzv?;M!cPNQUw.Kg$ &gMkioP.nrDQ\nfdpepBDOTwHtkwumsS3-u.x-&zu?$Lglwqq$KjRrQiYb B,"

### [END] Try the generator of sequence

### [START] Train Bigram Model

In [46]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [47]:
# Larger batches for training; get_batch reads this module-level value at call time.
batch_size = 32
for steps in range(20000):
    # Sample a fresh random batch of (input, next-token target) sequences.
    xb, yb = get_batch('train')
    
    # Evaluate the loss
    logits, loss = m(xb, yb)
    
    # Training step: reset grads (set_to_none=True sets them to None rather than
    # zero tensors), backprop, then let the AdamW optimizer update the parameters.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # print every 1000 steps (this is the loss of a single mini-batch, so it is noisy)
    if steps % 1000 == 0:
        print(f"Step: {steps} Loss: {loss.item()}")        
    
# Report the loss of the last step (steps == 19999 after the loop).
print(f"Step: {steps} Loss: {loss.item()}")

Step: 0 Loss: 4.67943000793457
Step: 1000 Loss: 3.673828125
Step: 2000 Loss: 3.024789571762085
Step: 3000 Loss: 2.6680943965911865
Step: 4000 Loss: 2.7315430641174316
Step: 5000 Loss: 2.521266222000122
Step: 6000 Loss: 2.639226198196411
Step: 7000 Loss: 2.5043129920959473
Step: 8000 Loss: 2.4675230979919434
Step: 9000 Loss: 2.4043755531311035
Step: 10000 Loss: 2.428785562515259
Step: 11000 Loss: 2.329158067703247
Step: 12000 Loss: 2.6495845317840576
Step: 13000 Loss: 2.573981285095215
Step: 14000 Loss: 2.5653419494628906
Step: 15000 Loss: 2.452608108520508
Step: 16000 Loss: 2.42024564743042
Step: 17000 Loss: 2.4490718841552734
Step: 18000 Loss: 2.319751501083374
Step: 19000 Loss: 2.5759243965148926
Step: 19999 Loss: 2.5398643016815186


In [48]:
# Generate
t_init = torch.zeros((1,1), dtype=torch.long)
gen = m.generate(idx=t_init, max_new_tokens=100)
decode(gen[0].tolist())

'\nAla ancade ted ELondiserdoundd?\nT:\nMitho yonteam s? ouiearsithe ghormm l beritold KI:\n\ne IChieeng ng'

### [END] Train Bigram Model

### [START] Trick on self-attention

In [49]:
torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

Let's simulate a simple form of attention: for each time step t, average the embeddings of all positions up to and including t in that sequence of the batch, without looking at later positions (t+1, t+2, etc.). This is a simplified masked self-attention.

In [50]:
# Naive causal average: xbow[b, t] = mean of x[b, 0..t]; positions after t are ignored.
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # [t+1, C]: every position up to and including t
        xbow[b, t] = torch.mean(xprev, 0)

In [51]:
xprev.shape # [t=8, C=2]

torch.Size([8, 2])

In [52]:
x[0][:]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [53]:
xbow[0][:]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [54]:
np.mean([0.1808, -0.3596]) # which matches -0.0894 of the xbow second row first column

np.float64(-0.0894)

In [55]:
np.mean([0.1808, -0.3596, 0.6258]) # which matches 0.1490 of the xbow third row first column

np.float64(0.14900000000000002)

### [END] Trick on self-attention

### [START] Trick on self-attention introducing TRIL

In [56]:
x = torch.tril(torch.ones(3,3))
x

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [57]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))

# Make rows normalized to sum 1
a_norm = a / torch.sum(a, dim=1, keepdim=True)

# Random 3x2 integer matrix whose rows get averaged
b = torch.randint(0,10,(3,2)).float()
# Row i of c is the mean of rows 0..i of b (a causal running average via one matmul)
c = a_norm @ b
print(f"{a=}\n{a_norm=}\n{b=}\n{c=}")

a=tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
a_norm=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


### [END] Trick on self-attention introducing TRIL

### [START] Applying TRIL

In [58]:
# Same seed and shape as the explicit-loop version, so `x` is the same tensor.
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Row t of `wei` puts weight 1/(t+1) on positions 0..t and 0 elsewhere,
# so `wei @ x` is the causal running mean over time.
wei = torch.ones(T, T).tril()
wei /= wei.sum(dim=-1, keepdim=True)
# [T, T] @ [B, T, C]: the [T, T] matrix is broadcast over the batch => [B, T, C]
xbow2 = wei @ x

In [59]:
print(f"{wei.shape=}\n{x.shape=}\n{xbow2.shape=}")

wei.shape=torch.Size([8, 8])
x.shape=torch.Size([4, 8, 2])
xbow2.shape=torch.Size([4, 8, 2])


### [END] Applying TRIL

### [START] Applying TRIL with Softmax

In [60]:
torch.manual_seed(1337)
B,T,C = 4,8,2
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))

In [61]:
wei = wei.masked_fill(tril == 0, float('-inf'))
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [62]:
wei = F.softmax(wei, dim=-1)

In [63]:
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [64]:
xbow3 = wei @ x

In [65]:
torch.allclose(xbow2, xbow3)

True

### [END] Applying TRIL with Softmax

### [START] Linear layer to project embeddings

In [66]:
n_embd = 32
token_embedding_table = nn.Embedding(vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size)

In [67]:
# Get only 3 samples from batch
xb1 = xb[:3,...]  # [3, 8]

In [68]:
# idx and targets are both (B,T) tensor of integers
tok_emb = token_embedding_table(xb1) # (B,T,C)
tok_emb.shape

torch.Size([3, 8, 32])

In [69]:
logits = lm_head(tok_emb) # (B,T,C)
logits.shape

torch.Size([3, 8, 65])

### [END] Linear layer to project embeddings

### [START] Position embedding

In [70]:
position_embedding_table = nn.Embedding(block_size, n_embd)

In [71]:
# Get only 3 samples from batch
xb1 = xb[:3,...]  # [3, 8]

# Get the token embeddings
tok_emb = token_embedding_table(xb1) # (B,T,C)

# Get the position embeddings
Txb = xb1.shape[1]
pos_emb = position_embedding_table(torch.arange(Txb)) # (T,C)

# Add the position embeddings to the token embeddings
emb_final = tok_emb + pos_emb  # Broadcasting is done automatically (B,T,C) + (T,C) => (B,T,C) + (B,T,C) => (B,T,C)

In [72]:
logits_emb_final = lm_head(emb_final) # (B,T,C)
logits_emb_final.shape

torch.Size([3, 8, 65])

### [END] Position embedding

### [START] Head Attention

Every token in the sequence emits a QUERY vector and a KEY vector. The dot product between one token's query and another token's key gives their attention affinity.

In [73]:
# Re-seed: both x and the three Linear weight inits below draw from the RNG,
# so their creation order determines the numbers shown in later cells.
torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.randn(B,T,C)

# Head
# Three bias-free projections C -> head_size (H): query, key and value.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # [B, T, H]
q = query(x) # [B, T, H]
v = value(x) # [B, T, H]

In [74]:
q.shape

torch.Size([4, 8, 16])

In [75]:
ex = torch.randn(2,3,4)
print(ex.shape)

# Transpose the last two dimensions
ex.transpose(-2, -1).shape

torch.Size([2, 3, 4])


torch.Size([2, 4, 3])

In [76]:
# Affinities between queries and keys: wei[b, t, s] = dot(q[b, t], k[b, s])
# careful, the transpose to do the dot product must NOT happen at the batch dimension level
wei = q @ k.transpose(-2, -1)  # (B, T, H ) @ (B, H, T) => (B, T, T)

In [77]:
# Causal mask: position t may only attend to positions <= t. Masked scores become
# -inf, so softmax gives them exactly 0 weight and each row sums to 1.
# NOTE: the scores are not scaled by head_size**-0.5 here (see "Scaled Attention").
tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v # (B, T, T) @ (B, T, H) => (B, T, H)

In [78]:
out.shape

torch.Size([4, 8, 16])

In [79]:
wei[0:2,...]

tensor([[[    1.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.5599,     0.4401,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.3220,     0.2016,     0.4764,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.1640,     0.0815,     0.2961,     0.4585,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.2051,     0.3007,     0.1894,     0.1808,     0.1241,     0.0000,     0.0000,     0.0000],
         [    0.0600,     0.1273,     0.0291,     0.0169,     0.0552,     0.7114,     0.0000,     0.0000],
         [    0.1408,     0.1025,     0.1744,     0.2038,     0.1690,     0.0669,     0.1426,     0.0000],
         [    0.0223,     0.1086,     0.0082,     0.0040,     0.0080,     0.7257,     0.0216,     0.1016]],

        [[    1.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.5634,     0.4366,  

###  -----  [START] Scaled Attention

In [86]:
# Seed so the variances printed by the next cells are reproducible (the saved
# outputs below come from different unseeded runs: k.var() differs between them).
torch.manual_seed(1337)
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
# Raw dot products of unit-variance vectors: variance grows with head_size.
wei = q @ k.transpose(-2, -1)
# Scaling by 1/sqrt(head_size) brings the variance back to ~1, keeping softmax
# from becoming too peaked at initialization. Reuses `wei` instead of a second matmul.
wei_norm = wei * (head_size ** -0.5)

In [85]:
print(f"{k.var()=:.4f} {q.var()=:.4f} {wei.var()=:.4f}")

k.var()=1.0028 q.var()=1.0326 wei.var()=17.6690


In [87]:
print(f"{k.var()=:.4f} {q.var()=:.4f} {wei_norm.var()=:.4f}")

k.var()=0.9256 q.var()=0.9796 wei_norm.var()=0.8931


###  -----  [END] Scaled Attention

### [END] Head Attention

###  [START] Normalization Layers

###  [START] Batch Normalization

In [88]:
class BatchNorm1d:
  """Minimal batch normalization layer (educational re-implementation).

  In training mode each feature (last dimension) is normalized with statistics
  computed over the batch dimension for 2D input (B, H), or over batch and time
  for 3D input (B, T, H). In eval mode the running statistics are used instead.
  A learnable scale (gamma) and shift (beta) are applied afterwards.

  Args:
    dim: number of features H.
    eps: added to the variance for numerical stability.
    momentum: weight of the current batch in the running-statistics update.
  """

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update'); always kept with shape [H]
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)

  def __call__(self, x):
    """Normalize x of shape (B, H) or (B, T, H); returns a tensor of the same shape.

    Raises:
      ValueError: in training mode, if x is neither 2D nor 3D.
    """
    if self.training:
      if x.ndim == 2:
        dim = 0
      elif x.ndim == 3:
        dim = (0, 1)
      else:
        raise ValueError(f"BatchNorm1d expects 2D or 3D input, got {x.ndim}D")
      # Reduce WITHOUT keepdim so the statistics have shape [H]: they broadcast
      # against both (B, H) and (B, T, H), and the running buffers stay [H]
      # instead of drifting to [1, H] / [1, 1, H] after the first update.
      xmean = x.mean(dim)  # [H]
      xvar = x.var(dim)  # [H] (unbiased estimate)
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # same shape as x
    self.out = self.gamma * xhat + self.beta
    # update the buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]

In [90]:
torch.manual_seed(1337)
module = BatchNorm1d(100)
x = torch.randn(32, 100)
x = module(x)
x.shape

torch.Size([32, 100])

In [91]:
# we are normalizing columns (rows are batch samples, so we normalize each feature across the batch)
xm = x[:,0].mean()
xstd = x[:,0].std()
print(f"{xm=:.4f} {xstd=:.4f}")

xm=0.0000 xstd=1.0000


In [95]:
class LayerNorm1d:
  """Minimal layer normalization (educational re-implementation).

  Normalizes every vector along the LAST dimension (the features) independently,
  for input of any rank such as (B, H) or (B, T, H), then applies a learnable
  scale (gamma) and shift (beta). There are no running statistics and no
  train/eval distinction.

  Args:
    dim: number of features H (size of the last dimension).
    eps: added to the variance for numerical stability.
    momentum: unused; accepted only so the signature matches BatchNorm1d.
  """

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)

  def __call__(self, x):
    # Reduce over the feature dimension only. For (B, T, H) input, a joint (T, H)
    # reduction would mix statistics across time steps, so each position's output
    # would depend on later positions.
    xmean = x.mean(-1, keepdim=True)  # [..., 1]
    xvar = x.var(-1, keepdim=True)  # [..., 1] (unbiased estimate)
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # same shape as x
    self.out = self.gamma * xhat + self.beta
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]

In [96]:
torch.manual_seed(1337)
module = LayerNorm1d(100)
x = torch.randn(32, 100)
x = module(x)
x.shape

torch.Size([32, 100])

In [97]:
# we are normalizing rows
xm = x[0, :].mean()
xstd = x[0, :].std()
print(f"{xm=:.4f} {xstd=:.4f}")

xm=-0.0000 xstd=1.0000
