# -*- coding: utf-8 -*-
"""
redshift_psql.py is a collection of functions used to load data from s3 into
redshift
Example Instantiation:
rspg = RedshiftPostgres(pipeline_stream_logger, "pg_auth_file",
                        run_local=True)
"""
import socket
import time
from datetime import datetime
import boto
import staticconf
from dateutil.parser import parse as parsedate
from staticconf import read_string
from staticconf import YamlConfiguration
import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from psycopg2.extensions import QueryCanceledError
from sherlock.common.aws import get_aws_creds
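# setting search_path makes unqualified table names resolve against the
# given schema; the %(schema_path)s placeholder is filled by psycopg2 on a
# per-session basis in run_sql_ex below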
ADD_SCHEMA_PATH = "SET search_path TO '$user', public, %(schema_path)s"
DEFAULT_NAMESPACE = "public"
# Copied from http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
from select import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE
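# wait_select_inter is an interruptible wait callback: psycopg2 invokes it
# whenever the connection needs to wait on the server, and a Ctrl-C during
# that wait cancels the in-flight statement server-side instead of leaving
# it running. It is registered via psycopg2.extensions.set_wait_callback()
# in RedshiftPostgres.__init__ below.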
def wait_select_inter(conn):
while True:
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select([conn.fileno()], [], [])
elif state == POLL_WRITE:
select([], [conn.fileno()], [])
else:
raise conn.OperationalError(
"bad state from poll: %s" % state)
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
def get_namespaced_tablename(tablename, schemaname=None):
if schemaname is None:
rs_schema = get_redshift_schema()
else:
        # note we do lower for backward compatibility
rs_schema = schemaname.lower()
if rs_schema == DEFAULT_NAMESPACE:
return tablename
return rs_schema + "." + tablename
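# e.g. get_namespaced_tablename("events", "Analytics") returns
# "analytics.events", while get_namespaced_tablename("events", "public")
# returns just "events"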
def get_redshift_schema():
    # note we do lower for backward compatibility
return read_string('redshift_schema', DEFAULT_NAMESPACE).lower()
class RedshiftPostgres(object):
"""
This class simplifies running queries on redshift. The current purpose is
for creating tables, and copying data into them from S3. However, it can
be used for general SQL commands.
    Constructor Args:
    logstrm -- a PipelineStreamLogger used to record started, completed
        and failed sql commands
    psql_auth_file -- the file from which we get a username and password
        for a redshift account
    run_local -- whether to run locally or not
"""
    # keepalive settings below: first probe after 1 second idle, then one
    # probe every 60 seconds, up to 60 retries, i.e. roughly one hour
    # (1 + 60 * 60 = 3601 seconds) for a sql command to complete before
    # the connection is dropped
SECONDS_BEFORE_SENDING_PROBE = 1
SECONDS_BETWEEN_SENDING_PROBE = 60
RETRIES_BEFORE_QUIT = 60
def __init__(self, logstrm, psql_auth_file, run_local=False):
self.run_local = run_local
self.host = staticconf.read_string('redshift_host')
self.port = staticconf.read_int('redshift_port')
private_dict = YamlConfiguration(psql_auth_file)
self.user = private_dict['redshift_user']
self.password = private_dict['redshift_password']
self.log_stream = logstrm
self._aws_key = ''
self._aws_secret = ''
self._aws_token = ''
self._aws_token_expiry = datetime.utcnow()
self._whitelist = ['select', 'create', 'insert', 'update']
self._set_aws_auth()
psycopg2.extensions.set_wait_callback(wait_select_inter)
def _set_aws_auth(self):
"""
_set_aws_auth gets key, secret, token and expiration either from a
file or from a temporary instance and sets them
"""
cred_tuple = get_aws_creds(self.run_local)
self._aws_key = cred_tuple.access_key_id
self._aws_secret = cred_tuple.secret_access_key
self._aws_token = cred_tuple.token
self._aws_token_expiry = parsedate(cred_tuple.expiration)
def get_boto_config(self):
boto_dict = {}
for section in boto.config.sections():
boto_dict[section] = {}
for option in boto.config.options(section):
if option != 'aws_secret_access_key':
boto_dict[section][option] = boto.config.get(section, option)
else:
boto_dict[section][option] = "xxxxxxxxxxxxxxxx"
return boto_dict
def get_connection(self, database):
"""
        gets a connection to a psql database
        Args:
        database -- the name of the database to connect to
        Returns:
        a connection object
"""
# additional logging to help with connection issues
boto_config_dict = self.get_boto_config()
self.log_stream.write_msg('boto', extra_msg=boto_config_dict)
log_template = "getting connection with host {0} port {1} user {2} db {3}"
log_msg = log_template.format(self.host, self.port, self.user, database)
self.log_stream.write_msg('starting', extra_msg=log_msg)
conn = psycopg2.connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
database=database,
sslmode='require')
self.log_stream.write_msg('finished', extra_msg=log_msg)
fd = conn.fileno()
sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE,
self.SECONDS_BEFORE_SENDING_PROBE)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL,
self.SECONDS_BETWEEN_SENDING_PROBE)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT,
self.RETRIES_BEFORE_QUIT)
return conn
def cleanse_sql(self, command):
"""
cleanses a psql command of any auth information
Args:
command -- the psql command
Returns:
the cleansed command
"""
cmd_list = command.split()
first_word = cmd_list[0]
if first_word.lower() in self._whitelist:
return " ".join(cmd_list)
return cmd_list[0] + " cleansed "
def run_sql_ex(self, sql, database, log_msg, s3_needed=False, params=None,
output=False, time_est_secs=10, need_commit=True,
schema=DEFAULT_NAMESPACE):
"""
        run_sql_ex takes a command and executes it using the connection
        found in get_connection.
Args:
sql -- the postgres command to run
database -- the database on which the command is to be run
log_msg -- a shortened message for what command we're running
        s3_needed -- True if the sql command requires S3 input, otherwise
            False. For example, a simple query of a table would have
            s3_needed=False, while a COPY from S3 would have
            s3_needed=True.
params -- if there are any parameters for the command
        output -- True if the command returns rows (e.g., a SELECT command)
time_est_secs -- how long the user estimates the command to run
in seconds. This is used to decide whether to get new certs
or not
need_commit -- False if command does not need
to be committed (ex: vacuum)
        schema -- the schema in the database on which the command is run.
            Anything other than the default namespace must have the
            schema name added to the search_path. This is ephemeral, so
            it must be done on a per-session basis, and since we close
            the cursor and connection after each query we check
            every time.
Returns:
if there's a return value, it is the results of the query
"""
start_time = time.time()
self.log_stream.write_msg('starting', extra_msg=log_msg)
exception = None
if s3_needed:
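            # interpolate the AWS key/secret (and session token, when one
            # is present) into the %s placeholders of the command,
            # typically the CREDENTIALS clause of a COPY statement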
try:
if self._aws_token is None:
sql = sql.replace(';token=%s', '')
sql = sql % (self._aws_key, self._aws_secret)
else:
sql = sql % (self._aws_key, self._aws_secret, self._aws_token)
except TypeError as type_error:
self.log_stream.write_msg(
'error', error_msg=repr(type_error), extra_msg=sql
)
raise
try:
result = dict()
with self.get_connection(database) as conn:
if need_commit is False:
conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
cur = conn.cursor()
if schema != DEFAULT_NAMESPACE:
schema_params = {'schema_path': schema}
cur.execute(ADD_SCHEMA_PATH, schema_params)
if params:
cur.execute(sql, params)
else:
cur.execute(sql)
if output:
rows = cur.fetchall()
result['status'] = cur.statusmessage
cur.close()
self.log_stream.write_msg(
'finished', job_start_secs=start_time, extra_msg=log_msg
)
if output:
result['output'] = rows
return result
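        # a Ctrl-C during execution makes wait_select_inter cancel the
        # statement, which surfaces here as QueryCanceledError; re-raise
        # it as KeyboardInterrupt so callers see the original interrupt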
except QueryCanceledError as cmd_exception:
exception = cmd_exception
raise KeyboardInterrupt
except Exception as cmd_exception:
exception = cmd_exception
raise
        finally:
            # only log an error when one actually occurred; the success
            # path has already logged 'finished' above
            if exception is not None:
                self.log_stream.write_msg(
                    'error',
                    job_start_secs=start_time,
                    error_msg=repr(exception), extra_msg=log_msg
                )
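    # thin convenience wrapper around run_sql_ex: returns the fetched rows
    # when output is True, otherwise True on success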
def run_sql(self, sql, database, log_msg, s3_needed=False, params=None,
output=False, time_est_secs=10, need_commit=True,
schema=DEFAULT_NAMESPACE):
result = self.run_sql_ex(
sql, database, log_msg, s3_needed, params,
output, time_est_secs, need_commit, schema=schema)
if result is False:
return False
return result['output'] if output is True else True
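# Example usage (a sketch; assumes staticconf has been loaded with
# redshift_host/redshift_port and that `logger` is a PipelineStreamLogger):
#
#     staticconf.YamlConfiguration("config.yaml")
#     rspg = RedshiftPostgres(logger, "pg_auth_file", run_local=True)
#     rows = rspg.run_sql(
#         "SELECT COUNT(*) FROM events", "dev", "count events", output=True
#     )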