Foreach polish (valayDave#62)
* Removed unused imports
* Added validation of the target Airflow version number for flows with foreaches (see the first sketch below)
* Removed the `airflow_schedule_interval` decorator.

* Added production/deployment token-related changes
- Uses S3 as a backend to store the production token
- The token is used to avoid name clashes between deployments
- The token is stored via `FlowDatastore`

* Graph-type validation for Airflow foreaches (second sketch below)
- Airflow foreaches only support single-node fanout.
- Validation rejects graphs with nested foreaches.

* Added a configuration option for the Kubernetes pod `startup_timeout`.

* Added a final TODO on the `resources` argument of the Kubernetes Pod Operator (third sketch below)
- Added a commented code block
- It needs to be uncommented once Airflow releases the patch for the operator
- Setting the pending Airflow patch aside, the code seems feature-complete
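
First sketch: the kind of Airflow-version gate the foreach validation implies. This is illustrative only; the function name and call site are not from this commit, and it assumes foreaches compile to Airflow's dynamic task mapping, which shipped in Airflow 2.3.0.

from metaflow.plugins.airflow.exception import AirflowException

def validate_airflow_version_for_foreach(version_string):
    # Naive parse; a production check would tolerate pre-release suffixes.
    major, minor = (int(v) for v in version_string.split(".")[:2])
    if (major, minor) < (2, 3):
        raise AirflowException(
            "Flows with foreaches need Airflow >= 2.3.0 (dynamic task "
            "mapping); the target Airflow version is %s" % version_string
        )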
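
Second sketch: what the graph-type validation can look like, assuming Metaflow's FlowGraph/DAGNode API (`node.type`, `node.out_funcs`, `node.split_parents`). This is a sketch of the constraint, not the commit's exact implementation.

from metaflow.plugins.airflow.exception import AirflowException

def validate_foreach_constraints(graph):
    for node in graph:
        if node.type != "foreach":
            continue
        # Reject a foreach that sits inside another foreach's scope.
        if any(graph[p].type == "foreach" for p in node.split_parents):
            raise AirflowException(
                "Step *%s* is a nested foreach, which Airflow deployments "
                "do not support." % node.name
            )
        # Single-node fanout: the mapped step must lead straight to the join.
        body = graph[node.out_funcs[0]]
        if any(graph[n].type != "join" for n in body.out_funcs):
            raise AirflowException(
                "Airflow foreaches support single-node fanout only; step "
                "*%s* must transition directly to the join." % body.name
            )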
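
Third sketch: the shape of such a commented `resources` block. The `container_resources` parameter name and the behavior of the patched operator are assumptions about the pending Airflow release, and `k8s_deco` refers to the Kubernetes decorator extracted in `_to_job` below.

# TODO: uncomment once Airflow releases the patched KubernetesPodOperator.
#
# from kubernetes.client import models as k8s
#
# resources = k8s.V1ResourceRequirements(
#     requests={
#         "cpu": str(k8s_deco.attributes["cpu"]),
#         "memory": "%sM" % str(k8s_deco.attributes["memory"]),
#     },
# )
# KubernetesPodOperator(..., container_resources=resources)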
valayDave committed Jul 28, 2022
1 parent 4b2dd12 commit 0673db7
Showing 6 changed files with 387 additions and 99 deletions.
7 changes: 7 additions & 0 deletions metaflow/metaflow_config.py
@@ -242,6 +242,13 @@ def from_conf(name, default=None):
)
#

##
# Airflow Configuration
##
AIRFLOW_KUBERNETES_STARTUP_TIMEOUT = from_conf(
"METAFLOW_AIRFLOW_KUBERNETES_STARTUP_TIMEOUT_SECONDS", 60 * 60
)


###
# Conda configuration
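Usage note (not part of the diff): because the new timeout resolves through `from_conf`, it can be overridden like any other Metaflow setting, e.g. via the environment. A minimal sketch:

import os

# Must be set before metaflow is imported, since metaflow_config reads the
# environment at import time; the value is in seconds (default 60 * 60).
os.environ["METAFLOW_AIRFLOW_KUBERNETES_STARTUP_TIMEOUT_SECONDS"] = "1800"

from metaflow.metaflow_config import AIRFLOW_KUBERNETES_STARTUP_TIMEOUT

print(AIRFLOW_KUBERNETES_STARTUP_TIMEOUT)  # -> 1800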
4 changes: 0 additions & 4 deletions metaflow/plugins/__init__.py
@@ -165,9 +165,6 @@ def get_plugin_cli():
from .aws.step_functions.schedule_decorator import ScheduleDecorator
from .project_decorator import ProjectDecorator

from .airflow.airflow_decorator import (
AirflowScheduleIntervalDecorator,
)

from .airflow.sensors import (
S3KeySensorDecorator,
@@ -178,7 +175,6 @@ def get_plugin_cli():
FLOW_DECORATORS = [CondaFlowDecorator, ScheduleDecorator, ProjectDecorator] + [
S3KeySensorDecorator,
ExternalTaskSensorDecorator,
AirflowScheduleIntervalDecorator,
SQLSensorDecorator,
]
_merge_lists(FLOW_DECORATORS, _ext_plugins["FLOW_DECORATORS"], "name")
69 changes: 55 additions & 14 deletions metaflow/plugins/airflow/airflow.py
@@ -1,4 +1,4 @@
import base64
from io import BytesIO
import json
import os
import random
@@ -17,6 +17,7 @@
DATATOOLS_S3ROOT,
KUBERNETES_SERVICE_ACCOUNT,
KUBERNETES_SECRETS,
AIRFLOW_KUBERNETES_STARTUP_TIMEOUT,
)
from metaflow.parameters import deploy_time_eval
from metaflow.plugins.kubernetes.kubernetes import Kubernetes
@@ -31,8 +32,6 @@
from .exception import AirflowException
from .sensors import SUPPORTED_SENSORS
from .airflow_utils import (
RUN_HASH_ID_LEN,
RUN_ID_PREFIX,
TASK_ID_XCOM_KEY,
AirflowTask,
Workflow,
@@ -44,6 +43,9 @@


class Airflow(object):

TOKEN_STORAGE_ROOT = "mf.airflow"

def __init__(
self,
name,
@@ -56,6 +58,7 @@ def __init__(
environment,
event_logger,
monitor,
production_token,
tags=None,
namespace=None,
username=None,
@@ -89,6 +92,40 @@ def __init__(
self.workflow_timeout = workflow_timeout
self.schedule = self._scheduling_interval()
self.parameters = self._process_parameters()
self.production_token = production_token

@classmethod
def get_existing_deployment(cls, name, flow_datastore):
_backend = flow_datastore._storage_impl
token_paths = _backend.list_content([cls.get_token_path(name)])
if len(token_paths) == 0:
return None

with _backend.load_bytes([token_paths[0]]) as get_results:
for _, path, _ in get_results:
if path is not None:
with open(path, "r") as f:
data = json.loads(f.read())
return (data["owner"], data["token"])

@classmethod
def get_token_path(cls, name):
return os.path.join(cls.TOKEN_STORAGE_ROOT, name)

@classmethod
def save_deployment_token(cls, owner, name, token, flow_datastore):
    _backend = flow_datastore._storage_impl
    _backend.save_bytes(
        [
            (
                # Key the token file by DAG name so that
                # `get_existing_deployment(name, ...)` can locate it.
                cls.get_token_path(name),
                BytesIO(
                    bytes(json.dumps({"token": token, "owner": owner}), "utf-8")
                ),
            )
        ],
        overwrite=False,
    )

def _scheduling_interval(self):
"""
@@ -320,13 +357,9 @@ def _to_job(self, node):

metaflow_version = self.environment.get_environment_info()
metaflow_version["flow_name"] = self.graph.name
metaflow_version["production_token"] = self.production_token
env["METAFLOW_VERSION"] = json.dumps(metaflow_version)

# Todo : Find ways to pass state using for the below usecases:
# 1. To set the cardinality of foreaches
# 2. To set the input paths from the parent steps of a foreach join.
# 3. To read the input paths in a foreach join.

# Extract the k8s decorators for constructing the arguments of the K8s Pod Operator on Airflow.
k8s_deco = [deco for deco in node.decorators if deco.name == "kubernetes"][0]
user_code_retries, _ = self._get_retries(node)
@@ -357,14 +390,16 @@ def _to_job(self, node):
"METAFLOW_DATATOOLS_S3ROOT": DATATOOLS_S3ROOT,
"METAFLOW_DEFAULT_DATASTORE": "s3",
"METAFLOW_DEFAULT_METADATA": "service",
# Question for (savin) : what does `METAFLOW_KUBERNETES_WORKLOAD` do ?
"METAFLOW_KUBERNETES_WORKLOAD": str(1),
"METAFLOW_KUBERNETES_WORKLOAD": str(
1
),  # This is used by the Kubernetes decorator.
"METAFLOW_RUNTIME_ENVIRONMENT": "kubernetes",
"METAFLOW_CARD_S3ROOT": DATASTORE_CARD_S3ROOT,
"METAFLOW_RUN_ID": AIRFLOW_MACROS.RUN_ID,
"METAFLOW_AIRFLOW_TASK_ID": AIRFLOW_MACROS.TASK_ID,
"METAFLOW_AIRFLOW_DAG_RUN_ID": AIRFLOW_MACROS.AIRFLOW_RUN_ID,
"METAFLOW_AIRFLOW_JOB_ID": AIRFLOW_MACROS.AIRFLOW_JOB_ID,
"METAFLOW_PRODUCTION_TOKEN": self.production_token,
"METAFLOW_ATTEMPT_NUMBER": AIRFLOW_MACROS.ATTEMPT,
}
env.update(additional_mf_variables)
@@ -399,6 +434,7 @@ def _to_job(self, node):
)

annotations = {
"metaflow/production_token": self.production_token,
"metaflow/owner": self.username,
"metaflow/user": self.username,
"metaflow/flow_name": self.flow.name,
@@ -437,7 +473,7 @@ def _to_job(self, node):
env_vars=[dict(name=k, value=v) for k, v in env.items()],
labels=labels,
task_id=node.name,
startup_timeout_seconds=60 * 60,
startup_timeout_seconds=AIRFLOW_KUBERNETES_STARTUP_TIMEOUT,
in_cluster=True,
get_logs=True,
do_xcom_push=True,
@@ -575,11 +611,15 @@ def _collect_flow_sensors(self):
self._depends_on_upstream_sensors = True
return af_tasks

def compile(self):
from metaflow.graph import DAGNode
def _contains_foreach(self):
for node in self.graph:
if node.type == "foreach":
return True
return False

def compile(self):
# Visit every node of the flow and recursively build the state machine.
def _visit(node: DAGNode, workflow, exit_node=None):
def _visit(node, workflow, exit_node=None):
if node.parallel_foreach:
raise AirflowException(
"Deploying flows with @parallel decorator(s) "
Expand Down Expand Up @@ -643,6 +683,7 @@ def _visit(node: DAGNode, workflow, exit_node=None):
tags=self.tags,
file_path=self._file_path,
graph_structure=self.graph_structure,
metadata=dict(contains_foreach=self._contains_foreach()),
**airflow_dag_args
)
workflow = _visit(self.graph["start"], workflow)
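For context, a sketch of how a deploy path might use the new token helpers to catch name clashes. The caller below is illustrative and not part of this commit; token generation and the ownership policy are assumptions.

from metaflow.plugins.airflow.airflow import Airflow
from metaflow.plugins.airflow.exception import AirflowException

def resolve_production_token(dag_name, current_user, new_token, flow_datastore):
    # Return the production token to deploy with, enforcing single ownership.
    existing = Airflow.get_existing_deployment(dag_name, flow_datastore)
    if existing is None:
        # First deployment under this DAG name: persist the fresh token.
        Airflow.save_deployment_token(current_user, dag_name, new_token, flow_datastore)
        return new_token
    owner, token = existing
    if owner != current_user:
        raise AirflowException(
            "A DAG named '%s' was already deployed by '%s'; redeploying "
            "requires that deployment's production token." % (dag_name, owner)
        )
    return token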
