In [2]:
import pandas as pd
import torch

In [5]:
inputs = torch.tensor(
    [[0.72, 0.45, 0.310], #Dream
    [0.75, 0.20,0.55], #big
    [0.30,0.80,0.40], #and
    [0.85,0.35,0.60], #work
    [0.55,0.15,0.75], #for
    [0.25,0.20,0.85] #it
    ]
)

#corresponding wordss
words = ["Dream", "big", "and", "work", "for", "it"]

In [6]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out =2

##### Randomly initializing Wq, Wk and Wv matrices

In [10]:
torch.manual_seed(0)
w_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
w_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [11]:
w_query

Parameter containing:
tensor([[0.4963, 0.7682],
        [0.0885, 0.1320],
        [0.3074, 0.6341]])

In [12]:
query_2 = x_2 @ w_query
key = x_2 @ w_key
value = x_2 @ w_value

In [13]:
query_2

tensor([0.5590, 0.9513])

#### Calculating the Q, K and V using X, Wq, Wk, Wv
![My image](./images/transforming_input_embd_Q_K_V_space.png)

In [15]:
keys = inputs @ w_key
values = inputs @ w_value
query = inputs @ w_value

print("Keys:\n", keys)
print("Values:\n", values)
print("Query:\n", query)

Keys:
 tensor([[0.6661, 1.0545],
        [0.6506, 1.0197],
        [0.6511, 0.9355],
        [0.7854, 1.2243],
        [0.5996, 0.8892],
        [0.5102, 0.6920]])
Values:
 tensor([[0.3646, 0.6029],
        [0.4592, 0.6704],
        [0.5209, 0.7855],
        [0.5404, 0.8050],
        [0.5796, 0.7707],
        [0.6574, 0.8259]])
Query:
 tensor([[0.3646, 0.6029],
        [0.4592, 0.6704],
        [0.5209, 0.7855],
        [0.5404, 0.8050],
        [0.5796, 0.7707],
        [0.6574, 0.8259]])


In [20]:
# For word 2 ("big")
keys_2 = keys[1]
atten_score_22 = query_2.dot(keys_2)
# attention score for word 2 with word 2
atten_score_22

tensor(1.3338)

### Attention score for entire

In [22]:
attn_scores = query @ keys.T
attn_scores

tensor([[0.8786, 0.8520, 0.8014, 1.0245, 0.7547, 0.6033],
        [1.0128, 0.9824, 0.9261, 1.1814, 0.8714, 0.6982],
        [1.1752, 1.1399, 1.0739, 1.3708, 1.0107, 0.8093],
        [1.2089, 1.1725, 1.1049, 1.4101, 1.0398, 0.8328],
        [1.1987, 1.1630, 1.0983, 1.3987, 1.0328, 0.8290],
        [1.3088, 1.2699, 1.2006, 1.5275, 1.1285, 0.9070]])

#### Scales by 1/sqrt(d) and then take softmax
![My image](./images/dot_product_to_attention_weights.png)

In [27]:
d_k = keys.shape[-1]
attn_weights = torch.nn.functional.softmax(attn_scores / torch.sqrt(torch.tensor(d_k)), dim=-1)
attn_weights

tensor([[0.1731, 0.1699, 0.1639, 0.1919, 0.1586, 0.1425],
        [0.1739, 0.1702, 0.1635, 0.1959, 0.1573, 0.1392],
        [0.1749, 0.1706, 0.1628, 0.2009, 0.1557, 0.1350],
        [0.1751, 0.1707, 0.1627, 0.2019, 0.1554, 0.1342],
        [0.1749, 0.1705, 0.1629, 0.2015, 0.1555, 0.1347],
        [0.1755, 0.1707, 0.1625, 0.2048, 0.1545, 0.1321]])

In [28]:
context_vector = attn_weights @ values
context_vector

tensor([[0.5159, 0.7415],
        [0.5153, 0.7413],
        [0.5145, 0.7410],
        [0.5144, 0.7409],
        [0.5145, 0.7410],
        [0.5140, 0.7409]])

#### Python class for doing this whole operation

In [29]:
import torch.nn as nn

class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        self.w_key = nn.Parameter(torch.rand(d_in, d_out))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.w_key
        queries = x @ self.w_query
        values = x @ self.w_value

        d_k = keys.shape[-1]
        attn_scores = queries @ keys.T
        attn_weights = torch.nn.functional.softmax(attn_scores / torch.sqrt(torch.tensor(d_k)), dim=-1)
        context_vector = attn_weights @ values
        return context_vector


In [33]:
torch.manual_seed(0)
self_atten = SelfAttention(d_in,d_out)

In [None]:
self_atten.forward(inputs)

tensor([[0.5145, 0.7409],
        [0.5136, 0.7406],
        [0.5160, 0.7415],
        [0.5128, 0.7403],
        [0.5138, 0.7407],
        [0.5149, 0.7411]], grad_fn=<MmBackward0>)