From b297b3674025d3836e3d39d6cabe9af7c5ac0737 Mon Sep 17 00:00:00 2001 From: subham611 Date: Sun, 14 Apr 2024 18:42:38 +0530 Subject: [PATCH 1/4] Pass airflow config as job parameters in databrickCreateJobOperator --- airflow/providers/databricks/operators/databricks.py | 9 +++++++++ tests/providers/databricks/operators/test_databricks.py | 2 ++ 2 files changed, 11 insertions(+) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 1d0d920ecca1b..799f22681a8ae 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -313,6 +313,15 @@ def execute(self, context: Context) -> int: if "name" not in self.json: raise AirflowException("Missing required parameter: name") job_id = self._hook.find_job_id_by_name(self.json["name"]) + + if self.json.get("parameters") is None and self.params is not None: + job_params = self.params.items() if self.params.items() is not None else {} + param_list = [] + for k, v in job_params: + param_list.append({"name": k, "default": v}) + self.log.info("Param list") + self.log.info(param_list) + self.json["parameters"] = param_list if job_id is None: return self._hook.create_job(self.json) self._hook.reset_job(str(job_id), self.json) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 46e14a917ab4e..76174343b43d3 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -441,6 +441,7 @@ def test_exec_create(self, db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "parameters": [] } ) db_mock_class.assert_called_once_with( @@ -491,6 +492,7 @@ def test_exec_reset(self, db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "parameters": [] } ) db_mock_class.assert_called_once_with( From 49cd58af238d7c7a73a14733e1bb172968935d5c Mon Sep 17 00:00:00 2001 From: subham611 Date: Sun, 14 Apr 2024 19:19:20 +0530 Subject: [PATCH 2/4] Adds UT --- .../databricks/operators/databricks.py | 3 -- .../databricks/operators/test_databricks.py | 53 +++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 799f22681a8ae..7461d294c7743 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -313,14 +313,11 @@ def execute(self, context: Context) -> int: if "name" not in self.json: raise AirflowException("Missing required parameter: name") job_id = self._hook.find_job_id_by_name(self.json["name"]) - if self.json.get("parameters") is None and self.params is not None: job_params = self.params.items() if self.params.items() is not None else {} param_list = [] for k, v in job_params: param_list.append({"name": k, "default": v}) - self.log.info("Param list") - self.log.info(param_list) self.json["parameters"] = param_list if job_id is None: return self._hook.create_job(self.json) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 76174343b43d3..431e74a851e9a 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -220,6 +220,7 @@ "permission_level": "CAN_MANAGE", } ] +JOB_PARAMS = [{"name": "param1", "default": "value1"}] def mock_dict(d: dict): @@ -575,6 +576,58 @@ def test_exec_update_job_permission_with_empty_acl(self, db_mock_class): db_mock.update_job_permission.assert_not_called() + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_create_job_with_job_parameter(self, db_mock_class): + """ + Test create job with job parameters. + """ + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + "parameters": JOB_PARAMS + } + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + db_mock = db_mock_class.return_value + db_mock.find_job_id_by_name.return_value = None + + op.execute({}) + + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + "parameters": JOB_PARAMS + } + ) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksCreateJobsOperator", + ) + + db_mock.create_job.assert_called_once_with(expected) + class TestDatabricksSubmitRunOperator: def test_init_with_notebook_task_named_parameters(self): From 23c4505ecc56a271fd212a5b37dedfca7ba44423 Mon Sep 17 00:00:00 2001 From: subham611 Date: Sun, 14 Apr 2024 19:43:21 +0530 Subject: [PATCH 3/4] Fix static check --- tests/providers/databricks/operators/test_databricks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 431e74a851e9a..a42e3a1dfb655 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -442,7 +442,7 @@ def test_exec_create(self, db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, - "parameters": [] + "parameters": [], } ) db_mock_class.assert_called_once_with( @@ -493,7 +493,7 @@ def test_exec_reset(self, db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, - "parameters": [] + "parameters": [], } ) db_mock_class.assert_called_once_with( @@ -593,7 +593,7 @@ def test_create_job_with_job_parameter(self, db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, - "parameters": JOB_PARAMS + "parameters": JOB_PARAMS, } op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) db_mock = db_mock_class.return_value @@ -614,7 +614,7 @@ def test_create_job_with_job_parameter(self, db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, - "parameters": JOB_PARAMS + "parameters": JOB_PARAMS, } ) From 01e3c76d8c2d5f8ea482ba4b0ee5af2456cffacf Mon Sep 17 00:00:00 2001 From: subham611 Date: Mon, 15 Apr 2024 11:17:53 +0530 Subject: [PATCH 4/4] code refactor --- airflow/providers/databricks/operators/databricks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 7461d294c7743..d579daf3d516b 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -313,8 +313,8 @@ def execute(self, context: Context) -> int: if "name" not in self.json: raise AirflowException("Missing required parameter: name") job_id = self._hook.find_job_id_by_name(self.json["name"]) - if self.json.get("parameters") is None and self.params is not None: - job_params = self.params.items() if self.params.items() is not None else {} + if not self.json.get("parameters") and self.params: + job_params = self.params.items() if self.params.items() else {} param_list = [] for k, v in job_params: param_list.append({"name": k, "default": v})