diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index 66067ae2d359a..6bf6cac25f538 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -170,7 +170,7 @@ def test_override_role_permissions_1_table(self): updated_override_me = security_manager.find_role("override_me") self.assertEqual(1, len(updated_override_me.permissions)) - birth_names = self.get_table(name="birth_names", schema=schema) + birth_names = self.get_table(name="birth_names") self.assertEqual( birth_names.perm, updated_override_me.permissions[0].view_menu.name ) @@ -205,7 +205,7 @@ def test_override_role_permissions_druid_and_table(self): "datasource_access", updated_role.permissions[1].permission.name ) - birth_names = self.get_table(name="birth_names", schema=schema) + birth_names = self.get_table(name="birth_names") self.assertEqual(birth_names.perm, perms[2].view_menu.name) self.assertEqual( "datasource_access", updated_role.permissions[2].permission.name @@ -223,7 +223,7 @@ def test_override_role_permissions_drops_absent_perms(self): override_me = security_manager.find_role("override_me") override_me.permissions.append( security_manager.find_permission_view_menu( - view_menu_name=self.get_table(name="energy_usage", schema=schema).perm, + view_menu_name=self.get_table(name="energy_usage").perm, permission_name="datasource_access", ) ) @@ -240,7 +240,7 @@ def test_override_role_permissions_drops_absent_perms(self): self.assertEqual(201, response.status_code) updated_override_me = security_manager.find_role("override_me") self.assertEqual(1, len(updated_override_me.permissions)) - birth_names = self.get_table(name="birth_names", schema=schema) + birth_names = self.get_table(name="birth_names") self.assertEqual( birth_names.perm, updated_override_me.permissions[0].view_menu.name ) diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 003de23e87b47..fdafe20b846c3 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -28,6 +28,7 @@ from flask import Response from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase +from sqlalchemy import inspect from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.ext.declarative.api import DeclarativeMeta from sqlalchemy.orm import Session @@ -250,6 +251,10 @@ def get_slice( def get_table( name: str, database_id: Optional[int] = None, schema: Optional[str] = None ) -> SqlaTable: + database = get_example_database() + engine = database.get_sqla_engine() + schema = schema or inspect(engine).default_schema_name + return ( db.session.query(SqlaTable) .filter_by( diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index f684304162dd7..64821b8235b57 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -279,12 +279,8 @@ def save_datasource_from_dict(self, datasource_post): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_change_database(self): - database = get_example_database() - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - self.login(username="admin") - tbl = self.get_table(name="birth_names", schema=schema) + tbl = self.get_table(name="birth_names") tbl_id = tbl.id db_id = tbl.database_id datasource_post = get_datasource_post() diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py index e6cd7e8229cc5..763d58c8a8145 100644 --- a/tests/integration_tests/fixtures/datasource.py +++ b/tests/integration_tests/fixtures/datasource.py @@ -17,8 +17,16 @@ """Fixtures for test_datasource.py""" from typing import Any, Dict +from sqlalchemy import inspect + +from superset.utils.core import get_example_database + def get_datasource_post() -> Dict[str, Any]: + database = get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name + return { "id": None, "column_formats": {"ratio": ".2%"}, @@ -30,7 +38,7 @@ def get_datasource_post() -> Dict[str, Any]: "table_name": "birth_names", "datasource_name": "birth_names", "type": "table", - "schema": None, + "schema": schema, "offset": 66, "cache_timeout": 55, "sql": "",