# 描述
基于 PyTorch 实现 attention    

参考文献：    
- [手写 Self-Attention 的四重境界，从 self-attention 到 multi-head self-attention](https://bruceyuan.com/hands-on-code/from-self-attention-to-multi-head-self-attention.html#%E7%AC%AC%E5%9B%9B%E9%87%8D-multi-head-self-attention)

In [None]:
import math
import torch
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are produced by three independent linear
    layers, each mapping ``hidden_dim -> hidden_dim``.

    Args:
        hidden_dim: Size of the last dimension of the input; also used as
            the scaling factor ``sqrt(hidden_dim)`` for the scores.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Creation order is kept (query, key, value) so parameter
        # initialisation consumes the RNG in the same sequence.
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Input whose last dimension is ``hidden_dim``; the intended
                layout is ``(batch_size, seq_len, hidden_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries = self.query_proj(x)
        keys = self.key_proj(x)
        values = self.value_proj(x)

        # Swap the last two axes of the keys so the product yields one
        # score per (query position, key position) pair:
        # (..., seq_len, hidden_dim) @ (..., hidden_dim, seq_len).
        scores = queries @ keys.transpose(-2, -1)

        # Scale, then normalise over the key axis so each row sums to 1.
        scale = math.sqrt(self.hidden_dim)
        probs = (scores / scale).softmax(dim=-1)

        # Weighted sum of the values: (..., seq_len, hidden_dim).
        return probs @ values
    
# Smoke test for SelfAttentionV1: random input of shape
# (batch_size=3, seq_len=2, hidden_dim=4).
X = torch.rand((3, 2, 4))
print(X)

net = SelfAttentionV1(hidden_dim=4)
ret = net(X)
print(ret)



In [None]:
class SelfAttentionV2(nn.Module):
    """Single-head self-attention with a fused Q/K/V projection.

    One linear layer of width ``3 * dim`` produces queries, keys and values
    in a single matmul; the result is split into three ``dim``-wide pieces.
    The attention output then passes through a final ``dim -> dim`` linear
    layer.

    Args:
        dim: Size of the last dimension of the input.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Linear(dim, 3 * dim)
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Input whose last dimension is ``dim``; the intended layout
                is ``(batch_size, seq_len, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # (..., seq_len, 3 * dim) -> three tensors of (..., seq_len, dim).
        fused = self.proj(x)
        queries, keys, values = fused.split(self.dim, dim=-1)

        # Scaled scores, normalised over the key axis.
        scores = queries @ keys.transpose(-2, -1)
        probs = (scores / math.sqrt(self.dim)).softmax(dim=-1)

        context = probs @ values
        return self.output_proj(context)
    
# Smoke test for SelfAttentionV2: random input of shape
# (batch_size=3, seq_len=2, dim=4).
X = torch.rand((3, 2, 4))
print(X)

net = SelfAttentionV2(dim=4)
ret = net(X)
print(ret)


In [None]:
class SelfAttentionV3(nn.Module):
    """Single-head self-attention with a fused Q/K/V projection, an optional
    attention mask, and dropout on the attention weights.

    Args:
        dim: Size of the last dimension of the input.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim, dropout=0.1):
        super(SelfAttentionV3, self).__init__()
        self.dim = dim
        self.proj = nn.Linear(dim, 3 * dim)
        self.att_drop = nn.Dropout(dropout)
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, x, att_mask=None):
        """Apply (optionally masked) self-attention to ``x``.

        Args:
            x: Input whose last dimension is ``dim``; the intended layout
                is ``(batch_size, seq_len, dim)``.
            att_mask: Optional tensor broadcastable to the score shape
                ``(batch_size, seq_len, seq_len)``. Positions where the
                mask equals 0 are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        QKV = self.proj(x)
        Q, K, V = torch.split(QKV, self.dim, dim=-1)
        att_weight = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.dim)
        if att_mask is not None:
            # Fill masked scores with the most negative finite value of the
            # score dtype so they get ~0 probability after softmax. A fixed
            # -1e20 is not representable in float16; a finite value (rather
            # than -inf) keeps fully-masked rows free of NaN.
            fill_value = torch.finfo(att_weight.dtype).min
            att_weight = att_weight.masked_fill(att_mask == 0, fill_value)
        att_weight = torch.softmax(att_weight, dim=-1)
        att_weight = self.att_drop(att_weight)
        output = self.output_proj(torch.matmul(att_weight, V))
        return output
    
# Smoke test for SelfAttentionV3: random input of shape
# (batch_size=3, seq_len=4, dim=2).
X = torch.rand((3, 4, 2))
print(X)

# One row per batch element; 1 marks a real token, 0 marks padding.
mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
])
print(mask.shape)

# Insert a query axis and copy the key mask for each of the 4 query
# positions: (3, 4) -> (3, 1, 4) -> (3, 4, 4).
x_mask = mask[:, None, :].repeat(1, 4, 1)
print(x_mask)
print(x_mask.shape)

net = SelfAttentionV3(dim=2)
ret = net(X, att_mask=x_mask)
print(ret)


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``nums_head`` heads of width ``hidden_dim // nums_head``, attended per
    head, concatenated back and passed through an output projection.

    Args:
        hidden_dim: Size of the last dimension of the input.
        nums_head: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``nums_head``.
    """

    def __init__(self, hidden_dim, nums_head, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        # Fail early: otherwise the per-head reshape in forward() raises an
        # opaque shape error.
        if hidden_dim % nums_head != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"nums_head ({nums_head})"
            )
        self.nums_head = nums_head
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // nums_head

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        self.att_dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, att_mask=None):
        """Apply (optionally masked) multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, hidden_dim)``.
            att_mask: Optional tensor broadcastable to
                ``(batch_size, nums_head, seq_len, seq_len)``. Positions
                where the mask equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_dim)``.
        """
        batch_size, seq_len, _ = x.shape

        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)  # (batch_size, seq_len, hidden_dim)

        # (batch_size, seq_len, hidden_dim)
        #   => (batch_size, nums_head, seq_len, head_dim)
        q_state = Q.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)
        k_state = K.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)
        v_state = V.view(batch_size, seq_len, self.nums_head, self.head_dim).transpose(1, 2)

        # Scores: (batch_size, nums_head, seq_len, seq_len)
        att_weight = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)
        if att_mask is not None:
            # Test the mask, not the scores, to pick the excluded positions.
            # The fill value is the most negative finite number of the score
            # dtype, which is valid for any floating dtype and keeps
            # fully-masked rows free of NaN.
            att_weight = att_weight.masked_fill(
                att_mask == 0, torch.finfo(att_weight.dtype).min
            )
        # Normalise over the key axis (the last dimension).
        att_weight = torch.softmax(att_weight, dim=-1)
        att_weight = self.att_dropout(att_weight)
        output_mid = att_weight @ v_state  # (batch_size, nums_head, seq_len, head_dim)

        # (batch_size, nums_head, seq_len, head_dim)
        #   => (batch_size, seq_len, nums_head, head_dim)
        # transpose() returns a non-contiguous view; view() needs contiguous
        # memory, hence the explicit contiguous() call.
        output_mid = output_mid.transpose(1, 2).contiguous()
        # Concatenate the heads back into hidden_dim.
        output = output_mid.view(batch_size, seq_len, -1)
        output = self.out_proj(output)
        return output
        


# Smoke test for MultiHeadAttention: random input of shape
# (batch_size=3, seq_len=4, hidden_dim=128) with 8 heads.
X = torch.rand((3, 4, 128))

# One row per batch element; 1 marks a real token, 0 marks padding.
mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
])
print(mask.shape)

# Insert head and query axes, then copy the key mask across 8 heads and
# 4 query positions: (3, 4) -> (3, 1, 1, 4) -> (3, 8, 4, 4).
x_mask = mask[:, None, None, :].repeat(1, 8, 4, 1)
print(x_mask.shape)

net = MultiHeadAttention(hidden_dim=128, nums_head=8)
ret = net(X, att_mask=x_mask)
print(ret)
