In [None]:
import numpy as np
import torch 
import torch.nn as nn
from torch.nn import functional as F

# https://www.youtube.com/watch?v=kCc8FmEb1nY - 1:19:12


In [None]:
# Hyperparameters
batch_size = 64 # number of independent sequences processed in parallel
block_size = 256 # context window size for prediction (max tokens the model attends over)
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
# Pick the best available accelerator. `torch.has_mps` is deprecated;
# `torch.backends.mps.is_available()` is the supported check.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("device", device)
n_embed = 384
n_head = 6 # 384 / 6 = 64 dim head
n_layer = 6
dropout = 0.2

# Get Input Data

In [None]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
#!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
# Read the whole corpus into memory as a single string
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
print("characters", len(text))
print(text[:100])  # preview the first 100 characters

# Tokenize Input

In [None]:
# The vocabulary is every distinct character in the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

In [None]:
stoi = { ch:i for i,ch in enumerate(chars) } # character -> integer id
itos = { i:ch for i,ch in enumerate(chars) } # integer id -> character

def encode(s):
    '''Convert a string into a list of integer token ids, one per character.'''
    return [stoi[c] for c in s]

def decode(l):
    '''Convert an iterable of integer token ids back into a string.'''
    return ''.join(itos[i] for i in l)

print("encoded", encode("hello world"))
print("decoded", decode(encode("hello world")))

In [None]:
# Encode the whole corpus as one long 1-D tensor of token ids
data = torch.as_tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])  # first 100 token ids

In [None]:
# First 90% of the tokens for training, the remaining 10% for validation
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

# Data Loader

In [None]:
# Data loader
torch.manual_seed(42)

def get_batch(split):
    '''
    Sample a random batch of (input, target) windows from one data split.

    split == "train" draws from `train_data`; any other value draws from `val_data`.
    Returns two (batch_size, block_size) tensors moved to `device`. The targets are
    the inputs shifted one position later in the sequence.
    '''
    source = train_data if split == "train" else val_data
    starts = torch.randint(0, source.size(0) - block_size, (batch_size,))
    # Row r holds the absolute positions starts[r], starts[r] + 1, ..., starts[r] + block_size - 1
    positions = starts.unsqueeze(1) + torch.arange(block_size)
    inputs = source[positions]
    targets = source[positions + 1]
    return inputs.to(device), targets.to(device)

# Sample one training batch to check shapes: inputs and their next-token targets
xb, yb = get_batch("train")
print(xb.shape, yb.shape)
for batch_part in (xb, yb):
    print(batch_part)

# Bigram Model

In [None]:
class Head(nn.Module): # Added in later

    '''
    One head of causal (masked) self-attention.

    Projects each token embedding to a key, query and value of width `head_size`.
    For every position, it returns a weighted sum of the values at that position
    and all earlier positions.
    '''

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)

        # Lower-triangular causal mask, registered as a buffer (not a parameter) so it is never trained
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, _ = x.shape
        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)

        # Compute self attention scores ("affinities"), scaled by 1/sqrt(head_size).
        # The scale must use the query/key width (head_size), not the embedding width C.
        # With roughly unit-variance q and k, it keeps the scores near unit variance,
        # so softmax does not saturate towards one-hot.
        w = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5 # (B, T, head_size) @ (B, head_size, T) = (B, T, T)
        w = w.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        w = F.softmax(w, dim=-1)
        w = self.dropout(w) # Drop out some of the communication between tokens

        out = w @ v # (B, T, T) @ (B, T, head_size) = (B, T, head_size)
        return out
    

class MultiHeadAttention(nn.Module):
    '''
    Multiple self-attention heads in parallel, followed by a linear projection.

    Args:
        num_heads: number of independent `Head` modules.
        head_size: output width of each head.

    The concatenated head outputs have width num_heads * head_size. They are
    projected back to n_embed so the result can be added to the residual pathway.
    '''

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The input width is the concatenated head width. It only equals n_embed when
        # n_embed is divisible by num_heads, so it is not hard-coded to n_embed.
        self.proj = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Increase the number of channels that a token can use for information with multiple heads
        out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, num_heads * head_size)

        # Project back to n_embed channels - this allows us to create the residual pathway in the transformer block
        out = self.proj(out) # (B, T, n_embed)

        # Dropout for regularization - prevent overfitting
        # During training, require the model to learn with a subset of weights (but a different subset with each pass)
        # During inference, it's like using an ensemble of sub-networks
        out = self.dropout(out)

        return out
    

class FeedForward(nn.Module):
    '''
    Position-wise MLP: Linear(n_embed -> 4*n_embed) -> ReLU -> Linear(4*n_embed -> n_embed) -> Dropout.

    Applied to every token independently, giving each token some computation
    on the information it gathered through self-attention.
    '''

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed  # inner layer is 4x wider than the embedding
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),  # introduce some non-linearity
            nn.Linear(hidden, n_embed),  # projection back into the residual pathway
            nn.Dropout(dropout),  # regularization against overfitting
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        transformed = self.net(x)
        return transformed
    
class LayerNorm:
    '''
    From-scratch layer normalization (for illustration; the model uses nn.LayerNorm).

    Each vector is normalized along its last (feature) dimension to zero mean and
    unit variance. A scale (gamma) and shift (beta) are then applied. For a (B, T, C)
    activation, each token's C features are normalized independently, which matches
    torch.nn.LayerNorm(C). Normalization rescales the values; it does not make them
    normally distributed.
    gamma and beta are plain tensors here (no requires_grad) and are created on the CPU.
    '''

    def __init__(self, dim, eps=1e-5):
        self.eps = eps
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):
        # Normalize over the feature (last) axis. Axis 1 would be the time axis for
        # (B, T, C) inputs; for 2-D (B, C) inputs the two choices coincide.
        xmean = x.mean(-1, keepdim=True)
        # Biased (population) variance, as nn.LayerNorm uses, so the output has unit variance
        xvar = x.var(-1, keepdim=True, unbiased=False)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # Normalization
        self.out = self.gamma * xhat + self.beta # Scale and shift
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]



class Block(nn.Module):
    '''
    Transformer block: communication (multi-head self-attention) followed by
    computation (feed-forward). Each is a residual branch whose input is layer-normalized.

    Args:
        n_embed: width of the residual stream.
        n_head: number of attention heads; each gets n_embed // n_head channels.
    '''

    def __init__(self, n_embed=8, n_head=4):
        super().__init__()
        per_head = n_embed // n_head
        # Sub-modules are created in this order on purpose: it fixes both the
        # parameter order and the random-initialization sequence.
        self.sa = MultiHeadAttention(n_head, per_head)  # n_head heads, each per_head channels wide
        self.ff = FeedForward(n_embed)  # per-token computation on what attention gathered, before the logits
        self.ln1 = nn.LayerNorm(n_embed)  # normalizes the input of the self-attention branch
        self.ln2 = nn.LayerNorm(n_embed)  # normalizes the input of the feed-forward branch

    def forward(self, x):
        # Skip (residual) connections: each branch's output is added to x. This gives
        # gradients a direct path back through deep stacks of blocks. Early in training
        # the branches typically contribute little, so the identity path carries the signal.
        attended = x + self.sa(self.ln1(x))
        return attended + self.ff(self.ln2(attended))


In [None]:
class BigramLanguageModel(nn.Module):
    '''
    Character-level transformer language model. The name is kept from the bigram
    model this notebook started with.

    Predicts the next token at every position from up to `block_size` preceding tokens.
    The pipeline is: token + position embeddings -> `n_layer` transformer Blocks ->
    final LayerNorm -> linear head producing logits over the vocabulary.
    '''

    def __init__(self, vocab_size):
        super().__init__()

        # Learned lookup table: token id -> n_embed dimensional vector
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed) # (vocab_size, n_embed)

        # Learned lookup table: position 0..block_size-1 -> n_embed dimensional vector
        self.position_embedding_table = nn.Embedding(block_size, n_embed) # (block_size, n_embed)

        # Stack of transformer blocks
        # Having a lot of blocks makes the net very deep - deep nets are hard to optimize
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head) for _ in range(n_layer)])

        self.ln_f = nn.LayerNorm(n_embed) # Final LayerNorm, applied before the output head

        # Project each n_embed dimensional vector to logits over the vocabulary
        self.lm_head = nn.Linear(n_embed, vocab_size) # n_embed -> vocab_size

    def forward(self, idx, targets=None):
        '''
        Args:
            idx: (B, T) tensor of token ids, with T <= block_size.
            targets: optional (B, T) tensor of next-token ids.
        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and loss is None.
            With targets, logits is flattened to (B*T, vocab_size) and loss is the mean
            cross-entropy.
        '''
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx) # (B, T, n_embed)
        # Build positions on the input's device so the model works wherever it is placed
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T, n_embed)
        x = token_emb + pos_emb # (B, T, n_embed); positions broadcast over the batch
        x = self.blocks(x) # (B, T, n_embed)
        x = self.ln_f(x) # final LayerNorm before projecting to the vocabulary
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            return logits, None

        # F.cross_entropy expects (N, C) logits and (N,) class indices
        B, T, C = logits.shape
        logits = logits.reshape(B*T, C)
        targets = targets.reshape(B*T)
        loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        '''
        Autoregressively sample `max_new_tokens` tokens after the prompt `idx`.

        Args:
            idx: (B, T) tensor of token ids used as the prompt.
            max_new_tokens: number of tokens to append.
        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the sampled tokens.

        Runs without gradient tracking and in eval mode (dropout off). The module's
        previous train/eval mode is restored afterwards.
        '''
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # The position table only covers block_size positions, so crop the context
                idx_cond = idx[:, -block_size:] # (B, <= block_size)

                logits, _ = self(idx_cond)

                # Only the prediction at the last time step matters
                logits = logits[:, -1, :] # (B, vocab_size)

                # Get probabilities from logits and sample the next token
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

                # Append sampled token to sequence
                idx = torch.cat([idx, idx_next], dim=1)
        finally:
            self.train(was_training)

        return idx # (B, T + max_new_tokens)



# Build the model and move its parameters to the selected device.
# nn.Module.to() moves the module in place and returns it, so `model` and `m` are the same object.
model = BigramLanguageModel(vocab_size).to(device)
m = model
logits, loss = m(xb, yb)  # sanity check on the sample batch: output shape and initial loss
print(logits.shape, loss)

In [None]:
eval_iters = 100

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out



In [None]:
# An untrained model predicts roughly uniformly over the vocabulary, so the initial
# cross-entropy should be about -ln(1/vocab_size) = ln(vocab_size).
print("We expect the loss to be around", -np.log(1/vocab_size))

In [None]:
# Start generation from a single token with id 0 (the \n character in this vocabulary, a natural start)
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = m.generate(context, 100)
print(decode(generated.cpu().numpy()[0]))

In [None]:
# Create optimizer, using the learning rate from the hyperparameter cell rather than a hard-coded value
optimizer = torch.optim.Adam(m.parameters(), lr=learning_rate)

In [None]:
# Training loop
# Step count, evaluation interval and batch size all come from the hyperparameter
# cell, so there is a single source of configuration (no hidden overrides here).
for steps in range(max_iters):

    # Periodically report train/val loss; also on the last step so the final loss is shown
    if steps % eval_interval == 0 or steps == max_iters - 1:
        losses = estimate_loss()
        print("Step", steps, "Train loss", losses["train"], "Val loss", losses["val"])

    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:
# Sample 100 new tokens from the trained model, starting from the same context
sample_ids = m.generate(context, 100)[0]
print(decode(sample_ids.cpu().numpy()))

# Math trick in self attention

In [None]:
'''
In the bigram model, we only look at the single most recent token to predict the next one, so there is no wider context.
To introduce context, we can naively average all the previous tokens, giving each position some information about what came before.
'''

In [None]:
# Toy input for the demos below: 4 sequences, 8 time steps, 2 channels
torch.manual_seed(1337)
B = 4  # batch
T = 8  # time
C = 2  # channels
x = torch.randn((B, T, C))
print(x.shape)

In [None]:
# Goal: xbow[b, t] = mean_{i<=t} x[b, i]
# "Bag of words" is the usual term for this kind of order-free averaging.
xbow = torch.zeros((B, T, C))
for b in range(B):  # each sequence in the batch
    for t in range(T):  # each position in the sequence
        # rows 0..t (inclusive) of sequence b form a (t+1, C) slice; average it over time
        xbow[b, t] = x[b, : t + 1].mean(dim=0)
xbow[0], x[0]

In [None]:
# The same averages computed with a single matrix multiplication
w = torch.tril(torch.ones((T, T)))  # lower-triangular: row t has ones in columns 0..t
w = w / w.sum(dim=1, keepdim=True)  # each row sums to 1 -> a uniform average over positions 0..t

# Row t of (w @ x[b]) is the mean of rows 0..t of x[b]: a weighted sum of the
# current and previous tokens (uniform weights in this naive case).
print(w)
xbow2 = w @ x  # (T, T) @ (B, T, C) -> broadcast over the batch -> (B, T, C)
print(xbow2[0])

In [None]:
# A third version: the same weights produced by a softmax over a masked matrix
tril = torch.tril(torch.ones((T, T)))
# Start from all-zero attention weights (every earlier token equally interesting),
# set the upper triangle to -inf so no token can look at the future, and then
# softmax each row. exp(-inf) = 0 keeps the future masked, and the equal zeros
# elsewhere give the same uniform averages as before; each row sums to 1.
w = F.softmax(torch.zeros((T, T)).masked_fill(tril == 0, float('-inf')), dim=1)
print(w)
xbow3 = w @ x  # (T, T) @ (B, T, C) -> (B, T, C)
print(xbow3[0])
print(torch.allclose(xbow2, xbow3))

In [None]:
# Self attention

torch.manual_seed(1337)
B, T, C = 4, 8, 32  # 32 channels: every token now carries a 32-dim vector
x = torch.randn(B, T, C)

# A single head performing self-attention
head_size = 16

# Three independent linear maps from the C channels (each weight is head_size x C = 16x32).
# The generator builds them in the order key, query, value, which fixes the random init sequence.
key, query, value = (nn.Linear(C, head_size, bias=False) for _ in range(3))

# Every token emits these from its own representation (identity + position in the full model):
#   query - what am I looking for?  (e.g. "previous vowels")
#   key   - what do I contain?      (e.g. "a consonant at the start of the block")
#   value - what I pass on to tokens that attend to me
# A query and key that are aligned give a large dot product.
k, q, v = key(x), query(x), value(x)  # each (B, T, head_size)

scores = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
# Divide by sqrt(head_size) to keep the numbers small, so softmax does not sharpen towards one-hot
scores = scores / np.sqrt(head_size)
tril = torch.tril(torch.ones((T, T)))
# Causal mask: -inf above the diagonal so no token peeks into the future.
# (Some tasks, e.g. protein language modeling, may not want to mask the future.)
w = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)  # each row sums to 1

# For each token, the output is a weighted sum of the "values" of itself and earlier tokens.
# The weights come from query/key alignment; this output is what propagates to the next layer.
out = w @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

# With multiple heads, each token produces several value vectors to communicate further;
# each head learns its own notion of which tokens matter through its own query/key alignment.
print(out.shape)
print(out[0])
print(w[0])

In [None]:
# Raw (unscaled, unmasked) attention scores for the first sequence.
# Only the lower triangle (including the diagonal) survives the causal mask.
# Entry [7, 6] (0-indexed) is the query of the 8th token dotted with the key of the 7th:
#   -> the 8th token broadcasts what it is looking for (a head_size-dim vector), and the
#      7th broadcasts what it contains (in the same head_size-dim space)
#   -> a high score means a large dot product, i.e. strong alignment
raw_scores = q @ k.transpose(-2, -1)
raw_scores[0]

In [None]:
# The above is called self-attention because q, k, and v all come from the same x.
# In cross-attention, the queries come from x while the keys and values come from another source y.
#   Basically, x is "looking" for some characteristic in the tokens of y.
#   y's keys advertise "what they are"; where a query aligns with a key, that token's value is propagated.