D401 Support - Providers: DaskExecutor to Github (Inclusive) (#34935)
ferruzzi committed Oct 16, 2023
1 parent f23170c commit 7a93b19
Showing 24 changed files with 110 additions and 108 deletions.
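D401 is the pydocstyle rule that requires the first line of a docstring to be written in the imperative mood ("Return X", not "Returns X"). Every hunk in this commit applies that single pattern to the Databricks provider. A minimal sketch of the change, reusing one of the methods edited below:

# Before: third-person summary line, flagged by D401.
def get_run_state(run_id: int) -> "RunState":
    """Retrieves run state of the run."""

# After: imperative mood, accepted by D401.
def get_run_state(run_id: int) -> "RunState":
    """Retrieve run state of the run."""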
46 changes: 23 additions & 23 deletions airflow/providers/databricks/hooks/databricks.py
@@ -196,7 +196,7 @@ def __init__(

def run_now(self, json: dict) -> int:
"""
Utility function to call the ``api/2.1/jobs/run-now`` endpoint.
Call the ``api/2.1/jobs/run-now`` endpoint.
:param json: The data used in the body of the request to the ``run-now`` endpoint.
:return: the run_id as an int
@@ -206,7 +206,7 @@ def run_now(self, json: dict) -> int:

def submit_run(self, json: dict) -> int:
"""
Utility function to call the ``api/2.1/jobs/runs/submit`` endpoint.
Call the ``api/2.1/jobs/runs/submit`` endpoint.
:param json: The data used in the body of the request to the ``submit`` endpoint.
:return: the run_id as an int
@@ -223,7 +223,7 @@ def list_jobs(
page_token: str | None = None,
) -> list[dict[str, Any]]:
"""
Lists the jobs in the Databricks Job Service.
List the jobs in the Databricks Job Service.
:param limit: The limit/batch size used to retrieve jobs.
:param offset: The offset of the first job to return, relative to the most recently created job.
@@ -274,7 +274,7 @@ def list_jobs(

def find_job_id_by_name(self, job_name: str) -> int | None:
"""
Finds job id by its name. If there are multiple jobs with the same name, raises AirflowException.
Find job id by its name; if there are multiple jobs with the same name, raise AirflowException.
:param job_name: The name of the job to look up.
:return: The job_id as an int or None if no job was found.
@@ -295,7 +295,7 @@ def list_pipelines(
self, batch_size: int = 25, pipeline_name: str | None = None, notebook_path: str | None = None
) -> list[dict[str, Any]]:
"""
Lists the pipelines in Databricks Delta Live Tables.
List the pipelines in Databricks Delta Live Tables.
:param batch_size: The limit/batch size used to retrieve pipelines.
:param pipeline_name: Optional name of a pipeline to search. Cannot be combined with path.
@@ -334,7 +334,7 @@ def list_pipelines(

def find_pipeline_id_by_name(self, pipeline_name: str) -> str | None:
"""
Finds pipeline id by its name. If multiple pipelines with the same name, raises AirflowException.
Find pipeline id by its name; if multiple pipelines with the same name, raise AirflowException.
:param pipeline_name: The name of the pipeline to look up.
:return: The pipeline_id as a GUID string or None if no pipeline was found.
@@ -354,7 +354,7 @@ def find_pipeline_id_by_name(self, pipeline_name: str) -> str | None:

def get_run_page_url(self, run_id: int) -> str:
"""
Retrieves run_page_url.
Retrieve run_page_url.
:param run_id: id of the run
:return: URL of the run page
@@ -376,7 +376,7 @@ async def a_get_run_page_url(self, run_id: int) -> str:

def get_job_id(self, run_id: int) -> int:
"""
Retrieves job_id from run_id.
Retrieve job_id from run_id.
:param run_id: id of the run
:return: Job id for given Databricks run
@@ -387,7 +387,7 @@ def get_job_id(self, run_id: int) -> int:

def get_run_state(self, run_id: int) -> RunState:
"""
Retrieves run state of the run.
Retrieve run state of the run.
Please note that any Airflow tasks that call the ``get_run_state`` method will result in
failure unless you have enabled xcom pickling. This can be done using the following
@@ -454,7 +454,7 @@ def get_run_state_str(self, run_id: int) -> str:

def get_run_state_lifecycle(self, run_id: int) -> str:
"""
Returns the lifecycle state of the run.
Return the lifecycle state of the run.
:param run_id: id of the run
:return: string with lifecycle state
@@ -463,7 +463,7 @@ def get_run_state_lifecycle(self, run_id: int) -> str:

def get_run_state_result(self, run_id: int) -> str:
"""
Returns the resulting state of the run.
Return the resulting state of the run.
:param run_id: id of the run
:return: string with resulting state
@@ -472,7 +472,7 @@ def get_run_state_result(self, run_id: int) -> str:

def get_run_state_message(self, run_id: int) -> str:
"""
Returns the state message for the run.
Return the state message for the run.
:param run_id: id of the run
:return: string with state message
@@ -481,7 +481,7 @@ def get_run_state_message(self, run_id: int) -> str:

def get_run_output(self, run_id: int) -> dict:
"""
Retrieves run output of the run.
Retrieve run output of the run.
:param run_id: id of the run
:return: output of the run
@@ -492,7 +492,7 @@ def get_run_output(self, run_id: int) -> dict:

def cancel_run(self, run_id: int) -> None:
"""
Cancels the run.
Cancel the run.
:param run_id: id of the run
"""
@@ -501,7 +501,7 @@ def cancel_run(self, run_id: int) -> None:

def cancel_all_runs(self, job_id: int) -> None:
"""
Cancels all active runs of a job. The runs are canceled asynchronously.
Cancel all active runs of a job asynchronously.
:param job_id: The canonical identifier of the job to cancel all runs of
"""
@@ -510,7 +510,7 @@ def cancel_all_runs(self, job_id: int) -> None:

def delete_run(self, run_id: int) -> None:
"""
Deletes a non-active run.
Delete a non-active run.
:param run_id: id of the run
"""
@@ -527,7 +527,7 @@ def repair_run(self, json: dict) -> None:

def get_cluster_state(self, cluster_id: str) -> ClusterState:
"""
Retrieves run state of the cluster.
Retrieve run state of the cluster.
:param cluster_id: id of the cluster
:return: state of the cluster
@@ -561,15 +561,15 @@ def restart_cluster(self, json: dict) -> None:

def start_cluster(self, json: dict) -> None:
"""
Starts the cluster.
Start the cluster.
:param json: json dictionary containing cluster specification.
"""
self._do_api_call(START_CLUSTER_ENDPOINT, json)

def terminate_cluster(self, json: dict) -> None:
"""
Terminates the cluster.
Terminate the cluster.
:param json: json dictionary containing cluster specification.
"""
@@ -597,7 +597,7 @@ def uninstall(self, json: dict) -> None:

def update_repo(self, repo_id: str, json: dict[str, Any]) -> dict:
"""
Updates given Databricks Repos.
Update given Databricks Repos.
:param repo_id: ID of Databricks Repos
:param json: payload
@@ -608,7 +608,7 @@ def update_repo(self, repo_id: str, json: dict[str, Any]) -> dict:

def delete_repo(self, repo_id: str):
"""
Deletes given Databricks Repos.
Delete given Databricks Repos.
:param repo_id: ID of Databricks Repos
:return:
@@ -618,7 +618,7 @@ def delete_repo(self, repo_id: str):

def create_repo(self, json: dict[str, Any]) -> dict:
"""
Creates a Databricks Repos.
Create a Databricks Repos.
:param json: payload
:return:
@@ -628,7 +628,7 @@ def create_repo(self, json: dict[str, Any]) -> dict:

def get_repo_by_path(self, path: str) -> str | None:
"""
Obtains Repos ID by path.
Obtain Repos ID by path.
:param path: path to a repository
:return: Repos ID if it exists, None if doesn't.
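The docstrings above cover the full job-run lifecycle exposed by DatabricksHook: trigger or submit a run, inspect its state, and cancel it. A minimal usage sketch, assuming a configured Databricks connection named "databricks_default"; the constructor keyword and the job_id payload are assumptions, while the method calls use only signatures visible in this file's hunks:

from airflow.providers.databricks.hooks.databricks import DatabricksHook

hook = DatabricksHook(databricks_conn_id="databricks_default")

# ``run_now`` calls the api/2.1/jobs/run-now endpoint and returns the run id.
run_id = hook.run_now(json={"job_id": 1234})

# Inspect the run via the accessors documented above.
print(hook.get_run_page_url(run_id))
print(hook.get_run_state_lifecycle(run_id))  # e.g. PENDING / RUNNING / TERMINATED
print(hook.get_run_state_message(run_id))

# Cancel the run if it is no longer needed.
hook.cancel_run(run_id)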
10 changes: 5 additions & 5 deletions airflow/providers/databricks/hooks/databricks_base.py
@@ -175,7 +175,7 @@ async def __aexit__(self, *err):
@staticmethod
def _parse_host(host: str) -> str:
"""
This function is resistant to incorrect connection settings provided by users, in the host field.
Parse host field data; this function is resistant to incorrect connection settings provided by users.
For example -- when users supply ``https://xx.cloud.databricks.com`` as the
host, we must strip out the protocol to get the host.::
@@ -215,7 +215,7 @@ def _a_get_retry_object(self) -> AsyncRetrying:
return AsyncRetrying(**self.retry_args)

def _get_sp_token(self, resource: str) -> str:
"""Function to get Service Principal token."""
"""Get Service Principal token."""
sp_token = self.oauth_tokens.get(resource)
if sp_token and self._is_oauth_token_valid(sp_token):
return sp_token["access_token"]
@@ -287,7 +287,7 @@ async def _a_get_sp_token(self, resource: str) -> str:

def _get_aad_token(self, resource: str) -> str:
"""
Function to get AAD token for given resource.
Get AAD token for given resource.
Supports managed identity or service principal auth.
:param resource: resource to issue token to
@@ -441,7 +441,7 @@ async def _a_get_aad_headers(self) -> dict:
@staticmethod
def _is_oauth_token_valid(token: dict, time_key="expires_on") -> bool:
"""
Utility function to check if an OAuth token is valid and hasn't expired yet.
Check if an OAuth token is valid and hasn't expired yet.
:param sp_token: dict with properties of OAuth token
:param time_key: name of the key that holds the time of expiration
@@ -556,7 +556,7 @@ def _do_api_call(
wrap_http_errors: bool = True,
):
"""
Utility function to perform an API call with retries.
Perform an API call with retries.
:param endpoint_info: Tuple of method and endpoint
:param json: Parameters for this API call.
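The _parse_host docstring above spells out the intended behaviour: when a user pastes a full URL such as https://xx.cloud.databricks.com into the connection's host field, only the hostname should be kept. An illustrative, standalone re-implementation of that behaviour (not the provider's actual code):

from urllib.parse import urlparse

def parse_host(host: str) -> str:
    # A full URL parses with a hostname; a bare host parses as a path only,
    # in which case the value is already usable as-is.
    parsed = urlparse(host)
    return parsed.hostname or host

assert parse_host("https://xx.cloud.databricks.com") == "xx.cloud.databricks.com"
assert parse_host("xx.cloud.databricks.com") == "xx.cloud.databricks.com"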
5 changes: 3 additions & 2 deletions airflow/providers/databricks/hooks/databricks_sql.py
@@ -101,7 +101,7 @@ def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
return endpoint

def get_conn(self) -> Connection:
"""Returns a Databricks SQL connection object."""
"""Return a Databricks SQL connection object."""
if not self._http_path:
if self._sql_endpoint_name:
endpoint = self._get_sql_endpoint_by_name(self._sql_endpoint_name)
@@ -178,7 +178,8 @@ def run(
split_statements: bool = True,
return_last: bool = True,
) -> T | list[T] | None:
"""Runs a command or a list of commands.
"""
Run a command or a list of commands.
Pass a list of SQL statements to the SQL parameter to get them to
execute sequentially.
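The run docstring above notes that a list of statements passed through the sql parameter is executed sequentially, with return_last controlling whether only the final statement's result is kept. A usage sketch; the constructor keyword names and the warehouse http_path are assumptions, not taken from this diff:

from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

hook = DatabricksSqlHook(
    databricks_conn_id="databricks_default",
    http_path="/sql/1.0/warehouses/abc123",  # placeholder warehouse path
)

# Per the docstring, the statements run one after another; return_last=True
# keeps only the outcome of the final statement (fetching rows would need a
# result handler, omitted here).
hook.run(
    sql=[
        "CREATE TABLE IF NOT EXISTS demo (id INT, name STRING)",
        "INSERT INTO demo VALUES (1, 'airflow')",
        "SELECT * FROM demo",
    ],
    split_statements=True,
    return_last=True,
)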
8 changes: 4 additions & 4 deletions airflow/providers/databricks/operators/databricks.py
@@ -44,7 +44,7 @@

def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
"""
Handles the Airflow + Databricks lifecycle logic for a Databricks operator.
Handle the Airflow + Databricks lifecycle logic for a Databricks operator.
:param operator: Databricks operator being handled
:param context: Airflow context
@@ -102,7 +102,7 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:

def _handle_deferrable_databricks_operator_execution(operator, hook, log, context) -> None:
"""
Handles the Airflow + Databricks lifecycle logic for deferrable Databricks operators.
Handle the Airflow + Databricks lifecycle logic for deferrable Databricks operators.
:param operator: Databricks async operator being handled
:param context: Airflow context
@@ -320,7 +320,7 @@ def __init__(
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
"""Creates a new ``DatabricksSubmitRunOperator``."""
"""Create a new ``DatabricksSubmitRunOperator``."""
super().__init__(**kwargs)
self.json = json or {}
self.databricks_conn_id = databricks_conn_id
@@ -621,7 +621,7 @@ def __init__(
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
"""Creates a new ``DatabricksRunNowOperator``."""
"""Create a new ``DatabricksRunNowOperator``."""
super().__init__(**kwargs)
self.json = json or {}
self.databricks_conn_id = databricks_conn_id
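The two constructors touched above wrap the hook calls from databricks.py: DatabricksSubmitRunOperator drives runs/submit and DatabricksRunNowOperator drives run-now, and the deferrable flag shown in both hands the polling off to the triggerer instead of blocking a worker slot. A minimal task definition, assuming an existing Databricks job; the job_id and notebook_params values are placeholders:

import pendulum

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

with DAG(
    dag_id="databricks_run_now_example",
    start_date=pendulum.datetime(2023, 10, 1, tz="UTC"),
    schedule=None,
    catchup=False,
):
    run_databricks_job = DatabricksRunNowOperator(
        task_id="run_databricks_job",
        databricks_conn_id="databricks_default",
        json={"job_id": 1234, "notebook_params": {"run_date": "{{ ds }}"}},
    )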
8 changes: 4 additions & 4 deletions airflow/providers/databricks/operators/databricks_repos.py
@@ -80,7 +80,7 @@ def __init__(
databricks_retry_delay: int = 1,
**kwargs,
) -> None:
"""Creates a new ``DatabricksReposCreateOperator``."""
"""Create a new ``DatabricksReposCreateOperator``."""
super().__init__(**kwargs)
self.databricks_conn_id = databricks_conn_id
self.databricks_retry_limit = databricks_retry_limit
@@ -125,7 +125,7 @@ def _hook(self) -> DatabricksHook:

def execute(self, context: Context):
"""
Creates a Databricks Repo.
Create a Databricks Repo.
:param context: context
:return: Repo ID
@@ -194,7 +194,7 @@ def __init__(
databricks_retry_delay: int = 1,
**kwargs,
) -> None:
"""Creates a new ``DatabricksReposUpdateOperator``."""
"""Create a new ``DatabricksReposUpdateOperator``."""
super().__init__(**kwargs)
self.databricks_conn_id = databricks_conn_id
self.databricks_retry_limit = databricks_retry_limit
@@ -266,7 +266,7 @@ def __init__(
databricks_retry_delay: int = 1,
**kwargs,
) -> None:
"""Creates a new ``DatabricksReposDeleteOperator``."""
"""Create a new ``DatabricksReposDeleteOperator``."""
super().__init__(**kwargs)
self.databricks_conn_id = databricks_conn_id
self.databricks_retry_limit = databricks_retry_limit
2 changes: 1 addition & 1 deletion airflow/providers/databricks/operators/databricks_sql.py
@@ -244,7 +244,7 @@ def __init__(
validate: bool | int | None = None,
**kwargs,
) -> None:
"""Creates a new ``DatabricksSqlOperator``."""
"""Create a new ``DatabricksSqlOperator``."""
super().__init__(**kwargs)
if files is not None and pattern is not None:
raise AirflowException("Only one of 'pattern' or 'files' should be specified")
8 changes: 4 additions & 4 deletions airflow/providers/databricks/sensors/databricks_partition.py
@@ -110,7 +110,7 @@ def __init__(
super().__init__(**kwargs)

def _sql_sensor(self, sql):
"""Executes the supplied SQL statement using the hook object."""
"""Execute the supplied SQL statement using the hook object."""
hook = self._get_hook
sql_result = hook.run(
sql,
@@ -121,7 +121,7 @@ def _sql_sensor(self, sql):

@cached_property
def _get_hook(self) -> DatabricksSqlHook:
"""Creates and returns a DatabricksSqlHook object."""
"""Create and return a DatabricksSqlHook object."""
return DatabricksSqlHook(
self.databricks_conn_id,
self._http_path,
@@ -166,7 +166,7 @@ def _generate_partition_query(
escape_key: bool = False,
) -> str:
"""
Queries the table for available partitions.
Query the table for available partitions.
Generates the SQL query based on the partition data types.
* For a list, it prepares the SQL in the format:
@@ -225,7 +225,7 @@ def _generate_partition_query(
return formatted_opts.strip()

def poke(self, context: Context) -> bool:
"""Checks the table partitions and returns the results."""
"""Check the table partitions and return the results."""
partition_result = self._check_table_partitions()
self.log.debug("Partition sensor result: %s", partition_result)
if partition_result:
