diff --git a/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py new file mode 100644 index 0000000000000..303e1eda58241 --- /dev/null +++ b/airflow/migrations/versions/0105_48925b2719cb_add_map_index_to_taskfail.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add map_index to TaskFail + +Drop index idx_task_fail_dag_task_date +Add run_id and map_index +Drop execution_date +Add FK `task_fail_ti_fkey`: TF -> TI ([dag_id, task_id, run_id, map_index]) + + +Revision ID: 48925b2719cb +Revises: 4eaab2fe6582 +Create Date: 2022-03-14 10:31:11.220720 +""" + +from typing import List + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.sql import ColumnElement, Update, and_, select + +from airflow.migrations.db_types import TIMESTAMP, StringID + +# revision identifiers, used by Alembic. 
+revision = '48925b2719cb'
+down_revision = '4eaab2fe6582'
+branch_labels = None
+depends_on = None
+airflow_version = '2.3.0'
+
+ID_LEN = 250
+
+
+def tables():
+    global task_instance, task_fail, dag_run
+    metadata = sa.MetaData()
+    task_instance = sa.Table(
+        'task_instance',
+        metadata,
+        sa.Column('task_id', StringID()),
+        sa.Column('dag_id', StringID()),
+        sa.Column('run_id', StringID()),
+        sa.Column('map_index', sa.Integer(), server_default='-1'),
+        sa.Column('execution_date', TIMESTAMP),
+    )
+    task_fail = sa.Table(
+        'task_fail',
+        metadata,
+        sa.Column('dag_id', StringID()),
+        sa.Column('task_id', StringID()),
+        sa.Column('run_id', StringID()),
+        sa.Column('map_index', sa.Integer(), server_default='-1'),
+        sa.Column('execution_date', TIMESTAMP),
+    )
+    dag_run = sa.Table(
+        'dag_run',
+        metadata,
+        sa.Column('dag_id', StringID()),
+        sa.Column('run_id', StringID()),
+        sa.Column('execution_date', TIMESTAMP),
+    )
+
+
+def _update_value_from_dag_run(
+    dialect_name: str,
+    target_table: sa.Table,
+    target_column: ColumnElement,
+    join_columns: List[str],
+) -> Update:
+    """
+    Grabs a value from the source table ``dag_run`` and updates target with this value.
+ :param dialect_name: dialect in use + :param target_table: the table to update + :param target_column: the column to update + """ + # for run_id: dag_id, execution_date + # otherwise: dag_id, run_id + condition_list = [getattr(dag_run.c, x) == getattr(target_table.c, x) for x in join_columns] + condition = and_(*condition_list) + + if dialect_name == "sqlite": + # Most SQLite versions don't support multi table update (and SQLA doesn't know about it anyway), so we + # need to do a Correlated subquery update + sub_q = select([dag_run.c[target_column.name]]).where(condition) + + return target_table.update().values({target_column: sub_q}) + else: + return target_table.update().where(condition).values({target_column: dag_run.c[target_column.name]}) + + +def upgrade(): + tables() + dialect_name = op.get_bind().dialect.name + + op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail') + + with op.batch_alter_table('task_fail') as batch_op: + batch_op.add_column(sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False)) + batch_op.add_column(sa.Column('run_id', type_=StringID(), nullable=True)) + + update_query = _update_value_from_dag_run( + dialect_name=dialect_name, + target_table=task_fail, + target_column=task_fail.c.run_id, + join_columns=['dag_id', 'execution_date'], + ) + op.execute(update_query) + with op.batch_alter_table('task_fail') as batch_op: + batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False) + batch_op.drop_column('execution_date') + batch_op.create_foreign_key( + 'task_fail_ti_fkey', + 'task_instance', + ['dag_id', 'task_id', 'run_id', 'map_index'], + ['dag_id', 'task_id', 'run_id', 'map_index'], + ondelete='CASCADE', + ) + + +def downgrade(): + tables() + dialect_name = op.get_bind().dialect.name + op.add_column('task_fail', sa.Column('execution_date', TIMESTAMP, nullable=True)) + update_query = _update_value_from_dag_run( + dialect_name=dialect_name, + target_table=task_fail, + 
target_column=task_fail.c.execution_date, + join_columns=['dag_id', 'run_id'], + ) + op.execute(update_query) + with op.batch_alter_table('task_fail', copy_from=task_fail) as batch_op: + batch_op.alter_column('execution_date', existing_type=TIMESTAMP, nullable=False) + if dialect_name != 'sqlite': + batch_op.drop_constraint('task_fail_ti_fkey', type_='foreignkey') + batch_op.drop_column('map_index', mssql_drop_default=True) + batch_op.drop_column('run_id') + op.create_index( + index_name='idx_task_fail_dag_task_date', + table_name='task_fail', + columns=['dag_id', 'task_id', 'execution_date'], + unique=False, + ) diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index 23d3086e2d013..4266179a3496e 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -16,9 +16,10 @@ # specific language governing permissions and limitations # under the License. """Taskfail tracks the failed run durations of each task instance""" -from sqlalchemy import Column, Index, Integer, String -from airflow.models.base import COLLATION_ARGS, ID_LEN, Base +from sqlalchemy import Column, ForeignKeyConstraint, Integer + +from airflow.models.base import Base, StringID from airflow.utils.sqlalchemy import UtcDateTime @@ -28,22 +29,42 @@ class TaskFail(Base): __tablename__ = "task_fail" id = Column(Integer, primary_key=True) - task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - execution_date = Column(UtcDateTime, nullable=False) + task_id = Column(StringID(), nullable=False) + dag_id = Column(StringID(), nullable=False) + run_id = Column(StringID(), nullable=False) + map_index = Column(Integer, nullable=False) start_date = Column(UtcDateTime) end_date = Column(UtcDateTime) duration = Column(Integer) - __table_args__ = (Index('idx_task_fail_dag_task_date', dag_id, task_id, execution_date, unique=False),) + __table_args__ = ( + ForeignKeyConstraint( + [dag_id, task_id, run_id, 
map_index], + [ + "task_instance.dag_id", + "task_instance.task_id", + "task_instance.run_id", + "task_instance.map_index", + ], + name='task_fail_ti_fkey', + ondelete="CASCADE", + ), + ) - def __init__(self, task, execution_date, start_date, end_date): + def __init__(self, task, run_id, start_date, end_date, map_index): self.dag_id = task.dag_id self.task_id = task.task_id - self.execution_date = execution_date + self.run_id = run_id + self.map_index = map_index self.start_date = start_date self.end_date = end_date if self.end_date and self.start_date: self.duration = int((self.end_date - self.start_date).total_seconds()) else: self.duration = None + + def __repr__(self): + prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}" + if self.map_index != -1: + prefix += f" map_index={self.map_index}" + return prefix + '>' diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index f0fd779f5f49c..0f9b08a5cc60e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1799,8 +1799,15 @@ def handle_failure( session.add(Log(State.FAILED, self)) # Log failure duration - dag_run = self.get_dagrun(session=session) # self.dag_run not populated by refresh_from_db - session.add(TaskFail(task, dag_run.execution_date, self.start_date, self.end_date)) + session.add( + TaskFail( + task=task, + run_id=self.run_id, + start_date=self.start_date, + end_date=self.end_date, + map_index=self.map_index, + ) + ) self.clear_next_method_args() diff --git a/airflow/utils/db.py b/airflow/utils/db.py index acaf953bd7efe..f9c7906f44893 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple from bcrypt import warnings -from sqlalchemy import Table, exc, func, inspect, or_, text +from sqlalchemy import Table, column, exc, func, inspect, literal, or_, table, text from sqlalchemy.orm.session import Session from airflow 
import settings @@ -827,6 +827,72 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]: ) +def reflect_tables(models, session): + """ + When running checks prior to upgrades, we use reflection to determine current state of the + database. + This function gets the current state of each table in the set of models provided and returns + a SqlAlchemy metadata object containing them. + """ + + import sqlalchemy.schema + + metadata = sqlalchemy.schema.MetaData(session.bind) + + for model in models: + try: + metadata.reflect(only=[model.__tablename__], extend_existing=True, resolve_fks=False) + except exc.InvalidRequestError: + continue + return metadata + + +def check_task_fail_for_duplicates(session): + """Check that there are no duplicates in the task_fail table before creating FK""" + metadata = reflect_tables([TaskFail], session) + task_fail = metadata.tables.get(TaskFail.__tablename__) # type: ignore + if task_fail is None: # table not there + return + if "run_id" in task_fail.columns: # upgrade already applied + return + yield from check_table_for_duplicates( + table_name=task_fail.name, + uniqueness=['dag_id', 'task_id', 'execution_date'], + session=session, + ) + + +def check_table_for_duplicates(table_name: str, uniqueness: List[str], session: Session) -> Iterable[str]: + """ + Check table for duplicates, given a list of columns which define the uniqueness of the table. + + Call from ``run_duplicates_checks``. 
+ + :param table_name: table name to check + :param uniqueness: uniqueness constraint to evaluate against + :param session: session of the sqlalchemy + :rtype: str + """ + table_obj = table(table_name, *[column(x) for x in uniqueness]) + dupe_count = 0 + try: + subquery = ( + session.query(table_obj, func.count().label('dupe_count')) + .group_by(*[text(x) for x in uniqueness]) + .having(func.count() > literal(1)) + .subquery() + ) + dupe_count = session.query(func.sum(subquery.c.dupe_count)).scalar() + except (exc.OperationalError, exc.ProgrammingError): + # fallback if tables hasn't been created yet + session.rollback() + if dupe_count: + yield ( + f"Found {dupe_count} duplicate records in table {table_name}. You must de-dupe these " + f"records before upgrading. The uniqueness constraint for this table is {uniqueness!r}" + ) + + def check_conn_type_null(session: Session) -> Iterable[str]: """ Check nullable conn_type column in Connection table @@ -963,13 +1029,20 @@ def _from_name(from_) -> str: def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]: + """ + Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id` + in many tables. + Here we go through each table and look for records that can't be mapped to a dag run. + When we find such "dangling" rows we back them up in a special table and delete them + from the main table. 
+ """ import sqlalchemy.schema from sqlalchemy import and_, outerjoin from airflow.models.renderedtifields import RenderedTaskInstanceFields metadata = sqlalchemy.schema.MetaData(session.bind) - models_to_dagrun: List[Any] = [TaskInstance, TaskReschedule, XCom, RenderedTaskInstanceFields] + models_to_dagrun: List[Any] = [RenderedTaskInstanceFields, TaskInstance, TaskFail, TaskReschedule, XCom] for model in models_to_dagrun + [DagRun]: try: metadata.reflect( @@ -1002,6 +1075,7 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str if "run_id" in source_table.columns: continue + # find rows in source table which don't have a matching dag run source_to_dag_run_join_cond = and_( source_table.c.dag_id == dagrun_table.c.dag_id, source_table.c.execution_date == dagrun_table.c.execution_date, @@ -1045,13 +1119,14 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]: :rtype: list[str] """ check_functions: Tuple[Callable[..., Iterable[str]], ...] = ( + check_task_fail_for_duplicates, check_conn_id_duplicates, check_conn_type_null, check_run_id_null, check_task_tables_without_matching_dagruns, ) for check_fn in check_functions: - yield from check_fn(session) + yield from check_fn(session=session) # Ensure there is no "active" transaction. 
Seems odd, but without this MSSQL can hang session.commit() diff --git a/airflow/www/views.py b/airflow/www/views.py index 1d74996bf48a6..96f2df5c3133f 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -2847,8 +2847,8 @@ def duration(self, dag_id, session=None): session.query(TaskFail) .filter( TaskFail.dag_id == dag.dag_id, - TaskFail.execution_date >= min_date, - TaskFail.execution_date <= base_date, + DagRun.execution_date >= min_date, + DagRun.execution_date <= base_date, TaskFail.task_id.in_([t.task_id for t in dag.tasks]), ) .all() @@ -3185,11 +3185,7 @@ def gantt(self, dag_id, session=None): .order_by(TaskInstance.start_date) ) - ti_fails = ( - session.query(TaskFail) - .join(DagRun, DagRun.execution_date == TaskFail.execution_date) - .filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id) - ) + ti_fails = session.query(TaskFail).filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id) tasks = [] for ti in tis: diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 9dfd7dd838b34..de8221620c6a8 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -25,7 +25,9 @@ Here's the list of all the Database Migrations that are executed via when you ru .. 
Beginning of auto-generated table +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ -| ``4eaab2fe6582`` (head) | ``c97c2ab6aa23`` | ``2.3.0`` | Migrate RTIF to use run_id and map_index | +| ``48925b2719cb`` (head) | ``4eaab2fe6582`` | ``2.3.0`` | Add map_index to TaskFail | ++---------------------------------+-------------------+-------------+--------------------------------------------------------------+ +| ``4eaab2fe6582`` | ``c97c2ab6aa23`` | ``2.3.0`` | Migrate RTIF to use run_id and map_index | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ | ``c97c2ab6aa23`` | ``c306b5b5ae4a`` | ``2.3.0`` | add callback request table | +---------------------------------+-------------------+-------------+--------------------------------------------------------------+ diff --git a/tests/api/common/test_delete_dag.py b/tests/api/common/test_delete_dag.py index d9dc0b0a01c7f..5765cbd04565c 100644 --- a/tests/api/common/test_delete_dag.py +++ b/tests/api/common/test_delete_dag.py @@ -96,7 +96,15 @@ def setup_dag_models(self, for_sub_dag=False): event="varimport", ) ) - session.add(TF(task=task, execution_date=test_date, start_date=test_date, end_date=test_date)) + session.add( + TF( + task=task, + run_id=ti.run_id, + start_date=test_date, + end_date=test_date, + map_index=ti.map_index, + ) + ) session.add( TR( task=ti.task, diff --git a/tests/core/test_core.py b/tests/core/test_core.py index 91d4e134d6e7a..3d6aabb9c7f7a 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- +from contextlib import suppress from datetime import timedelta from time import sleep @@ -23,7 +23,7 @@ from airflow import settings from airflow.exceptions import AirflowTaskTimeout -from airflow.models import TaskFail, TaskInstance +from airflow.models import DagRun, TaskFail, TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator @@ -89,22 +89,25 @@ def test_task_fail_duration(self, dag_maker): ) dag_maker.create_dagrun() session = settings.Session() - try: - op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - except Exception: - pass - try: + op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + with suppress(AirflowTaskTimeout): op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - except Exception: - pass op1_fails = ( session.query(TaskFail) - .filter_by(task_id='pass_sleepy', dag_id=dag.dag_id, execution_date=DEFAULT_DATE) + .filter( + TaskFail.task_id == 'pass_sleepy', + TaskFail.dag_id == dag.dag_id, + DagRun.execution_date == DEFAULT_DATE, + ) .all() ) op2_fails = ( session.query(TaskFail) - .filter_by(task_id='fail_sleepy', dag_id=dag.dag_id, execution_date=DEFAULT_DATE) + .filter( + TaskFail.task_id == 'fail_sleepy', + TaskFail.dag_id == dag.dag_id, + DagRun.execution_date == DEFAULT_DATE, + ) .all() )