################################################################################
# sqlize_csv.py, originally written by Luitien Pan
# Last updated 18 October 2017 by Melinda Morang
################################################################################
'''Copyright 2017 Esri
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.'''
################################################################################
# Imports the CSV-formatted GTFS information into a SQLite database file.
# Handles data conversion, table creation, and indexing transparently.
#
# If you specify multiple GTFS datasets, this merges them. In order to avoid
# collisions between identifiers that are supposed to be dataset-unique, I
# prepend an agency label to each *_id field value. This label comes from the
# last component of the corresponding GTFS_DIR* path. So for example, if I
# keep my CTA data in
# /home/luitien/gtfs/cta/*.txt
# then stop_id 1518 gets stored and manipulated as "cta:1518".
#
# In line with this, *don't* put your CSV files like so:
# [...]/cta/data/*.txt
# [...]/metra/data/*.txt
# and so on, because then they'll all be labelled as ``data''.
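#
# For illustration only (hypothetical paths): merging two datasets such as
#   /home/luitien/gtfs/cta/*.txt
#   /home/luitien/gtfs/metra/*.txt
# stores a stop_id of 1518 from each as "cta:1518" and "metra:1518",
# so the two agencies' identifiers can coexist in one database.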
import csv
from cStringIO import StringIO
import datetime
import itertools
import os
import re
import sqlite3
import sys
import hms  # local helper module used to convert HH:MM:SS time strings to seconds
class CustomError(Exception):
pass
Errors_To_Return = []
csv_fnames = ["stops.txt", "calendar.txt", "calendar_dates.txt",
"trips.txt", "stop_times.txt", "routes.txt", "frequencies.txt"]
sql_types = {
str : "TEXT" ,
float : "REAL" ,
int : "INTEGER" ,
}
# Each subdictionary specifies the columns for the named sql table.
# The format is:
# tbl_name : { col_name : (datatype, is_required) }
# where is_required is True for columns required by the GTFS spec,
# and otherwise holds the column's SQL default value.
sql_schema = {
"stops" : {
"stop_id" : (str, True) ,
"stop_code" : (str, "NULL") ,
"stop_name" : (str, True) ,
"stop_desc" : (str, "NULL") ,
"stop_lat" : (float, True) ,
"stop_lon" : (float, True) ,
"zone_id" : (str, "NULL") ,
"stop_url" : (str, "NULL") ,
"location_type" : (int, "NULL") ,
"parent_station" : (str, "NULL") ,
"wheelchair_boarding" : (int, "0"),
} ,
"calendar" : {
"service_id" : (str, True) ,
"monday" : (int, True) ,
"tuesday" : (int, True) ,
"wednesday" : (int, True) ,
"thursday" : (int, True) ,
"friday" : (int, True) ,
"saturday" : (int, True) ,
"sunday" : (int, True) ,
"start_date" : (str, True) ,
"end_date" : (str, True) ,
} ,
"calendar_dates" : {
"service_id" : (str, True) ,
"date" : (str, True) ,
"exception_type" : (int, True) ,
} ,
"stop_times" : {
"trip_id" : (str, True) ,
"arrival_time" : (float, True) ,
"departure_time" : (float, True) ,
"stop_id" : (str, True) ,
"stop_sequence" : (int, True) ,
} ,
"trips" : {
"route_id" : (str, True) ,
"service_id" : (str, True) ,
"trip_id" : (str, True) ,
"wheelchair_accessible" : (int, "0"),
"bikes_allowed": (int, "0"),
} ,
"routes" : {
"route_id" : (str, True),
"agency_id" : (str, "NULL"),
"route_short_name": (str, "NULL"),
"route_long_name": (str, "NULL"),
"route_desc": (str, "NULL"),
"route_type": (int, True),
"route_url": (str, "NULL"),
"route_color": (str, "NULL"),
"route_text_color": (str, "NULL"),
},
"frequencies" : {
"trip_id" : (str, True),
"start_time" : (float, True),
"end_time" : (float, True),
"headway_secs" : (int, True)
},
"linefeatures" : { # Non-GTFS table for relating network line features to stops and eids
"SourceOID" : (int, True),
"from_stop" : (str, True),
"to_stop" : (str, True),
"route_type" : (int, True),
"eid" : (int, True)
},
"schedules" : { # Non-GTFS table for each instance of a transit trip crossing a line
"SourceOIDKey" : (str, True),
"SourceOID" : (int, True),
"trip_id" : (str, True),
"start_time" : (float, True),
"end_time" : (float, True)
}
}
db = None
def connect(dbname):
global db
    if db is None:
db = sqlite3.connect(dbname)
# Turn off journaling and synchronous mode to make things run faster.
# We don't care about data corruption and backups because if sqlite crashes,
# the user will have to re-run this tool anyway.
c = db.cursor()
c.execute("PRAGMA journal_mode = OFF;")
c.execute("PRAGMA synchronous = OFF;")
db.commit()
c.close()
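# Usage sketch (hypothetical file name); everything below assumes connect()
# has already set the module-level db:
#   connect("GTFSschedules.sql")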
def check_time_str(s):
'''Check that the string s is a valid clock time of the form HH:MM:SS.'''
    if not re.match(r'^-?\d?\d:\d\d:\d\d$', s):
return False
return True
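# A few illustrative inputs, per the regex above. Note that GTFS clock times
# may run past 24:00:00 for trips continuing after midnight:
#   check_time_str("08:30:00")  -> True
#   check_time_str("25:15:00")  -> True
#   check_time_str("8:30")      -> False (seconds are required)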
def make_add_agency_labels(service, columns):
    '''Make a function that adds "${service}:" labels to the *_id column
    values of a row of data.'''
service = re.sub("[^A-Za-z0-9]", "", service)
# Figure out which columns need labelling:
s = set()
for idx,field in enumerate(columns):
if (field.endswith("_id") and field != "direction_id") or field == "parent_station":
s.add(idx)
# ... and here's the function:
def add_labels(row):
ret = list(row)
for idx in s:
if not row[idx]:
ret[idx] = row[idx]
else:
ret[idx] = "%s:%s" % (service, row[idx].strip())
return tuple(ret)
return add_labels
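# Usage sketch with made-up data: every *_id column except direction_id gets
# the label, as does parent_station:
#   labeller = make_add_agency_labels("cta", ["route_id", "service_id", "trip_id"])
#   labeller(("20", "104", "4230849"))  -> ("cta:20", "cta:104", "cta:4230849")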
def make_remove_extra_fields(tablename, columns):
'''Make a function that removes extraneous columns from the CSV rows.'''
orig_num_fields = len(columns)
# Identify the extraneous columns:
cols = [ ]
tbl = sql_schema[tablename]
for idx,field in enumerate(columns):
if field not in tbl:
cols.append(idx)
cols.reverse()
# ... and here's the function:
def drop_fields(in_row):
out_row = list(in_row)
# Check that row was the correct length in the first place.
if len(out_row) != orig_num_fields:
msg = u"GTFS table %s contains at least one row with the wrong number of fields. Fields: %s; Row: %s" % (tablename, columns, str(in_row))
Errors_To_Return.append(msg)
raise CustomError
# Remove the row entries for the extraneous columns
for idx in cols:
out_row.pop(idx)
return tuple(out_row)
return drop_fields
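# Usage sketch with made-up data: columns missing from sql_schema["stops"],
# such as stop_timezone here, are dropped from every row:
#   dropper = make_remove_extra_fields("stops", ["stop_id", "stop_name", "stop_timezone"])
#   dropper(("1518", "Clark & Lake", "America/Chicago"))  -> ("1518", "Clark & Lake")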
def check_for_required_fields(tablename, columns, dataset):
'''Check that the GTFS file has the required fields'''
for col in sql_schema[tablename]:
        if sql_schema[tablename][col][1] is True:
            if col not in columns:
msg = u"GTFS file " + tablename + u".txt in dataset " + dataset + u" is missing required field '" + col + u"'. Failed to SQLize GTFS data."
Errors_To_Return.append(msg)
raise CustomError
def smarter_convert_times(rows, col_names, fname, GTFSdir, time_columns=('arrival_time', 'departure_time')):
'''Parses time fields according to the column name. Accepts HMS or numeric
times, converting to seconds-since-midnight.'''
time_column_idxs = [col_names.index(x) for x in time_columns]
def convert_time_columns(row):
out_row = row[:] # copy
for idx in time_column_idxs:
field = row[idx].strip()
if check_time_str(field):
out_row[idx] = hms.str2sec(field)
elif field == '':
msg = u"GTFS dataset " + GTFSdir + u" contains empty \
values for arrival_time or departure_time in stop_times.txt. Although the \
GTFS spec allows empty values for these fields, this toolbox \
requires exact time values for all stops. You will not be able to use this \
dataset for your analysis."
Errors_To_Return.append(msg)
raise CustomError
else:
try:
                    out_row[idx] = float(field)
except ValueError:
                    msg = u'Column "' + col_names[idx] + u'" in file ' + os.path.join(GTFSdir, fname) + u' has an invalid value: ' + field + u'.'
Errors_To_Return.append(msg)
raise CustomError
return out_row
    return itertools.imap(convert_time_columns, rows)
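# Conversion sketch, assuming hms.str2sec maps "HH:MM:SS" to seconds since
# midnight (which is how it is used here): a stop_times.txt value of
# "08:30:00" becomes 30600, and an already-numeric "30600" becomes 30600.0;
# empty or malformed time values abort the SQLizing with an error message.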
def check_date_fields(rows, col_names, tablename, fname):
    '''Ensure date fields are in the correct YYYYMMDD format before adding them to the SQL table'''
    if tablename == "calendar":
        date_cols = ["start_date", "end_date"]
    elif tablename == "calendar_dates":
        date_cols = ["date"]
    # Look up the date column indices once rather than for every row.
    date_column_idxs = [col_names.index(x) for x in date_cols]
    def check_date_cols(row):
        for idx in date_column_idxs:
date = row[idx]
try:
datetime.datetime.strptime(date, '%Y%m%d')
except ValueError:
msg = u'Column "' + col_names[idx] + u'" in file ' + fname + u' has an invalid value: ' + date + u'. \
Date fields must be in YYYYMMDD format. Please check the date field formatting in calendar.txt and calendar_dates.txt.'
Errors_To_Return.append(msg)
raise CustomError
return row
return itertools.imap(check_date_cols, rows)
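# For example, "20171018" parses cleanly with the '%Y%m%d' format, while
# values like "2017-10-18" or "10/18/2017" are reported as invalid.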
def check_latlon_fields(rows, col_names, fname):
'''Ensure lat/lon fields are valid'''
def check_latlon_cols(row):
stop_id = row[col_names.index("stop_id")]
stop_lat = row[col_names.index("stop_lat")]
stop_lon = row[col_names.index("stop_lon")]
try:
stop_lat_float = float(stop_lat)
except ValueError:
msg = u'stop_id "%s" in %s contains an invalid non-numerical value \
for the stop_lat field: "%s". Please double-check all lat/lon values in your \
stops.txt file.' % (stop_id, fname, stop_lat)
Errors_To_Return.append(msg)
raise CustomError
try:
stop_lon_float = float(stop_lon)
except ValueError:
msg = u'stop_id "%s" in %s contains an invalid non-numerical value \
for the stop_lon field: "%s". Please double-check all lat/lon values in your \
stops.txt file.' % (stop_id, fname, stop_lon)
Errors_To_Return.append(msg)
raise CustomError
if not (-90.0 <= stop_lat_float <= 90.0):
            msg = u'stop_id "%s" in %s contains an invalid value outside the \
range [-90, 90] for the stop_lat field: "%s". stop_lat values must be valid WGS 84 \
coordinates. Please double-check all lat/lon values in your stops.txt file.\
' % (stop_id, fname, stop_lat)
Errors_To_Return.append(msg)
raise CustomError
if not (-180.0 <= stop_lon_float <= 180.0):
            msg = u'stop_id "%s" in %s contains an invalid value outside the \
range [-180, 180] for the stop_lon field: "%s". stop_lon values must be valid WGS 84 \
coordinates. Please double-check all lat/lon values in your stops.txt file.\
' % (stop_id, fname, stop_lon)
Errors_To_Return.append(msg)
raise CustomError
return row
return itertools.imap(check_latlon_cols, rows)
def column_specs(tablename):
    '''Turns the sql_schema data structure above into the appropriate
column specs for a CREATE TABLE statement. Used in create_table().'''
tblspec = sql_schema[tablename]
lines = [ "id INTEGER PRIMARY KEY" ]
    for col_name, (col_type, required) in tblspec.iteritems():
data_type = sql_types[col_type]
if required is True:
defaults_str = ""
else:
defaults_str = " DEFAULT %s" % required
        lines.append("%s\t%s%s" % (col_name, data_type, defaults_str))
    return " ,\n".join(lines)
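# Illustrative output for column_specs("calendar_dates") (name and type are
# tab-separated; column order follows Python 2's arbitrary dict iteration):
#   id INTEGER PRIMARY KEY ,
#   service_id  TEXT ,
#   date        TEXT ,
#   exception_type      INTEGER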
def create_table(tablename):
cur = db.cursor()
cur.execute("DROP TABLE IF EXISTS %s;" % tablename)
    create_stmt = "CREATE TABLE %s (%s);" % (tablename, column_specs(tablename))
cur.execute(create_stmt)
db.commit()
cur.close()
def handle_file(fname, service_label):
'''Creates and populates a table for the given CSV file.'''
if fname.endswith(".txt"):
tablename = fname[:-4]
else:
tablename = fname
tablename = os.path.basename(tablename)
#-- Read in everything from the CSV table
f = open(fname)
reader = csv.reader(f)
# Put everything in utf-8 to handle BOMs and weird characters.
# Eliminate blank rows (extra newlines) while we're at it.
reader = ([x.decode('utf-8-sig').strip() for x in r] for r in reader if len(r) > 0)
# First row is column names:
columns = [name.strip() for name in reader.next()]
#-- Do some data validity checking and reformatting
# Check that all required fields are present
check_for_required_fields(tablename, columns, service_label)
    # stop_times and frequencies contain HH:MM:SS time strings. Convert them to seconds since midnight.
if tablename == "stop_times":
rows = smarter_convert_times(reader, columns, fname, service_label)
elif tablename == "frequencies":
rows = smarter_convert_times(reader, columns, fname, service_label, ('start_time', 'end_time'))
# Make sure date fields are in YYYYMMDD format
elif tablename in ["calendar", "calendar_dates"]:
rows = check_date_fields(reader, columns, tablename, fname)
# Make sure lat/lon values are valid
elif tablename == "stops":
rows = check_latlon_fields(reader, columns, fname)
# Otherwise just leave them as they are
else:
rows = reader
    # Prepare functions for adding agency labels and filtering out extraneous columns
labeller = make_add_agency_labels(service_label, columns)
columns_filter = make_remove_extra_fields(tablename, columns)
# Remove unnecessary columns
columns = columns_filter(columns)
# Add agency labels for merged datasets
rows = itertools.imap(labeller, rows)
# Remove data from columns that aren't in the spec
rows = itertools.imap(columns_filter, rows)
# Add to the SQL table
cur = db.cursor()
values_placeholders = ["?"] * len(columns)
cur.executemany("INSERT INTO %s (%s) VALUES (%s);" %
(tablename,
",".join(columns),
",".join(values_placeholders))
, rows)
db.commit()
cur.close()
f.close()
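# Usage sketch (hypothetical path; assumes connect() and create_table() have
# already been run by the calling tool):
#   handle_file("/home/luitien/gtfs/cta/stop_times.txt", "cta")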
def handle_agency(gtfs_dir):
'''Parses the relevant parts of an agency's GTFS CSV files into
the sqlite database. Returns a list of error messages from some basic
GTFS dataset validation'''
try:
csvs_withPaths = []
# Create a dataset label
label = os.path.basename(os.path.normpath(gtfs_dir))
# Verify that the required files are present
missing_files = []
has_a_calendar = 0
for fname in csv_fnames:
fname2 = os.path.join(gtfs_dir, fname)
if os.path.exists(fname2):
csvs_withPaths.append(fname2)
# We must have at least one of calendar or calendar_dates
if fname in ["calendar_dates.txt", "calendar.txt"]:
has_a_calendar = 1
else:
# These files aren't required
if fname not in ["calendar.txt", "calendar_dates.txt", "frequencies.txt"]:
missing_files.append(fname)
if not has_a_calendar:
missing_files.append("calendar.txt or calendar_dates.txt")
if missing_files:
Errors_To_Return.append(u"GTFS dataset %s is missing files required for \
this tool: %s" % (label, str(missing_files)))
return Errors_To_Return
# Sqlize each GTFS file
for fname2 in csvs_withPaths:
handle_file(fname2, label)
# Return any errors we collected, or an empty list if there were none.
return Errors_To_Return
except UnicodeDecodeError:
Errors_To_Return.append(u"Unicode decoding of GTFS dataset %s failed. Please \
ensure that your GTFS files have the proper utf-8 encoding required by the GTFS \
specification." % label)
return Errors_To_Return
except CustomError:
return Errors_To_Return
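# Usage sketch (hypothetical directory): the dataset label becomes "cta", and
# the returned list is empty when the dataset passes all the checks above:
#   errors = handle_agency("/home/luitien/gtfs/cta")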
def create_indices():
cur = db.cursor()
cur.execute("CREATE INDEX trips_index_serviceIDs ON trips (service_id);")
cur.execute("CREATE INDEX trips_index_tripIDs ON trips (trip_id);")
cur.execute("CREATE INDEX stops_index_locationtype ON stops (location_type, parent_station);")
cur.execute("CREATE INDEX stopTimes_index_tripIdsSeq ON stop_times (trip_id, stop_sequence);")
db.commit()
cur.close()
def check_nonoverlapping_dateranges():
'''Check for non-overlapping date ranges in calendar.txt to prevent
double-counting in analyses that use generic weekdays.'''
# Function by Melinda Morang, Esri
# Only do this if we have a calendar table from calendar.txt.
c = db.cursor()
GetTblNamesStmt = "SELECT name FROM sqlite_master WHERE type='table' AND name='calendar';"
c.execute(GetTblNamesStmt)
tblnames = c.fetchall()
    overlapwarning = ""
    if tblnames:
        # Check for non-overlapping date ranges to prevent double-counting.
        serviceidlist = []
        startdatedict = {}
        enddatedict = {}
        nonoverlappingsids = []
# Find all the service_ids.
serviceidfetch = '''
SELECT service_id, start_date, end_date FROM calendar
;'''
c.execute(serviceidfetch)
ids = c.fetchall()
        for row in ids:
            # Add to the list of service_ids
            serviceidlist.append(row[0])
            startdatedict[row[0]] = row[1]
            enddatedict[row[0]] = row[2]
# Check for non-overlapping date ranges.
for sid in serviceidlist:
for eid in serviceidlist:
if startdatedict[sid] > enddatedict[eid]:
nonoverlappingsids.append([sid, eid])
if len(nonoverlappingsids) >= 10:
break
if len(nonoverlappingsids) >= 10:
break
if nonoverlappingsids:
overlapwarning = u"Warning! Your calendar.txt file(s) contain(s) \
non-overlapping date ranges. As a result, your analysis might double \
count the number of trips available if you are analyzing a generic weekday \
instead of a specific date. This is especially likely if the \
non-overlapping pairs are in the same GTFS dataset. Please check the date \
ranges in your calendar.txt file(s). See the User's Guide for further \
assistance. Date ranges do not overlap in the following pairs of service_ids: "
if len(nonoverlappingsids) == 10:
overlapwarning += "(Showing the first 10 non-overlaps) "
overlapwarning += str(nonoverlappingsids)
# Close up the SQL file.
c.close()
return overlapwarning
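# A minimal end-to-end driver, as a sketch only: in practice the calling
# script tool supplies the database path and GTFS directories. The file and
# directory names below are hypothetical.
if __name__ == '__main__':
    connect("GTFSschedules.sql")
    for table in ("stops", "calendar", "calendar_dates", "trips",
                  "stop_times", "routes", "frequencies"):
        create_table(table)
    errors = handle_agency("/home/luitien/gtfs/cta")
    if errors:
        for message in errors:
            print message
    else:
        create_indices()
        warning = check_nonoverlapping_dateranges()
        if warning:
            print warning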