In [1]:
import sys
import subprocess

def install_dependency(dependency: str) -> bool:
    try:
        subprocess.check_call([sys.executable, "-m", "pipenv", "install", dependency])
        return True
    except e:
        return False            

## Scaled Dot Product Attention

As we read the paper, the authors focus on the scaled dot product attention first, this combined with multihead attention--which we'll come to next--forms the basis of this paper.

In [27]:
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(
    queries: torch.Tensor, 
    keys: torch.Tensor, 
    values: torch.Tensor 
) -> torch.Tensor:
    """
    Computes scaled dot-product attention.

    Args:
        queries (torch.Tensor): Query matrix of shape (d_context, d_k).
        keys (torch.Tensor): Key matrix of shape (d_context, d_k).
        values (torch.Tensor): Value matrix of shape (d_context, d_k).

    Returns:
        torch.Tensor: Attention-weighted sum of values.
    """
    assert queries.shape[1] == keys.shape[1], "Queries and keys must have the same number of dimensions"
    
    # Compute attention scores
    compatibility = queries @ keys.T
    
    # Scale by sqrt(d_k)
    d_k = queries.shape[1]
    stabilized_compat = compatibility / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    scaled_compat = F.softmax(stabilized_compat, dim=-1)
    
    # Compute attention output
    return scaled_compat @ values

And we'll run a quick test to make sure everything works :) 

In [4]:
d_model = 512

d_k = 64
d_v = 64

queries = torch.randn((d_model, d_k), dtype=torch.float64, requires_grad=True)
keys = torch.randn((d_model, d_k), dtype=torch.float64, requires_grad=True)
values = torch.randn((d_model, d_v), dtype=torch.float64, requires_grad=True)

attention = scaled_dot_product_attention(queries, keys, values)

assert attention.shape[0] == d_model and attention.shape[1] == d_v, "Attention has incorrect shape, should be: (n_token, d_value)"

print (attention)

tensor([[ 0.0043, -0.0110, -0.0394,  ..., -0.0508,  0.0281, -0.0398],
        [ 0.0653,  0.0198,  0.1255,  ..., -0.0096,  0.0336, -0.0329],
        [-0.0497, -0.1065,  0.0174,  ..., -0.0755,  0.0563,  0.0345],
        ...,
        [ 0.1907, -0.0506,  0.1149,  ...,  0.0087, -0.0295, -0.0331],
        [-0.0345, -0.0536, -0.0193,  ..., -0.0849,  0.0768,  0.0350],
        [ 0.0262, -0.1209,  0.0337,  ..., -0.1057,  0.0646, -0.0657]],
       dtype=torch.float64, grad_fn=<MmBackward0>)


## Multihead Attention

With scaled dot product implemented, we can go ahead and implement multihead attention.

We'll focus on this part first, since there is a lot to unpack here:  
"_Instead of performing a single attention function with $d_{model}$-dimensional keys, values and queries,
we found it beneficial to linearly project the queries, keys and values h times with different, learned
linear projections to d_k, d_k and d_v dimensions, respectively_"

Let's get started

In [40]:
torch.arange(20).view(2, 2, 5).shape[-1]

5

In [41]:
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads=8):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, x):
        """ 
        Computes multi head attention
        """
        d_k = d_v = x.shape[-1] / self.n_heads
        query_projections = nn.Parameter(torch.randn(self.n_heads, d_model, d_k))
        key_projections = nn.Parameter(torch.randn(self.n_heads, d_model, d_k))
        value_projections = nn.Parameter(torch.randn(self.n_heads, d_model, d_v))
        output_projection = nn.Parameter(torch.randn(d_v * self.n_heads, d_model))
    
        head_outputs = []
        for i in range(num_heads):
            q_proj = x @ query_projections[i]
            k_proj = x @ key_projections[i]
            v_proj = x @ value_projections[i]
            head_output = scaled_dot_product_attention(q_proj, k_proj, v_proj)
            head_outputs.append(head_output)
    
        concat_output = torch.concat(head_outputs, dim=-1)  
        return concat_output @ output_projection    

In [42]:
d_context = 64

values = torch.randn(d_context, d_model)
values

tensor([[ 0.7603,  0.6376,  0.1728,  ..., -0.8567,  0.2092, -1.5860],
        [-1.3113, -0.1871, -0.4620,  ..., -0.1369,  0.7352, -1.1841],
        [-0.3600,  1.3046, -1.3649,  ...,  0.4419,  0.4811, -0.4008],
        ...,
        [ 0.5887,  1.1117,  0.9526,  ..., -0.2709,  0.8288, -0.4260],
        [ 0.8801,  0.9559, -0.5841,  ...,  0.7414,  0.8736,  1.6141],
        [-0.2932, -0.7811,  1.6160,  ...,  0.6907, -0.9375, -0.9236]])

In [43]:
attention = MultiHeadAttention()
attention(values)

TypeError: randn(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float"