Added support for Branching with Airflow
- Removed the `next` function from `AirflowTask`.
- `AirflowTask`s now have no knowledge of their next tasks.
- Removed stale todos and added some new ones.
- Graph construction on the Airflow side now uses the `graph_structure` data structure.
- `graph_structure` comes from `FlowGraph.output_steps()[1]` (illustrated below).
valayDave committed Mar 20, 2022
1 parent 8e9f649 commit 874b94a
Showing 2 changed files with 89 additions and 61 deletions.
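For illustration only (this example is not part of the commit), the nested-list format returned by `FlowGraph.output_steps()[1]` and consumed by the new Airflow-side traversal looks roughly like this for a flow that splits into two branches:

    # Hypothetical graph_structure for a flow with one split: step names are
    # strings, a list of lists encodes parallel branches, and a flat list
    # encodes a linear chain of steps.
    graph_structure = [
        "start",
        [
            ["branch_a"],
            ["branch_b"],
        ],
        "join",
        "end",
    ]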
70 changes: 45 additions & 25 deletions metaflow/plugins/airflow/airflow_compiler.py
@@ -99,9 +99,11 @@ def __init__(
self.schedule_interval = self._get_schedule()
self._file_path = file_path
self.metaflow_parameters = None
_, self.graph_structure = self.graph.output_steps()

def _get_schedule(self):
schedule = self.flow._flow_decorators.get("schedule")
# todo : Fix bug here in schedule. The regex pattern doesn't work as intended.
if schedule:
return schedule.schedule
# Schedule can be None.
@@ -226,6 +228,29 @@ def _process_parameters(self):

return parameters

def _make_parent_input_path_compressed(
self,
step_names,
):
return "%s:" % (self.run_id) + ",".join(
self._make_parent_input_path(s, only_task_id=True) for s in step_names
)

def _make_parent_input_path(self, step_name, only_task_id=False):
# This is set using the `airflow_internal` decorator.
# This will pull the `return_value` xcom which holds a dictionary.
# This helps pass state.
task_id_string = "/%s/{{ task_instance.xcom_pull(task_ids='%s')['%s'] }}" % (
step_name,
step_name,
TASK_ID_XCOM_KEY,
)

if only_task_id:
return task_id_string

return "%s%s" % (self.run_id, task_id_string)

def _to_job(self, node: DAGNode):
# supported compute : k8s (v1), local(v2), batch(v3)
attrs = {
@@ -261,6 +286,9 @@ def _to_job(self, node: DAGNode):
# The Below If/Else Block handle "Input Paths".
# Input Paths help manage dataflow across the graph.
if node.name == "start":
# Initialize parameters for the flow in the `start` step.
# `start` step has no upstream input dependencies aside from
# parameters.
parameters = self._process_parameters()
if parameters:
env["METAFLOW_PARAMETERS"] = self.parameter_macro
@@ -270,10 +298,6 @@
default_parameters[parameter["name"]] = parameter["value"]
# Dump the default values specified in the flow.
env["METAFLOW_DEFAULT_PARAMETERS"] = json.dumps(default_parameters)
# Initialize parameters for the flow in the `start` step.
# todo : Handle parameters
# `start` step has no upstream input dependencies aside from
# parameters.
input_paths = None
else:
if node.parallel_foreach:
@@ -293,22 +317,11 @@
# doesn't fail.
# From airflow docs :
# "Note: If the first task run is not succeeded then on every retry task XComs will be cleared to make the task run idempotent."
input_paths = (
# This is set using the `airflow_internal` decorator.
# This will pull the `return_value` xcom which holds a dictionary.
# This helps pass state.
"%s/%s/{{ task_instance.xcom_pull(task_ids='%s')['%s'] }}"
% (
self.run_id,
node.in_funcs[0],
node.in_funcs[0],
TASK_ID_XCOM_KEY,
)
)
input_paths = self._make_parent_input_path(node.in_funcs[0])
else:
# this is a join step following a split, so there can be more than one input path.
# todo : set input paths for a split join step
pass
input_paths = self._make_parent_input_path_compressed(node.in_funcs)

env["METAFLOW_INPUT_PATHS"] = input_paths

if node.is_inside_foreach:
@@ -401,7 +414,6 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
if node.name == "start":
# We need a separate unique ID for the special _parameters task
task_id_params = "%s-params" % self.task_id
# TODO : Currently I am putting this boiler plate because we need to check if parameters are set or not.
# Export user-defined parameters into runtime environment
param_file = "".join(
random.choice(string.ascii_lowercase) for _ in range(10)
Expand Down Expand Up @@ -495,23 +507,30 @@ def _visit(node: DAGNode, workflow: Workflow, exit_node=None):

state = AirflowTask(node.name).set_operator_args(**self._to_job(node))

if node.type == "end" or exit_node in node.out_funcs:
if node.type == "end":
workflow.add_state(state)

# Continue linear assignment within the (sub)workflow if the node
# doesn't branch or fork.
elif node.type in ("start", "linear", "join"):
workflow.add_state(state.next(node.out_funcs[0]))
_visit(self.graph[node.out_funcs[0]], workflow, exit_node)
workflow.add_state(state)
_visit(
self.graph[node.out_funcs[0]],
workflow,
)

elif node.type == "split":
# Todo : handle Taskgroup in this step cardinality in some way
pass
workflow.add_state(state)
for func in node.out_funcs:
_visit(
self.graph[func],
workflow,
)

elif node.type == "foreach":
# Todo : handle foreach cardinality in some way
# Continue the traversal from the matching_join.
_visit(self.graph[node.matching_join], workflow, exit_node)
_visit(self.graph[node.matching_join], workflow)
# Ideally we should never get here.
else:
raise AirflowException(
@@ -530,6 +549,7 @@ def _visit(node: DAGNode, workflow: Workflow, exit_node=None):
catchup=self.catchup,
tags=self.tags,
file_path=self._file_path,
graph_structure=self.graph_structure,
)
workflow = _visit(self.graph["start"], workflow)
workflow.set_parameters(self.metaflow_parameters)
80 changes: 44 additions & 36 deletions metaflow/plugins/airflow/airflow_utils.py
@@ -291,43 +291,29 @@ def __init__(self, name, operator_type="kubernetes", flow_name=None):
self.name = name
self._operator_args = None
self._operator_type = operator_type
self._next = None
self._flow_name = flow_name

def set_operator_args(self, **kwargs):
self._operator_args = kwargs
return self

@property
def next_state(self):
return self._next

def next(self, out_func_name):
self._next = out_func_name
return self

def to_dict(self):
return {
"name": self.name,
"next": self._next,
"operator_type": self._operator_type,
"operator_args": self._operator_args,
}

@classmethod
def from_dict(cls, jsd, flow_name=None):
op_args = {} if not "operator_args" in jsd else jsd["operator_args"]
return (
cls(
jsd["name"],
operator_type=jsd["operator_type"]
if "operator_type" in jsd
else "kubernetes",
flow_name=flow_name,
)
.next(jsd["next"])
.set_operator_args(**op_args)
)
return cls(
jsd["name"],
operator_type=jsd["operator_type"]
if "operator_type" in jsd
else "kubernetes",
flow_name=flow_name,
).set_operator_args(**op_args)

def _kubenetes_task(self):
KubernetesPodOperator = get_k8s_operator()
@@ -337,18 +323,18 @@ def _kubenetes_task(self):
return KubernetesPodOperator(**k8s_args)

def to_task(self):
# todo fix
if self._operator_type == "kubernetes":
return self._kubenetes_task()


class Workflow(object):
def __init__(self, file_path=None, **kwargs):
def __init__(self, file_path=None, graph_structure=None, **kwargs):
self._dag_instantiation_params = AirflowDAGArgs(**kwargs)
self._file_path = file_path
tree = lambda: defaultdict(tree)
self.states = tree()
self.metaflow_params = None
self.graph_structure = graph_structure

def set_parameters(self, params):
self.metaflow_params = params
@@ -358,6 +344,7 @@ def add_state(self, state):

def to_dict(self):
return dict(
graph_structure=self.graph_structure,
states={s: v.to_dict() for s, v in self.states.items()},
dag_instantiation_params=self._dag_instantiation_params.to_dict(),
file_path=self._file_path,
@@ -371,6 +358,7 @@ def to_json(self):
def from_dict(cls, data_dict):
re_cls = cls(
file_path=data_dict["file_path"],
graph_structure=data_dict["graph_structure"],
)
re_cls._dag_instantiation_params = AirflowDAGArgs.from_dict(
data_dict["dag_instantiation_params"]
@@ -407,20 +395,40 @@ def compile(self):

params_dict = self._construct_params()
dag = DAG(params=params_dict, **self._dag_instantiation_params.arguements)
curr_state = self.states["start"]
curr_task = self.states["start"].to_task()
prev_task = None
# Todo : Assert that fileloc is required because the DAG export has that information.
dag.fileloc = self._file_path if self._file_path is not None else dag.fileloc
with dag:
while curr_state is not None:
curr_task = curr_state.to_task()
if prev_task is not None:
prev_task >> curr_task

if curr_state.next_state is None:
curr_state = None
def add_node(node, parents, dag):
"""
A recursive function to traverse the specialized
graph_structure data structure.
"""
if type(node) == str:
task = self.states[node].to_task()
if parents:
for parent in parents:
parent >> task
return [task]  # return this task as the parent list for downstream nodes

# a list is either a set of parallel branches or a linear chain of steps
if type(node) == list:
# this means branching since everything within the list is a list
if all(isinstance(n, list) for n in node):
curr_parents = parents
parent_list = []
for node_list in node:
last_parent = add_node(node_list, curr_parents, dag)
parent_list.extend(last_parent)
return parent_list
else:
curr_state = self.states[curr_state.next_state]
prev_task = curr_task
# no branching: the elements of the list are step names, so chain them linearly.
curr_parents = parents
for node_x in node:
curr_parents = add_node(node_x, curr_parents, dag)
return curr_parents

with dag:
parent = None
for node in self.graph_structure:
parent = add_node(node, parent, dag)

return dag
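
As a rough sketch (not the committed code; the step names and sample structure are hypothetical), the recursive traversal in `compile()` wires Airflow dependencies like this:

    # A minimal, self-contained replica of the traversal above that prints the
    # dependencies it would create; the real compile() sets `parent >> task` on
    # Airflow operators instead of printing.
    def wire(node, parents):
        if isinstance(node, str):
            for p in parents or []:
                print("%s >> %s" % (p, node))
            return [node]
        if all(isinstance(n, list) for n in node):
            # branching: fan out every branch from the same parents
            out = []
            for branch in node:
                out.extend(wire(branch, parents))
            return out
        # linear chain: each element becomes the parent of the next
        for n in node:
            parents = wire(n, parents)
        return parents

    structure = ["start", [["branch_a"], ["branch_b"]], "join", "end"]
    parents = None
    for step in structure:
        parents = wire(step, parents)
    # Prints:
    #   start >> branch_a
    #   start >> branch_b
    #   branch_a >> join
    #   branch_b >> join
    #   join >> end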
