Skip to content

Commit

Permalink
perf: Implement model specific lookups by id to improve performance (#20974)
Browse files Browse the repository at this point in the history

* Implement model specific lookups by id to improve performance

* Address comments e.g. better variable names and test cleanup

* commit after cleanup

* even better name and test cleanup via rollback

Co-authored-by: Bogdan Kyryliuk <bogdankyryliuk@dropbox.com>
  • Loading branch information
2 people authored and michael-s-molina committed Aug 29, 2022
1 parent 094b17e commit 1e8259a
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 10 deletions.
2 changes: 2 additions & 0 deletions superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,8 @@ def get_viz_annotation_data(
chart = ChartDAO.find_by_id(annotation_layer["value"])
if not chart:
raise QueryObjectValidationError(_("The chart does not exist"))
if not chart.datasource:
raise QueryObjectValidationError(_("The chart datasource does not exist"))
form_data = chart.form_data.copy()
try:
viz_obj = get_viz(
Expand Down
7 changes: 5 additions & 2 deletions superset/dao/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,17 @@ class BaseDAO:

@classmethod
def find_by_id(
cls, model_id: Union[str, int], session: Session = None
cls,
model_id: Union[str, int],
session: Session = None,
skip_base_filter: bool = False,
) -> Optional[Model]:
"""
Find a model by id, if defined applies `base_filter`
"""
session = session or db.session
query = session.query(cls.model_cls)
if cls.base_filter:
if cls.base_filter and not skip_base_filter:
data_model = SQLAInterface(cls.model_cls, session)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
Expand Down
6 changes: 4 additions & 2 deletions superset/explore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@

def check_dataset_access(dataset_id: int) -> Optional[bool]:
if dataset_id:
dataset = DatasetDAO.find_by_id(dataset_id)
# Access checks below, no need to validate them twice as they can be expensive.
dataset = DatasetDAO.find_by_id(dataset_id, skip_base_filter=True)
if dataset:
can_access_datasource = security_manager.can_access_datasource(dataset)
if can_access_datasource:
Expand All @@ -48,7 +49,8 @@ def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None:
check_dataset_access(dataset_id)
if not chart_id:
return
chart = ChartDAO.find_by_id(chart_id)
# Access checks below, no need to validate them twice as they can be expensive.
chart = ChartDAO.find_by_id(chart_id, skip_base_filter=True)
if chart:
can_access_chart = (
is_user_admin()
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,7 @@ def test_get_dataset_related_objects_not_found(self):
rv = self.client.get(uri)
assert rv.status_code == 404
self.logout()

self.login(username="gamma")
table = self.get_birth_names_dataset()
uri = f"api/v1/dataset/{table.id}/related_objects"
Expand Down
10 changes: 5 additions & 5 deletions tests/integration_tests/explore/form_data/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_post_access_denied(client, chart_id: int, dataset_id: int):
"form_data": INITIAL_FORM_DATA,
}
resp = client.post("api/v1/explore/form_data", json=payload)
assert resp.status_code == 404
assert resp.status_code == 403


def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int):
Expand Down Expand Up @@ -310,7 +310,7 @@ def test_put_access_denied(client, chart_id: int, dataset_id: int):
"form_data": UPDATED_FORM_DATA,
}
resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
assert resp.status_code == 404
assert resp.status_code == 403


def test_put_not_owner(client, chart_id: int, dataset_id: int):
Expand All @@ -321,7 +321,7 @@ def test_put_not_owner(client, chart_id: int, dataset_id: int):
"form_data": UPDATED_FORM_DATA,
}
resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload)
assert resp.status_code == 404
assert resp.status_code == 403


def test_get_key_not_found(client):
Expand All @@ -341,7 +341,7 @@ def test_get(client):
def test_get_access_denied(client):
login(client, "gamma")
resp = client.get(f"api/v1/explore/form_data/{KEY}")
assert resp.status_code == 404
assert resp.status_code == 403


@patch("superset.security.SupersetSecurityManager.can_access_datasource")
Expand All @@ -361,7 +361,7 @@ def test_delete(client):
def test_delete_access_denied(client):
login(client, "gamma")
resp = client.delete(f"api/v1/explore/form_data/{KEY}")
assert resp.status_code == 404
assert resp.status_code == 403


def test_delete_not_owner(client, chart_id: int, dataset_id: int, admin_id: int):
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/explore/permalink/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_post(client, form_data: Dict[str, Any], permalink_salt: str):
def test_post_access_denied(client, form_data):
login(client, "gamma")
resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data})
assert resp.status_code == 404
assert resp.status_code == 403


def test_get_missing_chart(client, chart, permalink_salt: str) -> None:
Expand Down
16 changes: 16 additions & 0 deletions tests/unit_tests/charts/dao/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
67 changes: 67 additions & 0 deletions tests/unit_tests/charts/dao/dao_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Iterator

import pytest
from sqlalchemy.orm.session import Session

from superset.utils.core import DatasourceType


@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
    """
    Yield ``session`` with the schema created and one ``Slice`` row (id=1) added.

    The row is flushed instead of committed: it is visible to queries issued
    through this same session, and the ``rollback()`` performed after the test
    discards it again.
    """
    from superset.models.slice import Slice

    engine = session.get_bind()
    Slice.metadata.create_all(engine)  # pylint: disable=no-member

    slice_obj = Slice(
        id=1,
        datasource_id=1,
        datasource_type=DatasourceType.TABLE,
        datasource_name="tmp_perm_table",
        slice_name="slice_name",
    )

    session.add(slice_obj)
    # A commit here would make the trailing rollback a no-op for this row;
    # flushing keeps the insert inside the open transaction so it can be undone.
    session.flush()
    yield session
    session.rollback()


def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
    """
    ``ChartDAO.find_by_id`` with ``skip_base_filter=True`` returns the ``Slice``
    inserted by the fixture when asked for id 1.
    """
    from superset.charts.dao import ChartDAO
    from superset.models.slice import Slice

    found = ChartDAO.find_by_id(
        1,
        session=session_with_data,
        skip_base_filter=True,
    )

    assert found
    assert found.id == 1
    assert found.slice_name == "slice_name"
    assert isinstance(found, Slice)


def test_datasource_find_by_id_skip_base_filter_not_found(
    session_with_data: Session,
) -> None:
    """
    ``ChartDAO.find_by_id`` with ``skip_base_filter=True`` returns ``None`` for
    an id that the fixture never inserted.
    """
    from superset.charts.dao import ChartDAO

    # The fixture only adds a row with id=1, so this id has no match.
    missing_id = 125326326
    found = ChartDAO.find_by_id(
        missing_id,
        session=session_with_data,
        skip_base_filter=True,
    )
    assert found is None
16 changes: 16 additions & 0 deletions tests/unit_tests/datasets/dao/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
73 changes: 73 additions & 0 deletions tests/unit_tests/datasets/dao/dao_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Iterator

import pytest
from sqlalchemy.orm.session import Session


@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
    """
    Yield ``session`` with the schema created and one ``Database`` plus one
    ``SqlaTable`` flushed into it; the session is rolled back after the test.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database

    SqlaTable.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    dataset = SqlaTable(
        table_name="my_sqla_table",
        columns=[],
        metrics=[],
        database=database,
    )

    for obj in (database, dataset):
        session.add(obj)
    # Flush so the rows are queryable through this session without committing.
    session.flush()
    yield session
    session.rollback()


def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> None:
    """
    ``DatasetDAO.find_by_id`` with ``skip_base_filter=True`` returns the
    ``SqlaTable`` added by the fixture when asked for id 1.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.dao import DatasetDAO

    found = DatasetDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)

    assert found
    assert found.id == 1
    assert found.table_name == "my_sqla_table"
    assert isinstance(found, SqlaTable)


def test_datasource_find_by_id_skip_base_filter_not_found(
    session_with_data: Session,
) -> None:
    """
    ``DatasetDAO.find_by_id`` with ``skip_base_filter=True`` returns ``None``
    for an id that the fixture never inserted.
    """
    from superset.datasets.dao import DatasetDAO

    # The fixture adds a single table, so this id has no match.
    missing_id = 125326326
    found = DatasetDAO.find_by_id(
        missing_id, session=session_with_data, skip_base_filter=True
    )
    assert found is None

0 comments on commit 1e8259a

Please sign in to comment.