From 88029e21b6068f845d806cfc10d478a5d972ffa5 Mon Sep 17 00:00:00 2001
From: Elizabeth Thompson
Date: Mon, 21 Mar 2022 09:43:51 -0700
Subject: [PATCH] fix dataset update table (#19269)

---
 superset/connectors/sqla/models.py       | 272 +++++++++++++----------
 tests/unit_tests/datasets/test_models.py |  81 ++++++-
 2 files changed, 225 insertions(+), 128 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 62ae8c9ebaf9..bbd1b5d84dad 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1863,11 +1863,20 @@ def update_table(  # pylint: disable=unused-argument
         session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))
 
-        # update ``Column`` model as well
         dataset = (
-            session.query(NewDataset).filter_by(sqlatable_id=target.table.id).one()
+            session.query(NewDataset)
+            .filter_by(sqlatable_id=target.table.id)
+            .one_or_none()
         )
+        if not dataset:
+            # if the dataset is not found, write a new shadow copy
+            # instead of updating an existing one
+
+            SqlaTable.write_shadow_dataset(target.table, database, session)
+            return
+
+        # update ``Column`` model as well
         if isinstance(target, TableColumn):
             columns = [
                 column
@@ -1923,7 +1932,7 @@
                 column.extra_json = json.dumps(extra_json) if extra_json else None
 
     @staticmethod
-    def after_insert(  # pylint: disable=too-many-locals
+    def after_insert(
         mapper: Mapper, connection: Connection, target: "SqlaTable",
     ) -> None:
         """
@@ -1938,135 +1947,18 @@ def after_insert(  # pylint: disable=too-many-locals
         For more context: https://github.com/apache/superset/issues/14909
         """
+        session = inspect(target).session
+
         # set permissions
         security_manager.set_perm(mapper, connection, target)
 
-        session = inspect(target).session
-
         # get DB-specific conditional quoter for expressions that point to columns or
         # table names
         database = (
             target.database
             or session.query(Database).filter_by(id=target.database_id).one()
         )
-        engine = database.get_sqla_engine(schema=target.schema)
-        conditional_quote = engine.dialect.identifier_preparer.quote
-
-        # create columns
-        columns = []
-        for column in target.columns:
-            # ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
-            if column.is_active is False:
-                continue
-
-            extra_json = json.loads(column.extra or "{}")
-            for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
-                value = getattr(column, attr)
-                if value:
-                    extra_json[attr] = value
-
-            columns.append(
-                NewColumn(
-                    name=column.column_name,
-                    type=column.type or "Unknown",
-                    expression=column.expression
-                    or conditional_quote(column.column_name),
-                    description=column.description,
-                    is_temporal=column.is_dttm,
-                    is_aggregation=False,
-                    is_physical=column.expression is None,
-                    is_spatial=False,
-                    is_partition=False,
-                    is_increase_desired=True,
-                    extra_json=json.dumps(extra_json) if extra_json else None,
-                    is_managed_externally=target.is_managed_externally,
-                    external_url=target.external_url,
-                ),
-            )
-
-        # create metrics
-        for metric in target.metrics:
-            extra_json = json.loads(metric.extra or "{}")
-            for attr in {"verbose_name", "metric_type", "d3format"}:
-                value = getattr(metric, attr)
-                if value:
-                    extra_json[attr] = value
-
-            is_additive = (
-                metric.metric_type
-                and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
-            )
-
-            columns.append(
-                NewColumn(
-                    name=metric.metric_name,
-                    type="Unknown",  # figuring this out would require a type inferrer
-                    expression=metric.expression,
-                    warning_text=metric.warning_text,
-                    description=metric.description,
-                    is_aggregation=True,
-                    is_additive=is_additive,
-                    is_physical=False,
-                    is_spatial=False,
-                    is_partition=False,
-                    is_increase_desired=True,
-                    extra_json=json.dumps(extra_json) if extra_json else None,
-                    is_managed_externally=target.is_managed_externally,
-                    external_url=target.external_url,
-                ),
-            )
-
-        # physical dataset
-        tables = []
-        if target.sql is None:
-            physical_columns = [column for column in columns if column.is_physical]
-
-            # create table
-            table = NewTable(
-                name=target.table_name,
-                schema=target.schema,
-                catalog=None,  # currently not supported
-                database_id=target.database_id,
-                columns=physical_columns,
-                is_managed_externally=target.is_managed_externally,
-                external_url=target.external_url,
-            )
-            tables.append(table)
-        # virtual dataset
-        else:
-            # mark all columns as virtual (not physical)
-            for column in columns:
-                column.is_physical = False
-
-            # find referenced tables
-            parsed = ParsedQuery(target.sql)
-            referenced_tables = parsed.tables
-
-            # predicate for finding the referenced tables
-            predicate = or_(
-                *[
-                    and_(
-                        NewTable.schema == (table.schema or target.schema),
-                        NewTable.name == table.table,
-                    )
-                    for table in referenced_tables
-                ]
-            )
-            tables = session.query(NewTable).filter(predicate).all()
-
-        # create the new dataset
-        dataset = NewDataset(
-            sqlatable_id=target.id,
-            name=target.table_name,
-            expression=target.sql or conditional_quote(target.table_name),
-            tables=tables,
-            columns=columns,
-            is_physical=target.sql is None,
-            is_managed_externally=target.is_managed_externally,
-            external_url=target.external_url,
-        )
-        session.add(dataset)
+        SqlaTable.write_shadow_dataset(target, database, session)
 
     @staticmethod
     def after_delete(  # pylint: disable=unused-argument
@@ -2301,6 +2193,142 @@ def after_update(  # pylint: disable=too-many-branches, too-many-locals, too-man
         dataset.expression = target.sql or conditional_quote(target.table_name)
         dataset.is_physical = target.sql is None
 
+    @staticmethod
+    def write_shadow_dataset(  # pylint: disable=too-many-locals
+        dataset: "SqlaTable", database: Database, session: Session
+    ) -> None:
+        """
+        Shadow write the dataset to new models.
+
+        The ``SqlaTable`` model is currently being migrated to two new models, ``Table``
+        and ``Dataset``. In the first phase of the migration, the new models are
+        populated whenever ``SqlaTable`` is modified (created, updated, or deleted).
+
+        In the second phase of the migration, reads will be done from the new models.
+        Finally, in the third phase, the old models will be removed.
+
+        For more context: https://github.com/apache/superset/issues/14909
+        """
+
+        engine = database.get_sqla_engine(schema=dataset.schema)
+        conditional_quote = engine.dialect.identifier_preparer.quote
+
+        # create columns
+        columns = []
+        for column in dataset.columns:
+            # ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
+            if column.is_active is False:
+                continue
+
+            extra_json = json.loads(column.extra or "{}")
+            for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
+                value = getattr(column, attr)
+                if value:
+                    extra_json[attr] = value
+
+            columns.append(
+                NewColumn(
+                    name=column.column_name,
+                    type=column.type or "Unknown",
+                    expression=column.expression
+                    or conditional_quote(column.column_name),
+                    description=column.description,
+                    is_temporal=column.is_dttm,
+                    is_aggregation=False,
+                    is_physical=column.expression is None,
+                    is_spatial=False,
+                    is_partition=False,
+                    is_increase_desired=True,
+                    extra_json=json.dumps(extra_json) if extra_json else None,
+                    is_managed_externally=dataset.is_managed_externally,
+                    external_url=dataset.external_url,
+                ),
+            )
+
+        # create metrics
+        for metric in dataset.metrics:
+            extra_json = json.loads(metric.extra or "{}")
+            for attr in {"verbose_name", "metric_type", "d3format"}:
+                value = getattr(metric, attr)
+                if value:
+                    extra_json[attr] = value
+
+            is_additive = (
+                metric.metric_type
+                and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
+            )
+
+            columns.append(
+                NewColumn(
+                    name=metric.metric_name,
+                    type="Unknown",  # figuring this out would require a type inferrer
+                    expression=metric.expression,
+                    warning_text=metric.warning_text,
+                    description=metric.description,
+                    is_aggregation=True,
+                    is_additive=is_additive,
+                    is_physical=False,
+                    is_spatial=False,
+                    is_partition=False,
+                    is_increase_desired=True,
+                    extra_json=json.dumps(extra_json) if extra_json else None,
+                    is_managed_externally=dataset.is_managed_externally,
+                    external_url=dataset.external_url,
+                ),
+            )
+
+        # physical dataset
+        tables = []
+        if dataset.sql is None:
+            physical_columns = [column for column in columns if column.is_physical]
+
+            # create table
+            table = NewTable(
+                name=dataset.table_name,
+                schema=dataset.schema,
+                catalog=None,  # currently not supported
+                database_id=dataset.database_id,
+                columns=physical_columns,
+                is_managed_externally=dataset.is_managed_externally,
+                external_url=dataset.external_url,
+            )
+            tables.append(table)
+
+        # virtual dataset
+        else:
+            # mark all columns as virtual (not physical)
+            for column in columns:
+                column.is_physical = False
+
+            # find referenced tables
+            parsed = ParsedQuery(dataset.sql)
+            referenced_tables = parsed.tables
+
+            # predicate for finding the referenced tables
+            predicate = or_(
+                *[
+                    and_(
+                        NewTable.schema == (table.schema or dataset.schema),
+                        NewTable.name == table.table,
+                    )
+                    for table in referenced_tables
+                ]
+            )
+            tables = session.query(NewTable).filter(predicate).all()
+
+        # create the new dataset
+        new_dataset = NewDataset(
+            sqlatable_id=dataset.id,
+            name=dataset.table_name,
+            expression=dataset.sql or conditional_quote(dataset.table_name),
+            tables=tables,
+            columns=columns,
+            is_physical=dataset.sql is None,
+            is_managed_externally=dataset.is_managed_externally,
+            external_url=dataset.external_url,
+        )
+        session.add(new_dataset)
+
 
 sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
 sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
diff --git a/tests/unit_tests/datasets/test_models.py b/tests/unit_tests/datasets/test_models.py
index eab0a8aa2828..095b50276091 100644
--- a/tests/unit_tests/datasets/test_models.py
+++ b/tests/unit_tests/datasets/test_models.py
@@ -980,9 +980,9 @@ def test_update_sqlatable_schema(
     sqla_table.schema = "new_schema"
     session.flush()
 
-    dataset = session.query(Dataset).one()
-    assert dataset.tables[0].schema == "new_schema"
-    assert dataset.tables[0].id == 2
+    new_dataset = session.query(Dataset).one()
+    assert new_dataset.tables[0].schema == "new_schema"
+    assert new_dataset.tables[0].id == 2
 
 
 def test_update_sqlatable_metric(
@@ -1098,9 +1098,9 @@ def test_update_virtual_sqlatable_references(
     session.flush()
 
     # check that new dataset has both tables
-    dataset = session.query(Dataset).one()
-    assert dataset.tables == [table1, table2]
-    assert dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
+    new_dataset = session.query(Dataset).one()
+    assert new_dataset.tables == [table1, table2]
+    assert new_dataset.expression == "SELECT a, b FROM table_a JOIN table_b"
 
 
 def test_quote_expressions(app_context: None, session: Session) -> None:
@@ -1242,3 +1242,72 @@ def test_update_physical_sqlatable(
 
     # check that dataset points to the original table
     assert dataset.tables[0].database_id == 1
+
+
+def test_update_physical_sqlatable_no_dataset(
+    mocker: MockFixture, app_context: None, session: Session
+) -> None:
+    """
+    Test that updating the table on a physical dataset creates a new dataset
+    if one doesn't already exist.
+
+    When updating the table on a physical dataset by pointing it somewhere else (change
+    in database ID, schema, or table name), we should point the ``Dataset`` to an
+    existing ``Table`` if possible, and create a new one otherwise.
+ """ + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", return_value=session + ) + mocker.patch("superset.datasets.dao.db.session", session) + + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + from superset.tables.schemas import TableSchema + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="a", type="INTEGER"), + ] + + sqla_table = SqlaTable( + table_name="old_dataset", + columns=columns, + metrics=[], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + ) + session.add(sqla_table) + session.flush() + + # check that the table was created + table = session.query(Table).one() + assert table.id == 1 + + dataset = session.query(Dataset).one() + assert dataset.tables == [table] + + # point ``SqlaTable`` to a different database + new_database = Database( + database_name="my_other_database", sqlalchemy_uri="sqlite://" + ) + session.add(new_database) + session.flush() + sqla_table.database = new_database + session.flush() + + new_dataset = session.query(Dataset).one() + + # check that dataset now points to the new table + assert new_dataset.tables[0].database_id == 2 + + # point ``SqlaTable`` back + sqla_table.database_id = 1 + session.flush() + + # check that dataset points to the original table + assert new_dataset.tables[0].database_id == 1