Skip to content

Commit

Permalink
[AIP-34] TaskGroup: A UI task grouping concept as an alternative to S…
Browse files Browse the repository at this point in the history
…ubDagOperator #10153

(cherry picked from commit 49c193f)
  • Loading branch information
yuqian90 committed Sep 24, 2020
1 parent 515c7dc commit d4ef251
Show file tree
Hide file tree
Showing 16 changed files with 1,849 additions and 151 deletions.
58 changes: 58 additions & 0 deletions airflow/example_dags/example_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Example DAG demonstrating the usage of the TaskGroup."""

from airflow.models.dag import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago
from airflow.utils.task_group import TaskGroup

# [START howto_task_group]
# Everything created inside this ``with`` block is built in the context of ``dag``.
with DAG(dag_id="example_task_group", start_date=days_ago(2)) as dag:
    start = DummyOperator(task_id="start")

    # [START howto_task_group_section_1]
    # Operators instantiated while a TaskGroup context is active are created in that group.
    with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
        task_1 = DummyOperator(task_id="task_1")
        task_2 = DummyOperator(task_id="task_2")
        task_3 = DummyOperator(task_id="task_3")

        # Fan out: task_2 and task_3 are both set downstream of task_1.
        task_1 >> [task_2, task_3]
    # [END howto_task_group_section_1]

    # [START howto_task_group_section_2]
    # NOTE(review): the task_id values below repeat those used in section_1; presumably the
    # enclosing group makes the final ids unique -- confirm against TaskGroup.child_id().
    with TaskGroup("section_2", tooltip="Tasks for section_2") as section_2:
        task_1 = DummyOperator(task_id="task_1")

        # [START howto_task_group_inner_section_2]
        # TaskGroups can be nested: this group is created inside section_2.
        with TaskGroup("inner_section_2", tooltip="Tasks for inner_section2") as inner_section_2:
            task_2 = DummyOperator(task_id="task_2")
            task_3 = DummyOperator(task_id="task_3")
            task_4 = DummyOperator(task_id="task_4")

            # Fan in: task_4 is set downstream of both task_2 and task_3.
            [task_2, task_3] >> task_4
        # [END howto_task_group_inner_section_2]

    # [END howto_task_group_section_2]

    end = DummyOperator(task_id='end')

    # Whole groups can be chained with ``>>`` just like individual tasks.
    start >> section_1 >> section_2 >> end
# [END howto_task_group]
61 changes: 54 additions & 7 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,11 @@ def __init__(
do_xcom_push=True, # type: bool
inlets=None, # type: Optional[Dict]
outlets=None, # type: Optional[Dict]
task_group=None,
*args,
**kwargs
):
from airflow.utils.task_group import TaskGroupContext

if args or kwargs:
# TODO remove *args and **kwargs in Airflow 2.0
Expand All @@ -343,6 +345,11 @@ def __init__(
)
validate_key(task_id)
self.task_id = task_id
self.label = task_id
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
if task_group:
self.task_id = task_group.child_id(task_id)
task_group.add(self)
self.owner = owner
self.email = email
self.email_on_retry = email_on_retry
Expand Down Expand Up @@ -474,6 +481,42 @@ def __hash__(self):
hash_components.append(repr(val))
return hash(tuple(hash_components))

def __rshift__(self, other):
    """
    Support the ``self >> other`` syntax.

    For anything that is not a DAG this is ``self.set_downstream(other)``.
    When ``other`` is a DAG, this operator is assigned to that DAG instead,
    unless that very DAG is already assigned.

    NOTE: This method is supposed to live in TaskMixin. The override is kept here
    only for the special DAG handling and can be removed in Airflow 2.0.

    :param other: the DAG to attach to, or the object to put downstream of this task
    :return: ``other``, so that the expression can be chained further
    """
    if not isinstance(other, DAG):
        self.set_downstream(other)
        return other

    # Only (re)assign when this exact DAG object is not already attached.
    already_assigned = self.has_dag() and self.dag is other
    if not already_assigned:
        self.dag = other
    return other

def __lshift__(self, other):
    """
    Support the ``self << other`` syntax.

    For anything that is not a DAG this is ``self.set_upstream(other)``.
    When ``other`` is a DAG, this operator is assigned to that DAG instead,
    unless that very DAG is already assigned.

    NOTE: This method is supposed to live in TaskMixin. The override is kept here
    only for the special DAG handling and can be removed in Airflow 2.0.

    :param other: the DAG to attach to, or the object to put upstream of this task
    :return: ``other``, so that the expression can be chained further
    """
    if not isinstance(other, DAG):
        self.set_upstream(other)
        return other

    # Only (re)assign when this exact DAG object is not already attached.
    already_assigned = self.has_dag() and self.dag is other
    if not already_assigned:
        self.dag = other
    return other

@property
def dag(self):
"""
Expand Down Expand Up @@ -946,21 +989,25 @@ def roots(self):
"""Required by TaskMixin"""
return [self]

@property
def leaves(self):
    """
    Required by TaskMixin.

    :return: a one-element list containing this operator itself
    """
    return [self]

def _set_relatives(
self,
task_or_task_list, # type: Union[TaskMixin, Sequence[TaskMixin]]
upstream=False,
):
"""Sets relatives for the task or task list."""

if isinstance(task_or_task_list, Sequence):
task_like_object_list = task_or_task_list
else:
task_like_object_list = [task_or_task_list]
if not isinstance(task_or_task_list, Sequence):
task_or_task_list = [task_or_task_list]

task_list = [] # type: List["BaseOperator"]
for task_object in task_like_object_list:
task_list.extend(task_object.roots)
for task_object in task_or_task_list:
task_object.update_relative(self, not upstream)
relatives = task_object.leaves if upstream else task_object.roots
task_list.extend(relatives)

for task in task_list:
if not isinstance(task, BaseOperator):
Expand Down
50 changes: 44 additions & 6 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import warnings
from collections import OrderedDict, defaultdict
from datetime import timedelta, datetime
from typing import TYPE_CHECKING, Callable, Dict, FrozenSet, Iterable, List, Optional, Type, Union
from typing import Callable, Dict, FrozenSet, Iterable, List, Optional, Type, Union

import jinja2
import pendulum
Expand Down Expand Up @@ -64,9 +64,6 @@
from airflow.utils.sqlalchemy import UtcDateTime, Interval
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator # Avoid circular dependency

install_aliases()

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -245,6 +242,9 @@ def __init__(
jinja_environment_kwargs=None, # type: Optional[Dict]
tags=None, # type: Optional[List[str]]
):
from airflow.utils.task_group import TaskGroup
from airflow.models.baseoperator import BaseOperator

self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
self.default_args = copy.deepcopy(default_args or {})
Expand Down Expand Up @@ -329,6 +329,7 @@ def __init__(

self.jinja_environment_kwargs = jinja_environment_kwargs
self.tags = tags
self._task_group = TaskGroup.create_root(self)

def __repr__(self):
    """Return a short representation of the form ``<DAG: dag_id>``."""
    template = "<DAG: {self.dag_id}>"
    return template.format(self=self)
Expand Down Expand Up @@ -591,6 +592,10 @@ def filepath(self):
fn = fn.replace(os.path.dirname(__file__) + '/', '')
return fn

@property
def task_group(self):
    """
    Return the value of ``self._task_group``.

    ``__init__`` initialises that attribute with ``TaskGroup.create_root(self)``.
    """
    return self._task_group

@property
def folder(self):
"""Folder location of where the DAG object is instantiated."""
Expand Down Expand Up @@ -1221,6 +1226,7 @@ def sub_dag(self, task_regex, include_downstream=False,
based on a regex that should match one or many tasks, and includes
upstream and downstream neighbours based on the flag passed.
"""
from airflow.models.baseoperator import BaseOperator

# deep-copying self.task_dict takes a long time, and we don't want all
# the tasks anyway, so we copy the tasks manually later
Expand All @@ -1242,9 +1248,38 @@ def sub_dag(self, task_regex, include_downstream=False,
# Make sure to not recursively deepcopy the dag while copying the task
dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag})
for t in regex_match + also_include}

# Remove tasks not included in the subdag from task_group
def remove_excluded(group):
    """
    Recursively prune ``group`` down to what is part of the subdag.

    Tasks whose task_id is missing from ``dag.task_dict`` are dropped from
    ``group.children``; the remaining ones are replaced by the subdag's own task
    objects. Nested groups are pruned the same way and dropped once they are empty.

    :param group: the group whose ``children`` mapping is pruned in place
    """
    # Iterate over a snapshot because entries are popped from the mapping while looping.
    for member in list(group.children.values()):
        if not isinstance(member, BaseOperator):
            # Not a task, so it is a nested group: prune it first.
            remove_excluded(member)

            # Remove this TaskGroup if it doesn't contain any tasks in this subdag
            if not member.children:
                group.children.pop(member.group_id)
            continue

        task_id = member.task_id
        if task_id in dag.task_dict:
            # The tasks in the subdag are a copy of tasks in the original dag
            # so update the reference in the TaskGroups too.
            group.children[task_id] = dag.task_dict[task_id]
        else:
            group.children.pop(task_id)

remove_excluded(dag.task_group)

# Removing upstream/downstream references to tasks and TaskGroups that did not make
# the cut.
subdag_task_groups = dag.task_group.get_task_group_dict()
for group in subdag_task_groups.values():
group.upstream_group_ids = group.upstream_group_ids.intersection(subdag_task_groups.keys())
group.downstream_group_ids = group.downstream_group_ids.intersection(subdag_task_groups.keys())
group.upstream_task_ids = group.upstream_task_ids.intersection(dag.task_dict.keys())
group.downstream_task_ids = group.downstream_task_ids.intersection(dag.task_dict.keys())

for t in dag.tasks:
# Removing upstream/downstream references to tasks that did not
# made the cut
# make the cut
t._upstream_task_ids = t._upstream_task_ids.intersection(dag.task_dict.keys())
t._downstream_task_ids = t._downstream_task_ids.intersection(
dag.task_dict.keys())
Expand Down Expand Up @@ -1332,7 +1367,8 @@ def add_task(self, task):
elif task.end_date and self.end_date:
task.end_date = min(task.end_date, self.end_date)

if task.task_id in self.task_dict and self.task_dict[task.task_id] is not task:
if ((task.task_id in self.task_dict and self.task_dict[task.task_id] is not task)
or task.task_id in self._task_group.used_group_ids):
# TODO: raise an error in Airflow 2.0
warnings.warn(
'The requested task could not be added to the DAG because a '
Expand All @@ -1343,6 +1379,8 @@ def add_task(self, task):
else:
self.task_dict[task.task_id] = task
task.dag = self
# Add task_id to used_group_ids to prevent group_id and task_id collisions.
self._task_group.used_group_ids.add(task.task_id)

self.task_count = len(self.task_dict)

Expand Down
16 changes: 16 additions & 0 deletions airflow/models/taskmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def roots(self):
"""Should return list of root operator List[BaseOperator]"""
raise NotImplementedError()

@property
def leaves(self):
    """
    Should return the list of leaf operators, i.e. List[BaseOperator].

    This base implementation only raises, so it has to be overridden to be usable.

    :raises NotImplementedError: always
    """
    raise NotImplementedError()

@abstractmethod
def set_upstream(
self,
Expand All @@ -53,6 +58,17 @@ def set_downstream(
"""
raise NotImplementedError()

def update_relative(
    self,
    other,  # type: "TaskMixin"
    upstream=True
):
    """
    Record relationship information about another TaskMixin.

    The default implementation does nothing and returns ``None``; override it
    where the relationship needs to be tracked.

    :param other: the other TaskMixin taking part in the relationship
    :param upstream: direction flag for the relationship, defaults to True
    """

def __lshift__(
self,
other, # type: Union["TaskMixin", Sequence["TaskMixin"]]
Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ class DagAttributeTypes(str, Enum):
DICT = 'dict'
SET = 'set'
TUPLE = 'tuple'
TASK_GROUP = 'taskgroup'
48 changes: 47 additions & 1 deletion airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@
"_default_view": { "type" : "string"},
"_access_control": {"$ref": "#/definitions/dict" },
"is_paused_upon_creation": { "type": "boolean" },
"tags": { "type": "array" }
"tags": { "type": "array" },
"_task_group": {"anyOf": [
{ "type": "null" },
{ "$ref": "#/definitions/task_group" }
]}
},
"required": [
"_dag_id",
Expand Down Expand Up @@ -125,6 +129,7 @@
"_task_module": { "type": "string" },
"_operator_extra_links": { "$ref": "#/definitions/extra_links" },
"task_id": { "type": "string" },
"label": { "type": "string" },
"owner": { "type": "string" },
"start_date": { "$ref": "#/definitions/datetime" },
"end_date": { "$ref": "#/definitions/datetime" },
Expand Down Expand Up @@ -156,6 +161,47 @@
}
},
"additionalProperties": true
},
"task_group": {
"$comment": "A TaskGroup containing tasks",
"type": "object",
"required": [
"_group_id",
"prefix_group_id",
"children",
"tooltip",
"ui_color",
"ui_fgcolor",
"upstream_group_ids",
"downstream_group_ids",
"upstream_task_ids",
"downstream_task_ids"
],
"properties": {
"_group_id": {"anyOf": [{"type": "null"}, { "type": "string" }]},
"prefix_group_id": { "type": "boolean" },
"children": { "$ref": "#/definitions/dict" },
"tooltip": { "type": "string" },
"ui_color": { "type": "string" },
"ui_fgcolor": { "type": "string" },
"upstream_group_ids": {
"type": "array",
"items": { "type": "string" }
},
"downstream_group_ids": {
"type": "array",
"items": { "type": "string" }
},
"upstream_task_ids": {
"type": "array",
"items": { "type": "string" }
},
"downstream_task_ids": {
"type": "array",
"items": { "type": "string" }
}
},
"additionalProperties": false
}
},

Expand Down
Loading

0 comments on commit d4ef251

Please sign in to comment.