Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce Amazon Bedrock service (#38602)
* Introduce Amazon Bedrock service
- Loading branch information
Showing
8 changed files
with
378 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook | ||
|
||
|
||
class BedrockRuntimeHook(AwsBaseHook):
    """
    Interact with the Amazon Bedrock Runtime.

    Provide a thin wrapper around
    :external+boto3:py:class:`boto3.client("bedrock-runtime") <BedrockRuntime.Client>`.

    Any additional arguments (for example ``aws_conn_id``) are forwarded
    unchanged to the underlying AwsBaseHook.

    .. seealso::
        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    client_type = "bedrock-runtime"

    def __init__(self, *args, **kwargs) -> None:
        # The client type is always forced to this hook's ``client_type``,
        # replacing any value the caller may have supplied.
        forwarded_kwargs = {**kwargs, "client_type": self.client_type}
        super().__init__(*args, **forwarded_kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
import json | ||
from typing import TYPE_CHECKING, Any, Sequence | ||
|
||
from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook | ||
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator | ||
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields | ||
from airflow.utils.helpers import prune_dict | ||
|
||
if TYPE_CHECKING: | ||
from airflow.utils.context import Context | ||
|
||
|
||
class BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]):
    """
    Invoke the specified Bedrock model to run inference using the input provided.

    Use InvokeModel to run inference for text models, image models, and embedding models.
    To see the format and content of the input_data field for different models, refer to
    `Inference parameters docs <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`_.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BedrockInvokeModelOperator`

    :param model_id: The ID of the Bedrock model. (templated)
    :param input_data: Input data in the format specified in the content-type request header. (templated)
    :param content_type: The MIME type of the input data in the request. (templated) Default: application/json
    :param accept_type: The desired MIME type of the inference body in the response.
        (templated) Default: application/json
    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is ``None`` or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param verify: Whether or not to verify SSL certificates. See:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    """

    aws_hook_class = BedrockRuntimeHook
    template_fields: Sequence[str] = aws_template_fields(
        "model_id", "input_data", "content_type", "accept_type"
    )

    def __init__(
        self,
        model_id: str,
        input_data: dict[str, Any],
        content_type: str | None = None,
        accept_type: str | None = None,
        **kwargs,
    ):
        """Store the invocation parameters; all remaining kwargs go to the base operator."""
        super().__init__(**kwargs)
        self.model_id = model_id
        self.input_data = input_data
        self.content_type = content_type
        self.accept_type = accept_type

    def execute(self, context: Context) -> dict[str, str | int]:
        """
        Call ``invoke_model`` on the hook's client and return the JSON-decoded response body.

        :param context: The Airflow task context (not used by this method).
        :return: The response body parsed from JSON.
        """
        # These are optional values which the API defaults to "application/json" if not provided here.
        invoke_kwargs = prune_dict({"contentType": self.content_type, "accept": self.accept_type})

        # The request body is sent as a JSON string built from ``input_data``.
        response = self.hook.conn.invoke_model(
            body=json.dumps(self.input_data),
            modelId=self.model_id,
            **invoke_kwargs,
        )

        # ``response["body"]` is a readable stream; read it once and decode the JSON payload.
        response_body = json.loads(response["body"].read())
        self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data)
        self.log.info("Bedrock model response: %s", response_body)
        return response_body
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
docs/apache-airflow-providers-amazon/operators/bedrock.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
.. Licensed to the Apache Software Foundation (ASF) under one | ||
or more contributor license agreements. See the NOTICE file | ||
distributed with this work for additional information | ||
regarding copyright ownership. The ASF licenses this file | ||
to you under the Apache License, Version 2.0 (the | ||
"License"); you may not use this file except in compliance | ||
with the License. You may obtain a copy of the License at | ||
.. http://www.apache.org/licenses/LICENSE-2.0 | ||
.. Unless required by applicable law or agreed to in writing, | ||
software distributed under the License is distributed on an | ||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations | ||
under the License. | ||
============== | ||
Amazon Bedrock | ||
============== | ||
|
||
`Amazon Bedrock <https://aws.amazon.com/bedrock/>`__ is a fully managed service that | ||
offers a choice of high-performing foundation models (FMs) from leading AI companies | ||
like AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI, and Amazon via a | ||
single API, along with a broad set of capabilities you need to build generative AI | ||
applications with security, privacy, and responsible AI. | ||
|
||
Prerequisite Tasks | ||
------------------ | ||
|
||
.. include:: ../_partials/prerequisite_tasks.rst | ||
|
||
Generic Parameters | ||
------------------ | ||
|
||
.. include:: ../_partials/generic_parameters.rst | ||
|
||
Operators | ||
--------- | ||
|
||
.. _howto/operator:BedrockInvokeModelOperator: | ||
|
||
Invoke an existing Amazon Bedrock Model | ||
======================================= | ||
|
||
To invoke an existing Amazon Bedrock model, you can use | ||
:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockInvokeModelOperator`. | ||
|
||
Note that every model family has different input and output formats. | ||
For example, to invoke a Meta Llama model you would use: | ||
|
||
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py | ||
:language: python | ||
:dedent: 4 | ||
:start-after: [START howto_operator_invoke_llama_model] | ||
:end-before: [END howto_operator_invoke_llama_model] | ||
|
||
To invoke an Amazon Titan model you would use: | ||
|
||
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py | ||
:language: python | ||
:dedent: 4 | ||
:start-after: [START howto_operator_invoke_titan_model] | ||
:end-before: [END howto_operator_invoke_titan_model] | ||
|
||
For details on the different formats, see `Inference parameters for foundation models <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`__. | ||
|
||
|
||
Reference | ||
--------- | ||
|
||
* `AWS boto3 library documentation for Amazon Bedrock <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock.html>`__ |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook | ||
|
||
|
||
class TestBedrockRuntimeHook:
    """Sanity checks for ``BedrockRuntimeHook``."""

    def test_conn_returns_a_boto3_connection(self):
        """The hook's ``conn`` exists and reports the ``bedrock-runtime`` service name."""
        bedrock_hook = BedrockRuntimeHook()

        assert bedrock_hook.conn is not None
        service_name = bedrock_hook.conn.meta.service_model.service_name
        assert service_name == "bedrock-runtime"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import json | ||
from typing import Generator | ||
from unittest import mock | ||
|
||
import pytest | ||
from moto import mock_aws | ||
|
||
from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook | ||
from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator | ||
|
||
# Shared constants for the operator tests in this module.
MODEL_ID = "meta.llama2-13b-chat-v1"
PROMPT = "A very important question."
GENERATED_RESPONSE = "An important answer."
# JSON-encoded fake response body; the token counts are simply the character
# lengths of the prompt and of the generated text.
MOCK_RESPONSE = json.dumps(
    dict(
        generation=GENERATED_RESPONSE,
        prompt_token_count=len(PROMPT),
        generation_token_count=len(GENERATED_RESPONSE),
        stop_reason="stop",
    )
)
|
||
|
||
@pytest.fixture
def runtime_hook() -> Generator[BedrockRuntimeHook, None, None]:
    """Yield a ``BedrockRuntimeHook`` created while the ``mock_aws`` context is active."""
    aws_mock = mock_aws()
    with aws_mock:
        yield BedrockRuntimeHook(aws_conn_id="aws_default")
|
||
|
||
class TestBedrockInvokeModelOperator:
    """Unit tests for ``BedrockInvokeModelOperator`` with the hook's client mocked out."""

    @mock.patch.object(BedrockRuntimeHook, "conn")
    def test_invoke_model_prompt_good_combinations(self, mock_conn):
        """``execute`` returns the decoded JSON body produced by the mocked ``invoke_model``."""
        # Make the mocked response body's ``read()`` return the canned JSON payload.
        mocked_body = mock_conn.invoke_model.return_value["body"]
        mocked_body.read.return_value = MOCK_RESPONSE
        payload = {"input_data": {"prompt": PROMPT}}
        operator = BedrockInvokeModelOperator(task_id="test_task", model_id=MODEL_ID, input_data=payload)

        result = operator.execute({})

        assert result["generation"] == GENERATED_RESPONSE
Oops, something went wrong.