**Chapter 3: Coding Attention Mechanisms**

3.3 Attending to different parts of the input with self-attention

3.3.1 A simple self-attention mechanism without trainable weights

In [None]:
import torch

# Toy 3-dimensional embeddings, one row per token of the example sentence.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1 -> Your
    [0.55, 0.87, 0.66],  # x^2 -> journey
    [0.57, 0.85, 0.64],  # x^3 -> starts
    [0.22, 0.58, 0.33],  # x^4 -> with
    [0.77, 0.25, 0.10],  # x^5 -> one
    [0.05, 0.80, 0.55],  # x^6 -> step
])

In [None]:
# Square placeholder with one row and one column per input token.
attn_scores = torch.zeros((inputs.size(0), inputs.size(0)))
print(attn_scores)

tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])


In [None]:
# Fill the score matrix entry by entry: position (i, j) holds the dot product
# of input row i with input row j.
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [None]:
# All pairwise row dot products in a single matrix multiplication.
attn_scores = torch.matmul(inputs, inputs.T)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [None]:
# Normalise along the last dimension so that every row becomes a distribution.
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [None]:
# Sanity check: the weights of the first row should add up to 1.
print(torch.sum(attn_weights[0]))

tensor(1.0000)


In [None]:
# Each context vector is the attention-weighted sum of all input rows.
context_vector = torch.matmul(attn_weights, inputs)
print(context_vector)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


**3.4 Implementing self-attention with trainable weights**

In [None]:
# Input width comes from the embeddings; the projection width is fixed at 2.
d_in, d_out = inputs.shape[1], 2

torch.manual_seed(123)

# Three (d_in, d_out) trainable matrices, drawn in the order key, query, value
# so that the seeded random stream is consumed exactly as before.
W_key, W_query, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out)) for _ in range(3)
)



In [None]:
# Project every input row into key, query and value space.
keys = torch.matmul(inputs, W_key)
queries = torch.matmul(inputs, W_query)
values = torch.matmul(inputs, W_value)


In [None]:
# Raw scores: every query against every key.
attn_scores = torch.matmul(queries, keys.T)

# Width of a key vector, used for the square-root scaling below.
d_k = keys.size(1)

# Scale by sqrt(d_k) before normalising each row with softmax.
attn_weights = (attn_scores / d_k ** 0.5).softmax(dim=-1)



In [None]:
# Attention-weighted sum of the value vectors.
context_vector = torch.matmul(attn_weights, values)
print(context_vector)

tensor([[0.2947, 0.7956],
        [0.3015, 0.8132],
        [0.3010, 0.8120],
        [0.2925, 0.7902],
        [0.2863, 0.7737],
        [0.2979, 0.8043]], grad_fn=<MmBackward0>)


**3.4.2 Implementing a compact SelfAttention Class**

In [None]:
import torch.nn as nn

class SelfAttention_V1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The key, query and value projections are stored as ``nn.Parameter``
    matrices of shape ``(d_in, d_out)`` initialised with ``torch.rand``.
    """

    def __init__(self, d_in, d_out):
        """Create the three trainable projection matrices.

        Args:
            d_in: Width of each input row.
            d_out: Width of each projected key/query/value row.
        """
        super().__init__()
        # Creation order (key, query, value) decides which seeded random
        # numbers each matrix receives, so it must not be changed.
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute context vectors for ``x``.

        Args:
            x: 2-D tensor whose rows are the inputs; its last dimension must
                be ``d_in`` (``.T`` below requires a 2-D key matrix).

        Returns:
            Tensor with one ``d_out``-wide context vector per row of ``x``.
        """
        # Use the argument and this module's own parameters; the previous
        # version read the notebook globals `inputs` and `W_*` instead, so the
        # module ignored both `x` and its trainable weights.
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T
        # Scale by sqrt(key width) before the row-wise softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vector = attn_weights @ values

        return context_vector

# Re-seed so the randomly initialised weights are reproducible, then build the
# module and evaluate it on the example inputs.
torch.manual_seed(123)
sa_v1 = SelfAttention_V1(d_in, d_out)
sa_v1(inputs)

tensor([[0.2947, 0.7956],
        [0.3015, 0.8132],
        [0.3010, 0.8120],
        [0.2925, 0.7902],
        [0.2863, 0.7737],
        [0.2979, 0.8043]], grad_fn=<MmBackward0>)

In [None]:
import torch.nn as nn

class SelfAttention_V2(nn.Module):
    """Single-head scaled dot-product self-attention built from ``nn.Linear`` layers."""

    def __init__(self, d_in, d_out, qkv_bias=False):
        """Create the key, query and value projection layers.

        Args:
            d_in: Width of each input row.
            d_out: Width of each projected key/query/value row.
            qkv_bias: Whether the three linear layers carry a bias term.
        """
        super().__init__()
        # Layers are created in the order key, query, value so that a fixed
        # seed yields the same initial weights as before.
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return one context vector per row of the 2-D input ``x``."""
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)

        # Query/key similarities, scaled by sqrt(key width) and normalised per row.
        scale = keys.shape[-1] ** 0.5
        attn_weights = torch.softmax(torch.matmul(queries, keys.T) / scale, dim=-1)

        return torch.matmul(attn_weights, values)

# Re-seed so the linear layers get reproducible initial weights, then build the
# module and evaluate it on the example inputs.
torch.manual_seed(123)
sa_v2 = SelfAttention_V2(d_in, d_out)
sa_v2(inputs)

tensor([[-0.5300, -0.0988],
        [-0.5317, -0.1005],
        [-0.5317, -0.1005],
        [-0.5301, -0.1040],
        [-0.5298, -0.1011],
        [-0.5307, -0.1042]], grad_fn=<MmBackward0>)

**3.5 Hiding future words with causal attention**

In [None]:
# Recompute the key, query and value projections from the raw weight matrices.
keys = torch.matmul(inputs, W_key)
queries = torch.matmul(inputs, W_query)
values = torch.matmul(inputs, W_value)



In [None]:
# Unmasked scores: every query against every key.
attn_scores = torch.matmul(queries, keys.T)
print(attn_scores)

tensor([[0.9231, 1.2705, 1.2544, 0.6973, 0.6114, 0.8995],
        [1.3545, 1.8524, 1.8284, 1.0167, 0.8819, 1.3165],
        [1.3241, 1.8111, 1.7877, 0.9941, 0.8626, 1.2871],
        [0.7910, 1.0795, 1.0654, 0.5925, 0.5121, 0.7682],
        [0.4032, 0.5577, 0.5508, 0.3061, 0.2707, 0.3937],
        [1.1330, 1.5440, 1.5238, 0.8475, 0.7307, 1.0996]],
       grad_fn=<MmBackward0>)


In [None]:
context_length = len(inputs)
# Ones strictly above the diagonal mark the positions to hide.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
# Hidden positions become -inf so that softmax later assigns them zero weight.
masked = attn_scores.masked_fill(mask.bool(), float("-inf"))
print(masked)

tensor([[0.9231,   -inf,   -inf,   -inf,   -inf,   -inf],
        [1.3545, 1.8524,   -inf,   -inf,   -inf,   -inf],
        [1.3241, 1.8111, 1.7877,   -inf,   -inf,   -inf],
        [0.7910, 1.0795, 1.0654, 0.5925,   -inf,   -inf],
        [0.4032, 0.5577, 0.5508, 0.3061, 0.2707,   -inf],
        [1.1330, 1.5440, 1.5238, 0.8475, 0.7307, 1.0996]],
       grad_fn=<MaskedFillBackward0>)


In [None]:
# Scale by sqrt(key width), then normalise each row; -inf entries end up as 0.
attn_weights = (masked / keys.size(-1) ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4129, 0.5871, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2632, 0.3714, 0.3653, 0.0000, 0.0000, 0.0000],
        [0.2320, 0.2846, 0.2817, 0.2017, 0.0000, 0.0000],
        [0.1973, 0.2200, 0.2189, 0.1842, 0.1796, 0.0000],
        [0.1612, 0.2156, 0.2126, 0.1318, 0.1213, 0.1575]],
       grad_fn=<SoftmaxBackward0>)
tensor(1., grad_fn=<SumBackward0>)


In [None]:
torch.manual_seed(123)

# Drop each attention weight with probability 0.5.
dropout = torch.nn.Dropout(p=0.5)

print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.1742, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.7307, 0.0000, 0.0000, 0.0000],
        [0.4641, 0.5691, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3945, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4312, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


**3.5.3: Implementing a compact causal self-attention class**

In [None]:
# Duplicate the example sequence along a new leading dimension to form a batch of two.
batch = torch.stack([inputs, inputs])
print(batch)
print(batch.shape)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])
torch.Size([2, 6, 3])


In [None]:
class CausalAttention(nn.Module):
    """Single-head causal self-attention over a batch, with dropout on the attention weights."""

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        """Build the projection layers, the dropout layer and the causal mask.

        Args:
            d_in: Width of each input token vector.
            d_out: Width of each projected query/key/value vector.
            context_length: Side length of the square mask buffer.
            dropout: Drop probability passed to ``nn.Dropout``.
            qkv_bias: Whether the three linear layers carry a bias term.
        """
        super().__init__()
        self.d_out = d_out
        # Layers are created in the order query, key, value so that a fixed
        # seed yields the same initial weights as before.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the positions to hide.
        # Registered as a buffer rather than a parameter.
        upper_triangle = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer('mask', upper_triangle)

    def forward(self, x):
        """Return causal context vectors for a 3-D input ``x``.

        The three-way unpack of ``x.shape`` requires exactly three dimensions;
        the second one is the number of tokens and the last one is fed to the
        linear layers.
        """
        # If `num_tokens` exceeds `context_length`, the sliced mask is too small
        # and the masking step below raises an error. In practice this is not a
        # problem since the LLM (chapters 4-7) ensures that inputs do not exceed
        # `context_length` before reaching this forward method.
        _, num_tokens, _ = x.shape

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Slice the mask for batches shorter than the supported context size.
        hidden = self.mask[:num_tokens, :num_tokens].bool()

        attn_scores = torch.matmul(queries, keys.transpose(1, 2))
        attn_scores.masked_fill_(hidden, -torch.inf)  # in-place

        scale = keys.shape[-1] ** 0.5
        attn_weights = self.dropout(torch.softmax(attn_scores / scale, dim=-1))

        return torch.matmul(attn_weights, values)

# Re-seed so the randomly initialised linear layers are reproducible.
torch.manual_seed(123)

# The second dimension of the batch is used as the mask size.
context_length = batch.shape[1]
# Dropout probability is 0.0 here.
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


3.6 Extending single-head attention to multi-head attention

**3.6.2 Implementing multi-head attention with weight splits**

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention that splits one projection across heads."""

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        """Build the projection layers, the output layer, dropout and the causal mask.

        Args:
            d_in: Width of each input token vector.
            d_out: Total output width; must be divisible by ``num_heads``.
            context_length: Side length of the square mask buffer.
            dropout: Drop probability passed to ``nn.Dropout``.
            num_heads: Number of attention heads.
            qkv_bias: Whether the query/key/value layers carry a bias term.
        """
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Each head works on an equal slice of the d_out-wide projection.
        self.head_dim = d_out // num_heads

        # Creation order (query, key, value, out_proj) is kept so that a fixed
        # seed yields the same initial weights as before.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # combines the head outputs
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the positions to hide.
        upper_triangle = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", upper_triangle)

    def _split_heads(self, projected, b, num_tokens):
        """Reshape (b, num_tokens, d_out) to (b, num_heads, num_tokens, head_dim)."""
        per_head = projected.view(b, num_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Return causal multi-head context vectors of shape (b, num_tokens, d_out)."""
        # As in `CausalAttention`, if `num_tokens` exceeds `context_length` the
        # sliced mask is too small and the masking step below raises an error.
        # In practice this is not a problem since the LLM (chapters 4-7) ensures
        # that inputs do not exceed `context_length` before reaching this
        # forward method.
        b, num_tokens, _ = x.shape

        # Project once to d_out, then view the last dimension as separate heads.
        queries = self._split_heads(self.W_query(x), b, num_tokens)
        keys = self._split_heads(self.W_key(x), b, num_tokens)
        values = self._split_heads(self.W_value(x), b, num_tokens)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens).
        attn_scores = torch.matmul(queries, keys.transpose(2, 3))

        # Slice the mask to the current sequence length and hide masked positions in place.
        hidden = self.mask[:num_tokens, :num_tokens].bool()
        attn_scores.masked_fill_(hidden, -torch.inf)

        # After the split, the last key dimension equals head_dim.
        scale = self.head_dim ** 0.5
        attn_weights = self.dropout(torch.softmax(attn_scores / scale, dim=-1))

        # Back to (b, num_tokens, num_heads, head_dim), then merge the heads:
        # self.d_out = self.num_heads * self.head_dim.
        per_head_context = torch.matmul(attn_weights, values).transpose(1, 2)
        merged = per_head_context.contiguous().view(b, num_tokens, self.d_out)

        return self.out_proj(merged)  # optional projection

# Re-seed so the randomly initialised linear layers are reproducible.
torch.manual_seed(123)

# Take batch size, sequence length and input width straight from the batch.
batch_size, context_length, d_in = batch.shape
d_out = 2
# d_out = 2 split over 2 heads gives each head a width of 1; dropout is 0.0.
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)


tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
