diff --git a/async-openai/src/types/types.rs b/async-openai/src/types/types.rs
index b99c4134..03672a00 100644
--- a/async-openai/src/types/types.rs
+++ b/async-openai/src/types/types.rs
@@ -869,12 +869,17 @@ pub type ChatCompletionResponseStream =
     Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>>;
 
 // For reason (not documented by OpenAI) the response from stream is different
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+pub struct FunctionCallStream {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
+}
 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
 pub struct ChatCompletionStreamResponseDelta {
     pub role: Option<Role>,
     pub content: Option<String>,
-    pub function_call: Option<FunctionCall>,
+    pub function_call: Option<FunctionCallStream>,
 }
 
 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
diff --git a/examples/function-call-stream/src/main.rs b/examples/function-call-stream/src/main.rs
index ffe8cce6..0570e31e 100644
--- a/examples/function-call-stream/src/main.rs
+++ b/examples/function-call-stream/src/main.rs
@@ -12,6 +12,7 @@ use async_openai::{
 use futures::StreamExt;
 use serde_json::json;
+use async_openai::config::OpenAIConfig;
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn Error>> {
@@ -42,71 +43,95 @@ async fn main() -> Result<(), Box<dyn Error>> {
         .function_call("auto")
         .build()?;
 
-    // the first response from GPT is just the json response containing the function that was called
-    // and the model-generated arguments for that function (don't stream this)
-    let response = client
-        .chat()
-        .create(request)
-        .await?
-        .choices
-        .get(0)
-        .unwrap()
-        .message
-        .clone();
-
-    if let Some(function_call) = response.function_call {
-        let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
-            HashMap::new();
-        available_functions.insert("get_current_weather", get_current_weather);
-
-        let function_name = function_call.name;
-        let function_args: serde_json::Value = function_call.arguments.parse().unwrap();
-
-        let location = function_args["location"].as_str().unwrap();
-        let unit = "fahrenheit"; // why doesn't the model return a unit argument?
-        let function = available_functions.get(function_name.as_str()).unwrap();
-        let function_response = function(location, unit); // call the function
-
-        let message = vec![
-            ChatCompletionRequestMessageArgs::default()
-                .role(Role::User)
-                .content("What's the weather like in Boston?")
-                .build()?,
-            ChatCompletionRequestMessageArgs::default()
-                .role(Role::Function)
-                .content(function_response.to_string())
-                .name(function_name)
-                .build()?,
-        ];
-
-        let request = CreateChatCompletionRequestArgs::default()
-            .max_tokens(512u16)
-            .model("gpt-3.5-turbo-0613")
-            .messages(message)
-            .build()?;
-
-        // Now stream received response from model, which essentially formats the function response
-        let mut stream = client.chat().create_stream(request).await?;
-
-        let mut lock = stdout().lock();
-        while let Some(result) = stream.next().await {
-            match result {
-                Ok(response) => {
-                    response.choices.iter().for_each(|chat_choice| {
-                        if let Some(ref content) = chat_choice.delta.content {
-                            write!(lock, "{}", content).unwrap();
+    let mut stream = client.chat().create_stream(request).await?;
+
+    let mut fn_name = String::new();
+    let mut fn_args = String::new();
+
+    let mut lock = stdout().lock();
+    while let Some(result) = stream.next().await {
+        match result {
+            Ok(response) => {
+                for chat_choice in response.choices {
+                    if let Some(fn_call) = &chat_choice.delta.function_call {
+                        writeln!(lock, "function_call: {:?}", fn_call).unwrap();
+                        if let Some(name) = &fn_call.name {
+                            fn_name = name.clone();
                         }
-                    });
-                }
-                Err(err) => {
-                    writeln!(lock, "error: {err}").unwrap();
+                        if let Some(args) = &fn_call.arguments {
+                            fn_args.push_str(args);
+                        }
+                    }
+                    if let Some(finish_reason) = &chat_choice.finish_reason {
+                        if finish_reason == "function_call" {
+                            call_fn(&client, &fn_name, &fn_args).await?;
+                        }
+                    } else if let Some(content) = &chat_choice.delta.content {
+                        write!(lock, "{}", content).unwrap();
+                    }
+                }
             }
-            stdout().flush()?;
+            Err(err) => {
+                writeln!(lock, "error: {err}").unwrap();
+            }
         }
-        println!("{}", "\n");
+        stdout().flush()?;
     }
+
+    Ok(())
+}
+
+async fn call_fn(client: &Client<OpenAIConfig>, name: &str, args: &str) -> Result<(), Box<dyn Error>> {
+    let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
+        HashMap::new();
+    available_functions.insert("get_current_weather", get_current_weather);
+
+    let function_args: serde_json::Value = args.parse().unwrap();
+
+    let location = function_args["location"].as_str().unwrap();
+    let unit = function_args["unit"].as_str().unwrap_or("fahrenheit");
+    let function = available_functions.get(name).unwrap();
+    let function_response = function(location, unit); // call the function
+
+    let message = vec![
+        ChatCompletionRequestMessageArgs::default()
+            .role(Role::User)
+            .content("What's the weather like in Boston?")
+            .build()?,
+        ChatCompletionRequestMessageArgs::default()
+            .role(Role::Function)
+            .content(function_response.to_string())
+            .name(name)
+            .build()?,
+    ];
+
+    let request = CreateChatCompletionRequestArgs::default()
+        .max_tokens(512u16)
+        .model("gpt-3.5-turbo-0613")
+        .messages(message)
+        .build()?;
+
+    // Now stream the response from the model, which essentially formats the function response
+    let mut stream = client.chat().create_stream(request).await?;
+
+    let mut lock = stdout().lock();
+    while let Some(result) = stream.next().await {
+        match result {
+            Ok(response) => {
+                response.choices.iter().for_each(|chat_choice| {
+                    if let Some(ref content) = chat_choice.delta.content {
+                        write!(lock, "{}", content).unwrap();
+                    }
+                });
+            }
+            Err(err) => {
+                writeln!(lock, "error: {err}").unwrap();
+            }
+        }
+        stdout().flush()?;
+    }
+    println!("{}", "\n");
 
     Ok(())
 }