# Training GPT-1 on Tiny Shakespeare

In [22]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [23]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(len(text))

1115394


In [24]:
# get unique chars
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


# Preprocessing

Tokenizing

In our implementation, we are simply converting individual characters into integers, since we are building a character-level model. <br>

In [25]:
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # convert a string into a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # convert a list of integers into a string
print(encode('hi there'))
print(decode(encode('hi there')))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there


Tokenizing the entire set

In [26]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

Train Test Split

In [27]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# 'Chunking' the data <br>
Feeding the entire data into the model at once will be computationally infeasible. Hence we sample random chunks and feed into the model

In [28]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

A chunk, in essence, is multiple training examples because the model is trained to read each incremental sequence of characters and predict the next in sequence <br>
e.g. in 18, the prediction is 47. in 18,47 the prediction is 56, etc

In [29]:
x = train_data[:block_size] # the first block size chars
y = train_data[1:block_size+1] # the next block size char offset by 1
for t in range(block_size):
    context = x[:t+1] # chars up to and including t
    target = y[t]
    print(f'when input is {context}, the target is {target}')


when input is tensor([18]), the target is 47
when input is tensor([18, 47]), the target is 56
when input is tensor([18, 47, 56]), the target is 57
when input is tensor([18, 47, 56, 57]), the target is 58
when input is tensor([18, 47, 56, 57, 58]), the target is 1
when input is tensor([18, 47, 56, 57, 58,  1]), the target is 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is 58


Training in this way serves two purposes <br>
1. Computational efficiency <br>
2. Familiarize the transformer with context from as little as 1, all the way to its blocksize


Batching our data for parallelization

In [30]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,)) # generate batch size numbers of random offsets
    x = torch.stack([data[i:i+block_size] for i in ix]) # the first block size chars, stacked from 1D tensors into rows
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]) # the offset 
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print('targets:')
print(yb.shape)
print(yb)

print('------')

for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'when input is {context.tolist()} the target is {target}')

inputs:
torch.Size([4, 8])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
------
when input is [24] the target is 43
when input is [24, 43] the target is 58
when input is [24, 43, 58] the target is 5
when input is [24, 43, 58, 5] the target is 57
when input is [24, 43, 58, 5, 57] the target is 1
when input is [24, 43, 58, 5, 57, 1] the target is 46
when input is [24, 43, 58, 5, 57, 1, 46] the target is 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target is 39
when input is [44] the target is 53
when input is [44, 53] the target is 56
when input is [44, 53, 56] the target is 1
when input is [44, 53, 56, 1] the target is 58
when input is [44, 53, 56, 1, 58] the target is 46
when input is [44, 53, 56, 1, 58, 46] the target is 39
when input is [44, 53, 56, 1, 58, 46, 39] the target is 58
when input is [44, 53, 56, 1, 58, 46, 39, 58] th

# ----------------------------------------

# Neural Networks

Bi-Gram Language Model

In the code, the constructor creates an embedding table. <br> 
When we pass idx in the forward pass, every int in the input refers to the embedding table and gets a row corresponding to the index. <br>
This is returns in a batch, time, channel (B, T, C) tensor ie (4,8,65)

In [31]:
torch.manual_seed(1337)
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx) 
        if targets  is None:
            loss = None
        else:
            B, T, C = logits.shape # batch, time, channels
            logits = logits.view(B*T, C) # stretching the array into 2D
            targets = targets.view(B*T) # stretching to 1D
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            # get the last time step
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

idx = torch.zeros((1, 1), dtype=torch.long) # we are feeding a batch of zeros (newline chars)

print(decode(m.generate(idx, max_new_tokens=100)[0].tolist())) # [0] get the content without the batchsize

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


Training the Bigram

In [32]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(10000):
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

2.382369041442871


In [33]:
print(decode(m.generate(idx, max_new_tokens=300)[0].tolist()))


lso br. ave aviasurf my, yxMPZI ivee iuedrd whar ksth y h bora s be hese, woweee; the! KI 'de, ulseecherd d o blllando;LUCEO, oraingofof win!
RIfans picspeserer hee tha,
TOFonk? me ain ckntoty ded. bo'llll st ta d:
ELIS me hurf lal y, ma dus pe athouo
BEY:! Indy; by s afreanoo adicererupa anse tecor


# Standard Training Loop

In [54]:
batch_size = 32 # num of independent sequences processed in parallel
block_size = 8 # max context length for prediction
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
device = torch.device('mps')
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique chars in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder
decode = lambda l: ''.join(itos[i] for i in l) # decoder

# train test split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token reads off the logits for the next token from look-up table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # we also embed the position of the tokens
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        # then we pass the token embedding and positional embedding
        x = tok_emb + pos_emb
        logits = self.lm_head(x) # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            # focus on last time step 
            logits = logits[:, -1, :] # becomes (B, C)
            # get probs
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = BigramLanguageModel(vocab_size)
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # eval loss on train and val sets every once a while
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}')
    
    # sample batch of data
    xb, yb = get_batch('train')
    
    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))



Step 0: train loss 4.4801, val loss 4.4801
Step 300: train loss 2.5404, val loss 2.5566
Step 600: train loss 2.5160, val loss 2.5335
Step 900: train loss 2.4967, val loss 2.5149
Step 1200: train loss 2.5106, val loss 2.5254
Step 1500: train loss 2.4853, val loss 2.5109
Step 1800: train loss 2.4966, val loss 2.5198
Step 2100: train loss 2.4949, val loss 2.5100
Step 2400: train loss 2.4937, val loss 2.5102
Step 2700: train loss 2.5040, val loss 2.5114

Foasth prsexpizequppathel
GOMUKEE&CKIOMINCHUKEENORineg aggellprdrrvetecowhrthy;
The?
TyONGrsothy,
D HPayomind ppry Pad avend
Wh T:
TRDUMEYOf bykncaknd-htcthy
BORYOFRIOMBOf gwisexprenlbuststlant'GSCKENGBETHURK:
MONTRKIIDUKIEUSLUKI:
Why!
TICllppp'ly BEThapy HEED:
PORKINGLnk&CHBRI's,
B&CllDWlerd,
uqughold crrayf

PESThe
TJUSENCTINIXEnd!
3 bald-YOLYowmy!
HENRKVOfakthindy;
MBERK:
QULANGEGHLOfourknghm H:
YTIppmplkedextherex'digavitfuthrd.'GENUS:
My,
WhindKINGABedigerdjurd.
MERMjulKENGickequr.F INIS:
Y


# Self Attention

A Mathematical Trick

In [36]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2 
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

We have 8 tokens in a batch that initially don't 'talk' to each other <br>
We'd like to couple them in a way that the token in e.g. 5th location should not communicate with the token in the 6th, 7th, and 8th location<br>because those are in future locations and we are trying to predict the future token<br>
The token in 5th position should only talk to those in the 4th, 3rd, 2nd, and 1st position i.e. info only flows from previous timesteps <br>
to the current timestep <br>

The easiest way to communicate is to do an average of all previous elements <br>
e.g. in the 5th position, we want to take the average of channels in the 4th, 3rd, 2nd, and 1st positions <br>
This becomes sort of a feature vector that summarizes the 5th in the context of its history <br>
This 'average' is, of course, a very simplified reduction of information

Now, for every t-th token in a sequence, we want to calculate the average of all vectors in previous tokens and at the current token <br>
i.e. we want x[b,t] = mean_{i<=t} x[b, i]

In [37]:
xbow = torch.zeros((B,T,C)) # x bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # everything up to and including the t-th token
        xbow[b,t] = torch.mean(xprev, 0)

Making it efficient with MatMul

In [38]:
torch.manual_seed(42)
a = torch.ones(3, 3)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print(f'{a=}')
print(f'{b=}')
print(f'{c=}')

a=tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


torch has a function for lower triangular matrix

In [39]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print(f'{a=}')
print(f'{b=}')
print(f'{c=}')

a=tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


By design of matmul, the first row of a dot prod first col of b <br>
But since row 0 of a = 1, 0, 0, the 0, 0 essentially means that 6, 6, of col 1 in b are ignored and we get 2 + 0 + 0<br>
Now, with row 1 of a = 1, 1, 0, the last 0 means the last 6 of col 1 gets ignore and we get 2 + 6 = 8 <br>
And, with row 2 of a = 1, 1, 1, we sum up everything and get 2 + 6 + 6 = 14. <br>
Same logic applies for col 2

Depending on the number of zeros and ones, we are therefore doing a sum of the variables in b and depositing in c <br>
We can similarly do average of the rows of b in an incremental fashion by normalizing the rows of a so they sum to 1. <br>
Then we get an average

In [40]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print(f'{a=}')
print(f'{b=}')
print(f'{c=}')

a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


We are now getting the average of the rows in the next row (excluding row 0) i.e.<br>
4.0 = (2+6) / 2 and 5.5 = (7+4) / 2 <br>
4.6 = (2+6+6) / 3 and 5.3 = (7+4+5) / 3

Vectorizing our bag of words

In [41]:
wei = torch.tril(torch.ones(T, T)) # weights
wei = wei / wei.sum(1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [42]:
xbow2 = wei @ x # (T, T) @ (B, T, C) ---> (B, T, C)
torch.allclose(xbow, xbow2)

True

Another way of writing it: Using Softmax to normalize. <br>
In max_fill, we say for all elements where tril=0, make them -inf <br>
Softmax then normalizes the rows to 1

In [None]:
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

RuntimeError: The size of tensor a (2) must match the size of tensor b (32) at non-singleton dimension 2

In self attention terms

wei = torch.zeros() tells us that initially, the current tokens have zero affinity for past tokens. These token affinities will become data dependent and update each other during training <br>
wei = wei.masked_fill(tril == 0, float('-inf')) we say we won't aggregate anything from future tokens, i.e. the future can't communicate with the past <br>
wei = F.softmax(wei, dim=1) we then aggregate their values depending on how interesting they find each other <br>

# Query and Vector

We don't want wei to be uniform because diffferent tokens will attend to diff tokens in a data dependent manner e.g. a vowel needs to know what are the consonants in its past<br>
Self-attention solves this through two vectors - key and query - emitted by an input vector <br>

A query vector has the information we are looking for and a key vector has the information contained by the input

Affinity between the sequences is a dot product between the keys and the queries <br>
A query dot products with all the keys of all other tokens and the dot product becomes wei <br>
If a key and query align, it interacts at a very high amount and learns more about that specific token as opposed to other tokens in the sequence 

Implementing a Single Attention Head

In [51]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32 
x = torch.randn(B, T, C)

# single head performing self attention
head_size = 16
key = nn.Linear(C, head_size, bias=False) # direct MatMul without bias
query = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)

# at this stage, no communication has happened between k and q
# now, we make them communicate
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ x
out.shape

torch.Size([4, 8, 32])

Now, wei has a dynamic shape within batches that are based on token sizes 

In [52]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

Let's consider the 8th token on the last row <br>
It knows what content it has and the position it's in. <br>
Based on this, it creates a query "I'm a vowel, and I'm looking for any consonants at position up to X" <br>
All nodes will emit keys and maybe one of the nodes will emits keys that respond to the query "I am a consonant and I am in a position up to X" <br>
That key will have a high number in that specific channel, thus creating a high affinity e.g. on the last row, token 0.2297 and 0.2423 are of interest to the 8th token

When we do the aggregation, we don't actually aggregate on the tokens x. We aggregate on the 'value'

In [53]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32 
x = torch.randn(B, T, C)

# single head performing self attention
head_size = 16
key = nn.Linear(C, head_size, bias=False) # direct MatMul without bias
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
v = value(x)

# at this stage, no communication has happened between k and q
# now, we make them communicate
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

out = wei @ v
out.shape

torch.Size([4, 8, 16])

# Notes on Attention

We can think of x as private to the token e.g. I'm the 5th token with some identity, my information is kept in vector x <br>
For the purpose of a single head, here's what I'm interested in (q), here's what I have (k), and if you find me interesting, here's what I will communicate with you (v)

Attention can be seen as a communication mechanism between a number of nodes in a directed graph. <br>
Every node has a vector of information and it gets to aggregate the information via a weighted sum of all the nodes that point to it (in a data dependent manner) <br>
Our graph has 8 nodes (block size=8). The first node is pointed to by itself, the second node pointed to the first node and itself and so on.

Attention also has no spatial information, so we must encode the positional information to the nodes <br>
This is different from convolution, which has a sort of layout of information in the space

Elements across batch dimensions never talk to each other and are processed independently. 

In some cases, we want nodes to talk to each other fully i.e. not limited to past nodes. e.g. sentiment analysis <br>
This full-ended communication is done by eliminating the masking wei = wei.masked_fill(tril == 0, float('-inf')) 

Self attention vs Cross Attention <br>
In [our] self attention, the k, q, v all come from the same source (x), so the nodes are self-attending. <br>
In cross attention, attention can have q from x and k, v come from some separate source nodes. We produce queries and read information from other sources.

In the reference paper, we divide the softmax by sqrt(d) - referred to as scaled attention. <br>
This normalization prevents extremeties which result in softmax producing skewed values. 

In [56]:
batch_size = 32 # num of independent sequences processed in parallel
block_size = 8 # max context length for prediction
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = torch.device('mps')
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique chars in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder
decode = lambda l: ''.join(itos[i] for i in l) # decoder

# train test split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # create a buffer which we call tril 

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # weighted aggregation of values
        v = self.value(x)
        out = wei @ v 
        return out
    
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token reads off the logits for the next token from look-up table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # we also embed the position of the tokens
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        # then we pass the token embedding and positional embedding
        x = tok_emb + pos_emb
        x = self.sa_head(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop idx to fit into block size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            # focus on last time step 
            logits = logits[:, -1, :] # becomes (B, C)
            # get probs
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = BigramLanguageModel(vocab_size)
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # eval loss on train and val sets every once a while
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}')
    
    # sample batch of data
    xb, yb = get_batch('train')
    
    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


Step 0: train loss 4.2000, val loss 4.2047
Step 500: train loss 2.6911, val loss 2.7087
Step 1000: train loss 2.5196, val loss 2.5303
Step 1500: train loss 2.4775, val loss 2.4829
Step 2000: train loss 2.4408, val loss 2.4523
Step 2500: train loss 2.4272, val loss 2.4435
Step 3000: train loss 2.4130, val loss 2.4327
Step 3500: train loss 2.3956, val loss 2.4212
Step 4000: train loss 2.4041, val loss 2.3992
Step 4500: train loss 2.3980, val loss 2.4084

UNTGULOK:
MI foenderst el
O urfrnievil:
Alesk, COI yeg agnell thre Mtecoror shad ge?


ONGreothakechous omou mpery waly, the oube, er sickes bokecard ihiceny
Bing?

Al fe of ise fre lbustsel withous; to. Com artl at;
I me ffaruk monden itheland'l oer oghithet f, bad gien yof thougre yucouler asureis, nt rt hingesty ckield, wins, in mamy thavenyongmeroe, dojooutthendy sak shil brves
GHaraster him to, oupomp rede ds it hor avit gin LUSan thoms lathind my ouerouer aby any sot,
K thicerare.

I IS:
Y


# Multi Head Attention

Multiple self attention running in parallel

In [59]:
batch_size = 32 # num of independent sequences processed in parallel
block_size = 8 # max context length for prediction
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = torch.device('mps')
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique chars in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder
decode = lambda l: ''.join(itos[i] for i in l) # decoder

# train test split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # create a buffer which we call tril 

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # weighted aggregation of values
        v = self.value(x)
        out = wei @ v 
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token reads off the logits for the next token from look-up table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # we also embed the position of the tokens
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_heads = MultiHeadAttention(4, n_embd//4) # we have 4 comm channels, we want 8-dim self-attention so it concats to n_embed=32
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        # then we pass the token embedding and positional embedding
        x = tok_emb + pos_emb
        x = self.sa_heads(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop idx to fit into block size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            # focus on last time step 
            logits = logits[:, -1, :] # becomes (B, C)
            # get probs
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = BigramLanguageModel(vocab_size)
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # eval loss on train and val sets every once a while
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}')
    
    # sample batch of data
    xb, yb = get_batch('train')
    
    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


Step 0: train loss 4.2227, val loss 4.2226
Step 500: train loss 2.6592, val loss 2.6733
Step 1000: train loss 2.4980, val loss 2.5064
Step 1500: train loss 2.4291, val loss 2.4349
Step 2000: train loss 2.3716, val loss 2.3844
Step 2500: train loss 2.3417, val loss 2.3561
Step 3000: train loss 2.3149, val loss 2.3347
Step 3500: train loss 2.2918, val loss 2.3171
Step 4000: train loss 2.2895, val loss 2.2868
Step 4500: train loss 2.2748, val loss 2.2858

UNTGUST:
Ye pich heist el
O dof Wie by to yok, CO, tea agethe torr gaecoror cund ge?
Wen, reath LoD of youriompery wallav I cou hater sickes
Tokt ard dhiceny
Bo, troel fef gaise fre le stselant'dcus;
I mey
Thavely ourind houjeris, anrntit.

FAn's of I gimy.

Q:
That gientyou thoughe yurouler'dsureis lat riky nok thackield, wits, in mamy thout yougmeroe, do of that,
Nown
Hefil breche warlster him to, oupomp rete dat thim gaikt gin Theandersts lay ind my woed ourse I toy son,
KGo,-O'd arir notesed


# Adding the Feedforward Component

In the multi-head, we had the attention that performed communication, but we were too fast in logit calculation. <br>
The feed forward single layer performs a feed forward operation on the logits <br>
The self attention is the communication, and once that is done for all tokens, the nodes need to think on the data individually - which is what Feed Forward does

In [60]:
batch_size = 32 # num of independent sequences processed in parallel
block_size = 8 # max context length for prediction
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = torch.device('mps')
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique chars in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder
decode = lambda l: ''.join(itos[i] for i in l) # decoder

# train test split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # create a buffer which we call tril 

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # weighted aggregation of values
        v = self.value(x)
        out = wei @ v 
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_embd, n_embd), nn.ReLU())
    def forward(self, x):
        return self.net(x)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token reads off the logits for the next token from look-up table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # we also embed the position of the tokens
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_heads = MultiHeadAttention(4, n_embd//4) # we have 4 comm channels, we want 8-dim self-attention so it concats to n_embed=32
        self.ffwd = FeedForward(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        # then we pass the token embedding and positional embedding
        x = tok_emb + pos_emb
        x = self.sa_heads(x)
        x = self.ffwd(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop idx to fit into block size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            # focus on last time step 
            logits = logits[:, -1, :] # becomes (B, C)
            # get probs
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = BigramLanguageModel(vocab_size)
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # eval loss on train and val sets every once a while
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}')
    
    # sample batch of data
    xb, yb = get_batch('train')
    
    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


Step 0: train loss 4.1996, val loss 4.1995
Step 500: train loss 2.5993, val loss 2.6077
Step 1000: train loss 2.4628, val loss 2.4651
Step 1500: train loss 2.3975, val loss 2.3952
Step 2000: train loss 2.3298, val loss 2.3472
Step 2500: train loss 2.3011, val loss 2.3215
Step 3000: train loss 2.2830, val loss 2.2928
Step 3500: train loss 2.2494, val loss 2.2716
Step 4000: train loss 2.2432, val loss 2.2457
Step 4500: train loss 2.2291, val loss 2.2411

Fo saw prue to chatist eis yu friie hy tolesk, COICHed agetle torrigkect or cund to of O, rioth Loch,
Wel on mpend wallav he ou here. Pickes boktheall-hice: is ow?

Af sef awith hee letst tlowt' cull to loce artly ould mouds fris, and fithell wel of soghit.

QUS:
Thougen.
Why, to,
Brouck lead,
uqiis, notrrayf

DUCOLIUCONCENCHINIBSI mamy thave yougmy, eord Vofett;
Afrest
Hefquind, have allles him to, oupped rote dat thim gack fathrin'd
tursts lathise my doed our chalvoy son,
Kathit dire.
In ther 


# Transformer: Communication + Computation

<img src='transformer.png'>

We now intersperse communication with computation <br>
This is done through a transformer block

In [None]:
batch_size = 32 # num of independent sequences processed in parallel
block_size = 8 # max context length for prediction
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = torch.device('mps')
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique chars in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder
decode = lambda l: ''.join(itos[i] for i in l) # decoder

# train test split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # create a buffer which we call tril 

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # weighted aggregation of values
        v = self.value(x)
        out = wei @ v 
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_embd, n_embd), nn.ReLU())
    def forward(self, x):
        return self.net(x)
    
class Block(nn.Module):
    # Transformer block: communication followed by computation
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)

    def forward(self, x):
        x = self.sa(x)
        x = self.ffwd(x)
        return x

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token reads off the logits for the next token from look-up table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # we also embed the position of the tokens
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4)
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        # then we pass the token embedding and positional embedding
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop idx to fit into block size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            # focus on last time step 
            logits = logits[:, -1, :] # becomes (B, C)
            # get probs
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = BigramLanguageModel(vocab_size)
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # eval loss on train and val sets every once a while
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}')
    
    # sample batch of data
    xb, yb = get_batch('train')
    
    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


Our network becomes a deep neural network and suffers from optimization problems <br>
The paper suggests two otimisations to help with the depth of the network: 1) residual connections; 2) Layer norm

# Residual Connections 

Transform the data, and add a skip connection from the previous features. <br>
There is a residual pathway that can be forked from to perform additional computation and the resulting computation is added to the original pathway. <br>
This is useful because addition distributes gradients equally to both of its branches during backprop <br>
The gradients from the loss hops unimpeded from the loss to the input through the residual connection <br>

In [63]:
batch_size = 32 # num of independent sequences processed in parallel
block_size = 8 # max context length for prediction
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = torch.device('mps')
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique chars in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder
decode = lambda l: ''.join(itos[i] for i in l) # decoder

# train test split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # create a buffer which we call tril 

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # weighted aggregation of values
        v = self.value(x)
        out = wei @ v 
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads*head_size, n_embd) # projection into the residual pathway

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd), # the dimensionality is multiplied by 4 as in the paper
            nn.ReLU(),
            nn.Linear(4*n_embd, n_embd) # projection layer
            )
    def forward(self, x):
        return self.net(x)
    
class Block(nn.Module):
    # Transformer block: communication followed by computation
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)

    def forward(self, x):
        # Residual connection
        x = x + self.sa(x)
        x = x + self.ffwd(x)
        return x

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token reads off the logits for the next token from look-up table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # we also embed the position of the tokens
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4)
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        # then we pass the token embedding and positional embedding
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop idx to fit into block size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            # focus on last time step 
            logits = logits[:, -1, :] # becomes (B, C)
            # get probs
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = BigramLanguageModel(vocab_size)
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # eval loss on train and val sets every once a while
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}')
    
    # sample batch of data
    xb, yb = get_batch('train')
    
    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


Step 0: train loss 4.6255, val loss 4.6233
Step 500: train loss 2.3885, val loss 2.3850
Step 1000: train loss 2.2705, val loss 2.2691
Step 1500: train loss 2.1873, val loss 2.2097
Step 2000: train loss 2.1481, val loss 2.1844
Step 2500: train loss 2.1071, val loss 2.1539
Step 3000: train loss 2.0710, val loss 2.1441
Step 3500: train loss 2.0611, val loss 2.1178
Step 4000: train loss 2.0285, val loss 2.1121
Step 4500: train loss 2.0038, val loss 2.1022

Upastar duke is beed their your tie.

DUKE VINENO:
Heg agle with shease, make ad gequne am,
To you wousel in mades way as he bubserer sidels beked, dliht thy
But twoell.

VINCENTE:
Ole strelaw!' cust to lace and thould moust frised drefither haply enep is with madies itn you the good deatle
Theumeing not thy presty cried mayses, and bey that my thmeroed do of that, your
Hefor bracimay all wase, and to pomplectes a me his as this readanters,
O, this Kertured head by to camp,
Kat, the and the be f


# Layer Norm

This is similar to batch-norm but normalized row- <br>
Because the computation does not span across examples, we can delete all the buffers. <br>
Similarly, not distinction between training and runtime

In [64]:
class LayerNorm:  
    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps # epsilon
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)


    def __call__(self, x):
        # forward pass
        xmean = x.mean(1, keepdim=True) # batch mean
        xvar = x.var(1, keepdim=True) # batch variance
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
        self.out = self.gamma * xhat + self.beta
        return self.out
        
    def parameters(self):
        return [self.gamma, self.beta]

# Side note

Very few changes have been made on the transformer since its inception <br>
One thing that departs from papers is the position of the Layer Norm - this has recently been implement before the transformation (pre-norm shuffling)

In [66]:
batch_size = 32 # num of independent sequences processed in parallel
block_size = 8 # max context length for prediction
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = torch.device('mps')
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique chars in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder
decode = lambda l: ''.join(itos[i] for i in l) # decoder

# train test split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # create a buffer which we call tril 

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # weighted aggregation of values
        v = self.value(x)
        out = wei @ v 
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads*head_size, n_embd) # projection into the residual pathway

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd), # the dimensionality is multiplied by 4 as in the paper
            nn.ReLU(),
            nn.Linear(4*n_embd, n_embd) # projection layer
            )
    def forward(self, x):
        return self.net(x)
    
class Block(nn.Module):
    # Transformer block: communication followed by computation
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Residual connection and layer norm
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token reads off the logits for the next token from look-up table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # we also embed the position of the tokens
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4),
            nn.LayerNorm(n_embd)
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        # then we pass the token embedding and positional embedding
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop idx to fit into block size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            # focus on last time step 
            logits = logits[:, -1, :] # becomes (B, C)
            # get probs
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = BigramLanguageModel(vocab_size)
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # eval loss on train and val sets every once a while
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}')
    
    # sample batch of data
    xb, yb = get_batch('train')
    
    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


Step 0: train loss 4.3103, val loss 4.3097
Step 500: train loss 2.3999, val loss 2.4008
Step 1000: train loss 2.2647, val loss 2.2663
Step 1500: train loss 2.1659, val loss 2.1890
Step 2000: train loss 2.1309, val loss 2.1676
Step 2500: train loss 2.0808, val loss 2.1295
Step 3000: train loss 2.0515, val loss 2.1242
Step 3500: train loss 2.0438, val loss 2.1036
Step 4000: train loss 2.0123, val loss 2.0929
Step 4500: train loss 1.9912, val loss 2.0951

Forst.

MENENCESSTPRIARD OLANUS:
Pealy to yoker Offeegeage, Tis me goect you had genone sir
They scalian, in make that an he oubsere. Pick shalk, and dhich you back ellought is of bulb strelant's sir to. To goves as and should king, here selp welt ene?
Wither for head tnotor,
I'll may could
Theureis on array nok?
Thereel me sead in mamany him your myor, do of that, your
Hefquiter make all was im to, and my reat dapies'd gack.

WESTUS:
Werses laguies my doed head had oy son,
Kameiver, eye not sef


# Scaling Up the Transformer

A few cosmetic changes have been added to the code

In [None]:
batch_size = 32 # num of independent sequences processed in parallel
block_size = 8 # max context length for prediction
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = torch.device('mps')
eval_iters = 200
n_embd = 32
n_head = 6
n_layer = 6
dropout = 0.2

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# get all unique chars in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder
decode = lambda l: ''.join(itos[i] for i in l) # decoder

# train test split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # create a buffer which we call tril 
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        # weighted aggregation of values
        v = self.value(x)
        out = wei @ v 
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads*head_size, n_embd) # projection into the residual pathway
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd), # the dimensionality is multiplied by 4 as in the paper
            nn.ReLU(),
            nn.Linear(4*n_embd, n_embd), # projection layer
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
    
class Block(nn.Module):
    # Transformer block: communication followed by computation
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Residual connection and layer norm
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token reads off the logits for the next token from look-up table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # we also embed the position of the tokens
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        # then we pass the token embedding and positional embedding
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop idx to fit into block size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            # focus on last time step 
            logits = logits[:, -1, :] # becomes (B, C)
            # get probs
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = BigramLanguageModel(vocab_size)
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # eval loss on train and val sets every once a while
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}')
    
    # sample batch of data
    xb, yb = get_batch('train')
    
    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()




In [69]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=5000)[0].tolist()))


But what thou, madedine;
And his to pentieto bear moster, and from buke'
Puuch;
GoZ of our I law yor all har were thou hose, houlds it? arm, wo preade. To Dugas you priightne.

LOUCEOSS I'lby apidy winlish our when to tour? mace can trillous,
And chears'st hear of a lirunce me to that not ajave, that all all ore? I my repowers!
That so of more to goive mine the coldor. Frohtses as that will of am-dear.
WAROVINCUS:
Hesere hath, breli,
Whith dabibitcllied of thigh dister fa: And to bit llan Theit thy poor fe arquresing the the arewof.
But me
To
The foldwer's in hirh crayes?

LUTHERN MINTIOLEA:
Whis fars. Thes; hild with not this beMace him, as mict the the keet
Thou shall of hour you here counse:
Pearst larde; whines
Your quich madsire tear Sood this and welish, and in that I shall sike my you, thy, sashorticesing I ware hear pan mur?
The muse my nead mare!

Twor hould mothte; and like hart beartiond!
Shald lans.--
ClommessIy,
Thy, ope, the proothes like: you for it thel do, wed ad man:

# Encoder vs Decoder

<img src='transformer.png'>

The code implementation is a decoder-only transformer. This focuses on just generating text (autoregressively) <br>
We just use the triangular mask from the attention <br>
The original paper has an encoder decoder because of its usecase for translation. It expects an encoded French corpus which it is uses to start of the generation using a start and end token. <br>
The generation in that case is conditioned on the French corpus. The encoder encodes the French tokens into vector encodings. <br>
This encoded information is fed into the cross-attention layer i.e. the queries come from x, but the keys and values come from the encodings <br>
This conditions the decoding not just on the past, but also on the fully encoded French text (because full context is necessary for translation)