Skip to content

Commit

Permalink
fix dataset update table (#19269)
Browse files Browse the repository at this point in the history
  • Loading branch information
eschutho committed Mar 21, 2022
1 parent c07a707 commit 88029e2
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 128 deletions.
272 changes: 150 additions & 122 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,11 +1863,20 @@ def update_table( # pylint: disable=unused-argument

session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))

# update ``Column`` model as well
dataset = (
session.query(NewDataset).filter_by(sqlatable_id=target.table.id).one()
session.query(NewDataset)
.filter_by(sqlatable_id=target.table.id)
.one_or_none()
)

if not dataset:
# if dataset is not found create a new copy
# of the dataset instead of updating the existing

SqlaTable.write_shadow_dataset(target.table, database, session)
return

# update ``Column`` model as well
if isinstance(target, TableColumn):
columns = [
column
Expand Down Expand Up @@ -1923,7 +1932,7 @@ def update_table( # pylint: disable=unused-argument
column.extra_json = json.dumps(extra_json) if extra_json else None

@staticmethod
def after_insert( # pylint: disable=too-many-locals
def after_insert(
mapper: Mapper, connection: Connection, target: "SqlaTable",
) -> None:
"""
Expand All @@ -1938,135 +1947,18 @@ def after_insert( # pylint: disable=too-many-locals
For more context: https://github.com/apache/superset/issues/14909
"""
session = inspect(target).session
# set permissions
security_manager.set_perm(mapper, connection, target)

session = inspect(target).session

# get DB-specific conditional quoter for expressions that point to columns or
# table names
database = (
target.database
or session.query(Database).filter_by(id=target.database_id).one()
)
engine = database.get_sqla_engine(schema=target.schema)
conditional_quote = engine.dialect.identifier_preparer.quote

# create columns
columns = []
for column in target.columns:
# ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
if column.is_active is False:
continue

extra_json = json.loads(column.extra or "{}")
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(column, attr)
if value:
extra_json[attr] = value

columns.append(
NewColumn(
name=column.column_name,
type=column.type or "Unknown",
expression=column.expression
or conditional_quote(column.column_name),
description=column.description,
is_temporal=column.is_dttm,
is_aggregation=False,
is_physical=column.expression is None,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
),
)

# create metrics
for metric in target.metrics:
extra_json = json.loads(metric.extra or "{}")
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
extra_json[attr] = value

is_additive = (
metric.metric_type
and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
)

columns.append(
NewColumn(
name=metric.metric_name,
type="Unknown", # figuring this out would require a type inferrer
expression=metric.expression,
warning_text=metric.warning_text,
description=metric.description,
is_aggregation=True,
is_additive=is_additive,
is_physical=False,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
),
)

# physical dataset
tables = []
if target.sql is None:
physical_columns = [column for column in columns if column.is_physical]

# create table
table = NewTable(
name=target.table_name,
schema=target.schema,
catalog=None, # currently not supported
database_id=target.database_id,
columns=physical_columns,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
tables.append(table)

# virtual dataset
else:
# mark all columns as virtual (not physical)
for column in columns:
column.is_physical = False

# find referenced tables
parsed = ParsedQuery(target.sql)
referenced_tables = parsed.tables

# predicate for finding the referenced tables
predicate = or_(
*[
and_(
NewTable.schema == (table.schema or target.schema),
NewTable.name == table.table,
)
for table in referenced_tables
]
)
tables = session.query(NewTable).filter(predicate).all()

# create the new dataset
dataset = NewDataset(
sqlatable_id=target.id,
name=target.table_name,
expression=target.sql or conditional_quote(target.table_name),
tables=tables,
columns=columns,
is_physical=target.sql is None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
session.add(dataset)
SqlaTable.write_shadow_dataset(target, database, session)

@staticmethod
def after_delete( # pylint: disable=unused-argument
Expand Down Expand Up @@ -2301,6 +2193,142 @@ def after_update( # pylint: disable=too-many-branches, too-many-locals, too-man
dataset.expression = target.sql or conditional_quote(target.table_name)
dataset.is_physical = target.sql is None

@staticmethod
def write_shadow_dataset( # pylint: disable=too-many-locals
dataset: "SqlaTable", database: Database, session: Session
) -> None:
"""
Shadow write the dataset to new models.
The ``SqlaTable`` model is currently being migrated to two new models, ``Table``
and ``Dataset``. In the first phase of the migration the new models are populated
whenever ``SqlaTable`` is modified (created, updated, or deleted).
In the second phase of the migration reads will be done from the new models.
Finally, in the third phase of the migration the old models will be removed.
For more context: https://github.com/apache/superset/issues/14909
"""

engine = database.get_sqla_engine(schema=dataset.schema)
conditional_quote = engine.dialect.identifier_preparer.quote

# create columns
columns = []
for column in dataset.columns:
# ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
if column.is_active is False:
continue

extra_json = json.loads(column.extra or "{}")
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(column, attr)
if value:
extra_json[attr] = value

columns.append(
NewColumn(
name=column.column_name,
type=column.type or "Unknown",
expression=column.expression
or conditional_quote(column.column_name),
description=column.description,
is_temporal=column.is_dttm,
is_aggregation=False,
is_physical=column.expression is None,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
),
)

# create metrics
for metric in dataset.metrics:
extra_json = json.loads(metric.extra or "{}")
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
extra_json[attr] = value

is_additive = (
metric.metric_type
and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
)

columns.append(
NewColumn(
name=metric.metric_name,
type="Unknown", # figuring this out would require a type inferrer
expression=metric.expression,
warning_text=metric.warning_text,
description=metric.description,
is_aggregation=True,
is_additive=is_additive,
is_physical=False,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
),
)

# physical dataset
tables = []
if dataset.sql is None:
physical_columns = [column for column in columns if column.is_physical]

# create table
table = NewTable(
name=dataset.table_name,
schema=dataset.schema,
catalog=None, # currently not supported
database_id=dataset.database_id,
columns=physical_columns,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
)
tables.append(table)

# virtual dataset
else:
# mark all columns as virtual (not physical)
for column in columns:
column.is_physical = False

# find referenced tables
parsed = ParsedQuery(dataset.sql)
referenced_tables = parsed.tables

# predicate for finding the referenced tables
predicate = or_(
*[
and_(
NewTable.schema == (table.schema or dataset.schema),
NewTable.name == table.table,
)
for table in referenced_tables
]
)
tables = session.query(NewTable).filter(predicate).all()

# create the new dataset
new_dataset = NewDataset(
sqlatable_id=dataset.id,
name=dataset.table_name,
expression=dataset.sql or conditional_quote(dataset.table_name),
tables=tables,
columns=columns,
is_physical=dataset.sql is None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
)
session.add(new_dataset)


sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
Expand Down

0 comments on commit 88029e2

Please sign in to comment.