In [1]:
import pickle
import pandas as pd
from gensim.models import KeyedVectors
from text_preprocess import TextProcessor
from gensim.models import KeyedVectors



## 加载数据 df：共三列，sentence 是分好词的文本序列，train_tag 是训练/测试集标志，label 是情感极性标签

In [2]:
def pickle_load(path):
    """Read a pickled object from disk and return it.

    Args:
        path: Path of the pickle file; it is opened in binary read mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    A confirmation line is printed after the file has been closed.
    Security note: unpickling can execute arbitrary code, so only call this
    on files that come from a trusted source.
    """
    with open(file=path, mode='rb') as stream:
        loaded = pickle.load(stream)
    print('Load pickle data from', path)
    return loaded


def pickle_dump(obj, path):
    """Serialise ``obj`` with pickle and write it to ``path``.

    Args:
        obj: Any picklable object.
        path: Destination file; opened in ``'wb'`` mode, so an existing
            file at that location is overwritten.

    Returns:
        None. A confirmation line is printed once the file has been closed.
    """
    with open(file=path, mode='wb') as stream:
        pickle.dump(obj, file=stream)
    print('Dump pickle data to', path)
    
    
# Load the IMDB frame (columns shown below: label, sentence, train_tag)
# and split it into train / test parts using the train_tag flag.
data_pickle_path = '../data/imdb_data_3col.pkl'
df = pickle_load(data_pickle_path)
print('Load IMDB data, totally %d sentences.' % len(df))

train_df = df.loc[df['train_tag'].eq('train')]
test_df = df.loc[df['train_tag'].eq('test')]
print('%d train data; %d test data.' % (len(train_df), len(test_df)))

# Peek at the first few rows.
print(df.head())

Load pickle data from ../data/imdb_data_3col.pkl
Load IMDB data, totally 50000 sentences.
25000 train data; 25000 test data.
  label                                           sentence train_tag
0   pos  Bromwell High is a cartoon comedy . It ran at ...     train
1   pos  Homelessness ( or Houselessness as George Carl...     train
2   pos  Brilliant over-acting by Lesley Ann Warren . B...     train
3   pos  This is easily the most underrated film inn th...     train
4   pos  This is not the typical Mel Brooks film . It w...     train


## 用 TextProcessor 类在训练集的 sentence 上 fit，构建 word2id 映射和词向量矩阵

In [3]:
# Tunable settings for this step: vocabulary cap and word-vector file.
vocab_size = 40000
wv_path = '../data/2-W2V.50d.txt'

train_sent_lst = train_df['sentence']

# Build the vocabulary from the training sentences only.
data_processor = TextProcessor(train_sent_lst)
data_processor.build_word_freq_dct()
data_processor.build_word2id(vocab_size)

# Load the word2vec-format vectors and build the embedding weight matrix.
key_words = KeyedVectors.load_word2vec_format(wv_path)
data_processor.build_weights(key_words)

# Print the sentence-length distribution.
data_processor.view_sent_length_freq()

Original 134957 words in vocabulary.
After truncated low frequent word:
words num: 40000/134957; words freq: 0.981
Words exit in w2v file: 39210/40004, rate: 98.015198%
Shape of weight matrix: (40006, 50)
length of sentence: length : freq
11 1
12 1
13 1
14 2
18 1
20 1
21 2
22 2
24 2
26 2
27 3
28 3
29 3
30 1
31 4
32 4
33 4
34 8
35 4
36 5
37 5
38 10
39 4
40 7
41 10
42 13
43 14
44 16
45 15
46 20
47 26
48 28
49 19
50 32
51 27
52 30
53 45
54 33
55 28
56 45
57 27
58 32
59 27
60 28
61 46
62 48
63 45
64 41
65 44
66 46
67 43
68 44
69 45
70 39
71 34
72 37
73 35
74 41
75 49
76 32
77 47
78 31
79 32
80 43
81 35
82 34
83 46
84 40
85 30
86 37
87 38
88 42
89 39
90 33
91 35
92 41
93 34
94 44
95 35
96 42
97 31
98 35
99 48
100 50
101 39
102 38
103 38
104 17
105 35
106 21
107 38
108 30
109 37
110 42
111 24
112 32
113 29
114 38
115 38
116 36
117 40
118 45
119 45
120 53
121 47
122 55
123 55
124 43
125 66
126 57
127 53
128 79
129 78
130 93
131 86
132 102
133 102
134 105
135 110
136 116
137 119
138 131
139 13

## 设置一个最大序列长度，得到训练集和测试集的id序列

In [4]:
# Maximum sequence length passed to the processor. The method name suggests
# longer sentences are truncated; the (N, 500) output shape suggests shorter
# ones are padded. TODO confirm in TextProcessor.get_truncate_id_list.
max_seq_len = 500

train_seqs, train_lens = data_processor.get_truncate_id_list(
    train_df['sentence'], max_seq_len)
test_seqs, test_lens = data_processor.get_truncate_id_list(
    test_df['sentence'], max_seq_len)

# Bundle each split as a dict with keys 'data', 'data_len', 'label'.
# NOTE(review): 'label' is a pandas Series that keeps df's original index
# (presumably test rows therefore do not start at 0), while 'data' is
# array-like with .shape. Use positional access (.iloc / .to_numpy()) when
# pairing labels with rows of 'data'.
train_data = dict(data=train_seqs, data_len=train_lens, label=train_df['label'])
test_data = dict(data=test_seqs, data_len=test_lens, label=test_df['label'])

print('Train data shape:', train_data['data'].shape, 'label length:', len(train_data['label']))
print('Test data shape:', test_data['data'].shape, 'label length:', len(test_data['label']))

Train data shape: (25000, 500) label length: 25000
Test data shape: (25000, 500) label length: 25000


## 上述完整流程已封装在 load_data.py 的 load_imdb_data 函数中，可以通过参数设置词表大小和序列截断长度。  
## 该函数返回两个字典（分别对应训练集和测试集）以及一个词向量矩阵

In [5]:
# Packaged pipeline: load_imdb_data() in load_data.py repeats the steps above
# (called here with its default settings) and returns
# (train_data, test_data, weights).
# Import only the function used here instead of `from load_data import *`.
# A wildcard import could silently shadow names already defined in this
# notebook, such as pickle_load. Import any other helper explicitly if needed.
from load_data import load_imdb_data

train_data, test_data, weights = load_imdb_data()
print(train_data.keys())
print(test_data.keys())
print(weights.shape)

Load pickle data from ../data/imdb_data_3col.pkl
Original 134957 words in vocabulary.
After truncated low frequent word:
words num: 40000/134957; words freq: 0.981
Words exit in w2v file: 39210/40004, rate: 98.015198%
Shape of weight matrix: (40006, 50)
Train data shape: (25000, 500) label length: 25000
Test data shape: (25000, 500) label length: 25000
dict_keys(['data', 'data_len', 'label'])
dict_keys(['data', 'data_len', 'label'])
(40006, 50)
