# Self Attention Mechanism

### 1. Simplified Self attention mechanism

---



In [1]:
# dummy embeddings
import torch
inputs = torch.tensor([
[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66],  # journey (x^2)
[0.57, 0.85, 0.64],  # starts (x^3)
[0.22, 0.58, 0.33],  # with (x^4)
[0.77, 0.25, 0.10],  # one (x^5)
[0.05, 0.80, 0.55]   # step (x^6)
])
print(inputs)

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


In [2]:
# calculating attention scores via dot product
query = inputs[1] # journey
attn_score_2 = torch.empty(inputs.shape[0]) # 6 random 0. values
for i, x_i in enumerate(inputs):
    attn_score_2[i] = torch.dot(x_i, query)

print("Attention scores:")
print(attn_score_2)

Attention scores:
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [3]:
# normalize the attention scores
attn_weight_2 = attn_score_2 / attn_score_2.sum()
print("Attention scores:", attn_score_2)
print("Attention weights:", attn_weight_2)
print("Sum:", attn_weight_2.sum())

Attention scores: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [4]:
# normalization using softmax activation function
attn_weight_2 = torch.softmax(attn_score_2, dim = 0)
print("Attention weights:", attn_weight_2)
print("Sum:", attn_weight_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [5]:
# calculating context vector --> z^2 i.e journey
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weight_2[i] * x_i

print("Context vectors of input 2:")
print(context_vec_2)

Context vectors of input 2:
tensor([0.4419, 0.6515, 0.5683])


### 2. Multiple attention weights calculation

---



In [6]:
# calculate attention scores
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [7]:
# calculating using matrix multiplication
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [8]:
# calculate attention weights
attn_weights = torch.softmax(attn_scores, dim = -1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [9]:
# calculate context vectors
all_context_vec = attn_weights @ inputs
print(all_context_vec)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


### 2. Self attention mechanism with trainable weights

---



In [10]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [11]:
# initialize three weights matrices
w_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
w_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
w_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)

In [12]:
# calculate query, key and value vector
query_2 = x_2 @ w_query
key_2 = x_2 @ w_key
value_2 = x_2 @ w_value
print(query_2)

tensor([1.0774, 1.5918])


In [13]:
# all key and values
keys = inputs @ w_key
values = inputs @ w_value
print("Key shape:", keys.shape)
print("Value shape:", values.shape)

Key shape: torch.Size([6, 2])
Value shape: torch.Size([6, 2])


In [14]:
# calculate attention score of W22
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print("Attention score:", attn_score_22)

Attention score: tensor(2.2706)


In [15]:
# calculate attention score of all queries
attn_score_2 = query_2 @ keys.T
print("All Attention score:", attn_score_2)

All Attention score: tensor([1.3497, 2.2706, 2.2296, 1.3289, 0.8635, 1.8044])


In [16]:
# normalized attention weights
d_k = keys.shape[1]
attn_weights_2 = torch.softmax(attn_score_2 / d_k ** 0.5, dim = -1)
print(attn_weights_2)

tensor([0.1273, 0.2442, 0.2372, 0.1255, 0.0903, 0.1756])


In [17]:
# calculate the context vectors
context_vec_2 = attn_weights_2 @ values
print("Context vector:", context_vec_2)

Context vector: tensor([0.4599, 0.3309])


### 3. Python class

---



In [18]:
import torch
import torch.nn as nn

In [19]:
class SelfAttentionV1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / d_in ** 0.5, dim = -1)
        all_context_vec = attn_weights @ values
        return all_context_vec

In [20]:
sa_V1 = SelfAttentionV1(d_in, d_out)
print("Context vectors ->")
print(sa_V1(inputs))

Context vectors ->
tensor([[0.9396, 0.7081],
        [0.9674, 0.7226],
        [0.9661, 0.7220],
        [0.9107, 0.6933],
        [0.9084, 0.6942],
        [0.9286, 0.7015]], grad_fn=<MmBackward0>)


In [23]:
# using linear layer -> better for weights initialization
class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
        all_context_vec = attn_weights @ values
        return all_context_vec

In [24]:
sa_V2 = SelfAttentionV2(d_in, d_out)
print(sa_V1(inputs))

tensor([[ 0.1362, -0.4624],
        [ 0.1403, -0.4660],
        [ 0.1401, -0.4658],
        [ 0.1403, -0.4674],
        [ 0.1353, -0.4619],
        [ 0.1426, -0.4696]], grad_fn=<MmBackward0>)


### 4. Masked Self attention

---



In [26]:
# calculate query, key and value
queries = sa_V2.W_query(inputs)
key = sa_V2.W_key(inputs)
attn_score = queries @ key.T
attn_weights = torch.softmax(attn_score / keys.shape[-1] ** 0.5, dim = -1)
print(attn_weights)

tensor([[0.1578, 0.1763, 0.1760, 0.1620, 0.1627, 0.1652],
        [0.1656, 0.1692, 0.1691, 0.1650, 0.1649, 0.1662],
        [0.1657, 0.1691, 0.1690, 0.1651, 0.1649, 0.1662],
        [0.1671, 0.1669, 0.1669, 0.1664, 0.1662, 0.1666],
        [0.1681, 0.1675, 0.1674, 0.1656, 0.1651, 0.1664],
        [0.1658, 0.1676, 0.1675, 0.1662, 0.1663, 0.1665]],
       grad_fn=<SoftmaxBackward0>)


In [27]:
# create masking
context_length = attn_score.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [29]:
# multiply masks with attn_scores
masked_simple = mask_simple * attn_weights
print(masked_simple)

tensor([[0.1578, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1656, 0.1692, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1657, 0.1691, 0.1690, 0.0000, 0.0000, 0.0000],
        [0.1671, 0.1669, 0.1669, 0.1664, 0.0000, 0.0000],
        [0.1681, 0.1675, 0.1674, 0.1656, 0.1651, 0.0000],
        [0.1658, 0.1676, 0.1675, 0.1662, 0.1663, 0.1665]],
       grad_fn=<MulBackward0>)


In [32]:
# normalize weights
rows_sums = masked_simple.sum(dim = 1, keepdim = True)
masked_norm = masked_simple / rows_sums
print(masked_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4946, 0.5054, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3288, 0.3357, 0.3355, 0.0000, 0.0000, 0.0000],
        [0.2504, 0.2501, 0.2501, 0.2493, 0.0000, 0.0000],
        [0.2016, 0.2009, 0.2008, 0.1986, 0.1980, 0.0000],
        [0.1658, 0.1676, 0.1675, 0.1662, 0.1663, 0.1665]],
       grad_fn=<DivBackward0>)


In [37]:
# masking
mask = torch.triu(torch.ones(context_length, context_length), diagonal = 1)
masked = attn_score.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.1242,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.0485, 0.0790,   -inf,   -inf,   -inf,   -inf],
        [0.0486, 0.0777, 0.0769,   -inf,   -inf,   -inf],
        [0.0104, 0.0087, 0.0085, 0.0043,   -inf,   -inf],
        [0.0368, 0.0317, 0.0311, 0.0158, 0.0114,   -inf],
        [0.0121, 0.0266, 0.0265, 0.0153, 0.0158, 0.0180]],
       grad_fn=<MaskedFillBackward0>)


In [39]:
# apply softmax
attn_weights = torch.softmax(masked / keys.shape[0], dim = -1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4987, 0.5013, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3323, 0.3339, 0.3338, 0.0000, 0.0000, 0.0000],
        [0.2501, 0.2500, 0.2500, 0.2498, 0.0000, 0.0000],
        [0.2004, 0.2002, 0.2002, 0.1997, 0.1995, 0.0000],
        [0.1665, 0.1669, 0.1669, 0.1666, 0.1666, 0.1666]],
       grad_fn=<SoftmaxBackward0>)


In [43]:
# using dropouts for preventing overfitting
dropout = torch.nn.Dropout(p = 0.5)
print(dropout(attn_weights))

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6678, 0.6677, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5001, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4004, 0.0000, 0.0000, 0.3991, 0.0000],
        [0.3329, 0.0000, 0.3337, 0.0000, 0.0000, 0.3333]],
       grad_fn=<MulBackward0>)


In [45]:
# create dummy inputs
batch = torch.stack((inputs, inputs), dim = 0)
print(batch)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [51]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = torch.nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias = qkv_bias)
        self.dropout = torch.nn.Dropout(p = 0.5)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal = 1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_weights = attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_weights / keys.shape[-1] ** 0.5, dim = -1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

In [53]:
# context vectors
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out,context_length, 0.0)
context_vec = ca(batch)
print(context_vec)

tensor([[[-0.1421, -0.5438],
         [-0.0745, -0.2851],
         [-0.1348, -0.1828],
         [-0.0507, -0.0732],
         [ 0.0000,  0.0000],
         [-0.0763, -0.1756]],

        [[-0.1421, -0.5438],
         [-0.1742, -0.4291],
         [-0.1852, -0.3756],
         [-0.1307, -0.1932],
         [-0.1581, -0.1561],
         [-0.1342, -0.1897]]], grad_fn=<UnsafeViewBackward0>)


In [54]:
print(context_vec.shape)

torch.Size([2, 6, 2])


### 5. Multi head attention

---



In [57]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x)for head in self.heads], dim = -1)

In [58]:
context_length = batch.shape[1]
d_in, d_out = 3, 2
ma = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads = 2)
context_vec = ma(batch)
print(context_vec)

tensor([[[ 0.0000,  0.0000,  0.0324,  0.4033],
         [-0.0821,  0.2978,  0.0000,  0.0000],
         [-0.1015,  0.4133, -0.0387, -0.0218],
         [-0.0857,  0.2748, -0.0254,  0.0818],
         [-0.0345,  0.0956, -0.0269, -0.0163],
         [-0.1333,  0.3603, -0.0047, -0.0086]],

        [[-0.1468,  0.6157,  0.0324,  0.4033],
         [-0.0793,  0.3326,  0.0165,  0.2060],
         [-0.0542,  0.2273, -0.0831, -0.0487],
         [-0.0424,  0.1178, -0.0409, -0.0357],
         [-0.0294,  0.1157, -0.0957,  0.0054],
         [-0.0928,  0.3955, -0.0245, -0.0212]]], grad_fn=<CatBackward0>)


In [61]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = torch.nn.Linear(d_in, d_out, qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, qkv_bias)
        self.out_proj = torch.nn.Linear(d_out, d_out)
        self.dropout = torch.nn.Dropout(p = 0.5)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal = 1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

In [62]:
batch_size, context_length, d_in = batch.shape
d_out = 2
ma = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads = 2)
context_vec = ma(batch)
print(context_vec)
print(context_vec.shape)

tensor([[[-0.2121, -0.3697],
         [-0.4154, -0.4455],
         [-0.5426, -0.4872],
         [-0.5657, -0.4942],
         [-0.5532, -0.4912],
         [-0.5066, -0.4682]],

        [[-0.5577, -0.4983],
         [-0.3307, -0.4057],
         [-0.5048, -0.4674],
         [-0.4258, -0.4375],
         [-0.5907, -0.5109],
         [-0.5273, -0.4789]]], grad_fn=<ViewBackward0>)
torch.Size([2, 6, 2])
