Skip to content

Commit

Permalink
Add bedrock contract tests. (#227)
Browse files Browse the repository at this point in the history
The PR is a follow up on bedrock service support PR:

#209

We add contract tests for following Bedrock services that covers all
resource attributes we newly support:
1. Bedrock API: `GetGuardrail` 
2. BedrockAgent APIs: `GetAgent`, `GetDataSource`, `GetKnowledgeBase`
3. BedrockRuntime API: `InvokeModel`
4. BedrockAgentRuntime API: `InvokeAgent`

Upgrade `botocore` and `boto3` to the latest version `1.34.143` so that
to support Bedrock services API calls.
Upgrade `localstack/localstack` image to the latest version `3.5.0` so
resolve the SQS API call issue using `localstack/localstack:2.0.1` with
new version of `boto3`:
localstack/localstack#9610

**Contract test limitation:**
The contract tests in current repo is using
[LocalStackContainer](https://github.com/aws-observability/aws-otel-python-instrumentation/blob/912dd93ff19b7a4594bb9ed1a7d8cde4907a735d/contract-tests/tests/test/amazon/botocore/botocore_test.py#L68)
to serve AWS SDK service calls. But it doesn’t has bedrock related
service support (This is [the full service
list](https://docs.localstack.cloud/references/coverage/) it support.).
In this case, no matter which bedrock API we call in contract test, the
response will always be 4XX.

As a workaround, we inject `inject_200_success` and `inject_500_error`
directly into the API call to make sure we receive http response with
expected status code and attributes.

**`_assert_semantic_conventions_span_attributes` function change:**
In [_assert_semantic_conventions_span_attributes
function](https://github.com/aws-observability/aws-otel-python-instrumentation/blob/1753bbf2c3cd41778abc358a7f1e3199e48c7fa9/contract-tests/tests/test/amazon/botocore/botocore_test.py#L448)
it is checking the input `service` equals to `rpc.service`, however, we
pass the input service with
["remote_service"](https://github.com/aws-observability/aws-otel-python-instrumentation/blob/1753bbf2c3cd41778abc358a7f1e3199e48c7fa9/contract-tests/tests/test/amazon/botocore/botocore_test.py#L430),
where there is mismatch for example for Bedrock Agent Runtime: we have
`rpc_service="Bedrock Agent Runtime", remote_service="AWS::Bedrock". `
Thus we change to use `rpc_service` if it is provided by:
```
kwargs.get("rpc_service") if "rpc_service" in kwargs else kwargs.get("remote_service").split("::")[-1],
```



By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
  • Loading branch information
zzhlogin authored Aug 16, 2024
1 parent 1d0aedd commit 0d05ca7
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 4 deletions.
120 changes: 120 additions & 0 deletions contract-tests/images/applications/botocore/botocore_server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import atexit
import json
import os
import tempfile
from collections import namedtuple
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

Expand Down Expand Up @@ -41,6 +43,8 @@ def do_GET(self):
self._handle_sqs_request()
if self.in_path("kinesis"):
self._handle_kinesis_request()
if self.in_path("bedrock"):
self._handle_bedrock_request()

self._end_request(self.main_status)

Expand Down Expand Up @@ -203,6 +207,100 @@ def _handle_kinesis_request(self) -> None:
else:
set_main_status(404)

def _handle_bedrock_request(self) -> None:
# Localstack does not support Bedrock related services.
# we inject inject_200_success directly into the API call
# to make sure we receive http response with expected status code and attributes.
bedrock_client: BaseClient = boto3.client("bedrock", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
bedrock_agent_client: BaseClient = boto3.client(
"bedrock-agent", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION
)
bedrock_runtime_client: BaseClient = boto3.client(
"bedrock-runtime", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION
)
bedrock_agent_runtime_client: BaseClient = boto3.client(
"bedrock-agent-runtime", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION
)
if self.in_path("getknowledgebase/get_knowledge_base"):
set_main_status(200)
bedrock_agent_client.meta.events.register(
"before-call.bedrock-agent.GetKnowledgeBase",
inject_200_success,
)
bedrock_agent_client.get_knowledge_base(knowledgeBaseId="invalid-knowledge-base-id")
elif self.in_path("getdatasource/get_data_source"):
set_main_status(200)
bedrock_agent_client.meta.events.register(
"before-call.bedrock-agent.GetDataSource",
inject_200_success,
)
bedrock_agent_client.get_data_source(knowledgeBaseId="TESTKBSEID", dataSourceId="DATASURCID")
elif self.in_path("getagent/get-agent"):
set_main_status(200)
bedrock_agent_client.meta.events.register(
"before-call.bedrock-agent.GetAgent",
inject_200_success,
)
bedrock_agent_client.get_agent(agentId="TESTAGENTID")
elif self.in_path("getguardrail/get-guardrail"):
set_main_status(200)
bedrock_client.meta.events.register(
"before-call.bedrock.GetGuardrail",
lambda **kwargs: inject_200_success(guardrailId="bt4o77i015cu", **kwargs),
)
bedrock_client.get_guardrail(
guardrailIdentifier="arn:aws:bedrock:us-east-1:000000000000:guardrail/bt4o77i015cu"
)
elif self.in_path("invokeagent/invoke_agent"):
set_main_status(200)
bedrock_agent_runtime_client.meta.events.register(
"before-call.bedrock-agent-runtime.InvokeAgent",
inject_200_success,
)
bedrock_agent_runtime_client.invoke_agent(
agentId="Q08WFRPHVL",
agentAliasId="testAlias",
sessionId="testSessionId",
inputText="Invoke agent sample input text",
)
elif self.in_path("retrieve/retrieve"):
set_main_status(200)
bedrock_agent_runtime_client.meta.events.register(
"before-call.bedrock-agent-runtime.Retrieve",
inject_200_success,
)
bedrock_agent_runtime_client.retrieve(
knowledgeBaseId="test-knowledge-base-id",
retrievalQuery={
"text": "an example of retrieve query",
},
)
elif self.in_path("invokemodel/invoke-model"):
set_main_status(200)
bedrock_runtime_client.meta.events.register(
"before-call.bedrock-runtime.InvokeModel",
inject_200_success,
)
model_id = "amazon.titan-text-premier-v1:0"
user_message = "Describe the purpose of a 'hello world' program in one line."
prompt = f"<s>[INST] {user_message} [/INST]"
body = json.dumps(
{
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": 3072,
"stopSequences": [],
"temperature": 0.7,
"topP": 0.9,
},
}
)
accept = "application/json"
content_type = "application/json"
bedrock_runtime_client.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
else:
set_main_status(404)

def _end_request(self, status_code: int):
self.send_response_only(status_code)
self.end_headers()
Expand Down Expand Up @@ -251,6 +349,28 @@ def prepare_aws_server() -> None:
print("Unexpected exception occurred", exception)


def inject_200_success(**kwargs):
response_metadata = {
"HTTPStatusCode": 200,
"RequestId": "mock-request-id",
}

response_body = {
"Message": "Request succeeded",
"ResponseMetadata": response_metadata,
}

guardrail_id = kwargs.get("guardrailId")
if guardrail_id is not None:
response_body["guardrailId"] = guardrail_id

HTTPResponse = namedtuple("HTTPResponse", ["status_code", "headers", "body"])
headers = kwargs.get("headers", {})
body = kwargs.get("body", "")
http_response = HTTPResponse(200, headers=headers, body=body)
return http_response, response_body


def main() -> None:
prepare_aws_server()
server_address: Tuple[str, int] = ("0.0.0.0", _PORT)
Expand Down
4 changes: 2 additions & 2 deletions contract-tests/images/applications/botocore/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
opentelemetry-distro==0.46b0
opentelemetry-exporter-otlp-proto-grpc==1.25.0
typing-extensions==4.9.0
botocore==1.34.26
boto3==1.34.26
botocore==1.34.143
boto3==1.34.143
135 changes: 133 additions & 2 deletions contract-tests/tests/test/amazon/botocore/botocore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
_AWS_SQS_QUEUE_URL: str = "aws.sqs.queue.url"
_AWS_SQS_QUEUE_NAME: str = "aws.sqs.queue.name"
_AWS_KINESIS_STREAM_NAME: str = "aws.kinesis.stream.name"
_AWS_BEDROCK_AGENT_ID: str = "aws.bedrock.agent.id"
_AWS_BEDROCK_GUARDRAIL_ID: str = "aws.bedrock.guardrail.id"
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: str = "aws.bedrock.knowledge_base.id"
_AWS_BEDROCK_DATA_SOURCE_ID: str = "aws.bedrock.data_source.id"
_GEN_AI_REQUEST_MODEL: str = "gen_ai.request.model"


# pylint: disable=too-many-public-methods
Expand Down Expand Up @@ -66,7 +71,7 @@ def set_up_dependency_container(cls):
)
}
cls._local_stack: LocalStackContainer = (
LocalStackContainer(image="localstack/localstack:2.0.1")
LocalStackContainer(image="localstack/localstack:3.5.0")
.with_name("localstack")
.with_services("s3", "sqs", "dynamodb", "kinesis")
.with_env("DEFAULT_REGION", "us-west-2")
Expand Down Expand Up @@ -372,6 +377,132 @@ def test_kinesis_fault(self):
span_name="Kinesis.PutRecord",
)

def test_bedrock_runtime_invoke_model(self):
self.do_test_requests(
"bedrock/invokemodel/invoke-model",
"GET",
200,
0,
0,
rpc_service="Bedrock Runtime",
remote_service="AWS::BedrockRuntime",
remote_operation="InvokeModel",
remote_resource_type="AWS::Bedrock::Model",
remote_resource_identifier="amazon.titan-text-premier-v1:0",
request_specific_attributes={
_GEN_AI_REQUEST_MODEL: "amazon.titan-text-premier-v1:0",
},
span_name="Bedrock Runtime.InvokeModel",
)

def test_bedrock_get_guardrail(self):
self.do_test_requests(
"bedrock/getguardrail/get-guardrail",
"GET",
200,
0,
0,
rpc_service="Bedrock",
remote_service="AWS::Bedrock",
remote_operation="GetGuardrail",
remote_resource_type="AWS::Bedrock::Guardrail",
remote_resource_identifier="bt4o77i015cu",
request_specific_attributes={
_AWS_BEDROCK_GUARDRAIL_ID: "bt4o77i015cu",
},
span_name="Bedrock.GetGuardrail",
)

def test_bedrock_agent_runtime_invoke_agent(self):
self.do_test_requests(
"bedrock/invokeagent/invoke_agent",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent Runtime",
remote_service="AWS::Bedrock",
remote_operation="InvokeAgent",
remote_resource_type="AWS::Bedrock::Agent",
remote_resource_identifier="Q08WFRPHVL",
request_specific_attributes={
_AWS_BEDROCK_AGENT_ID: "Q08WFRPHVL",
},
span_name="Bedrock Agent Runtime.InvokeAgent",
)

def test_bedrock_agent_runtime_retrieve(self):
self.do_test_requests(
"bedrock/retrieve/retrieve",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent Runtime",
remote_service="AWS::Bedrock",
remote_operation="Retrieve",
remote_resource_type="AWS::Bedrock::KnowledgeBase",
remote_resource_identifier="test-knowledge-base-id",
request_specific_attributes={
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "test-knowledge-base-id",
},
span_name="Bedrock Agent Runtime.Retrieve",
)

def test_bedrock_agent_get_agent(self):
self.do_test_requests(
"bedrock/getagent/get-agent",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent",
remote_service="AWS::Bedrock",
remote_operation="GetAgent",
remote_resource_type="AWS::Bedrock::Agent",
remote_resource_identifier="TESTAGENTID",
request_specific_attributes={
_AWS_BEDROCK_AGENT_ID: "TESTAGENTID",
},
span_name="Bedrock Agent.GetAgent",
)

def test_bedrock_agent_get_knowledge_base(self):
self.do_test_requests(
"bedrock/getknowledgebase/get_knowledge_base",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent",
remote_service="AWS::Bedrock",
remote_operation="GetKnowledgeBase",
remote_resource_type="AWS::Bedrock::KnowledgeBase",
remote_resource_identifier="invalid-knowledge-base-id",
request_specific_attributes={
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "invalid-knowledge-base-id",
},
span_name="Bedrock Agent.GetKnowledgeBase",
)

def test_bedrock_agent_get_data_source(self):
self.do_test_requests(
"bedrock/getdatasource/get_data_source",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent",
remote_service="AWS::Bedrock",
remote_operation="GetDataSource",
remote_resource_type="AWS::Bedrock::DataSource",
remote_resource_identifier="DATASURCID",
request_specific_attributes={
_AWS_BEDROCK_DATA_SOURCE_ID: "DATASURCID",
},
span_name="Bedrock Agent.GetDataSource",
)

@override
def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None:
target_spans: List[Span] = []
Expand Down Expand Up @@ -427,7 +558,7 @@ def _assert_semantic_conventions_span_attributes(
self.assertEqual(target_spans[0].name, kwargs.get("span_name"))
self._assert_semantic_conventions_attributes(
target_spans[0].attributes,
kwargs.get("remote_service"),
kwargs.get("rpc_service") if "rpc_service" in kwargs else kwargs.get("remote_service").split("::")[-1],
kwargs.get("remote_operation"),
status_code,
kwargs.get("request_specific_attributes", {}),
Expand Down

0 comments on commit 0d05ca7

Please sign in to comment.