DatabricksSubmitRunOperator to support taskflow (#29840)
* move validation to execute + change type hints

* revert type changes

* update tests

* move json normalisation in SubmitRunDeferrable too

---------

Co-authored-by: Hernan Resnizky <hernanresnizky@NEX-0003.local>
Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
3 people committed Mar 7, 2023
1 parent c95184e commit c405ecb
Showing 2 changed files with 18 additions and 16 deletions.
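
Why this matters for taskflow: normalise_json_content accepts only numbers, strings, booleans, and containers of those, so running it in __init__ made the operator fail at DAG-parse time whenever an unresolved XComArg from a @task function appeared in json. Deferring normalisation to execute() lets templating resolve those references first. A minimal sketch of the usage this unlocks — the DAG, task, and parameter names below are illustrative, not part of the commit:

from datetime import datetime

from airflow.decorators import dag, task
from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator


@dag(schedule=None, start_date=datetime(2023, 3, 7))
def databricks_taskflow_example():
    @task
    def notebook_params() -> dict:
        # Computed at run time and passed downstream via XCom.
        return {"dt": "2023-03-07"}

    DatabricksSubmitRunOperator(
        task_id="submit_run",
        json={
            "new_cluster": {"spark_version": "12.2.x-scala2.12", "num_workers": 1},
            "notebook_task": {
                "notebook_path": "/Shared/example",
                # An XComArg: previously rejected in __init__, now resolved
                # by templating before normalise_json_content runs in execute().
                "base_parameters": notebook_params(),
            },
        },
    )


databricks_taskflow_example()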
7 changes: 4 additions & 3 deletions airflow/providers/databricks/operators/databricks.py
@@ -353,7 +353,6 @@ def __init__(
         if "dbt_task" in self.json and "git_source" not in self.json:
             raise AirflowException("git_source is required for dbt_task")
 
-        self.json = normalise_json_content(self.json)
         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
         self.do_xcom_push = do_xcom_push
@@ -372,7 +371,8 @@ def _get_hook(self, caller: str) -> DatabricksHook:
         )
 
     def execute(self, context: Context):
-        self.run_id = self._hook.submit_run(self.json)
+        json_normalised = normalise_json_content(self.json)
+        self.run_id = self._hook.submit_run(json_normalised)
         _handle_databricks_operator_execution(self, self._hook, self.log, context)
 
     def on_kill(self):
@@ -390,7 +390,8 @@ class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
 
     def execute(self, context):
         hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
-        self.run_id = hook.submit_run(self.json)
+        json_normalised = normalise_json_content(self.json)
+        self.run_id = hook.submit_run(json_normalised)
         _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
 
     def execute_complete(self, context: dict | None, event: dict):
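For reference, normalise_json_content (imported in the tests below as utils.normalise_json_content from airflow.providers.databricks.utils.databricks) stringifies numbers, passes strings and booleans through, recurses into lists and dicts, and raises AirflowException for anything else — the behaviour the updated assertions rely on. A quick sketch of calling it directly; the exact output shown is inferred from the error message asserted in test_init_with_bad_type:

from datetime import datetime

from airflow.exceptions import AirflowException
from airflow.providers.databricks.utils.databricks import normalise_json_content

# Numbers are coerced to strings; nested containers are walked recursively.
print(normalise_json_content({"num_workers": 8}))  # {'num_workers': '8'}

# Unsupported types raise, naming the offending path within the payload.
try:
    normalise_json_content({"test": datetime.now()})
except AirflowException as err:
    print(err)  # Type <class 'datetime.datetime'> used for parameter json[test] ...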
27 changes: 14 additions & 13 deletions tests/providers/databricks/operators/test_databricks.py
@@ -105,7 +105,7 @@ def test_init_with_notebook_task_named_parameters(self):
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
 
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_spark_python_task_named_parameters(self):
         """
@@ -118,7 +118,7 @@ def test_init_with_spark_python_task_named_parameters(self):
             {"new_cluster": NEW_CLUSTER, "spark_python_task": SPARK_PYTHON_TASK, "run_name": TASK_ID}
         )
 
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_spark_submit_task_named_parameters(self):
         """
@@ -131,7 +131,7 @@ def test_init_with_spark_submit_task_named_parameters(self):
             {"new_cluster": NEW_CLUSTER, "spark_submit_task": SPARK_SUBMIT_TASK, "run_name": TASK_ID}
         )
 
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_dbt_task_named_parameters(self):
         """
@@ -149,7 +149,7 @@ def test_init_with_dbt_task_named_parameters(self):
             {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID}
         )
 
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_dbt_task_mixed_parameters(self):
         """
@@ -168,7 +168,7 @@ def test_init_with_dbt_task_mixed_parameters(self):
             {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID}
         )
 
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_dbt_task_without_git_source_raises_error(self):
         """
@@ -197,13 +197,13 @@ def test_init_with_json(self):
         expected = utils.normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
         )
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_tasks(self):
         tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK}]
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks)
         expected = utils.normalise_json_content({"run_name": TASK_ID, "tasks": tasks})
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_specified_run_name(self):
         """
@@ -214,7 +214,7 @@ def test_init_with_specified_run_name(self):
         expected = utils.normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME}
         )
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_pipeline_task(self):
         """
@@ -226,7 +226,7 @@ def test_pipeline_task(self):
         expected = utils.normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "pipeline_task": pipeline_task, "run_name": RUN_NAME}
         )
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_merging(self):
         """
@@ -247,7 +247,7 @@ def test_init_with_merging(self):
                 "run_name": TASK_ID,
             }
         )
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_templating(self):
         json = {
@@ -264,7 +264,7 @@ def test_init_with_templating(self):
                 "run_name": TASK_ID,
             }
         )
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_git_source(self):
         json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME}
@@ -282,17 +282,18 @@ def test_init_with_git_source(self):
                 "git_source": git_source,
             }
         )
-        assert expected == op.json
+        assert expected == utils.normalise_json_content(op.json)
 
     def test_init_with_bad_type(self):
         json = {"test": datetime.now()}
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
         # Looks a bit weird since we have to escape regex reserved symbols.
         exception_message = (
             r"Type \<(type|class) \'datetime.datetime\'\> used "
             r"for parameter json\[test\] is not a number or a string"
         )
         with pytest.raises(AirflowException, match=exception_message):
-            DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+            utils.normalise_json_content(op.json)
 
     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_exec_success(self, db_mock_class):
