diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index e613222b36f7..042c85a930f9 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -105,7 +105,10 @@ def validate(self) -> None: # Validate/Populate dashboards only if it's a list if dashboard_ids is not None: - dashboards = DashboardDAO.find_by_ids(dashboard_ids) + dashboards = DashboardDAO.find_by_ids( + dashboard_ids, + skip_base_filter=True, + ) if len(dashboards) != len(dashboard_ids): exceptions.append(DashboardsNotFoundValidationError()) self._properties["dashboards"] = dashboards diff --git a/superset/dao/base.py b/superset/dao/base.py index c6890e53a5ce..126238f66132 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -73,16 +73,22 @@ def find_by_id( return None @classmethod - def find_by_ids(cls, model_ids: Union[List[str], List[int]]) -> List[Model]: + def find_by_ids( + cls, + model_ids: Union[List[str], List[int]], + session: Session = None, + skip_base_filter: bool = False, + ) -> List[Model]: """ Find a List of models by a list of ids, if defined applies `base_filter` """ id_col = getattr(cls.model_cls, cls.id_column_name, None) if id_col is None: return [] - query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids)) - if cls.base_filter: - data_model = SQLAInterface(cls.model_cls, db.session) + session = session or db.session + query = session.query(cls.model_cls).filter(id_col.in_(model_ids)) + if cls.base_filter and not skip_base_filter: + data_model = SQLAInterface(cls.model_cls, session) query = cls.base_filter( # pylint: disable=not-callable cls.id_column_name, data_model ).apply(query, None) diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 3d8a4695f4eb..965a9c137ba8 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -692,6 +692,61 @@ def test_update_chart_not_owned(self): db.session.delete(user_alpha2) db.session.commit() + def test_update_chart_linked_with_not_owned_dashboard(self): + """ + Chart API: Test update chart which is linked to not owned dashboard + """ + user_alpha1 = self.create_user( + "alpha1", "password", "Alpha", email="alpha1@superset.org" + ) + user_alpha2 = self.create_user( + "alpha2", "password", "Alpha", email="alpha2@superset.org" + ) + chart = self.insert_chart("title", [user_alpha1.id], 1) + + original_dashboard = Dashboard() + original_dashboard.dashboard_title = "Original Dashboard" + original_dashboard.slug = "slug" + original_dashboard.owners = [user_alpha1] + original_dashboard.slices = [chart] + original_dashboard.published = False + db.session.add(original_dashboard) + + new_dashboard = Dashboard() + new_dashboard.dashboard_title = "Cloned Dashboard" + new_dashboard.slug = "new_slug" + new_dashboard.owners = [user_alpha2] + new_dashboard.slices = [chart] + new_dashboard.published = False + db.session.add(new_dashboard) + + self.login(username="alpha1", password="password") + chart_data_with_invalid_dashboard = { + "slice_name": "title1_changed", + "dashboards": [original_dashboard.id, 0], + } + chart_data = { + "slice_name": "title1_changed", + "dashboards": [original_dashboard.id, new_dashboard.id], + } + uri = f"api/v1/chart/{chart.id}" + + rv = self.put_assert_metric(uri, chart_data_with_invalid_dashboard, "put") + self.assertEqual(rv.status_code, 422) + response = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": {"dashboards": ["Dashboards do not exist"]}} + self.assertEqual(response, expected_response) + + rv = self.put_assert_metric(uri, chart_data, "put") + self.assertEqual(rv.status_code, 200) + + db.session.delete(chart) + db.session.delete(original_dashboard) + db.session.delete(new_dashboard) + db.session.delete(user_alpha1) + db.session.delete(user_alpha2) + db.session.commit() + def test_update_chart_validate_datasource(self): """ Chart API: Test update validate datasource diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py index 31aa9f27d085..350425d08e89 100644 --- a/tests/unit_tests/datasets/dao/dao_tests.py +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -71,3 +71,33 @@ def test_datasource_find_by_id_skip_base_filter_not_found( skip_base_filter=True, ) assert result is None + + +def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> None: + from superset.connectors.sqla.models import SqlaTable + from superset.datasets.dao import DatasetDAO + + result = DatasetDAO.find_by_ids( + [1, 125326326], + session=session_with_data, + skip_base_filter=True, + ) + + assert result + assert [1] == list(map(lambda x: x.id, result)) + assert ["my_sqla_table"] == list(map(lambda x: x.table_name, result)) + assert isinstance(result[0], SqlaTable) + + +def test_datasource_find_by_ids_skip_base_filter_not_found( + session_with_data: Session, +) -> None: + from superset.datasets.dao import DatasetDAO + + result = DatasetDAO.find_by_ids( + [125326326, 125326326125326326], + session=session_with_data, + skip_base_filter=True, + ) + + assert len(result) == 0