
Commit fc40423
Temporary fix
jdddog committed Mar 18, 2024
1 parent 552f44c commit fc40423
Showing 2 changed files with 73 additions and 1 deletion.
9 changes: 8 additions & 1 deletion observatory_platform/sandbox/sandbox_environment.py
@@ -292,7 +292,14 @@ def run_task(self, task_id: str, map_index: int = -1) -> TaskInstance:
         task = dag.get_task(task_id=task_id)
         ti = TaskInstance(task, run_id=run_id, map_index=map_index)
         ti.refresh_from_db()
-        ti.run(ignore_ti_state=True, ignore_all_deps=True)
+
+        # TODO: remove this when this issue fixed / PR merged: https://github.com/apache/airflow/issues/34023#issuecomment-1705761692
+        # https://github.com/apache/airflow/pull/36462
+        ignore_task_deps = False
+        if map_index > -1:
+            ignore_task_deps = True
+
+        ti.run(ignore_task_deps=ignore_task_deps)

         return ti

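For context, a minimal sketch (not part of the commit) of how the changed run_task is exercised: dependency checks are now only ignored for mapped task instances (map_index > -1). It assumes SandboxEnvironment is importable from the module changed above, uses placeholder project and data-location values, and borrows create_dynamic_task_dag from the new test file below.

import pendulum

from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment

# Placeholder project id and data location; create_dynamic_task_dag is the
# DAG factory added in test_sandbox_environment.py below.
env = SandboxEnvironment("my-project-id", "us")
logical_date = pendulum.datetime(2024, 1, 1)
my_dag = create_dynamic_task_dag(dag_id="dynamic_task_dag", start_date=logical_date)

with env.create():
    with env.create_dag_run(my_dag, logical_date):
        env.run_task("fetch_releases")  # regular task: map_index defaults to -1
        # Mapped task instances: map_index > -1, so ti.run() is called with ignore_task_deps=True
        env.run_task("process_release.download", map_index=0)
        env.run_task("process_release.download", map_index=1)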
65 changes: 65 additions & 0 deletions observatory_platform/sandbox/tests/test_sandbox_environment.py
@@ -24,6 +24,8 @@
 import croniter
 import pendulum
 from airflow.decorators import dag, task
+from airflow.decorators import task_group
+from airflow.exceptions import AirflowSkipException
 from airflow.models.connection import Connection
 from airflow.models.dag import ScheduleArg
 from airflow.models.variable import Variable
@@ -75,6 +77,52 @@ def task3():
     return my_dag()


+def create_dynamic_task_dag(
+    *,
+    dag_id: str,
+    start_date: pendulum.DateTime,
+    schedule: str = "@weekly",
+    catchup: bool = False,
+):
+    @dag(
+        dag_id=dag_id,
+        schedule_interval=schedule,
+        start_date=start_date,
+        catchup=catchup,
+        tags=["example_tag"],
+    )
+    def example_workflow():
+        @task
+        def fetch_releases(**context):
+            releases = [0, 1]
+            if not releases:
+                raise AirflowSkipException("No new releases found, skipping")
+            return releases
+
+        @task_group(group_id="process_release")
+        def process_release(data, **context):
+            @task
+            def download(release: dict, **context):
+                print(f"Downloading {release}")
+
+            @task
+            def bq_load(release: dict, **context):
+                print(f"Loading to BigQuery {release}")
+
+            # Connects tasks
+            download(data) >> bq_load(data)
+
+        # Fetches releases
+        xcom_releases = fetch_releases()
+
+        # Using `.expand()` to dynamically create tasks for each release
+        process_release_task_group = process_release.expand(data=xcom_releases)
+
+        (xcom_releases >> process_release_task_group)
+
+    return example_workflow()
+
+
 class TestSandboxEnvironment(unittest.TestCase):
     """Test the SandboxEnvironment"""

@@ -331,3 +379,20 @@ def test_create_dag_run_timedelta(self):
             with env.create_dag_run(my_dag, execution_date):
                 self.assertIsNotNone(env.dag_run)
                 self.assertEqual(expected_dag_date, env.dag_run.start_date)
+
+    def test_map_index(self):
+        env = SandboxEnvironment(self.project_id, self.data_location)
+        logical_date = pendulum.datetime(2024, 1, 1)
+        my_dag = create_dynamic_task_dag(dag_id="dynamic_task_dag", start_date=logical_date)
+        with env.create():
+            with env.create_dag_run(my_dag, logical_date):
+                self.assertIsNotNone(env.dag_run)
+                ti = env.run_task("fetch_releases")
+                self.assertEqual(TaskInstanceState.SUCCESS, ti.state)
+
+                for map_index in range(2):
+                    ti = env.run_task("process_release.download", map_index=map_index)
+                    self.assertEqual(TaskInstanceState.SUCCESS, ti.state)
+
+                    ti = env.run_task("process_release.bq_load", map_index=map_index)
+                    self.assertEqual(TaskInstanceState.SUCCESS, ti.state)
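If the tests directory is importable as a package (an assumption about the repository layout), the new test can be run on its own with the standard unittest runner, for example:

python -m unittest observatory_platform.sandbox.tests.test_sandbox_environment.TestSandboxEnvironment.test_map_index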
