Skip to content
Permalink
Browse files
Add support for queued state in DagRun update endpoint. (#23481)
  • Loading branch information
tirkarthi committed May 9, 2022
1 parent 1220c1a commit 4485393562ea4151a42f1be47bea11638b236001
Showing 4 changed files with 20 additions and 7 deletions.
@@ -23,7 +23,11 @@
from sqlalchemy import or_
from sqlalchemy.orm import Query, Session

from airflow.api.common.mark_tasks import set_dag_run_state_to_failed, set_dag_run_state_to_success
from airflow.api.common.mark_tasks import (
set_dag_run_state_to_failed,
set_dag_run_state_to_queued,
set_dag_run_state_to_success,
)
from airflow.api_connexion import security
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters
@@ -308,6 +312,8 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW
dag = current_app.dag_bag.get_dag(dag_id)
if state == DagRunState.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True)
elif state == DagRunState.QUEUED:
set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True)
else:
set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True)
dag_run = session.query(DagRun).get(dag_run.id)
@@ -2442,6 +2442,7 @@ components:
enum:
- success
- failed
- queued

DAGRunCollection:
type: object
@@ -112,7 +112,11 @@ def autofill(self, data, **kwargs):
class SetDagRunStateFormSchema(Schema):
"""Schema for handling the request of setting state of DAG run"""

state = DagStateField(validate=validate.OneOf([DagRunState.SUCCESS.value, DagRunState.FAILED.value]))
state = DagStateField(
validate=validate.OneOf(
[DagRunState.SUCCESS.value, DagRunState.FAILED.value, DagRunState.QUEUED.value]
)
)


class DAGRunCollection(NamedTuple):
@@ -1271,7 +1271,7 @@ def test_should_raises_403_unauthorized(self, username):


class TestPatchDagRunState(TestDagRunEndpoint):
@pytest.mark.parametrize("state", ["failed", "success"])
@pytest.mark.parametrize("state", ["failed", "success", "queued"])
@pytest.mark.parametrize("run_type", [state.value for state in DagRunType])
def test_should_respond_200(self, state, run_type, dag_maker, session):
dag_id = "TEST_DAG_ID"
@@ -1294,8 +1294,10 @@ def test_should_respond_200(self, state, run_type, dag_maker, session):
environ_overrides={"REMOTE_USER": "test"},
)

ti.refresh_from_db()
assert ti.state == state
if state != "queued":
ti.refresh_from_db()
assert ti.state == state

dr = session.query(DagRun).filter(DagRun.run_id == dr.run_id).first()
assert response.status_code == 200
assert response.json == {
@@ -1314,7 +1316,7 @@ def test_should_respond_200(self, state, run_type, dag_maker, session):
'run_type': run_type,
}

@pytest.mark.parametrize('invalid_state', ["running", "queued"])
@pytest.mark.parametrize('invalid_state', ["running"])
@freeze_time(TestDagRunEndpoint.default_time)
def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, dag_maker):
dag_id = "TEST_DAG_ID"
@@ -1332,7 +1334,7 @@ def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state,
)
assert response.status_code == 400
assert response.json == {
'detail': f"'{invalid_state}' is not one of ['success', 'failed'] - 'state'",
'detail': f"'{invalid_state}' is not one of ['success', 'failed', 'queued'] - 'state'",
'status': 400,
'title': 'Bad Request',
'type': EXCEPTIONS_LINK_MAP[400],

0 comments on commit 4485393

Please sign in to comment.