In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [22]:
class TokenEmbedding(nn.Module):
    """Learned lookup table that maps integer token ids to dense vectors.

    Args:
        vocab_size: number of distinct token ids the table can look up.
        d_model: length of each embedding vector.
        device: device on which the embedding weight is allocated.
    """

    def __init__(self, vocab_size, d_model, device):
        super().__init__()
        # weight matrix: (vocab_size, d_model)
        self.embedding = nn.Embedding(vocab_size, d_model, device=device)

    def forward(self, x):
        """Replace every id in ``x`` by its d_model-sized embedding vector."""
        return self.embedding(x)

class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding table.

    Row ``pos`` holds ``sin(pos / 10000^(2i/d_model))`` in even column ``2i`` and
    ``cos(pos / 10000^(2i/d_model))`` in odd column ``2i + 1``.

    Args:
        d_model: encoding width. May be odd, in which case the last column is a sine.
        max_len: number of positions pre-computed; longer sequences are rejected.
        device: device on which the table is created.
    """

    def __init__(self, d_model, max_len, device):
        super(PositionalEmbedding, self).__init__()
        pos = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1)  # (max_len, 1)
        _2i = torch.arange(0, d_model, 2, dtype=torch.float, device=device)  # (ceil(d_model / 2),)
        angles = pos / (10000 ** (_2i / d_model))  # (max_len, ceil(d_model / 2))
        encoding = torch.zeros(max_len, d_model, device=device)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies get a cosine.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Register as a non-persistent buffer, for three reasons:
        # - it follows module.to(...) / .cuda() with the rest of the model;
        # - it is never a trainable parameter (positional encoding needs no gradient);
        # - it is left out of state_dict, matching the previous plain-attribute behaviour.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the first ``seq_len`` rows of the table, shape (seq_len, d_model).

        ``x`` is used only for its shape and must be 2-D: (batch_size, seq_len).

        Raises:
            ValueError: if seq_len exceeds the ``max_len`` given at construction.
        """
        _, seq_len = x.size()
        max_len = self.encoding.size(0)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        return self.encoding[:seq_len, :]
    
class TransformerEmbedding(nn.Module):
    """Input embedding for the Transformer.

    Sums a learned token embedding and the sinusoidal positional encoding, then
    applies dropout.

    Args:
        vocab_size: number of distinct token ids.
        d_model: embedding width.
        max_len: longest sequence the positional table supports.
        dropout: dropout probability applied to the summed embedding.
        device: device for the sub-modules' tensors.
    """

    def __init__(self, vocab_size, d_model, max_len, dropout, device):
        super().__init__()
        self.token_embedding = TokenEmbedding(vocab_size, d_model, device)
        self.positional_embedding = PositionalEmbedding(d_model, max_len, device)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed token ids ``x`` of shape (batch_size, seq_len).

        Returns a tensor of shape (batch_size, seq_len, d_model). The positional
        table of shape (seq_len, d_model) broadcasts over the batch dimension.
        """
        summed = self.token_embedding(x) + self.positional_embedding(x)
        return self.dropout(summed)

In [23]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries come from ``encodings_for_q``; keys and values come from
    ``encodings_for_k`` / ``encodings_for_v``. The key/value sequence length may
    differ from the query length, as in encoder-decoder cross-attention.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of heads; each head works in ``d_model // n_head`` dims.
        device: device for the projection weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head, device):
        super(MultiHeadAttention, self).__init__()
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        self.d_model = d_model
        self.n_head = n_head
        self.W_q = nn.Linear(d_model, d_model, bias=False, device=device)
        self.W_k = nn.Linear(d_model, d_model, bias=False, device=device)
        self.W_v = nn.Linear(d_model, d_model, bias=False, device=device)
        # Output projection applied after the heads are concatenated back together.
        self.concat = nn.Linear(d_model, d_model, bias=False, device=device)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape (batch, length, d_model) -> (batch, n_head, length, d_model // n_head)."""
        batch, length, _ = t.size()
        return t.view(batch, length, self.n_head, self.d_model // self.n_head).permute(0, 2, 1, 3)

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask):
        """Attend from the query sequence over the key/value sequence.

        Args:
            encodings_for_q: (batch, q_len, d_model)
            encodings_for_k: (batch, kv_len, d_model)
            encodings_for_v: (batch, kv_len, d_model)
            mask: None, or a tensor broadcastable to (batch, n_head, q_len, kv_len).
                Positions where ``mask == 0`` receive no attention.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch, q_len, _ = encodings_for_q.size()
        n_d = self.d_model // self.n_head
        # Each projection is split with its OWN sequence length, so kv_len may differ from q_len.
        Q = self._split_heads(self.W_q(encodings_for_q))  # (batch, n_head, q_len, n_d)
        K = self._split_heads(self.W_k(encodings_for_k))  # (batch, n_head, kv_len, n_d)
        V = self._split_heads(self.W_v(encodings_for_v))  # (batch, n_head, kv_len, n_d)
        scaled_sims = Q @ K.transpose(-2, -1) / math.sqrt(n_d)  # (batch, n_head, q_len, kv_len)
        if mask is not None:
            # The dtype's most negative finite value works in fp16 too,
            # where a literal -1e9 would overflow.
            scaled_sims = scaled_sims.masked_fill(mask == 0, torch.finfo(scaled_sims.dtype).min)
        attention_percent = self.softmax(scaled_sims)
        attention_scores = attention_percent @ V  # (batch, n_head, q_len, n_d)
        # Concatenate the heads back into (batch, q_len, d_model).
        attention_scores = attention_scores.permute(0, 2, 1, 3).contiguous().view(batch, q_len, self.d_model)
        return self.concat(attention_scores)

In [24]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position.

    Computes ``fc2(dropout(relu(fc1(x))))``, mapping d_model -> hidden -> d_model.

    Args:
        d_model: input and output width.
        hidden: width of the intermediate layer.
        dropout: dropout probability applied after the ReLU.
        device: device for the linear-layer weights.
    """

    def __init__(self, d_model, hidden, dropout, device):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden, device=device)  # expand
        self.fc2 = nn.Linear(hidden, d_model, device=device)  # project back
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP over the last dimension; output has the same shape as ``x``."""
        activated = F.relu(self.fc1(x))
        return self.fc2(self.dropout(activated))

In [36]:
class DecoderLayer(nn.Module):
    """One post-LayerNorm Transformer decoder layer.

    Each sub-layer is wrapped as ``norm(residual + dropout(sublayer(x)))``:
      1. self-attention over the decoder stream, masked by ``t_mask``;
      2. cross-attention, where decoder queries attend to ``encoder_output``,
         masked by ``s_mask``;
      3. position-wise feed-forward network.

    Args:
        d_model: model width.
        ffn_hidden: hidden width of the feed-forward sub-layer.
        n_head: number of attention heads.
        dropout: dropout probability for every sub-layer.
        device: device for all parameters.
    """

    def __init__(self, d_model, ffn_hidden, n_head, dropout, device):
        super(DecoderLayer, self).__init__()
        self.attention1 = MultiHeadAttention(d_model, n_head, device)
        self.norm1 = nn.LayerNorm(d_model, device=device)
        self.dropout1 = nn.Dropout(dropout)
        self.cross_attention = MultiHeadAttention(d_model, n_head, device)
        self.norm2 = nn.LayerNorm(d_model, device=device)
        self.dropout2 = nn.Dropout(dropout)
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, dropout, device)
        self.norm3 = nn.LayerNorm(d_model, device=device)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, encoder_output, decoder_input, t_mask, s_mask):
        """Run the three sub-layers on ``decoder_input``.

        Args:
            encoder_output: encoder states used as keys AND values of the
                cross-attention; presumably (batch, src_len, d_model).
            decoder_input: decoder stream; presumably (batch, tgt_len, d_model).
            t_mask: mask for the self-attention (e.g. causal), or None.
            s_mask: mask for the cross-attention (e.g. source padding), or None.

        Returns:
            The updated decoder stream, (batch, tgt_len, d_model).
        """
        # 1. Self-attention over the decoder stream.
        _x = decoder_input
        x = self.attention1(decoder_input, decoder_input, decoder_input, t_mask)
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        # 2. Cross-attention: queries from the decoder, keys/values from the encoder.
        _x = x
        x = self.cross_attention(x, encoder_output, encoder_output, s_mask)
        x = self.dropout2(x)
        x = self.norm2(x + _x)

        # 3. Feed-forward. The residual must be the cross-attention output,
        #    not the stale input of step 2.
        _x = x
        x = self.ffn(x)
        x = self.dropout3(x)
        x = self.norm3(x + _x)
        return x

In [31]:
class Decoder(nn.Module):
    """Transformer decoder: embedding -> ``n_layer`` DecoderLayers -> vocabulary projection.

    Args:
        decoder_vocab_size: size of the target vocabulary (embedding rows and logit width).
        max_len: longest target sequence supported by the positional encoding.
        d_model: model width.
        ffn_hidden: hidden width of each layer's feed-forward sub-layer.
        n_head: number of attention heads per layer.
        n_layer: number of stacked DecoderLayers.
        dropout: dropout probability.
        device: device for all parameters.
    """

    def __init__(self, decoder_vocab_size, max_len, d_model, ffn_hidden, n_head, n_layer, dropout, device):
        super(Decoder, self).__init__()
        self.embedding = TransformerEmbedding(decoder_vocab_size, d_model, max_len, dropout, device)
        self.decoder_layers = nn.ModuleList(
            [
                DecoderLayer(d_model, ffn_hidden, n_head, dropout, device) for _ in range(n_layer)
            ]
        )
        self.fc = nn.Linear(d_model, decoder_vocab_size, device=device)

    def forward(self, encoder_output, decoder_output, t_mask, s_mask):
        """Decode target token ids into vocabulary logits.

        Args:
            encoder_output: encoder states, fed to every layer's cross-attention.
            decoder_output: target token ids of shape (batch, tgt_len). The name is
                kept for backward compatibility; these are the decoder *inputs*.
            t_mask: self-attention mask (e.g. causal), or None.
            s_mask: cross-attention mask (e.g. source padding), or None.

        Returns:
            Logits of shape (batch, tgt_len, decoder_vocab_size).
        """
        x = self.embedding(decoder_output)
        for decoder_layer in self.decoder_layers:
            # Keyword arguments: DecoderLayer.forward takes encoder_output FIRST.
            # Passing the decoder stream positionally in that slot would swap the two.
            x = decoder_layer(encoder_output=encoder_output, decoder_input=x, t_mask=t_mask, s_mask=s_mask)
        return self.fc(x)
        

In [39]:
# Smoke test: run a small, randomly initialised decoder on random inputs and check shapes.
torch.manual_seed(0)  # reproducible weights and inputs

decoder_vocab_size = 100
max_len = 20
d_model = 64
n_head = 8
ffn_hidden = 64
n_layer = 2
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout = 0.1

batch_size = 2
seq_len = 12
# Causal mask for decoder self-attention: position i may only attend to positions <= i.
t_mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
s_mask = None  # no source padding in this example

decoder = Decoder(decoder_vocab_size, max_len, d_model, ffn_hidden, n_head, n_layer, dropout, device)

encoder_output = torch.randn(batch_size, seq_len, d_model, device=device)
decoder_input = torch.randint(0, decoder_vocab_size, (batch_size, seq_len), device=device)

output = decoder(encoder_output, decoder_input, t_mask, s_mask)
print('输出的数据类型:', output.dtype)
print('输出形状:', output.shape)

输入的数据类型: torch.float32
输出形状: torch.Size([2, 12, 100])
