Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,12 @@ def state(self):
return synonym('_state', descriptor=property(self.get_state, self.set_state))

@provide_session
def refresh_from_db(self, session=None):
def refresh_from_db(self, session: Session = None):
"""
Reloads the current dagrun from the database

:param session: database session
:type session: Session
"""
DR = DagRun

Expand Down Expand Up @@ -203,6 +204,7 @@ def find(

@staticmethod
def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
"""Generate Run ID based on Run Type and Execution Date"""
return f"{run_type.value}__{execution_date.isoformat()}"

@provide_session
Expand Down Expand Up @@ -237,11 +239,14 @@ def get_task_instances(self, state=None, session=None):
return tis.all()

@provide_session
def get_task_instance(self, task_id, session=None):
def get_task_instance(self, task_id: str, session: Session = None):
"""
Returns the task instance specified by task_id for this dag run

:param task_id: the task id
:type task_id: str
:param session: Sqlalchemy ORM Session
:type session: Session
"""
ti = session.query(TI).filter(
TI.dag_id == self.dag_id,
Expand All @@ -258,8 +263,7 @@ def get_dag(self):
:return: DAG
"""
if not self.dag:
raise AirflowException("The DAG (.dag) for {} needs to be set"
.format(self))
raise AirflowException("The DAG (.dag) for {} needs to be set".format(self))

return self.dag

Expand All @@ -280,7 +284,7 @@ def get_previous_dagrun(self, state: Optional[str] = None, session: Session = No
).first()

@provide_session
def get_previous_scheduled_dagrun(self, session=None):
def get_previous_scheduled_dagrun(self, session: Session = None):
"""The previous, SCHEDULED DagRun, if there is one"""
dag = self.get_dag()

Expand All @@ -290,11 +294,13 @@ def get_previous_scheduled_dagrun(self, session=None):
).first()

@provide_session
def update_state(self, session=None) -> List[TI]:
def update_state(self, session: Session = None) -> List[TI]:
"""
Determines the overall state of the DagRun based on the state
of its TaskInstances.

:param session: Sqlalchemy ORM Session
:type session: Session
:return: ready_tis: the tis that can be scheduled in the current loop
:rtype ready_tis: list[airflow.models.TaskInstance]
"""
Expand Down Expand Up @@ -336,8 +342,7 @@ def update_state(self, session=None) -> List[TI]:
):
self.log.error('Marking run %s failed', self)
self.set_state(State.FAILED)
dag.handle_callback(self, success=False, reason='task_failure',
session=session)
dag.handle_callback(self, success=False, reason='task_failure', session=session)

# if all leafs succeeded and no unfinished tasks, the run succeeded
elif not unfinished_tasks and all(
Expand Down Expand Up @@ -430,10 +435,13 @@ def _emit_duration_stats_for_finished_state(self):
Stats.timing('dagrun.duration.failed.{}'.format(self.dag_id), duration)

@provide_session
def verify_integrity(self, session=None):
def verify_integrity(self, session: Session = None):
"""
Verifies the DagRun by checking for removed tasks or tasks that are not in the
database yet. It will set state to removed or add the task if required.

:param session: Sqlalchemy ORM Session
:type session: Session
"""
dag = self.get_dag()
tis = self.get_task_instances(session=session)
Expand Down Expand Up @@ -487,8 +495,12 @@ def verify_integrity(self, session=None):
session.rollback()

@staticmethod
def get_run(session, dag_id, execution_date):
def get_run(session: Session, dag_id: str, execution_date: datetime):
"""
Get a single DAG Run

:param session: Sqlalchemy ORM Session
:type session: Session
:param dag_id: DAG ID
:type dag_id: unicode
:param execution_date: execution date
Expand Down