From 56d001835e918261324ea07008a27dec1e29ecaa Mon Sep 17 00:00:00 2001 From: Bogdan Date: Mon, 21 Sep 2020 10:34:03 -0700 Subject: [PATCH] fix: use nullpool even for user lookup in the celery (#10938) * Use nullpool even for user lookup in the celery * Address feedback Co-authored-by: bogdan kyryliuk --- superset/tasks/schedules.py | 34 +++++++++++++++++++++++----------- tests/schedules_test.py | 2 +- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index 09a42145e8cd..2cc1280f0108 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -46,6 +46,7 @@ from selenium.common.exceptions import WebDriverException from selenium.webdriver import chrome, firefox from selenium.webdriver.remote.webdriver import WebDriver +from sqlalchemy import func from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError from sqlalchemy.orm import Session @@ -200,12 +201,21 @@ def _get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str: return urllib.parse.urljoin(str(base_url), url_for(view, **kwargs)) -def create_webdriver() -> WebDriver: - return WebDriverProxy(driver_type=config["WEBDRIVER_TYPE"]).auth(get_reports_user()) +def create_webdriver(session: Session) -> WebDriver: + return WebDriverProxy(driver_type=config["WEBDRIVER_TYPE"]).auth( + get_reports_user(session) + ) -def get_reports_user() -> "User": - return security_manager.find_user(config["EMAIL_REPORTS_USER"]) +def get_reports_user(session: Session) -> "User": + return ( + session.query(security_manager.user_model) + .filter( + func.lower(security_manager.user_model.username) + == func.lower(config["EMAIL_REPORTS_USER"]) + ) + .one() + ) def destroy_webdriver( @@ -249,7 +259,7 @@ def deliver_dashboard( # pylint: disable=too-many-locals ) # Create a driver, fetch the page, wait for the page to render - driver = create_webdriver() + driver = create_webdriver(session) window = config["WEBDRIVER_WINDOW"]["dashboard"] driver.set_window_size(*window) driver.get(dashboard_url) @@ -303,7 +313,9 @@ def deliver_dashboard( # pylint: disable=too-many-locals ) -def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportContent: +def _get_slice_data( + slc: Slice, delivery_type: EmailDeliveryType, session: Session +) -> ReportContent: slice_url = _get_url_path( "Superset.explore_json", csv="true", form_data=json.dumps({"slice_id": slc.id}) ) @@ -315,7 +327,7 @@ def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportConte # Login on behalf of the "reports" user in order to get cookies to deal with auth auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies( - get_reports_user() + get_reports_user(session) ) # Build something like "session=cool_sess.val;other-cookie=awesome_other_cookie" cookie_str = ";".join([f"{key}={val}" for key, val in auth_cookies.items()]) @@ -384,10 +396,10 @@ def _get_slice_screenshot(slice_id: int, session: Session) -> ScreenshotData: def _get_slice_visualization( - slc: Slice, delivery_type: EmailDeliveryType + slc: Slice, delivery_type: EmailDeliveryType, session: Session ) -> ReportContent: # Create a driver, fetch the page, wait for the page to render - driver = create_webdriver() + driver = create_webdriver(session) window = config["WEBDRIVER_WINDOW"]["slice"] driver.set_window_size(*window) @@ -438,9 +450,9 @@ def deliver_slice( # pylint: disable=too-many-arguments slc = session.query(Slice).filter_by(id=slice_id).one() if email_format == SliceEmailReportFormat.data: - report_content = _get_slice_data(slc, delivery_type) + report_content = _get_slice_data(slc, delivery_type, session) elif email_format == SliceEmailReportFormat.visualization: - report_content = _get_slice_visualization(slc, delivery_type) + report_content = _get_slice_visualization(slc, delivery_type, session) else: raise RuntimeError("Unknown email report format") diff --git a/tests/schedules_test.py b/tests/schedules_test.py index 88b6d1f924d1..b18007e2f77a 100644 --- a/tests/schedules_test.py +++ b/tests/schedules_test.py @@ -171,7 +171,7 @@ def test_create_driver(self, mock_driver_class): mock_driver_class.return_value = mock_driver mock_driver.find_elements_by_id.side_effect = [True, False] - create_webdriver() + create_webdriver(db.session) mock_driver.add_cookie.assert_called_once() @patch("superset.tasks.schedules.firefox.webdriver.WebDriver")