In [1]:
import sys
sys.path.append("..")
import torch
import numpy as np
from tqdm import tqdm
import os
from Utils import VariantWordDataset, Config
from torch.utils.data import DataLoader
from torch import nn

## 1. 加载数据

In [2]:
# 实例化全局参数 Config 对象
config = Config()


# 构建数据集
train_set = VariantWordDataset("train", config.source_dic_path, config.target_dic_path)
valid_set = VariantWordDataset("test", config.source_dic_path, config.target_dic_path)
print(f"Train size: {len(train_set)}")

# 构建data_loader
n_cpu = os.cpu_count()
train_dataloader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True, collate_fn=train_set.generate_batch, num_workers=n_cpu)
valid_dataloader = DataLoader(valid_set, batch_size=config.batch_size, shuffle=False, collate_fn=valid_set.generate_batch, num_workers=n_cpu)

loading dictionary from ../Data/source_vocal.pkl
loading dictionary from ../Data/target_vocal.pkl
loading dictionary from ../Data/source_vocal.pkl
loading dictionary from ../Data/target_vocal.pkl
Train size: 9086


## 2. 模型训练

In [4]:
from Model.BaselineModel import BaselineModel
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
# import wandb

# 初始化 wandb
# wandb.init(project="Graduation_project")

# 模型初始化
model = BaselineModel(config)

# # wandb logger
# wandb_logger = WandbLogger(project = "Graduation_project",
#                            name = 'Transformer-CrossEL-lr-0.02',
#                            save_dir = '../Logs',
#                            log_model="all")

# # 模型参数保存
# checkpoint_callback = ModelCheckpoint(
#     monitor="valid_accuracy",
#     dirpath="../Weights",
#     filename="Baseline-Transformer-CrossEntropyLoss-{epoch:02d}-{valid_accuracy:.2f}",
#     save_top_k=3,
#     mode="max",
# )

# # 训练 Trainer 定义
# trainer = pl.Trainer(
#     max_epochs=5, 
#     gpus=0,
#     logger = wandb_logger,
#     callbacks=[checkpoint_callback]
#     )


# # 模型训练
# trainer.fit(
#     model, 
#     train_dataloaders=train_dataloader, 
#     val_dataloaders=valid_dataloader
# )

## 模型预测

In [26]:
import sys
sys.path.append("..")
from Model.TranslationModel import TranslationModel
from Utils.Variant_word import VariantWordDataset
import torch
from torchtext.data.metrics import bleu_score

def translate(model, src, data_loader, config):
    
    source_dic = data_loader.source_dic    
    target_dic = data_loader.target_dic

    model.eval()

    tokens = [source_dic.word2idx[i] for i in list(src)] # 构造一个样本
    num_tokens = len(tokens)
    src = (torch.LongTensor(tokens).reshape(num_tokens, 1))  # 将src_len 作为第一个维度
    tgt_tokens = greedy_decode(model, src, max_len=num_tokens + 5,
                                start_symbol=config.BOS_IDX, config=config).flatten()  # 解码的预测结果
    
    return " ".join([target_dic.idx2word[int(tok)] for tok in tgt_tokens]).replace("[BOS]", "").replace("[EOS]", "")


def greedy_decode(model, src, max_len, start_symbol, config):

    src = src.to(config.device)
    memory = model.encoder(src)  # 对输入的Token序列进行解码翻译
    ys = torch.ones(1, 1).fill_(start_symbol). \
       type(torch.long).to(config.device)  # 解码的第一个输入，起始符号

    for i in range(max_len - 1):
        memory = memory.to(config.device)
        tgt_mask = (model.my_transformer.generate_square_subsequent_mask(ys.size(0))
                   .type(torch.bool)).to(config.device)  # 根据tgt_len产生一个注意力mask矩阵（对称的）
        out = model.decoder(ys, memory, tgt_mask)  # [tgt_len,tgt_vocab_size]
        out = out.transpose(0, 1)  # [tgt_vocab_size, tgt_len]
        prob = model.classification(out[:, -1])  # 只对对预测的下一个词进行分类
        _, next_word = torch.max(prob, dim=1)  # 选择概率最大者
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        # 将当前时刻解码的预测输出结果，同之前所有的结果堆叠作为输入再去预测下一个词。
        if next_word == config.EOS_IDX:  # 如果当前时刻的预测输出为结束标志，则跳出循环结束预测。
            break
    return ys

def translate_to_right(src, config):
    data_loader = VariantWordDataset("train", config.source_dic_path, config.target_dic_path)
    translation_model = TranslationModel(src_vocab_size=config.source_vocab_size,
                                        tgt_vocab_size=config.target_vocab_size,
                                        d_model=config.d_model,
                                        nhead=config.num_head,
                                         num_encoder_layers=config.num_encoder_layers,
                                       num_decoder_layers=config.num_decoder_layers,
                                       dim_feedforward=config.dim_feedforward,
                                       dropout=config.dropout)
    translation_model = translation_model.to(config.device)
    torch.load("../Weights/model.pkl", map_location="cpu")
    r = translate(translation_model, src, data_loader, config)
    return r



srcs = ["9306你好,鉴于你良好的信誉,特聘请你~来我店帮忙工作（350/天）咨询,Q:707941883."]
tgts = ["9306你好,鉴于你良好的信誉,特聘请你来我店帮忙工作(350天)咨询,Q:707941883."]
config = Config()
for i, src in enumerate(srcs):
    r = translate_to_right(src, config)
    print(f"德语：{src}")
    print(f"翻译：{r}")
    print(f"英语：{tgts[i]}")
    # print(len([src]))
    # print(len([tgts[i]]))
    # print([tgts[i]])
    print([[i for i in src]])
    print([[[i for i in tgts[i]]]])
    print(bleu_score([[i for i in src]], [[[i for i in tgts[i]]]]))

loading dictionary from ../Data/source_vocal.pkl
loading dictionary from ../Data/target_vocal.pkl
德语：9306你好,鉴于你良好的信誉,特聘请你~来我店帮忙工作（350/天）咨询,Q:707941883.
翻译： 又 馥 崭 燕 元 元 雨 雨 雨 雨 雨 雨 猎 雨 雨 雨 雨 實 實 實 實 實 實 實 實 莱 實 實 實 實 實 實 實 實 實 實 實 筛 筛 筛 雨 雨 莱 莱 實 缇 實 實 實 實 實 雨 雨 猎
英语：9306你好,鉴于你良好的信誉,特聘请你来我店帮忙工作(350天)咨询,Q:707941883.
[['9', '3', '0', '6', '你', '好', ',', '鉴', '于', '你', '良', '好', '的', '信', '誉', ',', '特', '聘', '请', '你', '~', '来', '我', '店', '帮', '忙', '工', '作', '（', '3', '5', '0', '/', '天', '）', '咨', '询', ',', 'Q', ':', '7', '0', '7', '9', '4', '1', '8', '8', '3', '.']]
[[['9', '3', '0', '6', '你', '好', ',', '鉴', '于', '你', '良', '好', '的', '信', '誉', ',', '特', '聘', '请', '你', '来', '我', '店', '帮', '忙', '工', '作', '(', '3', '5', '0', '天', ')', '咨', '询', ',', 'Q', ':', '7', '0', '7', '9', '4', '1', '8', '8', '3', '.']]]
0.8034114837646484
