Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bedrock contract tests. #227

Merged
merged 9 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions contract-tests/images/applications/botocore/botocore_server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import atexit
import json
import os
import tempfile
from collections import namedtuple
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

Expand Down Expand Up @@ -41,6 +43,8 @@ def do_GET(self):
self._handle_sqs_request()
if self.in_path("kinesis"):
self._handle_kinesis_request()
if self.in_path("bedrock"):
self._handle_bedrock_request()

self._end_request(self.main_status)

Expand Down Expand Up @@ -203,6 +207,100 @@ def _handle_kinesis_request(self) -> None:
else:
set_main_status(404)

def _handle_bedrock_request(self) -> None:
# Localstack does not support Bedrock related services.
# we inject inject_200_success directly into the API call
# to make sure we receive http response with expected status code and attributes.
bedrock_client: BaseClient = boto3.client("bedrock", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
bedrock_agent_client: BaseClient = boto3.client(
"bedrock-agent", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION
)
bedrock_runtime_client: BaseClient = boto3.client(
"bedrock-runtime", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION
)
bedrock_agent_runtime_client: BaseClient = boto3.client(
"bedrock-agent-runtime", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION
)
if self.in_path("getknowledgebase/get_knowledge_base"):
set_main_status(200)
bedrock_agent_client.meta.events.register(
"before-call.bedrock-agent.GetKnowledgeBase",
inject_200_success,
)
bedrock_agent_client.get_knowledge_base(knowledgeBaseId="invalid-knowledge-base-id")
elif self.in_path("getdatasource/get_data_source"):
set_main_status(200)
bedrock_agent_client.meta.events.register(
"before-call.bedrock-agent.GetDataSource",
inject_200_success,
)
bedrock_agent_client.get_data_source(knowledgeBaseId="TESTKBSEID", dataSourceId="DATASURCID")
elif self.in_path("getagent/get-agent"):
set_main_status(200)
bedrock_agent_client.meta.events.register(
"before-call.bedrock-agent.GetAgent",
inject_200_success,
)
bedrock_agent_client.get_agent(agentId="TESTAGENTID")
elif self.in_path("getguardrail/get-guardrail"):
set_main_status(200)
bedrock_client.meta.events.register(
"before-call.bedrock.GetGuardrail",
lambda **kwargs: inject_200_success(guardrailId="bt4o77i015cu", **kwargs),
)
bedrock_client.get_guardrail(
guardrailIdentifier="arn:aws:bedrock:us-east-1:000000000000:guardrail/bt4o77i015cu"
)
elif self.in_path("invokeagent/invoke_agent"):
set_main_status(200)
bedrock_agent_runtime_client.meta.events.register(
"before-call.bedrock-agent-runtime.InvokeAgent",
inject_200_success,
)
bedrock_agent_runtime_client.invoke_agent(
agentId="Q08WFRPHVL",
agentAliasId="testAlias",
sessionId="testSessionId",
inputText="Invoke agent sample input text",
)
elif self.in_path("retrieve/retrieve"):
set_main_status(200)
bedrock_agent_runtime_client.meta.events.register(
"before-call.bedrock-agent-runtime.Retrieve",
inject_200_success,
)
bedrock_agent_runtime_client.retrieve(
knowledgeBaseId="test-knowledge-base-id",
retrievalQuery={
"text": "an example of retrieve query",
},
)
elif self.in_path("invokemodel/invoke-model"):
set_main_status(200)
bedrock_runtime_client.meta.events.register(
"before-call.bedrock-runtime.InvokeModel",
inject_200_success,
)
model_id = "amazon.titan-text-premier-v1:0"
user_message = "Describe the purpose of a 'hello world' program in one line."
prompt = f"<s>[INST] {user_message} [/INST]"
body = json.dumps(
{
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": 3072,
"stopSequences": [],
"temperature": 0.7,
"topP": 0.9,
},
}
)
accept = "application/json"
content_type = "application/json"
bedrock_runtime_client.invoke_model(body=body, modelId=model_id, accept=accept, contentType=content_type)
else:
set_main_status(404)

def _end_request(self, status_code: int):
self.send_response_only(status_code)
self.end_headers()
Expand Down Expand Up @@ -251,6 +349,28 @@ def prepare_aws_server() -> None:
print("Unexpected exception occurred", exception)


def inject_200_success(**kwargs):
response_metadata = {
"HTTPStatusCode": 200,
"RequestId": "mock-request-id",
}

response_body = {
"Message": "Request succeeded",
"ResponseMetadata": response_metadata,
}

guardrail_id = kwargs.get("guardrailId")
if guardrail_id is not None:
response_body["guardrailId"] = guardrail_id

HTTPResponse = namedtuple("HTTPResponse", ["status_code", "headers", "body"])
headers = kwargs.get("headers", {})
body = kwargs.get("body", "")
http_response = HTTPResponse(200, headers=headers, body=body)
return http_response, response_body


def main() -> None:
prepare_aws_server()
server_address: Tuple[str, int] = ("0.0.0.0", _PORT)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
opentelemetry-distro==0.46b0
opentelemetry-exporter-otlp-proto-grpc==1.25.0
typing-extensions==4.9.0
botocore==1.34.26
boto3==1.34.26
botocore==1.34.143
boto3==1.34.143
135 changes: 133 additions & 2 deletions contract-tests/tests/test/amazon/botocore/botocore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
_AWS_SQS_QUEUE_URL: str = "aws.sqs.queue.url"
_AWS_SQS_QUEUE_NAME: str = "aws.sqs.queue.name"
_AWS_KINESIS_STREAM_NAME: str = "aws.kinesis.stream.name"
_AWS_BEDROCK_AGENT_ID: str = "aws.bedrock.agent.id"
_AWS_BEDROCK_GUARDRAIL_ID: str = "aws.bedrock.guardrail.id"
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: str = "aws.bedrock.knowledge_base.id"
_AWS_BEDROCK_DATA_SOURCE_ID: str = "aws.bedrock.data_source.id"
_GEN_AI_REQUEST_MODEL: str = "gen_ai.request.model"


# pylint: disable=too-many-public-methods
Expand Down Expand Up @@ -66,7 +71,7 @@ def set_up_dependency_container(cls):
)
}
cls._local_stack: LocalStackContainer = (
LocalStackContainer(image="localstack/localstack:2.0.1")
LocalStackContainer(image="localstack/localstack:3.5.0")
.with_name("localstack")
.with_services("s3", "sqs", "dynamodb", "kinesis")
.with_env("DEFAULT_REGION", "us-west-2")
Expand Down Expand Up @@ -372,6 +377,132 @@ def test_kinesis_fault(self):
span_name="Kinesis.PutRecord",
)

def test_bedrock_runtime_invoke_model(self):
self.do_test_requests(
"bedrock/invokemodel/invoke-model",
"GET",
200,
0,
0,
rpc_service="Bedrock Runtime",
zzhlogin marked this conversation as resolved.
Show resolved Hide resolved
remote_service="AWS::BedrockRuntime",
remote_operation="InvokeModel",
remote_resource_type="AWS::Bedrock::Model",
remote_resource_identifier="amazon.titan-text-premier-v1:0",
request_specific_attributes={
_GEN_AI_REQUEST_MODEL: "amazon.titan-text-premier-v1:0",
},
span_name="Bedrock Runtime.InvokeModel",
)

def test_bedrock_get_guardrail(self):
self.do_test_requests(
"bedrock/getguardrail/get-guardrail",
"GET",
200,
0,
0,
rpc_service="Bedrock",
remote_service="AWS::Bedrock",
remote_operation="GetGuardrail",
remote_resource_type="AWS::Bedrock::Guardrail",
remote_resource_identifier="bt4o77i015cu",
request_specific_attributes={
_AWS_BEDROCK_GUARDRAIL_ID: "bt4o77i015cu",
},
span_name="Bedrock.GetGuardrail",
)

def test_bedrock_agent_runtime_invoke_agent(self):
self.do_test_requests(
"bedrock/invokeagent/invoke_agent",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent Runtime",
remote_service="AWS::Bedrock",
remote_operation="InvokeAgent",
remote_resource_type="AWS::Bedrock::Agent",
remote_resource_identifier="Q08WFRPHVL",
request_specific_attributes={
_AWS_BEDROCK_AGENT_ID: "Q08WFRPHVL",
},
span_name="Bedrock Agent Runtime.InvokeAgent",
)

def test_bedrock_agent_runtime_retrieve(self):
self.do_test_requests(
"bedrock/retrieve/retrieve",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent Runtime",
remote_service="AWS::Bedrock",
remote_operation="Retrieve",
remote_resource_type="AWS::Bedrock::KnowledgeBase",
remote_resource_identifier="test-knowledge-base-id",
request_specific_attributes={
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "test-knowledge-base-id",
},
span_name="Bedrock Agent Runtime.Retrieve",
)

def test_bedrock_agent_get_agent(self):
self.do_test_requests(
"bedrock/getagent/get-agent",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent",
remote_service="AWS::Bedrock",
remote_operation="GetAgent",
remote_resource_type="AWS::Bedrock::Agent",
remote_resource_identifier="TESTAGENTID",
request_specific_attributes={
_AWS_BEDROCK_AGENT_ID: "TESTAGENTID",
},
span_name="Bedrock Agent.GetAgent",
)

def test_bedrock_agent_get_knowledge_base(self):
self.do_test_requests(
"bedrock/getknowledgebase/get_knowledge_base",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent",
remote_service="AWS::Bedrock",
remote_operation="GetKnowledgeBase",
remote_resource_type="AWS::Bedrock::KnowledgeBase",
remote_resource_identifier="invalid-knowledge-base-id",
request_specific_attributes={
_AWS_BEDROCK_KNOWLEDGE_BASE_ID: "invalid-knowledge-base-id",
},
span_name="Bedrock Agent.GetKnowledgeBase",
)

def test_bedrock_agent_get_data_source(self):
self.do_test_requests(
"bedrock/getdatasource/get_data_source",
"GET",
200,
0,
0,
rpc_service="Bedrock Agent",
remote_service="AWS::Bedrock",
remote_operation="GetDataSource",
remote_resource_type="AWS::Bedrock::DataSource",
remote_resource_identifier="DATASURCID",
request_specific_attributes={
_AWS_BEDROCK_DATA_SOURCE_ID: "DATASURCID",
},
span_name="Bedrock Agent.GetDataSource",
)

@override
def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None:
target_spans: List[Span] = []
Expand Down Expand Up @@ -427,7 +558,7 @@ def _assert_semantic_conventions_span_attributes(
self.assertEqual(target_spans[0].name, kwargs.get("span_name"))
self._assert_semantic_conventions_attributes(
target_spans[0].attributes,
kwargs.get("remote_service"),
kwargs.get("rpc_service") if "rpc_service" in kwargs else kwargs.get("remote_service").split("::")[-1],
kwargs.get("remote_operation"),
status_code,
kwargs.get("request_specific_attributes", {}),
Expand Down
Loading