Commit

Change logical_date back to execution_date, this can be done as another PR later
jdddog committed Apr 10, 2024
1 parent 112844f commit 286efe2
Showing 6 changed files with 182 additions and 50 deletions.
18 changes: 9 additions & 9 deletions observatory-platform/observatory/platform/airflow.py
@@ -96,12 +96,12 @@ def check_connections(*connections):


def send_slack_msg(
*, ti: TaskInstance, logical_date: pendulum.DateTime, comments: str = "", slack_conn_id: str = AirflowConns.SLACK
*, ti: TaskInstance, execution_date: pendulum.DateTime, comments: str = "", slack_conn_id: str = AirflowConns.SLACK
):
"""
Send a Slack message using the token in the slack airflow connection.
:param ti: Task instance.
:param logical_date: DagRun execution date.
:param execution_date: DagRun execution date.
:param comments: Additional comments in slack message
:param slack_conn_id: the Airflow connection id for the Slack connection.
"""
@@ -118,7 +118,7 @@ def send_slack_msg(
).format(
task=ti.task_id,
dag=ti.dag_id,
exec_date=logical_date,
exec_date=execution_date,
log_url=ti.log_url,
comments=comments,
)
@@ -205,8 +205,8 @@ def on_failure_callback(context):

comments = f"Task failed, exception:\n{formatted_exception}"
ti = context["ti"]
logical_date = context["logical_date"]
send_slack_msg(ti=ti, logical_date=logical_date, comments=comments, slack_conn_id=AirflowConns.SLACK)
execution_date = context["execution_date"]
send_slack_msg(ti=ti, execution_date=execution_date, comments=comments, slack_conn_id=AirflowConns.SLACK)


def normalized_schedule_interval(schedule_interval: Optional[str]) -> Optional[ScheduleInterval]:
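For context, a minimal sketch (not part of this commit) of wiring the on_failure_callback shown above into a DAG so task failures are posted to Slack. The dag_id and schedule are illustrative; the import path follows the module in this diff, and after this commit the callback reads context["execution_date"].

import pendulum
from airflow import DAG
from airflow.operators.empty import EmptyOperator

from observatory.platform.airflow import on_failure_callback

with DAG(
    dag_id="example_dag",  # hypothetical DAG id
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule="@weekly",
    catchup=False,
    # on_failure_callback receives the failing task's context and calls send_slack_msg
    default_args={"on_failure_callback": on_failure_callback},
) as dag:
    EmptyOperator(task_id="dummy_task")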
@@ -344,19 +344,19 @@ def fetch_dag_bag(path: str, include_examples: bool = False) -> DagBag:
def delete_old_xcoms(
session: Session = None,
dag_id: str = None,
logical_date: pendulum.DateTime = None,
execution_date: pendulum.DateTime = None,
retention_days: int = 31,
):
"""Delete XCom messages created by the DAG with the given ID that are as old or older than than
logical_date - retention_days. Defaults to 31 days of retention.
execution_date - retention_days. Defaults to 31 days of retention.
:param session: DB session.
:param dag_id: DAG ID.
:param logical_date: DAG execution date.
:param execution_date: DAG execution date.
:param retention_days: Days of messages to retain.
"""

cut_off_date = logical_date.subtract(days=retention_days)
cut_off_date = execution_date.subtract(days=retention_days)
results = session.query(XCom).filter(
and_(
XCom.dag_id == dag_id,
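A usage sketch (not part of this commit) of calling delete_old_xcoms from a cleanup task. It assumes the function injects its session via Airflow's @provide_session when none is passed, which is not shown in this hunk; the task name is hypothetical.

from airflow.decorators import task
from airflow.operators.python import get_current_context

from observatory.platform.airflow import delete_old_xcoms


@task
def cleanup_xcoms():
    # After this commit the keyword is execution_date rather than logical_date
    context = get_current_context()
    delete_old_xcoms(
        dag_id=context["dag"].dag_id,
        execution_date=context["execution_date"],
        retention_days=31,
    )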
Original file line number Diff line number Diff line change
@@ -440,7 +440,7 @@ def run_task(self, task_id: str, map_index: int = -1) -> TaskInstance:
"""Run an Airflow task.
:param task_id: the Airflow task identifier.
:param map_index: the index of an expanded dynamic task.
:param map_index: the map index if the task is a dynamic task
:return: None.
"""

@@ -451,7 +451,27 @@ def run_task(self, task_id: str, map_index: int = -1) -> TaskInstance:
task = dag.get_task(task_id=task_id)
ti = TaskInstance(task, run_id=run_id, map_index=map_index)
ti.refresh_from_db()
ti.run(ignore_ti_state=True, ignore_all_deps=True)

# TODO: remove this when this issue fixed / PR merged: https://github.com/apache/airflow/issues/34023#issuecomment-1705761692
# https://github.com/apache/airflow/pull/36462
ignore_task_deps = False
if map_index > -1:
ignore_task_deps = True

ti.run(ignore_task_deps=ignore_task_deps)

return ti

def skip_task(self, task_id: str, map_index: int = -1) -> TaskInstance:
"""Mark an Airflow task as skipped.
:param task_id: the Airflow task identifier.
:param map_index: the map index if the task is a dynamic task.
:return: the task instance.
"""

assert self.dag_run is not None, "with create_dag_run must be called before skip_task"

dag = self.dag_run.dag
run_id = self.dag_run.run_id
task = dag.get_task(task_id=task_id)
ti = TaskInstance(task, run_id=run_id, map_index=map_index)
ti.refresh_from_db()
ti.set_state(State.SKIPPED)

return ti

@@ -474,30 +494,30 @@ def get_task_instance(self, task_id: str) -> TaskInstance:
def create_dag_run(
self,
dag: DAG,
logical_date: pendulum.DateTime,
execution_date: pendulum.DateTime,
run_type: DagRunType = DagRunType.SCHEDULED,
):
"""Create a DagRun that can be used when running tasks.
During cleanup the DAG run state is updated.
:param dag: the Airflow DAG instance.
:param logical_date: the execution date of the DAG.
:param execution_date: the execution date of the DAG.
:param run_type: what run_type to use when running the DAG run.
:return: None.
"""

# Get start date, which is one schedule interval after execution date
if isinstance(dag.normalized_schedule_interval, (timedelta, relativedelta)):
start_date = (
datetime.fromtimestamp(logical_date.timestamp(), pendulum.tz.UTC) + dag.normalized_schedule_interval
datetime.fromtimestamp(execution_date.timestamp(), pendulum.tz.UTC) + dag.normalized_schedule_interval
)
else:
start_date = croniter.croniter(dag.normalized_schedule_interval, logical_date).get_next(pendulum.DateTime)
start_date = croniter.croniter(dag.normalized_schedule_interval, execution_date).get_next(pendulum.DateTime)

try:
self.dag_run = dag.create_dagrun(
state=State.RUNNING,
execution_date=logical_date,
execution_date=execution_date,
start_date=start_date,
run_type=run_type,
)
@@ -794,9 +814,11 @@ def assert_dag_structure(self, expected: Dict, dag: DAG):

expected_keys = expected.keys()
actual_keys = dag.task_dict.keys()
diff = set(expected_keys) - set(actual_keys)
self.assertEqual(expected_keys, actual_keys)

for task_id, downstream_list in expected.items():
print(task_id)
self.assertTrue(dag.has_task(task_id))
task = dag.get_task(task_id)
expected = set(downstream_list)
@@ -1181,18 +1203,18 @@ def create(self):
self.server_thread.join()


def make_dummy_dag(dag_id: str, logical_date: pendulum.DateTime) -> DAG:
def make_dummy_dag(dag_id: str, execution_date: pendulum.DateTime) -> DAG:
"""A Dummy DAG for testing purposes.
:param dag_id: the DAG id.
:param logical_date: the DAG's execution date.
:param execution_date: the DAG's execution date.
:return: the DAG.
"""

with DAG(
dag_id=dag_id,
schedule="@weekly",
default_args={"owner": "airflow", "start_date": logical_date},
default_args={"owner": "airflow", "start_date": execution_date},
catchup=False,
) as dag:
task1 = EmptyOperator(task_id="dummy_task")
105 changes: 105 additions & 0 deletions observatory-platform/observatory/platform/refactor/sensors.py
@@ -0,0 +1,105 @@
# Copyright 2020, 2021 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Author: Tuan Chien, Keegan Smith, Jamie Diprose

from __future__ import annotations

from datetime import timedelta
from functools import partial
from typing import Callable, List, Optional

import pendulum
from airflow.models import DagRun
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.db import provide_session
from sqlalchemy.orm.scoping import scoped_session


class DagCompleteSensor(ExternalTaskSensor):
"""
A sensor that awaits the completion of an external dag by default. Wait functionality can be customised by
providing a different execution_date_fn.
The sensor checks for completion of a dag with "external_dag_id" on the logical date returned by the
execution_date_fn.
"""

def __init__(
self,
task_id: str,
external_dag_id: str,
mode: str = "reschedule",
poke_interval: int = 1200, # Check if dag run is ready every 20 minutes
timeout: int = int(timedelta(days=1).total_seconds()), # Sensor will fail after 1 day of waiting
check_existence: bool = True,
execution_date_fn: Optional[Callable] = None,
**kwargs,
):
"""
:param task_id: the id of the sensor task to create
:param external_dag_id: the id of the external dag to check
:param mode: the sensor mode; can be reschedule or poke.
:param poke_interval: how often to check if the external dag run is complete
:param timeout: how long to check before the sensor fails
:param check_existence: whether to check that the provided dag_id exists
:param execution_date_fn: a function that returns the logical date(s) of the external DAG runs to query for,
since you need a logical date and a DAG ID to find a particular DAG run to wait for.
"""

if execution_date_fn is None:
execution_date_fn = partial(get_logical_dates, external_dag_id)

super().__init__(
task_id=task_id,
external_dag_id=external_dag_id,
mode=mode,
poke_interval=poke_interval,
timeout=timeout,
check_existence=check_existence,
execution_date_fn=execution_date_fn,
**kwargs,
)


@provide_session
def get_logical_dates(
external_dag_id: str, logical_date: pendulum.DateTime, session: scoped_session = None, **context
) -> List[pendulum.DateTime]:
"""Get the logical dates for a given external dag that fall between and returns its data_interval_start (logical date)
:param external_dag_id: the DAG ID of the external DAG we are waiting for.
:param logical_date: the logical date of the waiting DAG.
:param session: the SQL Alchemy session.
:param context: the Airflow context.
:return: the last logical date of the external DAG that falls before the data interval end of the waiting DAG.
"""

data_interval_end = context["data_interval_end"]
dag_runs = (
session.query(DagRun)
.filter(
DagRun.dag_id == external_dag_id,
DagRun.data_interval_end <= data_interval_end,
)
.all()
)
dates = [d.logical_date for d in dag_runs]
dates.sort(reverse=True)

# If more than 1 date return first date
if len(dates) >= 2:
dates = [dates[0]]

return dates
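A usage sketch (not part of this commit) of the new DagCompleteSensor waiting on an upstream DAG. The DAG and task ids are hypothetical; the defaults defined above give reschedule mode, a 20-minute poke interval and a one-day timeout.

import pendulum
from airflow import DAG
from airflow.operators.empty import EmptyOperator

from observatory.platform.refactor.sensors import DagCompleteSensor

with DAG(
    dag_id="downstream_dag",  # hypothetical
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule="@weekly",
    catchup=False,
) as dag:
    wait = DagCompleteSensor(
        task_id="wait_for_upstream",
        external_dag_id="upstream_dag",  # hypothetical upstream DAG id
    )
    process = EmptyOperator(task_id="process")
    wait >> process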
22 changes: 11 additions & 11 deletions observatory-platform/observatory/platform/utils/dag_run_sensor.py
@@ -33,7 +33,7 @@ class DagRunSensor(BaseSensorOperator):
A sensor for monitoring dag runs from other DAGs.
Behaviour:
* If the DAG ID does not exist, throw an exception.
* Monitors the time period [logical_date - duration, logical_date] for the DAG's dag runs.
* Monitors the time period [execution_date - duration, execution_date] for the DAG's dag runs.
- If there are no runs, then return a success state.
- If there are runs, look at the latest run in the time period:
* If the latest run is successful, return a success state.
@@ -79,17 +79,17 @@ def poke(self, context: Dict, session: scoped_session = None):

if self.check_exists:
self.check_dag_exists(session)
logical_date = context["logical_date"]
execution_date = context["execution_date"]

date = self.get_latest_execution_date(session=session, logical_date=logical_date)
date = self.get_latest_execution_date(session=session, execution_date=execution_date)

# If no logical_date could be found, sleep the grace period, and try again once. This will sleep the duration
# If no execution_date could be found, sleep the grace period, and try again once. This will sleep the duration
# of the grace period and try again. This is an alternative to scheduling the workflow at a slightly later time
# to allow Airflow to record the dagrun in the database.
# Note that this occupies a slot in execution queue.
if date is None:
sleep(self.grace_period.total_seconds())
date = self.get_latest_execution_date(session=session, logical_date=logical_date)
date = self.get_latest_execution_date(session=session, execution_date=execution_date)

if not date:
return success_state
@@ -102,28 +102,28 @@ def get_latest_execution_date(
return retry_state

def get_latest_execution_date(
self, *, session: scoped_session, logical_date: datetime.datetime
self, *, session: scoped_session, execution_date: datetime.datetime
) -> Union[datetime.datetime, None]:
"""Get the most recent execution date for a task in a given DAG.
:param sesssion: Db session.
:param logical_date: Current execution date.
:param execution_date: Current execution date.
:return: Most recent execution date in the window.
"""

end_date = logical_date
end_date = execution_date
start_date = end_date - self.duration

response = (
session.query(DagRun)
.filter(
DagRun.dag_id == self.external_dag_id,
DagRun.logical_date.between(start_date, end_date),
DagRun.execution_date.between(start_date, end_date),
)
.all()
)

dates = [r.logical_date for r in response]
dates = [r.execution_date for r in response]
dates.sort(reverse=True)

if len(dates) > 0:
@@ -152,7 +152,7 @@ def is_dagrun_done(self, *, date: datetime.datetime, session: scoped_session) ->
session.query(DagRun)
.filter(
DagRun.dag_id == self.external_dag_id,
DagRun.logical_date.in_([date]),
DagRun.execution_date.in_([date]),
DagRun.state.in_(success_states),
)
.first()
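A hedged sketch (not part of this commit) of instantiating DagRunSensor. Its constructor is not shown in this diff, so the duration, grace_period and check_exists keyword names are assumptions inferred from the attributes used in poke() and get_latest_execution_date() above; the DAG ids are hypothetical.

from datetime import timedelta

import pendulum
from airflow import DAG

from observatory.platform.utils.dag_run_sensor import DagRunSensor

with DAG(
    dag_id="waiting_dag",  # hypothetical
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule="@weekly",
    catchup=False,
) as dag:
    wait_for_dag = DagRunSensor(
        task_id="wait_for_external_dag",
        external_dag_id="upstream_dag",      # hypothetical upstream DAG id
        duration=timedelta(days=7),          # assumed kwarg: window searched before execution_date
        grace_period=timedelta(minutes=10),  # assumed kwarg: one-off sleep before re-querying
        check_exists=True,                   # assumed kwarg: verify the external DAG id exists
    )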

0 comments on commit 286efe2
