In [4]:
import math
import torch
import torch.nn as nn

In [6]:
class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention with separate Q/K/V layers.

    Args:
        hidden_dim: Feature size of the input's last axis. Each of the three
            projections maps hidden_dim -> hidden_dim, and the same value is
            used for the 1/sqrt(hidden_dim) score scaling.
    """

    def __init__(self, hidden_dim: int = 768):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Three independent linear maps, one per attention role.
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        """Attend every position of ``X`` to every position of ``X``.

        Args:
            X: Input tensor, expected as (batch_size, seq_len, hidden_dim).

        Returns:
            Tensor with the same (batch, seq, hidden) layout as ``X``.

        Side effects:
            Prints the attention weights to stdout (kept for demonstration).
        """
        queries = self.query_proj(X)
        keys = self.key_proj(X)
        values = self.value_proj(X)

        # K has to become (batch, hidden_dim, seq) so that Q @ K^T yields the
        # (batch, seq, seq) matrix of pairwise scores.
        scores = queries @ keys.transpose(-2, -1)
        scaled_scores = scores / math.sqrt(self.hidden_dim)

        # Normalise over the key axis: each query's weights sum to 1.
        attention_weight = torch.softmax(scaled_scores, dim=-1)
        print(attention_weight)

        # Weighted sum of the values -> (batch, seq, hidden).
        return attention_weight @ values


# Demo input: 3 sequences, 2 tokens each, 4 features per token.
X = torch.randn(3, 2, 4)
print(X)
# Swapping the last two axes is the same operation forward() applies to K.
print(X.transpose(-2, -1))
self_att_net = SelfAttentionV1(hidden_dim=4)
# The call itself also prints the attention weights from inside forward().
output = self_att_net(X)
print(output)

tensor([[[-0.2930, -0.5030,  0.3918,  0.8008],
         [-0.0612,  0.2693,  0.3765,  1.2838]],

        [[-0.7289, -0.4290,  1.4241,  1.8681],
         [-0.8519,  0.6102,  0.5591, -0.0497]],

        [[-1.5693, -1.0519,  0.0561, -1.3462],
         [-1.0723,  0.5500, -0.7984,  0.3470]]])
tensor([[[-0.2930, -0.0612],
         [-0.5030,  0.2693],
         [ 0.3918,  0.3765],
         [ 0.8008,  1.2838]],

        [[-0.7289, -0.8519],
         [-0.4290,  0.6102],
         [ 1.4241,  0.5591],
         [ 1.8681, -0.0497]],

        [[-1.5693, -1.0723],
         [-1.0519,  0.5500],
         [ 0.0561, -0.7984],
         [-1.3462,  0.3470]]])
tensor([[[0.4715, 0.5285],
         [0.4655, 0.5345]],

        [[0.5644, 0.4356],
         [0.5156, 0.4844]],

        [[0.3903, 0.6097],
         [0.3447, 0.6553]]], grad_fn=<SoftmaxBackward0>)
tensor([[[-0.3969, -0.0183,  0.2707, -0.9467],
         [-0.3958, -0.0194,  0.2713, -0.9474]],

        [[-0.3971,  0.0331,  0.7189, -1.0445],
         [-0.3501, 

In [None]:
class SelfAttentionV2(nn.Module):
    """Self-attention that computes Q, K and V with one fused linear layer.

    Args:
        dim: Feature size of the input's last axis. The fused projection maps
            dim -> 3 * dim and is then cut into three dim-sized pieces.
    """

    def __init__(self, dim: int = 768):
        super().__init__()
        self.dim = dim
        # Single matmul producing Q, K and V concatenated on the last axis.
        self.proj = nn.Linear(dim, dim * 3)

    def forward(self, X):
        """Apply scaled dot-product self-attention to ``X``.

        Args:
            X: Input tensor, expected as (batch, seq, dim).

        Returns:
            Attention output with the same layout as ``X``.

        Side effects:
            Prints the attention weights to stdout (kept for demonstration).
        """
        # (batch, seq, 3 * dim) -> three equal (batch, seq, dim) slices.
        Q, K, V = self.proj(X).chunk(3, dim=-1)

        scores = Q @ K.transpose(-2, -1)
        att_weight = torch.softmax(scores / math.sqrt(self.dim), dim=-1)
        print(att_weight)
        return att_weight @ V


# Demo input laid out as (batch=3, seq=2, dim=4).
X = torch.randn((3, 2, 4))
net = SelfAttentionV2(dim=4)
# forward() prints the attention weights; the outer print shows the output.
print(net(X))

In [None]:
# 1. Where to apply dropout (on the attention weights, after softmax)
# 2. Support for an attention_mask
# 3. Output projection matrix

class SelfAttentionV3(nn.Module):
    """Self-attention with a fused QKV layer, masking, dropout and output projection.

    Args:
        dim: Feature size of the input's last axis.
        dropout_rate: Probability passed to ``nn.Dropout``, applied to the
            attention weights after softmax.
        *args, **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, dim, dropout_rate=0.1, *args, **kwargs):
        super(SelfAttentionV3, self).__init__()
        self.dim = dim
        # Fused projection: one matmul yields Q, K and V concatenated.
        self.proj = nn.Linear(dim, dim * 3)
        self.attention_dropout = nn.Dropout(dropout_rate)
        # Optional: final linear map applied to the attention result.
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X, attention_mask=None):
        """Apply masked scaled dot-product self-attention to ``X``.

        Args:
            X: Input tensor, expected as (batch, seq, dim).
            attention_mask: Optional tensor broadcastable to the
                (batch, seq, seq) score matrix. Positions equal to 0 are
                excluded from attention; every other value keeps the score.

        Returns:
            Tensor with the same layout as ``X``.

        Side effects:
            Prints the masked scores and the softmax weights to stdout
            (kept for demonstration).
        """
        # X: (batch, seq, dim)
        QKV = self.proj(X)
        Q, K, V = torch.split(QKV, self.dim, dim=-1)
        # (batch, seq, seq)
        attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # Use the most negative finite value of the score dtype. A literal
            # such as -1e20 is not representable in float16 and makes
            # masked_fill raise there; the finite minimum still drives masked
            # weights to zero after softmax and keeps fully masked rows finite.
            fill_value = torch.finfo(attention_weight.dtype).min
            attention_weight = attention_weight.masked_fill(attention_mask == 0, fill_value)
        print(attention_weight)
        attention_weight = torch.softmax(attention_weight, dim=-1)
        print(attention_weight)
        # Dropout acts on the normalised weights, before they mix the values.
        attention_weight = self.attention_dropout(attention_weight)
        attention_result = attention_weight @ V
        output = self.output_proj(attention_result)
        return output


# Demo input laid out as (batch=3, seq=4, dim=2).
X = torch.randn(3, 4, 2)
# Per-sequence key mask, initially (batch, seq): 1 keeps a position, 0 masks it.
mark = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0]
])
print(mark.shape)
# Add a query axis and repeat it so the mask lines up with the score matrix:
# (batch, seq) -> (batch, 1, seq) -> (batch, seq, seq).
mark = mark.unsqueeze(dim=1).repeat(1, 4, 1)
print(mark.shape)
# dim must equal X's last axis (2). The earlier SelfAttentionV3(4) instance was
# overwritten before use and did not match the input, so it has been dropped.
net = SelfAttentionV3(2)
print(net(X, mark))

In [None]:
class SelfAttentionV4(nn.Module):
    """Masked self-attention with separate Q/K/V layers and attention dropout.

    Args:
        dim: Feature size of the input's last axis; each projection maps
            dim -> dim.
        dropout_rate: Probability passed to ``nn.Dropout``, applied to the
            attention weights after softmax.
    """

    def __init__(self, dim: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.dim = dim
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.attention_dropout = nn.Dropout(dropout_rate)

    def forward(self, X, attention_mask=None):
        """Apply masked scaled dot-product self-attention to ``X``.

        Args:
            X: Input tensor, expected as (batch, seq, dim).
            attention_mask: Optional tensor broadcastable to the
                (batch, seq, seq) score matrix. Positions equal to 0 are
                excluded from attention.

        Returns:
            Tensor with the same layout as ``X``. A row whose keys are all
            masked gets uniform weights over the keys instead of NaN.
        """
        # X: (batch, seq, dim)
        Q = self.query(X)
        K = self.key(X)
        V = self.value(X)

        # Pairwise scores: (batch, seq, seq)
        attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # Fill with the dtype's most negative finite value, not -inf:
            # softmax over a row of only -inf evaluates to NaN and would poison
            # the output and gradients, while the finite minimum still gives
            # masked entries zero weight whenever one key is visible.
            fill_value = torch.finfo(attention_weight.dtype).min
            attention_weight = attention_weight.masked_fill(attention_mask == 0, fill_value)
        attention_weight = torch.softmax(attention_weight, dim=-1)
        attention_weight = self.attention_dropout(attention_weight)
        # (batch, seq, seq) @ (batch, seq, dim) -> (batch, seq, dim)
        output = attention_weight @ V
        return output


# Demo input laid out as (batch=3, seq=4, dim=2).
X = torch.randn((3, 4, 2))
# One key mask per sequence, (batch, seq): 0 marks a position to ignore.
mark = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0]
])
# Expand to (batch, seq, seq) so each query row carries its sequence's mask.
mark = mark.unsqueeze(1).repeat(1, 4, 1)
net = SelfAttentionV4(dim=2)
print(net(X, mark))