In [2]:
import torch
from torch import nn
import math
from torch import functional as F

In [None]:
class selfAttention(nn.Module):
    def __init__(self, hidden_dim:int = 728)->None:
        super().__init__()
        self.hidden_dim = hidden_dim
    #初始化qkv
        self.query_proj = nn.Linear(hidden_dim,hidden_dim)
        self.key_proj = nn.Linear(hidden_dim,hidden_dim)
        self.value_proj = nn.Linear(hidden_dim,hidden_dim)
    
    def forward(self,X):
        # Q : (batch_size,seq_len,dim)
        # K*T:(batch_size,dim,seq_len)
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        attention_value = torch.matmul(
            #给k转置
            Q,K.transpose(-1,-2)
        )

        attention_softmax = torch.softmax(
            attention_value / math.sqrt(self.hidden_dim),dim=-1
        )

        result = torch.matmul(
            attention_softmax,V
        )
        return result

X = torch.rand(2,3,4)


self_att_test = selfAttention(4)
self_att_test(X)


tensor([[[-0.1647, -0.5922,  0.2470, -0.0232],
         [-0.1650, -0.5883,  0.2439, -0.0200],
         [-0.1641, -0.5926,  0.2476, -0.0244]],

        [[-0.2515, -0.7480,  0.2592, -0.1423],
         [-0.2503, -0.7461,  0.2571, -0.1425],
         [-0.2504, -0.7468,  0.2562, -0.1411]]], grad_fn=<UnsafeViewBackward0>)

In [40]:
class selfAttentionV2(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.qkv = nn.Linear(dim,dim*3)

    def forward(self,X):
        # 
        QKV = self.qkv(X)
        Q,K,V = torch.split(QKV,self.dim,dim=-1)

        result = torch.softmax(torch.matmul(
            Q,K.transpose(-1,-2)
        ) / math.sqrt(self.dim),dim=-1
        )
        print(result)
        output = result @ V
        return output
X = torch.rand(3,2,4)
model2 = selfAttentionV2(4)
model2(X)


tensor([[[0.4894, 0.5106],
         [0.5109, 0.4891]],

        [[0.5137, 0.4863],
         [0.4978, 0.5022]],

        [[0.5048, 0.4952],
         [0.5059, 0.4941]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.4449,  0.2280, -0.5718,  0.0779],
         [-0.4421,  0.2280, -0.5792,  0.0705]],

        [[-0.5378,  0.1127, -0.5518,  0.1663],
         [-0.5330,  0.1195, -0.5498,  0.1640]],

        [[-0.5221,  0.0467, -0.6415,  0.1114],
         [-0.5225,  0.0464, -0.6415,  0.1117]]], grad_fn=<UnsafeViewBackward0>)

In [43]:
class MutiheadTransformer(nn.Module):
    def __init__(self,head_num,hidden_dim,drop_out = 0.1):
        super().__init__()
        self.head_num = head_num
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // head_num
        self.q_proj = nn.Linear(hidden_dim,hidden_dim)
        self.k_proj = nn.Linear(hidden_dim,hidden_dim)
        self.v_proj = nn.Linear(hidden_dim,hidden_dim)

        self.drop_out = nn.Dropout(drop_out)
        
        self.out_proj = nn.Linear(hidden_dim,hidden_dim)

    def forward(self,X,hidden_mask = None):
        #   qkv (batch,s,h)  ->   (batch,head_num,s,head_dim)
        batch , seq , _ =  X.size()
        Q = self.q_proj(X)
        K = self.k_proj(X)
        V = self.v_proj(X)

        #   transform it's shape (batch,s,head_num,head_dim)  (batch,head_num,s,head_dim)
        q = Q.view(batch,seq,self.head_num,self.head_dim).transpose(1,2)   
        k = K.view(batch,seq,self.head_num,self.head_dim).transpose(1,2)   
        v = V.view(batch,seq,self.head_num,self.head_dim).transpose(1,2)  

        result_mid = (q @ k.transpose(-1,-2)) / math.sqrt(self.head_dim)

        if hidden_mask is not None:
            result_mid = result_mid.masked_fill(
                hidden_mask == 0 , float('-inf') #   无限小的值
            )
        print(result_mid)

        #   对行向量进行softmax
        result_mid = torch.softmax(result_mid,dim=-1)
        result_drop = self.drop_out(result_mid)

        output_mid = (result_drop @ v).transpose(1,2).contiguous()
        output_mid = output_mid.view(batch,seq,-1)#h

        output = self.out_proj(output_mid)
        return output

mask = torch.tensor([
    [1,0],
    [0,0],
    [1,1]
]).unsqueeze(1).unsqueeze(2).expand(3,8,2,2)
print(mask.shape)
X = torch.rand(3,2,128)
model = MutiheadTransformer(8,128)# 8 head
model(X,mask)






torch.Size([3, 8, 2, 2])
tensor([[[[ 0.1323,    -inf],
          [ 0.0970,    -inf]],

         [[ 0.0960,    -inf],
          [-0.0008,    -inf]],

         [[-0.1000,    -inf],
          [-0.0680,    -inf]],

         [[ 0.1082,    -inf],
          [ 0.1479,    -inf]],

         [[-0.1186,    -inf],
          [-0.1090,    -inf]],

         [[ 0.0476,    -inf],
          [ 0.0272,    -inf]],

         [[-0.0938,    -inf],
          [-0.0341,    -inf]],

         [[ 0.1974,    -inf],
          [ 0.0012,    -inf]]],


        [[[   -inf,    -inf],
          [   -inf,    -inf]],

         [[   -inf,    -inf],
          [   -inf,    -inf]],

         [[   -inf,    -inf],
          [   -inf,    -inf]],

         [[   -inf,    -inf],
          [   -inf,    -inf]],

         [[   -inf,    -inf],
          [   -inf,    -inf]],

         [[   -inf,    -inf],
          [   -inf,    -inf]],

         [[   -inf,    -inf],
          [   -inf,    -inf]],

         [[   -inf,    -inf],
          [  

tensor([[[-2.0130e-01,  3.0907e-01, -3.5146e-01, -2.6524e-01,  2.9541e-01,
           1.6806e-01, -1.6961e-01,  2.8238e-01,  7.8997e-02, -9.3792e-02,
          -1.4074e-01, -2.9890e-01,  7.3648e-02,  5.9242e-02, -2.5970e-02,
          -1.6604e-01,  1.4730e-01,  3.2386e-01,  3.8600e-02, -1.1986e-01,
          -4.4942e-02,  3.5500e-02,  2.6568e-02, -7.6863e-02, -4.7323e-03,
          -9.5115e-02,  1.1173e-02,  8.6500e-02, -1.3608e-01,  1.3307e-01,
           4.8173e-02,  2.8138e-01, -3.1594e-02, -2.8889e-01, -2.1654e-01,
          -1.8872e-01,  1.0044e-01,  1.1752e-01, -1.6978e-01,  7.4238e-03,
           7.9879e-02, -1.8393e-01, -3.9519e-02,  1.2685e-01,  3.2053e-01,
           1.0059e-01, -1.4045e-01,  5.3156e-02, -5.0888e-01,  7.2849e-03,
          -9.4807e-03, -2.1026e-02, -4.0783e-01, -4.0632e-02, -4.2254e-01,
          -2.3455e-01,  4.1237e-02, -1.0502e-02,  3.2463e-02, -2.3092e-02,
           2.4118e-01,  1.2116e-01, -1.2872e-01, -4.2637e-01,  9.7296e-02,
          -2.3868e-01,  4