diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 9e62b305416e..edde2232056c 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -58,6 +58,7 @@ backref, Mapped, Query, + reconstructor, relationship, RelationshipProperty, Session, @@ -218,6 +219,30 @@ class TableColumn(Model, BaseColumn, CertificationMixin): update_from_object_fields = [s for s in export_fields if s not in ("table_id",)] export_parent = "table" + def __init__(self, **kwargs: Any) -> None: + """ + Construct a TableColumn object. + + Historically a TableColumn object (from an ORM perspective) was tighly bound to + a SqlaTable object, however with the introduction of the Query datasource this + is no longer true, i.e., the SqlaTable relationship is optional. + + Now the TableColumn is either directly associated with the Database object ( + which is unknown to the ORM) or indirectly via the SqlaTable object (courtesy of + the ORM) depending on the context. + """ + + self._database: Database | None = kwargs.pop("database", None) + super().__init__(**kwargs) + + @reconstructor + def init_on_load(self) -> None: + """ + Construct a TableColumn object when invoked via the SQLAlchemy ORM. + """ + + self._database = None + @property def is_boolean(self) -> bool: """ @@ -251,51 +276,33 @@ def is_temporal(self) -> bool: return self.is_dttm return self.type_generic == GenericDataType.TEMPORAL + @property + def database(self) -> Database: + return self.table.database if self.table else self._database + @property def db_engine_spec(self) -> type[BaseEngineSpec]: - return self.table.db_engine_spec + return self.database.db_engine_spec @property def db_extra(self) -> dict[str, Any]: - return self.table.database.get_extra() + return self.database.get_extra() @property def type_generic(self) -> utils.GenericDataType | None: if self.is_dttm: return GenericDataType.TEMPORAL - bool_types = ("BOOL",) - num_types = ( - "DOUBLE", - "FLOAT", - "INT", - "BIGINT", - "NUMBER", - "LONG", - "REAL", - "NUMERIC", - "DECIMAL", - "MONEY", - ) - date_types = ("DATE", "TIME") - str_types = ("VARCHAR", "STRING", "CHAR") - - if self.table is None: - # Query.TableColumns don't have a reference to a table.db_engine_spec - # reference so this logic will manage rendering types - if self.type and any(map(lambda t: t in self.type.upper(), str_types)): - return GenericDataType.STRING - if self.type and any(map(lambda t: t in self.type.upper(), bool_types)): - return GenericDataType.BOOLEAN - if self.type and any(map(lambda t: t in self.type.upper(), num_types)): - return GenericDataType.NUMERIC - if self.type and any(map(lambda t: t in self.type.upper(), date_types)): - return GenericDataType.TEMPORAL - - column_spec = self.db_engine_spec.get_column_spec( - self.type, db_extra=self.db_extra + return ( + column_spec.generic_type # pylint: disable=used-before-assignment + if ( + column_spec := self.db_engine_spec.get_column_spec( + self.type, + db_extra=self.db_extra, + ) + ) + else None ) - return column_spec.generic_type if column_spec else None def get_sqla_col( self, @@ -312,7 +319,7 @@ def get_sqla_col( col = literal_column(expression, type_=type_) else: col = column(self.column_name, type_=type_) - col = self.table.make_sqla_column_compatible(col, label) + col = self.database.make_sqla_column_compatible(col, label) return col @property @@ -343,7 +350,7 @@ def get_timestamp_expression( type_ = column_spec.sqla_type if column_spec else DateTime if not self.expression and not time_grain and not is_epoch: sqla_col = column(self.column_name, type_=type_) - return self.table.make_sqla_column_compatible(sqla_col, label) + return self.database.make_sqla_column_compatible(sqla_col, label) if expression := self.expression: if template_processor: expression = template_processor.process_template(expression) @@ -351,7 +358,7 @@ def get_timestamp_expression( else: col = column(self.column_name, type_=type_) time_expr = self.db_engine_spec.get_timestamp_expr(col, pdf, time_grain) - return self.table.make_sqla_column_compatible(time_expr, label) + return self.database.make_sqla_column_compatible(time_expr, label) @property def data(self) -> dict[str, Any]: @@ -423,7 +430,7 @@ def get_sqla_col( expression = template_processor.process_template(expression) sqla_col: ColumnClause = literal_column(expression) - return self.table.make_sqla_column_compatible(sqla_col, label) + return self.table.database.make_sqla_column_compatible(sqla_col, label) @property def perm(self) -> str | None: @@ -1008,23 +1015,6 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals ) return self.make_sqla_column_compatible(sqla_column, label) - def make_sqla_column_compatible( - self, sqla_col: ColumnElement, label: str | None = None - ) -> ColumnElement: - """Takes a sqlalchemy column object and adds label info if supported by engine. - :param sqla_col: sqlalchemy column instance - :param label: alias/label that column is expected to have - :return: either a sql alchemy column or label instance if supported by engine - """ - label_expected = label or sqla_col.name - db_engine_spec = self.db_engine_spec - # add quotes to tables - if db_engine_spec.allows_alias_in_select: - label = db_engine_spec.make_label_compatible(label_expected) - sqla_col = sqla_col.label(label) - sqla_col.key = label_expected - return sqla_col - def make_orderby_compatible( self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement] ) -> None: diff --git a/superset/models/core.py b/superset/models/core.py index 3c2b12d3782b..c18d12049e91 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=line-too-long +# pylint: disable=line-too-long,too-many-lines """A collection of ORM sqlalchemy models for Superset""" import builtins import enum @@ -53,7 +53,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import expression, Select +from sqlalchemy.sql import ColumnElement, expression, Select from superset import app, db_engine_specs from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK @@ -953,6 +953,22 @@ def get_dialect(self) -> Dialect: sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()() + def make_sqla_column_compatible( + self, sqla_col: ColumnElement, label: Optional[str] = None + ) -> ColumnElement: + """Takes a sqlalchemy column object and adds label info if supported by engine. + :param sqla_col: sqlalchemy column instance + :param label: alias/label that column is expected to have + :return: either a sql alchemy column or label instance if supported by engine + """ + label_expected = label or sqla_col.name + # add quotes to tables + if self.db_engine_spec.allows_alias_in_select: + label = self.db_engine_spec.make_label_compatible(label_expected) + sqla_col = sqla_col.label(label) + sqla_col.key = label_expected + return sqla_col + sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) sqla.event.listen(Database, "after_update", security_manager.database_after_update) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index b9ab153798f9..a566f75b43b2 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -192,18 +192,17 @@ def columns(self) -> list["TableColumn"]: TableColumn, ) - columns = [] - for col in self.extra.get("columns", []): - columns.append( - TableColumn( - column_name=col["name"], - type=col["type"], - is_dttm=col["is_dttm"], - groupby=True, - filterable=True, - ) + return [ + TableColumn( + column_name=col["name"], + database=self.database, + is_dttm=col["is_dttm"], + filterable=True, + groupby=True, + type=col["type"], ) - return columns + for col in self.extra.get("columns", []) + ] @property def db_extra(self) -> Optional[dict[str, Any]]: diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index c4bc7aa89bd9..3a5f7c0a77a1 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -671,3 +671,8 @@ def test_data_for_slices_with_adhoc_column(self): # clean up and auto commit metadata_db.session.delete(slc) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_table_column_database(self) -> None: + tbl = self.get_table(name="birth_names") + assert tbl.get_column("ds").database is tbl.database # type: ignore diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index d37296447ad6..267b7c024aae 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -207,3 +207,8 @@ def test_dttm_sql_literal( result: str, ) -> None: assert SqlaTable(database=database).dttm_sql_literal(dttm, col) == result + + +def test_table_column_database() -> None: + database = Database(database_name="db") + assert TableColumn(database=database).database is database