Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 79 additions & 17 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,26 @@
# under the License.
from __future__ import annotations

import asyncio
import time
from pprint import pformat
from typing import TYPE_CHECKING, Any, Iterable

import botocore.exceptions
from asgiref.sync import sync_to_async

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils import trim_none_values

if TYPE_CHECKING:
from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa

FINISHED_STATE = "FINISHED"
FAILED_STATE = "FAILED"
ABORTED_STATE = "ABORTED"
# Terminal states in which the statement did not complete successfully.
FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
# States reported by the Redshift Data API while a statement is still in progress.
RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}


class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
"""
Expand Down Expand Up @@ -108,27 +118,33 @@ def execute_query(

return statement_id

def wait_for_results(self, statement_id, poll_interval):
def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
while True:
self.log.info("Polling statement %s", statement_id)
resp = self.conn.describe_statement(
Id=statement_id,
)
status = resp["Status"]
if status == "FINISHED":
num_rows = resp.get("ResultRows")
if num_rows is not None:
self.log.info("Processed %s rows", num_rows)
return status
elif status in ("FAILED", "ABORTED"):
raise ValueError(
f"Statement {statement_id!r} terminated with status {status}. "
f"Response details: {pformat(resp)}"
)
else:
self.log.info("Query %s", status)
is_finised = self.check_query_is_finised(statement_id)
if is_finised:
return FINISHED_STATE

time.sleep(poll_interval)

def check_query_is_finised(self, statement_id: str) -> bool:
    """Check whether the query identified by *statement_id* has finished.

    Returns True when the statement reached the FINISHED state, and False
    while it is still in progress (any non-terminal status).

    :param statement_id: the UUID of the statement to inspect
    :raises ValueError: if the statement terminated in a FAILED or ABORTED state
    """
    resp = self.conn.describe_statement(Id=statement_id)
    status = resp["Status"]
    if status == FINISHED_STATE:
        # ResultRows is only present for statements that produced rows.
        num_rows = resp.get("ResultRows")
        if num_rows is not None:
            self.log.info("Processed %s rows", num_rows)
        return True
    elif status in FAILURE_STATES:
        raise ValueError(
            f"Statement {statement_id!r} terminated with status {status}. "
            f"Response details: {pformat(resp)}"
        )

    self.log.info("Query %s", status)
    return False

def get_table_primary_key(
self,
table: str,
Expand Down Expand Up @@ -201,3 +217,49 @@ def get_table_primary_key(
break

return pk_columns or None

async def check_query_is_finised_async(
    self, statement_id: str, poll_interval: int = 10
) -> dict[str, str | list[str]]:
    """Async check that a statement has finished, polling until terminal.

    Makes an async connection to the Redshift Data API, polls until the
    statement leaves the running states, then maps the terminal status to a
    result dict with a ``status`` of ``"success"`` or ``"error"``.

    :param statement_id: the UUID of the statement
    :param poll_interval: how often in seconds to check the query status
    """
    try:
        client = await sync_to_async(self.get_conn)()
        while await self.is_still_running(statement_id):
            # BUG FIX: the original slept on ``self.poll_interval``, which is
            # not an attribute of this hook and raised AttributeError.
            await asyncio.sleep(poll_interval)

        resp = client.describe_statement(Id=statement_id)
        status = resp["Status"]
        if status == FINISHED_STATE:
            return {"status": "success", "statement_id": statement_id}
        elif status == FAILED_STATE:
            return {
                "status": "error",
                # BUG FIX: single quotes for the subscripts — nested double
                # quotes inside a double-quoted f-string are a SyntaxError
                # on Python < 3.12.
                "message": f"Error: {resp['QueryString']} query Failed due to, {resp['Error']}",
                "statement_id": statement_id,
                "type": status,
            }
        elif status == ABORTED_STATE:
            return {
                "status": "error",
                "message": "The query run was stopped by the user.",
                "statement_id": statement_id,
                "type": status,
            }
        # Defensive fallback: the original fell through and implicitly
        # returned None for any unrecognized terminal state.
        return {
            "status": "error",
            "message": f"Statement {statement_id!r} ended in unexpected state {status}",
            "statement_id": statement_id,
            "type": status,
        }
    except botocore.exceptions.ClientError as error:
        return {"status": "error", "message": str(error), "type": "ERROR"}

async def is_still_running(self, statement_id: str) -> bool:
    """Report whether the statement is still in an in-progress state.

    :param statement_id: the UUID of the statement
    """
    conn = await sync_to_async(self.get_conn)()
    status = conn.describe_statement(Id=statement_id)["Status"]
    return status in RUNNING_STATES
49 changes: 47 additions & 2 deletions airflow/providers/amazon/aws/operators/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger

if TYPE_CHECKING:
from mypy_boto3_redshift_data.type_defs import GetStatementResultResponseTypeDef
Expand Down Expand Up @@ -87,6 +90,7 @@ def __init__(
aws_conn_id: str = "aws_default",
region: str | None = None,
workgroup_name: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -121,6 +125,11 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
"""Execute a statement against Amazon Redshift."""
self.log.info("Executing statement: %s", self.sql)

# Set wait_for_completion to False so that it waits for the status in the deferred task.
wait_for_completion = self.wait_for_completion
if self.deferrable and self.wait_for_completion:
self.wait_for_completion = False

self.statement_id = self.hook.execute_query(
database=self.database,
sql=self.sql,
Expand All @@ -131,17 +140,53 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
secret_arn=self.secret_arn,
statement_name=self.statement_name,
with_event=self.with_event,
wait_for_completion=self.wait_for_completion,
wait_for_completion=wait_for_completion,
poll_interval=self.poll_interval,
)

if self.deferrable:
is_finished = self.hook.check_query_is_finised(self.statement_id)
if not is_finished:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftDataTrigger(
task_id=self.task_id,
poll_interval=self.poll_interval,
aws_conn_id=self.aws_conn_id,
region_name=self.region,
statement_id=self.statement_id,
),
method_name="execute_complete",
)

if self.return_sql_result:
result = self.hook.conn.get_statement_result(Id=self.statement_id)
self.log.debug("Statement result: %s", result)
return result
else:
return self.statement_id

def execute_complete(
    self, context: Context, event: dict[str, Any] | None = None
) -> GetStatementResultResponseTypeDef | str:
    """Resume execution after the trigger fires.

    :param context: the task execution context
    :param event: payload emitted by RedshiftDataTrigger
    :raises AirflowException: if no event was received or the trigger
        reported an error
    :return: the statement result when ``return_sql_result`` is True,
        otherwise the statement id
    """
    if event is None:
        err_msg = "Trigger error: event is None"
        self.log.info(err_msg)
        raise AirflowException(err_msg)

    if event["status"] == "error":
        # BUG FIX: nested double quotes inside a double-quoted f-string are
        # a SyntaxError on Python < 3.12; use single quotes for the subscript.
        msg = f"context: {context}, error message: {event['message']}"
        raise AirflowException(msg)
    elif event["status"] == "success":
        self.log.info("%s completed successfully.", self.task_id)

    if self.return_sql_result:
        result = self.hook.conn.get_statement_result(Id=self.statement_id)
        self.log.debug("Statement result: %s", result)
        return result
    else:
        return self.statement_id

def on_kill(self) -> None:
"""Cancel the submitted redshift query."""
if self.statement_id:
Expand Down
75 changes: 75 additions & 0 deletions airflow/providers/amazon/aws/triggers/redshift_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Any, AsyncIterator

# BUG FIX: import the hook from this repository's provider package — the
# original imported from the third-party ``astronomer.providers`` distribution,
# which is not a dependency of the Amazon provider; the async method used by
# the trigger lives on airflow's own RedshiftDataHook.
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class RedshiftDataTrigger(BaseTrigger):
    """
    RedshiftDataTrigger is fired as a deferred class with params to run the task in the triggerer.

    :param statement_id: the UUID of the statement
    :param task_id: task ID of the Dag
    :param poll_interval: polling period in seconds to check for the status
    :param aws_conn_id: AWS connection ID for redshift
    :param region: aws region to use
    """

    def __init__(
        self,
        statement_id: str,
        task_id: str,
        poll_interval: int,
        aws_conn_id: str = "aws_default",
        region: str | None = None,
    ):
        super().__init__()
        self.statement_id = statement_id
        self.task_id = task_id
        self.aws_conn_id = aws_conn_id
        self.poll_interval = poll_interval
        self.region = region

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes RedshiftDataTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
            {
                "statement_id": self.statement_id,
                "task_id": self.task_id,
                "aws_conn_id": self.aws_conn_id,
                "poll_interval": self.poll_interval,
                "region": self.region,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Poll the statement until it reaches a terminal state and yield one TriggerEvent."""
        hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
        try:
            # BUG FIX: forward the configured poll interval to the hook; the
            # original serialized ``poll_interval`` but never used it here,
            # silently falling back to the hook's default. Dead commented-out
            # hook construction removed.
            response = await hook.check_query_is_finised_async(
                self.statement_id, poll_interval=self.poll_interval
            )
            if not response:
                response = {"status": "error", "message": f"{self.task_id} failed"}
            yield TriggerEvent(response)
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})