# 3. Coding Attention Mechanisms

In [1]:
import torch
import torch.nn as nn

## 3.3 Attending to different parts of the input with self-attention
### 3.3.1 A simple self-attention mechanism without trainable weights

In [2]:
input_embeddings = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

query = input_embeddings[1]

# Compute the intermediate attention scores 
# between the query token and each input token.
attention_scores_2 = torch.empty(input_embeddings.shape[0])
for i, embedding in enumerate(input_embeddings):
    attention_scores_2[i] = torch.dot(embedding, query)

attention_weights_2 = torch.softmax(attention_scores_2, dim=0)
print("Attention weights:", attention_weights_2)
print("Sum:", attention_weights_2.sum())

# Compute the context vector 
# as the attention-weighted sum of the input embeddings.
context_vector_2 = torch.zeros(query.shape)
for i, embedding in enumerate(input_embeddings):
    context_vector_2 += attention_weights_2[i] * embedding
print("Context vector:", context_vector_2)

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
Context vector: tensor([0.4419, 0.6515, 0.5683])


### 3.3.2 Computing attention weights for all input tokens

In [3]:
attention_scores = input_embeddings @ input_embeddings.T
attention_weights = torch.softmax(attention_scores, dim=-1)
print(attention_weights[1])

context_vectors = attention_weights @ input_embeddings
print(context_vectors[1])

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor([0.4419, 0.6515, 0.5683])


## 3.4 Implementing self-attention with trainable weights

### 3.4.1 Implementing a compact self-attention Python class 

In [4]:
context_length, embedding_dim = input_embeddings.shape
context_dim = 2

In [5]:
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.randn(embedding_dim, context_dim), requires_grad=False)
W_key = torch.nn.Parameter(torch.randn(embedding_dim, context_dim), requires_grad=False)
W_value = torch.nn.Parameter(torch.randn(embedding_dim, context_dim), requires_grad=False)

keys = input_embeddings @ W_key
values = input_embeddings @ W_value

print("Keys shape:", keys.shape)
print("Values shape:", values.shape)

Keys shape: torch.Size([6, 2])
Values shape: torch.Size([6, 2])


In [6]:
# compute the attention score for the second token
input_embeddings_2 = input_embeddings[1]

query_2 = input_embeddings_2 @ W_query
key_2 = input_embeddings_2 @ W_key
value_2 = input_embeddings_2 @ W_value
print("Query:", query_2)

# compute attention score between query_2 and key_2
attention_score_22 = query_2.dot(key_2)
print(attention_score_22)

# compute attention scores between query_2 and all keys
attention_scores_2 = query_2 @ keys.T
print(attention_scores_2)

# standardize the attention scores
key_dim = keys.shape[-1]  # same as output_dim
attention_weights_2 = torch.softmax(attention_scores_2 / key_dim**0.5, dim=-1)
print(attention_weights_2)

# compute the context vector
context_vector_2 = attention_weights_2 @ values
print(context_vector_2)

Query: tensor([-1.1729, -0.0048])
tensor(0.1376)
tensor([ 0.2172,  0.1376,  0.1730, -0.0491,  0.7616, -0.3809])
tensor([0.1704, 0.1611, 0.1652, 0.1412, 0.2505, 0.1117])
tensor([0.2854, 0.4081])


### 3.4.2 Implementing a compact self-attention Python class

In [7]:
class SelfAttention_v2(nn.Module):

    def __init__(self, embedding_dim, context_dim, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(embedding_dim, context_dim, bias=qkv_bias)
        self.W_key = nn.Linear(embedding_dim, context_dim, bias=qkv_bias)
        self.W_value = nn.Linear(embedding_dim, context_dim, bias=qkv_bias)
    
    def forward(self, emdeddings):
        keys = self.W_key(emdeddings)
        queries = self.W_query(emdeddings)
        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1]**0.5, 
            dim=-1
        )
        values = self.W_value(emdeddings)
        context_vectors = attention_weights @ values
        return context_vectors

torch.manual_seed(123)
self_attention = SelfAttention_v2(embedding_dim, context_dim)
# context vectors has different values due to different initial weights
print(self_attention(input_embeddings))

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


## 3.5 Hiding future words with causal attention
### 3.5.1 Applying a causal attention mask

In [8]:
queries = self_attention.W_query(input_embeddings)
keys = self_attention.W_key(input_embeddings)
attention_scores = queries @ keys.T
attention_weights = torch.softmax(
    attention_scores / keys.shape[-1]**0.5, 
    dim=-1
)
print("Attention weights:\n", attention_weights)

Attention weights:
 tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


In [9]:
context_length = attention_scores.shape[0]
mask_0 = torch.tril(torch.ones((context_length, context_length)))
print("Causal mask:\n", mask_0)

Causal mask:
 tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [10]:
masked_attention_weights = attention_weights * mask_0
print("Masked attention weights (unnormalized):\n", masked_attention_weights)

Masked attention weights (unnormalized):
 tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)


In [11]:
attention_weight_sums = masked_attention_weights.sum(dim=-1, keepdim=True)
normalized_masked_attention_weights = masked_attention_weights / attention_weight_sums
print("Masked attention weights (normalized):\n", normalized_masked_attention_weights)

Masked attention weights (normalized):
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)


In [12]:
mask_inf = torch.triu(torch.ones(context_length, context_length), diagonal=1) 
masked_attention_scores = attention_scores.masked_fill(mask_inf.bool(), -torch.inf) 
print("Masked attention scores:\n", masked_attention_scores)

masked_attention_weights = torch.softmax(masked_attention_scores / keys.shape[-1]**0.5, dim=1) 
print("Attention weights:\n", masked_attention_weights)

Masked attention scores:
 tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)
Attention weights:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


### 3.5.2 Masking additional attention weights with dropout

In [None]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [None]:
torch.manual_seed(123)
print(dropout(masked_attention_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6380, 0.6816, 0.6804, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5090, 0.5085, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4120, 0.0000, 0.3869, 0.0000, 0.0000],
        [0.0000, 0.3418, 0.3413, 0.3308, 0.3249, 0.0000]],
       grad_fn=<MulBackward0>)


### 3.5.3 Implementing a compact causal attention class

In [15]:
class CausalAttention(nn.Module):

    def __init__(
        self, 
        embedding_dim, 
        context_dim, 
        context_length, 
        dropout_rate, 
        qkv_bias=False
    ):
        super().__init__()
        self.context_dim = context_dim
        self.W_query = nn.Linear(embedding_dim, context_dim, bias=qkv_bias)
        self.W_key = nn.Linear(embedding_dim, context_dim, bias=qkv_bias)
        self.W_value = nn.Linear(embedding_dim, context_dim, bias=qkv_bias)
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length, context_length), 
                diagonal=1
            )
        )
        self.dropout = nn.Dropout(dropout_rate)
    
    def forward(self, embeddings):
        n_batches, n_tokens, embedding_dim = embeddings.shape

        keys = self.W_key(embeddings)
        queries = self.W_query(embeddings)
        attention_scores = queries @ keys.transpose(1, 2)
        attention_scores.masked_fill_(
            self.mask.bool()[
                :n_tokens, 
                :n_tokens
            ],
            -torch.inf
        )
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1]**0.5, 
            dim=-1 
        )
        attention_weights = self.dropout(attention_weights)

        context_vectors = attention_weights @ values
        return context_vectors


batch = torch.stack((input_embeddings, input_embeddings), dim=0)

torch.manual_seed(123)
causal_attention = CausalAttention(embedding_dim, context_dim, context_length, dropout_rate=0.0)

context_vectors = causal_attention(batch)
print("context_vectors.shape:", context_vectors.shape)

context_vectors.shape: torch.Size([2, 6, 2])


## 3.6 Extending single-head attention to multi-head attention
### 3.6.1 Stacking multiple single-head attention layers

In [16]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(
        self, 
        n_heads,
        embedding_dim, 
        context_dim, 
        context_length, 
        dropout_rate, 
        qkv_bias=False
    ):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(
                    embedding_dim, 
                    context_dim, 
                    context_length, 
                    dropout_rate, 
                    qkv_bias
                )
                for _ in range(n_heads)
            ]
        )
    
    def forward(self, embeddings):
        return torch.cat([head(embeddings) for head in self.heads], dim=-1)


torch.manual_seed(123)

multihead_attention_wrapper = MultiHeadAttentionWrapper(
    2, embedding_dim, context_dim, context_length, 0.0
)
context_vectors = multihead_attention_wrapper(batch)

print(context_vectors)
print("context_vectors.shape:", context_vectors.shape)

tensor([[[ 0.1196, -0.3566,  0.1196, -0.3566],
         [ 0.2700,  0.1519,  0.2826,  0.1944],
         [ 0.3173,  0.3174,  0.3278,  0.3531],
         [ 0.2979,  0.3382,  0.3014,  0.3538],
         [ 0.2916,  0.4008,  0.2918,  0.4088],
         [ 0.2887,  0.3903,  0.2920,  0.4108]],

        [[ 0.1196, -0.3566,  0.1196, -0.3566],
         [ 0.2700,  0.1519,  0.2826,  0.1944],
         [ 0.3173,  0.3174,  0.3278,  0.3531],
         [ 0.2979,  0.3382,  0.3014,  0.3538],
         [ 0.2916,  0.4008,  0.2918,  0.4088],
         [ 0.2887,  0.3903,  0.2920,  0.4108]]], grad_fn=<CatBackward0>)
context_vectors.shape: torch.Size([2, 6, 4])


### 3.6.2 Implementing multi-head attention with weight splits

In [17]:
class MultiHeadAttention(nn.Module):

    def __init__(
        self, 
        n_heads, 
        embedding_dim, 
        context_dim, 
        context_length, 
        dropout_rate, 
        qkv_bias=False
    ):
        super().__init__()
        assert (
            context_dim % n_heads == 0
        ), "context dimension must be divisible by number of heads"


        self.context_dim = context_dim
        self.n_heads = n_heads
        self.head_dim = context_dim // n_heads

        self.W_query = nn.Linear(embedding_dim, context_dim, bias=qkv_bias)
        self.W_key = nn.Linear(embedding_dim, context_dim, bias=qkv_bias)
        self.W_value = nn.Linear(embedding_dim, context_dim, bias=qkv_bias)
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length, context_length), 
                diagonal=1
            )
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.out_projection = nn.Linear(context_dim, context_dim)
    
    def forward(self, embeddings):
        n_batches, context_length, embedding_dim = embeddings.shape

        keys = self.W_key(embeddings)
        queries = self.W_query(embeddings)
        values = self.W_value(embeddings)

        keys = keys.view(n_batches, context_length, self.n_heads, self.head_dim)
        queries = queries.view(n_batches, context_length, self.n_heads, self.head_dim)
        values = values.view(n_batches, context_length, self.n_heads, self.head_dim)

        # transpose to (n_batches, n_heads, n_tokens, head_dim) 
        keys = keys.transpose(1, 2)   
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attention_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:context_length, :context_length]
        attention_scores.masked_fill_(mask_bool, -torch.inf)

        attention_weights = torch.softmax(attention_scores / (self.head_dim**0.5), dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vectors = attention_weights @ values
        context_vectors = context_vectors.transpose(1, 2).contiguous()
        context_vectors = context_vectors.view(n_batches, context_length, self.context_dim)
        
        context_vectors = self.out_projection(context_vectors)
        return context_vectors
    

torch.manual_seed(123)

multihead_attention = MultiHeadAttention(2, embedding_dim, context_dim, context_length, 0.0)
context_vectors = multihead_attention(batch)
print(context_vectors)
print("context_vectors.shape:", context_vectors.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vectors.shape: torch.Size([2, 6, 2])
