# In [1]:
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Single-head self-attention with learned query/key/value projections.

    Each projection is ``Linear(hidden_size, hidden_size)`` wrapped in an
    ``nn.Sequential``; the wrapper is kept so state_dict keys such as
    ``q_layer.0.weight`` stay unchanged.

    :param hidden_size: feature size ``e`` of both the input and output vectors.
    :param scale: if True, divide the attention scores by ``sqrt(hidden_size)``
        (scaled dot-product attention). Defaults to False, which reproduces the
        original unscaled scores.
    """

    def __init__(self, hidden_size, scale=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.scale = scale
        self.q_layer = nn.Sequential(
            nn.Linear(in_features=hidden_size, out_features=hidden_size)
        )
        self.k_layer = nn.Sequential(
            nn.Linear(in_features=hidden_size, out_features=hidden_size)
        )
        self.v_layer = nn.Sequential(
            nn.Linear(in_features=hidden_size, out_features=hidden_size)
        )

    def forward(self, x):
        """Mix every time step with every other time step of the same sample.

        :param x: [..., t, e], e.g. [n, t, e] for n texts of t time steps, each an
            e-dimensional vector. An unbatched [t, e] input is also accepted.
        :return: tensor with the same shape as ``x``.
        """
        # 1. Project the input into queries, keys and values: [..., t, e] each.
        q = self.q_layer(x)
        k = self.k_layer(x)
        v = self.v_layer(x)

        # 2. Similarity of every time step with every time step: [..., t, t].
        #    transpose(-2, -1) (rather than permute(0, 2, 1)) works for any
        #    number of leading batch dimensions, including none.
        scores = torch.matmul(q, k.transpose(-2, -1))
        if self.scale:
            # Optional 1/sqrt(e) scaling keeps scores from growing with e.
            scores = scores / (self.hidden_size ** 0.5)

        # 3. Normalize each row of scores into weights over the key time steps.
        alpha = torch.softmax(scores, dim=-1)  # [..., t, t]

        # 4. Weighted sum of the values: [..., t, e].
        return torch.matmul(alpha, v)


@torch.no_grad()
def attention():
    """Demo: embed two toy token sequences and pass them through self-attention.

    Prints a few static embedding vectors, the outputs of one attention layer,
    and the feature-wise concatenation of two independently initialised
    attention layers (similar to multi-head attention without an output
    projection). Returns nothing.
    """
    # Two samples with three time steps each; token 1 opens both samples.
    token_ids = torch.tensor([
        [1, 3, 5],
        [1, 6, 3],
    ])

    # Static, context-free lookup (in the spirit of a Word2Vec embedding layer).
    embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)
    embedded = embedding(token_ids)  # [2, 3, 4]
    # The same token id yields the same static vector in both samples.
    for sample in (0, 1):
        print(embedded[sample][0])
    print("=" * 100)

    first_head = SelfAttention(hidden_size=4)
    attended = first_head(embedded)
    # After attention each vector mixes in its sample's other tokens, so the
    # shared first token no longer maps to the same vector.
    for sample, step in ((0, 0), (1, 0), (1, 1)):
        print(attended[sample][step])

    second_head = SelfAttention(hidden_size=4)
    attended_again = second_head(embedded)

    merged = torch.concat([attended, attended_again], dim=2)  # [2, 3, 8]
    print(merged.shape)
    for sample in (0, 1):
        print(merged[sample][0])


if __name__ == '__main__':
    # Script entry point: run the demo only when executed directly, not on import.
    attention()


# Example output from one run (no random seed is set, so the numbers vary between runs):
# tensor([-1.2545, -0.5140,  1.0517, -0.6680])
# tensor([-1.2545, -0.5140,  1.0517, -0.6680])
# tensor([-0.1423, -0.3471,  0.5710,  0.2368])
# tensor([ 0.0423, -0.6835,  0.4214, -0.1109])
# tensor([ 0.0367, -0.6609,  0.3978, -0.1035])
# torch.Size([2, 3, 8])
# tensor([-0.1423, -0.3471,  0.5710,  0.2368, -0.3454,  0.0538, -0.0294,  0.2424])
# tensor([ 0.0423, -0.6835,  0.4214, -0.1109, -0.2446,  0.0827, -0.0693,  0.2088])
