Save scheduler execution time by caching dags #30704

Merged · 8 commits · May 18, 2023
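At a glance, the change memoizes `DagBag.get_dag` lookups while the scheduler iterates over dag_runs, so many runs of the same DAG trigger only one lookup per scheduling pass. A simplified, self-contained sketch of the pattern (the `fetch_dag` stand-in and the example dag_ids are illustrative only, not part of the diff):

# Simplified sketch of the caching pattern introduced by this PR.
# fetch_dag() is a hypothetical stand-in for the comparatively expensive
# DagBag.get_dag() call; only the cached_dags/dag_id names mirror the diff.
def fetch_dag(dag_id: str) -> str:
    print(f"loading {dag_id}")
    return f"<DAG {dag_id}>"

cached_dags: dict = {}
for dag_id in ["etl_daily", "etl_daily", "reporting", "etl_daily"]:
    dag = cached_dags.get(dag_id)
    if not dag:
        # Only the first dag_run seen for each dag_id pays the lookup cost.
        dag = cached_dags[dag_id] = fetch_dag(dag_id)

# "loading ..." is printed twice: once for etl_daily and once for reporting.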
16 changes: 14 additions & 2 deletions airflow/jobs/scheduler_job_runner.py
@@ -62,6 +62,7 @@
 from airflow.timetables.simple import DatasetTriggeredTimetable
 from airflow.utils import timezone
 from airflow.utils.event_scheduler import EventScheduler
+from airflow.utils.helpers import get_value_with_cache
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
@@ -1083,8 +1084,13 @@ def _do_scheduling(self, session: Session) -> int:
             callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, session)

         # Send the callbacks after we commit to ensure the context is up to date when it gets run
+        # cache saves time during scheduling of many dag_runs for the same dag
+        cached_dags: dict = {}
         for dag_run, callback_to_run in callback_tuples:
-            dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
+            dag = get_value_with_cache(
+                cached_dags, dag_run.dag_id, lambda: self.dagbag.get_dag(dag_run.dag_id, session=session)
+            )
+
             if not dag:
                 self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
                 continue
@@ -1348,8 +1354,14 @@ def _update_state(dag: DAG, dag_run: DagRun):
                     tags={"dag_id": dag.dag_id},
                 )

+        # cache saves time during scheduling of many dag_runs for the same dag
+        cached_dags: dict = {}
+
         for dag_run in dag_runs:
-            dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
+            dag = dag_run.dag = get_value_with_cache(
+                cached_dags, dag_run.dag_id, lambda: self.dagbag.get_dag(dag_run.dag_id, session=session)
+            )
+
             if not dag:
                 self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
                 continue
9 changes: 9 additions & 0 deletions airflow/utils/helpers.py
@@ -45,6 +45,15 @@
 S = TypeVar("S")


+def get_value_with_cache(cache: dict[str, Any], key: str, do_get_value: Callable) -> Any:
+    """Return value from ``cache``, computing and caching it via ``do_get_value`` if missing or falsy."""
+    return_value = cache.get(key)
+    if not return_value:
+        return_value = cache[key] = do_get_value()
+
+    return return_value
+
+
 def validate_key(k: str, max_length: int = 250):
     """Validates value used as a key."""
     if not isinstance(k, str):
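For reference, a minimal usage sketch of the new helper (assuming it is importable as added above; the fake loader and call counter are illustrative, not part of the diff). Repeated keys hit the cache, so the loader runs once per distinct key; a falsy cached value (e.g. `None` or `""`) is treated as missing and recomputed, which the `empty_cached_value` test case below exercises.

# Illustrative usage of get_value_with_cache (not part of the PR's diff).
from airflow.utils.helpers import get_value_with_cache

calls = 0

def load_dag(dag_id: str) -> str:
    # Stand-in for the comparatively expensive DagBag.get_dag() lookup.
    global calls
    calls += 1
    return f"<DAG {dag_id}>"

cached_dags: dict = {}
for dag_id in ["dag_a", "dag_a", "dag_b", "dag_a"]:
    dag = get_value_with_cache(cached_dags, dag_id, lambda: load_dag(dag_id))

assert calls == 2  # one lookup per distinct dag_id, not one per dag_run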
18 changes: 18 additions & 0 deletions tests/utils/test_helpers.py
@@ -31,6 +31,7 @@
     at_most_one,
     build_airflow_url_with_query,
     exactly_one,
+    get_value_with_cache,
     merge_dicts,
     prune_dict,
     validate_group_key,
@@ -174,6 +175,23 @@ def test_build_airflow_url_with_query(self):
         with cached_app(testing=True).test_request_context():
             assert build_airflow_url_with_query(query) == expected_url

+    @pytest.mark.parametrize(
+        "cache, expected_value",
+        [
+            ({}, 1),
+            ({"key2": 10}, 1),
+            ({"key1": 10}, 10),
+            ({"key1": ""}, 1),
+        ],
+        ids=["cache_empty", "missing_cache_item", "cached", "empty_cached_value"],
+    )
+    def test_get_value_with_cache(self, cache, expected_value):
+        def do_get_value():
+            return 1
+
+        assert expected_value == get_value_with_cache(cache, "key1", do_get_value)
+        assert expected_value == cache["key1"]
+
     @pytest.mark.parametrize(
         "key_id, message, exception",
         [