diff --git a/metaflow/plugins/airflow/airflow_compiler.py b/metaflow/plugins/airflow/airflow_compiler.py
index 3fa9d7c219..71a1eae055 100644
--- a/metaflow/plugins/airflow/airflow_compiler.py
+++ b/metaflow/plugins/airflow/airflow_compiler.py
@@ -43,6 +43,8 @@ class Airflow(object):
 
+    parameter_macro = "{{ params | json_dump }}"
+
     task_id = AIRFLOW_TASK_ID_TEMPLATE_VALUE
     task_id_arg = "--task-id %s" % task_id
@@ -96,13 +98,15 @@ def __init__(
         self.catchup = catchup
         self.schedule_interval = self._get_schedule()
         self._file_path = file_path
+        self.metaflow_parameters = None
 
     def _get_schedule(self):
         schedule = self.flow._flow_decorators.get("schedule")
         if schedule:
             return schedule.schedule
-        # Airflow requires a scheduling arguement so keeping this
-        return "*/2 * * * *"
+        # The schedule may be None; this is needed when parameters
+        # without defaults are set at the top level.
+        return None
 
     def _k8s_job(self, node, input_paths, env):
         # since we are attaching k8s at cli, there will be one for a step.
@@ -163,6 +167,16 @@ def _process_parameters(self):
         # Copied from metaflow.plugins.aws.step_functions.step_functions
         parameters = []
         seen = set()
+        airflow_params = []
+        allowed_types = [int, str, bool, float]
+        type_transform_dict = {
+            int.__name__: "integer",
+            str.__name__: "string",
+            bool.__name__: "string",
+            float.__name__: "number",
+        }
+        type_parser = {bool.__name__: lambda v: str(v)}
+
         for var, param in self.flow._get_parameters():
             # Throw an exception if the parameter is specified twice.
             norm = param.name.lower()
@@ -181,12 +195,35 @@ def _process_parameters(self):
             if "default" not in param.kwargs and is_required:
                 raise MetaflowException(
                     "The parameter *%s* does not have a "
-                    "default and is required. Scheduling "
-                    "such parameters via AWS Event Bridge "
-                    "is not currently supported." % param.name
+                    "default while having 'required' set to 'True'. A default "
+                    "is required for such parameters when deploying on Airflow." % param.name
                 )
+            if "default" not in param.kwargs and self.schedule_interval:
+                raise MetaflowException(
+                    "When @schedule is set with Airflow, parameters require default values. "
+                    "The parameter *%s* does not have a "
+                    "'default' set." % param.name
+                )
             value = deploy_time_eval(param.kwargs.get("default"))
             parameters.append(dict(name=param.name, value=value))
+            # Set the Airflow-related Param arguments.
+            param_type = param.kwargs.get("type", None)
+            airflow_param = dict(
+                name=param.name,
+            )
+            param_help = param.kwargs.get("help", None)
+            if value is not None:
+                airflow_param["default"] = value
+            if param_help:
+                airflow_param["description"] = param_help
+            if param_type in allowed_types:
+                airflow_param["type"] = type_transform_dict[param_type.__name__]
+                if param_type.__name__ in type_parser and value is not None:
+                    airflow_param["default"] = type_parser[param_type.__name__](value)
+
+            airflow_params.append(airflow_param)
+        self.metaflow_parameters = airflow_params
+
         return parameters
 
     def _to_job(self, node: DAGNode):
@@ -226,7 +263,7 @@ def _to_job(self, node: DAGNode):
         if node.name == "start":
             parameters = self._process_parameters()
             if parameters:
-                env["METAFLOW_PARAMETERS"] = "{{ params }}"
+                env["METAFLOW_PARAMETERS"] = self.parameter_macro
                 default_parameters = {}
                 for parameter in parameters:
                     if parameter["value"] is not None:
@@ -495,8 +532,9 @@ def _visit(node: DAGNode, workflow: Workflow, exit_node=None):
             tags=self.tags,
             file_path=self._file_path,
         )
-        json_dag = _visit(self.graph["start"], workflow).to_dict()
-        return self._create_airflow_file(json_dag)
+        workflow = _visit(self.graph["start"], workflow)
+        workflow.set_parameters(self.metaflow_parameters)
+        return self._create_airflow_file(workflow.to_dict())
 
     def _create_airflow_file(self, json_dag):
         util_file = None
diff --git a/metaflow/plugins/airflow/airflow_utils.py b/metaflow/plugins/airflow/airflow_utils.py
index cc34acdff9..ad49bf6784 100644
--- a/metaflow/plugins/airflow/airflow_utils.py
+++ b/metaflow/plugins/airflow/airflow_utils.py
@@ -53,6 +53,10 @@ def task_id_creator(lst):
     return hashlib.md5("/".join(lst).encode("utf-8")).hexdigest()
 
 
+def json_dump(val):
+    return json.dumps(val)
+
+
 class AirflowDAGArgs(object):
     # _arg_types This object helps map types of
     # different keys that need to be parsed. None of the values in this
@@ -95,6 +99,7 @@ class AirflowDAGArgs(object):
         "user_defined_filters": dict(
             hash=lambda my_value: hasher(my_value),
             task_id_creator=lambda v: task_id_creator(v),
+            json_dump=lambda val: json_dump(val),
         ),
     }
 
@@ -343,6 +348,10 @@ def __init__(self, file_path=None, **kwargs):
         self._file_path = file_path
         tree = lambda: defaultdict(tree)
         self.states = tree()
+        self.metaflow_params = None
+
+    def set_parameters(self, params):
+        self.metaflow_params = params
 
     def add_state(self, state):
         self.states[state.name] = state
@@ -352,6 +361,7 @@ def to_dict(self):
             states={s: v.to_dict() for s, v in self.states.items()},
             dag_instantiation_params=self._dag_instantiation_params.to_dict(),
             file_path=self._file_path,
+            metaflow_params=self.metaflow_params,
         )
 
     def to_json(self):
@@ -372,6 +382,7 @@ def from_dict(cls, data_dict):
                 sd, flow_name=re_cls._dag_instantiation_params.arguements["dag_id"]
             )
         )
+        re_cls.set_parameters(data_dict.get("metaflow_params"))
         return re_cls
 
     @classmethod
@@ -379,10 +390,23 @@ def from_json(cls, json_string):
         data = json.loads(json_string)
         return cls.from_dict(data)
 
+    def _construct_params(self):
+        from airflow.models.param import Param
+
+        if self.metaflow_params is None:
+            return {}
+        param_dict = {}
+        for p in self.metaflow_params:
+            p = dict(p)  # copy so repeated compiles don't mutate the stored params
+            param_dict[p.pop("name")] = Param(**p)
+        return param_dict
+
     def compile(self):
         from airflow import DAG
 
-        dag = DAG(**self._dag_instantiation_params.arguements)
+        params_dict = self._construct_params()
+        dag = DAG(params=params_dict, **self._dag_instantiation_params.arguements)
         curr_state = self.states["start"]
         curr_task = self.states["start"].to_task()
         prev_task = None
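
For reviewers, a minimal sketch (not part of the patch) of what the two halves do together at runtime, assuming Airflow 2.2+ where `airflow.models.param.Param` is available. The flow name and parameters below are invented for illustration; only `Param`, `DAG(params=...)`, and the `json_dump` filter wiring mirror the changes above.

    from datetime import datetime

    from airflow import DAG
    from airflow.models.param import Param

    # What _process_parameters() would serialize for
    #   Parameter("alpha", type=float, default=0.5, help="learning rate")
    #   Parameter("dry_run", type=bool, default=False)
    metaflow_params = [
        {"name": "alpha", "default": 0.5, "type": "number", "description": "learning rate"},
        # bool defaults go through type_parser, so they arrive stringified:
        {"name": "dry_run", "default": "False", "type": "string"},
    ]

    # What Workflow._construct_params() builds from that list:
    params = {
        p["name"]: Param(**{k: v for k, v in p.items() if k != "name"})
        for p in metaflow_params
    }

    dag = DAG(
        dag_id="HelloFlow",
        start_date=datetime(2022, 1, 1),
        schedule_interval=None,  # now a valid schedule; see _get_schedule above
        params=params,
    )

    # In a task's environment, METAFLOW_PARAMETERS is the rendered
    # parameter_macro "{{ params | json_dump }}", i.e. json.dumps() of the
    # resolved params: '{"alpha": 0.5, "dry_run": "False"}'

This also illustrates why the new validation exists: once `schedule_interval` is non-None, Airflow triggers the DAG unattended, so every Metaflow parameter must carry a default for `Param` to fall back on.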