Skip to content

Commit

Permalink
fix failed tests and ci issues
Browse files Browse the repository at this point in the history
  • Loading branch information
ofekisr committed Sep 5, 2021
1 parent 581bbc8 commit e9b794b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 30 deletions.
50 changes: 24 additions & 26 deletions superset/utils/sqllab_execution_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,25 +118,24 @@ def is_run_asynchronous(self) -> bool:
def select_as_cta(self) -> bool:
return self.create_table_as_select is not None

def set_database(self, db: Database) -> None:
self._validate_db(db)
self.database = db
def set_database(self, database: Database) -> None:
self._validate_db(database)
self.database = database
if self.select_as_cta:
self.create_table_as_select.target_schema_name = self._get_ctas_target_schema_name( # type: ignore
db
)
schema_name = self._get_ctas_target_schema_name(database)
self.create_table_as_select.target_schema_name = schema_name # type: ignore

def _get_ctas_target_schema_name(self, db: Database) -> Optional[str]:
if db.force_ctas_schema:
return db.force_ctas_schema
else:
return get_cta_schema_name(db, g.user, self.schema, self.sql)
def _get_ctas_target_schema_name(self, database: Database) -> Optional[str]:
if database.force_ctas_schema:
return database.force_ctas_schema
return get_cta_schema_name(database, g.user, self.schema, self.sql)

def _validate_db(self, db: Database) -> None:
def _validate_db(self, database: Database) -> None:
# TODO validate db.id is equal to self.database_id
pass

def create_query(self) -> Query:
# pylint: disable=C0301
start_time = now_as_float()
if self.select_as_cta:
return Query(
Expand All @@ -154,22 +153,21 @@ def create_query(self) -> Query:
user_id=self.user_id,
client_id=self.client_id_or_short_id,
)
else:
return Query(
database_id=self.database_id,
sql=self.sql,
schema=self.schema,
select_as_cta=False,
start_time=start_time,
tab_name=self.tab_name,
status=self.status,
sql_editor_id=self.sql_editor_id,
user_id=self.user_id,
client_id=self.client_id_or_short_id,
)
return Query(
database_id=self.database_id,
sql=self.sql,
schema=self.schema,
select_as_cta=False,
start_time=start_time,
tab_name=self.tab_name,
status=self.status,
sql_editor_id=self.sql_editor_id,
user_id=self.user_id,
client_id=self.client_id_or_short_id,
)


class CreateTableAsSelect:
class CreateTableAsSelect: # pylint: disable=R0903
ctas_method: CtasMethod
target_schema_name: Optional[str]
target_table_name: str
Expand Down
2 changes: 1 addition & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2566,7 +2566,7 @@ def is_query_handled(cls, query: Optional[Query]) -> bool:
QueryStatus.TIMED_OUT,
]

def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals
def sql_json_exec( # pylint: disable=too-many-statements
self,
execution_context: SqlJsonExecutionContext,
query_params: Dict[str, Any],
Expand Down
7 changes: 5 additions & 2 deletions tests/integration_tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_query_by_id(id: int):

@pytest.fixture(autouse=True, scope="module")
def setup_sqllab():

with app.app_context():
yield

Expand Down Expand Up @@ -216,7 +217,8 @@ def test_run_sync_query_cta_no_data(setup_sqllab):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
"superset.utils.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
"superset.utils.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_sync_query_cta_config(setup_sqllab, ctas_method):
if backend() == "sqlite":
Expand All @@ -243,7 +245,8 @@ def test_run_sync_query_cta_config(setup_sqllab, ctas_method):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
"superset.utils.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
"superset.utils.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_async_query_cta_config(setup_sqllab, ctas_method):
if backend() == "sqlite":
Expand Down
3 changes: 2 additions & 1 deletion tests/integration_tests/sqllab_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
QUERY_3 = "SELECT * FROM birth_names LIMIT 10"


@pytest.mark.sqllab
class TestSqlLab(SupersetTestCase):
"""Testings for Sql Lab"""

Expand Down Expand Up @@ -188,7 +189,7 @@ def test_sql_json_cta_dynamic_db(self, ctas_method):
return

with mock.patch(
"superset.views.core.get_cta_schema_name",
"superset.utils.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: f"{u.username}_database",
):
old_allow_ctas = examples_db.allow_ctas
Expand Down

0 comments on commit e9b794b

Please sign in to comment.