In [1]:
import torch
import numpy as np

### Self-Attention

In [2]:
vocab_size = 50000



sentence = "Life is short eat dessert first"

dct = {word:i for i, word in enumerate(sorted(sentence.split()))}
sentence_int = torch.tensor([dct[s] for s in sentence.split()])

torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)

embedded_sentence = embed(sentence_int).detach()

print(f"Embedded vector: {embedded_sentence}")
print(f'Shape of embeddeding {embedded_sentence.shape}')

Embedded vector: tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])
Shape of embeddeding torch.Size([6, 3])


In [3]:
torch.manual_seed(123)
d = embedded_sentence.shape[1]

d_q, d_k, d_v = 2,2,4

W_q = torch.nn.Parameter(torch.rand(d, d_q))
W_k = torch.nn.Parameter(torch.rand(d, d_k))
W_v = torch.nn.Parameter(torch.rand(d, d_v))


print(W_q.shape)
print(W_k.shape)
print(W_v.shape)

torch.Size([3, 2])
torch.Size([3, 2])
torch.Size([3, 4])


In [4]:
x_2 = embedded_sentence[1]
q_2 = x_2 @ W_q
k_2 = x_2 @ W_k
v_2 = x_2 @ W_v

print(f"Second token query vector: {q_2} . Shape: {q_2.shape}")
print(f"Second token key vector: {k_2} . Shape: {k_2.shape}")
print(f"Second token value vector: {v_2} . Shape: {v_2.shape}")

Second token query vector: tensor([0.5667, 1.8269], grad_fn=<SqueezeBackward4>) . Shape: torch.Size([2])
Second token key vector: tensor([0.5295, 1.7355], grad_fn=<SqueezeBackward4>) . Shape: torch.Size([2])
Second token value vector: tensor([0.6612, 1.8972, 1.0963, 1.8106], grad_fn=<SqueezeBackward4>) . Shape: torch.Size([4])


In [12]:
keys = embedded_sentence @ W_k
values = embedded_sentence @ W_v

print(f"Keys : {keys} Shape of Keys : {keys.shape}")
print(f"Values : {values} Shape of values : {values.shape}")

Keys : tensor([[-0.0823, -0.3031],
        [ 0.5295,  1.7355],
        [-0.2991, -0.7295],
        [ 0.1420,  0.2291],
        [ 0.1920,  0.6467],
        [-0.4788, -0.5835]], grad_fn=<MmBackward0>) Shape of Keys : torch.Size([6, 2])
Values : tensor([[-0.2546, -0.2608, -0.1544, -0.2801],
        [ 0.6612,  1.8972,  1.0963,  1.8106],
        [-0.8598, -0.6161, -0.5940, -0.9455],
        [ 0.5932,  0.0981,  0.2741,  0.4151],
        [ 0.5605,  0.5645,  0.3676,  0.6429],
        [-1.2107, -0.4929, -1.0081, -1.4031]], grad_fn=<MmBackward0>) Shape of values : torch.Size([6, 4])


In [13]:
omega_24 = q_2.dot(keys[4])
print(omega_24)

omega_2 = q_2 @ keys.T
print(f"Unormalized attention weights of each word to the second input token {omega_2}")


tensor(1.2903, grad_fn=<DotBackward0>)
Unormalized attention weights of each word to the second input token tensor([-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374],
       grad_fn=<SqueezeBackward4>)


In [16]:
import torch.nn.functional as F
alpha_2 = F.softmax(omega_2/d_k**0.5, dim=-1)

print(f"Normalized attention weights with respect to second token {alpha_2}")


Normalized attention weights with respect to second token tensor([0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229],
       grad_fn=<SoftmaxBackward0>)


In [17]:
context_vector_2 = alpha_2 @ values


print(context_vector_2)
print(context_vector_2.shape)

tensor([0.5313, 1.3607, 0.7891, 1.3110], grad_fn=<SqueezeBackward4>)
torch.Size([4])


In [20]:
import torch.nn as nn
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.W_k = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_q = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_v = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x):
        keys = x @ self.W_k
        values = x @ self.W_v
        queries = x @ self.W_q

        attn_scores = queries @ keys.T

        attn_weights = torch.softmax(attn_scores / self.d_out_kq**0.5, dim=-1)
        context = attn_weights @ values
        return context

In [21]:
torch.manual_seed(123)
sa = SelfAttention(3, 2, 4)
print(sa(embedded_sentence))

tensor([[-0.2358,  0.0274, -0.1529, -0.1919],
        [ 0.5449,  1.4054,  0.8220,  1.3609],
        [-0.4417, -0.1620, -0.3432, -0.4809],
        [ 0.0351,  0.3622,  0.1272,  0.2443],
        [ 0.2236,  0.6611,  0.3469,  0.5929],
        [-0.3897, -0.1226, -0.2999, -0.4153]], grad_fn=<MmBackward0>)


### Multi-Head Attention

In [28]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(d_in, d_out_kq, d_out_v) for _ in range(num_heads)])
    
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [29]:
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3,2,1

sa = SelfAttention(d_in, d_out_kq, d_out_v)

sa(embedded_sentence)

tensor([[-0.0529],
        [ 0.4134],
        [-0.1403],
        [ 0.0794],
        [ 0.1848],
        [-0.1201]], grad_fn=<MmBackward0>)

In [30]:
torch.manual_seed(123)

mha = MultiHeadAttentionWrapper(d_in, d_out_kq, d_out_v, num_heads=4)

mha(embedded_sentence)

tensor([[-0.0529,  0.1157,  0.1206,  0.0557],
        [ 0.4134,  1.6489,  1.4377,  0.9843],
        [-0.1403, -0.0389, -0.0751,  0.0052],
        [ 0.0794,  0.1602,  0.3335, -0.1488],
        [ 0.1848,  0.3579,  0.5004, -0.0543],
        [-0.1201, -0.1457, -0.1849, -0.2312]], grad_fn=<CatBackward0>)

### Cross Attention

In [40]:
class CrossAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.W_q = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_k = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_v = nn.Parameter(torch.rand(d_in, d_out_v))
        self.d_out_kq = d_out_kq
    
    def forward(self, x_1, x_2):
        queries_1 = x_1 @ self.W_q
        keys_2 = x_2 @ self.W_k
        values_2 = x_2 @ self.W_v

        attn_scores = queries_1 @ keys_2.T

        attn_weights = torch.softmax(attn_scores/self.d_out_kq**0.5, dim=-1)

        context_vectors= attn_weights @ values_2

        return context_vectors

In [42]:
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3,2,4

crossattn = CrossAttention(d_in, d_out_kq, d_out_v)

first_input = embedded_sentence
second_input = torch.rand(8, d_in)

print("First input shape", first_input.shape)
print("Second input shape", second_input.shape)

context_vectors = crossattn(first_input, second_input)
print(context_vectors)
print("Output shape:", context_vectors.shape)

First input shape torch.Size([6, 3])
Second input shape torch.Size([8, 3])
tensor([[0.4231, 0.8665, 0.6503, 1.0042],
        [0.4874, 0.9718, 0.7359, 1.1353],
        [0.4054, 0.8359, 0.6258, 0.9667],
        [0.4357, 0.8886, 0.6678, 1.0311],
        [0.4429, 0.9006, 0.6775, 1.0460],
        [0.3860, 0.8021, 0.5985, 0.9250]], grad_fn=<MmBackward0>)
Output shape: torch.Size([6, 4])


In [83]:
class MaskedSelfAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v, mask=None):
        super().__init__()
        self.W_q = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_k = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_v = nn.Parameter(torch.rand(d_in, d_out_v))
        self.d_out_kq = d_out_kq
        self.d_in = d_in
        self.mask = mask

    def forward(self, x):
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        print(queries.shape)

        attn_scores = queries @ keys.T
        block_size = attn_scores.shape[0]
        if self.mask:
            block_size = attn_scores.shape[0]
            mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
            attn_scores = attn_scores.masked_fill(mask.bool(), -torch.inf)

        attn_weights = torch.softmax(attn_scores/self.d_out_kq**0.5, dim=-1)
        print(attn_weights)

        context_vector = attn_weights @ values

        return context_vector
    


In [84]:
torch.manual_seed(123)
msa = MaskedSelfAttention(3,2,1)

msa(embedded_sentence)

torch.Size([6, 2])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
        [0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
        [0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<SoftmaxBackward0>)


tensor([[-0.1055],
        [ 0.5085],
        [-0.1312],
        [ 0.1236],
        [ 0.1905],
        [-0.1827]], grad_fn=<MmBackward0>)

In [82]:
class MaskedMultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([MaskedSelfAttention(d_in, d_out_kq, d_out_v) for _ in range(num_heads)])
    
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [81]:
torch.manual_seed(123)

mmha = MaskedMultiHeadAttentionWrapper(3, 2, 1, 1)

print(mmha(embedded_sentence))

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
        [0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
        [0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<SoftmaxBackward0>)
tensor([[-0.1055],
        [ 0.5085],
        [-0.1312],
        [ 0.1236],
        [ 0.1905],
        [-0.1827]], grad_fn=<CatBackward0>)


### Building Encoder Block

### Using a KV-Cache

From here on we will assume we are working with a decoder only language model. Therefore, we won't have cross-attention and we will use a mask. 

In [None]:
class MultiHeadAttention(nn.Module):

    def __init__(self, max_batch_size, max_seq_len, d_in, d_out_kq, d_out_v, num_heads, kv_cache=False):
        self.cache_k = torch.zeros((max_batch_size, max_seq_len, num_heads, d_out_kq))
        self.cache_v = torch.zeros((max_batch_size, max_seq_len, num_heads, d_out_v))

        
    

### Multi-Query Attention