From 913395d8db78f7c3ec836e11f3027996b395ce69 Mon Sep 17 00:00:00 2001
From: Jarek Potiuk
Date: Wed, 24 Jul 2024 13:31:09 +0200
Subject: [PATCH] Make standalone dag file processor work in DB isolation mode (#40916)

There were a few missing DB operations in DagFileProcessor that prevented it
from running in DB isolation mode. Those have been refactored and exposed as
internal API calls.

A bug was fixed in scheduler_job_runner that caused next_event to be used
before it had been declared (which occurred when the standalone dag processor
was used together with db isolation mode).

The DB retry will now correctly use the logger when it is used as a decorator
on a class method.

The "main" code that removes the DB connection from the configuration (mostly
in case of Breeze) when untrusted components are used has been improved to
handle the case where the DAG file processor forks parsing subprocesses.

The tmux configuration has been improved so that both non-isolation and
isolation modes distribute panels better.

Simplified InternalApiConfig - "main" now directly sets DB or internal-API use
in db_isolation mode depending on the component.
---
 .github/workflows/basic-tests.yml | 6 +
 .github/workflows/ci.yml | 1 +
 airflow/__main__.py | 28 +-
 .../endpoints/rpc_api_endpoint.py | 16 +-
 airflow/api_internal/internal_api_call.py | 63 ++--
 airflow/cli/commands/dag_processor_command.py | 4 +
 airflow/cli/commands/internal_api_command.py | 7 +-
 airflow/dag_processing/manager.py | 68 ++--
 airflow/dag_processing/processor.py | 290 ++++++++++--------
 airflow/models/dagcode.py | 2 +
 airflow/models/serialized_dag.py | 1 +
 airflow/settings.py | 12 +-
 .../task/task_runner/standard_task_runner.py | 6 +-
 airflow/utils/retries.py | 8 +-
 airflow/www/app.py | 3 +-
 scripts/in_container/bin/run_tmux | 12 +-
 tests/api_internal/test_internal_api_call.py | 35 ++-
 .../cli/commands/test_internal_api_command.py | 8 +-
 tests/conftest.py | 3 +-
 tests/core/test_settings.py | 29 +-
 tests/core/test_sqlalchemy_config.py | 5 +
 tests/dag_processing/test_processor.py | 10 +-
 .../test_dag_import_error_listener.py | 2 +-
 tests/www/views/test_views_home.py | 84 +++--
 24 files changed, 427 insertions(+), 276 deletions(-)

diff --git a/.github/workflows/basic-tests.yml b/.github/workflows/basic-tests.yml
index db84bae38e2e2..9828a14993581 100644
--- a/.github/workflows/basic-tests.yml
+++ b/.github/workflows/basic-tests.yml
@@ -52,6 +52,12 @@ on: # yamllint disable-line rule:truthy
         description: "Whether to run only latest version checks (true/false)"
         required: true
         type: string
+      enable-aip-44:
+        description: "Whether to enable AIP-44 (true/false)"
+        required: true
+        type: string
+env:
+  AIRFLOW_ENABLE_AIP_44: "${{ inputs.enable-aip-44 }}"
 jobs:
   run-breeze-tests:
     timeout-minutes: 10
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e423c388f9ec8..babb1da265121 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -173,6 +173,7 @@ jobs:
       skip-pre-commits: ${{needs.build-info.outputs.skip-pre-commits}}
       canary-run: ${{needs.build-info.outputs.canary-run}}
       latest-versions-only: ${{needs.build-info.outputs.latest-versions-only}}
+      enable-aip-44: "false"
 
   build-ci-images:
     name: >
diff --git a/airflow/__main__.py b/airflow/__main__.py
index 82a866c42a478..8fbcd7e777640 100644
--- a/airflow/__main__.py
+++ b/airflow/__main__.py
@@ -22,6 +22,7 @@
 from __future__ import annotations
 
 import os
+from argparse import Namespace
 
 import argcomplete
 
@@ -35,7 +36,8 @@
 # any possible import cycles with settings downstream.
from airflow import configuration from airflow.cli import cli_parser -from airflow.configuration import write_webserver_configuration_if_needed +from airflow.configuration import AirflowConfigParser, write_webserver_configuration_if_needed +from airflow.exceptions import AirflowException def main(): @@ -55,23 +57,33 @@ def main(): conf = write_default_airflow_configuration_if_needed() if args.subcommand in ["webserver", "internal-api", "worker"]: write_webserver_configuration_if_needed(conf) + configure_internal_api(args, conf) + + args.func(args) + + +def configure_internal_api(args: Namespace, conf: AirflowConfigParser): if conf.getboolean("core", "database_access_isolation", fallback=False): if args.subcommand in ["worker", "dag-processor", "triggerer", "run"]: # Untrusted components if "AIRFLOW__DATABASE__SQL_ALCHEMY_CONN" in os.environ: # make sure that the DB is not available for the components that should not access it - del os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] + os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = "none://" conf.set("database", "sql_alchemy_conn", "none://") - from airflow.settings import force_traceback_session_for_untrusted_components + from airflow.api_internal.internal_api_call import InternalApiConfig - force_traceback_session_for_untrusted_components() + InternalApiConfig.set_use_internal_api(args.subcommand) else: - # Trusted components + # Trusted components (this setting is mostly for Breeze where db_isolation and DB are both set + db_connection_url = conf.get("database", "sql_alchemy_conn") + if not db_connection_url or db_connection_url == "none://": + raise AirflowException( + f"Running trusted components {args.subcommand} in db isolation mode " + f"requires connection to be configured via database/sql_alchemy_conn." 
+ ) from airflow.api_internal.internal_api_call import InternalApiConfig - InternalApiConfig.force_database_direct_access("Running " + args.subcommand + " command") - - args.func(args) + InternalApiConfig.set_use_database_access(args.subcommand) if __name__ == "__main__": diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index fcac1925d90fc..7e655e5b4ecfe 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -36,6 +36,7 @@ from airflow.api_connexion.exceptions import PermissionDenied from airflow.configuration import conf from airflow.jobs.job import Job, most_recent_job +from airflow.models.dagcode import DagCode from airflow.models.taskinstance import _record_task_map_for_downstreams from airflow.models.xcom_arg import _get_task_map_length from airflow.sensors.base import _orig_start_date @@ -89,13 +90,21 @@ def initialize_method_map() -> dict[str, Callable]: _add_log, _xcom_pull, _record_task_map_for_downstreams, - DagFileProcessor.update_import_errors, - DagFileProcessor.manage_slas, - DagFileProcessorManager.deactivate_stale_dags, + DagCode.remove_deleted_code, DagModel.deactivate_deleted_dags, DagModel.get_paused_dag_ids, DagModel.get_current, + DagFileProcessor._execute_task_callbacks, + DagFileProcessor.execute_callbacks, + DagFileProcessor.execute_callbacks_without_dag, + DagFileProcessor.manage_slas, + DagFileProcessor.save_dag_to_db, + DagFileProcessor.update_import_errors, + DagFileProcessor._validate_task_pools_and_update_dag_warnings, + DagFileProcessorManager._fetch_callbacks, + DagFileProcessorManager._get_priority_filelocs, DagFileProcessorManager.clear_nonexistent_import_errors, + DagFileProcessorManager.deactivate_stale_dags, DagWarning.purge_inactive_dag_warnings, DatasetManager.register_dataset_change, FileTaskHandler._render_filename_db_access, @@ -124,6 +133,7 @@ def initialize_method_map() -> dict[str, Callable]: DagRun._get_log_template, RenderedTaskInstanceFields._update_runtime_evaluated_template_fields, SerializedDagModel.get_serialized_dag, + SerializedDagModel.remove_deleted_dags, SkipMixin._skip, SkipMixin._skip_all_except, TaskInstance._check_and_change_state_before_execution, diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py index 07bd0ec5fedd9..7962b5b590eed 100644 --- a/airflow/api_internal/internal_api_call.py +++ b/airflow/api_internal/internal_api_call.py @@ -30,7 +30,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, AirflowException -from airflow.settings import _ENABLE_AIP_44 +from airflow.settings import _ENABLE_AIP_44, force_traceback_session_for_untrusted_components from airflow.typing_compat import ParamSpec from airflow.utils.jwt_signer import JWTSigner @@ -43,67 +43,52 @@ class InternalApiConfig: """Stores and caches configuration for Internal API.""" - _initialized = False _use_internal_api = False _internal_api_endpoint = "" @staticmethod - def force_database_direct_access(message: str): + def set_use_database_access(component: str): """ Block current component from using Internal API. All methods decorated with internal_api_call will always be executed locally.` This mode is needed for "trusted" components like Scheduler, Webserver, Internal Api server """ - InternalApiConfig._initialized = True InternalApiConfig._use_internal_api = False - if _ENABLE_AIP_44: - logger.info("Forcing database direct access. 
%s", message) + if not _ENABLE_AIP_44: + raise RuntimeError("The AIP_44 is not enabled so you cannot use it. ") + logger.info( + "DB isolation mode. But this is a trusted component and DB connection is set. " + "Using database direct access when running %s.", + component, + ) @staticmethod - def force_api_access(api_endpoint: str): - """ - Force using Internal API with provided endpoint. - - All methods decorated with internal_api_call will always be executed remote/via API. - This mode is needed for remote setups/remote executor. - """ - InternalApiConfig._initialized = True + def set_use_internal_api(component: str): + if not _ENABLE_AIP_44: + raise RuntimeError("The AIP_44 is not enabled so you cannot use it. ") + internal_api_url = conf.get("core", "internal_api_url") + url_conf = urlparse(internal_api_url) + api_path = url_conf.path + if api_path in ["", "/"]: + # Add the default path if not given in the configuration + api_path = "/internal_api/v1/rpcapi" + if url_conf.scheme not in ["http", "https"]: + raise AirflowConfigException("[core]internal_api_url must start with http:// or https://") + internal_api_endpoint = f"{url_conf.scheme}://{url_conf.netloc}{api_path}" InternalApiConfig._use_internal_api = True - InternalApiConfig._internal_api_endpoint = api_endpoint + InternalApiConfig._internal_api_endpoint = internal_api_endpoint + logger.info("DB isolation mode. Using internal_api when running %s.", component) + force_traceback_session_for_untrusted_components() @staticmethod def get_use_internal_api(): - if not InternalApiConfig._initialized: - InternalApiConfig._init_values() return InternalApiConfig._use_internal_api @staticmethod def get_internal_api_endpoint(): - if not InternalApiConfig._initialized: - InternalApiConfig._init_values() return InternalApiConfig._internal_api_endpoint - @staticmethod - def _init_values(): - use_internal_api = conf.getboolean("core", "database_access_isolation", fallback=False) - if use_internal_api and not _ENABLE_AIP_44: - raise RuntimeError("The AIP_44 is not enabled so you cannot use it.") - internal_api_endpoint = "" - if use_internal_api: - url_conf = urlparse(conf.get("core", "internal_api_url")) - api_path = url_conf.path - if api_path in ["", "/"]: - # Add the default path if not given in the configuration - api_path = "/internal_api/v1/rpcapi" - if url_conf.scheme not in ["http", "https"]: - raise AirflowConfigException("[core]internal_api_url must start with http:// or https://") - internal_api_endpoint = f"{url_conf.scheme}://{url_conf.netloc}{api_path}" - - InternalApiConfig._initialized = True - InternalApiConfig._use_internal_api = use_internal_api - InternalApiConfig._internal_api_endpoint = internal_api_endpoint - def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]: """ diff --git a/airflow/cli/commands/dag_processor_command.py b/airflow/cli/commands/dag_processor_command.py index 14deee6770284..8ec173ba5202e 100644 --- a/airflow/cli/commands/dag_processor_command.py +++ b/airflow/cli/commands/dag_processor_command.py @@ -22,6 +22,7 @@ from datetime import timedelta from typing import Any +from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.cli.commands.daemon_utils import run_command_with_daemon_option from airflow.configuration import conf from airflow.dag_processing.manager import DagFileProcessorManager, reload_configuration_for_dag_processing @@ -37,6 +38,9 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner: """Create DagFileProcessorProcess instance.""" 
processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout") processor_timeout = timedelta(seconds=processor_timeout_seconds) + if InternalApiConfig.get_use_internal_api(): + from airflow.models.renderedtifields import RenderedTaskInstanceFields # noqa: F401 + from airflow.models.trigger import Trigger # noqa: F401 return DagProcessorJobRunner( job=Job(), processor=DagFileProcessorManager( diff --git a/airflow/cli/commands/internal_api_command.py b/airflow/cli/commands/internal_api_command.py index 4d377d56e0972..d1ab8eea86787 100644 --- a/airflow/cli/commands/internal_api_command.py +++ b/airflow/cli/commands/internal_api_command.py @@ -222,7 +222,12 @@ def create_app(config=None, testing=False): if "SQLALCHEMY_ENGINE_OPTIONS" not in flask_app.config: flask_app.config["SQLALCHEMY_ENGINE_OPTIONS"] = settings.prepare_engine_args() - InternalApiConfig.force_database_direct_access("Gunicorn worker initialization") + if conf.getboolean("core", "database_access_isolation", fallback=False): + InternalApiConfig.set_use_database_access("Gunicorn worker initialization") + else: + raise AirflowConfigException( + "The internal-api component should only be run when database_access_isolation is enabled." + ) csrf = CSRFProtect() csrf.init_app(flask_app) diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 57858fc4270ac..c03bc074d0abd 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -618,7 +618,10 @@ def _run_parsing_loop(self): self._processors.pop(processor.file_path) if self.standalone_dag_processor: - self._fetch_callbacks(max_callbacks_per_loop) + for callback in DagFileProcessorManager._fetch_callbacks( + max_callbacks_per_loop, self.standalone_dag_processor, self.get_dag_directory() + ): + self._add_callback_to_queue(callback) self._scan_stale_dags() DagWarning.purge_inactive_dag_warnings() refreshed_dag_dir = self._refresh_dag_dir() @@ -707,30 +710,46 @@ def _run_parsing_loop(self): else: poll_time = 0.0 + @classmethod + @internal_api_call @provide_session - def _fetch_callbacks(self, max_callbacks: int, session: Session = NEW_SESSION): - self._fetch_callbacks_with_retries(max_callbacks, session) + def _fetch_callbacks( + cls, + max_callbacks: int, + standalone_dag_processor: bool, + dag_directory: str, + session: Session = NEW_SESSION, + ) -> list[CallbackRequest]: + return cls._fetch_callbacks_with_retries( + max_callbacks, standalone_dag_processor, dag_directory, session + ) + @classmethod @retry_db_transaction - def _fetch_callbacks_with_retries(self, max_callbacks: int, session: Session): + def _fetch_callbacks_with_retries( + cls, max_callbacks: int, standalone_dag_processor: bool, dag_directory: str, session: Session + ) -> list[CallbackRequest]: """Fetch callbacks from database and add them to the internal queue for execution.""" - self.log.debug("Fetching callbacks from the database.") + cls.logger().debug("Fetching callbacks from the database.") + + callback_queue: list[CallbackRequest] = [] with prohibit_commit(session) as guard: query = select(DbCallbackRequest) - if self.standalone_dag_processor: + if standalone_dag_processor: query = query.where( - DbCallbackRequest.processor_subdir == self.get_dag_directory(), + DbCallbackRequest.processor_subdir == dag_directory, ) query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks) query = with_row_locks(query, of=DbCallbackRequest, session=session, skip_locked=True) callbacks = session.scalars(query) for callback 
in callbacks: try: - self._add_callback_to_queue(callback.get_callback_request()) + callback_queue.append(callback.get_callback_request()) session.delete(callback) except Exception as e: - self.log.warning("Error adding callback for execution: %s, %s", callback, e) + cls.logger().warning("Error adding callback for execution: %s, %s", callback, e) guard.commit() + return callback_queue def _add_callback_to_queue(self, request: CallbackRequest): # requests are sent by dag processors. SLAs exist per-dag, but can be generated once per SLA-enabled @@ -768,23 +787,30 @@ def _add_callback_to_queue(self, request: CallbackRequest): self._add_paths_to_queue([request.full_filepath], True) Stats.incr("dag_processing.other_callback_count") - @provide_session - def _refresh_requested_filelocs(self, session=NEW_SESSION) -> None: + def _refresh_requested_filelocs(self) -> None: """Refresh filepaths from dag dir as requested by users via APIs.""" # Get values from DB table + filelocs = DagFileProcessorManager._get_priority_filelocs() + for fileloc in filelocs: + # Try removing the fileloc if already present + try: + self._file_path_queue.remove(fileloc) + except ValueError: + pass + # enqueue fileloc to the start of the queue. + self._file_path_queue.appendleft(fileloc) + + @classmethod + @internal_api_call + @provide_session + def _get_priority_filelocs(cls, session: Session = NEW_SESSION): + """Get filelocs from DB table.""" + filelocs: list[str] = [] requests = session.scalars(select(DagPriorityParsingRequest)) for request in requests: - # Check if fileloc is in valid file paths. Parsing any - # filepaths can be a security issue. - if request.fileloc in self._file_paths: - # Try removing the fileloc if already present - try: - self._file_path_queue.remove(request.fileloc) - except ValueError: - pass - # enqueue fileloc to the start of the queue. - self._file_path_queue.appendleft(request.fileloc) + filelocs.append(request.fileloc) session.delete(request) + return filelocs def _refresh_dag_dir(self) -> bool: """Refresh file paths from dag dir if we haven't done it for too long.""" diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index ceb0476b8a382..84049de4e2675 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -23,9 +23,9 @@ import threading import time import zipfile -from contextlib import redirect_stderr, redirect_stdout, suppress +from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress from datetime import timedelta -from typing import TYPE_CHECKING, Iterable, Iterator +from typing import TYPE_CHECKING, Generator, Iterable, Iterator from setproctitle import setproctitle from sqlalchemy import delete, event, func, or_, select @@ -68,6 +68,20 @@ from airflow.models.operator import Operator +@contextmanager +def count_queries(session: Session) -> Generator[list[int], None, None]: + # using list allows to read the updated counter from what context manager returns + query_count: list[int] = [0] + + @event.listens_for(session, "do_orm_execute") + def _count_db_queries(orm_execute_state): + nonlocal query_count + query_count[0] += 1 + + yield query_count + event.remove(session, "do_orm_execute", _count_db_queries) + + class DagFileProcessorProcess(LoggingMixin, MultiprocessingStartMethodMixin): """ Runs DAG processing in a separate process using DagFileProcessor. 
@@ -599,109 +613,124 @@ def update_import_errors( import_errors: dict[str, str], processor_subdir: str | None, session: Session = NEW_SESSION, - ) -> None: + ) -> int: """ Update any import errors to be displayed in the UI. For the DAGs in the given DagBag, record any associated import errors and clears errors for files that no longer have them. These are usually displayed through the Airflow UI so that users know that there are issues parsing DAGs. - - :param dagbag: DagBag containing DAGs with import errors + :param file_last_changed: Dictionary containing the last changed time of the files + :param import_errors: Dictionary containing the import errors :param session: session for ORM operations """ files_without_error = file_last_changed - import_errors.keys() - # Clear the errors of the processed files - # that no longer have errors - for dagbag_file in files_without_error: - session.execute( - delete(ParseImportError) - .where(ParseImportError.filename.startswith(dagbag_file)) - .execution_options(synchronize_session="fetch") - ) + with count_queries(session) as query_count: + # Clear the errors of the processed files + # that no longer have errors + for dagbag_file in files_without_error: + session.execute( + delete(ParseImportError) + .where(ParseImportError.filename.startswith(dagbag_file)) + .execution_options(synchronize_session="fetch") + ) - # files that still have errors - existing_import_error_files = [x.filename for x in session.query(ParseImportError.filename).all()] + # files that still have errors + existing_import_error_files = [x.filename for x in session.query(ParseImportError.filename).all()] - # Add the errors of the processed files - for filename, stacktrace in import_errors.items(): - if filename in existing_import_error_files: - session.query(ParseImportError).filter(ParseImportError.filename == filename).update( - {"filename": filename, "timestamp": timezone.utcnow(), "stacktrace": stacktrace}, - synchronize_session="fetch", - ) - # sending notification when an existing dag import error occurs - get_listener_manager().hook.on_existing_dag_import_error( - filename=filename, stacktrace=stacktrace - ) - else: - session.add( - ParseImportError( - filename=filename, - timestamp=timezone.utcnow(), - stacktrace=stacktrace, - processor_subdir=processor_subdir, + # Add the errors of the processed files + for filename, stacktrace in import_errors.items(): + if filename in existing_import_error_files: + session.query(ParseImportError).filter(ParseImportError.filename == filename).update( + {"filename": filename, "timestamp": timezone.utcnow(), "stacktrace": stacktrace}, + synchronize_session="fetch", + ) + # sending notification when an existing dag import error occurs + get_listener_manager().hook.on_existing_dag_import_error( + filename=filename, stacktrace=stacktrace ) + else: + session.add( + ParseImportError( + filename=filename, + timestamp=timezone.utcnow(), + stacktrace=stacktrace, + processor_subdir=processor_subdir, + ) + ) + # sending notification when a new dag import error occurs + get_listener_manager().hook.on_new_dag_import_error( + filename=filename, stacktrace=stacktrace + ) + ( + session.query(DagModel) + .filter(DagModel.fileloc == filename) + .update({"has_import_errors": True}, synchronize_session="fetch") ) - # sending notification when a new dag import error occurs - get_listener_manager().hook.on_new_dag_import_error(filename=filename, stacktrace=stacktrace) - ( - session.query(DagModel) - .filter(DagModel.fileloc == filename) - 
.update({"has_import_errors": True}, synchronize_session="fetch") - ) - session.commit() + session.commit() + session.flush() + return query_count[0] - @provide_session - def _validate_task_pools(self, *, dagbag: DagBag, session: Session = NEW_SESSION): + @classmethod + def update_dag_warnings(cla, *, dagbag: DagBag) -> int: """Validate and raise exception if any task in a dag is using a non-existent pool.""" - from airflow.models.pool import Pool - def check_pools(dag): - task_pools = {task.pool for task in dag.tasks} - nonexistent_pools = task_pools - pools - if nonexistent_pools: - return f"Dag '{dag.dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}" + def get_pools(dag) -> dict[str, set[str]]: + return {dag.dag_id: {task.pool for task in dag.tasks}} - pools = {p.pool for p in Pool.get_pools(session)} + pool_dict: dict[str, set[str]] = {} for dag in dagbag.dags.values(): - message = check_pools(dag) - if message: - self.dag_warnings.add(DagWarning(dag.dag_id, DagWarningType.NONEXISTENT_POOL, message)) + pool_dict.update(get_pools(dag)) for subdag in dag.subdags: - message = check_pools(subdag) - if message: - self.dag_warnings.add(DagWarning(subdag.dag_id, DagWarningType.NONEXISTENT_POOL, message)) - - def update_dag_warnings(self, *, session: Session, dagbag: DagBag) -> None: - """ - Update any import warnings to be displayed in the UI. - - For the DAGs in the given DagBag, record any associated configuration warnings and clear - warnings for files that no longer have them. These are usually displayed through the - Airflow UI so that users know that there are issues parsing DAGs. - - :param session: session for ORM operations - :param dagbag: DagBag containing DAGs with configuration warnings - """ - self._validate_task_pools(dagbag=dagbag) + pool_dict.update(get_pools(subdag)) + dag_ids = {dag.dag_id for dag in dagbag.dags.values()} + return DagFileProcessor._validate_task_pools_and_update_dag_warnings(pool_dict, dag_ids) - stored_warnings = set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dagbag.dags)).all()) + @classmethod + @internal_api_call + @provide_session + def _validate_task_pools_and_update_dag_warnings( + cls, pool_dict: dict[str, set[str]], dag_ids: set[str], session: Session = NEW_SESSION + ) -> int: + with count_queries(session) as query_count: + from airflow.models.pool import Pool + + all_pools = {p.pool for p in Pool.get_pools(session)} + warnings: set[DagWarning] = set() + for dag_id, dag_pools in pool_dict.items(): + nonexistent_pools = dag_pools - all_pools + if nonexistent_pools: + warnings.add( + DagWarning( + dag_id, + DagWarningType.NONEXISTENT_POOL, + f"Dag '{dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}", + ) + ) - for warning_to_delete in stored_warnings - self.dag_warnings: - session.delete(warning_to_delete) + stored_warnings = set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dag_ids)).all()) - for warning_to_add in self.dag_warnings: - session.merge(warning_to_add) + for warning_to_delete in stored_warnings - warnings: + session.delete(warning_to_delete) - session.commit() + for warning_to_add in warnings: + session.merge(warning_to_add) + session.flush() + session.commit() + return query_count[0] + @classmethod + @internal_api_call @provide_session def execute_callbacks( - self, dagbag: DagBag, callback_requests: list[CallbackRequest], session: Session = NEW_SESSION - ) -> None: + cls, + dagbag: DagBag, + callback_requests: list[CallbackRequest], + unit_test_mode: bool, + session: Session = 
NEW_SESSION, + ) -> int: """ Execute on failure callbacks. @@ -710,47 +739,58 @@ def execute_callbacks( :param dagbag: Dag Bag of dags :param callback_requests: failure callbacks to execute :param session: DB session. - """ - for request in callback_requests: - self.log.debug("Processing Callback Request: %s", request) - try: - if isinstance(request, TaskCallbackRequest): - self._execute_task_callbacks(dagbag, request, session=session) - elif isinstance(request, SlaCallbackRequest): - DagFileProcessor.manage_slas(dagbag.dag_folder, request.dag_id, session=session) - elif isinstance(request, DagCallbackRequest): - self._execute_dag_callbacks(dagbag, request, session) - except Exception: - self.log.exception( - "Error executing %s callback for file: %s", - request.__class__.__name__, - request.full_filepath, - ) - session.flush() + :return: number of queries executed + """ + with count_queries(session) as query_count: + for request in callback_requests: + cls.logger().debug("Processing Callback Request: %s", request) + try: + if isinstance(request, TaskCallbackRequest): + cls._execute_task_callbacks(dagbag, request, unit_test_mode, session=session) + elif isinstance(request, SlaCallbackRequest): + DagFileProcessor.manage_slas(dagbag.dag_folder, request.dag_id, session=session) + elif isinstance(request, DagCallbackRequest): + cls._execute_dag_callbacks(dagbag, request, session=session) + except Exception: + cls.logger().exception( + "Error executing %s callback for file: %s", + request.__class__.__name__, + request.full_filepath, + ) + session.flush() + session.commit() + return query_count[0] + @classmethod + @internal_api_call + @provide_session def execute_callbacks_without_dag( - self, callback_requests: list[CallbackRequest], session: Session - ) -> None: + cls, callback_requests: list[CallbackRequest], unit_test_mode: bool, session: Session = NEW_SESSION + ) -> int: """ Execute what callbacks we can as "best effort" when the dag cannot be found/had parse errors. This is so important so that tasks that failed when there is a parse error don't get stuck in queued state. 
""" - for request in callback_requests: - self.log.debug("Processing Callback Request: %s", request) - if isinstance(request, TaskCallbackRequest): - self._execute_task_callbacks(None, request, session) - else: - self.log.info( - "Not executing %s callback for file %s as there was a dag parse error", - request.__class__.__name__, - request.full_filepath, - ) + with count_queries(session) as query_count: + for request in callback_requests: + cls.logger().debug("Processing Callback Request: %s", request) + if isinstance(request, TaskCallbackRequest): + cls._execute_task_callbacks(None, request, unit_test_mode, session) + else: + cls.logger().info( + "Not executing %s callback for file %s as there was a dag parse error", + request.__class__.__name__, + request.full_filepath, + ) + session.flush() + session.commit() + return query_count[0] - @provide_session - def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session): + @classmethod + def _execute_dag_callbacks(cls, dagbag: DagBag, request: DagCallbackRequest, session: Session): dag = dagbag.dags[request.dag_id] callbacks, context = DAG.fetch_callback( dag=dag, @@ -763,7 +803,12 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se if callbacks and context: DAG.execute_callback(callbacks, context, dag.dag_id) - def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session): + @classmethod + @internal_api_call + @provide_session + def _execute_task_callbacks( + cls, dagbag: DagBag | None, request: TaskCallbackRequest, unit_test_mode: bool, session: Session + ): if not request.is_failure_callback: return @@ -796,8 +841,8 @@ def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRe if task: ti.refresh_from_task(task) - ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session) - self.log.info("Executed failure callback for %s in state %s", ti, ti.state) + ti.handle_failure(error=request.msg, test_mode=unit_test_mode, session=session) + cls.logger().info("Executed failure callback for %s in state %s", ti, ti.state) session.flush() @classmethod @@ -809,13 +854,11 @@ def _get_dagbag(cls, file_path: str): Stats.incr("dag_file_refresh_error", tags={"file_path": file_path}) raise - @provide_session def process_file( self, file_path: str, callback_requests: list[CallbackRequest], pickle_dags: bool = False, - session: Session = NEW_SESSION, ) -> tuple[int, int, int]: """ Process a Python file containing Airflow DAGs. 
@@ -833,15 +876,9 @@ def process_file( :param callback_requests: failure callback to execute :param pickle_dags: whether serialize the DAGs found in the file and save them to the db - :param session: Sqlalchemy ORM Session :return: number of dags found, count of import errors, last number of db queries """ self.log.info("Processing file %s for tasks to queue", file_path) - - @event.listens_for(session, "do_orm_execute") - def _count_db_queries(orm_execute_state): - self._last_num_of_db_queries += 1 - try: dagbag = DagFileProcessor._get_dagbag(file_path) except Exception: @@ -853,21 +890,21 @@ def _count_db_queries(orm_execute_state): self.log.info("DAG(s) %s retrieved from %s", ", ".join(map(repr, dagbag.dags)), file_path) else: self.log.warning("No viable dags retrieved from %s", file_path) - DagFileProcessor.update_import_errors( + self._last_num_of_db_queries += DagFileProcessor.update_import_errors( file_last_changed=dagbag.file_last_changed, import_errors=dagbag.import_errors, processor_subdir=self._dag_directory, - session=session, ) if callback_requests: # If there were callback requests for this file but there was a # parse error we still need to progress the state of TIs, # otherwise they might be stuck in queued/running for ever! - self.execute_callbacks_without_dag(callback_requests, session) + self._last_num_of_db_queries += DagFileProcessor.execute_callbacks_without_dag( + callback_requests, self.UNIT_TEST_MODE + ) return 0, len(dagbag.import_errors), self._last_num_of_db_queries - self.execute_callbacks(dagbag, callback_requests, session) - session.commit() + self._last_num_of_db_queries += self.execute_callbacks(dagbag, callback_requests, self.UNIT_TEST_MODE) serialize_errors = DagFileProcessor.save_dag_to_db( dags=dagbag.dags, @@ -879,18 +916,17 @@ def _count_db_queries(orm_execute_state): # Record import errors into the ORM try: - DagFileProcessor.update_import_errors( + self._last_num_of_db_queries += DagFileProcessor.update_import_errors( file_last_changed=dagbag.file_last_changed, import_errors=dagbag.import_errors, processor_subdir=self._dag_directory, - session=session, ) except Exception: self.log.exception("Error logging import errors!") # Record DAG warnings in the metadatabase. 
try: - self.update_dag_warnings(session=session, dagbag=dagbag) + self._last_num_of_db_queries += self.update_dag_warnings(dagbag=dagbag) except Exception: self.log.exception("Error logging DAG warnings.") diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py index a8c45fe492065..321f819999bf6 100644 --- a/airflow/models/dagcode.py +++ b/airflow/models/dagcode.py @@ -26,6 +26,7 @@ from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.sql.expression import literal +from airflow.api_internal.internal_api_call import internal_api_call from airflow.exceptions import AirflowException, DagCodeNotFound from airflow.models.base import Base from airflow.utils import timezone @@ -129,6 +130,7 @@ def bulk_sync_to_db(cls, filelocs: Iterable[str], session: Session = NEW_SESSION session.merge(orm_dag_code) @classmethod + @internal_api_call @provide_session def remove_deleted_code( cls, diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 4655057880a42..646e1c61f1cd6 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -245,6 +245,7 @@ def remove_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> None: session.execute(cls.__table__.delete().where(cls.dag_id == dag_id)) @classmethod + @internal_api_call @provide_session def remove_deleted_dags( cls, diff --git a/airflow/settings.py b/airflow/settings.py index 5e797b268a47e..81eebc9d01f8a 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -319,6 +319,11 @@ def configure_orm(disable_connection_pool=False, pool_class=None): Session = SkipDBTestsSession engine = None return + if conf.get("database", "sql_alchemy_conn") == "none://": + from airflow.api_internal.internal_api_call import InternalApiConfig + + InternalApiConfig.set_use_internal_api("ORM reconfigured in forked process.") + return log.debug("Setting up DB connection pool (PID %s)", os.getpid()) engine_args = prepare_engine_args(disable_connection_pool, pool_class) @@ -344,9 +349,14 @@ def configure_orm(disable_connection_pool=False, pool_class=None): def force_traceback_session_for_untrusted_components(): + log.info("Forcing TracebackSession for untrusted components.") global Session global engine - dispose_orm() + try: + dispose_orm() + except NameError: + # This exception might be thrown in case the ORM has not been initialized yet. 
+ pass Session = TracebackSession engine = None diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py index 5f0606fe5e245..6a4351e17a5b5 100644 --- a/airflow/task/task_runner/standard_task_runner.py +++ b/airflow/task/task_runner/standard_task_runner.py @@ -29,7 +29,6 @@ import psutil from setproctitle import setproctitle -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.models.taskinstance import TaskReturnCode from airflow.settings import CAN_FORK from airflow.stats import Stats @@ -73,6 +72,11 @@ def _start_by_fork(self): self.log.info("Started process %d to run task", pid) return psutil.Process(pid) else: + from airflow.api_internal.internal_api_call import InternalApiConfig + from airflow.configuration import conf + + if conf.getboolean("core", "database_access_isolation", fallback=False): + InternalApiConfig.set_use_internal_api("Forked task runner") # Start a new process group set_new_process_group() diff --git a/airflow/utils/retries.py b/airflow/utils/retries.py index e62e885307bfe..809d176ef6c8e 100644 --- a/airflow/utils/retries.py +++ b/airflow/utils/retries.py @@ -76,8 +76,12 @@ def retry_decorator(func: Callable) -> Callable: @functools.wraps(func) def wrapped_function(*args, **kwargs): - logger = args[0].log if args and hasattr(args[0], "log") else logging.getLogger(func.__module__) - + if args and hasattr(args[0], "logger"): + logger = args[0].logger() + elif args and hasattr(args[0], "log"): + logger = args[0].log + else: + logger = logging.getLogger(func.__module__) # Get session from args or kwargs if "session" in kwargs: session = kwargs["session"] diff --git a/airflow/www/app.py b/airflow/www/app.py index c496e314dc458..e093e66cfd881 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -137,7 +137,8 @@ def create_app(config=None, testing=False): flask_app.json_provider_class = AirflowJsonProvider flask_app.json = AirflowJsonProvider(flask_app) - InternalApiConfig.force_database_direct_access("Gunicorn worker initialization") + if conf.getboolean("core", "database_access_isolation", fallback=False): + InternalApiConfig.set_use_database_access("Gunicorn worker initialization") csrf.init_app(flask_app) diff --git a/scripts/in_container/bin/run_tmux b/scripts/in_container/bin/run_tmux index 3ca2ddc3eed8d..40fc695a643b0 100755 --- a/scripts/in_container/bin/run_tmux +++ b/scripts/in_container/bin/run_tmux @@ -73,20 +73,20 @@ if [[ ${INTEGRATION_CELERY} == "true" ]]; then tmux split-window -h tmux send-keys 'airflow celery worker' C-m fi -if [[ ${DATABASE_ISOLATION=} == "true" ]]; then - tmux select-pane -t 0 +if [[ ${INTEGRATION_CELERY} == "true" && ${CELERY_FLOWER} == "true" ]]; then + tmux select-pane -t 3 tmux split-window -h - tmux send-keys 'airflow internal-api' C-m + tmux send-keys 'airflow celery flower' C-m fi if [[ ${STANDALONE_DAG_PROCESSOR} == "true" ]]; then tmux select-pane -t 3 tmux split-window -h tmux send-keys 'airflow dag-processor' C-m fi -if [[ ${INTEGRATION_CELERY} == "true" && ${CELERY_FLOWER} == "true" ]]; then - tmux select-pane -t 3 +if [[ ${DATABASE_ISOLATION=} == "true" ]]; then + tmux select-pane -t 0 tmux split-window -h - tmux send-keys 'airflow celery flower' C-m + tmux send-keys 'airflow internal-api' C-m fi # Attach Session, on the Main window diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py index 896e88d77e824..d779b504ea479 100644 --- a/tests/api_internal/test_internal_api_call.py +++ 
b/tests/api_internal/test_internal_api_call.py @@ -19,13 +19,16 @@ from __future__ import annotations import json +from argparse import Namespace from typing import TYPE_CHECKING from unittest import mock import pytest import requests +from airflow.__main__ import configure_internal_api from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.configuration import conf from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator from airflow.serialization.serialized_objects import BaseSerialization @@ -41,7 +44,21 @@ @pytest.fixture(autouse=True) def reset_init_api_config(): - InternalApiConfig._initialized = False + InternalApiConfig._use_internal_api = False + InternalApiConfig._internal_api_endpoint = "" + from airflow import settings + + old_engine = settings.engine + old_session = settings.Session + old_conn = settings.SQL_ALCHEMY_CONN + try: + yield + finally: + InternalApiConfig._use_internal_api = False + InternalApiConfig._internal_api_endpoint = "" + settings.engine = old_engine + settings.Session = old_session + settings.SQL_ALCHEMY_CONN = old_conn @pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") @@ -50,18 +67,22 @@ class TestInternalApiConfig: { ("core", "database_access_isolation"): "false", ("core", "internal_api_url"): "http://localhost:8888", + ("database", "sql_alchemy_conn"): "none://", } ) def test_get_use_internal_api_disabled(self): + configure_internal_api(Namespace(subcommand="webserver"), conf) assert InternalApiConfig.get_use_internal_api() is False @conf_vars( { ("core", "database_access_isolation"): "true", ("core", "internal_api_url"): "http://localhost:8888", + ("database", "sql_alchemy_conn"): "none://", } ) def test_get_use_internal_api_enabled(self): + configure_internal_api(Namespace(subcommand="dag-processor"), conf) assert InternalApiConfig.get_use_internal_api() is True assert InternalApiConfig.get_internal_api_endpoint() == "http://localhost:8888/internal_api/v1/rpcapi" @@ -72,7 +93,7 @@ def test_get_use_internal_api_enabled(self): } ) def test_force_database_direct_access(self): - InternalApiConfig.force_database_direct_access("message") + InternalApiConfig.set_use_database_access("message") assert InternalApiConfig.get_use_internal_api() is False @@ -118,10 +139,12 @@ def test_local_call(self, mock_requests): { ("core", "database_access_isolation"): "true", ("core", "internal_api_url"): "http://localhost:8888", + ("database", "sql_alchemy_conn"): "none://", } ) @mock.patch("airflow.api_internal.internal_api_call.requests") def test_remote_call(self, mock_requests): + configure_internal_api(Namespace(subcommand="dag-processor"), conf) response = requests.Response() response.status_code = 200 @@ -149,10 +172,12 @@ def test_remote_call(self, mock_requests): { ("core", "database_access_isolation"): "true", ("core", "internal_api_url"): "http://localhost:8888", + ("database", "sql_alchemy_conn"): "none://", } ) @mock.patch("airflow.api_internal.internal_api_call.requests") def test_remote_call_with_none_result(self, mock_requests): + configure_internal_api(Namespace(subcommand="dag-processor"), conf) response = requests.Response() response.status_code = 200 response._content = b"" @@ -166,10 +191,12 @@ def test_remote_call_with_none_result(self, mock_requests): { ("core", "database_access_isolation"): "true", ("core", "internal_api_url"): "http://localhost:8888", + ("database", "sql_alchemy_conn"): "none://", } ) 
@mock.patch("airflow.api_internal.internal_api_call.requests") def test_remote_call_with_params(self, mock_requests): + configure_internal_api(Namespace(subcommand="dag-processor"), conf) response = requests.Response() response.status_code = 200 @@ -204,10 +231,12 @@ def test_remote_call_with_params(self, mock_requests): { ("core", "database_access_isolation"): "true", ("core", "internal_api_url"): "http://localhost:8888", + ("database", "sql_alchemy_conn"): "none://", } ) @mock.patch("airflow.api_internal.internal_api_call.requests") def test_remote_classmethod_call_with_params(self, mock_requests): + configure_internal_api(Namespace(subcommand="dag-processor"), conf) response = requests.Response() response.status_code = 200 @@ -241,10 +270,12 @@ def test_remote_classmethod_call_with_params(self, mock_requests): { ("core", "database_access_isolation"): "true", ("core", "internal_api_url"): "http://localhost:8888", + ("database", "sql_alchemy_conn"): "none://", } ) @mock.patch("airflow.api_internal.internal_api_call.requests") def test_remote_call_with_serialized_model(self, mock_requests): + configure_internal_api(Namespace(subcommand="dag-processor"), conf) response = requests.Response() response.status_code = 200 diff --git a/tests/cli/commands/test_internal_api_command.py b/tests/cli/commands/test_internal_api_command.py index 9de857588a3fc..99992e6266861 100644 --- a/tests/cli/commands/test_internal_api_command.py +++ b/tests/cli/commands/test_internal_api_command.py @@ -32,6 +32,7 @@ from airflow.cli.commands.internal_api_command import GunicornMonitor from airflow.settings import _ENABLE_AIP_44 from tests.cli.commands._common_cli_classes import _ComonCLIGunicornTestClass +from tests.test_utils.config import conf_vars console = Console(width=400, color_system="standard") @@ -99,6 +100,8 @@ def test_cli_internal_api_background(self, tmp_path): try: # Run internal-api as daemon in background. Note that the wait method is not called. 
console.print("[magenta]Starting airflow internal-api --daemon") + env = os.environ.copy() + env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "true" proc = subprocess.Popen( [ "airflow", @@ -112,7 +115,8 @@ def test_cli_internal_api_background(self, tmp_path): os.fspath(stderr), "--log-file", os.fspath(logfile), - ] + ], + env=env, ) assert proc.poll() is None @@ -150,6 +154,7 @@ def test_cli_internal_api_background(self, tmp_path): console.print(file.read_text()) raise + @conf_vars({("core", "database_access_isolation"): "true"}) def test_cli_internal_api_debug(self, app): with mock.patch( "airflow.cli.commands.internal_api_command.create_app", return_value=app @@ -169,6 +174,7 @@ def test_cli_internal_api_debug(self, app): host="0.0.0.0", ) + @conf_vars({("core", "database_access_isolation"): "true"}) def test_cli_internal_api_args(self): with mock.patch("subprocess.Popen") as Popen, mock.patch.object( internal_api_command, "GunicornMonitor" diff --git a/tests/conftest.py b/tests/conftest.py index 6cb74446dce8b..284341e5fd826 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1200,9 +1200,9 @@ def suppress_info_logs_for_dag_and_fab(): @pytest.fixture(scope="module", autouse=True) def _clear_db(request): + """Clear DB before each test module run.""" from tests.test_utils.db import clear_all - """Clear DB before each test module run.""" if not request.config.option.db_cleanup: return if skip_db_tests: @@ -1220,7 +1220,6 @@ def _clear_db(request): if dist_option != "no" or hasattr(request.config, "workerinput"): # Skip if pytest-xdist detected (controller or worker) return - try: clear_all() except Exception as ex: diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py index dd264b61569e0..d05344bfa91d8 100644 --- a/tests/core/test_settings.py +++ b/tests/core/test_settings.py @@ -21,12 +21,15 @@ import os import sys import tempfile +from argparse import Namespace from unittest import mock from unittest.mock import MagicMock, call, patch import pytest +from airflow.__main__ import configure_internal_api from airflow.api_internal.internal_api_call import InternalApiConfig +from airflow.configuration import conf from airflow.exceptions import AirflowClusterPolicyViolation, AirflowConfigException from airflow.settings import _ENABLE_AIP_44, TracebackSession, is_usage_data_collection_enabled from airflow.utils.session import create_session @@ -67,12 +70,21 @@ def task_must_have_owners(task: BaseOperator): @pytest.fixture def clear_internal_api(): + InternalApiConfig._use_internal_api = False + InternalApiConfig._internal_api_endpoint = "" + from airflow import settings + + old_engine = settings.engine + old_session = settings.Session + old_conn = settings.SQL_ALCHEMY_CONN try: yield finally: - InternalApiConfig._initialized = False - InternalApiConfig._use_internal_api = None - InternalApiConfig._internal_api_endpoint = None + InternalApiConfig._use_internal_api = False + InternalApiConfig._internal_api_endpoint = "" + settings.engine = old_engine + settings.Session = old_session + settings.SQL_ALCHEMY_CONN = old_conn class SettingsContext: @@ -294,11 +306,11 @@ def test_encoding_absent_in_v2(is_v1, mock_conf): { ("core", "database_access_isolation"): "true", ("core", "internal_api_url"): "http://localhost:8888", + ("database", "sql_alchemy_conn"): "none://", } ) def test_get_traceback_session_if_aip_44_enabled(clear_internal_api): - # ensure we take the database_access_isolation config - InternalApiConfig._init_values() + 
configure_internal_api(Namespace(subcommand="worker"), conf) assert InternalApiConfig.get_use_internal_api() is True with create_session() as session: @@ -319,14 +331,15 @@ def test_get_traceback_session_if_aip_44_enabled(clear_internal_api): { ("core", "database_access_isolation"): "true", ("core", "internal_api_url"): "http://localhost:8888", + ("database", "sql_alchemy_conn"): "none://", } ) @patch("airflow.utils.session.TracebackSession.__new__") def test_create_session_ctx_mgr_no_call_methods(mock_new, clear_internal_api): + configure_internal_api(Namespace(subcommand="worker"), conf) m = MagicMock() mock_new.return_value = m - # ensure we take the database_access_isolation config - InternalApiConfig._init_values() + assert InternalApiConfig.get_use_internal_api() is True with create_session() as session: @@ -348,7 +361,7 @@ def test_create_session_ctx_mgr_no_call_methods(mock_new, clear_internal_api): (None, "False", False), # Default env, conf disables ], ) -def test_usage_data_collection_disabled(env_var, conf_setting, is_enabled): +def test_usage_data_collection_disabled(env_var, conf_setting, is_enabled, clear_internal_api): conf_patch = conf_vars({("usage_data_collection", "enabled"): conf_setting}) if env_var is not None: diff --git a/tests/core/test_sqlalchemy_config.py b/tests/core/test_sqlalchemy_config.py index d0fd77654107c..068f803a6f015 100644 --- a/tests/core/test_sqlalchemy_config.py +++ b/tests/core/test_sqlalchemy_config.py @@ -23,6 +23,7 @@ from sqlalchemy.pool import NullPool from airflow import settings +from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.exceptions import AirflowConfigException from tests.test_utils.config import conf_vars @@ -36,12 +37,16 @@ def setup_method(self): self.old_engine = settings.engine self.old_session = settings.Session self.old_conn = settings.SQL_ALCHEMY_CONN + InternalApiConfig._use_internal_api = False + InternalApiConfig._internal_api_endpoint = "" settings.SQL_ALCHEMY_CONN = "mysql+foobar://user:pass@host/dbname?inline=param&another=param" def teardown_method(self): settings.engine = self.old_engine settings.Session = self.old_session settings.SQL_ALCHEMY_CONN = self.old_conn + InternalApiConfig._use_internal_api = False + InternalApiConfig._internal_api_endpoint = "" @patch("airflow.settings.setup_event_handlers") @patch("airflow.settings.scoped_session") diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index 124a13ff11573..2cc0067cc67ac 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -111,7 +111,7 @@ def _process_file(self, file_path, dag_directory, session): dag_ids=[], dag_directory=str(dag_directory), log=mock.MagicMock() ) - dag_file_processor.process_file(file_path, [], False, session) + dag_file_processor.process_file(file_path, [], False) @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag") def test_dag_file_processor_sla_miss_callback(self, mock_get_dagbag, create_dummy_dag, get_test_dag): @@ -508,7 +508,7 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message" ) ] - dag_file_processor.execute_callbacks(dagbag, requests, session) + dag_file_processor.execute_callbacks(dagbag, requests, dag_file_processor.UNIT_TEST_MODE, session) mock_ti_handle_failure.assert_called_once_with( error="Message", test_mode=conf.getboolean("core", "unit_test_mode"), session=session ) @@ 
-546,7 +546,7 @@ def test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message" ) ] - dag_file_processor.execute_callbacks_without_dag(requests, session) + dag_file_processor.execute_callbacks_without_dag(requests, True, session) mock_ti_handle_failure.assert_called_once_with( error="Message", test_mode=conf.getboolean("core", "unit_test_mode"), session=session ) @@ -577,7 +577,7 @@ def test_failure_callbacks_should_not_drop_hostname(self): full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message" ) ] - dag_file_processor.execute_callbacks(dagbag, requests) + dag_file_processor.execute_callbacks(dagbag, requests, False) with create_session() as session: tis = session.query(TaskInstance) @@ -611,7 +611,7 @@ def test_process_file_should_failure_callback(self, monkeypatch, tmp_path, get_t msg="Message", ) ] - dag_file_processor.process_file(dag.fileloc, requests, session=session) + dag_file_processor.process_file(dag.fileloc, requests) ti.refresh_from_db() msg = " ".join([str(k) for k in ti.key.primary]) + " fired callback" diff --git a/tests/listeners/test_dag_import_error_listener.py b/tests/listeners/test_dag_import_error_listener.py index 0417ae24f56c2..ce6aae01518be 100644 --- a/tests/listeners/test_dag_import_error_listener.py +++ b/tests/listeners/test_dag_import_error_listener.py @@ -100,7 +100,7 @@ def _process_file(self, file_path, dag_directory, session): dag_ids=[], dag_directory=str(dag_directory), log=mock.MagicMock() ) - dag_file_processor.process_file(file_path, [], False, session) + dag_file_processor.process_file(file_path, [], False) def test_newly_added_import_error(self, tmp_path, session): dag_import_error_listener.clear() diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py index cffe0844b012e..275c585170c3f 100644 --- a/tests/www/views/test_views_home.py +++ b/tests/www/views/test_views_home.py @@ -24,7 +24,6 @@ from airflow.dag_processing.processor import DagFileProcessor from airflow.security import permissions -from airflow.utils.session import create_session from airflow.utils.state import State from airflow.www.utils import UIAlert from airflow.www.views import FILTER_LASTRUN_COOKIE, FILTER_STATUS_COOKIE, FILTER_TAGS_COOKIE @@ -192,20 +191,18 @@ def client_single_dag_edit(app, user_single_dag_edit): TEST_TAGS = ["example", "test", "team", "group"] -def _process_file(file_path, session): +def _process_file(file_path): dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory="/tmp", log=mock.MagicMock()) - dag_file_processor.process_file(file_path, [], False, session) + dag_file_processor.process_file(file_path, [], False) @pytest.fixture def working_dags(tmp_path): dag_contents_template = "from airflow import DAG\ndag = DAG('{}', tags=['{}'])" - - with create_session() as session: - for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS): - path = tmp_path / f"{dag_id}.py" - path.write_text(dag_contents_template.format(dag_id, tag)) - _process_file(path, session) + for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS): + path = tmp_path / f"{dag_id}.py" + path.write_text(dag_contents_template.format(dag_id, tag)) + _process_file(path) @pytest.fixture @@ -215,14 +212,13 @@ def working_dags_with_read_perm(tmp_path): "from airflow import DAG\ndag = DAG('{}', tags=['{}'], " "access_control={{'role_single_dag':{{'can_read'}}}}) " ) - with create_session() as session: - for dag_id, tag in 
zip(TEST_FILTER_DAG_IDS, TEST_TAGS): - path = tmp_path / f"{dag_id}.py" - if dag_id == "filter_test_1": - path.write_text(dag_contents_template_with_read_perm.format(dag_id, tag)) - else: - path.write_text(dag_contents_template.format(dag_id, tag)) - _process_file(path, session) + for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS): + path = tmp_path / f"{dag_id}.py" + if dag_id == "filter_test_1": + path.write_text(dag_contents_template_with_read_perm.format(dag_id, tag)) + else: + path.write_text(dag_contents_template.format(dag_id, tag)) + _process_file(path) @pytest.fixture @@ -232,50 +228,44 @@ def working_dags_with_edit_perm(tmp_path): "from airflow import DAG\ndag = DAG('{}', tags=['{}'], " "access_control={{'role_single_dag':{{'can_edit'}}}}) " ) - with create_session() as session: - for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS): - path = tmp_path / f"{dag_id}.py" - if dag_id == "filter_test_1": - path.write_text(dag_contents_template_with_read_perm.format(dag_id, tag)) - else: - path.write_text(dag_contents_template.format(dag_id, tag)) - _process_file(path, session) + for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS): + path = tmp_path / f"{dag_id}.py" + if dag_id == "filter_test_1": + path.write_text(dag_contents_template_with_read_perm.format(dag_id, tag)) + else: + path.write_text(dag_contents_template.format(dag_id, tag)) + _process_file(path) @pytest.fixture def broken_dags(tmp_path, working_dags): - with create_session() as session: - for dag_id in TEST_FILTER_DAG_IDS: - path = tmp_path / f"{dag_id}.py" - path.write_text("airflow DAG") - _process_file(path, session) + for dag_id in TEST_FILTER_DAG_IDS: + path = tmp_path / f"{dag_id}.py" + path.write_text("airflow DAG") + _process_file(path) @pytest.fixture def broken_dags_with_read_perm(tmp_path, working_dags_with_read_perm): - with create_session() as session: - for dag_id in TEST_FILTER_DAG_IDS: - path = tmp_path / f"{dag_id}.py" - path.write_text("airflow DAG") - _process_file(path, session) + for dag_id in TEST_FILTER_DAG_IDS: + path = tmp_path / f"{dag_id}.py" + path.write_text("airflow DAG") + _process_file(path) @pytest.fixture def broken_dags_after_working(tmp_path): # First create and process a DAG file that works path = tmp_path / "all_in_one.py" - with create_session() as session: - contents = "from airflow import DAG\n" - for i, dag_id in enumerate(TEST_FILTER_DAG_IDS): - contents += f"dag{i} = DAG('{dag_id}')\n" - path.write_text(contents) - _process_file(path, session) - - # Then break it! - with create_session() as session: - contents += "foobar()" - path.write_text(contents) - _process_file(path, session) + contents = "from airflow import DAG\n" + for i, dag_id in enumerate(TEST_FILTER_DAG_IDS): + contents += f"dag{i} = DAG('{dag_id}')\n" + path.write_text(contents) + _process_file(path) + + contents += "foobar()" + path.write_text(contents) + _process_file(path) def test_home_filter_tags(working_dags, admin_client):