Skip to content

Commit

Permalink
fix: use nullpool even for user lookup in the celery (#10938)
Browse files Browse the repository at this point in the history
* Use nullpool even for user lookup in the celery

* Address feedback

Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
  • Loading branch information
bkyryliuk and bogdan-dbx committed Sep 21, 2020
1 parent 801fb40 commit 56d0018
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
34 changes: 23 additions & 11 deletions superset/tasks/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
from selenium.webdriver.remote.webdriver import WebDriver
from sqlalchemy import func
from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -200,12 +201,21 @@ def _get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str:
return urllib.parse.urljoin(str(base_url), url_for(view, **kwargs))


def create_webdriver() -> WebDriver:
return WebDriverProxy(driver_type=config["WEBDRIVER_TYPE"]).auth(get_reports_user())
def create_webdriver(session: Session) -> WebDriver:
return WebDriverProxy(driver_type=config["WEBDRIVER_TYPE"]).auth(
get_reports_user(session)
)


def get_reports_user() -> "User":
return security_manager.find_user(config["EMAIL_REPORTS_USER"])
def get_reports_user(session: Session) -> "User":
return (
session.query(security_manager.user_model)
.filter(
func.lower(security_manager.user_model.username)
== func.lower(config["EMAIL_REPORTS_USER"])
)
.one()
)


def destroy_webdriver(
Expand Down Expand Up @@ -249,7 +259,7 @@ def deliver_dashboard( # pylint: disable=too-many-locals
)

# Create a driver, fetch the page, wait for the page to render
driver = create_webdriver()
driver = create_webdriver(session)
window = config["WEBDRIVER_WINDOW"]["dashboard"]
driver.set_window_size(*window)
driver.get(dashboard_url)
Expand Down Expand Up @@ -303,7 +313,9 @@ def deliver_dashboard( # pylint: disable=too-many-locals
)


def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportContent:
def _get_slice_data(
slc: Slice, delivery_type: EmailDeliveryType, session: Session
) -> ReportContent:
slice_url = _get_url_path(
"Superset.explore_json", csv="true", form_data=json.dumps({"slice_id": slc.id})
)
Expand All @@ -315,7 +327,7 @@ def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportConte

# Login on behalf of the "reports" user in order to get cookies to deal with auth
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(
get_reports_user()
get_reports_user(session)
)
# Build something like "session=cool_sess.val;other-cookie=awesome_other_cookie"
cookie_str = ";".join([f"{key}={val}" for key, val in auth_cookies.items()])
Expand Down Expand Up @@ -384,10 +396,10 @@ def _get_slice_screenshot(slice_id: int, session: Session) -> ScreenshotData:


def _get_slice_visualization(
slc: Slice, delivery_type: EmailDeliveryType
slc: Slice, delivery_type: EmailDeliveryType, session: Session
) -> ReportContent:
# Create a driver, fetch the page, wait for the page to render
driver = create_webdriver()
driver = create_webdriver(session)
window = config["WEBDRIVER_WINDOW"]["slice"]
driver.set_window_size(*window)

Expand Down Expand Up @@ -438,9 +450,9 @@ def deliver_slice( # pylint: disable=too-many-arguments
slc = session.query(Slice).filter_by(id=slice_id).one()

if email_format == SliceEmailReportFormat.data:
report_content = _get_slice_data(slc, delivery_type)
report_content = _get_slice_data(slc, delivery_type, session)
elif email_format == SliceEmailReportFormat.visualization:
report_content = _get_slice_visualization(slc, delivery_type)
report_content = _get_slice_visualization(slc, delivery_type, session)
else:
raise RuntimeError("Unknown email report format")

Expand Down
2 changes: 1 addition & 1 deletion tests/schedules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_create_driver(self, mock_driver_class):
mock_driver_class.return_value = mock_driver
mock_driver.find_elements_by_id.side_effect = [True, False]

create_webdriver()
create_webdriver(db.session)
mock_driver.add_cookie.assert_called_once()

@patch("superset.tasks.schedules.firefox.webdriver.WebDriver")
Expand Down

0 comments on commit 56d0018

Please sign in to comment.