# 手写 self-attention

## 1. 什么是 self-attention
formula:
$$ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

In [2]:
# simple version
import math
import torch
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are produced by three independent linear
    projections of the same input.

    Args:
        hidden_dim: Size of the input's last dimension; every projection
            maps ``hidden_dim -> hidden_dim``.
    """

    def __init__(self, hidden_dim: int = 728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        # Separate projections for queries, keys and values.
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Return softmax(Q K^T / sqrt(hidden_dim)) V for input ``x``.

        ``x`` is expected to be (batch_size, seq_len, hidden_dim); the
        result has the same shape.
        """
        queries, keys, values = self.q(x), self.k(x), self.v(x)

        # Pairwise query/key similarities: (batch, seq, seq).
        scores = queries @ keys.transpose(-2, -1)

        # Scale the logits, then normalise each row over the key axis.
        scale = math.sqrt(self.hidden_dim)
        weights = (scores / scale).softmax(dim=-1)

        # Weighted sum of the values: (batch, seq, hidden_dim).
        return weights @ values
    
# Smoke test: run SelfAttentionV1 on a random tensor of shape (3, 2, 4)
# and print both the input and the attention output.
X = torch.rand(3, 2, 4)
print("X:", X)

# hidden_dim=4 matches the last dimension of X.
self_att_net = SelfAttentionV1(4)
print(self_att_net(X))

X: tensor([[[0.7795, 0.2373, 0.7946, 0.9313],
         [0.1784, 0.8063, 0.1943, 0.1121]],

        [[0.7090, 0.2948, 0.8234, 0.1098],
         [0.6749, 0.5183, 0.0602, 0.8583]],

        [[0.6374, 0.1271, 0.4514, 0.6903],
         [0.2616, 0.2978, 0.7061, 0.7400]]])
tensor([[[ 0.2579, -0.4916,  0.3464,  0.2177],
         [ 0.2610, -0.4959,  0.3418,  0.2271]],

        [[ 0.3481, -0.6142,  0.2398,  0.3398],
         [ 0.3484, -0.6145,  0.2405,  0.3396]],

        [[ 0.2389, -0.5029,  0.2769,  0.3219],
         [ 0.2386, -0.5025,  0.2773,  0.3216]]], grad_fn=<UnsafeViewBackward0>)


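这里之所以要除以 $\sqrt{d_k}$，是因为当 $d_k$ 较大时，Q 和 K 的点积的方差会随 $d_k$ 线性增大，数值过大的点积会把 softmax 推入饱和区，使梯度变得极小。除以 $\sqrt{d_k}$ 可以把点积的方差缩放回 1 左右，使点积的数值更稳定，训练时更不容易出现梯度消失或者梯度爆炸的情况。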

## 2. 如果要做效率优化，我们可以怎么做？当网络比较小的时候，又该怎么做？

In [3]:
# efficiency optimize
class SelfAttentionV2(nn.Module):
    """Single-head self-attention with a fused Q/K/V projection.

    One ``hidden_dim -> 3 * hidden_dim`` linear layer replaces three
    separate projections, so Q, K and V come out of a single matmul.

    Args:
        hidden_dim: Size of the input's last dimension and of each of the
            Q, K and V slices.
    """

    def __init__(self, hidden_dim: int = 128):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.proj = nn.Linear(hidden_dim, hidden_dim * 3)

    def forward(self, x):
        """Return softmax(Q K^T / sqrt(hidden_dim)) V for input ``x``.

        ``x`` is expected to be (batch, seq, hidden_dim); the result has
        the same shape.
        """
        # Fused projection: (batch, seq, hidden_dim * 3).
        QKV = self.proj(x)
        # Slice the last dimension into three hidden_dim-wide pieces.
        Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)

        # The scaling must be applied to the logits BEFORE softmax;
        # dividing afterwards would leave each attention row summing to
        # 1 / sqrt(hidden_dim) instead of 1.
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.hidden_dim)
        att_weight = torch.softmax(scores, dim=-1)

        output = att_weight @ V
        return output
    
# Smoke test: run SelfAttentionV2 (the fused-projection variant defined in
# this cell) on a random tensor of shape (3, 2, 4) and print input and output.
X = torch.rand(3, 2, 4)
print("X:", X)

# hidden_dim=4 matches the last dimension of X.
self_att_net = SelfAttentionV2(4)
print(self_att_net(X))

X: tensor([[[0.2239, 0.4982, 0.3198, 0.2065],
         [0.6760, 0.5683, 0.2843, 0.7094]],

        [[0.4296, 0.0365, 0.9676, 0.8618],
         [0.9661, 0.8206, 0.7621, 0.0858]],

        [[0.1024, 0.1856, 0.5597, 0.9592],
         [0.6754, 0.9313, 0.9466, 0.3101]]])
tensor([[[ 0.4516,  0.3538, -0.3691,  0.7849],
         [ 0.4505,  0.3566, -0.3699,  0.7883]],

        [[ 0.6903,  0.1282, -0.0618,  1.0307],
         [ 0.6926,  0.1296, -0.0609,  1.0310]],

        [[ 0.5350,  0.1434, -0.2532,  0.9843],
         [ 0.5435,  0.1446, -0.2481,  0.9878]]], grad_fn=<UnsafeViewBackward0>)


## 3. 加入一些细节

In [4]:
## 1. dropout position
## 2. attention mask
## 3. output 

class SelfAttentionV3(nn.Module):
    """Single-head self-attention with masking, dropout and an output layer.

    Args:
        hidden_dim: Size of the input's last dimension and of each of the
            Q, K and V slices.
        dropout_rate: Drop probability applied to the attention weights
            after the softmax.
    """

    def __init__(self, hidden_dim: int = 128, dropout_rate: float = 0.8):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Fused Q/K/V projection, split apart again in forward().
        self.proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.attention_dropout = nn.Dropout(dropout_rate)

        # Final linear layer applied to the attention result.
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, attention_mask=None):
        """Apply masked self-attention followed by the output projection.

        Args:
            x: Input whose last dimension is ``hidden_dim``.
            attention_mask: Optional tensor usable as a ``masked_fill`` mask
                on the score matrix; positions equal to 0 are filled with
                ``-inf`` before the softmax, so they receive zero weight.
        """
        queries, keys, values = self.proj(x).split(self.hidden_dim, dim=-1)

        # Scaled similarity logits.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.hidden_dim)

        if attention_mask is not None:
            blocked = attention_mask == 0
            scores = scores.masked_fill(blocked, float("-inf"))

        # Normalise over the key axis, then apply dropout to the weights.
        probs = self.attention_dropout(scores.softmax(dim=-1))

        context = torch.matmul(probs, values)
        return self.output_proj(context)
    
# Smoke test setup for SelfAttentionV3: a random (3, 4, 2) input plus a
# per-sequence mask in which 1 marks a kept position and 0 a masked one.
X = torch.rand(3, 4, 2)
mask = torch.tensor(
    [
        [1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0]
    ]
)
# print("mask shape ", mask.shape)
print(mask)
# Expand (3, 4) -> (3, 4, 4): the same key mask is repeated for each of the
# 4 query rows so it lines up with the attention score matrix.
mask = mask.unsqueeze(dim=1).repeat(1, 4, 1)
print(mask)
# print("mask shape ", mask.shape)

print("X:", X)

# hidden_dim=2 matches the last dimension of X.
self_att_net = SelfAttentionV3(2)
# The forward call is left disabled in the original cell.
# print(self_att_net(X, mask))

tensor([[1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0]])
tensor([[[1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0]],

        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0]],

        [[1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0]]])
X: tensor([[[0.5080, 0.5815],
         [0.3296, 0.9273],
         [0.5979, 0.6149],
         [0.5742, 0.7032]],

        [[0.4230, 0.2200],
         [0.3816, 0.2268],
         [0.0430, 0.7821],
         [0.6195, 0.4677]],

        [[0.6937, 0.2079],
         [0.1230, 0.0851],
         [0.1146, 0.8231],
         [0.2967, 0.6506]]])


## 4. Interview oriented

In [28]:
class SelfAttentionV4(nn.Module):
    """Single-head masked self-attention that prints its intermediate tensors.

    Uses separate Q/K/V projections and dropout on the attention weights;
    there is no output projection.

    Args:
        hidden_dim: Size of the input's last dimension and of each projection.
        dropout_rate: Drop probability applied to the attention weights
            after the softmax.
    """

    def __init__(self, hidden_dim: int = 128, dropout_rate: float = 0.8):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Separate projections for queries, keys and values.
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

        self.atten_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Compute masked attention over ``x``, printing shapes and scores.

        Args:
            x: Input whose last dimension is ``hidden_dim``.
            attention_mask: Optional tensor usable as a ``masked_fill`` mask
                on the score matrix; positions equal to 0 are filled with
                ``-inf`` before the softmax.
        """
        queries = self.q(x)
        keys = self.k(x)
        values = self.v(x)

        print("Q shape:", queries.shape)

        # Scaled similarity logits; printed before and after masking.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.hidden_dim)
        print("attention weight shape:", scores.shape)
        print(scores)
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float("-inf"))
        print("\r\n", scores)

        # Normalise over the key axis, then apply dropout to the weights.
        probs = self.atten_dropout(scores.softmax(dim=-1))

        # Weighted sum of the values.
        return torch.matmul(probs, values)

# Smoke test: run SelfAttentionV4 on a random (3, 4, 2) input with a
# per-sequence mask in which 1 marks a kept position and 0 a masked one.
X = torch.rand(3, 4, 2)
print("X size shape", X.shape)
mask = torch.tensor(
    [
        [1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0]
    ]
)
print("mask shape", mask.shape)
# Expand (3, 4) -> (3, 4, 4): the same key mask is repeated for each of the
# 4 query rows so it lines up with the attention score matrix.
mask = mask.unsqueeze(dim=1).repeat(1, 4, 1)
print("mask shape", mask.shape)

# hidden_dim=2 matches the last dimension of X; forward() prints its
# intermediate tensors itself, so the result is not printed here.
self_att_net = SelfAttentionV4(2)
result = self_att_net(X, mask)


X size shape torch.Size([3, 4, 2])
mask shape torch.Size([3, 4])
mask shape torch.Size([3, 4, 4])
Q shape: torch.Size([3, 4, 2])
attention weight shape: torch.Size([3, 4, 4])
tensor([[[-0.5785, -0.1090, -0.3257, -0.2762],
         [-0.4806, -0.0688, -0.3072, -0.2650],
         [-0.4174, -0.0869, -0.2212, -0.1858],
         [-0.4043, -0.0826, -0.2168, -0.1825]],

        [[-0.0487, -0.1576,  0.0632, -0.3716],
         [-0.0673, -0.2039,  0.0792, -0.4687],
         [-0.1082, -0.2105,  0.0579, -0.3725],
         [-0.1546, -0.3601,  0.1178, -0.7248]],

        [[-0.4661,  0.0917, -0.3320, -0.0616],
         [-0.2817,  0.0463, -0.1862, -0.0323],
         [-0.4904,  0.0822, -0.3266, -0.0571],
         [-0.3798,  0.0598, -0.2467, -0.0421]]], grad_fn=<DivBackward0>)

 tensor([[[-0.5785, -0.1090, -0.3257,    -inf],
         [-0.4806, -0.0688, -0.3072,    -inf],
         [-0.4174, -0.0869, -0.2212,    -inf],
         [-0.4043, -0.0826, -0.2168,    -inf]],

        [[-0.0487, -0.1576,    -inf,   