From 1a01826e174159725997b3fc13d3234ed79ae9ba Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 13 Dec 2021 15:47:57 -0800 Subject: [PATCH] fix: column extra in import/export --- superset/datasets/commands/export.py | 15 +- .../datasets/commands/importers/v1/utils.py | 32 +-- superset/datasets/schemas.py | 3 +- tests/unit_tests/conftest.py | 41 +++- tests/unit_tests/datasets/__init__.py | 16 ++ .../unit_tests/datasets/commands/__init__.py | 16 ++ .../datasets/commands/export_test.py | 192 ++++++++++++++++++ .../datasets/commands/importers/__init__.py | 16 ++ .../commands/importers/v1/__init__.py | 16 ++ .../commands/importers/v1/import_test.py | 190 +++++++++++++++++ 10 files changed, 512 insertions(+), 25 deletions(-) create mode 100644 tests/unit_tests/datasets/__init__.py create mode 100644 tests/unit_tests/datasets/commands/__init__.py create mode 100644 tests/unit_tests/datasets/commands/export_test.py create mode 100644 tests/unit_tests/datasets/commands/importers/__init__.py create mode 100644 tests/unit_tests/datasets/commands/importers/v1/__init__.py create mode 100644 tests/unit_tests/datasets/commands/importers/v1/import_test.py diff --git a/superset/datasets/commands/export.py b/superset/datasets/commands/export.py index 84ae6c486c08d..45460f36e3455 100644 --- a/superset/datasets/commands/export.py +++ b/superset/datasets/commands/export.py @@ -59,12 +59,15 @@ def _export(model: SqlaTable) -> Iterator[Tuple[str, str]]: payload[key] = json.loads(payload[key]) except json.decoder.JSONDecodeError: logger.info("Unable to decode `%s` field: %s", key, payload[key]) - for metric in payload.get("metrics", []): - if metric.get("extra"): - try: - metric["extra"] = json.loads(metric["extra"]) - except json.decoder.JSONDecodeError: - logger.info("Unable to decode `extra` field: %s", metric["extra"]) + for key in ("metrics", "columns"): + for attributes in payload.get(key, []): + if attributes.get("extra"): + try: + attributes["extra"] = json.loads(attributes["extra"]) + except json.decoder.JSONDecodeError: + logger.info( + "Unable to decode `extra` field: %s", attributes["extra"] + ) payload["version"] = EXPORT_VERSION payload["database_uuid"] = str(model.database.uuid) diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py index 37522da28c2d2..92687324aab8c 100644 --- a/superset/datasets/commands/importers/v1/utils.py +++ b/superset/datasets/commands/importers/v1/utils.py @@ -30,7 +30,6 @@ from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database -from superset.utils.core import get_example_database logger = logging.getLogger(__name__) @@ -96,13 +95,17 @@ def import_dataset( config[key] = json.dumps(config[key]) except TypeError: logger.info("Unable to encode `%s` field: %s", key, config[key]) - for metric in config.get("metrics", []): - if metric.get("extra") is not None: - try: - metric["extra"] = json.dumps(metric["extra"]) - except TypeError: - logger.info("Unable to encode `extra` field: %s", metric["extra"]) - metric["extra"] = None + for key in ("metrics", "columns"): + for attributes in config.get(key, []): + # should be a dictionary, but in initial exports this was a string + if isinstance(attributes.get("extra"), dict): + try: + attributes["extra"] = json.dumps(attributes["extra"]) + except TypeError: + logger.info( + "Unable to encode `extra` field: %s", attributes["extra"] + ) + attributes["extra"] = None # should we delete columns and metrics not present in the current import? sync = ["columns", "metrics"] if overwrite else [] @@ -127,9 +130,8 @@ def import_dataset( if dataset.id is None: session.flush() - example_database = get_example_database() try: - table_exists = example_database.has_table_by_name(dataset.table_name) + table_exists = dataset.database.has_table_by_name(dataset.table_name) except Exception: # pylint: disable=broad-except # MySQL doesn't play nice with GSheets table names logger.warning( @@ -139,7 +141,7 @@ def import_dataset( if data_uri and (not table_exists or force_data): logger.info("Downloading data from %s", data_uri) - load_data(data_uri, dataset, example_database, session) + load_data(data_uri, dataset, dataset.database, session) if hasattr(g, "user") and g.user: dataset.owners.append(g.user) @@ -148,7 +150,7 @@ def import_dataset( def load_data( - data_uri: str, dataset: SqlaTable, example_database: Database, session: Session + data_uri: str, dataset: SqlaTable, database: Database, session: Session ) -> None: data = request.urlopen(data_uri) # pylint: disable=consider-using-with if data_uri.endswith(".gz"): @@ -162,14 +164,12 @@ def load_data( df[column_name] = pd.to_datetime(df[column_name]) # reuse session when loading data if possible, to make import atomic - if example_database.sqlalchemy_uri == current_app.config.get( - "SQLALCHEMY_DATABASE_URI" - ) or not current_app.config.get("SQLALCHEMY_EXAMPLES_URI"): + if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"): logger.info("Loading data inside the import transaction") connection = session.connection() else: logger.warning("Loading data outside the import transaction") - connection = example_database.get_sqla_engine() + connection = database.get_sqla_engine() df.to_sql( dataset.table_name, diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 58258b1dda158..3a206c7539300 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -131,7 +131,8 @@ class DatasetRelatedObjectsResponse(Schema): class ImportV1ColumnSchema(Schema): column_name = fields.String(required=True) - extra = fields.Dict(allow_none=True) + # extra was initially exported incorrectly as a string + extra = fields.Raw(allow_none=True) verbose_name = fields.String(allow_none=True) is_dttm = fields.Boolean(default=False, allow_none=True) is_active = fields.Boolean(default=True, allow_none=True) diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 4700cd19f8d7d..8522877c28865 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -14,25 +14,62 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=redefined-outer-name + +from typing import Iterator import pytest +from pytest_mock import MockFixture +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm.session import Session from superset.app import SupersetApp from superset.initialization import SupersetAppInitializer +@pytest.fixture() +def session() -> Iterator[Session]: + """ + Create an in-memory SQLite session to test models. + """ + engine = create_engine("sqlite://") + Session_ = sessionmaker(bind=engine) # pylint: disable=invalid-name + in_memory_session = Session_() + + # flask calls session.remove() + in_memory_session.remove = lambda: None + + yield in_memory_session + + @pytest.fixture -def app_context(): +def app(mocker: MockFixture, session: Session) -> Iterator[SupersetApp]: """ - A fixture for running the test inside an app context. + A fixture that generates a Superset app. """ app = SupersetApp(__name__) app.config.from_object("superset.config") app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://" + app.config["FAB_ADD_SECURITY_VIEWS"] = False app_initializer = app.config.get("APP_INITIALIZER", SupersetAppInitializer)(app) app_initializer.init_app() + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", return_value=session, + ) + mocker.patch("superset.db.session", session) + + yield app + + +@pytest.fixture +def app_context(app: SupersetApp) -> Iterator[None]: + """ + A fixture that yields and application context. + """ with app.app_context(): yield diff --git a/tests/unit_tests/datasets/__init__.py b/tests/unit_tests/datasets/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/datasets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/datasets/commands/__init__.py b/tests/unit_tests/datasets/commands/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/datasets/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/datasets/commands/export_test.py b/tests/unit_tests/datasets/commands/export_test.py new file mode 100644 index 0000000000000..5c67ab2e145ad --- /dev/null +++ b/tests/unit_tests/datasets/commands/export_test.py @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, unused-argument, unused-import + +import json + +from sqlalchemy.orm.session import Session + + +def test_export(app_context: None, session: Session) -> None: + """ + Test exporting a dataset. + """ + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.datasets.commands.export import ExportDatasetsCommand + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + session.add(database) + session.flush() + + columns = [ + TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), + TableColumn(column_name="user_id", type="INTEGER"), + TableColumn(column_name="revenue", type="INTEGER"), + TableColumn(column_name="expenses", type="INTEGER"), + TableColumn( + column_name="profit", + type="INTEGER", + expression="revenue-expenses", + extra=json.dumps({"certified_by": "User"}), + ), + ] + metrics = [ + SqlMetric(metric_name="cnt", expression="COUNT(*)"), + ] + + sqla_table = SqlaTable( + table_name="my_table", + columns=columns, + metrics=metrics, + main_dttm_col="ds", + database=database, + offset=-8, + description="This is the description", + is_featured=1, + cache_timeout=3600, + schema="my_schema", + sql=None, + params=json.dumps( + {"remote_id": 64, "database_name": "examples", "import_time": 1606677834,} + ), + perm=None, + filter_select_enabled=1, + fetch_values_predicate="foo IN (1, 2)", + is_sqllab_view=0, # no longer used? + template_params=json.dumps({"answer": "42"}), + schema_perm=None, + extra=json.dumps({"warning_markdown": "*WARNING*"}), + ) + + export = list( + ExportDatasetsCommand._export(sqla_table) # pylint: disable=protected-access + ) + assert export == [ + ( + "datasets/my_database/my_table.yaml", + f"""table_name: my_table +main_dttm_col: ds +description: This is the description +default_endpoint: null +offset: -8 +cache_timeout: 3600 +schema: my_schema +sql: null +params: + remote_id: 64 + database_name: examples + import_time: 1606677834 +template_params: + answer: '42' +filter_select_enabled: 1 +fetch_values_predicate: foo IN (1, 2) +extra: '{{\"warning_markdown\": \"*WARNING*\"}}' +uuid: null +metrics: +- metric_name: cnt + verbose_name: null + metric_type: null + expression: COUNT(*) + description: null + d3format: null + extra: null + warning_text: null +columns: +- column_name: profit + verbose_name: null + is_dttm: null + is_active: null + type: INTEGER + groupby: null + filterable: null + expression: revenue-expenses + description: null + python_date_format: null + extra: + certified_by: User +- column_name: ds + verbose_name: null + is_dttm: 1 + is_active: null + type: TIMESTAMP + groupby: null + filterable: null + expression: null + description: null + python_date_format: null + extra: null +- column_name: user_id + verbose_name: null + is_dttm: null + is_active: null + type: INTEGER + groupby: null + filterable: null + expression: null + description: null + python_date_format: null + extra: null +- column_name: expenses + verbose_name: null + is_dttm: null + is_active: null + type: INTEGER + groupby: null + filterable: null + expression: null + description: null + python_date_format: null + extra: null +- column_name: revenue + verbose_name: null + is_dttm: null + is_active: null + type: INTEGER + groupby: null + filterable: null + expression: null + description: null + python_date_format: null + extra: null +version: 1.0.0 +database_uuid: {database.uuid} +""", + ), + ( + "databases/my_database.yaml", + f"""database_name: my_database +sqlalchemy_uri: sqlite:// +cache_timeout: null +expose_in_sqllab: true +allow_run_async: false +allow_ctas: false +allow_cvas: false +allow_file_upload: false +extra: + metadata_params: {{}} + engine_params: {{}} + metadata_cache_timeout: {{}} + schemas_allowed_for_file_upload: [] +uuid: {database.uuid} +version: 1.0.0 +""", + ), + ] diff --git a/tests/unit_tests/datasets/commands/importers/__init__.py b/tests/unit_tests/datasets/commands/importers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/datasets/commands/importers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/datasets/commands/importers/v1/__init__.py b/tests/unit_tests/datasets/commands/importers/v1/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/datasets/commands/importers/v1/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py new file mode 100644 index 0000000000000..99fa42d072f3d --- /dev/null +++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, unused-argument, unused-import, invalid-name + +import json +import uuid +from typing import Any, Dict + +from sqlalchemy.orm.session import Session + + +def test_import_(app_context: None, session: Session) -> None: + """ + Test importing a dataset. + """ + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.datasets.commands.importers.v1.utils import import_dataset + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + session.add(database) + session.flush() + + dataset_uuid = uuid.uuid4() + config = { + "table_name": "my_table", + "main_dttm_col": "ds", + "description": "This is the description", + "default_endpoint": None, + "offset": -8, + "cache_timeout": 3600, + "schema": "my_schema", + "sql": None, + "params": { + "remote_id": 64, + "database_name": "examples", + "import_time": 1606677834, + }, + "template_params": {"answer": "42",}, + "filter_select_enabled": True, + "fetch_values_predicate": "foo IN (1, 2)", + "extra": '{"warning_markdown": "*WARNING*"}', + "uuid": dataset_uuid, + "metrics": [ + { + "metric_name": "cnt", + "verbose_name": None, + "metric_type": None, + "expression": "COUNT(*)", + "description": None, + "d3format": None, + "extra": None, + "warning_text": None, + } + ], + "columns": [ + { + "column_name": "profit", + "verbose_name": None, + "is_dttm": None, + "is_active": None, + "type": "INTEGER", + "groupby": None, + "filterable": None, + "expression": "revenue-expenses", + "description": None, + "python_date_format": None, + "extra": {"certified_by": "User",}, + } + ], + "database_uuid": database.uuid, + "database_id": database.id, + } + + sqla_table = import_dataset(session, config) + assert sqla_table.table_name == "my_table" + assert sqla_table.main_dttm_col == "ds" + assert sqla_table.description == "This is the description" + assert sqla_table.default_endpoint is None + assert sqla_table.offset == -8 + assert sqla_table.cache_timeout == 3600 + assert sqla_table.schema == "my_schema" + assert sqla_table.sql is None + assert sqla_table.params == json.dumps( + {"remote_id": 64, "database_name": "examples", "import_time": 1606677834} + ) + assert sqla_table.template_params == json.dumps({"answer": "42"}) + assert sqla_table.filter_select_enabled is True + assert sqla_table.fetch_values_predicate == "foo IN (1, 2)" + assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}' + assert sqla_table.uuid == dataset_uuid + assert len(sqla_table.metrics) == 1 + assert sqla_table.metrics[0].metric_name == "cnt" + assert sqla_table.metrics[0].verbose_name is None + assert sqla_table.metrics[0].metric_type is None + assert sqla_table.metrics[0].expression == "COUNT(*)" + assert sqla_table.metrics[0].description is None + assert sqla_table.metrics[0].d3format is None + assert sqla_table.metrics[0].extra is None + assert sqla_table.metrics[0].warning_text is None + assert len(sqla_table.columns) == 1 + assert sqla_table.columns[0].column_name == "profit" + assert sqla_table.columns[0].verbose_name is None + assert sqla_table.columns[0].is_dttm is False + assert sqla_table.columns[0].is_active is True + assert sqla_table.columns[0].type == "INTEGER" + assert sqla_table.columns[0].groupby is True + assert sqla_table.columns[0].filterable is True + assert sqla_table.columns[0].expression == "revenue-expenses" + assert sqla_table.columns[0].description is None + assert sqla_table.columns[0].python_date_format is None + assert sqla_table.columns[0].extra == '{"certified_by": "User"}' + assert sqla_table.database.uuid == database.uuid + assert sqla_table.database.id == database.id + + +def test_import_column_extra_is_string(app_context: None, session: Session) -> None: + """ + Test importing a dataset when the column extra is a string. + """ + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.datasets.commands.importers.v1.utils import import_dataset + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + session.add(database) + session.flush() + + dataset_uuid = uuid.uuid4() + config: Dict[str, Any] = { + "table_name": "my_table", + "main_dttm_col": "ds", + "description": "This is the description", + "default_endpoint": None, + "offset": -8, + "cache_timeout": 3600, + "schema": "my_schema", + "sql": None, + "params": { + "remote_id": 64, + "database_name": "examples", + "import_time": 1606677834, + }, + "template_params": {"answer": "42",}, + "filter_select_enabled": True, + "fetch_values_predicate": "foo IN (1, 2)", + "extra": '{"warning_markdown": "*WARNING*"}', + "uuid": dataset_uuid, + "metrics": [], + "columns": [ + { + "column_name": "profit", + "verbose_name": None, + "is_dttm": None, + "is_active": None, + "type": "INTEGER", + "groupby": None, + "filterable": None, + "expression": "revenue-expenses", + "description": None, + "python_date_format": None, + "extra": '{"certified_by": "User"}', + } + ], + "database_uuid": database.uuid, + "database_id": database.id, + } + + sqla_table = import_dataset(session, config) + assert sqla_table.columns[0].extra == '{"certified_by": "User"}'