Add logic to handle on_kill for BigQueryInsertJobOperator when deferrable=True (#38912)
sunank200 committed Apr 15, 2024
1 parent 456ec48 commit e237041
Showing 5 changed files with 122 additions and 0 deletions.
25 changes: 25 additions & 0 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -3388,6 +3388,31 @@ async def create_job_for_partition_get(
job_query_resp = await job_client.query(query_request, cast(Session, session))
return job_query_resp["jobReference"]["jobId"]

async def cancel_job(self, job_id: str, project_id: str | None, location: str | None) -> None:
"""
Cancel a BigQuery job.

:param job_id: ID of the job to cancel.
:param project_id: Google Cloud Project where the job was running.
:param location: Location where the job was running.
"""
async with ClientSession() as session:
token = await self.get_token(session=session)
job = Job(job_id=job_id, project=project_id, location=location, token=token, session=session) # type: ignore[arg-type]

self.log.info(
"Attempting to cancel BigQuery job: %s in project: %s, location: %s",
job_id,
project_id,
location,
)
try:
await job.cancel()
self.log.info("Job %s cancellation requested.", job_id)
except Exception as e:
self.log.error("Failed to cancel BigQuery job %s: %s", job_id, str(e))
raise

def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
"""Convert a response from BigQuery to records.
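For orientation, here is a minimal sketch of invoking the new hook method directly. It assumes a working google_cloud_default connection; the job and project identifiers are placeholders, not values from this commit:

import asyncio

from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook


async def main() -> None:
    hook = BigQueryAsyncHook()
    # Placeholder identifiers for illustration; point these at a real running job.
    await hook.cancel_job(job_id="my_job_id", project_id="my-project", location="US")


asyncio.run(main())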
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -2903,6 +2903,7 @@ def execute(self, context: Any):
location=self.location or hook.location,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
cancel_on_kill=self.cancel_on_kill,
),
method_name="execute_complete",
)
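A hedged sketch of how the flag is exercised from a DAG (illustrative DAG and task IDs): with deferrable=True the operator defers to BigQueryInsertJobTrigger, which now receives cancel_on_kill, so killing or clearing the task also cancels the server-side BigQuery job:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

with DAG(dag_id="bq_cancel_on_kill_demo", start_date=datetime(2024, 1, 1), schedule=None):
    insert_job = BigQueryInsertJobOperator(
        task_id="insert_query_job",
        configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
        deferrable=True,
        cancel_on_kill=True,  # the default; set False to leave the job running when the task is killed
    )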
11 changes: 11 additions & 0 deletions airflow/providers/google/cloud/triggers/bigquery.py
@@ -57,6 +57,7 @@ def __init__(
table_id: str | None = None,
poll_interval: float = 4.0,
impersonation_chain: str | Sequence[str] | None = None,
cancel_on_kill: bool = True,
):
super().__init__()
self.log.info("Using the connection %s .", conn_id)
@@ -69,6 +70,7 @@
self.table_id = table_id
self.poll_interval = poll_interval
self.impersonation_chain = impersonation_chain
self.cancel_on_kill = cancel_on_kill

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize BigQueryInsertJobTrigger arguments and classpath."""
@@ -83,6 +85,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
"cancel_on_kill": self.cancel_on_kill,
},
)

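The reason cancel_on_kill must be added to serialize() is that the triggerer persists only the (classpath, kwargs) pair and reconstructs the trigger from it after deferral; any field omitted from kwargs silently falls back to its default. A minimal round-trip sketch, assuming the constructor keyword arguments shown in the hunk above (values are illustrative):

from importlib import import_module

from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger

trigger = BigQueryInsertJobTrigger(
    conn_id="google_cloud_default",
    job_id="1234",
    project_id="my-project",
    location="US",
    cancel_on_kill=False,
)
classpath, kwargs = trigger.serialize()
module_path, class_name = classpath.rsplit(".", 1)
rebuilt = getattr(import_module(module_path), class_name)(**kwargs)
assert rebuilt.cancel_on_kill is False  # the flag survives the round-trip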
@@ -113,6 +116,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
self.poll_interval,
)
await asyncio.sleep(self.poll_interval)
except asyncio.CancelledError:
self.log.info("Task was killed.")
if self.job_id and self.cancel_on_kill:
await hook.cancel_job( # type: ignore[union-attr]
job_id=self.job_id, project_id=self.project_id, location=self.location
)
else:
self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id)
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
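The except asyncio.CancelledError branch is what wires on_kill to the server side: when the task is killed or cleared, the triggerer cancels the coroutine, and the poll loop gets one chance to clean up before unwinding. A standalone sketch of the pattern:

import asyncio


async def poll_forever() -> None:
    try:
        while True:
            print("polling job status ...")
            await asyncio.sleep(1.0)
    except asyncio.CancelledError:
        # The single cleanup opportunity, e.g. requesting server-side cancellation.
        print("cancelled: requesting server-side job cancellation")
        raise


async def main() -> None:
    task = asyncio.create_task(poll_forever())
    await asyncio.sleep(2.5)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass


asyncio.run(main())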
49 changes: 49 additions & 0 deletions tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2190,6 +2190,55 @@ async def test_get_job_output_assert_once_with(self, mock_job_instance):
resp = await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID)
assert resp == response

@pytest.mark.asyncio
@pytest.mark.db_test
@mock.patch("google.auth.default")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Job")
async def test_cancel_job_success(self, mock_job, mock_auth_default):
    """Test that BigQueryAsyncHook.cancel_job requests job cancellation exactly once."""
mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
mock_credentials.token = "ACCESS_TOKEN"
mock_auth_default.return_value = (mock_credentials, PROJECT_ID)
job_id = "test_job_id"
project_id = "test_project"
location = "US"

mock_job_instance = AsyncMock()
mock_job_instance.cancel.return_value = None
mock_job.return_value = mock_job_instance

await self.hook.cancel_job(job_id=job_id, project_id=project_id, location=location)

mock_job_instance.cancel.assert_called_once()

@pytest.mark.asyncio
@pytest.mark.db_test
@mock.patch("google.auth.default")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Job")
async def test_cancel_job_failure(self, mock_job, mock_auth_default):
"""
Test that BigQueryAsyncHook handles exceptions during job cancellation correctly.
"""
mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
mock_credentials.token = "ACCESS_TOKEN"
mock_auth_default.return_value = (mock_credentials, PROJECT_ID)

mock_job_instance = AsyncMock()
mock_job_instance.cancel.side_effect = Exception("Cancellation failed")
mock_job.return_value = mock_job_instance

hook = BigQueryAsyncHook()

job_id = "test_job_id"
project_id = "test_project"
location = "US"

with pytest.raises(Exception) as excinfo:
await hook.cancel_job(job_id=job_id, project_id=project_id, location=location)

assert "Cancellation failed" in str(excinfo.value), "Exception message not passed correctly"

mock_job_instance.cancel.assert_called_once()

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
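For readers unfamiliar with the mocking above, a self-contained illustration of the AsyncMock pattern these tests rely on: awaiting an AsyncMock records the call, so assert_called_once() and assert_awaited_once() both work, and side_effect raises on demand:

import asyncio
from unittest.mock import AsyncMock


async def demo() -> None:
    job = AsyncMock()
    job.cancel.return_value = None
    await job.cancel()
    job.cancel.assert_called_once()
    job.cancel.assert_awaited_once()

    job.cancel.side_effect = Exception("Cancellation failed")
    try:
        await job.cancel()
    except Exception as err:
        print(f"raised as configured: {err}")


asyncio.run(demo())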
36 changes: 36 additions & 0 deletions tests/providers/google/cloud/triggers/test_bigquery.py
@@ -165,6 +165,7 @@ def test_serialization(self, insert_job_trigger):
classpath, kwargs = insert_job_trigger.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger"
assert kwargs == {
"cancel_on_kill": True,
"conn_id": TEST_CONN_ID,
"job_id": TEST_JOB_ID,
"project_id": TEST_GCP_PROJECT_ID,
@@ -233,6 +234,41 @@ async def test_bigquery_op_trigger_exception(self, mock_job_status, caplog, insert_job_trigger):
actual = await generator.asend(None)
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
async def test_bigquery_insert_job_trigger_cancellation(
self, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
):
"""
Test that BigQueryInsertJobTrigger handles cancellation correctly, logs the appropriate message,
and conditionally cancels the job based on the `cancel_on_kill` attribute.
"""
insert_job_trigger.cancel_on_kill = True
insert_job_trigger.job_id = "1234"

mock_get_job_status.side_effect = [
{"status": "running", "message": "Job is still running"},
asyncio.CancelledError(),
]

mock_cancel_job.return_value = asyncio.Future()
mock_cancel_job.return_value.set_result(None)

caplog.set_level(logging.INFO)

try:
async for _ in insert_job_trigger.run():
pass
except asyncio.CancelledError:
pass

assert (
"Task was killed" in caplog.text
or "Bigquery job status is running. Sleeping for 4.0 seconds." in caplog.text
), "Expected messages about task status or cancellation not found in log."
mock_cancel_job.assert_awaited_once()


class TestBigQueryGetDataTrigger:
def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):
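The cancellation test above drives the trigger with a sequenced side_effect: when side_effect is a list, successive calls return each item in turn, and an exception instance in the list is raised instead of returned, which is how the second poll forces the CancelledError path. A small standalone illustration:

import asyncio
from unittest.mock import AsyncMock


async def demo() -> None:
    get_status = AsyncMock()
    get_status.side_effect = [
        {"status": "running", "message": "Job is still running"},
        asyncio.CancelledError(),
    ]
    print(await get_status())  # first call returns the "running" status
    try:
        await get_status()  # second call raises, as configured
    except asyncio.CancelledError:
        print("CancelledError raised on the second call")


asyncio.run(demo())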
