-
Notifications
You must be signed in to change notification settings - Fork 69
/
mysql.py
467 lines (369 loc) · 17.5 KB
/
mysql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
"""drivers/sqlite.py
MySQL-backed raster driver. Metadata is stored in a MySQL database, raster data is assumed
to be present on disk.
"""
from typing import (Tuple, Dict, Iterator, Sequence, Union,
Mapping, Any, Optional, cast, TypeVar, NamedTuple)
from collections import OrderedDict
import contextlib
from contextlib import AbstractContextManager
import re
import json
import urllib.parse as urlparse
from urllib.parse import ParseResult
import numpy as np
import pymysql
from pymysql.connections import Connection
from pymysql.cursors import DictCursor
from terracotta import get_settings, __version__
from terracotta.drivers.raster_base import RasterDriver
from terracotta.drivers.base import requires_connection
from terracotta import exceptions
from terracotta.profile import trace
T = TypeVar('T')
_ERROR_ON_CONNECT = (
'Could not connect to database. Make sure that the given path points '
'to a valid Terracotta database, and that you ran driver.create().'
)
@contextlib.contextmanager
def convert_exceptions(msg: str) -> Iterator:
"""Convert internal mysql exceptions to our InvalidDatabaseError"""
from pymysql import OperationalError, InternalError, ProgrammingError
try:
yield
except (OperationalError, InternalError, ProgrammingError) as exc:
raise exceptions.InvalidDatabaseError(msg) from exc
class MySQLCredentials(NamedTuple):
host: str
port: int
db: str
user: Optional[str] = None
password: Optional[str] = None
class MySQLDriver(RasterDriver):
    """A MySQL-backed raster driver.

    Assumes raster data to be present in separate GDAL-readable files on disk or remotely.
    Stores metadata and paths to raster files in MySQL.

    Requires a running MySQL server.

    The MySQL database consists of 4 different tables:

    - ``terracotta``: Metadata about the database itself.
    - ``key_names``: Contains two columns holding all available keys and their description.
    - ``datasets``: Maps key values to physical raster path.
    - ``metadata``: Contains actual metadata as separate columns. Indexed via key values.

    This driver caches raster data and key names, but not metadata.
    """
    # MySQL limits the total primary key length; divide by 4 since utf8mb4
    # characters can occupy up to 4 bytes each
    _MAX_PRIMARY_KEY_LENGTH = 767 // 4  # Max key length for MySQL is at least 767B

    # (column name, MySQL column type) pairs making up the metadata table
    _METADATA_COLUMNS: Tuple[Tuple[str, ...], ...] = (
        ('bounds_north', 'REAL'),
        ('bounds_east', 'REAL'),
        ('bounds_south', 'REAL'),
        ('bounds_west', 'REAL'),
        ('convex_hull', 'LONGTEXT'),
        ('valid_percentage', 'REAL'),
        ('min', 'REAL'),
        ('max', 'REAL'),
        ('mean', 'REAL'),
        ('stdev', 'REAL'),
        ('percentiles', 'BLOB'),
        ('metadata', 'LONGTEXT')
    )

    # character set used for all connections and tables
    _CHARSET: str = 'utf8mb4'

    def __init__(self, mysql_path: str) -> None:
        """Initialize the MySQLDriver.

        This should not be called directly, use :func:`~terracotta.get_driver` instead.

        Arguments:

            mysql_path: URL to running MySQL server, in the form
                ``mysql://username:password@hostname/database``

        """
        settings = get_settings()
        self.DB_CONNECTION_TIMEOUT: int = settings.DB_CONNECTION_TIMEOUT

        con_params = urlparse.urlparse(mysql_path)

        if not con_params.hostname:
            # bare 'host/db' paths parse with an empty hostname; retry with
            # an explicit scheme prefix so urlparse fills in the components
            con_params = urlparse.urlparse(f'mysql://{mysql_path}')

        assert con_params.hostname is not None

        if con_params.scheme != 'mysql':
            raise ValueError(f'unsupported URL scheme "{con_params.scheme}"')

        self._db_args = MySQLCredentials(
            host=con_params.hostname,
            user=con_params.username,
            password=con_params.password,
            port=con_params.port or 3306,
            db=self._parse_db_name(con_params)
        )

        # populated lazily on first _connect(); declared here for type checkers
        self._connection: Connection
        self._cursor: DictCursor
        self._connected = False

        self._version_checked: bool = False
        # cache of key names/descriptions; None means "not fetched yet"
        self._db_keys: Optional[OrderedDict] = None

        qualified_path = self._build_qualified_path(self._db_args)
        super().__init__(qualified_path)

    @staticmethod
    def _build_qualified_path(db_args: MySQLCredentials) -> str:
        """Convert path to mysql://({USER}(:[REDACTED])@){HOST}:{PORT}/{DB}"""
        qualified_path = ['mysql://']
        if db_args.user:
            qualified_path.append(f'{db_args.user}')
            if db_args.password:
                # never leak the actual password into logs or reprs
                qualified_path.append(':[REDACTED]')
            qualified_path.append('@')
        qualified_path.append(f'{db_args.host}:{db_args.port}/{db_args.db}')
        return ''.join(qualified_path)

    @staticmethod
    def _parse_db_name(con_params: ParseResult) -> str:
        """Extract the database name from a parsed MySQL URL.

        Raises:

            ValueError: If no database or a nested path is given.

        """
        if not con_params.path:
            raise ValueError('database must be specified in MySQL path')

        path = con_params.path.strip('/')
        if len(path.split('/')) != 1:
            raise ValueError('invalid database path')

        return path

    @requires_connection
    @convert_exceptions(_ERROR_ON_CONNECT)
    def _get_db_version(self) -> str:
        """Terracotta version used to create the database"""
        cursor = self._cursor
        cursor.execute('SELECT version from terracotta')
        db_row = cast(Dict[str, str], cursor.fetchone())
        return db_row['version']

    db_version = cast(str, property(_get_db_version))

    def _connection_callback(self) -> None:
        """Verify database version compatibility once per driver instance."""
        if not self._version_checked:
            # check for version compatibility
            def versiontuple(version_string: str) -> Sequence[str]:
                return version_string.split('.')

            db_version = self.db_version
            current_version = __version__

            # only major and minor components must match
            if versiontuple(db_version)[:2] != versiontuple(current_version)[:2]:
                raise exceptions.InvalidDatabaseError(
                    f'Version conflict: database was created in v{db_version}, '
                    f'but this is v{current_version}'
                )
            self._version_checked = True

    def _get_key_names(self) -> Tuple[str, ...]:
        """Names of all keys defined by the database"""
        return tuple(self.get_keys().keys())

    key_names = cast(Tuple[str], property(_get_key_names))

    def connect(self) -> AbstractContextManager:
        """Open a checked connection to the database as a context manager."""
        return self._connect(check=True)

    @contextlib.contextmanager
    def _connect(self, check: bool = True) -> Iterator:
        """Connection context manager; commits on success, rolls back on error.

        Arguments:

            check: Whether to run the version-compatibility callback.
                Disabled during :meth:`create`, before the version table exists.

        """
        close = False
        try:
            if not self._connected:
                with convert_exceptions(_ERROR_ON_CONNECT):
                    self._connection = pymysql.connect(
                        host=self._db_args.host, user=self._db_args.user, db=self._db_args.db,
                        password=self._db_args.password, port=self._db_args.port,
                        read_timeout=self.DB_CONNECTION_TIMEOUT,
                        write_timeout=self.DB_CONNECTION_TIMEOUT,
                        binary_prefix=True, charset=self._CHARSET
                    )
                self._cursor = cast(DictCursor, self._connection.cursor(DictCursor))
                # only the outermost _connect in a nested stack closes the connection
                self._connected = close = True

                if check:
                    self._connection_callback()

            try:
                yield
            except Exception:
                self._connection.rollback()
                raise

        finally:
            if close:
                self._cursor.close()
                self._connection.commit()
                self._connection.close()
                self._connected = False

    @convert_exceptions('Could not create database')
    def create(self, keys: Sequence[str],
               key_descriptions: Optional[Mapping[str, str]] = None) -> None:
        """Create and initialize database with empty tables.

        This must be called before opening the first connection. The MySQL database must not
        exist already.

        Arguments:

            keys: Key names to use throughout the Terracotta database.
            key_descriptions: Optional (but recommended) full-text description for some keys,
                in the form of ``{key_name: description}``.

        Raises:

            exceptions.InvalidKeyError: If no keys are given, a key name is invalid
                or reserved, or ``key_descriptions`` contains unknown keys.

        """
        if key_descriptions is None:
            key_descriptions = {}
        else:
            key_descriptions = dict(key_descriptions)

        if not keys:
            # would otherwise produce invalid SQL and a ZeroDivisionError below
            raise exceptions.InvalidKeyError('at least one key is required')

        if not all(k in keys for k in key_descriptions.keys()):
            raise exceptions.InvalidKeyError('key description dict contains unknown keys')

        if not all(re.match(r'^\w+$', key) for key in keys):
            raise exceptions.InvalidKeyError('key names must be alphanumeric')

        if any(key in self._RESERVED_KEYS for key in keys):
            raise exceptions.InvalidKeyError(f'key names cannot be one of {self._RESERVED_KEYS!s}')

        for key in keys:
            if key not in key_descriptions:
                key_descriptions[key] = ''

        # total primary key length has an upper limit in MySQL
        key_size = self._MAX_PRIMARY_KEY_LENGTH // len(keys)
        key_type = f'VARCHAR({key_size})'

        with pymysql.connect(host=self._db_args.host, user=self._db_args.user,
                             password=self._db_args.password, port=self._db_args.port,
                             read_timeout=self.DB_CONNECTION_TIMEOUT,
                             write_timeout=self.DB_CONNECTION_TIMEOUT,
                             binary_prefix=True, charset=self._CHARSET) as con:
            # NOTE(review): ``con.execute`` assumes pymysql < 1.0, where the connection
            # context manager yields a cursor — confirm against the pinned pymysql version
            con.execute(f'CREATE DATABASE {self._db_args.db}')

        with self._connect(check=False):
            cursor = self._cursor
            cursor.execute(f'CREATE TABLE terracotta (version VARCHAR(255)) '
                           f'CHARACTER SET {self._CHARSET}')
            cursor.execute('INSERT INTO terracotta VALUES (%s)', [str(__version__)])

            cursor.execute(f'CREATE TABLE key_names (key_name {key_type}, '
                           f'description VARCHAR(8000)) CHARACTER SET {self._CHARSET}')
            key_rows = [(key, key_descriptions[key]) for key in keys]
            cursor.executemany('INSERT INTO key_names VALUES (%s, %s)', key_rows)

            key_string = ', '.join([f'{key} {key_type}' for key in keys])
            cursor.execute(f'CREATE TABLE datasets ({key_string}, filepath VARCHAR(8000), '
                           f'PRIMARY KEY({", ".join(keys)})) CHARACTER SET {self._CHARSET}')

            column_string = ', '.join(f'{col} {col_type}' for col, col_type
                                      in self._METADATA_COLUMNS)
            cursor.execute(f'CREATE TABLE metadata ({key_string}, {column_string}, '
                           f'PRIMARY KEY ({", ".join(keys)})) CHARACTER SET {self._CHARSET}')

        # invalidate key cache
        self._db_keys = None

    def get_keys(self) -> OrderedDict:
        """Return cached mapping of key names to descriptions, fetching on first use."""
        if self._db_keys is None:
            self._db_keys = self._get_keys()
        return self._db_keys

    @requires_connection
    @convert_exceptions('Could not retrieve keys from database')
    def _get_keys(self) -> OrderedDict:
        """Fetch key names and descriptions from the key_names table."""
        out: OrderedDict = OrderedDict()

        cursor = self._cursor
        cursor.execute('SELECT * FROM key_names')
        key_rows = cursor.fetchall() or ()

        for row in key_rows:
            out[row['key_name']] = row['description']

        return out

    @trace('get_datasets')
    @requires_connection
    @convert_exceptions('Could not retrieve datasets')
    def get_datasets(self, where: Optional[Mapping[str, str]] = None,
                     page: int = 0, limit: Optional[int] = None) -> Dict[Tuple[str, ...], str]:
        """Return ``{key tuple: filepath}`` for all datasets matching ``where``.

        Arguments:

            where: Optional ``{key_name: value}`` filter; keys must be known.
            page: Zero-based page index, used together with ``limit``.
            limit: Maximum number of rows to return; ``None`` returns all.

        """
        cursor = self._cursor

        if limit is not None:
            # explicitly cast to int to prevent SQL injection
            page_fragment = f'LIMIT {int(limit)} OFFSET {int(page) * int(limit)}'
        else:
            page_fragment = ''

        # sort by keys to ensure deterministic results
        order_fragment = f'ORDER BY {", ".join(self.key_names)}'

        if where is None:
            cursor.execute(f'SELECT * FROM datasets {order_fragment} {page_fragment}')
        else:
            if not all(key in self.key_names for key in where.keys()):
                raise exceptions.InvalidKeyError('Encountered unrecognized keys in '
                                                 'where clause')
            where_fragment = ' AND '.join([f'{key}=%s' for key in where.keys()])
            cursor.execute(
                f'SELECT * FROM datasets WHERE {where_fragment} {order_fragment} {page_fragment}',
                list(where.values())
            )

        def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]:
            return tuple(row[key] for key in self.key_names)

        datasets = {}
        for row in cursor:
            row = cast(Dict[str, Any], row)
            datasets[keytuple(row)] = row['filepath']

        return datasets

    @staticmethod
    def _encode_data(decoded: Mapping[str, Any]) -> Dict[str, Any]:
        """Transform from internal format to database representation"""
        encoded = {
            'bounds_north': decoded['bounds'][0],
            'bounds_east': decoded['bounds'][1],
            'bounds_south': decoded['bounds'][2],
            'bounds_west': decoded['bounds'][3],
            # JSON for structured values, raw float32 bytes for percentiles
            'convex_hull': json.dumps(decoded['convex_hull']),
            'valid_percentage': decoded['valid_percentage'],
            'min': decoded['range'][0],
            'max': decoded['range'][1],
            'mean': decoded['mean'],
            'stdev': decoded['stdev'],
            'percentiles': np.array(decoded['percentiles'], dtype='float32').tobytes(),
            'metadata': json.dumps(decoded['metadata'])
        }
        return encoded

    @staticmethod
    def _decode_data(encoded: Mapping[str, Any]) -> Dict[str, Any]:
        """Transform from database format to internal representation"""
        decoded = {
            'bounds': tuple([encoded[f'bounds_{d}'] for d in ('north', 'east', 'south', 'west')]),
            'convex_hull': json.loads(encoded['convex_hull']),
            'valid_percentage': encoded['valid_percentage'],
            'range': (encoded['min'], encoded['max']),
            'mean': encoded['mean'],
            'stdev': encoded['stdev'],
            'percentiles': np.frombuffer(encoded['percentiles'], dtype='float32').tolist(),
            'metadata': json.loads(encoded['metadata'])
        }
        return decoded

    @trace('get_metadata')
    @requires_connection
    @convert_exceptions('Could not retrieve metadata')
    def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]:
        """Return decoded metadata for the dataset identified by ``keys``.

        Computes and stores metadata lazily if the dataset exists but has
        no metadata row yet.

        Raises:

            exceptions.InvalidKeyError: If the wrong number of keys is given.
            exceptions.DatasetNotFoundError: If no dataset matches ``keys``.

        """
        keys = tuple(self._key_dict_to_sequence(keys))
        if len(keys) != len(self.key_names):
            raise exceptions.InvalidKeyError('Got wrong number of keys')

        cursor = self._cursor

        where_string = ' AND '.join([f'{key}=%s' for key in self.key_names])
        cursor.execute(f'SELECT * FROM metadata WHERE {where_string}', keys)
        row = cursor.fetchone()

        if not row:  # support lazy loading
            filepath = self.get_datasets(dict(zip(self.key_names, keys)))
            if not filepath:
                raise exceptions.DatasetNotFoundError(f'No dataset found for given keys {keys}')
            assert len(filepath) == 1

            # compute metadata and try again
            self.insert(keys, filepath[keys], skip_metadata=False)
            cursor.execute(f'SELECT * FROM metadata WHERE {where_string}', keys)
            row = cursor.fetchone()

        assert row

        data_columns, _ = zip(*self._METADATA_COLUMNS)
        encoded_data = {col: row[col] for col in self.key_names + data_columns}
        return self._decode_data(encoded_data)

    @trace('insert')
    @requires_connection
    @convert_exceptions('Could not write to database')
    def insert(self,
               keys: Union[Sequence[str], Mapping[str, str]],
               filepath: str, *,
               metadata: Optional[Mapping[str, Any]] = None,
               skip_metadata: bool = False,
               override_path: Optional[str] = None) -> None:
        """Register a dataset, replacing any existing entry with the same keys.

        Arguments:

            keys: Key values identifying the dataset.
            filepath: Path to the raster file; used to compute metadata.
            metadata: Pre-computed metadata; computed from ``filepath`` if None
                (unless ``skip_metadata`` is set).
            skip_metadata: Skip metadata computation entirely.
            override_path: Path to store in the database instead of ``filepath``.

        """
        cursor = self._cursor

        if len(keys) != len(self.key_names):
            raise exceptions.InvalidKeyError(
                f'Got wrong number of keys (available keys: {self.key_names})'
            )

        if override_path is None:
            override_path = filepath

        keys = self._key_dict_to_sequence(keys)
        template_string = ', '.join(['%s'] * (len(keys) + 1))
        cursor.execute(f'REPLACE INTO datasets VALUES ({template_string})',
                       [*keys, override_path])

        if metadata is None and not skip_metadata:
            metadata = self.compute_metadata(filepath)

        if metadata is not None:
            encoded_data = self._encode_data(metadata)
            row_keys, row_values = zip(*encoded_data.items())
            template_string = ', '.join(['%s'] * (len(keys) + len(row_values)))
            cursor.execute(f'REPLACE INTO metadata ({", ".join(self.key_names)}, '
                           f'{", ".join(row_keys)}) VALUES ({template_string})',
                           [*keys, *row_values])

    @trace('delete')
    @requires_connection
    @convert_exceptions('Could not write to database')
    def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None:
        """Remove a dataset and its metadata from the database.

        Raises:

            exceptions.InvalidKeyError: If the wrong number of keys is given.
            exceptions.DatasetNotFoundError: If no dataset matches ``keys``.

        """
        cursor = self._cursor

        if len(keys) != len(self.key_names):
            raise exceptions.InvalidKeyError(
                f'Got wrong number of keys (available keys: {self.key_names})'
            )

        keys = self._key_dict_to_sequence(keys)
        key_dict = dict(zip(self.key_names, keys))

        if not self.get_datasets(key_dict):
            raise exceptions.DatasetNotFoundError(f'No dataset found with keys {keys}')

        where_string = ' AND '.join([f'{key}=%s' for key in self.key_names])
        cursor.execute(f'DELETE FROM datasets WHERE {where_string}', keys)
        cursor.execute(f'DELETE FROM metadata WHERE {where_string}', keys)