diff --git a/superset-frontend/src/components/FlashProvider.tsx b/superset-frontend/src/components/FlashProvider.tsx
index d4b1ddbff399..5504288e87db 100644
--- a/superset-frontend/src/components/FlashProvider.tsx
+++ b/superset-frontend/src/components/FlashProvider.tsx
@@ -31,6 +31,7 @@ interface Props {
 const flashObj = {
   info: 'addInfoToast',
+  alert: 'addDangerToast',
   danger: 'addDangerToast',
   warning: 'addWarningToast',
   success: 'addSuccessToast',
@@ -42,7 +43,9 @@ class FlashProvider extends React.PureComponent<Props> {
     flashMessages.forEach(message => {
       const [type, text] = message;
       const flash = flashObj[type];
-      this.props[flash](text);
+      if (this.props[flash]) {
+        this.props[flash](text);
+      }
     });
   }
   render() {
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 8677be6d6946..8cfae24222f5 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1238,7 +1238,10 @@ def fetch_metadata(self, commit: bool = True) -> None:
 
     @classmethod
     def import_obj(
-        cls, i_datasource: "SqlaTable", import_time: Optional[int] = None
+        cls,
+        i_datasource: "SqlaTable",
+        database_id: Optional[int] = None,
+        import_time: Optional[int] = None,
     ) -> int:
         """Imports the datasource from the object to the database.
 
@@ -1275,7 +1278,12 @@ def lookup_database(table_: SqlaTable) -> Database:
             )
 
         return import_datasource.import_datasource(
-            db.session, i_datasource, lookup_database, lookup_sqlatable, import_time
+            db.session,
+            i_datasource,
+            lookup_database,
+            lookup_sqlatable,
+            import_time,
+            database_id,
         )
 
     @classmethod
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 51bd85f8460e..c849d678b3f2 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -77,3 +77,7 @@ class DatabaseNotFound(SupersetException):
 
 class QueryObjectValidationError(SupersetException):
     status = 400
+
+
+class DashboardImportException(SupersetException):
+    pass
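Since `database_id` now precedes `import_time` in `SqlaTable.import_obj`, existing positional callers pick up the new argument silently; keywords keep the call unambiguous. A minimal usage sketch, assuming an active Superset app context (the table name below is illustrative, not part of this patch):

```python
# Illustrative only: assumes a running Superset app context and an
# existing table named "my_table".
from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.utils.core import get_example_database

table = db.session.query(SqlaTable).filter_by(table_name="my_table").one()

# database_id comes before import_time in the new signature; passing
# None falls back to the lookup-by-database-name behavior.
new_id = SqlaTable.import_obj(
    table, database_id=get_example_database().id, import_time=1990
)
```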
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index b34906fc3f0f..02a33c7183c0 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -247,7 +247,7 @@ def position(self) -> Dict[str, Any]:
 
     @classmethod
     def import_obj(  # pylint: disable=too-many-locals,too-many-branches,too-many-statements
-        cls, dashboard_to_import: "Dashboard", import_time: Optional[int] = None
+        cls, dashboard_to_import: "Dashboard", import_time: Optional[int] = None,
     ) -> int:
         """Imports the dashboard from the object to the database.
 
@@ -311,6 +311,10 @@ def alter_positions(
         # copy slices object as Slice.import_slice will mutate the slice
         # and will remove the existing dashboard - slice association
         slices = copy(dashboard_to_import.slices)
+
+        # Clearing the slug to avoid conflicts
+        dashboard_to_import.slug = None
+
         old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}")
         old_to_new_slc_id_dict: Dict[int, int] = {}
         new_timed_refresh_immune_slices = []
@@ -332,8 +336,8 @@ def alter_positions(
             new_slc_id = Slice.import_obj(slc, remote_slc, import_time=import_time)
             old_to_new_slc_id_dict[slc.id] = new_slc_id
             # update json metadata that deals with slice ids
-            new_slc_id_str = "{}".format(new_slc_id)
-            old_slc_id_str = "{}".format(slc.id)
+            new_slc_id_str = str(new_slc_id)
+            old_slc_id_str = str(slc.id)
             if (
                 "timed_refresh_immune_slices" in i_params_dict
                 and old_slc_id_str in i_params_dict["timed_refresh_immune_slices"]
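Two details worth noting in this hunk: the slug is reset because dashboard slugs are unique, so re-importing the same export would otherwise collide; and slice ids inside `json_metadata` are compared and rewritten as strings. A toy sketch of that id remapping, with invented ids and metadata:

```python
# Toy example of the slice-id remapping above -- all values are invented.
old_to_new_slc_id_dict = {10001: 42, 10002: 43}
i_params_dict = {"timed_refresh_immune_slices": ["10001"]}

# Keep only slices that were immune in the old dashboard, under their new ids.
new_timed_refresh_immune_slices = [
    str(new_id)
    for old_id, new_id in old_to_new_slc_id_dict.items()
    if str(old_id) in i_params_dict["timed_refresh_immune_slices"]
]
# Result: ["42"] -- the imported dashboard references the new slice id.
```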
diff --git a/superset/templates/superset/import_dashboards.html b/superset/templates/superset/import_dashboards.html
index 753953787c23..3f7fe4a6c0a2 100644
--- a/superset/templates/superset/import_dashboards.html
+++ b/superset/templates/superset/import_dashboards.html
@@ -24,29 +24,46 @@
 {% include "superset/flash_wrapper.html" %}
 <div class="container">
-  <h1>Import dashboards</h1>
-  <form method="post" enctype="multipart/form-data">
-    <input
-      type="hidden"
-      name="csrf_token"
-      value="{{ csrf_token() if csrf_token else '' }}" />
-    <input type="file" name="file" />
-    <input type="submit" value="Upload" />
-  </form>
+  <div class="panel">
+    <div class="panel-heading">
+      <h3>{{ _("Import Dashboard(s)") }}</h3>
+    </div>
+    <div class="panel-body">
+      <form method="post" enctype="multipart/form-data">
+        <input
+          type="hidden"
+          name="csrf_token"
+          id="csrf_token"
+          value="{{ csrf_token() if csrf_token else '' }}" />
+        <table class="table">
+          <tr>
+            <td>{{ _("File") }}</td>
+            <td><input type="file" name="file" /></td>
+          </tr>
+          <tr>
+            <td>{{ _("Database") }}</td>
+            <td>
+              <select name="db_id" id="db_id" class="form-control">
+                {% for database in databases %}
+                <option value="{{ database.id }}">{{ database.name }}</option>
+                {% endfor %}
+              </select>
+            </td>
+          </tr>
+        </table>
+        <input type="submit" value="Upload" class="btn btn-primary" />
+      </form>
+    </div>
+  </div>
 </div>
 {% endblock %}
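Because the form now posts both the file and a `db_id`, imports can also be scripted against the same endpoint. A rough sketch using `requests`, under the assumption of a local instance with the default `/superset` route base; authentication and CSRF retrieval are deployment-specific and only stubbed here:

```python
# Rough sketch: assumes an already-authenticated session and a valid
# CSRF token, both of which are deployment-specific.
import requests

session = requests.Session()
csrf_token = "..."  # obtained after logging in

with open("dashboards.json", "rb") as export_file:
    session.post(
        "http://localhost:8088/superset/import_dashboards",
        files={"file": export_file},
        data={"db_id": "1", "csrf_token": csrf_token},  # "1" = target database id
    )
```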
diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py
index 05889b7deb7a..6ae500b40f4c 100644
--- a/superset/utils/dashboard_import_export.py
+++ b/superset/utils/dashboard_import_export.py
@@ -21,9 +21,11 @@
 from io import BytesIO
 from typing import Any, Dict, Optional
 
+from flask_babel import lazy_gettext as _
 from sqlalchemy.orm import Session
 
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+from superset.exceptions import DashboardImportException
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 
@@ -69,14 +71,19 @@ def decode_dashboards(  # pylint: disable=too-many-return-statements
 
 
 def import_dashboards(
-    session: Session, data_stream: BytesIO, import_time: Optional[int] = None
+    session: Session,
+    data_stream: BytesIO,
+    database_id: Optional[int] = None,
+    import_time: Optional[int] = None,
 ) -> None:
     """Imports dashboards from a stream to databases"""
     current_tt = int(time.time())
     import_time = current_tt if import_time is None else import_time
     data = json.loads(data_stream.read(), object_hook=decode_dashboards)
+    if not data:
+        raise DashboardImportException(_("No data in file"))
     for table in data["datasources"]:
-        type(table).import_obj(table, import_time=import_time)
+        type(table).import_obj(table, database_id, import_time=import_time)
     session.commit()
     for dashboard in data["dashboards"]:
         Dashboard.import_obj(dashboard, import_time=import_time)
diff --git a/superset/utils/import_datasource.py b/superset/utils/import_datasource.py
index 50f375cb9d17..25da876b28b7 100644
--- a/superset/utils/import_datasource.py
+++ b/superset/utils/import_datasource.py
@@ -24,12 +24,13 @@
 logger = logging.getLogger(__name__)
 
 
-def import_datasource(
+def import_datasource(  # pylint: disable=too-many-arguments
     session: Session,
     i_datasource: Model,
     lookup_database: Callable[[Model], Model],
     lookup_datasource: Callable[[Model], Model],
     import_time: Optional[int] = None,
+    database_id: Optional[int] = None,
 ) -> int:
     """Imports the datasource from the object to the database.
 
@@ -41,7 +42,9 @@ def import_datasource(
 
     logger.info("Started import of the datasource: %s", i_datasource.to_json())
     i_datasource.id = None
-    i_datasource.database_id = lookup_database(i_datasource).id
+    i_datasource.database_id = (
+        database_id if database_id else lookup_database(i_datasource).id
+    )
     i_datasource.alter_params(import_time=import_time)
 
     # override the datasource
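The new argument threads down to `import_datasource`, where an explicit `database_id` takes precedence over the name-based `lookup_database` fallback. Called programmatically, the flow might look like this sketch (the file path is hypothetical, and a Flask app context is assumed):

```python
# Sketch only: assumes an app context and an export file on disk.
from superset import db
from superset.utils import dashboard_import_export
from superset.utils.core import get_example_database

with open("/tmp/dashboards_export.json", "rb") as data_stream:
    dashboard_import_export.import_dashboards(
        db.session,
        data_stream,
        database_id=get_example_database().id,  # overrides lookup_database()
    )
```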
diff --git a/superset/utils/log.py b/superset/utils/log.py
index 68a0c95f52de..1b6e1b694504 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -24,6 +24,7 @@
 from typing import Any, Callable, cast, Optional, Type
 
 from flask import current_app, g, request
+from sqlalchemy.exc import SQLAlchemyError
 
 from superset.stats_logger import BaseStatsLogger
 
@@ -169,6 +170,10 @@ def log(  # pylint: disable=too-many-locals
             )
             logs.append(log)
 
-        sesh = current_app.appbuilder.get_session
-        sesh.bulk_save_objects(logs)
-        sesh.commit()
+        try:
+            sesh = current_app.appbuilder.get_session
+            sesh.bulk_save_objects(logs)
+            sesh.commit()
+        except SQLAlchemyError as ex:
+            logging.error("DBEventLogger failed to log event(s)")
+            logging.exception(ex)
diff --git a/superset/views/core.py b/superset/views/core.py
index 3db10abfe24e..fea37b9ac549 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -75,6 +75,7 @@
     SupersetTimeoutException,
 )
 from superset.jinja_context import get_template_processor
+from superset.models.core import Database
 from superset.models.dashboard import Dashboard
 from superset.models.datasource_access_request import DatasourceAccessRequest
 from superset.models.slice import Slice
@@ -541,10 +542,15 @@ def explore_json(
     @expose("/import_dashboards", methods=["GET", "POST"])
     def import_dashboards(self) -> FlaskResponse:
         """Overrides the dashboards using json instances from the file."""
-        f = request.files.get("file")
-        if request.method == "POST" and f:
+        import_file = request.files.get("file")
+        if request.method == "POST" and import_file:
+            success = False
+            database_id = request.form.get("db_id")
             try:
-                dashboard_import_export.import_dashboards(db.session, f.stream)
+                dashboard_import_export.import_dashboards(
+                    db.session, import_file.stream, database_id
+                )
+                success = True
             except DatabaseNotFound as ex:
                 logger.exception(ex)
                 flash(
@@ -565,8 +571,14 @@ def import_dashboards(self) -> FlaskResponse:
                 ),
                 "danger",
             )
-            return redirect("/dashboard/list/")
-        return self.render_template("superset/import_dashboards.html")
+            if success:
+                flash("Dashboard(s) have been imported", "success")
+                return redirect("/dashboard/list/")
+
+        databases = db.session.query(Database).all()
+        return self.render_template(
+            "superset/import_dashboards.html", databases=databases
+        )
 
     @event_logger.log_this
     @has_access
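The `utils/log.py` change applies a common rule: best-effort event logging should never fail the request it is recording. The same pattern in isolation, as a generic sketch (this helper is illustrative, not Superset code):

```python
# Illustrative pattern, not part of this patch: swallow database errors
# from best-effort logging so the caller's request path is unaffected.
import logging

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


def save_logs_best_effort(session: Session, logs: list) -> None:
    try:
        session.bulk_save_objects(logs)
        session.commit()
    except SQLAlchemyError:
        # logger.exception records the traceback along with the message.
        logger.exception("Failed to persist %d log record(s)", len(logs))
```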
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index 16443afa2317..e772d16b136b 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -34,6 +34,7 @@
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
+from superset.utils.core import get_example_database
 
 from .base_tests import SupersetTestCase
 
@@ -157,8 +158,11 @@ def get_datasource(self, datasource_id):
     def get_table_by_name(self, name):
         return db.session.query(SqlaTable).filter_by(table_name=name).first()
 
-    def assert_dash_equals(self, expected_dash, actual_dash, check_position=True):
-        self.assertEqual(expected_dash.slug, actual_dash.slug)
+    def assert_dash_equals(
+        self, expected_dash, actual_dash, check_position=True, check_slugs=True
+    ):
+        if check_slugs:
+            self.assertEqual(expected_dash.slug, actual_dash.slug)
         self.assertEqual(expected_dash.dashboard_title, actual_dash.dashboard_title)
         self.assertEqual(len(expected_dash.slices), len(actual_dash.slices))
         expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "")
@@ -378,7 +382,9 @@ def test_import_dashboard_1_slice(self):
         expected_dash = self.create_dashboard("dash_with_1_slice", slcs=[slc], id=10002)
         make_transient(expected_dash)
-        self.assert_dash_equals(expected_dash, imported_dash, check_position=False)
+        self.assert_dash_equals(
+            expected_dash, imported_dash, check_position=False, check_slugs=False
+        )
         self.assertEqual(
             {"remote_id": 10002, "import_time": 1990},
             json.loads(imported_dash.json_metadata),
@@ -420,7 +426,9 @@ def test_import_dashboard_2_slices(self):
             "dash_with_2_slices", slcs=[e_slc, b_slc], id=10003
         )
         make_transient(expected_dash)
-        self.assert_dash_equals(imported_dash, expected_dash, check_position=False)
+        self.assert_dash_equals(
+            imported_dash, expected_dash, check_position=False, check_slugs=False
+        )
         i_e_slc = self.get_slice_by_name("e_slc")
         i_b_slc = self.get_slice_by_name("b_slc")
         expected_json_metadata = {
@@ -466,7 +474,9 @@ def test_import_override_dashboard_2_slices(self):
         )
         make_transient(expected_dash)
         imported_dash = self.get_dash(imported_dash_id_2)
-        self.assert_dash_equals(expected_dash, imported_dash, check_position=False)
+        self.assert_dash_equals(
+            expected_dash, imported_dash, check_position=False, check_slugs=False
+        )
         self.assertEqual(
             {"remote_id": 10004, "import_time": 1992},
             json.loads(imported_dash.json_metadata),
@@ -556,8 +566,9 @@ def _create_dashboard_for_import(self, id_=10100):
         return dash_with_1_slice
 
     def test_import_table_no_metadata(self):
+        db_id = get_example_database().id
         table = self.create_table("pure_table", id=10001)
-        imported_id = SqlaTable.import_obj(table, import_time=1989)
+        imported_id = SqlaTable.import_obj(table, db_id, import_time=1989)
         imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
 
@@ -565,7 +576,8 @@ def test_import_table_1_col_1_met(self):
         table = self.create_table(
             "table_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"]
         )
-        imported_id = SqlaTable.import_obj(table, import_time=1990)
+        db_id = get_example_database().id
+        imported_id = SqlaTable.import_obj(table, db_id, import_time=1990)
         imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
         self.assertEqual(
@@ -580,7 +592,8 @@ def test_import_table_2_col_2_met(self):
             cols_names=["c1", "c2"],
             metric_names=["m1", "m2"],
         )
-        imported_id = SqlaTable.import_obj(table, import_time=1991)
+        db_id = get_example_database().id
+        imported_id = SqlaTable.import_obj(table, db_id, import_time=1991)
         imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
 
@@ -589,7 +602,8 @@ def test_import_table_override(self):
         table = self.create_table(
             "table_override", id=10003, cols_names=["col1"], metric_names=["m1"]
         )
-        imported_id = SqlaTable.import_obj(table, import_time=1991)
+        db_id = get_example_database().id
+        imported_id = SqlaTable.import_obj(table, db_id, import_time=1991)
 
         table_over = self.create_table(
             "table_override",
@@ -597,7 +611,7 @@ def test_import_table_override(self):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_over_id = SqlaTable.import_obj(table_over, import_time=1992)
+        imported_over_id = SqlaTable.import_obj(table_over, db_id, import_time=1992)
 
         imported_over = self.get_table_by_id(imported_over_id)
         self.assertEqual(imported_id, imported_over.id)
@@ -616,7 +630,8 @@ def test_import_table_override_identical(self):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_id = SqlaTable.import_obj(table, import_time=1993)
+        db_id = get_example_database().id
+        imported_id = SqlaTable.import_obj(table, db_id, import_time=1993)
 
         copy_table = self.create_table(
             "copy_cat",
@@ -624,7 +639,7 @@ def test_import_table_override_identical(self):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_id_copy = SqlaTable.import_obj(copy_table, import_time=1994)
+        imported_id_copy = SqlaTable.import_obj(copy_table, db_id, import_time=1994)
 
         self.assertEqual(imported_id, imported_id_copy)
         self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))