-
Notifications
You must be signed in to change notification settings - Fork 12.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactored the SQL calling functions into the QueryRunner class.
- Loading branch information
Bogdan Kyryliuk
committed
Aug 17, 2016
1 parent
e1e3382
commit 939c588
Showing
9 changed files
with
471 additions
and
375 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import celery | ||
from caravel import models, app, utils, sql_lab_utils | ||
from datetime import datetime | ||
|
||
# Celery application for async SQL Lab queries, configured from the
# Flask app's CELERY_CONFIG entry.
celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
|
||
|
||
@celery_app.task
def get_sql_results(query_id):
    """Executes the sql query and returns a serializable result payload.

    :param query_id: id of the models.Query row to execute
    :return: dict with 'query_id' and 'status' keys, plus 'data'/'columns'
        on success or 'error' on failure
    """
    sql_manager = QueryRunner(query_id)
    sql_manager.run_sql()
    query = sql_manager.query()
    # Build the shared part of the response once instead of duplicating
    # it in both branches.
    response = {
        'query_id': query.id,
        'status': query.status,
    }
    if query.status == models.QueryStatus.FINISHED:
        response['data'] = sql_manager.data()
        response['columns'] = sql_manager.columns()
    else:
        response['error'] = query.error_message
    return response
|
||
|
||
class QueryRunner:
    """Executes a single SQL Lab query and persists status, progress and
    results on the corresponding models.Query row.

    A dedicated SQLAlchemy session is used so concurrent workers do not
    share the web server's db.session.
    """

    def __init__(self, query_id):
        self._query_id = query_id
        # Creates a separate session, reusing the db.session leads to the
        # concurrency issues.
        self._session = sql_lab_utils.create_scoped_session()
        self._query = self._session.query(models.Query).filter_by(
            id=query_id).first()
        # Guard: only dereference the query row if it was actually found,
        # otherwise the database lookup below would raise AttributeError.
        self._db_to_query = None
        if self._query:
            self._db_to_query = self._session.query(
                models.Database).filter_by(
                id=self._query.database_id).first()
        # Query result.
        self._data = None
        self._columns = None

    def _sanity_check(self):
        """Verifies the query row and its target database exist.

        Marks the query FAILED (when present) and returns False on any
        problem; returns True when execution can proceed.
        """
        if not self._query:
            # There is no query row to record an error message on, so
            # just signal failure to the caller.
            return False
        if not self._db_to_query:
            self._query.error_message = (
                "Database with id {0} is missing.".format(
                    self._query.database_id)
            )

        if self._query.error_message:
            self._query.status = models.QueryStatus.FAILED
            self._session.flush()
            return False
        return True

    def query(self):
        """Returns the models.Query row being executed (may be None)."""
        return self._query

    def data(self):
        """Returns the fetched rows (list of dicts) or None on failure."""
        return self._data

    def columns(self):
        """Returns the result column names or None on failure."""
        return self._columns

    def run_sql(self):
        """Runs the query and returns its final models.QueryStatus."""
        if not self._sanity_check():
            if not self._query:
                app.logger.error(
                    "Query with id {0} not found.".format(self._query_id))
                return models.QueryStatus.FAILED
            # Log through the app logger instead of bare print().
            app.logger.error("FAILED QUERY#" + str(self._query.id))
            app.logger.error(self._query.executed_sql)
            app.logger.error(self._query.error_message)
            return self._query.status

        # TODO(bkyryliuk): dump results somewhere for the webserver.
        engine = self._db_to_query.get_sqla_engine(schema=self._query.schema)
        self._query.executed_sql = self._query.sql.strip().strip(';')

        # Limit enforced only for retrieving the data, not for the CTA queries.
        self._query.select_as_cta_used = False
        self._query.limit_used = False
        if sql_lab_utils.is_query_select(self._query.sql):
            if self._query.select_as_cta:
                if not self._query.tmp_table_name:
                    self._query.tmp_table_name = 'tmp_{}_table_{}'.format(
                        self._query.user_id,
                        self._query.start_time.strftime('%Y_%m_%d_%H_%M_%S'))
                self._query.executed_sql = sql_lab_utils.create_table_as(
                    self._query.executed_sql, self._query.tmp_table_name)
                self._query.select_as_cta_used = True
            elif self._query.limit:
                self._query.executed_sql = sql_lab_utils.add_limit_to_the_sql(
                    self._query.executed_sql, self._query.limit, engine)
                self._query.limit_used = True

        # TODO(bkyryliuk): ensure that tmp table was created.
        # Do not set tmp table name if table wasn't created.
        if not self._query.select_as_cta_used:
            self._query.tmp_table_name = None
        self._get_sql_results(engine)

        self._query.end_time = datetime.now()
        self._session.flush()
        return self._query.status

    def _get_sql_results(self, engine):
        """Executes the prepared SQL and stores rows/columns/status."""
        try:
            result_proxy = engine.execute(
                self._query.executed_sql, schema=self._query.schema)
        except Exception as e:
            self._query.error_message = utils.error_msg_from_exception(e)
            self._query.status = models.QueryStatus.FAILED
            return

        cursor = result_proxy.cursor
        if hasattr(cursor, "poll"):
            query_stats = cursor.poll()
            self._query.status = models.QueryStatus.IN_PROGRESS
            self._session.flush()
            # poll returns dict -- JSON status information or ``None``
            # if the query is done
            # https://github.com/dropbox/PyHive/blob/
            # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
            while query_stats:
                # Update the object and wait for the kill signal.
                self._session.refresh(self._query)
                completed_splits = int(query_stats['stats']['completedSplits'])
                total_splits = int(query_stats['stats']['totalSplits'])
                # Guard against ZeroDivisionError when no splits have
                # been scheduled yet.
                if total_splits:
                    progress = 100 * completed_splits / total_splits
                    if progress > self._query.progress:
                        self._query.progress = progress

                self._session.flush()
                query_stats = cursor.poll()
                # TODO(b.kyryliuk): check for the kill signal.

        sql_results = sql_lab_utils.fetch_response_from_cursor(
            result_proxy)
        self._columns = sql_results['columns']
        self._data = sql_results['data']
        self._query.rows = result_proxy.rowcount
        self._query.status = models.QueryStatus.FINISHED

        # CTAs queries result in 1 cell having the # of the added rows.
        if self._query.select_as_cta_used:
            self._query.select_sql = sql_lab_utils.select_star(
                engine, self._query.tmp_table_name, self._query.limit)
        else:
            # NOTE(review): 'tmp_table' looks like a typo for
            # 'tmp_table_name' -- confirm against the models.Query schema.
            self._query.tmp_table = None
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# SQL Lab Utils | ||
import pandas as pd | ||
|
||
import sqlparse | ||
from caravel import models, app | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import scoped_session, sessionmaker | ||
from sqlalchemy import select, text | ||
from sqlalchemy.sql.expression import TextAsFrom | ||
|
||
|
||
def create_scoped_session():
    """Creates new SQLAlchemy scoped_session."""
    db_engine = create_engine(
        app.config.get('SQLALCHEMY_DATABASE_URI'), convert_unicode=True)
    session_factory = sessionmaker(
        autocommit=True, autoflush=False, bind=db_engine)
    return scoped_session(session_factory)
|
||
|
||
def fetch_response_from_cursor(result_proxy):
    """Reads all rows from *result_proxy* into a JSON-friendly payload.

    Returns a dict with 'columns' (list of column names) and 'data'
    (list of per-row dicts); both are None when the proxy has no cursor.
    Null cells are replaced with 0 before serialization.
    """
    columns = None
    data = None
    cursor = result_proxy.cursor
    if cursor:
        column_names = [description[0] for description in cursor.description]
        frame = pd.DataFrame(result_proxy.fetchall(), columns=column_names)
        frame = frame.fillna(0)
        columns = list(frame.columns)
        data = frame.to_dict(orient='records')
    return {
        'columns': columns,
        'data': data,
    }
|
||
|
||
def is_query_select(sql):
    """Returns True when *sql* parses as a SELECT statement."""
    try:
        statements = sqlparse.parse(sql)
        return statements[0].get_type() == 'SELECT'
    # Capture sqlparse exceptions, worker shouldn't fail here.
    except Exception:
        # TODO(bkyryliuk): add logging here.
        return False
|
||
|
||
# if sqlparse provides the stream of tokens but don't provide the API | ||
# to access the table names, more on it: | ||
# https://groups.google.com/forum/#!topic/sqlparse/sL2aAi6dSJU | ||
# https://github.com/andialbrecht/sqlparse/blob/master/examples/ | ||
# extract_table_names.py | ||
# | ||
# Another approach would be to run the EXPLAIN on the sql statement: | ||
# https://prestodb.io/docs/current/sql/explain.html | ||
# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Explain | ||
def get_tables():
    """Retrieves the table names referenced by the query.

    Stub: sqlparse exposes a token stream but no table-name API (see the
    links above), so this is not implemented yet.
    """
    # TODO(bkyryliuk): implement parsing the sql statement.
    pass
|
||
|
||
def select_star(engine, table_name, limit):
    """Compiles a literal SELECT * statement over *table_name*,
    capped at *limit* rows when a limit is given."""
    qry = select('*').select_from(table_name)
    if limit:
        qry = qry.limit(limit)

    # SQL code to preview the results
    return '{}'.format(qry.compile(
        engine, compile_kwargs={"literal_binds": True}))
|
||
|
||
def add_limit_to_the_sql(sql, limit, eng):
    """Wraps a single SELECT statement so at most *limit* rows return.

    Multi-statement input and non-SELECT statements are returned
    unchanged, as is the original *sql* when parsing fails.

    :param sql: string, sql query that will be executed
    :param limit: int, maximum number of rows to return
    :param eng: SQLAlchemy engine used to compile the wrapped statement
    :return: string, the (possibly) limited sql statement
    """
    # Treat as single sql statement in case of failure.
    try:
        sql_statements = [s for s in sqlparse.split(sql) if s]
    except Exception as e:
        # str(e) instead of the Python-2-only e.message attribute; also
        # fixed the missing spaces in the logged message.
        app.logger.info(
            "Statement " + sql + " failed to be transformed to have the "
            "limit with the exception " + str(e))
        return sql
    if len(sql_statements) == 1 and is_query_select(sql):
        qry = select('*').select_from(
            TextAsFrom(text(sql_statements[0]), ['*']).alias(
                'inner_qry')).limit(limit)
        sql_statement = str(qry.compile(
            eng, compile_kwargs={"literal_binds": True}))
        return sql_statement
    return sql
|
||
|
||
# create table works only for the single statement. | ||
# TODO(bkyryliuk): enforce that all the columns have names. Presto requires it | ||
# for the CTA operation. | ||
def create_table_as(sql, table_name, override=False):
    """Reformats the query into the create table as query.

    Works only for the single select SQL statements, in all other cases
    the sql query is not modified.

    :param sql: string, sql query that will be executed
    :param table_name: string, will contain the results of the query execution
    :param override: boolean, table table_name will be dropped if true
    :return: string, create table as query
    """
    # TODO(bkyryliuk): drop table if allowed, check the namespace and
    # the permissions.
    # Treat as single sql statement in case of failure.
    try:
        # Filter out empty statements.
        sql_statements = [s for s in sqlparse.split(sql) if s]
    except Exception as e:
        # str(e) instead of the Python-2-only e.message attribute; also
        # fixed the missing spaces in the logged message.
        app.logger.info(
            "Statement " + sql + " failed to be transformed as create "
            "table as with the exception " + str(e))
        return sql
    if len(sql_statements) == 1 and is_query_select(sql):
        updated_sql = ''
        # TODO(bkyryliuk): use sqlalchemy statements for the
        # the drop and create operations.
        if override:
            updated_sql = 'DROP TABLE IF EXISTS {};\n'.format(table_name)
        updated_sql += "CREATE TABLE %s AS %s" % (
            table_name, sql_statements[0])
        return updated_sql
    return sql
Oops, something went wrong.