fix: Ensure table uniqueness on update #15909

Merged
3 changes: 3 additions & 0 deletions UPDATING.md
@@ -23,6 +23,9 @@ This file documents any backwards-incompatible changes in Superset and
assists people when migrating to a new version.

## Next
+- [15909](https://github.com/apache/incubator-superset/pull/15909): a change which drops
+  a uniqueness constraint (which may or may not have existed) from the `tables` table. This
+  constraint was obsolete, as uniqueness is enforced by the ORM instead, owing to differences
+  in how MySQL, PostgreSQL, etc. treat uniqueness for NULL values (see the sketch below).
- [15927](https://github.com/apache/superset/pull/15927): Upgrades Celery to 5.x. Per the [upgrading](https://docs.celeryproject.org/en/stable/history/whatsnew-5.0.html#upgrading-from-celery-4-x) instructions Celery 5.0 introduces a new CLI implementation which is not completely backwards compatible. Please ensure global options are positioned before the sub-command.

- [13772](https://github.com/apache/superset/pull/13772): Row level security (RLS) is now enabled by default. To activate the feature, please run `superset init` to expose the RLS menus to Admin users.
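Why the constraint cannot live in the physical schema: most databases treat NULL as distinct from NULL when enforcing UNIQUE, so a physical constraint over (database_id, schema, table_name) never fires once the optional schema is NULL. A minimal sketch (not Superset code), using in-memory SQLite for illustration — PostgreSQL and MySQL behave the same way:

```python
# A minimal illustration of why a physical UNIQUE constraint fails to enforce
# dataset uniqueness when the optional schema column is NULL: NULL compares as
# distinct from NULL, so both inserts below succeed.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE tables (
        id INTEGER PRIMARY KEY,
        database_id INTEGER NOT NULL,
        schema TEXT,
        table_name TEXT NOT NULL,
        UNIQUE (database_id, schema, table_name)
    )
    """
)

# Neither insert violates the constraint, leaving two "identical" datasets.
conn.execute(
    "INSERT INTO tables (database_id, schema, table_name) VALUES (1, NULL, 'birth_names')"
)
conn.execute(
    "INSERT INTO tables (database_id, schema, table_name) VALUES (1, NULL, 'birth_names')"
)
print(conn.execute("SELECT COUNT(*) FROM tables").fetchone()[0])  # prints 2
```

Hence the PR enforces uniqueness in the ORM/DAO layer rather than in the migrations.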
52 changes: 51 additions & 1 deletion superset/connectors/sqla/models.py
@@ -483,7 +483,15 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
    owner_class = security_manager.user_model

    __tablename__ = "tables"
-    __table_args__ = (UniqueConstraint("database_id", "table_name"),)
+
+    # Note this uniqueness constraint is not part of the physical schema, i.e., it does
+    # not exist in the migrations, but is required by `import_from_dict` to ensure the
+    # correct filters are applied in order to identify uniqueness.
+    #
+    # The reason it does not physically exist is that MySQL, PostgreSQL, etc. have
+    # different interpretations of uniqueness when it comes to NULL, which is
+    # problematic given the schema is optional.
+    __table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),)

    table_name = Column(String(250), nullable=False)
    main_dttm_col = Column(String(250))
@@ -1604,6 +1612,47 @@ class and any keys added via `ExtraCache`.
        extra_cache_keys += sqla_query.extra_cache_keys
        return extra_cache_keys

+    @staticmethod
+    def before_update(
+        mapper: Mapper,  # pylint: disable=unused-argument
+        connection: Connection,  # pylint: disable=unused-argument
+        target: "SqlaTable",
+    ) -> None:
+        """
+        Check before update whether the target table already exists.
+
+        Note this listener is called when any fields are being updated, and thus it is
+        necessary to first check whether the table reference is being updated.
+
+        Note this logic is temporary, given uniqueness is handled via the dataset DAO,
+        but is necessary until both the legacy datasource editor and datasource/save
+        endpoints are deprecated.
+
+        :param mapper: The table mapper
+        :param connection: The DB-API connection
+        :param target: The mapped instance being persisted
+        :raises Exception: If the target table is not unique
+        """
+
+        from superset.datasets.commands.exceptions import get_dataset_exist_error_msg
+        from superset.datasets.dao import DatasetDAO
+
+        # Check whether the relevant attributes have changed.
+        state = db.inspect(target)  # pylint: disable=no-member
+
+        for attr in ["database_id", "schema", "table_name"]:
+            history = state.get_history(attr, True)
+
+            if history.has_changes():
+                break
+        else:
+            return None
+
+        if not DatasetDAO.validate_uniqueness(
+            target.database_id, target.schema, target.table_name
+        ):
+            raise Exception(get_dataset_exist_error_msg(target.full_name))


def update_table(
    _mapper: Mapper, _connection: Connection, obj: Union[SqlMetric, TableColumn]
@@ -1621,6 +1670,7 @@ def update_table(

sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
sa.event.listen(SqlaTable, "after_update", security_manager.set_perm)
+sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlMetric, "after_update", update_table)
sa.event.listen(TableColumn, "after_update", update_table)

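A standalone sketch (not Superset code; model and column names here are illustrative) of the pattern added above: a `before_update` listener that inspects attribute history so it only validates when identity-defining columns actually change:

```python
# Self-contained sketch of the before_update + attribute-history pattern,
# written against SQLAlchemy 1.3/1.4.
from sqlalchemy import Column, Integer, String, create_engine, event, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


class Dataset(Base):
    __tablename__ = "datasets"
    id = Column(Integer, primary_key=True)
    schema = Column(String(255))
    table_name = Column(String(250), nullable=False)


@event.listens_for(Dataset, "before_update")
def before_update(mapper, connection, target):
    state = inspect(target)
    # The listener fires for *any* update, so bail out early when none of the
    # identity-defining columns changed.
    for attr in ["schema", "table_name"]:
        if state.get_history(attr, True).has_changes():
            break
    else:
        return
    # ... here one would validate uniqueness and raise if a duplicate exists.
    print(f"identity changed: schema={target.schema!r} name={target.table_name!r}")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

dataset = Dataset(table_name="birth_names")
session.add(dataset)
session.commit()

dataset.table_name = "birth_names_v2"
session.commit()  # flush fires the listener and prints the changed identity
```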
@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""drop tables constraint

Revision ID: 31b2a1039d4a
Revises: ae1ed299413b
Create Date: 2021-07-27 08:25:20.755453

"""

from alembic import op
from sqlalchemy import engine
from sqlalchemy.exc import OperationalError, ProgrammingError

from superset.utils.core import generic_find_uq_constraint_name

# revision identifiers, used by Alembic.
revision = "31b2a1039d4a"
down_revision = "ae1ed299413b"

conv = {"uq": "uq_%(table_name)s_%(column_0_name)s"}


def upgrade():
    bind = op.get_bind()
    insp = engine.reflection.Inspector.from_engine(bind)

    # Drop the uniqueness constraint if it exists.
    constraint = generic_find_uq_constraint_name("tables", {"table_name"}, insp)

    if constraint:
        with op.batch_alter_table("tables", naming_convention=conv) as batch_op:
            batch_op.drop_constraint(constraint, type_="unique")


def downgrade():
    # One cannot simply re-add the uniqueness constraint as it may not have previously
    # existed.
    pass
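A hedged sketch of what a helper like `generic_find_uq_constraint_name` plausibly does (the real implementation lives in `superset.utils.core` and may differ in detail): scan the inspector's reflected unique constraints and return the name of the one covering the given columns:

```python
# Plausible shape of the constraint-lookup helper used in upgrade() above;
# this is an assumption for illustration, not the Superset source.
from typing import Optional, Set

from sqlalchemy.engine.reflection import Inspector


def find_uq_constraint_name(
    table: str, columns: Set[str], insp: Inspector
) -> Optional[str]:
    for uq in insp.get_unique_constraints(table):
        if columns <= set(uq["column_names"]):
            # The reflected name may itself be None on some backends (e.g.,
            # SQLite), in which case there is nothing to drop by name.
            return uq["name"]
    return None
```

The `naming_convention` matters because SQLite cannot drop constraints in place: batch mode recreates the table, and the convention lets Alembic address a constraint that was originally created without an explicit name.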
5 changes: 2 additions & 3 deletions superset/models/slice.py
@@ -53,10 +53,9 @@
logger = logging.getLogger(__name__)


-class Slice(
+class Slice(  # pylint: disable=too-many-instance-attributes,too-many-public-methods
    Model, AuditMixinNullable, ImportExportMixin
-):  # pylint: disable=too-many-public-methods, too-many-instance-attributes
-
+):
    """A slice is essentially a report or a view on data"""

    __tablename__ = "slices"

Review thread:
Member Author: I'm not sure why pylint was complaining about this file, as it was unchanged.
Member: Yes, I also faced this issue.
Member: Same here. :-P
8 changes: 4 additions & 4 deletions tests/integration_tests/access_tests.py
@@ -161,7 +161,7 @@ def test_override_role_permissions_1_table(self):

        updated_override_me = security_manager.find_role("override_me")
        self.assertEqual(1, len(updated_override_me.permissions))
-        birth_names = self.get_table_by_name("birth_names")
+        birth_names = self.get_table(name="birth_names")
        self.assertEqual(
            birth_names.perm, updated_override_me.permissions[0].view_menu.name
        )
@@ -190,7 +190,7 @@ def test_override_role_permissions_druid_and_table(self):
"datasource_access", updated_role.permissions[1].permission.name
)

birth_names = self.get_table_by_name("birth_names")
birth_names = self.get_table(name="birth_names")
self.assertEqual(birth_names.perm, perms[2].view_menu.name)
self.assertEqual(
"datasource_access", updated_role.permissions[2].permission.name
@@ -204,7 +204,7 @@ def test_override_role_permissions_drops_absent_perms(self):
        override_me = security_manager.find_role("override_me")
        override_me.permissions.append(
            security_manager.find_permission_view_menu(
-                view_menu_name=self.get_table_by_name("energy_usage").perm,
+                view_menu_name=self.get_table(name="energy_usage").perm,
                permission_name="datasource_access",
            )
        )
@@ -218,7 +218,7 @@ def test_override_role_permissions_drops_absent_perms(self):
        self.assertEqual(201, response.status_code)
        updated_override_me = security_manager.find_role("override_me")
        self.assertEqual(1, len(updated_override_me.permissions))
-        birth_names = self.get_table_by_name("birth_names")
+        birth_names = self.get_table(name="birth_names")
        self.assertEqual(
            birth_names.perm, updated_override_me.permissions[0].view_menu.name
        )
43 changes: 23 additions & 20 deletions tests/integration_tests/base_tests.py
@@ -99,10 +99,6 @@ def post_assert_metric(
    return rv


-def get_table_by_name(name: str) -> SqlaTable:
-    return db.session.query(SqlaTable).filter_by(table_name=name).one()
-
-
@pytest.fixture
def logged_in_admin():
    """Fixture with app context and logged in admin user."""
@@ -132,12 +128,7 @@ def get_nonexistent_numeric_id(model):

    @staticmethod
    def get_birth_names_dataset() -> SqlaTable:
-        example_db = get_example_database()
-        return (
-            db.session.query(SqlaTable)
-            .filter_by(database=example_db, table_name="birth_names")
-            .one()
-        )
+        return SupersetTestCase.get_table(name="birth_names")

    @staticmethod
    def create_user_with_roles(
@@ -254,13 +245,31 @@ def get_slice(
        return slc

    @staticmethod
-    def get_table_by_name(name: str) -> SqlaTable:
-        return get_table_by_name(name)
+    def get_table(
+        name: str, database_id: Optional[int] = None, schema: Optional[str] = None
+    ) -> SqlaTable:
+        return (
+            db.session.query(SqlaTable)
+            .filter_by(
+                database_id=database_id
+                or SupersetTestCase.get_database_by_name("examples").id,
+                schema=schema,
+                table_name=name,
+            )
+            .one()
+        )

Review thread:
Member Author: Cleanup to ensure consistency.

    @staticmethod
    def get_database_by_id(db_id: int) -> Database:
        return db.session.query(Database).filter_by(id=db_id).one()

+    @staticmethod
+    def get_database_by_name(database_name: str = "main") -> Database:
+        if database_name == "examples":
+            return get_example_database()
+        else:
+            raise ValueError("Database doesn't exist")
+
    @staticmethod
    def get_druid_ds_by_name(name: str) -> DruidDatasource:
        return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
@@ -340,12 +349,6 @@ def revoke_role_access_to_table(self, role_name, table):
        ):
            security_manager.del_permission_role(public_role, perm)

-    def _get_database_by_name(self, database_name="main"):
-        if database_name == "examples":
-            return get_example_database()
-        else:
-            raise ValueError("Database doesn't exist")
-
    def run_sql(
        self,
        sql,
@@ -364,7 +367,7 @@ def run_sql(
        if user_name:
            self.logout()
        self.login(username=(user_name or "admin"))
-        dbid = self._get_database_by_name(database_name).id
+        dbid = SupersetTestCase.get_database_by_name(database_name).id
        json_payload = {
            "database_id": dbid,
            "sql": sql,
@@ -448,7 +451,7 @@ def validate_sql(
        if user_name:
            self.logout()
        self.login(username=(user_name if user_name else "admin"))
-        dbid = self._get_database_by_name(database_name).id
+        dbid = SupersetTestCase.get_database_by_name(database_name).id
        resp = self.get_json_resp(
            "/superset/validate_sql_json/",
            raise_on_error=False,
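A hypothetical usage sketch of the consolidated test helpers above (assumes Superset's standard test fixtures, i.e., an "examples" database containing `birth_names`; only runnable inside the test environment):

```python
# Hypothetical test snippet exercising the new helpers.
from tests.integration_tests.base_tests import SupersetTestCase

# Defaults to the "examples" database and a NULL schema.
birth_names = SupersetTestCase.get_table(name="birth_names")

# An explicit database_id and/or schema can disambiguate same-named tables,
# e.g. SupersetTestCase.get_table(name="birth_names", schema="public"),
# matching the (database_id, schema, table_name) identity the PR enforces.
assert birth_names.database_id == SupersetTestCase.get_database_by_name("examples").id
```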
2 changes: 1 addition & 1 deletion tests/integration_tests/charts/api_tests.py
@@ -545,7 +545,7 @@ def test_update_chart(self):
"""
admin = self.get_user("admin")
gamma = self.get_user("gamma")
birth_names_table_id = SupersetTestCase.get_table_by_name("birth_names").id
birth_names_table_id = SupersetTestCase.get_table(name="birth_names").id
chart_id = self.insert_chart(
"title", [admin.id], birth_names_table_id, admin
).id
4 changes: 2 additions & 2 deletions tests/integration_tests/csv_upload_tests.py
@@ -221,7 +221,7 @@ def test_import_csv_explore_database(setup_csv_upload, create_csv_files):
        f'CSV file "{CSV_FILENAME1}" uploaded to table "{CSV_UPLOAD_TABLE_W_EXPLORE}"'
        in resp
    )
-    table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE_W_EXPLORE)
+    table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE_W_EXPLORE)
    assert table.database_id == utils.get_example_database().id


@@ -267,7 +267,7 @@ def test_import_csv(setup_csv_upload, create_csv_files):
    )
    assert success_msg_f2 in resp

-    table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE)
+    table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE)
    # make sure the new column name is reflected in the table metadata
    assert "d" in table.column_names
8 changes: 6 additions & 2 deletions tests/integration_tests/dashboard_utils.py
@@ -35,6 +35,7 @@ def create_table_for_dashboard(
    dtype: Dict[str, Any],
    table_description: str = "",
    fetch_values_predicate: Optional[str] = None,
+    schema: Optional[str] = None,
) -> SqlaTable:
    df.to_sql(
        table_name,
@@ -44,14 +45,17 @@
        dtype=dtype,
        index=False,
        method="multi",
+        schema=schema,
    )

    table_source = ConnectorRegistry.sources["table"]
    table = (
-        db.session.query(table_source).filter_by(table_name=table_name).one_or_none()
+        db.session.query(table_source)
+        .filter_by(database_id=database.id, schema=schema, table_name=table_name)
+        .one_or_none()
    )
    if not table:
-        table = table_source(table_name=table_name)
+        table = table_source(schema=schema, table_name=table_name)
    if fetch_values_predicate:
        table.fetch_values_predicate = fetch_values_predicate
    table.database = database
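A hypothetical call illustrating the new `schema` parameter (the DataFrame, dtype map, schema name, and the elided leading parameters of the helper are assumptions for illustration):

```python
# Hypothetical usage of the updated helper; assumes the Superset test
# environment and that the helper's leading parameters are (df, table_name,
# database, dtype, ...) as suggested by the diff.
import pandas as pd

from superset.utils.core import get_example_database
from tests.integration_tests.dashboard_utils import create_table_for_dashboard

df = pd.DataFrame({"name": ["Alice", "Bob"], "num": [10, 20]})
table = create_table_for_dashboard(
    df=df,
    table_name="births_demo",
    database=get_example_database(),
    dtype={},
    schema="public",  # now threaded through both df.to_sql and the dataset lookup
)
```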