In [None]:
#手撕multi-head attention机制
import torch
import torch.nn as nn
import torch.nn.functional as F 
import math

# Model hyper-parameters: feature width and number of attention heads.
d_model = 512
n_heads = 8
# Random example input batch of shape (128, 64, 512).
x = torch.randn(128, 64, d_model)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: width of the input and output features; must be divisible
            by ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape (batch, length, d_model) -> (batch, n_heads, length, head_dim)."""
        batch, length, _ = t.shape
        head_dim = self.d_model // self.n_heads
        return t.reshape(batch, length, self.n_heads, head_dim).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from queries ``q`` over keys ``k`` / values ``v``.

        Args:
            q: queries, shape (batch, len_q, d_model).
            k: keys, shape (batch, len_k, d_model).
            v: values, shape (batch, len_k, d_model).
            mask: optional tensor broadcastable to (batch, n_heads, len_q, len_k);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, len_q, d_model).
        """
        batch, q_len, _ = q.shape
        head_dim = self.d_model // self.n_heads
        # k/v are split with their OWN length so cross-attention (len_k != len_q) works.
        q = self._split_heads(self.wq(q))
        k = self._split_heads(self.wk(k))
        v = self._split_heads(self.wv(v))

        attn_score = q @ k.transpose(2, 3) / math.sqrt(head_dim)
        if mask is not None:
            # Use the dtype's most negative finite value: a literal like -1e9
            # cannot be represented in float16 and makes masked_fill fail.
            attn_score = attn_score.masked_fill(
                mask == 0, torch.finfo(attn_score.dtype).min
            )
        out = self.softmax(attn_score) @ v
        # Merge heads back: (batch, n_heads, len_q, head_dim) -> (batch, len_q, d_model).
        out = out.permute(0, 2, 1, 3).reshape(batch, q_len, self.d_model)
        return self.w_combine(out)



#encoder层
#首先是token embedding
class TokenEmbedding(nn.Module):
    """Map integer token ids to dense ``d_model``-dimensional vectors.

    Args:
        d_model: embedding width.
        vocab_size: number of distinct token ids.
    """

    def __init__(self, d_model, vocab_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Look up embeddings: integer ids (..., ) -> (..., d_model)."""
        return self.embedding(x)

#然后是位置编码
# 提供位置信息：在自然语言处理（NLP）等序列处理任务中，单词的顺序通常是语义理解的关键。位置编码确保模型能够捕捉到这种顺序关系，即使自注意力层本身并不直接处理位置信息。
# 保持模型的并行性：在Transformer模型中，自注意力层和前馈网络层可以并行处理序列中的所有元素。位置编码允许模型在保持这种并行性的同时，理解元素之间的位置关系。
# 增强模型的表达能力：通过将位置编码添加到输入表示中，模型可以学习到位置相关的特征，这有助于提高模型对序列数据的理解深度。
# 允许模型处理可变长度的序列：位置编码使得模型能够处理不同长度的输入序列。通过为每个可能的位置提供唯一的编码，模型可以适应不同长度的输入。
# 训练效率：位置编码是固定的或可学习的参数，这意味着它们不需要在每个训练步骤中重新计算。这有助于提高训练效率，尤其是在处理长序列时。
class positionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: encoding width (odd values are supported).
        max_len: maximum sequence length that can be encoded.
        device: device on which the table is initially built.
    """

    def __init__(self, d_model, max_len, device):
        super(positionalEmbedding, self).__init__()
        encoding = torch.zeros(max_len, d_model, device=device)
        # Positions 0..max_len-1 as a column vector, shape (max_len, 1).
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(1)
        # Even feature indices 2i; each sin/cos column pair shares one frequency.
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        angles = pos / (10000 ** (_2i / d_model))
        encoding[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Non-persistent buffer: never trained, follows .to()/.cuda() with the
        # module, and stays out of state_dict like the former plain attribute.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the first ``x.shape[1]`` rows of the table, shape (seq_len, d_model).

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.shape[1]
        max_len = self.encoding.size(0)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        return self.encoding[:seq_len, :]
    

#然后是layer norm
#与batch norm不同,layer norm只在每个样本的最后一维上归一化,不依赖batch统计量,因此小batch(例如显存受限时)也能稳定训练
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable affine map.

    Args:
        d_model: size of the last dimension being normalized.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its last dimension, then scale by gamma and shift by beta."""
        mean = x.mean(-1, keepdim=True)
        # Biased (population) variance, matching the standard LayerNorm definition.
        var = x.var(-1, unbiased=False, keepdim=True)
        out = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * out + self.beta
            
#然后是全连接层
#position-wise feed forward network
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    fc1 (d_model -> hidden) -> ReLU -> dropout -> fc2 (hidden -> d_model)

    Args:
        d_model: input/output width.
        hidden: inner (expanded) width.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden, dropout=0.1):
        # nn.Module.__init__ must run before submodules can be assigned.
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out
#拼在一起形成transformer的embedding:
#输出是attention层的输入
class transformerEmbedding(nn.Module):
    """Token embedding plus sinusoidal positional encoding, followed by dropout.

    Args:
        d_model: embedding width.
        vocab_size: number of distinct token ids.
        max_len: maximum supported sequence length.
        device: device passed to ``positionalEmbedding``.
        drop_prop: dropout probability.
    """

    def __init__(self, d_model, vocab_size, max_len, device, drop_prop):
        super(transformerEmbedding, self).__init__()
        self.tokemb = TokenEmbedding(d_model, vocab_size)
        self.posemb = positionalEmbedding(d_model, max_len, device)
        self.drop = nn.Dropout(drop_prop)

    def forward(self, x):
        tokemb = self.tokemb(x)
        # posemb is a (seq_len, d_model) table that broadcasts over the batch dimension.
        posemb = self.posemb(x)
        return self.drop(tokemb + posemb)
        

#最后就是encoder层
#self-attention > add & norm > position-wise feed forward > add & norm
class encoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    self-attention -> dropout -> add & norm -> feed-forward -> dropout -> add & norm

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        hidden: inner width of the feed-forward sub-layer.
        dropout: dropout probability.
    """

    def __init__(self, d_model, n_heads, hidden, dropout=0.1):
        super(encoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.layer_norm1 = LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        self.ffn = PositionwiseFeedForward(d_model, hidden, dropout)
        self.layer_norm2 = LayerNorm(d_model)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run self-attention and feed-forward sub-layers with residual connections.

        Args:
            x: input features.
            mask: optional attention mask forwarded to the self-attention.
        """
        residual = x
        x = self.self_attn(x, x, x, mask)
        x = self.layer_norm1(residual + self.drop1(x))
        residual = x
        x = self.ffn(x)
        # Dropout's output must be used, otherwise the sub-layer is never regularized.
        x = self.layer_norm2(residual + self.drop2(x))
        return x


#decoder层
#masked multi-head attention > add & norm > cross attention > add & norm > position-wise feed forward > add & norm
class decoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    masked self-attention -> add & norm -> cross-attention -> add & norm
    -> feed-forward -> add & norm

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        hidden: inner width of the feed-forward sub-layer.
        dropout: dropout probability.
        mask: unused; kept only for signature compatibility. Masks are
            supplied to ``forward`` instead.
    """

    def __init__(self, d_model, n_heads, hidden, dropout=0.1, mask=True):
        super(decoderLayer, self).__init__()
        self.attn1 = MultiHeadAttention(d_model, n_heads)
        self.layer_norm1 = LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        # MultiHeadAttention takes only (d_model, n_heads); masks go to forward().
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.layer_norm2 = LayerNorm(d_model)
        self.drop2 = nn.Dropout(dropout)
        self.ffn = PositionwiseFeedForward(d_model, hidden, dropout)
        self.layer_norm3 = LayerNorm(d_model)
        self.drop3 = nn.Dropout(dropout)

    def forward(self, dec, enc, t__mask, s_mask):
        """Decode one layer.

        Args:
            dec: decoder-side input features.
            enc: encoder output used as keys/values for cross-attention;
                when ``None`` the cross-attention sub-layer is skipped.
            t__mask: lower-triangular (causal) mask hiding future target positions.
            s_mask: mask used in cross-attention to ignore source padding.
        """
        residual = dec
        x = self.attn1(dec, dec, dec, t__mask)
        x = self.layer_norm1(residual + self.drop1(x))
        if enc is not None:
            residual = x
            x = self.cross_attn(x, enc, enc, s_mask)
            x = self.layer_norm2(residual + self.drop2(x))
        residual = x
        x = self.ffn(x)
        x = self.layer_norm3(residual + self.drop3(x))
        return x


#一整个encoder由多个encoder layer堆叠组成
class transformerEncoder(nn.Module):
    """Embedding followed by a stack of ``n_layers`` encoder layers.

    Args:
        enc_voc_size: source vocabulary size.
        max_len: maximum supported source length.
        d_model: model width.
        ffn_hidden: inner width of each feed-forward sub-layer.
        n_heads: number of attention heads.
        n_layers: number of stacked encoder layers.
        dropout: dropout probability used throughout.
        device: device for the positional-encoding table; ``None`` (the new
            default, required because it follows a defaulted parameter) uses
            PyTorch's default device.
    """

    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, dropout=0.1, device=None):
        super(transformerEncoder, self).__init__()
        self.embedding = transformerEmbedding(d_model, enc_voc_size, max_len, device, dropout)
        self.layers = nn.ModuleList(
            [encoderLayer(d_model, n_heads, ffn_hidden, dropout) for _ in range(n_layers)]
        )

    def forward(self, x, s_mask):
        """Embed source token ids and run them through every encoder layer.

        Args:
            x: source token ids.
            s_mask: source padding mask forwarded to each layer's self-attention.
        """
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, s_mask)
        return x
    

#一整个decoder由多个decoder layer堆叠组成,最后接一个线性层输出词表logits
class transformerDecoder(nn.Module):
    """Embedding, a stack of decoder layers, and a projection to vocabulary logits.

    Args:
        dec_voc_size: target vocabulary size (also the output width of ``fc``).
        max_len: maximum supported target length.
        d_model: model width.
        ffn_hidden: inner width of each feed-forward sub-layer.
        n_heads: number of attention heads.
        n_layers: number of stacked decoder layers.
        dropout: dropout probability used throughout.
        device: device for the positional-encoding table; ``None`` (the new
            default, required because it follows a defaulted parameter) uses
            PyTorch's default device.
    """

    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, dropout=0.1, device=None):
        super(transformerDecoder, self).__init__()
        self.embedding = transformerEmbedding(d_model, dec_voc_size, max_len, device, dropout)
        self.layers = nn.ModuleList(
            [decoderLayer(d_model, n_heads, ffn_hidden, dropout) for _ in range(n_layers)]
        )
        self.fc = nn.Linear(d_model, dec_voc_size)

    def forward(self, dec, enc, t_mask, s_mask):
        """Decode target ids against the encoder output.

        Args:
            dec: target token ids.
            enc: encoder output, used by every layer's cross-attention.
            t_mask: mask for the masked self-attention over the target.
            s_mask: mask for target-to-source cross-attention.

        Returns:
            Logits whose last dimension is ``dec_voc_size``.
        """
        dec = self.embedding(dec)
        # Chain the layers: each one consumes the previous layer's output.
        for layer in self.layers:
            dec = layer(dec, enc, t_mask, s_mask)
        return self.fc(dec)


#最后就是transformer
class transformer(nn.Module):
    """Encoder-decoder Transformer.

    Args:
        src_pad_idx: padding token id in source sequences.
        trg_pad_idx: padding token id in target sequences.
        enc_voc_size: source vocabulary size.
        dec_voc_size: target vocabulary size.
        max_len: maximum supported sequence length.
        d_model: model width.
        ffn_hidden: inner width of each feed-forward sub-layer.
        n_heads: number of attention heads.
        n_layers: number of layers in both encoder and decoder.
        dropout: dropout probability used throughout.
        device: device for the positional-encoding tables; ``None`` (the new
            default, required because it follows a defaulted parameter) uses
            PyTorch's default device. Masks are built on the inputs' device.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, dropout=0.1, device=None):
        super(transformer, self).__init__()
        self.encoder = transformerEncoder(enc_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, dropout, device)
        self.decoder = transformerDecoder(dec_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, dropout, device)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_trg_mask(self, q, k):
        """Causal (lower-triangular) mask of shape (len_q, len_k); True = may attend."""
        len_q, len_k = q.size(1), k.size(1)
        # Build on the input's device so it combines with padding masks without a copy.
        return torch.tril(torch.ones(len_q, len_k, device=q.device)).bool()

    def make_src_mask(self, q, k, pad_idx_q, pad_idx_k):
        """Padding mask of shape (batch, 1, len_q, len_k).

        An entry is True only when both the query token and the key token are
        not padding.
        """
        q_valid = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)  # (batch, 1, len_q, 1)
        k_valid = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, len_k)
        # Broadcasting the AND yields (batch, 1, len_q, len_k) without explicit repeats.
        return q_valid & k_valid

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg``, returning the decoder's output."""
        src_mask = self.make_src_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # Target self-attention must ignore both target padding and future positions.
        trg_mask = self.make_src_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) & self.make_trg_mask(trg, trg)
        src_trg_mask = self.make_src_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)
        enc = self.encoder(src, src_mask)
        return self.decoder(trg, enc, trg_mask, src_trg_mask)


