Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
595 changes: 467 additions & 128 deletions engine/Cargo.lock

Large diffs are not rendered by default.

60 changes: 54 additions & 6 deletions engine/baml-lib/llm-client/src/clients/aws_bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ use crate::{
UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection,
};
use anyhow::Result;
use indexmap::IndexMap;
use secrecy::SecretString;

use baml_types::{ApiKeyWithProvenance, EvaluationContext, GetEnvVar, StringOr};
use baml_types::{ApiKeyWithProvenance, EvaluationContext, GetEnvVar, StringOr, UnresolvedValue};
use serde_json::Value;

use super::helpers::{Error, PropertyHandler};

#[derive(Debug, Clone)]
pub struct UnresolvedAwsBedrock {
pub struct UnresolvedAwsBedrock<Meta> {
model: Option<StringOr>,
region: Option<StringOr>,
access_key_id: Option<StringOr>,
Expand All @@ -24,6 +26,7 @@ pub struct UnresolvedAwsBedrock {
supported_request_modes: SupportedRequestModes,
inference_config: Option<UnresolvedInferenceConfiguration>,
finish_reason_filter: UnresolvedFinishReasonFilter,
additional_model_request_fields: Option<IndexMap<String, (Meta, UnresolvedValue<Meta>)>>,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -76,6 +79,7 @@ pub struct ResolvedAwsBedrock {
pub allowed_role_metadata: AllowedRoleMetadata,
pub supported_request_modes: SupportedRequestModes,
pub finish_reason_filter: FinishReasonFilter,
pub additional_model_request_fields: Option<IndexMap<String, Value>>,
}

impl std::fmt::Debug for ResolvedAwsBedrock {
Expand All @@ -92,6 +96,10 @@ impl std::fmt::Debug for ResolvedAwsBedrock {
.field("allowed_role_metadata", &self.allowed_role_metadata)
.field("supported_request_modes", &self.supported_request_modes)
.field("finish_reason_filter", &self.finish_reason_filter)
.field(
"additional_model_request_fields",
&self.additional_model_request_fields,
)
.finish()
}
}
Expand Down Expand Up @@ -122,7 +130,33 @@ impl ResolvedAwsBedrock {
}
}

impl UnresolvedAwsBedrock {
impl<Meta: Clone> UnresolvedAwsBedrock<Meta> {
pub fn without_meta(&self) -> UnresolvedAwsBedrock<()> {
UnresolvedAwsBedrock {
model: self.model.clone(),
region: self.region.clone(),
access_key_id: self.access_key_id.clone(),
secret_access_key: self.secret_access_key.clone(),
session_token: self.session_token.clone(),
profile: self.profile.clone(),
role_selection: self.role_selection.clone(),
allowed_role_metadata: self.allowed_role_metadata.clone(),
supported_request_modes: self.supported_request_modes.clone(),
inference_config: self.inference_config.clone(),
finish_reason_filter: self.finish_reason_filter.clone(),
additional_model_request_fields: self.additional_model_request_fields.as_ref().map(
|fields| {
fields
.iter()
.map(|(k, (_, v))| (k.clone(), ((), v.without_meta())))
.collect::<IndexMap<_, _>>()
},
),
}
}
}

impl<Meta: Clone> UnresolvedAwsBedrock<Meta> {
pub fn required_env_vars(&self) -> HashSet<String> {
let mut env_vars = HashSet::new();
if let Some(m) = self.model.as_ref() {
Expand Down Expand Up @@ -309,6 +343,17 @@ impl UnresolvedAwsBedrock {
}
}

let additional_model_request_fields = self
.additional_model_request_fields
.as_ref()
.map(|fields| {
fields
.iter()
.map(|(k, (_, v))| Ok((k.clone(), v.resolve_serde::<serde_json::Value>(ctx)?)))
.collect::<Result<IndexMap<_, _>>>()
})
.transpose()?;

Ok(ResolvedAwsBedrock {
model: model.resolve(ctx)?,
region,
Expand All @@ -325,12 +370,11 @@ impl UnresolvedAwsBedrock {
.map(|c| c.resolve(ctx))
.transpose()?,
finish_reason_filter: self.finish_reason_filter.resolve(ctx)?,
additional_model_request_fields,
})
}

pub fn create_from<Meta: Clone>(
mut properties: PropertyHandler<Meta>,
) -> Result<Self, Vec<Error<Meta>>> {
pub fn create_from(mut properties: PropertyHandler<Meta>) -> Result<Self, Vec<Error<Meta>>> {
let model = {
// Add AWS Bedrock-specific validation logic here
let model_id = properties.ensure_string("model_id", false);
Expand Down Expand Up @@ -374,6 +418,9 @@ impl UnresolvedAwsBedrock {
let role_selection = properties.ensure_roles_selection();
let allowed_metadata = properties.ensure_allowed_metadata();
let supported_request_modes = properties.ensure_supported_request_modes();
let additional_model_request_fields = properties
.ensure_map("additional_model_request_fields", false)
.map(|(_, map, _)| map);

let inference_config = {
let mut inference_config = UnresolvedInferenceConfiguration {
Expand Down Expand Up @@ -452,6 +499,7 @@ impl UnresolvedAwsBedrock {
supported_request_modes,
inference_config,
finish_reason_filter,
additional_model_request_fields,
})
}
}
4 changes: 2 additions & 2 deletions engine/baml-lib/llm-client/src/clients/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub mod vertex;
pub enum UnresolvedClientProperty<Meta> {
OpenAI(openai::UnresolvedOpenAI<Meta>),
Anthropic(anthropic::UnresolvedAnthropic<Meta>),
AWSBedrock(aws_bedrock::UnresolvedAwsBedrock),
AWSBedrock(aws_bedrock::UnresolvedAwsBedrock<Meta>),
Vertex(vertex::UnresolvedVertex<Meta>),
GoogleAI(google_ai::UnresolvedGoogleAI<Meta>),
RoundRobin(round_robin::UnresolvedRoundRobin<Meta>),
Expand Down Expand Up @@ -106,7 +106,7 @@ impl<Meta: Clone> UnresolvedClientProperty<Meta> {
UnresolvedClientProperty::Anthropic(a.without_meta())
}
UnresolvedClientProperty::AWSBedrock(a) => {
UnresolvedClientProperty::AWSBedrock(a.clone())
UnresolvedClientProperty::AWSBedrock(a.without_meta())
}
UnresolvedClientProperty::Vertex(v) => {
UnresolvedClientProperty::Vertex(v.without_meta())
Expand Down
8 changes: 4 additions & 4 deletions engine/baml-runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ tempfile = "3.19.0"


[target.'cfg(target_arch = "wasm32")'.dependencies]
aws-config = { version = "1.5.3", default-features = false, features = [] }
aws-sdk-bedrockruntime = { version = "1.37.0", default-features = false, features = [
aws-config = { version = "1.6.2", default-features = false, features = [] }
aws-sdk-bedrockruntime = { version = "1.85.0", default-features = false, features = [
] }
colored = { version = "2.1.0", default-features = false, features = [
"no-color",
Expand Down Expand Up @@ -145,8 +145,8 @@ web-sys = { version = "0.3.69", features = [
wasmtimer = "0.4.1"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
aws-config = "1.5.3"
aws-sdk-bedrockruntime = "1.37.0"
aws-config = "1.6.2"
aws-sdk-bedrockruntime = "1.85.0"
axum = "0.7.5"
axum-extra = { version = "0.9.3", features = ["erased-json", "typed-header"] }
criterion = "0.5.1"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use aws_config::Region;
use aws_config::{identity::IdentityCache, retry::RetryConfig, BehaviorVersion, ConfigLoader};
Expand All @@ -10,7 +11,7 @@ use aws_credential_types::{
},
Credentials,
};
use aws_sdk_bedrockruntime::config::Intercept;
use aws_sdk_bedrockruntime::config::{Intercept, StalledStreamProtectionConfig};
use aws_sdk_bedrockruntime::Client as BedrockRuntimeClient;
use aws_sdk_bedrockruntime::{self as bedrock, operation::converse::ConverseOutput};

Expand All @@ -19,6 +20,7 @@ use aws_smithy_json::serialize::JsonObjectWriter;
use aws_smithy_runtime_api::client::result::SdkError;
use aws_smithy_runtime_api::http::Headers;
use aws_smithy_types::Blob;
use aws_smithy_types::Document;
use baml_types::tracing::events::{
ContentId, FunctionId, HTTPBody, HTTPRequest, HTTPResponse, HttpRequestId, TraceData,
TraceEvent, TraceLevel,
Expand Down Expand Up @@ -92,6 +94,37 @@ fn resolve_properties(
Ok(props)
}

// Helper function to convert serde_json::Value to aws_smithy_types::Document
fn serde_json_to_aws_document(value: serde_json::Value) -> Document {
match value {
serde_json::Value::Null => Document::Null,
serde_json::Value::Bool(b) => Document::Bool(b),
serde_json::Value::Number(n) => {
if n.is_i64() {
Document::Number(aws_smithy_types::Number::NegInt(n.as_i64().unwrap()))
} else if n.is_u64() {
Document::Number(aws_smithy_types::Number::PosInt(n.as_u64().unwrap()))
} else {
// Fallback to f64
Document::Number(aws_smithy_types::Number::Float(
n.as_f64().unwrap_or(f64::NAN),
))
}
}
serde_json::Value::String(s) => Document::String(s),
serde_json::Value::Array(arr) => {
Document::Array(arr.into_iter().map(serde_json_to_aws_document).collect())
}
serde_json::Value::Object(map) => {
let converted_map: HashMap<String, Document> = map
.into_iter()
.map(|(k, v)| (k, serde_json_to_aws_document(v)))
.collect();
Document::Object(converted_map)
}
}
}

#[derive(Debug)]
struct CollectorInterceptor {
span_id: Option<Uuid>,
Expand Down Expand Up @@ -392,37 +425,57 @@ impl AwsClient {
let bedrock_config = aws_sdk_bedrockruntime::config::Builder::from(&config)
// To support HTTPS_PROXY https://github.com/awslabs/aws-sdk-rust/issues/169
.http_client(http_client)
// Adding a custom http client (above) breaks the stalled stream protection for some reason. If a bedrock request takes longer than 5s (the default grace period, it makes it error out), so we disable it.
.stalled_stream_protection(StalledStreamProtectionConfig::disabled())
.interceptor(CollectorInterceptor::new(span_id, http_request_id.clone()))
.build();
Ok(BedrockRuntimeClient::from_conf(bedrock_config))
}

async fn chat_anyhow<'r>(&self, response: &'r ConverseOutput) -> Result<&'r String> {
async fn chat_anyhow<'r>(&self, response: &'r ConverseOutput) -> Result<String> {
let Some(bedrock::types::ConverseOutput::Message(ref message)) = response.output else {
anyhow::bail!(
"Expected message output in response, but is type {}",
"unknown"
);
};
let content = message
.content
.first()
.context("Expected message output to have content")?;
let bedrock::types::ContentBlock::Text(ref content) = content else {
anyhow::bail!(
"Expected message output to be text, got {}",
match content {
bedrock::types::ContentBlock::Image(_) => "image",
bedrock::types::ContentBlock::GuardContent(_) => "guardContent",
bedrock::types::ContentBlock::ToolResult(_) => "toolResult",
bedrock::types::ContentBlock::ToolUse(_) => "toolUse",
bedrock::types::ContentBlock::Text(_) => "text",
_ => "unknown",
}
);
};
// Try to extract text from all content blocks
let mut extracted_text = String::new();
let mut has_text = false;

Ok(content)
if message.content.is_empty() {
anyhow::bail!("Expected message output to have content, but content is empty");
}

for content_block in &message.content {
if let bedrock::types::ContentBlock::Text(text) = content_block {
has_text = true;
extracted_text.push_str(text);
}
}

// If we found at least one text block, return the concatenated text
if has_text {
let content = extracted_text;
return Ok(content);
}

// If we didn't find any text blocks, return an error with details about the content
anyhow::bail!(
"Expected message output to contain at least one text block, but found none. Content: {:?}",
message.content.iter().map(|block| match block {
bedrock::types::ContentBlock::Image(_) => "image",
bedrock::types::ContentBlock::GuardContent(_) => "guardContent",
bedrock::types::ContentBlock::ToolResult(_) => "toolResult",
bedrock::types::ContentBlock::ToolUse(_) => "toolUse",
bedrock::types::ContentBlock::Text(_) => "text",
bedrock::types::ContentBlock::ReasoningContent(_) => "reasoningContent",
bedrock::types::ContentBlock::CachePoint(_) => "cachePoint",
bedrock::types::ContentBlock::Document(_) => "document",
bedrock::types::ContentBlock::Video(_) => "video",
_ => "unknown",
}).collect::<Vec<_>>()
);
}

fn build_request(
Expand Down Expand Up @@ -460,8 +513,23 @@ impl AwsClient {
.build()
});

let additional_fields_doc = self
.properties
.additional_model_request_fields
.as_ref()
.map(|map| {
// Convert IndexMap<String, serde_json::Value> to serde_json::Value::Object
let json_map: serde_json::Map<String, serde_json::Value> =
map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
let json_value = serde_json::Value::Object(json_map);
// Convert serde_json::Value to aws_smithy_types::Document
serde_json_to_aws_document(json_value)
})
.unwrap_or_else(|| Document::Object(HashMap::new())); // Default to empty object

bedrock::operation::converse::ConverseInput::builder()
.set_inference_config(inference_config)
.set_additional_model_request_fields(Some(additional_fields_doc))
.set_model_id(Some(self.properties.model.clone()))
.set_system(system_message)
.set_messages(Some(converse_messages))
Expand Down Expand Up @@ -596,12 +664,15 @@ impl WithStreamChat for AwsClient {
}
};

let additional_model_request_fields = request.additional_model_request_fields;

let request = aws_client
.converse_stream()
.set_model_id(request.model_id)
.set_inference_config(request.inference_config)
.set_system(request.system)
.set_messages(request.messages);
.set_messages(request.messages)
.set_additional_model_request_fields(additional_model_request_fields);

let system_start = SystemTime::now();
let instant_start = Instant::now();
Expand Down Expand Up @@ -894,6 +965,7 @@ impl WithChat for AwsClient {
let request = aws_client
.converse()
.set_model_id(request.model_id)
.set_additional_model_request_fields(request.additional_model_request_fields)
.set_inference_config(request.inference_config)
.set_system(request.system)
.set_messages(request.messages);
Expand Down
22 changes: 22 additions & 0 deletions integ-tests/baml_src/test-files/providers/aws.baml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ function TestAwsInferenceProfile(input: string) -> string {
"#
}

// slow on purpose to try and trigger the stalled stream protection (which should be disabled)
function TestAwsClaude37(input: string) -> string {
client AwsBedrockClaude37Client
prompt #"
Write 12 haikus. Number them.
"#
}

test TestName {
functions [TestAwsInferenceProfile]
args {
Expand All @@ -75,6 +83,20 @@ test TestName {
}
}

client<llm> AwsBedrockClaude37Client {
provider "aws-bedrock"
options {
model "arn:aws:bedrock:us-east-1:404337120808:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
additional_model_request_fields {
thinking {
type "enabled"
budget_tokens 1030
}
}
}
}



client<llm> AwsBedrockInferenceProfileClient {
provider "aws-bedrock"
Expand Down
Loading
Loading