In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# Pick the compute device: CUDA when a GPU is available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 1. 定义词汇表加载和更新函数
def load_vocab(file_path):
    """Load a vocabulary file (one token per line) into a token -> index dict.

    Indices are assigned in order of first appearance and always form the
    contiguous range ``0 .. len(vocab) - 1``. Blank lines and repeated tokens
    are skipped: using the raw line number as the index would leave gaps, so
    ``len(vocab)`` would no longer equal the largest index + 1 (which is what
    embedding tables and newly appended tokens are sized from).

    Args:
        file_path: Path to a UTF-8 text file containing one token per line.

    Returns:
        dict mapping each distinct, whitespace-stripped token to its index.
    """
    vocab = {}
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            token = line.strip()
            # Skip blanks and duplicates so indices stay dense and unique.
            if token and token not in vocab:
                vocab[token] = len(vocab)
    return vocab

def update_vocab(vocab):
    """Add the special tokens ``<unk>``, ``<s>``, ``</s>``, ``<pad>`` if missing.

    The dict is modified in place and also returned. Tokens that are already
    present keep their existing index.

    Args:
        vocab: dict mapping token -> integer index.

    Returns:
        The same ``vocab`` object, now containing all four special tokens.
    """
    special_tokens = ['<unk>', '<s>', '</s>', '<pad>']
    # Allocate after the largest index in use rather than at len(vocab): the
    # two are equal for a dense vocabulary, but with gaps len(vocab) can land
    # on an index that already belongs to another token.
    next_index = max(vocab.values(), default=-1) + 1
    for token in special_tokens:
        if token not in vocab:
            vocab[token] = next_index
            next_index += 1
    return vocab

# 2. 数据预处理功能
def sentence_to_index(sentence, vocab, add_sos=False, add_eos=True):
    """Convert a whitespace-separated sentence into a list of vocabulary indices.

    Words missing from ``vocab`` map to the ``<unk>`` index.

    Args:
        sentence: Text whose tokens are separated by whitespace.
        vocab: dict mapping token -> index; must contain ``</s>`` and ``<unk>``
            (and ``<s>`` when ``add_sos`` is true).
        add_sos: Prepend the ``<s>`` index when true.
        add_eos: Append the ``</s>`` index when true.

    Returns:
        list of integer indices.
    """
    prefix = [vocab['<s>']] if add_sos else []
    end_id = vocab['</s>']
    unknown_id = vocab['<unk>']

    body = []
    for token in sentence.split():
        body.append(vocab.get(token, unknown_id))

    suffix = [end_id] if add_eos else []
    return prefix + body + suffix

# 3. 定义数据集类
class TranslationDataset(Dataset):
    """Line-aligned parallel corpus: line ``i`` of the source file pairs with
    line ``i`` of the target file.

    Each item is a ``(source_ids, target_ids)`` pair of 1-D ``torch.long``
    tensors, both wrapped in ``<s>`` ... ``</s>``. Sentences differ in length,
    so batching them requires padding.

    Args:
        source_file: Path to the UTF-8 source-language file, one sentence per line.
        target_file: Path to the UTF-8 target-language file, one sentence per line.
        source_vocab: token -> index dict for the source language.
        target_vocab: token -> index dict for the target language.

    Raises:
        ValueError: if the two files do not contain the same number of lines.
    """

    def __init__(self, source_file, target_file, source_vocab, target_vocab):
        self.source_sentences = self._read_lines(source_file)
        self.target_sentences = self._read_lines(target_file)
        if len(self.source_sentences) != len(self.target_sentences):
            # A silent mismatch would pair sentences with the wrong translation
            # or fail much later with an IndexError in __getitem__.
            raise ValueError(
                f"source/target line counts differ: "
                f"{len(self.source_sentences)} vs {len(self.target_sentences)}"
            )
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

    @staticmethod
    def _read_lines(path):
        """Read ``path`` as a list of lines, closing the file afterwards."""
        with open(path, encoding='utf-8') as handle:
            lines = handle.read().split('\n')
        # A file that ends with a newline yields a trailing '' that is not a
        # sentence; keeping it would add an empty training pair.
        if lines and lines[-1] == '':
            lines.pop()
        return lines

    def __len__(self):
        return len(self.source_sentences)

    def __getitem__(self, idx):
        source_sentence = self.source_sentences[idx]
        target_sentence = self.target_sentences[idx]
        return (
            torch.tensor(sentence_to_index(source_sentence, self.source_vocab, add_sos=True, add_eos=True), dtype=torch.long),
            torch.tensor(sentence_to_index(target_sentence, self.target_vocab, add_sos=True, add_eos=True), dtype=torch.long)
        )

# 4. 定义 Transformer 模型
class SimpleTransformerModel(nn.Module):
    """Encoder-decoder Transformer over batch-first token-id tensors.

    Args:
        num_tokens: Size of both embedding tables and of the output layer;
            must exceed the largest source and target token index.
        dim_model: Embedding / model width.
        num_heads: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dropout: Dropout probability inside the Transformer.
        max_len: Longest sequence (in tokens) that can be position-encoded.
    """

    def __init__(self, num_tokens, dim_model, num_heads, num_encoder_layers,
                 num_decoder_layers, dropout, max_len=1024):
        super(SimpleTransformerModel, self).__init__()
        self.embed_src = nn.Embedding(num_tokens, dim_model)
        self.embed_tgt = nn.Embedding(num_tokens, dim_model)
        # One learned vector per position, shape (1, max_len, dim_model), so
        # slicing along dim 1 by sequence length broadcasts over the batch.
        self.positional_encoding = nn.Parameter(torch.randn(1, max_len, dim_model))
        # batch_first=True: inputs arrive as (batch, seq) from the DataLoader.
        self.transformer = nn.Transformer(d_model=dim_model, nhead=num_heads,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dropout=dropout,
                                          batch_first=True)
        self.fc_out = nn.Linear(dim_model, num_tokens)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        """Compute next-token logits for every target position.

        Args:
            src: ``(batch, src_len)`` long tensor of source token ids.
            tgt: ``(batch, tgt_len)`` long tensor of target token ids.
            src_key_padding_mask: Optional ``(batch, src_len)`` bool tensor,
                True where ``src`` is padding.
            tgt_key_padding_mask: Optional ``(batch, tgt_len)`` bool tensor,
                True where ``tgt`` is padding.

        Returns:
            ``(batch, tgt_len, num_tokens)`` tensor of logits.

        Raises:
            ValueError: if either sequence is longer than ``max_len``.
        """
        src_len, tgt_len = src.size(1), tgt.size(1)
        max_len = self.positional_encoding.size(1)
        if src_len > max_len or tgt_len > max_len:
            raise ValueError(
                f"sequence length {max(src_len, tgt_len)} exceeds max_len={max_len}"
            )

        src = self.embed_src(src) + self.positional_encoding[:, :src_len]
        tgt = self.embed_tgt(tgt) + self.positional_encoding[:, :tgt_len]

        # Causal mask: True above the diagonal blocks attention to later
        # positions, so position i only sees targets 0..i.
        positions = torch.arange(tgt_len, device=tgt.device)
        causal_mask = positions.unsqueeze(0) > positions.unsqueeze(1)

        output = self.transformer(src, tgt,
                                  tgt_mask=causal_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=src_key_padding_mask)
        output = self.fc_out(output)
        return output

# 5. Load the vocabularies
vocab_en = update_vocab(load_vocab('./data/vocab.en'))
vocab_zh = update_vocab(load_vocab('./data/vocab.zh'))

# 6. Build the dataset and DataLoader
dataset = TranslationDataset('./data/train.zh', './data/train.en', vocab_zh, vocab_en)


def collate_batch(batch):
    """Pad a list of (source_ids, target_ids) pairs into two batch-first tensors.

    Sentences have different lengths, so the default collate function cannot
    stack them; each side is right-padded with its own ``<pad>`` index.
    """
    src_batch, tgt_batch = zip(*batch)
    src_padded = nn.utils.rnn.pad_sequence(list(src_batch), batch_first=True,
                                           padding_value=vocab_zh['<pad>'])
    tgt_padded = nn.utils.rnn.pad_sequence(list(tgt_batch), batch_first=True,
                                           padding_value=vocab_en['<pad>'])
    return src_padded, tgt_padded


train_loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True,
                          collate_fn=collate_batch)

# 7. Initialise the model and training configuration
# The model sizes both its source and target embedding tables from
# num_tokens, so it must cover the largest index of either vocabulary
# (source is Chinese, target is English).
num_tokens = max(max(vocab_zh.values()), max(vocab_en.values())) + 1
model = SimpleTransformerModel(num_tokens=num_tokens, dim_model=512, num_heads=8,
                               num_encoder_layers=3, num_decoder_layers=3, dropout=0.1).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Padded target positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab_en['<pad>'])

# 8. Train the model
for epoch in range(10):
    model.train()
    total_loss = 0
    for src, tgt in train_loader:
        src = src.to(device)
        tgt_input = tgt[:, :-1].to(device)  # decoder input: everything but the last position
        tgt_output = tgt[:, 1:].to(device)  # labels: the same sequence shifted left by one

        optimizer.zero_grad()
        preds = model(src, tgt_input)
        # Flatten using the actual logit width rather than len(vocab_en), which
        # need not equal the model's output size.
        loss = criterion(preds.reshape(-1, preds.size(-1)), tgt_output.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # max(..., 1) avoids ZeroDivisionError when the loader yields no batches.
    print(f'Epoch {epoch+1}, Loss: {total_loss / max(len(train_loader), 1)}')



In [None]:
def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=200):
    """Greedily decode a translation of ``src_sentence``.

    Decoding goes through ``model(src, tgt)`` so inference applies exactly the
    same embedding, positional and masking steps as training. The source is
    therefore re-encoded at every step, which favours correctness over speed.

    Args:
        model: Module whose ``forward(src, tgt)`` takes batch-first id tensors
            and returns ``(batch, tgt_len, vocab)`` logits.
        src_sentence: Whitespace-tokenised source sentence.
        src_vocab: token -> index dict for the source language.
        tgt_vocab: token -> index dict for the target language.
        device: Device on which the input tensors are created.
        max_len: Maximum number of tokens to generate.

    Returns:
        The generated target tokens joined by single spaces, without
        ``<s>`` / ``</s>``.
    """
    model.eval()
    src_ids = sentence_to_index(src_sentence, src_vocab, add_sos=True, add_eos=True)
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)

    eos_id = tgt_vocab['</s>']
    generated = [tgt_vocab['<s>']]
    with torch.no_grad():  # inference only: no autograd graph needed
        for _ in range(max_len):
            tgt = torch.tensor(generated, dtype=torch.long, device=device).unsqueeze(0)
            logits = model(src, tgt)
            next_id = logits[0, -1].argmax().item()
            if next_id == eos_id:
                break
            # </s> is never appended, so hitting max_len keeps the last real token.
            generated.append(next_id)

    # Build the reverse lookup once; setdefault keeps the first token when
    # several tokens share an index.
    index_to_token = {}
    for token, index in tgt_vocab.items():
        index_to_token.setdefault(index, token)

    # Ids outside the vocabulary (the output layer may be wider) render as <unk>.
    return ' '.join(index_to_token.get(index, '<unk>') for index in generated[1:])

# Example: translate a single sentence.
# NOTE(review): sentence_to_index tokenises with str.split(), and this string
# contains no whitespace, so the whole sentence is looked up as one token.
# Presumably the training data is pre-segmented with spaces — confirm and
# segment the example the same way.
src_sentence = "你好，中国科学院大学"
translated_sentence = translate(model, src_sentence, vocab_zh, vocab_en, device)
print("Translated:", translated_sentence)


In [None]:
# Save the trained weights
CHECKPOINT_PATH = 'transformer_translation_model.pth'
torch.save(model.state_dict(), CHECKPOINT_PATH)

# Reload the weights. map_location remaps the stored tensors onto the current
# device, so a checkpoint written on a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.to(device)


In [None]:
from nltk.translate.bleu_score import corpus_bleu

def _invert_vocab(vocab):
    """Return an index -> token dict, keeping the first token for each index."""
    inverse = {}
    for token, index in vocab.items():
        inverse.setdefault(index, token)
    return inverse


def _ids_to_tokens(ids, index_to_token, skip_ids):
    """Map ids to tokens, dropping ids in ``skip_ids`` or absent from the vocabulary."""
    return [index_to_token[i] for i in ids if i in index_to_token and i not in skip_ids]


def calculate_bleu(data_loader, model, src_vocab, tgt_vocab, device):
    """Compute the corpus-level BLEU-4 score of ``model`` over ``data_loader``.

    Each source row is converted back to text, translated with ``translate``
    and compared against its target row as the single reference.

    Args:
        data_loader: Iterable of ``(src, tgt)`` batches of 2-D id tensors
            (one sentence per row).
        model: Translation model passed through to ``translate``.
        src_vocab: token -> index dict for the source language.
        tgt_vocab: token -> index dict for the target language.
        device: Device passed through to ``translate``.

    Returns:
        The BLEU-4 score, or ``0.0`` when the loader yields no sentences.
    """
    model.eval()
    specials = ('<s>', '</s>', '<pad>')
    src_lookup = _invert_vocab(src_vocab)
    tgt_lookup = _invert_vocab(tgt_vocab)
    # Strip specials from the source as well: translate() adds <s>/</s> itself,
    # and padding must not be fed to the model as if it were a word.
    src_skip = {src_vocab[t] for t in specials if t in src_vocab}
    tgt_skip = {tgt_vocab[t] for t in specials if t in tgt_vocab}

    references = []
    candidates = []
    for src, tgt in data_loader:
        # tolist() yields plain ints, so lookups are O(1) dict/set operations.
        for src_ids, tgt_ids in zip(src.tolist(), tgt.tolist()):
            src_sentence = ' '.join(_ids_to_tokens(src_ids, src_lookup, src_skip))
            translated_sentence = translate(model, src_sentence, src_vocab, tgt_vocab, device)

            # Record the candidate translation and its reference.
            candidates.append(translated_sentence.split())
            references.append([_ids_to_tokens(tgt_ids, tgt_lookup, tgt_skip)])

    if not candidates:
        # Nothing to score; avoids a division by zero inside corpus_bleu.
        return 0.0

    # BLEU-4: equal weight on 1- to 4-gram precision.
    score = corpus_bleu(references, candidates, weights=(0.25, 0.25, 0.25, 0.25))
    return score

# Score the model with BLEU-4 on the test data.
# NOTE(review): `test_loader` is not defined in any cell visible here, so a
# fresh "Restart & Run All" raises NameError at this line. Presumably it is a
# DataLoader over a held-out (source, target) set — confirm and define it first.
bleu_score = calculate_bleu(test_loader, model, vocab_zh, vocab_en, device)
print(f"BLEU-4 Score: {bleu_score}")
