# Self attention embedding table

Adding the path to the custom libraries

In [1]:
import sys
import os


# Make the project's helper packages (../../scripts/lib, relative to the
# current working directory) importable as top-level modules.
dirname = os.path.abspath(os.path.join(os.getcwd(), "..", "..", "scripts/lib"))
# Guard the append: re-running this cell would otherwise add a duplicate
# entry to sys.path every time.
if dirname not in sys.path:
    sys.path.append(dirname)

## Importing libraries

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from utils.compile import compileFolder
from utils.tokenizer import CharTokenizer, END_CHAR
from utils.datasets import TextChunksDataset, split_dataset, get_batch

In [3]:
# This module helps to quickly save the weights and load them
from transformers import Module

## Setting Hyperparameters

In [4]:
# Maximum context length, in tokens, of a single training chunk
block_size: int = 8

# Fraction of the whole dataset reserved for the test/validation split
test_train_split_ratio: float = 0.1

# Size of each token embedding vector
n_embed: int = 32

## Setting up the data, tokenizer and dataset

In [5]:
# Load the raw training text from the 'tate' folder via the project helper
# `compileFolder` (presumably it merges the folder's files into one text --
# verify against utils.compile).
raw_data = compileFolder('tate')

# Build the tokenizer from the raw text (presumably a character-level
# vocabulary derived from `raw_data` -- verify against utils.tokenizer).
tokenizer = CharTokenizer(raw_data)

# Wrap the raw text in a dataset object, handing it the context length
# (`block_size`) and the tokenizer it should use.
data = TextChunksDataset(raw_data, block_size, tokenizer)

In [6]:
# Split using the ratio declared with the hyperparameters rather than a
# duplicated literal, so changing `test_train_split_ratio` actually takes effect.
# NOTE(review): assumes the second argument of `split_dataset` is the test
# fraction -- confirm against utils.datasets.
train_data, test_data = split_dataset(data, test_train_split_ratio)

## The mathematical trick to self attention

In [7]:
# Toy input for the averaging demo: a small random batch

torch.manual_seed(198)  # fixed seed so the printed values are reproducible
B = 4  # batch
T = 8  # time
C = 2  # channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [8]:
# Naive reference version: xbow[b, t] is the mean of x[b, i] over all i <= t
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        # x[b, :t+1] has shape (t+1, C): every position up to and including t
        xbow[b, t] = x[b, : t + 1].mean(dim=0)
xbow[0]

tensor([[ 0.3597,  0.1501],
        [ 0.3383,  0.7864],
        [ 0.3464,  0.5391],
        [ 0.0438,  0.4631],
        [ 0.0989,  0.1779],
        [ 0.2658,  0.2987],
        [-0.0105,  0.2283],
        [ 0.0798,  0.2066]])

In [9]:
# Same running average, expressed as a softmax over masked scores
# Lower-triangular matrix of ones: row t keeps columns 0..t
tril = torch.ones(T, T).tril()
# All scores start at zero; positions above the diagonal are set to -inf so
# the softmax gives them zero weight and spreads the rest uniformly per row
wei = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
wei = wei.softmax(dim=-1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [10]:
# (T, T) weights are broadcast over the batch: (T, T) @ (B, T, C) -> (B, T, C)
xbow2 = torch.matmul(wei, x)
xbow2[0]

tensor([[ 0.3597,  0.1501],
        [ 0.3383,  0.7864],
        [ 0.3464,  0.5391],
        [ 0.0438,  0.4631],
        [ 0.0989,  0.1779],
        [ 0.2658,  0.2987],
        [-0.0105,  0.2283],
        [ 0.0798,  0.2066]])

In [11]:
# Sanity check: the explicit loop (xbow) and the masked-softmax matrix
# product (xbow2) give the same tensor, up to floating-point tolerance
torch.allclose(xbow, xbow2)

True

## Implementation of the self attention block

We take almost the same structure as the base-embedding structure

> Note: from now on, we're going to use `cuda` when it is available

In [12]:

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BigramLanguageModel(Module):
    """Language model made of a single token embedding table.

    Each token index is looked up in an ``nn.Embedding`` of shape
    ``(vocab_size, n_embed)`` and the looked-up vectors are used directly
    as logits.

    NOTE(review): the last dimension of the logits is therefore ``n_embed``,
    not the vocabulary size. The loss and ``generate`` only cover the whole
    vocabulary when ``n_embed == vocab_size`` -- confirm this is the intended
    intermediate state of the model.

    Args:
        vocab_size: Number of distinct tokens. A ``CharTokenizer`` or a
            ``TextChunksDataset`` may be passed instead; the size is then
            taken from ``len(tokenizer)`` / ``len(dataset.tokenizer)``.
        n_embed: Size of each embedding vector.
        device: Device on which the embedding table is created. Defaults to
            the module-level ``device``.
    """

    def __init__(self, vocab_size: int | CharTokenizer | TextChunksDataset, n_embed, device=device):
        super().__init__()
        self.device = device
        # isinstance (rather than an exact type comparison) also accepts subclasses
        if isinstance(vocab_size, TextChunksDataset):
            vocab_size = len(vocab_size.tokenizer)
        elif isinstance(vocab_size, CharTokenizer):
            vocab_size = len(vocab_size)
        # `device` must be passed by keyword: the third positional parameter
        # of nn.Embedding is `padding_idx`, not the device.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed, device=self.device)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of integer token indices.
            targets: Optional (B, T) tensor of integer target indices.

        Returns:
            ``(logits, loss)``. Without targets, logits has shape (B, T, C)
            and loss is ``None``. With targets, logits is flattened to
            (B*T, C) and loss is the cross-entropy against the flattened
            targets.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            return logits, None

        # F.cross_entropy expects (N, C) logits and (N,) targets, so merge
        # the batch and time dimensions.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens: int):
        """Extend ``idx`` by sampling ``max_new_tokens`` tokens one at a time.

        Args:
            idx: (B, T) tensor of integer token indices used as the prompt.
            max_new_tokens: Number of tokens to append to every sequence.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.
        """
        for _ in range(max_new_tokens):
            # get the predictions (no targets, so the loss is None and ignored)
            logits, _ = self(idx)
            # focus only on the last time step -> (B, C)
            logits = logits[:, -1, :]
            # apply softmax over the channel dimension to get probabilities
            probs = F.softmax(logits, dim=-1)
            # sample one index per sequence from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled token to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
        