In [1]:
import torch

In [2]:
# Six 3-element vectors, one row per word of the labelled sentence.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # your
        [0.55, 0.89, 0.66],  # journey
        [0.57, 0.85, 0.64],  # starts
        [0.22, 0.58, 0.33],  # with
        [0.77, 0.25, 0.10],  # one
        [0.05, 0.80, 0.55],  # step
    ]
)

In [4]:
# Input width is read from the data (columns of `inputs`); output width is chosen by hand.
d_in = inputs.size(1)
d_out = 2

In [9]:
# Seed the global RNG so the three random weight matrices below are reproducible.
torch.manual_seed(123)
# Each projection is a trainable (d_in, d_out) matrix filled by torch.rand.
# The creation order (query, key, value) decides which random draws each matrix
# receives, so reordering these lines changes every downstream output.
W_query=torch.nn.Parameter(torch.rand(d_in,d_out))
W_key=torch.nn.Parameter(torch.rand(d_in,d_out))
W_value=torch.nn.Parameter(torch.rand(d_in,d_out))

In [10]:
# Inspect the query projection weights.
W_query

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True)

In [12]:
# Project every input row into query, key and value space.
query_matrix = torch.matmul(inputs, W_query)
key_matrix = torch.matmul(inputs, W_key)
value_matrix = torch.matmul(inputs, W_value)

In [13]:
# Inspect the projected queries (one row per row of `inputs`).
query_matrix

tensor([[0.2309, 1.0966],
        [0.4357, 1.4688],
        [0.4300, 1.4343],
        [0.2355, 0.7990],
        [0.2983, 0.6565],
        [0.2568, 1.0533]], grad_fn=<MmBackward0>)

In [15]:
query_2 = query_matrix[1]  # query vector of the second input row (index 1)

In [16]:
# Dot product of the second query with every key: one unnormalised score per input row.
attn_score_2 = torch.matmul(query_2, key_matrix.T)

In [18]:
# Print the scores of the second query against every key.
print(attn_score_2)

tensor([1.2829, 1.8933, 1.8287, 1.0900, 0.5632, 1.5589],
       grad_fn=<SqueezeBackward4>)


In [19]:
# All pairwise query-key dot products at once; row i holds the scores for query i.
attn_scores = torch.matmul(query_matrix, key_matrix.T)

In [20]:
# Inspect the full matrix of pairwise query-key scores.
attn_scores

tensor([[0.9231, 1.3713, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2829, 1.8933, 1.8287, 1.0900, 0.5632, 1.5589],
        [1.2544, 1.8508, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0292, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8926, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3328, 1.2871, 0.7682, 0.3937, 1.0996]],
       grad_fn=<MmBackward0>)

In [22]:
# Divide the scores by sqrt(key width), then normalise each row with softmax.
d_key = key_matrix.size(-1)
attn_weights = (attn_scores / d_key ** 0.5).softmax(dim=-1)

In [23]:
# Inspect the attention weights (softmax was taken over the last dimension).
attn_weights

tensor([[0.1547, 0.2124, 0.2054, 0.1409, 0.1071, 0.1794],
        [0.1493, 0.2298, 0.2196, 0.1302, 0.0897, 0.1814],
        [0.1498, 0.2284, 0.2184, 0.1311, 0.0911, 0.1812],
        [0.1588, 0.2008, 0.1959, 0.1475, 0.1204, 0.1766],
        [0.1608, 0.1961, 0.1920, 0.1499, 0.1263, 0.1749],
        [0.1554, 0.2111, 0.2043, 0.1416, 0.1086, 0.1790]],
       grad_fn=<SoftmaxBackward0>)

In [24]:
# Each context vector is the attention-weighted sum of the value rows.
context_vectors = torch.matmul(attn_weights, value_matrix)

In [25]:
# Inspect the resulting context vectors.
context_vectors

tensor([[0.3012, 0.8075],
        [0.3081, 0.8241],
        [0.3075, 0.8228],
        [0.2962, 0.7959],
        [0.2941, 0.7910],
        [0.3006, 0.8062]], grad_fn=<MmBackward0>)

# Organising the self-attention computation into classes

In [31]:
import torch.nn as nn
class Self_Attention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are stored as ``nn.Parameter``
    tensors of shape ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the projections and of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order (query, key, value) fixes which random draws each
        # matrix receives for a given seed; keep it stable for reproducibility.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
                dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        query_matrix = x @ self.W_query
        key_matrix = x @ self.W_key
        value_matrix = x @ self.W_value

        # Swap only the last two dims: `.T` reverses *all* dims, which is
        # wrong (and deprecated) for inputs with a leading batch dimension.
        attn_scores = query_matrix @ key_matrix.transpose(-2, -1)

        # Scale by sqrt(d_out) before the softmax over the key axis.
        scale = key_matrix.shape[-1] ** 0.5
        attn_weights = torch.softmax(attn_scores / scale, dim=-1)

        return attn_weights @ value_matrix
        
        

In [32]:
torch.manual_seed(123)
# Build the module from the seeded RNG state set just above, then run it on the inputs.
sa = Self_Attention_v1(d_in=d_in, d_out=d_out)
sa(inputs)

tensor([[0.3012, 0.8075],
        [0.3081, 0.8241],
        [0.3075, 0.8228],
        [0.2962, 0.7959],
        [0.2941, 0.7910],
        [0.3006, 0.8062]], grad_fn=<MmBackward0>)

In [40]:
import torch.nn as nn
class Self_Attention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the projections and of the output.
        q_bias: Whether the projection layers carry a bias term. Despite the
            name, the flag is applied to the query, key AND value layers.
    """

    def __init__(self, d_in, d_out, q_bias=False):
        super().__init__()
        # Creation order (query, key, value) fixes which random draws each
        # layer receives for a given seed; keep it stable for reproducibility.
        self.W_query = nn.Linear(d_in, d_out, bias=q_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=q_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=q_bias)

    def forward(self, x):
        """Compute context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
                dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        query_matrix = self.W_query(x)
        key_matrix = self.W_key(x)
        value_matrix = self.W_value(x)

        # Swap only the last two dims: `.T` reverses *all* dims, which is
        # wrong (and deprecated) for inputs with a leading batch dimension.
        attn_scores = query_matrix @ key_matrix.transpose(-2, -1)

        # Scale by sqrt(d_out) before the softmax over the key axis.
        scale = key_matrix.shape[-1] ** 0.5
        attn_weights = torch.softmax(attn_scores / scale, dim=-1)

        return attn_weights @ value_matrix
        

In [41]:
# Seed before construction: nn.Linear draws its initial weights from the global RNG.
torch.manual_seed(789)
sa = Self_Attention_v2(d_in=d_in, d_out=d_out)
sa(inputs)

tensor([[-0.0743,  0.0710],
        [-0.0753,  0.0700],
        [-0.0753,  0.0700],
        [-0.0765,  0.0682],
        [-0.0768,  0.0677],
        [-0.0759,  0.0690]], grad_fn=<MmBackward0>)