Skip to content
Closed
6 changes: 5 additions & 1 deletion engine/baml-lib/baml-types/src/tracing/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,11 @@ impl HTTPBody {
}

pub fn text(&self) -> anyhow::Result<&str> {
std::str::from_utf8(&self.raw).map_err(|e| anyhow::anyhow!("HTTP body is not UTF-8: {}", e))
match self.raw.len() {
0 => Ok(""),
_ => std::str::from_utf8(&self.raw)
.map_err(|e| anyhow::anyhow!("HTTP body is not UTF-8: {}", e)),
}
}

pub fn json(&self) -> anyhow::Result<serde_json::Value> {
Expand Down
2 changes: 1 addition & 1 deletion engine/language_client_cffi/build.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{path::Path, process::Command};
use std::path::Path;

use cbindgen;
use flatc::flatc;
Expand Down
328 changes: 324 additions & 4 deletions engine/language_client_cffi/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
/// cbindgen:ignore
mod ctypes;

use std::{collections::HashMap, ffi::CStr, ptr::null, sync::Arc};
mod raw_ptr_wrapper;

use anyhow::Result;
use baml_runtime::client_registry::ClientRegistry;
use baml_runtime::tracingv2::storage::storage::{Collector, LLMCall, LLMStreamCall, Timing, Usage};
use baml_runtime::{BamlRuntime, FunctionResult};
use baml_types::tracing::events::HTTPBody;
use once_cell::sync::{Lazy, OnceCell};
use raw_ptr_wrapper::{CollectorWrapper, FunctionLogWrapper, HTTPResponseWrapper};
use std::ops::Deref;
use std::ptr::null_mut;
use std::{
collections::HashMap,
ffi::CStr,
ptr::null,
sync::{Arc, Mutex},
};

const VERSION: &str = env!("CARGO_PKG_VERSION");

Expand Down Expand Up @@ -163,8 +174,18 @@ pub extern "C" fn call_function_from_c(
encoded_args: *const libc::c_char,
length: usize,
id: u32,
collectors: *const libc::c_void,
collectors_length: usize,
) -> *const libc::c_void {
match call_function_from_c_inner(runtime, function_name, encoded_args, length, id) {
match call_function_from_c_inner(
runtime,
function_name,
encoded_args,
length,
id,
collectors,
collectors_length,
) {
Ok(_) => null(),
Err(e) => {
Box::into_raw(Box::new(CString::new(e.to_string()).unwrap())) as *const libc::c_void
Expand All @@ -178,6 +199,8 @@ fn call_function_from_c_inner(
encoded_args: *const libc::c_char,
length: usize,
id: u32,
collectors: *const libc::c_void,
collectors_length: usize,
) -> Result<()> {
// Safety: assume that the pointers provided are valid.
let runtime = unsafe { &*(runtime as *const BamlRuntime) };
Expand All @@ -195,6 +218,20 @@ fn call_function_from_c_inner(
let function_args = ctypes::buffer_to_cffi_function_arguments(buffer)?;
let env_vars = function_args.env_vars.clone();

// let runtime = unsafe { &*(runtime as *const BamlRuntime) };
let collector_ptrs = unsafe {
std::slice::from_raw_parts(collectors as *const *const libc::c_void, collectors_length)
};
let collectors = match collectors_length {
0 => None,
_ => Some(
collector_ptrs
.iter()
.map(|c| CollectorWrapper::from_raw(*c, true))
.collect::<Vec<_>>(),
),
};

let ctx = runtime.create_ctx_manager(BamlValue::String("cffi".to_string()), None);

// Spawn an async task to await the future and call the callback when done.
Expand All @@ -208,7 +245,7 @@ fn call_function_from_c_inner(
&ctx,
None,
function_args.client_registry.as_ref(),
None,
collectors.map(|c| c.iter().map(|c| c.deref().clone()).collect()),
env_vars,
)
.await;
Expand Down Expand Up @@ -267,7 +304,8 @@ fn call_function_stream_from_c_inner(
None,
None,
None,
env_vars){
env_vars,
) {
Ok(stream) => stream,
Err(e) => {
return Err(anyhow::anyhow!("Failed to stream function: {}", e));
Expand All @@ -289,3 +327,285 @@ fn call_function_stream_from_c_inner(
fn on_event(id: u32, result: FunctionResult) {
safe_trigger_callback(id, true, Ok(result));
}

#[no_mangle]
pub extern "C" fn call_collector_function(
object: *const libc::c_void,
object_type: *const c_char,
function_name: *const c_char,
) -> *const libc::c_void {
match call_collector_function_inner(object, object_type, function_name) {
Ok(result) => result,
Err(e) => {
Box::into_raw(Box::new(CString::new(e.to_string()).unwrap())) as *const libc::c_void
}
}
}

fn call_collector_function_inner(
object: *const libc::c_void,
object_type: *const c_char,
function_name: *const c_char,
) -> Result<*const libc::c_void> {
let object_type = match unsafe { CStr::from_ptr(object_type) }.to_str() {
Ok(s) => s.to_owned(),
Err(_) => {
return Err(anyhow::anyhow!("Failed to convert object type to string"));
}
};

let function_name = match unsafe { CStr::from_ptr(function_name) }.to_str() {
Ok(s) => s.to_owned(),
Err(_) => {
return Err(anyhow::anyhow!("Failed to convert function name to string"));
}
};

if object.is_null() {
return match (object_type.as_str(), function_name.as_str()) {
("collector", "new") => {
let collector = Collector::new(None);
Ok(Arc::into_raw(Arc::new(collector)) as *const libc::c_void)
}
_ => Err(anyhow::anyhow!(
"Failed to call collector function: {}",
function_name
)),
};
}

match object_type.as_str() {
"collector" => {
let collector = CollectorWrapper::from_raw(object, true);

match function_name.as_str() {
"destroy" => {
collector.destroy();
// collector goes out of scope here
Ok(null())
}
"usage" => {
let usage = collector.usage();
Ok(Box::into_raw(Box::new(usage)) as *const libc::c_void)
}
"last" => {
let last = collector.last_function_log();
let wrapper = Arc::new(Mutex::new(last));
Ok(Arc::into_raw(wrapper) as *const libc::c_void)
}
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
}
}
"usage" => {
let usage = unsafe { &mut *(object as *mut Usage) };

match function_name.as_str() {
"destroy" => {
let _ = unsafe { Box::from_raw(object as *mut Usage) };
Ok(null())
}
// pretend this is an integer not a pointer, which is dirty but works for now
"input_tokens" => Ok(usage.input_tokens.unwrap_or_default() as *mut libc::c_void),
"output_tokens" => Ok(usage.output_tokens.unwrap_or_default() as *mut libc::c_void),
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
}
}
"function_log" => {
let function_log = FunctionLogWrapper::from_raw(object, true);
match function_name.as_str() {
"id" => {
let id = function_log.lock().unwrap().id().to_string();
let c_id = CString::new(id).unwrap();
Ok(c_id.into_raw() as *const libc::c_void)
}
"function_name" => {
let function_name = function_log.lock().unwrap().function_name();
let c_function_name = CString::new(function_name).unwrap();
Ok(c_function_name.into_raw() as *const libc::c_void)
}
"log_type" => {
let log_type = function_log.lock().unwrap().log_type();
let c_log_type = CString::new(log_type.to_string()).unwrap();
Ok(c_log_type.into_raw() as *const libc::c_void)
}
"raw_llm_response" => {
let raw_llm_response = function_log.lock().unwrap().raw_llm_response();
let c_raw_llm_response =
CString::new(raw_llm_response.unwrap_or_default()).unwrap();
Ok(c_raw_llm_response.into_raw() as *const libc::c_void)
}
"calls" => {
let calls = function_log.lock().unwrap().calls();
let c_calls = calls
.iter()
.map(|c| match c {
baml_runtime::tracingv2::storage::storage::LLMCallKind::Basic(
inner,
) => Box::into_raw(Box::new(inner.clone())) as *mut libc::c_void,
baml_runtime::tracingv2::storage::storage::LLMCallKind::Stream(
inner,
) => Box::into_raw(Box::new(inner.clone())) as *mut libc::c_void,
})
.chain(std::iter::once(null_mut()))
.collect::<Vec<_>>();
let c_calls_ptr = c_calls.as_ptr() as *const libc::c_void;
// leak this so go can have it
std::mem::forget(c_calls);
Ok(c_calls_ptr)
}
"timing" => {
let timing = function_log.lock().unwrap().timing();
Ok(Box::into_raw(Box::new(timing)) as *const libc::c_void)
}
"usage" => {
let usage = function_log.lock().unwrap().usage();
Ok(Box::into_raw(Box::new(usage)) as *const libc::c_void)
}
"destroy" => {
function_log.destroy();
Ok(null())
}
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
}
}
"timing" => {
let timing = unsafe { &mut *(object as *mut Timing) };
match function_name.as_str() {
"destroy" => {
let _ = unsafe { Box::from_raw(object as *mut Timing) };
Ok(null())
}
"start_time_utc_ms" => Ok(timing.start_time_utc_ms as *mut libc::c_void),
"duration_ms" => Ok(timing
.duration_ms
.map(|d| d as *mut libc::c_void)
.unwrap_or(null_mut())),
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
}
}
"llm_call" => {
let llm_call = unsafe { &mut *(object as *mut LLMCall) };
match function_name.as_str() {
"client_name" => {
let c_client_name = CString::new(llm_call.client_name.clone()).unwrap();
Ok(c_client_name.into_raw() as *const libc::c_void)
}
"provider" => {
let c_provider = CString::new(llm_call.provider.clone()).unwrap();
Ok(c_provider.into_raw() as *const libc::c_void)
}
"timing" => {
let timing = llm_call.timing.clone();
Ok(Box::into_raw(Box::new(timing)) as *const libc::c_void)
}
"usage" => {
let usage = llm_call.usage.clone().unwrap();
Ok(Box::into_raw(Box::new(usage)) as *const libc::c_void)
}
"destroy" => {
let _ = unsafe { Box::from_raw(object as *mut LLMCall) };
Ok(null())
}
"selected" => {
let selected = if llm_call.selected { 1 } else { 0 };
Ok(selected as *const libc::c_void)
}
"http_response" => llm_call
.response
.clone()
.map(|r| Arc::into_raw(r) as *const libc::c_void)
.ok_or(anyhow::anyhow!("No response")),
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
}
}
"http_response" => {
let http_response = HTTPResponseWrapper::from_raw(object, true);
match function_name.as_str() {
"destroy" => {
http_response.destroy();
Ok(null())
}
"http_body" => {
let http_body = http_response.body.clone();
Ok(Box::into_raw(Box::new(http_body)) as *const libc::c_void)
}
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
}
}
"http_body" => {
let http_body = unsafe { &mut *(object as *mut HTTPBody) };
match function_name.as_str() {
"destroy" => {
let _ = unsafe { Box::from_raw(object as *mut HTTPBody) };
Ok(null())
}
"text" => {
let text = http_body.text().unwrap();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid using unwrap() on the result of http_body.text(). This FFI boundary should propagate errors instead of panicking on invalid UTF-8 data.

let c_text = CString::new(text).unwrap();
Ok(c_text.into_raw() as *const libc::c_void)
}
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
}
}
"string" => match function_name.as_str() {
"destroy" => {
let _ = unsafe { CString::from_raw(object as *mut c_char) };
Ok(null())
}
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
},
"list" => {
let ptrs = object as *mut *mut libc::c_void;
match function_name.as_str() {
"destroy" => {
unsafe {
drop(Box::from_raw(ptrs));
}
Ok(null())
}
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
}
}
_ => Err(anyhow::anyhow!(
"Failed to call function: {} on object type: {}",
function_name,
object_type
)),
}
}
Loading
Loading