In [1]:
%pip install "torch>=2" numpy --quiet

Note: you may need to restart the kernel to use updated packages.


# Download TinyShakespeare Text

In [2]:
#!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

# Inspect Text

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
print(f"length of dataset in characters: {len(text)}")

length of dataset in characters: 1115394


In [5]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



# Tokenization

In [6]:
# all the unique chars that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [7]:
# Implement a simple chaacter level tokenization schema. More sophisticated tokenizers include SentencePiece / tiktoken
stoi = { ch:i for i,ch in enumerate(chars) } # str to int mapping
itos = { i:ch for i,ch in enumerate(chars) } # int to str mapping
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

print(encode("transformers"))
print(decode(encode("transformers")))

[58, 56, 39, 52, 57, 44, 53, 56, 51, 43, 56, 57]
transformers


In [8]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

# Train-Test Split

In [9]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
print(len(train_data), len(val_data))

1003854 111540


In [10]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [11]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input content is {context}, the targer is {target}")

When input content is tensor([18]), the targer is 47
When input content is tensor([18, 47]), the targer is 56
When input content is tensor([18, 47, 56]), the targer is 57
When input content is tensor([18, 47, 56, 57]), the targer is 58
When input content is tensor([18, 47, 56, 57, 58]), the targer is 1
When input content is tensor([18, 47, 56, 57, 58,  1]), the targer is 15
When input content is tensor([18, 47, 56, 57, 58,  1, 15]), the targer is 47
When input content is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the targer is 58


In [12]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split, batch_size, block_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train', batch_size, block_size)
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size):
    for t in range(batch_size):
        context = xb[b, :t+1]
        target =yb[b,t]
        print(f"when input is {context}, target is {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----
when input is tensor([24]), target is 43
when input is tensor([24, 43]), target is 58
when input is tensor([24, 43, 58]), target is 5
when input is tensor([24, 43, 58,  5]), target is 57
when input is tensor([44]), target is 53
when input is tensor([44, 53]), target is 56
when input is tensor([44, 53, 56]), target is 1
when input is tensor([44, 53, 56,  1]), target is 58
when input is tensor([52]), target is 58
when input is tensor([52, 58]), target is 1
when input is tensor([52, 58,  1]), target is 58
when input is tensor([52, 58,  1, 58]), target is 46
when input is tensor([25]), target is 17


# Bigram LM

In [13]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)


class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # the whole model is simply a square lookup table.
        # For each char (token) a in the whole char set, we maintain the probability of char b appearing after a.
        # So the size is vocab_size x vocab_size

        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx)  # (B, T, C) -> Batch=4(batch size), Time=8(block_size), Channel=65(vocab_size)
        
        if targets is None:
            loss = None
        else:
            # reshape logits & targets into 2D to cater for F.corss_entropy
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens): # idx is (B, T) array of indices in the current context
        
        # generate max_new_tokens tokens iteratively by looking at only the last token each time
        for _ in range(max_new_tokens):
            
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)


In [14]:
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))


SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


# Train Bigram LM

In [15]:
def train(model, steps, batch_size, block_size, lr=1e-3):
    optimizer = torch.optim.AdamW(model.parameters(), lr)
    for steps in range(steps):
        xb, yb = get_batch('train', batch_size, block_size)

        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True) # clear grads from the previous step
        loss.backward() # calculate grads for all params
        optimizer.step() # update params

    print(loss.item())

In [16]:
batch_size = 32
train(m, 10000, batch_size, block_size)

2.382369041442871


In [17]:
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))


lso br. ave aviasurf my, yxMPZI ivee iuedrd whar ksth y h bora s be hese, woweee; the! KI 'de, ulsee


# Deriving self-attention

In [18]:
# consider this toy example batch:

torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

Currently, the bigram model is not communicating / paying attention at the (n-1, n-2, ... , 1)th tokens when predicting the (n+1)th token from the nth. Majority of the context info is lost.

We need to derive a mechanism for the model to attend to previous tokens when predicting the future token.


## Naive aggregation: averaging past tokens (weakest from of "communication")

### Version 1: by naive for loop

In [19]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B, T, C)) # x bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]
        xbow[b, t] = torch.mean(xprev, 0)
xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [20]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a/ torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print(f'a={a}')
print(f'b={b}')
print(f'c={c}')

a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


### Version 2: by matrix multiplication

The operation of x[b,t] = mean_{i<=t} x[b,i] can be simplified & optimized by a row-normalized (each row sums to 1) lower triangular matric @ x

### Version 3: get weights by dividing by row sum

In [21]:
torch.manual_seed(42)
wei = torch.tril(torch.ones(T, T)) # weight - the row-normalized lower triangular matrix
wei = wei / wei.sum(1, keepdim=True)

xbow2 = wei @ x # (B (auto broadcasted by torch), T, T) @ (B, T, C) --> (B, T, C)
print(f"wei: {wei}")
print(f"xbow2: {xbow2}")
torch.allclose(xbow, xbow2)

wei: tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
xbow2: tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.008

True

### Version 4: get weights by softmax

by turning all zeros to -inf in a lower tri mat, then applying softmax to row, we can get the same weights

In [22]:
wei = torch.tril(torch.ones(T, T)) # T by T lower tri mat
wei = wei.masked_fill(wei==0, float('-inf'))
wei

tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., -inf],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [23]:
wei = F.softmax(wei, dim=1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

### Implementing the averaging head module 

In [24]:
class AveragingHead(nn.Module):
    """ one head of naive aggregation """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        B, T, C = x.shape
        wei = torch.tril(torch.ones(T, T))
        wei = F.softmax(wei.masked_fill(wei==0, float('-inf')), dim=1) # T by T lower tri mat
        agg_x = wei @ x # (T, T) @ (B, T, C) --> (B, T, C)
        return agg_x

In [25]:
torch.manual_seed(500)
x = torch.rand(1, 3, 3)

ah = AveragingHead()
agg_x = ah(x)
print(f'x: {x}')
print(f'agg_x: {agg_x}')

x: tensor([[[0.5820, 0.1338, 0.7995],
         [0.3071, 0.6526, 0.6105],
         [0.1575, 0.6983, 0.7883]]])
agg_x: tensor([[[0.5820, 0.1338, 0.7995],
         [0.4446, 0.3932, 0.7050],
         [0.3489, 0.4949, 0.7328]]])


### Adding aggregation head to Bigram LM

#### A few changes on BigramLanguageModelV2 from BigramLanguageModel
1. add position_embedding_table (along T / block_size axis) to capture positional info
2. parameterize n_embed in embedding tables to configure # dimensions of embedding vectors
3. add AggregationHead to establish the weakest form of communication between upper context of text
4. Now that we have implemented positional embedding, we cannot feed idx longer than block_size, else we will get index out of range when accessing the positional embedding table. idx is cropped to the last block_size block during generate()

In [26]:
class BigramLanguageModelWithAveragingHead(nn.Module):

    def __init__(self, vocab_size, n_embed):
        super().__init__()
        
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.head = AveragingHead()
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B, T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embed)
        pos_emb = self.position_embedding_table(torch.arange(T)) # (T, n_embed)
        x = tok_emb + pos_emb # (B, T, C)
        agg_x = self.head(x) # (B, C, C)
        logits = self.lm_head(agg_x)  # (B, T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            # reshape logits & targets into 2D to cater for F.corss_entropy
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens): # idx is (B, T) array of indices in the current context
        
        # generate max_new_tokens tokens iteratively by looking at only the last token each time
        for _ in range(max_new_tokens):
            
            # crop idx to the last block_size tokens
            idx_cropped = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cropped)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [27]:
m1 = torch.compile(BigramLanguageModelWithAveragingHead(vocab_size, n_embed=32))
m1

OptimizedModule(
  (_orig_mod): BigramLanguageModelWithAveragingHead(
    (token_embedding_table): Embedding(65, 32)
    (position_embedding_table): Embedding(8, 32)
    (head): AveragingHead()
    (lm_head): Linear(in_features=32, out_features=65, bias=True)
  )
)

In [28]:
torch.manual_seed(1337)
train(m1, 20000, batch_size, block_size, lr=1e-3)

2.3330883979797363


In [29]:
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m1.generate(idx, max_new_tokens=100)[0].tolist()))


Jo win, pull jutourd fay t as, h ith w y BULousithatrt thiouthe?
ARachissthiver y size, meatane nd h
