Skip to content

Commit

Permalink
remove autoflush for queries during dual write (#20460)
Browse files Browse the repository at this point in the history
  • Loading branch information
eschutho committed Jun 23, 2022
1 parent 661ab35 commit 44f0b51
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 12 deletions.
54 changes: 44 additions & 10 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def to_sl_column(
self, known_columns: Optional[Dict[str, NewColumn]] = None
) -> NewColumn:
"""Convert a TableColumn to NewColumn"""
session: Session = inspect(self).session
column = known_columns.get(self.uuid) if known_columns else None
if not column:
column = NewColumn()
Expand All @@ -452,6 +453,21 @@ def to_sl_column(
if value:
extra_json[attr] = value

if not column.id:
    # Look for an already-persisted NewColumn carrying this uuid. The lookup
    # runs under no_autoflush so that issuing the query does not flush
    # pending objects in the session as a side effect.
    with session.no_autoflush:
        saved_column = (
            session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none()
        )
        if saved_column:
            # Log both the source column and the id that is being reused;
            # the message previously formatted `self` where it announced
            # an id.
            logger.warning(
                "sl_column already exists for %s. Assigning existing id %s",
                self,
                saved_column.id,
            )

            # uuid isn't a primary key, so add the id of the existing column to
            # ensure that the column is modified instead of created
            # in order to avoid a uuid collision
            column.id = saved_column.id

column.uuid = self.uuid
column.created_on = self.created_on
column.changed_on = self.changed_on
Expand Down Expand Up @@ -555,6 +571,7 @@ def to_sl_column(
) -> NewColumn:
"""Convert a SqlMetric to NewColumn. Find and update existing or
create a new one."""
session: Session = inspect(self).session
column = known_columns.get(self.uuid) if known_columns else None
if not column:
column = NewColumn()
Expand All @@ -568,6 +585,20 @@ def to_sl_column(
self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER
)

if not column.id:
    # Look for an already-persisted NewColumn carrying this metric's uuid.
    # The lookup runs under no_autoflush so that issuing the query does not
    # flush pending objects in the session as a side effect.
    with session.no_autoflush:
        saved_column = (
            session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none()
        )
        if saved_column:
            # Log both the source metric and the id that is being reused;
            # the message previously formatted `self` where it announced
            # an id.
            logger.warning(
                "sl_column already exists for %s. Assigning existing id %s",
                self,
                saved_column.id,
            )
            # uuid isn't a primary key, so add the id of the existing column to
            # ensure that the column is modified instead of created
            # in order to avoid a uuid collision
            column.id = saved_column.id

column.uuid = self.uuid
column.name = self.metric_name
column.created_on = self.created_on
Expand Down Expand Up @@ -2149,10 +2180,11 @@ def get_sl_columns(self) -> List[NewColumn]:
uuids.remove(column.uuid)

if uuids:
# load those not found from db
existing_columns |= set(
session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
)
with session.no_autoflush:
# load those not found from db
existing_columns |= set(
session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
)

known_columns = {column.uuid: column for column in existing_columns}
return [
Expand Down Expand Up @@ -2192,9 +2224,10 @@ def update_column( # pylint: disable=unused-argument
# update changed_on timestamp
session.execute(update(NewDataset).where(NewDataset.id == dataset.id))
try:
column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
# update `Column` model as well
session.merge(target.to_sl_column({target.uuid: column}))
with session.no_autoflush:
column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
# update `Column` model as well
session.merge(target.to_sl_column({target.uuid: column}))
except NoResultFound:
logger.warning("No column was found for %s", target)
# see if the column is in cache
Expand All @@ -2204,14 +2237,15 @@ def update_column( # pylint: disable=unused-argument
),
None,
)
if column:
logger.warning("New column was found in cache: %s", column)

if not column:
else:
# to be safe, use a different uuid and create a new column
uuid = uuid4()
target.uuid = uuid
column = NewColumn(uuid=uuid)

session.add(target.to_sl_column({column.uuid: column}))
session.add(target.to_sl_column())

@staticmethod
def after_insert(
Expand Down
18 changes: 18 additions & 0 deletions tests/integration_tests/datasets/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from unittest import mock

import pytest
from sqlalchemy import inspect
from sqlalchemy.orm.exc import NoResultFound

from superset.columns.models import Column
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.extensions import db
from tests.integration_tests.base_tests import SupersetTestCase
Expand Down Expand Up @@ -59,6 +61,10 @@ def test_dual_update_column_not_found(self, column_mock) -> None:
with mock.patch("sqlalchemy.orm.query.Query.one", side_effect=NoResultFound):
SqlaTable.update_column(None, None, target=column)

session = inspect(column).session

session.flush()

# refetch
dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
# it should create a new uuid
Expand All @@ -67,3 +73,15 @@ def test_dual_update_column_not_found(self, column_mock) -> None:
# reset
column.uuid = column_uuid
SqlaTable.update_column(None, None, target=column)

@pytest.mark.usefixtures("load_dataset_with_columns")
def test_to_sl_column_no_known_columns(self) -> None:
    """
    Converting a column without passing ``known_columns`` keeps its uuid.
    """
    students = (
        db.session.query(SqlaTable).filter_by(table_name="students").first()
    )
    table_column = students.columns[0]

    # no known_columns mapping is handed to the conversion
    converted = table_column.to_sl_column()

    # it should use the same uuid
    assert table_column.uuid == converted.uuid
8 changes: 6 additions & 2 deletions tests/unit_tests/datasets/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ def test_update_physical_sqlatable_columns(
metrics=[],
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
)

session.add(sqla_table)
session.flush()

Expand All @@ -735,8 +736,11 @@ def test_update_physical_sqlatable_columns(
assert session.query(Column).count() == 3
dataset = session.query(Dataset).one()
assert len(dataset.columns) == 2
for table_column, dataset_column in zip(sqla_table.columns, dataset.columns):
assert table_column.uuid == dataset_column.uuid

# check that both lists have the same uuids, ignoring order; sorted() is
# used because list.sort() sorts in place and returns None, which would
# make the comparison `None == None` and the assertion always pass
assert sorted(col.uuid for col in sqla_table.columns) == sorted(
    col.uuid for col in dataset.columns
)

# delete the column in the original instance
sqla_table.columns = sqla_table.columns[1:]
Expand Down

0 comments on commit 44f0b51

Please sign in to comment.