# Attention Mechanism
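As its name suggests, attention is a mechanism that lets a model focus on the most relevant parts of the input sequence while processing each token. Concretely, each input embedding is projected into a query, a key and a value vector. The output for a token is a weighted sum of the value vectors, with weights $\mathrm{softmax}\!\left(QK^\top / \sqrt{d_k}\right)$.
Let's build a simple attention mechanism from scratch, starting with self-attention.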


In [3]:
import torch
# Toy input: the six-token sentence "Your journey starts with one step",
# with each token represented by a 3-dimensional embedding (one row per token).
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

# Attention scores are dot products between query and key vectors. The output
# for each token is the attention-weighted sum of the value vectors.
# In self-attention the queries, keys and values are all computed from these
# same inputs, but each one goes through its OWN linear projection. They are
# therefore generally different vectors, not the same one.

In [4]:
# Pick out one token and fix the projection sizes used by the layers below.
x_2 = inputs[1]          # embedding of the second token ("journey")
d_in = inputs.shape[-1]  # size of each input embedding (3 here)
d_out = 2                # size of the projected query/key/value vectors

## Self-Attention


In [5]:
import torch.nn as nn

class SelfAttention(nn.Module):
    """Scaled dot-product self-attention without any masking.

    Every token may attend to every token in the sequence.

    Args:
        d_in: Size of each input embedding (last dimension of ``x``).
        d_out: Size of the query, key, value and output vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Separate, bias-free projections for queries, keys and values.
        self.query = nn.Linear(d_in, d_out, bias=False)
        self.key = nn.Linear(d_in, d_out, bias=False)
        self.value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        """Return the attention-weighted values for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or
                ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(num_tokens, d_out)`` or
            ``(batch, num_tokens, d_out)``, matching the input's rank.
        """
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        # Swap only the last two dims. `.T` reverses *all* dims and is
        # deprecated for tensors with ndim != 2, so batched input would break.
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1))
        # Scale by sqrt(d_k), then normalise over the key axis (always the
        # last one, whether or not a batch dimension is present).
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        output = torch.matmul(attention_weights, values)
        return output

# Seed the RNG first so the randomly initialised projection weights (and
# hence the printed result) are reproducible.
torch.manual_seed(42)
self_attention_layer = SelfAttention(d_in=d_in, d_out=d_out)
attention_output = self_attention_layer(inputs)
print(attention_output)

tensor([[0.3755, 0.2777],
        [0.3761, 0.2831],
        [0.3761, 0.2833],
        [0.3768, 0.2763],
        [0.3754, 0.2836],
        [0.3772, 0.2746]], grad_fn=<MmBackward0>)


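The problem with the implementation above is that every token can attend to every token in the input sequence, including the ones that come after it. In generation tasks the model predicts the next token from the ones before it, so each token must only attend to itself and the previous tokens.
To achieve this, we use a causal mask that hides future positions before the softmax. Let's extend the SelfAttention class with a causal mask; we'll also add dropout on the attention weights and accept a batch of sequences.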

## Causal Self-Attention (Masked Attention)

In [8]:
class CausalSelfAttention(nn.Module):
    """Single-head scaled dot-product attention with a causal mask.

    Token ``i`` can only attend to tokens ``0..i``. Later positions are
    masked out before the softmax.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of the query, key, value and output vectors.
        context_length: Maximum sequence length the causal mask supports.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_in, d_out, context_length, dropout=0.1):
        super().__init__()
        self.query = nn.Linear(d_in, d_out, bias=False)
        self.key = nn.Linear(d_in, d_out, bias=False)
        self.value = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark "future" positions. A buffer
        # (not a Parameter) so it is not trained but moves with the module.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)`` or
                ``(num_tokens, d_in)``, with ``num_tokens <= context_length``.

        Returns:
            Tensor with the same leading dims as ``x`` and last dim ``d_out``.
        """
        # The sequence axis is second-to-last for both batched and unbatched
        # input, so index from the end instead of unpacking exactly 3 dims.
        num_tokens = x.shape[-2]
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        attn_scores = torch.matmul(queries, keys.transpose(-2, -1))
        # Slice the mask to the actual sequence length (which may be shorter
        # than context_length). Set future positions to -inf in place so they
        # get zero weight after the softmax.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, values)
        return output

# Build a batch from two copies of the same sequence: shape (2, 6, 3).
batch = torch.stack((inputs, inputs), dim=0)
context_length = batch.size(1)

# Seed just before constructing the module. Both the weight initialisation and
# the dropout mask drawn inside forward() then come from a fixed RNG state.
torch.manual_seed(42)
causal_attention_layer = CausalSelfAttention(
    d_in, d_out, context_length=context_length
)
attention_output = causal_attention_layer(batch)
print(attention_output)

tensor([[[0.4921, 0.1196],
         [0.5174, 0.2886],
         [0.5257, 0.3366],
         [0.4595, 0.3246],
         [0.4531, 0.2852],
         [0.2807, 0.1521]],

        [[0.4921, 0.1196],
         [0.5174, 0.2886],
         [0.5257, 0.3366],
         [0.4595, 0.3246],
         [0.3358, 0.1922],
         [0.4191, 0.3051]]], grad_fn=<UnsafeViewBackward0>)


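Notice that the two copies of the same sequence in the batch produce different outputs in some rows. This is because dropout (p=0.1) is active while the module is in training mode, so some attention weights are randomly zeroed.

Now we have built what is called a single-head attention mechanism. We can do better by using multiple attention heads. Each head has its own slice of the query, key and value projections, which lets the model attend to different parts of the input sequence in parallel.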

## Multi-Head Attention

In [13]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head scaled dot-product attention.

    The ``d_out``-sized query/key/value projections are split into
    ``num_heads`` heads of size ``d_out // num_heads``. Each head attends
    independently under a causal mask. The head outputs are concatenated and
    passed through a final linear projection.

    Args:
        d_in: Size of each input embedding.
        d_out: Total size of the concatenated head outputs.
        context_length: Maximum sequence length the causal mask supports.
        num_heads: Number of attention heads; must evenly divide ``d_out``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_out``.
    """

    def __init__(self, d_in, d_out, context_length, num_heads, dropout=0.1):
        super().__init__()
        # Explicit check rather than `assert`, which `python -O` strips.
        # num_heads < 1 would otherwise surface as a ZeroDivisionError.
        if num_heads < 1 or d_out % num_heads != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_out = d_out
        self.head_dim = d_out // num_heads
        self.query = nn.Linear(d_in, d_out, bias=False)
        self.key = nn.Linear(d_in, d_out, bias=False)
        self.value = nn.Linear(d_in, d_out, bias=False)
        self.output = nn.Linear(d_out, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Apply causal multi-head attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)`` with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        b, num_tokens, _ = x.shape

        keys = self.key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.query(x)
        values = self.value(x)

        # Split the last dim into heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move heads next to batch so each head is an independent "batch":
        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Causal mask truncated to the actual sequence length, as booleans.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Future positions get -inf so they receive zero softmax weight.
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) (the keys' last dim), normalise over keys.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Back to token-major order: (b, num_tokens, num_heads, head_dim)
        output = (attn_weights @ values).transpose(1, 2)

        # Concatenate heads; self.d_out == self.num_heads * self.head_dim.
        # contiguous() is needed because view() requires contiguous memory
        # and the transpose above produced a strided view.
        output = output.contiguous().view(b, num_tokens, self.d_out)
        output = self.output(output)  # final output projection

        return output

# Fix the RNG so the weight initialisation is reproducible.
torch.manual_seed(42)

# Derive the layer sizes from the batch built earlier: (2, 6, 3).
batch_size, context_length, d_in = batch.shape
d_out = 2
# Two heads of size d_out // 2 = 1. Dropout is disabled so the output is
# deterministic.
mha = MultiHeadAttention(d_in, d_out, context_length, num_heads=2, dropout=0.0)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.1268, -0.1532],
         [-0.1070, -0.1730],
         [-0.1016, -0.1791],
         [-0.0842, -0.1591],
         [-0.0892, -0.1540],
         [-0.0760, -0.1452]],

        [[-0.1268, -0.1532],
         [-0.1070, -0.1730],
         [-0.1016, -0.1791],
         [-0.0842, -0.1591],
         [-0.0892, -0.1540],
         [-0.0760, -0.1452]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
