# Coding Attention Mechanims

## Self Attention: attending to different parts of the input

### Simple self-attention mechanism without trainable weights

In [11]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

In [None]:
query = inputs[1]

attention_score_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attention_score_2[i] = torch.dot(x_i, query)

attention_score_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [18]:
# Attention Weights
attention_weights_2 = torch.softmax(attention_score_2, dim=0)
attention_weights_2

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [20]:
query = inputs[1]

context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2+=attention_weights_2[i]*inputs[i]


print(context_vec_2)


tensor([0.4419, 0.6515, 0.5683])


In [21]:
inputs.shape

torch.Size([6, 3])

### Full Attention Score Matric

In [22]:
atten_scores = torch.empty(inputs.shape[0], inputs.shape[0])

for i , x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        atten_scores[i,j] = torch.dot(x_i,x_j)

atten_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

### Use Matrix Multiplication instead of double for loop

In [23]:
atten_scores = inputs @ inputs.T
atten_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [26]:
attention_weights = torch.softmax(atten_scores, dim=1)
attention_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [32]:
all_context = attention_weights @ inputs
all_context

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

### Implementing Selt-attention with trainable weights

#### focusing only on the second input

In [33]:
x_2 = inputs[1]
d_in = inputs.shape[1]
# output embedding size
d_out = 2


In [35]:
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
W_value = torch.nn.Parameter(torch.rand(d_in, d_out))



In [37]:
query_2 = x_2 @ W_query
query_2

tensor([0.4306, 1.4551], grad_fn=<SqueezeBackward4>)

In [38]:
keys  =  inputs @ W_key
values  =  inputs @ W_value


In [41]:
keys

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]], grad_fn=<MmBackward0>)

In [42]:
atten_scores_2 = query_2 @ keys.T
atten_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
       grad_fn=<SqueezeBackward4>)

In [43]:
d_k  = keys.shape[1]

atten_weights_2 = torch.softmax(atten_scores_2 / d_k**0.5 , dim=-1)
atten_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
       grad_fn=<SoftmaxBackward0>)

In [44]:
context_vec_2 = atten_weights_2 @ values
context_vec_2

tensor([0.3061, 0.8210], grad_fn=<SqueezeBackward4>)

#### Creating a compact class for self attention

In [None]:
import torch.nn as nn

class SelfAttention_V1(nn.Module):
    
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out))
        
    
    def forward(self, x):
        queries = x @ W_query
        keys = x @ W_key
        values = x @ W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[1] **0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec
        

torch.manual_seed(123)
sa_v1 = SelfAttention_V1(d_in=d_in, d_out=d_out)       
sa_v1(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

In [46]:
class SelfAttention_V2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = torch.nn.Linear(d_in, d_out, qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out,qkv_bias)
        
    
    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[1] **0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)
sa_v2 = SelfAttention_V2(d_in=d_in, d_out=d_out)       
sa_v2(inputs)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

#### Mask Future Tokens

In [47]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
values = sa_v2.W_value(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[1] **0.5, dim=-1)

In [50]:
context_length = attn_scores.shape[0]
context_length

6

In [52]:
simple_mask = torch.tril(torch.ones(context_length,context_length))
simple_mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [53]:
masked_simple = attn_scores * simple_mask
masked_simple

tensor([[0.3111, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1655, 0.2602, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.2602, 0.2577, 0.0000, 0.0000, 0.0000],
        [0.0510, 0.1080, 0.1064, 0.0643, 0.0000, 0.0000],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121, 0.0000],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MulBackward0>)