# HW5: Transformer

# Sequence-to-Sequence 介紹
- 大多數常見的 seq2seq model 為 encoder-decoder model，主要由 encoder 和 decoder 兩個部分組成。這兩個部分可以使用 recurrent neural network (RNN) 或 transformer 來實作，主要用來解決輸入和輸出長度不一樣的情況
- **Encoder** 是將一連串的輸入，如文字、影片、聲音訊號等，編碼為單個向量，這單個向量可以想像為是整個輸入的抽象表示，包含了整個輸入的資訊
- **Decoder** 是將 encoder 輸出的單個向量逐步解碼，一次輸出一個結果，直到產生出最終的目標輸出為止。每次的輸出會影響下一次的輸出；一般會在開頭加入 "< BOS >" 來表示開始解碼，並在結尾輸出 "< EOS >" 來表示輸出結束

# 作業介紹
- 英文翻譯中文
  - 輸入： 一句英文 （e.g.		tom is a student .） 
  - 輸出： 中文翻譯 （e.g. 		湯姆 是 個 學生 。）

- TODO
  - 訓練一個 RNN 模型達到 Seq2seq 翻譯
  - 訓練一個 Transformer 大幅提升效能
  - 實作 Back-translation 大幅提升效能

In [None]:
import os
import re
import random
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MarianConfig,
    get_linear_schedule_with_warmup,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

import sentencepiece as spm
from datasets import load_metric
import sacrebleu

## 1. 设置随机种子

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU generator and
    all CUDA devices), then puts cuDNN into deterministic, non-benchmarking
    mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. 配置参数

In [None]:
class Config:
    """Static settings shared by preprocessing, training and inference.

    All values are plain class attributes; the notebook reads them through
    a single module-level instance.
    """

    # --- Paths ---
    data_dir = "./res/hw5/ted2020"
    save_dir = "./output/hw5checkpoints_hf"

    # --- Data ---
    src_lang = "en"
    tgt_lang = "zh"
    max_length = 128
    vocab_size = 8000

    # --- Model ---
    model_type = "custom"  # 'custom' or 'pretrained'
    d_model = 512
    num_encoder_layers = 6
    num_decoder_layers = 6
    num_attention_heads = 8
    ffn_dim = 2048
    dropout = 0.1

    # --- Training ---
    batch_size = 32
    gradient_accumulation_steps = 2
    num_epochs = 30
    learning_rate = 5e-4
    warmup_steps = 4000
    max_grad_norm = 1.0

    # --- Inference ---
    beam_size = 5
    length_penalty = 1.0

    # --- Miscellaneous ---
    num_workers = 4
    save_steps = 1000
    eval_steps = 1000
    logging_steps = 100

# Shared settings instance; the checkpoint directory is created up front so
# later saves never fail on a missing folder.
config = Config()
Path(config.save_dir).mkdir(parents=True, exist_ok=True)

## 3. 数据预处理工具

In [None]:
# Code-point translation table: ideographic space (U+3000) -> ASCII space,
# and the full-width forms U+FF01..U+FF5E -> their ASCII counterparts
# (a constant offset of 0xFEE0).
_FULLWIDTH_TO_HALFWIDTH = {0x3000: 0x20}
_FULLWIDTH_TO_HALFWIDTH.update(
    (code, code - 0xFEE0) for code in range(0xFF01, 0xFF5F)
)


def strQ2B(ustring):
    """Convert full-width characters to their half-width equivalents.

    Args:
        ustring: A string (or any iterable of strings).

    Returns:
        A single string with every full-width character replaced by its
        half-width form; all other characters are left untouched.
    """
    return ''.join(
        chunk.translate(_FULLWIDTH_TO_HALFWIDTH) for chunk in ustring
    )


def clean_text(text, lang):
    """Normalise one sentence of the parallel corpus.

    For ``lang == 'en'``: drops parenthesised asides, removes hyphens and
    pads punctuation with spaces. For ``lang == 'zh'``: converts full-width
    characters to half-width, drops parenthesised asides, removes spaces,
    em dashes and underscores, maps curly double quotes to ``"`` and pads
    punctuation with spaces. Any other ``lang`` only gets whitespace
    normalisation.

    Args:
        text: Raw sentence.
        lang: Language code, ``'en'`` or ``'zh'``.

    Returns:
        The cleaned sentence with runs of whitespace collapsed to single
        spaces and no leading/trailing whitespace.
    """
    if lang == 'en':
        text = re.sub(r"\([^()]*\)", "", text)
        text = text.replace('-', '')
        text = re.sub('([.,;!?()\"])', r' \1 ', text)
    elif lang == 'zh':
        text = strQ2B(text)
        text = re.sub(r"\([^()]*\)", "", text)
        text = text.replace(' ', '')
        text = text.replace('—', '')
        # Curly double quotes are not in the full-width range handled by
        # strQ2B, so map them to the ASCII quote explicitly; otherwise they
        # bypass the punctuation-spacing step below. Escapes are used so the
        # characters cannot be silently normalised away in the source file.
        text = text.replace('\u201c', '"')
        text = text.replace('\u201d', '"')
        text = text.replace('_', '')
        text = re.sub('([。,;!?()\"~「」])', r' \1 ', text)

    text = ' '.join(text.strip().split())
    return text

def get_text_length(text, lang):
    """Return the length of *text* in language-appropriate units.

    Chinese (``lang == 'zh'``) is measured in characters, ignoring spaces;
    every other language is measured in whitespace-separated tokens.
    """
    if lang != 'zh':
        return len(text.split())
    return sum(1 for ch in text if ch != ' ')

def filter_parallel_corpus(src_file, tgt_file, output_prefix,
                          ratio=9, max_len=1000, min_len=1):
    """Clean a parallel corpus and drop sentence pairs that look unusable.

    Each line pair is cleaned with ``clean_text`` and measured with
    ``get_text_length`` (languages come from the module-level ``config``).
    A pair is discarded when either side is shorter than ``min_len``, longer
    than ``max_len``, or when one side is more than ``ratio`` times longer
    than the other. Surviving pairs are written to
    ``{output_prefix}.{src_lang}`` and ``{output_prefix}.{tgt_lang}``.

    If both output files already exist the function returns immediately
    without re-filtering.

    Args:
        src_file: Path to the raw source-language file.
        tgt_file: Path to the raw target-language file.
        output_prefix: Path prefix for the two output files.
        ratio: Maximum allowed length ratio; ``<= 0`` disables the check.
        max_len: Maximum allowed length; ``<= 0`` disables the check.
        min_len: Minimum allowed length; ``<= 0`` disables the check.
    """
    src_out = f"{output_prefix}.{config.src_lang}"
    tgt_out = f"{output_prefix}.{config.tgt_lang}"

    if os.path.exists(src_out) and os.path.exists(tgt_out):
        print(f"Filtered files already exist, skipping...")
        return

    with open(src_file, 'r', encoding='utf-8') as f_src, \
         open(tgt_file, 'r', encoding='utf-8') as f_tgt, \
         open(src_out, 'w', encoding='utf-8') as f_src_out, \
         open(tgt_out, 'w', encoding='utf-8') as f_tgt_out:

        kept = 0
        total = 0

        for src_line, tgt_line in zip(f_src, f_tgt):
            total += 1
            src_line = clean_text(src_line.strip(), config.src_lang)
            tgt_line = clean_text(tgt_line.strip(), config.tgt_lang)

            src_len = get_text_length(src_line, config.src_lang)
            tgt_len = get_text_length(tgt_line, config.tgt_lang)

            # Filtering rules
            if min_len > 0 and min(src_len, tgt_len) < min_len:
                continue
            if max_len > 0 and max(src_len, tgt_len) > max_len:
                continue
            # Compare by multiplication rather than division so that an empty
            # side (possible when the min_len check is disabled) cannot raise
            # ZeroDivisionError.
            if ratio > 0 and (src_len > ratio * tgt_len
                              or tgt_len > ratio * src_len):
                continue

            f_src_out.write(src_line + '\n')
            f_tgt_out.write(tgt_line + '\n')
            kept += 1

        # Guard the percentage against empty input files.
        kept_pct = kept / total * 100 if total else 0.0
        print(f"Kept {kept}/{total} sentence pairs ({kept_pct:.2f}%)")

## 4. 自定义 Tokenizer (使用 SentencePiece)

In [None]:
class SPMTokenizer:
    """Small wrapper that gives a SentencePiece model a tokenizer-style API.

    The special-token ids are read from the loaded SentencePiece model, so
    they reflect whatever ids the model was trained with.
    """

    def __init__(self, model_path):
        """Load the SentencePiece model stored at *model_path*."""
        self.sp = spm.SentencePieceProcessor(model_file=model_path)
        self.vocab_size = self.sp.vocab_size()

        # Special tokens
        self.pad_token = '<pad>'
        self.unk_token = '<unk>'
        self.bos_token = '<s>'
        self.eos_token = '</s>'

        self.pad_token_id = self.sp.pad_id()
        self.unk_token_id = self.sp.unk_id()
        self.bos_token_id = self.sp.bos_id()
        self.eos_token_id = self.sp.eos_id()

    def encode(self, text, add_special_tokens=True, max_length=None,
               padding=False, truncation=False, return_tensors=None):
        """Encode *text* into token ids.

        Args:
            text: Sentence to tokenize.
            add_special_tokens: Wrap the ids in BOS ... EOS.
            max_length: Target length for truncation and/or padding.
            padding: Right-pad with the pad id up to ``max_length``.
            truncation: Cut the sequence down to ``max_length``.
            return_tensors: ``'pt'`` to get a ``(1, seq_len)`` tensor,
                otherwise a plain list is returned.

        Returns:
            ``{'input_ids': ids}``.
        """
        ids = self.sp.encode(text, out_type=int)

        if add_special_tokens:
            ids = [self.bos_token_id] + ids + [self.eos_token_id]

        if truncation and max_length and len(ids) > max_length:
            ids = ids[:max_length]
            if add_special_tokens:
                # Keep the sequence terminated after cutting it, matching the
                # truncation done in TranslationDataset.
                ids[-1] = self.eos_token_id

        if padding and max_length and len(ids) < max_length:
            ids = ids + [self.pad_token_id] * (max_length - len(ids))

        if return_tensors == 'pt':
            ids = torch.tensor([ids])

        return {'input_ids': ids}

    def decode(self, ids, skip_special_tokens=True):
        """Decode a 1-D sequence of token ids back into text.

        Negative ids (for example the ``-100`` used to pad labels) have no
        vocabulary entry and are always dropped before decoding.

        Args:
            ids: List of ids or a 1-D tensor.
            skip_special_tokens: Also drop pad/BOS/EOS ids.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()

        ids = [token_id for token_id in ids if token_id >= 0]

        if skip_special_tokens:
            special_ids = {
                self.pad_token_id, self.bos_token_id, self.eos_token_id,
            }
            ids = [token_id for token_id in ids if token_id not in special_ids]

        return self.sp.decode(ids)

    def batch_decode(self, ids_list, skip_special_tokens=True):
        """Decode every sequence in *ids_list*; returns a list of strings."""
        return [self.decode(ids, skip_special_tokens) for ids in ids_list]

    def __len__(self):
        """Return the vocabulary size of the underlying model."""
        return self.vocab_size

## 5. 训练 SentencePiece 模型

In [None]:
def train_sentencepiece(input_files, model_prefix, vocab_size=8000):
    """Train a unigram SentencePiece model unless one already exists.

    Args:
        input_files: Paths of the text files to train on.
        model_prefix: Output prefix; the model is written to
            ``{model_prefix}.model``.
        vocab_size: Size of the vocabulary to learn.
    """
    model_file = f"{model_prefix}.model"

    # Nothing to do when a previous run already produced the model.
    if os.path.exists(model_file):
        print(f"SentencePiece model already exists: {model_file}")
        return

    print("Training SentencePiece model...")
    trainer_options = {
        'input': ','.join(input_files),
        'model_prefix': model_prefix,
        'vocab_size': vocab_size,
        'character_coverage': 1.0,
        'model_type': 'unigram',
        # Fixed ids and surface forms for the four special tokens.
        'pad_id': 0,
        'unk_id': 1,
        'bos_id': 2,
        'eos_id': 3,
        'pad_piece': '<pad>',
        'unk_piece': '<unk>',
        'bos_piece': '<s>',
        'eos_piece': '</s>',
        'user_defined_symbols': [],
    }
    spm.SentencePieceTrainer.train(**trainer_options)
    print(f"SentencePiece model saved to {model_file}")

## 6. Dataset 类

In [None]:
class TranslationDataset(Dataset):
    """Line-aligned parallel corpus tokenized on the fly.

    Both files are read fully into memory; line *i* of the source file is
    paired with line *i* of the target file.
    """

    def __init__(self, src_file, tgt_file, tokenizer, max_length=128):
        """Load the two text files.

        Args:
            src_file: Path to the source-language file (UTF-8).
            tgt_file: Path to the target-language file (UTF-8).
            tokenizer: Object exposing ``sp.encode``, ``bos_token_id`` and
                ``eos_token_id``.
            max_length: Maximum number of ids per sequence, special tokens
                included.

        Raises:
            ValueError: If the two files do not have the same number of lines.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Read the data
        with open(src_file, 'r', encoding='utf-8') as f:
            self.src_texts = [line.strip() for line in f]

        with open(tgt_file, 'r', encoding='utf-8') as f:
            self.tgt_texts = [line.strip() for line in f]

        # Explicit check instead of `assert`, which is stripped under -O.
        if len(self.src_texts) != len(self.tgt_texts):
            raise ValueError(
                f"Source/target line counts differ: "
                f"{len(self.src_texts)} vs {len(self.tgt_texts)}"
            )
        print(f"Loaded {len(self.src_texts)} examples")

    def __len__(self):
        return len(self.src_texts)

    def _encode(self, text):
        """Tokenize *text*, wrap it in BOS/EOS and truncate to max_length."""
        bos = self.tokenizer.bos_token_id
        eos = self.tokenizer.eos_token_id
        ids = [bos] + self.tokenizer.sp.encode(text, out_type=int) + [eos]
        if len(ids) > self.max_length:
            # Truncate but keep the sequence terminated by EOS.
            ids = ids[:self.max_length - 1] + [eos]
        return ids

    def __getitem__(self, idx):
        """Return the token ids and raw text of example *idx*.

        NOTE(review): labels keep a leading BOS; confirm this is intended,
        since Hugging Face seq2seq models normally prepend the decoder start
        token themselves when building decoder inputs from labels.
        """
        src_text = self.src_texts[idx]
        tgt_text = self.tgt_texts[idx]

        return {
            'input_ids': self._encode(src_text),
            'labels': self._encode(tgt_text),
            'src_text': src_text,
            'tgt_text': tgt_text,
        }

def collate_fn(batch, pad_token_id=0):
    """Pad a list of dataset examples into batch tensors.

    Args:
        batch: Examples carrying ``'input_ids'`` and ``'labels'`` id lists.
        pad_token_id: Id used to right-pad the source sequences.

    Returns:
        Dict of ``torch.long`` tensors: ``input_ids`` (padded with
        ``pad_token_id``), ``attention_mask`` (1 for real tokens, 0 for
        padding) and ``labels`` (padded with -100, the value PyTorch's loss
        ignores).
    """
    def right_pad(seq, width, fill):
        return seq + [fill] * (width - len(seq))

    sources = [example['input_ids'] for example in batch]
    targets = [example['labels'] for example in batch]

    src_width = max(map(len, sources))
    tgt_width = max(map(len, targets))

    padded_sources = [right_pad(seq, src_width, pad_token_id) for seq in sources]
    masks = [right_pad([1] * len(seq), src_width, 0) for seq in sources]
    padded_targets = [right_pad(seq, tgt_width, -100) for seq in targets]

    return {
        'input_ids': torch.tensor(padded_sources, dtype=torch.long),
        'attention_mask': torch.tensor(masks, dtype=torch.long),
        'labels': torch.tensor(padded_targets, dtype=torch.long),
    }

## 7. 构建模型

In [None]:
def build_transformer_model(vocab_size, config):
    """Build a randomly initialised Marian encoder-decoder model.

    Args:
        vocab_size: Size of the shared vocabulary.
        config: Settings object providing ``d_model``,
            ``num_encoder_layers``, ``num_decoder_layers``,
            ``num_attention_heads``, ``ffn_dim`` and ``dropout``.

    Returns:
        A ``MarianMTModel`` whose trainable linear and embedding weights are
        drawn from N(0, 0.02).
    """
    model_config = MarianConfig(
        vocab_size=vocab_size,
        d_model=config.d_model,
        encoder_layers=config.num_encoder_layers,
        decoder_layers=config.num_decoder_layers,
        encoder_attention_heads=config.num_attention_heads,
        decoder_attention_heads=config.num_attention_heads,
        encoder_ffn_dim=config.ffn_dim,
        decoder_ffn_dim=config.ffn_dim,
        dropout=config.dropout,
        attention_dropout=config.dropout,
        activation_dropout=config.dropout,
        max_position_embeddings=1024,
        # Special-token ids; these must match the ids the SentencePiece
        # model was trained with.
        pad_token_id=0,
        bos_token_id=2,
        eos_token_id=3,
        decoder_start_token_id=2,
        forced_eos_token_id=3,
    )

    model = MarianMTModel(model_config)

    # Parameter initialisation
    def init_weights(module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # Marian's sinusoidal positional embeddings are nn.Embedding
            # subclasses with a fixed, non-trainable weight. Re-sampling them
            # would replace the position signal with frozen random noise, so
            # leave any non-trainable embedding untouched.
            if not module.weight.requires_grad:
                return
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    model.apply(init_weights)

    return model

## 8. 训练函数

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, scaler, device, config):
    """Train the model for one epoch with AMP and gradient accumulation.

    Args:
        model: Seq2seq model whose forward pass returns an object with
            ``.loss`` when called with ``labels``.
        dataloader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        optimizer: Optimizer stepped through ``scaler``.
        scheduler: LR scheduler stepped once per optimizer step.
        scaler: ``GradScaler`` used for mixed-precision training.
        device: Device the batch tensors are moved to.
        config: Provides ``gradient_accumulation_steps`` and
            ``max_grad_norm``.

    Returns:
        Mean (unscaled) loss per batch, or 0.0 for an empty dataloader.
    """
    model.train()
    accum_steps = config.gradient_accumulation_steps
    num_batches = len(dataloader)
    total_loss = 0.0

    # Start from clean gradients so nothing left over from a previous epoch
    # leaks into the first update.
    optimizer.zero_grad()
    progress_bar = tqdm(dataloader, desc="Training")

    for step, batch in enumerate(progress_bar, start=1):
        # Move the batch to the target device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Mixed-precision forward pass
        with autocast():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            # Scale so accumulated gradients average over the window.
            loss = outputs.loss / accum_steps

        # Backward pass
        scaler.scale(loss).backward()

        # Update at the end of each accumulation window, and also on the
        # final batch so a trailing partial window is not silently dropped.
        if step % accum_steps == 0 or step == num_batches:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        batch_loss = loss.item() * accum_steps
        total_loss += batch_loss
        progress_bar.set_postfix({'loss': batch_loss})

    return total_loss / num_batches if num_batches else 0.0

## 9. 评估函数

In [None]:
def evaluate(model, dataloader, tokenizer, device, config):
    """Compute validation loss and corpus BLEU with beam-search decoding.

    Args:
        model: Seq2seq model supporting ``forward(labels=...)`` and
            ``generate``.
        dataloader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors (labels padded with -100).
        tokenizer: Provides ``batch_decode`` and ``pad_token_id``.
        device: Device the batch tensors are moved to.
        config: Provides ``max_length``, ``beam_size`` and
            ``length_penalty``; ``tgt_lang`` is consulted if present.

    Returns:
        Dict with ``loss`` (mean per batch), ``bleu`` (corpus BLEU score)
        and the first five ``predictions`` / ``references``.
    """
    model.eval()
    total_loss = 0.0
    predictions = []
    references = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Loss
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            total_loss += outputs.loss.item()

            # Generate translations
            generated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=config.max_length,
                num_beams=config.beam_size,
                length_penalty=config.length_penalty,
                early_stopping=True,
            )

            # -100 is only a loss-masking sentinel, not a vocabulary id;
            # swap it for the pad id before handing labels to the tokenizer.
            label_ids = labels.masked_fill(labels == -100, tokenizer.pad_token_id)

            # Decode
            predictions.extend(
                tokenizer.batch_decode(generated, skip_special_tokens=True)
            )
            references.extend(
                tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            )

    num_batches = len(dataloader)
    avg_loss = total_loss / num_batches if num_batches else 0.0

    # BLEU. Chinese text is not whitespace-segmented, so sacreBLEU's default
    # tokenizer would treat whole clauses as single tokens; use its 'zh'
    # tokenizer when the target language is Chinese.
    bleu_kwargs = {}
    if getattr(config, 'tgt_lang', None) == 'zh':
        bleu_kwargs['tokenize'] = 'zh'

    if predictions:
        bleu_score = sacrebleu.corpus_bleu(
            predictions, [references], **bleu_kwargs
        ).score
    else:
        bleu_score = 0.0

    return {
        'loss': avg_loss,
        'bleu': bleu_score,
        'predictions': predictions[:5],
        'references': references[:5],
    }

## 10. 主训练流程

In [None]:
def main():
    """Run the whole pipeline: preprocess, tokenize, train and evaluate.

    Steps: filter the raw corpus, split it 99/1 into train/valid (only if
    the train file does not exist yet), train/load the SentencePiece model,
    build datasets and loaders, build the model, then train for
    ``config.num_epochs`` epochs. The checkpoint with the best validation
    BLEU is written to ``{config.save_dir}/best_model`` together with a copy
    of the tokenizer model; an extra checkpoint is written every 5 epochs.
    """
    # Local imports keep this cell self-contained.
    import shutil
    from functools import partial

    print("="*70)
    print("Machine Translation Training with Hugging Face Transformers")
    print("="*70)

    # 1. Data preprocessing
    print("\n[1/6] Data Preprocessing...")
    prefix = Path(config.data_dir)

    # Filter the raw corpus
    filter_parallel_corpus(
        f"{prefix}/train_dev.raw.en",
        f"{prefix}/train_dev.raw.zh",
        f"{prefix}/train_dev.clean"
    )

    # Split into training and validation sets
    train_src = f"{prefix}/train.{config.src_lang}"
    train_tgt = f"{prefix}/train.{config.tgt_lang}"
    valid_src = f"{prefix}/valid.{config.src_lang}"
    valid_tgt = f"{prefix}/valid.{config.tgt_lang}"

    if not os.path.exists(train_src):
        print("Splitting train/valid...")
        # The filtered files are written as UTF-8, so read and write them
        # as UTF-8 explicitly instead of relying on the platform default.
        with open(f"{prefix}/train_dev.clean.{config.src_lang}", 'r',
                  encoding='utf-8') as f_src, \
             open(f"{prefix}/train_dev.clean.{config.tgt_lang}", 'r',
                  encoding='utf-8') as f_tgt:
            src_lines = f_src.readlines()
            tgt_lines = f_tgt.readlines()

        split_idx = int(len(src_lines) * 0.99)
        splits = (
            (train_src, src_lines[:split_idx]),
            (train_tgt, tgt_lines[:split_idx]),
            (valid_src, src_lines[split_idx:]),
            (valid_tgt, tgt_lines[split_idx:]),
        )
        for path, lines in splits:
            with open(path, 'w', encoding='utf-8') as f:
                f.writelines(lines)

    # 2. Train SentencePiece
    print("\n[2/6] Training SentencePiece...")
    spm_prefix = f"{prefix}/spm{config.vocab_size}"
    train_sentencepiece(
        [train_src, train_tgt, valid_src, valid_tgt],
        spm_prefix,
        config.vocab_size
    )

    # 3. Load the tokenizer
    print("\n[3/6] Loading Tokenizer...")
    tokenizer = SPMTokenizer(f"{spm_prefix}.model")
    print(f"Vocab size: {len(tokenizer)}")

    # 4. Datasets and DataLoaders
    print("\n[4/6] Creating Datasets...")
    train_dataset = TranslationDataset(train_src, train_tgt, tokenizer, config.max_length)
    valid_dataset = TranslationDataset(valid_src, valid_tgt, tokenizer, config.max_length)

    # functools.partial (unlike a lambda) can be pickled, which DataLoader
    # workers need on platforms that start worker processes with "spawn".
    collate = partial(collate_fn, pad_token_id=tokenizer.pad_token_id)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        collate_fn=collate,
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        collate_fn=collate,
    )

    # 5. Build the model
    print("\n[5/6] Building Model...")
    model = build_transformer_model(len(tokenizer), config)
    model = model.to(device)

    num_params = sum(p.numel() for p in model.parameters())
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {num_params:,}")
    print(f"Trainable parameters: {num_trainable:,}")

    # 6. Optimizer and scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(0.9, 0.98),
        eps=1e-9,
        weight_decay=0.01
    )

    num_training_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=num_training_steps
    )

    scaler = GradScaler()

    # 7. Training loop
    print("\n[6/6] Training...")
    # Start below any achievable score so the first evaluated epoch is
    # always saved, even when its BLEU is exactly 0.
    best_bleu = float('-inf')

    for epoch in range(config.num_epochs):
        print(f"\n{'='*70}")
        print(f"Epoch {epoch + 1}/{config.num_epochs}")
        print(f"{'='*70}")

        # Train
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, scaler, device, config)
        print(f"Train Loss: {train_loss:.4f}")

        # Evaluate
        eval_results = evaluate(model, valid_loader, tokenizer, device, config)
        print(f"Valid Loss: {eval_results['loss']:.4f}")
        print(f"Valid BLEU: {eval_results['bleu']:.2f}")

        # Show a few examples
        print("\nExample translations:")
        for i, (pred, ref) in enumerate(zip(eval_results['predictions'], eval_results['references'])):
            print(f"\nExample {i+1}:")
            print(f"  Prediction: {pred}")
            print(f"  Reference:  {ref}")

        # Save the best model
        if eval_results['bleu'] > best_bleu:
            best_bleu = eval_results['bleu']
            model.save_pretrained(f"{config.save_dir}/best_model")
            # The SentencePiece processor is loaded from a file on disk, so
            # persist the tokenizer by copying that model file next to the
            # checkpoint.
            shutil.copyfile(f"{spm_prefix}.model", f"{config.save_dir}/tokenizer.model")
            print(f"✓ Saved best model (BLEU: {best_bleu:.2f})")

        # Periodic checkpoint
        if (epoch + 1) % 5 == 0:
            model.save_pretrained(f"{config.save_dir}/checkpoint-epoch-{epoch+1}")

    print("\n" + "="*70)
    print(f"Training completed! Best BLEU: {max(best_bleu, 0.0):.2f}")
    print("="*70)

## 11. 推理函数

In [None]:
def generate_predictions(model_path, tokenizer_path, test_src, output_file, config):
    """Translate a source file with a saved model and write the results.

    Args:
        model_path: Directory of a model saved with ``save_pretrained``.
        tokenizer_path: Path to the SentencePiece ``.model`` file.
        test_src: UTF-8 text file with one source sentence per line.
        output_file: Destination file; one translation per line.
        config: Provides ``src_lang``, ``batch_size``, ``max_length``,
            ``beam_size`` and ``length_penalty``.
    """
    print("Loading model for inference...")

    # Load the model and tokenizer
    model = MarianMTModel.from_pretrained(model_path)
    model = model.to(device)
    model.eval()

    tokenizer = SPMTokenizer(tokenizer_path)

    # Read the test data and apply the same cleaning the training corpus
    # went through, so the model sees text in the form it was trained on.
    with open(test_src, 'r', encoding='utf-8') as f:
        test_texts = [clean_text(line.strip(), config.src_lang) for line in f]

    predictions = []

    with torch.no_grad():
        for i in tqdm(range(0, len(test_texts), config.batch_size), desc="Generating"):
            batch_texts = test_texts[i:i + config.batch_size]

            # Tokenize
            input_ids_list = []
            for text in batch_texts:
                ids = tokenizer.sp.encode(text, out_type=int)
                ids = [tokenizer.bos_token_id] + ids + [tokenizer.eos_token_id]
                # Truncate like the training dataset does (keeping EOS) so an
                # unusually long line cannot exceed what the model was
                # trained on.
                if len(ids) > config.max_length:
                    ids = ids[:config.max_length - 1] + [tokenizer.eos_token_id]
                input_ids_list.append(ids)

            # Right-pad to the longest sequence in the batch
            max_len = max(len(ids) for ids in input_ids_list)
            input_ids_padded = []
            attention_mask = []

            for ids in input_ids_list:
                padding_length = max_len - len(ids)
                input_ids_padded.append(ids + [tokenizer.pad_token_id] * padding_length)
                attention_mask.append([1] * len(ids) + [0] * padding_length)

            input_ids = torch.tensor(input_ids_padded, dtype=torch.long).to(device)
            attention_mask = torch.tensor(attention_mask, dtype=torch.long).to(device)

            # Generate
            generated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=config.max_length,
                num_beams=config.beam_size,
                length_penalty=config.length_penalty,
                early_stopping=True,
            )

            # Decode
            batch_predictions = tokenizer.batch_decode(generated, skip_special_tokens=True)
            predictions.extend(batch_predictions)

    # Save the predictions
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(pred + '\n' for pred in predictions)

    print(f"Predictions saved to {output_file}")

## 12. 运行

In [None]:
 # Train the model (runs the full preprocessing + training pipeline)
main()

In [None]:
# Generate predictions for the test set using the best saved checkpoint and
# the tokenizer model stored next to it; output goes to ./prediction.txt.
generate_predictions(
    model_path=f"{config.save_dir}/best_model",
    tokenizer_path=f"{config.save_dir}/tokenizer.model",
    test_src=f"{config.data_dir}/test.raw.{config.src_lang}",
    output_file="./prediction.txt",
    config=config
)