From b8f73768ec13f8d4cc1605cca3fa93be6caac473 Mon Sep 17 00:00:00 2001
From: Shahar Epstein <60007259+shahar1@users.noreply.github.com>
Date: Tue, 9 May 2023 09:05:24 +0300
Subject: [PATCH] Add `as_dict` param to `BigQueryGetDataOperator` (#30887)

* Add "as_dict" param to BigQueryGetDataOperator
---
 .../providers/google/cloud/hooks/bigquery.py  | 12 +++++--
 .../google/cloud/operators/bigquery.py        | 35 +++++++++++++------
 .../google/cloud/triggers/bigquery.py         | 13 +++++--
 .../operators/cloud/bigquery.rst              |  7 ++--
 .../google/cloud/hooks/test_bigquery.py       | 26 ++++++++++++++
 .../google/cloud/operators/test_bigquery.py   | 16 ++++++---
 6 files changed, 87 insertions(+), 22 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index f24562da7b846..a091dd73fe200 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -3125,20 +3125,26 @@ async def create_job_for_partition_get(
         job_query_resp = await job_client.query(query_request, cast(Session, session))
         return job_query_resp["jobReference"]["jobId"]

-    def get_records(self, query_results: dict[str, Any]) -> list[Any]:
+    def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
         """
         Given the output query response from gcloud-aio bigquery, convert the response to records.

         :param query_results: the results from a SQL query
+        :param as_dict: if True, returns the result as a list of dictionaries; otherwise, as a list of lists.
         """
-        buffer = []
+        buffer: list[Any] = []
         if "rows" in query_results and query_results["rows"]:
             rows = query_results["rows"]
             fields = query_results["schema"]["fields"]
             col_types = [field["type"] for field in fields]
             for dict_row in rows:
                 typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])]
-                buffer.append(typed_row)
+                if not as_dict:
+                    buffer.append(typed_row)
+                else:
+                    fields_names = [field["name"] for field in fields]
+                    typed_row_dict = dict(zip(fields_names, typed_row))
+                    buffer.append(typed_row_dict)
         return buffer

     def value_check(
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 22150a6221653..3d3d9719cc046 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -758,12 +758,19 @@ def execute(self, context=None):

 class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     """
-    Fetches the data from a BigQuery table (alternatively fetch data for selected columns)
-    and returns data in a python list. The number of elements in the returned list will
-    be equal to the number of rows fetched. Each element in the list will again be a list
-    where element would represent the columns values for that row.
+    Fetches the data from a BigQuery table (alternatively fetch data for selected columns) and returns data
+    in either of the following two formats, based on the value of ``as_dict``:
+    1. False (default) - A Python list of lists, with the number of nested lists equal to the number of rows
+    fetched. Each nested list represents a row, where the elements within it correspond to the column values
+    for that particular row.

-    **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]``
+    **Example Result**: ``[['Tony', 10], ['Mike', 20]]``
+
+    2. True - A Python list of dictionaries, where each dictionary represents a row. In each dictionary,
+    the keys are the column names and the values are the corresponding values for those columns.
+
+    **Example Result**: ``[{'name': 'Tony', 'age': 10}, {'name': 'Mike', 'age': 20}]``

     .. seealso::
         For more information on how to use this operator, take a look at the guide:
@@ -810,6 +817,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     :param deferrable: Run operator in the deferrable mode
     :param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job.
         Defaults to 4 seconds.
+    :param as_dict: if True, returns the result as a list of dictionaries; otherwise, as a list of lists
+        (default: False).
     """

     template_fields: Sequence[str] = (
@@ -835,6 +844,7 @@ def __init__(
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = False,
         poll_interval: float = 4.0,
+        as_dict: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -849,6 +859,7 @@ def __init__(
         self.project_id = project_id
         self.deferrable = deferrable
         self.poll_interval = poll_interval
+        self.as_dict = as_dict

     def _submit_job(
         self,
@@ -884,7 +895,6 @@ def execute(self, context: Context):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-        self.hook = hook

         if not self.deferrable:
             self.log.info(
@@ -910,21 +920,26 @@ def execute(self, context: Context):

             self.log.info("Total extracted rows: %s", len(rows))

-            table_data = [row.values() for row in rows]
+            if self.as_dict:
+                table_data = [{k: v for k, v in row.items()} for row in rows]
+            else:
+                table_data = [row.values() for row in rows]
+
             return table_data

         job = self._submit_job(hook, job_id="")
-        self.job_id = job.job_id
-        context["ti"].xcom_push(key="job_id", value=self.job_id)
+
+        context["ti"].xcom_push(key="job_id", value=job.job_id)
         self.defer(
             timeout=self.execution_timeout,
             trigger=BigQueryGetDataTrigger(
                 conn_id=self.gcp_conn_id,
-                job_id=self.job_id,
+                job_id=job.job_id,
                 dataset_id=self.dataset_id,
                 table_id=self.table_id,
                 project_id=hook.project_id,
                 poll_interval=self.poll_interval,
+                as_dict=self.as_dict,
             ),
             method_name="execute_complete",
         )
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index ba4ce8c19be42..1da7f87f90259 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -165,7 +165,16 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]

 class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
-    """BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class"""
+    """
+    BigQueryGetDataTrigger runs on the trigger worker; it inherits from the BigQueryInsertJobTrigger class.
+
+    :param as_dict: if True, returns the result as a list of dictionaries; otherwise, as a list of lists
+        (default: False).
+ """ + + def __init__(self, as_dict: bool = False, **kwargs): + super().__init__(**kwargs) + self.as_dict = as_dict def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes BigQueryInsertJobTrigger arguments and classpath.""" @@ -190,7 +199,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) if response_from_hook == "success": query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id) - records = hook.get_records(query_results) + records = hook.get_records(query_results=query_results, as_dict=self.as_dict) self.log.debug("Response from hook: %s", response_from_hook) yield TriggerEvent( { diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst index 61f4439fe1897..6529ee4522f0b 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst @@ -208,10 +208,11 @@ To fetch data from a BigQuery table you can use Alternatively you can fetch data for selected columns if you pass fields to ``selected_fields``. -This operator returns data in a Python list where the number of elements in the -returned list will be equal to the number of rows fetched. Each element in the -list will again be a list where elements would represent the column values for +The result of this operator can be retrieved in two different formats based on the value of the ``as_dict`` parameter: +``False`` (default) - A Python list of lists, where the number of elements in the nesting list will be equal to the number of rows fetched. Each element in the +nesting will a nested list where elements would represent the column values for that row. +``True`` - A Python list of dictionaries, where each dictionary represents a row. In each dictionary, the keys are the column names and the values are the corresponding values for those columns. .. 
exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py :language: python diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 0e09080fa7d83..a508f1424904a 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -2348,3 +2348,29 @@ def test_get_records_return_type(self): assert isinstance(result[0][0], int) assert isinstance(result[0][1], float) assert isinstance(result[0][2], str) + + def test_get_records_as_dict(self): + query_result = { + "kind": "bigquery#getQueryResultsResponse", + "etag": "test_etag", + "schema": { + "fields": [ + {"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"}, + {"name": "f1_", "type": "FLOAT", "mode": "NULLABLE"}, + {"name": "f2_", "type": "STRING", "mode": "NULLABLE"}, + ] + }, + "jobReference": { + "projectId": "test_airflow-providers", + "jobId": "test_jobid", + "location": "US", + }, + "totalRows": "1", + "rows": [{"f": [{"v": "22"}, {"v": "3.14"}, {"v": "PI"}]}], + "totalBytesProcessed": "0", + "jobComplete": True, + "cacheHit": False, + } + hook = BigQueryAsyncHook() + result = hook.get_records(query_result, as_dict=True) + assert result == [{"f0_": 22, "f1_": 3.14, "f2_": "PI"}] diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index d5fafb6995dec..c5051d92e6b0e 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -785,8 +785,9 @@ def test_bigquery_operator_extra_link_when_multiple_query( class TestBigQueryGetDataOperator: + @pytest.mark.parametrize("as_dict", [True, False]) @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") - def test_execute(self, mock_hook): + def test_execute(self, mock_hook, as_dict): max_results = 100 selected_fields = "DATE" operator = BigQueryGetDataOperator( @@ -797,6 +798,7 @@ def test_execute(self, mock_hook): max_results=max_results, selected_fields=selected_fields, location=TEST_DATASET_LOCATION, + as_dict=as_dict, ) operator.execute(None) mock_hook.return_value.list_rows.assert_called_once_with( @@ -840,9 +842,10 @@ def test_bigquery_get_data_operator_async_with_selected_fields( exc.value.trigger, BigQueryGetDataTrigger ), "Trigger is not a BigQueryGetDataTrigger" + @pytest.mark.parametrize("as_dict", [True, False]) @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_bigquery_get_data_operator_async_without_selected_fields( - self, mock_hook, create_task_instance_of_operator + self, mock_hook, create_task_instance_of_operator, as_dict ): """ Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired @@ -862,6 +865,7 @@ def test_bigquery_get_data_operator_async_without_selected_fields( table_id=TEST_TABLE_ID, max_results=100, deferrable=True, + as_dict=as_dict, ) with pytest.raises(TaskDeferred) as exc: @@ -871,7 +875,8 @@ def test_bigquery_get_data_operator_async_without_selected_fields( exc.value.trigger, BigQueryGetDataTrigger ), "Trigger is not a BigQueryGetDataTrigger" - def test_bigquery_get_data_operator_execute_failure(self): + @pytest.mark.parametrize("as_dict", [True, False]) + def test_bigquery_get_data_operator_execute_failure(self, as_dict): """Tests that an AirflowException is raised in case of error event""" operator = BigQueryGetDataOperator( @@ -880,6 +885,7 @@ def 
test_bigquery_get_data_operator_execute_failure(self): table_id="any", max_results=100, deferrable=True, + as_dict=as_dict, ) with pytest.raises(AirflowException): @@ -887,7 +893,8 @@ def test_bigquery_get_data_operator_execute_failure(self): context=None, event={"status": "error", "message": "test failure message"} ) - def test_bigquery_get_data_op_execute_complete_with_records(self): + @pytest.mark.parametrize("as_dict", [True, False]) + def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict): """Asserts that exception is raised with correct expected exception message""" operator = BigQueryGetDataOperator( @@ -896,6 +903,7 @@ def test_bigquery_get_data_op_execute_complete_with_records(self): table_id="any", max_results=100, deferrable=True, + as_dict=as_dict, ) with mock.patch.object(operator.log, "info") as mock_log_info:
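
The patch wires `as_dict` through the operator, the trigger, and the async hook. As a quick illustration of what the flag changes for DAG authors, here is a minimal usage sketch (not part of the diff above); the DAG id, dataset, table, and field names are hypothetical:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDataOperator

    with DAG(
        dag_id="example_bigquery_get_data_as_dict",  # hypothetical DAG id
        start_date=datetime(2023, 5, 1),
        schedule=None,
        catchup=False,
    ):
        # With as_dict=True the task's return value (XCom) looks like
        # [{'name': 'Tony', 'age': 10}, ...] instead of [['Tony', 10], ...].
        get_data = BigQueryGetDataOperator(
            task_id="get_data_as_dict",
            dataset_id="my_dataset",  # hypothetical dataset
            table_id="my_table",  # hypothetical table
            max_results=10,
            selected_fields="name,age",  # hypothetical columns
            as_dict=True,
        )

In deferrable mode the operator forwards the same flag to BigQueryGetDataTrigger, so both the synchronous and the deferred code paths return rows in the same shape.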