forked from spotify/luigi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
postgres.py
344 lines (291 loc) · 11.8 KB
/
postgres.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Implements a subclass of :py:class:`~luigi.target.Target` that writes data to Postgres.
Also provides a helper task to copy data into a Postgres table.
"""
import datetime
import logging
import re
import tempfile
from luigi import six
import luigi
from luigi.contrib import rdbms
logger = logging.getLogger('luigi-interface')
try:
import psycopg2
import psycopg2.errorcodes
import psycopg2.extensions
except ImportError:
logger.warning("Loading postgres module without psycopg2 installed. Will crash at runtime if postgres functionality is used.")
class MultiReplacer(object):
    """
    Object for one-pass replace of multiple words

    Substituted parts will not be matched against other replace patterns, as opposed to when using multipass replace.
    The order of the items in the replace_pairs input will dictate replacement precedence.

    Constructor arguments:
    replace_pairs -- list of 2-tuples which hold strings to be replaced and replace string

    An empty ``replace_pairs`` yields a replacer that returns its input unchanged.

    Usage:

    .. code-block:: python

        >>> replace_pairs = [("a", "b"), ("b", "c")]
        >>> MultiReplacer(replace_pairs)("abcd")
        'bccd'
        >>> replace_pairs = [("ab", "x"), ("a", "x")]
        >>> MultiReplacer(replace_pairs)("ab")
        'x'
        >>> replace_pairs.reverse()
        >>> MultiReplacer(replace_pairs)("ab")
        'xb'
        >>> MultiReplacer([])("ab")
        'ab'
    """
    # TODO: move to misc/util module

    def __init__(self, replace_pairs):
        """
        Initializes a MultiReplacer instance.

        :param replace_pairs: 2-tuples which hold strings to be replaced and replace string.
                              If the same search string occurs more than once, the last
                              replacement given for it is the one used.
        :type replace_pairs: iterable of 2-tuples
        """
        replace_list = list(replace_pairs)  # make a copy in case input is a one-shot iterable
        self._replace_dict = dict(replace_list)
        # Alternation order follows the input order, which is what gives earlier pairs precedence.
        pattern = '|'.join(re.escape(search) for search, _ in replace_list)
        self._search_re = re.compile(pattern)

    def _replacer(self, match_object):
        # this method is used as the replace function in the re.sub below
        return self._replace_dict[match_object.group()]

    def __call__(self, search_string):
        """
        Return ``search_string`` with every configured search string replaced in a single pass.
        """
        if not self._replace_dict:
            # With no pairs the compiled pattern is '' and matches at every position;
            # looking up '' in the empty dict would raise KeyError, so short-circuit.
            return search_string
        # using function replacing for a per-result replace
        return self._search_re.sub(self._replacer, search_string)
# Escape sequences recognized by postgres COPY, according to
# http://www.postgresql.org/docs/8.1/static/sql-copy.html
# Each pair maps a raw character to the backslash sequence written in its place.
default_escape = MultiReplacer(
    [
        ('\\', '\\\\'),  # backslash
        ('\t', '\\t'),   # tab
        ('\n', '\\n'),   # newline
        ('\r', '\\r'),   # carriage return
        ('\v', '\\v'),   # vertical tab
        ('\b', '\\b'),   # backspace
        ('\f', '\\f'),   # form feed
    ]
)
class PostgresTarget(luigi.Target):
    """
    Target for a resource in Postgres.

    Completion is tracked by a row in ``marker_table`` keyed on ``update_id``.

    This will rarely have to be directly instantiated by the user.
    """
    # Name of the marker table, read from the [postgres] config section.
    marker_table = luigi.configuration.get_config().get('postgres', 'marker-table', 'table_updates')

    # Use DB side timestamps or client side timestamps in the marker_table
    use_db_timestamps = True

    def __init__(
        self, host, database, user, password, table, update_id, port=None
    ):
        """
        Args:
            host (str): Postgres server address. Possibly a host:port string.
            database (str): Database name
            user (str): Database user
            password (str): Password for specified user
            table (str): Name of the table this target refers to; stored in the marker row
            update_id (str): An identifier for this data set
            port (int): Postgres server port. Ignored when `host` contains a colon.
        """
        if ':' in host:
            # NOTE: the port taken from the host string is kept as a string.
            self.host, self.port = host.split(':')
        else:
            self.host = host
            self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.table = table
        self.update_id = update_id

    def touch(self, connection=None):
        """
        Mark this update as complete.

        Important: If the marker table doesn't exist, the connection transaction will be aborted
        and the connection reset.
        Then the marker table will be created.

        If no connection is passed in, one is opened in autocommit mode and closed again
        before returning. A connection supplied by the caller is neither committed nor closed.
        """
        self.create_marker_table()

        owns_connection = connection is None
        if owns_connection:
            # TODO: test this
            connection = self.connect()
            connection.autocommit = True  # if connection created here, we commit it here

        try:
            if self.use_db_timestamps:
                connection.cursor().execute(
                    """INSERT INTO {marker_table} (update_id, target_table)
                       VALUES (%s, %s)
                    """.format(marker_table=self.marker_table),
                    (self.update_id, self.table))
            else:
                connection.cursor().execute(
                    """INSERT INTO {marker_table} (update_id, target_table, inserted)
                       VALUES (%s, %s, %s);
                    """.format(marker_table=self.marker_table),
                    (self.update_id, self.table,
                     datetime.datetime.now()))

            # make sure update is properly marked
            assert self.exists(connection)
        finally:
            # Only release what was opened here; the caller keeps ownership of its own connection.
            if owns_connection:
                connection.close()

    def exists(self, connection=None):
        """
        Return True if the marker table contains a row for this target's ``update_id``.

        A missing marker table is treated as "does not exist". If no connection is passed in,
        one is opened in autocommit mode and closed again before returning.
        """
        owns_connection = connection is None
        if owns_connection:
            connection = self.connect()
            connection.autocommit = True

        try:
            cursor = connection.cursor()
            try:
                cursor.execute("""SELECT 1 FROM {marker_table}
                    WHERE update_id = %s
                    LIMIT 1""".format(marker_table=self.marker_table),
                               (self.update_id,)
                               )
                row = cursor.fetchone()
            except psycopg2.ProgrammingError as e:
                if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE:
                    row = None
                else:
                    raise
            return row is not None
        finally:
            if owns_connection:
                connection.close()

    def connect(self):
        """
        Get a psycopg2 connection object to the database where the table is.
        """
        connection = psycopg2.connect(
            host=self.host,
            port=self.port,
            database=self.database,
            user=self.user,
            password=self.password)
        connection.set_client_encoding('utf-8')
        return connection

    def create_marker_table(self):
        """
        Create marker table if it doesn't exist.

        Using a separate connection since the transaction might have to be reset.
        The connection is closed even when table creation fails with an unexpected error.
        """
        connection = self.connect()
        try:
            connection.autocommit = True
            cursor = connection.cursor()
            if self.use_db_timestamps:
                sql = """ CREATE TABLE {marker_table} (
                          update_id TEXT PRIMARY KEY,
                          target_table TEXT,
                          inserted TIMESTAMP DEFAULT NOW())
                      """.format(marker_table=self.marker_table)
            else:
                sql = """ CREATE TABLE {marker_table} (
                          update_id TEXT PRIMARY KEY,
                          target_table TEXT,
                          inserted TIMESTAMP);
                      """.format(marker_table=self.marker_table)
            try:
                cursor.execute(sql)
            except psycopg2.ProgrammingError as e:
                # An already existing marker table is the expected case; anything else is a real error.
                if e.pgcode != psycopg2.errorcodes.DUPLICATE_TABLE:
                    raise
        finally:
            connection.close()

    def open(self, mode):
        """Not supported: a PostgresTarget has no file-like representation."""
        raise NotImplementedError("Cannot open() PostgresTarget")
class CopyToTable(rdbms.CopyToTable):
    """
    Template task for inserting a data set into Postgres

    Usage:
    Subclass and override the required `host`, `database`, `user`,
    `password`, `table` and `columns` attributes.

    To customize how to access data from an input task, override the `rows` method
    with a generator that yields each row as a tuple with fields ordered according to `columns`.
    """

    def rows(self):
        """
        Return/yield tuples or lists corresponding to each row to be inserted.

        The default reads the task input line by line and splits each line on tabs.
        """
        with self.input().open('r') as fobj:
            for line in fobj:
                yield line.strip('\n').split('\t')

    def map_column(self, value):
        """
        Applied to each column of every row returned by `rows`.

        Default behaviour is to escape special characters and identify any self.null_values.
        """
        if value in self.null_values:
            # same NULL marker that is passed as `null=` to copy_from in copy()
            return r'\N'
        else:
            return default_escape(six.text_type(value))

    # everything below will rarely have to be overridden

    def output(self):
        """
        Returns a PostgresTarget representing the inserted dataset.

        Normally you don't override this.
        """
        return PostgresTarget(
            host=self.host,
            database=self.database,
            user=self.user,
            password=self.password,
            table=self.table,
            update_id=self.update_id()
        )

    def copy(self, cursor, file):
        """
        Load the contents of `file` into `self.table` using the cursor's `copy_from`.

        `self.columns` may hold either plain column names or (name, type) pairs; only the
        first entry is inspected to decide which of the two forms is in use.

        :raises Exception: if the first entry of `self.columns` is neither a string nor a 2-item sequence.
        """
        if isinstance(self.columns[0], six.string_types):
            column_names = self.columns
        elif len(self.columns[0]) == 2:
            column_names = [c[0] for c in self.columns]
        else:
            raise Exception('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],))
        cursor.copy_from(file, self.table, null=r'\N', sep=self.column_separator, columns=column_names)

    def run(self):
        """
        Inserts data generated by rows() into target table.

        If the target table doesn't exist, self.create_table will be called to attempt to create the table.

        The temporary file and the database connection are released even if the import fails;
        in that case nothing is committed.

        Normally you don't want to override this.
        """
        if not (self.table and self.columns):
            raise Exception("table and columns need to be specified")

        connection = self.output().connect()
        try:
            # transform all data generated by rows() using map_column and write data
            # to a temporary file for import using postgres COPY
            tmp_dir = luigi.configuration.get_config().get('postgres', 'local-tmp-dir', None)
            tmp_file = tempfile.TemporaryFile(dir=tmp_dir)
            try:
                n = 0
                for row in self.rows():
                    n += 1
                    if n % 100000 == 0:
                        logger.info("Wrote %d lines", n)
                    rowstr = self.column_separator.join(self.map_column(val) for val in row)
                    rowstr += "\n"
                    tmp_file.write(rowstr.encode('utf-8'))

                logger.info("Done writing, importing at %s", datetime.datetime.now())

                # attempt to copy the data into postgres
                # if it fails because the target table doesn't exist
                # try to create it by running self.create_table
                for attempt in range(2):
                    # Rewind before every attempt so a retry always starts from the first
                    # row, however much of the file the failed attempt consumed.
                    tmp_file.seek(0)
                    try:
                        cursor = connection.cursor()
                        self.init_copy(connection)
                        self.copy(cursor, tmp_file)
                    except psycopg2.ProgrammingError as e:
                        if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0:
                            # if first attempt fails with "relation not found", try creating table
                            logger.info("Creating table %s", self.table)
                            connection.reset()
                            self.create_table(connection)
                        else:
                            raise
                    else:
                        break

                # mark as complete in same transaction
                self.output().touch(connection)

                # commit only after both the copy and the marker insert succeeded
                connection.commit()
            finally:
                tmp_file.close()
        finally:
            connection.close()