Skip to content

Commit

Permalink
fix: cache warmup for non-legacy charts
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed Jul 18, 2023
1 parent aa01b51 commit 74f25f0
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 155 deletions.
2 changes: 1 addition & 1 deletion superset/charts/commands/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


# keys present in the standard export that are not needed
REMOVE_KEYS = ["datasource_type", "datasource_name", "query_context", "url_params"]
REMOVE_KEYS = ["datasource_type", "datasource_name", "url_params"]


class ExportChartsCommand(ExportModelsCommand):
Expand Down
73 changes: 49 additions & 24 deletions superset/charts/commands/warm_up_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@
import simplejson as json
from flask import g

from superset.charts.commands.exceptions import WarmUpCacheChartNotFoundError
from superset.charts.commands.exceptions import (
ChartInvalidError,
WarmUpCacheChartNotFoundError,
)
from superset.charts.data.commands.get_data_command import ChartDataCommand
from superset.commands.base import BaseCommand
from superset.extensions import db
from superset.models.slice import Slice
from superset.utils.core import error_msg_from_exception
from superset.views.utils import get_dashboard_extra_filters, get_form_data, get_viz
from superset.viz import viz_types


class ChartWarmUpCacheCommand(BaseCommand):
Expand All @@ -43,31 +48,51 @@ def __init__(
def run(self) -> dict[str, Any]:
self.validate()
chart: Slice = self._chart_or_id # type: ignore

try:
form_data = get_form_data(chart.id, use_slice_data=True)[0]
if self._dashboard_id:
form_data["extra_filters"] = (
json.loads(self._extra_filters)
if self._extra_filters
else get_dashboard_extra_filters(chart.id, self._dashboard_id)
)

if not chart.datasource:
raise Exception("Chart's datasource does not exist")

obj = get_viz(
datasource_type=chart.datasource.type,
datasource_id=chart.datasource.id,
form_data=form_data,
force=True,
)

# pylint: disable=assigning-non-slot
g.form_data = form_data
payload = obj.get_payload()
delattr(g, "form_data")
error = payload["errors"] or None
status = payload["status"]

if form_data.get("viz_type") in viz_types:
# Legacy visualizations.
if not chart.datasource:
raise ChartInvalidError("Chart's datasource does not exist")

if self._dashboard_id:
form_data["extra_filters"] = (
json.loads(self._extra_filters)
if self._extra_filters
else get_dashboard_extra_filters(chart.id, self._dashboard_id)
)

g.form_data = form_data # pylint: disable=assigning-non-slot
payload = get_viz(
datasource_type=chart.datasource.type,
datasource_id=chart.datasource.id,
form_data=form_data,
force=True,
).get_payload()
delattr(g, "form_data")
error = payload["errors"] or None
status = payload["status"]
else:
# Non-legacy visualizations.
query_context = chart.get_query_context()

if not query_context:
raise ChartInvalidError("Chart's query context does not exist")

query_context.force = True
command = ChartDataCommand(query_context)
command.validate()
payload = command.run()

# Report the first error.
for query in payload["queries"]:
error = query["error"]
status = query["status"]

if error is not None:
break
except Exception as ex: # pylint: disable=broad-except
error = error_msg_from_exception(ex)
status = None
Expand Down
59 changes: 20 additions & 39 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
security_manager,
)
from superset.charts.commands.exceptions import ChartNotFoundError
from superset.charts.commands.warm_up_cache import ChartWarmUpCacheCommand
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.connectors.base.models import BaseDatasource
from superset.connectors.sqla.models import SqlaTable
Expand Down Expand Up @@ -72,6 +73,7 @@
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.cache import etag_cache
from superset.utils.core import (
base_json_conv,
DatasourceType,
get_user_id,
get_username,
Expand All @@ -95,7 +97,6 @@
check_datasource_perms,
check_explore_cache_perms,
check_resource_permissions,
get_dashboard_extra_filters,
get_datasource_info,
get_form_data,
get_viz,
Expand Down Expand Up @@ -769,7 +770,8 @@ def save_or_overwrite_slice(
@api
@has_access_api
@expose("/warm_up_cache/", methods=("GET",))
def warm_up_cache( # pylint: disable=too-many-locals,no-self-use
@deprecated(new_target="api/v1/chart/warm_up_cache/")
def warm_up_cache( # pylint: disable=no-self-use
self,
) -> FlaskResponse:
"""Warms up the cache for the slice or table.
Expand Down Expand Up @@ -825,43 +827,22 @@ def warm_up_cache( # pylint: disable=too-many-locals,no-self-use
.all()
)

result = []

for slc in slices:
try:
form_data = get_form_data(slc.id, use_slice_data=True)[0]
if dashboard_id:
form_data["extra_filters"] = (
json.loads(extra_filters)
if extra_filters
else get_dashboard_extra_filters(slc.id, dashboard_id)
)

if not slc.datasource:
raise Exception("Slice's datasource does not exist")

obj = get_viz(
datasource_type=slc.datasource.type,
datasource_id=slc.datasource.id,
form_data=form_data,
force=True,
)

# pylint: disable=assigning-non-slot
g.form_data = form_data
payload = obj.get_payload()
delattr(g, "form_data")
error = payload["errors"] or None
status = payload["status"]
except Exception as ex: # pylint: disable=broad-except
error = utils.error_msg_from_exception(ex)
status = None

result.append(
{"slice_id": slc.id, "viz_error": error, "viz_status": status}
)

return json_success(json.dumps(result))
return json_success(
json.dumps(
[
{
"slice_id" if key == "chart_id" else key: value
for key, value in ChartWarmUpCacheCommand(
slc, dashboard_id, extra_filters
)
.run()
.items()
}
for slc in slices
],
default=base_json_conv,
),
)

@has_access
@expose("/dashboard/<dashboard_id_or_slug>/")
Expand Down
105 changes: 95 additions & 10 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,41 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
from io import BytesIO
from unittest import mock
from zipfile import is_zipfile, ZipFile

import prison
import pytest
import yaml
from flask_babel import lazy_gettext as _
from parameterized import parameterized
from sqlalchemy import and_
from sqlalchemy.sql import func

from superset.charts.commands.exceptions import ChartDataQueryFailedError
from superset.charts.data.commands.get_data_command import ChartDataCommand
from superset.connectors.sqla.models import SqlaTable
from superset.extensions import cache_manager, db, security_manager
from superset.models.core import Database, FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.reports.models import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.reports.models import ReportSchedule, ReportScheduleType
from superset.utils.core import get_example_default_schema
from superset.utils.database import get_example_database

from tests.integration_tests.conftest import with_feature_flags
from superset.viz import viz_types
from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.conftest import with_feature_flags
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_with_slice,
load_energy_table_data,
load_energy_table_with_slice,
)
from tests.integration_tests.fixtures.importexport import (
chart_config,
Expand Down Expand Up @@ -1710,12 +1714,16 @@ def test_gets_owned_created_favorited_by_me_filter(self):
assert data["result"][0]["slice_name"] == "name0"
assert data["result"][0]["datasource_id"] == 1

@pytest.mark.usefixtures(
"load_energy_table_with_slice", "load_birth_names_dashboard_with_slices"
@parameterized.expand(
[
"Top 10 Girl Name Share", # Legacy chart
"Pivot Table v2", # Non-legacy chart
],
)
def test_warm_up_cache(self):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache(self, slice_name):
self.login()
slc = self.get_slice("Top 10 Girl Name Share", db.session)
slc = self.get_slice(slice_name, db.session)
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id})
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
Expand Down Expand Up @@ -1780,7 +1788,6 @@ def test_warm_up_cache_payload_validation(self):
)
self.assertEqual(rv.status_code, 400)
data = json.loads(rv.data.decode("utf-8"))
print(data)
self.assertEqual(
data,
{
Expand All @@ -1791,3 +1798,81 @@ def test_warm_up_cache_payload_validation(self):
}
},
)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache_error(self) -> None:
self.login()
slc = self.get_slice("Pivot Table v2", db.session)

with mock.patch.object(ChartDataCommand, "run") as mock_run:
mock_run.side_effect = ChartDataQueryFailedError(
_(
"Error: %(error)s",
error=_("Empty query?"),
)
)

assert json.loads(
self.client.put(
"/api/v1/chart/warm_up_cache",
json={"chart_id": slc.id},
).data
) == {
"result": [
{
"chart_id": slc.id,
"viz_error": "Error: Empty query?",
"viz_status": None,
},
],
}

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache_no_query_context(self) -> None:
self.login()
slc = self.get_slice("Pivot Table v2", db.session)

with mock.patch.object(Slice, "get_query_context") as mock_get_query_context:
mock_get_query_context.return_value = None

assert json.loads(
self.client.put(
f"/api/v1/chart/warm_up_cache",
json={"chart_id": slc.id},
).data
) == {
"result": [
{
"chart_id": slc.id,
"viz_error": "Chart's query context does not exist",
"viz_status": None,
},
],
}

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache_no_datasource(self) -> None:
self.login()
slc = self.get_slice("Top 10 Girl Name Share", db.session)

with mock.patch.object(
Slice,
"datasource",
new_callable=mock.PropertyMock,
) as mock_datasource:
mock_datasource.return_value = None

assert json.loads(
self.client.put(
f"/api/v1/chart/warm_up_cache",
json={"chart_id": slc.id},
).data
) == {
"result": [
{
"chart_id": slc.id,
"viz_error": "Chart's datasource does not exist",
"viz_status": None,
},
],
}
Loading

0 comments on commit 74f25f0

Please sign in to comment.