# 手写 Self-Attention

## 1. 什么是 Self-Attention
formula:
$$ Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V $$

In [6]:
# simple version
import math
import torch
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    """Single-head self-attention with three separate Q/K/V projections.

    Computes ``softmax(Q K^T / sqrt(hidden_dim)) V`` where Q, K and V are
    independent linear projections of the same input.

    Args:
        hidden_dim: feature size of the input and of every projection.
            NOTE(review): 728 is an unusual default (768 is the common
            transformer width) and may be a typo; kept for compatibility.
    """

    def __init__(self, hidden_dim: int = 728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        # One projection per role; q, k, v are created in this order so that
        # parameter initialisation consumes the RNG in a fixed order.
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Attend over the second-to-last (sequence) dimension of ``x``.

        Args:
            x: tensor whose last dimension is ``hidden_dim``; the test cell
               feeds (batch_size, seq_len, hidden_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        query, key, value = self.q(x), self.k(x), self.v(x)

        # scores[..., i, j] = similarity of query position i with key position j,
        # scaled by sqrt(d_k) before the softmax.
        scores = query @ key.transpose(-1, -2) / math.sqrt(self.hidden_dim)
        probs = scores.softmax(dim=-1)

        return probs @ value
    
# test: a batch of 3 sequences, 2 tokens each, hidden size 4
X = torch.rand(3, 2, 4)
print("X:", X)

self_att_net = SelfAttentionV1(hidden_dim=4)
print(self_att_net(X))

X: tensor([[[0.4821, 0.5496, 0.2873, 0.6103],
         [0.7172, 0.1542, 0.7106, 0.2280]],

        [[0.4413, 0.1183, 0.5076, 0.6402],
         [0.5094, 0.3109, 0.7545, 0.1079]],

        [[0.6474, 0.8568, 0.4717, 0.9785],
         [0.0347, 0.8786, 0.8726, 0.8526]]])
tensor([[[-0.4421,  0.1213, -0.3098,  0.0249],
         [-0.4422,  0.1214, -0.3098,  0.0249]],

        [[-0.3942, -0.0236, -0.1933,  0.0258],
         [-0.3940, -0.0229, -0.1918,  0.0274]],

        [[-0.7986,  0.2971,  0.0608, -0.0584],
         [-0.7986,  0.2971,  0.0608, -0.0584]]], grad_fn=<UnsafeViewBackward0>)


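这里之所以要除以 $\sqrt{d_k}$：假设 Q、K 的各分量近似独立、均值为 0、方差为 1，则点积 $q \cdot k$ 的方差为 $d_k$，会随维度增大而变大。数值过大的 logits 会把 softmax 推入饱和区（输出接近 one-hot），此时梯度趋近于 0。除以 $\sqrt{d_k}$ 可以把点积的方差拉回到 1 左右，使 Q 和 K 的点积更稳定，训练时不容易出现梯度消失（以及由此带来的数值不稳定）。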

## 2. 如果我们要优化效率，可以怎么做？当网络比较小的时候，我们该怎么做？（例如把 Q、K、V 三个投影合并成一个线性层，用一次矩阵乘法算出来再切分）

In [7]:
# efficiency optimize
class SelfAttentionV2(nn.Module):
    """Single-head self-attention with one fused Q/K/V projection.

    A single ``nn.Linear(hidden_dim, 3 * hidden_dim)`` replaces three
    separate projections, so Q, K and V come out of one matrix multiply and
    are then split along the last dimension.

    Args:
        hidden_dim: size of the input features and of each of Q, K and V.
    """

    def __init__(self, hidden_dim: int = 128):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Output of proj is [Q | K | V] concatenated on the last dimension.
        self.proj = nn.Linear(hidden_dim, hidden_dim * 3)

    def forward(self, x):
        """Return ``softmax(Q K^T / sqrt(hidden_dim)) V`` for input ``x``.

        Args:
            x: tensor whose last dimension is ``hidden_dim``; the test cell
               feeds (batch, seq, hidden_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # (batch, seq, hidden_dim * 3) -> three (batch, seq, hidden_dim)
        QKV = self.proj(x)
        Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)
        # Scale the logits *before* the softmax. Dividing the softmax output
        # instead would leave the logits untempered and make each row of
        # weights sum to 1/sqrt(d) rather than 1.
        att_weight = torch.softmax(
            Q @ K.transpose(-1, -2) / math.sqrt(self.hidden_dim),
            dim=-1,
        )
        return att_weight @ V
    
# test
X = torch.rand(3, 2, 4)
print("X:", X)

# Instantiate the class this cell defines (SelfAttentionV2), not V1.
self_att_net = SelfAttentionV2(4)
print(self_att_net(X))

X: tensor([[[0.1162, 0.8403, 0.3770, 0.5865],
         [0.0939, 0.2697, 0.4536, 0.6432]],

        [[0.8119, 0.2099, 0.7808, 0.2939],
         [0.0445, 0.3325, 0.1797, 0.2189]],

        [[0.1676, 0.0918, 0.4949, 0.8680],
         [0.7305, 0.5256, 0.3589, 0.2128]]])
tensor([[[-0.2721,  0.2719, -0.0715, -0.2825],
         [-0.2718,  0.2717, -0.0714, -0.2830]],

        [[-0.1489,  0.1622, -0.1168, -0.1334],
         [-0.1489,  0.1622, -0.1165, -0.1342]],

        [[-0.2043,  0.2489, -0.1393, -0.0328],
         [-0.2276,  0.2667, -0.1377, -0.0190]]], grad_fn=<UnsafeViewBackward0>)


## 3. 加入一些细节

In [14]:
## 1. dropout position
## 2. attention mask
## 3. output 

class SelfAttentionV3(nn.Module):
    """Self-attention with fused QKV, optional mask, dropout and output projection.

    Args:
        hidden_dim: size of the input features and of each of Q, K and V.
        dropout_rate: probability passed to ``nn.Dropout`` on the attention
            probabilities (active only in training mode).
            NOTE(review): 0.8 zeroes ~80% of the weights during training;
            values around 0.1 are more typical. Kept for compatibility.
    """

    def __init__(self, hidden_dim: int = 128, dropout_rate: float = 0.8):
        super().__init__()
        self.hidden_dim = hidden_dim

        # proj is created before output_proj so parameter init order is fixed.
        self.proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.attention_dropout = nn.Dropout(dropout_rate)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, attention_mask=None):
        """Attend over the sequence dimension, then apply ``output_proj``.

        Args:
            x: tensor whose last dimension is ``hidden_dim``; the test cell
               feeds (batch, seq, hidden_dim).
            attention_mask: optional tensor broadcastable to the
               (..., seq, seq) score matrix; entries equal to 0 are set to
               -inf before the softmax. A row masked everywhere would give
               NaN weights.

        Returns:
            Tensor with the same shape as ``x``.
        """
        query, key, value = self.proj(x).split(self.hidden_dim, dim=-1)

        logits = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.hidden_dim)
        if attention_mask is not None:
            logits = logits.masked_fill(attention_mask == 0, float("-inf"))

        # Dropout is applied to the normalised probabilities, after softmax.
        probs = self.attention_dropout(logits.softmax(dim=-1))

        return self.output_proj(probs @ value)
    
# test
X = torch.rand(3, 4, 2)

# Per-sequence key mask: 1 = attend to this position, 0 = masked out.
mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
])
print("mask shape ", mask.shape)

# (batch, seq) -> (batch, seq, seq): every query row shares the same key mask.
mask = mask[:, None, :].repeat(1, 4, 1)
print("mask shape ", mask.shape)

print("X:", X)

self_att_net = SelfAttentionV3(2)
print(self_att_net(X, mask))

mask shape  torch.Size([3, 4])
mask shape  torch.Size([3, 4, 4])
X: tensor([[[0.8306, 0.1741],
         [0.2105, 0.7246],
         [0.8159, 0.0152],
         [0.5970, 0.1283]],

        [[0.4402, 0.6248],
         [0.8392, 0.3145],
         [0.0983, 0.4623],
         [0.8232, 0.2566]],

        [[0.9905, 0.2713],
         [0.9090, 0.8953],
         [0.7381, 0.6693],
         [0.4292, 0.7702]]])
tensor([[[ 0.5422, -0.0870],
         [-0.3529,  1.8475],
         [ 0.6292, -0.2785],
         [ 0.1152,  0.8339]],

        [[ 0.6292, -0.2785],
         [ 0.3301,  0.3729],
         [ 0.6292, -0.2785],
         [-0.0797,  1.2573]],

        [[-0.8814,  2.9921],
         [ 0.6292, -0.2785],
         [ 0.6292, -0.2785],
         [-0.8814,  2.9921]]], grad_fn=<ViewBackward0>)


## 4. Interview oriented

In [15]:
class SelfAttentionV4(nn.Module):
    """Interview-style single-head self-attention with optional masking.

    Separate Q/K/V projections, scaled dot-product scores, an optional mask
    and dropout on the attention probabilities. There is no output projection.

    Args:
        hidden_dim: size of the input features and of each projection.
        dropout_rate: probability passed to ``nn.Dropout`` on the attention
            probabilities (active only in training mode).
            NOTE(review): 0.8 is very high; it is why most rows in the
            training-mode demo output are all zeros. Kept for compatibility.
    """

    def __init__(self, hidden_dim: int = 128, dropout_rate: float = 0.8):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Created in q, k, v order so parameter init order is fixed.
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

        self.atten_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Return dropout(softmax(masked Q K^T / sqrt(hidden_dim))) @ V.

        Args:
            x: tensor whose last dimension is ``hidden_dim``; the test cell
               feeds (batch, seq, hidden_dim).
            attention_mask: optional tensor broadcastable to the
               (..., seq, seq) score matrix; entries equal to 0 are set to
               -inf before the softmax.

        Returns:
            Tensor with the same shape as ``x``.
        """
        query = self.q(x)
        key = self.k(x)
        value = self.v(x)

        # (batch, seq, seq) score matrix, scaled before softmax
        scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.hidden_dim)
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float("-inf"))

        weights = self.atten_dropout(torch.softmax(scores, dim=-1))

        # (batch, seq, dim)
        return weights @ value

# test
X = torch.rand(3, 4, 2)

# 1 = visible key position, 0 = masked-out position, one row per sequence.
mask = torch.tensor(
    [[1, 1, 1, 0],
     [1, 1, 0, 0],
     [1, 0, 0, 0]]
)
print("mask shape ", mask.shape)

# Repeat the key mask for every query position: (batch, seq) -> (batch, seq, seq).
mask = mask.unsqueeze(1).repeat(1, mask.size(1), 1)
print("mask shape ", mask.shape)

print("X:", X)

self_att_net = SelfAttentionV4(2)
print(self_att_net(X, mask))


mask shape  torch.Size([3, 4])
mask shape  torch.Size([3, 4, 4])
X: tensor([[[0.4089, 0.7366],
         [0.0619, 0.2798],
         [0.7809, 0.8408],
         [0.1825, 0.4636]],

        [[0.9935, 0.9957],
         [0.9282, 0.5119],
         [0.9985, 0.7510],
         [0.0298, 0.7328]],

        [[0.2781, 0.8414],
         [0.0555, 0.3954],
         [0.5185, 0.9357],
         [0.6754, 0.0191]]])
tensor([[[-0.4380, -0.3482],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 1.3571, -0.6325]],

        [[ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]]], grad_fn=<UnsafeViewBackward0>)
