## **SELF ATTENTION WITH TRAINABLE WEIGHTS**

In [1]:
import torch
import torch.nn as nn

# One row per token of the sentence "Your journey starts with one step";
# each row is that token's 3-value input vector.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: Your
    [0.55, 0.87, 0.66],  # x^2: journey
    [0.57, 0.85, 0.64],  # x^3: starts
    [0.22, 0.58, 0.33],  # x^4: with
    [0.77, 0.25, 0.10],  # x^5: one
    [0.05, 0.80, 0.55],  # x^6: step
])

In [2]:
# Display the tensor's dimensions: (number of rows, values per row).
inputs.shape

torch.Size([6, 3])

In [3]:
# Use the second row of `inputs` as the single-token running example.
x = inputs[1]

# Input width is read from the data; the projection width is picked by hand.
d_in, d_out = inputs.shape[1], 2

In [4]:
# Fix the RNG so the three projection matrices are reproducible.
torch.manual_seed(123)

# The generator is consumed left to right, so the random draws happen in
# query, key, value order. Gradients are disabled on all three matrices.
W_query, W_key, W_value = (
    nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [5]:
# Show the three projection matrices, one labelled section each.
print(
    f"W_query: \n{W_query}",
    f"\nW_key: \n{W_key}",
    f"\nW_value: \n{W_value}",
    sep="\n",
)

W_query: 
Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

W_key: 
Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])

W_value: 
Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


In [6]:
# Project the single example row through each weight matrix in turn.
query, key, value = (x @ weight for weight in (W_query, W_key, W_value))

In [7]:
# Query, key and value vectors for the second input row only.
print(
    f"query: \n{query}",
    f"\nkey: \n{key}",
    f"\nvalue: \n{value}",
    sep="\n",
)

query: 
tensor([0.4306, 1.4551])

key: 
tensor([0.4433, 1.1419])

value: 
tensor([0.3951, 1.0037])


## **QKV for whole Input**

In [8]:
# Project every input row at once: each result has one row per input row.
query = torch.matmul(inputs, W_query)
key = torch.matmul(inputs, W_key)
value = torch.matmul(inputs, W_value)

In [9]:
# Show the query, key and value matrices for the whole input.
print(
    f"query: \n{query}",
    f"\nkey: \n{key}",
    f"\nvalue: \n{value}",
    sep="\n",
)

query: 
tensor([[0.2309, 1.0966],
        [0.4306, 1.4551],
        [0.4300, 1.4343],
        [0.2355, 0.7990],
        [0.2983, 0.6565],
        [0.2568, 1.0533]])

key: 
tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])

value: 
tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])


In [10]:
# Row 1 holds the query and key of the second input row.
query_2, key_2 = query[1], key[1]

# Unnormalised score of the second row's query against its own key.
attn_scores_2 = torch.dot(query_2, key_2)
attn_scores_2

tensor(1.8524)

In [11]:
# Score the second row's query against every key in one matrix product.
attn_scores_2 = torch.matmul(query_2, key.t())
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [12]:
# Pairwise attention scores: entry (i, j) is query i dotted with key j.
attn_scores = torch.matmul(query, key.transpose(0, 1))
attn_scores

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])

In [13]:
# Display the dimensions of the score matrix (one row per query, one column per key).
attn_scores.shape

torch.Size([6, 6])

## **Attention Weights**

In [15]:
# Width of each key vector, used to scale the scores before the softmax.
d_k = key.size(1)

# Divide by sqrt(d_k), then normalise each row so it sums to 1.
attn_weights = (attn_scores / d_k ** 0.5).softmax(dim=-1)
print(f"Attention weights: \n{attn_weights}")

Attention weights: 
tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])


In [18]:
# Each context vector is the attention-weighted sum of the value rows.
context_vector = torch.matmul(attn_weights, value)
print(f"Context Vectors: \n{context_vector}")

Context Vectors: 
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


## **Self Attention Complete**

In [19]:
class SelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention using raw weight matrices.

  The query, key and value projections are plain ``nn.Parameter`` matrices of
  shape ``(d_in, d_out)`` initialised with ``torch.rand``.

  Args:
    d_in: Size of the last dimension of the input.
    d_out: Size of the last dimension of the queries, keys, values and output.
  """

  def __init__(self, d_in, d_out):
    super().__init__()
    # Creation order (query, value, key) is kept deliberately: it decides
    # which random draws end up in which matrix under a fixed seed.
    self.W_Query = nn.Parameter(torch.rand(d_in, d_out))
    self.W_Value = nn.Parameter(torch.rand(d_in, d_out))
    self.W_Key = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    """Return one context vector per input token.

    Args:
      x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
        dimensions, ``(..., num_tokens, d_in)``.

    Returns:
      Tensor of shape ``(..., num_tokens, d_out)``.
    """
    queries = x @ self.W_Query
    keys = x @ self.W_Key
    values = x @ self.W_Value

    # Swap only the last two axes so batched input works; `.T` would reverse
    # every dimension of a tensor with more than two axes.
    attn_scores = queries @ keys.transpose(-2, -1)

    # Scale by the key width (always the last axis), not by the token count
    # that `shape[1]` would pick up on batched input.
    scale = keys.shape[-1] ** 0.5
    attn_weights = torch.softmax(attn_scores / scale, dim=-1)

    context_vec = attn_weights @ values
    return context_vec

In [24]:
# Build the module (3-wide inputs, 2-wide outputs) and run it on all tokens.
# No seed is set here, so the weights come from the current global RNG state.
sa = SelfAttention(d_in=3, d_out=2)
print(sa(inputs))

tensor([[0.7236, 0.8354],
        [0.7308, 0.8426],
        [0.7306, 0.8425],
        [0.7145, 0.8256],
        [0.7170, 0.8303],
        [0.7179, 0.8283]], grad_fn=<MmBackward0>)


In [25]:
class SelfAttention_V2(nn.Module):
  """Single-head scaled dot-product self-attention built on ``nn.Linear``.

  Args:
    d_in: Size of the last dimension of the input.
    d_out: Size of the last dimension of the queries, keys, values and output.
    qkv_bias: Whether the three projection layers carry a bias term.
  """

  def __init__(self, d_in, d_out, qkv_bias = False):
    super().__init__()
    # Creation order (query, value, key) is kept deliberately: it decides
    # which random draws initialise which layer under a fixed seed.
    self.W_Query = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_Value = nn.Linear(d_in, d_out, bias = qkv_bias)
    self.W_Key = nn.Linear(d_in, d_out, bias = qkv_bias)

  def forward(self, x):
    """Return one context vector per input token.

    Args:
      x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
        dimensions, ``(..., num_tokens, d_in)``.

    Returns:
      Tensor of shape ``(..., num_tokens, d_out)``.
    """
    keys = self.W_Key(x)
    queries = self.W_Query(x)
    values = self.W_Value(x)

    # Swap only the last two axes so batched input works; `.T` would reverse
    # every dimension of a tensor with more than two axes.
    attn_scores = queries @ keys.transpose(-2, -1)

    # Scale by the key width (always the last axis), not by the token count
    # that `shape[1]` would pick up on batched input.
    scale = keys.shape[-1] ** 0.5
    attn_weights = torch.softmax(attn_scores / scale, dim = -1)

    context_vec = attn_weights @ values
    return context_vec

In [26]:
# Same demo with the nn.Linear-based variant; `sa` is rebound to the new module.
# No seed is set here, so the weights come from the current global RNG state.
sa = SelfAttention_V2(d_in=3, d_out=2)
print(sa(inputs))

tensor([[0.2371, 0.7407],
        [0.2305, 0.7362],
        [0.2307, 0.7364],
        [0.2286, 0.7345],
        [0.2338, 0.7396],
        [0.2269, 0.7328]], grad_fn=<MmBackward0>)
