ingest_data - operator, trigger, sensor, and waiter (with unit tests)
ferruzzi committed Apr 29, 2024
1 parent b19e0ee commit 78df56c
Showing 8 changed files with 405 additions and 1 deletion.
73 changes: 73 additions & 0 deletions airflow/providers/amazon/aws/operators/bedrock.py
@@ -565,3 +565,76 @@ def execute(self, context: Context) -> str:
)

return create_ds_response["dataSource"]["dataSourceId"]


class BedrockIngestDataOperator(AwsBaseOperator[BedrockAgentHook]):
"""
Begin an ingestion job, in which an Amazon Bedrock data source is added to an Amazon Bedrock knowledge base.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BedrockIngestDataOperator`
:param knowledge_base_id: The unique identifier of the knowledge base to which to add the data source. (templated)
:param data_source_id: The unique identifier of the data source to ingest. (templated)
:param wait_for_completion: Whether to wait for the ingestion job to complete. (default: True)
:param waiter_delay: Time in seconds to wait between status checks. (default: 60)
:param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 10)
:param deferrable: If True, the operator will wait asynchronously for the ingestion job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

aws_hook_class = BedrockAgentHook

def __init__(
self,
knowledge_base_id: str,
data_source_id: str,
ingest_data_kwargs: dict[str, Any] | None = None,
wait_for_completion: bool = True,
waiter_delay: int = 60,
waiter_max_attempts: int = 10,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
self.knowledge_base_id = knowledge_base_id
self.data_source_id = data_source_id
self.ingest_data_kwargs = ingest_data_kwargs or {}

self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

template_fields: Sequence[str] = aws_template_fields(
"knowledge_base_id",
"data_source_id",
)

def execute(self, context: Context) -> str:
ingestion_job_id = self.hook.conn.start_ingestion_job(
knowledgeBaseId=self.knowledge_base_id, dataSourceId=self.data_source_id
)["ingestionJob"]["ingestionJobId"]

if self.wait_for_completion:
self.log.info("Waiting for ingestion job %s", ingestion_job_id)
self.hook.get_waiter(waiter_name="ingestion_job_complete").wait(
knowledgeBaseId=self.knowledge_base_id,
dataSourceId=self.data_source_id,
ingestionJobId=ingestion_job_id,
)

return ingestion_job_id
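
As a usage illustration (not part of this commit), here is a minimal DAG sketch that starts an ingestion job with the new operator. The DAG parameters and the Bedrock IDs are hypothetical placeholders; in practice they would typically come from upstream create-knowledge-base and create-data-source tasks.

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.bedrock import BedrockIngestDataOperator

# Minimal sketch; dag_id, start_date, and the Bedrock IDs below are hypothetical.
with DAG(dag_id="example_bedrock_ingest", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    ingest_data = BedrockIngestDataOperator(
        task_id="ingest_data",
        knowledge_base_id="KB12345678",  # placeholder knowledge base ID
        data_source_id="DS12345678",     # placeholder data source ID
        wait_for_completion=True,        # block until the waiter reports COMPLETE
    )
```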
80 changes: 80 additions & 0 deletions airflow/providers/amazon/aws/sensors/bedrock.py
@@ -26,6 +26,7 @@
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.bedrock import (
BedrockCustomizeModelCompletedTrigger,
BedrockIngestionJobTrigger,
BedrockKnowledgeBaseActiveTrigger,
BedrockProvisionModelThroughputCompletedTrigger,
)
@@ -328,3 +329,82 @@ def execute(self, context: Context) -> Any:
)
else:
super().execute(context=context)


class BedrockIngestionJobSensor(BedrockAgentBaseSensor):
"""
Poll the ingestion job status until it reaches a terminal state; fails if the job fails.
.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:BedrockIngestionJobSensor`
:param knowledge_base_id: The unique identifier of the knowledge base for which to get information. (templated)
:param data_source_id: The unique identifier of the data source in the ingestion job. (templated)
:param ingestion_job_id: The unique identifier of the ingestion job. (templated)
:param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
:param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
:param max_retries: Number of times before returning the current state (default: 10)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

INTERMEDIATE_STATES: tuple[str, ...] = ("STARTING", "IN_PROGRESS")
FAILURE_STATES: tuple[str, ...] = ("FAILED",)
SUCCESS_STATES: tuple[str, ...] = ("COMPLETE",)
FAILURE_MESSAGE = "Bedrock ingestion job sensor failed."

template_fields: Sequence[str] = aws_template_fields(
"knowledge_base_id", "data_source_id", "ingestion_job_id"
)

def __init__(
self,
*,
knowledge_base_id: str,
data_source_id: str,
ingestion_job_id: str,
poke_interval: int = 60,
max_retries: int = 10,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.poke_interval = poke_interval
self.max_retries = max_retries
self.knowledge_base_id = knowledge_base_id
self.data_source_id = data_source_id
self.ingestion_job_id = ingestion_job_id

def get_state(self) -> str:
return self.hook.conn.get_ingestion_job(
knowledgeBaseId=self.knowledge_base_id,
ingestionJobId=self.ingestion_job_id,
dataSourceId=self.data_source_id,
)["ingestionJob"]["status"]

def execute(self, context: Context) -> Any:
if self.deferrable:
self.defer(
trigger=BedrockIngestionJobTrigger(
knowledge_base_id=self.knowledge_base_id,
ingestion_job_id=self.ingestion_job_id,
data_source_id=self.data_source_id,
waiter_delay=int(self.poke_interval),
waiter_max_attempts=self.max_retries,
aws_conn_id=self.aws_conn_id,
),
method_name="poke",
)
else:
super().execute(context=context)
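
For illustration only (not part of this commit), the operator and sensor can be paired so the operator starts the job without waiting and the sensor, optionally deferrable, polls it to completion. The IDs and DAG parameters below are hypothetical placeholders.

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.bedrock import BedrockIngestDataOperator
from airflow.providers.amazon.aws.sensors.bedrock import BedrockIngestionJobSensor

# Minimal sketch; IDs and DAG parameters are hypothetical.
with DAG(dag_id="example_bedrock_ingest_sensor", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    ingest_data = BedrockIngestDataOperator(
        task_id="ingest_data",
        knowledge_base_id="KB12345678",
        data_source_id="DS12345678",
        wait_for_completion=False,  # let the sensor below do the waiting
    )

    # Passing the operator's XCom output to the templated ingestion_job_id field
    # also wires the task dependency automatically.
    await_ingest = BedrockIngestionJobSensor(
        task_id="await_ingest",
        knowledge_base_id="KB12345678",
        data_source_id="DS12345678",
        ingestion_job_id=ingest_data.output,
        deferrable=True,  # optional: poll from the triggerer instead of a worker slot
    )
```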
49 changes: 49 additions & 0 deletions airflow/providers/amazon/aws/triggers/bedrock.py
@@ -131,3 +131,52 @@ def __init__(

def hook(self) -> AwsGenericHook:
return BedrockHook(aws_conn_id=self.aws_conn_id)


class BedrockIngestionJobTrigger(AwsBaseWaiterTrigger):
"""
Trigger when a Bedrock ingestion job reaches the COMPLETE state.
:param knowledge_base_id: The unique identifier of the knowledge base for which to get information.
:param data_source_id: The unique identifier of the data source in the ingestion job.
:param ingestion_job_id: The unique identifier of the ingestion job.
:param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
:param waiter_max_attempts: The maximum number of attempts to be made. (default: 10)
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
*,
knowledge_base_id: str,
data_source_id: str,
ingestion_job_id: str,
waiter_delay: int = 60,
waiter_max_attempts: int = 10,
aws_conn_id: str | None = None,
) -> None:
super().__init__(
serialized_fields={
"knowledge_base_id": knowledge_base_id,
"data_source_id": data_source_id,
"ingestion_job_id": ingestion_job_id,
},
waiter_name="ingestion_job_complete",
waiter_args={
"knowledgeBaseId": knowledge_base_id,
"dataSourceId": data_source_id,
"ingestionJobId": ingestion_job_id,
},
failure_message="Bedrock ingestion job creation failed.",
status_message="Status of Bedrock ingestion job is",
status_queries=["ingestionJob.status"],
return_key="ingestion_job_id",
return_value=ingestion_job_id,
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)

def hook(self) -> AwsGenericHook:
return BedrockAgentHook(aws_conn_id=self.aws_conn_id)
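
To make the deferrable hand-off concrete, here is a small sketch (not part of this commit) of the serialization round trip the triggerer relies on: the trigger stores the fields listed in serialized_fields plus the common waiter settings, and the deferred sensor resumes via method_name="poke" once the trigger fires. The IDs are hypothetical placeholders.

```python
from airflow.providers.amazon.aws.triggers.bedrock import BedrockIngestionJobTrigger

# Minimal sketch; the IDs are hypothetical placeholders.
trigger = BedrockIngestionJobTrigger(
    knowledge_base_id="KB12345678",
    data_source_id="DS12345678",
    ingestion_job_id="IJ12345678",
    waiter_delay=30,
    waiter_max_attempts=20,
)

# BaseTrigger.serialize() returns the classpath and the kwargs needed to
# re-create the trigger inside the triggerer process.
classpath, kwargs = trigger.serialize()
assert kwargs["ingestion_job_id"] == "IJ12345678"
```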
31 changes: 31 additions & 0 deletions airflow/providers/amazon/aws/waiters/bedrock-agent.json
@@ -37,6 +37,37 @@
"state": "failure"
}
]
},
"ingestion_job_complete": {
"delay": 60,
"maxAttempts": 10,
"operation": "getIngestionJob",
"acceptors": [
{
"matcher": "path",
"argument": "ingestionJob.status",
"expected": "COMPLETE",
"state": "success"
},
{
"matcher": "path",
"argument": "ingestionJob.status",
"expected": "STARTING",
"state": "retry"
},
{
"matcher": "path",
"argument": "ingestionJob.status",
"expected": "IN_PROGRESS",
"state": "retry"
},
{
"matcher": "path",
"argument": "ingestionJob.status",
"expected": "FAILED",
"state": "failure"
}
]
}
}
}
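
The waiter defined above is looked up by name from both the synchronous path (the operator's hook.get_waiter call) and the asynchronous path (the trigger). As a hedged sketch (not part of this commit), it can also be driven directly through the agent hook; the IDs are placeholders, and WaiterConfig overrides the delay/maxAttempts defaults from the JSON definition.

```python
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook

# Minimal sketch; IDs are hypothetical placeholders and real AWS credentials
# are needed for the wait() call to succeed.
hook = BedrockAgentHook()
hook.get_waiter("ingestion_job_complete").wait(
    knowledgeBaseId="KB12345678",
    dataSourceId="DS12345678",
    ingestionJobId="IJ12345678",
    WaiterConfig={"Delay": 30, "MaxAttempts": 20},  # override the JSON defaults
)
```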
32 changes: 32 additions & 0 deletions tests/providers/amazon/aws/operators/test_bedrock.py
@@ -31,6 +31,7 @@
BedrockCreateKnowledgeBaseOperator,
BedrockCreateProvisionedModelThroughputOperator,
BedrockCustomizeModelOperator,
BedrockIngestDataOperator,
BedrockInvokeModelOperator,
)

@@ -314,3 +315,34 @@ def test_id_returned(self, mock_conn):
result = self.operator.execute({})

assert result == self.DATA_SOURCE_ID


class TestBedrockIngestDataOperator:
INGESTION_JOB_ID = "ingestion_job_id"

@pytest.fixture
def mock_conn(self) -> Generator[BaseAwsConnection, None, None]:
with mock.patch.object(BedrockAgentHook, "conn") as _conn:
_conn.start_ingestion_job.return_value = {
"ingestionJob": {"ingestionJobId": self.INGESTION_JOB_ID}
}
yield _conn

@pytest.fixture
def bedrock_hook(self) -> Generator[BedrockAgentHook, None, None]:
with mock_aws():
hook = BedrockAgentHook()
yield hook

def setup_method(self):
self.operator = BedrockIngestDataOperator(
task_id="create_data_source",
data_source_id="data_source_id",
knowledge_base_id="knowledge_base_id",
wait_for_completion=False,
)

def test_id_returned(self, mock_conn):
result = self.operator.execute({})

assert result == self.INGESTION_JOB_ID
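
A possible follow-up test, sketched here as an assumption rather than part of this commit, would cover the wait_for_completion=True path by also patching the hook's get_waiter and asserting the custom waiter is requested.

```python
from unittest import mock

from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook
from airflow.providers.amazon.aws.operators.bedrock import BedrockIngestDataOperator


# Hypothetical sketch: verify the operator asks for the custom waiter when
# wait_for_completion is left at its default of True.
@mock.patch.object(BedrockAgentHook, "get_waiter")
@mock.patch.object(BedrockAgentHook, "conn")
def test_waits_when_requested(mock_conn, mock_get_waiter):
    mock_conn.start_ingestion_job.return_value = {"ingestionJob": {"ingestionJobId": "ingestion_job_id"}}
    operator = BedrockIngestDataOperator(
        task_id="ingest_data",
        data_source_id="data_source_id",
        knowledge_base_id="knowledge_base_id",
        wait_for_completion=True,
    )

    assert operator.execute({}) == "ingestion_job_id"
    mock_get_waiter.assert_called_once_with(waiter_name="ingestion_job_complete")
```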
63 changes: 63 additions & 0 deletions tests/providers/amazon/aws/sensors/test_bedrock.py
@@ -25,6 +25,7 @@
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
from airflow.providers.amazon.aws.sensors.bedrock import (
BedrockCustomizeModelCompletedSensor,
BedrockIngestionJobSensor,
BedrockKnowledgeBaseActiveSensor,
BedrockProvisionModelThroughputCompletedSensor,
)
@@ -209,3 +210,65 @@ def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
sensor.poke({})


class TestBedrockIngestionJobSensor:
SENSOR = BedrockIngestionJobSensor

def setup_method(self):
self.default_op_kwargs = dict(
task_id="test_bedrock_knowledge_base_active_sensor",
knowledge_base_id="knowledge_base_id",
data_source_id="data_source_id",
ingestion_job_id="ingestion_job_id",
poke_interval=5,
max_retries=1,
)
self.sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)

def test_base_aws_op_attributes(self):
op = self.SENSOR(**self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

op = self.SENSOR(
**self.default_op_kwargs,
aws_conn_id="aws-test-custom-conn",
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)
assert op.hook.aws_conn_id == "aws-test-custom-conn"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

@pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES)
@mock.patch.object(BedrockAgentHook, "conn")
def test_poke_success_states(self, mock_conn, state):
mock_conn.get_ingestion_job.return_value = {"ingestionJob": {"status": state}}
assert self.sensor.poke({}) is True

@pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES)
@mock.patch.object(BedrockAgentHook, "conn")
def test_poke_intermediate_states(self, mock_conn, state):
mock_conn.get_ingestion_job.return_value = {"ingestionJob": {"status": state}}
assert self.sensor.poke({}) is False

@pytest.mark.parametrize(
"soft_fail, expected_exception",
[
pytest.param(False, AirflowException, id="not-soft-fail"),
pytest.param(True, AirflowSkipException, id="soft-fail"),
],
)
@pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
@mock.patch.object(BedrockAgentHook, "conn")
def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
mock_conn.get_ingestion_job.return_value = {"ingestionJob": {"status": state}}
sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
sensor.poke({})
