import re
import sys
import cPickle

from theano_util import *
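# Helpers assumed to come from theano_util (not shown in this file, so their
# exact signatures are an assumption based on how they are called below):
#   pad_statement(tokens, null_word, max_words)        -> pads/truncates one token list
#   pad_memories(stmts, null_word, max_stmts, max_words) -> pads the list of statements
#   transform_ques_weak(question, word_to_id, word_id) -> maps tokens to word ids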
def only_words(line):
    ps = re.sub(r'[^a-zA-Z0-9]', r' ', line)  # Replace non-alphanumeric characters with spaces
    ws = re.sub(r'(\W)', r' \1 ', ps)         # Put spaces around punctuation
    ns = re.sub(r'(\d+)', r' <number> ', ws)  # Replace numbers with a <number> token
    hs = re.sub(r'-', r' ', ns)               # Replace hyphens with spaces
    rs = re.sub(r' +', r' ', hs)              # Collapse runs of spaces into one
    rs = rs.lower().strip()
    return rs
def clean_sentence(line):
    ps = re.sub(r'[^a-zA-Z0-9\.\?\!]', ' ', line)  # Keep only alphanumerics and .?!; replace the rest with spaces
    ws = re.sub(r'(\W)', r' \1 ', ps)              # Put spaces around punctuation
    ns = re.sub(r'(\d+)', r' <number> ', ws)       # Replace numbers with a <number> token
    hs = re.sub(r'-', r' ', ns)                    # Replace hyphens with spaces
    rs = re.sub(r' +', r' ', hs)                   # Collapse runs of spaces into one
    rs = rs.lower().strip()
    return rs
def get_sentences(line):
    ps = re.sub(r'[^a-zA-Z0-9\.\?\!]', ' ', line)  # Keep only alphanumerics and .?!; replace the rest with spaces
    s = re.sub(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', '\t', ps)  # Mark sentence boundaries with tabs; lookbehinds skip abbreviations like "Mr." or "U.S."
    ws = re.sub(r'(\W)', r' \1 ', s)               # Put spaces around punctuation
    ns = re.sub(r'(\d+)', r' <number> ', ws)       # Replace numbers with a <number> token
    hs = re.sub(r'-', r' ', ns)                    # Replace hyphens with spaces
    rs = re.sub(r' +', r' ', hs)                   # Collapse runs of spaces into one
    rs = rs.lower().strip()
    return rs.split('\t')
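# Illustrative example (not from the original source): under the rules above,
# get_sentences("Alice went home. She slept!") yields roughly
# ['alice went home .', 'she slept !'] -- elements may carry stray spaces,
# which callers below remove with .strip().split().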
def get_answer_index(a):
    answer_to_index = {
        'A': 0,
        'B': 1,
        'C': 2,
        'D': 3,
    }
    return answer_to_index[a]
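# Expected file layout, inferred from the asserts and indexing below (an
# assumption, not an official spec): each line of the questions .tsv carries
# 23 tab-separated fields -- story id, story properties, story text, then four
# blocks of (question, option A, option B, option C, option D); each line of
# the answers .ans file carries the four correct letters, tab-separated.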
def parse_mc_test_dataset(questions_file, answers_file, word_id=0, word_to_id=None, update_word_ids=True, max_stmts=20, max_words=20, pad=True):
    # Avoid a mutable default argument: start a fresh vocabulary when none is passed in.
    if word_to_id is None:
        word_to_id = {}

    dataset = []
    questions = []
    null_word = '<NULL>'
    if null_word not in word_to_id:
        if update_word_ids:
            word_to_id[null_word] = word_id
            word_id += 1
        else:
            print "Null word '%s' not found in the vocabulary" % null_word
            sys.exit(1)
    null_word_id = word_to_id[null_word]

    print "Parsing questions %s %s" % (questions_file, answers_file)
    q_file = open(questions_file, 'r')
    a_file = open(answers_file, 'r')
    questions_data = q_file.readlines()
    answers_data = a_file.readlines()
    q_file.close()
    a_file.close()
    assert(len(questions_data) == len(answers_data))

    more_than_1_word_answers = 0
    answer_word_unknown = 0

    for i in xrange(len(questions_data)):
        question_line = questions_data[i]
        answer_line = answers_data[i]

        question_pieces = question_line.strip().split('\t')
        assert(len(question_pieces) == 23)
        answer_pieces = answer_line.strip().split('\t')
        assert(len(answer_pieces) == 4)

        text = question_pieces[2]
        text = text.replace('\\newline', ' ')
        sentences = get_sentences(text)

        statements = []
        for s in sentences:
            tokens = s.strip().split()
            if update_word_ids:
                for token in tokens:
                    if token not in word_to_id:
                        word_to_id[token] = word_id
                        word_id += 1
            else:
                tokens = filter(lambda x: x in word_to_id, tokens)
            if pad:
                tokens = pad_statement(tokens, null_word, max_words)
            statements.append(tokens)
            dataset.append(tokens)

        if pad:
            statements = pad_memories(statements, null_word, max_stmts, max_words)

        # Each story carries 4 questions, laid out in blocks of 5 fields.
        for j in range(4):
            q_index = (j * 5) + 3
            q_words = question_pieces[q_index]
            q_words = clean_sentence(q_words).split()
            options = [
                only_words(question_pieces[q_index + 1]),
                only_words(question_pieces[q_index + 2]),
                only_words(question_pieces[q_index + 3]),
                only_words(question_pieces[q_index + 4]),
            ]
            correct = get_answer_index(answer_pieces[j])
            answer = options[correct]

            if update_word_ids:
                # Note: each option string is added to the vocabulary as a single token.
                for token in (q_words + options):
                    if token not in word_to_id:
                        word_to_id[token] = word_id
                        word_id += 1
            else:
                q_words = filter(lambda x: x in word_to_id, q_words)
            if pad:
                q_words = pad_statement(q_words, null_word, max_words)

            # Skip questions whose answer is longer than one word
            if len(answer.split(' ')) > 1:
                more_than_1_word_answers += 1
                continue
            elif len(filter(lambda x: x not in word_to_id, options)) > 0:
                answer_word_unknown += 1
                continue

            option_word_ids = map(lambda x: word_to_id[x], options)
            article_no = len(questions)
            questions.append([article_no, -1, statements, q_words, answer, option_word_ids])

    print "There are %d questions" % len(questions)
    print "There are %d statements" % len(dataset)
    print "There are %d words" % len(word_to_id)
    print "Ignored %d questions whose answers had more than one word" % more_than_1_word_answers
    print "Ignored %d questions which had an unknown answer word" % answer_word_unknown

    print "Final processing..."
    questions_seq = map(lambda x: transform_ques_weak(x, word_to_id, word_id), questions)
    return dataset, questions_seq, word_to_id, word_id, null_word_id
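# Each entry of questions_seq derives from a tuple of the form
# [article_no, -1, statements, question_words, answer, option_word_ids];
# the -1 appears to be a placeholder field kept for compatibility with the
# rest of the pipeline (an assumption -- transform_ques_weak lives in
# theano_util and is not shown here).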
if __name__ == "__main__":
    ADD_PADDING = True
    train_file = 'mc500.train.tsv'
    train_answers = train_file.replace('tsv', 'ans')
    test_file = train_file.replace('train', 'test')
    test_answers = test_file.replace('tsv', 'ans')

    data_dir = sys.argv[1]
    train_dataset, train_questions, word_to_id, num_words, null_word_id = parse_mc_test_dataset(data_dir + '/' + train_file, data_dir + '/' + train_answers, pad=ADD_PADDING)
    test_dataset, test_questions, word_to_id, num_words, null_word_id = parse_mc_test_dataset(data_dir + '/' + test_file, data_dir + '/' + test_answers, word_id=num_words, word_to_id=word_to_id, update_word_ids=False, pad=ADD_PADDING)

    # Fold the dev split into the test set
    test2_file = train_file.replace('train', 'dev')
    test2_answers = test2_file.replace('tsv', 'ans')
    test2_dataset, test2_questions, word_to_id, num_words, null_word_id = parse_mc_test_dataset(data_dir + '/' + test2_file, data_dir + '/' + test2_answers, word_id=num_words, word_to_id=word_to_id, update_word_ids=False, pad=ADD_PADDING)

    test_dataset += test2_dataset
    test_questions += test2_questions

    print "Pickling train..."
    train_pickle = train_file.replace('tsv', 'pickle')
    f = open(data_dir + '/' + train_pickle, 'wb')
    cPickle.dump((train_dataset, train_questions, word_to_id, num_words, null_word_id), f, protocol=cPickle.HIGHEST_PROTOCOL)
    f.close()

    print "Pickling test..."
    test_pickle = test_file.replace('tsv', 'pickle')
    f = open(data_dir + '/' + test_pickle, 'wb')
    cPickle.dump((test_dataset, test_questions, word_to_id, num_words, null_word_id), f, protocol=cPickle.HIGHEST_PROTOCOL)
    f.close()
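# Usage sketch (assumes the MCTest MC500 files live in one directory):
#   python mctest_dataset_parser.py /path/to/MCTest
# reads mc500.{train,test,dev}.tsv and the matching .ans files from that
# directory and writes mc500.train.pickle and mc500.test.pickle next to them.
# Loading a pickle back (illustrative):
#   f = open('/path/to/MCTest/mc500.train.pickle', 'rb')
#   dataset, questions, word_to_id, num_words, null_word_id = cPickle.load(f)
#   f.close()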