refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase I) #26200

Merged
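All of the changes below follow one pattern: instead of instantiating a session (db.session()) or threading a session/sessionmaker argument through call chains, code now uses the Flask-SQLAlchemy-managed db.session directly. A minimal before/after sketch of the pattern (count_rows is a hypothetical function, not from this PR):

# Before: a session is created or passed explicitly and must be managed
# by the caller.
def count_rows(session, model):
    return session.query(model).count()

# After: rely on the Flask-SQLAlchemy scoped session, which is shared per
# application context and cleaned up automatically.
from superset import db  # assumes code running under the Superset Flask app

def count_rows(model):
    return db.session.query(model).count()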
17 changes: 7 additions & 10 deletions scripts/benchmark_migration.py
@@ -142,8 +142,6 @@ def main(
     filepath: str, limit: int = 1000, force: bool = False, no_auto_cleanup: bool = False
 ) -> None:
     auto_cleanup = not no_auto_cleanup
-    session = db.session()
-
     print(f"Importing migration script: {filepath}")
     module = import_migration_script(Path(filepath))

@@ -174,10 +172,9 @@ def main(
     models = find_models(module)
     model_rows: dict[type[Model], int] = {}
     for model in models:
-        rows = session.query(model).count()
+        rows = db.session.query(model).count()
         print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
         model_rows[model] = rows
-    session.close()
Member Author commented: There is no need to explicitly close the session, as this is handled by Flask-SQLAlchemy when the session is torn down, whether at the end of a request or when a script or shell terminates.
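For background, db.session in Flask-SQLAlchemy is a scoped session whose removal is hooked into application-context teardown. A minimal sketch of roughly what the extension wires up (illustrative, not the exact Flask-SQLAlchemy source):

from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://"
db = SQLAlchemy(app)

# Flask-SQLAlchemy registers a teardown roughly equivalent to this:
@app.teardown_appcontext
def _remove_session(exc):
    # remove() closes the context-local session and returns its connection
    # to the pool, which is why callers never need an explicit close().
    db.session.remove()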
     print("Benchmarking migration")
     results: dict[str, float] = {}
@@ -199,16 +196,16 @@ def main(
             print(f"- Adding {missing} entities to the {model.__name__} model")
             bar = ChargingBar("Processing", max=missing)
             try:
-                for entity in add_sample_rows(session, model, missing):
+                for entity in add_sample_rows(model, missing):
                     entities.append(entity)
                     bar.next()
             except Exception:
-                session.rollback()
+                db.session.rollback()
                 raise
             bar.finish()
             model_rows[model] = min_entities
-            session.add_all(entities)
-            session.commit()
+            db.session.add_all(entities)
+            db.session.commit()

         if auto_cleanup:
             new_models[model].extend(entities)
@@ -227,10 +224,10 @@ def main(
         print("Cleaning up DB")
         # delete in reverse order of creation to handle relationships
         for model, entities in list(new_models.items())[::-1]:
-            session.query(model).filter(
+            db.session.query(model).filter(
                 model.id.in_(entity.id for entity in entities)
            ).delete(synchronize_session=False)
-        session.commit()
+        db.session.commit()

     if current_revision != revision and not force:
         click.confirm(f"\nRevert DB to {revision}?", abort=True)
1 change: 0 additions & 1 deletion superset/cachekeys/api.py
@@ -84,7 +84,6 @@ def invalidate(self) -> Response:
         datasource_uids = set(datasources.get("datasource_uids", []))
         for ds in datasources.get("datasources", []):
             ds_obj = SqlaTable.get_datasource_by_name(
-                session=db.session,
                 datasource_name=ds.get("datasource_name"),
                 schema=ds.get("schema"),
                 database_name=ds.get("database_name"),
32 changes: 15 additions & 17 deletions superset/commands/dashboard/importers/v0.py
@@ -22,7 +22,7 @@
 from typing import Any, Optional

 from flask_babel import lazy_gettext as _
-from sqlalchemy.orm import make_transient, Session
+from sqlalchemy.orm import make_transient

 from superset import db
 from superset.commands.base import BaseCommand
@@ -55,7 +55,6 @@ def import_chart(
     :returns: The resulting id for the imported slice
     :rtype: int
     """
-    session = db.session
     make_transient(slc_to_import)
     slc_to_import.dashboards = []
     slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)
@@ -64,19 +63,18 @@ def import_chart(
     slc_to_import.reset_ownership()
     params = slc_to_import.params_dict
     datasource = SqlaTable.get_datasource_by_name(
-        session=session,
         datasource_name=params["datasource_name"],
         database_name=params["database_name"],
         schema=params["schema"],
     )
     slc_to_import.datasource_id = datasource.id  # type: ignore
     if slc_to_override:
         slc_to_override.override(slc_to_import)
-        session.flush()
+        db.session.flush()
         return slc_to_override.id
-    session.add(slc_to_import)
+    db.session.add(slc_to_import)
     logger.info("Final slice: %s", str(slc_to_import.to_json()))
-    session.flush()
+    db.session.flush()
     return slc_to_import.id


@@ -156,7 +154,6 @@ def alter_native_filters(dashboard: Dashboard) -> None:
     dashboard.json_metadata = json.dumps(json_metadata)

     logger.info("Started import of the dashboard: %s", dashboard_to_import.to_json())
-    session = db.session
     logger.info("Dashboard has %d slices", len(dashboard_to_import.slices))
     # copy slices object as Slice.import_slice will mutate the slice
     # and will remove the existing dashboard - slice association
@@ -173,7 +170,7 @@ def alter_native_filters(dashboard: Dashboard) -> None:
     i_params_dict = dashboard_to_import.params_dict
     remote_id_slice_map = {
         slc.params_dict["remote_id"]: slc
-        for slc in session.query(Slice).all()
+        for slc in db.session.query(Slice).all()
         if "remote_id" in slc.params_dict
     }
     for slc in slices:
@@ -224,7 +221,7 @@ def alter_native_filters(dashboard: Dashboard) -> None:

     # override the dashboard
     existing_dashboard = None
-    for dash in session.query(Dashboard).all():
+    for dash in db.session.query(Dashboard).all():
         if (
             "remote_id" in dash.params_dict
             and dash.params_dict["remote_id"] == dashboard_to_import.id
@@ -253,18 +250,20 @@ def alter_native_filters(dashboard: Dashboard) -> None:
     alter_native_filters(dashboard_to_import)

     new_slices = (
-        session.query(Slice).filter(Slice.id.in_(old_to_new_slc_id_dict.values())).all()
+        db.session.query(Slice)
+        .filter(Slice.id.in_(old_to_new_slc_id_dict.values()))
+        .all()
     )

     if existing_dashboard:
         existing_dashboard.override(dashboard_to_import)
         existing_dashboard.slices = new_slices
-        session.flush()
+        db.session.flush()
         return existing_dashboard.id

     dashboard_to_import.slices = new_slices
-    session.add(dashboard_to_import)
-    session.flush()
+    db.session.add(dashboard_to_import)
+    db.session.flush()
     return dashboard_to_import.id  # type: ignore


@@ -291,7 +290,6 @@ def decode_dashboards(o: dict[str, Any]) -> Any:


 def import_dashboards(
-    session: Session,
     content: str,
     database_id: Optional[int] = None,
     import_time: Optional[int] = None,
@@ -308,10 +306,10 @@
         params = json.loads(table.params)
         dataset_id_mapping[params["remote_id"]] = new_dataset_id

-    session.commit()
+    db.session.commit()
     for dashboard in data["dashboards"]:
         import_dashboard(dashboard, dataset_id_mapping, import_time=import_time)
-    session.commit()
+    db.session.commit()


 class ImportDashboardsCommand(BaseCommand):
@@ -334,7 +332,7 @@ def run(self) -> None:

         for file_name, content in self.contents.items():
             logger.info("Importing dashboard from file %s", file_name)
-            import_dashboards(db.session, content, self.database_id)
+            import_dashboards(content, self.database_id)

     def validate(self) -> None:
         # ensure all files are JSON
3 changes: 1 addition & 2 deletions superset/commands/explore/get.py
@@ -24,7 +24,6 @@
 from flask_babel import lazy_gettext as _
 from sqlalchemy.exc import SQLAlchemyError

-from superset import db
 from superset.commands.base import BaseCommand
 from superset.commands.explore.form_data.get import GetFormDataCommand
 from superset.commands.explore.form_data.parameters import (
@@ -114,7 +113,7 @@ def run(self) -> Optional[dict[str, Any]]:
         if self._datasource_id is not None:
             with contextlib.suppress(DatasourceNotFound):
                 datasource = DatasourceDAO.get_datasource(
-                    db.session, cast(str, self._datasource_type), self._datasource_id
+                    cast(str, self._datasource_type), self._datasource_id
                 )
             datasource_name = datasource.name if datasource else _("[Missing Dataset]")
         viz_type = form_data.get("viz_type")
3 changes: 1 addition & 2 deletions superset/commands/utils.py
@@ -29,7 +29,6 @@
 )
 from superset.daos.datasource import DatasourceDAO
 from superset.daos.exceptions import DatasourceNotFound
-from superset.extensions import db
 from superset.utils.core import DatasourceType, get_user_id

if TYPE_CHECKING:
@@ -80,7 +79,7 @@ def populate_roles(role_ids: list[int] | None = None) -> list[Role]:
 def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
     try:
         return DatasourceDAO.get_datasource(
-            db.session, DatasourceType(datasource_type), datasource_id
+            DatasourceType(datasource_type), datasource_id
         )
     except DatasourceNotFound as ex:
         raise DatasourceNotFoundValidationError() from ex
5 changes: 2 additions & 3 deletions superset/common/query_context_factory.py
@@ -18,7 +18,7 @@

 from typing import Any, TYPE_CHECKING

-from superset import app, db
+from superset import app
 from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.query_context import QueryContext
 from superset.common.query_object import QueryObject
@@ -35,7 +35,7 @@


 def create_query_object_factory() -> QueryObjectFactory:
-    return QueryObjectFactory(config, DatasourceDAO(), db.session)
+    return QueryObjectFactory(config, DatasourceDAO())


 class QueryContextFactory:  # pylint: disable=too-few-public-methods
@@ -95,7 +95,6 @@ def create(

     def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
         return DatasourceDAO.get_datasource(
-            session=db.session,
             datasource_type=DatasourceType(datasource["type"]),
             datasource_id=int(datasource["id"]),
         )
6 changes: 0 additions & 6 deletions superset/common/query_object_factory.py
@@ -33,26 +33,21 @@
 )

 if TYPE_CHECKING:
-    from sqlalchemy.orm import sessionmaker
-
     from superset.connectors.sqla.models import BaseDatasource
     from superset.daos.datasource import DatasourceDAO


 class QueryObjectFactory:  # pylint: disable=too-few-public-methods
     _config: dict[str, Any]
     _datasource_dao: DatasourceDAO
-    _session_maker: sessionmaker

     def __init__(
         self,
         app_configurations: dict[str, Any],
         _datasource_dao: DatasourceDAO,
-        session_maker: sessionmaker,
     ):
         self._config = app_configurations
         self._datasource_dao = _datasource_dao
-        self._session_maker = session_maker

     def create(  # pylint: disable=too-many-arguments
         self,
@@ -91,7 +86,6 @@ def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
         return self._datasource_dao.get_datasource(
             datasource_type=DatasourceType(datasource["type"]),
             datasource_id=int(datasource["id"]),
-            session=self._session_maker(),
         )

     def _process_extras(
14 changes: 5 additions & 9 deletions superset/connectors/sqla/models.py
@@ -699,7 +699,7 @@ def raise_for_access(self) -> None:

     @classmethod
     def get_datasource_by_name(
-        cls, session: Session, datasource_name: str, schema: str, database_name: str
+        cls, datasource_name: str, schema: str, database_name: str
     ) -> BaseDatasource | None:
         raise NotImplementedError()

@@ -1238,14 +1238,13 @@ def database_name(self) -> str:
     @classmethod
     def get_datasource_by_name(
         cls,
-        session: Session,
         datasource_name: str,
         schema: str | None,
         database_name: str,
     ) -> SqlaTable | None:
         schema = schema or None
         query = (
-            session.query(cls)
+            db.session.query(cls)
             .join(Database)
             .filter(cls.table_name == datasource_name)
             .filter(Database.database_name == database_name)
@@ -1939,12 +1938,10 @@ def query_datasources_by_permissions(  # pylint: disable=invalid-name
         )

     @classmethod
-    def get_eager_sqlatable_datasource(
-        cls, session: Session, datasource_id: int
-    ) -> SqlaTable:
+    def get_eager_sqlatable_datasource(cls, datasource_id: int) -> SqlaTable:
         """Returns SqlaTable with columns and metrics."""
         return (
-            session.query(cls)
+            db.session.query(cls)
             .options(
                 sa.orm.subqueryload(cls.columns),
                 sa.orm.subqueryload(cls.metrics),
@@ -2037,8 +2034,7 @@ def update_column(  # pylint: disable=unused-argument
         :param connection: Unused.
         :param target: The metric or column that was updated.
         """
-        inspector = inspect(target)
-        session = inspector.session
+        session = inspect(target).session

         # Forces an update to the table's changed_on value when a metric or column on the
         # table is updated. This busts the cache key for all charts that use the table.
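A note on the session = inspect(target).session idiom above: SQLAlchemy's inspect() on a mapped instance returns its InstanceState, whose .session attribute is the Session the object is currently attached to (or None if detached). A self-contained sketch (Widget is a hypothetical model, not from this PR):

from sqlalchemy import Column, Integer, create_engine, inspect
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Widget(Base):
    __tablename__ = "widget"
    id = Column(Integer, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    widget = Widget()
    session.add(widget)
    # inspect() on an instance returns its InstanceState; .session is the
    # owning Session while the object is attached, None once detached.
    assert inspect(widget).session is session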
7 changes: 4 additions & 3 deletions superset/daos/dashboard.py
@@ -170,7 +170,7 @@ def validate_update_slug_uniqueness(dashboard_id: int, slug: str | None) -> bool:
         return True

     @staticmethod
-    def set_dash_metadata(  # pylint: disable=too-many-locals
+    def set_dash_metadata(
         dashboard: Dashboard,
         data: dict[Any, Any],
         old_to_new_slice_ids: dict[int, int] | None = None,
@@ -187,8 +187,9 @@ def set_dash_metadata(
                 if isinstance(value, dict)
             ]

-            session = db.session()
-            current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
+            current_slices = (
+                db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
+            )

             dashboard.slices = current_slices

6 changes: 2 additions & 4 deletions superset/daos/datasource.py
@@ -18,8 +18,7 @@
 import logging
 from typing import Union

-from sqlalchemy.orm import Session
-
+from superset import db
 from superset.connectors.sqla.models import SqlaTable
 from superset.daos.base import BaseDAO
 from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
@@ -45,15 +44,14 @@ class DatasourceDAO(BaseDAO[Datasource]):
     @classmethod
     def get_datasource(
         cls,
-        session: Session,
Member Author commented: It is not apparent to me why the session was included here. There are unit tests which use a session other than db.session but, per here, the global db.session is mocked to use said session. (See the test-mocking sketch after this diff.)
         datasource_type: Union[DatasourceType, str],
         datasource_id: int,
     ) -> Datasource:
         if datasource_type not in cls.sources:
             raise DatasourceTypeNotSupportedError()

         datasource = (
-            session.query(cls.sources[datasource_type])
+            db.session.query(cls.sources[datasource_type])
             .filter_by(id=datasource_id)
             .one_or_none()
         )
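To illustrate the comment above, a hypothetical sketch of how a test can redirect the global db.session to a session of its own (the test name and session_factory fixture are illustrative, not from the Superset test suite):

from unittest import mock

from superset.daos.datasource import DatasourceDAO
from superset.utils.core import DatasourceType

def test_get_datasource_uses_global_session(session_factory):
    test_session = session_factory()  # hypothetical: a session bound to a test DB
    # Patch the attribute so code calling db.session.query(...) transparently
    # uses the test session instead of the request-scoped one.
    with mock.patch("superset.db.session", test_session):
        datasource = DatasourceDAO.get_datasource(DatasourceType.TABLE, 1)
    assert datasource is not None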
4 changes: 2 additions & 2 deletions superset/datasource/api.py
@@ -18,7 +18,7 @@

 from flask_appbuilder.api import expose, protect, safe

-from superset import app, db, event_logger
+from superset import app, event_logger
 from superset.daos.datasource import DatasourceDAO
 from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
 from superset.exceptions import SupersetSecurityException
@@ -100,7 +100,7 @@ def get_column_values(
         """
         try:
             datasource = DatasourceDAO.get_datasource(
-                db.session, DatasourceType(datasource_type), datasource_id
+                DatasourceType(datasource_type), datasource_id
             )
             datasource.raise_for_access()
         except ValueError: