diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
index a54ced136f3ba..671ad33488e27 100644
--- a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
+++ b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
@@ -18,10 +18,8 @@
 """This module contains Google BigQuery to BigQuery operator."""
 from __future__ import annotations
 
-import warnings
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
@@ -110,6 +108,58 @@ def __init__(
         self.encryption_configuration = encryption_configuration
         self.location = location
         self.impersonation_chain = impersonation_chain
+        self.hook: BigQueryHook | None = None
+
+    def _prepare_job_configuration(self):
+        self.source_project_dataset_tables = (
+            [self.source_project_dataset_tables]
+            if not isinstance(self.source_project_dataset_tables, list)
+            else self.source_project_dataset_tables
+        )
+
+        source_project_dataset_tables_fixup = []
+        for source_project_dataset_table in self.source_project_dataset_tables:
+            source_project, source_dataset, source_table = self.hook.split_tablename(
+                table_input=source_project_dataset_table,
+                default_project_id=self.hook.project_id,
+                var_name="source_project_dataset_table",
+            )
+            source_project_dataset_tables_fixup.append(
+                {"projectId": source_project, "datasetId": source_dataset, "tableId": source_table}
+            )
+
+        destination_project, destination_dataset, destination_table = self.hook.split_tablename(
+            table_input=self.destination_project_dataset_table,
+            default_project_id=self.hook.project_id,
+        )
+        configuration = {
+            "copy": {
+                "createDisposition": self.create_disposition,
+                "writeDisposition": self.write_disposition,
+                "sourceTables": source_project_dataset_tables_fixup,
+                "destinationTable": {
+                    "projectId": destination_project,
+                    "datasetId": destination_dataset,
+                    "tableId": destination_table,
+                },
+            }
+        }
+
+        if self.labels:
+            configuration["labels"] = self.labels
+
+        if self.encryption_configuration:
+            configuration["copy"]["destinationEncryptionConfiguration"] = self.encryption_configuration
+
+        return configuration
+
+    def _submit_job(
+        self,
+        hook: BigQueryHook,
+        configuration: dict,
+    ) -> str:
+        job = hook.insert_job(configuration=configuration, project_id=hook.project_id)
+        return job.job_id
 
     def execute(self, context: Context) -> None:
         self.log.info(
@@ -122,24 +172,20 @@ def execute(self, context: Context) -> None:
             location=self.location,
             impersonation_chain=self.impersonation_chain,
         )
+        self.hook = hook
 
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", AirflowProviderDeprecationWarning)
-            job_id = hook.run_copy(
-                source_project_dataset_tables=self.source_project_dataset_tables,
-                destination_project_dataset_table=self.destination_project_dataset_table,
-                write_disposition=self.write_disposition,
-                create_disposition=self.create_disposition,
-                labels=self.labels,
-                encryption_configuration=self.encryption_configuration,
-            )
+        if not hook.project_id:
+            raise ValueError("The project_id should be set")
 
-        job = hook.get_job(job_id=job_id, location=self.location).to_api_repr()
-        conf = job["configuration"]["copy"]["destinationTable"]
-        BigQueryTableLink.persist(
-            context=context,
-            task_instance=self,
-            dataset_id=conf["datasetId"],
-            project_id=conf["projectId"],
-            table_id=conf["tableId"],
-        )
+        configuration = self._prepare_job_configuration()
+        job_id = self._submit_job(hook=hook, configuration=configuration)
+
+        job = hook.get_job(job_id=job_id, location=self.location).to_api_repr()
+        conf = job["configuration"]["copy"]["destinationTable"]
+        BigQueryTableLink.persist(
+            context=context,
+            task_instance=self,
+            dataset_id=conf["datasetId"],
+            project_id=conf["projectId"],
+            table_id=conf["tableId"],
+        )
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
index 5c0f8f560caba..ed06928c2ccff 100644
--- a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
@@ -23,64 +23,90 @@
 BQ_HOOK_PATH = "airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryHook"
 TASK_ID = "test-bq-create-table-operator"
+TEST_GCP_PROJECT_ID = "test-project"
 TEST_DATASET = "test-dataset"
 TEST_TABLE_ID = "test-table-id"
+SOURCE_PROJECT_DATASET_TABLES = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}"
+DESTINATION_PROJECT_DATASET_TABLE = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET + '_new'}.{TEST_TABLE_ID}"
+WRITE_DISPOSITION = "WRITE_EMPTY"
+CREATE_DISPOSITION = "CREATE_IF_NEEDED"
+LABELS = {"k1": "v1"}
+ENCRYPTION_CONFIGURATION = {"key": "kk"}
+
+
+def split_tablename_side_effect(*args, **kwargs):
+    if kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLES:
+        return (
+            TEST_GCP_PROJECT_ID,
+            TEST_DATASET,
+            TEST_TABLE_ID,
+        )
+    elif kwargs["table_input"] == DESTINATION_PROJECT_DATASET_TABLE:
+        return (
+            TEST_GCP_PROJECT_ID,
+            TEST_DATASET + "_new",
+            TEST_TABLE_ID,
+        )
+
 
 class TestBigQueryToBigQueryOperator:
     @mock.patch(BQ_HOOK_PATH)
     def test_execute_without_location_should_execute_successfully(self, mock_hook):
-        source_project_dataset_tables = f"{TEST_DATASET}.{TEST_TABLE_ID}"
-        destination_project_dataset_table = f"{TEST_DATASET + '_new'}.{TEST_TABLE_ID}"
-        write_disposition = "WRITE_EMPTY"
-        create_disposition = "CREATE_IF_NEEDED"
-        labels = {"k1": "v1"}
-        encryption_configuration = {"key": "kk"}
-
         operator = BigQueryToBigQueryOperator(
             task_id=TASK_ID,
-            source_project_dataset_tables=source_project_dataset_tables,
-            destination_project_dataset_table=destination_project_dataset_table,
-            write_disposition=write_disposition,
-            create_disposition=create_disposition,
-            labels=labels,
-            encryption_configuration=encryption_configuration,
+            source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLES,
+            destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE,
+            write_disposition=WRITE_DISPOSITION,
+            create_disposition=CREATE_DISPOSITION,
+            labels=LABELS,
+            encryption_configuration=ENCRYPTION_CONFIGURATION,
         )
 
+        mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect
         operator.execute(context=mock.MagicMock())
 
-        mock_hook.return_value.run_copy.assert_called_once_with(
-            source_project_dataset_tables=source_project_dataset_tables,
-            destination_project_dataset_table=destination_project_dataset_table,
-            write_disposition=write_disposition,
-            create_disposition=create_disposition,
-            labels=labels,
-            encryption_configuration=encryption_configuration,
+        mock_hook.return_value.insert_job.assert_called_once_with(
+            configuration={
+                "copy": {
+                    "createDisposition": CREATE_DISPOSITION,
+                    "destinationEncryptionConfiguration": ENCRYPTION_CONFIGURATION,
+                    "destinationTable": {
+                        "datasetId": TEST_DATASET + "_new",
+                        "projectId": TEST_GCP_PROJECT_ID,
+                        "tableId": TEST_TABLE_ID,
+                    },
+                    "sourceTables": [
+                        {
+                            "datasetId": TEST_DATASET,
+                            "projectId": TEST_GCP_PROJECT_ID,
+                            "tableId": TEST_TABLE_ID,
+                        },
+                    ],
+                    "writeDisposition": WRITE_DISPOSITION,
+                },
+                "labels": LABELS,
+            },
+            project_id=mock_hook.return_value.project_id,
        )
 
     @mock.patch(BQ_HOOK_PATH)
     def test_execute_single_regional_location_should_execute_successfully(self, mock_hook):
-        source_project_dataset_tables = f"{TEST_DATASET}.{TEST_TABLE_ID}"
-        destination_project_dataset_table = f"{TEST_DATASET + '_new'}.{TEST_TABLE_ID}"
-        write_disposition = "WRITE_EMPTY"
-        create_disposition = "CREATE_IF_NEEDED"
-        labels = {"k1": "v1"}
         location = "us-central1"
-        encryption_configuration = {"key": "kk"}
-
-        mock_hook.return_value.run_copy.return_value = "job-id"
 
         operator = BigQueryToBigQueryOperator(
             task_id=TASK_ID,
-            source_project_dataset_tables=source_project_dataset_tables,
-            destination_project_dataset_table=destination_project_dataset_table,
-            write_disposition=write_disposition,
-            create_disposition=create_disposition,
-            labels=labels,
-            encryption_configuration=encryption_configuration,
+            source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLES,
+            destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE,
+            write_disposition=WRITE_DISPOSITION,
+            create_disposition=CREATE_DISPOSITION,
+            labels=LABELS,
+            encryption_configuration=ENCRYPTION_CONFIGURATION,
             location=location,
         )
 
+        mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect
        operator.execute(context=mock.MagicMock())
 
         mock_hook.return_value.get_job.assert_called_once_with(
-            job_id="job-id",
+            job_id=mock_hook.return_value.insert_job.return_value.job_id,
             location=location,
         )