Made function call streamable (#98)
* made function call streamable

* Revert `FunctionCall` and Introduce `FunctionCallStream`
buraktabn committed Aug 17, 2023
1 parent 263eb70 commit ede4114
Showing 2 changed files with 90 additions and 60 deletions.
7 changes: 6 additions & 1 deletion async-openai/src/types/types.rs
@@ -869,12 +869,17 @@ pub type ChatCompletionResponseStream =
     Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>>;
 
+// For some reason (not documented by OpenAI) the function call arrives in a different shape when streamed
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+pub struct FunctionCallStream {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
+}
+
 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
 pub struct ChatCompletionStreamResponseDelta {
     pub role: Option<Role>,
     pub content: Option<String>,
-    pub function_call: Option<FunctionCall>,
+    pub function_call: Option<FunctionCallStream>,
 }
 
 #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
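Both fields of FunctionCallStream are Option because a streamed function call arrives piecemeal: the name usually lands whole in an early delta, while the arguments accumulate as string fragments across later deltas. A minimal sketch of that accumulation, with hypothetical delta payloads (the JSON below is illustrative, not captured API output; requires serde with the "derive" feature and serde_json):

use serde::Deserialize;

// Mirrors the struct added above; any given delta may carry only one field.
#[derive(Debug, Deserialize)]
struct FunctionCallStream {
    name: Option<String>,
    arguments: Option<String>,
}

fn main() {
    // Hypothetical deltas, in arrival order.
    let deltas = [
        r#"{"name": "get_current_weather", "arguments": ""}"#,
        r#"{"arguments": "{\"location\": \"Boston"}"#,
        r#"{"arguments": ", MA\"}"}"#,
    ];

    let (mut name, mut args) = (String::new(), String::new());
    for delta in deltas {
        let call: FunctionCallStream = serde_json::from_str(delta).unwrap();
        if let Some(n) = call.name {
            name = n; // the name arrives whole, once
        }
        if let Some(a) = call.arguments {
            args.push_str(&a); // the arguments arrive in fragments
        }
    }

    assert_eq!(name, "get_current_weather");
    assert_eq!(args, r#"{"location": "Boston, MA"}"#);
}

The same accumulate-then-parse pattern appears in the updated example below.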
143 changes: 84 additions & 59 deletions examples/function-call-stream/src/main.rs
@@ -12,6 +12,7 @@ use async_openai::{
 
 use futures::StreamExt;
 use serde_json::json;
+use async_openai::config::OpenAIConfig;
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn Error>> {
@@ -42,71 +43,95 @@ async fn main() -> Result<(), Box<dyn Error>> {
         .function_call("auto")
         .build()?;
 
-    // the first response from GPT is just the json response containing the function that was called
-    // and the model-generated arguments for that function (don't stream this)
-    let response = client
-        .chat()
-        .create(request)
-        .await?
-        .choices
-        .get(0)
-        .unwrap()
-        .message
-        .clone();
-
-    if let Some(function_call) = response.function_call {
-        let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
-            HashMap::new();
-        available_functions.insert("get_current_weather", get_current_weather);
-
-        let function_name = function_call.name;
-        let function_args: serde_json::Value = function_call.arguments.parse().unwrap();
-
-        let location = function_args["location"].as_str().unwrap();
-        let unit = "fahrenheit"; // why doesn't the model return a unit argument?
-        let function = available_functions.get(function_name.as_str()).unwrap();
-        let function_response = function(location, unit); // call the function
-
-        let message = vec![
-            ChatCompletionRequestMessageArgs::default()
-                .role(Role::User)
-                .content("What's the weather like in Boston?")
-                .build()?,
-            ChatCompletionRequestMessageArgs::default()
-                .role(Role::Function)
-                .content(function_response.to_string())
-                .name(function_name)
-                .build()?,
-        ];
-
-        let request = CreateChatCompletionRequestArgs::default()
-            .max_tokens(512u16)
-            .model("gpt-3.5-turbo-0613")
-            .messages(message)
-            .build()?;
-
-        // Now stream received response from model, which essentially formats the function response
-        let mut stream = client.chat().create_stream(request).await?;
-
-        let mut lock = stdout().lock();
-        while let Some(result) = stream.next().await {
-            match result {
-                Ok(response) => {
-                    response.choices.iter().for_each(|chat_choice| {
-                        if let Some(ref content) = chat_choice.delta.content {
-                            write!(lock, "{}", content).unwrap();
-                        }
-                    });
-                }
-                Err(err) => {
-                    writeln!(lock, "error: {err}").unwrap();
-                }
-            }
-            stdout().flush()?;
-        }
-        println!("{}", "\n");
-    }
+    // Stream the first response, accumulating the function name and arguments from the deltas
+    let mut stream = client.chat().create_stream(request).await?;
+
+    let mut fn_name = String::new();
+    let mut fn_args = String::new();
+
+    let mut lock = stdout().lock();
+    while let Some(result) = stream.next().await {
+        match result {
+            Ok(response) => {
+                for chat_choice in response.choices {
+                    if let Some(fn_call) = &chat_choice.delta.function_call {
+                        writeln!(lock, "function_call: {:?}", fn_call).unwrap();
+                        if let Some(name) = &fn_call.name {
+                            fn_name = name.clone();
+                        }
+                        if let Some(args) = &fn_call.arguments {
+                            fn_args.push_str(args);
+                        }
+                    }
+                    if let Some(finish_reason) = &chat_choice.finish_reason {
+                        if finish_reason == "function_call" {
+                            call_fn(&client, &fn_name, &fn_args).await?;
+                        }
+                    } else if let Some(content) = &chat_choice.delta.content {
+                        write!(lock, "{}", content).unwrap();
+                    }
+                }
+            }
+            Err(err) => {
+                writeln!(lock, "error: {err}").unwrap();
+            }
+        }
+        stdout().flush()?;
+    }
 
     Ok(())
 }
+
+async fn call_fn(
+    client: &Client<OpenAIConfig>,
+    name: &str,
+    args: &str,
+) -> Result<(), Box<dyn Error>> {
+    let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
+        HashMap::new();
+    available_functions.insert("get_current_weather", get_current_weather);
+
+    let function_args: serde_json::Value = args.parse().unwrap();
+
+    let location = function_args["location"].as_str().unwrap();
+    let unit = function_args["unit"].as_str().unwrap_or("fahrenheit");
+    let function = available_functions.get(name).unwrap();
+    let function_response = function(location, unit); // call the function
+
+    let message = vec![
+        ChatCompletionRequestMessageArgs::default()
+            .role(Role::User)
+            .content("What's the weather like in Boston?")
+            .build()?,
+        ChatCompletionRequestMessageArgs::default()
+            .role(Role::Function)
+            .content(function_response.to_string())
+            .name(name)
+            .build()?,
+    ];
+
+    let request = CreateChatCompletionRequestArgs::default()
+        .max_tokens(512u16)
+        .model("gpt-3.5-turbo-0613")
+        .messages(message)
+        .build()?;
+
+    // Now stream the model's response, which formats the function's result
+    let mut stream = client.chat().create_stream(request).await?;
+
+    let mut lock = stdout().lock();
+    while let Some(result) = stream.next().await {
+        match result {
+            Ok(response) => {
+                response.choices.iter().for_each(|chat_choice| {
+                    if let Some(ref content) = chat_choice.delta.content {
+                        write!(lock, "{}", content).unwrap();
+                    }
+                });
+            }
+            Err(err) => {
+                writeln!(lock, "error: {err}").unwrap();
+            }
+        }
+        stdout().flush()?;
+    }
+    println!("{}", "\n");
+    Ok(())
+}
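A note on the example above: fn_args is built with push_str and only parsed inside call_fn because individual argument fragments are generally not valid JSON on their own; parsing has to wait until the stream signals finish_reason == "function_call". A small self-contained sketch with made-up fragments (requires serde_json):

fn main() {
    // Hypothetical fragments; each one alone is invalid JSON.
    let fragments = [r#"{"location": "#, r#""Boston, MA", "#, r#""unit": "celsius"}"#];

    let mut fn_args = String::new();
    for fragment in fragments {
        fn_args.push_str(fragment); // mirrors the delta loop in main()
    }

    // Only the concatenation parses, as in call_fn.
    let parsed: serde_json::Value = fn_args.parse().unwrap();
    assert_eq!(parsed["location"], "Boston, MA");
    assert_eq!(parsed["unit"], "celsius");
}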

