Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Address regression introduced in #22853 #24121

Merged
merged 3 commits into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 44 additions & 54 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
backref,
Mapped,
Query,
reconstructor,
relationship,
RelationshipProperty,
Session,
Expand Down Expand Up @@ -218,6 +219,30 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
update_from_object_fields = [s for s in export_fields if s not in ("table_id",)]
export_parent = "table"

def __init__(self, **kwargs: Any) -> None:
"""
Construct a TableColumn object.

Historically a TableColumn object (from an ORM perspective) was tighly bound to
a SqlaTable object, however with the introduction of the Query datasource this
is no longer true, i.e., the SqlaTable relationship is optional.

Now the TableColumn is either directly associated with the Database object (
which is unknown to the ORM) or indirectly via the SqlaTable object (courtesy of
the ORM) depending on the context.
"""

self._database: Database | None = kwargs.pop("database", None)
super().__init__(**kwargs)

@reconstructor
def init_on_load(self) -> None:
"""
Construct a TableColumn object when invoked via the SQLAlchemy ORM.
"""

self._database = None

@property
def is_boolean(self) -> bool:
"""
Expand Down Expand Up @@ -251,51 +276,33 @@ def is_temporal(self) -> bool:
return self.is_dttm
return self.type_generic == GenericDataType.TEMPORAL

@property
def database(self) -> Database:
return self.table.database if self.table else self._database
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By design either self.table or self._database (but not both) will be defined and thus the database property will always be non-null.


@property
def db_engine_spec(self) -> type[BaseEngineSpec]:
return self.table.db_engine_spec
return self.database.db_engine_spec

@property
def db_extra(self) -> dict[str, Any]:
return self.table.database.get_extra()
return self.database.get_extra()

@property
def type_generic(self) -> utils.GenericDataType | None:
if self.is_dttm:
return GenericDataType.TEMPORAL

bool_types = ("BOOL",)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A nice byproduct of this change is this previous logic can be removed given that the db_engine_spec is accessible if associated with a Query or SqlaTable.

num_types = (
"DOUBLE",
"FLOAT",
"INT",
"BIGINT",
"NUMBER",
"LONG",
"REAL",
"NUMERIC",
"DECIMAL",
"MONEY",
)
date_types = ("DATE", "TIME")
str_types = ("VARCHAR", "STRING", "CHAR")

if self.table is None:
# Query.TableColumns don't have a reference to a table.db_engine_spec
# reference so this logic will manage rendering types
if self.type and any(map(lambda t: t in self.type.upper(), str_types)):
return GenericDataType.STRING
if self.type and any(map(lambda t: t in self.type.upper(), bool_types)):
return GenericDataType.BOOLEAN
if self.type and any(map(lambda t: t in self.type.upper(), num_types)):
return GenericDataType.NUMERIC
if self.type and any(map(lambda t: t in self.type.upper(), date_types)):
return GenericDataType.TEMPORAL

column_spec = self.db_engine_spec.get_column_spec(
self.type, db_extra=self.db_extra
return (
column_spec.generic_type # pylint: disable=used-before-assignment
if (
column_spec := self.db_engine_spec.get_column_spec(
self.type,
db_extra=self.db_extra,
)
)
else None
)
return column_spec.generic_type if column_spec else None

def get_sqla_col(
self,
Expand All @@ -312,7 +319,7 @@ def get_sqla_col(
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
col = self.table.make_sqla_column_compatible(col, label)
col = self.database.make_sqla_column_compatible(col, label)
return col

@property
Expand Down Expand Up @@ -343,15 +350,15 @@ def get_timestamp_expression(
type_ = column_spec.sqla_type if column_spec else DateTime
if not self.expression and not time_grain and not is_epoch:
sqla_col = column(self.column_name, type_=type_)
return self.table.make_sqla_column_compatible(sqla_col, label)
return self.database.make_sqla_column_compatible(sqla_col, label)
if expression := self.expression:
if template_processor:
expression = template_processor.process_template(expression)
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
time_expr = self.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
return self.table.make_sqla_column_compatible(time_expr, label)
return self.database.make_sqla_column_compatible(time_expr, label)

@property
def data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -423,7 +430,7 @@ def get_sqla_col(
expression = template_processor.process_template(expression)

sqla_col: ColumnClause = literal_column(expression)
return self.table.make_sqla_column_compatible(sqla_col, label)
return self.table.database.make_sqla_column_compatible(sqla_col, label)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SqlMetric class is still tightly bound to the SqlaTable.


@property
def perm(self) -> str | None:
Expand Down Expand Up @@ -1008,23 +1015,6 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
)
return self.make_sqla_column_compatible(sqla_column, label)

def make_sqla_column_compatible(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SqlaTable.make_sqla_column_compatible method was moved to Database.make_sqla_column_compatible given that now the TableColumn.table variable can be None.

self, sqla_col: ColumnElement, label: str | None = None
) -> ColumnElement:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
:param sqla_col: sqlalchemy column instance
:param label: alias/label that column is expected to have
:return: either a sql alchemy column or label instance if supported by engine
"""
label_expected = label or sqla_col.name
db_engine_spec = self.db_engine_spec
# add quotes to tables
if db_engine_spec.allows_alias_in_select:
label = db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
sqla_col.key = label_expected
return sqla_col

def make_orderby_compatible(
self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
) -> None:
Expand Down
20 changes: 18 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long
# pylint: disable=line-too-long,too-many-lines
"""A collection of ORM sqlalchemy models for Superset"""
import builtins
import enum
Expand Down Expand Up @@ -53,7 +53,7 @@
from sqlalchemy.orm import relationship
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import expression, Select
from sqlalchemy.sql import ColumnElement, expression, Select

from superset import app, db_engine_specs
from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK
Expand Down Expand Up @@ -953,6 +953,22 @@ def get_dialect(self) -> Dialect:
sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted)
return sqla_url.get_dialect()()

def make_sqla_column_compatible(
self, sqla_col: ColumnElement, label: Optional[str] = None
) -> ColumnElement:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
:param sqla_col: sqlalchemy column instance
:param label: alias/label that column is expected to have
:return: either a sql alchemy column or label instance if supported by engine
"""
label_expected = label or sqla_col.name
# add quotes to tables
if self.db_engine_spec.allows_alias_in_select:
label = self.db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
sqla_col.key = label_expected
return sqla_col


sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
sqla.event.listen(Database, "after_update", security_manager.database_after_update)
Expand Down
21 changes: 10 additions & 11 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,17 @@ def columns(self) -> list["TableColumn"]:
TableColumn,
)

columns = []
for col in self.extra.get("columns", []):
columns.append(
TableColumn(
column_name=col["name"],
type=col["type"],
is_dttm=col["is_dttm"],
groupby=True,
filterable=True,
)
return [
TableColumn(
column_name=col["name"],
database=self.database,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a rewrite of the previous logic using a list comprehension. The only addition is the database argument which is required given a SQL Lab Query object has no associated SqlaTable which is previously were/how the database was obtained.

is_dttm=col["is_dttm"],
filterable=True,
groupby=True,
type=col["type"],
)
return columns
for col in self.extra.get("columns", [])
]

@property
def db_extra(self) -> Optional[dict[str, Any]]:
Expand Down
5 changes: 5 additions & 0 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,8 @@ def test_data_for_slices_with_adhoc_column(self):

# clean up and auto commit
metadata_db.session.delete(slc)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_table_column_database(self) -> None:
tbl = self.get_table(name="birth_names")
assert tbl.get_column("ds").database is tbl.database # type: ignore
5 changes: 5 additions & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,8 @@ def test_dttm_sql_literal(
result: str,
) -> None:
assert SqlaTable(database=database).dttm_sql_literal(dttm, col) == result


def test_table_column_database() -> None:
database = Database(database_name="db")
assert TableColumn(database=database).database is database