data.py
# -*- coding: utf-8 -*-
import pandas as pd
# import logging
from main.utils.tokenize import seg_words
from main.utils.vocab import Vocab
import numpy as np

class Dataset(object):
    """Loads the CSV data sets, builds vocabularies, and generates mini-batches."""

    def __init__(self, opts):
        # self.logger = logging.getLogger("rc")
        # self.feature_extract = FeatureExtract()
        self.max_sentence_size = opts.max_sentence_size
        self.min_count = opts.min_count
        self.embedding_size = opts.embedding_size
        self.train_set, self.dev_set, self.test_set = [], [], []
        # Columns that are never treated as prediction targets.
        self.un_tgt_field = ['id', 'content', 'seg_content']
        self.train_set = self._load_dataset(opts.train_data_path, train=True)
        self.dev_set = self._load_dataset(opts.validate_data_path)
        self.test_set = self._load_dataset(opts.test_data_path)
        self.tgt_info = self._get_tgt_info()
        self.src_vocab = None
        self.tgt_vocab = None
        # self.logger.info('train_set size: {}'.format(len(self.train_set)))
        # self.logger.info('dev_set size: {}'.format(len(self.dev_set)))

    def _load_dataset(self, data_path, header=0, encoding="utf-8", train=False):
        """Read a CSV file into a list of record dicts and segment each content field."""
        data = pd.read_csv(data_path, header=header, encoding=encoding).to_dict('records')
        seg_words(data)
        return data

    def _get_tgt_info(self):
        """Return the column names that are used as prediction targets."""
        return [field for field in self.train_set[0].keys()
                if field not in self.un_tgt_field]

    def build_vocab(self):
        """Build the source word vocabulary and one target vocabulary per target field."""
        src_vocab = Vocab()
        for word in self.word_iter('train'):
            src_vocab.add(word)
        # unfiltered_vocab_size = src_vocab.size()
        src_vocab.filter_tokens_by_cnt(min_cnt=self.min_count)
        # filtered_num = unfiltered_vocab_size - src_vocab.size()
        # self.logger.info('After filter {} tokens, the final vocab size is {}'.format(filtered_num,
        #                                                                              src_vocab.size()))
        # self.logger.info('Assigning embeddings...')
        src_vocab.randomly_init_embeddings(self.embedding_size)
        tgt_vocab_dict = {}
        for tgt_field in self.tgt_info:
            tgt_vocab = Vocab(initial_tokens=False, lower=False)
            for sample in self.train_set:
                tgt = sample[tgt_field]
                tgt_vocab.add(tgt)
            tgt_vocab_dict[tgt_field] = tgt_vocab
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab_dict
        self.convert_to_ids(self.src_vocab, self.tgt_vocab)

    def _one_mini_batch(self, data, indices, tgt_field, set_name):
        """Assemble and pad one mini-batch for the given sample indices."""
        raw_data = [data[i] for i in indices]
        batch = []
        for sample in raw_data:
            batch_data = {}
            batch_data['sentence_word_ids'] = sample['sentence_word_ids']
            if set_name in ['train', 'dev']:
                batch_data['tgt'] = sample[tgt_field]
            batch.append(batch_data)
        batch, pad_sentence_size = self._dynamic_padding(batch, 0)
        return batch, pad_sentence_size

    def _dynamic_padding(self, batch_data, pad_id):
        """Pad (or truncate) every sentence in the batch to max_sentence_size."""
        # pad_sentence_size = min(self.max_sentence_size,
        #                         max([len(t['sentence_word_ids']) for t in batch_data]))
        pad_sentence_size = self.max_sentence_size
        for sub_batch_data in batch_data:
            ids = sub_batch_data['sentence_word_ids']
            padded = ids + [pad_id] * (pad_sentence_size - len(ids))
            sub_batch_data['sentence_word_ids'] = padded[:pad_sentence_size]
        return batch_data, pad_sentence_size

    def word_iter(self, set_name=None):
        """Iterate over every segmented token in the chosen data set (all sets if None)."""
        if set_name is None:
            data_set = self.train_set + self.dev_set + self.test_set
        elif set_name == 'train':
            data_set = self.train_set
        elif set_name == 'dev':
            data_set = self.dev_set
        elif set_name == 'test':
            data_set = self.test_set
        else:
            raise NotImplementedError('No data set named as {}'.format(set_name))
        if data_set is not None:
            for sample in data_set:
                for token in sample['seg_content']:
                    yield token

    def convert_to_ids(self, src_vocab, tgt_vocab):
        """Map segmented words to vocabulary ids; also map target labels for train/dev."""
        for idx, data_set in enumerate([self.train_set, self.dev_set, self.test_set]):
            if not len(data_set):
                continue
            for sample in data_set:
                sample['sentence_word_ids'] = src_vocab.convert_to_ids(sample['seg_content'])
                if idx <= 1:  # only train (0) and dev (1) carry target labels
                    for tgt_field in self.tgt_info:
                        if tgt_field in sample:
                            sample[tgt_field + '_id'] = tgt_vocab[tgt_field].get_id(sample[tgt_field])

    def gen_mini_batches(self, set_name, batch_size, tgt_field, shuffle=True):
        """Yield (batch, pad_sentence_size) tuples for the requested data set."""
        if set_name == 'train':
            data = self.train_set
        elif set_name == 'dev':
            data = self.dev_set
        elif set_name == 'test':
            data = self.test_set
        else:
            raise NotImplementedError('No data set named as {}'.format(set_name))
        data_size = len(data)
        indices = np.arange(data_size)
        if shuffle:
            np.random.shuffle(indices)
        for batch_start in np.arange(0, data_size, batch_size):
            batch_indices = indices[batch_start: batch_start + batch_size]
            yield self._one_mini_batch(data, batch_indices, tgt_field, set_name)
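

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The CSV paths and
# the 'label' target column are placeholders; the opts attributes mirror the
# ones read in Dataset.__init__ above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    opts = SimpleNamespace(
        max_sentence_size=100,                    # pad/truncate length used by _dynamic_padding
        min_count=2,                              # vocabulary frequency cutoff
        embedding_size=128,                       # size of the randomly initialised embeddings
        train_data_path='data/train.csv',         # hypothetical path
        validate_data_path='data/validate.csv',   # hypothetical path
        test_data_path='data/test.csv',           # hypothetical path
    )
    dataset = Dataset(opts)
    dataset.build_vocab()  # builds src/tgt vocabs and converts samples to ids

    # 'label' stands in for one of the target columns listed in dataset.tgt_info.
    for batch, pad_size in dataset.gen_mini_batches('train', batch_size=32,
                                                    tgt_field='label'):
        print(len(batch), pad_size)
        break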