# Self-attention embedding table

Adding the path to the custom libraries

In [1]:
import sys
import os


# Make <cwd>/../../scripts/lib importable so the project's `utils` and
# `transformers` helper modules can be found.
dirname = os.path.abspath(os.path.join(os.getcwd(), "..", "..", "scripts/lib"))
# Guard against duplicate sys.path entries when the cell is re-run.
if dirname not in sys.path:
    sys.path.append(dirname)

## Importing libraries

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from utils.compile import compileFolder
from utils.tokenizer import CharTokenizer, END_CHAR
from utils.datasets import TextChunksDataset, split_dataset, get_batch

In [3]:
# This module helps to quickly save the weights and load them
from transformers import Module

## Setting Hyperparameters

In [4]:
# Maximum context length (a.k.a. block size), measured in tokens
block_size = 8

# Fraction of the full dataset held out for testing/validation
test_train_split_ratio = 0.1

# Embedding dimension: size of each token / position embedding vector
n_embd = 32

# Compute device: the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Setting up the data and the tokenizer

In [5]:
# Load the raw 'tate' corpus with the project's compileFolder helper
raw_data = compileFolder('tate')

# Create the tokenizer from that corpus
# (its vocabulary presumably comes from the characters of raw_data)
tokenizer = CharTokenizer(raw_data)

# Tokenize the corpus and wrap it in a dataset using a context length of
# block_size tokens
data = TextChunksDataset(raw_data, block_size, tokenizer)

In [6]:
# Hold out `test_train_split_ratio` of the data for evaluation; use the
# hyperparameter instead of a duplicated literal so the two stay in sync.
train_data, test_data = split_dataset(data, test_train_split_ratio)

## Implementation of the self-attention block

We reuse almost the same structure as the base embedding model.

> Note: from now on, we use `cuda` when it is available.

In [7]:
class BigramLanguageModel(Module):
    """Language model built from token + position embeddings (no attention yet).

    For each position, the token embedding and the position embedding are
    summed and passed through a linear head that produces next-token logits.
    """

    def __init__(self, vocab_size: int | CharTokenizer | TextChunksDataset, n_embd, device=device, context_size=None):
        """
        If vocab_size is a Dataset with context_length, then no need to specify context_size

        Args:
            vocab_size: vocabulary size as an int, a CharTokenizer (its ``len``
                is used) or a TextChunksDataset (``len`` of its tokenizer is used).
            n_embd: embedding dimension.
            device: device on which the parameters are allocated.
            context_size: maximum context length. When omitted, ``vocab_size``
                must be a TextChunksDataset and its ``context_length`` is used.

        Raises:
            ValueError: if ``context_size`` is None and ``vocab_size`` is not a
                TextChunksDataset.
        """
        super().__init__()
        if context_size is None:
            if not isinstance(vocab_size, TextChunksDataset):
                raise ValueError("You need to specify the context length")
            context_size = vocab_size.context_length
        self.block_size = context_size
        self.device = device
        if isinstance(vocab_size, TextChunksDataset):
            vocab_size = len(vocab_size.tokenizer)
        elif isinstance(vocab_size, CharTokenizer):
            vocab_size = len(vocab_size)
        # One learned vector per token id and one per position in the context
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd, device=self.device)
        self.position_embedding_table = nn.Embedding(self.block_size, n_embd, device=self.device)
        self.lm_head = nn.Linear(n_embd, vocab_size, device=self.device)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids, T <= block_size, on ``self.device``.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``: logits are (B, T, vocab_size) when ``targets``
            is None (loss is then None); otherwise logits are flattened to
            (B*T, vocab_size) and loss is a scalar tensor.
        """
        B, T = idx.shape

        tok_embd = self.token_embedding_table(idx)  # (B, T, C)
        pos_embd = self.position_embedding_table(torch.arange(T, device=self.device))  # (T, C)
        x = tok_embd + pos_embd  # (B, T, C), position embedding broadcast over the batch
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects (N, C) logits and (N,) class indices
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens: int):
        """Autoregressively sample ``max_new_tokens`` tokens.

        Runs under ``torch.no_grad()`` because sampling never backpropagates.

        Args:
            idx: (B, T) tensor of prompt token ids on ``self.device``.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        for _ in range(max_new_tokens):
            # Crop to the last block_size tokens: the position table has only
            # block_size rows.
            logits, _ = self(idx[:, -self.block_size:])
            # Only the prediction for the last time step matters
            logits = logits[:, -1, :]  # (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
        

In [8]:
torch.manual_seed(89)
m = BigramLanguageModel(train_data, n_embd)
# Sanity-check one forward pass on the first 10 examples. The batch is moved
# to `device` (a no-op on CPU) so the cell also works when CUDA is used.
xb, yb = train_data[:10]
xb, yb = xb.to(device), yb.to(device)
out = m(xb, yb)
# Sample 100 tokens from the untrained model, starting from token id 0.
# The prompt is created on `device`; the result is brought back to the CPU
# for decoding.
print(tokenizer.decodeText(
    m.generate(idx=torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=100)[0].cpu()
))


LeL:havOnUnmEVLM:bRwFt dM0c?cDtLv-?oJ=?VaRJ7"2M20AM
fn8bFyZ:?m7N0na0wm'qKeL
8T?B?9!ib:6Zbp2A:%tsaM8


The result is random characters, as expected from an untrained model.

## Training the model

In [9]:
# AdamW optimizer over all of the model's parameters (learning rate 1e-3)
optimizer = torch.optim.AdamW(params=m.parameters(), lr=1e-3)

In [10]:
batch_size = 32
num_epochs = 100

# verbose: print the training loss every `show_loss_each_epoch` steps
show_loss_each_epoch = 10

def train(optimizer, num_epochs=num_epochs, model=None, dataset=None):
    """Run ``num_epochs`` optimisation steps on random mini-batches.

    Despite its name, ``num_epochs`` counts optimizer steps: each step samples
    one batch of ``batch_size`` examples with ``get_batch``. It does not count
    full passes over the dataset.

    Args:
        optimizer: optimizer already bound to the model's parameters.
        num_epochs: number of training steps to run.
        model: model to train; defaults to the global ``m`` (looked up at
            call time, as before).
        dataset: dataset to sample from; defaults to the global ``train_data``.
    """
    if model is None:
        model = m
    if dataset is None:
        dataset = train_data

    for step in range(num_epochs):
        # sample a batch of data and move it to the model's device; the device
        # get_batch returns is not visible here, and .to() is a no-op if it
        # already matches
        xb, yb = get_batch(dataset, batch_size)
        xb, yb = xb.to(device), yb.to(device)

        # evaluate the loss and take one optimizer step
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if (step + 1) % show_loss_each_epoch == 0:
            print('loss :', loss.item())
    print('done!')

In [11]:
train(optimizer, num_epochs=1500)  # 1500 optimizer steps on random batches

loss : 4.564981937408447
loss : 4.342898845672607
loss : 4.289647102355957
loss : 4.166916370391846
loss : 3.926363468170166
loss : 3.88124942779541
loss : 3.672184705734253
loss : 3.659520387649536
loss : 3.497307300567627
loss : 3.4887804985046387
loss : 3.4159581661224365
loss : 3.3326728343963623
loss : 3.1581273078918457
loss : 3.283029794692993
loss : 3.0645952224731445
loss : 3.0826728343963623
loss : 3.0945539474487305
loss : 2.938164710998535
loss : 3.0291380882263184
loss : 2.8865647315979004
loss : 3.0385732650756836
loss : 2.7839837074279785
loss : 2.8570516109466553
loss : 2.7800652980804443
loss : 3.0109987258911133
loss : 2.7849817276000977
loss : 2.7034881114959717
loss : 2.6601319313049316
loss : 2.9246671199798584
loss : 2.7791905403137207
loss : 2.710498571395874
loss : 2.7385547161102295
loss : 2.6902031898498535
loss : 2.615878105163574
loss : 2.616091251373291
loss : 2.9033002853393555
loss : 2.735739231109619
loss : 2.7156715393066406
loss : 2.7572851181030273
lo

In [12]:
# Sample 100 tokens from the trained model, starting from token id 0. The
# prompt is created on `device` so this also works on CUDA.
print(tokenizer.decodeText(
    m.generate(idx=torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=100)[0].cpu()
))




Oheve  y angowhay lerly I anoura g te ubulitinst wea youtpe ththin igrst yo t joine war, meat



s


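The output is still about as good as the base model's. Each position's prediction depends only on its own token and position, because no attention has been added yet.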


## The mathematical trick behind self-attention

In [13]:
# Toy example: a random (batch, time, channels) tensor with a fixed seed
torch.manual_seed(198)
B, T, C = (4, 8, 2)  # batch, time, channels
x = torch.randn((B, T, C))
x.size()

torch.Size([4, 8, 2])

In [14]:
# Version 1 (explicit loops): we want xbow[b, t] = mean of x[b, i] over all i <= t
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # (t+1, C): time steps 0..t of example b
        xbow[b, t] = torch.mean(xprev, 0)
xbow[0]

tensor([[ 0.3597,  0.1501],
        [ 0.3383,  0.7864],
        [ 0.3464,  0.5391],
        [ 0.0438,  0.4631],
        [ 0.0989,  0.1779],
        [ 0.2658,  0.2987],
        [-0.0105,  0.2283],
        [ 0.0798,  0.2066]])

In [15]:
# Version 2 (softmax): set future positions to -inf, then softmax each row so
# every allowed (past or current) position gets the same weight.
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [16]:
# (T, T) @ (B, T, C): the weight matrix is broadcast over the batch -> (B, T, C)
xbow2 = torch.matmul(wei, x)
xbow2[0]

tensor([[ 0.3597,  0.1501],
        [ 0.3383,  0.7864],
        [ 0.3464,  0.5391],
        [ 0.0438,  0.4631],
        [ 0.0989,  0.1779],
        [ 0.2658,  0.2987],
        [-0.0105,  0.2283],
        [ 0.0798,  0.2066]])

In [17]:
# Both versions should agree (up to floating-point tolerance)
xbow.allclose(xbow2)

True

Both versions give the same result.

### Implementing an attention head
Now we build a third version: the attention head.

First, we introduce a new hyperparameter: the head size.

In [18]:
# Head size: output dimension of each key / query / value projection
head_size = 16

In [19]:
# Version 3: learned key and query projections (no bias) from C channels to head_size
key = nn.Linear(in_features=C, out_features=head_size, bias=False)
query = nn.Linear(in_features=C, out_features=head_size, bias=False)


In [20]:
k, q = key(x), query(x)  # each (B, T, head_size)
# Raw affinity of every query with every key:
# (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
wei = q @ k.transpose(-2, -1)

# Causal mask: position t may only look at positions <= t
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)

In [21]:
wei.select(0, 0)  # attention weights of the first batch element, (T, T)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4099, 0.5901, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3780, 0.2279, 0.3942, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1302, 0.5305, 0.1159, 0.2234, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2031, 0.0725, 0.2210, 0.0464, 0.4569, 0.0000, 0.0000, 0.0000],
        [0.1195, 0.0402, 0.1308, 0.3458, 0.3329, 0.0308, 0.0000, 0.0000],
        [0.0555, 0.5743, 0.0458, 0.0614, 0.0072, 0.2289, 0.0269, 0.0000],
        [0.1148, 0.0417, 0.1248, 0.1055, 0.2782, 0.0637, 0.1468, 0.1245]],
       grad_fn=<SelectBackward0>)

The weights within a row are no longer uniform: they now depend on the data.

We have implemented the keys and the queries; now we implement the values.

In [22]:
# Values: a third bias-free projection, built like the key and query layers
value = nn.Linear(in_features=C, out_features=head_size, bias=False)
v = value(x)                # (B, T, head_size)
out = torch.matmul(wei, v)  # (B, T, T) @ (B, T, head_size) --> (B, T, head_size)

out.shape

torch.Size([4, 8, 16])

In [23]:
out[0][0]  # output vector of the first token (t=0) of the first batch element

tensor([ 0.2252,  0.0816, -0.0861, -0.0980, -0.1490, -0.0824, -0.2695,  0.0722,
         0.1332, -0.1667, -0.0388, -0.2477, -0.0050,  0.1838, -0.0065,  0.1206],
       grad_fn=<SelectBackward0>)

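To keep the variance of the attention scores stable, `wei` is divided by $\sqrt{\text{head\_size}}$. With unit-variance keys and queries, the dot product has a variance of about $\text{head\_size}$. Without the scaling, the softmax becomes very peaked (close to one-hot).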

In [24]:
k, q = key(x), query(x)  # each (B, T, head_size)
# Scaled dot-product scores: (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
wei = (q @ k.transpose(-2, -1)) * (head_size ** -0.5)

# Causal mask, then normalise each row into attention weights
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)

# Weighted aggregation of the values
v = value(x)
out = wei @ v

Let's create a `Head` class that implements a single head.

In [25]:
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Each input vector is projected to a key, a query and a value of size
    ``head_size``. Scaled dot-product scores are computed, future positions
    are masked out, and the attention-weighted sum of the values is returned.

    Args:
        n_embd: size of the last dimension of the input.
        head_size: size of the key/query/value projections (and of the output).
        block_size: maximum sequence length supported by the causal mask.
    """

    def __init__(self, n_embd: int, head_size: int, block_size: int) -> None:
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask; registered as a buffer so it moves with
        # .to(device) and is part of the state_dict without being trainable.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: (B, T, n_embd) tensor with T <= block_size.

        Returns:
            (B, T, head_size) tensor.

        Raises:
            ValueError: if T exceeds the block_size the mask was built for.
        """
        _, T, _ = x.shape
        max_T = self.tril.size(0)
        if T > max_T:
            raise ValueError(f"sequence length {T} exceeds block_size {max_T}")
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Attention scores ('affinities'), scaled by 1/sqrt(head_size) so that
        # the softmax does not saturate:
        # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
        wei = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)
        # Position t may only attend to positions <= t
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)
        # Weighted aggregation of the values
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v      # (B, T, T) @ (B, T, head_size) --> (B, T, head_size)
        return out

Now let's redefine `BigramLanguageModel` to use a self-attention head.

In [26]:
class BigramLanguageModel(Module):
    """Language model: token + position embeddings, one causal self-attention head.

    The summed embeddings pass through a single ``Head`` before the linear
    head that produces next-token logits.
    """

    def __init__(self, vocab_size: int | CharTokenizer | TextChunksDataset, n_embd, device=device, context_size=None, head_size=None):
        """
        If vocab_size is a Dataset with context_length, then no need to specify context_size

        Args:
            vocab_size: vocabulary size as an int, a CharTokenizer (its ``len``
                is used) or a TextChunksDataset (``len`` of its tokenizer is used).
            n_embd: embedding dimension.
            device: device on which all parameters and buffers are allocated.
            context_size: maximum context length. When omitted, ``vocab_size``
                must be a TextChunksDataset and its ``context_length`` is used.
            head_size: output size of the attention head; defaults to ``n_embd``.

        Raises:
            ValueError: if ``context_size`` is None and ``vocab_size`` is not a
                TextChunksDataset.
        """
        super().__init__()
        if context_size is None:
            if not isinstance(vocab_size, TextChunksDataset):
                raise ValueError("You need to specify the context length")
            context_size = vocab_size.context_length
        self.block_size = context_size
        self.device = device
        if isinstance(vocab_size, TextChunksDataset):
            vocab_size = len(vocab_size.tokenizer)
        elif isinstance(vocab_size, CharTokenizer):
            vocab_size = len(vocab_size)
        if head_size is None:
            head_size = n_embd
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd, device=self.device)
        self.position_embedding_table = nn.Embedding(self.block_size, n_embd, device=self.device)
        # The causal mask must cover the context length (not the vocabulary
        # size), and the head must live on the same device as the embeddings.
        self.sa_head = Head(n_embd, head_size, self.block_size).to(self.device)
        self.lm_head = nn.Linear(head_size, vocab_size, device=self.device)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids, T <= block_size, on ``self.device``.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``: logits are (B, T, vocab_size) when ``targets``
            is None (loss is then None); otherwise logits are flattened to
            (B*T, vocab_size) and loss is a scalar tensor.
        """
        B, T = idx.shape

        tok_embd = self.token_embedding_table(idx)  # (B, T, C)
        pos_embd = self.position_embedding_table(torch.arange(T, device=self.device))  # (T, C)
        x = tok_embd + pos_embd  # (B, T, C)
        x = self.sa_head(x)      # (B, T, head_size)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects (N, C) logits and (N,) class indices
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens: int):
        """Autoregressively sample ``max_new_tokens`` tokens.

        Runs under ``torch.no_grad()`` because sampling never backpropagates.

        Args:
            idx: (B, T) tensor of prompt token ids on ``self.device``.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        for _ in range(max_new_tokens):
            # Crop to the last block_size tokens: the position table and the
            # causal mask only cover block_size positions.
            logits, _ = self(idx[:, -self.block_size:])
            # Only the prediction for the last time step matters
            logits = logits[:, -1, :]  # (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
        

In [27]:

m = BigramLanguageModel(train_data, n_embd)
# Sanity-check one forward pass on the first 10 examples. The batch is moved
# to `device` (a no-op on CPU) so the cell also works when CUDA is used.
xb, yb = train_data[:10]
xb, yb = xb.to(device), yb.to(device)
out = m(xb, yb)
# Sample 100 tokens from the untrained model, starting from token id 0.
# The prompt is created on `device`; the result is brought back to the CPU
# for decoding.
print(tokenizer.decodeText(
    m.generate(idx=torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=100)[0].cpu()
))


gth2C-riwP"wSCi0vOy%C…XsT5/"ftn1
8!!*aG*T8eP$4
mUxCcYTC1H/Wiy
Bl5wX
-6
*Q'*4W/5HC9XSq4CKC*…96>oC2Z…A


### Training

In [28]:
# AdamW optimizer over the new model's parameters (smaller learning rate: 1e-4)
optimizer = torch.optim.AdamW(params=m.parameters(), lr=1e-4)

In [29]:
train(optimizer, num_epochs=5000)  # 5000 optimizer steps on random batches

loss : 4.447386741638184
loss : 4.421143531799316
loss : 4.351117134094238
loss : 4.393448352813721
loss : 4.375442028045654
loss : 4.336314678192139
loss : 4.288149356842041
loss : 4.264897346496582
loss : 4.263428211212158
loss : 4.239197731018066
loss : 4.180443286895752
loss : 4.160712242126465
loss : 4.150298595428467
loss : 4.129261016845703
loss : 4.094697952270508
loss : 4.087104797363281
loss : 4.062117099761963
loss : 3.9753551483154297
loss : 3.968000888824463
loss : 3.92964506149292
loss : 3.8719117641448975
loss : 3.9430956840515137
loss : 3.8390052318573
loss : 3.8218882083892822
loss : 3.7549421787261963
loss : 3.7504026889801025
loss : 3.759014368057251
loss : 3.6938560009002686
loss : 3.599816083908081
loss : 3.6268932819366455
loss : 3.5400853157043457
loss : 3.522675037384033
loss : 3.5024819374084473
loss : 3.520388126373291
loss : 3.523282289505005
loss : 3.50000262260437
loss : 3.451000690460205
loss : 3.400134325027466
loss : 3.378018379211426
loss : 3.4012668132

In [30]:
# Sample 100 tokens from the trained attention model, starting from token id 0.
# The prompt is created on `device` so this also works on CUDA.
print(tokenizer.decodeText(
    m.generate(idx=torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=100)[0].cpu()
))


I', ed zikd pde bele d oTd ngmycsassWtharo n alep then' wan ind'renst ok
Eiforyal wetel-inig Cro lth


Not very good yet, although the loss is still decreasing.