Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import os
import textwrap
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from typing import List, Optional
from typing import List, Optional, Tuple, Union

from pendulum.parsing.exceptions import ParserError
from sqlalchemy.orm.exc import NoResultFound
Expand All @@ -39,9 +39,9 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.xcom import IN_MEMORY_RUN_ID
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
from airflow.typing_compat import Literal
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
get_dag,
Expand All @@ -54,15 +54,27 @@
from airflow.utils.log.logging_mixin import StreamLogWriter
from airflow.utils.net import get_hostname
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState

# Accepted values for the ``create_if_necessary`` parameter of ``_get_dag_run``
# and ``_get_ti``: ``False`` means never create a missing DAG run; ``"db"``
# creates a real (temporary) DAG run in the database; ``"memory"`` builds a
# transient in-memory DagRun object without persisting it.
CreateIfNecessary = Union[Literal[False], Literal["db"], Literal["memory"]]


def _generate_temporary_run_id() -> str:
    """Build a ``run_id`` for a temporary DAG run.

    Primarily used by ``airflow task test``, which creates a throwaway DAG
    run and removes it again once the task has finished running.
    """
    timestamp = timezone.utcnow().isoformat()
    return "__airflow_temporary_run_{}__".format(timestamp)


def _get_dag_run(
*,
dag: DAG,
exec_date_or_run_id: str,
create_if_necessary: bool,
create_if_necessary: CreateIfNecessary,
session: Session,
) -> DagRun:
) -> Tuple[DagRun, bool]:
"""Try to retrieve a DAG run from a string representing either a run ID or logical date.

This checks DAG runs like this:
Expand All @@ -78,15 +90,15 @@ def _get_dag_run(
"""
dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
if dag_run:
return dag_run
return dag_run, False

try:
execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id)
except (ParserError, TypeError):
execution_date = None

try:
return (
dag_run = (
session.query(DagRun)
.filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
.one()
Expand All @@ -96,10 +108,25 @@ def _get_dag_run(
raise DagRunNotFound(
f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
) from None
else:
return dag_run, False

if execution_date is not None:
return DagRun(dag.dag_id, run_id=IN_MEMORY_RUN_ID, execution_date=execution_date)
return DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=timezone.utcnow())
dag_run_execution_date = execution_date
else:
dag_run_execution_date = timezone.utcnow()
if create_if_necessary == "memory":
dag_run = DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=dag_run_execution_date)
return dag_run, True
elif create_if_necessary == "db":
dag_run = dag.create_dagrun(
state=DagRunState.QUEUED,
execution_date=dag_run_execution_date,
run_id=_generate_temporary_run_id(),
session=session,
)
return dag_run, True
raise ValueError(f"unknown create_if_necessary value: {create_if_necessary!r}")


@provide_session
Expand All @@ -108,16 +135,16 @@ def _get_ti(
exec_date_or_run_id: str,
map_index: int,
*,
create_if_necessary: bool = False,
create_if_necessary: CreateIfNecessary = False,
session: Session = NEW_SESSION,
) -> TaskInstance:
) -> Tuple[TaskInstance, bool]:
"""Get the task instance through DagRun.run_id, if that fails, get the TI the old way"""
if task.is_mapped:
if map_index < 0:
raise RuntimeError("No map_index passed to mapped task")
elif map_index >= 0:
raise RuntimeError("map_index passed to non-mapped task")
dag_run = _get_dag_run(
dag_run, dr_created = _get_dag_run(
dag=task.dag,
exec_date_or_run_id=exec_date_or_run_id,
create_if_necessary=create_if_necessary,
Expand All @@ -137,7 +164,7 @@ def _get_ti(
else:
ti = ti_or_none
ti.refresh_from_task(task)
return ti
return ti, dr_created


def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None:
Expand Down Expand Up @@ -332,7 +359,7 @@ def task_run(args, dag=None):
# Use DAG from parameter
pass
task = dag.get_task(task_id=args.task_id)
ti = _get_ti(task, args.execution_date_or_run_id, args.map_index)
ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
ti.init_run_context(raw=args.raw)

hostname = get_hostname()
Expand Down Expand Up @@ -360,7 +387,7 @@ def task_failed_deps(args):
"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti = _get_ti(task, args.execution_date_or_run_id, args.map_index)
ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)

dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
Expand All @@ -383,7 +410,7 @@ def task_state(args):
"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti = _get_ti(task, args.execution_date_or_run_id, args.map_index)
ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
print(ti.current_state())


Expand Down Expand Up @@ -502,7 +529,7 @@ def task_test(args, dag=None):
if task.params:
task.params.validate()

ti = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary=True)
ti, dr_created = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="db")

try:
if args.dry_run:
Expand All @@ -520,6 +547,9 @@ def task_test(args, dag=None):
# Make sure to reset back to normal. When run for CLI this doesn't
# matter, but it does for test suite
logging.getLogger('airflow.task').propagate = False
if dr_created:
with create_session() as session:
session.delete(ti.dag_run)


@cli_utils.action_cli(check_db=False)
Expand All @@ -528,7 +558,7 @@ def task_render(args):
"""Renders and displays templated fields for a given task"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary=True)
ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="memory")
ti.render_templates()
for attr in task.__class__.template_fields:
print(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,13 @@ def upgrade():
with op.batch_alter_table("xcom") as batch_op:
batch_op.create_primary_key("xcom_pkey", ["dag_run_id", "task_id", "map_index", "key"])
batch_op.create_index("idx_xcom_key", ["key"])
batch_op.create_index("idx_xcom_ti_id", ["dag_id", "run_id", "task_id", "map_index"])
batch_op.create_foreign_key(
"xcom_task_instance_fkey",
"task_instance",
["dag_id", "task_id", "run_id", "map_index"],
["dag_id", "task_id", "run_id", "map_index"],
ondelete="CASCADE",
)


def downgrade():
Expand Down
1 change: 1 addition & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,6 +2244,7 @@ def xcom_push(
task_id=self.task_id,
dag_id=self.dag_id,
run_id=self.run_id,
map_index=self.map_index,
session=session,
)

Expand Down
43 changes: 22 additions & 21 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type, Union, cast, overload

import pendulum
from sqlalchemy import Column, Index, Integer, LargeBinary, String
from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, LargeBinary, String
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, Session, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound
Expand All @@ -46,10 +46,6 @@
MAX_XCOM_SIZE = 49344
XCOM_RETURN_KEY = 'return_value'

# Stand-in value for 'airflow task test' generating a temporary in-memory DAG
# run without storing it in the database.
IN_MEMORY_RUN_ID = "__airflow_in_memory_dagrun__"

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey

Expand All @@ -71,26 +67,33 @@ class BaseXCom(Base, LoggingMixin):
value = Column(LargeBinary)
timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

__table_args__ = (
# Ideally we should create a unique index over (key, dag_id, task_id, run_id),
# but it goes over MySQL's index length limit. So we instead index 'key'
# separately, and enforce uniqueness with DagRun.id instead.
Index("idx_xcom_key", key),
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
"task_instance.dag_id",
"task_instance.task_id",
"task_instance.run_id",
"task_instance.map_index",
],
name="xcom_task_instance_fkey",
ondelete="CASCADE",
),
)

dag_run = relationship(
"DagRun",
primaryjoin="""and_(
BaseXCom.dag_id == foreign(DagRun.dag_id),
BaseXCom.run_id == foreign(DagRun.run_id),
)""",
primaryjoin="BaseXCom.dag_run_id == foreign(DagRun.id)",
uselist=False,
lazy="joined",
passive_deletes="all",
)
execution_date = association_proxy("dag_run", "execution_date")

__table_args__ = (
# Ideally we should create a unique index over (key, dag_id, task_id, run_id),
# but it goes over MySQL's index length limit. So we instead create indexes
# separately, and enforce uniqueness with DagRun.id instead.
Index("idx_xcom_key", key),
Index("idx_xcom_ti_id", dag_id, task_id, run_id, map_index),
)

@reconstructor
def init_on_load(self):
"""
Expand Down Expand Up @@ -175,8 +178,6 @@ def set(
)
except NoResultFound:
raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None
elif run_id == IN_MEMORY_RUN_ID:
dag_run_id = -1
else:
dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
if dag_run_id is None:
Expand All @@ -197,6 +198,7 @@ def set(
cls.run_id == run_id,
cls.task_id == task_id,
cls.dag_id == dag_id,
cls.map_index == map_index,
).delete()
new = cast(Any, cls)( # Work around Mypy complaining model not defining '__init__'.
dag_run_id=dag_run_id,
Expand All @@ -205,6 +207,7 @@ def set(
run_id=run_id,
task_id=task_id,
dag_id=dag_id,
map_index=map_index,
)
session.add(new)
session.flush()
Expand Down Expand Up @@ -452,8 +455,6 @@ def get_many(
if execution_date is not None:
query = query.filter(DagRun.execution_date <= execution_date)
else:
# This returns an empty query result for IN_MEMORY_RUN_ID,
# but that is impossible to implement. Sorry?
dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
query = query.filter(cls.execution_date <= dr.c.execution_date)
elif execution_date is not None:
Expand Down
Loading