In [33]:
with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("number of characters: ", len(text))
print(text[:100])

number of characters:  2579948
Three Rings for the Elven-kings under the sky,
               Seven for the Dwarf-lords in their hal


In [28]:
vocab = sorted(list(set(text)))
len(vocab)

82

tokenizer

In [29]:
class Tokenizer:
    def __init__(self, vocab):
        self.stoi = { ch:i for i,ch in enumerate(vocab) }
        self.itos = { i:ch for i,ch in enumerate(vocab) }

    def encode(self, text):
        ids = [self.stoi[c] for c in text]
        return ids

    def decode(self, ids):
        text = ''.join([self.itos[i] for i in ids])
        return text

tokenizer = Tokenizer(text)
text = """we are building an agi"""
print(tokenizer.encode(text))
print(tokenizer.decode(tokenizer.encode(text)))

[2579849, 2579885, 2579886, 2579888, 2579859, 2579885, 2579886, 2579877, 2579824, 2579889, 2579870, 2579890, 2579889, 2579827, 2579715, 2579886, 2579888, 2579827, 2579886, 2579888, 2579715, 2579889]
we are building an agi


just use tiktoken

In [5]:
import tiktoken
enc = tiktoken.get_encoding("o200k_base")

In [6]:
text = "we are building an agi"
ids = enc.encode(text)
print(ids)
print(enc.decode(ids))

[854, 553, 6282, 448, 1017, 72]
we are building an agi


build a dataset class

In [34]:
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):

    def __init__(self, text, tokenizer, max_length, stride):

        self.input_ids = []
        self.target_ids = []

        token_ids = tokenizer.encode(text)

        for i in range(0, len(token_ids) - max_length, stride):

            input_chunk = token_ids[i: i+max_length]
            target_chunk = token_ids[i+1: i+max_length+1]

            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


In [35]:
data = CustomDataset(text, enc, max_length=4, stride=1)

In [39]:
len(data)

623287

In [10]:

dataloader = DataLoader(dataset=data, 
                        batch_size=1,
                        num_workers=0,
                        drop_last=True, # drops last batch if its shorter than specified batch_size
                        shuffle=True)

In [11]:
data_iter = iter(dataloader)
x, y = next(data_iter)
x, y

(tensor([[ 553, 6282,  448, 1017]]), tensor([[6282,  448, 1017,   72]]))

In [12]:
vocab_size = 50257
output_dim = 256
embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
token_embd = embedding_layer(x)

max_length=4
pos_embedding_layer = torch.nn.Embedding(max_length, output_dim)
pos_embd = pos_embedding_layer(torch.arange(max_length)) # 0 1 ... max_length-1

input_embd = token_embd + pos_embd
input_embd.shape

torch.Size([1, 4, 256])

attention mechanism

In [13]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

# step 1
attn_scr = inputs @ inputs.T # efficient matmul
print(attn_scr)

# step 2
attn_w = torch.softmax(attn_scr, dim=-1)
print(attn_w)

# step 3
context_vec = attn_w @ inputs
print(context_vec)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


attention mechanism with trainable matrices

In [15]:
import torch
import torch.nn as nn

inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out):

        super().__init__()
        # instead of nn.Parameter(torch.rand(d_in, d_out)) as nn.Linear has 
        # optimized weight initialization scheme and it is more efficient
        # in matmul ops when bias=False
        self.w_q = nn.Linear(d_in, d_out, bias=False)
        self.w_k = nn.Linear(d_in, d_out, bias=False)
        self.w_v = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):

        keys = self.w_k(x)
        queries = self.w_q(x)
        values = self.w_v(x)

        attn_scores = queries @ keys.T

        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1) # original paper d_k=64
        
        context_vectors = attn_weights @ values

        return context_vectors 

d_in = 3
d_out = 2
sa = SelfAttention(d_in, d_out)
sa(inputs)

tensor([[ 0.0033, -0.0760],
        [ 0.0025, -0.0773],
        [ 0.0024, -0.0775],
        [ 0.0010, -0.0804],
        [ 0.0001, -0.0821],
        [ 0.0019, -0.0786]], grad_fn=<MmBackward0>)

causal attention: prevent model from accessing future tokens

In [37]:
import torch
import torch.nn as nn

inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

class CasualAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, bias=False):

        super().__init__()
        # instead of nn.Parameter(torch.rand(d_in, d_out)) as nn.Linear has 
        # optimized weight initialization scheme and it is more efficient
        # in matmul ops when bias=False
        self.w_q = nn.Linear(d_in, d_out, bias=False)
        self.w_k = nn.Linear(d_in, d_out, bias=False)
        self.w_v = nn.Linear(d_in, d_out, bias=False)

        self.d_out = d_out
        self.dropout = nn.Dropout(0.5) # prevent overfitting

        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # mask to prevent model from seeing future tokens

    def forward(self, x):
        
        b, tokens, d_in = x.shape
        
        keys = self.w_k(x)
        queries = self.w_q(x)
        values = self.w_v(x)

        attn_scores = queries @ keys.transpose(1, 2)

        masked = attn_scores.masked_fill(self.mask.bool()[:tokens, :tokens], -torch.inf) # masking with -inf

        d_k = keys.shape[-1]
        attn_weights = torch.softmax(masked / d_k ** 0.5, dim=-1) # original paper d_k=64

        # drop by 50% and remaining are rescale by a factor of 1/0.5 -> 2
        attn_weights = self.dropout(attn_weights)

        context_vectors = attn_weights @ values

        return context_vectors

d_out = 2
batch = torch.stack((inputs, inputs), dim=0)
batch_size, context_length, d_in = batch.shape
ca = CasualAttention(d_in, d_out, context_length)
ca(batch)

tensor([[[ 0.0000,  0.0000],
         [-0.2049, -0.1062],
         [-0.1351, -0.0700],
         [ 0.0655, -0.0176],
         [-0.2211, -0.0101],
         [ 0.0513, -0.0541]],

        [[ 0.5709,  0.1466],
         [ 0.2761,  0.0709],
         [ 0.0471, -0.0232],
         [-0.0386, -0.0638],
         [-0.1129, -0.0938],
         [-0.0184, -0.0903]]], grad_fn=<UnsafeViewBackward0>)

multi-head attention: split attention into multiple heads where each head learns different aspect of the data and then combine the outputs

In [38]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

class MultiHeadAttentionStack(nn.Module):

    def __init__(self, d_in, d_out, context_length, n_head, bias=False):

        super().__init__()
        self.heads = nn.ModuleList([CasualAttention(d_in, d_out, context_length) for _ in range(n_head)])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

d_in = 3
d_out = 2
batch = torch.stack((inputs, inputs, inputs), dim=0)
batch_size, context_length, d_in = batch.shape
mhas = MultiHeadAttentionStack(d_in, d_out, context_length, n_head=2)
mhas(batch)


tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
         [-0.9600, -0.6369, -0.2028,  0.6374],
         [-1.7226, -0.9330,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.2691,  0.9411],
         [-0.9658, -0.6532, -0.2174,  0.7612],
         [-0.3829, -0.1564, -0.1982,  0.5920]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [-0.9600, -0.6369, -0.3205,  1.2088],
         [-1.7226, -0.9330, -0.2739,  0.8554],
         [-0.7407, -0.5015, -0.1639,  0.6102],
         [-0.2062, -0.2165, -0.0476,  0.2311],
         [-0.6981, -0.5547, -0.2338,  0.7444]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.3205,  1.2088],
         [-0.6381, -0.4233, -0.1374,  0.4262],
         [-0.4829, -0.3204, -0.1597,  0.4945],
         [-1.1720, -0.8698, -0.2396,  0.7670],
         [-0.3193, -0.2118, -0.1086,  0.4078]]], grad_fn=<CatBackward0>)

lets combine causal attention and multi-head attention code to compute the attention in parallel. currently we're just stacking multiple causal attention blocks. 

In [22]:
import torch
import torch.nn as nn

inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

class MultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, n_heads, bias=False):

        super().__init__()
        
        # d_out must be divisible by n_heads to distribute it perfectly across heads
        self.d_out = d_out
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads

        self.w_q = nn.Linear(d_in, d_out, bias=False)
        self.w_k = nn.Linear(d_in, d_out, bias=False)
        self.w_v = nn.Linear(d_in, d_out, bias=False)

        self.out_proj = nn.Linear(d_out, d_out)
        
        self.dropout = nn.Dropout(0.5)

        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))


    def forward(self, x):
        
        b, n_tokens, d_in = x.shape

        keys = self.w_k(x)
        queries = self.w_q(x)
        values = self.w_v(x)

        # split d_out into n_heads and head_dim
        keys = keys.view(b, n_tokens, self.n_heads, self.head_dim)
        queries = queries.view(b, n_tokens, self.n_heads, self.head_dim)
        values = values.view(b, n_tokens, self.n_heads, self.head_dim)

        # align all the three matrices
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)

        # create mask of shape (6, 6) with true and false
        mask = self.mask.bool()[:n_tokens, :n_tokens]
        
        # use mask to fill the attention scores matrix 
        attn_scores.masked_fill(mask, -torch.inf)
        
        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1)

        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2) # transpose back to the shape (b, n_tokens, n_heads, head_dim)

        context_vec = context_vec.contiguous().view(b, n_tokens, self.d_out)

        context_vec = self.out_proj(context_vec) # linear projection

        return context_vec

d_out = 2
batch = torch.stack((inputs, inputs, inputs), dim=0)
batch_size, context_length, d_in = batch.shape
mha = MultiHeadAttention(d_in, d_out, context_length, n_heads=2)
mha(batch)


tensor([[[ 0.6093, -0.2380],
         [ 0.6680, -0.1577],
         [ 0.6637, -0.3462],
         [ 0.4764, -0.4829],
         [ 0.4008, -0.6265],
         [ 0.4819, -0.5453]],

        [[ 0.6528, -0.4079],
         [ 0.5989, -0.3880],
         [ 0.6899, -0.4989],
         [ 0.6011, -0.6095],
         [ 0.5705, -0.3075],
         [ 0.5086, -0.5835]],

        [[ 0.5433, -0.2040],
         [ 0.5737, -0.3353],
         [ 0.5032, -0.5368],
         [ 0.6032, -0.2582],
         [ 0.5619, -0.4527],
         [ 0.6547, -0.3854]]], grad_fn=<ViewBackward0>)