From 1e8259a4108a1e5747570345a6df57239df07bfd Mon Sep 17 00:00:00 2001 From: Bogdan Date: Tue, 9 Aug 2022 09:59:31 -0700 Subject: [PATCH] perf: Implement model specific lookups by id to improve performance (#20974) * Implement model specific lookups by id to improve performance * Address comments e.g. better variable names and test cleanup * commit after cleanup * even better name and test cleanup via rollback Co-authored-by: Bogdan Kyryliuk --- superset/common/query_context_processor.py | 2 + superset/dao/base.py | 7 +- superset/explore/utils.py | 6 +- tests/integration_tests/datasets/api_tests.py | 1 + .../explore/form_data/api_tests.py | 10 +-- .../explore/permalink/api_tests.py | 2 +- tests/unit_tests/charts/dao/__init__.py | 16 ++++ tests/unit_tests/charts/dao/dao_tests.py | 67 +++++++++++++++++ tests/unit_tests/datasets/dao/__init__.py | 16 ++++ tests/unit_tests/datasets/dao/dao_tests.py | 73 +++++++++++++++++++ 10 files changed, 190 insertions(+), 10 deletions(-) create mode 100644 tests/unit_tests/charts/dao/__init__.py create mode 100644 tests/unit_tests/charts/dao/dao_tests.py create mode 100644 tests/unit_tests/datasets/dao/__init__.py create mode 100644 tests/unit_tests/datasets/dao/dao_tests.py diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index c87e878fddef..d528aa32930c 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -473,6 +473,8 @@ def get_viz_annotation_data( chart = ChartDAO.find_by_id(annotation_layer["value"]) if not chart: raise QueryObjectValidationError(_("The chart does not exist")) + if not chart.datasource: + raise QueryObjectValidationError(_("The chart datasource does not exist")) form_data = chart.form_data.copy() try: viz_obj = get_viz( diff --git a/superset/dao/base.py b/superset/dao/base.py index 607967e3041e..981243d0dbdf 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -50,14 +50,17 @@ class BaseDAO: @classmethod def find_by_id( - cls, model_id: Union[str, int], session: Session = None + cls, + model_id: Union[str, int], + session: Session = None, + skip_base_filter: bool = False, ) -> Optional[Model]: """ Find a model by id, if defined applies `base_filter` """ session = session or db.session query = session.query(cls.model_cls) - if cls.base_filter: + if cls.base_filter and not skip_base_filter: data_model = SQLAInterface(cls.model_cls, session) query = cls.base_filter( # pylint: disable=not-callable cls.id_column_name, data_model diff --git a/superset/explore/utils.py b/superset/explore/utils.py index 7ab29de2f70e..989294619f71 100644 --- a/superset/explore/utils.py +++ b/superset/explore/utils.py @@ -35,7 +35,8 @@ def check_dataset_access(dataset_id: int) -> Optional[bool]: if dataset_id: - dataset = DatasetDAO.find_by_id(dataset_id) + # Access checks below, no need to validate them twice as they can be expensive. + dataset = DatasetDAO.find_by_id(dataset_id, skip_base_filter=True) if dataset: can_access_datasource = security_manager.can_access_datasource(dataset) if can_access_datasource: @@ -48,7 +49,8 @@ def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None: check_dataset_access(dataset_id) if not chart_id: return - chart = ChartDAO.find_by_id(chart_id) + # Access checks below, no need to validate them twice as they can be expensive. + chart = ChartDAO.find_by_id(chart_id, skip_base_filter=True) if chart: can_access_chart = ( is_user_admin() diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 2031d645570f..e6562010300b 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -1579,6 +1579,7 @@ def test_get_dataset_related_objects_not_found(self): rv = self.client.get(uri) assert rv.status_code == 404 self.logout() + self.login(username="gamma") table = self.get_birth_names_dataset() uri = f"api/v1/dataset/{table.id}/related_objects" diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py index c05be00e9618..af7df78f9040 100644 --- a/tests/integration_tests/explore/form_data/api_tests.py +++ b/tests/integration_tests/explore/form_data/api_tests.py @@ -119,7 +119,7 @@ def test_post_access_denied(client, chart_id: int, dataset_id: int): "form_data": INITIAL_FORM_DATA, } resp = client.post("api/v1/explore/form_data", json=payload) - assert resp.status_code == 404 + assert resp.status_code == 403 def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int): @@ -310,7 +310,7 @@ def test_put_access_denied(client, chart_id: int, dataset_id: int): "form_data": UPDATED_FORM_DATA, } resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) - assert resp.status_code == 404 + assert resp.status_code == 403 def test_put_not_owner(client, chart_id: int, dataset_id: int): @@ -321,7 +321,7 @@ def test_put_not_owner(client, chart_id: int, dataset_id: int): "form_data": UPDATED_FORM_DATA, } resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) - assert resp.status_code == 404 + assert resp.status_code == 403 def test_get_key_not_found(client): @@ -341,7 +341,7 @@ def test_get(client): def test_get_access_denied(client): login(client, "gamma") resp = client.get(f"api/v1/explore/form_data/{KEY}") - assert resp.status_code == 404 + assert resp.status_code == 403 @patch("superset.security.SupersetSecurityManager.can_access_datasource") @@ -361,7 +361,7 @@ def test_delete(client): def test_delete_access_denied(client): login(client, "gamma") resp = client.delete(f"api/v1/explore/form_data/{KEY}") - assert resp.status_code == 404 + assert resp.status_code == 403 def test_delete_not_owner(client, chart_id: int, dataset_id: int, admin_id: int): diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index a44bc70a7b49..d25f3f11491e 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -85,7 +85,7 @@ def test_post(client, form_data: Dict[str, Any], permalink_salt: str): def test_post_access_denied(client, form_data): login(client, "gamma") resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data}) - assert resp.status_code == 404 + assert resp.status_code == 403 def test_get_missing_chart(client, chart, permalink_salt: str) -> None: diff --git a/tests/unit_tests/charts/dao/__init__.py b/tests/unit_tests/charts/dao/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/charts/dao/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/charts/dao/dao_tests.py b/tests/unit_tests/charts/dao/dao_tests.py new file mode 100644 index 000000000000..15310712a5f8 --- /dev/null +++ b/tests/unit_tests/charts/dao/dao_tests.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + +from superset.utils.core import DatasourceType + + +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: + from superset.models.slice import Slice + + engine = session.get_bind() + Slice.metadata.create_all(engine) # pylint: disable=no-member + + slice_obj = Slice( + id=1, + datasource_id=1, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_perm_table", + slice_name="slice_name", + ) + + session.add(slice_obj) + session.commit() + yield session + session.rollback() + + +def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None: + from superset.charts.dao import ChartDAO + from superset.models.slice import Slice + + result = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True) + + assert result + assert 1 == result.id + assert "slice_name" == result.slice_name + assert isinstance(result, Slice) + + +def test_datasource_find_by_id_skip_base_filter_not_found( + session_with_data: Session, +) -> None: + from superset.charts.dao import ChartDAO + + result = ChartDAO.find_by_id( + 125326326, session=session_with_data, skip_base_filter=True + ) + assert result is None diff --git a/tests/unit_tests/datasets/dao/__init__.py b/tests/unit_tests/datasets/dao/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/datasets/dao/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py new file mode 100644 index 000000000000..31aa9f27d085 --- /dev/null +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: + from superset.connectors.sqla.models import SqlaTable + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=db, + ) + + session.add(db) + session.add(sqla_table) + session.flush() + yield session + session.rollback() + + +def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> None: + from superset.connectors.sqla.models import SqlaTable + from superset.datasets.dao import DatasetDAO + + result = DatasetDAO.find_by_id( + 1, + session=session_with_data, + skip_base_filter=True, + ) + + assert result + assert 1 == result.id + assert "my_sqla_table" == result.table_name + assert isinstance(result, SqlaTable) + + +def test_datasource_find_by_id_skip_base_filter_not_found( + session_with_data: Session, +) -> None: + from superset.datasets.dao import DatasetDAO + + result = DatasetDAO.find_by_id( + 125326326, + session=session_with_data, + skip_base_filter=True, + ) + assert result is None