Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use nullpool even for user lookup in the celery #10938

Merged
merged 2 commits into from
Sep 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions superset/tasks/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
from selenium.webdriver.remote.webdriver import WebDriver
from sqlalchemy import func
from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -200,12 +201,21 @@ def _get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str:
return urllib.parse.urljoin(str(base_url), url_for(view, **kwargs))


def create_webdriver() -> WebDriver:
return WebDriverProxy(driver_type=config["WEBDRIVER_TYPE"]).auth(get_reports_user())
def create_webdriver(session: Session) -> WebDriver:
return WebDriverProxy(driver_type=config["WEBDRIVER_TYPE"]).auth(
get_reports_user(session)
)


def get_reports_user() -> "User":
return security_manager.find_user(config["EMAIL_REPORTS_USER"])
def get_reports_user(session: Session) -> "User":
return (
session.query(security_manager.user_model)
.filter(
func.lower(security_manager.user_model.username)
== func.lower(config["EMAIL_REPORTS_USER"])
Comment on lines +214 to +215
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious about this, what are the cases where it's needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, this is from security manager implementation, wanted them to be aligned, email is case insensitive

)
.one()
)


def destroy_webdriver(
Expand Down Expand Up @@ -249,7 +259,7 @@ def deliver_dashboard( # pylint: disable=too-many-locals
)

# Create a driver, fetch the page, wait for the page to render
driver = create_webdriver()
driver = create_webdriver(session)
window = config["WEBDRIVER_WINDOW"]["dashboard"]
driver.set_window_size(*window)
driver.get(dashboard_url)
Expand Down Expand Up @@ -303,7 +313,9 @@ def deliver_dashboard( # pylint: disable=too-many-locals
)


def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportContent:
def _get_slice_data(
slc: Slice, delivery_type: EmailDeliveryType, session: Session
) -> ReportContent:
slice_url = _get_url_path(
"Superset.explore_json", csv="true", form_data=json.dumps({"slice_id": slc.id})
)
Expand All @@ -315,7 +327,7 @@ def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportConte

# Login on behalf of the "reports" user in order to get cookies to deal with auth
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(
get_reports_user()
get_reports_user(session)
)
# Build something like "session=cool_sess.val;other-cookie=awesome_other_cookie"
cookie_str = ";".join([f"{key}={val}" for key, val in auth_cookies.items()])
Expand Down Expand Up @@ -384,10 +396,10 @@ def _get_slice_screenshot(slice_id: int, session: Session) -> ScreenshotData:


def _get_slice_visualization(
slc: Slice, delivery_type: EmailDeliveryType
slc: Slice, delivery_type: EmailDeliveryType, session: Session
) -> ReportContent:
# Create a driver, fetch the page, wait for the page to render
driver = create_webdriver()
driver = create_webdriver(session)
window = config["WEBDRIVER_WINDOW"]["slice"]
driver.set_window_size(*window)

Expand Down Expand Up @@ -438,9 +450,9 @@ def deliver_slice( # pylint: disable=too-many-arguments
slc = session.query(Slice).filter_by(id=slice_id).one()

if email_format == SliceEmailReportFormat.data:
report_content = _get_slice_data(slc, delivery_type)
report_content = _get_slice_data(slc, delivery_type, session)
elif email_format == SliceEmailReportFormat.visualization:
report_content = _get_slice_visualization(slc, delivery_type)
report_content = _get_slice_visualization(slc, delivery_type, session)
else:
raise RuntimeError("Unknown email report format")

Expand Down
2 changes: 1 addition & 1 deletion tests/schedules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_create_driver(self, mock_driver_class):
mock_driver_class.return_value = mock_driver
mock_driver.find_elements_by_id.side_effect = [True, False]

create_webdriver()
create_webdriver(db.session)
mock_driver.add_cookie.assert_called_once()

@patch("superset.tasks.schedules.firefox.webdriver.WebDriver")
Expand Down