Skip to content

Commit

Permalink
feat: create function for get_sqla_engine with context (#21790)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughhhh committed Oct 25, 2022
1 parent 1388f21 commit 7600da8
Show file tree
Hide file tree
Showing 14 changed files with 182 additions and 143 deletions.
14 changes: 13 additions & 1 deletion superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import textwrap
from ast import literal_eval
from contextlib import closing
from contextlib import closing, contextmanager
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
Expand Down Expand Up @@ -362,6 +362,18 @@ def get_effective_user(self, object_url: URL) -> Optional[str]:
else None
)

@contextmanager
def get_sqla_engine_with_context(
self,
schema: Optional[str] = None,
nullpool: bool = True,
source: Optional[utils.QuerySource] = None,
) -> Engine:
try:
yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

def get_sqla_engine(
self,
schema: Optional[str] = None,
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def mock_provider() -> Mock:

@fixture(scope="session")
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
return example_db_provider().get_sqla_engine()
with example_db_provider().get_sqla_engine_with_context() as engine:
return engine


@fixture(scope="session")
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/access_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def test_override_role_permissions_is_admin_only(self):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_override_role_permissions_1_table(self):
database = get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name

perm_data = ROLE_TABLES_PERM_DATA.copy()
perm_data["database"][0]["schema"][0]["name"] = schema
Expand All @@ -186,8 +186,8 @@ def test_override_role_permissions_1_table(self):
)
def test_override_role_permissions_drops_absent_perms(self):
database = get_example_database()
engine = database.get_sqla_engine()
schema = inspect(engine).default_schema_name
with database.get_sqla_engine_with_context() as engine:
schema = inspect(engine).default_schema_name

override_me = security_manager.find_role("override_me")
override_me.permissions.append(
Expand Down
4 changes: 3 additions & 1 deletion tests/integration_tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def run_sql(
def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
"""Drop table if it exists, works on any DB"""
sql = f"DROP {table_type} IF EXISTS {table_name}"
get_example_database().get_sqla_engine().execute(sql)
database = get_example_database()
with database.get_sqla_engine_with_context() as engine:
engine.execute(sql)


def quote_f(value: Optional[str]):
Expand Down
18 changes: 8 additions & 10 deletions tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ def setup_sample_data() -> Any:
yield

with app.app_context():
engine = get_example_database().get_sqla_engine()

# drop sqlachemy tables

db.session.commit()
Expand Down Expand Up @@ -210,14 +208,14 @@ def setup_presto_if_needed():

if backend in {"presto", "hive"}:
database = get_example_database()
engine = database.get_sqla_engine()
drop_from_schema(engine, CTAS_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}")

drop_from_schema(engine, ADMIN_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {ADMIN_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {ADMIN_SCHEMA_NAME}")
with database.get_sqla_engine_with_context() as engine:
drop_from_schema(engine, CTAS_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}")

drop_from_schema(engine, ADMIN_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {ADMIN_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {ADMIN_SCHEMA_NAME}")


def with_feature_flags(**mock_feature_flags):
Expand Down
16 changes: 8 additions & 8 deletions tests/integration_tests/csv_upload_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ def setup_csv_upload(login_as_admin):
yield

upload_db = get_upload_db()
engine = upload_db.get_sqla_engine()
engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}")
db.session.delete(upload_db)
db.session.commit()
with upload_db.get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}")
db.session.delete(upload_db)
db.session.commit()


@pytest.fixture(scope="module")
Expand Down
27 changes: 14 additions & 13 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,10 @@ def test_create_dataset_same_name_different_schema(self):
return

example_db = get_example_database()
example_db.get_sqla_engine().execute(
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
)
with example_db.get_sqla_engine_with_context() as engine:
engine.execute(
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
)

self.login(username="admin")
table_data = {
Expand All @@ -690,9 +691,8 @@ def test_create_dataset_same_name_different_schema(self):
uri = f'api/v1/dataset/{data.get("id")}'
rv = self.client.delete(uri)
assert rv.status_code == 200
example_db.get_sqla_engine().execute(
f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names"
)
with example_db.get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names")

def test_create_dataset_validate_database(self):
"""
Expand Down Expand Up @@ -758,13 +758,14 @@ def test_create_dataset_validate_view_exists(
mock_get_table.return_value = None

example_db = get_example_database()
engine = example_db.get_sqla_engine()
dialect = engine.dialect

with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]
with example_db.get_sqla_engine_with_context() as engine:
engine = engine
dialect = engine.dialect

with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]

self.login(username="admin")
table_data = {
Expand Down
19 changes: 9 additions & 10 deletions tests/integration_tests/datasource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,17 @@ def create_test_table_context(database: Database):
schema = get_example_default_schema()
full_table_name = f"{schema}.test_table" if schema else "test_table"

database.get_sqla_engine().execute(
f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second"
)
database.get_sqla_engine().execute(
f"INSERT INTO {full_table_name} (first, second) VALUES (1, 2)"
)
database.get_sqla_engine().execute(
f"INSERT INTO {full_table_name} (first, second) VALUES (3, 4)"
)
with database.get_sqla_engine_with_context() as engine:
engine.execute(
f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second"
)
engine.execute(f"INSERT INTO {full_table_name} (first, second) VALUES (1, 2)")
engine.execute(f"INSERT INTO {full_table_name} (first, second) VALUES (3, 4)")

yield db.session
database.get_sqla_engine().execute(f"DROP TABLE {full_table_name}")

with database.get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE {full_table_name}")


class TestDatasource(SupersetTestCase):
Expand Down
27 changes: 14 additions & 13 deletions tests/integration_tests/fixtures/energy_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,22 @@
def load_energy_table_data():
with app.app_context():
database = get_example_database()
df = _get_dataframe()
df.to_sql(
ENERGY_USAGE_TBL_NAME,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
index=False,
dtype={"source": String(255), "target": String(255), "value": Float()},
method="multi",
schema=get_example_default_schema(),
)
with database.get_sqla_engine_with_context() as engine:
df = _get_dataframe()
df.to_sql(
ENERGY_USAGE_TBL_NAME,
engine,
if_exists="replace",
chunksize=500,
index=False,
dtype={"source": String(255), "target": String(255), "value": Float()},
method="multi",
schema=get_example_default_schema(),
)
yield
with app.app_context():
engine = get_example_database().get_sqla_engine()
engine.execute("DROP TABLE IF EXISTS energy_usage")
with get_example_database().get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE IF EXISTS energy_usage")


@pytest.fixture()
Expand Down
21 changes: 11 additions & 10 deletions tests/integration_tests/fixtures/unicode_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,17 @@
@pytest.fixture(scope="session")
def load_unicode_data():
with app.app_context():
_get_dataframe().to_sql(
UNICODE_TBL_NAME,
get_example_database().get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype={"phrase": String(500)},
index=False,
method="multi",
schema=get_example_default_schema(),
)
with get_example_database().get_sqla_engine_with_context() as engine:
_get_dataframe().to_sql(
UNICODE_TBL_NAME,
engine,
if_exists="replace",
chunksize=500,
dtype={"phrase": String(500)},
index=False,
method="multi",
schema=get_example_default_schema(),
)

yield
with app.app_context():
Expand Down
21 changes: 11 additions & 10 deletions tests/integration_tests/fixtures/world_bank_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,17 @@ def load_world_bank_data():
"country_name": String(255),
"region": String(255),
}
_get_dataframe(database).to_sql(
WB_HEALTH_POPULATION,
get_example_database().get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype=dtype,
index=False,
method="multi",
schema=get_example_default_schema(),
)
with database.get_sqla_engine_with_context() as engine:
_get_dataframe(database).to_sql(
WB_HEALTH_POPULATION,
engine,
if_exists="replace",
chunksize=500,
dtype=dtype,
index=False,
method="multi",
schema=get_example_default_schema(),
)

yield
with app.app_context():
Expand Down
Loading

0 comments on commit 7600da8

Please sign in to comment.