Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Add map_index to TaskFail

Drop index idx_task_fail_dag_task_date
Add run_id and map_index
Drop execution_date
Add FK `task_fail_ti_fkey`: TF -> TI ([dag_id, task_id, run_id, map_index])


Revision ID: 48925b2719cb
Revises: 4eaab2fe6582
Create Date: 2022-03-14 10:31:11.220720
"""

from typing import List

import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import ColumnElement, Update, and_, select

from airflow.migrations.db_types import TIMESTAMP, StringID

# revision identifiers, used by Alembic.
revision = '48925b2719cb'
down_revision = '4eaab2fe6582'
branch_labels = None
depends_on = None
airflow_version = '2.3.0'

ID_LEN = 250


def tables():
    """Populate module-level ``Table`` objects used by this migration's data queries.

    These are deliberately lightweight, handwritten definitions (only the columns
    the migration touches) rather than reflections of the live schema. They back
    the backfill UPDATE statements and serve as ``copy_from`` for SQLite batch ops.
    """
    global task_instance, task_fail, dag_run
    metadata = sa.MetaData()
    task_instance = sa.Table(
        'task_instance',
        metadata,
        sa.Column('task_id', StringID()),
        sa.Column('dag_id', StringID()),
        sa.Column('run_id', StringID()),
        sa.Column('map_index', sa.Integer(), server_default='-1'),
        sa.Column('execution_date', TIMESTAMP),
    )
    task_fail = sa.Table(
        'task_fail',
        metadata,
        sa.Column('dag_id', StringID()),
        sa.Column('task_id', StringID()),
        sa.Column('run_id', StringID()),
        # Fixed: map_index is an integer column — it is created with sa.Integer()
        # in upgrade() and must match task_instance.map_index for the FK. It was
        # previously (and incorrectly) declared as StringID() here.
        sa.Column('map_index', sa.Integer(), server_default='-1'),
        sa.Column('execution_date', TIMESTAMP),
    )
    dag_run = sa.Table(
        'dag_run',
        metadata,
        sa.Column('dag_id', StringID()),
        sa.Column('run_id', StringID()),
        sa.Column('execution_date', TIMESTAMP),
    )


def _update_value_from_dag_run(
    dialect_name: str,
    target_table: sa.Table,
    target_column: ColumnElement,
    join_columns: List[str],
) -> Update:
    """
    Build an UPDATE that copies a value from ``dag_run`` into ``target_table``.

    :param dialect_name: dialect in use
    :param target_table: the table to update
    :param target_column: the column to update
    :param join_columns: column names used to match ``dag_run`` rows to
        ``target_table`` rows (``dag_id``/``execution_date`` when backfilling
        ``run_id``; ``dag_id``/``run_id`` otherwise)
    """
    condition = and_(*(dag_run.c[name] == target_table.c[name] for name in join_columns))
    source_column = dag_run.c[target_column.name]

    if dialect_name != "sqlite":
        # Dialects with multi-table UPDATE support can join directly.
        return target_table.update().where(condition).values({target_column: source_column})

    # Most SQLite versions don't support multi table update (and SQLA doesn't know
    # about it anyway), so we need to do a correlated-subquery update instead.
    correlated = select([source_column]).where(condition)
    return target_table.update().values({target_column: correlated})


def upgrade():
    """Apply: replace ``execution_date`` with ``run_id``/``map_index`` on task_fail.

    Statement order matters: the new columns must exist before the backfill
    UPDATE runs, and the backfill must complete before ``execution_date`` is
    dropped and ``run_id`` is made NOT NULL.
    """
    tables()
    dialect_name = op.get_bind().dialect.name

    # The old index covers execution_date, which is about to be dropped.
    op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail')

    with op.batch_alter_table('task_fail') as batch_op:
        batch_op.add_column(sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False))
        # run_id starts out nullable; it is backfilled below and tightened afterwards.
        batch_op.add_column(sa.Column('run_id', type_=StringID(), nullable=True))

    # Backfill run_id from dag_run, matching rows on (dag_id, execution_date)
    # while execution_date still exists on task_fail.
    update_query = _update_value_from_dag_run(
        dialect_name=dialect_name,
        target_table=task_fail,
        target_column=task_fail.c.run_id,
        join_columns=['dag_id', 'execution_date'],
    )
    op.execute(update_query)
    with op.batch_alter_table('task_fail') as batch_op:
        # Every surviving row now has a run_id, so the column can become NOT NULL.
        batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False)
        batch_op.drop_column('execution_date')
        batch_op.create_foreign_key(
            'task_fail_ti_fkey',
            'task_instance',
            ['dag_id', 'task_id', 'run_id', 'map_index'],
            ['dag_id', 'task_id', 'run_id', 'map_index'],
            ondelete='CASCADE',
        )


def downgrade():
    """Revert: restore ``execution_date`` on task_fail, drop ``run_id``/``map_index`` and the FK.

    Mirrors :func:`upgrade`: re-add the column nullable, backfill it from
    dag_run, then tighten it and drop the new columns.
    """
    tables()
    dialect_name = op.get_bind().dialect.name
    # Re-add execution_date as nullable first so existing rows can be backfilled.
    op.add_column('task_fail', sa.Column('execution_date', TIMESTAMP, nullable=True))
    # Backfill execution_date from dag_run, matching rows on (dag_id, run_id).
    update_query = _update_value_from_dag_run(
        dialect_name=dialect_name,
        target_table=task_fail,
        target_column=task_fail.c.execution_date,
        join_columns=['dag_id', 'run_id'],
    )
    op.execute(update_query)
    with op.batch_alter_table('task_fail', copy_from=task_fail) as batch_op:
        batch_op.alter_column('execution_date', existing_type=TIMESTAMP, nullable=False)
        if dialect_name != 'sqlite':
            # On SQLite the batch copy recreates the table, so the FK disappears
            # without an explicit drop; other dialects need it dropped by name.
            batch_op.drop_constraint('task_fail_ti_fkey', type_='foreignkey')
        batch_op.drop_column('map_index', mssql_drop_default=True)
        batch_op.drop_column('run_id')
    op.create_index(
        index_name='idx_task_fail_dag_task_date',
        table_name='task_fail',
        columns=['dag_id', 'task_id', 'execution_date'],
        unique=False,
    )
37 changes: 29 additions & 8 deletions airflow/models/taskfail.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# specific language governing permissions and limitations
# under the License.
"""Taskfail tracks the failed run durations of each task instance"""
from sqlalchemy import Column, Index, Integer, String

from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
from sqlalchemy import Column, ForeignKeyConstraint, Integer

from airflow.models.base import Base, StringID
from airflow.utils.sqlalchemy import UtcDateTime


Expand All @@ -28,22 +29,42 @@ class TaskFail(Base):
__tablename__ = "task_fail"

id = Column(Integer, primary_key=True)
task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
execution_date = Column(UtcDateTime, nullable=False)
task_id = Column(StringID(), nullable=False)
dag_id = Column(StringID(), nullable=False)
run_id = Column(StringID(), nullable=False)
map_index = Column(Integer, nullable=False)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
duration = Column(Integer)

__table_args__ = (Index('idx_task_fail_dag_task_date', dag_id, task_id, execution_date, unique=False),)
__table_args__ = (
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
"task_instance.dag_id",
"task_instance.task_id",
"task_instance.run_id",
"task_instance.map_index",
],
name='task_fail_ti_fkey',
ondelete="CASCADE",
),
)

def __init__(self, task, execution_date, start_date, end_date):
def __init__(self, task, run_id, start_date, end_date, map_index):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.execution_date = execution_date
self.run_id = run_id
self.map_index = map_index
self.start_date = start_date
self.end_date = end_date
if self.end_date and self.start_date:
self.duration = int((self.end_date - self.start_date).total_seconds())
else:
self.duration = None

def __repr__(self):
prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
if self.map_index != -1:
prefix += f" map_index={self.map_index}"
return prefix + '>'
11 changes: 9 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1799,8 +1799,15 @@ def handle_failure(
session.add(Log(State.FAILED, self))

# Log failure duration
dag_run = self.get_dagrun(session=session) # self.dag_run not populated by refresh_from_db
session.add(TaskFail(task, dag_run.execution_date, self.start_date, self.end_date))
session.add(
TaskFail(
task=task,
run_id=self.run_id,
start_date=self.start_date,
end_date=self.end_date,
map_index=self.map_index,
)
)

self.clear_next_method_args()

Expand Down
81 changes: 78 additions & 3 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple

from bcrypt import warnings
from sqlalchemy import Table, exc, func, inspect, or_, text
from sqlalchemy import Table, column, exc, func, inspect, literal, or_, table, text
from sqlalchemy.orm.session import Session

from airflow import settings
Expand Down Expand Up @@ -827,6 +827,72 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
)


def reflect_tables(models, session):
    """
    When running checks prior to upgrades, we use reflection to determine current state of the
    database.

    This function gets the current state of each table in the set of models provided and returns
    a SqlAlchemy metadata object containing them.
    """
    import sqlalchemy.schema

    metadata = sqlalchemy.schema.MetaData(session.bind)
    for model in models:
        try:
            metadata.reflect(only=[model.__tablename__], extend_existing=True, resolve_fks=False)
        except exc.InvalidRequestError:
            # Table doesn't exist in this database yet — skip it.
            pass
    return metadata


def check_task_fail_for_duplicates(session):
    """Check that there are no duplicates in the task_fail table before creating FK"""
    metadata = reflect_tables([TaskFail], session)
    task_fail = metadata.tables.get(TaskFail.__tablename__)  # type: ignore
    # Nothing to validate when the table is missing, or when the upgrade has
    # already replaced execution_date with run_id.
    if task_fail is None or "run_id" in task_fail.columns:
        return
    yield from check_table_for_duplicates(
        table_name=task_fail.name,
        uniqueness=['dag_id', 'task_id', 'execution_date'],
        session=session,
    )


def check_table_for_duplicates(table_name: str, uniqueness: List[str], session: Session) -> Iterable[str]:
    """
    Check table for duplicates, given a list of columns which define the uniqueness of the table.

    Call from ``run_duplicates_checks``.

    :param table_name: table name to check
    :param uniqueness: uniqueness constraint to evaluate against
    :param session: session of the sqlalchemy
    :rtype: str
    """
    table_obj = table(table_name, *(column(name) for name in uniqueness))
    dupe_count = 0
    try:
        # Group by the uniqueness columns and keep only groups with > 1 row,
        # then sum their sizes to get the total number of duplicate records.
        grouped = (
            session.query(table_obj, func.count().label('dupe_count'))
            .group_by(*[text(name) for name in uniqueness])
            .having(func.count() > literal(1))
            .subquery()
        )
        dupe_count = session.query(func.sum(grouped.c.dupe_count)).scalar()
    except (exc.OperationalError, exc.ProgrammingError):
        # Fallback: the table hasn't been created yet.
        session.rollback()
    if dupe_count:
        yield (
            f"Found {dupe_count} duplicate records in table {table_name}. You must de-dupe these "
            f"records before upgrading. The uniqueness constraint for this table is {uniqueness!r}"
        )


def check_conn_type_null(session: Session) -> Iterable[str]:
"""
Check nullable conn_type column in Connection table
Expand Down Expand Up @@ -963,13 +1029,20 @@ def _from_name(from_) -> str:


def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]:
"""
Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id`
in many tables.
Here we go through each table and look for records that can't be mapped to a dag run.
When we find such "dangling" rows we back them up in a special table and delete them
from the main table.
"""
import sqlalchemy.schema
from sqlalchemy import and_, outerjoin

from airflow.models.renderedtifields import RenderedTaskInstanceFields

metadata = sqlalchemy.schema.MetaData(session.bind)
models_to_dagrun: List[Any] = [TaskInstance, TaskReschedule, XCom, RenderedTaskInstanceFields]
models_to_dagrun: List[Any] = [RenderedTaskInstanceFields, TaskInstance, TaskFail, TaskReschedule, XCom]
for model in models_to_dagrun + [DagRun]:
try:
metadata.reflect(
Expand Down Expand Up @@ -1002,6 +1075,7 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str
if "run_id" in source_table.columns:
continue

# find rows in source table which don't have a matching dag run
source_to_dag_run_join_cond = and_(
source_table.c.dag_id == dagrun_table.c.dag_id,
source_table.c.execution_date == dagrun_table.c.execution_date,
Expand Down Expand Up @@ -1045,13 +1119,14 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
:rtype: list[str]
"""
check_functions: Tuple[Callable[..., Iterable[str]], ...] = (
check_task_fail_for_duplicates,
check_conn_id_duplicates,
check_conn_type_null,
check_run_id_null,
check_task_tables_without_matching_dagruns,
)
for check_fn in check_functions:
yield from check_fn(session)
yield from check_fn(session=session)
# Ensure there is no "active" transaction. Seems odd, but without this MSSQL can hang
session.commit()

Expand Down
10 changes: 3 additions & 7 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2847,8 +2847,8 @@ def duration(self, dag_id, session=None):
session.query(TaskFail)
.filter(
TaskFail.dag_id == dag.dag_id,
TaskFail.execution_date >= min_date,
TaskFail.execution_date <= base_date,
DagRun.execution_date >= min_date,
DagRun.execution_date <= base_date,
TaskFail.task_id.in_([t.task_id for t in dag.tasks]),
)
.all()
Expand Down Expand Up @@ -3185,11 +3185,7 @@ def gantt(self, dag_id, session=None):
.order_by(TaskInstance.start_date)
)

ti_fails = (
session.query(TaskFail)
.join(DagRun, DagRun.execution_date == TaskFail.execution_date)
.filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id)
)
ti_fails = session.query(TaskFail).filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id)

tasks = []
for ti in tis:
Expand Down
Loading