From 8ff7dfbd9e76aa40b04adeb231df3820606f5ba3 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Sun, 2 Jul 2023 12:31:13 +0100 Subject: [PATCH] Sanitize `DagRun.run_id` and allow flexibility (#32293) This commit sanitizes the DagRun.run_id parameter by introducing a configurable option. Users now have the ability to select a specific run_id pattern for their runs, ensuring stricter control over the values used. This update does not impact the default run_id generation performed by the scheduler for scheduled DAG runs or for Dag runs triggered without modifying the run_id parameter in the run configuration page. The configuration flexibility empowers users to align the run_id pattern with their specific requirements. (cherry picked from commit 05bd90f563649f2e9c8f0c85cf5838315a665a02) --- airflow/config_templates/config.yml | 9 ++++ airflow/config_templates/default_airflow.cfg | 5 ++ airflow/models/dag.py | 55 +++++++++++--------- airflow/models/dagrun.py | 16 +++++- airflow/www/views.py | 32 ++++++++---- tests/models/test_dagrun.py | 32 ++++++++++++ tests/www/views/test_views_trigger_dag.py | 30 +++++++++++ 7 files changed, 143 insertions(+), 36 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 68c2001e0036b..4653a82416fd8 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -2440,6 +2440,15 @@ scheduler: type: float example: ~ default: "120.0" + allowed_run_id_pattern: + description: | + The run_id pattern used to verify the validity of user input to the run_id parameter when + triggering a DAG. This pattern cannot change the pattern used by scheduler to generate run_id + for scheduled DAG runs or DAG runs triggered without changing the run_id parameter. 
+ version_added: 2.6.3 + type: string + example: ~ + default: "^[A-Za-z0-9_.~:+-]+$" triggerer: description: ~ options: diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 04971050c4cc9..c765f17eada0e 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -1245,6 +1245,11 @@ task_queued_timeout = 600.0 # longer than `[scheduler] task_queued_timeout`. task_queued_timeout_check_interval = 120.0 +# The run_id pattern used to verify the validity of user input to the run_id parameter when +# triggering a DAG. This pattern cannot change the pattern used by scheduler to generate run_id +# for scheduled DAG runs or DAG runs triggered without changing the run_id parameter. +allowed_run_id_pattern = ^[A-Za-z0-9_.~:+-]+$ + [triggerer] # How many triggers a single Triggerer will run at once, by default. default_capacity = 1000 diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 479c70524bd9b..97b35e5a788c4 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -66,7 +66,7 @@ from airflow import settings, utils from airflow.api_internal.internal_api_call import internal_api_call from airflow.compat.functools import cached_property -from airflow.configuration import conf, secrets_backend_list +from airflow.configuration import conf as airflow_conf, secrets_backend_list from airflow.exceptions import ( AirflowDagInconsistent, AirflowException, @@ -80,7 +80,7 @@ from airflow.models.base import Base, StringID from airflow.models.dagcode import DagCode from airflow.models.dagpickle import DagPickle -from airflow.models.dagrun import DagRun +from airflow.models.dagrun import RUN_ID_REGEX, DagRun from airflow.models.operator import Operator from airflow.models.param import DagParam, ParamsDict from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances @@ -402,13 +402,13 @@ def __init__( 
user_defined_filters: dict | None = None, default_args: dict | None = None, concurrency: int | None = None, - max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"), - max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"), + max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"), + max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), dagrun_timeout: timedelta | None = None, sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None, - default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(), - orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"), - catchup: bool = conf.getboolean("scheduler", "catchup_by_default"), + default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(), + orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"), + catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, doc_md: str | None = None, @@ -2429,7 +2429,7 @@ def run( mark_success=False, local=False, executor=None, - donot_pickle=conf.getboolean("core", "donot_pickle"), + donot_pickle=airflow_conf.getboolean("core", "donot_pickle"), ignore_task_deps=False, ignore_first_depends_on_past=True, pool=None, @@ -2666,13 +2666,14 @@ def create_dagrun( "Creating DagRun needs either `run_id` or both `run_type` and `execution_date`" ) - if run_id and "/" in run_id: - warnings.warn( - "Using forward slash ('/') in a DAG run ID is deprecated. 
Note that this character " - "also makes the run impossible to retrieve via Airflow's REST API.", - RemovedInAirflow3Warning, - stacklevel=3, - ) + regex = airflow_conf.get("scheduler", "allowed_run_id_pattern") + + if run_id and not re.match(RUN_ID_REGEX, run_id): + if not regex.strip() or not re.match(regex.strip(), run_id): + raise AirflowException( + f"The provided run ID '{run_id}' is invalid. It does not match either " + f"the configured pattern: '{regex}' or the built-in pattern: '{RUN_ID_REGEX}'" + ) # create a copy of params before validating copied_params = copy.deepcopy(self.params) @@ -2960,7 +2961,7 @@ def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION): def get_default_view(self): """This is only there for backward compatible jinja2 templates""" if self.default_view is None: - return conf.get("webserver", "dag_default_view").lower() + return airflow_conf.get("webserver", "dag_default_view").lower() else: return self.default_view @@ -3177,7 +3178,7 @@ class DagModel(Base): root_dag_id = Column(StringID()) # A DAG can be paused from the UI / DB # Set this default value of is_paused based on a configuration value! 
- is_paused_at_creation = conf.getboolean("core", "dags_are_paused_at_creation") + is_paused_at_creation = airflow_conf.getboolean("core", "dags_are_paused_at_creation") is_paused = Column(Boolean, default=is_paused_at_creation) # Whether the DAG is a subdag is_subdag = Column(Boolean, default=False) @@ -3251,7 +3252,9 @@ class DagModel(Base): "TaskOutletDatasetReference", cascade="all, delete, delete-orphan", ) - NUM_DAGS_PER_DAGRUN_QUERY = conf.getint("scheduler", "max_dagruns_to_create_per_loop", fallback=10) + NUM_DAGS_PER_DAGRUN_QUERY = airflow_conf.getint( + "scheduler", "max_dagruns_to_create_per_loop", fallback=10 + ) def __init__(self, concurrency=None, **kwargs): super().__init__(**kwargs) @@ -3264,10 +3267,10 @@ def __init__(self, concurrency=None, **kwargs): ) self.max_active_tasks = concurrency else: - self.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag") + self.max_active_tasks = airflow_conf.getint("core", "max_active_tasks_per_dag") if self.max_active_runs is None: - self.max_active_runs = conf.getint("core", "max_active_runs_per_dag") + self.max_active_runs = airflow_conf.getint("core", "max_active_runs_per_dag") if self.has_task_concurrency_limits is None: # Be safe -- this will be updated later once the DAG is parsed @@ -3346,7 +3349,7 @@ def get_default_view(self) -> str: have a value """ # This is for backwards-compatibility with old dags that don't have None as default_view - return self.default_view or conf.get_mandatory_value("webserver", "dag_default_view").lower() + return self.default_view or airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower() @property def safe_dag_id(self): @@ -3529,13 +3532,13 @@ def dag( user_defined_filters: dict | None = None, default_args: dict | None = None, concurrency: int | None = None, - max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"), - max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"), + max_active_tasks: int = 
airflow_conf.getint("core", "max_active_tasks_per_dag"), + max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), dagrun_timeout: timedelta | None = None, sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None, - default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(), - orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"), - catchup: bool = conf.getboolean("scheduler", "catchup_by_default"), + default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(), + orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"), + catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, doc_md: str | None = None, diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index ba0fd9fda331f..028a81bb75751 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -24,6 +24,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload +import re2 as re from sqlalchemy import ( Boolean, Column, @@ -43,7 +44,7 @@ ) from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.orm import Session, declared_attr, joinedload, relationship, synonym +from sqlalchemy.orm import Session, declared_attr, joinedload, relationship, synonym, validates from sqlalchemy.sql.expression import false, select, true from airflow import settings @@ -75,6 +76,8 @@ CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI]) TaskCreator = Callable[[Operator, Iterable[int]], CreatedTasks] +RUN_ID_REGEX = 
r"^(?:manual|scheduled|dataset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$" + class TISchedulingDecision(NamedTuple): """Type of return for DagRun.task_instance_scheduling_decisions""" @@ -238,6 +241,17 @@ def __repr__(self): external_trigger=self.external_trigger, ) + @validates("run_id") + def validate_run_id(self, key: str, run_id: str) -> str | None: + if not run_id: + return None + regex = airflow_conf.get("scheduler", "allowed_run_id_pattern") + if not re.match(regex, run_id) and not re.match(RUN_ID_REGEX, run_id): + raise ValueError( + f"The run_id provided '{run_id}' does not match the pattern '{regex}' or '{RUN_ID_REGEX}'" + ) + return run_id + @property def logical_date(self) -> datetime: return self.execution_date diff --git a/airflow/www/views.py b/airflow/www/views.py index c29d952583568..18188d3de5912 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -96,7 +96,7 @@ from airflow.models.abstractoperator import AbstractOperator from airflow.models.dag import DAG, get_dataset_triggered_next_run_info from airflow.models.dagcode import DagCode -from airflow.models.dagrun import DagRun, DagRunType +from airflow.models.dagrun import RUN_ID_REGEX, DagRun, DagRunType from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel from airflow.models.mappedoperator import MappedOperator from airflow.models.operator import Operator @@ -1896,7 +1896,7 @@ def delete(self): def trigger(self, session: Session = NEW_SESSION): """Triggers DAG Run.""" dag_id = request.values["dag_id"] - run_id = request.values.get("run_id", "") + run_id = request.values.get("run_id", "").replace(" ", "+") origin = get_safe_url(request.values.get("origin")) unpause = request.values.get("unpause") request_conf = request.values.get("conf") @@ -2016,13 +2016,27 @@ def trigger(self, session: Session = NEW_SESSION): flash(message, "error") return redirect(origin) - # Flash a warning when slash is used, but still 
allow it to continue on. - if run_id and "/" in run_id: - flash( - "Using forward slash ('/') in a DAG run ID is deprecated. Note that this character " - "also makes the run impossible to retrieve via Airflow's REST API.", - "warning", - ) + regex = conf.get("scheduler", "allowed_run_id_pattern") + if run_id and not re.match(RUN_ID_REGEX, run_id): + if not regex.strip() or not re.match(regex.strip(), run_id): + flash( + f"The provided run ID '{run_id}' is invalid. It does not match either " + f"the configured pattern: '{regex}' or the built-in pattern: '{RUN_ID_REGEX}'", + "error", + ) + + form = DateTimeForm(data={"execution_date": execution_date}) + return self.render_template( + "airflow/trigger.html", + form_fields=form_fields, + dag=dag, + dag_id=dag_id, + origin=origin, + conf=request_conf, + form=form, + is_dag_run_conf_overrides_params=is_dag_run_conf_overrides_params, + recent_confs=recent_confs, + ) run_conf = {} if request_conf: diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 38f882c687de6..5cb2c61e7b227 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -29,6 +29,7 @@ from airflow import settings from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.decorators import task, task_group +from airflow.exceptions import AirflowException from airflow.models import ( DAG, DagBag, @@ -53,6 +54,7 @@ from airflow.utils.types import DagRunType from tests.models import DEFAULT_DATE as _DEFAULT_DATE from tests.test_utils import db +from tests.test_utils.config import conf_vars from tests.test_utils.mock_operators import MockOperator DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE) @@ -2328,3 +2330,33 @@ def the_task(): assert session.query(DagRun).filter(DagRun.id == dr.id).one_or_none() is None assert session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one_or_none() is None + + +@pytest.mark.parametrize( + "pattern, run_id, result", + [ + ["^[A-Z]", "ABC", True], + 
["^[A-Z]", "abc", False], + ["^[0-9]", "123", True], + # The below params test that user configuration does not affect internally generated + # run_ids + ["", "scheduled__2023-01-01T00:00:00+00:00", True], + ["", "manual__2023-01-01T00:00:00+00:00", True], + ["", "dataset_triggered__2023-01-01T00:00:00+00:00", True], + ["", "scheduled_2023-01-01T00", False], + ["", "manual_2023-01-01T00", False], + ["", "dataset_triggered_2023-01-01T00", False], + ["^[0-9]", "scheduled__2023-01-01T00:00:00+00:00", True], + ["^[0-9]", "manual__2023-01-01T00:00:00+00:00", True], + ["^[a-z]", "dataset_triggered__2023-01-01T00:00:00+00:00", True], + ], +) +def test_dag_run_id_config(session, dag_maker, pattern, run_id, result): + with conf_vars({("scheduler", "allowed_run_id_pattern"): pattern}): + with dag_maker(): + ... + if result: + dag_maker.create_dagrun(run_id=run_id) + else: + with pytest.raises(AirflowException): + dag_maker.create_dagrun(run_id=run_id) diff --git a/tests/www/views/test_views_trigger_dag.py b/tests/www/views/test_views_trigger_dag.py index 1c7cc5827ee49..b6fcd20d3effc 100644 --- a/tests/www/views/test_views_trigger_dag.py +++ b/tests/www/views/test_views_trigger_dag.py @@ -30,6 +30,7 @@ from airflow.utils.session import create_session from airflow.utils.types import DagRunType from tests.test_utils.api_connexion_utils import create_test_client +from tests.test_utils.config import conf_vars from tests.test_utils.www import check_content_in_response @@ -286,3 +287,32 @@ def test_trigger_dag_params_array_value_none_render(admin_client, dag_maker, ses f'', resp, ) + + +@pytest.mark.parametrize( + "pattern, run_id, result", + [ + ["^[A-Z]", "ABC", True], + ["^[A-Z]", "abc", False], + ["^[0-9]", "123", True], + # The below params test that user configuration does not affect internally generated + # run_ids. We use manual__ as a prefix for manually triggered DAGs due to a restriction + # in manually triggered DAGs that the run_id must not start with scheduled__. 
+ ["", "manual__2023-01-01T00:00:00+00:00", True], + ["", "scheduled_2023-01-01T00", False], + ["", "manual_2023-01-01T00", False], + ["", "dataset_triggered_2023-01-01T00", False], + ["^[0-9]", "manual__2023-01-01T00:00:00+00:00", True], + ["^[a-z]", "manual__2023-01-01T00:00:00+00:00", True], + ], +) +def test_dag_run_id_pattern(session, admin_client, pattern, run_id, result): + with conf_vars({("scheduler", "allowed_run_id_pattern"): pattern}): + test_dag_id = "example_bash_operator" + admin_client.post(f"dags/{test_dag_id}/trigger?&run_id={run_id}") + run = session.query(DagRun).filter(DagRun.dag_id == test_dag_id).first() + if result: + assert run is not None + assert run.run_type == DagRunType.MANUAL + else: + assert run is None