## Self attention

The goal of self attention is to compute the `context vector` of each token in a sequence. The context vector is an enriched embedding representation of a token. It is packed with information about the token itself and its relationship/relevance to other tokens in a sequence.

#### Implement self-attention with untrainable weights

In [3]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [4]:
# Compute attention scores, attention weights and context vector
attention_scores = inputs @ inputs.T
attention_weights = torch.softmax(attention_scores, dim=-1)
context_vector = attention_weights @ inputs

In [5]:
context_vector.shape

torch.Size([6, 3])

In [6]:
context_vector

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

#### Implement self-attention with trainable weights

In [22]:
from torch import nn

class SelfAttention_V1(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        # din is dimension of input token embeddings
        # dout is output dimension of context vectors
        self.W_query = nn.Parameter(torch.randn(din, dout))
        self.W_key = nn.Parameter(torch.randn(din, dout))
        self.W_value = nn.Parameter(torch.randn(din, dout))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        
        context_vectors = attention_weights @ values
        return context_vectors

In [23]:
torch.manual_seed(123)
din = inputs.shape[1]
dout = 2
sa_v1 = SelfAttention_V1(din, dout)

In [18]:
din, dout

(3, 2)

In [24]:
sa_v1(inputs)

tensor([[0.2845, 0.4071],
        [0.2854, 0.4081],
        [0.2854, 0.4075],
        [0.2864, 0.3974],
        [0.2863, 0.3910],
        [0.2860, 0.4039]], grad_fn=<MmBackward0>)

In [34]:
class SelfAttention_V2(nn.Module):
    def __init__(self, din, dout, qkv_bias=False):
        super().__init__()
        # din is dimension of input token embeddings
        # dout is output dimension of context vectors
        self.W_query = nn.Linear(din, dout, bias=qkv_bias)
        self.W_key = nn.Linear(din, dout, bias=qkv_bias)
        self.W_value = nn.Linear(din, dout, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        
        context_vectors = attention_weights @ values
        return context_vectors

In [35]:
torch.manual_seed(123)
din = inputs.shape[1]
dout = 2
sa_v2 = SelfAttention_V2(din, dout)
print(sa_v2(inputs))

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


In [36]:
# Reuse the query and key weight matrices of the
# SelfAttention_v2 object from the previous section for convenience
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs) 
attn_scores = queries @ keys.T

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


In [39]:
context_length = attn_scores.shape[0]
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.9995,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.9544, 1.4950,   -inf,   -inf,   -inf,   -inf],
        [0.9422, 1.4754, 1.4570,   -inf,   -inf,   -inf],
        [0.4753, 0.8434, 0.8296, 0.4937,   -inf,   -inf],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654,   -inf],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

#### Implementing a compact causal self-attention class

In [None]:
class CausalAttention(nn.Module):
    def __init__(self, din, dout, context_length, dropout, qkv_bias=False):
        super().__init__()
        # din is dimension of input token embeddings
        # dout is output dimension of context vectors
        self.W_query = nn.Linear(din, dout, bias=qkv_bias)
        self.W_key = nn.Linear(din, dout, bias=qkv_bias)
        self.W_value = nn.Linear(din, dout, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        # x is a batched input
        B, num_tokens, din = x.shape
        
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attention_scores = queries @ keys.transpose(1, 2)
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        context_vectors = attention_weights @ values
        return context_vectors

In [48]:
batch = torch.stack([inputs, inputs], dim=0)
batch.shape

torch.Size([2, 6, 3])

In [49]:
torch.manual_seed(123)
_, context_length, din = batch.shape
dout = 2
causal_attention = CausalAttention(din, dout, context_length, 0.0)
print(causal_attention(batch))

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


## Extending single-head attention to multi-head attention

#### A less performant way of stacking multiple single-head attention

In [51]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, din, dout, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(din, dout, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [52]:
torch.manual_seed(123)

context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)

context_vecs = mha(batch)
context_vecs

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)

#### A more performant way of stacking multiple single-head attention

In [61]:
class MultiHeadAttention(nn.Module):

    def __init__(self, din, dout, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (dout % num_heads == 0), "dout must be divisible by num_heads"

        self.dout = dout
        self.num_heads = num_heads
        self.head_dim = dout // num_heads

        self.W_query = nn.Linear(din, dout, bias=qkv_bias)
        self.W_key = nn.Linear(din, dout, bias=qkv_bias)
        self.W_value = nn.Linear(din, dout, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(dout, dout)
        # Register a buffer for the causal mask
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        B, num_tokens, din = x.shape

        queries = self.W_query(x) # (B, num_tokens, din) @ (din, dout) -> (B, num_tokens, dout)
        keys = self.W_key(x)
        values = self.W_value(x)

        # lets transform all queries, keys and values to (B, num_tokens, num_heads, head_dim)
        queries = queries.view(B, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(B, num_tokens, self.num_heads, self.head_dim)
        values = values.view(B, num_tokens, self.num_heads, self.head_dim)

        # Lets swap the dimension 1 and 2 for queries, keys and values, so we can do compute attention scores over the head dimension
        # Hence it transforms into (B, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # (B, num_heads, num_tokens, head_dim) @ (B, num_heads, head_dim, num_tokens) -> (B, num_heads, num_tokens, num_tokens)
        attention_scores = queries @ keys.transpose(2, 3)

        # Lets prevent the model from seeing future tokens
        # We then update attention scores inplace using masked_fill_
        # We also scale the softmax along the head dimension
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores / self.head_dim**0.5, dim=-1)

        # Let do some drop out on attention weights to avoid over-fitting
        attention_weights = self.dropout(attention_weights)

        # Let's now compute context vectors
        context_vectors = attention_weights @ values # (B, H, T, head_dim)

        # Reshape to have dout dimension
        # (B, H, T, head_dim) -> (B, T, H, head_dim) and make it contiguous in memory 
        context_vectors = context_vectors.transpose(1, 2).contiguous().view(B, num_tokens, self.dout)

        # Output -> (B, num_tokens, dout)
        context_vectors = self.out_proj(context_vectors)
        return context_vectors
        

In [62]:
torch.manual_seed(123)

context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttention(
    d_in, d_out, context_length, 0.0, num_heads=2
)

context_vecs = mha(batch)
context_vecs

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)