diff --git a/docs/docs/installation/alerts-reports.mdx b/docs/docs/installation/alerts-reports.mdx index d8f04817e872..3538ca1479e4 100644 --- a/docs/docs/installation/alerts-reports.mdx +++ b/docs/docs/installation/alerts-reports.mdx @@ -371,10 +371,36 @@ to specify on behalf of which username to render the dashboards. In general dash are not accessible to unauthorized requests, that is why the worker needs to take over credentials of an existing user to take a snapshot. +By default, Alerts and Reports are executed as the user that the `THUMBNAIL_SELENIUM_USER` config +parameter is set to. To change this user, just change the config as follows: + ```python THUMBNAIL_SELENIUM_USER = 'username_with_permission_to_access_dashboards' ``` +In addition, it's also possible to execute the reports as the report owners/creators. This is typically +needed if there isn't a central service account that has access to all objects or databases (e.g. +when using user impersonation on database connections). For this there's the config flag +`ALERTS_REPORTS_EXECUTE_AS` which makes it possible to customize how alerts and reports are executed. +To first try to execute as the creator in the owners list (if present), then fall +back to the creator, then the last modifier in the owners list (if present), then the +last modifier, then an owner (giving priority to the last modifier and then the +creator if either is contained within the list of owners, otherwise the first owner +will be used) and finally `THUMBNAIL_SELENIUM_USER`, set as follows: + +```python +from superset.reports.types import ReportScheduleExecutor + +ALERT_REPORTS_EXECUTE_AS = [ + ReportScheduleExecutor.CREATOR_OWNER, + ReportScheduleExecutor.CREATOR, + ReportScheduleExecutor.MODIFIER_OWNER, + ReportScheduleExecutor.MODIFIER, + ReportScheduleExecutor.OWNER, + ReportScheduleExecutor.SELENIUM, +] +``` + **Important notes** - Be mindful of the concurrency setting for celery (using `-c 4`). Selenium/webdriver instances can @@ -382,7 +408,7 @@ THUMBNAIL_SELENIUM_USER = 'username_with_permission_to_access_dashboards' - In some cases, if you notice a lot of leaked geckodriver processes, try running your celery processes with `celery worker --pool=prefork --max-tasks-per-child=128 ...` - It is recommended to run separate workers for the `sql_lab` and `email_reports` tasks. This can be - done using the `queue` field in `CELERY_ANNOTATIONS`. + done using the `queue` field in `task_annotations`. - Adjust `WEBDRIVER_BASEURL` in your configuration file if celery workers can’t access Superset via its default value of `http://0.0.0.0:8080/`. diff --git a/superset/config.py b/superset/config.py index 30f8bbc89341..341217822398 100644 --- a/superset/config.py +++ b/superset/config.py @@ -57,6 +57,7 @@ from superset.advanced_data_type.types import AdvancedDataType from superset.constants import CHANGE_ME_SECRET_KEY from superset.jinja_context import BaseTemplateProcessor +from superset.reports.types import ReportScheduleExecutor from superset.stats_logger import DummyStatsLogger from superset.superset_typing import CacheConfig from superset.utils.core import is_test, NO_TIME_RANGE, parse_boolean_string @@ -1143,6 +1144,24 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # sliding cron window size, should be synced with the celery beat config minus 1 second ALERT_REPORTS_CRON_WINDOW_SIZE = 59 ALERT_REPORTS_WORKING_TIME_OUT_KILL = True +# Which user to attempt to execute Alerts/Reports as. By default, +# use the user defined in the `THUMBNAIL_SELENIUM_USER` config parameter. +# To first try to execute as the creator in the owners list (if present), then fall +# back to the creator, then the last modifier in the owners list (if present), then the +# last modifier, then an owner (giving priority to the last modifier and then the +# creator if either is contained within the list of owners, otherwise the first owner +# will be used) and finally `THUMBNAIL_SELENIUM_USER`, set as follows: +# ALERT_REPORTS_EXECUTE_AS = [ +# ReportScheduleExecutor.CREATOR_OWNER, +# ReportScheduleExecutor.CREATOR, +# ReportScheduleExecutor.MODIFIER_OWNER, +# ReportScheduleExecutor.MODIFIER, +# ReportScheduleExecutor.OWNER, +# ReportScheduleExecutor.SELENIUM, +# ] +ALERT_REPORTS_EXECUTE_AS: List[ReportScheduleExecutor] = [ + ReportScheduleExecutor.SELENIUM +] # if ALERT_REPORTS_WORKING_TIME_OUT_KILL is True, set a celery hard timeout # Equal to working timeout + ALERT_REPORTS_WORKING_TIME_OUT_LAG ALERT_REPORTS_WORKING_TIME_OUT_LAG = int(timedelta(seconds=10).total_seconds()) diff --git a/superset/reports/commands/exceptions.py b/superset/reports/commands/exceptions.py index c087cf95e0d1..89f2c82fb9ff 100644 --- a/superset/reports/commands/exceptions.py +++ b/superset/reports/commands/exceptions.py @@ -250,8 +250,8 @@ class ReportScheduleNotificationError(CommandException): message = _("Alert on grace period") -class ReportScheduleSelleniumUserNotFoundError(CommandException): - message = _("Report Schedule sellenium user not found") +class ReportScheduleUserNotFoundError(CommandException): + message = _("Report Schedule user not found") class ReportScheduleStateNotFoundError(CommandException): diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index a8776f0d7d8a..0ab6c6252440 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -22,10 +22,9 @@ import pandas as pd from celery.exceptions import SoftTimeLimitExceeded -from flask_appbuilder.security.sqla.models import User from sqlalchemy.orm import Session -from superset import app, security_manager +from superset import app from superset.commands.base import BaseCommand from superset.commands.exceptions import CommandException from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType @@ -45,7 +44,6 @@ ReportSchedulePreviousWorkingError, ReportScheduleScreenshotFailedError, ReportScheduleScreenshotTimeout, - ReportScheduleSelleniumUserNotFoundError, ReportScheduleStateNotFoundError, ReportScheduleUnexpectedError, ReportScheduleWorkingTimeoutError, @@ -67,6 +65,7 @@ from superset.reports.notifications import create_notification from superset.reports.notifications.base import NotificationContent from superset.reports.notifications.exceptions import NotificationError +from superset.reports.utils import get_executor from superset.utils.celery import session_scope from superset.utils.core import HeaderDataType, override_user from superset.utils.csv import get_chart_csv_data, get_chart_dataframe @@ -77,13 +76,6 @@ logger = logging.getLogger(__name__) -def _get_user() -> User: - user = security_manager.find_user(username=app.config["THUMBNAIL_SELENIUM_USER"]) - if not user: - raise ReportScheduleSelleniumUserNotFoundError() - return user - - class BaseReportState: current_states: List[ReportState] = [] initial: bool = False @@ -182,11 +174,11 @@ def _get_url( **kwargs, ) - # If we need to render dashboard in a specific sate, use stateful permalink + # If we need to render dashboard in a specific state, use stateful permalink dashboard_state = self._report_schedule.extra.get("dashboard") if dashboard_state: permalink_key = CreateDashboardPermalinkCommand( - dashboard_id=self._report_schedule.dashboard_id, + dashboard_id=str(self._report_schedule.dashboard_id), state=dashboard_state, ).run() return get_url_path("Superset.dashboard_permalink", key=permalink_key) @@ -206,7 +198,7 @@ def _get_screenshots(self) -> List[bytes]: :raises: ReportScheduleScreenshotFailedError """ url = self._get_url() - user = _get_user() + user = get_executor(self._report_schedule) if self._report_schedule.chart: screenshot: Union[ChartScreenshot, DashboardScreenshot] = ChartScreenshot( url, @@ -236,16 +228,15 @@ def _get_screenshots(self) -> List[bytes]: def _get_csv_data(self) -> bytes: url = self._get_url(result_format=ChartDataResultFormat.CSV) - auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies( - _get_user() - ) + user = get_executor(self._report_schedule) + auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user) if self._report_schedule.chart.query_context is None: logger.warning("No query context found, taking a screenshot to generate it") self._update_query_context() try: - logger.info("Getting chart from %s", url) + logger.info("Getting chart from %s as user %s", url, user.username) csv_data = get_chart_csv_data(url, auth_cookies) except SoftTimeLimitExceeded as ex: raise ReportScheduleCsvTimeout() from ex @@ -262,16 +253,15 @@ def _get_embedded_data(self) -> pd.DataFrame: Return data as a Pandas dataframe, to embed in notifications as a table. """ url = self._get_url(result_format=ChartDataResultFormat.JSON) - auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies( - _get_user() - ) + user = get_executor(self._report_schedule) + auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user) if self._report_schedule.chart.query_context is None: logger.warning("No query context found, taking a screenshot to generate it") self._update_query_context() try: - logger.info("Getting chart from %s", url) + logger.info("Getting chart from %s as user %s", url, user.username) dataframe = get_chart_dataframe(url, auth_cookies) except SoftTimeLimitExceeded as ex: raise ReportScheduleDataFrameTimeout() from ex @@ -674,10 +664,16 @@ def __init__(self, task_id: str, model_id: int, scheduled_dttm: datetime): def run(self) -> None: with session_scope(nullpool=True) as session: try: - with override_user(_get_user()): - self.validate(session=session) - if not self._model: - raise ReportScheduleExecuteUnexpectedError() + self.validate(session=session) + if not self._model: + raise ReportScheduleExecuteUnexpectedError() + user = get_executor(self._model) + with override_user(user): + logger.info( + "Running report schedule %s as user %s", + self._execution_id, + user.username, + ) ReportScheduleStateMachine( session, self._execution_id, self._model, self._scheduled_dttm ).run() @@ -695,6 +691,8 @@ def validate( # pylint: disable=arguments-differ self._model_id, self._execution_id, ) - self._model = ReportScheduleDAO.find_by_id(self._model_id, session=session) + self._model = ( + session.query(ReportSchedule).filter_by(id=self._model_id).one_or_none() + ) if not self._model: raise ReportScheduleNotFoundError() diff --git a/superset/reports/types.py b/superset/reports/types.py index d487e3ad2376..7977a2defa9a 100644 --- a/superset/reports/types.py +++ b/superset/reports/types.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from enum import Enum from typing import TypedDict from superset.dashboards.permalink.types import DashboardPermalinkState @@ -21,3 +22,12 @@ class ReportScheduleExtra(TypedDict): dashboard: DashboardPermalinkState + + +class ReportScheduleExecutor(str, Enum): + SELENIUM = "selenium" + CREATOR = "creator" + CREATOR_OWNER = "creator_owner" + MODIFIER = "modifier" + MODIFIER_OWNER = "modifier_owner" + OWNER = "owner" diff --git a/superset/reports/utils.py b/superset/reports/utils.py new file mode 100644 index 000000000000..215fca99887a --- /dev/null +++ b/superset/reports/utils.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from flask_appbuilder.security.sqla.models import User + +from superset import app, security_manager +from superset.reports.commands.exceptions import ReportScheduleUserNotFoundError +from superset.reports.models import ReportSchedule +from superset.reports.types import ReportScheduleExecutor + + +# pylint: disable=too-many-branches +def get_executor(report_schedule: ReportSchedule) -> User: + """ + Extract the user that should be used to execute a report schedule as. + + :param report_schedule: The report to execute + :return: User to execute the report as + """ + user_types = app.config["ALERT_REPORTS_EXECUTE_AS"] + owners = report_schedule.owners + owner_dict = {owner.id: owner for owner in owners} + for user_type in user_types: + if user_type == ReportScheduleExecutor.SELENIUM: + username = app.config["THUMBNAIL_SELENIUM_USER"] + if username and (user := security_manager.find_user(username=username)): + return user + if user_type == ReportScheduleExecutor.CREATOR_OWNER: + if (user := report_schedule.created_by) and ( + owner := owner_dict.get(user.id) + ): + return owner + if user_type == ReportScheduleExecutor.CREATOR: + if user := report_schedule.created_by: + return user + if user_type == ReportScheduleExecutor.MODIFIER_OWNER: + if (user := report_schedule.changed_by) and ( + owner := owner_dict.get(user.id) + ): + return owner + if user_type == ReportScheduleExecutor.MODIFIER: + if user := report_schedule.changed_by: + return user + if user_type == ReportScheduleExecutor.OWNER: + owners = report_schedule.owners + if len(owners) == 1: + return owners[0] + if len(owners) > 1: + if modifier := report_schedule.changed_by: + if modifier and (user := owner_dict.get(modifier.id)): + return user + if creator := report_schedule.created_by: + if creator and (user := owner_dict.get(creator.id)): + return user + return owners[0] + + raise ReportScheduleUserNotFoundError() diff --git a/superset/utils/machine_auth.py b/superset/utils/machine_auth.py index 01347f90f1c2..d58f739f7713 100644 --- a/superset/utils/machine_auth.py +++ b/superset/utils/machine_auth.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import importlib import logging from typing import Callable, Dict, TYPE_CHECKING @@ -34,7 +36,7 @@ class MachineAuthProvider: def __init__( - self, auth_webdriver_func_override: Callable[[WebDriver, "User"], WebDriver] + self, auth_webdriver_func_override: Callable[[WebDriver, User], WebDriver] ): # This is here in order to allow for the authenticate_webdriver func to be # overridden via config, as opposed to the entire provider implementation @@ -43,7 +45,7 @@ def __init__( def authenticate_webdriver( self, driver: WebDriver, - user: "User", + user: User, ) -> WebDriver: """ Default AuthDriverFuncType type that sets a session cookie flask-login style @@ -69,7 +71,7 @@ def authenticate_webdriver( return driver @staticmethod - def get_auth_cookies(user: "User") -> Dict[str, str]: + def get_auth_cookies(user: User) -> Dict[str, str]: # Login with the user specified to get the reports with current_app.test_request_context("/login"): login_user(user) diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index 7c4d372414fc..c81ede6dc195 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging from io import BytesIO from typing import Optional, TYPE_CHECKING, Union @@ -68,7 +70,7 @@ def cache_key( return md5_sha_from_dict(args) def get_screenshot( - self, user: "User", window_size: Optional[WindowSize] = None + self, user: User, window_size: Optional[WindowSize] = None ) -> Optional[bytes]: driver = self.driver(window_size) self.screenshot = driver.get_screenshot(self.url, self.element, user) @@ -76,8 +78,8 @@ def get_screenshot( def get( self, - user: "User" = None, - cache: "Cache" = None, + user: User = None, + cache: Cache = None, thumb_size: Optional[WindowSize] = None, ) -> Optional[BytesIO]: """ @@ -103,7 +105,7 @@ def get( def get_from_cache( self, - cache: "Cache", + cache: Cache, window_size: Optional[WindowSize] = None, thumb_size: Optional[WindowSize] = None, ) -> Optional[BytesIO]: @@ -111,7 +113,7 @@ def get_from_cache( return self.get_from_cache_key(cache, cache_key) @staticmethod - def get_from_cache_key(cache: "Cache", cache_key: str) -> Optional[BytesIO]: + def get_from_cache_key(cache: Cache, cache_key: str) -> Optional[BytesIO]: logger.info("Attempting to get from cache: %s", cache_key) payload = cache.get(cache_key) if payload: @@ -121,10 +123,10 @@ def get_from_cache_key(cache: "Cache", cache_key: str) -> Optional[BytesIO]: def compute_and_cache( # pylint: disable=too-many-arguments self, - user: "User" = None, + user: User = None, window_size: Optional[WindowSize] = None, thumb_size: Optional[WindowSize] = None, - cache: "Cache" = None, + cache: Cache = None, force: bool = True, ) -> Optional[bytes]: """ diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py index 5fdc0a213116..93e8957f2176 100644 --- a/superset/utils/webdriver.py +++ b/superset/utils/webdriver.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging from enum import Enum from time import sleep @@ -83,7 +85,7 @@ def create(self) -> WebDriver: return driver_class(**kwargs) - def auth(self, user: "User") -> WebDriver: + def auth(self, user: User) -> WebDriver: driver = self.create() return machine_auth_provider_factory.instance.authenticate_webdriver( driver, user @@ -104,7 +106,7 @@ def destroy(driver: WebDriver, tries: int = 2) -> None: pass def get_screenshot( - self, url: str, element_name: str, user: "User" + self, url: str, element_name: str, user: User ) -> Optional[bytes]: driver = self.auth(user) driver.set_window_size(*self._window) @@ -134,7 +136,11 @@ def get_screenshot( ] logger.debug("Wait %i seconds for chart animation", selenium_animation_wait) sleep(selenium_animation_wait) - logger.info("Taking a PNG screenshot of url %s", url) + logger.info( + "Taking a PNG screenshot of url %s as user %s", + url, + user.username, + ) img = element.screenshot_as_png except TimeoutException: logger.warning("Selenium timed out requesting url %s", url, exc_info=True) diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index b3ef86c5e32c..2dd1c461cafc 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -23,6 +23,7 @@ import pytest from flask import current_app +from flask_appbuilder.security.sqla.models import User from flask_sqlalchemy import BaseQuery from freezegun import freeze_time from sqlalchemy.sql import func @@ -55,6 +56,7 @@ ReportScheduleValidatorType, ReportState, ) +from superset.reports.types import ReportScheduleExecutor from superset.utils.database import get_example_database from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, @@ -68,7 +70,7 @@ cleanup_report_schedule, create_report_notification, CSV_FILE, - OWNER_EMAIL, + DEFAULT_OWNER_EMAIL, SCREENSHOT_FILE, TEST_ID, ) @@ -152,6 +154,19 @@ def create_report_email_chart(): cleanup_report_schedule(report_schedule) +@pytest.fixture() +def create_report_email_chart_alpha_owner(get_user): + with app.app_context(): + owners = [get_user("alpha")] + chart = db.session.query(Slice).first() + report_schedule = create_report_notification( + email_target="target@email.com", chart=chart, owners=owners + ) + yield report_schedule + + cleanup_report_schedule(report_schedule) + + @pytest.fixture() def create_report_email_chart_force_screenshot(): with app.app_context(): @@ -645,6 +660,65 @@ def test_email_chart_report_schedule( assert_log(ReportState.SUCCESS) +@pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", "create_report_email_chart_alpha_owner" +) +@patch("superset.reports.notifications.email.send_email_smtp") +@patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") +def test_email_chart_report_schedule_alpha_owner( + screenshot_mock, + email_mock, + create_report_email_chart_alpha_owner, +): + """ + ExecuteReport Command: Test chart email report schedule with screenshot + executed as the chart owner + """ + config_key = "ALERT_REPORTS_EXECUTE_AS" + original_config_value = app.config[config_key] + app.config[config_key] = [ReportScheduleExecutor.OWNER] + + # setup screenshot mock + username = "" + + def _screenshot_side_effect(user: User) -> Optional[bytes]: + nonlocal username + username = user.username + + return SCREENSHOT_FILE + + screenshot_mock.side_effect = _screenshot_side_effect + + with freeze_time("2020-01-01T00:00:00Z"): + AsyncExecuteReportScheduleCommand( + TEST_ID, create_report_email_chart_alpha_owner.id, datetime.utcnow() + ).run() + + notification_targets = get_target_from_report_schedule( + create_report_email_chart_alpha_owner + ) + # assert that the screenshot is executed as the chart owner + assert username == "alpha" + + # assert that the link sent is correct + assert ( + 'Explore in Superset' + in email_mock.call_args[0][2] + ) + # Assert the email smtp address + assert email_mock.call_args[0][0] == notification_targets[0] + # Assert the email inline screenshot + smtp_images = email_mock.call_args[1]["images"] + assert smtp_images[list(smtp_images.keys())[0]] == SCREENSHOT_FILE + # Assert logs are correct + assert_log(ReportState.SUCCESS) + + app.config[config_key] = original_config_value + + @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_email_chart_force_screenshot", @@ -1465,7 +1539,7 @@ def test_soft_timeout_alert(email_mock, create_alert_email_chart): notification_targets = get_target_from_report_schedule(create_alert_email_chart) # Assert the email smtp address, asserts a notification was sent with the error - assert email_mock.call_args[0][0] == OWNER_EMAIL + assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL assert_log( ReportState.ERROR, error_message="A timeout occurred while executing the query." @@ -1494,7 +1568,7 @@ def test_soft_timeout_screenshot(screenshot_mock, email_mock, create_alert_email ).run() # Assert the email smtp address, asserts a notification was sent with the error - assert email_mock.call_args[0][0] == OWNER_EMAIL + assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL assert_log( ReportState.ERROR, error_message="A timeout occurred while taking a screenshot." @@ -1534,7 +1608,7 @@ def test_soft_timeout_csv( create_report_email_chart_with_csv ) # Assert the email smtp address, asserts a notification was sent with the error - assert email_mock.call_args[0][0] == OWNER_EMAIL + assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL assert_log( ReportState.ERROR, @@ -1574,7 +1648,7 @@ def test_generate_no_csv( create_report_email_chart_with_csv ) # Assert the email smtp address, asserts a notification was sent with the error - assert email_mock.call_args[0][0] == OWNER_EMAIL + assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL assert_log( ReportState.ERROR, @@ -1603,7 +1677,7 @@ def test_fail_screenshot(screenshot_mock, email_mock, create_report_email_chart) notification_targets = get_target_from_report_schedule(create_report_email_chart) # Assert the email smtp address, asserts a notification was sent with the error - assert email_mock.call_args[0][0] == OWNER_EMAIL + assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL assert_log( ReportState.ERROR, error_message="Failed taking a screenshot Unexpected error" @@ -1636,7 +1710,7 @@ def test_fail_csv( get_target_from_report_schedule(create_report_email_chart_with_csv) # Assert the email smtp address, asserts a notification was sent with the error - assert email_mock.call_args[0][0] == OWNER_EMAIL + assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL assert_log( ReportState.ERROR, error_message="Failed generating csv " @@ -1685,7 +1759,7 @@ def test_invalid_sql_alert(email_mock, create_invalid_sql_alert_email_chart): create_invalid_sql_alert_email_chart ) # Assert the email smtp address, asserts a notification was sent with the error - assert email_mock.call_args[0][0] == OWNER_EMAIL + assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL @pytest.mark.usefixtures("create_invalid_sql_alert_email_chart") @@ -1706,7 +1780,7 @@ def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart): create_invalid_sql_alert_email_chart ) # Assert the email smtp address, asserts a notification was sent with the error - assert email_mock.call_args[0][0] == OWNER_EMAIL + assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL assert ( get_notification_error_sent_count(create_invalid_sql_alert_email_chart) == 1 ) diff --git a/tests/integration_tests/reports/scheduler_tests.py b/tests/integration_tests/reports/scheduler_tests.py index 9f3d0d55d886..76e4c4006e13 100644 --- a/tests/integration_tests/reports/scheduler_tests.py +++ b/tests/integration_tests/reports/scheduler_tests.py @@ -14,9 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from random import randint +from typing import List from unittest.mock import patch +import pytest +from flask_appbuilder.security.sqla.models import User from freezegun import freeze_time from freezegun.api import FakeDatetime # type: ignore @@ -27,8 +31,14 @@ from tests.integration_tests.test_app import app +@pytest.fixture +def owners(get_user) -> List[User]: + return [get_user("admin")] + + +@pytest.mark.usefixtures("owners") @patch("superset.tasks.scheduler.execute.apply_async") -def test_scheduler_celery_timeout_ny(execute_mock): +def test_scheduler_celery_timeout_ny(execute_mock, owners): """ Reports scheduler: Test scheduler setting celery soft and hard timeout """ @@ -39,6 +49,7 @@ def test_scheduler_celery_timeout_ny(execute_mock): name="report", crontab="0 4 * * *", timezone="America/New_York", + owners=owners, ) with freeze_time("2020-01-01T09:00:00Z"): @@ -49,8 +60,9 @@ def test_scheduler_celery_timeout_ny(execute_mock): db.session.commit() +@pytest.mark.usefixtures("owners") @patch("superset.tasks.scheduler.execute.apply_async") -def test_scheduler_celery_no_timeout_ny(execute_mock): +def test_scheduler_celery_no_timeout_ny(execute_mock, owners): """ Reports scheduler: Test scheduler setting celery soft and hard timeout """ @@ -61,6 +73,7 @@ def test_scheduler_celery_no_timeout_ny(execute_mock): name="report", crontab="0 4 * * *", timezone="America/New_York", + owners=owners, ) with freeze_time("2020-01-01T09:00:00Z"): @@ -71,8 +84,9 @@ def test_scheduler_celery_no_timeout_ny(execute_mock): app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True +@pytest.mark.usefixtures("owners") @patch("superset.tasks.scheduler.execute.apply_async") -def test_scheduler_celery_timeout_utc(execute_mock): +def test_scheduler_celery_timeout_utc(execute_mock, owners): """ Reports scheduler: Test scheduler setting celery soft and hard timeout """ @@ -83,6 +97,7 @@ def test_scheduler_celery_timeout_utc(execute_mock): name="report", crontab="0 9 * * *", timezone="UTC", + owners=owners, ) with freeze_time("2020-01-01T09:00:00Z"): @@ -93,8 +108,9 @@ def test_scheduler_celery_timeout_utc(execute_mock): db.session.commit() +@pytest.mark.usefixtures("owners") @patch("superset.tasks.scheduler.execute.apply_async") -def test_scheduler_celery_no_timeout_utc(execute_mock): +def test_scheduler_celery_no_timeout_utc(execute_mock, owners): """ Reports scheduler: Test scheduler setting celery soft and hard timeout """ @@ -105,6 +121,7 @@ def test_scheduler_celery_no_timeout_utc(execute_mock): name="report", crontab="0 9 * * *", timezone="UTC", + owners=owners, ) with freeze_time("2020-01-01T09:00:00Z"): @@ -115,9 +132,10 @@ def test_scheduler_celery_no_timeout_utc(execute_mock): app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] = True +@pytest.mark.usefixtures("owners") @patch("superset.tasks.scheduler.is_feature_enabled") @patch("superset.tasks.scheduler.execute.apply_async") -def test_scheduler_feature_flag_off(execute_mock, is_feature_enabled): +def test_scheduler_feature_flag_off(execute_mock, is_feature_enabled, owners): """ Reports scheduler: Test scheduler with feature flag off """ @@ -128,6 +146,7 @@ def test_scheduler_feature_flag_off(execute_mock, is_feature_enabled): name="report", crontab="0 9 * * *", timezone="UTC", + owners=owners, ) with freeze_time("2020-01-01T09:00:00Z"): @@ -137,10 +156,11 @@ def test_scheduler_feature_flag_off(execute_mock, is_feature_enabled): db.session.commit() +@pytest.mark.usefixtures("owners") @patch("superset.reports.commands.execute.AsyncExecuteReportScheduleCommand.__init__") @patch("superset.reports.commands.execute.AsyncExecuteReportScheduleCommand.run") @patch("superset.tasks.scheduler.execute.update_state") -def test_execute_task(update_state_mock, command_mock, init_mock): +def test_execute_task(update_state_mock, command_mock, init_mock, owners): from superset.reports.commands.exceptions import ReportScheduleUnexpectedError with app.app_context(): @@ -149,6 +169,7 @@ def test_execute_task(update_state_mock, command_mock, init_mock): name=f"report-{randint(0,1000)}", crontab="0 4 * * *", timezone="America/New_York", + owners=owners, ) init_mock.return_value = None command_mock.side_effect = ReportScheduleUnexpectedError("Unexpected error") diff --git a/tests/integration_tests/reports/utils.py b/tests/integration_tests/reports/utils.py index 4b204f0d9004..3801beb1a328 100644 --- a/tests/integration_tests/reports/utils.py +++ b/tests/integration_tests/reports/utils.py @@ -35,26 +35,27 @@ ReportScheduleType, ReportState, ) +from superset.utils.core import override_user from tests.integration_tests.test_app import app from tests.integration_tests.utils import read_fixture TEST_ID = str(uuid4()) CSV_FILE = read_fixture("trends.csv") SCREENSHOT_FILE = read_fixture("sample.png") -OWNER_EMAIL = "admin@fab.org" +DEFAULT_OWNER_EMAIL = "admin@fab.org" def insert_report_schedule( type: str, name: str, crontab: str, + owners: List[User], timezone: Optional[str] = None, sql: Optional[str] = None, description: Optional[str] = None, chart: Optional[Slice] = None, dashboard: Optional[Dashboard] = None, database: Optional[Database] = None, - owners: Optional[List[User]] = None, validator_type: Optional[str] = None, validator_config_json: Optional[str] = None, log_retention: Optional[int] = None, @@ -70,28 +71,30 @@ def insert_report_schedule( recipients = recipients or [] logs = logs or [] last_state = last_state or ReportState.NOOP - report_schedule = ReportSchedule( - type=type, - name=name, - crontab=crontab, - timezone=timezone, - sql=sql, - description=description, - chart=chart, - dashboard=dashboard, - database=database, - owners=owners, - validator_type=validator_type, - validator_config_json=validator_config_json, - log_retention=log_retention, - grace_period=grace_period, - recipients=recipients, - logs=logs, - last_state=last_state, - report_format=report_format, - extra=extra, - force_screenshot=force_screenshot, - ) + + with override_user(owners[0]): + report_schedule = ReportSchedule( + type=type, + name=name, + crontab=crontab, + timezone=timezone, + sql=sql, + description=description, + chart=chart, + dashboard=dashboard, + database=database, + owners=owners, + validator_type=validator_type, + validator_config_json=validator_config_json, + log_retention=log_retention, + grace_period=grace_period, + recipients=recipients, + logs=logs, + last_state=last_state, + report_format=report_format, + extra=extra, + force_screenshot=force_screenshot, + ) db.session.add(report_schedule) db.session.commit() return report_schedule @@ -112,12 +115,16 @@ def create_report_notification( name: Optional[str] = None, extra: Optional[Dict[str, Any]] = None, force_screenshot: bool = False, + owners: Optional[List[User]] = None, ) -> ReportSchedule: - owner = ( - db.session.query(security_manager.user_model) - .filter_by(email=OWNER_EMAIL) - .one_or_none() - ) + if not owners: + owners = [ + ( + db.session.query(security_manager.user_model) + .filter_by(email=DEFAULT_OWNER_EMAIL) + .one_or_none() + ) + ] if slack_channel: recipient = ReportRecipients( @@ -147,7 +154,7 @@ def create_report_notification( dashboard=dashboard, database=database, recipients=[recipient], - owners=[owner], + owners=owners, validator_type=validator_type, validator_config_json=validator_config_json, grace_period=grace_period, diff --git a/tests/unit_tests/reports/__init__.py b/tests/unit_tests/reports/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/reports/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/reports/test_utils.py b/tests/unit_tests/reports/test_utils.py new file mode 100644 index 000000000000..8b4bf93e718a --- /dev/null +++ b/tests/unit_tests/reports/test_utils.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union +from unittest.mock import patch + +import pytest +from flask_appbuilder.security.sqla.models import User + +from superset.reports.types import ReportScheduleExecutor + +SELENIUM_USER_ID = 1234 + + +def _get_users( + params: Optional[Union[int, List[int]]] +) -> Optional[Union[User, List[User]]]: + if params is None: + return None + if isinstance(params, int): + return User(id=params) + return [User(id=user) for user in params] + + +@dataclass +class ReportConfig: + owners: List[int] + creator: Optional[int] = None + modifier: Optional[int] = None + + +@pytest.mark.parametrize( + "config,report_config,expected_user", + [ + ( + [ReportScheduleExecutor.SELENIUM], + ReportConfig( + owners=[1, 2], + creator=3, + modifier=4, + ), + SELENIUM_USER_ID, + ), + ( + [ + ReportScheduleExecutor.CREATOR, + ReportScheduleExecutor.CREATOR_OWNER, + ReportScheduleExecutor.OWNER, + ReportScheduleExecutor.MODIFIER, + ReportScheduleExecutor.MODIFIER_OWNER, + ReportScheduleExecutor.SELENIUM, + ], + ReportConfig(owners=[]), + SELENIUM_USER_ID, + ), + ( + [ + ReportScheduleExecutor.CREATOR, + ReportScheduleExecutor.CREATOR_OWNER, + ReportScheduleExecutor.OWNER, + ReportScheduleExecutor.MODIFIER, + ReportScheduleExecutor.MODIFIER_OWNER, + ReportScheduleExecutor.SELENIUM, + ], + ReportConfig(owners=[], modifier=1), + 1, + ), + ( + [ + ReportScheduleExecutor.CREATOR, + ReportScheduleExecutor.CREATOR_OWNER, + ReportScheduleExecutor.OWNER, + ReportScheduleExecutor.MODIFIER, + ReportScheduleExecutor.MODIFIER_OWNER, + ReportScheduleExecutor.SELENIUM, + ], + ReportConfig(owners=[2], modifier=1), + 2, + ), + ( + [ + ReportScheduleExecutor.CREATOR, + ReportScheduleExecutor.CREATOR_OWNER, + ReportScheduleExecutor.OWNER, + ReportScheduleExecutor.MODIFIER, + ReportScheduleExecutor.MODIFIER_OWNER, + ReportScheduleExecutor.SELENIUM, + ], + ReportConfig(owners=[2], creator=3, modifier=1), + 3, + ), + ( + [ + ReportScheduleExecutor.OWNER, + ], + ReportConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=3, modifier=4), + 4, + ), + ( + [ + ReportScheduleExecutor.OWNER, + ], + ReportConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=3, modifier=8), + 3, + ), + ( + [ + ReportScheduleExecutor.MODIFIER_OWNER, + ], + ReportConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=8, modifier=9), + None, + ), + ( + [ + ReportScheduleExecutor.MODIFIER_OWNER, + ], + ReportConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=8, modifier=4), + 4, + ), + ( + [ + ReportScheduleExecutor.CREATOR_OWNER, + ], + ReportConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=8, modifier=9), + None, + ), + ( + [ + ReportScheduleExecutor.CREATOR_OWNER, + ], + ReportConfig(owners=[1, 2, 3, 4, 5, 6, 7], creator=4, modifier=8), + 4, + ), + ], +) +def test_get_executor( + config: List[ReportScheduleExecutor], + report_config: ReportConfig, + expected_user: Optional[int], +) -> None: + from superset import app, security_manager + from superset.reports.commands.exceptions import ReportScheduleUserNotFoundError + from superset.reports.models import ReportSchedule + from superset.reports.utils import get_executor + + selenium_user = User(id=SELENIUM_USER_ID) + + with patch.dict(app.config, {"ALERT_REPORTS_EXECUTE_AS": config}), patch.object( + security_manager, "find_user", return_value=selenium_user + ): + report_schedule = ReportSchedule( + id=1, + type="report", + name="test_report", + owners=_get_users(report_config.owners), + created_by=_get_users(report_config.creator), + changed_by=_get_users(report_config.modifier), + ) + if expected_user is None: + with pytest.raises(ReportScheduleUserNotFoundError): + get_executor(report_schedule) + else: + assert get_executor(report_schedule).id == expected_user