Skip to content

Commit

Permalink
Introduce decorator to load providers configuration (#32765)
Browse files Browse the repository at this point in the history
A number of commands in Airflow relies on the fact that
providers configuration is loaded. This is a rather fast operation
as it does not involve any importing of provider classes, just
discovering entrypoints, running them and parsing yaml configuration,
so it is a very low sub-second time to do it.

We cannot do it once in settings/config because we actually need
settings/config to be pre-initialized without providers in order
to be able to bootstrap airflow, therefore we need to run
it individually in each command that can be run with the
"airflow" entrypoint. Decorator seems to be best suited to do
the job:

* easy to apply and not easy to forget when you create another
  command and look at other commands
* nicely wraps around local ProvidersManager import

There are exceptions for the "version" and "providers lazy-loaded"
commands because they are NOT supposed to initialize configuration
of providers.
  • Loading branch information
potiuk committed Jul 22, 2023
1 parent accdb0b commit 56c41d4
Show file tree
Hide file tree
Showing 33 changed files with 250 additions and 35 deletions.
11 changes: 6 additions & 5 deletions airflow/__main__.py
Expand Up @@ -46,12 +46,13 @@ def main():
argcomplete.autocomplete(parser)
args = parser.parse_args()

# Here we ensure that the default configuration is written if needed before running any command
# that might need it. This used to be done during configuration initialization but having it
# in main ensures that it is not done during tests and other ways airflow imports are used
from airflow.configuration import write_default_airflow_configuration_if_needed
if args.subcommand not in ["lazy_loaded", "version"]:
# Here we ensure that the default configuration is written if needed before running any command
# that might need it. This used to be done during configuration initialization but having it
# in main ensures that it is not done during tests and other ways airflow imports are used
from airflow.configuration import write_default_airflow_configuration_if_needed

write_default_airflow_configuration_if_needed()
write_default_airflow_configuration_if_needed()
args.func(args)


Expand Down
6 changes: 3 additions & 3 deletions airflow/cli/cli_config.py
Expand Up @@ -1872,9 +1872,9 @@ class GroupCommand(NamedTuple):
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
name="status",
help="Get information about provider initialization status",
func=lazy_load_command("airflow.cli.commands.provider_command.status"),
name="lazy-loaded",
help="Checks that provider configuration is lazy loaded",
func=lazy_load_command("airflow.cli.commands.provider_command.lazy_loaded"),
args=(ARG_VERBOSE,),
),
)
Expand Down
5 changes: 5 additions & 0 deletions airflow/cli/commands/celery_command.py
Expand Up @@ -36,12 +36,14 @@
from airflow.configuration import conf
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations, setup_logging
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.serve_logs import serve_logs

WORKER_PROCESS_NAME = "worker"


@cli_utils.action_cli
@providers_configuration_loaded
def flower(args):
"""Starts Flower, Celery monitoring tool."""
# This needs to be imported locally to not trigger Providers Manager initialization
Expand Down Expand Up @@ -103,6 +105,7 @@ def _serve_logs(skip_serve_logs: bool = False):


@after_setup_logger.connect()
@providers_configuration_loaded
def logger_setup_handler(logger, **kwargs):
"""
Reconfigure the logger.
Expand Down Expand Up @@ -132,6 +135,7 @@ def filter(self, record):


@cli_utils.action_cli
@providers_configuration_loaded
def worker(args):
"""Starts Airflow Celery worker."""
# This needs to be imported locally to not trigger Providers Manager initialization
Expand Down Expand Up @@ -239,6 +243,7 @@ def worker(args):


@cli_utils.action_cli
@providers_configuration_loaded
def stop_worker(args):
"""Sends SIGTERM to Celery worker."""
# Read PID from file
Expand Down
6 changes: 3 additions & 3 deletions airflow/cli/commands/config_command.py
Expand Up @@ -25,8 +25,10 @@
from airflow.configuration import conf
from airflow.utils.cli import should_use_colors
from airflow.utils.code_utils import get_terminal_formatter
from airflow.utils.providers_configuration_loader import providers_configuration_loaded


@providers_configuration_loaded
def show_config(args):
"""Show current application configuration."""
with io.StringIO() as output:
Expand All @@ -47,16 +49,14 @@ def show_config(args):
print(code)


@providers_configuration_loaded
def get_value(args):
"""Get one value from configuration."""
# while this will make get_value quite a bit slower we must initialize configuration
# for providers because we do not know what sections and options will be available after
# providers are initialized. Theoretically Providers might add new sections and options
# but also override defaults for existing options, so without loading all providers we
# cannot be sure what is the final value of the option.
from airflow.providers_manager import ProvidersManager

ProvidersManager().initialize_providers_configuration()
if not conf.has_option(args.section, args.option):
raise SystemExit(f"The option [{args.section}/{args.option}] is not found in config.")

Expand Down
8 changes: 8 additions & 0 deletions airflow/cli/commands/connection_command.py
Expand Up @@ -39,6 +39,7 @@
from airflow.secrets.local_filesystem import load_connections_dict
from airflow.utils import cli as cli_utils, helpers, yaml
from airflow.utils.cli import suppress_logs_and_warning
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import create_session


Expand All @@ -61,6 +62,7 @@ def _connection_mapper(conn: Connection) -> dict[str, Any]:


@suppress_logs_and_warning
@providers_configuration_loaded
def connections_get(args):
"""Get a connection."""
try:
Expand All @@ -75,6 +77,7 @@ def connections_get(args):


@suppress_logs_and_warning
@providers_configuration_loaded
def connections_list(args):
"""Lists all connections at the command line."""
with create_session() as session:
Expand Down Expand Up @@ -150,6 +153,7 @@ def _get_connection_types() -> list[str]:
return _connection_types


@providers_configuration_loaded
def connections_export(args):
"""Exports all connections to a file."""
file_formats = [".yaml", ".json", ".env"]
Expand Down Expand Up @@ -200,6 +204,7 @@ def connections_export(args):


@cli_utils.action_cli
@providers_configuration_loaded
def connections_add(args):
"""Adds new connection."""
has_uri = bool(args.conn_uri)
Expand Down Expand Up @@ -291,6 +296,7 @@ def connections_add(args):


@cli_utils.action_cli
@providers_configuration_loaded
def connections_delete(args):
"""Deletes connection from DB."""
with create_session() as session:
Expand All @@ -306,6 +312,7 @@ def connections_delete(args):


@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def connections_import(args):
"""Imports connections from a file."""
if os.path.exists(args.file):
Expand Down Expand Up @@ -343,6 +350,7 @@ def _import_helper(file_path: str, overwrite: bool) -> None:


@suppress_logs_and_warning
@providers_configuration_loaded
def connections_test(args) -> None:
"""Test an Airflow connection."""
console = AirflowConsole()
Expand Down
23 changes: 21 additions & 2 deletions airflow/cli/commands/dag_command.py
Expand Up @@ -45,6 +45,7 @@
from airflow.utils import cli as cli_utils, timezone
from airflow.utils.cli import get_dag, get_dags, process_subdir, sigint_handler, suppress_logs_and_warning
from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState

Expand Down Expand Up @@ -120,6 +121,7 @@ def _run_dag_backfill(dags: list[DAG], args) -> None:


@cli_utils.action_cli
@providers_configuration_loaded
def dag_backfill(args, dag: list[DAG] | DAG | None = None) -> None:
"""Creates backfill job or dry run for a DAG or list of DAGs using regex."""
logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
Expand Down Expand Up @@ -150,6 +152,7 @@ def dag_backfill(args, dag: list[DAG] | DAG | None = None) -> None:


@cli_utils.action_cli
@providers_configuration_loaded
def dag_trigger(args) -> None:
"""Creates a dag run for the specified dag."""
api_client = get_current_api_client()
Expand All @@ -170,6 +173,7 @@ def dag_trigger(args) -> None:


@cli_utils.action_cli
@providers_configuration_loaded
def dag_delete(args) -> None:
"""Deletes all DB records related to the specified dag."""
api_client = get_current_api_client()
Expand All @@ -188,17 +192,20 @@ def dag_delete(args) -> None:


@cli_utils.action_cli
@providers_configuration_loaded
def dag_pause(args) -> None:
"""Pauses a DAG."""
set_is_paused(True, args)


@cli_utils.action_cli
@providers_configuration_loaded
def dag_unpause(args) -> None:
"""Unpauses a DAG."""
set_is_paused(False, args)


@providers_configuration_loaded
def set_is_paused(is_paused: bool, args) -> None:
"""Sets is_paused for DAG by a given dag_id."""
dag = DagModel.get_dagmodel(args.dag_id)
Expand All @@ -211,6 +218,7 @@ def set_is_paused(is_paused: bool, args) -> None:
print(f"Dag: {args.dag_id}, paused: {is_paused}")


@providers_configuration_loaded
def dag_dependencies_show(args) -> None:
"""Displays DAG dependencies, save to file or show as imgcat image."""
dot = render_dag_dependencies(SerializedDagModel.get_dag_dependencies())
Expand All @@ -230,6 +238,7 @@ def dag_dependencies_show(args) -> None:
print(dot.source)


@providers_configuration_loaded
def dag_show(args) -> None:
"""Displays DAG or saves it's graphic representation to the file."""
dag = get_dag(args.subdir, args.dag_id)
Expand Down Expand Up @@ -273,6 +282,7 @@ def _save_dot_to_file(dot: Dot, filename: str) -> None:


@cli_utils.action_cli
@providers_configuration_loaded
@provide_session
def dag_state(args, session: Session = NEW_SESSION) -> None:
"""
Expand All @@ -296,6 +306,7 @@ def dag_state(args, session: Session = NEW_SESSION) -> None:


@cli_utils.action_cli
@providers_configuration_loaded
def dag_next_execution(args) -> None:
"""
Returns the next execution datetime of a DAG at the command line.
Expand Down Expand Up @@ -335,6 +346,7 @@ def print_execution_interval(interval: DataInterval | None):

@cli_utils.action_cli
@suppress_logs_and_warning
@providers_configuration_loaded
def dag_list_dags(args) -> None:
"""Displays dags with or without stats at the command line."""
dagbag = DagBag(process_subdir(args.subdir))
Expand All @@ -360,6 +372,7 @@ def dag_list_dags(args) -> None:

@cli_utils.action_cli
@suppress_logs_and_warning
@providers_configuration_loaded
@provide_session
def dag_details(args, session=NEW_SESSION):
"""Get DAG details given a DAG id."""
Expand All @@ -381,6 +394,7 @@ def dag_details(args, session=NEW_SESSION):

@cli_utils.action_cli
@suppress_logs_and_warning
@providers_configuration_loaded
def dag_list_import_errors(args) -> None:
"""Displays dags with import errors on the command line."""
dagbag = DagBag(process_subdir(args.subdir))
Expand All @@ -395,6 +409,7 @@ def dag_list_import_errors(args) -> None:

@cli_utils.action_cli
@suppress_logs_and_warning
@providers_configuration_loaded
def dag_report(args) -> None:
"""Displays dagbag stats at the command line."""
dagbag = DagBag(process_subdir(args.subdir))
Expand All @@ -413,6 +428,7 @@ def dag_report(args) -> None:

@cli_utils.action_cli
@suppress_logs_and_warning
@providers_configuration_loaded
@provide_session
def dag_list_jobs(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
"""Lists latest n jobs."""
Expand Down Expand Up @@ -443,6 +459,7 @@ def dag_list_jobs(args, dag: DAG | None = None, session: Session = NEW_SESSION)

@cli_utils.action_cli
@suppress_logs_and_warning
@providers_configuration_loaded
@provide_session
def dag_list_dag_runs(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
"""Lists dag runs for a given DAG."""
Expand Down Expand Up @@ -479,8 +496,9 @@ def dag_list_dag_runs(args, dag: DAG | None = None, session: Session = NEW_SESSI
)


@provide_session
@cli_utils.action_cli
@providers_configuration_loaded
@provide_session
def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
"""Execute one single DagRun for a given DAG and execution date."""
run_conf = None
Expand Down Expand Up @@ -513,8 +531,9 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
print(dot_graph.source)


@provide_session
@cli_utils.action_cli
@providers_configuration_loaded
@provide_session
def dag_reserialize(args, session: Session = NEW_SESSION) -> None:
"""Serialize a DAG instance."""
session.execute(delete(SerializedDagModel).execution_options(synchronize_session=False))
Expand Down
2 changes: 2 additions & 0 deletions airflow/cli/commands/dag_processor_command.py
Expand Up @@ -31,6 +31,7 @@
from airflow.jobs.job import Job, run_job
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations, setup_logging
from airflow.utils.providers_configuration_loader import providers_configuration_loaded

log = logging.getLogger(__name__)

Expand All @@ -53,6 +54,7 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner:


@cli_utils.action_cli
@providers_configuration_loaded
def dag_processor(args):
"""Starts Airflow Dag Processor Job."""
if not conf.getboolean("scheduler", "standalone_dag_processor"):
Expand Down

0 comments on commit 56c41d4

Please sign in to comment.