From 32bb1ce3ff3ae9b7e54708463f152326d35ceb98 Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Thu, 2 Jun 2022 16:48:16 -0700 Subject: [PATCH] feat!: pass datasource_type and datasource_id to form_data (#19981) * pass datasource_type and datasource_id to form_data * add datasource_type to delete command * add datasource_type to delete command * fix old keys implementation * add more tests --- superset-frontend/src/constants.ts | 10 + .../components/gridComponents/Chart.jsx | 1 + .../ExploreViewContainer.test.tsx | 4 +- .../components/ExploreViewContainer/index.jsx | 37 +- .../controls/DatasourceControl/index.jsx | 4 +- .../src/explore/exploreUtils/formData.ts | 29 +- superset/cachekeys/schemas.py | 3 +- superset/charts/schemas.py | 7 +- superset/commands/exceptions.py | 18 +- superset/dao/datasource/dao.py | 6 +- superset/databases/dao.py | 4 +- superset/datasets/dao.py | 4 +- superset/examples/birth_names.py | 8 +- superset/examples/country_map.py | 3 +- superset/examples/deck.py | 15 +- superset/examples/energy.py | 7 +- superset/examples/long_lat.py | 3 +- superset/examples/multi_line.py | 3 +- superset/examples/multiformat_time_series.py | 3 +- superset/examples/random_time_series.py | 3 +- superset/examples/world_bank.py | 23 +- superset/explore/form_data/api.py | 12 +- superset/explore/form_data/commands/create.py | 12 +- superset/explore/form_data/commands/delete.py | 11 +- superset/explore/form_data/commands/get.py | 8 +- .../explore/form_data/commands/parameters.py | 5 +- superset/explore/form_data/commands/state.py | 3 +- superset/explore/form_data/commands/update.py | 10 +- superset/explore/form_data/commands/utils.py | 10 +- superset/explore/form_data/schemas.py | 24 +- superset/explore/permalink/commands/create.py | 14 +- superset/explore/permalink/commands/get.py | 8 +- superset/explore/permalink/types.py | 3 +- superset/explore/utils.py | 46 ++- superset/utils/cache_manager.py | 27 +- superset/utils/core.py | 5 +- tests/integration_tests/charts/api_tests.py | 20 +- tests/integration_tests/dashboard_utils.py | 4 +- .../explore/form_data/api_tests.py | 109 +++--- .../explore/form_data/commands_tests.py | 359 ++++++++++++++++++ .../explore/permalink/api_tests.py | 4 +- .../integration_tests/import_export_tests.py | 4 +- tests/integration_tests/model_tests.py | 3 +- tests/integration_tests/security_tests.py | 7 +- .../utils/cache_manager_tests.py | 49 +++ tests/unit_tests/dao/datasource_test.py | 6 +- tests/unit_tests/explore/utils_test.py | 177 ++++++++- 47 files changed, 959 insertions(+), 176 deletions(-) create mode 100644 tests/integration_tests/explore/form_data/commands_tests.py create mode 100644 tests/integration_tests/utils/cache_manager_tests.py diff --git a/superset-frontend/src/constants.ts b/superset-frontend/src/constants.ts index 3d0fd5dd2d59..60668ddcb865 100644 --- a/superset-frontend/src/constants.ts +++ b/superset-frontend/src/constants.ts @@ -67,10 +67,18 @@ export const URL_PARAMS = { name: 'slice_id', type: 'string', }, + datasourceId: { + name: 'datasource_id', + type: 'string', + }, datasetId: { name: 'dataset_id', type: 'string', }, + datasourceType: { + name: 'datasource_type', + type: 'string', + }, dashboardId: { name: 'dashboard_id', type: 'string', @@ -88,6 +96,8 @@ export const URL_PARAMS = { export const RESERVED_CHART_URL_PARAMS: string[] = [ URL_PARAMS.formDataKey.name, URL_PARAMS.sliceId.name, + URL_PARAMS.datasourceId.name, + URL_PARAMS.datasourceType.name, URL_PARAMS.datasetId.name, ]; export const RESERVED_DASHBOARD_URL_PARAMS: 
string[] = [ diff --git a/superset-frontend/src/dashboard/components/gridComponents/Chart.jsx b/superset-frontend/src/dashboard/components/gridComponents/Chart.jsx index 09472060176f..e5d19e931c58 100644 --- a/superset-frontend/src/dashboard/components/gridComponents/Chart.jsx +++ b/superset-frontend/src/dashboard/components/gridComponents/Chart.jsx @@ -272,6 +272,7 @@ export default class Chart extends React.Component { : undefined; const key = await postFormData( this.props.datasource.id, + this.props.datasource.type, this.props.formData, this.props.slice.slice_id, nextTabId, diff --git a/superset-frontend/src/explore/components/ExploreViewContainer/ExploreViewContainer.test.tsx b/superset-frontend/src/explore/components/ExploreViewContainer/ExploreViewContainer.test.tsx index 7743997a3552..2260346968dd 100644 --- a/superset-frontend/src/explore/components/ExploreViewContainer/ExploreViewContainer.test.tsx +++ b/superset-frontend/src/explore/components/ExploreViewContainer/ExploreViewContainer.test.tsx @@ -92,7 +92,7 @@ test('generates a new form_data param when none is available', async () => { expect(replaceState).toHaveBeenCalledWith( expect.anything(), undefined, - expect.stringMatching('dataset_id'), + expect.stringMatching('datasource_id'), ); replaceState.mockRestore(); }); @@ -109,7 +109,7 @@ test('generates a different form_data param when one is provided and is mounting expect(replaceState).toHaveBeenCalledWith( expect.anything(), undefined, - expect.stringMatching('dataset_id'), + expect.stringMatching('datasource_id'), ); replaceState.mockRestore(); }); diff --git a/superset-frontend/src/explore/components/ExploreViewContainer/index.jsx b/superset-frontend/src/explore/components/ExploreViewContainer/index.jsx index 3685023f39f0..e102f2dc970a 100644 --- a/superset-frontend/src/explore/components/ExploreViewContainer/index.jsx +++ b/superset-frontend/src/explore/components/ExploreViewContainer/index.jsx @@ -152,14 +152,24 @@ const ExplorePanelContainer = styled.div` `; const updateHistory = debounce( - async (formData, datasetId, isReplace, standalone, force, title, tabId) => { + async ( + formData, + datasourceId, + datasourceType, + isReplace, + standalone, + force, + title, + tabId, + ) => { const payload = { ...formData }; const chartId = formData.slice_id; const additionalParam = {}; if (chartId) { additionalParam[URL_PARAMS.sliceId.name] = chartId; } else { - additionalParam[URL_PARAMS.datasetId.name] = datasetId; + additionalParam[URL_PARAMS.datasourceId.name] = datasourceId; + additionalParam[URL_PARAMS.datasourceType.name] = datasourceType; } const urlParams = payload?.url_params || {}; @@ -173,11 +183,24 @@ const updateHistory = debounce( let key; let stateModifier; if (isReplace) { - key = await postFormData(datasetId, formData, chartId, tabId); + key = await postFormData( + datasourceId, + datasourceType, + formData, + chartId, + tabId, + ); stateModifier = 'replaceState'; } else { key = getUrlParam(URL_PARAMS.formDataKey); - await putFormData(datasetId, key, formData, chartId, tabId); + await putFormData( + datasourceId, + datasourceType, + key, + formData, + chartId, + tabId, + ); stateModifier = 'pushState'; } const url = mountExploreUrl( @@ -229,11 +252,12 @@ function ExploreViewContainer(props) { dashboardId: props.dashboardId, } : props.form_data; - const datasetId = props.datasource.id; + const { id: datasourceId, type: datasourceType } = props.datasource; updateHistory( formData, - datasetId, + datasourceId, + datasourceType, isReplace, props.standalone, 
props.force, @@ -245,6 +269,7 @@ function ExploreViewContainer(props) { props.dashboardId, props.form_data, props.datasource.id, + props.datasource.type, props.standalone, props.force, tabId, diff --git a/superset-frontend/src/explore/components/controls/DatasourceControl/index.jsx b/superset-frontend/src/explore/components/controls/DatasourceControl/index.jsx index 73aa5e4d913d..3d6ea2fdd266 100644 --- a/superset-frontend/src/explore/components/controls/DatasourceControl/index.jsx +++ b/superset-frontend/src/explore/components/controls/DatasourceControl/index.jsx @@ -189,9 +189,9 @@ class DatasourceControl extends React.PureComponent { const isMissingDatasource = datasource.id == null; let isMissingParams = false; if (isMissingDatasource) { - const datasetId = getUrlParam(URL_PARAMS.datasetId); + const datasourceId = getUrlParam(URL_PARAMS.datasourceId); const sliceId = getUrlParam(URL_PARAMS.sliceId); - if (!datasetId && !sliceId) { + if (!datasourceId && !sliceId) { isMissingParams = true; } } diff --git a/superset-frontend/src/explore/exploreUtils/formData.ts b/superset-frontend/src/explore/exploreUtils/formData.ts index 9987b5d8cfa7..36de6640a5c8 100644 --- a/superset-frontend/src/explore/exploreUtils/formData.ts +++ b/superset-frontend/src/explore/exploreUtils/formData.ts @@ -20,7 +20,8 @@ import { omit } from 'lodash'; import { SupersetClient, JsonObject } from '@superset-ui/core'; type Payload = { - dataset_id: number; + datasource_id: number; + datasource_type: string; form_data: string; chart_id?: number; }; @@ -42,12 +43,14 @@ const assembleEndpoint = (key?: string, tabId?: string) => { }; const assemblePayload = ( - datasetId: number, + datasourceId: number, + datasourceType: string, formData: JsonObject, chartId?: number, ) => { const payload: Payload = { - dataset_id: datasetId, + datasource_id: datasourceId, + datasource_type: datasourceType, form_data: JSON.stringify(sanitizeFormData(formData)), }; if (chartId) { @@ -57,18 +60,25 @@ const assemblePayload = ( }; export const postFormData = ( - datasetId: number, + datasourceId: number, + datasourceType: string, formData: JsonObject, chartId?: number, tabId?: string, ): Promise => SupersetClient.post({ endpoint: assembleEndpoint(undefined, tabId), - jsonPayload: assemblePayload(datasetId, formData, chartId), + jsonPayload: assemblePayload( + datasourceId, + datasourceType, + formData, + chartId, + ), }).then(r => r.json.key); export const putFormData = ( - datasetId: number, + datasourceId: number, + datasourceType: string, key: string, formData: JsonObject, chartId?: number, @@ -76,5 +86,10 @@ export const putFormData = ( ): Promise => SupersetClient.put({ endpoint: assembleEndpoint(key, tabId), - jsonPayload: assemblePayload(datasetId, formData, chartId), + jsonPayload: assemblePayload( + datasourceId, + datasourceType, + formData, + chartId, + ), }).then(r => r.json.message); diff --git a/superset/cachekeys/schemas.py b/superset/cachekeys/schemas.py index a44a7c545add..3d913e8b5f6e 100644 --- a/superset/cachekeys/schemas.py +++ b/superset/cachekeys/schemas.py @@ -22,6 +22,7 @@ datasource_type_description, datasource_uid_description, ) +from superset.utils.core import DatasourceType class Datasource(Schema): @@ -36,7 +37,7 @@ class Datasource(Schema): ) datasource_type = fields.String( description=datasource_type_description, - validate=validate.OneOf(choices=("druid", "table", "view")), + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), required=True, ) diff --git a/superset/charts/schemas.py 
b/superset/charts/schemas.py index 6a05e4d9942b..8a82e364be47 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -31,6 +31,7 @@ from superset.utils import pandas_postprocessing, schema as utils from superset.utils.core import ( AnnotationType, + DatasourceType, FilterOperator, PostProcessingBoxplotWhiskerType, PostProcessingContributionOrientation, @@ -198,7 +199,7 @@ class ChartPostSchema(Schema): datasource_id = fields.Integer(description=datasource_id_description, required=True) datasource_type = fields.String( description=datasource_type_description, - validate=validate.OneOf(choices=("druid", "table", "view")), + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), required=True, ) datasource_name = fields.String( @@ -244,7 +245,7 @@ class ChartPutSchema(Schema): ) datasource_type = fields.String( description=datasource_type_description, - validate=validate.OneOf(choices=("druid", "table", "view")), + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), allow_none=True, ) dashboards = fields.List(fields.Integer(description=dashboards_description)) @@ -983,7 +984,7 @@ class ChartDataDatasourceSchema(Schema): ) type = fields.String( description="Datasource type", - validate=validate.OneOf(choices=("druid", "table")), + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), ) diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index 2a60318b46e0..a661ef4d6047 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -115,8 +115,24 @@ def __init__(self) -> None: super().__init__([_("Some roles do not exist")], field_name="roles") +class DatasourceTypeInvalidError(ValidationError): + status = 422 + + def __init__(self) -> None: + super().__init__( + [_("Datasource type is invalid")], field_name="datasource_type" + ) + + class DatasourceNotFoundValidationError(ValidationError): status = 404 def __init__(self) -> None: - super().__init__([_("Dataset does not exist")], field_name="datasource_id") + super().__init__([_("Datasource does not exist")], field_name="datasource_id") + + +class QueryNotFoundValidationError(ValidationError): + status = 404 + + def __init__(self) -> None: + super().__init__([_("Query does not exist")], field_name="datasource_id") diff --git a/superset/dao/datasource/dao.py b/superset/dao/datasource/dao.py index 8b4845db3c51..caa45564aa25 100644 --- a/superset/dao/datasource/dao.py +++ b/superset/dao/datasource/dao.py @@ -39,11 +39,11 @@ class DatasourceDAO(BaseDAO): sources: Dict[DatasourceType, Type[Datasource]] = { - DatasourceType.SQLATABLE: SqlaTable, + DatasourceType.TABLE: SqlaTable, DatasourceType.QUERY: Query, DatasourceType.SAVEDQUERY: SavedQuery, DatasourceType.DATASET: Dataset, - DatasourceType.TABLE: Table, + DatasourceType.SLTABLE: Table, } @classmethod @@ -66,7 +66,7 @@ def get_datasource( @classmethod def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]: - source_class = DatasourceDAO.sources[DatasourceType.SQLATABLE] + source_class = DatasourceDAO.sources[DatasourceType.TABLE] qry = session.query(source_class) qry = source_class.default_query(qry) return qry.all() diff --git a/superset/databases/dao.py b/superset/databases/dao.py index 5e47772cfc63..892ab86ed21d 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -24,6 +24,7 @@ from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import TabState +from 
superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -75,7 +76,8 @@ def get_related_objects(cls, database_id: int) -> Dict[str, Any]: charts = ( db.session.query(Slice) .filter( - Slice.datasource_id.in_(dataset_ids), Slice.datasource_type == "table" + Slice.datasource_id.in_(dataset_ids), + Slice.datasource_type == DatasourceType.TABLE, ) .all() ) diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index 89460f3b43c6..44ab8efa0ce5 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -26,6 +26,7 @@ from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.utils.core import DatasourceType from superset.views.base import DatasourceFilter logger = logging.getLogger(__name__) @@ -56,7 +57,8 @@ def get_related_objects(database_id: int) -> Dict[str, Any]: charts = ( db.session.query(Slice) .filter( - Slice.datasource_id == database_id, Slice.datasource_type == "table" + Slice.datasource_id == database_id, + Slice.datasource_type == DatasourceType.TABLE, ) .all() ) diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index de0018ce7cd9..6b37fe9d08dc 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -29,6 +29,7 @@ from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.utils.core import DatasourceType from ..utils.database import get_example_database from .helpers import ( @@ -205,13 +206,16 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[ if admin_owner: slice_props = dict( datasource_id=tbl.id, - datasource_type="table", + datasource_type=DatasourceType.TABLE, owners=[admin], created_by=admin, ) else: slice_props = dict( - datasource_id=tbl.id, datasource_type="table", owners=[], created_by=admin + datasource_id=tbl.id, + datasource_type=DatasourceType.TABLE, + owners=[], + created_by=admin, ) print("Creating some slices") diff --git a/superset/examples/country_map.py b/superset/examples/country_map.py index 049de6650c44..c959a92085fc 100644 --- a/superset/examples/country_map.py +++ b/superset/examples/country_map.py @@ -24,6 +24,7 @@ from superset import db from superset.connectors.sqla.models import SqlMetric from superset.models.slice import Slice +from superset.utils.core import DatasourceType from .helpers import ( get_example_data, @@ -112,7 +113,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N slc = Slice( slice_name="Birth in France by department in 2016", viz_type="country_map", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json(slice_data), ) diff --git a/superset/examples/deck.py b/superset/examples/deck.py index f6c7a8c6996c..418ed9d28ba1 100644 --- a/superset/examples/deck.py +++ b/superset/examples/deck.py @@ -19,6 +19,7 @@ from superset import db from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.utils.core import DatasourceType from .helpers import ( get_slice_json, @@ -213,7 +214,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements slc = Slice( slice_name="Deck.gl Scatterplot", viz_type="deck_scatter", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json(slice_data), ) @@ -248,7 +249,7 @@ def load_deck_dash() -> None: # 
pylint: disable=too-many-statements slc = Slice( slice_name="Deck.gl Screen grid", viz_type="deck_screengrid", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json(slice_data), ) @@ -284,7 +285,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements slc = Slice( slice_name="Deck.gl Hexagons", viz_type="deck_hex", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json(slice_data), ) @@ -321,7 +322,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements slc = Slice( slice_name="Deck.gl Grid", viz_type="deck_grid", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json(slice_data), ) @@ -410,7 +411,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements slc = Slice( slice_name="Deck.gl Polygons", viz_type="deck_polygon", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=polygon_tbl.id, params=get_slice_json(slice_data), ) @@ -460,7 +461,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements slc = Slice( slice_name="Deck.gl Arcs", viz_type="deck_arc", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=db.session.query(table) .filter_by(table_name="flights") .first() @@ -512,7 +513,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements slc = Slice( slice_name="Deck.gl Path", viz_type="deck_path", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=db.session.query(table) .filter_by(table_name="bart_lines") .first() diff --git a/superset/examples/energy.py b/superset/examples/energy.py index 137d7fe73501..d88d693651d4 100644 --- a/superset/examples/energy.py +++ b/superset/examples/energy.py @@ -25,6 +25,7 @@ from superset import db from superset.connectors.sqla.models import SqlMetric from superset.models.slice import Slice +from superset.utils.core import DatasourceType from .helpers import ( get_example_data, @@ -81,7 +82,7 @@ def load_energy( slc = Slice( slice_name="Energy Sankey", viz_type="sankey", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=textwrap.dedent( """\ @@ -105,7 +106,7 @@ def load_energy( slc = Slice( slice_name="Energy Force Layout", viz_type="graph_chart", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=textwrap.dedent( """\ @@ -129,7 +130,7 @@ def load_energy( slc = Slice( slice_name="Heatmap", viz_type="heatmap", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=textwrap.dedent( """\ diff --git a/superset/examples/long_lat.py b/superset/examples/long_lat.py index 4245be1057fe..ba9824bb43fe 100644 --- a/superset/examples/long_lat.py +++ b/superset/examples/long_lat.py @@ -24,6 +24,7 @@ import superset.utils.database as database_utils from superset import db from superset.models.slice import Slice +from superset.utils.core import DatasourceType from .helpers import ( get_example_data, @@ -113,7 +114,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None slc = Slice( slice_name="Mapbox Long/Lat", viz_type="mapbox", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json(slice_data), ) diff --git a/superset/examples/multi_line.py b/superset/examples/multi_line.py index 1887fd09069e..6ca023cdcf16 100644 --- 
a/superset/examples/multi_line.py +++ b/superset/examples/multi_line.py @@ -18,6 +18,7 @@ from superset import db from superset.models.slice import Slice +from superset.utils.core import DatasourceType from .birth_names import load_birth_names from .helpers import merge_slice, misc_dash_slices @@ -35,7 +36,7 @@ def load_multi_line(only_metadata: bool = False) -> None: ] slc = Slice( - datasource_type="table", # not true, but needed + datasource_type=DatasourceType.TABLE, # not true, but needed datasource_id=1, # cannot be empty slice_name="Multi Line", viz_type="line_multi", diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index 1e9ee497db6d..9b8bb22c98e8 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -21,6 +21,7 @@ from superset import app, db from superset.models.slice import Slice +from superset.utils.core import DatasourceType from ..utils.database import get_example_database from .helpers import ( @@ -120,7 +121,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals slc = Slice( slice_name=f"Calendar Heatmap multiformat {i}", viz_type="cal_heatmap", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json(slice_data), ) diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py index 0f39b95bd16c..152b63e1cc32 100644 --- a/superset/examples/random_time_series.py +++ b/superset/examples/random_time_series.py @@ -21,6 +21,7 @@ import superset.utils.database as database_utils from superset import app, db from superset.models.slice import Slice +from superset.utils.core import DatasourceType from .helpers import ( get_example_data, @@ -89,7 +90,7 @@ def load_random_time_series_data( slc = Slice( slice_name="Calendar Heatmap", viz_type="cal_heatmap", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json(slice_data), ) diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 421818724f83..39b982aa5246 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -29,6 +29,7 @@ from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.utils import core as utils +from superset.utils.core import DatasourceType from ..connectors.base.models import BaseDatasource from .helpers import ( @@ -172,7 +173,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="Region Filter", viz_type="filter_box", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -201,7 +202,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="World's Population", viz_type="big_number", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -215,7 +216,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="Most Populated Countries", viz_type="table", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -227,7 +228,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="Growth Rate", viz_type="line", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -241,7 +242,7 @@ def 
create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="% Rural", viz_type="world_map", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -254,7 +255,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="Life Expectancy VS Rural %", viz_type="bubble", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -298,7 +299,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="Rural Breakdown", viz_type="sunburst", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -313,7 +314,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="World's Pop Growth", viz_type="area", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -327,7 +328,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="Box plot", viz_type="box_plot", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -343,7 +344,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="Treemap", viz_type="treemap", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, @@ -357,7 +358,7 @@ def create_slices(tbl: BaseDatasource) -> List[Slice]: Slice( slice_name="Parallel Coordinates", viz_type="para", - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=tbl.id, params=get_slice_json( defaults, diff --git a/superset/explore/form_data/api.py b/superset/explore/form_data/api.py index ea2b38658bb6..00c8730ee411 100644 --- a/superset/explore/form_data/api.py +++ b/superset/explore/form_data/api.py @@ -104,7 +104,8 @@ def post(self) -> Response: tab_id = request.args.get("tab_id") args = CommandParameters( actor=g.user, - dataset_id=item["dataset_id"], + datasource_id=item["datasource_id"], + datasource_type=item["datasource_type"], chart_id=item.get("chart_id"), tab_id=tab_id, form_data=item["form_data"], @@ -123,7 +124,7 @@ def post(self) -> Response: @safe @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", - log_to_statsd=False, + log_to_statsd=True, ) @requires_json def put(self, key: str) -> Response: @@ -174,7 +175,8 @@ def put(self, key: str) -> Response: tab_id = request.args.get("tab_id") args = CommandParameters( actor=g.user, - dataset_id=item["dataset_id"], + datasource_id=item["datasource_id"], + datasource_type=item["datasource_type"], chart_id=item.get("chart_id"), tab_id=tab_id, key=key, @@ -196,7 +198,7 @@ def put(self, key: str) -> Response: @safe @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get", - log_to_statsd=False, + log_to_statsd=True, ) def get(self, key: str) -> Response: """Retrives a form_data. @@ -247,7 +249,7 @@ def get(self, key: str) -> Response: @safe @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.delete", - log_to_statsd=False, + log_to_statsd=True, ) def delete(self, key: str) -> Response: """Deletes a form_data. 
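For reference, the form_data endpoint changes above replace the old dataset_id field with a datasource_id / datasource_type pair. A minimal client-side sketch of the new contract (the base URL, the ids, and the pre-authenticated requests.Session are illustrative assumptions, not part of this patch):

import json
import requests

session = requests.Session()  # assumed to already carry Superset auth/CSRF cookies
payload = {
    "datasource_id": 1,          # id of the table or query backing the chart
    "datasource_type": "table",  # a DatasourceType value, e.g. "table" or "query"
    "chart_id": 42,              # optional; omit for an unsaved chart
    "form_data": json.dumps({"datasource": "1__table", "viz_type": "table"}),
}
resp = session.post("http://localhost:8088/api/v1/explore/form_data", json=payload)
key = resp.json()["key"]         # temporary form_data key returned on 201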
diff --git a/superset/explore/form_data/commands/create.py b/superset/explore/form_data/commands/create.py index 84ad0f007850..7946980c8268 100644 --- a/superset/explore/form_data/commands/create.py +++ b/superset/explore/form_data/commands/create.py @@ -39,20 +39,24 @@ def __init__(self, cmd_params: CommandParameters): def run(self) -> str: self.validate() try: - dataset_id = self._cmd_params.dataset_id + datasource_id = self._cmd_params.datasource_id + datasource_type = self._cmd_params.datasource_type chart_id = self._cmd_params.chart_id tab_id = self._cmd_params.tab_id actor = self._cmd_params.actor form_data = self._cmd_params.form_data - check_access(dataset_id, chart_id, actor) - contextual_key = cache_key(session.get("_id"), tab_id, dataset_id, chart_id) + check_access(datasource_id, chart_id, actor, datasource_type) + contextual_key = cache_key( + session.get("_id"), tab_id, datasource_id, chart_id, datasource_type + ) key = cache_manager.explore_form_data_cache.get(contextual_key) if not key or not tab_id: key = random_key() if form_data: state: TemporaryExploreState = { "owner": get_owner(actor), - "dataset_id": dataset_id, + "datasource_id": datasource_id, + "datasource_type": datasource_type, "chart_id": chart_id, "form_data": form_data, } diff --git a/superset/explore/form_data/commands/delete.py b/superset/explore/form_data/commands/delete.py index 69d5a79b8f01..598ece3f080f 100644 --- a/superset/explore/form_data/commands/delete.py +++ b/superset/explore/form_data/commands/delete.py @@ -16,6 +16,7 @@ # under the License. import logging from abc import ABC +from typing import Optional from flask import session from sqlalchemy.exc import SQLAlchemyError @@ -31,6 +32,7 @@ TemporaryCacheDeleteFailedError, ) from superset.temporary_cache.utils import cache_key +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -47,14 +49,15 @@ def run(self) -> bool: key ) if state: - dataset_id = state["dataset_id"] - chart_id = state["chart_id"] - check_access(dataset_id, chart_id, actor) + datasource_id: int = state["datasource_id"] + chart_id: Optional[int] = state["chart_id"] + datasource_type = DatasourceType(state["datasource_type"]) + check_access(datasource_id, chart_id, actor, datasource_type) if state["owner"] != get_owner(actor): raise TemporaryCacheAccessDeniedError() tab_id = self._cmd_params.tab_id contextual_key = cache_key( - session.get("_id"), tab_id, dataset_id, chart_id + session.get("_id"), tab_id, datasource_id, chart_id, datasource_type ) cache_manager.explore_form_data_cache.delete(contextual_key) return cache_manager.explore_form_data_cache.delete(key) diff --git a/superset/explore/form_data/commands/get.py b/superset/explore/form_data/commands/get.py index 809672f32f98..982c8e3b4b7d 100644 --- a/superset/explore/form_data/commands/get.py +++ b/superset/explore/form_data/commands/get.py @@ -27,6 +27,7 @@ from superset.explore.form_data.commands.utils import check_access from superset.extensions import cache_manager from superset.temporary_cache.commands.exceptions import TemporaryCacheGetFailedError +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -45,7 +46,12 @@ def run(self) -> Optional[str]: key ) if state: - check_access(state["dataset_id"], state["chart_id"], actor) + check_access( + state["datasource_id"], + state["chart_id"], + actor, + DatasourceType(state["datasource_type"]), + ) if self._refresh_timeout: cache_manager.explore_form_data_cache.set(key, state) return state["form_data"] 
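The create, delete, and get commands above now store and validate a richer cache entry. A rough before/after sketch of the TemporaryExploreState payload kept in explore_form_data_cache (owner and id values are illustrative):

# old entry: only the backing dataset was recorded
old_state = {
    "owner": 1,
    "dataset_id": 10,
    "chart_id": 42,
    "form_data": '{"datasource": "10__table"}',
}

# new entry: datasource_id plus an explicit datasource_type,
# so query-backed charts can be represented as well
new_state = {
    "owner": 1,
    "datasource_id": 10,
    "datasource_type": "table",
    "chart_id": 42,
    "form_data": '{"datasource": "10__table"}',
}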
diff --git a/superset/explore/form_data/commands/parameters.py b/superset/explore/form_data/commands/parameters.py index 3e830810b500..fec06a581fb7 100644 --- a/superset/explore/form_data/commands/parameters.py +++ b/superset/explore/form_data/commands/parameters.py @@ -19,11 +19,14 @@ from flask_appbuilder.security.sqla.models import User +from superset.utils.core import DatasourceType + @dataclass class CommandParameters: actor: User - dataset_id: int = 0 + datasource_type: DatasourceType = DatasourceType.TABLE + datasource_id: int = 0 chart_id: int = 0 tab_id: Optional[int] = None key: Optional[str] = None diff --git a/superset/explore/form_data/commands/state.py b/superset/explore/form_data/commands/state.py index c8061e81f5a7..470f2e22f598 100644 --- a/superset/explore/form_data/commands/state.py +++ b/superset/explore/form_data/commands/state.py @@ -21,6 +21,7 @@ class TemporaryExploreState(TypedDict): owner: Optional[int] - dataset_id: int + datasource_id: int + datasource_type: str chart_id: Optional[int] form_data: str diff --git a/superset/explore/form_data/commands/update.py b/superset/explore/form_data/commands/update.py index 76dfee1dadef..fdc75093bef8 100644 --- a/superset/explore/form_data/commands/update.py +++ b/superset/explore/form_data/commands/update.py @@ -47,12 +47,13 @@ def __init__( def run(self) -> Optional[str]: self.validate() try: - dataset_id = self._cmd_params.dataset_id + datasource_id = self._cmd_params.datasource_id chart_id = self._cmd_params.chart_id + datasource_type = self._cmd_params.datasource_type actor = self._cmd_params.actor key = self._cmd_params.key form_data = self._cmd_params.form_data - check_access(dataset_id, chart_id, actor) + check_access(datasource_id, chart_id, actor, datasource_type) state: TemporaryExploreState = cache_manager.explore_form_data_cache.get( key ) @@ -64,7 +65,7 @@ def run(self) -> Optional[str]: # Generate a new key if tab_id changes or equals 0 tab_id = self._cmd_params.tab_id contextual_key = cache_key( - session.get("_id"), tab_id, dataset_id, chart_id + session.get("_id"), tab_id, datasource_id, chart_id, datasource_type ) key = cache_manager.explore_form_data_cache.get(contextual_key) if not key or not tab_id: @@ -73,7 +74,8 @@ def run(self) -> Optional[str]: new_state: TemporaryExploreState = { "owner": owner, - "dataset_id": dataset_id, + "datasource_id": datasource_id, + "datasource_type": datasource_type, "chart_id": chart_id, "form_data": form_data, } diff --git a/superset/explore/form_data/commands/utils.py b/superset/explore/form_data/commands/utils.py index 5d09657fbec4..7927457178c9 100644 --- a/superset/explore/form_data/commands/utils.py +++ b/superset/explore/form_data/commands/utils.py @@ -31,11 +31,17 @@ TemporaryCacheAccessDeniedError, TemporaryCacheResourceNotFoundError, ) +from superset.utils.core import DatasourceType -def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None: +def check_access( + datasource_id: int, + chart_id: Optional[int], + actor: User, + datasource_type: DatasourceType, +) -> None: try: - explore_check_access(dataset_id, chart_id, actor) + explore_check_access(datasource_id, chart_id, actor, datasource_type) except (ChartNotFoundError, DatasetNotFoundError) as ex: raise TemporaryCacheResourceNotFoundError from ex except (ChartAccessDeniedError, DatasetAccessDeniedError) as ex: diff --git a/superset/explore/form_data/schemas.py b/superset/explore/form_data/schemas.py index 6d5509d777a3..192df089e818 100644 --- a/superset/explore/form_data/schemas.py +++ 
b/superset/explore/form_data/schemas.py @@ -14,12 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from marshmallow import fields, Schema +from marshmallow import fields, Schema, validate + +from superset.utils.core import DatasourceType class FormDataPostSchema(Schema): - dataset_id = fields.Integer( - required=True, allow_none=False, description="The dataset ID" + datasource_id = fields.Integer( + required=True, allow_none=False, description="The datasource ID" + ) + datasource_type = fields.String( + required=True, + allow_none=False, + description="The datasource type", + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), ) chart_id = fields.Integer(required=False, description="The chart ID") form_data = fields.String( @@ -28,8 +36,14 @@ class FormDataPostSchema(Schema): class FormDataPutSchema(Schema): - dataset_id = fields.Integer( - required=True, allow_none=False, description="The dataset ID" + datasource_id = fields.Integer( + required=True, allow_none=False, description="The datasource ID" + ) + datasource_type = fields.String( + required=True, + allow_none=False, + description="The datasource type", + validate=validate.OneOf(choices=[ds.value for ds in DatasourceType]), ) chart_id = fields.Integer(required=False, description="The chart ID") form_data = fields.String( diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index c09ca3b37212..7bd6365d814b 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -22,9 +22,10 @@ from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError -from superset.explore.utils import check_access +from superset.explore.utils import check_access as check_chart_access from superset.key_value.commands.create import CreateKeyValueCommand from superset.key_value.utils import encode_permalink_key +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -39,11 +40,16 @@ def __init__(self, actor: User, state: Dict[str, Any]): def run(self) -> str: self.validate() try: - dataset_id = int(self.datasource.split("__")[0]) - check_access(dataset_id, self.chart_id, self.actor) + d_id, d_type = self.datasource.split("__") + datasource_id = int(d_id) + datasource_type = DatasourceType(d_type) + check_chart_access( + datasource_id, self.chart_id, self.actor, datasource_type + ) value = { "chartId": self.chart_id, - "datasetId": dataset_id, + "datasourceId": datasource_id, + "datasourceType": datasource_type, "datasource": self.datasource, "state": self.state, } diff --git a/superset/explore/permalink/commands/get.py b/superset/explore/permalink/commands/get.py index 1e3ea1fdc6f9..f75df69d7a63 100644 --- a/superset/explore/permalink/commands/get.py +++ b/superset/explore/permalink/commands/get.py @@ -24,10 +24,11 @@ from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError from superset.explore.permalink.types import ExplorePermalinkValue -from superset.explore.utils import check_access +from superset.explore.utils import check_access as check_chart_access from superset.key_value.commands.get import GetKeyValueCommand from superset.key_value.exceptions import KeyValueGetFailedError, 
KeyValueParseKeyError from superset.key_value.utils import decode_permalink_id +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -47,8 +48,9 @@ def run(self) -> Optional[ExplorePermalinkValue]: ).run() if value: chart_id: Optional[int] = value.get("chartId") - dataset_id = value["datasetId"] - check_access(dataset_id, chart_id, self.actor) + datasource_id: int = value["datasourceId"] + datasource_type = DatasourceType(value["datasourceType"]) + check_chart_access(datasource_id, chart_id, self.actor, datasource_type) return value return None except ( diff --git a/superset/explore/permalink/types.py b/superset/explore/permalink/types.py index b396e335104b..b90b4d760d4d 100644 --- a/superset/explore/permalink/types.py +++ b/superset/explore/permalink/types.py @@ -24,6 +24,7 @@ class ExplorePermalinkState(TypedDict, total=False): class ExplorePermalinkValue(TypedDict): chartId: Optional[int] - datasetId: int + datasourceId: int + datasourceType: str datasource: str state: ExplorePermalinkState diff --git a/superset/explore/utils.py b/superset/explore/utils.py index 7ab29de2f70e..f0bfd8f0aa40 100644 --- a/superset/explore/utils.py +++ b/superset/explore/utils.py @@ -24,11 +24,18 @@ ChartNotFoundError, ) from superset.charts.dao import ChartDAO +from superset.commands.exceptions import ( + DatasourceNotFoundValidationError, + DatasourceTypeInvalidError, + QueryNotFoundValidationError, +) from superset.datasets.commands.exceptions import ( DatasetAccessDeniedError, DatasetNotFoundError, ) from superset.datasets.dao import DatasetDAO +from superset.queries.dao import QueryDAO +from superset.utils.core import DatasourceType from superset.views.base import is_user_admin from superset.views.utils import is_owner @@ -44,10 +51,41 @@ def check_dataset_access(dataset_id: int) -> Optional[bool]: raise DatasetNotFoundError() -def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None: - check_dataset_access(dataset_id) +def check_query_access(query_id: int) -> Optional[bool]: + if query_id: + query = QueryDAO.find_by_id(query_id) + if query: + security_manager.raise_for_access(query=query) + return True + raise QueryNotFoundValidationError() + + +ACCESS_FUNCTION_MAP = { + DatasourceType.TABLE: check_dataset_access, + DatasourceType.QUERY: check_query_access, +} + + +def check_datasource_access( + datasource_id: int, datasource_type: DatasourceType +) -> Optional[bool]: + if datasource_id: + try: + return ACCESS_FUNCTION_MAP[datasource_type](datasource_id) + except KeyError as ex: + raise DatasourceTypeInvalidError() from ex + raise DatasourceNotFoundValidationError() + + +def check_access( + datasource_id: int, + chart_id: Optional[int], + actor: User, + datasource_type: DatasourceType, +) -> Optional[bool]: + check_datasource_access(datasource_id, datasource_type) if not chart_id: - return + return True chart = ChartDAO.find_by_id(chart_id) if chart: can_access_chart = ( @@ -56,6 +94,6 @@ def check_access(dataset_id: int, chart_id: Optional[int], actor: User) -> None: or security_manager.can_access("can_read", "Chart") ) if can_access_chart: - return + return True raise ChartAccessDeniedError() raise ChartNotFoundError() diff --git a/superset/utils/cache_manager.py b/superset/utils/cache_manager.py index 3f071b15435b..d3b2dbdb00d5 100644 --- a/superset/utils/cache_manager.py +++ b/superset/utils/cache_manager.py @@ -15,15 +15,40 @@ # specific language governing permissions and limitations # under the License. 
import logging +from typing import Any, Optional, Union from flask import Flask from flask_caching import Cache +from markupsafe import Markup + +from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) CACHE_IMPORT_PATH = "superset.extensions.metastore_cache.SupersetMetastoreCache" +class ExploreFormDataCache(Cache): + def get(self, *args: Any, **kwargs: Any) -> Optional[Union[str, Markup]]: + cache = self.cache.get(*args, **kwargs) + + if not cache: + return None + + # rename data keys for existing cache based on new TemporaryExploreState model + if isinstance(cache, dict): + cache = { + ("datasource_id" if key == "dataset_id" else key): value + for (key, value) in cache.items() + } + # add default datasource_type if it doesn't exist + # temporarily defaulting to table until sqlatables are deprecated + if "datasource_type" not in cache: + cache["datasource_type"] = DatasourceType.TABLE + + return cache + + class CacheManager: def __init__(self) -> None: super().__init__() @@ -32,7 +57,7 @@ def __init__(self) -> None: self._data_cache = Cache() self._thumbnail_cache = Cache() self._filter_state_cache = Cache() - self._explore_form_data_cache = Cache() + self._explore_form_data_cache = ExploreFormDataCache() @staticmethod def _init_cache( diff --git a/superset/utils/core.py b/superset/utils/core.py index 4a9992d2a29f..6c90837959ed 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -175,12 +175,13 @@ class GenericDataType(IntEnum): # ROW = 7 -class DatasourceType(Enum): - SQLATABLE = "sqlatable" +class DatasourceType(str, Enum): + SLTABLE = "sl_table" TABLE = "table" DATASET = "dataset" QUERY = "query" SAVEDQUERY = "saved_query" + VIEW = "view" class DatasourceDict(TypedDict): diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 6b8d625d567e..a37acf6eafc3 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -520,7 +520,13 @@ def test_create_chart_validate_datasource(self): response = json.loads(rv.data.decode("utf-8")) self.assertEqual( response, - {"message": {"datasource_type": ["Must be one of: druid, table, view."]}}, + { + "message": { + "datasource_type": [ + "Must be one of: sl_table, table, dataset, query, saved_query, view." + ] + } + }, ) chart_data = { "slice_name": "title1", @@ -531,7 +537,7 @@ def test_create_chart_validate_datasource(self): self.assertEqual(rv.status_code, 422) response = json.loads(rv.data.decode("utf-8")) self.assertEqual( - response, {"message": {"datasource_id": ["Dataset does not exist"]}} + response, {"message": {"datasource_id": ["Datasource does not exist"]}} ) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -686,7 +692,13 @@ def test_update_chart_validate_datasource(self): response = json.loads(rv.data.decode("utf-8")) self.assertEqual( response, - {"message": {"datasource_type": ["Must be one of: druid, table, view."]}}, + { + "message": { + "datasource_type": [ + "Must be one of: sl_table, table, dataset, query, saved_query, view." 
+ ] + } + }, ) chart_data = {"datasource_id": 0, "datasource_type": "table"} @@ -694,7 +706,7 @@ def test_update_chart_validate_datasource(self): self.assertEqual(rv.status_code, 422) response = json.loads(rv.data.decode("utf-8")) self.assertEqual( - response, {"message": {"datasource_id": ["Dataset does not exist"]}} + response, {"message": {"datasource_id": ["Datasource does not exist"]}} ) db.session.delete(chart) diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py index fa6efd60b4da..41a34fa36edf 100644 --- a/tests/integration_tests/dashboard_utils.py +++ b/tests/integration_tests/dashboard_utils.py @@ -26,7 +26,7 @@ from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice -from superset.utils.core import get_example_default_schema +from superset.utils.core import DatasourceType, get_example_default_schema def get_table( @@ -72,7 +72,7 @@ def create_slice( return Slice( slice_name=title, viz_type=viz_type, - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_id=table.id, params=json.dumps(slices_dict, indent=4, sort_keys=True), ) diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py index c05be00e9618..8b375df56ae3 100644 --- a/tests/integration_tests/explore/form_data/api_tests.py +++ b/tests/integration_tests/explore/form_data/api_tests.py @@ -56,7 +56,7 @@ def admin_id() -> int: @pytest.fixture -def dataset_id() -> int: +def datasource() -> int: with app.app_context() as ctx: session: Session = ctx.app.appbuilder.get_session dataset = ( @@ -64,24 +64,26 @@ def dataset_id() -> int: .filter_by(table_name="wb_health_population") .first() ) - return dataset.id + return dataset @pytest.fixture(autouse=True) -def cache(chart_id, admin_id, dataset_id): +def cache(chart_id, admin_id, datasource): entry: TemporaryExploreState = { "owner": admin_id, - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } cache_manager.explore_form_data_cache.set(KEY, entry) -def test_post(client, chart_id: int, dataset_id: int): +def test_post(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } @@ -89,10 +91,11 @@ def test_post(client, chart_id: int, dataset_id: int): assert resp.status_code == 201 -def test_post_bad_request_non_string(client, chart_id: int, dataset_id: int): +def test_post_bad_request_non_string(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": 1234, } @@ -100,10 +103,11 @@ def test_post_bad_request_non_string(client, chart_id: int, dataset_id: int): assert resp.status_code == 400 -def test_post_bad_request_non_json_string(client, chart_id: int, dataset_id: int): +def test_post_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": "foo", } @@ -111,10 +115,11 @@ def test_post_bad_request_non_json_string(client, chart_id: int, dataset_id: int 
assert resp.status_code == 400 -def test_post_access_denied(client, chart_id: int, dataset_id: int): +def test_post_access_denied(client, chart_id: int, datasource: SqlaTable): login(client, "gamma") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } @@ -122,10 +127,11 @@ def test_post_access_denied(client, chart_id: int, dataset_id: int): assert resp.status_code == 404 -def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int): +def test_post_same_key_for_same_context(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -139,11 +145,12 @@ def test_post_same_key_for_same_context(client, chart_id: int, dataset_id: int): def test_post_different_key_for_different_context( - client, chart_id: int, dataset_id: int + client, chart_id: int, datasource: SqlaTable ): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -151,7 +158,8 @@ def test_post_different_key_for_different_context( data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "form_data": json.dumps({"test": "initial value"}), } resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) @@ -160,10 +168,11 @@ def test_post_different_key_for_different_context( assert first_key != second_key -def test_post_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int): +def test_post_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": json.dumps({"test": "initial value"}), } @@ -177,11 +186,12 @@ def test_post_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int): def test_post_different_key_for_different_tab_id( - client, chart_id: int, dataset_id: int + client, chart_id: int, datasource: SqlaTable ): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": json.dumps({"test": "initial value"}), } @@ -194,10 +204,11 @@ def test_post_different_key_for_different_tab_id( assert first_key != second_key -def test_post_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int): +def test_post_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } @@ -210,10 +221,11 @@ def test_post_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int assert first_key != second_key -def test_put(client, chart_id: int, dataset_id: int): +def test_put(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -221,10 +233,11 @@ def 
test_put(client, chart_id: int, dataset_id: int): assert resp.status_code == 200 -def test_put_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int): +def test_put_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -237,10 +250,13 @@ def test_put_same_key_for_same_tab_id(client, chart_id: int, dataset_id: int): assert first_key == second_key -def test_put_different_key_for_different_tab_id(client, chart_id: int, dataset_id: int): +def test_put_different_key_for_different_tab_id( + client, chart_id: int, datasource: SqlaTable +): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -253,10 +269,11 @@ def test_put_different_key_for_different_tab_id(client, chart_id: int, dataset_i assert first_key != second_key -def test_put_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int): +def test_put_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -269,10 +286,11 @@ def test_put_different_key_for_no_tab_id(client, chart_id: int, dataset_id: int) assert first_key != second_key -def test_put_bad_request(client, chart_id: int, dataset_id: int): +def test_put_bad_request(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": 1234, } @@ -280,10 +298,11 @@ def test_put_bad_request(client, chart_id: int, dataset_id: int): assert resp.status_code == 400 -def test_put_bad_request_non_string(client, chart_id: int, dataset_id: int): +def test_put_bad_request_non_string(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": 1234, } @@ -291,10 +310,11 @@ def test_put_bad_request_non_string(client, chart_id: int, dataset_id: int): assert resp.status_code == 400 -def test_put_bad_request_non_json_string(client, chart_id: int, dataset_id: int): +def test_put_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable): login(client, "admin") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": "foo", } @@ -302,10 +322,11 @@ def test_put_bad_request_non_json_string(client, chart_id: int, dataset_id: int) assert resp.status_code == 400 -def test_put_access_denied(client, chart_id: int, dataset_id: int): +def test_put_access_denied(client, chart_id: int, datasource: SqlaTable): login(client, "gamma") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -313,10 +334,11 @@ def test_put_access_denied(client, chart_id: int, dataset_id: int): assert resp.status_code == 404 -def test_put_not_owner(client, chart_id: int, dataset_id: int): +def test_put_not_owner(client, chart_id: int, datasource: 
SqlaTable): login(client, "gamma") payload = { - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } @@ -364,12 +386,13 @@ def test_delete_access_denied(client): assert resp.status_code == 404 -def test_delete_not_owner(client, chart_id: int, dataset_id: int, admin_id: int): +def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id: int): another_key = "another_key" another_owner = admin_id + 1 entry: TemporaryExploreState = { "owner": another_owner, - "dataset_id": dataset_id, + "datasource_id": datasource.id, + "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } diff --git a/tests/integration_tests/explore/form_data/commands_tests.py b/tests/integration_tests/explore/form_data/commands_tests.py new file mode 100644 index 000000000000..4db48cfa7973 --- /dev/null +++ b/tests/integration_tests/explore/form_data/commands_tests.py @@ -0,0 +1,359 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import json +from unittest.mock import patch + +import pytest + +from superset import app, db, security, security_manager +from superset.commands.exceptions import DatasourceTypeInvalidError +from superset.connectors.sqla.models import SqlaTable +from superset.explore.form_data.commands.create import CreateFormDataCommand +from superset.explore.form_data.commands.delete import DeleteFormDataCommand +from superset.explore.form_data.commands.get import GetFormDataCommand +from superset.explore.form_data.commands.parameters import CommandParameters +from superset.explore.form_data.commands.update import UpdateFormDataCommand +from superset.models.slice import Slice +from superset.models.sql_lab import Query +from superset.utils.core import DatasourceType, get_example_default_schema +from superset.utils.database import get_example_database +from tests.integration_tests.base_tests import SupersetTestCase + + +class TestCreateFormDataCommand(SupersetTestCase): + @pytest.fixture() + def create_dataset(self): + with self.create_app().app_context(): + dataset = SqlaTable( + table_name="dummy_sql_table", + database=get_example_database(), + schema=get_example_default_schema(), + sql="select 123 as intcol, 'abc' as strcol", + ) + session = db.session + session.add(dataset) + session.commit() + + yield dataset + + # rollback + session.delete(dataset) + session.commit() + + @pytest.fixture() + def create_slice(self): + with self.create_app().app_context(): + session = db.session + dataset = ( + session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = Slice( + datasource_id=dataset.id, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_perm_table", + slice_name="slice_name", + ) + + session.add(slice) + session.commit() + + yield slice + + # rollback + session.delete(slice) + session.commit() + + @pytest.fixture() + def create_query(self): + with self.create_app().app_context(): + session = db.session + + query = Query( + sql="select 1 as foo;", + client_id="sldkfjlk", + database=get_example_database(), + ) + + session.add(query) + session.commit() + + yield query + + # rollback + session.delete(query) + session.commit() + + @patch("superset.security.manager.g") + @pytest.mark.usefixtures("create_dataset", "create_slice") + def test_create_form_data_command(self, mock_g): + mock_g.user = security_manager.find_user("admin") + + dataset = ( + db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = db.session.query(Slice).filter_by(slice_name="slice_name").first() + + datasource = f"{dataset.id}__{DatasourceType.TABLE}" + args = CommandParameters( + actor=mock_g.user, + datasource_id=dataset.id, + datasource_type=DatasourceType.TABLE, + chart_id=slice.id, + tab_id=1, + form_data=json.dumps({"datasource": datasource}), + ) + command = CreateFormDataCommand(args) + + assert isinstance(command.run(), str) + + @patch("superset.security.manager.g") + @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") + def test_create_form_data_command_invalid_type(self, mock_g): + mock_g.user = security_manager.find_user("admin") + app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + "REFRESH_TIMEOUT_ON_RETRIEVAL": True + } + + dataset = ( + db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = db.session.query(Slice).filter_by(slice_name="slice_name").first() + + datasource = f"{dataset.id}__{DatasourceType.TABLE}" + create_args = CommandParameters( + actor=mock_g.user, + 
datasource_id=dataset.id, + datasource_type="InvalidType", + chart_id=slice.id, + tab_id=1, + form_data=json.dumps({"datasource": datasource}), + ) + with pytest.raises(DatasourceTypeInvalidError) as exc: + CreateFormDataCommand(create_args).run() + + assert "Datasource type is invalid" in str(exc.value) + + @patch("superset.security.manager.g") + @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") + def test_create_form_data_command_type_as_string(self, mock_g): + mock_g.user = security_manager.find_user("admin") + app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + "REFRESH_TIMEOUT_ON_RETRIEVAL": True + } + + dataset = ( + db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = db.session.query(Slice).filter_by(slice_name="slice_name").first() + + datasource = f"{dataset.id}__{DatasourceType.TABLE}" + create_args = CommandParameters( + actor=mock_g.user, + datasource_id=dataset.id, + datasource_type="table", + chart_id=slice.id, + tab_id=1, + form_data=json.dumps({"datasource": datasource}), + ) + command = CreateFormDataCommand(create_args) + + assert isinstance(command.run(), str) + + @patch("superset.security.manager.g") + @pytest.mark.usefixtures("create_dataset", "create_slice") + def test_get_form_data_command(self, mock_g): + mock_g.user = security_manager.find_user("admin") + app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + "REFRESH_TIMEOUT_ON_RETRIEVAL": True + } + + dataset = ( + db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = db.session.query(Slice).filter_by(slice_name="slice_name").first() + + datasource = f"{dataset.id}__{DatasourceType.TABLE}" + create_args = CommandParameters( + actor=mock_g.user, + datasource_id=dataset.id, + datasource_type=DatasourceType.TABLE, + chart_id=slice.id, + tab_id=1, + form_data=json.dumps({"datasource": datasource}), + ) + key = CreateFormDataCommand(create_args).run() + + key_args = CommandParameters(actor=mock_g.user, key=key) + get_command = GetFormDataCommand(key_args) + cache_data = json.loads(get_command.run()) + + assert cache_data.get("datasource") == datasource + + @patch("superset.security.manager.g") + @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") + def test_update_form_data_command(self, mock_g): + mock_g.user = security_manager.find_user("admin") + app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + "REFRESH_TIMEOUT_ON_RETRIEVAL": True + } + + dataset = ( + db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = db.session.query(Slice).filter_by(slice_name="slice_name").first() + + query = db.session.query(Query).filter_by(sql="select 1 as foo;").first() + + datasource = f"{dataset.id}__{DatasourceType.TABLE}" + create_args = CommandParameters( + actor=mock_g.user, + datasource_id=dataset.id, + datasource_type=DatasourceType.TABLE, + chart_id=slice.id, + tab_id=1, + form_data=json.dumps({"datasource": datasource}), + ) + key = CreateFormDataCommand(create_args).run() + + query_datasource = f"{dataset.id}__{DatasourceType.TABLE}" + update_args = CommandParameters( + actor=mock_g.user, + datasource_id=query.id, + datasource_type=DatasourceType.QUERY, + chart_id=slice.id, + tab_id=1, + form_data=json.dumps({"datasource": query_datasource}), + key=key, + ) + + update_command = UpdateFormDataCommand(update_args) + new_key = update_command.run() + + # it should return a key + assert isinstance(new_key, str) + # the updated key returned should be different from the old one + 
assert new_key != key + + key_args = CommandParameters(actor=mock_g.user, key=key) + get_command = GetFormDataCommand(key_args) + + cache_data = json.loads(get_command.run()) + + assert cache_data.get("datasource") == query_datasource + + @patch("superset.security.manager.g") + @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") + def test_update_form_data_command_same_form_data(self, mock_g): + mock_g.user = security_manager.find_user("admin") + app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + "REFRESH_TIMEOUT_ON_RETRIEVAL": True + } + + dataset = ( + db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = db.session.query(Slice).filter_by(slice_name="slice_name").first() + + datasource = f"{dataset.id}__{DatasourceType.TABLE}" + create_args = CommandParameters( + actor=mock_g.user, + datasource_id=dataset.id, + datasource_type=DatasourceType.TABLE, + chart_id=slice.id, + tab_id=1, + form_data=json.dumps({"datasource": datasource}), + ) + key = CreateFormDataCommand(create_args).run() + + update_args = CommandParameters( + actor=mock_g.user, + datasource_id=dataset.id, + datasource_type=DatasourceType.TABLE, + chart_id=slice.id, + tab_id=1, + form_data=json.dumps({"datasource": datasource}), + key=key, + ) + + update_command = UpdateFormDataCommand(update_args) + new_key = update_command.run() + + # it should return a key + assert isinstance(new_key, str) + + # the updated key returned should be the same as the old one + assert new_key == key + + key_args = CommandParameters(actor=mock_g.user, key=key) + get_command = GetFormDataCommand(key_args) + + cache_data = json.loads(get_command.run()) + + assert cache_data.get("datasource") == datasource + + @patch("superset.security.manager.g") + @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") + def test_delete_form_data_command(self, mock_g): + mock_g.user = security_manager.find_user("admin") + app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + "REFRESH_TIMEOUT_ON_RETRIEVAL": True + } + + dataset = ( + db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + ) + slice = db.session.query(Slice).filter_by(slice_name="slice_name").first() + + datasource = f"{dataset.id}__{DatasourceType.TABLE}" + create_args = CommandParameters( + actor=mock_g.user, + datasource_id=dataset.id, + datasource_type=DatasourceType.TABLE, + chart_id=slice.id, + tab_id=1, + form_data=json.dumps({"datasource": datasource}), + ) + key = CreateFormDataCommand(create_args).run() + + delete_args = CommandParameters( + actor=mock_g.user, + key=key, + ) + + delete_command = DeleteFormDataCommand(delete_args) + response = delete_command.run() + + assert response is True + + @patch("superset.security.manager.g") + @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") + def test_delete_form_data_command_key_expired(self, mock_g): + mock_g.user = security_manager.find_user("admin") + app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + "REFRESH_TIMEOUT_ON_RETRIEVAL": True + } + + delete_args = CommandParameters( + actor=mock_g.user, + key="some_expired_key", + ) + + delete_command = DeleteFormDataCommand(delete_args) + response = delete_command.run() + + assert response is False diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index a44bc70a7b49..b5228ab301b2 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ 
b/tests/integration_tests/explore/permalink/api_tests.py @@ -27,6 +27,7 @@ from superset.key_value.types import KeyValueResource from superset.key_value.utils import decode_permalink_id, encode_permalink_key from superset.models.slice import Slice +from superset.utils.core import DatasourceType from tests.integration_tests.base_tests import login from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.world_bank_dashboard import ( @@ -97,7 +98,8 @@ def test_get_missing_chart(client, chart, permalink_salt: str) -> None: value=pickle.dumps( { "chartId": chart_id, - "datasetId": chart.datasource.id, + "datasourceId": chart.datasource.id, + "datasourceType": DatasourceType.TABLE, "formData": { "slice_id": chart_id, "datasource": f"{chart.datasource.id}__{chart.datasource.type}", diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index 6d7d581ec6d4..81acda80185c 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -40,7 +40,7 @@ from superset.datasets.commands.importers.v0 import import_dataset from superset.models.dashboard import Dashboard from superset.models.slice import Slice -from superset.utils.core import get_example_default_schema +from superset.utils.core import DatasourceType, get_example_default_schema from superset.utils.database import get_example_database from tests.integration_tests.fixtures.world_bank_dashboard import ( @@ -103,7 +103,7 @@ def create_slice( return Slice( slice_name=name, - datasource_type="table", + datasource_type=DatasourceType.TABLE, viz_type="bubble", params=json.dumps(params), datasource_id=ds_id, diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index ace75da35a88..a1791db34bff 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -16,6 +16,7 @@ # under the License. 
# isort:skip_file import json +from superset.utils.core import DatasourceType import textwrap import unittest from unittest import mock @@ -604,7 +605,7 @@ def test_data_for_slices_with_adhoc_column(self): dashboard = self.get_dash_by_slug("births") slc = Slice( slice_name="slice with adhoc column", - datasource_type="table", + datasource_type=DatasourceType.TABLE, viz_type="table", params=json.dumps( { diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index c44335552b01..e66bf02e82cb 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -39,6 +39,7 @@ from superset.models.slice import Slice from superset.sql_parse import Table from superset.utils.core import ( + DatasourceType, backend, get_example_default_schema, ) @@ -120,7 +121,7 @@ def setUp(self): ds_slices = ( session.query(Slice) - .filter_by(datasource_type="table") + .filter_by(datasource_type=DatasourceType.TABLE) .filter_by(datasource_id=ds.id) .all() ) @@ -143,7 +144,7 @@ def tearDown(self): ds.schema_perm = None ds_slices = ( session.query(Slice) - .filter_by(datasource_type="table") + .filter_by(datasource_type=DatasourceType.TABLE) .filter_by(datasource_id=ds.id) .all() ) @@ -365,7 +366,7 @@ def test_set_perm_slice(self): # no schema permission slice = Slice( datasource_id=table.id, - datasource_type="table", + datasource_type=DatasourceType.TABLE, datasource_name="tmp_perm_table", slice_name="slice_name", ) diff --git a/tests/integration_tests/utils/cache_manager_tests.py b/tests/integration_tests/utils/cache_manager_tests.py new file mode 100644 index 000000000000..c5d4b390f9c9 --- /dev/null +++ b/tests/integration_tests/utils/cache_manager_tests.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +from superset.extensions import cache_manager +from superset.utils.core import backend, DatasourceType +from tests.integration_tests.base_tests import SupersetTestCase + + +class UtilsCacheManagerTests(SupersetTestCase): + def test_get_set_explore_form_data_cache(self): + key = "12345" + data = {"foo": "bar", "datasource_type": "query"} + cache_manager.explore_form_data_cache.set(key, data) + assert cache_manager.explore_form_data_cache.get(key) == data + + def test_get_same_context_twice(self): + key = "12345" + data = {"foo": "bar", "datasource_type": "query"} + cache_manager.explore_form_data_cache.set(key, data) + assert cache_manager.explore_form_data_cache.get(key) == data + assert cache_manager.explore_form_data_cache.get(key) == data + + def test_get_set_explore_form_data_cache_no_datasource_type(self): + key = "12345" + data = {"foo": "bar"} + cache_manager.explore_form_data_cache.set(key, data) + # datasource_type should be added because it is not present + assert cache_manager.explore_form_data_cache.get(key) == { + "datasource_type": DatasourceType.TABLE, + **data, + } + + def test_get_explore_form_data_cache_invalid_key(self): + assert cache_manager.explore_form_data_cache.get("foo") is None diff --git a/tests/unit_tests/dao/datasource_test.py b/tests/unit_tests/dao/datasource_test.py index dd0db265e7a0..a15684d71e69 100644 --- a/tests/unit_tests/dao/datasource_test.py +++ b/tests/unit_tests/dao/datasource_test.py @@ -106,7 +106,7 @@ def test_get_datasource_sqlatable( from superset.dao.datasource.dao import DatasourceDAO result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.SQLATABLE, + datasource_type=DatasourceType.TABLE, datasource_id=1, session=session_with_data, ) @@ -151,7 +151,9 @@ def test_get_datasource_sl_table(app_context: None, session_with_data: Session) # todo(hugh): This will break once we remove the dual write # update the datasource_id=1 and this will pass again result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.TABLE, datasource_id=2, session=session_with_data + datasource_type=DatasourceType.SLTABLE, + datasource_id=2, + session=session_with_data, ) assert result.id == 2 diff --git a/tests/unit_tests/explore/utils_test.py b/tests/unit_tests/explore/utils_test.py index 11e2906ed88d..9ef92872177e 100644 --- a/tests/unit_tests/explore/utils_test.py +++ b/tests/unit_tests/explore/utils_test.py @@ -23,12 +23,21 @@ ChartAccessDeniedError, ChartNotFoundError, ) +from superset.commands.exceptions import ( + DatasourceNotFoundValidationError, + DatasourceTypeInvalidError, + OwnersNotFoundValidationError, + QueryNotFoundValidationError, +) from superset.datasets.commands.exceptions import ( DatasetAccessDeniedError, DatasetNotFoundError, ) +from superset.exceptions import SupersetSecurityException +from superset.utils.core import DatasourceType dataset_find_by_id = "superset.datasets.dao.DatasetDAO.find_by_id" +query_find_by_id = "superset.queries.dao.QueryDAO.find_by_id" chart_find_by_id = "superset.charts.dao.ChartDAO.find_by_id" is_user_admin = "superset.explore.utils.is_user_admin" is_owner = "superset.explore.utils.is_owner" @@ -36,88 +45,142 @@ "superset.security.SupersetSecurityManager.can_access_datasource" ) can_access = "superset.security.SupersetSecurityManager.can_access" +raise_for_access = "superset.security.SupersetSecurityManager.raise_for_access" +query_datasources_by_name = ( + "superset.connectors.sqla.models.SqlaTable.query_datasources_by_name" +) def 
test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None: - from superset.explore.utils import check_access + from superset.explore.utils import check_access as check_chart_access - with raises(DatasetNotFoundError): - check_access(dataset_id=0, chart_id=0, actor=User()) + with raises(DatasourceNotFoundValidationError): + check_chart_access( + datasource_id=0, + chart_id=0, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) def test_unsaved_chart_unknown_dataset_id( mocker: MockFixture, app_context: AppContext ) -> None: - from superset.explore.utils import check_access + from superset.explore.utils import check_access as check_chart_access with raises(DatasetNotFoundError): mocker.patch(dataset_find_by_id, return_value=None) - check_access(dataset_id=1, chart_id=0, actor=User()) + check_chart_access( + datasource_id=1, + chart_id=0, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) + + +def test_unsaved_chart_unknown_query_id( + mocker: MockFixture, app_context: AppContext +) -> None: + from superset.explore.utils import check_access as check_chart_access + + with raises(QueryNotFoundValidationError): + mocker.patch(query_find_by_id, return_value=None) + check_chart_access( + datasource_id=1, + chart_id=0, + actor=User(), + datasource_type=DatasourceType.QUERY, + ) def test_unsaved_chart_unauthorized_dataset( mocker: MockFixture, app_context: AppContext ) -> None: from superset.connectors.sqla.models import SqlaTable - from superset.explore import utils + from superset.explore.utils import check_access as check_chart_access with raises(DatasetAccessDeniedError): mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=False) - utils.check_access(dataset_id=1, chart_id=0, actor=User()) + check_chart_access( + datasource_id=1, + chart_id=0, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) def test_unsaved_chart_authorized_dataset( mocker: MockFixture, app_context: AppContext ) -> None: from superset.connectors.sqla.models import SqlaTable - from superset.explore.utils import check_access + from superset.explore.utils import check_access as check_chart_access mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) - check_access(dataset_id=1, chart_id=0, actor=User()) + check_chart_access( + datasource_id=1, + chart_id=0, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_unknown_chart_id( mocker: MockFixture, app_context: AppContext ) -> None: from superset.connectors.sqla.models import SqlaTable - from superset.explore.utils import check_access + from superset.explore.utils import check_access as check_chart_access with raises(ChartNotFoundError): mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) mocker.patch(chart_find_by_id, return_value=None) - check_access(dataset_id=1, chart_id=1, actor=User()) + check_chart_access( + datasource_id=1, + chart_id=1, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_unauthorized_dataset( mocker: MockFixture, app_context: AppContext ) -> None: from superset.connectors.sqla.models import SqlaTable - from superset.explore import utils + from superset.explore.utils import check_access as check_chart_access with raises(DatasetAccessDeniedError): mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=False) - 
utils.check_access(dataset_id=1, chart_id=1, actor=User()) + check_chart_access( + datasource_id=1, + chart_id=1, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None: from superset.connectors.sqla.models import SqlaTable - from superset.explore.utils import check_access + from superset.explore.utils import check_access as check_chart_access from superset.models.slice import Slice mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) mocker.patch(is_user_admin, return_value=True) mocker.patch(chart_find_by_id, return_value=Slice()) - check_access(dataset_id=1, chart_id=1, actor=User()) + check_chart_access( + datasource_id=1, + chart_id=1, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None: from superset.connectors.sqla.models import SqlaTable - from superset.explore.utils import check_access + from superset.explore.utils import check_access as check_chart_access from superset.models.slice import Slice mocker.patch(dataset_find_by_id, return_value=SqlaTable()) @@ -125,12 +188,17 @@ def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> N mocker.patch(is_user_admin, return_value=False) mocker.patch(is_owner, return_value=True) mocker.patch(chart_find_by_id, return_value=Slice()) - check_access(dataset_id=1, chart_id=1, actor=User()) + check_chart_access( + datasource_id=1, + chart_id=1, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> None: from superset.connectors.sqla.models import SqlaTable - from superset.explore.utils import check_access + from superset.explore.utils import check_access as check_chart_access from superset.models.slice import Slice mocker.patch(dataset_find_by_id, return_value=SqlaTable()) @@ -139,12 +207,17 @@ def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> mocker.patch(is_owner, return_value=False) mocker.patch(can_access, return_value=True) mocker.patch(chart_find_by_id, return_value=Slice()) - check_access(dataset_id=1, chart_id=1, actor=User()) + check_chart_access( + datasource_id=1, + chart_id=1, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None: from superset.connectors.sqla.models import SqlaTable - from superset.explore.utils import check_access + from superset.explore.utils import check_access as check_chart_access from superset.models.slice import Slice with raises(ChartAccessDeniedError): @@ -154,4 +227,66 @@ def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> mocker.patch(is_owner, return_value=False) mocker.patch(can_access, return_value=False) mocker.patch(chart_find_by_id, return_value=Slice()) - check_access(dataset_id=1, chart_id=1, actor=User()) + check_chart_access( + datasource_id=1, + chart_id=1, + actor=User(), + datasource_type=DatasourceType.TABLE, + ) + + +def test_dataset_has_access(mocker: MockFixture, app_context: AppContext) -> None: + from superset.connectors.sqla.models import SqlaTable + from superset.explore.utils import check_datasource_access + + mocker.patch(dataset_find_by_id, return_value=SqlaTable()) + mocker.patch(can_access_datasource, return_value=True) + mocker.patch(is_user_admin, return_value=False) + 
mocker.patch(is_owner, return_value=False) + mocker.patch(can_access, return_value=True) + assert ( + check_datasource_access( + datasource_id=1, + datasource_type=DatasourceType.TABLE, + ) + is True + ) + + +def test_query_has_access(mocker: MockFixture, app_context: AppContext) -> None: + from superset.explore.utils import check_datasource_access + from superset.models.sql_lab import Query + + mocker.patch(query_find_by_id, return_value=Query()) + mocker.patch(raise_for_access, return_value=True) + mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_owner, return_value=False) + mocker.patch(can_access, return_value=True) + assert ( + check_datasource_access( + datasource_id=1, + datasource_type=DatasourceType.QUERY, + ) + is True + ) + + +def test_query_no_access(mocker: MockFixture, app_context: AppContext) -> None: + from superset.connectors.sqla.models import SqlaTable + from superset.explore.utils import check_datasource_access + from superset.models.core import Database + from superset.models.sql_lab import Query + + with raises(SupersetSecurityException): + mocker.patch( + query_find_by_id, + return_value=Query(database=Database(), sql="select * from foo"), + ) + mocker.patch(query_datasources_by_name, return_value=[SqlaTable()]) + mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_owner, return_value=False) + mocker.patch(can_access, return_value=False) + check_datasource_access( + datasource_id=1, + datasource_type=DatasourceType.QUERY, + )
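
For orientation, the payload reshaping exercised by the API-test hunks above reduces to the sketch below. This is editorial illustration, not part of the patch; `dataset_id`, `chart_id`, `form_data`, and `datasource` stand in for the fixture values those tests use.

old_payload = {  # shape before this change (the dropped "-" lines)
    "dataset_id": dataset_id,
    "chart_id": chart_id,
    "form_data": form_data,  # must be a JSON-encoded string; otherwise the API returns 400
}

new_payload = {  # shape after this change (the added "+" lines)
    "datasource_id": datasource.id,
    "datasource_type": datasource.type,  # e.g. "table" or "query"
    "chart_id": chart_id,
    "form_data": form_data,
}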