Make Start and Stop SageMaker Pipelines operators deferrable (#32683)
vandonr-amz committed Jul 25, 2023
1 parent 9f3af9c commit 9570cb1
Showing 6 changed files with 269 additions and 31 deletions.
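The change in a nutshell: with deferrable=True, both operators release their worker slot while the pipeline runs and hand polling off to the triggerer via the new SageMakerPipelineTrigger. Below is a minimal sketch of how the operators might be wired into a DAG; the DAG id, pipeline name, execution ARN, and intervals are hypothetical, not from this commit.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerStartPipelineOperator,
    SageMakerStopPipelineOperator,
)

with DAG(
    dag_id="sagemaker_pipeline_deferrable_demo",  # hypothetical DAG id
    start_date=datetime(2023, 7, 1),
    schedule=None,
    catchup=False,
):
    # Starts the pipeline, then defers until it completes,
    # polling every 60s for at most 120 attempts.
    start = SageMakerStartPipelineOperator(
        task_id="start_pipeline",
        pipeline_name="my-pipeline",  # hypothetical pipeline name
        check_interval=60,
        waiter_max_attempts=120,
        deferrable=True,
    )

    # Requests a stop, then defers until the execution reaches "Stopped".
    stop = SageMakerStopPipelineOperator(
        task_id="stop_pipeline",
        # hypothetical execution ARN; in practice this comes from a prior start
        pipeline_exec_arn="arn:aws:sagemaker:us-east-1:111122223333:pipeline/my-pipeline/execution/abc123",
        deferrable=True,
    )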
38 changes: 24 additions & 14 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -23,14 +23,15 @@
import tarfile
import tempfile
import time
import warnings
from collections import Counter
from datetime import datetime
from functools import partial
from typing import Any, Callable, Generator, cast

from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -1061,7 +1062,7 @@ def start_pipeline(
display_name: str = "airflow-triggered-execution",
pipeline_params: dict | None = None,
wait_for_completion: bool = False,
check_interval: int = 30,
check_interval: int | None = None,
verbose: bool = True,
) -> str:
"""Start a new execution for a SageMaker pipeline.
@@ -1073,14 +1074,19 @@ def start_pipeline(
:param display_name: The name this pipeline execution will have in the UI. Doesn't need to be unique.
:param pipeline_params: Optional parameters for the pipeline.
All parameters supplied need to already be present in the pipeline definition.
:param wait_for_completion: If true, will only return once the pipeline is complete.
:param check_interval: How long to wait between checks for pipeline status when waiting for
completion.
:param verbose: Whether to print step details when waiting for completion.
Defaults to true; consider turning it off for pipelines that have thousands of steps.
:return: the ARN of the pipeline execution launched.
"""
if wait_for_completion or check_interval is not None:
warnings.warn(
"parameter `wait_for_completion` and `check_interval` are deprecated, "
"remove them and call check_status yourself if you want to wait for completion",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
if check_interval is None:
check_interval = 30

formatted_params = format_tags(pipeline_params, key_label="Name")

try:
@@ -1108,7 +1114,7 @@ def stop_pipeline(
self,
pipeline_exec_arn: str,
wait_for_completion: bool = False,
check_interval: int = 10,
check_interval: int | None = None,
verbose: bool = True,
fail_if_not_running: bool = False,
) -> str:
@@ -1119,12 +1125,6 @@
:param pipeline_exec_arn: Amazon Resource Name (ARN) of the pipeline execution.
It's the ARN of the pipeline itself followed by "/execution/" and an id.
:param wait_for_completion: Whether to wait for the pipeline to reach a final state.
(i.e. either 'Stopped' or 'Failed')
:param check_interval: How long to wait between checks for pipeline status when waiting for
completion.
:param verbose: Whether to print step details when waiting for completion.
Defaults to true; consider turning it off for pipelines that have thousands of steps.
:param fail_if_not_running: This method will raise an exception if the pipeline we're trying to stop
is not in an "Executing" state when the call is sent (which would mean that the pipeline is
already either stopping or stopped).
@@ -1133,6 +1133,16 @@
:return: Status of the pipeline execution after the operation.
One of 'Executing'|'Stopping'|'Stopped'|'Failed'|'Succeeded'.
"""
if wait_for_completion or check_interval is not None:
warnings.warn(
"parameter `wait_for_completion` and `check_interval` are deprecated, "
"remove them and call check_status yourself if you want to wait for completion",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
if check_interval is None:
check_interval = 10

retries = 2 # i.e. 3 calls max, 1 initial + 2 retries
while True:
try:
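As the deprecation message suggests, callers who still want blocking behaviour can poll with check_status themselves. A hedged sketch of that migration, using only the hook API visible in this diff; the connection id and pipeline name are hypothetical:

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")  # hypothetical connection id
arn = hook.start_pipeline(pipeline_name="my-pipeline")  # hypothetical name

# Mirrors what the deprecated wait_for_completion=True used to do internally:
# poll every 30s until the execution leaves its non-terminal states.
hook.check_status(
    arn,
    "PipelineExecutionStatus",
    lambda a: hook.describe_pipeline_exec(a, True),
    check_interval=30,
    non_terminal_states=hook.pipeline_non_terminal_states,
)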
82 changes: 75 additions & 7 deletions airflow/providers/amazon/aws/operators/sagemaker.py
@@ -30,7 +30,10 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
from airflow.providers.amazon.aws.triggers.sagemaker import (
SageMakerPipelineTrigger,
SageMakerTrigger,
)
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
from airflow.providers.amazon.aws.utils.tags import format_tags
@@ -998,8 +1001,10 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
All parameters supplied need to already be present in the pipeline definition.
:param wait_for_completion: If true, this operator will only complete once the pipeline is complete.
:param check_interval: How long to wait between checks for pipeline status when waiting for completion.
:param waiter_max_attempts: How many times to check the status before failing.
:param verbose: Whether to print step details when waiting for completion.
Defaults to true; consider turning it off for pipelines that have thousands of steps.
:param deferrable: Run operator in the deferrable mode.
:return str: Returns the ARN of the pipeline execution created in Amazon SageMaker.
"""
@@ -1015,7 +1020,9 @@ def __init__(
pipeline_params: dict | None = None,
wait_for_completion: bool = False,
check_interval: int = CHECK_INTERVAL_SECOND,
waiter_max_attempts: int = 9999,
verbose: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
@@ -1024,22 +1031,46 @@ def __init__(
self.pipeline_params = pipeline_params
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.waiter_max_attempts = waiter_max_attempts
self.verbose = verbose
self.deferrable = deferrable

def execute(self, context: Context) -> str:
arn = self.hook.start_pipeline(
pipeline_name=self.pipeline_name,
display_name=self.display_name,
pipeline_params=self.pipeline_params,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
verbose=self.verbose,
)
self.log.info(
"Starting a new execution for pipeline %s, running with ARN %s", self.pipeline_name, arn
)
if self.deferrable:
self.defer(
trigger=SageMakerPipelineTrigger(
waiter_type=SageMakerPipelineTrigger.Type.COMPLETE,
pipeline_execution_arn=arn,
waiter_delay=self.check_interval,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
)
elif self.wait_for_completion:
self.hook.check_status(
arn,
"PipelineExecutionStatus",
lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
self.check_interval,
non_terminal_states=self.hook.pipeline_non_terminal_states,
max_ingestion_time=self.waiter_max_attempts * self.check_interval,
)
return arn

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
if event is None or event["status"] != "success":
raise AirflowException(f"Failure during pipeline execution: {event}")
return event["value"]


class SageMakerStopPipelineOperator(SageMakerBaseOperator):
"""
@@ -1057,6 +1088,7 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
:param verbose: Whether to print step details when waiting for completion.
Defaults to true; consider turning it off for pipelines that have thousands of steps.
:param fail_if_not_running: Raises an exception if the pipeline stopped or succeeded before this was run.
:param deferrable: Run operator in the deferrable mode.
:return str: Returns the status of the pipeline execution after the operation has been done.
"""
@@ -1073,32 +1105,68 @@ def __init__(
pipeline_exec_arn: str,
wait_for_completion: bool = False,
check_interval: int = CHECK_INTERVAL_SECOND,
waiter_max_attempts: int = 9999,
verbose: bool = True,
fail_if_not_running: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
self.pipeline_exec_arn = pipeline_exec_arn
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.waiter_max_attempts = waiter_max_attempts
self.verbose = verbose
self.fail_if_not_running = fail_if_not_running
self.deferrable = deferrable

def execute(self, context: Context) -> str:
status = self.hook.stop_pipeline(
pipeline_exec_arn=self.pipeline_exec_arn,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
verbose=self.verbose,
fail_if_not_running=self.fail_if_not_running,
)
self.log.info(
"Stop requested for pipeline execution with ARN %s. Status is now %s",
self.pipeline_exec_arn,
status,
)

if status not in self.hook.pipeline_non_terminal_states:
# pipeline already stopped
return status

# otherwise, wait for completion if requested
if self.deferrable:
self.defer(
trigger=SageMakerPipelineTrigger(
waiter_type=SageMakerPipelineTrigger.Type.STOPPED,
pipeline_execution_arn=self.pipeline_exec_arn,
waiter_delay=self.check_interval,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
)
elif self.wait_for_completion:
status = self.hook.check_status(
self.pipeline_exec_arn,
"PipelineExecutionStatus",
lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
self.check_interval,
non_terminal_states=self.hook.pipeline_non_terminal_states,
max_ingestion_time=self.waiter_max_attempts * self.check_interval,
)["PipelineExecutionStatus"]

return status

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
if event is None or event["status"] != "success":
raise AirflowException(f"Failure during pipeline execution: {event}")
else:
# theoretically we should do a `describe` call to know this,
# but if we reach this point, this is the only possible status
return "Stopped"


class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
"""
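For context on the pattern above: self.defer(...) raises TaskDeferred, the scheduler releases the worker slot, and the task resumes at the named method (execute_complete) once the trigger fires. A hedged sketch of how the deferral could be exercised in a test; the stubbed ARN, task id, and pipeline name are hypothetical:

from unittest import mock

import pytest

from airflow.exceptions import TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerStartPipelineOperator
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerPipelineTrigger


@mock.patch.object(SageMakerHook, "start_pipeline", return_value="arn:fake")  # hypothetical ARN
def test_start_pipeline_defers(start_pipeline_mock):
    op = SageMakerStartPipelineOperator(
        task_id="start_pipeline",
        pipeline_name="my-pipeline",  # hypothetical
        deferrable=True,
    )
    # In deferrable mode, execute() raises TaskDeferred instead of blocking.
    with pytest.raises(TaskDeferred) as deferred:
        op.execute(context={})
    assert isinstance(deferred.value.trigger, SageMakerPipelineTrigger)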
1 change: 1 addition & 0 deletions airflow/providers/amazon/aws/triggers/ecs.py
@@ -173,6 +173,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
)
# we reach this point only if the waiter met a success criteria
yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
return
except WaiterError as error:
if "terminal failure" in str(error):
raise
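The one-line ecs.py fix above matters because an async generator keeps running if it is iterated again after yielding: without the explicit return, control would fall back into the retry loop after the success event. A small standalone illustration of that behaviour (not provider code):

import asyncio

async def poll(stop_after_yield: bool):
    attempts = 0
    while attempts < 3:
        attempts += 1
        yield {"status": "success", "attempt": attempts}
        if stop_after_yield:
            return  # the fix: stop iterating once success is reported

async def main():
    # Without the return, the consumer sees three "success" events.
    print([e async for e in poll(stop_after_yield=False)])
    # With the return, exactly one event is emitted.
    print([e async for e in poll(stop_after_yield=True)])

asyncio.run(main())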
85 changes: 84 additions & 1 deletion airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -17,9 +17,15 @@

from __future__ import annotations

import asyncio
from collections import Counter
from enum import IntEnum
from functools import cached_property
from typing import Any
from typing import Any, AsyncIterator

from botocore.exceptions import WaiterError

from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -115,3 +121,80 @@ async def run(self):
status_args=[self._get_response_status_key(self.job_type)],
)
yield TriggerEvent({"status": "success", "message": "Job completed."})


class SageMakerPipelineTrigger(BaseTrigger):
"""Trigger to wait for a sagemaker pipeline execution to finish."""

class Type(IntEnum):
"""Type of waiter to use."""

COMPLETE = 1
STOPPED = 2

def __init__(
self,
waiter_type: Type,
pipeline_execution_arn: str,
waiter_delay: int,
waiter_max_attempts: int,
aws_conn_id: str,
):
self.waiter_type = waiter_type
self.pipeline_execution_arn = pipeline_execution_arn
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
self.__class__.__module__ + "." + self.__class__.__qualname__,
{
"waiter_type": self.waiter_type.value, # saving the int value here
"pipeline_execution_arn": self.pipeline_execution_arn,
"waiter_delay": self.waiter_delay,
"waiter_max_attempts": self.waiter_max_attempts,
"aws_conn_id": self.aws_conn_id,
},
)

_waiter_name = {
Type.COMPLETE: "PipelineExecutionComplete",
Type.STOPPED: "PipelineExecutionStopped",
}

async def run(self) -> AsyncIterator[TriggerEvent]:
attempts = 0
hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
async with hook.async_conn as conn:
waiter = hook.get_waiter(self._waiter_name[self.waiter_type], deferrable=True, client=conn)
while attempts < self.waiter_max_attempts:
attempts = attempts + 1
try:
await waiter.wait(
PipelineExecutionArn=self.pipeline_execution_arn, WaiterConfig={"MaxAttempts": 1}
)
# we reach this point only if the waiter met a success criteria
yield TriggerEvent({"status": "success", "value": self.pipeline_execution_arn})
return
except WaiterError as error:
if "terminal failure" in str(error):
raise

self.log.info(
"Status of the pipeline execution: %s", error.last_response["PipelineExecutionStatus"]
)

res = await conn.list_pipeline_execution_steps(
PipelineExecutionArn=self.pipeline_execution_arn
)
count_by_state = Counter(s["StepStatus"] for s in res["PipelineExecutionSteps"])
running_steps = [
s["StepName"] for s in res["PipelineExecutionSteps"] if s["StepStatus"] == "Executing"
]
self.log.info("State of the pipeline steps: %s", count_by_state)
self.log.info("Steps currently in progress: %s", running_steps)

await asyncio.sleep(int(self.waiter_delay))

raise AirflowException("Waiter error: max attempts reached")
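Like any Airflow trigger, the new class can also be driven directly with asyncio, which is handy for debugging the waiter loop outside a triggerer process. A hedged sketch; the execution ARN and connection id are hypothetical, and a real run needs valid AWS credentials:

import asyncio

from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerPipelineTrigger

trigger = SageMakerPipelineTrigger(
    waiter_type=SageMakerPipelineTrigger.Type.COMPLETE,
    pipeline_execution_arn=(
        "arn:aws:sagemaker:us-east-1:111122223333:pipeline/my-pipeline/execution/abc123"  # hypothetical
    ),
    waiter_delay=30,
    waiter_max_attempts=10,
    aws_conn_id="aws_default",  # hypothetical connection id
)

async def main():
    # Yields a single TriggerEvent, e.g. {'status': 'success', 'value': '<arn>'}.
    async for event in trigger.run():
        print(event.payload)

asyncio.run(main())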
