<a href="https://colab.research.google.com/github/Yuan-Yu-Han/PTAS/blob/main/machine_translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math
import time
import os
import urllib.request
import zipfile
import shutil

In [2]:
def download_and_read_data(max_lines=3000):
    url = "http://www.manythings.org/anki/cmn-eng.zip"
    zip_path = "cmn-eng.zip"
    txt_path = "cmn.txt"

    # 1. 如果文件不存在，下载
    if not os.path.exists(txt_path):
        if not os.path.exists(zip_path):
            print(f"正在下载数据集: {url} ...")

            # --- 修改开始：添加 User-Agent 伪装成浏览器 ---
            req = urllib.request.Request(
                url,
                headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}
            )

            with urllib.request.urlopen(req) as response, open(zip_path, 'wb') as out_file:
                shutil.copyfileobj(response, out_file)
            # --- 修改结束 ---

        print("解压数据...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(".")

    # 2. 读取前 max_lines 行 (保持不变)
    print(f"读取前 {max_lines} 行数据...")
    pairs = []
    with open(txt_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= max_lines: break
            parts = line.strip().split('\t')
            if len(parts) >= 2:
                en_sent = parts[0].lower().strip('.!?')
                cn_sent = parts[1].strip()
                pairs.append((en_sent, cn_sent))
    return pairs

In [3]:
class SimpleTokenizer:
    def __init__(self, sentences, is_char_level=False):
        self.stoi = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
        self.itos = {0: '<pad>', 1: '<sos>', 2: '<eos>', 3: '<unk>'}
        self.is_char_level = is_char_level

        for sent in sentences:
            tokens = self._tokenize(sent)
            for token in tokens:
                if token not in self.stoi:
                    idx = len(self.stoi)
                    self.stoi[token] = idx
                    self.itos[idx] = token

    def _tokenize(self, sent):
        if self.is_char_level:
            return list(sent)
        else:
            return sent.split()

    def encode(self, sent, max_len=20):
        tokens = self._tokenize(sent)
        indices = [self.stoi.get(t, self.stoi['<unk>']) for t in tokens]
        indices = [self.stoi['<sos>']] + indices + [self.stoi['<eos>']]
        if len(indices) < max_len:
            indices += [self.stoi['<pad>']] * (max_len - len(indices))
        else:
            indices = indices[:max_len]
        return torch.tensor(indices, dtype=torch.long)

    # --- 这里是修改后的 decode 方法 ---
    def decode(self, indices):
        tokens = []
        for idx in indices:
            # 修正点：增加判断，如果 idx 已经是 int，就不用调 .item()
            if isinstance(idx, torch.Tensor):
                idx = idx.item()

            if idx == self.stoi['<eos>']: break
            if idx == self.stoi['<sos>']: continue
            if idx == self.stoi['<pad>']: continue
            tokens.append(self.itos.get(idx, '?'))

        return "".join(tokens) if self.is_char_level else " ".join(tokens)

In [4]:
# ==========================================
# 3. Transformer 模型 (复用之前的标准结构)
# ==========================================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

class TransformerMT(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model, nhead, num_layers):
        super(TransformerMT, self).__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        src = self.pos_enc(self.src_emb(src) * math.sqrt(src.shape[-1]))
        tgt = self.pos_enc(self.tgt_emb(tgt) * math.sqrt(tgt.shape[-1]))

        # --- 修改点在这里 ---
        # 错误写法: key_padding_mask=src_pad_mask
        # 正确写法: src_key_padding_mask=src_pad_mask
        out = self.transformer(src, tgt,
                               tgt_mask=tgt_mask,
                               src_key_padding_mask=src_pad_mask,  # <--- 修正了这里
                               tgt_key_padding_mask=tgt_pad_mask)
        return self.fc(out)

def get_tgt_mask(size):
    mask = torch.tril(torch.ones(size, size) == 1).float()
    mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
    return mask

In [5]:
# ==========================================
# 4. 训练主流程
# ==========================================
# --- 配置 ---
MAX_LINES = 2000   # 只用前2000句简单句，保证速度
BATCH_SIZE = 32
EPOCHS = 30        # 真实数据需要多跑几轮，但因为数据少，依然很快
LR = 0.0005
D_MODEL = 128

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. 准备数据
raw_data = download_and_read_data(MAX_LINES) # [(en, cn), ...]
en_sents = [p[0] for p in raw_data]
cn_sents = [p[1] for p in raw_data]

# 2. 构建分词器
en_tokenizer = SimpleTokenizer(en_sents, is_char_level=False)
cn_tokenizer = SimpleTokenizer(cn_sents, is_char_level=True) # 中文按字

print(f"英文词表: {len(en_tokenizer.stoi)}, 中文词表: {len(cn_tokenizer.stoi)}")

# 3. Dataset & DataLoader
class TransDataset(Dataset):
  def __init__(self, data): self.data = data
  def __len__(self): return len(self.data)
  def __getitem__(self, idx):
      src = en_tokenizer.encode(self.data[idx][0])
      tgt = cn_tokenizer.encode(self.data[idx][1])
      return src, tgt

dataloader = DataLoader(TransDataset(raw_data), batch_size=BATCH_SIZE, shuffle=True)

正在下载数据集: http://www.manythings.org/anki/cmn-eng.zip ...
解压数据...
读取前 2000 行数据...
英文词表: 1064, 中文词表: 1182


In [6]:
# 4. 模型初始化
model = TransformerMT(
  len(en_tokenizer.stoi), len(cn_tokenizer.stoi),
  d_model=D_MODEL, nhead=4, num_layers=2
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0) # 0 is pad
optimizer = optim.Adam(model.parameters(), lr=LR)

# 5. 训练循环
print("开始训练 (预计 1-2 分钟)...")
model.train()
start_time = time.time()

for epoch in range(15):
  total_loss = 0
  for src, tgt in dataloader:
      src, tgt = src.to(device), tgt.to(device)

      tgt_input = tgt[:, :-1]
      tgt_output = tgt[:, 1:]

      tgt_mask = get_tgt_mask(tgt_input.size(1)).to(device)
      src_pad_mask = (src == 0)
      tgt_pad_mask = (tgt_input == 0)

      optimizer.zero_grad()
      output = model(src, tgt_input, tgt_mask=tgt_mask,
                      src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask)

      loss = criterion(output.reshape(-1, len(cn_tokenizer.stoi)), tgt_output.reshape(-1))
      loss.backward()
      optimizer.step()
      total_loss += loss.item()

  if (epoch+1) % 5 == 0:
      print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

print(f"训练结束，耗时: {time.time()-start_time:.1f}s")

开始训练 (预计 1-2 分钟)...




Epoch 5, Loss: 2.5733
Epoch 10, Loss: 1.3932
Epoch 15, Loss: 0.6980
训练结束，耗时: 317.2s


In [7]:
# ==========================================
# 5. 推理测试
# ==========================================
def translate(sent):
  model.eval()
  src = en_tokenizer.encode(sent).unsqueeze(0).to(device)
  tgt_indices = [1] # <sos>

  for i in range(20):
      tgt_tensor = torch.tensor(tgt_indices).unsqueeze(0).to(device)
      with torch.no_grad():
          output = model(src, tgt_tensor)
      next_token = output[0, -1, :].argmax().item()
      if next_token == 2: break # <eos>
      tgt_indices.append(next_token)

  return cn_tokenizer.decode(tgt_indices[1:])

print("\n=== 真实翻译测试 (过拟合前几千句的效果) ===")
test_sentences = ["I am fine.", "He runs fast.", "Wait for me.", "It is a book."]
# 注意：如果输入的句子不在训练集的前2000句里，效果会很差，因为这是极小数据集

for s in test_sentences:
  print(f"En: {s:15} -> Zh: {translate(s.lower())}")


=== 真实翻译测试 (过拟合前几千句的效果) ===
En: I am fine.      -> Zh: 我同意。
En: He runs fast.   -> Zh: 他不要。
En: Wait for me.    -> Zh: 静静静静的。
En: It is a book.   -> Zh: 把它关掉。
