Skip to content

Commit

Permalink
fix: Address regression introduced in #22853 (#24121)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed Jun 12, 2023
1 parent 6f25275 commit 2b36489
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 67 deletions.
98 changes: 44 additions & 54 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
backref,
Mapped,
Query,
reconstructor,
relationship,
RelationshipProperty,
Session,
Expand Down Expand Up @@ -218,6 +219,30 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
update_from_object_fields = [s for s in export_fields if s not in ("table_id",)]
export_parent = "table"

def __init__(self, **kwargs: Any) -> None:
"""
Construct a TableColumn object.
Historically a TableColumn object (from an ORM perspective) was tighly bound to
a SqlaTable object, however with the introduction of the Query datasource this
is no longer true, i.e., the SqlaTable relationship is optional.
Now the TableColumn is either directly associated with the Database object (
which is unknown to the ORM) or indirectly via the SqlaTable object (courtesy of
the ORM) depending on the context.
"""

self._database: Database | None = kwargs.pop("database", None)
super().__init__(**kwargs)

@reconstructor
def init_on_load(self) -> None:
"""
Construct a TableColumn object when invoked via the SQLAlchemy ORM.
"""

self._database = None

@property
def is_boolean(self) -> bool:
"""
Expand Down Expand Up @@ -251,51 +276,33 @@ def is_temporal(self) -> bool:
return self.is_dttm
return self.type_generic == GenericDataType.TEMPORAL

@property
def database(self) -> Database:
return self.table.database if self.table else self._database

@property
def db_engine_spec(self) -> type[BaseEngineSpec]:
return self.table.db_engine_spec
return self.database.db_engine_spec

@property
def db_extra(self) -> dict[str, Any]:
return self.table.database.get_extra()
return self.database.get_extra()

@property
def type_generic(self) -> utils.GenericDataType | None:
if self.is_dttm:
return GenericDataType.TEMPORAL

bool_types = ("BOOL",)
num_types = (
"DOUBLE",
"FLOAT",
"INT",
"BIGINT",
"NUMBER",
"LONG",
"REAL",
"NUMERIC",
"DECIMAL",
"MONEY",
)
date_types = ("DATE", "TIME")
str_types = ("VARCHAR", "STRING", "CHAR")

if self.table is None:
# Query.TableColumns don't have a reference to a table.db_engine_spec
# reference so this logic will manage rendering types
if self.type and any(map(lambda t: t in self.type.upper(), str_types)):
return GenericDataType.STRING
if self.type and any(map(lambda t: t in self.type.upper(), bool_types)):
return GenericDataType.BOOLEAN
if self.type and any(map(lambda t: t in self.type.upper(), num_types)):
return GenericDataType.NUMERIC
if self.type and any(map(lambda t: t in self.type.upper(), date_types)):
return GenericDataType.TEMPORAL

column_spec = self.db_engine_spec.get_column_spec(
self.type, db_extra=self.db_extra
return (
column_spec.generic_type # pylint: disable=used-before-assignment
if (
column_spec := self.db_engine_spec.get_column_spec(
self.type,
db_extra=self.db_extra,
)
)
else None
)
return column_spec.generic_type if column_spec else None

def get_sqla_col(
self,
Expand All @@ -312,7 +319,7 @@ def get_sqla_col(
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
col = self.table.make_sqla_column_compatible(col, label)
col = self.database.make_sqla_column_compatible(col, label)
return col

@property
Expand Down Expand Up @@ -343,15 +350,15 @@ def get_timestamp_expression(
type_ = column_spec.sqla_type if column_spec else DateTime
if not self.expression and not time_grain and not is_epoch:
sqla_col = column(self.column_name, type_=type_)
return self.table.make_sqla_column_compatible(sqla_col, label)
return self.database.make_sqla_column_compatible(sqla_col, label)
if expression := self.expression:
if template_processor:
expression = template_processor.process_template(expression)
col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
time_expr = self.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
return self.table.make_sqla_column_compatible(time_expr, label)
return self.database.make_sqla_column_compatible(time_expr, label)

@property
def data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -423,7 +430,7 @@ def get_sqla_col(
expression = template_processor.process_template(expression)

sqla_col: ColumnClause = literal_column(expression)
return self.table.make_sqla_column_compatible(sqla_col, label)
return self.table.database.make_sqla_column_compatible(sqla_col, label)

@property
def perm(self) -> str | None:
Expand Down Expand Up @@ -1008,23 +1015,6 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
)
return self.make_sqla_column_compatible(sqla_column, label)

def make_sqla_column_compatible(
self, sqla_col: ColumnElement, label: str | None = None
) -> ColumnElement:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
:param sqla_col: sqlalchemy column instance
:param label: alias/label that column is expected to have
:return: either a sql alchemy column or label instance if supported by engine
"""
label_expected = label or sqla_col.name
db_engine_spec = self.db_engine_spec
# add quotes to tables
if db_engine_spec.allows_alias_in_select:
label = db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
sqla_col.key = label_expected
return sqla_col

def make_orderby_compatible(
self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
) -> None:
Expand Down
20 changes: 18 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long
# pylint: disable=line-too-long,too-many-lines
"""A collection of ORM sqlalchemy models for Superset"""
import builtins
import enum
Expand Down Expand Up @@ -53,7 +53,7 @@
from sqlalchemy.orm import relationship
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import expression, Select
from sqlalchemy.sql import ColumnElement, expression, Select

from superset import app, db_engine_specs
from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK
Expand Down Expand Up @@ -953,6 +953,22 @@ def get_dialect(self) -> Dialect:
sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted)
return sqla_url.get_dialect()()

def make_sqla_column_compatible(
self, sqla_col: ColumnElement, label: Optional[str] = None
) -> ColumnElement:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
:param sqla_col: sqlalchemy column instance
:param label: alias/label that column is expected to have
:return: either a sql alchemy column or label instance if supported by engine
"""
label_expected = label or sqla_col.name
# add quotes to tables
if self.db_engine_spec.allows_alias_in_select:
label = self.db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
sqla_col.key = label_expected
return sqla_col


sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
sqla.event.listen(Database, "after_update", security_manager.database_after_update)
Expand Down
21 changes: 10 additions & 11 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,17 @@ def columns(self) -> list["TableColumn"]:
TableColumn,
)

columns = []
for col in self.extra.get("columns", []):
columns.append(
TableColumn(
column_name=col["name"],
type=col["type"],
is_dttm=col["is_dttm"],
groupby=True,
filterable=True,
)
return [
TableColumn(
column_name=col["name"],
database=self.database,
is_dttm=col["is_dttm"],
filterable=True,
groupby=True,
type=col["type"],
)
return columns
for col in self.extra.get("columns", [])
]

@property
def db_extra(self) -> Optional[dict[str, Any]]:
Expand Down
5 changes: 5 additions & 0 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,8 @@ def test_data_for_slices_with_adhoc_column(self):

# clean up and auto commit
metadata_db.session.delete(slc)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_table_column_database(self) -> None:
tbl = self.get_table(name="birth_names")
assert tbl.get_column("ds").database is tbl.database # type: ignore
5 changes: 5 additions & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,8 @@ def test_dttm_sql_literal(
result: str,
) -> None:
assert SqlaTable(database=database).dttm_sql_literal(dttm, col) == result


def test_table_column_database() -> None:
database = Database(database_name="db")
assert TableColumn(database=database).database is database

0 comments on commit 2b36489

Please sign in to comment.