In [9]:
from datasets import load_dataset

# Download (or reuse the local cache of) the Multi30k English-German parallel corpus.
dataset = load_dataset(path="bentrevett/multi30k")

In [10]:
# Work with a small slice (the first 200 training pairs) to keep experiments fast.
train_data = dataset["train"].select(range(200))

# Split the pairs into parallel lists: English is the source side, German the target side.
src_texts = [example["en"] for example in train_data]
tgt_texts = [example["de"] for example in train_data]

# Sanity-check the slice size and peek at the first few aligned sentences.
print(len(src_texts))  # 200
print(src_texts[:3])
print(tgt_texts[:3])

200
['Two young, White males are outside near many bushes.', 'Several men in hard hats are operating a giant pulley system.', 'A little girl climbing into a wooden playhouse.']
['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.', 'Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.', 'Ein kleines Mädchen klettert in ein Spielhaus aus Holz.']


In [11]:
from collections import Counter

def build_vocab(sentences, min_freq=1):
    """Build a word -> id vocabulary from lower-cased, whitespace-split sentences.

    Ids 0-3 are reserved for '<pad>', '<sos>', '<eos>' and '<unk>'. Every word that
    occurs at least ``min_freq`` times then receives the next free id, in order of
    its first occurrence in ``sentences``. Tokenisation is a plain ``str.split()``,
    so punctuation stays attached to words (e.g. 'bushes.' is its own token).

    Args:
        sentences: iterable of raw sentences.
        min_freq: minimum corpus frequency for a word to be included.

    Returns:
        dict mapping token -> integer id.
    """
    word_counts = Counter(
        token for sentence in sentences for token in sentence.lower().split()
    )
    specials = ['<pad>', '<sos>', '<eos>', '<unk>']
    vocab = {token: position for position, token in enumerate(specials)}
    frequent_words = (word for word, count in word_counts.items() if count >= min_freq)
    for next_id, word in enumerate(frequent_words, start=len(specials)):
        vocab[word] = next_id
    return vocab

# Separate vocabularies for the English source side and the German target side.
src_vocab, tgt_vocab = build_vocab(src_texts), build_vocab(tgt_texts)

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset):
    """Map-style dataset of aligned (source, target) sentence pairs as fixed-length id tensors.

    Every sentence is lower-cased, split on whitespace, wrapped in '<sos>' ... '<eos>',
    mapped to ids through its vocabulary (unknown words -> '<unk>'), then right-padded
    with '<pad>' or truncated so that it is exactly ``max_len`` ids long. Truncation
    always keeps the closing '<eos>'.

    Args:
        src_texts: source-language sentences.
        tgt_texts: target-language sentences, aligned index-by-index with ``src_texts``.
        src_vocab: token -> id mapping for the source side; must contain
            '<pad>', '<sos>', '<eos>' and '<unk>'.
        tgt_vocab: token -> id mapping for the target side, with the same special tokens.
        max_len: length of every returned tensor, special tokens included. Must be
            at least 2 so that both '<sos>' and '<eos>' fit.

    Raises:
        ValueError: if the two text lists differ in length or ``max_len`` < 2.
    """

    def __init__(self, src_texts, tgt_texts, src_vocab, tgt_vocab, max_len=32):
        # Pairs are looked up by index, so a length mismatch would silently misalign
        # them or only fail later inside the DataLoader.
        if len(src_texts) != len(tgt_texts):
            raise ValueError(
                f"src_texts and tgt_texts must have the same length, "
                f"got {len(src_texts)} and {len(tgt_texts)}"
            )
        if max_len < 2:
            raise ValueError(f"max_len must be at least 2 to hold <sos> and <eos>, got {max_len}")
        self.src_texts = src_texts
        self.tgt_texts = tgt_texts
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_len = max_len

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.src_texts)

    def encode(self, text, vocab):
        """Convert one sentence into a LongTensor of exactly ``self.max_len`` ids."""
        tokens = ['<sos>'] + text.lower().split() + ['<eos>']
        unk_id = vocab['<unk>']
        ids = [vocab.get(tok, unk_id) for tok in tokens]
        if len(ids) > self.max_len:
            # Truncate but keep the closing '<eos>' (the last id) so that long sentences
            # still teach the decoder where a sentence ends.
            ids = ids[:self.max_len - 1] + ids[-1:]
        else:
            ids += [vocab['<pad>']] * (self.max_len - len(ids))
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(src_ids, tgt_ids)``, each a LongTensor of shape [max_len]."""
        src_ids = self.encode(self.src_texts[idx], self.src_vocab)
        tgt_ids = self.encode(self.tgt_texts[idx], self.tgt_vocab)
        return src_ids, tgt_ids

# Wrap the tokenised pairs in a Dataset and serve them as shuffled mini-batches of 16.
train_dataset = TranslationDataset(
    src_texts=src_texts,
    tgt_texts=tgt_texts,
    src_vocab=src_vocab,
    tgt_vocab=tgt_vocab,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

In [13]:
from transformer import make_model
# Prefer Apple MPS, then CUDA, and fall back to the CPU.
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG before building the model so that weight initialisation
# (and, later, DataLoader shuffling and dropout) are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

model = make_model(
    src_vocab=len(src_vocab),
    tgt_vocab=len(tgt_vocab),
    d_model=768, h=8, d_ff=2048, N=6
).to(device)

In [14]:
def make_src_mask(src, pad_idx=None):
    """Boolean attention mask that hides padding positions of a source batch.

    Args:
        src: integer id tensor of shape [batch, seq_len].
        pad_idx: id treated as padding. When omitted, falls back to the
            notebook-level ``src_vocab['<pad>']`` (the original behaviour).

    Returns:
        Bool tensor of shape [batch, 1, 1, seq_len]; True where the token is not padding.
    """
    if pad_idx is None:
        pad_idx = src_vocab['<pad>']
    return (src != pad_idx).unsqueeze(1).unsqueeze(2)  # [B, 1, 1, seq_len]

def make_tgt_mask(tgt, pad_idx=None):
    """Decoder self-attention mask combining key padding and causal (look-ahead) masking.

    Args:
        tgt: integer id tensor of shape [batch, seq_len].
        pad_idx: id treated as padding. When omitted, falls back to the
            notebook-level ``tgt_vocab['<pad>']`` (the original behaviour).

    Returns:
        Bool tensor of shape [batch, 1, seq_len, seq_len]. Entry [b, 0, i, j] is True
        when query position i may attend to key position j, i.e. j <= i and token j
        is not padding.
    """
    if pad_idx is None:
        pad_idx = tgt_vocab['<pad>']
    tgt_pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)  # [B, 1, 1, seq_len]
    seq_len = tgt.size(1)
    # Lower-triangular matrix: each position sees itself and earlier positions only.
    tgt_sub_mask = torch.tril(torch.ones((seq_len, seq_len), device=tgt.device)).bool()
    return tgt_pad_mask & tgt_sub_mask  # broadcasts to [B, 1, seq_len, seq_len]

# TEST: mask
# for src, tgt in train_loader:
#     src, tgt = src.to(device), tgt.to(device)
#     src_mask = make_src_mask(src)
#     tgt_mask = make_tgt_mask(tgt[:, :-1])
#     print(src.shape)
#     print(tgt.shape)
#     print(src_mask.shape)
#     print(tgt_mask.shape)
#     out = model(src, tgt[:, :-1], src_mask, tgt_mask)
#     print(out.shape)  # [batch, seq_len-1, tgt_vocab_size]
#     break

In [16]:
import torch.optim as optim
# Token-level cross-entropy that skips '<pad>' targets, so padding never contributes to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=tgt_vocab['<pad>'])
# Plain Adam with a constant learning rate (no warm-up schedule).
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [29]:
# # TEST runnable
# for src, tgt in train_loader:
#     model.eval()
#     src = src.to(device)
#     tgt = tgt.to(device)
    
#     tgt_input = tgt[:, :-1]
#     tgt_output = tgt[:, 1:]
    
#     src_mask = make_src_mask(src)
#     tgt_mask = make_tgt_mask(tgt_input)
    
#     optimizer.zero_grad()
#     output = model(src, tgt_input, src_mask, tgt_mask)  # [B, seq_len, vocab_size]
    
#     loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
#     loss.backward()
#     optimizer.step()
#     print(loss)
    
#     break

In [59]:
from tqdm import tqdm
import os

num_epochs = 10
save_dir = './checkpoints'
# torch.save does not create missing directories; without this a fresh checkout
# fails with FileNotFoundError at the end of the first epoch.
os.makedirs(save_dir, exist_ok=True)

# NOTE(review): in the recorded run the average loss stays around 7.1 for all 10 epochs
# and the demo translation repeats a single word, i.e. the model is not learning.
# Before tuning the optimiser, confirm that `model(...)` returns target-vocabulary
# logits. If `transformer.make_model` follows the Annotated Transformer, the decoder
# output must still be passed through `model.generator`. TODO confirm against transformer.py.

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    # Wrap the DataLoader in tqdm to get a per-batch progress bar.
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for src, tgt in loop:
        src = src.to(device)
        tgt = tgt.to(device)

        # Teacher forcing: the decoder reads tokens [0, T-1) and learns to predict [1, T).
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        src_mask = make_src_mask(src)
        tgt_mask = make_tgt_mask(tgt_input)

        optimizer.zero_grad()
        output = model(src, tgt_input, src_mask, tgt_mask)  # assumed [B, seq_len, vocab_size] logits

        # Flatten batch and time so every target token is one classification example.
        loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single host sync per batch
        total_loss += batch_loss

        # Show the current batch loss on the progress bar.
        loop.set_postfix(loss=batch_loss)

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} finished, Average Loss: {avg_loss:.4f}")
    # Save a checkpoint after every epoch.
    save_path = os.path.join(save_dir, f"transformer_epoch{epoch+1}.pt")
    torch.save(model.state_dict(), save_path)
    print(f"Saved model checkpoint to {save_path}")


                                                                      

Epoch 1 finished, Average Loss: 7.1451
Saved model checkpoint to ./checkpoints/transformer_epoch1.pt


                                                                      

Epoch 2 finished, Average Loss: 7.1522
Saved model checkpoint to ./checkpoints/transformer_epoch2.pt


                                                                      

Epoch 3 finished, Average Loss: 7.1300
Saved model checkpoint to ./checkpoints/transformer_epoch3.pt


                                                                      

Epoch 4 finished, Average Loss: 7.1615
Saved model checkpoint to ./checkpoints/transformer_epoch4.pt


                                                                      

Epoch 5 finished, Average Loss: 7.1338
Saved model checkpoint to ./checkpoints/transformer_epoch5.pt


                                                                      

Epoch 6 finished, Average Loss: 7.1347
Saved model checkpoint to ./checkpoints/transformer_epoch6.pt


                                                                      

Epoch 7 finished, Average Loss: 7.1382
Saved model checkpoint to ./checkpoints/transformer_epoch7.pt


                                                                      

Epoch 8 finished, Average Loss: 7.1440
Saved model checkpoint to ./checkpoints/transformer_epoch8.pt


                                                                      

Epoch 9 finished, Average Loss: 7.1396
Saved model checkpoint to ./checkpoints/transformer_epoch9.pt


                                                                       

Epoch 10 finished, Average Loss: 7.1407
Saved model checkpoint to ./checkpoints/transformer_epoch10.pt


In [65]:
import torch

def translate(model, sentence, src_vocab, tgt_vocab, max_len=32, device='cuda'):
    """Greedily translate one source sentence with a trained encoder-decoder model.

    The sentence is whitespace-split, each token is lower-cased and looked up in
    ``src_vocab`` (unknown words -> '<unk>'), and the ids are wrapped in <sos>/<eos>.
    Decoding starts from the target <sos> and repeatedly appends the arg-max token at
    the last position until <eos> is produced or ``max_len`` tokens have been generated.

    Args:
        model: seq2seq model called as ``model(src, tgt, src_mask, tgt_mask)``. It is
            assumed to return per-position scores over the target vocabulary,
            shape [1, tgt_len, vocab_size]. TODO confirm against ``transformer.make_model``.
        sentence: raw source-language sentence.
        src_vocab: source token -> id mapping containing '<sos>', '<eos>' and '<unk>'.
        tgt_vocab: target token -> id mapping containing '<sos>' and '<eos>'.
        max_len: maximum number of tokens to generate.
        device: device the model lives on (e.g. 'cuda', 'mps', 'cpu').

    Returns:
        The generated target tokens joined by single spaces, without <sos>/<eos>.
    """
    model.eval()

    # Encode the source sentence into token ids.
    src_ids = [src_vocab.get(tok.lower(), src_vocab['<unk>']) for tok in sentence.split()]
    src_ids = [src_vocab['<sos>']] + src_ids + [src_vocab['<eos>']]
    src_tensor = torch.tensor(src_ids, device=device).unsqueeze(0)  # [1, src_len]
    src_mask = make_src_mask(src_tensor)

    eos_id = tgt_vocab['<eos>']
    # The decoder input starts as just <sos>.
    ys = torch.tensor([[tgt_vocab['<sos>']]], device=device)  # [1, 1]

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for _ in range(max_len):
            tgt_mask = make_tgt_mask(ys)
            out = model(src_tensor, ys, src_mask, tgt_mask)
            # Greedy choice at the last position; keepdim keeps shape [batch, 1] for the concat.
            next_word = out[:, -1, :].argmax(-1, keepdim=True)
            ys = torch.cat([ys, next_word], dim=1)
            if next_word.item() == eos_id:
                break

    # Map ids back to words, dropping the leading <sos> and a trailing <eos> if present.
    id_to_word = {idx: word for word, idx in tgt_vocab.items()}
    generated = ys[0, 1:].tolist()
    if generated and generated[-1] == eos_id:
        generated = generated[:-1]
    return ' '.join(id_to_word[i] for i in generated)


In [66]:
# map_location places the stored tensors on this session's device. Without it they are
# restored to the backend they were saved from (e.g. MPS), and loading fails on a machine
# where that backend is unavailable.
model.load_state_dict(torch.load('./checkpoints/transformer_epoch10.pt', map_location=device))

<All keys matched successfully>

In [67]:
# Make sure the freshly loaded weights live on the selected device before decoding.
model = model.to(device)

src_sentence = "A man in a black shirt is playing a guitar."
translation = translate(model, src_sentence, src_vocab, tgt_vocab, device=device)
for label, text in (("Source:", src_sentence), ("Translation:", translation)):
    print(label, text)

Source: A man in a black shirt is playing a guitar.
Translation: mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit mahlzeit
