# 1. 导入所需要的模块，设置语料的路径和训练时的batch_size

In [1]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import jieba
import matplotlib.pyplot as plt

# Mini-batch size used when the DataLoader is created further down.
batch_size = 2
# Paths of the English and Chinese training corpus files.
en_name = './corpus/train/EN.txt'
cn_name = './corpus/train/CN.txt'

# 2. 从文件中读取数据并且使用jieba分词，随后生成词表

In [2]:
def file2tokens(file_name: str):
    r"""
    Read a corpus file and tokenize each of its lines with jieba.

    :param file_name: path of a text file; it is split on '\n' into sentences
    :return: list of token lists, one per kept line. Lines that yield
        5 or fewer tokens are dropped.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # locale encoding, which can fail or mis-decode Chinese text.
    # NOTE(review): assumes the corpus files are stored as UTF-8 -- confirm.
    with open(file_name, 'r', encoding='utf-8') as file:
        lines = file.read().split('\n')

    tokenized = (list(jieba.cut(line)) for line in lines)
    # Keep only sentences with more than 5 tokens.
    return [tokens for tokens in tokenized if len(tokens) > 5]


def tokens2vocab(data, target=None):
    r"""
    Build a vocabulary and its related lookup data from tokenized text.

    :param data: list of token lists, as generated by file2tokens
    :param target: 'src' (vocabulary starts with 'P' only) or
        'tgt' (vocabulary starts with 'P', 'S' and 'E')
    :return: (vocab, idx2token, vocab_size) where vocab maps token -> index,
        idx2token maps index -> token and vocab_size is the number of entries
    :raises ValueError: if target is neither 'src' nor 'tgt'
    """
    # The two sides differ only in which special tokens are reserved first.
    if target == 'src':
        special_tokens = ('P',)
    elif target == 'tgt':
        special_tokens = ('P', 'S', 'E')
    else:
        raise ValueError("invalid param about target!")

    vocab = {special: position for position, special in enumerate(special_tokens)}
    for sentence in data:
        for word in sentence:
            # First occurrence gets the next free index; repeats are no-ops.
            vocab.setdefault(word, len(vocab))

    idx2token = {position: word for word, position in vocab.items()}
    return vocab, idx2token, len(vocab)

# 3. 生成Transformer使用的LongTensor序列

In [3]:
def make_data(src_tokens: list, tgt_tokens: list, src_vocab: dict, tgt_vocab: dict):
    r"""
    Convert tokenized sentences into padded index tensors.

    Every source sentence is padded to the longest source sentence. Every
    target sentence yields two sequences, both padded to the longest target
    sentence plus one:

    * decoder input:  ``[S, t1, ..., tn, P, ..., P]``
    * decoder output: ``[t1, ..., tn, E, P, ..., P]``

    The end token is placed directly after the last real token, before any
    padding, so that it lines up with the position where the decoder has
    just consumed the final token.

    :param src_tokens: list of token lists for the source side
    :param tgt_tokens: list of token lists for the target side
    :param src_vocab: token -> index mapping for the source side
    :param tgt_vocab: token -> index mapping for the target side; must
        contain the keys 'S' and 'E'
    :return data_list: [enc_inputs, dec_inputs, dec_outputs] as LongTensors
    :raises KeyError: if a token is missing from its vocabulary
    """
    pad_idx = 0  # padding index, same value the original padding used

    def pad(x, max_len):
        # Right-pad x with pad_idx up to max_len.
        return x + [pad_idx] * (max_len - len(x))

    # default=0 keeps an empty corpus from raising inside max().
    src_len = max((len(sentence) for sentence in src_tokens), default=0)
    # +1 leaves room for the start / end token added to each target sequence.
    tgt_len = max((len(sentence) for sentence in tgt_tokens), default=0) + 1

    enc_inputs = [
        pad([src_vocab[token] for token in sentence], src_len)
        for sentence in src_tokens
    ]

    dec_inputs, dec_outputs = [], []
    for sentence in tgt_tokens:
        ids = [tgt_vocab[token] for token in sentence]
        # Add the start / end markers BEFORE padding so that 'E' follows the
        # last real token instead of trailing after all the padding.
        dec_inputs.append(pad([tgt_vocab['S']] + ids, tgt_len))
        dec_outputs.append(pad(ids + [tgt_vocab['E']], tgt_len))

    data_list = [
        torch.LongTensor(enc_inputs),
        torch.LongTensor(dec_inputs),
        torch.LongTensor(dec_outputs),
    ]
    return data_list

In [4]:
# Tokenize the English corpus and build the source-side vocabulary.
src_tokens = file2tokens(en_name)
src2idx, idx2src, src_vocab_size = tokens2vocab(src_tokens, 'src')

# Tokenize the Chinese corpus and build the target-side vocabulary.
tgt_tokens = file2tokens(cn_name)
tgt2idx, idx2tgt, tgt_vocab_size = tokens2vocab(tgt_tokens, 'tgt')

# NOTE(review): file2tokens drops short lines separately for each file, so
# src_tokens[i] and tgt_tokens[i] may stop being a translation pair -- verify
# that both corpora keep the same lines.
# Convert both token lists into the list of index tensors used for training.
data_list = make_data(src_tokens, tgt_tokens, src2idx, tgt2idx)

Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
Loading model cost 0.343 seconds.
Prefix dict has been built successfully.


# 4. 实现自己的Dataset类，并用DataLoader按batch加载数据

In [5]:
class MyDataSet(Data.Dataset):
    """Dataset serving aligned (enc_input, dec_input, dec_output) samples."""

    def __init__(self, data_list):
        """Store the three parallel sequences unpacked from data_list."""
        super().__init__()
        encoder_in, decoder_in, decoder_out = data_list
        self.enc_inputs = encoder_in
        self.dec_inputs = decoder_in
        self.dec_outputs = decoder_out

    def __len__(self):
        """Return the number of samples, taken from the encoder inputs."""
        return self.enc_inputs.shape[0]

    def __getitem__(self, idx):
        """Return the idx-th entry of each of the three stored sequences."""
        fields = (self.enc_inputs, self.dec_inputs, self.dec_outputs)
        return tuple(field[idx] for field in fields)


# Wrap the dataset in a shuffling DataLoader; batch_size is passed positionally.
loader = Data.DataLoader(MyDataSet(data_list), batch_size, shuffle=True)
# Sanity check: print the decoder inputs of the first batch only, then stop.
for a, b, c in loader:
    print(repr(b))
    break

tensor([[  1, 194,  15, 195, 150,   5, 196, 180, 197, 198,  21,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0],
        [  1,  50,  49,   8,  71,  15, 253, 254, 255, 256, 144,  12, 145, 147,
           5, 257, 258, 259,   5, 260, 261, 147, 183,   5, 211, 262, 208, 262,
          15, 263,  38, 246,  38, 264, 265, 266,  15, 263,  38, 245,  38, 264,
         267, 266,  21,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0]])


In [6]:
# Show the number of entries in the target-side vocabulary.
print(tgt_vocab_size)

294
