Skip to content

Commit

Permalink
Introduce Amazon Bedrock service (#38602)
Browse files Browse the repository at this point in the history
* Introduce Amazon Bedrock service
  • Loading branch information
ferruzzi committed Mar 30, 2024
1 parent 0f51347 commit 0723a8f
Show file tree
Hide file tree
Showing 8 changed files with 378 additions and 0 deletions.
39 changes: 39 additions & 0 deletions airflow/providers/amazon/aws/hooks/bedrock.py
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


class BedrockRuntimeHook(AwsBaseHook):
    """
    Interact with the Amazon Bedrock Runtime.

    Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock-runtime") <BedrockRuntime.Client>`.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    # Name of the boto3 client this hook is bound to.
    client_type = "bedrock-runtime"

    def __init__(self, *args, **kwargs) -> None:
        # Always pin the client type, overriding any value supplied by the caller.
        kwargs.update(client_type=self.client_type)
        super().__init__(*args, **kwargs)
93 changes: 93 additions & 0 deletions airflow/providers/amazon/aws/operators/bedrock.py
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
from airflow.utils.context import Context


class BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]):
    """
    Invoke the specified Bedrock model to run inference using the input provided.

    Use InvokeModel to run inference for text models, image models, and embedding models.
    To see the format and content of the input_data field for different models, refer to
    `Inference parameters docs <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`_.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BedrockInvokeModelOperator`

    :param model_id: The ID of the Bedrock model. (templated)
    :param input_data: Input data in the format specified in the content-type request header.
        It is serialized to JSON before being sent. (templated)
    :param content_type: The MIME type of the input data in the request. (templated) Default: application/json
    :param accept_type: The desired MIME type of the inference body in the response.
        (templated) Default: application/json

    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is ``None`` or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param verify: Whether or not to verify SSL certificates. See:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    """

    aws_hook_class = BedrockRuntimeHook
    template_fields: Sequence[str] = aws_template_fields(
        "model_id", "input_data", "content_type", "accept_type"
    )

    def __init__(
        self,
        model_id: str,
        input_data: dict[str, Any],
        content_type: str | None = None,
        accept_type: str | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_id = model_id
        self.input_data = input_data
        self.content_type = content_type
        self.accept_type = accept_type

    def execute(self, context: Context) -> dict[str, str | int]:
        """
        Invoke the model and return its response body decoded from JSON.

        :param context: Airflow task context; not used by this method.
        :return: The parsed JSON body of the ``invoke_model`` response.
        """
        # These are optional values which the API defaults to "application/json" if not provided here.
        invoke_kwargs = prune_dict({"contentType": self.content_type, "accept": self.accept_type})

        # Log the input before calling the service so it is still recorded if the call raises.
        self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data)

        response = self.hook.conn.invoke_model(
            body=json.dumps(self.input_data),
            modelId=self.model_id,
            **invoke_kwargs,
        )

        # The response body is a stream-like object; read it fully before decoding.
        response_body = json.loads(response["body"].read())
        self.log.info("Bedrock model response: %s", response_body)
        return response_body
12 changes: 12 additions & 0 deletions airflow/providers/amazon/provider.yaml
Expand Up @@ -142,6 +142,12 @@ integrations:
- /docs/apache-airflow-providers-amazon/operators/athena/athena_boto.rst
- /docs/apache-airflow-providers-amazon/operators/athena/athena_sql.rst
tags: [aws]
- integration-name: Amazon Bedrock
external-doc-url: https://aws.amazon.com/bedrock/
logo: /integration-logos/aws/Amazon-Bedrock_light-bg@4x.png
how-to-guide:
- /docs/apache-airflow-providers-amazon/operators/bedrock.rst
tags: [aws]
- integration-name: Amazon Chime
external-doc-url: https://aws.amazon.com/chime/
logo: /integration-logos/aws/Amazon-Chime-light-bg.png
Expand Down Expand Up @@ -363,6 +369,9 @@ operators:
- integration-name: AWS Batch
python-modules:
- airflow.providers.amazon.aws.operators.batch
- integration-name: Amazon Bedrock
python-modules:
- airflow.providers.amazon.aws.operators.bedrock
- integration-name: Amazon CloudFormation
python-modules:
- airflow.providers.amazon.aws.operators.cloud_formation
Expand Down Expand Up @@ -514,6 +523,9 @@ hooks:
python-modules:
- airflow.providers.amazon.aws.hooks.athena
- airflow.providers.amazon.aws.hooks.athena_sql
- integration-name: Amazon Bedrock
python-modules:
- airflow.providers.amazon.aws.hooks.bedrock
- integration-name: Amazon Chime
python-modules:
- airflow.providers.amazon.aws.hooks.chime
Expand Down
72 changes: 72 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/bedrock.rst
@@ -0,0 +1,72 @@
 .. Licensed to the Apache Software Foundation (ASF) under one
    or more contributor license agreements. See the NOTICE file
    distributed with this work for additional information
    regarding copyright ownership. The ASF licenses this file
    to you under the Apache License, Version 2.0 (the
    "License"); you may not use this file except in compliance
    with the License. You may obtain a copy of the License at

 ..   http://www.apache.org/licenses/LICENSE-2.0

 .. Unless required by applicable law or agreed to in writing,
    software distributed under the License is distributed on an
    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    KIND, either express or implied. See the License for the
    specific language governing permissions and limitations
    under the License.

==============
Amazon Bedrock
==============

`Amazon Bedrock <https://aws.amazon.com/bedrock/>`__ is a fully managed service that
offers a choice of high-performing foundation models (FMs) from leading AI companies
like AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI, and Amazon via a
single API, along with a broad set of capabilities you need to build generative AI
applications with security, privacy, and responsible AI.

Prerequisite Tasks
------------------

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

.. _howto/operator:BedrockInvokeModelOperator:

Invoke an existing Amazon Bedrock Model
=======================================

To invoke an existing Amazon Bedrock model, you can use
:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockInvokeModelOperator`.

Note that every model family has different input and output formats.
For example, to invoke a Meta Llama model you would use:

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_invoke_llama_model]
    :end-before: [END howto_operator_invoke_llama_model]

To invoke an Amazon Titan model you would use:

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_invoke_titan_model]
    :end-before: [END howto_operator_invoke_titan_model]

For details on the different formats, see `Inference parameters for foundation models <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`__.


Reference
---------

* `AWS boto3 library documentation for Amazon Bedrock <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock.html>`__
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 27 additions & 0 deletions tests/providers/amazon/aws/hooks/test_bedrock.py
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook


class TestBedrockRuntimeHook:
    """Sanity checks for :class:`BedrockRuntimeHook`."""

    def test_conn_returns_a_boto3_connection(self):
        """The hook's ``conn`` exists and targets the ``bedrock-runtime`` service."""
        client = BedrockRuntimeHook().conn

        assert client is not None
        assert client.meta.service_model.service_name == "bedrock-runtime"
59 changes: 59 additions & 0 deletions tests/providers/amazon/aws/operators/test_bedrock.py
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import json
from typing import Generator
from unittest import mock

import pytest
from moto import mock_aws

from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator

# Values shared by the operator tests in this module.
MODEL_ID = "meta.llama2-13b-chat-v1"
PROMPT = "A very important question."
GENERATED_RESPONSE = "An important answer."
# JSON-encoded body that the mocked ``invoke_model`` response stream returns.
# The token counts are simply the character lengths of the strings above.
MOCK_RESPONSE = json.dumps(
    dict(
        generation=GENERATED_RESPONSE,
        prompt_token_count=len(PROMPT),
        generation_token_count=len(GENERATED_RESPONSE),
        stop_reason="stop",
    )
)


@pytest.fixture
def runtime_hook() -> Generator[BedrockRuntimeHook, None, None]:
    """Yield a :class:`BedrockRuntimeHook` created while moto's ``mock_aws`` context is active."""
    with mock_aws():
        hook = BedrockRuntimeHook(aws_conn_id="aws_default")
        yield hook


class TestBedrockInvokeModelOperator:
    """Unit tests for :class:`BedrockInvokeModelOperator`."""

    @mock.patch.object(BedrockRuntimeHook, "conn")
    def test_invoke_model_prompt_good_combinations(self, mock_conn):
        """``execute`` forwards the JSON-serialized input to the model and returns the decoded body."""
        # The operator reads the payload via ``response["body"].read()``.
        mock_conn.invoke_model.return_value["body"].read.return_value = MOCK_RESPONSE
        input_data = {"input_data": {"prompt": PROMPT}}
        operator = BedrockInvokeModelOperator(task_id="test_task", model_id=MODEL_ID, input_data=input_data)

        response = operator.execute({})

        assert response["generation"] == GENERATED_RESPONSE
        # content_type and accept_type are left unset here; this assumes prune_dict drops
        # those None entries, so only the body and model ID are sent.
        mock_conn.invoke_model.assert_called_once_with(body=json.dumps(input_data), modelId=MODEL_ID)

0 comments on commit 0723a8f

Please sign in to comment.