Make Dataprep system test self-sufficient (#34880)
moiseenkov committed Oct 26, 2023
1 parent fe360cb commit acff4c7
Showing 5 changed files with 652 additions and 57 deletions.
102 changes: 94 additions & 8 deletions airflow/providers/google/cloud/hooks/dataprep.py
@@ -72,9 +72,10 @@ class GoogleDataprepHook(BaseHook):
conn_type = "dataprep"
hook_name = "Google Dataprep"

-def __init__(self, dataprep_conn_id: str = default_conn_name) -> None:
+def __init__(self, dataprep_conn_id: str = default_conn_name, api_version: str = "v4") -> None:
super().__init__()
self.dataprep_conn_id = dataprep_conn_id
+self.api_version = api_version
conn = self.get_connection(self.dataprep_conn_id)
extras = conn.extra_dejson
self._token = _get_field(extras, "token")
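
The new api_version argument replaces the "v4" literal that was previously hard-coded into every endpoint path. A minimal usage sketch, assuming a configured "dataprep_default" connection whose extras carry the API token:

    from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook

    # api_version defaults to "v4"; pass it explicitly to pin a different version.
    hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default", api_version="v4")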
@@ -95,7 +96,7 @@ def get_jobs_for_job_group(self, job_id: int) -> dict[str, Any]:
:param job_id: The ID of the job group whose jobs will be fetched
"""
endpoint_path = f"v4/jobGroups/{job_id}/jobs"
endpoint_path = f"{self.api_version}/jobGroups/{job_id}/jobs"
url: str = urljoin(self._base_url, endpoint_path)
response = requests.get(url, headers=self._headers)
self._raise_for_status(response)
@@ -113,7 +114,7 @@ def get_job_group(self, job_group_id: int, embed: str, include_deleted: bool) ->
:param include_deleted: if set to True, deleted objects will be included
"""
params: dict[str, Any] = {"embed": embed, "includeDeleted": include_deleted}
endpoint_path = f"v4/jobGroups/{job_group_id}"
endpoint_path = f"{self.api_version}/jobGroups/{job_group_id}"
url: str = urljoin(self._base_url, endpoint_path)
response = requests.get(url, headers=self._headers, params=params)
self._raise_for_status(response)
@@ -131,12 +132,26 @@ def run_job_group(self, body_request: dict) -> dict[str, Any]:
:param body_request: Body of the POST request, containing the identifier of the recipe you would like to run.
"""
endpoint_path = "v4/jobGroups"
endpoint_path = f"{self.api_version}/jobGroups"
url: str = urljoin(self._base_url, endpoint_path)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
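
Per the linked API reference, the request body names the wrangled dataset (recipe) to execute. A hedged usage sketch with a placeholder recipe ID:

    # `hook` as constructed in the sketch above; 12345 stands in for a real recipe ID.
    job_group = hook.run_job_group(body_request={"wrangledDataset": {"id": 12345}})
    job_group_id = job_group["id"]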

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def create_flow(self, *, body_request: dict) -> dict:
"""
Creates a flow.

:param body_request: Body of the POST request to be sent.
    For more details, see https://clouddataprep.com/documentation/api#operation/createFlow
"""
endpoint = f"/{self.api_version}/flows"
url: str = urljoin(self._base_url, endpoint)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
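
A hedged usage sketch; the "name" and "description" fields follow the createFlow request schema linked above:

    flow = hook.create_flow(body_request={"name": "test_flow", "description": "system test flow"})
    flow_id = flow["id"]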

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def copy_flow(
self, *, flow_id: int, name: str = "", description: str = "", copy_datasources: bool = False
@@ -149,7 +164,7 @@ def copy_flow(
:param description: Description of the copy of the flow
:param copy_datasources: Whether copies of the flow's data inputs should be made.
"""
endpoint_path = f"v4/flows/{flow_id}/copy"
endpoint_path = f"{self.api_version}/flows/{flow_id}/copy"
url: str = urljoin(self._base_url, endpoint_path)
body_request = {
"name": name,
@@ -167,7 +182,7 @@ def delete_flow(self, *, flow_id: int) -> None:
:param flow_id: ID of the flow to be deleted
"""
endpoint_path = f"v4/flows/{flow_id}"
endpoint_path = f"{self.api_version}/flows/{flow_id}"
url: str = urljoin(self._base_url, endpoint_path)
response = requests.delete(url, headers=self._headers)
self._raise_for_status(response)
@@ -180,7 +195,7 @@ def run_flow(self, *, flow_id: int, body_request: dict) -> dict:
:param flow_id: ID of the flow to be run
:param body_request: Body of the POST request to be sent.
"""
endpoint = f"v4/flows/{flow_id}/run"
endpoint = f"{self.api_version}/flows/{flow_id}/run"
url: str = urljoin(self._base_url, endpoint)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
@@ -193,7 +208,7 @@ def get_job_group_status(self, *, job_group_id: int) -> JobGroupStatuses:
:param job_group_id: ID of the job group to check
"""
endpoint = f"/v4/jobGroups/{job_group_id}/status"
endpoint = f"/{self.api_version}/jobGroups/{job_group_id}/status"
url: str = urljoin(self._base_url, endpoint)
response = requests.get(url, headers=self._headers)
self._raise_for_status(response)
@@ -205,3 +220,74 @@ def _raise_for_status(self, response: requests.models.Response) -> None:
except HTTPError:
self.log.error(response.json().get("exception"))
raise
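
Each of the new methods below reuses the tenacity policy already applied to the existing ones. An illustrative standalone sketch of what that decorator does:

    from tenacity import retry, stop_after_attempt, wait_exponential

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
    def flaky_call() -> None:
        # Up to 5 attempts, waiting 1s, 2s, 4s, then 8s between them (capped at 10s).
        ...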

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def create_imported_dataset(self, *, body_request: dict) -> dict:
"""
Creates an imported dataset.

:param body_request: Body of the POST request to be sent.
    For more details, see https://clouddataprep.com/documentation/api#operation/createImportedDataset
"""
endpoint = f"/{self.api_version}/importedDatasets"
url: str = urljoin(self._base_url, endpoint)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
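
A hedged sketch of importing a file from GCS; the "uri" and "name" fields are taken from the createImportedDataset schema, and the bucket path is a placeholder:

    dataset = hook.create_imported_dataset(
        body_request={"uri": "gs://my-bucket/input.csv", "name": "test_dataset"}
    )
    dataset_id = dataset["id"]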

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def create_wrangled_dataset(self, *, body_request: dict) -> dict:
"""
Creates a wrangled dataset.

:param body_request: Body of the POST request to be sent.
    For more details, see https://clouddataprep.com/documentation/api#operation/createWrangledDataset
"""
endpoint = f"/{self.api_version}/wrangledDatasets"
url: str = urljoin(self._base_url, endpoint)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
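
A hedged sketch that ties an imported dataset into a flow as a recipe; the body fields follow the createWrangledDataset schema, and the IDs are placeholders for values returned by the earlier create calls:

    wrangled = hook.create_wrangled_dataset(
        body_request={
            "importedDataset": {"id": 123},  # from create_imported_dataset
            "flow": {"id": 45},  # from create_flow
            "name": "test_wrangled_dataset",
        }
    )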

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def create_output_object(self, *, body_request: dict) -> dict:
"""
Creates an output object.

:param body_request: Body of the POST request to be sent.
    For more details, see https://clouddataprep.com/documentation/api#operation/createOutputObject
"""
endpoint = f"/{self.api_version}/outputObjects"
url: str = urljoin(self._base_url, endpoint)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
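
A hedged sketch; "execution", "profiler", and "flowNodeId" are fields assumed from the createOutputObject schema, with the recipe (flow node) ID as a placeholder:

    output = hook.create_output_object(
        body_request={"execution": "dataflow", "profiler": False, "flowNodeId": 678}
    )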

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def create_write_settings(self, *, body_request: dict) -> dict:
"""
Creates write settings.

:param body_request: Body of the POST request to be sent.
    For more details, see https://clouddataprep.com/documentation/api#tag/createWriteSetting
"""
endpoint = f"/{self.api_version}/writeSettings"
url: str = urljoin(self._base_url, endpoint)
response = requests.post(url, headers=self._headers, data=json.dumps(body_request))
self._raise_for_status(response)
return response.json()
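
A hedged sketch that writes job results back to GCS; the fields are assumed from the createWriteSetting schema and the output path is a placeholder:

    hook.create_write_settings(
        body_request={
            "path": "gs://my-bucket/output.csv",
            "action": "create",
            "format": "csv",
            "outputObjectId": 910,  # from create_output_object
        }
    )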

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10))
def delete_imported_dataset(self, *, dataset_id: int) -> None:
"""
Deletes an imported dataset.

:param dataset_id: ID of the imported dataset to be removed.
"""
endpoint = f"/{self.api_version}/importedDatasets/{dataset_id}"
url: str = urljoin(self._base_url, endpoint)
response = requests.delete(url, headers=self._headers)
self._raise_for_status(response)
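
Together with delete_flow, this lets the system test provision every resource it needs and remove it afterwards, which is what makes the test self-sufficient. A tear-down sketch with placeholder IDs:

    # Clean up the resources created during the test run.
    hook.delete_imported_dataset(dataset_id=123)
    hook.delete_flow(flow_id=45)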
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/dataprep.py
@@ -51,13 +51,13 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)
-self.dataprep_conn_id = (dataprep_conn_id,)
+self.dataprep_conn_id = dataprep_conn_id
self.job_group_id = job_group_id

def execute(self, context: Context) -> dict:
self.log.info("Fetching data for job with id: %d ...", self.job_group_id)
hook = GoogleDataprepHook(
dataprep_conn_id="dataprep_default",
dataprep_conn_id=self.dataprep_conn_id,
)
response = hook.get_jobs_for_job_group(job_id=int(self.job_group_id))
return response
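
With both fixes, the operator stores the connection ID as a string rather than a one-element tuple and actually forwards it to the hook instead of the hard-coded default. A minimal task sketch, assuming a "my_dataprep_conn" connection exists:

    from airflow.providers.google.cloud.operators.dataprep import DataprepGetJobsForJobGroupOperator

    get_jobs = DataprepGetJobsForJobGroupOperator(
        task_id="get_jobs_for_job_group",
        dataprep_conn_id="my_dataprep_conn",  # previously ignored in favor of "dataprep_default"
        job_group_id=1312,
    )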
