## Coding Attention Mechanisms

### 1) Simple self-attention mechanism without trainable weights

In [1]:
import torch

# Toy input: 6 rows, each a 3-dimensional vector. Throughout the notebook
# each row is treated as one sequence element that attends to the others.
inputs = torch.tensor(
    [
        [0.43, 0.12, 0.45],
        [0.23, 0.67, 0.10],
        [0.12, 0.45, 0.43],
        [0.67, 0.23, 0.10],
        [0.45, 0.10, 0.43],
        [0.10, 0.43, 0.67],
    ]
)


In [2]:
# Use the second row (index 1) as the query vector for the examples below.
input_query = inputs[1]
input_query

tensor([0.2300, 0.6700, 0.1000])

In [3]:
# First row of the inputs; it is scored against the query in the next cell.
input_1 = inputs[0]
input_1

tensor([0.4300, 0.1200, 0.4500])

In [4]:
# Unnormalized attention score between the query and the first input:
# a plain dot product of the two vectors.
torch.dot(input_query, input_1)

tensor(0.2243)

In [5]:
# Dot product written out by hand: multiply matching components of
# inputs[i] and the query, then accumulate the products.
i = 1
res = 0.

for component, query_component in zip(inputs[i], input_query):
    res += component * query_component

res

tensor(0.5118)

In [6]:
# Same score as the manual loop in the previous cell, computed with torch.dot.
i = 1
res = torch.dot(inputs[i], input_query) 
res

tensor(0.5118)

In [7]:
# Attention scores of every input row with respect to the query (row 1):
# one dot product per row, collected into a 1-D tensor.
query = inputs[1]

attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    # Use the `query` defined in this cell instead of `input_query` left
    # over from an earlier cell, so the cell does not rely on hidden state.
    attn_scores_2[i] = torch.dot(x_i, query)

print(attn_scores_2)

tensor([0.2243, 0.5118, 0.3721, 0.3182, 0.2135, 0.3781])


In [8]:
# Simple normalization: divide each score by the sum of all scores so the
# resulting weights add up to 1.
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
attn_weights_2_tmp

tensor([0.1111, 0.2536, 0.1844, 0.1577, 0.1058, 0.1874])

In [9]:
# Sanity check: the normalized weights sum to 1.
attn_weights_2_tmp.sum()

tensor(1.0000)

In [10]:
# Normalize the scores with softmax instead; dim=0 is the only dimension of
# the 1-D score tensor. These weights are used for the context vector below.
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

In [11]:
# Zero vector with the same shape as the query; the next cell starts its
# running sum from a tensor built the same way.
torch.zeros(query.shape)

tensor([0., 0., 0.])

In [12]:
query = inputs[1]

# Context vector for the query: weighted sum of all input rows, using the
# softmax attention weights as coefficients.
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    # Show which weight is applied to which input row.
    print(f"{attn_weights_2[i]} ----> {x_i}")
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

0.14822857081890106 ----> tensor([0.4300, 0.1200, 0.4500])
0.19760210812091827 ----> tensor([0.2300, 0.6700, 0.1000])
0.17183856666088104 ----> tensor([0.1200, 0.4500, 0.4300])
0.16282165050506592 ----> tensor([0.6700, 0.2300, 0.1000])
0.14663630723953247 ----> tensor([0.4500, 0.1000, 0.4300])
0.17287269234657288 ----> tensor([0.1000, 0.4300, 0.6700])
tensor([0.3222, 0.3540, 0.3555])


### 2) Simple self-attention mechanism without trainable weights, generalized to all inputs

In [13]:
# Pairwise attention scores: entry [i, j] is the dot product of input row i
# with input row j. The matrix size is derived from `inputs` instead of
# being hard-coded, so the cell keeps working if the number of rows changes.
num_inputs = inputs.shape[0]
attn_scores = torch.empty(num_inputs, num_inputs)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.4018, 0.2243, 0.2991, 0.3607, 0.3990, 0.3961],
        [0.2243, 0.5118, 0.3721, 0.3182, 0.2135, 0.3781],
        [0.2991, 0.3721, 0.4018, 0.2269, 0.2839, 0.4936],
        [0.3607, 0.3182, 0.2269, 0.5118, 0.3675, 0.2329],
        [0.3990, 0.2135, 0.2839, 0.3675, 0.3974, 0.3761],
        [0.3961, 0.3781, 0.4936, 0.2329, 0.3761, 0.6438]])


In [14]:
# Same pairwise scores as the nested loops, computed with a single matrix
# multiplication of the inputs with their transpose.
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.4018, 0.2243, 0.2991, 0.3607, 0.3990, 0.3961],
        [0.2243, 0.5118, 0.3721, 0.3182, 0.2135, 0.3781],
        [0.2991, 0.3721, 0.4018, 0.2269, 0.2839, 0.4936],
        [0.3607, 0.3182, 0.2269, 0.5118, 0.3675, 0.2329],
        [0.3990, 0.2135, 0.2839, 0.3675, 0.3974, 0.3761],
        [0.3961, 0.3781, 0.4936, 0.2329, 0.3761, 0.6438]])


In [15]:
# Softmax along dim=1 normalizes each row of the score matrix, so every row
# of attn_weights sums to 1.
attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

tensor([[0.1757, 0.1471, 0.1586, 0.1686, 0.1752, 0.1747],
        [0.1482, 0.1976, 0.1718, 0.1628, 0.1466, 0.1729],
        [0.1584, 0.1704, 0.1755, 0.1473, 0.1560, 0.1924],
        [0.1700, 0.1629, 0.1487, 0.1977, 0.1711, 0.1496],
        [0.1765, 0.1466, 0.1573, 0.1710, 0.1762, 0.1725],
        [0.1614, 0.1585, 0.1779, 0.1371, 0.1582, 0.2068]])


In [16]:
# Context vectors for all inputs at once: row i is the sum of the input rows
# weighted by attn_weights[i].
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.3377, 0.3225, 0.3712],
        [0.3222, 0.3540, 0.3555],
        [0.3165, 0.3444, 0.3745],
        [0.3528, 0.3234, 0.3503],
        [0.3396, 0.3213, 0.3701],
        [0.3110, 0.3419, 0.3853]])


### 3) Implementing self-attention with trainable weights

In [17]:
x_2 = inputs[1]          # second input row, used as the worked example
d_in = inputs.shape[1]   # input vector size (number of columns of `inputs`)
d_out = 2                # size of the projected query/key/value vectors

In [18]:
# Fix the seed so the randomly initialized projection matrices are reproducible.
torch.manual_seed(123)

# Trainable projection matrices of shape (d_in, d_out) for queries, keys
# and values, initialized from a standard normal distribution.
W_query = torch.nn.Parameter(torch.randn(d_in, d_out))
W_key = torch.nn.Parameter(torch.randn(d_in, d_out))
W_value = torch.nn.Parameter(torch.randn(d_in, d_out))

In [19]:
# Project the second input into query space:
# (d_in,) @ (d_in, d_out) -> (d_out,).
query_2 = x_2 @ W_query

query_2

tensor([-0.3930, -0.1125], grad_fn=<SqueezeBackward4>)

In [20]:
# The input vector that was projected in the previous cell.
x_2

tensor([0.2300, 0.6700, 0.1000])

In [21]:
# Inspect the query projection matrix.
W_query

Parameter containing:
tensor([[-0.1115,  0.1204],
        [-0.3696, -0.2404],
        [-1.1969,  0.2093]], requires_grad=True)

In [22]:
# Project every input row into key space and value space.
keys = inputs @ W_key
value = inputs @ W_value  # NOTE: singular name, but holds one value vector per input row

keys.shape

torch.Size([6, 2])

In [23]:
# Inspect the projected key vectors (one row per input).
keys

tensor([[-0.2846, -0.5136],
        [ 0.0144, -0.2855],
        [ 0.1195, -0.3075],
        [-0.5559, -0.5699],
        [-0.3147, -0.5187],
        [ 0.1830, -0.3840]], grad_fn=<MmBackward0>)

In [24]:
# Attention score of the second input with itself: dot product of its query
# vector and its own key vector.
keys_2 = keys[1]
attn_score_22 = torch.dot(query_2, keys_2)
attn_score_22

tensor(0.0264, grad_fn=<DotBackward0>)

In [25]:
# Scores of query_2 against all keys at once. This rebinds attn_scores_2,
# which earlier held the scores from the example without trainable weights.
attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([ 0.1696,  0.0264, -0.0124,  0.2826,  0.1820, -0.0287],
       grad_fn=<SqueezeBackward4>)

In [26]:
# Scaled dot-product attention: divide the scores by sqrt(d_k), where d_k is
# the key vector size, before applying softmax.
d_k = keys.shape[1]

attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
attn_weights_2

tensor([0.1741, 0.1573, 0.1531, 0.1886, 0.1756, 0.1513],
       grad_fn=<SoftmaxBackward0>)

In [27]:
# Context vector for the second input: attention-weighted sum of the value
# vectors.
context_vec_2 = attn_weights_2 @ value
context_vec_2

tensor([0.1818, 0.2301], grad_fn=<SqueezeBackward4>)

### 4) Implementing a compact SelfAttention class

In [35]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    Args:
        d_in: Size of each input vector.
        d_out: Size of the projected query, key and value vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Drawn in query, key, value order; keep this order so that seeded
        # runs reproduce the same weights.
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, inputs):
        """Compute one context vector per input position.

        Args:
            inputs: Tensor of shape (num_tokens, d_in), optionally with
                leading batch dimensions: (..., num_tokens, d_in).

        Returns:
            Tensor of shape (..., num_tokens, d_out).
        """
        queries = inputs @ self.W_query
        keys = inputs @ self.W_key
        values = inputs @ self.W_value

        # Swap only the last two dimensions. `.T` reverses every dimension
        # and is only valid for 2-D tensors, so it breaks on batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalizing over the key positions.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec
    
# Seed before constructing the module so its random weights are reproducible,
# then compute the context vectors for all input rows.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
sa_v1(inputs)

tensor([[0.1823, 0.2348],
        [0.1818, 0.2301],
        [0.1821, 0.2355],
        [0.1820, 0.2287],
        [0.1822, 0.2343],
        [0.1823, 0.2405]], grad_fn=<MmBackward0>)

In [37]:
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on nn.Linear layers.

    Args:
        d_in: Size of each input vector.
        d_out: Size of the projected query, key and value vectors.
        qkv_bias: Whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Created in query, key, value order; keep this order so that seeded
        # runs reproduce the same weights.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, inputs):
        """Compute one context vector per input position.

        Args:
            inputs: Tensor of shape (num_tokens, d_in), optionally with
                leading batch dimensions: (..., num_tokens, d_in).

        Returns:
            Tensor of shape (..., num_tokens, d_out).
        """
        queries = self.W_query(inputs)
        keys = self.W_key(inputs)
        values = self.W_value(inputs)

        # Swap only the last two dimensions. `.T` reverses every dimension
        # and is only valid for 2-D tensors, so it breaks on batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalizing over the key positions.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec
    
# Seed before constructing the module so its random weights are reproducible,
# then compute the context vectors for all input rows.
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
sa_v2(inputs)

tensor([[-0.3572, -0.0460],
        [-0.3570, -0.0488],
        [-0.3571, -0.0474],
        [-0.3572, -0.0469],
        [-0.3572, -0.0460],
        [-0.3572, -0.0468]], grad_fn=<MmBackward0>)

### 5) Hiding future words with causal attention

In [38]:
# Recompute the attention weights step by step, reusing the linear layers of
# sa_v2, so the scores and weights are available for masking below.
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
values = sa_v2.W_value(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[1]**0.5, dim=-1)

In [39]:
# Unmasked attention weights: every position attends to every other position.
attn_weights

tensor([[0.1683, 0.1630, 0.1646, 0.1680, 0.1683, 0.1678],
        [0.1640, 0.1712, 0.1684, 0.1651, 0.1637, 0.1676],
        [0.1661, 0.1672, 0.1666, 0.1665, 0.1659, 0.1677],
        [0.1669, 0.1656, 0.1658, 0.1671, 0.1668, 0.1678],
        [0.1684, 0.1629, 0.1646, 0.1680, 0.1683, 0.1678],
        [0.1670, 0.1653, 0.1656, 0.1672, 0.1669, 0.1679]],
       grad_fn=<SoftmaxBackward0>)

In [40]:
# Lower-triangular matrix of ones: row i keeps columns 0..i and zeroes the
# columns after it (the later positions).
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [41]:
# Zero out the weights above the diagonal by elementwise multiplication with
# the mask. The rows no longer sum to 1 after this step.
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1683, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1640, 0.1712, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1661, 0.1672, 0.1666, 0.0000, 0.0000, 0.0000],
        [0.1669, 0.1656, 0.1658, 0.1671, 0.0000, 0.0000],
        [0.1684, 0.1629, 0.1646, 0.1680, 0.1683, 0.0000],
        [0.1670, 0.1653, 0.1656, 0.1672, 0.1669, 0.1679]],
       grad_fn=<MulBackward0>)

In [42]:
# Renormalize so that each row sums to 1 again after masking.
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
masked_simple_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4892, 0.5108, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3323, 0.3344, 0.3333, 0.0000, 0.0000, 0.0000],
        [0.2508, 0.2489, 0.2492, 0.2511, 0.0000, 0.0000],
        [0.2024, 0.1957, 0.1977, 0.2019, 0.2023, 0.0000],
        [0.1670, 0.1653, 0.1656, 0.1672, 0.1669, 0.1679]],
       grad_fn=<DivBackward0>)

In [43]:
# More direct approach: put -inf above the diagonal of the raw scores so
# that softmax assigns those positions zero weight.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), float('-inf'))
masked

tensor([[ 0.1325,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0157,  0.0454,    -inf,    -inf,    -inf,    -inf],
        [ 0.0547,  0.0638,  0.0588,    -inf,    -inf,    -inf],
        [ 0.0866,  0.0756,  0.0774,  0.0881,    -inf,    -inf],
        [ 0.1333,  0.0860,  0.1006,  0.1303,  0.1327,    -inf],
        [ 0.1002,  0.0856,  0.0885,  0.1016,  0.0990,  0.1078]],
       grad_fn=<MaskedFillBackward0>)

In [44]:
# Softmax over the masked, scaled scores: exp(-inf) is 0, so masked positions
# get zero weight and each row still sums to 1. Rebinds attn_weights.
attn_weights = torch.softmax(masked / keys.shape[1]**0.5, dim=-1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4892, 0.5108, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3323, 0.3344, 0.3333, 0.0000, 0.0000, 0.0000],
        [0.2508, 0.2489, 0.2492, 0.2511, 0.0000, 0.0000],
        [0.2024, 0.1957, 0.1977, 0.2019, 0.2023, 0.0000],
        [0.1670, 0.1653, 0.1656, 0.1672, 0.1669, 0.1679]],
       grad_fn=<SoftmaxBackward0>)

### 6) Masking additional attention weights with dropout

In [45]:
# Seed so the random dropout mask is reproducible.
torch.manual_seed(123)

# Dropout layer with drop probability 0.5.
layer = torch.nn.Dropout(0.5)

In [47]:
# Apply dropout to a matrix of ones. A freshly created module is in training
# mode, so dropped entries become 0 and the surviving entries are scaled by
# 1 / (1 - p) = 2.
example = torch.ones(6,6)
layer(example)

tensor([[2., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 0., 2.],
        [0., 0., 2., 2., 0., 2.],
        [0., 2., 0., 2., 2., 0.],
        [0., 2., 2., 2., 2., 2.],
        [2., 2., 0., 0., 2., 2.]])

### 7) Implementing a compact causal self-attention class

In [59]:
# Batch of two identical sequences, stacked along a new leading dimension:
# shape (2, 6, 3).
batch = torch.stack((inputs, inputs), dim=0)

In [60]:
import torch.nn as nn

class CasualAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each position attends only to itself and to earlier positions. The class
    name is spelled "Casual" and is kept as is so existing references work.

    Args:
        d_in: Size of each input vector.
        d_out: Size of the projected query, key and value vectors.
        context_length: Side length of the square mask stored as a buffer.
        dropout: Drop probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        # Projections are created in query, key, value order.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the positions to hide.
        # Stored as a buffer so it is part of the module state without
        # being a trainable parameter.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future_positions)

    def forward(self, inputs):
        """Compute causally masked context vectors.

        Args:
            inputs: Tensor of shape (batch, num_tokens, d_in); the shape is
                unpacked into exactly three dimensions.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        _, num_tokens, _ = inputs.shape

        queries = self.W_query(inputs)
        keys = self.W_key(inputs)
        values = self.W_value(inputs)

        scores = torch.matmul(queries, keys.transpose(1, 2))
        # Crop the stored mask to the actual sequence length, then overwrite
        # the hidden positions with -inf so softmax gives them zero weight.
        hidden = self.mask.bool()[:num_tokens, :num_tokens]
        scores.masked_fill_(hidden, -torch.inf)

        scale = keys.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))
        return torch.matmul(weights, values)
    
# Seed so the layer weights are reproducible.
torch.manual_seed(123)

# The mask size is taken from the sequence length of the batch; dropout is
# disabled (0.0) for this run.
context_length = batch.shape[1]
dropout = 0.0
ca = CasualAttention(d_in, d_out, context_length, dropout)
ca(batch)

tensor([[[-0.3481,  0.0685],
         [-0.3586, -0.1112],
         [-0.3415, -0.0884],
         [-0.3637, -0.1007],
         [-0.3604, -0.0659],
         [-0.3572, -0.0468]],

        [[-0.3481,  0.0685],
         [-0.3586, -0.1112],
         [-0.3415, -0.0884],
         [-0.3637, -0.1007],
         [-0.3604, -0.0659],
         [-0.3572, -0.0468]]], grad_fn=<UnsafeViewBackward0>)

### 8) Extending single-head attention to multi-head attention

In [62]:
class MultiHeadAttentionWrapper(nn.Module):
    """Run several independent CasualAttention heads and concatenate their outputs.

    Args:
        d_in: Size of each input vector.
        d_out: Output size of each individual head.
        context_length: Mask size passed to every head.
        dropout: Drop probability passed to every head.
        num_heads: Number of attention heads to create.
        qkv_bias: Whether each head's projections include a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads=2, qkv_bias=False):
        super().__init__()
        # nn.ModuleList (not a plain Python list) registers every head as a
        # submodule. With a plain list the heads' parameters and buffers are
        # invisible to parameters(), state_dict(), .to(device) and
        # train()/eval(), so they would never be optimized, saved or moved.
        self.heads = nn.ModuleList(
            [
                CasualAttention(d_in, d_out, context_length, dropout, qkv_bias)
                for _ in range(num_heads)
            ]
        )

    def forward(self, inputs):
        """Apply every head to `inputs` and join the results on the last dim.

        Returns:
            Tensor whose last dimension has size num_heads * d_out.
        """
        return torch.cat([head(inputs) for head in self.heads], dim=-1)
    
# Seed so the weights of all heads are reproducible.
torch.manual_seed(123)

context_length = batch.shape[1]
d_in, d_out = 3, 2

# Two heads with d_out=2 each, so the concatenated output has 4 features
# per position.
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=2)
mha(batch)

tensor([[[-0.3481,  0.0685,  0.3582,  0.1255],
         [-0.3586, -0.1112,  0.3401,  0.2513],
         [-0.3415, -0.0884,  0.3239,  0.2406],
         [-0.3637, -0.1007,  0.3483,  0.2434],
         [-0.3604, -0.0659,  0.3501,  0.2216],
         [-0.3572, -0.0468,  0.3479,  0.2172]],

        [[-0.3481,  0.0685,  0.3582,  0.1255],
         [-0.3586, -0.1112,  0.3401,  0.2513],
         [-0.3415, -0.0884,  0.3239,  0.2406],
         [-0.3637, -0.1007,  0.3483,  0.2434],
         [-0.3604, -0.0659,  0.3501,  0.2216],
         [-0.3572, -0.0468,  0.3479,  0.2172]]], grad_fn=<CatBackward0>)