Added support for Parameters.
- Supporting int, str, bool, float, JSONType
valayDave committed Mar 20, 2022
1 parent c9378e9 commit 5b23eb7
Showing 2 changed files with 71 additions and 9 deletions.
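For orientation, the examples threaded through this diff assume a small, hypothetical flow like the one below (the flow and parameter names are made up for illustration, not taken from the commit). _process_parameters() in airflow_compiler.py turns each such Parameter into an Airflow Param spec: int maps to "integer", float to "number", str to "string", and bool defaults are stringified and exposed as "string".

from metaflow import FlowSpec, Parameter, step

class ParamDemoFlow(FlowSpec):
    # Two illustrative Parameters: a float with a help string and a bool.
    alpha = Parameter("alpha", type=float, default=0.01, help="learning rate")
    dry_run = Parameter("dry_run", type=bool, default=False)

    @step
    def start(self):
        print(self.alpha, self.dry_run)
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    ParamDemoFlow()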
54 changes: 46 additions & 8 deletions metaflow/plugins/airflow/airflow_compiler.py
@@ -43,6 +43,8 @@

class Airflow(object):

parameter_macro = "{{ params | json_dump }}"

task_id = AIRFLOW_TASK_ID_TEMPLATE_VALUE
task_id_arg = "--task-id %s" % task_id

@@ -96,13 +98,15 @@ def __init__(
self.catchup = catchup
self.schedule_interval = self._get_schedule()
self._file_path = file_path
self.metaflow_parameters = None

def _get_schedule(self):
schedule = self.flow._flow_decorators.get("schedule")
if schedule:
return schedule.schedule
# Airflow requires a scheduling argument so keeping this
return "*/2 * * * *"
# Schedule can be None.
# Especially if parameters are provided without defaults from toplevel.
return None

def _k8s_job(self, node, input_paths, env):
# since we are attaching k8s at the cli, there will be one for each step.
@@ -163,6 +167,16 @@ def _process_parameters(self):
# Copied from metaflow.plugins.aws.step_functions.step_functions
parameters = []
seen = set()
airflow_params = []
allowed_types = [int, str, bool, float]
type_transform_dict = {
int.__name__: "integer",
str.__name__: "string",
bool.__name__: "string",
float.__name__: "number",
}
type_parser = {bool.__name__: lambda v: str(v)}

for var, param in self.flow._get_parameters():
# Throw an exception if the parameter is specified twice.
norm = param.name.lower()
@@ -181,12 +195,35 @@
if "default" not in param.kwargs and is_required:
raise MetaflowException(
"The parameter *%s* does not have a "
"default and is required. Scheduling "
"such parameters via AWS Event Bridge "
"is not currently supported." % param.name
"default while having 'required' set to 'True'. "
"A default is required for such parameters when deploying on Airflow." % param.name
)
if "default" not in param.kwargs and self.schedule_interval:
raise MetaflowException(
"When @schedule is set with Airflow, Parameters require default values. "
"The parameter *%s* does not have a "
"'default' set." % param.name
)
value = deploy_time_eval(param.kwargs.get("default"))
parameters.append(dict(name=param.name, value=value))
# Setting airflow related param args.
param_type = param.kwargs.get("type", None)
airflowparam = dict(
name=param.name,
)
phelp = param.kwargs.get("help", None)
if value is not None:
airflowparam["default"] = value
if phelp:
airflowparam["description"] = phelp
if param_type in allowed_types:
airflowparam["type"] = type_transform_dict[param_type.__name__]
if param_type.__name__ in type_parser and value is not None:
airflowparam["default"] = type_parser[param_type.__name__](value)

airflow_params.append(airflowparam)
self.metaflow_parameters = airflow_params

return parameters

def _to_job(self, node: DAGNode):
@@ -226,7 +263,7 @@ def _to_job(self, node: DAGNode):
if node.name == "start":
parameters = self._process_parameters()
if parameters:
env["METAFLOW_PARAMETERS"] = "{{ params }}"
env["METAFLOW_PARAMETERS"] = self.parameter_macro
default_parameters = {}
for parameter in parameters:
if parameter["value"] is not None:
@@ -495,8 +532,9 @@ def _visit(node: DAGNode, workflow: Workflow, exit_node=None):
tags=self.tags,
file_path=self._file_path,
)
json_dag = _visit(self.graph["start"], workflow).to_dict()
return self._create_airflow_file(json_dag)
workflow = _visit(self.graph["start"], workflow)
workflow.set_parameters(self.metaflow_parameters)
return self._create_airflow_file(workflow.to_dict())

def _create_airflow_file(self, json_dag):
util_file = None
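At run time, the start task's METAFLOW_PARAMETERS environment variable is simply the Jinja template in parameter_macro above; Airflow renders it against the DAG run's params using the json_dump user-defined filter registered in airflow_utils.py below. A rough, self-contained sketch of that rendering (it assumes jinja2 is installed and uses values from the hypothetical flow above, not anything defined by this commit):

import json
from jinja2 import Environment

env = Environment()
env.filters["json_dump"] = json.dumps  # same filter the DAG registers via user_defined_filters

parameter_macro = "{{ params | json_dump }}"
rendered = env.from_string(parameter_macro).render(
    params={"alpha": 0.01, "dry_run": "False"}  # bool default arrives stringified
)
print(rendered)  # {"alpha": 0.01, "dry_run": "False"}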
26 changes: 25 additions & 1 deletion metaflow/plugins/airflow/airflow_utils.py
@@ -53,6 +53,10 @@ def task_id_creator(lst):
return hashlib.md5("/".join(lst).encode("utf-8")).hexdigest()


def json_dump(val):
return json.dumps(val)


class AirflowDAGArgs(object):
# _arg_types: This object helps map the types of
# different keys that need to be parsed. None of the values in this
@@ -95,6 +99,7 @@ class AirflowDAGArgs(object):
"user_defined_filters": dict(
hash=lambda my_value: hasher(my_value),
task_id_creator=lambda v: task_id_creator(v),
json_dump=lambda val: json_dump(val),
),
}

@@ -343,6 +348,10 @@ def __init__(self, file_path=None, **kwargs):
self._file_path = file_path
tree = lambda: defaultdict(tree)
self.states = tree()
self.metaflow_params = None

def set_parameters(self, params):
self.metaflow_params = params

def add_state(self, state):
self.states[state.name] = state
@@ -352,6 +361,7 @@ def to_dict(self):
states={s: v.to_dict() for s, v in self.states.items()},
dag_instantiation_params=self._dag_instantiation_params.to_dict(),
file_path=self._file_path,
metaflow_params=self.metaflow_params,
)

def to_json(self):
@@ -372,17 +382,31 @@ def from_dict(cls, data_dict):
sd, flow_name=re_cls._dag_instantiation_params.arguements["dag_id"]
)
)
re_cls.set_parameters(data_dict["metaflow_params"])
return re_cls

@classmethod
def from_json(cls, json_string):
data = json.loads(json_string)
return cls.from_dict(data)

def _construct_params(self):
from airflow.models.param import Param

if self.metaflow_params is None:
return {}
param_dict = {}
for p in self.metaflow_params:
name = p["name"]
del p["name"]
param_dict[name] = Param(**p)
return param_dict

def compile(self):
from airflow import DAG

dag = DAG(**self._dag_instantiation_params.arguements)
params_dict = self._construct_params()
dag = DAG(params=params_dict, **self._dag_instantiation_params.arguements)
curr_state = self.states["start"]
curr_task = self.states["start"].to_task()
prev_task = None
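Putting the airflow_utils.py pieces together, a minimal usage sketch: the serialized metaflow_params list is converted into airflow.models.param.Param objects (mirroring _construct_params above) and handed to DAG(params=...), so the defaults show up as typed, overridable parameters when a run is triggered. The dag_id, start_date, and parameter values below are placeholders, not part of the commit.

from datetime import datetime

from airflow import DAG
from airflow.models.param import Param

metaflow_params = [
    {"name": "alpha", "default": 0.01, "type": "number", "description": "learning rate"},
    {"name": "dry_run", "default": "False", "type": "string"},
]

param_dict = {}
for p in metaflow_params:
    spec = dict(p)      # copy so the serialized list is left untouched
    name = spec.pop("name")
    param_dict[name] = Param(**spec)

dag = DAG(
    dag_id="metaflow_parameter_demo",
    start_date=datetime(2022, 3, 20),
    schedule_interval=None,
    params=param_dict,  # each value can be overridden per DAG run from the trigger form or API
)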
