In [2]:
import sys
import subprocess

def install_dependency(dependency: str) -> bool:
    try:
        subprocess.check_call([sys.executable, "-m", "pipenv", "install", dependency])
        return True
    except e:
        return False            

In [15]:
import torch
from torch import nn
import torch.nn.functional as F

## Scaled Dot Product Attention

As we read the paper, the authors focus on the scaled dot product attention first, this combined with multihead attention--which we'll come to next--forms the basis of this paper.

In [16]:
def scaled_dot_product_attention(
    queries: torch.Tensor, 
    keys: torch.Tensor, 
    values: torch.Tensor 
) -> torch.Tensor:
    """
    Computes scaled dot-product attention.

    Args:
        queries (torch.Tensor): Query matrix of shape (d_context, d_k).
        keys (torch.Tensor): Key matrix of shape (d_context, d_k).
        values (torch.Tensor): Value matrix of shape (d_context, d_k).

    Returns:
        torch.Tensor: Attention-weighted sum of values.
    """
    assert queries.shape[1] == keys.shape[1], "Queries and keys must have the same number of dimensions"
    
    # Compute attention scores
    compatibility = queries @ keys.T
    
    # Scale by sqrt(d_k)
    d_k = queries.shape[1]
    stabilized_compat = compatibility / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    scaled_compat = F.softmax(stabilized_compat, dim=-1)
    
    # Compute attention output
    return scaled_compat @ values

And we'll run a quick test to make sure everything works :) 

In [4]:
d_model = 512

d_k = 64
d_v = 64

queries = torch.randn((d_model, d_k), dtype=torch.float64, requires_grad=True)
keys = torch.randn((d_model, d_k), dtype=torch.float64, requires_grad=True)
values = torch.randn((d_model, d_v), dtype=torch.float64, requires_grad=True)

attention = scaled_dot_product_attention(queries, keys, values)

assert attention.shape[0] == d_model and attention.shape[1] == d_v, "Attention has incorrect shape, should be: (n_token, d_value)"

print (attention)

tensor([[ 0.0006, -0.1030,  0.0279,  ..., -0.0876,  0.0201,  0.0039],
        [ 0.1138, -0.0085, -0.0288,  ..., -0.0100, -0.0136, -0.0037],
        [ 0.0294, -0.0451, -0.0455,  ...,  0.0154, -0.0138,  0.0649],
        ...,
        [ 0.0653, -0.1240, -0.0054,  ..., -0.0147,  0.0097,  0.0454],
        [ 0.1139, -0.0221, -0.0298,  ..., -0.0383,  0.0234, -0.0043],
        [ 0.0297,  0.0306, -0.0823,  ..., -0.0151,  0.0243,  0.1242]],
       dtype=torch.float64, grad_fn=<MmBackward0>)


## Multihead Attention

With scaled dot product implemented, we can go ahead and implement multihead attention.

We'll focus on this part first, since there is a lot to unpack here:  
"_Instead of performing a single attention function with $d_{model}$-dimensional keys, values and queries,
we found it beneficial to linearly project the queries, keys and values h times with different, learned
linear projections to d_k, d_k and d_v dimensions, respectively_"

Let's get started

In [5]:
torch.arange(20).view(2, 2, 5).shape[-1]

5

In [12]:
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads=8):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, x):
        """ 
        Computes multi head attention
        """
        d_k = d_v = x.shape[-1] // self.n_heads
        query_projections = nn.Parameter(torch.randn(self.n_heads, d_model, d_k))
        key_projections = nn.Parameter(torch.randn(self.n_heads, d_model, d_k))
        value_projections = nn.Parameter(torch.randn(self.n_heads, d_model, d_v))
        output_projection = nn.Parameter(torch.randn(d_v * self.n_heads, d_model))
    
        head_outputs = []
        for i in range(self.n_heads):
            q_proj = x @ query_projections[i]
            k_proj = x @ key_projections[i]
            v_proj = x @ value_projections[i]
            head_output = scaled_dot_product_attention(q_proj, k_proj, v_proj)
            head_outputs.append(head_output)
    
        concat_output = torch.concat(head_outputs, dim=-1)  
        return concat_output @ output_projection    

In [13]:
d_context = 64

values = torch.randn(d_context, d_model)
values

tensor([[ 1.0940,  0.2034, -0.5026,  ..., -1.4236, -0.7154,  0.0523],
        [ 1.2520, -0.3363,  0.1674,  ...,  0.4714, -1.7212,  1.3816],
        [-0.7531, -0.0715, -0.2196,  ..., -1.9276,  0.9531,  1.9429],
        ...,
        [ 2.0618, -0.2677, -1.2191,  ..., -1.6306,  1.1393,  0.6736],
        [ 0.1889,  0.6635, -1.3172,  ...,  1.4840,  1.0351,  0.6650],
        [ 0.3105,  0.5503,  0.3433,  ..., -0.4514, -1.1513, -0.6309]])

In [14]:
attention = MultiHeadAttention()
attention(values)

tensor([[ 8.5218e+02, -3.1108e+02, -3.0129e+02,  ...,  4.0344e-01,
         -4.8299e+02, -1.6218e+02],
        [ 3.3179e+02,  1.8065e+02, -4.0647e+02,  ..., -7.4219e+02,
          3.7257e+02,  2.6203e+02],
        [ 2.2212e+02,  6.6790e+02, -2.0758e+02,  ..., -3.6973e+02,
         -2.8809e+02, -3.5029e+02],
        ...,
        [ 5.5337e+02,  5.3414e+01,  2.4222e+02,  ...,  6.6526e+02,
         -3.2101e+02,  1.3975e+02],
        [ 8.2053e+02, -2.7026e+02, -7.3547e+02,  ...,  5.8289e+02,
          1.0888e+02,  6.2960e+01],
        [ 9.4588e+02, -6.3049e+02, -5.1539e+02,  ..., -4.1804e+02,
          4.9018e+02, -5.8086e+02]], grad_fn=<MmBackward0>)

## Tokenization & Embedding

In [21]:
class TokenEmbedder(nn.Module):
    def __init__(self, vocab_size: int, embed_size: int) -> None:
        """ """
        self.embedding_matrix = torch.nn.Parameter(torch.randn((vocab_size, embed_size)))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Produces an embedding of the specified token sequence

        Args:
            tokens: a sequence of tokens of shape (max_seq_len, vocab_size)
        """
        return tokens @ self.embedding_matrix
        

## Positional Encoding

In [36]:
bias = torch.full((d_model,), 10_000, dtype=torch.float)
indices = ((torch.arange(d_model) // 2) * 2) / d_model
divisor = bias.pow(indices)
divisor

tensor([1.0000e+00, 1.0000e+00, 1.0366e+00, 1.0366e+00, 1.0746e+00, 1.0746e+00,
        1.1140e+00, 1.1140e+00, 1.1548e+00, 1.1548e+00, 1.1971e+00, 1.1971e+00,
        1.2409e+00, 1.2409e+00, 1.2864e+00, 1.2864e+00, 1.3335e+00, 1.3335e+00,
        1.3824e+00, 1.3824e+00, 1.4330e+00, 1.4330e+00, 1.4855e+00, 1.4855e+00,
        1.5399e+00, 1.5399e+00, 1.5963e+00, 1.5963e+00, 1.6548e+00, 1.6548e+00,
        1.7154e+00, 1.7154e+00, 1.7783e+00, 1.7783e+00, 1.8434e+00, 1.8434e+00,
        1.9110e+00, 1.9110e+00, 1.9810e+00, 1.9810e+00, 2.0535e+00, 2.0535e+00,
        2.1288e+00, 2.1288e+00, 2.2067e+00, 2.2067e+00, 2.2876e+00, 2.2876e+00,
        2.3714e+00, 2.3714e+00, 2.4582e+00, 2.4582e+00, 2.5483e+00, 2.5483e+00,
        2.6416e+00, 2.6416e+00, 2.7384e+00, 2.7384e+00, 2.8387e+00, 2.8387e+00,
        2.9427e+00, 2.9427e+00, 3.0505e+00, 3.0505e+00, 3.1623e+00, 3.1623e+00,
        3.2781e+00, 3.2781e+00, 3.3982e+00, 3.3982e+00, 3.5227e+00, 3.5227e+00,
        3.6517e+00, 3.6517e+00, 3.7855e+

In [41]:
max_seq_len = 1024
pos = torch.arange(max_seq_len, dtype=torch.float).view((-1, 1))
freqs = pos / divisor
freqs

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0000e+00, 9.6466e-01,  ..., 1.0746e-04, 1.0366e-04,
         1.0366e-04],
        [2.0000e+00, 2.0000e+00, 1.9293e+00,  ..., 2.1492e-04, 2.0733e-04,
         2.0733e-04],
        ...,
        [1.0210e+03, 1.0210e+03, 9.8492e+02,  ..., 1.0972e-01, 1.0584e-01,
         1.0584e-01],
        [1.0220e+03, 1.0220e+03, 9.8588e+02,  ..., 1.0982e-01, 1.0594e-01,
         1.0594e-01],
        [1.0230e+03, 1.0230e+03, 9.8685e+02,  ..., 1.0993e-01, 1.0605e-01,
         1.0605e-01]])

In [43]:
PE = torch.zeros((max_seq_len, d_model), dtype=torch.float)
PE[:, 0::2] = torch.sin(freqs[:, 0::2])
PE[:, 1::2] = torch.cos(freqs[:, 1::2])
PE

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,
          1.0366e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,
          2.0733e-04,  1.0000e+00],
        ...,
        [ 1.7612e-02, -9.9984e-01, -9.9954e-01,  ...,  9.9399e-01,
          1.0564e-01,  9.9440e-01],
        [-8.3182e-01, -5.5504e-01, -5.4457e-01,  ...,  9.9398e-01,
          1.0575e-01,  9.9439e-01],
        [-9.1649e-01,  4.0007e-01,  3.7906e-01,  ...,  9.9396e-01,
          1.0585e-01,  9.9438e-01]])

Now let's pull it all together

In [47]:
class PositionalEncode(nn.Module):
    def __init__(self, max_seq_length: int = 1024, d_model: int = 512) -> None:
        """
        Generates sinusoidal positional encodings.
    
        Parameters:
            max_seq_len (int): Maximum sequence length.
            d_model (int): Dimensionality of the model embeddings.
    
        Returns:
            torch.Tensor: A tensor of shape (max_seq_len, d_model) containing 
                          the positional encodings.
        """
        # Create position indices: pos = [0, 1, ..., max_seq_len-1]
        pos_indices = torch.arange(max_seq_len, dtype=torch.float32)
        
        # Create dimension indices: dim = [0, 1, ..., d_model-1]
        dim_indices = torch.arange(d_model, dtype=torch.float32)
        
        # Compute the scaling exponent: 2 * floor(dim/2) / d_model
        exponent = ((dim_indices // 2) * 2) / d_model
        
        # Compute the denominator term: 10000^(exponent)
        div_term = torch.pow(10000, exponent)
        
        # Compute the angle rates: pos / div_term
        angle_rates = pos_indices.unsqueeze(1) / div_term
        
        # Initialize the positional encoding matrix and apply sine to even 
        # indices and cosine to odd indices.
        pos_encoding = torch.zeros_like(angle_rates)
        pos_encoding[:, 0::2] = torch.sin(angle_rates[:, 0::2])
        pos_encoding[:, 1::2] = torch.cos(angle_rates[:, 1::2])
        
        return pos_encoding

    def forward(self):
        return self.PE

In [48]:
PositionalEncode()()

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,
          1.0366e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,
          2.0733e-04,  1.0000e+00],
        ...,
        [ 1.7612e-02, -9.9984e-01, -9.9954e-01,  ...,  9.9399e-01,
          1.0564e-01,  9.9440e-01],
        [-8.3182e-01, -5.5504e-01, -5.4457e-01,  ...,  9.9398e-01,
          1.0575e-01,  9.9439e-01],
        [-9.1649e-01,  4.0007e-01,  3.7906e-01,  ...,  9.9396e-01,
          1.0585e-01,  9.9438e-01]])

In [51]:
def generate_positional_encodings(max_seq_length, d_model):
        dimensions = torch.arange(d_model)
        positions = torch.arange(max_seq_len)

        exponent = ((dimensions // 2) * 2) / d_model
        freq_divisor = torch.full_like(dimensions, 10_000).pow(exponent)
        freqs = positions.view((max_seq_len, 1)) / freq_divisor
        position_encoding = torch.zeros_like(freqs)
        position_encoding[:, 0::2] = torch.sin(freqs[:, 0::2])
        position_encoding[:, 1::2] = torch.cos(freqs[:, 1::2])
        return position_encoding

In [52]:
generate_positional_encodings(1024, 512)

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,
          1.0366e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,
          2.0733e-04,  1.0000e+00],
        ...,
        [ 1.7612e-02, -9.9984e-01, -9.9954e-01,  ...,  9.9399e-01,
          1.0564e-01,  9.9440e-01],
        [-8.3182e-01, -5.5504e-01, -5.4457e-01,  ...,  9.9398e-01,
          1.0575e-01,  9.9439e-01],
        [-9.1649e-01,  4.0007e-01,  3.7906e-01,  ...,  9.9396e-01,
          1.0585e-01,  9.9438e-01]])