In [1]:
import pandas as pd
import re
import numpy as np
import torch 
import torch.nn as nn
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from collections.abc import Iterable, Iterator
from torch.utils.data import DataLoader,Dataset
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import random
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils.rnn import pad_sequence

### 数据处理

– 编码数据、解码数据字典构建
– Dataload数据后处理方法

### 模型结构

– Encoder: embedding + rnn
– Decoder: embedding + rnn + Linear
– Encoder和Decoder衔接
– Encoder采样方向：双向RNN

### 模型训练

### 生成推理

In [2]:
# 读取并拆分数据
def read_data(data_file):
    with open(data_file, "r") as f:
        lines = f.readlines()
        enc_data, dec_data = [], []
        for line in lines:
            if line == "":
                continue
            enc, dec = line.split("\t")
            enc_tokens = re.findall(r"[\w']+", enc)
            dec_tokens = ["<BOS>"] + re.findall(r"[\u4e00-\u9fff]", dec) + ["<EOS>"]
            enc_data.append(enc_tokens)
            dec_data.append(dec_tokens)

    assert len(enc_data) == len(dec_data), "编码数据和解码数据长度不一致。"
    return enc_data, dec_data

In [3]:
enc_data, dec_data = read_data("../../data/cmn.txt")
print(len(enc_data))
print(len(dec_data))

20133
20133


In [4]:
dec_data

[['<BOS>', '嗨', '<EOS>'],
 ['<BOS>', '你', '好', '<EOS>'],
 ['<BOS>', '你', '用', '跑', '的', '<EOS>'],
 ['<BOS>', '等', '等', '<EOS>'],
 ['<BOS>', '你', '好', '<EOS>'],
 ['<BOS>', '让', '我', '来', '<EOS>'],
 ['<BOS>', '我', '赢', '了', '<EOS>'],
 ['<BOS>', '不', '会', '吧', '<EOS>'],
 ['<BOS>', '乾', '杯', '<EOS>'],
 ['<BOS>', '他', '跑', '了', '<EOS>'],
 ['<BOS>', '跳', '进', '来', '<EOS>'],
 ['<BOS>', '我', '迷', '失', '了', '<EOS>'],
 ['<BOS>', '我', '退', '出', '<EOS>'],
 ['<BOS>', '我', '沒', '事', '<EOS>'],
 ['<BOS>', '听', '着', '<EOS>'],
 ['<BOS>', '不', '可', '能', '<EOS>'],
 ['<BOS>', '没', '门', '<EOS>'],
 ['<BOS>', '你', '确', '定', '<EOS>'],
 ['<BOS>', '试', '试', '吧', '<EOS>'],
 ['<BOS>', '我', '们', '来', '试', '试', '<EOS>'],
 ['<BOS>', '为', '什', '么', '是', '我', '<EOS>'],
 ['<BOS>', '去', '问', '汤', '姆', '<EOS>'],
 ['<BOS>', '冷', '静', '点', '<EOS>'],
 ['<BOS>', '公', '平', '点', '<EOS>'],
 ['<BOS>', '友', '善', '点', '<EOS>'],
 ['<BOS>', '和', '气', '点', '<EOS>'],
 ['<BOS>', '联', '系', '我', '<EOS>'],
 ['<BOS>', '联', '系', '我', '们', '<

In [114]:
class Vocabulary:
    def __init__(self, vocab):
        self.vocab = vocab

    @classmethod
    def from_documents(cls, documents):
        tokens = set() 
        for cmt in documents:
            tokens.update(list(cmt))
        tokens = ["<PAD>", "<UNK>", "<BOS>", "<EOS>"] + sorted(list(tokens)) # set是无序的，可以在list之后做排序,保证每次构建词典顺序一致
        vocab = {token:i for i, token in enumerate(tokens)} 
        return cls(vocab)

In [133]:
enc_vocab = Vocabulary.from_documents(enc_data)
dec_vocab = Vocabulary.from_documents(dec_data)
print(len(enc_vocab.vocab))
print(len(dec_vocab.vocab))

6952
3330


In [275]:
# 编码-解码 样本

# 能getitem，能len，不就是dataset吗

# def batch_fn(data):
#     "自己编写的"
#     enc_list, dec_lsit = [], []
#     for enc, dec in data:
#         enc_index = [enc_vocab.vocab.get(tk, 1) for tk in enc]
#         dec_index = [dec_vocab.vocab.get(tk, 1) for tk in dec]
#         enc_tensor = torch.zeros(size=(1000, ), dtype=torch.long)
#         dec_tensor = torch.zeros(size=(1000, ), dtype=torch.long)

#         for i in range(len(enc_index)):
#             enc_tensor[i] = enc_index[i]
#         enc_list.append(enc_tensor)

#         for i in range(len(dec_index)):
#             dec_tensor[i] = dec_index[i]
#         dec_lsit.append(dec_tensor)

#     enc_input = torch.stack(enc_list, dim=0)
#     dec_input = torch.stack(dec_lsit, dim=0)
#     return enc_input,dec_input

#     pass

def batch_fn(data):
    enc_list, dec_lsit = [], []
    for enc, dec in data:
        enc_index = [enc_vocab.vocab.get(tk, 1) for tk in enc]
        dec_index = [dec_vocab.vocab.get(tk, 1) for tk in dec]
        enc_list.append(torch.tensor(enc_index, dtype=torch.long))
        dec_lsit.append(torch.tensor(dec_index, dtype=torch.long))

    # 用批次中最长的序列长度作为最大长度 ： (batch_size ,max_token_len)
    # 把元素为张量的List转换为张量矩阵，自动填充
    enc_input = pad_sequence(enc_list, batch_first=True)
    dec_input = pad_sequence(dec_lsit, batch_first=True)
    return enc_input, dec_input
    pass



dataset = list(zip(enc_data, dec_data))
dataloader = DataLoader(dataset, batch_size=2,shuffle=True, collate_fn=batch_fn)
for enc_input, dec_input in dataloader:
    print(enc_input.shape)
    print(enc_input)
    print(dec_input.shape)
    print(dec_input)
    break

torch.Size([2, 6])
tensor([[ 586, 2698, 5231, 3698, 1641, 3256],
        [1124, 2792, 4278, 5600,    0,    0]])
torch.Size([2, 8])
tensor([[2788,  298, 1112, 1589, 2619,  285,   91,  148],
        [2871, 1592, 1392,  104,  137, 1070,   40,    0]])
