# Single-Head Self-Attention in Transformers

In [3]:
import torch
from torch import nn

In [4]:
class selfAttention(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are stored as plain ``nn.Parameter``
    matrices of shape ``(d_in, d_out)`` and applied with ``x @ W``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # nn.Parameter already defaults to requires_grad=True; it is spelled
        # out once here for illustration. Weights are drawn from U[0, 1).
        self.query = nn.Parameter(torch.rand(d_in, d_out), requires_grad=True)
        self.key = nn.Parameter(torch.rand(d_in, d_out))
        self.value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute the context vectors for ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, d_in)``; any number of leading
                batch dimensions is accepted.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.
        """
        x_q = x @ self.query
        x_k = x @ self.key
        x_v = x @ self.value

        # Swap only the last two axes: ``.T`` reverses *all* axes and is
        # therefore wrong (and deprecated) for batched, 3-D+ input.
        att_score = x_q @ x_k.transpose(-2, -1)

        # Scale by sqrt(d_k) so the softmax does not saturate as d_out grows.
        norm_factor = x_k.shape[-1] ** 0.5
        att_weights = torch.softmax(att_score / norm_factor, dim=-1)

        # Each output row is the attention-weighted sum of the value rows.
        context = att_weights @ x_v

        return context
        
        
        

In [8]:
# Quick smoke test: 6 tokens with 3 features each -> 6 context vectors of size 3.
sa = selfAttention(3, 3)
# Call the module itself rather than .forward() so nn.Module hooks are run.
sa(torch.randn(6, 3))

tensor([[ 0.5193,  0.4446,  0.0996],
        [-0.1750,  0.1582, -0.3943],
        [ 0.0792,  0.2550, -0.2068],
        [-0.2201,  0.1439, -0.4212],
        [-0.8927, -0.0247, -0.9990],
        [-0.7117,  0.0028, -0.8256]], grad_fn=<MmBackward0>)

## Improving by using nn.Linear layers instead of nn.Parameter

In [19]:
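# nn.Linear performs the same projection (a matrix multiplication) but comes with a
# scaled, zero-centred weight initialization and an optional bias, which usually trains more stably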

class selfAttentionImproved(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        qkv_biased: If True, the three projection layers learn a bias term.
    """

    def __init__(self, d_in, d_out, qkv_biased=False):
        super().__init__()
        # Attribute names are kept as-is because they are the state_dict keys.
        self.liner_query = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.liner_key = nn.Linear(d_in, d_out, bias=qkv_biased)
        self.liner_value = nn.Linear(d_in, d_out, bias=qkv_biased)

    def forward(self, x):
        """Compute the context vectors for ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, d_in)``; any number of leading
                batch dimensions is accepted.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.
        """
        x_q = self.liner_query(x)
        x_k = self.liner_key(x)
        x_v = self.liner_value(x)

        # Swap only the last two axes: ``.T`` reverses *all* axes and is
        # therefore wrong (and deprecated) for batched, 3-D+ input.
        att_score = x_q @ x_k.transpose(-2, -1)

        # Scale by sqrt(d_k) so the softmax does not saturate as d_out grows.
        norm_factor = x_k.shape[-1] ** 0.5
        att_weights = torch.softmax(att_score / norm_factor, dim=-1)

        # Each output row is the attention-weighted sum of the value rows.
        context = att_weights @ x_v

        return context
        
        
        

In [20]:
# Quick smoke test: 6 tokens with 3 features each -> 6 context vectors of size 3.
sai = selfAttentionImproved(3, 3)
# Call the module itself rather than .forward() so nn.Module hooks are run.
sai(torch.randn(6, 3))

tensor([[-0.0044, -0.0584,  0.0811],
        [ 0.0716, -0.0795,  0.0767],
        [ 0.3407, -0.1604,  0.0786],
        [-0.1485, -0.0115,  0.0736],
        [ 0.1258, -0.0614,  0.0208],
        [-0.0783, -0.0678,  0.1265]], grad_fn=<MmBackward0>)

In [None]:
# NOTE:
    # All weights are randomly initialized, not trained,
    # so the results (attention scores) might not make sense.