Skip to content

Commit

Permalink
feat: refactor all get_sqla_engine to use contextmanager in codebase (
Browse files Browse the repository at this point in the history
  • Loading branch information
hughhhh committed Nov 15, 2022
1 parent 06f87e1 commit e23efef
Show file tree
Hide file tree
Showing 41 changed files with 638 additions and 598 deletions.
12 changes: 6 additions & 6 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,13 +804,13 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())

engine = self.database.get_sqla_engine()
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
with self.database.get_sqla_engine_with_context() as engine:
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)

df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()
df = pd.read_sql_query(sql=sql, con=engine)
return df[column_name].to_list()

def mutate_query_from_config(self, sql: str) -> str:
"""Apply config's SQL_QUERY_MUTATOR
Expand Down
39 changes: 23 additions & 16 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
)

db_engine_spec = dataset.database.db_engine_spec
engine = dataset.database.get_sqla_engine(schema=dataset.schema)
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
Expand All @@ -137,13 +136,18 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
# TODO(villebro): refactor to use same code that's used by
# sql_lab.py:execute_sql_statements
try:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
cols = result_set.columns
with dataset.database.get_sqla_engine_with_context(
schema=dataset.schema
) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(
result, cursor.description, db_engine_spec
)
cols = result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
return cols
Expand All @@ -155,14 +159,17 @@ def get_columns_description(
) -> List[ResultSetColumnType]:
db_engine_spec = database.db_engine_spec
try:
with closing(database.get_sqla_engine().raw_connection()) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
return result_set.columns
with database.get_sqla_engine_with_context() as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(
result, cursor.description, db_engine_spec
)
return result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex

Expand Down
52 changes: 26 additions & 26 deletions superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def run(self) -> None: # pylint: disable=too-many-statements
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)

engine = database.get_sqla_engine()
event_logger.log_with_context(
action="test_connection_attempt",
engine=database.db_engine_spec.__name__,
Expand All @@ -100,31 +99,32 @@ def ping(engine: Engine) -> bool:
with closing(engine.raw_connection()) as conn:
return engine.dialect.do_ping(conn)

try:
alive = func_timeout(
int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()),
ping,
args=(engine,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(engine)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception as ex: # pylint: disable=broad-except
alive = False
# So we stop losing the original message if any
ex_str = str(ex)
with database.get_sqla_engine_with_context() as engine:
try:
alive = func_timeout(
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
ping,
args=(engine,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(engine)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception as ex: # pylint: disable=broad-except
alive = False
# So we stop losing the original message if any
ex_str = str(ex)

if not alive:
raise DBAPIError(ex_str or None, None, None)
Expand Down
31 changes: 16 additions & 15 deletions superset/databases/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,22 @@ def run(self) -> None:
database.set_sqlalchemy_uri(sqlalchemy_uri)
database.db_engine_spec.mutate_db_for_connection_test(database)

engine = database.get_sqla_engine()
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex
alive = False
with database.get_sqla_engine_with_context() as engine:
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex

if not alive:
raise DatabaseOfflineError(
Expand Down
33 changes: 21 additions & 12 deletions superset/datasets/commands/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,26 @@ def load_data(
if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"):
logger.info("Loading data inside the import transaction")
connection = session.connection()
df.to_sql(
dataset.table_name,
con=connection,
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)
else:
logger.warning("Loading data outside the import transaction")
connection = database.get_sqla_engine()

df.to_sql(
dataset.table_name,
con=connection,
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)
with database.get_sqla_engine_with_context() as engine:
df.to_sql(
dataset.table_name,
con=engine,
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)
35 changes: 23 additions & 12 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import (
Any,
Callable,
ContextManager,
Dict,
List,
Match,
Expand Down Expand Up @@ -480,8 +481,16 @@ def get_engine(
database: "Database",
schema: Optional[str] = None,
source: Optional[utils.QuerySource] = None,
) -> Engine:
return database.get_sqla_engine(schema=schema, source=source)
) -> ContextManager[Engine]:
"""
Return an engine context manager.
>>> with DBEngineSpec.get_engine(database, schema, source) as engine:
... connection = engine.connect()
... connection.execute(sql)
"""
return database.get_sqla_engine_with_context(schema=schema, source=source)

@classmethod
def get_timestamp_expr(
Expand Down Expand Up @@ -903,17 +912,17 @@ def df_to_sql(
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""

engine = cls.get_engine(database)
to_sql_kwargs["name"] = table.table

if table.schema:
# Only add schema when it is preset and non empty.
to_sql_kwargs["schema"] = table.schema

if engine.dialect.supports_multivalues_insert:
to_sql_kwargs["method"] = "multi"
with cls.get_engine(database) as engine:
if engine.dialect.supports_multivalues_insert:
to_sql_kwargs["method"] = "multi"

df.to_sql(con=engine, **to_sql_kwargs)
df.to_sql(con=engine, **to_sql_kwargs)

@classmethod
def convert_dttm( # pylint: disable=unused-argument
Expand Down Expand Up @@ -1286,13 +1295,15 @@ def estimate_query_cost(
parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()

engine = cls.get_engine(database, schema=schema, source=source)
costs = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(cls.estimate_statement_cost(processed_statement, cursor))
with cls.get_engine(database, schema=schema, source=source) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(
cls.estimate_statement_cost(processed_statement, cursor)
)
return costs

@classmethod
Expand Down
8 changes: 6 additions & 2 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,12 @@ def df_to_sql(
if not table.schema:
raise Exception("The table schema must be defined")

engine = cls.get_engine(database)
to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host}
to_gbq_kwargs = {}
with cls.get_engine(database) as engine:
to_gbq_kwargs = {
"destination_table": str(table),
"project_id": engine.url.host,
}

# Add credentials if they are set on the SQLAlchemy dialect.
creds = engine.dialect.credentials_info
Expand Down
10 changes: 5 additions & 5 deletions superset/db_engine_specs/gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def extra_table_metadata(
table_name: str,
schema_name: Optional[str],
) -> Dict[str, Any]:
engine = cls.get_engine(database, schema=schema_name)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
results = cursor.fetchone()[0]
with cls.get_engine(database, schema=schema_name) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
results = cursor.fetchone()[0]

try:
metadata = json.loads(results)
Expand Down
5 changes: 2 additions & 3 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,6 @@ def df_to_sql(
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""

engine = cls.get_engine(database)

if to_sql_kwargs["if_exists"] == "append":
raise SupersetException("Append operation not currently supported")

Expand All @@ -205,7 +203,8 @@ def df_to_sql(
if table_exists:
raise SupersetException("Table already exists")
elif to_sql_kwargs["if_exists"] == "replace":
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
with cls.get_engine(database) as engine:
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")

def _get_hive_type(dtype: np.dtype) -> str:
hive_type_by_dtype = {
Expand Down
33 changes: 16 additions & 17 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,11 @@ def get_view_names(
).strip()
params = {}

engine = cls.get_engine(database, schema=schema)

with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()
with cls.get_engine(database, schema=schema) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()

return sorted([row[0] for row in results])

Expand Down Expand Up @@ -989,17 +988,17 @@ def get_create_view(
# pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError

engine = cls.get_engine(database, schema)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)

except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)
return rows[0][0]
with cls.get_engine(database, schema=schema) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)

except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)
return rows[0][0]

@classmethod
def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
Expand Down
Loading

0 comments on commit e23efef

Please sign in to comment.