Skip to content

Commit

Permalink
Refactored the sql calling functions into the QueryRunner class.
Browse files Browse the repository at this point in the history
  • Loading branch information
Bogdan Kyryliuk committed Aug 17, 2016
1 parent e1e3382 commit 939c588
Show file tree
Hide file tree
Showing 9 changed files with 471 additions and 375 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def upgrade():
sa.Column('select_as_cta', sa.Boolean(), nullable=True),
sa.Column('select_as_cta_used', sa.Boolean(), nullable=True),
sa.Column('progress', sa.Integer(), nullable=True),
sa.Column('rows', sa.Integer(), nullable=True),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('start_time', sa.DateTime(), nullable=True),
sa.Column('end_time', sa.DateTime(), nullable=True),
Expand Down
2 changes: 2 additions & 0 deletions caravel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,6 +1761,8 @@ class Query(Model):

# 1..100
progress = Column(Integer)
# # of rows in the result set or rows modified.
rows = Column(Integer)
error_message = Column(Text)
start_time = Column(DateTime)
end_time = Column(DateTime)
154 changes: 154 additions & 0 deletions caravel/sql_lab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import celery
from caravel import models, app, utils, sql_lab_utils
from datetime import datetime

celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))


@celery_app.task
def get_sql_results(query_id):
    """Runs the query identified by *query_id* and returns a result dict.

    The payload always contains 'query_id' and 'status'; on success it
    also carries 'data' and 'columns', otherwise 'error'.
    """
    runner = QueryRunner(query_id)
    runner.run_sql()
    query = runner.query()
    # Payload returned to synchronous (non-celery) callers.
    response = {
        'query_id': query.id,
        'status': query.status,
    }
    if query.status == models.QueryStatus.FINISHED:
        response['data'] = runner.data()
        response['columns'] = runner.columns()
    else:
        response['error'] = query.error_message
    return response


class QueryRunner:
    """Executes a single models.Query against its target database.

    Owns its own DB session (separate from db.session to avoid
    concurrency issues), tracks progress for pollable cursors
    (e.g. Presto via PyHive), and stores the fetched result set.
    """

    def __init__(self, query_id):
        self._query_id = query_id
        # Creates a separate session, reusing the db.session leads to the
        # concurrency issues.
        self._session = sql_lab_utils.create_scoped_session()
        self._query = self._session.query(models.Query).filter_by(
            id=query_id).first()
        # Guard: without the query row there is no database_id to look up;
        # dereferencing self._query.database_id would raise AttributeError.
        self._db_to_query = None
        if self._query:
            self._db_to_query = self._session.query(
                models.Database).filter_by(
                id=self._query.database_id).first()
        # Query result.
        self._data = None
        self._columns = None

    def _sanity_check(self):
        """Returns True when it is safe to run the query.

        On failure the query row (when present) is marked FAILED with an
        error message and flushed.
        """
        if not self._query:
            # No query row exists to record the failure on; just bail out.
            return False
        if not self._db_to_query:
            self._query.error_message = (
                "Database with id {0} is missing.".format(
                    self._query.database_id)
            )
            self._query.status = models.QueryStatus.FAILED
            self._session.flush()
            return False
        return True

    def query(self):
        """Returns the models.Query row being executed (or None)."""
        return self._query

    def data(self):
        """Returns the fetched rows (list of dicts) or None."""
        return self._data

    def columns(self):
        """Returns the result column names or None."""
        return self._columns

    def run_sql(self):
        """Executes the query and returns its final status."""
        if not self._sanity_check():
            if not self._query:
                # Nothing to report on; the query row itself is missing.
                return models.QueryStatus.FAILED
            app.logger.error("FAILED QUERY#" + str(self._query.id))
            app.logger.error(str(self._query.executed_sql))
            app.logger.error(str(self._query.error_message))
            app.logger.error(str(self._query.status))
            return self._query.status

        # TODO(bkyryliuk): dump results somewhere for the webserver.
        engine = self._db_to_query.get_sqla_engine(schema=self._query.schema)
        self._query.executed_sql = self._query.sql.strip().strip(';')

        # Limit enforced only for retrieving the data, not for the CTA queries.
        self._query.select_as_cta_used = False
        self._query.limit_used = False
        if sql_lab_utils.is_query_select(self._query.sql):
            if self._query.select_as_cta:
                if not self._query.tmp_table_name:
                    self._query.tmp_table_name = 'tmp_{}_table_{}'.format(
                        self._query.user_id,
                        self._query.start_time.strftime('%Y_%m_%d_%H_%M_%S'))
                self._query.executed_sql = sql_lab_utils.create_table_as(
                    self._query.executed_sql, self._query.tmp_table_name)
                self._query.select_as_cta_used = True
            elif self._query.limit:
                self._query.executed_sql = sql_lab_utils.add_limit_to_the_sql(
                    self._query.executed_sql, self._query.limit, engine)
                self._query.limit_used = True

        # TODO(bkyryliuk): ensure that tmp table was created.
        # Do not set tmp table name if table wasn't created.
        if not self._query.select_as_cta_used:
            self._query.tmp_table_name = None
        self._get_sql_results(engine)

        self._query.end_time = datetime.now()
        self._session.flush()
        return self._query.status

    def _get_sql_results(self, engine):
        """Runs the prepared SQL, polling for progress when supported."""
        try:
            result_proxy = engine.execute(
                self._query.executed_sql, schema=self._query.schema)
        except Exception as e:
            self._query.error_message = utils.error_msg_from_exception(e)
            self._query.status = models.QueryStatus.FAILED
            return

        cursor = result_proxy.cursor
        if hasattr(cursor, "poll"):
            query_stats = cursor.poll()
            self._query.status = models.QueryStatus.IN_PROGRESS
            self._session.flush()
            # poll returns dict -- JSON status information or ``None``
            # if the query is done
            # https://github.com/dropbox/PyHive/blob/
            # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
            while query_stats:
                # Update the object and wait for the kill signal.
                self._session.refresh(self._query)
                completed_splits = int(query_stats['stats']['completedSplits'])
                total_splits = int(query_stats['stats']['totalSplits'])
                # totalSplits can be 0 before the query is planned; avoid
                # ZeroDivisionError.  progress may still be NULL (None) on
                # the first iteration, hence the `or 0`.
                if total_splits:
                    progress = 100 * completed_splits / total_splits
                    if progress > (self._query.progress or 0):
                        self._query.progress = progress

                self._session.flush()
                query_stats = cursor.poll()
                # TODO(b.kyryliuk): check for the kill signal.

        sql_results = sql_lab_utils.fetch_response_from_cursor(
            result_proxy)
        self._columns = sql_results['columns']
        self._data = sql_results['data']
        self._query.rows = result_proxy.rowcount
        self._query.status = models.QueryStatus.FINISHED

        # CTAs queries result in 1 cell having the # of the added rows.
        if self._query.select_as_cta_used:
            self._query.select_sql = sql_lab_utils.select_star(
                engine, self._query.tmp_table_name, self._query.limit)
        else:
            # Was `self._query.tmp_table = None` -- a typo that set an
            # unmapped, transient attribute instead of the model column.
            self._query.tmp_table_name = None


125 changes: 125 additions & 0 deletions caravel/sql_lab_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# SQL Lab Utils
import pandas as pd

import sqlparse
from caravel import models, app
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy import select, text
from sqlalchemy.sql.expression import TextAsFrom


def create_scoped_session():
    """Builds a fresh SQLAlchemy scoped_session bound to the app database."""
    db_uri = app.config.get('SQLALCHEMY_DATABASE_URI')
    bind = create_engine(db_uri, convert_unicode=True)
    session_factory = sessionmaker(
        autocommit=True, autoflush=False, bind=bind)
    return scoped_session(session_factory)


def fetch_response_from_cursor(result_proxy):
    """Materializes all rows from *result_proxy* as JSON-friendly records.

    Returns a dict with 'columns' (list of column names) and 'data'
    (list of per-row dicts); both values are None when the proxy has
    no cursor.  NaN/None cells are replaced with 0 via fillna.
    """
    columns = None
    data = None
    if result_proxy.cursor:
        column_names = [desc[0] for desc in result_proxy.cursor.description]
        frame = pd.DataFrame(result_proxy.fetchall(), columns=column_names)
        frame = frame.fillna(0)
        columns = list(frame.columns)
        data = frame.to_dict(orient='records')
    return {
        'columns': columns,
        'data': data,
    }


def is_query_select(sql):
    """Returns True when *sql*'s first statement parses as a SELECT."""
    try:
        parsed = sqlparse.parse(sql)
        return parsed[0].get_type() == 'SELECT'
    # Capture sqlparse exceptions, worker shouldn't fail here.
    except Exception:
        # TODO(bkyryliuk): add logging here.
        return False


# if sqlparse provides the stream of tokens but don't provide the API
# to access the table names, more on it:
# https://groups.google.com/forum/#!topic/sqlparse/sL2aAi6dSJU
# https://github.com/andialbrecht/sqlparse/blob/master/examples/
# extract_table_names.py
#
# Another approach would be to run the EXPLAIN on the sql statement:
# https://prestodb.io/docs/current/sql/explain.html
# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Explain
def get_tables():
    """Retrieves the table names referenced by the query.

    Not implemented yet; always returns None.
    """
    # TODO(bkyryliuk): implement parsing the sql statement.
    return None


def select_star(engine, table_name, limit):
    """Compiles a 'SELECT * FROM table_name [LIMIT n]' string for *engine*."""
    qry = select('*').select_from(table_name)
    if limit:
        qry = qry.limit(limit)

    # SQL code to preview the results
    compiled = qry.compile(engine, compile_kwargs={"literal_binds": True})
    return '{}'.format(compiled)


def add_limit_to_the_sql(sql, limit, eng):
    """Wraps a single SELECT statement so it returns at most *limit* rows.

    :param sql: string, the sql to be limited
    :param limit: int, max number of rows in the result set
    :param eng: sqlalchemy engine used to compile the wrapping query
    :return: string, the limited query; *sql* unchanged when it is not a
        single SELECT statement or cannot be parsed
    """
    # Treat as single sql statement in case of failure.
    try:
        sql_statements = [s for s in sqlparse.split(sql) if s]
    except Exception as e:
        # str(e) instead of e.message: .message is py2-only and not present
        # on all exception types.  The original message also lacked spacing.
        app.logger.info(
            "Statement " + sql + " failed to be transformed to have "
            "the limit with the exception " + str(e))
        return sql
    if len(sql_statements) == 1 and is_query_select(sql):
        qry = select('*').select_from(
            TextAsFrom(text(sql_statements[0]), ['*']).alias(
                'inner_qry')).limit(limit)
        sql_statement = str(qry.compile(
            eng, compile_kwargs={"literal_binds": True}))
        return sql_statement
    return sql


# create table works only for the single statement.
# TODO(bkyryliuk): enforce that all the columns have names. Presto requires it
# for the CTA operation.
# create table works only for the single statement.
# TODO(bkyryliuk): enforce that all the columns have names. Presto requires it
# for the CTA operation.
def create_table_as(sql, table_name, override=False):
    """Reformats the query into the create table as query.

    Works only for the single select SQL statements, in all other cases
    the sql query is not modified.
    :param sql: string, sql query that will be executed
    :param table_name: string, will contain the results of the query execution
    :param override: boolean, table table_name will be dropped if true
    :return: string, create table as query
    """
    # TODO(bkyryliuk): drop table if allowed, check the namespace and
    # the permissions.
    # Treat as single sql statement in case of failure.
    try:
        # Filter out empty statements.
        sql_statements = [s for s in sqlparse.split(sql) if s]
    except Exception as e:
        # str(e) instead of e.message: .message is py2-only and not present
        # on all exception types.  The original message also lacked spacing.
        app.logger.info(
            "Statement " + sql + " failed to be transformed as create "
            "table as with the exception " + str(e))
        return sql
    if len(sql_statements) == 1 and is_query_select(sql):
        updated_sql = ''
        # TODO(bkyryliuk): use sqlalchemy statements for the
        # the drop and create operations.
        if override:
            updated_sql = 'DROP TABLE IF EXISTS {};\n'.format(table_name)
        updated_sql += "CREATE TABLE %s AS %s" % (
            table_name, sql_statements[0])
        return updated_sql
    return sql

0 comments on commit 939c588

Please sign in to comment.