In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [44]:
class PositionEncoder(nn.Module):
    """Scale embeddings by sqrt(d) and add fixed sinusoidal position encodings.

    The table follows the standard formulation: for every even column ``i``
    the pair ``(i, i + 1)`` shares one frequency ``1 / 10000 ** (i / d)``,
    with the sine in column ``i`` and the cosine in column ``i + 1``.

    Args:
        d: Width of the embeddings the encodings are added to.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d, max_len):
        super().__init__()
        self.d = d
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        # `even_cols` already holds 0, 2, 4, ... so the exponent is i / d;
        # multiplying by 2 again would square the intended wavelength.
        even_cols = torch.arange(0, d, 2, dtype=torch.float32)
        inv_freq = torch.exp(even_cols * (-math.log(10000.0) / d))
        angles = positions * inv_freq

        pos = torch.zeros(max_len, d)
        pos[:, 0::2] = torch.sin(angles)
        # For an odd d there is one cosine column fewer than sine columns.
        pos[:, 1::2] = torch.cos(angles[:, : d // 2])
        # Stored as a buffer of shape (1, max_len, d): it follows the module
        # across devices but is not a trainable parameter.
        self.register_buffer('pos', pos.unsqueeze(0))

    def forward(self, X):
        """Return ``X * sqrt(d)`` plus the encodings for the first ``X.size(1)`` positions.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        length = X.size(1)
        if length > self.pos.size(1):
            raise ValueError(
                f"sequence length {length} exceeds max_len {self.pos.size(1)}"
            )
        return X * math.sqrt(self.d) + self.pos[:, :length]

In [45]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d: Model width; must be divisible by ``heads``.
        dropout: Dropout probability applied to the attention weights.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``d`` cannot be split evenly across ``heads``.
    """

    def __init__(self, d, dropout, heads):
        super().__init__()
        if d % heads != 0:
            raise ValueError(f"d ({d}) must be divisible by heads ({heads})")
        self.heads = heads
        self.d = d
        self.part = d // heads  # width of a single head
        self.linQ = nn.Linear(d, d)
        self.linK = nn.Linear(d, d)
        self.linV = nn.Linear(d, d)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d, d)

    def attention(self, d, Q, K, V, mask):
        """Compute ``softmax(Q K^T / sqrt(d)) V`` over the last two dimensions.

        Args:
            d: Width used for scaling the scores (the per-head width).
            Q, K, V: Query, key and value tensors whose last two dimensions
                are (sequence, width).
            mask: Optional tensor broadcastable to the score shape; positions
                where it equals 0 are excluded from the softmax.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d)

        if mask is not None:
            # A large negative score drives the softmax weight to zero.
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        return torch.matmul(weights, V)

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head, and merge the heads again.

        Inputs are (batch, sequence, d); the result has the query's
        (batch, sequence, d) shape.
        """
        bs = Q.size(0)
        # (batch, seq, d) -> (batch, heads, seq, part)
        Q = self.linQ(Q).view(bs, -1, self.heads, self.part).transpose(1, 2)
        K = self.linK(K).view(bs, -1, self.heads, self.part).transpose(1, 2)
        V = self.linV(V).view(bs, -1, self.heads, self.part).transpose(1, 2)
        # Scale by the per-head width: each dot product sums over `part`
        # terms, not over the full model width.
        context = self.attention(self.part, Q, K, V, mask)
        context = context.transpose(1, 2).contiguous().view(bs, -1, self.d)
        return self.out(context)

In [46]:
class FeedForward(nn.Module):
    """Two-layer position-wise network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        dropout: Dropout probability applied after the ReLU.
        d: Input and output width.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, dropout, d, d_ff):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.lin1 = nn.Linear(in_features=d, out_features=d_ff)
        self.lin2 = nn.Linear(in_features=d_ff, out_features=d)

    def forward(self, X):
        """Expand to ``d_ff``, apply ReLU and dropout, then project back to ``d``."""
        hidden = self.lin1(X)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.lin2(hidden)

In [47]:
class NormLayer(nn.Module):
    """Normalise the last dimension, then apply a learned scale and shift.

    Args:
        d: Size of the last dimension (length of ``alpha`` and ``beta``).
        eps: Constant added to the standard deviation before dividing.
    """

    def __init__(self, d, eps):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(d))   # learned scale, starts at 1
        self.beta = nn.Parameter(torch.zeros(d))   # learned shift, starts at 0
        self.eps = eps

    def forward(self, X):
        """Return ``alpha * (X - mean) / (std + eps) + beta`` along the last dimension."""
        mean = X.mean(dim=-1, keepdim=True)
        # Tensor.std defaults to the unbiased (n - 1) estimator.
        spread = X.std(dim=-1, keepdim=True) + self.eps
        centred = X - mean
        return self.alpha * centred / spread + self.beta

In [48]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sub-layer first normalises its input and then adds the sub-layer's
    (dropped-out) output to that *normalised* tensor, so the skip path carries
    the normalised activations rather than the raw input.

    Args:
        d: Model width.
        d_ff: Hidden width of the feed-forward network.
        heads: Number of attention heads.
        eps: Epsilon forwarded to the normalisation layers.
        dropout: Dropout probability used throughout the block.
    """

    def __init__(self, d, d_ff, heads, eps, dropout):
        super().__init__()
        self.d = d
        self.dropout = dropout
        self.norm1 = NormLayer(d, eps)
        self.norm2 = NormLayer(d, eps)
        self.attention = MultiheadAttention(d, dropout, heads)
        self.ffn = FeedForward(dropout, d, d_ff)
        self.drop1 = nn.Dropout(p=dropout)
        self.drop2 = nn.Dropout(p=dropout)

    def forward(self, X, src_mask):
        """Run self-attention and the feed-forward network on ``X`` under ``src_mask``."""
        # Self-attention sub-layer.
        normed = self.norm1(X)
        attended = self.attention(normed, normed, normed, src_mask)
        after_attention = normed + self.drop1(attended)

        # Feed-forward sub-layer.
        normed = self.norm2(after_attention)
        transformed = self.ffn(normed)
        return normed + self.drop2(transformed)

In [49]:
class Encoder(nn.Module):
    """Embedding, position encoding, a stack of encoder layers, and a final norm.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d: Model width.
        d_ff: Hidden width of each feed-forward network.
        heads: Number of attention heads per layer.
        max_len: Number of positions precomputed by the position encoder.
        eps: Epsilon forwarded to the normalisation layers.
        N: Number of stacked encoder layers.
        dropout: Dropout probability forwarded to every layer.
    """

    def __init__(self, vocab_size, d, d_ff, heads, max_len = 500, eps = 1e-7, N = 8, dropout = 0.1):
        super().__init__()
        self.pe = PositionEncoder(d, max_len)
        self.norm = NormLayer(d, eps)
        self.embed = nn.Embedding(vocab_size, d)
        self.layers = nn.ModuleList()
        for _ in range(N):
            self.layers.append(EncoderLayer(d, d_ff, heads, eps, dropout))

    def forward(self, X, src_mask):
        """Encode token ids ``X``; ``src_mask`` is passed to every layer's attention."""
        hidden = self.pe(self.embed(X))
        for block in self.layers:
            hidden = block(hidden, src_mask)
        encoded = self.norm(hidden)
        return encoded

In [50]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each sub-layer first normalises its input and then adds the sub-layer's
    (dropped-out) output to that *normalised* tensor, so the skip path carries
    the normalised activations rather than the raw input.

    Args:
        d: Model width.
        heads: Number of attention heads.
        dropout: Dropout probability used throughout the block.
        d_ff: Hidden width of the feed-forward network.
        eps: Epsilon forwarded to the normalisation layers.
    """

    def __init__(self, d, heads, dropout, d_ff, eps):
        super().__init__()
        self.norm1 = NormLayer(d, eps)
        self.att1 = MultiheadAttention(d, dropout, heads)
        self.drop1 = nn.Dropout(p=dropout)
        self.norm2 = NormLayer(d, eps)
        self.att2 = MultiheadAttention(d, dropout, heads)
        self.drop2 = nn.Dropout(p=dropout)
        self.norm3 = NormLayer(d, eps)
        self.ffn = FeedForward(dropout, d, d_ff)
        self.drop3 = nn.Dropout(p=dropout)

    def forward(self, Y, Z, src_mask, trg_mask):
        """Update target states ``Y`` using encoder output ``Z``.

        ``trg_mask`` restricts the self-attention over ``Y``; ``src_mask``
        restricts the cross-attention over ``Z``.
        """
        # Masked self-attention over the target sequence.
        normed = self.norm1(Y)
        self_attended = self.att1(normed, normed, normed, trg_mask)
        hidden = normed + self.drop1(self_attended)

        # Cross-attention: queries from the target, keys/values from the encoder.
        normed = self.norm2(hidden)
        cross_attended = self.att2(normed, Z, Z, src_mask)
        hidden = normed + self.drop2(cross_attended)

        # Position-wise feed-forward network.
        normed = self.norm3(hidden)
        transformed = self.ffn(normed)
        return normed + self.drop3(transformed)

In [51]:
class Decoder(nn.Module):
    """Embedding, position encoding, a stack of decoder layers, and a final norm.

    Args:
        vocab_size: Number of rows in the target token embedding table.
        d: Model width.
        heads: Number of attention heads per layer.
        dropout: Dropout probability forwarded to every layer.
        d_ff: Hidden width of each feed-forward network.
        max_len: Number of positions precomputed by the position encoder.
        N: Number of stacked decoder layers.
        eps: Epsilon forwarded to the normalisation layers.
    """

    def __init__(self, vocab_size, d, heads, dropout, d_ff, max_len, N, eps):
        super().__init__()
        # Honour the caller's max_len instead of a hard-coded table size so
        # the encoder and decoder share the same positional range.
        self.pos = PositionEncoder(d, max_len)
        self.embed = nn.Embedding(vocab_size, d)
        self.layers = nn.ModuleList([DecoderLayer(d, heads, dropout, d_ff, eps) for _ in range(N)])
        self.norm = NormLayer(d, eps)

    def forward(self, Y, Z, src_mask, trg_mask):
        """Decode target token ids ``Y`` against encoder output ``Z``.

        ``src_mask`` and ``trg_mask`` are passed unchanged to every layer.
        """
        Y = self.embed(Y)
        Y = self.pos(Y)
        for layer in self.layers:
            Y = layer(Y, Z, src_mask, trg_mask)
        return self.norm(Y)

In [52]:
class Transformer(nn.Module):
    """Encoder-decoder model with a linear projection onto the target vocabulary.

    Args:
        trg_vocab: Target vocabulary size; sizes the decoder embedding and
            the output projection.
        vocab_size: Source vocabulary size; sizes the encoder embedding.
        d: Model width.
        d_ff: Hidden width of each feed-forward network.
        heads: Number of attention heads per layer.
        max_len: Number of positions precomputed by the position encoders.
        eps: Epsilon forwarded to the normalisation layers.
        N: Number of layers in the encoder and in the decoder.
        dropout: Dropout probability forwarded to every layer.
    """

    def __init__(self, trg_vocab, vocab_size, d, d_ff, heads, max_len , eps, N, dropout):
        super().__init__()
        self.Encoder = Encoder(vocab_size, d, d_ff, heads, max_len, eps, N, dropout)
        # The decoder embeds *target* tokens, so its table must cover
        # trg_vocab; sizing it with the source vocabulary makes any target id
        # >= vocab_size (including ids predicted by self.out) out of range.
        self.Decoder = Decoder(trg_vocab, d, heads, dropout, d_ff, max_len, N, eps)
        self.out = nn.Linear(d, trg_vocab)

    def forward(self, X, Y, src_mask, trg_mask):
        """Return logits over the target vocabulary for every position of ``Y``.

        ``X`` holds source token ids and ``Y`` target token ids; the masks are
        forwarded to the encoder and decoder attention layers.
        """
        Z = self.Encoder(X, src_mask)
        out = self.Decoder(Y, Z, src_mask, trg_mask)
        out = self.out(out)
        return out

In [53]:
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Vocabulary sizes: `vocab_size` is the source side, `trg_vocab` the width of
# the output projection.
vocab_size, trg_vocab = 1000, 2000

# Model shape: width, feed-forward width, attention heads, layer count.
d, d_ff, heads, n = 512, 2048, 8, 6
max_len = 50
dropout, eps = 0.1, 1e-6

# Geometry of each synthetic batch.
batch_size, src_len, trg_len = 32, 40, 30

# Optimisation settings; `pad_idx` is the label id the loss ignores.
learning_rate = 1e-4
pad_idx = 0
epochs, steps_per_epoch = 20, 50

In [None]:
def data_loader(vocab_size, batch_size, src_len, trg_len):
    """Generate one random batch of token ids together with attention masks.

    Args:
        vocab_size: Exclusive upper bound for the sampled token ids; sampling
            starts at 2, so ids 0 and 1 are never drawn.
        batch_size: Number of sequences in the batch.
        src_len: Length of every source sequence.
        trg_len: Length of every target sequence.

    Returns:
        Tuple ``(src_data, trg_data, src_mask, trg_mask)``:
        ``src_data`` is (batch_size, src_len) and ``trg_data`` is
        (batch_size, trg_len), both integer tensors; the first target column
        is set to 1. ``src_mask`` is an all-True bool tensor of shape
        (batch_size, 1, 1, src_len). ``trg_mask`` is a lower-triangular bool
        tensor of shape (1, 1, trg_len, trg_len).
    """
    src_data = torch.randint(2, vocab_size, (batch_size, src_len))
    trg_data = torch.randint(2, vocab_size, (batch_size, trg_len))
    trg_data[:, 0] = 1  # every target sequence starts with token id 1

    # Nothing is masked on the source side.
    src_mask = torch.ones(batch_size, 1, 1, src_len, dtype=torch.bool)
    # Lower-triangular mask: position t may only attend to positions <= t.
    trg_mask = torch.tril(torch.ones(trg_len, trg_len)).bool()
    trg_mask = trg_mask.unsqueeze(0).unsqueeze(0)

    return src_data, trg_data, src_mask, trg_mask

In [None]:
def train_epoch(model, optimizer, criterion, device, steps, batch_size, src_len, trg_len, vocab_size):
    """Train ``model`` for ``steps`` batches drawn from ``data_loader``.

    Each batch is trained with teacher forcing: the decoder receives the
    target sequence without its last token and is scored against the same
    sequence shifted left by one, so position ``t`` predicts token ``t + 1``.

    Args:
        model: Module called as ``model(src, trg_in, src_mask, trg_mask)`` and
            returning logits with the class dimension last.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(logits_2d, labels_1d)``.
        device: Device the model and every batch are moved to.
        steps: Number of batches to train on.
        batch_size, src_len, trg_len, vocab_size: Forwarded to ``data_loader``.

    Returns:
        Mean loss over the processed batches (0.0 when ``steps`` is 0).

    Raises:
        ValueError: If ``trg_len`` is too short to form an input/label pair.
    """
    if trg_len < 2:
        raise ValueError("trg_len must be at least 2 to build shifted decoder inputs and labels")

    model.to(device)
    model.train()
    total_loss = 0.0

    for _ in range(steps):
        optimizer.zero_grad()
        src_data, trg_data, src_mask, trg_mask = data_loader(vocab_size, batch_size, src_len, trg_len)
        src_data = src_data.to(device)
        trg_data = trg_data.to(device)
        src_mask = src_mask.to(device)
        trg_mask = trg_mask.to(device)

        # Scoring the logits against the very tokens fed to the decoder would
        # only teach it to copy its input, so shift by one position instead.
        decoder_input = trg_data[:, :-1]
        labels = trg_data[:, 1:]
        # Trim the square target mask to the shortened decoder input.
        decoder_mask = trg_mask[..., :-1, :-1]

        logits = model(src_data, decoder_input, src_mask, decoder_mask)
        loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        loss.backward()
        # Cap the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / steps if steps else 0.0
    print(f"平均损失: {avg_loss:.4f}")

    return avg_loss

In [None]:
# Use the GPU when CUDA reports one, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the model from the notebook-level hyperparameters, then move it to `device`.
model = Transformer(
    trg_vocab=trg_vocab,
    vocab_size=vocab_size,
    d=d,
    d_ff=d_ff,
    heads=heads,
    max_len=max_len,
    N=n,
    eps=eps,
    dropout=dropout,
)
model = model.to(device)

# Labels equal to `pad_idx` do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

print(f"模型运行在 {device} 上。")
print("Start")

# One call to train_epoch per epoch; it prints and returns the mean loss.
for epoch in range(1, epochs + 1):
    print(f"\nEpoch {epoch}/{epochs}")
    train_epoch(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        steps=steps_per_epoch,
        batch_size=batch_size,
        src_len=src_len,
        trg_len=trg_len,
        vocab_size=vocab_size,
    )

print("\nFinish")

模型运行在 cpu 上。
Start

Epoch 1/20


In [None]:
def greedy_decode(model, src, src_mask, max_len, sos_idx, eos_idx, device):
    """Greedily decode a single sequence, one token at a time.

    The model is switched to eval mode (and left there). Decoding stops when
    the predicted token equals ``eos_idx`` or when the output reaches
    ``max_len`` tokens; the end token itself is not appended.

    Args:
        model: Object exposing ``Encoder(src, src_mask)``,
            ``Decoder(tokens, memory, src_mask, trg_mask)`` and ``out(hidden)``.
        src: Source token ids for one sequence (the ``.item()`` call below
            requires a batch of exactly one).
        src_mask: Mask forwarded to the encoder and to the decoder.
        max_len: Maximum number of tokens in the result, start token included.
        sos_idx: Token id the output starts with.
        eos_idx: Token id that terminates decoding.
        device: Device on which the output tokens and masks are created.

    Returns:
        Integer tensor of shape (1, L) starting with ``sos_idx``.
    """
    model.eval()
    trg_tokens = torch.full((1, 1), sos_idx, dtype=torch.long, device=device)

    # Keep the whole loop, including the output projection, under no_grad so
    # inference never records an autograd graph.
    with torch.no_grad():
        Z = model.Encoder(src, src_mask)

        for _ in range(max_len - 1):
            trg_len = trg_tokens.size(1)
            # Lower-triangular mask for the tokens generated so far, built
            # directly on the target device.
            causal = torch.tril(torch.ones(trg_len, trg_len, device=device)).bool()
            trg_mask = causal.unsqueeze(0).unsqueeze(0)

            output = model.Decoder(trg_tokens, Z, src_mask, trg_mask)
            # Only the last position predicts the next token.
            pred_logits = model.out(output[:, -1, :])
            next_token = torch.argmax(pred_logits, dim=-1, keepdim=True)
            if next_token.item() == eos_idx:
                break
            trg_tokens = torch.cat([trg_tokens, next_token], dim=1)
    return trg_tokens

# Decode one random source sequence with the model built earlier in the notebook.
src_len_input = 15
# Sample ids below the configured source vocabulary size rather than a
# hard-coded 1000, so the demo stays valid if `vocab_size` changes.
src_data = torch.randint(3, vocab_size, (1, src_len_input)).to(device) # [1, L_src]
src_mask = torch.ones(1, 1, 1, src_len_input).bool().to(device)

predicted_sequence = greedy_decode(
    model=model, 
    src=src_data, 
    src_mask=src_mask, 
    max_len=max_len, 
    sos_idx=1, 
    eos_idx=2, 
    device=device
)

print(f"输入序列形状 (Encoder Input): {src_data.shape}")
print(f"预测序列形状 (Token IDs): {predicted_sequence.shape}")
print(f"预测的 Token ID 序列: {predicted_sequence.squeeze(0).tolist()}")

输入序列形状 (Encoder Input): torch.Size([1, 15])
预测序列形状 (Token IDs): torch.Size([1, 1])
预测的 Token ID 序列: [1]
