Commit

Remove usage of deprecated method from BigQueryToBigQueryOperator (#3…

MaksYermak committed Nov 20, 2023
1 parent f8dd192 commit 9207e7d
Showing 2 changed files with 127 additions and 55 deletions.
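In short: the operator previously had to silence an AirflowProviderDeprecationWarning around the deprecated BigQueryHook.run_copy helper; it now builds the BigQuery copy-job configuration itself and submits it through BigQueryHook.insert_job. A minimal sketch of the new call path, for orientation before the diff (project, dataset, and table names are illustrative, not from the commit):

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

    hook = BigQueryHook(gcp_conn_id="google_cloud_default")
    # Copy-job configuration in the BigQuery REST API shape accepted by insert_job.
    job = hook.insert_job(
        configuration={
            "copy": {
                "createDisposition": "CREATE_IF_NEEDED",
                "writeDisposition": "WRITE_EMPTY",
                "sourceTables": [
                    {"projectId": "my-project", "datasetId": "src_dataset", "tableId": "src_table"},
                ],
                "destinationTable": {
                    "projectId": "my-project",
                    "datasetId": "dst_dataset",
                    "tableId": "dst_table",
                },
            }
        },
        project_id=hook.project_id,
    )
    print(job.job_id)  # insert_job returns a job object exposing the server-assigned job id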
88 changes: 67 additions & 21 deletions airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
@@ -18,10 +18,8 @@
 """This module contains Google BigQuery to BigQuery operator."""
 from __future__ import annotations
 
-import warnings
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
@@ -110,6 +108,58 @@ def __init__(
         self.encryption_configuration = encryption_configuration
         self.location = location
         self.impersonation_chain = impersonation_chain
+        self.hook: BigQueryHook | None = None
+
+    def _prepare_job_configuration(self):
+        self.source_project_dataset_tables = (
+            [self.source_project_dataset_tables]
+            if not isinstance(self.source_project_dataset_tables, list)
+            else self.source_project_dataset_tables
+        )
+
+        source_project_dataset_tables_fixup = []
+        for source_project_dataset_table in self.source_project_dataset_tables:
+            source_project, source_dataset, source_table = self.hook.split_tablename(
+                table_input=source_project_dataset_table,
+                default_project_id=self.hook.project_id,
+                var_name="source_project_dataset_table",
+            )
+            source_project_dataset_tables_fixup.append(
+                {"projectId": source_project, "datasetId": source_dataset, "tableId": source_table}
+            )
+
+        destination_project, destination_dataset, destination_table = self.hook.split_tablename(
+            table_input=self.destination_project_dataset_table,
+            default_project_id=self.hook.project_id,
+        )
+        configuration = {
+            "copy": {
+                "createDisposition": self.create_disposition,
+                "writeDisposition": self.write_disposition,
+                "sourceTables": source_project_dataset_tables_fixup,
+                "destinationTable": {
+                    "projectId": destination_project,
+                    "datasetId": destination_dataset,
+                    "tableId": destination_table,
+                },
+            }
+        }
+
+        if self.labels:
+            configuration["labels"] = self.labels
+
+        if self.encryption_configuration:
+            configuration["copy"]["destinationEncryptionConfiguration"] = self.encryption_configuration
+
+        return configuration
+
+    def _submit_job(
+        self,
+        hook: BigQueryHook,
+        configuration: dict,
+    ) -> str:
+        job = hook.insert_job(configuration=configuration, project_id=hook.project_id)
+        return job.job_id
 
     def execute(self, context: Context) -> None:
         self.log.info(
@@ -122,24 +172,20 @@ def execute(self, context: Context) -> None:
             location=self.location,
             impersonation_chain=self.impersonation_chain,
         )
+        self.hook = hook
 
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", AirflowProviderDeprecationWarning)
-            job_id = hook.run_copy(
-                source_project_dataset_tables=self.source_project_dataset_tables,
-                destination_project_dataset_table=self.destination_project_dataset_table,
-                write_disposition=self.write_disposition,
-                create_disposition=self.create_disposition,
-                labels=self.labels,
-                encryption_configuration=self.encryption_configuration,
-            )
+        if not hook.project_id:
+            raise ValueError("The project_id should be set")
 
-            job = hook.get_job(job_id=job_id, location=self.location).to_api_repr()
-            conf = job["configuration"]["copy"]["destinationTable"]
-            BigQueryTableLink.persist(
-                context=context,
-                task_instance=self,
-                dataset_id=conf["datasetId"],
-                project_id=conf["projectId"],
-                table_id=conf["tableId"],
-            )
+        configuration = self._prepare_job_configuration()
+        job_id = self._submit_job(hook=hook, configuration=configuration)
+
+        job = hook.get_job(job_id=job_id, location=self.location).to_api_repr()
+        conf = job["configuration"]["copy"]["destinationTable"]
+        BigQueryTableLink.persist(
+            context=context,
+            task_instance=self,
+            dataset_id=conf["datasetId"],
+            project_id=conf["projectId"],
+            table_id=conf["tableId"],
+        )
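For a concrete sense of what the new _prepare_job_configuration helper builds: for a hypothetical operator copying my-project.src_dataset.src_table into my-project.dst_dataset.dst_table with labels {"team": "data"} and no encryption configured (all values illustrative), it would return roughly:

    {
        "copy": {
            "createDisposition": "CREATE_IF_NEEDED",
            "writeDisposition": "WRITE_EMPTY",
            "sourceTables": [
                {"projectId": "my-project", "datasetId": "src_dataset", "tableId": "src_table"},
            ],
            "destinationTable": {
                "projectId": "my-project",
                "datasetId": "dst_dataset",
                "tableId": "dst_table",
            },
        },
        "labels": {"team": "data"},
    }

Note that split_tablename falls back to the hook's project_id when a table string omits the project, which is why execute now raises if hook.project_id is unset.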
94 changes: 60 additions & 34 deletions tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py
@@ -23,64 +23,90 @@
 
 BQ_HOOK_PATH = "airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryHook"
 TASK_ID = "test-bq-create-table-operator"
+TEST_GCP_PROJECT_ID = "test-project"
 TEST_DATASET = "test-dataset"
 TEST_TABLE_ID = "test-table-id"
 
+SOURCE_PROJECT_DATASET_TABLES = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}"
+DESTINATION_PROJECT_DATASET_TABLE = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET + '_new'}.{TEST_TABLE_ID}"
+WRITE_DISPOSITION = "WRITE_EMPTY"
+CREATE_DISPOSITION = "CREATE_IF_NEEDED"
+LABELS = {"k1": "v1"}
+ENCRYPTION_CONFIGURATION = {"key": "kk"}
+
+
+def split_tablename_side_effect(*args, **kwargs):
+    if kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLES:
+        return (
+            TEST_GCP_PROJECT_ID,
+            TEST_DATASET,
+            TEST_TABLE_ID,
+        )
+    elif kwargs["table_input"] == DESTINATION_PROJECT_DATASET_TABLE:
+        return (
+            TEST_GCP_PROJECT_ID,
+            TEST_DATASET + "_new",
+            TEST_TABLE_ID,
+        )
 
 
 class TestBigQueryToBigQueryOperator:
     @mock.patch(BQ_HOOK_PATH)
     def test_execute_without_location_should_execute_successfully(self, mock_hook):
-        source_project_dataset_tables = f"{TEST_DATASET}.{TEST_TABLE_ID}"
-        destination_project_dataset_table = f"{TEST_DATASET + '_new'}.{TEST_TABLE_ID}"
-        write_disposition = "WRITE_EMPTY"
-        create_disposition = "CREATE_IF_NEEDED"
-        labels = {"k1": "v1"}
-        encryption_configuration = {"key": "kk"}
-
         operator = BigQueryToBigQueryOperator(
             task_id=TASK_ID,
-            source_project_dataset_tables=source_project_dataset_tables,
-            destination_project_dataset_table=destination_project_dataset_table,
-            write_disposition=write_disposition,
-            create_disposition=create_disposition,
-            labels=labels,
-            encryption_configuration=encryption_configuration,
+            source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLES,
+            destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE,
+            write_disposition=WRITE_DISPOSITION,
+            create_disposition=CREATE_DISPOSITION,
+            labels=LABELS,
+            encryption_configuration=ENCRYPTION_CONFIGURATION,
         )
 
+        mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect
         operator.execute(context=mock.MagicMock())
-        mock_hook.return_value.run_copy.assert_called_once_with(
-            source_project_dataset_tables=source_project_dataset_tables,
-            destination_project_dataset_table=destination_project_dataset_table,
-            write_disposition=write_disposition,
-            create_disposition=create_disposition,
-            labels=labels,
-            encryption_configuration=encryption_configuration,
+        mock_hook.return_value.insert_job.assert_called_once_with(
+            configuration={
+                "copy": {
+                    "createDisposition": CREATE_DISPOSITION,
+                    "destinationEncryptionConfiguration": ENCRYPTION_CONFIGURATION,
+                    "destinationTable": {
+                        "datasetId": TEST_DATASET + "_new",
+                        "projectId": TEST_GCP_PROJECT_ID,
+                        "tableId": TEST_TABLE_ID,
+                    },
+                    "sourceTables": [
+                        {
+                            "datasetId": TEST_DATASET,
+                            "projectId": TEST_GCP_PROJECT_ID,
+                            "tableId": TEST_TABLE_ID,
+                        },
+                    ],
+                    "writeDisposition": WRITE_DISPOSITION,
+                },
+                "labels": LABELS,
+            },
+            project_id=mock_hook.return_value.project_id,
         )
 
     @mock.patch(BQ_HOOK_PATH)
     def test_execute_single_regional_location_should_execute_successfully(self, mock_hook):
-        source_project_dataset_tables = f"{TEST_DATASET}.{TEST_TABLE_ID}"
-        destination_project_dataset_table = f"{TEST_DATASET + '_new'}.{TEST_TABLE_ID}"
-        write_disposition = "WRITE_EMPTY"
-        create_disposition = "CREATE_IF_NEEDED"
-        labels = {"k1": "v1"}
         location = "us-central1"
-        encryption_configuration = {"key": "kk"}
-        mock_hook.return_value.run_copy.return_value = "job-id"
 
         operator = BigQueryToBigQueryOperator(
             task_id=TASK_ID,
-            source_project_dataset_tables=source_project_dataset_tables,
-            destination_project_dataset_table=destination_project_dataset_table,
-            write_disposition=write_disposition,
-            create_disposition=create_disposition,
-            labels=labels,
-            encryption_configuration=encryption_configuration,
+            source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLES,
+            destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE,
+            write_disposition=WRITE_DISPOSITION,
+            create_disposition=CREATE_DISPOSITION,
+            labels=LABELS,
+            encryption_configuration=ENCRYPTION_CONFIGURATION,
             location=location,
         )
 
+        mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect
         operator.execute(context=mock.MagicMock())
         mock_hook.return_value.get_job.assert_called_once_with(
-            job_id="job-id",
+            job_id=mock_hook.return_value.insert_job.return_value.job_id,
             location=location,
         )
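The operator's public interface is unchanged by this commit, so existing DAGs need no edits. A typical usage sketch (DAG id, schedule, and table names here are illustrative, not from the commit):

    import pendulum

    from airflow import DAG
    from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import (
        BigQueryToBigQueryOperator,
    )

    with DAG(
        dag_id="example_bigquery_to_bigquery",
        start_date=pendulum.datetime(2023, 11, 1, tz="UTC"),
        schedule=None,
        catchup=False,
    ) as dag:
        copy_table = BigQueryToBigQueryOperator(
            task_id="copy_table",
            source_project_dataset_tables="my-project.src_dataset.src_table",
            destination_project_dataset_table="my-project.dst_dataset.dst_table",
            write_disposition="WRITE_EMPTY",
            create_disposition="CREATE_IF_NEEDED",
        )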
