In [62]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("number of characters: ", len(text))
print(text[:100])

number of characters:  1115393
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [63]:
vocab = sorted(list(set(text)))

In [64]:
class Tokenizer:
    def __init__(self, vocab):
        self.stoi = { ch:i for i,ch in enumerate(vocab) }
        self.itos = { i:ch for i,ch in enumerate(vocab) }

    def encode(self, text):
        ids = [self.stoi[c] for c in text]
        return ids

    def decode(self, ids):
        text = ''.join([self.itos[i] for i in ids])
        return text

tokenizer = Tokenizer(text)
text = """we are building an agi"""
print(tokenizer.encode(text))
print(tokenizer.decode(tokenizer.encode(text)))

[1115386, 1115374, 1115385, 1115387, 1115383, 1115374, 1115385, 1115309, 1115380, 1115389, 1115373, 1115349, 1115389, 1115390, 1115391, 1115385, 1115387, 1115390, 1115385, 1115387, 1115391, 1115389]
we are building an agi


In [65]:
import tiktoken
enc = tiktoken.get_encoding("o200k_base")

In [66]:
text = "we are building an agi"
ids = enc.encode(text)
print(ids)
print(enc.decode(ids))

[854, 553, 6282, 448, 1017, 72]
we are building an agi


In [67]:
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):

    def __init__(self, text, tokenizer, max_length, stride):

        self.input_ids = []
        self.target_ids = []

        token_ids = tokenizer.encode(text)

        for i in range(0, len(token_ids) - max_length, stride):

            input_chunk = token_ids[i: i+max_length]
            target_chunk = token_ids[i+1: i+max_length+1]

            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


In [68]:
data = CustomDataset(text, enc, max_length=4, stride=1)

In [69]:
len(data)

2

In [80]:

dataloader = DataLoader(dataset=data, 
                        batch_size=1,
                        num_workers=0,
                        drop_last=True, # drops last batch if its shorter than specified batch_size
                        shuffle=True)

In [81]:
data_iter = iter(dataloader)
x, y = next(data_iter)
x, y

(tensor([[ 854,  553, 6282,  448]]), tensor([[ 553, 6282,  448, 1017]]))

In [92]:
vocab_size = 50257
output_dim = 256
embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
token_embd = embedding_layer(x)

max_length=4
pos_embedding_layer = torch.nn.Embedding(max_length, output_dim)
pos_embd = pos_embedding_layer(torch.arange(max_length)) # 0 1 ... max_length-1

input_embd = token_embd + pos_embd
input_embd.shape

torch.Size([1, 4, 256])

attention mechanism

In [187]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

# step 1
attn_scr = inputs @ inputs.T # efficient matmul
print(attn_scr)

# step 2
attn_w = torch.softmax(attn_scr, dim=-1)
print(attn_w)

# step 3
context_vec = attn_w @ inputs
print(context_vec)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


attention mechanism with trainable matrices

In [188]:
import torch
import torch.nn as nn

inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out):

        super().__init__()
        # instead of nn.Parameter(torch.rand(d_in, d_out)) as nn.Linear has 
        # optimized weight initialization scheme and it is more efficient
        # in matmul ops when bias=False
        self.w_q = nn.Linear(d_in, d_out, bias=False)
        self.w_k = nn.Linear(d_in, d_out, bias=False)
        self.w_v = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):

        keys = self.w_k(x)
        queries = self.w_q(x)
        values = self.w_v(x)

        attn_scores = queries @ keys.T

        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1) # original paper d_k=64
        context_vectors = attn_weights @ values

        return context_vectors 


sa = SelfAttention(d_in, d_out)
sa(inputs)

tensor([[ 0.0310, -0.5954],
        [ 0.0372, -0.5953],
        [ 0.0371, -0.5954],
        [ 0.0355, -0.5953],
        [ 0.0336, -0.5962],
        [ 0.0367, -0.5948]], grad_fn=<MmBackward0>)

causal attention: prevent model from accessing future tokens

In [223]:
import torch
import torch.nn as nn

inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

class CasualAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, bias=False):

        super().__init__()
        # instead of nn.Parameter(torch.rand(d_in, d_out)) as nn.Linear has 
        # optimized weight initialization scheme and it is more efficient
        # in matmul ops when bias=False
        self.w_q = nn.Linear(d_in, d_out, bias=False)
        self.w_k = nn.Linear(d_in, d_out, bias=False)
        self.w_v = nn.Linear(d_in, d_out, bias=False)

        self.d_out = d_out
        self.dropout = nn.Dropout(0.5)

        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        
        b, tokens, d_in = x.shape
        keys = self.w_k(x) # (2, 6, 2)
        queries = self.w_q(x)
        values = self.w_v(x)

        attn_scores = queries @ keys.transpose(1, 2) # (2, 6, 3) @ (2, 3, 6) -> (2, 6, 6)

        masked = attn_scores.masked_fill(self.mask.bool()[:tokens, :tokens], -torch.inf) # masking with -inf

        d_k = keys.shape[-1]
        attn_weights = torch.softmax(masked / d_k ** 0.5, dim=-1) # original paper d_k=64

        # drop by 50% and remaining are rescale by a factor of 1/0.5 -> 2
        attn_weights = self.dropout(attn_weights) # reduce overfitting

        context_vectors = attn_weights @ values # (2, 6, 6) @ (2, 6, 2)

        return context_vectors # (2, 6, 2)

batch = torch.stack((inputs, inputs), dim=0)
context_length = batch.shape[1]
ca = CasualAttention(d_in, d_out, context_length)
ca(batch)

tensor([[[ 0.5284,  0.7834],
         [ 0.2639,  0.3914],
         [ 0.0757,  0.4419],
         [ 0.0026,  0.3854],
         [-0.0873,  0.3367],
         [-0.1483,  0.1921]],

        [[ 0.5284,  0.7834],
         [ 0.2639,  0.3914],
         [ 0.0000,  0.0000],
         [-0.0531,  0.0556],
         [-0.0637,  0.1064],
         [-0.0464,  0.2675]]], grad_fn=<UnsafeViewBackward0>)

multi-head attention: split attention into multiple heads where each head learns different aspect of the data and then combine the outputs

In [225]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

class MultiHeadAttentionStack(nn.Module):

    def __init__(self, d_in, d_out, context_length, n_head, bias=False):

        super().__init__()
        self.heads = nn.ModuleList([CasualAttention(d_in, d_out, context_length) for _ in range(n_head)])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


batch = torch.stack((inputs, inputs, inputs), dim=0)
context_length = batch.shape[1] # number of tokens (3, 6, 3)
mhas = MultiHeadAttentionStack(d_in, d_out, context_length, n_head=2)
mhas(batch) # (3, 6, 4) -> 2 batch size, 6 number of tokens, 4 context vector size * number of heads


tensor([[[-0.4583, -0.7410,  0.3520,  0.9139],
         [-0.1937, -0.3133,  0.0780, -0.0513],
         [-0.3056, -0.2787,  0.0538, -0.0353],
         [-0.1013, -0.1639,  0.0737, -0.0828],
         [-0.0936, -0.1312, -0.0721, -0.0675],
         [-0.1440, -0.1313,  0.1875,  0.1113]],

        [[-0.4583, -0.7410,  0.3520,  0.9139],
         [-0.6871, -0.7642,  0.0780, -0.0513],
         [-0.4326, -0.4813,  0.0538, -0.0353],
         [-0.6828, -0.6833,  0.0414, -0.0272],
         [-0.3691, -0.3743,  0.1317,  0.1418],
         [-0.1589, -0.1376,  0.0977,  0.0937]],

        [[-0.4583, -0.7410,  0.0000,  0.0000],
         [-0.1937, -0.3133,  0.2702,  0.4477],
         [-0.6161, -0.5624,  0.1724,  0.3077],
         [-0.1360, -0.1128,  0.1374,  0.1904],
         [ 0.0000,  0.0000,  0.0259, -0.0216],
         [-0.4578, -0.4484,  0.0282, -0.0378]]], grad_fn=<CatBackward0>)

lets combine causal attention and multi-head attention code to compute the attention in parallel. currently we're just stacking multiple causal attention blocks. 

In [228]:
import torch
import torch.nn as nn

inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

class MultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, n_heads, bias=False):

        super().__init__()
        
        # d_out must be divisible by n_heads to distribute it perfectly
        self.d_out = d_out
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads

        self.w_q = nn.Linear(d_in, d_out, bias=False)
        self.w_k = nn.Linear(d_in, d_out, bias=False)
        self.w_v = nn.Linear(d_in, d_out, bias=False)

        self.out_proj = nn.Linear(d_out, d_out)
        
        self.dropout = nn.Dropout(0.5)

        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))


    def forward(self, x):
        
        b, n_tokens, d_out = x.shape

        keys = self.w_k(x)
        queries = self.w_q(x)
        values = self.w_v(x)

        # split d_out into n_heads and head_dim
        keys = keys.view(b, n_tokens, self.n_heads, self.head_dim)
        queries = queries.view(b, n_tokens, self.n_heads, self.head_dim)
        values = values.view(b, n_tokens, self.n_heads, self.head_dim)

        # align the matrices
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)

        # create mask of shape (6, 6) with true and false
        mask = self.mask.bool()[:n_tokens, :n_tokens]
        
        # use mask to fill the attention scores matrix 
        attn_scores.masked_fill(mask, -torch.inf)
        
        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1)

        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.contiguous().view(b, n_tokens, self.d_out)

        context_vec = self.out_proj(context_vec) # linear projection

        return context_vec

batch = torch.stack((inputs, inputs, inputs), dim=0)
context_length = batch.shape[1] # number of tokens (3, 6, 3)
mha = MultiHeadAttention(d_in, d_out, context_length, n_heads=2)
mha(batch)
        

tensor([[[-0.5277, -0.4210],
         [-0.5655, -0.2873],
         [-0.3715, -0.3506],
         [-0.3938, -0.2946],
         [-0.5428, -0.2676],
         [-0.5057, -0.4016]],

        [[-0.5692, -0.4380],
         [-0.5212, -0.2434],
         [-0.4669, -0.4453],
         [-0.4240, -0.3500],
         [-0.5441, -0.3412],
         [-0.5578, -0.3455]],

        [[-0.4985, -0.3627],
         [-0.5168, -0.3241],
         [-0.4536, -0.3536],
         [-0.5506, -0.2607],
         [-0.5175, -0.4166],
         [-0.5060, -0.3071]]], grad_fn=<ViewBackward0>)