In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import copy

In [None]:
## Positional encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Inputs are batch-first, ``(batch_size, seq_len, d_model)``, matching the
    ``Embeddings`` output and the ``(batch, seq)`` token tensors used elsewhere
    in this notebook.

    Args:
        d_model: embedding dimension (odd values are supported).
        max_len: longest sequence that can be encoded; must be a positive integer.

    Raises:
        ValueError: if ``max_len`` is not a positive integer.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        if max_len <= 0 or int(max_len) != max_len:
            raise ValueError("max_len must be a positive integer, got {!r}".format(max_len))
        max_len = int(max_len)
        # Encoding table: one row per position, one column per embedding dimension.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Frequencies 1 / 10000^(2i/d_model), computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions use sine
        # When d_model is odd there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # odd dimensions use cosine
        # Leading singleton axis broadcasts over the batch: (1, max_len, d_model).
        # Registered as a buffer: saved in state_dict, moved with .to(), not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: tensor of shape ``(batch_size, seq_len, d_model)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError("sequence length {} exceeds max_len {}".format(seq_len, self.pe.size(1)))
        return x + self.pe[:, :seq_len]
        

In [None]:
## Multi-head self-attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head width, used for queries, keys and values

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        self.fc_out = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout)
        self.scale = math.sqrt(self.d_k)  # sqrt(d_k) scaling from scaled dot-product attention

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: ``(batch, q_len, d_model)``.
            key, value: ``(batch, k_len, d_model)``.
            mask: optional tensor broadcastable to ``(batch, num_heads, q_len, k_len)``.
                Positions where ``mask == 0`` receive (near) zero attention. A 3-D mask,
                e.g. ``(batch, 1, k_len)`` or ``(batch, q_len, k_len)``, gets a head
                axis inserted at dim 1.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = query.size(0)

        # Linear projections, then split d_model into num_heads x d_k:
        # (batch, len, d_model) -> (batch, num_heads, len, d_k)
        Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Attention scores: (batch, num_heads, q_len, k_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # add head axis so the same mask applies to every head
            # Large negative score so masked positions get ~0 weight after softmax.
            scores = scores.masked_fill(mask == 0, -1e9)
        # Normalize over the key axis so each query's weights sum to 1.
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge heads back into d_model.
        x = torch.matmul(attn, V)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.fc_out(x)
        
            
        

In [None]:
## Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position.

    Args:
        d_model: input/output width.
        d_ff: hidden width.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        # Project back to d_model so the output can be added to the residual stream.
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map ``(..., d_model)`` to ``(..., d_model)``."""
        return self.fc2(self.dropout(self.activation(self.fc1(x))))
    

In [None]:
## Sublayer connection: every sublayer (self-attention, feed-forward) is wrapped
## with a residual connection and layer normalization.
class SublayerConnection(nn.Module):
    """Residual wrapper computing ``x + Dropout(sublayer(LayerNorm(x)))``.

    The normalization is applied to the sublayer's input (pre-norm), not after
    the residual addition.

    Args:
        d_model: feature width normalized by the LayerNorm.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, d_model, dropout=0.1):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (a callable taking one tensor) to the normalized ``x``
        and add the result back onto ``x``."""
        return x + self.dropout(sublayer(self.norm(x)))
    

In [None]:
## Encoder layer: one multi-head self-attention sublayer and one feed-forward sublayer
class EncoderLayer(nn.Module):
    """Single encoder layer: self-attention then feed-forward, each wrapped in a
    ``SublayerConnection`` (residual + LayerNorm + dropout).

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout probability for the sublayer connections and feed-forward block.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.sublayer = clones(SublayerConnection(d_model, dropout), 2)
        self.d_model = d_model  # read by Encoder to size its final LayerNorm

    def forward(self, x, mask):
        """Run ``x`` through both sublayers; ``mask`` is forwarded to self-attention."""
        # SublayerConnection calls its second argument on the normalized input, so it
        # must receive a function, not a precomputed tensor. y serves as query, key and value.
        x = self.sublayer[0](x, lambda y: self.self_attn(y, y, y, mask))
        # Second residual sublayer: position-wise feed-forward network.
        return self.sublayer[1](x, self.feed_forward)

In [None]:
## Clone helper: build N identical, independent copies of a sub-layer
def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` deep copies of ``module``.

    Because each entry is a deep copy, the clones share no parameters with
    each other or with ``module``.
    """
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return copies


In [None]:
## Decoder layer: masked self-attention, encoder-decoder (source) attention, feed-forward
class DecoderLayer(nn.Module):
    """Single decoder layer with three residual sublayers.

    1. self-attention over the target, restricted by ``tgt_mask``;
    2. attention from the target over ``memory``, restricted by ``src_mask``;
    3. position-wise feed-forward network.

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout probability for the sublayer connections. It is not
            forwarded to the feed-forward block (unlike EncoderLayer).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.src_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.sublayer = clones(SublayerConnection(d_model, dropout), 3)
        self.d_model = d_model  # read by Decoder to size its final LayerNorm

    def forward(self, x, memory, src_mask, tgt_mask):
        """Decode ``x`` against ``memory`` (the encoder output in ``Transformer``)."""
        m = memory
        # Sublayers take a function of the normalized input, not a precomputed tensor.
        x = self.sublayer[0](x, lambda y: self.self_attn(y, y, y, tgt_mask))
        # Cross-attention: queries come from the decoder, keys/values from the encoder memory.
        x = self.sublayer[1](x, lambda y: self.src_attn(y, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)
        

In [None]:
## Encoder: a stack of identical encoder layers followed by a final LayerNorm
class Encoder(nn.Module):
    """Stack of ``num_layers`` deep copies of ``encoder_layer`` plus a final LayerNorm.

    Args:
        encoder_layer: template layer to clone; must expose a ``d_model`` attribute.
        num_layers: how many copies to stack.
    """

    def __init__(self, encoder_layer, num_layers):
        super(Encoder, self).__init__()
        self.layers = clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(encoder_layer.d_model)

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order (each receives ``mask``), then normalize."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)
        

In [None]:
## Decoder: a stack of identical decoder layers followed by a final LayerNorm
class Decoder(nn.Module):
    """Stack of ``num_layers`` deep copies of ``decoder_layer`` plus a final LayerNorm.

    Args:
        decoder_layer: template layer to clone; must expose a ``d_model`` attribute.
        num_layers: how many copies to stack.
    """

    def __init__(self, decoder_layer, num_layers):
        super(Decoder, self).__init__()
        self.layers = clones(decoder_layer, num_layers)
        self.norm = nn.LayerNorm(decoder_layer.d_model)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run ``x`` through every layer in order, then normalize.

        ``memory`` (the encoder output when called from ``Transformer``), ``src_mask``
        and ``tgt_mask`` are passed unchanged to every layer. ``src_mask`` is intended
        to hide source padding and ``tgt_mask`` to hide future target positions.
        """
        out = x
        for decoder_layer in self.layers:
            out = decoder_layer(out, memory, src_mask, tgt_mask)
        return self.norm(out)

In [None]:
## Transformer model: embeddings -> encoder -> decoder -> generator
class Transformer(nn.Module):
    """Encoder-decoder model wiring together the given sub-modules.

    Args:
        encoder: callable ``(embedded_src, src_mask) -> memory``.
        decoder: callable ``(embedded_tgt, memory, src_mask, tgt_mask) -> hidden``.
        src_embed: callable embedding source token ids.
        trg_embed: callable embedding target token ids.
        generator: callable mapping decoder output to vocabulary scores.
    """

    def __init__(self, encoder, decoder, src_embed, trg_embed, generator):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.generator = generator

    def encode(self, src, src_mask):
        """Embed ``src`` and run the encoder; returns the memory."""
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder against ``memory``."""
        return self.decoder(self.trg_embed(tgt), memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask, memory=None):
        """Return ``generator(decoder(...))`` for the given source/target batch.

        Previously ``memory`` was a required third positional parameter that was
        always overwritten, which made the natural call
        ``model(src, tgt, src_mask, tgt_mask)`` fail. It is now an optional
        keyword: when given, it is used as the encoder output and ``src`` is not
        re-encoded (useful for step-by-step decoding).
        """
        if memory is None:
            memory = self.encode(src, src_mask)
        output = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.generator(output)

In [None]:
## Generator: project decoder output onto the target vocabulary
class Generator(nn.Module):
    """Linear projection to ``vocab_size`` followed by ``log_softmax``.

    Returns log-probabilities rather than probabilities. ``nn.NLLLoss`` expects
    log-probabilities, and ``nn.CrossEntropyLoss`` gives the same result on them
    (log_softmax is idempotent). Feeding softmax *probabilities* to
    CrossEntropyLoss, as before, applies softmax twice and badly flattens the
    gradients. Use ``.exp()`` to recover probabilities; argmax is unchanged.

    Args:
        d_model: input width.
        vocab_size: number of output classes.
    """

    def __init__(self, d_model, vocab_size):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map ``(..., d_model)`` to log-probabilities of shape ``(..., vocab_size)``."""
        return torch.log_softmax(self.proj(x), dim=-1)

In [None]:
## Embedding layer: token lookup scaled by sqrt(d_model)
class Embeddings(nn.Module):
    """Look up ``d_model``-dimensional vectors for token ids and scale them by sqrt(d_model).

    Args:
        d_model: embedding width.
        vocab_size: number of distinct token ids.
    """

    def __init__(self, d_model, vocab_size):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Return ``lut(x) * sqrt(d_model)``; output adds a trailing ``d_model`` axis to ``x``."""
        scale = math.sqrt(self.d_model)
        embedded = self.lut(x)
        return embedded * scale

In [None]:
## Causal mask: hide future positions
def subsequent_mask(size):
    """Return a boolean tensor of shape ``(1, size, size)`` whose entry
    ``[0, i, j]`` is True exactly when ``j <= i``, i.e. position ``i`` may
    attend to itself and earlier positions only."""
    lower = torch.tril(torch.ones(size, size))
    return lower.bool().unsqueeze(0)

In [None]:
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, num_heads=8, dropout=0.1):
    """Build an encoder-decoder Transformer and Xavier-initialize its weight matrices.

    Args:
        src_vocab: source vocabulary size.
        tgt_vocab: target vocabulary size.
        N: number of encoder layers and of decoder layers.
        d_model: model width.
        d_ff: feed-forward hidden width.
        num_heads: attention heads per attention module.
        dropout: dropout probability passed to the encoder/decoder layers.

    Returns:
        A ``Transformer`` instance.
    """
    c = copy.deepcopy
    # Only d_model is passed: PositionalEncoding's second positional parameter is
    # max_len, so passing `dropout` there would request a length-0.1 table.
    position = PositionalEncoding(d_model)

    model = Transformer(
        Encoder(EncoderLayer(d_model, num_heads, d_ff, dropout), N),
        Decoder(DecoderLayer(d_model, num_heads, d_ff, dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab),
    )

    # Xavier-uniform init for every parameter with more than one dimension
    # (weight matrices, embedding tables); 1-D parameters keep their default init.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
    

In [None]:
## Training: a single optimization step on random data (smoke test)
# Hyperparameters
src_vocab = 10000
tgt_vocab = 10000
batch_size = 64
seq_len = 10  # tokens per sequence; independent of N (the number of layers)
N = 6
d_model = 512
d_ff = 2048
num_heads = 8
dropout = 0.1
pad_idx = 0  # token id treated as padding: ignored by the loss and masked in attention

torch.manual_seed(0)  # reproducible initial weights and random batch

## Build the model
model = make_model(src_vocab, tgt_vocab, N, d_model, d_ff, num_heads, dropout)

## Loss and optimizer
# CrossEntropyLoss applies log_softmax internally; it is correct when the model
# outputs logits or log-probabilities (log_softmax is idempotent), not probabilities.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

## One training step on a random batch (ids start at 1, so no padding is generated)
src = torch.randint(1, src_vocab, (batch_size, seq_len))
tgt = torch.randint(1, tgt_vocab, (batch_size, seq_len))

# Teacher forcing: the decoder reads tgt[:, :-1] and is trained to predict the
# next token tgt[:, 1:]; otherwise it could simply copy its own input.
tgt_in = tgt[:, :-1]
tgt_out = tgt[:, 1:]

# src_mask: (batch, 1, src_len) - hides source padding positions.
src_mask = (src != pad_idx).unsqueeze(-2)
# tgt_mask: (batch, tgt_len, tgt_len) - hides target padding and future positions.
tgt_mask = (tgt_in != pad_idx).unsqueeze(-2) & subsequent_mask(tgt_in.size(1))

model.train()
optimizer.zero_grad()
output = model(src, tgt_in, src_mask, tgt_mask)  # expected (batch, tgt_len, tgt_vocab)
loss = criterion(output.reshape(-1, tgt_vocab), tgt_out.reshape(-1))
loss.backward()
optimizer.step()
loss.item()