### 公式：
$$
Attention(Q, K, V) = Softmax(\frac{QK^T}{\sqrt{d_k}})V
$$

### 第一重境界：简化版本

In [6]:
import math
import torch
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    def __init__(self, hidden_dim: int =728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)    
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
    
    def forward(self, x):
        # x: [batch_size, seq_len, hidden_dim]
        Q = self.query_proj(x)
        K = self.key_proj(x)
        V = self.value_proj(x)
        # Q K V: [batch_size, seq_len, hidden_dim]

        attention_value = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(self.hidden_dim)
        # attention_value: [batch_size, seq_len, seq_len]
        attention_weight = torch.softmax(attention_value, dim=-1)
        print(attention_weight)
        # attention_weight: [batch_size, seq_len, seq_len]
        output = torch.matmul(attention_weight, V)
        # output: [batch_size, seq_len, hidden_dim]

        return output

In [7]:
X = torch.randn(3, 2, 4)
print(X)

self_att_net = SelfAttentionV1(4)
self_att_net(X)

In [12]:
# 探究softmax的指定dim效果
X = torch.randn(3, 2, 4)
print(X)
Y = torch.softmax(X, dim=1)
print(Y)

### 第二重境界：效率优化
只做一次proj

In [14]:
import math
import torch
import torch.nn as nn

class SelfAttentionV2(nn.Module):
    def __init__(self, hidden_dim: int =728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        self.proj = nn.Linear(hidden_dim, hidden_dim * 3)
    
    def forward(self, x):
        # x: [batch_size, seq_len, hidden_dim]
        QKV = self.proj(x)
        # QKV: [batch_size, seq_len, hidden_dim * 3]
        Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)

        attention_value = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(self.hidden_dim)
        # attention_value: [batch_size, seq_len, seq_len]
        attention_weight = torch.softmax(attention_value, dim=-1)
        # attention_weight: [batch_size, seq_len, seq_len]
        output = attention_weight @ V  # @ 等于torch.matmul()
        # output: [batch_size, seq_len, hidden_dim]

        return output

In [15]:
X = torch.randn(3, 2, 4)
print(X)

self_att_net = SelfAttentionV2(4)
self_att_net(X)

### 第三重境界：加入一些细节
* dropout 的位置
* attention_mask：输入的seq_len不可能都一样
* MHA 中还有一个output_proj映射

In [1]:
import math
import torch
import torch.nn as nn

class SelfAttentionV3(nn.Module):
    def __init__(self, hidden_dim: int =728, dropout_rate=0.1) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)    
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attention_dropout = nn.Dropout(dropout_rate)

        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
    
    def forward(self, x, attention_mask=None):
        # x: [batch_size, seq_len, hidden_dim]
        Q = self.query_proj(x)
        K = self.key_proj(x)
        V = self.value_proj(x)
        # Q K V: [batch_size, seq_len, hidden_dim]

        attention_value = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(self.hidden_dim)
        # attention_value: [batch_size, seq_len, seq_len]

        if attention_mask is not None:
            # 根据给定的条件（attention_mask == 0）对attention_value进行修改，将满足条件的元素替换为指定的值
            attention_value = attention_value.masked_fill(attention_mask == 0, -1e9)

        attention_weight = torch.softmax(attention_value, dim=-1)
        print("attention_weight:", attention_weight)

        attention_weight = self.attention_dropout(attention_weight)  # dropout在这里！！！将有些词直接drop掉
        print("attention_weight:", attention_weight)
        # attention_weight: [batch_size, seq_len, seq_len]

        attention_result = torch.matmul(attention_weight, V)
        # attention_result: [batch_size, seq_len, hidden_dim]

        output = self.output_proj(attention_result)

        return output 

In [2]:
x = torch.randn(3, 4, 2)
# x : (batch_size, seq_len, hidden_dim)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
# mask : (batch_size, seq_len)
# 要和attention_weight的shape保持一致：(batch_size, seq_len, seq_len)
# 扩维
mask = mask.unsqueeze(dim=1).repeat(1, x.size(1), 1)
# unsqueeze在指定的维度（dim=1）上增加一个新的维度，变成(batch_size, 1, seq_len)
# repeat沿着指定的维度重复张量，变成(batch_size, seq_len, seq_len)
print(mask)

tensor([[[1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0]],

        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0]],

        [[1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0]]])


In [3]:
# 深入理解unsqueeze和repeat
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
print(mask, mask.size())
mask = mask.unsqueeze(dim=1)
print(mask, mask.size())
mask = mask.repeat(1, 1, x.size(1))
print(mask, mask.size())
mask = mask.reshape(x.size(0), 1, x.size(1), x.size(1))
print(mask, mask.size())

tensor([[1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0]]) torch.Size([3, 4])
tensor([[[1, 1, 1, 0]],

        [[1, 1, 0, 0]],

        [[1, 0, 0, 0]]]) torch.Size([3, 1, 4])
tensor([[[1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0]],

        [[1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0]],

        [[1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]]]) torch.Size([3, 1, 16])
tensor([[[[1, 1, 1, 0],
          [1, 1, 1, 0],
          [1, 1, 1, 0],
          [1, 1, 1, 0]]],


        [[[1, 1, 0, 0],
          [1, 1, 0, 0],
          [1, 1, 0, 0],
          [1, 1, 0, 0]]],


        [[[1, 0, 0, 0],
          [1, 0, 0, 0],
          [1, 0, 0, 0],
          [1, 0, 0, 0]]]]) torch.Size([3, 1, 4, 4])


In [6]:
net = SelfAttentionV3(2)
net(x, mask)

attention_weight: tensor([[[[0.4364, 0.4057, 0.1580, 0.0000],
          [0.3586, 0.3635, 0.2778, 0.0000],
          [0.3707, 0.3408, 0.2885, 0.0000],
          [0.3331, 0.3058, 0.3611, 0.0000]],

         [[0.3782, 0.4063, 0.2155, 0.0000],
          [0.2360, 0.2727, 0.4914, 0.0000],
          [0.2831, 0.2639, 0.4530, 0.0000],
          [0.4154, 0.4344, 0.1501, 0.0000]],

         [[0.2569, 0.4696, 0.2735, 0.0000],
          [0.1911, 0.5683, 0.2407, 0.0000],
          [0.3203, 0.3579, 0.3218, 0.0000],
          [0.2732, 0.4304, 0.2963, 0.0000]]],


        [[[0.5182, 0.4818, 0.0000, 0.0000],
          [0.4966, 0.5034, 0.0000, 0.0000],
          [0.5211, 0.4789, 0.0000, 0.0000],
          [0.5214, 0.4786, 0.0000, 0.0000]],

         [[0.4821, 0.5179, 0.0000, 0.0000],
          [0.4639, 0.5361, 0.0000, 0.0000],
          [0.5176, 0.4824, 0.0000, 0.0000],
          [0.4888, 0.5112, 0.0000, 0.0000]],

         [[0.3536, 0.6464, 0.0000, 0.0000],
          [0.2516, 0.7484, 0.0000, 0.0000],
  

tensor([[[[-0.0549,  1.2250],
          [-0.1789,  1.1136],
          [-0.0378,  1.2962],
          [-0.2227,  0.9227]],

         [[-0.0359,  1.3616],
          [-0.3655,  0.7591],
          [-0.2266,  1.1342],
          [-0.0459,  1.3197]],

         [[-0.0496,  1.2516],
          [-0.2866,  0.7378],
          [-0.0310,  1.3288],
          [-0.0432,  1.2785]]],


        [[[-0.0756,  1.1383],
          [-0.0749,  1.1425],
          [-0.0757,  1.1377],
          [-0.0757,  1.1377]],

         [[-0.0685,  1.2278],
          [-0.0678,  1.2319],
          [-0.2751,  0.9257],
          [-0.0687,  1.2263]],

         [[-0.0759,  1.1355],
          [-0.0937,  1.0636],
          [-0.0553,  1.2192],
          [-0.0699,  1.1600]]],


        [[[-0.0916,  1.0449],
          [-0.0916,  1.0449],
          [-0.0916,  1.0449],
          [-0.0916,  1.0449]],

         [[-0.4829,  0.5425],
          [-0.0858,  1.1106],
          [-0.4829,  0.5425],
          [-0.0858,  1.1106]],

         [[ 0.0366, 

### 第四重境界：面试写法
没啥区别啊，up主又在瞎讲

In [38]:
import math
import torch
import torch.nn as nn

class SelfAttentionV4(nn.Module):
    def __init__(self, hidden_dim: int=728, dropout_rate: float=0.1) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)    
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attention_dropout = nn.Dropout(dropout_rate)

    
    def forward(self, x, attention_mask=None):
        # x: [batch_size, seq_len, hidden_dim]
        Q = self.query_proj(x)
        K = self.key_proj(x)
        V = self.value_proj(x)
        # Q K V: [batch_size, seq_len, hidden_dim]

        attention_value = torch.matmul(Q, K.transpose(1, 2)) / math.sqrt(self.hidden_dim)
        # attention_value: [batch_size, seq_len, seq_len]

        if attention_mask is not None:
            # 根据给定的条件（attention_mask == 0）对attention_value进行修改，将满足条件的元素替换为指定的值
            attention_value = attention_value.masked_fill(attention_mask == 0, -1e9)

        attention_weight = torch.softmax(attention_value, dim=-1)
        attention_weight = self.attention_dropout(attention_weight)  # dropout在这里！！！将有些词直接drop掉
        # attention_weight: [batch_size, seq_len, seq_len]

        output = torch.matmul(attention_weight, V)

        return output 