3 changes: 2 additions & 1 deletion pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
"ujson>=5.10.0",
"boto3==1.38.0",
"setuptools",
"Deprecated==1.2.18",
]

requires-python = ">=3.9"
@@ -47,7 +48,7 @@ dev = [
"qdrant-client",
"graphlit-client",
"python-dotenv",
"pinecone",
"pinecone>=3.1.0,<=6.0.2",
"langchain",
"langchain-community",
"langchain-openai",
@@ -5,6 +5,10 @@
"METHOD": "aws_bedrock.invoke_model",
"ENDPOINT": "/invoke-model",
},
"INVOKE_MODEL_WITH_RESPONSE_STREAM": {
"METHOD": "aws_bedrock.invoke_model_with_response_stream",
"ENDPOINT": "/invoke-model-with-response-stream",
},
"CONVERSE": {
"METHOD": AWSBedrockMethods.CONVERSE.value,
"ENDPOINT": "/converse",
93 changes: 75 additions & 18 deletions src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py
@@ -15,8 +15,10 @@
"""

import json
import io

from wrapt import ObjectProxy
from itertools import tee
from .stream_body_wrapper import BufferedStreamBody
from functools import wraps
from langtrace.trace_attributes import (
@@ -43,6 +45,7 @@
set_span_attributes,
set_usage_attributes,
)
from langtrace_python_sdk.utils import set_event_prompt


def converse_stream(original_method, version, tracer):
@@ -104,7 +107,7 @@ def traced_method(wrapped, instance, args, kwargs):
def patch_converse_stream(original_method, tracer, version):
def traced_method(*args, **kwargs):
modelId = kwargs.get("modelId")
(vendor, _) = modelId.split(".")
vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
input_content = [
{
"role": message.get("role", "user"),
@@ -128,7 +131,9 @@ def traced_method(*args, **kwargs):
response = original_method(*args, **kwargs)

if span.is_recording():
set_span_streaming_response(span, response)
stream1, stream2 = tee(response["stream"])
set_span_streaming_response(span, stream1)
response["stream"] = stream2
return response

return traced_method
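Note on the tee call above: set_span_streaming_response consumes the converse stream to record the response text, so the patch splits the stream and returns the untouched copy to the caller. A toy illustration of the pattern (the events are made up; note that tee buffers items in memory until both copies are drained):

    from itertools import tee

    events = iter([
        {"messageStart": {"role": "assistant"}},
        {"contentBlockDelta": {"delta": {"text": "Hi"}}},
        {"messageStop": {}},
    ])

    for_tracing, for_caller = tee(events)

    # The SDK drains one copy to build span attributes...
    drained = list(for_tracing)
    # ...while the caller still receives every event from the other copy.
    assert list(for_caller) == drained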
@@ -137,7 +142,7 @@ def traced_method(*args, **kwargs):
def patch_converse(original_method, tracer, version):
def traced_method(*args, **kwargs):
modelId = kwargs.get("modelId")
(vendor, _) = modelId.split(".")
vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
input_content = [
{
"role": message.get("role", "user"),
@@ -167,12 +172,29 @@ def traced_method(*args, **kwargs):
return traced_method


def parse_vendor_and_model_name_from_model_id(model_id):
if model_id.startswith("arn:aws:bedrock:"):
# This needs to be in one of the following forms:
# arn:aws:bedrock:region:account-id:foundation-model/vendor.model-name
# arn:aws:bedrock:region:account-id:custom-model/vendor.model-name/model-id
parts = model_id.split("/")
identifiers = parts[1].split(".")
return identifiers[0], identifiers[1]
parts = model_id.split(".")
if len(parts) == 1:
return parts[0], parts[0]
else:
return parts[-2], parts[-1]
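A quick sketch of how the new helper resolves the supported model-ID shapes (the model names and account ID below are placeholders, not values from this PR):

    # Plain "vendor.model" IDs keep the last two dot-separated parts
    assert parse_vendor_and_model_name_from_model_id(
        "anthropic.claude-3-sonnet-20240229-v1:0"
    ) == ("anthropic", "claude-3-sonnet-20240229-v1:0")

    # Cross-region inference profiles drop the leading region prefix
    assert parse_vendor_and_model_name_from_model_id(
        "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
    ) == ("anthropic", "claude-3-7-sonnet-20250219-v1:0")

    # Foundation-model ARNs parse the segment after the first slash
    assert parse_vendor_and_model_name_from_model_id(
        "arn:aws:bedrock:us-east-1:123456789012:foundation-model/amazon.titan-embed-text-v1"
    ) == ("amazon", "titan-embed-text-v1")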


def patch_invoke_model(original_method, tracer, version):
def traced_method(*args, **kwargs):
modelId = kwargs.get("modelId")
(vendor, _) = modelId.split(".")
vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
span_attributes = {
**get_langtrace_attributes(version, vendor, vendor_type="framework"),
SpanAttributes.LLM_PATH: APIS["INVOKE_MODEL"]["ENDPOINT"],
SpanAttributes.LLM_IS_STREAMING: False,
**get_extra_attributes(),
}
with tracer.start_as_current_span(
@@ -193,9 +215,11 @@ def patch_invoke_model_with_response_stream(original_method, tracer, version):
@wraps(original_method)
def traced_method(*args, **kwargs):
modelId = kwargs.get("modelId")
(vendor, _) = modelId.split(".")
vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
span_attributes = {
**get_langtrace_attributes(version, vendor, vendor_type="framework"),
SpanAttributes.LLM_PATH: APIS["INVOKE_MODEL_WITH_RESPONSE_STREAM"]["ENDPOINT"],
SpanAttributes.LLM_IS_STREAMING: True,
**get_extra_attributes(),
}
span = tracer.start_span(
@@ -217,7 +241,7 @@ def handle_streaming_call(span, kwargs, response):
def stream_finished(response_body):
request_body = json.loads(kwargs.get("body"))

(vendor, model) = kwargs.get("modelId").split(".")
vendor, model = parse_vendor_and_model_name_from_model_id(kwargs.get("modelId"))

set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, model)
set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model)
@@ -241,18 +265,22 @@ def stream_finished(response_body):

def handle_call(span, kwargs, response):
modelId = kwargs.get("modelId")
(vendor, model_name) = modelId.split(".")
vendor, model_name = parse_vendor_and_model_name_from_model_id(modelId)
read_response_body = response.get("body").read()
request_body = json.loads(kwargs.get("body"))
response_body = json.loads(read_response_body)
response["body"] = BufferedStreamBody(
response["body"]._raw_stream, response["body"]._content_length
io.BytesIO(read_response_body), len(read_response_body)
)
request_body = json.loads(kwargs.get("body"))
response_body = json.loads(response.get("body").read())

set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, modelId)
set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, modelId)

if vendor == "amazon":
set_amazon_attributes(span, request_body, response_body)
if model_name.startswith("titan-embed-text"):
set_amazon_embedding_attributes(span, request_body, response_body)
else:
set_amazon_attributes(span, request_body, response_body)

if vendor == "anthropic":
if "prompt" in request_body:
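The body re-wrap in handle_call above deals with a one-shot stream: botocore's StreamingBody can only be read once, so the patch drains it, inspects the JSON, and hands the caller a replayable buffer. A minimal sketch of the idea (OneShotStream is a stand-in for StreamingBody, not a real botocore class):

    import io
    import json

    class OneShotStream:
        """Stand-in for botocore's StreamingBody: readable only once."""

        def __init__(self, data: bytes):
            self._buf = io.BytesIO(data)

        def read(self) -> bytes:
            return self._buf.read()

    response = {"body": OneShotStream(b'{"results": []}')}

    # Drain the stream once so the instrumentation can inspect the payload...
    raw = response["body"].read()
    payload = json.loads(raw)

    # ...then replace the body so the caller can still read it.
    response["body"] = OneShotStream(raw)
    assert json.loads(response["body"].read()) == payload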
@@ -356,6 +384,27 @@ def set_amazon_attributes(span, request_body, response_body):
set_event_completion(span, completions)


def set_amazon_embedding_attributes(span, request_body, response_body):
input_text = request_body.get("inputText")
set_event_prompt(span, input_text)

embeddings = response_body.get("embedding", [])
input_tokens = response_body.get("inputTextTokenCount")
set_usage_attributes(
span,
{
"input_tokens": input_tokens,
"output": len(embeddings),
},
)
set_span_attribute(
span, SpanAttributes.LLM_REQUEST_MODEL, request_body.get("modelId")
)
set_span_attribute(
span, SpanAttributes.LLM_RESPONSE_MODEL, request_body.get("modelId")
)
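For reference, the Titan embedding payloads this function reads look roughly like this (values illustrative):

    # Request body sent to a model such as amazon.titan-embed-text-v1
    request_body = {"inputText": "hello world"}

    # Response body: the span records inputTextTokenCount as input tokens
    # and len(embedding) as the output figure
    response_body = {
        "embedding": [0.12, -0.03, 0.57],  # vector truncated for brevity
        "inputTextTokenCount": 2,
    }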


def set_anthropic_completions_attributes(span, request_body, response_body):
set_span_attribute(
span,
@@ -442,10 +491,10 @@ def _set_response_attributes(span, kwargs, result):
)


def set_span_streaming_response(span, response):
def set_span_streaming_response(span, response_stream):
streaming_response = ""
role = None
for event in response["stream"]:
for event in response_stream:
if "messageStart" in event:
role = event["messageStart"]["role"]
elif "contentBlockDelta" in event:
@@ -475,13 +524,15 @@ def __init__(
stream_done_callback=None,
):
super().__init__(response)

self._stream_done_callback = stream_done_callback
self._accumulating_body = {"generation": ""}
self.last_chunk = None

def __iter__(self):
for event in self.__wrapped__:
# Process the event
self._process_event(event)
# Yield the original event immediately
yield event

def _process_event(self, event):
@@ -496,7 +547,11 @@ def _process_event(self, event):
self._stream_done_callback(decoded_chunk)
return
if "generation" in decoded_chunk:
self._accumulating_body["generation"] += decoded_chunk.get("generation")
generation = decoded_chunk.get("generation")
if self.last_chunk == generation:
return
self.last_chunk = generation
self._accumulating_body["generation"] += generation

if type == "message_start":
self._accumulating_body = decoded_chunk.get("message")
@@ -505,9 +560,11 @@ def _process_event(self, event):
decoded_chunk.get("content_block")
)
elif type == "content_block_delta":
self._accumulating_body["content"][-1]["text"] += decoded_chunk.get(
"delta"
).get("text")
text = decoded_chunk.get("delta").get("text")
if self.last_chunk == text:
return
self.last_chunk = text
self._accumulating_body["content"][-1]["text"] += text

elif self.has_finished(type, decoded_chunk):
self._accumulating_body["invocation_metrics"] = decoded_chunk.get(
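The last_chunk guards added in the two hunks above skip an event whose text matches the immediately preceding one. The mechanism in miniature:

    deltas = ["Hel", "Hel", "lo"]  # second "Hel" simulates a duplicated event
    last_chunk, accumulated = None, ""
    for text in deltas:
        if text == last_chunk:
            continue  # drop consecutive duplicates, as the wrapper does
        last_chunk = text
        accumulated += text
    assert accumulated == "Hello"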
@@ -33,7 +33,7 @@ class PineconeInstrumentation(BaseInstrumentor):
The PineconeInstrumentation class represents the Pinecone instrumentation"""

def instrumentation_dependencies(self) -> Collection[str]:
return ["pinecone >= 3.1.0"]
return ["pinecone >= 3.1.0", "pinecone <= 6.0.2"]

def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")
2 changes: 1 addition & 1 deletion src/langtrace_python_sdk/version.py
@@ -1 +1 @@
__version__ = "3.8.18"
__version__ = "3.8.21"
45 changes: 45 additions & 0 deletions src/tests/aws_bedrock/cassettes/test_chat_completion.yaml
@@ -0,0 +1,45 @@
interactions:
- request:
body: '{"messages": [{"role": "user", "content": "Say this is a test three times"}],
"anthropic_version": "bedrock-2023-05-31", "max_tokens": 100}'
headers:
Accept:
- !!binary |
YXBwbGljYXRpb24vanNvbg==
Content-Length:
- '139'
Content-Type:
- !!binary |
YXBwbGljYXRpb24vanNvbg==
User-Agent:
- !!binary |
Qm90bzMvMS4zOC4xOCBtZC9Cb3RvY29yZSMxLjM4LjE4IHVhLzIuMSBvcy9tYWNvcyMyNC40LjAg
bWQvYXJjaCNhcm02NCBsYW5nL3B5dGhvbiMzLjEzLjEgbWQvcHlpbXBsI0NQeXRob24gbS9aLGIg
Y2ZnL3JldHJ5LW1vZGUjc3RhbmRhcmQgQm90b2NvcmUvMS4zOC4xOA==
method: POST
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-3-7-sonnet-20250219-v1%3A0/invoke
response:
body:
string: '{"id":"msg_bdrk_01NJB1bDTLkFh6pgfoAD5hkb","type":"message","role":"assistant","model":"claude-3-7-sonnet-20250219","content":[{"type":"text","text":"This
is a test.\nThis is a test.\nThis is a test."}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":14,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":20}}'
headers:
Connection:
- keep-alive
Content-Length:
- '355'
Content-Type:
- application/json
Date:
- Mon, 19 May 2025 16:42:05 GMT
X-Amzn-Bedrock-Input-Token-Count:
- '14'
X-Amzn-Bedrock-Invocation-Latency:
- '926'
X-Amzn-Bedrock-Output-Token-Count:
- '20'
x-amzn-RequestId:
- c0a92363-ec28-4a8b-9c09-571131d946b0
status:
code: 200
message: OK
version: 1
41 changes: 41 additions & 0 deletions src/tests/aws_bedrock/cassettes/test_generate_embedding.yaml

Large diffs are not rendered by default.

48 changes: 48 additions & 0 deletions src/tests/aws_bedrock/conftest.py
@@ -0,0 +1,48 @@
"""Unit tests configuration module."""

import pytest
import os

from boto3.session import Session
from botocore.config import Config

from langtrace_python_sdk.instrumentation.aws_bedrock.instrumentation import (
AWSBedrockInstrumentation,
)


@pytest.fixture(autouse=True)
def environment():
if not os.getenv("AWS_ACCESS_KEY_ID"):
os.environ["AWS_ACCESS_KEY_ID"] = "test_aws_access_key_id"
if not os.getenv("AWS_SECRET_ACCESS_KEY"):
os.environ["AWS_SECRET_ACCESS_KEY"] = "test_aws_secret_access_key"


@pytest.fixture
def aws_bedrock_client():
bedrock_config = Config(
region_name="us-east-1",
connect_timeout=300,
read_timeout=300,
retries={"total_max_attempts": 2, "mode": "standard"},
)
return Session().client("bedrock-runtime", config=bedrock_config)


@pytest.fixture(scope="module")
def vcr_config():
return {
"filter_headers": [
"authorization",
"X-Amz-Date",
"X-Amz-Security-Token",
"amz-sdk-invocation-id",
"amz-sdk-request",
]
}


@pytest.fixture(scope="session", autouse=True)
def instrument():
AWSBedrockInstrumentation().instrument()
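The tests themselves are not shown in this diff; a sketch of what a test consuming these fixtures could look like, replaying the chat-completion cassette recorded above (assumes a pytest-vcr/pytest-recording style marker, which is not confirmed by this PR):

    import json
    import pytest

    @pytest.mark.vcr
    def test_chat_completion(aws_bedrock_client):
        response = aws_bedrock_client.invoke_model(
            modelId="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
            body=json.dumps({
                "messages": [{"role": "user", "content": "Say this is a test three times"}],
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 100,
            }),
        )
        result = json.loads(response["body"].read())
        assert result["content"][0]["text"].startswith("This is a test.")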