-
Notifications
You must be signed in to change notification settings - Fork 0
/
json2csv.py
242 lines (218 loc) · 7.55 KB
/
json2csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import argparse
import csv
import json
import os
import random
# Lower-case option letters 'a'..'m'.
alphas = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm']
# Spider db_ids used for the DEV-M split.
DEVM_DBIDS = ['pets_1', 'car_1', 'flight_2', 'employee_hire_evaluation', 'cre_Doc_Template_Mgt', 'museum_visit',
              'poker_player', 'orchestra', 'network_1', 'dog_kennels']
# Spider db_ids selected by --test. 'course_teach' has already been annotated, thus is excluded here.
TESTM_DBIDS = ['wta_1', 'real_estate_properties', 'singer', 'tvshow', 'battle_death', 'student_transcripts_tracking',
               'concert_singer', 'world_1', 'voter_1']
# DuSQL (Chinese) db_ids selected by --dusql_test.
DUSQL_TESTM_DBIDS = ['运动员比赛记录', '洗衣机', '中国高校', '企业融资', '综艺节目', '友好城市', '欧洲杯球队', '打车软件',
                     '枪击事件', '城市财政收入']
def parse_colors_english(item):
    """Convert @table@ / $column$ markup in an English question to HTML font tags.

    The question is split on whitespace and re-joined with single spaces. A
    token that starts with '@' opens a red span and one that starts with '$'
    opens a blue span; a later '@' / '$' in the same token, or in a following
    token, becomes '</font>'. A marked name may therefore span several words.

    Args:
        item: question text containing the markup.

    Returns:
        The question with the markers replaced by ``<font color=...>`` and
        ``</font>`` tags.

    Raises:
        AssertionError: if a token has more markers than can be paired: more
            than two of its leading marker, or (for a token not starting with
            a marker) more than one '@' or more than one '$'.
    """
    openers = {'@': '<font color=red>', '$': '<font color=blue>'}
    converted = []
    for word in item.split():
        lead = word[0]  # split() never yields empty tokens
        if lead in openers:
            if word.count(lead) > 2:
                raise AssertionError('too many %r markers in token %r (question: %r)' % (lead, word, item))
            converted.append(openers[lead] + word[1:].replace(lead, '</font>'))
        else:
            if word.count('@') > 1 or word.count('$') > 1:
                raise AssertionError('too many closing markers in token %r (question: %r)' % (word, item))
            # A bare marker here closes a span opened in an earlier token.
            converted.append(word.replace('@', '</font>').replace('$', '</font>'))
    return ' '.join(converted)
def _is_ascii_alnum(ch):
    """Return True if *ch* encodes to ASCII letters/digits only (False for CJK etc.)."""
    return ch.encode('UTF-8').isalnum()


def parse_colors_chinese(item):
    """Convert @table@ / $column$ markup in a Chinese question to HTML font tags.

    The text is processed character by character (no whitespace
    tokenisation). '@' and '$' each toggle a span: the first occurrence emits
    '<font color=red>' / '<font color=blue>' and the next emits '</font>'. An
    '@' that has an ASCII letter or digit on both sides (e.g. inside an
    e-mail address) is kept as a literal '@' and does not toggle anything.

    Args:
        item: question text containing the markup.

    Returns:
        The question with the markers replaced by HTML font tags.
    """
    pieces = []
    in_table = False   # inside an @...@ span
    in_column = False  # inside a $...$ span
    last = len(item) - 1
    for idx, ch in enumerate(item):
        if ch == '@':
            if 0 < idx < last and _is_ascii_alnum(item[idx - 1]) and _is_ascii_alnum(item[idx + 1]):
                pieces.append(ch)
                continue
            pieces.append('</font>' if in_table else '<font color=red>')
            in_table = not in_table
        elif ch == '$':
            pieces.append('</font>' if in_column else '<font color=blue>')
            in_column = not in_column
        else:
            pieces.append(ch)
    return ''.join(pieces)
def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` would turn every non-empty string -- including "False" and
    "0" -- into True, so a flag could never be switched off explicitly.

    Args:
        value: the raw string given on the command line.

    Returns:
        True for 1/true/t/yes/y; False for 0/false/f/no/n or an empty string
        (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: for any other spelling.
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


# Boolean flags accept an explicit value ("-d True" / "-d False") or may be
# given bare ("-d"), which means True.
_BOOL_FLAG = dict(type=_str2bool, nargs='?', const=True, default=False)

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--dark', **_BOOL_FLAG)
parser.add_argument('-p', '--pilot', help="whether this is for a pilot experiment; if so, only entries whose "
                                          "db_id is '企业融资' are kept", **_BOOL_FLAG)
parser.add_argument('-r', '--run', **_BOOL_FLAG)
parser.add_argument('-i', '--input', type=str, default='')
parser.add_argument('-o', '--output', type=str, default='')
parser.add_argument('-l', '--lang', type=str, default='eng', choices=['eng', 'chi'])
parser.add_argument('-t', '--test', **_BOOL_FLAG)
parser.add_argument('--dusql_test', **_BOOL_FLAG)
parser.add_argument('--amendment', **_BOOL_FLAG)
args = parser.parse_args()
# Choose the input JSON file(s) and matching output CSV file(s); PATHS[i] is
# converted into OUT_PATHS[i].
if args.dark:
    # Each database has its own results directory: saved_results/<db_id>/.
    # Both lists are derived from the directory, not from each other, so the
    # CSV lands next to its JSON (not under a path ending in ".json").
    _dark_dirs = [os.path.join('saved_results', db_id) for db_id in ['concert_singer', 'pets_1', 'car_1']]
    PATHS = [os.path.join(d, 'qrys_saved.json') for d in _dark_dirs]
    OUT_PATHS = [os.path.join(d, 'input.csv') for d in _dark_dirs]
elif args.run:
    PATHS = [os.path.join('saved_results', 'qrys_saved.json')]
    OUT_PATHS = [os.path.join('saved_results', 'input.csv')]
elif args.input:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not args.output:
        raise AssertionError('--output is required when --input is given')
    PATHS = [args.input]
    OUT_PATHS = [args.output]
else:
    raise AssertionError('one of --dark, --run or --input/--output must be given')
# Convert each selected JSON file of annotated queries into a CSV with one row
# per entry, and persist any newly assigned 'global_idx' values back into the
# source JSON.

_DIGIT_TOKENS = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
# Per-language markup converter and the entry key holding the question sequence.
_COLOR_PARSERS = {'eng': parse_colors_english, 'chi': parse_colors_chinese}
_QUESTION_KEYS = {'eng': 'question_sequence', 'chi': 'question_sequence_chinese'}
_CSV_HEADER = 'topic,sequence,ref_sequence,ref_gold,answer,ref_answer,qryidx\n'


def _select_entries(entries):
    """Return the subset of *entries* chosen by the subset flags (first set flag wins)."""
    if args.pilot:
        # Pilot run: only the '企业融资' database is kept.
        return [e for e in entries if e['db_id'] == '企业融资']
    if args.test:
        return [e for e in entries if e['db_id'] in TESTM_DBIDS]
    if args.dusql_test:
        return [e for e in entries if e['db_id'] in DUSQL_TESTM_DBIDS]
    if args.amendment:
        return [e for e in entries if e['db_id'] == 'course_teach']
    return entries


def _topic(db_id):
    """Readable topic: *db_id* split on '_' with single-digit parts dropped."""
    return ' '.join(part for part in db_id.split('_') if part not in _DIGIT_TOKENS)


def _html_table(rows):
    """Render a result sample as an HTML table string.

    ``rows[0]`` is a comma-separated header line (with '*' displayed as
    'Everything'); every later element is a comma-separated data line that
    becomes its own table row. Cells are split naively on ',' and stripped of
    surrounding whitespace and double quotes. *rows* is not modified, so the
    display substitution never leaks into the JSON written back to disk.
    """
    def cells(line):
        return ''.join(' <th> ' + cell.strip().strip('"') + ' </th> ' for cell in line.split(','))

    html = '<table border=2> <tr> ' + cells(rows[0].replace('*', 'Everything')) + '</tr> '
    for line in rows[1:]:
        html += '<tr> ' + cells(line) + ' </tr> '
    return html + '</table>'


if args.lang not in _COLOR_PARSERS:
    raise AssertionError('unsupported --lang %r; expected one of %s' % (args.lang, sorted(_COLOR_PARSERS)))
colorize = _COLOR_PARSERS[args.lang]
question_key = _QUESTION_KEYS[args.lang]

for path, out_path in zip(PATHS, OUT_PATHS):
    with open(path, 'r', encoding='utf-8') as fp:
        all_entries = json.load(fp)

    # Number entries by their position in the FULL file, before any subset
    # filtering, so the qryidx column matches the JSON that is written back.
    update_json_flag = False
    for ent_idx, entry in enumerate(all_entries):
        if 'global_idx' not in entry:
            entry['global_idx'] = ent_idx
            update_json_flag = True

    entries = _select_entries(all_entries)
    print(len(entries))

    with open(out_path, 'w', encoding='utf-8', newline='') as fp:
        # The header line is written unquoted; data rows are fully quoted by
        # the csv module, which also escapes any embedded double quotes.
        fp.write(_CSV_HEADER)
        writer = csv.writer(fp, quoting=csv.QUOTE_ALL, lineterminator='\n')
        seen_sqls = set()
        for entry in entries:
            if 'query' in entry:
                sql = entry['query']
            elif 'sql_query' in entry:
                sql = entry['sql_query']
            else:
                raise AssertionError("entry %r has neither 'query' nor 'sql_query'" % entry['global_idx'])
            # English entries are deduplicated by SQL; Chinese ones are not.
            if args.lang == 'eng' and sql in seen_sqls:
                continue
            seen_sqls.add(sql)

            ref_sequence = ' <br> '.join(colorize(q) for q in entry['ref_question_sequence'])
            sequence = ' <br> '.join(colorize(q) for q in entry[question_key])
            writer.writerow([
                _topic(entry['db_id']),
                # Double quotes are dropped from the questions but turned into
                # single quotes in the reference questions and reference gold.
                sequence.replace('"', ''),
                ref_sequence.replace('"', "'"),
                entry['ref_gold'].replace('"', "'"),
                _html_table(entry['answer_sample']),
                _html_table(entry['ref_response']),
                str(entry['global_idx']),
            ])

    if update_json_flag:
        # Rewrite this file in full (every entry, not just the filtered or
        # newly numbered ones) so no data is lost.
        with open(path, 'w', encoding='utf-8') as fp:
            json.dump(all_entries, fp, indent=4, ensure_ascii=False)