## 自注意力机制

[参考文档](https://zhuanlan.zhihu.com/p/631398525?utm_id=0)
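计算流程大致如下：
1. 输入 $X$：$N \times H$（N 个位置，每个位置 H 维特征）
2. $W^Q、W^K、W^V$：$H \times H$，可训练参数，在所有位置之间共享
3. $Q = XW^Q,\ K = XW^K,\ V = XW^V$：均为 $N \times H$
4. $\alpha = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$（$d_k = H$，softmax 沿最后一维 dim=-1 计算）：$N \times N$
5. 输出 $\alpha V$：$N \times H$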

![](./img/attention_self1.png)

In [None]:
import torch 
import torch.nn as nn

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Computes ``softmax(Q K^T / sqrt(d)) V`` where Q, K and V are learned
    linear projections of the input and ``d`` is ``embed_dim``.

    Args:
        embed_dim: feature size H of every input position.
    """

    def __init__(self, embed_dim):
        super(SelfAttention, self).__init__()
        self.d_kq = embed_dim
        self.K = nn.Linear(embed_dim, embed_dim)
        self.Q = nn.Linear(embed_dim, embed_dim)
        self.V = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention over the second-to-last axis of ``x``.

        Args:
            x: tensor of shape (N, H), or batched (..., N, H).

        Returns:
            Tensor with the same shape as ``x``.
        """
        key = self.K(x)
        query = self.Q(x)
        value = self.V(x)
        # Swap only the last two axes: `.T` would reverse every axis of a
        # batched (..., N, H) tensor. `self.d_kq` is a Python int, so take the
        # square root with `** 0.5` (torch.sqrt only accepts tensors).
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d_kq ** 0.5)
        attention_scores = nn.functional.softmax(attention_scores, dim=-1)
        self_attention = torch.matmul(attention_scores, value)
        return self_attention

# 多头注意力机制

多头注意力机制和自注意力机制区别：
1. 多头注意力使用多个独立的注意力头，分别计算注意力权重，再将各头的结果拼接（或加权求和），从而获得更丰富的表示。
  在自注意力机制中，每个单词或者字都仅仅只有一个 q、k、v与其对应，如下图所示：
  $$q^i = W^q a^i$$
  ![](./img/attention_self2.png)
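2. 多头注意力机制在 $a^i$ 乘以 $W^q$、$W^k$、$W^v$ 得到 $q^i$、$k^i$、$v^i$ 之后，再为每个头分别生成一组 q、k、v。
即在 $q^i = W^q a^i$ 之后继续划分多个 head。例如有两个 head 时，以 q 为例的两个头为：
$$q_{i,1} = W_{q,1}q^i$$
$$q_{i,2} = W_{q,2}q^i$$
每个头只在自己内部计算注意力分数，例如第 1 个头中位置 $i$ 对位置 $j$ 的分数为：
$$q_{i,1} \cdot k_{j,1}^T$$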

![](./img/attention_multi1.png)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Multi-head self-attention module.
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with output projection and residual.

    The input is projected to queries/keys/values. Each projection is split
    into ``num_heads`` chunks of ``embed_dim // num_heads`` features, and
    attention is computed independently per head. The heads are then
    concatenated, passed through ``fc`` and added back to the input.

    Args:
        embed_dim: feature size of every input position; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        # Fail early with a clear message instead of an opaque `view` shape
        # error on the first forward pass.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (batch, seq_len, embed_dim).

        Returns:
            Tensor of shape (batch, seq_len, embed_dim): ``fc(attention(x)) + x``.
        """
        batch_size, seq_len, embed_dim = x.size()

        # Split each projection into heads: (batch, num_heads, seq_len, head_dim).
        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores per head: (batch, num_heads, seq_len, seq_len).
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(attn_weights, dim=-1)

        # Weighted sum of values, then merge heads back to (batch, seq_len, embed_dim).
        attended_values = (
            torch.matmul(attn_weights, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, embed_dim)
        )

        # Output projection plus residual connection.
        return self.fc(attended_values) + x

# Multi-head self-attention classifier model.
class MultiHeadSelfAttentionClassifier(nn.Module):
    """Sequence classifier: self-attention, mean-pool over positions, two-layer MLP.

    Args:
        embed_dim: feature size of each input position.
        num_heads: number of heads for the attention layer.
        hidden_dim: width of the hidden MLP layer.
        num_classes: number of output logits.
    """

    def __init__(self, embed_dim, num_heads, hidden_dim, num_classes):
        super().__init__()
        # Registration order (attention, fc1, fc2) fixes the state_dict layout.
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map an input sequence batch to class logits.

        Args:
            x: input batch; the attention layer unpacks it as
                (batch, seq_len, embed_dim).

        Returns:
            Logits of shape (batch, num_classes).
        """
        attended = self.attention(x)
        # Average the per-position vectors into one vector per sequence.
        pooled = attended.mean(dim=1)
        hidden = torch.relu(self.fc1(pooled))
        return self.fc2(hidden)


In [None]:

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no output projection, no residual).

    Args:
        num_head: number of attention heads; must divide ``emb_dim``.
        emb_dim: feature size of every input position.

    Raises:
        ValueError: if ``emb_dim`` is not divisible by ``num_head``.
    """

    def __init__(self, num_head, emb_dim):
        # nn.Module.__init__ must run before any submodule is assigned to self.
        super().__init__()
        if emb_dim % num_head != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by num_head ({num_head})"
            )
        self.dim = emb_dim // num_head  # per-head feature size
        # Project the full embedding; the result is split into heads in forward().
        self.q = nn.Linear(emb_dim, emb_dim)
        self.k = nn.Linear(emb_dim, emb_dim)
        self.v = nn.Linear(emb_dim, emb_dim)
        self.embed_dim = emb_dim
        self.head_num = num_head

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq_len, emb_dim) into (batch, head_num, seq_len, dim)."""
        return t.view(batch_size, seq_len, self.head_num, self.dim).transpose(1, 2)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (batch, seq_len, emb_dim).

        Returns:
            Tensor of shape (batch, seq_len, emb_dim) holding the concatenated
            per-head attention outputs.

        Raises:
            ValueError: if the last axis of ``x`` is not ``emb_dim``.
        """
        batch_size, seq_len, embed_dim = x.size()
        # Explicit check: an `assert` on a non-empty string literal never fires.
        if embed_dim != self.embed_dim:
            raise ValueError("input dim 不等于初始化dim")
        q = self._split_heads(self.q(x), batch_size, seq_len)
        k = self._split_heads(self.k(x), batch_size, seq_len)
        v = self._split_heads(self.v(x), batch_size, seq_len)

        # Per-head scores: (batch, head_num, seq_len, seq_len).
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        weights = torch.softmax(scores, dim=-1)

        # Merge heads back: (batch, head_num, seq_len, dim) -> (batch, seq_len, emb_dim).
        out = torch.matmul(weights, v).transpose(1, 2).contiguous()
        return out.view(batch_size, seq_len, self.embed_dim)

In [14]:
import torch
import torch.nn as nn
# Demo: project a (3, 10, 20) tensor to 30 features, then split the last axis
# into 2 heads of 15 features each with `view`.
x = torch.rand(3, 10, 20)
fc = torch.nn.Linear(20, 30)
res = fc(x).view(3, 10, 2, 15)  # (3, 10, 30) -> (3, 10, 2, 15)
res.shape

torch.Size([3, 10, 2, 15])

In [19]:
# Demo of a shape mismatch: nn.Linear acts on the LAST axis of its input, so
# Linear(10, 30) cannot consume x whose last axis has size 20. Running this cell
# raises RuntimeError (mat1 and mat2 shapes cannot be multiplied).
x = torch.rand(3,10,20)
# x = torch.rand(3,10,20)
fc = torch.nn.Linear(10,30)
res = fc(x)  # fails: last axis of x is 20, but fc expects 10 input features

RuntimeError: mat1 and mat2 shapes cannot be multiplied (30x20 and 10x30)

In [12]:
# Demo: `view` regroups the 4 rows of a (4, 10) tensor into shape (2, 2, 10);
# rows 0-1 form the first block and rows 2-3 the second.
a = torch.rand(4, 10)
print(a)
a.view(2, -1, 10)  # -1 is inferred as 2

tensor([[0.3249, 0.5755, 0.1321, 0.3219, 0.5015, 0.4152, 0.6919, 0.4683, 0.3959,
         0.8583],
        [0.0333, 0.6990, 0.1000, 0.4782, 0.4039, 0.6993, 0.9004, 0.9163, 0.3077,
         0.1099],
        [0.6061, 0.3521, 0.1836, 0.2257, 0.9928, 0.6356, 0.8696, 0.2656, 0.3837,
         0.3558],
        [0.2074, 0.1046, 0.0048, 0.5473, 0.2107, 0.6902, 0.9560, 0.6407, 0.3129,
         0.6873]])


tensor([[[0.3249, 0.5755, 0.1321, 0.3219, 0.5015, 0.4152, 0.6919, 0.4683,
          0.3959, 0.8583],
         [0.0333, 0.6990, 0.1000, 0.4782, 0.4039, 0.6993, 0.9004, 0.9163,
          0.3077, 0.1099]],

        [[0.6061, 0.3521, 0.1836, 0.2257, 0.9928, 0.6356, 0.8696, 0.2656,
          0.3837, 0.3558],
         [0.2074, 0.1046, 0.0048, 0.5473, 0.2107, 0.6902, 0.9560, 0.6407,
          0.3129, 0.6873]]])