In [1]:
import torch
from torch import nn
import math
from torch import functional as F

In [None]:
class selfAttention(nn.Module):
    def __init__(self, hidden_dim:int = 728)->None:
        super().__init__()
        self.hidden_dim = hidden_dim
    #初始化qkv
        self.query_proj = nn.Linear(hidden_dim,hidden_dim)
        self.key_proj = nn.Linear(hidden_dim,hidden_dim)
        self.value_proj = nn.Linear(hidden_dim,hidden_dim)
    
    def forward(self,X):
        # Q : (batch_size,seq_len,dim)
        # K*T:(batch_size,dim,seq_len)
        Q = self.query_proj(X)  # X (batch_size,seq,dim) -> (batch_size,seq,hidden_dim)
        K = self.key_proj(X)
        V = self.value_proj(X)

        attention_value = torch.matmul(
            #给k转置
            Q,K.transpose(-1,-2)
        )

        attention_softmax = torch.softmax(
            attention_value / math.sqrt(self.hidden_dim),dim=-1
        )

        result = torch.matmul(
            attention_softmax,V
        )
        return result

X = torch.rand(2,3,4)


self_att_test = selfAttention(4)
self_att_test(X)


tensor([[[-0.1647, -0.5922,  0.2470, -0.0232],
         [-0.1650, -0.5883,  0.2439, -0.0200],
         [-0.1641, -0.5926,  0.2476, -0.0244]],

        [[-0.2515, -0.7480,  0.2592, -0.1423],
         [-0.2503, -0.7461,  0.2571, -0.1425],
         [-0.2504, -0.7468,  0.2562, -0.1411]]], grad_fn=<UnsafeViewBackward0>)

In [40]:
class selfAttentionV2(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.qkv = nn.Linear(dim,dim*3)

    def forward(self,X):
        # 
        QKV = self.qkv(X)
        Q,K,V = torch.split(QKV,self.dim,dim=-1)

        result = torch.softmax(torch.matmul(
            Q,K.transpose(-1,-2)
        ) / math.sqrt(self.dim),dim=-1
        )
        print(result)
        output = result @ V
        return output
X = torch.rand(3,2,4)
model2 = selfAttentionV2(4)
model2(X)


tensor([[[0.4894, 0.5106],
         [0.5109, 0.4891]],

        [[0.5137, 0.4863],
         [0.4978, 0.5022]],

        [[0.5048, 0.4952],
         [0.5059, 0.4941]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.4449,  0.2280, -0.5718,  0.0779],
         [-0.4421,  0.2280, -0.5792,  0.0705]],

        [[-0.5378,  0.1127, -0.5518,  0.1663],
         [-0.5330,  0.1195, -0.5498,  0.1640]],

        [[-0.5221,  0.0467, -0.6415,  0.1114],
         [-0.5225,  0.0464, -0.6415,  0.1117]]], grad_fn=<UnsafeViewBackward0>)

In [None]:
class MutiheadTransformer(nn.Module):
    def __init__(self,head_num,hidden_dim,drop_out = 0.1):
        super().__init__()
        self.head_num = head_num
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // head_num
        self.q_proj = nn.Linear(hidden_dim,hidden_dim)
        self.k_proj = nn.Linear(hidden_dim,hidden_dim)
        self.v_proj = nn.Linear(hidden_dim,hidden_dim)

        self.drop_out = nn.Dropout(drop_out)
        
        self.out_proj = nn.Linear(hidden_dim,hidden_dim)

    def forward(self,X,hidden_mask = None):
        #   qkv (batch,s,h)  ->   (batch,head_num,s,head_dim)
        batch , seq , _ =  X.size()
        Q = self.q_proj(X)
        K = self.k_proj(X)
        V = self.v_proj(X)

        #   transform it's shape (batch,s,head_num,head_dim)  (batch,head_num,s,head_dim)
        q = Q.view(batch,seq,self.head_num,self.head_dim).transpose(1,2)   
        k = K.view(batch,seq,self.head_num,self.head_dim).transpose(1,2)   
        v = V.view(batch,seq,self.head_num,self.head_dim).transpose(1,2)  

        result_mid = (q @ k.transpose(-1,-2)) / math.sqrt(self.head_dim)

        if hidden_mask is not None:
            result_mid = result_mid.masked_fill(
                hidden_mask == 0 , float('-inf') #   无限小的值
            )
        print(result_mid)

        #   对行向量进行softmax
        result_mid = torch.softmax(result_mid,dim=-1)
        result_drop = self.drop_out(result_mid)

        output_mid = (result_drop @ v).transpose(1,2).contiguous()
        output_mid = output_mid.view(batch,seq,-1)#h

        output = self.out_proj(output_mid)
        return output

mask = torch.tensor([
    [1,0],
    [0,0],
    [1,1]
]).unsqueeze(1).unsqueeze(2).expand(3,8,2,2)
print(mask.shape)
X = torch.rand(3,2,128)
model = MutiheadTransformer(8,128)# 8 head
model(X,mask)

MutiHead 中QKV的不同的头，即 将其拆分为不同的特征向量。他们经过不同的初始化之后有着不同的权重值，再将其**语义的特征向量**进行拆分，并分别与不同的KV矩阵相乘

In [None]:
class SimpleDecoder(nn.Module):
    def __init__(self,hidden_dim,head_num,dropout_rate=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num
        self.dropout_rate = dropout_rate

        self.q_proj = nn.Linear(hidden_dim,hidden_dim)
        self.k_proj = nn.Linear(hidden_dim,hidden_dim)
        self.v_proj = nn.Linear(hidden_dim,hidden_dim)

        # upgrade dimension 
        self.up = nn.Linear(hidden_dim,hidden_dim*4)
        self.down = nn.Linear(hidden_dim*4 , hidden_dim)
        self.act_proj = nn.ReLU()
        # dropout
        self.dropout_proj = nn.Dropout(dropout_rate)
        self.layerNorm_proj = nn.LayerNorm(hidden_dim,eps=0.000001)


    def attention_layer(self,q,k,v,mask,):
        # (batch,seq,head_num,head_dim) -> (batch,head_num,seq,head_dim)
        attention_mid = (q @ k.transpose(-2,-1)) / math.sqrt(self.head_dim)
        # Itself tril and original mask
        if mask is not None:
            mask = mask.tril()
            attention_mid = attention_mid.masked_fill(
                mask == 0 , float('-inf')
            )
        else:
            # make a ones trangles matrix which dimensions same like original 
            mask = attention_mid.ones_like().tril()
            attention_mid = attention_mid.masked_fill(
                mask == 0 , float('-inf')
            )
        # Specify a dimension row
        attention_mid = torch.softmax(attention_mid,dim=-1)
        # dropout layer
        attention_dropout = self.dropout_proj(attention_mid)
        #(batch,head_num,seq,head_dim)
        output = (attention_dropout @ v)

        return output


    def ffn(self,X):

        up = self.up(X)
        up = self.act_proj(up)
        down = self.down(up)
        down = self.dropout_proj(down)
        return self.layerNorm_proj(down + X)

    def mha(self,X,mask):
        batch,seq,_ = X.size()
        Q = self.q_proj(X)
        K = self.k_proj(X)
        V = self.v_proj(X)
        # (batch,seq,hidden_dim) -> (batch,seq,head_num,head_dim) -> (batch,head_num,seq,head_dim)
        q = Q.view(batch,seq,self.head_num,-1).transpose(1,2)
        k = K.view(batch,seq,self.head_num,-1).transpose(1,2)
        v = V.view(batch,seq,self.head_num,-1).transpose(1,2)

        attention = self.attention_layer(q,k,v,mask).transpose(1,2).contiguous()
        output = attention.view(batch,seq,-1)
        return output

    def forward(self,X,attention_mask=None):
        X = self.mha(X,attention_mask)
        X = self.ffn(X)
        
        return X

class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer_list = nn.ModuleList(
            [
                SimpleDecoder(64,8) # hidden=64,head_num=8
            ]
        )
        self.emb = nn.Embedding(12,64)
        self.out = nn.Linear(64,12)

    def forward(self,X,mask=None):
        X = self.emb(X)
        for i , l in enumerate(self.layer_list):
            X = l(X,mask)
        print(X.shape)
        output = self.out(X)
        return torch.softmax(output,dim=-1)
    
x  = torch.randint(low=0,high=12,size=(3,4))
#x = torch.rand(3,4,64)
net = Decoder()
mask = (
    torch.tensor([[1,1,1,1],[1,1,0,0],[1,0,0,0]])
    .unsqueeze(1)
    .unsqueeze(2)
    .repeat(1,8,4,1)
)
#print(mask.shape)
net(x,mask)


torch.Size([3, 4, 64])


tensor([[[0.0299, 0.0615, 0.0786, 0.0514, 0.1575, 0.0586, 0.0886, 0.0536,
          0.1314, 0.1477, 0.0744, 0.0667],
         [0.0370, 0.0445, 0.0565, 0.0635, 0.1195, 0.0328, 0.1554, 0.0355,
          0.0491, 0.2019, 0.0530, 0.1513],
         [0.0189, 0.0400, 0.0905, 0.1386, 0.1475, 0.0352, 0.0664, 0.0400,
          0.0685, 0.2035, 0.0357, 0.1150],
         [0.0210, 0.0402, 0.0771, 0.1360, 0.1415, 0.0348, 0.0512, 0.0334,
          0.0797, 0.2284, 0.0266, 0.1299]],

        [[0.0567, 0.0522, 0.0610, 0.1105, 0.0988, 0.0443, 0.1244, 0.0548,
          0.0397, 0.1198, 0.0902, 0.1478],
         [0.0344, 0.0637, 0.1130, 0.1915, 0.1115, 0.0312, 0.0915, 0.0554,
          0.0373, 0.0636, 0.0412, 0.1659],
         [0.0400, 0.0625, 0.0957, 0.1342, 0.1082, 0.0427, 0.1139, 0.0451,
          0.0280, 0.1210, 0.0431, 0.1656],
         [0.0524, 0.0748, 0.1068, 0.1486, 0.0640, 0.0352, 0.1018, 0.0440,
          0.0347, 0.1116, 0.0392, 0.1869]],

        [[0.0500, 0.0579, 0.0581, 0.1085, 0.0823, 0.0383, 0.