Fixed setting task_id:
- switch task-id from the Airflow job id to a hash of "runid/stepname"
- refactor how xcom variables are set
- added comments
valayDave committed Mar 19, 2022
1 parent e2a1e50 commit 563a200
Showing 3 changed files with 68 additions and 31 deletions.
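
At its core, the commit replaces the retry-sensitive `{{ ti.job_id }}` with an id derived from the run id and step name. A minimal sketch of the scheme; `derive_task_id` is an illustrative helper, not a function in the diff:

```python
import hashlib

def derive_task_id(run_id, step_name):
    # Stable across retries: the id depends only on the run id and step name,
    # unlike ti.job_id, which changes on every retry.
    return "arf-" + hashlib.md5("/".join([run_id, step_name]).encode("utf-8")).hexdigest()

print(derive_task_id("manual__2022-03-15T01:26:41.186781+00:00", "start"))
```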
38 changes: 27 additions & 11 deletions metaflow/plugins/airflow/airflow_compiler.py
@@ -19,7 +19,13 @@
from metaflow.plugins.cards.card_modules import chevron
from metaflow.plugins.aws.aws_utils import compute_resource_attributes
from .exceptions import AirflowNotPresent, AirflowException
from .airflow_utils import Workflow, AirflowTask, AirflowDAGArgs
from .airflow_utils import (
TASK_ID_XCOM_KEY,
Workflow,
AirflowTask,
AirflowDAGArgs,
AIRFLOW_TASK_ID_TEMPLATE_VALUE,
)
from . import airflow_utils as af_utils
from .compute.k8s import create_k8s_args
import metaflow.util as util
@@ -37,11 +43,9 @@

class Airflow(object):

# {{ ti.job_id }} doesn't provide the guarantees we need for `Task` ids.
# {{ ti.job_id }} changes when a retry happens.

task_id = "arf-{{ ti.job_id }}"
task_id = AIRFLOW_TASK_ID_TEMPLATE_VALUE
task_id_arg = "--task-id %s" % task_id

# Airflow run_ids are of the form: "manual__2022-03-15T01:26:41.186781+00:00"
# Such run-ids break `metaflow.util.decompress_list`; this is why we hash the run_id
run_id = "%s-$(echo -n {{ run_id }} | md5sum | awk '{print $1}')" % AIRFLOW_PREFIX
@@ -110,8 +114,9 @@ def _k8s_job(self, node, input_paths, env):
k8s_deco = [deco for deco in node.decorators if deco.name == "kubernetes"][0]
user_code_retries, total_retries = self._get_retries(node)
retry_delay = self._get_retry_delay(node)
# This sets timeouts for @timeout decorators.
# The timeout is set as "execution_timeout" for an airflow task.
runtime_limit = get_run_time_limit_for_task(node.decorators)

return create_k8s_args(
self.flow_datastore,
self.metadata,
@@ -155,8 +160,8 @@ def _get_retries(self, node):
def _get_retry_delay(self, node):
retry_decos = [deco for deco in node.decorators if deco.name == "retry"]
if len(retry_decos) > 0:
retry_mins = retry_decos[0]["attributes"]["minutes_between_retries"]
return timedelta(minutes=retry_mins)
retry_mins = retry_decos[0].attributes["minutes_between_retries"]
return timedelta(minutes=int(retry_mins))
return None

def _process_parameters(self):
@@ -251,11 +256,22 @@ def _to_job(self, node: DAGNode):
else:
if len(node.in_funcs) == 1:
# set input paths where this is only one parent node
# The parent-task-id is passed via the xcom;
# The parent-task-id is passed via the xcom; there is no other way to get it.
# A key property of xcoms is that they are immutable and only accepted if the task
# doesn't fail.
# From the Airflow docs:
# "Note: If the first task run is not succeeded then on every retry task XComs will be cleared to make the task run idempotent."
input_paths = (
# This is set using the `airflow_internal` decorator.
"%s/%s/{{ task_instance.xcom_pull('%s')['metaflow_task_id'] }}"
% (self.run_id, node.in_funcs[0], node.in_funcs[0])
# This pulls the `return_value` xcom, which holds a dictionary.
# This is how state gets passed between tasks.
"%s/%s/{{ task_instance.xcom_pull(task_ids='%s')['%s'] }}"
% (
self.run_id,
node.in_funcs[0],
node.in_funcs[0],
TASK_ID_XCOM_KEY,
)
)
else:
# this is a split scenario where there can be more than one input paths.
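
For reference, a sketch of how the templated `input_paths` string is assembled in the single-parent case, with an illustrative rendered value (names are taken from the diff above):

```python
TASK_ID_XCOM_KEY = "metaflow_task_id"
run_id = "arf-{{ run_id | hash }}"
parent = "start"  # stands in for node.in_funcs[0]

input_paths = "%s/%s/{{ task_instance.xcom_pull(task_ids='%s')['%s'] }}" % (
    run_id,
    parent,
    parent,
    TASK_ID_XCOM_KEY,
)
print(input_paths)
# Airflow renders the Jinja pieces at runtime, yielding something like:
# arf-<md5 of run_id>/start/arf-<md5 of "run_id/start">
```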
19 changes: 10 additions & 9 deletions metaflow/plugins/airflow/airflow_decorator.py
@@ -5,6 +5,7 @@
from metaflow.decorators import StepDecorator
from metaflow.metadata import MetaDatum
from .plumbing.airflow_xcom_push import push_xcom_values
from .airflow_utils import TASK_ID_XCOM_KEY


class AirflowInternalDecorator(StepDecorator):
@@ -29,6 +30,8 @@ def task_pre_step(
# handle xcom push / pull differently
meta = {}
meta["airflow-execution"] = os.environ["METAFLOW_RUN_ID"]
meta["airflow-dag-run-id"] = os.environ["METAFLOW_AIRFLOW_DAG_RUN_ID"]
meta["airflow-job-id"] = os.environ["METAFLOW_AIRFLOW_JOB_ID"]
entries = [
MetaDatum(
field=k, value=v, type=k, tags=["attempt_id:{0}".format(retry_count)]
@@ -37,17 +40,15 @@
]
# Register book-keeping metadata for debugging.
metadata.register_metadata(run_id, step_name, task_id, entries)
if retry_count == 0:
push_xcom_values(
{
TASK_ID_XCOM_KEY: os.environ["METAFLOW_AIRFLOW_TASK_ID"],
}
)

def task_finished(
self, step_name, flow, graph, is_task_ok, retry_count, max_user_code_retries
):
if not is_task_ok:
# The task finished with an exception - execution won't
# continue so no need to do anything here.
return
pass
# TODO: figure out a way to determine foreach cardinality here.
push_xcom_values(
{
"metaflow_task_id": os.environ["METAFLOW_AIRFLOW_TASK_ID"],
}
)
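
`push_xcom_values` (imported from `.plumbing.airflow_xcom_push`) is not shown in this commit. A plausible sketch, assuming it follows the KubernetesPodOperator convention of writing `/airflow/xcom/return.json`, which the xcom sidecar publishes as the task's `return_value` dictionary; that matches the `xcom_pull(...)[TASK_ID_XCOM_KEY]` lookup in the compiler:

```python
import json
import os

XCOM_DIR = "/airflow/xcom"  # assumed sidecar convention, not shown in the diff

def push_xcom_values(values):
    # Merge the given values into the pod's return.json so the Airflow
    # xcom sidecar publishes them as this task's `return_value` XCom.
    os.makedirs(XCOM_DIR, exist_ok=True)
    path = os.path.join(XCOM_DIR, "return.json")
    current = {}
    if os.path.exists(path):
        with open(path) as f:
            current = json.load(f)
    current.update(values)
    with open(path, "w") as f:
        json.dump(current, f)
```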
42 changes: 31 additions & 11 deletions metaflow/plugins/airflow/airflow_utils.py
@@ -7,10 +7,21 @@
import hashlib
import re
from datetime import timedelta, datetime

LABEL_VALUE_REGEX = re.compile(r"^[a-zA-Z0-9]([a-zA-Z0-9\-\_\.]{0,61}[a-zA-Z0-9])?$")

TASK_ID_XCOM_KEY = "metaflow_task_id"

# AIRFLOW_TASK_ID_TEMPLATE_VALUE works for linear/branched workflows.
# `ti.task_id` is the step name in the Metaflow code.
# AIRFLOW_TASK_ID_TEMPLATE_VALUE uses a Jinja filter called `task_id_creator`, which
# joins the run id and step name with a `/` and hashes the result (see the rendering
# sketch after `task_id_creator` below). Since the run id changes on every run while
# the step name stays the same, the task id is unique per run but stable across retries.
# Since Airflow doesn't encourage dynamic rewriting of DAGs, steps inside a foreach
# could be renamed with indexes (eg. `stepname-$index`) to create them statically.
# Hence: foreaches will require some special form of plumbing.
# https://stackoverflow.com/questions/62962386/can-an-airflow-task-dynamically-generate-a-dag-at-runtime
AIRFLOW_TASK_ID_TEMPLATE_VALUE = "arf-{{ [run_id, ti.task_id ] | task_id_creator }}"


def sanitize_label_value(val):
# Label sanitization: if the value can be used as is, return it as is.
@@ -35,6 +46,13 @@ def hasher(my_value):
return hashlib.md5(my_value.encode("utf-8")).hexdigest()


def task_id_creator(lst):
    # This is a filter which creates a hash of the "run_id/step_name" string.
    # Since run_ids in Airflow are constant within a run, the resulting hash is
    # stable across task retries.
    return hashlib.md5("/".join(lst).encode("utf-8")).hexdigest()
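
To see the template and filter together, a self-contained sketch using jinja2 directly (Airflow itself registers the filter through `user_defined_filters`, shown further down):

```python
import hashlib
from types import SimpleNamespace

from jinja2 import Environment

def task_id_creator(lst):
    # Same filter as defined above.
    return hashlib.md5("/".join(lst).encode("utf-8")).hexdigest()

env = Environment()
env.filters["task_id_creator"] = task_id_creator
template = env.from_string("arf-{{ [run_id, ti.task_id ] | task_id_creator }}")

# Stand-ins for the objects Airflow injects when rendering operator fields.
rendered = template.render(
    run_id="manual__2022-03-15T01:26:41.186781+00:00",
    ti=SimpleNamespace(task_id="start"),
)
print(rendered)  # arf-<md5 hex digest>, stable across retries of the same run
```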


class AirflowDAGArgs(object):
# _arg_types This object helps map types of
# different keys that need to be parsed. None of the values in this
@@ -74,7 +92,10 @@ class AirflowDAGArgs(object):

metaflow_centric_args = {
# Reference for user_defined_filters: https://stackoverflow.com/a/70175317
"user_defined_filters": dict(hash=lambda my_value: hasher(my_value)),
"user_defined_filters": dict(
hash=lambda my_value: hasher(my_value),
task_id_creator=lambda v: task_id_creator(v),
),
}

def __init__(self, **kwargs):
@@ -151,9 +172,7 @@ def generate_rfc1123_name(flow_name, step_name):
def set_k8s_operator_args(flow_name, step_name, operator_args):
from kubernetes import client

task_id = (
"arf-{{ ti.job_id }}" # Todo : find a way to switch this with something else.
)
task_id = AIRFLOW_TASK_ID_TEMPLATE_VALUE
run_id = "arf-{{ run_id | hash }}" # hash is added via the `user_defined_filters`
attempt = "{{ task_instance.try_number - 1 }}"
# Set dynamic env variables like run-id, task-id etc from here.
@@ -165,6 +184,8 @@ def set_k8s_operator_args(flow_name, step_name, operator_args):
for k, v in dict(
METAFLOW_RUN_ID=run_id,
METAFLOW_AIRFLOW_TASK_ID=task_id,
METAFLOW_AIRFLOW_DAG_RUN_ID="{{run_id}}",
METAFLOW_AIRFLOW_JOB_ID="{{ti.job_id}}",
METAFLOW_ATTEMPT_NUMBER=attempt,
).items()
]
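
The list comprehension around these entries is truncated in the hunk; a sketch of the full construction, assuming each value is wrapped in `client.V1EnvVar` as elsewhere in this function:

```python
from kubernetes import client

# Each templated value becomes a pod env var; Airflow renders the Jinja
# expressions before the pod launches, so the Metaflow process inside the
# container sees concrete ids in its environment.
env = [
    client.V1EnvVar(name=k, value=str(v))
    for k, v in dict(
        METAFLOW_RUN_ID="arf-{{ run_id | hash }}",
        METAFLOW_AIRFLOW_TASK_ID="arf-{{ [run_id, ti.task_id ] | task_id_creator }}",
        METAFLOW_AIRFLOW_DAG_RUN_ID="{{run_id}}",
        METAFLOW_AIRFLOW_JOB_ID="{{ti.job_id}}",
        METAFLOW_ATTEMPT_NUMBER="{{ task_instance.try_number - 1 }}",
    ).items()
]
```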
@@ -179,6 +200,7 @@ def set_k8s_operator_args(flow_name, step_name, operator_args):
]
volumes = [client.V1Volume(**v) for v in operator_args.get("volumes", [])]
args = {
# "on_retry_callback": retry_callback,
"namespace": operator_args.get("namespace", "airflow"),
"image": operator_args.get("image", "python"),
"name": generate_rfc1123_name(flow_name, step_name),
@@ -241,12 +263,10 @@ def set_k8s_operator_args(flow_name, step_name, operator_args):
}
args["labels"].update(labels)
if operator_args.get("execution_timeout", None):
args["execution_timeout"] = (
timedelta(
**operator_args.get(
"execution_timeout",
)
),
args["execution_timeout"] = timedelta(
**operator_args.get(
"execution_timeout",
)
)
if operator_args.get("retry_delay", None):
args["retry_delay"] = timedelta(**operator_args.get("retry_delay"))
