In [1]:
import torch

In [8]:
import torch.nn as nn

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    Args:
        d_in: size of the last dimension of the input ``x``.
        d_out: size of the query/key/value projections (and of the output).

    The three projection matrices are raw ``nn.Parameter`` tensors of shape
    ``(d_in, d_out)`` initialised with ``torch.rand`` (uniform on [0, 1)).
    No causal mask is applied here; every position attends to every position.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        self.w_key = nn.Parameter(torch.rand(d_in, d_out))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute self-attention over the rows of ``x``.

        Args:
            x: tensor of shape ``(seq_len, d_in)``, or batched ``(..., seq_len, d_in)``.

        Returns:
            Tuple ``(context_vector, attn_weights, attn_scores)``:
              * ``context_vector``: ``(..., seq_len, d_out)``, weighted sum of values.
              * ``attn_weights``: ``(..., seq_len, seq_len)``, softmax over the last
                dim, so each row sums to 1.
              * ``attn_scores``: ``(..., seq_len, seq_len)``, the *unscaled*
                query-key dot products (before division by sqrt(d_k)).
        """
        keys = x @ self.w_key
        queries = x @ self.w_query
        values = x @ self.w_value

        d_k = keys.shape[-1]
        # transpose(-2, -1) swaps only the last two dims, so batched inputs work.
        # ``keys.T`` would reverse *all* dims of a 3-D tensor, and is deprecated there.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Standard scaled dot-product attention: divide by sqrt(d_k) before softmax.
        attn_weights = torch.nn.functional.softmax(attn_scores / d_k ** 0.5, dim=-1)
        context_vector = attn_weights @ values
        return context_vector, attn_weights, attn_scores


In [4]:
# Toy 3-dimensional embeddings for the six tokens of "Dream big and work for it"
inputs = torch.tensor([
    [0.72, 0.45, 0.31],  # Dream
    [0.75, 0.20, 0.55],  # big
    [0.30, 0.80, 0.40],  # and
    [0.85, 0.35, 0.60],  # work
    [0.55, 0.15, 0.75],  # for
    [0.25, 0.20, 0.85],  # it
])

x_2 = inputs[1]          # embedding of the second token ("big")
d_in = inputs.shape[1]   # size of each input embedding
d_out = 2                # projection size for queries / keys / values

# Token strings, in the same order as the rows of `inputs`
words = ["Dream", "big", "and", "work", "for", "it"]

In [10]:
# Seed the global RNG first: SelfAttention initialises its projections with
# torch.rand, so without a seed every Restart & Run All gives different outputs.
torch.manual_seed(123)
atten_obj = SelfAttention(d_in, d_out)

In [11]:
# Call the module itself rather than .forward() so nn.Module.__call__ runs any
# registered hooks. Returns (context vectors, attention weights, unscaled scores).
context_vector, attention_weights, attn_scores = atten_obj(inputs)

In [7]:
# Unmasked attention weights: one row per query token, each row softmax-normalised to sum to 1
attention_weights

tensor([[0.1581, 0.1610, 0.1710, 0.1981, 0.1609, 0.1510],
        [0.1577, 0.1609, 0.1707, 0.1987, 0.1610, 0.1510],
        [0.1566, 0.1613, 0.1691, 0.1982, 0.1622, 0.1526],
        [0.1557, 0.1597, 0.1713, 0.2055, 0.1598, 0.1480],
        [0.1572, 0.1613, 0.1695, 0.1974, 0.1620, 0.1525],
        [0.1573, 0.1621, 0.1680, 0.1939, 0.1634, 0.1552]],
       grad_fn=<SoftmaxBackward0>)

#### Lower-triangular matrix (causal mask)

In [13]:
# Number of tokens in the sequence (equivalently inputs.shape[0])
context_length = attn_scores.shape[0]
# Ones on and below the diagonal, zeros above: row i may only look at positions <= i
mask_sample = torch.ones(context_length, context_length).tril()
mask_sample

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

#### Attention weights after applying mask

In [14]:
# Element-wise product zeroes every weight above the diagonal (the "future" tokens)
masked_sample = torch.mul(attention_weights, mask_sample)
masked_sample

tensor([[0.1467, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1485, 0.1725, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1412, 0.1741, 0.1270, 0.0000, 0.0000, 0.0000],
        [0.1438, 0.1738, 0.1303, 0.1962, 0.0000, 0.0000],
        [0.1484, 0.1724, 0.1376, 0.1894, 0.1807, 0.0000],
        [0.1483, 0.1721, 0.1379, 0.1889, 0.1808, 0.1720]],
       grad_fn=<MulBackward0>)

#### Re-normalizing the masked attention weights (each row sums to 1 again)

In [18]:
# Divide each row by the sum of its surviving weights so it sums to 1 again;
# keepdim=True leaves a column vector that broadcasts across each row.
rows_sum = torch.sum(masked_sample, dim=1, keepdim=True)
masked_simple_normalized = masked_sample.div(rows_sum)
masked_simple_normalized


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4626, 0.5374, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3192, 0.3937, 0.2871, 0.0000, 0.0000, 0.0000],
        [0.2233, 0.2698, 0.2023, 0.3046, 0.0000, 0.0000],
        [0.1791, 0.2080, 0.1661, 0.2286, 0.2181, 0.0000],
        [0.1483, 0.1721, 0.1379, 0.1889, 0.1808, 0.1720]],
       grad_fn=<DivBackward0>)

In [17]:
# Per-row totals of the masked weights before re-normalisation (a column vector, one entry per row)
rows_sum

tensor([[0.1467],
        [0.3211],
        [0.4423],
        [0.6440],
        [0.8286],
        [1.0000]], grad_fn=<SumBackward1>)

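#### What masked (causal) self-attention actually does

Instead of zeroing weights after the softmax and then re-normalizing, we set the scores above the diagonal to $-\infty$ *before* the softmax. Because $e^{-\infty} = 0$, those positions get exactly zero weight and each row is already normalized over the visible positions. This gives the same result as mask-then-renormalize (compare the two outputs) and needs only one softmax.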
![image](./images/masked_self_attention.png)

In [19]:
# Strictly upper-triangular indicator: a 1 marks a future position that must be hidden
mask = torch.ones(context_length, context_length).triu(diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [21]:
# Scale the raw scores by sqrt(d_k) (here d_k equals d_out), then overwrite every
# future position with -inf so the following softmax assigns it exactly zero weight.
attention_scores = attn_scores / torch.tensor(d_out).sqrt()
masked = attention_scores.masked_fill(mask.eq(1), float("-inf"))
masked

tensor([[0.6464,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.5821, 0.7318,   -inf,   -inf,   -inf,   -inf],
        [0.8037, 1.0133, 0.6976,   -inf,   -inf,   -inf],
        [0.7351, 0.9243, 0.6364, 1.0457,   -inf,   -inf],
        [0.5728, 0.7222, 0.4972, 0.8168, 0.7696,   -inf],
        [0.5621, 0.7111, 0.4895, 0.8038, 0.7600, 0.7100]],
       grad_fn=<MaskedFillBackward0>)

In [22]:
# Row-wise softmax of the masked scores: exp(-inf) = 0, so future tokens get no weight
atten_weights = masked.softmax(dim=-1)
atten_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4626, 0.5374, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3192, 0.3937, 0.2871, 0.0000, 0.0000, 0.0000],
        [0.2233, 0.2698, 0.2023, 0.3046, 0.0000, 0.0000],
        [0.1791, 0.2080, 0.1661, 0.2286, 0.2181, 0.0000],
        [0.1483, 0.1721, 0.1379, 0.1889, 0.1808, 0.1720]],
       grad_fn=<SoftmaxBackward0>)

#### How dropout is applied to the attention weights in self-attention
![img](./images/drop_out.png)
![img2](./images/dropout_3.png)
![img3](./images/dropout_2.png)

In [24]:
# Matrix of ones (context_length x context_length) to visualise what dropout does
example = torch.ones((context_length, context_length))
example

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

In [26]:
# Dropout with p=0.5: each entry is zeroed with probability 0.5, and the survivors
# are scaled by 1 / (1 - 0.5) = 2 so the expected value is unchanged.
# Seed first so the displayed pattern is reproducible on Restart & Run All.
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
dropout(example)

tensor([[2., 2., 0., 0., 0., 2.],
        [0., 0., 2., 2., 0., 0.],
        [0., 0., 2., 2., 2., 2.],
        [0., 2., 2., 2., 2., 2.],
        [2., 0., 0., 2., 0., 0.],
        [2., 0., 0., 2., 2., 0.]])

In [27]:
# Re-seed so the dropout pattern is reproducible, then apply the same Dropout
# module to the causal attention weights (surviving weights are doubled).
torch.random.manual_seed(0)
dropout(atten_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0747, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7873, 0.5742, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3583, 0.4160, 0.3322, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3443, 0.0000, 0.3778, 0.3615, 0.0000]],
       grad_fn=<MulBackward0>)