## **SELF ATTENTION WITH TRAINABLE WEIGHTS**

In [1]:
import torch
import torch.nn as nn

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [2]:
inputs.shape

torch.Size([6, 3])

In [3]:
x = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [4]:
torch.manual_seed(123)
W_query = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [5]:
print(f"W_query: \n{W_query}")
print(f"\nW_key: \n{W_key}")
print(f"\nW_value: \n{W_value}")

W_query: 
Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

W_key: 
Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])

W_value: 
Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


In [6]:
query = x @ W_query
key = x @ W_key
value = x @ W_value

In [7]:
# For input 2
print(f"query: \n{query}")
print(f"\nkey: \n{key}")
print(f"\nvalue: \n{value}")

query: 
tensor([0.4306, 1.4551])

key: 
tensor([0.4433, 1.1419])

value: 
tensor([0.3951, 1.0037])


## **QKV for whole Input**

In [8]:
query = inputs @ W_query
key = inputs @ W_key
value = inputs @ W_value

In [9]:
print(f"query: \n{query}")
print(f"\nkey: \n{key}")
print(f"\nvalue: \n{value}")

query: 
tensor([[0.2309, 1.0966],
        [0.4306, 1.4551],
        [0.4300, 1.4343],
        [0.2355, 0.7990],
        [0.2983, 0.6565],
        [0.2568, 1.0533]])

key: 
tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])

value: 
tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])


In [10]:
query_2 = query[1]
key_2 = key[1]

attn_scores_2 = query_2.dot(key_2)
attn_scores_2

tensor(1.8524)

In [11]:
attn_scores_2 = query_2 @ key.T
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [12]:
# Attention scores
attn_scores = query @ key.T
attn_scores

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])

In [13]:
attn_scores.shape

torch.Size([6, 6])

## **Attention Weights**

In [15]:
d_k = key.shape[1]
attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim = -1)
print(f"Attention weights: \n{attn_weights}")

Attention weights: 
tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])


In [18]:
context_vector = attn_weights @ value
print(f"Context Vectors: \n{context_vector}")

Context Vectors: 
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


## **Self Attention Complete**

In [19]:
class SelfAttention(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    self.W_Query = nn.Parameter(torch.rand(d_in, d_out))
    self.W_Value = nn.Parameter(torch.rand(d_in, d_out))
    self.W_Key = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    queries = x @ self.W_Query
    keys = x @ self.W_Key
    values = x @ self.W_Value

    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(attn_scores / keys.shape[1] ** 0.5, dim = -1)

    context_vec = attn_weights @ values
    return context_vec

In [24]:
sa = SelfAttention(d_in = 3, d_out = 2)
print(sa(inputs))

tensor([[0.7236, 0.8354],
        [0.7308, 0.8426],
        [0.7306, 0.8425],
        [0.7145, 0.8256],
        [0.7170, 0.8303],
        [0.7179, 0.8283]], grad_fn=<MmBackward0>)


In [25]:
class SelfAttention_V2(nn.Module):
  def __init__(self, d_in, d_out, qkv_bias = False):
    super().__init__()
    self.W_Query = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_Value = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_Key = nn.Linear(d_in, d_out, bias = qkv_bias)

  def forward(self, x):
    keys = self.W_Key(x)
    queries = self.W_Query(x)
    values = self.W_Value(x)

    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(attn_scores / keys.shape[1] ** 0.5, dim = -1)

    context_vec = attn_weights @ values
    return context_vec

In [26]:
sa = SelfAttention_V2(d_in = 3, d_out = 2)
print(sa(inputs))

tensor([[0.2371, 0.7407],
        [0.2305, 0.7362],
        [0.2307, 0.7364],
        [0.2286, 0.7345],
        [0.2338, 0.7396],
        [0.2269, 0.7328]], grad_fn=<MmBackward0>)


## **Hidding Future Words with Causal Attention**

In [45]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
  #  [0.4419, 0.6515, 0.5683]]
)

In [46]:
queries = sa.W_Query(inputs)
values = sa.W_Value(inputs)
keys = sa.W_Key(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
print(attn_weights)

tensor([[0.1533, 0.1751, 0.1748, 0.1654, 0.1627, 0.1687],
        [0.1660, 0.1679, 0.1679, 0.1659, 0.1660, 0.1664],
        [0.1657, 0.1682, 0.1681, 0.1658, 0.1659, 0.1664],
        [0.1693, 0.1653, 0.1654, 0.1666, 0.1673, 0.1661],
        [0.1604, 0.1727, 0.1726, 0.1642, 0.1636, 0.1665],
        [0.1722, 0.1628, 0.1629, 0.1675, 0.1685, 0.1660]],
       grad_fn=<SoftmaxBackward0>)


In [47]:
context_length = attn_weights.shape[-1]
torch.ones(context_length, context_length)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

In [48]:
context_length = attn_weights.shape[-1]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [49]:
masked_simple = mask_simple * attn_weights
print(masked_simple)

tensor([[0.1533, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1660, 0.1679, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1657, 0.1682, 0.1681, 0.0000, 0.0000, 0.0000],
        [0.1693, 0.1653, 0.1654, 0.1666, 0.0000, 0.0000],
        [0.1604, 0.1727, 0.1726, 0.1642, 0.1636, 0.0000],
        [0.1722, 0.1628, 0.1629, 0.1675, 0.1685, 0.1660]],
       grad_fn=<MulBackward0>)


In [55]:
rows_sum = masked_simple.sum(dim = 1, keepdim = True)
masked_simple_norm = masked_simple / rows_sum
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4971, 0.5029, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3300, 0.3350, 0.3349, 0.0000, 0.0000, 0.0000],
        [0.2540, 0.2480, 0.2481, 0.2500, 0.0000, 0.0000],
        [0.1924, 0.2073, 0.2070, 0.1970, 0.1963, 0.0000],
        [0.1722, 0.1628, 0.1629, 0.1675, 0.1685, 0.1660]],
       grad_fn=<DivBackward0>)


## **Using Infinity to mask**

In [63]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal = 1)
masked = attn_scores.masked_fill(mask.bool(), float("-inf"))
print(masked)

tensor([[ 0.0327,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0223,  0.0386,    -inf,    -inf,    -inf,    -inf],
        [ 0.0252,  0.0465,  0.0462,    -inf,    -inf,    -inf],
        [ 0.0006, -0.0333, -0.0329, -0.0221,    -inf,    -inf],
        [ 0.0701,  0.1752,  0.1737,  0.1030,  0.0986,    -inf],
        [-0.0260, -0.1052, -0.1042, -0.0648, -0.0571, -0.0778]],
       grad_fn=<MaskedFillBackward0>)


In [64]:
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim =1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4971, 0.5029, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3300, 0.3350, 0.3349, 0.0000, 0.0000, 0.0000],
        [0.2540, 0.2480, 0.2481, 0.2500, 0.0000, 0.0000],
        [0.1924, 0.2073, 0.2070, 0.1970, 0.1963, 0.0000],
        [0.1722, 0.1628, 0.1629, 0.1675, 0.1685, 0.1660]],
       grad_fn=<SoftmaxBackward0>)


## **Masking Additional Attention Weights with dropout**

In [71]:
torch.manual_seed(123)
dropout = nn.Dropout(0.5)
example = torch.ones(6, 6)
print("Before Dropout: \n", example)
print("\nAfter Dropout: \n", dropout(example))

Before Dropout: 
 tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

After Dropout: 
 tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [72]:
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6699, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4960, 0.0000, 0.4999, 0.0000, 0.0000],
        [0.0000, 0.4145, 0.4141, 0.3939, 0.3927, 0.0000],
        [0.3444, 0.3257, 0.0000, 0.0000, 0.3369, 0.3320]],
       grad_fn=<MulBackward0>)


## **Causal Attention Complete**

In [73]:
batch = torch.stack((inputs, inputs), dim = 0)

In [75]:
print(batch.shape)

torch.Size([2, 6, 3])


In [78]:
class CausalAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
    super().__init__()
    self.d_out = d_out
    self.W_Query = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_Value = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_Key = nn.Linear(d_in, d_out, bias = qkv_bias)

    self.dropout = nn.Dropout(dropout)
    self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal = 1))

  def forward(self, x):
    b, num_tokens, d_in = x.shape
    keys = self.W_Key(x)
    queries = self.W_Query(x)
    values = self.W_Value(x)

    attn_scores = queries @ keys.transpose(1, 2)
    attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
    attn_weights = self.dropout(attn_weights)

    context_vec = attn_weights @ values
    return context_vec

In [87]:
from os import cpu_count
context_length = batch.shape[1]
ca = CausalAttention(d_in = 3, d_out = 2, context_length=context_length, dropout=0.0)
context_vec = ca(batch)

In [88]:
context_vec

tensor([[[-0.2390,  0.1580],
         [-0.2445,  0.1625],
         [-0.2442,  0.1622],
         [-0.2392,  0.1583],
         [-0.2351,  0.1549],
         [-0.2425,  0.1610]],

        [[-0.2390,  0.1580],
         [-0.2445,  0.1625],
         [-0.2442,  0.1622],
         [-0.2392,  0.1583],
         [-0.2351,  0.1549],
         [-0.2425,  0.1610]]], grad_fn=<UnsafeViewBackward0>)

In [89]:
context_vec.shape

torch.Size([2, 6, 2])

## **Multi-Head Attention**

In [90]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
    super().__init__()
    self.heads = [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]

  def forward(self, x):
    return torch.cat([head(x) for head in self.heads], dim = -1)

In [93]:
context_length = batch.shape[1]
d_in, d_out = 3, 2
mha = MultiHeadAttention(d_in, d_out, context_length, dropout = 0.0, num_heads = 2)
context_vecs = mha(batch)

In [94]:
context_vecs

tensor([[[-0.3298,  0.2368,  0.0692,  0.1881],
         [-0.3278,  0.2369,  0.0741,  0.1855],
         [-0.3280,  0.2368,  0.0737,  0.1857],
         [-0.3292,  0.2367,  0.0708,  0.1875],
         [-0.3324,  0.2352,  0.0636,  0.1906],
         [-0.3274,  0.2375,  0.0751,  0.1855]],

        [[-0.3298,  0.2368,  0.0692,  0.1881],
         [-0.3278,  0.2369,  0.0741,  0.1855],
         [-0.3280,  0.2368,  0.0737,  0.1857],
         [-0.3292,  0.2367,  0.0708,  0.1875],
         [-0.3324,  0.2352,  0.0636,  0.1906],
         [-0.3274,  0.2375,  0.0751,  0.1855]]], grad_fn=<CatBackward0>)

In [95]:
context_vecs.shape

torch.Size([2, 6, 4])

## **Multi-Head Attention with Weight Splits**

In [98]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
    super().__init__()
    assert (d_out % num_heads) == 0, "d_out must be divisible to num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads

    self.W_Query = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_Value = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_Key = nn.Linear(d_in, d_out, bias = qkv_bias)

    self.out_proj = nn.Linear(d_out, d_out) # Linear Layer to combine outputs
    self.dropout = nn.Dropout(dropout)
    self.register_buffer(
        'mask',
        torch.triu(torch.ones(context_length, context_length), diagonal = 1)
    )

  def forward(self, x):
    b, num_tokens, d_in = x.shape

    keys = self.W_Key(x)
    values = self.W_Value(x)
    queries = self.W_Query(x)

    # We implicitly split the matrix by adding 'num_heads' dimension
    # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

    # Transpose (b, num_heads, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
    queries = queries.transpose(1, 2)

    # Scaled dot product attention (Self-Attention) with causal mask
    attn_scores = queries @ keys.transpose(2, 3) # Dot product of each head

    # Original mask truncated to the number of tokens and converted to boolean
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

    # Use the mask to fill the attention scores
    attn_scores.masked_fill_(mask_bool, -torch.inf)

    attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim = -1)
    attn_weights = self.dropout(attn_weights)

    # Shape: (b, num_tokens, num_heads, head_dim)
    context_vec = (attn_weights @ values).transpose(1, 2)

    # Combine heads, where self.d_out = self.num_heads * self.head_dim
    context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
    context_vec = self.out_proj(context_vec)

    return context_vec

In [101]:
torch.manual_seed(123)
inputs = torch.rand(3, 6)

In [102]:
inputs

tensor([[0.2961, 0.5166, 0.2517, 0.6886, 0.0740, 0.8665],
        [0.1366, 0.1025, 0.1841, 0.7264, 0.3153, 0.6871],
        [0.0756, 0.1966, 0.3164, 0.4017, 0.1186, 0.8274]])

In [103]:
batch = torch.stack((inputs, inputs), dim = 0)

In [104]:
batch

tensor([[[0.2961, 0.5166, 0.2517, 0.6886, 0.0740, 0.8665],
         [0.1366, 0.1025, 0.1841, 0.7264, 0.3153, 0.6871],
         [0.0756, 0.1966, 0.3164, 0.4017, 0.1186, 0.8274]],

        [[0.2961, 0.5166, 0.2517, 0.6886, 0.0740, 0.8665],
         [0.1366, 0.1025, 0.1841, 0.7264, 0.3153, 0.6871],
         [0.0756, 0.1966, 0.3164, 0.4017, 0.1186, 0.8274]]])

In [105]:
print(batch.shape)

torch.Size([2, 3, 6])


In [106]:
batch_size, context_length, d_in = batch.shape

In [107]:
print(f"batch_size: {batch_size}")
print(f"context_length: {context_length}")
print(f"d_in: {d_in}")

batch_size: 2
context_length: 3
d_in: 6


In [108]:
d_out = 6
mha = MultiHeadAttention(d_in, d_out, context_length, dropout = 0.0, num_heads = 2)
context_vecs = mha(batch)

In [110]:
print(context_vecs)
print(f"\nContext Vector shape: {context_vecs.shape}")

tensor([[[-0.5220,  0.5167,  0.4977,  0.1516,  0.0896,  0.3226],
         [-0.4437,  0.4934,  0.5212,  0.1573,  0.1259,  0.3273],
         [-0.4483,  0.4875,  0.5042,  0.1744,  0.1281,  0.3392]],

        [[-0.5220,  0.5167,  0.4977,  0.1516,  0.0896,  0.3226],
         [-0.4437,  0.4934,  0.5212,  0.1573,  0.1259,  0.3273],
         [-0.4483,  0.4875,  0.5042,  0.1744,  0.1281,  0.3392]]],
       grad_fn=<ViewBackward0>)

Context Vector shape: torch.Size([2, 3, 6])
