Skip to content

Commit

Permalink
Add logic for on_kill for operator
Browse files Browse the repository at this point in the history
  • Loading branch information
sunank200 committed Apr 12, 2024
1 parent 4a3caa2 commit a5cd68d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
25 changes: 25 additions & 0 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3389,6 +3389,31 @@ async def create_job_for_partition_get(
job_query_resp = await job_client.query(query_request, cast(Session, session))
return job_query_resp["jobReference"]["jobId"]

async def cancel_job(self, job_id: str, project_id: str | None, location: str | None) -> None:
    """
    Request cancellation of a BigQuery job from async code.

    Opens a client session, obtains a token via ``self.get_token`` and issues
    the cancel call through a ``Job`` object bound to that session. A failure
    of the cancel call is logged and then re-raised to the caller.

    :param job_id: ID of the job to cancel.
    :param project_id: Google Cloud Project where the job was running.
    :param location: Location where the job was running.
    """
    async with ClientSession() as http_session:
        access_token = await self.get_token(session=http_session)
        bq_job = Job(job_id=job_id, project=project_id, location=location, token=access_token, session=http_session)  # type: ignore

        self.log.info(
            "Attempting to cancel BigQuery job: %s in project: %s, location: %s", job_id, project_id, location
        )
        try:
            await bq_job.cancel()
            self.log.info("Job %s cancellation requested.", job_id)
        except Exception as err:
            # Record the failure for whoever reads the logs, then let it propagate unchanged.
            self.log.error("Failed to cancel BigQuery job %s: %s", job_id, str(err))
            raise

def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
"""Convert a response from BigQuery to records.
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2903,6 +2903,7 @@ def execute(self, context: Any):
location=self.location or hook.location,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
cancel_on_kill=self.cancel_on_kill,
),
method_name="execute_complete",
)
Expand Down
11 changes: 11 additions & 0 deletions airflow/providers/google/cloud/triggers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
table_id: str | None = None,
poll_interval: float = 4.0,
impersonation_chain: str | Sequence[str] | None = None,
cancel_on_kill: bool = True,
):
super().__init__()
self.log.info("Using the connection %s .", conn_id)
Expand All @@ -69,6 +70,7 @@ def __init__(
self.table_id = table_id
self.poll_interval = poll_interval
self.impersonation_chain = impersonation_chain
self.cancel_on_kill = cancel_on_kill

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize BigQueryInsertJobTrigger arguments and classpath."""
Expand All @@ -83,6 +85,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
"cancel_on_kill": self.cancel_on_kill,
},
)

Expand Down Expand Up @@ -113,6 +116,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
self.poll_interval,
)
await asyncio.sleep(self.poll_interval)
except asyncio.CancelledError:
self.log.info("Task was killed.")
if self.job_id and self.cancel_on_kill:
await hook.cancel_job( # type: ignore[union-attr]
job_id=self.job_id, project_id=self.project_id, location=self.location
)
else:
self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id)
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
Expand Down

0 comments on commit a5cd68d

Please sign in to comment.