Skip to content

Commit

Permalink
catch some potential errors on dual write
Browse files Browse the repository at this point in the history
  • Loading branch information
eschutho committed Jun 11, 2022
1 parent eab0009 commit baa5cd8
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 17 deletions.
42 changes: 27 additions & 15 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
)
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
Expand Down Expand Up @@ -933,7 +934,8 @@ def mutate_query_from_config(self, sql: str) -> str:
if sql_query_mutator:
sql = sql_query_mutator(
sql,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
# TODO(john-bodley): Deprecate in 3.0.
user_name=get_username(),
security_manager=security_manager,
database=self.database,
)
Expand Down Expand Up @@ -2115,7 +2117,7 @@ def get_sl_columns(self) -> List[NewColumn]:
]

@staticmethod
def update_table( # pylint: disable=unused-argument
def update_column( # pylint: disable=unused-argument
mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn]
) -> None:
"""
Expand All @@ -2130,7 +2132,7 @@ def update_table( # pylint: disable=unused-argument
# table is updated. This busts the cache key for all charts that use the table.
session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))

# if table itself has changed, shadow-writing will happen in `after_udpate` anyway
# if table itself has changed, shadow-writing will happen in `after_update` anyway
if target.table not in session.dirty:
dataset: NewDataset = (
session.query(NewDataset)
Expand All @@ -2146,17 +2148,27 @@ def update_table( # pylint: disable=unused-argument

# update changed_on timestamp
session.execute(update(NewDataset).where(NewDataset.id == dataset.id))

# update `Column` model as well
session.add(
target.to_sl_column(
{
target.uuid: session.query(NewColumn)
.filter_by(uuid=target.uuid)
.one_or_none()
}
try:
column = session.query(NewColumn).filter_by(uuid=target.uuid).one()
# update `Column` model as well
session.merge(target.to_sl_column({target.uuid: column}))
except NoResultFound:
logger.warning("No column was found for %s", target)
# see if the column is in cache
column = next(
find_cached_objects_in_session(
session, NewColumn, uuids=[target.uuid]
),
None,
)
)

if not column:
# to be safe, use a different uuid and create a new column
uuid = uuid4()
target.uuid = uuid
column = NewColumn(uuid=uuid)

session.add(target.to_sl_column({column.uuid: column}))

@staticmethod
def after_insert(
Expand Down Expand Up @@ -2441,9 +2453,9 @@ def write_shadow_dataset(
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete)
sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update)
sa.event.listen(SqlMetric, "after_update", SqlaTable.update_table)
sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column)
sa.event.listen(SqlMetric, "after_delete", SqlMetric.after_delete)
sa.event.listen(TableColumn, "after_update", SqlaTable.update_table)
sa.event.listen(TableColumn, "after_update", SqlaTable.update_column)
sa.event.listen(TableColumn, "after_delete", TableColumn.after_delete)

RLSFilterRoles = Table(
Expand Down
11 changes: 10 additions & 1 deletion superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from contextlib import closing
from typing import (
Any,
Expand All @@ -35,6 +36,7 @@
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from sqlalchemy.sql.type_api import TypeEngine

from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
Expand Down Expand Up @@ -191,6 +193,7 @@ def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:


DeclarativeModel = TypeVar("DeclarativeModel", bound=DeclarativeMeta)
logger = logging.getLogger(__name__)


def find_cached_objects_in_session(
Expand All @@ -209,9 +212,15 @@ def find_cached_objects_in_session(
if not ids and not uuids:
return iter([])
uuids = uuids or []
try:
items = set(session)
except ObjectDeletedError:
logger.warning("ObjectDeletedError", exc_info=True)
return iter(())

return (
item
# `session` is an iterator of all known items
for item in set(session)
for item in items
if isinstance(item, cls) and (item.id in ids if ids else item.uuid in uuids)
)
6 changes: 6 additions & 0 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
DAODeleteFailedError,
DAOUpdateFailedError,
)
from superset.datasets.models import Dataset
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.utils.core import backend, get_example_default_schema
Expand Down Expand Up @@ -1635,16 +1636,21 @@ def test_import_dataset(self):
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
shadow_dataset = (
db.session.query(Dataset).filter_by(uuid=dataset_config["uuid"]).one()
)
assert database.database_name == "imported_database"

assert len(database.tables) == 1
dataset = database.tables[0]
assert dataset.table_name == "imported_dataset"
assert str(dataset.uuid) == dataset_config["uuid"]
assert str(shadow_dataset.uuid) == dataset_config["uuid"]

dataset.owners = []
database.owners = []
db.session.delete(dataset)
db.session.delete(shadow_dataset)
db.session.delete(database)
db.session.commit()

Expand Down
69 changes: 69 additions & 0 deletions tests/integration_tests/datasets/model_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock

import pytest
from sqlalchemy.orm.exc import NoResultFound

from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.extensions import db
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.datasource import load_dataset_with_columns


class SqlaTableModelTest(SupersetTestCase):
    """Integration tests for the SqlaTable -> shadow ``Column`` dual write."""

    @pytest.mark.usefixtures("load_dataset_with_columns")
    def test_dual_update_column(self) -> None:
        """
        Test that when updating a sqla ``TableColumn``
        the shadow ``Column`` model is also updated (dual write).
        """
        dataset = db.session.query(SqlaTable).filter_by(table_name="students").first()
        column = dataset.columns[0]
        column_name = column.column_name
        column.column_name = "new_column_name"
        # mapper/connection are unused by update_column, so None is safe here
        SqlaTable.update_column(None, None, target=column)

        # refetch
        dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
        assert dataset.columns[0].column_name == "new_column_name"

        # reset
        column.column_name = column_name
        SqlaTable.update_column(None, None, target=column)

    @pytest.mark.usefixtures("load_dataset_with_columns")
    @mock.patch("superset.columns.models.Column")
    def test_dual_update_column_not_found(self, column_mock) -> None:
        """
        Test that when the shadow ``Column`` for a sqla ``TableColumn``
        cannot be found, ``update_column`` falls back to assigning a fresh
        uuid and creating a new shadow column instead of failing.
        """
        dataset = db.session.query(SqlaTable).filter_by(table_name="students").first()
        column = dataset.columns[0]
        column_uuid = column.uuid
        # force the NoResultFound branch regardless of DB state
        with mock.patch("sqlalchemy.orm.query.Query.one", side_effect=NoResultFound):
            SqlaTable.update_column(None, None, target=column)

        # refetch
        dataset = db.session.query(SqlaTable).filter_by(id=dataset.id).one()
        # it should create a new uuid
        assert dataset.columns[0].uuid != column_uuid

        # reset
        column.uuid = column_uuid
        SqlaTable.update_column(None, None, target=column)
50 changes: 49 additions & 1 deletion tests/integration_tests/fixtures/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,20 @@
# specific language governing permissions and limitations
# under the License.
"""Fixtures for test_datasource.py"""
from typing import Any, Dict
from typing import Any, Dict, Generator

import pytest
from sqlalchemy import Column, create_engine, Date, Integer, MetaData, String, Table
from sqlalchemy.ext.declarative.api import declarative_base

from superset.columns.models import Column as Sl_Column
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.extensions import db
from superset.models.core import Database
from superset.tables.models import Table as Sl_Table
from superset.utils.core import get_example_default_schema
from superset.utils.database import get_example_database
from tests.integration_tests.test_app import app


def get_datasource_post() -> Dict[str, Any]:
Expand Down Expand Up @@ -159,3 +169,41 @@ def get_datasource_post() -> Dict[str, Any]:
},
],
}


@pytest.fixture()
def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
    """
    Create a physical ``students`` table plus a ``SqlaTable`` dataset
    (with one ``TableColumn``) pointing at it, then clean both up.

    Yields the ``SqlaTable`` so tests can exercise dataset/column dual writes.
    """
    with app.app_context():
        engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True)
        meta = MetaData()
        session = db.session

        students = Table(
            "students",
            meta,
            Column("id", Integer, primary_key=True),
            Column("name", String),
            Column("lastname", String),
            Column("ds", Date),
        )
        meta.create_all(engine)

        # BUG FIX: the original built this INSERT statement but never executed
        # it, so the seed row was silently never written. Execute it on the
        # engine so the table actually contains data.
        engine.execute(students.insert().values(name="George", ds="2021-01-01"))

        dataset = SqlaTable(
            database_id=db.session.query(Database).first().id, table_name="students"
        )
        # NOTE: dataset.id is still None pre-flush; the relationship
        # assignment below is what actually links the column to the dataset.
        column = TableColumn(table_id=dataset.id, column_name="name")
        dataset.columns = [column]
        session.add(dataset)
        session.commit()
        yield dataset

        # cleanup: drop the physical table, then remove the ORM records
        students_table = meta.tables.get("students")
        if students_table is not None:
            base = declarative_base()
            base.metadata.drop_all(engine, [students_table], checkfirst=True)
        session.delete(dataset)
        session.delete(column)
        session.commit()

0 comments on commit baa5cd8

Please sign in to comment.