Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: new column UUID conflicts in dual write #20460

Merged
merged 1 commit into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def to_sl_column(
self, known_columns: Optional[Dict[str, NewColumn]] = None
) -> NewColumn:
"""Convert a TableColumn to NewColumn"""
session: Session = inspect(self).session
column = known_columns.get(self.uuid) if known_columns else None
if not column:
column = NewColumn()
Expand All @@ -452,6 +453,21 @@ def to_sl_column(
if value:
extra_json[attr] = value

if not column.id:
with session.no_autoflush:
saved_column = (
session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none()
)
if saved_column:
logger.warning(
"sl_column already exists. Assigning existing id %s", self
)

# uuid isn't a primary key, so add the id of the existing column to
# ensure that the column is modified instead of created
# in order to avoid a uuid collision
column.id = saved_column.id

column.uuid = self.uuid
column.created_on = self.created_on
column.changed_on = self.changed_on
Expand Down Expand Up @@ -555,6 +571,7 @@ def to_sl_column(
) -> NewColumn:
"""Convert a SqlMetric to NewColumn. Find and update existing or
create a new one."""
session: Session = inspect(self).session
column = known_columns.get(self.uuid) if known_columns else None
if not column:
column = NewColumn()
Expand All @@ -568,6 +585,20 @@ def to_sl_column(
self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER
)

if not column.id:
with session.no_autoflush:
saved_column = (
session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none()
)
if saved_column:
logger.warning(
"sl_column already exists. Assigning existing id %s", self
)
# uuid isn't a primary key, so add the id of the existing column to
# ensure that the column is modified instead of created
# in order to avoid a uuid collision
column.id = saved_column.id

column.uuid = self.uuid
column.name = self.metric_name
column.created_on = self.created_on
Expand Down Expand Up @@ -2149,10 +2180,11 @@ def get_sl_columns(self) -> List[NewColumn]:
uuids.remove(column.uuid)

if uuids:
# load those not found from db
existing_columns |= set(
session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
)
with session.no_autoflush:
# load those not found from db
existing_columns |= set(
session.query(NewColumn).filter(NewColumn.uuid.in_(uuids))
)

known_columns = {column.uuid: column for column in existing_columns}
return [
Expand Down Expand Up @@ -2192,9 +2224,10 @@ def update_column( # pylint: disable=unused-argument
# update changed_on timestamp
session.execute(update(NewDataset).where(NewDataset.id == dataset.id))
try:
column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
# update `Column` model as well
session.merge(target.to_sl_column({target.uuid: column}))
with session.no_autoflush:
column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
# update `Column` model as well
session.merge(target.to_sl_column({target.uuid: column}))
except NoResultFound:
logger.warning("No column was found for %s", target)
# see if the column is in cache
Expand All @@ -2204,14 +2237,15 @@ def update_column( # pylint: disable=unused-argument
),
None,
)
if column:
logger.warning("New column was found in cache: %s", column)

if not column:
else:
# to be safe, use a different uuid and create a new column
uuid = uuid4()
target.uuid = uuid
column = NewColumn(uuid=uuid)

session.add(target.to_sl_column({column.uuid: column}))
session.add(target.to_sl_column())

@staticmethod
def after_insert(
Expand Down
18 changes: 18 additions & 0 deletions tests/integration_tests/datasets/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from unittest import mock

import pytest
from sqlalchemy import inspect
from sqlalchemy.orm.exc import NoResultFound

from superset.columns.models import Column
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.extensions import db
from tests.integration_tests.base_tests import SupersetTestCase
Expand Down Expand Up @@ -59,6 +61,10 @@ def test_dual_update_column_not_found(self, column_mock) -> None:
with mock.patch("sqlalchemy.orm.query.Query.one", side_effect=NoResultFound):
SqlaTable.update_column(None, None, target=column)

session = inspect(column).session

session.flush()

# refetch
dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
# it should create a new uuid
Expand All @@ -67,3 +73,15 @@ def test_dual_update_column_not_found(self, column_mock) -> None:
# reset
column.uuid = column_uuid
SqlaTable.update_column(None, None, target=column)

@pytest.mark.usefixtures("load_dataset_with_columns")
def test_to_sl_column_no_known_columns(self) -> None:
    """
    Check that ``to_sl_column()`` called without ``known_columns`` returns
    a column carrying the same uuid as the source column.
    """
    students_query = db.session.query(SqlaTable).filter_by(table_name="students")
    students = students_query.first()
    source_column = students.columns[0]

    # convert without passing any known columns
    converted = source_column.to_sl_column()

    # the uuid must be carried over unchanged
    assert source_column.uuid == converted.uuid
8 changes: 6 additions & 2 deletions tests/unit_tests/datasets/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ def test_update_physical_sqlatable_columns(
metrics=[],
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
)

session.add(sqla_table)
session.flush()

Expand All @@ -735,8 +736,11 @@ def test_update_physical_sqlatable_columns(
assert session.query(Column).count() == 3
dataset = session.query(Dataset).one()
assert len(dataset.columns) == 2
for table_column, dataset_column in zip(sqla_table.columns, dataset.columns):
assert table_column.uuid == dataset_column.uuid

# check that both lists have the same uuids, ignoring order.
# NOTE: use sorted(), which returns the sorted list; list.sort() sorts in
# place and returns None, so comparing two .sort() results always passes.
assert sorted(col.uuid for col in sqla_table.columns) == sorted(
    col.uuid for col in dataset.columns
)

# delete the column in the original instance
sqla_table.columns = sqla_table.columns[1:]
Expand Down