# 数据载入模块
目标：给定数据集文件路径，返回整个文件经处理后得到的字符 id 索引序列及对应标签

In [1]:
import itertools

import numpy as np
import pandas as pd

In [2]:
# Path to the token-embedding file that get_embed() parses line by line.
vocab_file = "./data/token_vec_300.bin"

### 读取 vocab_file 文件，生成嵌入矩阵，并建立字符与索引之间的双向映射

In [6]:
def get_embed(file):
    """Load a whitespace-separated text embedding file.

    Each non-blank line is parsed as ``<token> <v1> <v2> ...``. Id 0 is
    reserved for the ``'PAD&UNK'`` token, whose vector is all zeros; tokens
    from the file get ids 1, 2, ... in file order. If a token appears more
    than once, only its first occurrence is kept, so ids stay contiguous.

    Args:
        file: Path to the embedding file, read as UTF-8 text.

    Returns:
        ``(embed, word2idx, idx2word)``:
        - ``embed`` is a float64 ndarray of shape ``(vocab_size, dim)``,
          where ``dim`` is taken from the first vector in the file (300 if
          the file contains no vectors).
        - ``word2idx`` maps token -> id.
        - ``idx2word`` maps id -> token.

    Raises:
        ValueError: if a line's vector length differs from the first vector's.
    """
    pad_token = 'PAD&UNK'
    word2idx = {pad_token: 0}  # token -> id; id 0 is the shared pad/unknown slot
    vectors = []               # vectors[i] belongs to id i + 1
    dim = None

    with open(file, encoding='utf-8') as f:
        # Stream the file instead of readlines() so large vocabularies are
        # not held twice in memory.
        for lineno, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines (e.g. a trailing newline)
            word, values = parts[0], parts[1:]
            if word in word2idx:
                # A duplicate would overwrite an existing id and leave a gap
                # in the id range; keep the first occurrence instead.
                continue
            vec = [float(v) for v in values]
            if dim is None:
                dim = len(vec)
            elif len(vec) != dim:
                raise ValueError(
                    f"{file}:{lineno}: expected {dim} values for {word!r}, "
                    f"got {len(vec)}"
                )
            word2idx[word] = len(word2idx)
            vectors.append(vec)

    if dim is None:
        dim = 300  # no vectors in the file: keep the original fixed width

    idx2word = {idx: w for w, idx in word2idx.items()}
    # Row 0 is the all-zero pad/unknown vector; row i is the vector of id i.
    embed = np.array([[0.0] * dim] + vectors, dtype=float)
    return embed, word2idx, idx2word

In [9]:
embed, char2idx, idx2char = get_embed(vocab_file)
text = '我爱学习'
id_list = [14, 34, 55, 12, 7, 2, 44, 58, 59]
text_ids = []

# Print the vocabulary id of every (lower-cased) character of the sample text.
print("将'我爱学习'文本转化为id序列", end="")
for char in text.lower():
    char_id = char2idx[char]
    print(char_id, end=' ')
    text_ids.append(char_id)

# Print the embedding row stored for each id in id_list.
print("\nid_list转化为嵌入向量为")
for idx in id_list:
    vector = embed[idx]
    print(vector, end="")

将'我爱学习'文本转化为id序列579 545 35 799 
id_list转化为嵌入向量为
[ 1.79629400e+00  2.51598120e+00  1.82228640e+00 -3.53821160e+00
  1.99003430e+00 -2.30876790e-01  1.03420600e+00 -1.01830220e+00
 -3.21812240e-01 -1.52446540e+00 -9.26450250e-01  4.37394680e-01
  1.91165830e-01  2.55095330e-01 -2.22546630e+00  7.01172100e-01
  3.40463230e+00  2.32615610e+00 -1.45944270e+00 -2.23996950e+00
  3.06487830e-02  1.17222360e+00 -4.57386060e+00  1.06058770e-01
 -1.43589580e+00 -1.63493960e+00 -3.04952030e-01  1.70673070e-01
 -2.68366960e+00 -2.61135980e+00  2.69963200e+00  2.73121140e+00
 -1.21422040e+00  1.75877790e+00  1.26100090e+00 -4.85585100e-01
  9.02912200e-01 -7.35035240e-01 -1.96040900e-01 -2.51912360e+00
 -2.91899230e-01 -4.74387600e-01 -3.05935550e+00  4.04685400e+00
  3.44127630e+00  1.25115690e+00 -1.58445160e+00 -2.80243870e+00
 -2.95927670e-01 -9.60050900e-01  5.19874040e-01 -3.38323000e+00
 -5.32603450e+00 -2.45540450e+00 -4.37591500e+00  3.62992440e-01
 -2.19670990e+00 -4.98429000e-02  3.195730

In [13]:
def padding(text, maxlen=20):
    """Truncate or zero-pad every id sequence to exactly ``maxlen`` entries.

    Args:
        text: Iterable of sentences, each an iterable of integer ids. Only
            the first ``maxlen`` items of each sentence are consumed.
        maxlen: Target length. Longer sentences are truncated; shorter ones
            are right-padded with 0, the id ``get_embed`` assigns to
            ``'PAD&UNK'``.

    Returns:
        List with one entry per input sentence. Each entry is a list of
        ``maxlen`` Python ints.
    """
    pad_text = []
    for sentence in text:
        row = np.zeros(maxlen, dtype='int64')
        # islice never reads past maxlen items, so maxlen=0 no longer tries
        # to write into an empty array, and lazy iterators are not drained.
        head = list(itertools.islice(sentence, maxlen))
        row[:len(head)] = head  # values are coerced to int64 as before
        pad_text.append(row.tolist())
    return pad_text

In [14]:
# Sample batch: a short sentence (to be zero-padded) and a 130-id one
# (to be truncated).
text = [
    [5, 4, 15, 12, 7, 7],
    list(range(130)),
]
pad_text = padding(text, maxlen=20)
print(pad_text)

[[5, 4, 15, 12, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]]


# 整合上述两个函数：给定一批文本对，返回它们的字符索引 id
处理词表中未出现过的字符（映射为 0），并对序列进行截断与补齐

In [17]:
def char_index(text_a, text_b, file, maxlen=20):
    """Convert paired sentences into padded character-id sequences.

    Each sentence is lower-cased and split into characters. Every character
    is looked up in the vocabulary that ``get_embed`` builds from *file*.
    Characters missing from the vocabulary map to id 0, the ``'PAD&UNK'``
    slot. Both outputs are then truncated or zero-padded to *maxlen*.

    Args:
        text_a: Iterable of first sentences. Non-string values are converted
            with ``str``; ``None`` and float NaN (what pandas yields for
            empty cells) are treated as empty sentences.
        text_b: Iterable of second sentences, paired with *text_a* via
            ``zip``. The shorter of the two determines the number of pairs.
        file: Path of the vocabulary/embedding file passed to ``get_embed``.
        maxlen: Length every id sequence is truncated or padded to.

    Returns:
        ``(a_list, b_list)``: two lists of ``maxlen``-long lists of int ids.
    """
    # Use the caller-supplied vocabulary path rather than a global.
    _, char2idx, _ = get_embed(file)

    def _encode(sentence):
        # Map one sentence to its list of character ids (unknown -> 0).
        # Without the NaN check, str(nan) == 'nan' would be encoded as the
        # characters 'n', 'a', 'n'.
        if sentence is None or (isinstance(sentence, float) and sentence != sentence):
            return []
        return [char2idx.get(ch, 0) for ch in str(sentence).lower()]

    a_list, b_list = [], []
    for a_sentence, b_sentence in zip(text_a, text_b):
        a_list.append(_encode(a_sentence))
        b_list.append(_encode(b_sentence))

    return padding(a_list, maxlen=maxlen), padding(b_list, maxlen=maxlen)

### 函数应用举例

In [18]:
# One pair of common characters and one pair of rare characters; the rare
# ones are expected to fall back to id 0 if absent from the vocabulary.
ta = ['我爱你', '岇鎩']
tb = ["再来一遍", '溌郶']
a_list, b_list = char_index(ta, tb, vocab_file)

for padded in (a_list, b_list):
    print(padded)

[[579, 545, 1164, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
[[489, 93, 9, 1523, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]


In [19]:
def load_char_data(filename, file):
    """Read a tab-separated sentence-pair file and encode it as character ids.

    The file is read with its first row as the header. It must provide
    ``text_a``, ``text_b`` and ``label`` columns. The two text columns are
    encoded by ``char_index`` using the vocabulary file *file*.

    Args:
        filename: Path of the UTF-8, tab-separated data file.
        file: Path of the vocabulary/embedding file.

    Returns:
        ``(a_index, b_index, label)`` as NumPy arrays.
    """
    frame = pd.read_csv(filename, encoding='utf-8', sep='\t')
    first, second, labels = (frame[col].values for col in ('text_a', 'text_b', 'label'))
    a_index, b_index = char_index(first, second, file)
    return tuple(np.array(part) for part in (a_index, b_index, labels))

In [20]:
# Load the LCQMC training file as padded char-id matrices plus labels.
a_index, b_index, label = load_char_data(filename='./data/lcqmc_train.tsv', file=vocab_file)

In [21]:
a_index.shape[0]  # number of encoded sentence pairs

238766

In [22]:
print(a_index[0:17])  # first 17 encoded text_a rows

[[1019 1165  516 1598  184    2  626   63 1019 1165  992 1205  575    2
   281   63    0    0    0    0]
 [ 579  318  151 2542   36    1  579  712  912   40  318  151    0    0
     0    0    0    0    0    0]
 [  20   75 1134  143  467  556  762 2683    0    0    0    0    0    0
     0    0    0    0    0    0]
 [ 606 1249  323   55  360 1191  510  198  351    0    0    0    0    0
     0    0    0    0    0    0]
 [ 971   47 2289 1134  383  436 1158  151 1245  344  297   21  992 1205
  1051  342 2683 1542    0    0]
 [  35   31  265 1069  451  318  151   47    2    0    0    0    0    0
     0    0    0    0    0    0]
 [ 516  619  151   26  138 1144 2401  575  461  298    1  194  257  700
   212  560    0    0    0    0]
 [1895 1776  659  165 1073 1275  214   90 2401  575  285  165    0    0
     0    0    0    0    0    0]
 [ 992 1205  525    9    6  255  521  214  108    0    0    0    0    0
     0    0    0    0    0    0]
 [ 762  380 2406    9  138  276   58    0    0    0    

In [23]:
b_index[0:17]  # first 17 encoded text_b rows

array([[ 545,  516, 1598,  184,    2,  626,   63, 1019, 1165,  992, 1205,
         575,    2,  281,   63,    0,    0,    0,    0,    0],
       [ 579,  712, 1286,   40,  120,  318,  151,    1,  606,  570, 1989,
           0,    0,    0,    0,    0,    0,    0,    0,    0],
       [  20,   75, 1134,  143, 1552,  626,  556,  762, 2683, 1542,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0],
       [ 606, 1249,  323,   55,  360,  198,  351, 1191,  510,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0],
       [2200, 1441,   95,   27, 1300, 1158,  151, 1245,  344,  297, 2683,
        2709,    0,    0,    0,    0,    0,    0,    0,    0],
       [ 318,  151,   35,   31,  265,    2, 1069,  451,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0],
       [ 257,  700,  726,  383,  300,  209,    2,  138, 1144,  461,  298,
          72,  516,  619,  151,   47,    0,    0,    0,    0],
       [1895, 1776,  659,  165, 1073, 127

In [24]:
label[0:17]  # labels of the first 17 pairs

array([1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1], dtype=int64)