In [24]:
import pickle
import torch
import json
from torch.utils.data import DataLoader
from EncoderDecoderAttentionModel import Seq2Seq
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence


In [19]:
def read_data():
    with open('./train_data/train/in.txt', 'r') as f:
        lines = f.readlines()
        input_lines = [line.strip() for line in lines if line.strip()]
    with open('./train_data/train/out.txt', 'r') as f:
        lines = f.readlines()
        output_lines = [line.strip() for line in lines if line.strip()]


    assert len(input_lines) == len(output_lines), "Input and output files must have the same number of lines."

    enc_data, dec_data = [], []

    for line in input_lines:
        enc_data.append(line.split())

    for line in output_lines:
        dec_data.append(['BOS'] + line.split() + ['EOS'])

    with open('./train_data/vocabs', 'r') as f:
        lines = f.readlines()
        tokens = ['PAD', 'UNK', 'BOS', 'EOS'] + [line.strip() for line in lines if line.strip()]
        vocab = { tk:i for i, tk in enumerate(tokens)}
    
    return enc_data, dec_data, vocab

In [26]:
enc_data, dec_data, vocab = read_data()

In [28]:
def get_proc(vocab):
    # 嵌套函数定义
    # 外部函数变量生命周期会延续到内部函数调用结束 （闭包）

    def batch_proc(data):
        """
        批次数据处理并返回
        """
        enc_ids, dec_ids, labels = [],[],[]
        for enc,dec in data:
            # token -> token index
            enc_idx = [vocab[tk] for tk in enc]
            dec_idx = [vocab[tk] for tk in dec]

            # encoder_input
            enc_ids.append(torch.tensor(enc_idx))
            # decoder_input
            dec_ids.append(torch.tensor(dec_idx[:-1]))
            # label
            labels.append(torch.tensor(dec_idx[1:]))

        
        # 数据转换张量 [batch, max_token_len]
        # 用批次中最长token序列构建张量
        enc_input = pad_sequence(enc_ids, batch_first=True)
        dec_input = pad_sequence(dec_ids, batch_first=True)
        targets = pad_sequence(labels, batch_first=True)

        # 返回数据都是模型训练和推理的需要
        return enc_input, dec_input, targets

    # 返回回调函数
    return batch_proc    


In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dl = DataLoader(
    list(zip(enc_data, dec_data)),
    batch_size=256,
    shuffle=True,
    collate_fn=get_proc(vocab)
)


In [33]:
from torch.utils.tensorboard import SummaryWriter


# 构建训练模型
# 模型构建
model = Seq2Seq(
    enc_emb_size=len(vocab),
    dec_emb_size=len(vocab),
    emb_dim=100,
    hidden_size=120,
    dropout=0.5,
)
model.to(device)

writer = SummaryWriter()

# 优化器、损失
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# 训练
i = 0
for epoch in range(5):
    model.train()
    tpbar = tqdm(dl)
    for enc_input, dec_input, targets in tpbar:
        enc_input = enc_input.to(device)
        dec_input = dec_input.to(device)
        targets = targets.to(device)

        # 前向传播 
        logits, _ = model(enc_input, dec_input)

        # 计算损失
        # CrossEntropyLoss需要将logits和targets展平
        # logits: [batch_size, seq_len, vocab_size]
        # targets: [batch_size, seq_len]
        # 展平为 [batch_size * seq_len, vocab_size] 和 [batch_size * seq_len]
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tpbar.set_description(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')
        if i % 10 == 0:
            writer.add_scalar(f'training loss', loss.item(), i)
        i += 1


writer.close()
torch.save(model.state_dict(), 'seq2seq_state.bin')

Epoch 1, Loss: 1.6889: 100%|██████████| 3010/3010 [02:19<00:00, 21.57it/s]
Epoch 2, Loss: 1.3905: 100%|██████████| 3010/3010 [02:21<00:00, 21.32it/s]
Epoch 3, Loss: 1.3318: 100%|██████████| 3010/3010 [02:23<00:00, 21.01it/s]
Epoch 4, Loss: 1.4152: 100%|██████████| 3010/3010 [02:23<00:00, 20.96it/s]
Epoch 5, Loss: 1.3719: 100%|██████████| 3010/3010 [02:24<00:00, 20.86it/s]
