From b7f607ecc7c20ee465c485f49ef20b19a7c82b06 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 14 Mar 2022 11:53:02 -0700 Subject: [PATCH 01/17] Add map_index to TaskFail --- ..._48925b2719cb_add_map_index_to_taskfail.py | 76 +++++++++++++++++++ airflow/models/taskfail.py | 47 ++++++++++-- airflow/models/taskinstance.py | 10 ++- airflow/www/views.py | 10 +-- docs/apache-airflow/migrations-ref.rst | 4 +- tests/core/test_core.py | 12 ++- 6 files changed, 141 insertions(+), 18 deletions(-) create mode 100644 airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py diff --git a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py new file mode 100644 index 0000000000000..ea8f4668b5d28 --- /dev/null +++ b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add map_index to TaskFail + +Drop index idx_task_fail_dag_task_date +Add FK `task_fail_ti_fkey`: TF -> TI ([dag_id, task_id, run_id, map_index]) + + + +Revision ID: 48925b2719cb +Revises: 4eaab2fe6582 +Create Date: 2022-03-14 10:31:11.220720 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '48925b2719cb' +down_revision = '4eaab2fe6582' +branch_labels = None +depends_on = None +airflow_version = '2.3.0' + + +def upgrade(): + """Apply Add map_index to TaskFail""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('task_fail', schema=None) as batch_op: + batch_op.add_column(sa.Column('run_id', sa.String(length=250), nullable=False)) + batch_op.add_column(sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False)) + batch_op.drop_index('idx_task_fail_dag_task_date') + batch_op.create_index('idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'run_id'], unique=False) + batch_op.create_foreign_key( + 'sla_miss_ti_fkey', + 'task_instance', + ['dag_id', 'task_id', 'run_id', 'map_index'], + ['dag_id', 'task_id', 'run_id', 'map_index'], + ondelete='CASCADE', + ) + batch_op.drop_column('execution_date') + + # ### end Alembic commands ### + + +def downgrade(): + """Unapply Add map_index to TaskFail""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('task_fail', schema=None) as batch_op: + batch_op.add_column(sa.Column('execution_date', sa.DATETIME(), nullable=False)) + batch_op.drop_constraint('sla_miss_ti_fkey', type_='foreignkey') + batch_op.drop_index('idx_task_fail_dag_task_date') + batch_op.create_index( + 'idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False + ) + batch_op.drop_column('map_index') + batch_op.drop_column('run_id') + + # ### end Alembic commands ### diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index 23d3086e2d013..34263c76291b9 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -16,11 +16,17 @@ # specific language governing permissions and limitations # under the License. """Taskfail tracks the failed run durations of each task instance""" -from sqlalchemy import Column, Index, Integer, String +from typing import TYPE_CHECKING -from airflow.models.base import COLLATION_ARGS, ID_LEN, Base +from sqlalchemy import Column, ForeignKeyConstraint, Integer +from sqlalchemy.orm import relationship + +from airflow.models.base import Base, StringID from airflow.utils.sqlalchemy import UtcDateTime +if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance + class TaskFail(Base): """TaskFail tracks the failed run durations of each task instance.""" @@ -28,22 +34,47 @@ class TaskFail(Base): __tablename__ = "task_fail" id = Column(Integer, primary_key=True) - task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - execution_date = Column(UtcDateTime, nullable=False) + task_id = Column(StringID(), primary_key=True) + dag_id = Column(StringID(), primary_key=True) + run_id = Column(StringID(), primary_key=True) + map_index = Column(Integer, primary_key=True, server_default='-1') start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) duration = Column(Integer) - __table_args__ = (Index('idx_task_fail_dag_task_date', dag_id, task_id, execution_date, unique=False),) + __table_args__ = ( + ForeignKeyConstraint( + [dag_id, task_id, run_id, map_index], + [ + "task_instance.dag_id", + "task_instance.task_id", + "task_instance.run_id", + "task_instance.map_index", + ], + name='task_fail_ti_fkey', + ondelete="CASCADE", + ), + ) - def __init__(self, task, execution_date, start_date, end_date): + task_instance: "TaskInstance" = relationship( + "TaskInstance", + lazy='joined', + ) + + def __init__(self, task, run_id, start_date, end_date, map_index=None): self.dag_id = task.dag_id self.task_id = task.task_id - self.execution_date = execution_date + self.run_id = run_id + self.map_index = map_index self.start_date = start_date self.end_date = end_date if self.end_date and self.start_date: self.duration = int((self.end_date - self.start_date).total_seconds()) else: self.duration = None + + def __repr__(self): + prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}" + if self.map_index != -1: + prefix += f" map_index={self.map_index}" + return prefix + '>' diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index f0fd779f5f49c..a2a7b3e12eaf9 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1800,7 +1800,15 @@ def handle_failure( # Log failure duration dag_run = self.get_dagrun(session=session) # self.dag_run not populated by refresh_from_db - session.add(TaskFail(task, dag_run.execution_date, self.start_date, self.end_date)) + session.add( + TaskFail( + task=task, + run_id=dag_run.run_id, + start_date=self.start_date, + end_date=self.end_date, + map_index=self.map_index, + ) + ) self.clear_next_method_args() diff --git a/airflow/www/views.py b/airflow/www/views.py index 1d74996bf48a6..8c21efa22b51d 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -2847,8 +2847,8 @@ def duration(self, dag_id, session=None): session.query(TaskFail) .filter( TaskFail.dag_id == dag.dag_id, - TaskFail.execution_date >= min_date, - TaskFail.execution_date <= base_date, + TaskFail.task_instance.execution_date >= min_date, + TaskFail.task_instance.execution_date <= base_date, TaskFail.task_id.in_([t.task_id for t in dag.tasks]), ) .all() @@ -3185,10 +3185,8 @@ def gantt(self, dag_id, session=None): .order_by(TaskInstance.start_date) ) - ti_fails = ( - session.query(TaskFail) - .join(DagRun, DagRun.execution_date == TaskFail.execution_date) - .filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id) + ti_fails = session.query(TaskFail).filter( + TaskFail.task_instance.execution_date == dttm, TaskFail.dag_id == dag_id ) tasks = [] diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 9dfd7dd838b34..de8221620c6a8 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -25,7 +25,9 @@ Here's the list of all the Database Migrations that are executed via when you ru .. Beginning of auto-generated table +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ -| ``4eaab2fe6582`` (head) | ``c97c2ab6aa23`` | ``2.3.0`` | Migrate RTIF to use run_id and map_index | +| ``48925b2719cb`` (head) | ``4eaab2fe6582`` | ``2.3.0`` | Add map_index to TaskFail | ++---------------------------------+-------------------+-------------+--------------------------------------------------------------+ +| ``4eaab2fe6582`` | ``c97c2ab6aa23`` | ``2.3.0`` | Migrate RTIF to use run_id and map_index | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ | ``c97c2ab6aa23`` | ``c306b5b5ae4a`` | ``2.3.0`` | add callback request table | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ diff --git a/tests/core/test_core.py b/tests/core/test_core.py index 91d4e134d6e7a..510a5ea87352b 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -99,12 +99,20 @@ def test_task_fail_duration(self, dag_maker): pass op1_fails = ( session.query(TaskFail) - .filter_by(task_id='pass_sleepy', dag_id=dag.dag_id, execution_date=DEFAULT_DATE) + .filter( + TaskFail.task_id == 'pass_sleepy', + TaskFail.dag_id == dag.dag_id, + TaskFail.task_instance.execution_date == DEFAULT_DATE, + ) .all() ) op2_fails = ( session.query(TaskFail) - .filter_by(task_id='fail_sleepy', dag_id=dag.dag_id, execution_date=DEFAULT_DATE) + .filter( + TaskFail.task_id == 'pass_sleepy', + TaskFail.dag_id == dag.dag_id, + TaskFail.task_instance.execution_date == DEFAULT_DATE, + ) .all() ) From 320dc97556c64a3268085299ef64000fcb889084 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 14 Mar 2022 13:28:59 -0700 Subject: [PATCH 02/17] update migration --- ..._48925b2719cb_add_map_index_to_taskfail.py | 137 +++++++++++++++--- 1 file changed, 114 insertions(+), 23 deletions(-) diff --git a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py index ea8f4668b5d28..279f31b364ba6 100644 --- a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py +++ b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py @@ -26,11 +26,16 @@ Revision ID: 48925b2719cb Revises: 4eaab2fe6582 Create Date: 2022-03-14 10:31:11.220720 - """ +from typing import List + import sqlalchemy as sa from alembic import op +from sqlalchemy.sql import ColumnElement, Update, and_, select + +from airflow.migrations.db_types import TIMESTAMP, StringID +from airflow.migrations.utils import get_mssql_table_constraints # revision identifiers, used by Alembic. revision = '48925b2719cb' @@ -39,38 +44,124 @@ depends_on = None airflow_version = '2.3.0' +ID_LEN = 250 + + +def tables(): + global task_instance, task_fail, dag_run + metadata = sa.MetaData() + task_instance = sa.Table( + 'task_instance', + metadata, + sa.Column('task_id', StringID()), + sa.Column('dag_id', StringID()), + sa.Column('run_id', StringID()), + sa.Column('map_index', sa.Integer(), server_default='-1'), + sa.Column('execution_date', TIMESTAMP), + ) + task_fail = sa.Table( + 'task_fail', + metadata, + sa.Column('dag_id', StringID()), + sa.Column('task_id', StringID()), + sa.Column('run_id', StringID()), + sa.Column('map_index', StringID()), + sa.Column('execution_date', TIMESTAMP), + ) + dag_run = sa.Table( + 'dag_run', + metadata, + sa.Column('dag_id', StringID()), + sa.Column('run_id', StringID()), + sa.Column('execution_date', TIMESTAMP), + ) + + +def _update_value_from_dag_run( + dialect_name: str, + target_table: sa.Table, + target_column: ColumnElement, + join_columns: List[str], +) -> Update: + """ + Grabs a value from the source table ``dag_run`` and updates target with this value. + :param dialect_name: dialect in use + :param target_table: the table to update + :param target_column: the column to update + """ + # for run_id: dag_id, execution_date + # otherwise: dag_id, run_id + condition_list = [getattr(dag_run.c, x) == getattr(target_table.c, x) for x in join_columns] + condition = and_(*condition_list) + + if dialect_name == "sqlite": + # Most SQLite versions don't support multi table update (and SQLA doesn't know about it anyway), so we + # need to do a Correlated subquery update + sub_q = select([dag_run.c[target_column.name]]).where(condition) + + return target_table.update().values({target_column: sub_q}) + else: + return target_table.update().where(condition).values({target_column: dag_run.c[target_column.name]}) + def upgrade(): - """Apply Add map_index to TaskFail""" - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('task_fail', schema=None) as batch_op: - batch_op.add_column(sa.Column('run_id', sa.String(length=250), nullable=False)) + tables() + dialect_name = op.get_bind().dialect.name + + op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail') + + with op.batch_alter_table('task_fail') as batch_op: batch_op.add_column(sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False)) - batch_op.drop_index('idx_task_fail_dag_task_date') - batch_op.create_index('idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'run_id'], unique=False) + batch_op.add_column(sa.Column('run_id', type_=StringID(), nullable=True)) + + update_query = _update_value_from_dag_run( + dialect_name=dialect_name, + target_table=task_fail, + target_column=task_fail.c.run_id, + join_columns=['dag_id', 'execution_date'], + ) + op.execute(update_query) + with op.batch_alter_table('task_fail') as batch_op: + if dialect_name == 'mssql': + constraints = get_mssql_table_constraints(op.get_bind(), 'task_fail') + pk, _ = constraints['PRIMARY KEY'].popitem() + batch_op.drop_constraint(pk, type_='primary') + elif dialect_name != 'sqlite': # sqlite PK is managed by SQLA + batch_op.drop_constraint('task_fail_pkey', type_='primary') + batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False) + batch_op.drop_column('execution_date') + batch_op.create_primary_key('task_fail_pkey', ['dag_id', 'task_id', 'run_id', 'map_index']) batch_op.create_foreign_key( - 'sla_miss_ti_fkey', + 'task_fail_ti_fkey', 'task_instance', ['dag_id', 'task_id', 'run_id', 'map_index'], ['dag_id', 'task_id', 'run_id', 'map_index'], ondelete='CASCADE', ) - batch_op.drop_column('execution_date') - - # ### end Alembic commands ### def downgrade(): - """Unapply Add map_index to TaskFail""" - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('task_fail', schema=None) as batch_op: - batch_op.add_column(sa.Column('execution_date', sa.DATETIME(), nullable=False)) - batch_op.drop_constraint('sla_miss_ti_fkey', type_='foreignkey') - batch_op.drop_index('idx_task_fail_dag_task_date') - batch_op.create_index( - 'idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False - ) - batch_op.drop_column('map_index') + tables() + dialect_name = op.get_bind().dialect.name + op.add_column('task_fail', sa.Column('execution_date', TIMESTAMP, nullable=True)) + update_query = _update_value_from_dag_run( + dialect_name=dialect_name, + target_table=task_fail, + target_column=task_fail.c.execution_date, + join_columns=['dag_id', 'run_id'], + ) + op.execute(update_query) + with op.batch_alter_table('task_fail', copy_from=task_fail) as batch_op: + batch_op.alter_column('execution_date', existing_type=TIMESTAMP, nullable=False) + if dialect_name != 'sqlite': + batch_op.drop_constraint('task_fail_ti_fkey', type_='foreignkey') + batch_op.drop_constraint('task_fail_pkey', type_='primary') + batch_op.create_primary_key('task_fail_pkey', ['dag_id', 'task_id', 'execution_date']) + batch_op.drop_column('map_index', mssql_drop_default=True) batch_op.drop_column('run_id') - - # ### end Alembic commands ### + op.create_index( + index_name='idx_task_fail_dag_task_date', + table_name='task_fail', + columns=['dag_id', 'task_id', 'execution_date'], + unique=False, + ) From 07f8cb508f3b9029a0acad6a3dd7ccd4ee8b4b39 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 14 Mar 2022 14:34:48 -0700 Subject: [PATCH 03/17] Update airflow/models/taskinstance.py Co-authored-by: Ash Berlin-Taylor --- airflow/models/taskinstance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index a2a7b3e12eaf9..11051e8d8559c 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1803,7 +1803,7 @@ def handle_failure( session.add( TaskFail( task=task, - run_id=dag_run.run_id, + run_id=self.run_id, start_date=self.start_date, end_date=self.end_date, map_index=self.map_index, From 62bcb3f9235766c18ebe341293b1b7eefb119681 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 14 Mar 2022 22:18:49 -0700 Subject: [PATCH 04/17] fix mypy --- airflow/models/taskinstance.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 11051e8d8559c..0f9b08a5cc60e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1799,7 +1799,6 @@ def handle_failure( session.add(Log(State.FAILED, self)) # Log failure duration - dag_run = self.get_dagrun(session=session) # self.dag_run not populated by refresh_from_db session.add( TaskFail( task=task, From 2ac142123b85fae0e65b223fc31d5b456e0f2726 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 14 Mar 2022 22:45:27 -0700 Subject: [PATCH 05/17] fix pk --- .../0105_48925b2719cb_add_map_index_to_taskfail.py | 2 +- airflow/models/taskfail.py | 8 ++++---- tests/api/common/test_delete_dag.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py index 279f31b364ba6..db91b6afd9bc2 100644 --- a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py +++ b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py @@ -130,7 +130,7 @@ def upgrade(): batch_op.drop_constraint('task_fail_pkey', type_='primary') batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False) batch_op.drop_column('execution_date') - batch_op.create_primary_key('task_fail_pkey', ['dag_id', 'task_id', 'run_id', 'map_index']) + batch_op.create_primary_key('task_fail_pkey', ['id']) batch_op.create_foreign_key( 'task_fail_ti_fkey', 'task_instance', diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index 34263c76291b9..c6968c22102a1 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -34,10 +34,10 @@ class TaskFail(Base): __tablename__ = "task_fail" id = Column(Integer, primary_key=True) - task_id = Column(StringID(), primary_key=True) - dag_id = Column(StringID(), primary_key=True) - run_id = Column(StringID(), primary_key=True) - map_index = Column(Integer, primary_key=True, server_default='-1') + task_id = Column(StringID()) + dag_id = Column(StringID()) + run_id = Column(StringID()) + map_index = Column(Integer, server_default='-1') start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) duration = Column(Integer) diff --git a/tests/api/common/test_delete_dag.py b/tests/api/common/test_delete_dag.py index d9dc0b0a01c7f..9b0cf94dd2c6d 100644 --- a/tests/api/common/test_delete_dag.py +++ b/tests/api/common/test_delete_dag.py @@ -96,7 +96,7 @@ def setup_dag_models(self, for_sub_dag=False): event="varimport", ) ) - session.add(TF(task=task, execution_date=test_date, start_date=test_date, end_date=test_date)) + session.add(TF(task=task, run_id=ti.run_id, start_date=test_date, end_date=test_date)) session.add( TR( task=ti.task, From 66137eba357eb55d6287cfd14550d111ae89b04a Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 14 Mar 2022 23:54:31 -0700 Subject: [PATCH 06/17] fix some tests --- airflow/models/taskfail.py | 8 +------- airflow/www/views.py | 6 +++--- tests/core/test_core.py | 17 ++++++----------- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index c6968c22102a1..dfdf000a9f9b8 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -19,13 +19,12 @@ from typing import TYPE_CHECKING from sqlalchemy import Column, ForeignKeyConstraint, Integer -from sqlalchemy.orm import relationship from airflow.models.base import Base, StringID from airflow.utils.sqlalchemy import UtcDateTime if TYPE_CHECKING: - from airflow.models.taskinstance import TaskInstance + pass class TaskFail(Base): @@ -56,11 +55,6 @@ class TaskFail(Base): ), ) - task_instance: "TaskInstance" = relationship( - "TaskInstance", - lazy='joined', - ) - def __init__(self, task, run_id, start_date, end_date, map_index=None): self.dag_id = task.dag_id self.task_id = task.task_id diff --git a/airflow/www/views.py b/airflow/www/views.py index 8c21efa22b51d..bba8155a56c07 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -2847,8 +2847,8 @@ def duration(self, dag_id, session=None): session.query(TaskFail) .filter( TaskFail.dag_id == dag.dag_id, - TaskFail.task_instance.execution_date >= min_date, - TaskFail.task_instance.execution_date <= base_date, + TaskInstance.execution_date >= min_date, + TaskInstance.execution_date <= base_date, TaskFail.task_id.in_([t.task_id for t in dag.tasks]), ) .all() @@ -3186,7 +3186,7 @@ def gantt(self, dag_id, session=None): ) ti_fails = session.query(TaskFail).filter( - TaskFail.task_instance.execution_date == dttm, TaskFail.dag_id == dag_id + TaskInstance.execution_date == dttm, TaskFail.dag_id == dag_id ) tasks = [] diff --git a/tests/core/test_core.py b/tests/core/test_core.py index 510a5ea87352b..5ed786e9cd819 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +from contextlib import suppress from datetime import timedelta from time import sleep @@ -89,29 +89,24 @@ def test_task_fail_duration(self, dag_maker): ) dag_maker.create_dagrun() session = settings.Session() - try: - op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - except Exception: - pass - try: + op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + with suppress(AirflowTaskTimeout): op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - except Exception: - pass op1_fails = ( session.query(TaskFail) .filter( TaskFail.task_id == 'pass_sleepy', TaskFail.dag_id == dag.dag_id, - TaskFail.task_instance.execution_date == DEFAULT_DATE, + TaskInstance.execution_date == DEFAULT_DATE, ) .all() ) op2_fails = ( session.query(TaskFail) .filter( - TaskFail.task_id == 'pass_sleepy', + TaskFail.task_id == 'fail_sleepy', TaskFail.dag_id == dag.dag_id, - TaskFail.task_instance.execution_date == DEFAULT_DATE, + TaskInstance.execution_date == DEFAULT_DATE, ) .all() ) From 93f8e12cf0dbc0e77732c2e42c64b71117d2126c Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 15 Mar 2022 14:52:25 -0700 Subject: [PATCH 07/17] remove unused --- airflow/models/taskfail.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index dfdf000a9f9b8..4009416593df0 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -16,16 +16,12 @@ # specific language governing permissions and limitations # under the License. """Taskfail tracks the failed run durations of each task instance""" -from typing import TYPE_CHECKING from sqlalchemy import Column, ForeignKeyConstraint, Integer from airflow.models.base import Base, StringID from airflow.utils.sqlalchemy import UtcDateTime -if TYPE_CHECKING: - pass - class TaskFail(Base): """TaskFail tracks the failed run durations of each task instance.""" From 78c5e20ed9b7c9e4cd43ae3c95286cab07babb39 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 15 Mar 2022 15:08:09 -0700 Subject: [PATCH 08/17] use DagRun --- airflow/www/views.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index bba8155a56c07..96f2df5c3133f 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -2847,8 +2847,8 @@ def duration(self, dag_id, session=None): session.query(TaskFail) .filter( TaskFail.dag_id == dag.dag_id, - TaskInstance.execution_date >= min_date, - TaskInstance.execution_date <= base_date, + DagRun.execution_date >= min_date, + DagRun.execution_date <= base_date, TaskFail.task_id.in_([t.task_id for t in dag.tasks]), ) .all() @@ -3185,9 +3185,7 @@ def gantt(self, dag_id, session=None): .order_by(TaskInstance.start_date) ) - ti_fails = session.query(TaskFail).filter( - TaskInstance.execution_date == dttm, TaskFail.dag_id == dag_id - ) + ti_fails = session.query(TaskFail).filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id) tasks = [] for ti in tis: From fdf718006a19c550d52b8c71d27c71bed14ee8f2 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 15 Mar 2022 15:08:47 -0700 Subject: [PATCH 09/17] fix nullable --- airflow/models/taskfail.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index 4009416593df0..3eb1c1148d8dd 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -29,10 +29,10 @@ class TaskFail(Base): __tablename__ = "task_fail" id = Column(Integer, primary_key=True) - task_id = Column(StringID()) - dag_id = Column(StringID()) - run_id = Column(StringID()) - map_index = Column(Integer, server_default='-1') + task_id = Column(StringID(), nullable=False) + dag_id = Column(StringID(), nullable=False) + run_id = Column(StringID(), nullable=False) + map_index = Column(Integer, server_default='-1', nullable=False) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) duration = Column(Integer) From 141164e6e7a8fd1a96c4f35e8e793258261ad875 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 18 Mar 2022 10:24:08 -0700 Subject: [PATCH 10/17] fix mysql migration --- .../0105_48925b2719cb_add_map_index_to_taskfail.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py index db91b6afd9bc2..3550bd4575120 100644 --- a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py +++ b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py @@ -19,8 +19,11 @@ """Add map_index to TaskFail Drop index idx_task_fail_dag_task_date +Add run_id and map_index +Drop execution_date Add FK `task_fail_ti_fkey`: TF -> TI ([dag_id, task_id, run_id, map_index]) - +Change primary key from [id, dag_id, task_id, execution_date] to [id] + * since we are changing the PK have to handle mysql autoincrement column `id` separately Revision ID: 48925b2719cb @@ -32,6 +35,7 @@ import sqlalchemy as sa from alembic import op +from sqlalchemy import Integer from sqlalchemy.sql import ColumnElement, Update, and_, select from airflow.migrations.db_types import TIMESTAMP, StringID @@ -126,11 +130,17 @@ def upgrade(): constraints = get_mssql_table_constraints(op.get_bind(), 'task_fail') pk, _ = constraints['PRIMARY KEY'].popitem() batch_op.drop_constraint(pk, type_='primary') + elif dialect_name == 'mysql': # have to handle mysql autoincrement column separately + batch_op.alter_column('id', type_=Integer, autoincrement=False, nullable=False) + batch_op.drop_constraint('task_fail_pkey', type_='primary') elif dialect_name != 'sqlite': # sqlite PK is managed by SQLA batch_op.drop_constraint('task_fail_pkey', type_='primary') + batch_op.alter_column('id', existing_type=StringID(), existing_nullable=True, nullable=False) batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False) batch_op.drop_column('execution_date') batch_op.create_primary_key('task_fail_pkey', ['id']) + if dialect_name == 'mysql': # have to handle mysql autoincrement column separately + batch_op.alter_column('id', type_=Integer, autoincrement=True, nullable=False) batch_op.create_foreign_key( 'task_fail_ti_fkey', 'task_instance', From 09e4c073661b340d235df0a496801144745cb297 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 18 Mar 2022 13:27:22 -0700 Subject: [PATCH 11/17] don't mess with id col --- ..._48925b2719cb_add_map_index_to_taskfail.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py index 3550bd4575120..303e1eda58241 100644 --- a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py +++ b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py @@ -22,8 +22,6 @@ Add run_id and map_index Drop execution_date Add FK `task_fail_ti_fkey`: TF -> TI ([dag_id, task_id, run_id, map_index]) -Change primary key from [id, dag_id, task_id, execution_date] to [id] - * since we are changing the PK have to handle mysql autoincrement column `id` separately Revision ID: 48925b2719cb @@ -35,11 +33,9 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy import Integer from sqlalchemy.sql import ColumnElement, Update, and_, select from airflow.migrations.db_types import TIMESTAMP, StringID -from airflow.migrations.utils import get_mssql_table_constraints # revision identifiers, used by Alembic. revision = '48925b2719cb' @@ -126,21 +122,8 @@ def upgrade(): ) op.execute(update_query) with op.batch_alter_table('task_fail') as batch_op: - if dialect_name == 'mssql': - constraints = get_mssql_table_constraints(op.get_bind(), 'task_fail') - pk, _ = constraints['PRIMARY KEY'].popitem() - batch_op.drop_constraint(pk, type_='primary') - elif dialect_name == 'mysql': # have to handle mysql autoincrement column separately - batch_op.alter_column('id', type_=Integer, autoincrement=False, nullable=False) - batch_op.drop_constraint('task_fail_pkey', type_='primary') - elif dialect_name != 'sqlite': # sqlite PK is managed by SQLA - batch_op.drop_constraint('task_fail_pkey', type_='primary') - batch_op.alter_column('id', existing_type=StringID(), existing_nullable=True, nullable=False) batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False) batch_op.drop_column('execution_date') - batch_op.create_primary_key('task_fail_pkey', ['id']) - if dialect_name == 'mysql': # have to handle mysql autoincrement column separately - batch_op.alter_column('id', type_=Integer, autoincrement=True, nullable=False) batch_op.create_foreign_key( 'task_fail_ti_fkey', 'task_instance', @@ -165,8 +148,6 @@ def downgrade(): batch_op.alter_column('execution_date', existing_type=TIMESTAMP, nullable=False) if dialect_name != 'sqlite': batch_op.drop_constraint('task_fail_ti_fkey', type_='foreignkey') - batch_op.drop_constraint('task_fail_pkey', type_='primary') - batch_op.create_primary_key('task_fail_pkey', ['dag_id', 'task_id', 'execution_date']) batch_op.drop_column('map_index', mssql_drop_default=True) batch_op.drop_column('run_id') op.create_index( From a64901f54dc512072062befc89f6a9671a5e1a77 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 18 Mar 2022 17:54:59 -0700 Subject: [PATCH 12/17] add TaskFail to run_id check and some comments --- airflow/utils/db.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index acaf953bd7efe..67ef672eaa03e 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -963,13 +963,20 @@ def _from_name(from_) -> str: def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]: + """ + Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id` + in many tables. + Here we go through each table and look for records that can't be mapped to a dag run. + When we find such "dangling" rows we back them up in a special table and delete them + from the main table. + """ import sqlalchemy.schema from sqlalchemy import and_, outerjoin from airflow.models.renderedtifields import RenderedTaskInstanceFields metadata = sqlalchemy.schema.MetaData(session.bind) - models_to_dagrun: List[Any] = [TaskInstance, TaskReschedule, XCom, RenderedTaskInstanceFields] + models_to_dagrun: List[Any] = [RenderedTaskInstanceFields, TaskInstance, TaskFail, TaskReschedule, XCom] for model in models_to_dagrun + [DagRun]: try: metadata.reflect( @@ -1002,6 +1009,7 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str if "run_id" in source_table.columns: continue + # find rows in source table which don't have a matching dag run source_to_dag_run_join_cond = and_( source_table.c.dag_id == dagrun_table.c.dag_id, source_table.c.execution_date == dagrun_table.c.execution_date, From 95a896384d767865fc8d1c5bed3558ac52ed28f7 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 18 Mar 2022 21:17:22 -0700 Subject: [PATCH 13/17] add logic to check for dupes before creating fk --- airflow/utils/db.py | 64 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 67ef672eaa03e..9cfaa49900c2b 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple from bcrypt import warnings -from sqlalchemy import Table, exc, func, inspect, or_, text +from sqlalchemy import Table, column, exc, func, inspect, literal, or_, table, text from sqlalchemy.orm.session import Session from airflow import settings @@ -827,6 +827,65 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]: ) +def reflect_tables(models, session): + import sqlalchemy.schema + + metadata = sqlalchemy.schema.MetaData(session.bind) + + for model in models: + try: + metadata.reflect(only=[model.__tablename__], extend_existing=True, resolve_fks=False) + except exc.InvalidRequestError: + continue + return metadata + + +def check_task_fail_for_duplicates(session): + """Check that there are no duplicates in the task_fail table before creating FK""" + metadata = reflect_tables([TaskFail], session) + task_fail = metadata.tables.get(TaskFail.__tablename__) # type: ignore + if task_fail is None: # table not there + return + if "run_id" in task_fail.columns: # upgrade already applied + return + yield from check_table_for_duplicates( + table_name=task_fail.name, + uniqueness=['dag_id', 'task_id', 'execution_date'], + session=session, + ) + + +def check_table_for_duplicates(table_name: str, uniqueness: List[str], session: Session) -> Iterable[str]: + """ + Check table for duplicates, given a list of columns which define the uniqueness of the table. + + Call from ``run_duplicates_checks``. + + :param table_name: table name to check + :param uniqueness: uniqueness constraint to evaluate against + :param session: session of the sqlalchemy + :rtype: str + """ + table_obj = table(table_name, *[column(x) for x in uniqueness]) + dupe_count = 0 + try: + subquery = ( + session.query(table_obj, func.count().label('dupe_count')) + .group_by(*[text(x) for x in uniqueness]) + .having(func.count() > literal(1)) + .subquery() + ) + dupe_count = session.query(func.sum(subquery.c.dupe_count)).scalar() + except (exc.OperationalError, exc.ProgrammingError): + # fallback if tables hasn't been created yet + session.rollback() + if dupe_count: + yield ( + f"Found {dupe_count} duplicate records in table {table_name}. You must de-dupe these " + f"records before upgrading. The uniqueness constraint for this table is {uniqueness!r}" + ) + + def check_conn_type_null(session: Session) -> Iterable[str]: """ Check nullable conn_type column in Connection table @@ -1053,13 +1112,14 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]: :rtype: list[str] """ check_functions: Tuple[Callable[..., Iterable[str]], ...] = ( + check_task_fail_for_duplicates, check_conn_id_duplicates, check_conn_type_null, check_run_id_null, check_task_tables_without_matching_dagruns, ) for check_fn in check_functions: - yield from check_fn(session) + yield from check_fn(session=session) # Ensure there is no "active" transaction. Seems odd, but without this MSSQL can hang session.commit() From d9ef5580c58d631673cfcc0fe77c0d1769db6179 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 18 Mar 2022 23:08:30 -0700 Subject: [PATCH 14/17] add docstring --- airflow/utils/db.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 9cfaa49900c2b..f9c7906f44893 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -828,6 +828,13 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]: def reflect_tables(models, session): + """ + When running checks prior to upgrades, we use reflection to determine current state of the + database. + This function gets the current state of each table in the set of models provided and returns + a SqlAlchemy metadata object containing them. + """ + import sqlalchemy.schema metadata = sqlalchemy.schema.MetaData(session.bind) From 3541e2a73a64cf6bc77fe16f6b6e2208fef9f9cc Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Sat, 19 Mar 2022 08:37:29 -0700 Subject: [PATCH 15/17] DagRun.exec-dat --- tests/core/test_core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_core.py b/tests/core/test_core.py index 5ed786e9cd819..3d6aabb9c7f7a 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -23,7 +23,7 @@ from airflow import settings from airflow.exceptions import AirflowTaskTimeout -from airflow.models import TaskFail, TaskInstance +from airflow.models import DagRun, TaskFail, TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator @@ -97,7 +97,7 @@ def test_task_fail_duration(self, dag_maker): .filter( TaskFail.task_id == 'pass_sleepy', TaskFail.dag_id == dag.dag_id, - TaskInstance.execution_date == DEFAULT_DATE, + DagRun.execution_date == DEFAULT_DATE, ) .all() ) @@ -106,7 +106,7 @@ def test_task_fail_duration(self, dag_maker): .filter( TaskFail.task_id == 'fail_sleepy', TaskFail.dag_id == dag.dag_id, - TaskInstance.execution_date == DEFAULT_DATE, + DagRun.execution_date == DEFAULT_DATE, ) .all() ) From ebb48e2a6cae21b5e4be432f49e1c4b953190189 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Sat, 19 Mar 2022 08:37:51 -0700 Subject: [PATCH 16/17] remove default for map_index --- airflow/models/taskfail.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index 3eb1c1148d8dd..4266179a3496e 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -32,7 +32,7 @@ class TaskFail(Base): task_id = Column(StringID(), nullable=False) dag_id = Column(StringID(), nullable=False) run_id = Column(StringID(), nullable=False) - map_index = Column(Integer, server_default='-1', nullable=False) + map_index = Column(Integer, nullable=False) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) duration = Column(Integer) @@ -51,7 +51,7 @@ class TaskFail(Base): ), ) - def __init__(self, task, run_id, start_date, end_date, map_index=None): + def __init__(self, task, run_id, start_date, end_date, map_index): self.dag_id = task.dag_id self.task_id = task.task_id self.run_id = run_id From 0cd6301c59885e98705ef179ad429034d1f8ad2a Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Sun, 20 Mar 2022 20:57:51 -0700 Subject: [PATCH 17/17] fix missing map index --- tests/api/common/test_delete_dag.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/api/common/test_delete_dag.py b/tests/api/common/test_delete_dag.py index 9b0cf94dd2c6d..5765cbd04565c 100644 --- a/tests/api/common/test_delete_dag.py +++ b/tests/api/common/test_delete_dag.py @@ -96,7 +96,15 @@ def setup_dag_models(self, for_sub_dag=False): event="varimport", ) ) - session.add(TF(task=task, run_id=ti.run_id, start_date=test_date, end_date=test_date)) + session.add( + TF( + task=task, + run_id=ti.run_id, + start_date=test_date, + end_date=test_date, + map_index=ti.map_index, + ) + ) session.add( TR( task=ti.task,