Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,12 @@ def execute(self, context: Context) -> int:
if "name" not in self.json:
raise AirflowException("Missing required parameter: name")
job_id = self._hook.find_job_id_by_name(self.json["name"])
if not self.json.get("parameters") and self.params:
job_params = self.params.items() if self.params.items() else {}
param_list = []
for k, v in job_params:
param_list.append({"name": k, "default": v})
self.json["parameters"] = param_list
Comment on lines +317 to +321
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
job_params = self.params.items() if self.params.items() else {}
param_list = []
for k, v in job_params:
param_list.append({"name": k, "default": v})
self.json["parameters"] = param_list
if self.params.items() is not None:
self.json["parameters"] = [{"name": k, "default": v} for k, v in self.params.items()]
else:
self.json["parameters"] = {}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this self.params from?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These come from Airflow's task-level `params` dict (available on every operator), not from Airflow configuration.

if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
Expand Down
55 changes: 55 additions & 0 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@
"permission_level": "CAN_MANAGE",
}
]
# Job-level parameter payload in the shape the Databricks Jobs API expects
# (a list of {"name": ..., "default": ...} entries); used as a test fixture.
JOB_PARAMS = [{"name": "param1", "default": "value1"}]


def mock_dict(d: dict):
Expand Down Expand Up @@ -441,6 +442,7 @@ def test_exec_create(self, db_mock_class):
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
"parameters": [],
}
)
db_mock_class.assert_called_once_with(
Expand Down Expand Up @@ -491,6 +493,7 @@ def test_exec_reset(self, db_mock_class):
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
"parameters": [],
}
)
db_mock_class.assert_called_once_with(
Expand Down Expand Up @@ -573,6 +576,58 @@ def test_exec_update_job_permission_with_empty_acl(self, db_mock_class):

db_mock.update_job_permission.assert_not_called()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_create_job_with_job_parameter(self, db_mock_class):
    """
    Test create job with job parameters.

    When the supplied json already carries a "parameters" list, the operator
    must forward it unchanged to ``create_job``.
    """
    # Full job specification, including the job-level "parameters" entry.
    job_spec = {
        "name": JOB_NAME,
        "tags": TAGS,
        "tasks": TASKS,
        "job_clusters": JOB_CLUSTERS,
        "email_notifications": EMAIL_NOTIFICATIONS,
        "webhook_notifications": WEBHOOK_NOTIFICATIONS,
        "timeout_seconds": TIMEOUT_SECONDS,
        "schedule": SCHEDULE,
        "max_concurrent_runs": MAX_CONCURRENT_RUNS,
        "git_source": GIT_SOURCE,
        "access_control_list": ACCESS_CONTROL_LIST,
        "parameters": JOB_PARAMS,
    }

    # Hand the operator its own copy so the expected payload below stays
    # independent of anything the operator does to its json.
    op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=dict(job_spec))
    hook_mock = db_mock_class.return_value
    # No job with this name exists yet, so execute() must take the create path.
    hook_mock.find_job_id_by_name.return_value = None

    op.execute({})

    expected = utils.normalise_json_content(dict(job_spec))

    # The hook is constructed once, with the operator's retry configuration.
    db_mock_class.assert_called_once_with(
        DEFAULT_CONN_ID,
        retry_limit=op.databricks_retry_limit,
        retry_delay=op.databricks_retry_delay,
        retry_args=None,
        caller="DatabricksCreateJobsOperator",
    )

    hook_mock.create_job.assert_called_once_with(expected)


class TestDatabricksSubmitRunOperator:
def test_init_with_notebook_task_named_parameters(self):
Expand Down