Skip to content

Commit

Permalink
refactor: Ensure Flask framework leverages the Flask-SQLAlchemy sessi…
Browse files Browse the repository at this point in the history
…on (Phase II) (#26909)
  • Loading branch information
john-bodley committed Feb 13, 2024
1 parent 827864b commit 847ed3f
Show file tree
Hide file tree
Showing 96 changed files with 656 additions and 730 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / stateme
good-names=_,df,ex,f,i,id,j,k,l,o,pk,Run,ts,v,x,y

# Bad variable names which should always be refused, separated by a comma
bad-names=fd,foo,bar,baz,toto,tutu,tata
bad-names=bar,baz,db,fd,foo,sesh,session,tata,toto,tutu

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
Expand Down
3 changes: 1 addition & 2 deletions superset/cli/importexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def legacy_export_dashboards(
# pylint: disable=import-outside-toplevel
from superset.utils import dashboard_import_export

data = dashboard_import_export.export_dashboards(db.session)
data = dashboard_import_export.export_dashboards()
if print_stdout or not dashboard_file:
print(data)
if dashboard_file:
Expand Down Expand Up @@ -263,7 +263,6 @@ def legacy_export_datasources(
from superset.utils import dict_import_export

data = dict_import_export.export_to_dict(
session=db.session,
recursive=True,
back_references=back_references,
include_defaults=include_defaults,
Expand Down
10 changes: 4 additions & 6 deletions superset/commands/chart/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ class ImportChartsCommand(ImportModelsCommand):
import_error = ChartImportError

@staticmethod
def _import(
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# discover datasets associated with charts
dataset_uuids: set[str] = set()
for file_name, config in configs.items():
Expand All @@ -66,7 +64,7 @@ def _import(
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database = import_database(config, overwrite=False)
database_ids[str(database.uuid)] = database.id

# import datasets with the correct parent ref
Expand All @@ -77,7 +75,7 @@ def _import(
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
dataset = import_dataset(session, config, overwrite=False)
dataset = import_dataset(config, overwrite=False)
datasets[str(dataset.uuid)] = dataset

# import charts with the correct parent ref
Expand All @@ -101,4 +99,4 @@ def _import(
if "query_context" in config:
config["query_context"] = None

import_chart(session, config, overwrite=overwrite)
import_chart(config, overwrite=overwrite)
13 changes: 4 additions & 9 deletions superset/commands/chart/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
from inspect import isclass
from typing import Any

from sqlalchemy.orm import Session

from superset import security_manager
from superset import db, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.migrations.shared.migrate_viz import processors
from superset.migrations.shared.migrate_viz.base import MigrateViz
Expand All @@ -46,13 +44,12 @@ def filter_chart_annotations(chart_config: dict[str, Any]) -> None:


def import_chart(
session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Slice:
can_write = ignore_permissions or security_manager.can_access("can_write", "Chart")
existing = session.query(Slice).filter_by(uuid=config["uuid"]).first()
existing = db.session.query(Slice).filter_by(uuid=config["uuid"]).first()
if existing:
if overwrite and can_write and get_user():
if not security_manager.can_access_chart(existing):
Expand All @@ -76,11 +73,9 @@ def import_chart(
# migrate old viz types to new ones
config = migrate_chart(config)

chart = Slice.import_from_dict(
session, config, recursive=False, allow_reparenting=True
)
chart = Slice.import_from_dict(config, recursive=False, allow_reparenting=True)
if chart.id is None:
session.flush()
db.session.flush()

if user := get_user():
chart.owners.append(user)
Expand Down
19 changes: 9 additions & 10 deletions superset/commands/dashboard/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import select

from superset import db
from superset.charts.schemas import ImportV1ChartSchema
from superset.commands.chart.importers.v1.utils import import_chart
from superset.commands.dashboard.exceptions import DashboardImportError
Expand Down Expand Up @@ -59,9 +60,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
# TODO (betodealmeida): refactor to use code from other commands
# pylint: disable=too-many-branches, too-many-locals
@staticmethod
def _import(
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# discover charts and datasets associated with dashboards
chart_uuids: set[str] = set()
dataset_uuids: set[str] = set()
Expand All @@ -87,7 +86,7 @@ def _import(
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database = import_database(config, overwrite=False)
database_ids[str(database.uuid)] = database.id

# import datasets with the correct parent ref
Expand All @@ -98,7 +97,7 @@ def _import(
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
dataset = import_dataset(session, config, overwrite=False)
dataset = import_dataset(config, overwrite=False)
dataset_info[str(dataset.uuid)] = {
"datasource_id": dataset.id,
"datasource_type": dataset.datasource_type,
Expand All @@ -122,12 +121,12 @@ def _import(
if "query_context" in config:
config["query_context"] = None

chart = import_chart(session, config, overwrite=False)
chart = import_chart(config, overwrite=False)
charts.append(chart)
chart_ids[str(chart.uuid)] = chart.id

# store the existing relationship between dashboards and charts
existing_relationships = session.execute(
existing_relationships = db.session.execute(
select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
).fetchall()

Expand All @@ -137,7 +136,7 @@ def _import(
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
config = update_id_refs(config, chart_ids, dataset_info)
dashboard = import_dashboard(session, config, overwrite=overwrite)
dashboard = import_dashboard(config, overwrite=overwrite)
dashboards.append(dashboard)
for uuid in find_chart_uuids(config["position"]):
if uuid not in chart_ids:
Expand All @@ -151,7 +150,7 @@ def _import(
{"dashboard_id": dashboard_id, "slice_id": chart_id}
for (dashboard_id, chart_id) in dashboard_chart_ids
]
session.execute(dashboard_slices.insert(), values)
db.session.execute(dashboard_slices.insert(), values)

# Migrate any filter-box charts to native dashboard filters.
for dashboard in dashboards:
Expand All @@ -160,4 +159,4 @@ def _import(
# Remove all obsolete filter-box charts.
for chart in charts:
if chart.viz_type == "filter_box":
session.delete(chart)
db.session.delete(chart)
11 changes: 4 additions & 7 deletions superset/commands/dashboard/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import logging
from typing import Any

from sqlalchemy.orm import Session

from superset import security_manager
from superset import db, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.models.dashboard import Dashboard
from superset.utils.core import get_user
Expand Down Expand Up @@ -146,7 +144,6 @@ def update_id_refs( # pylint: disable=too-many-locals


def import_dashboard(
session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
Expand All @@ -155,7 +152,7 @@ def import_dashboard(
"can_write",
"Dashboard",
)
existing = session.query(Dashboard).filter_by(uuid=config["uuid"]).first()
existing = db.session.query(Dashboard).filter_by(uuid=config["uuid"]).first()
if existing:
if overwrite and can_write and get_user():
if not security_manager.can_access_dashboard(existing):
Expand Down Expand Up @@ -187,9 +184,9 @@ def import_dashboard(
except TypeError:
logger.info("Unable to encode `%s` field: %s", key, value)

dashboard = Dashboard.import_from_dict(session, config, recursive=False)
dashboard = Dashboard.import_from_dict(config, recursive=False)
if dashboard.id is None:
session.flush()
db.session.flush()

if user := get_user():
dashboard.owners.append(user)
Expand Down
8 changes: 3 additions & 5 deletions superset/commands/database/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,12 @@ class ImportDatabasesCommand(ImportModelsCommand):
import_error = DatabaseImportError

@staticmethod
def _import(
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# first import databases
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=overwrite)
database = import_database(config, overwrite=overwrite)
database_ids[str(database.uuid)] = database.id

# import related datasets
Expand All @@ -61,4 +59,4 @@ def _import(
):
config["database_id"] = database_ids[config["database_uuid"]]
# overwrite=False prevents deleting any non-imported columns/metrics
import_dataset(session, config, overwrite=False)
import_dataset(config, overwrite=False)
13 changes: 5 additions & 8 deletions superset/commands/database/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
import json
from typing import Any

from sqlalchemy.orm import Session

from superset import app, security_manager
from superset import app, db, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
Expand All @@ -30,7 +28,6 @@


def import_database(
session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
Expand All @@ -39,7 +36,7 @@ def import_database(
"can_write",
"Database",
)
existing = session.query(Database).filter_by(uuid=config["uuid"]).first()
existing = db.session.query(Database).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite or not can_write:
return existing
Expand Down Expand Up @@ -67,12 +64,12 @@ def import_database(
# Before it gets removed in import_from_dict
ssh_tunnel = config.pop("ssh_tunnel", None)

database = Database.import_from_dict(session, config, recursive=False)
database = Database.import_from_dict(config, recursive=False)
if database.id is None:
session.flush()
db.session.flush()

if ssh_tunnel:
ssh_tunnel["database_id"] = database.id
SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
SSHTunnel.import_from_dict(ssh_tunnel, recursive=False)

return database
Loading

0 comments on commit 847ed3f

Please sign in to comment.