diff --git a/README.md b/README.md
index 8b2e5ff..3a81639 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,28 @@ AIScript excels in these scenarios:
 
 Check out the [examples](./examples) directory for more sample code.
 
+## Supported AI Models
+
+AIScript supports the following AI models:
+
+- [x] OpenAI (uses the `OPENAI_API_KEY` environment variable by default)
+- [x] DeepSeek
+- [ ] Anthropic
+
+Configure them in `project.toml`:
+
+```toml
+# use OpenAI
+[ai.openai]
+api_key = "YOUR_API_KEY"
+model = "gpt-3.5-turbo"
+
+# or use DeepSeek
+[ai.deepseek]
+api_key = "YOUR_API_KEY"
+model = "deepseek-chat"
+```
+
 ## Roadmap
 
 See our [roadmap](https://aiscript.dev/guide/contribution/roadmap) for upcoming features and improvements.
diff --git a/aiscript-vm/src/ai/agent.rs b/aiscript-vm/src/ai/agent.rs
index 8b6baf8..90ca33d 100644
--- a/aiscript-vm/src/ai/agent.rs
+++ b/aiscript-vm/src/ai/agent.rs
@@ -8,7 +8,6 @@ use openai_api_rs::v1::{
         ChatCompletionMessage, ChatCompletionMessageForResponse, ChatCompletionRequest, Content,
         MessageRole, Tool, ToolCall, ToolChoiceType, ToolType,
     },
-    common::GPT3_5_TURBO,
     types::{self, FunctionParameters, JSONSchemaDefine},
 };
 use tokio::runtime::Handle;
@@ -278,6 +277,8 @@ pub async fn _run_agent<'gc>(
     mut agent: Gc<'gc, Agent<'gc>>,
     args: Vec<Value<'gc>>,
 ) -> Value<'gc> {
+    use super::default_model;
+
     let message = args[0];
     let debug = args[1].as_boolean();
     let mut history = Vec::new();
@@ -288,11 +289,11 @@ pub async fn _run_agent<'gc>(
         tool_calls: None,
         tool_call_id: None,
     });
-    let mut client = super::openai_client();
+    let mut client = super::openai_client(state.ai_config.as_ref());
     loop {
         let mut messages = vec![agent.get_instruction_message()];
         messages.extend(history.clone());
-        let mut req = ChatCompletionRequest::new(GPT3_5_TURBO.to_string(), messages);
+        let mut req = ChatCompletionRequest::new(default_model(state.ai_config.as_ref()), messages);
         let tools = agent.get_tools();
         if !tools.is_empty() {
             req = req
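For illustration, here is a minimal standalone sketch (not part of this patch) of how the `[ai.*]` tables in `project.toml` are expected to map onto the `AiConfig` enum introduced in the next file. The `ProjectConfig` wrapper is a hypothetical stand-in for whatever loader AIScript actually uses; the enum and struct are redeclared locally so the sketch compiles on its own with `serde` and the `toml` crate:

```rust
use serde::Deserialize;

// Local mirrors of the types this patch adds in aiscript-vm/src/ai/mod.rs,
// redeclared here so the sketch is self-contained.
#[derive(Debug, Deserialize)]
enum AiConfig {
    #[serde(rename = "openai")]
    OpenAI(ModelConfig),
    #[serde(rename = "anthropic")]
    Anthropic(ModelConfig),
    #[serde(rename = "deepseek")]
    DeepSeek(ModelConfig),
}

#[derive(Debug, Deserialize)]
struct ModelConfig {
    api_key: String,
    model: Option<String>,
}

// Hypothetical wrapper: because AiConfig is an externally tagged enum,
// the table name under `[ai.*]` selects the variant.
#[derive(Debug, Deserialize)]
struct ProjectConfig {
    ai: AiConfig,
}

fn main() {
    let cfg: ProjectConfig = toml::from_str(
        r#"
            [ai.deepseek]
            api_key = "YOUR_API_KEY"
            model = "deepseek-chat"
        "#,
    )
    .unwrap();
    // Prints: DeepSeek(ModelConfig { api_key: "YOUR_API_KEY", model: Some("deepseek-chat") })
    println!("{:?}", cfg.ai);
}
```

This also shows why the enum replaces the old struct of three `Option<ModelConfig>` fields: exactly one provider can be active at a time, and serde enforces that at deserialization.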
diff --git a/aiscript-vm/src/ai/mod.rs b/aiscript-vm/src/ai/mod.rs
index 787715a..ba82193 100644
--- a/aiscript-vm/src/ai/mod.rs
+++ b/aiscript-vm/src/ai/mod.rs
@@ -4,16 +4,22 @@ mod prompt;
 use std::env;
 
 pub use agent::{Agent, run_agent};
-use openai_api_rs::v1::api::OpenAIClient;
+use openai_api_rs::v1::{api::OpenAIClient, common::GPT3_5_TURBO};
 pub use prompt::{PromptConfig, prompt_with_config};
 use serde::Deserialize;
 
-#[derive(Debug, Clone, Deserialize, Default)]
-pub struct AiConfig {
-    pub openai: Option<ModelConfig>,
-    pub anthropic: Option<ModelConfig>,
-    pub deepseek: Option<ModelConfig>,
+const DEEPSEEK_API_ENDPOINT: &str = "https://api.deepseek.com/v1";
+const DEEPSEEK_V3: &str = "deepseek-chat";
+
+#[derive(Debug, Clone, Deserialize)]
+pub enum AiConfig {
+    #[serde(rename = "openai")]
+    OpenAI(ModelConfig),
+    #[serde(rename = "anthropic")]
+    Anthropic(ModelConfig),
+    #[serde(rename = "deepseek")]
+    DeepSeek(ModelConfig),
 }
 
 #[derive(Debug, Clone, Deserialize)]
@@ -22,10 +28,60 @@ pub struct ModelConfig {
     pub model: Option<String>,
 }
 
+impl AiConfig {
+    pub(crate) fn take_model(&mut self) -> Option<String> {
+        match self {
+            Self::OpenAI(ModelConfig { model, .. }) => model.take(),
+            Self::Anthropic(ModelConfig { model, .. }) => model.take(),
+            Self::DeepSeek(ModelConfig { model, .. }) => model.take(),
+        }
+    }
+
+    pub(crate) fn set_model(&mut self, m: String) {
+        match self {
+            Self::OpenAI(ModelConfig { model, .. }) => model.replace(m),
+            Self::Anthropic(ModelConfig { model, .. }) => model.replace(m),
+            Self::DeepSeek(ModelConfig { model, .. }) => model.replace(m),
+        };
+    }
+}
+
 #[allow(unused)]
-pub(crate) fn openai_client() -> OpenAIClient {
-    OpenAIClient::builder()
-        .with_api_key(env::var("OPENAI_API_KEY").unwrap().to_string())
-        .build()
-        .unwrap()
+pub(crate) fn openai_client(config: Option<&AiConfig>) -> OpenAIClient {
+    match config {
+        None => OpenAIClient::builder()
+            .with_api_key(env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
+            .build()
+            .unwrap(),
+        Some(AiConfig::OpenAI(model_config)) => {
+            let api_key = if model_config.api_key.is_empty() {
+                env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")
+            } else {
+                model_config.api_key.clone()
+            };
+            OpenAIClient::builder()
+                .with_api_key(api_key)
+                .build()
+                .unwrap()
+        }
+        Some(AiConfig::DeepSeek(ModelConfig { api_key, .. })) => OpenAIClient::builder()
+            .with_endpoint(DEEPSEEK_API_ENDPOINT)
+            .with_api_key(api_key)
+            .build()
+            .unwrap(),
+        Some(AiConfig::Anthropic(_)) => unimplemented!("Anthropic API not yet supported"),
+    }
+}
+
+pub(crate) fn default_model(config: Option<&AiConfig>) -> String {
+    match config {
+        None => GPT3_5_TURBO.to_string(),
+        Some(AiConfig::OpenAI(ModelConfig { model, .. })) => {
+            model.clone().unwrap_or(GPT3_5_TURBO.to_string())
+        }
+        Some(AiConfig::DeepSeek(ModelConfig { model, .. })) => {
+            model.clone().unwrap_or(DEEPSEEK_V3.to_string())
+        }
+        Some(AiConfig::Anthropic(_)) => unimplemented!("Anthropic API not yet supported"),
+    }
 }
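The two helpers above encode the provider fallback rules: no config means OpenAI with the env-var key and `GPT3_5_TURBO`, a config without an explicit `model` falls back to that provider's default, and an explicit `model` always wins. If a regression guard seems useful, something along these lines could sit at the bottom of `mod.rs` (a hypothetical test, not part of the patch; `default_model` is `pub(crate)`, so it must live in-crate):

```rust
#[cfg(test)]
mod tests {
    use super::{AiConfig, ModelConfig, default_model};
    use openai_api_rs::v1::common::GPT3_5_TURBO;

    #[test]
    fn default_model_falls_back_per_provider() {
        // No config at all: the OpenAI default model is used.
        assert_eq!(default_model(None), GPT3_5_TURBO.to_string());

        // DeepSeek config without an explicit model: falls back to DEEPSEEK_V3.
        let cfg = AiConfig::DeepSeek(ModelConfig {
            api_key: "YOUR_API_KEY".into(),
            model: None,
        });
        assert_eq!(default_model(Some(&cfg)), "deepseek-chat");

        // An explicit model always wins over the provider default.
        let cfg = AiConfig::OpenAI(ModelConfig {
            api_key: "YOUR_API_KEY".into(),
            model: Some("gpt-4".to_string()),
        });
        assert_eq!(default_model(Some(&cfg)), "gpt-4");
    }
}
```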
diff --git a/aiscript-vm/src/ai/prompt.rs b/aiscript-vm/src/ai/prompt.rs
index 9e9f919..68a5eb4 100644
--- a/aiscript-vm/src/ai/prompt.rs
+++ b/aiscript-vm/src/ai/prompt.rs
@@ -1,9 +1,11 @@
 use openai_api_rs::v1::common::GPT3_5_TURBO;
 use tokio::runtime::Handle;
 
+use super::{AiConfig, ModelConfig, default_model};
+
 pub struct PromptConfig {
     pub input: String,
-    pub model: Option<String>,
+    pub ai_config: Option<AiConfig>,
     pub max_tokens: Option<i64>,
     pub temperature: Option<f64>,
     pub system_prompt: Option<String>,
@@ -13,7 +15,10 @@ impl Default for PromptConfig {
     fn default() -> Self {
         Self {
             input: String::new(),
-            model: Some(GPT3_5_TURBO.to_string()),
+            ai_config: Some(AiConfig::OpenAI(ModelConfig {
+                api_key: Default::default(),
+                model: Some(GPT3_5_TURBO.to_string()),
+            })),
             max_tokens: Default::default(),
             temperature: Default::default(),
             system_prompt: Default::default(),
@@ -21,6 +26,21 @@ impl Default for PromptConfig {
     }
 }
 
+impl PromptConfig {
+    fn take_model(&mut self) -> String {
+        self.ai_config
+            .as_mut()
+            .and_then(|config| config.take_model())
+            .unwrap_or_else(|| default_model(self.ai_config.as_ref()))
+    }
+
+    pub(crate) fn set_model(&mut self, model: String) {
+        if let Some(config) = self.ai_config.as_mut() {
+            config.set_model(model);
+        }
+    }
+}
+
 #[cfg(feature = "ai_test")]
 async fn _prompt_with_config(config: PromptConfig) -> String {
     return format!("AI: {}", config.input);
@@ -28,12 +48,9 @@ async fn _prompt_with_config(config: PromptConfig) -> String {
 }
 
 #[cfg(not(feature = "ai_test"))]
 async fn _prompt_with_config(mut config: PromptConfig) -> String {
-    use openai_api_rs::v1::{
-        chat_completion::{self, ChatCompletionRequest},
-        common::GPT3_5_TURBO,
-    };
-
-    let mut client = super::openai_client();
+    use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
+    let mut client = super::openai_client(config.ai_config.as_ref());
+    let model = config.take_model();
 
     // Create system message if provided
     let mut messages = Vec::new();
@@ -57,13 +74,7 @@ async fn _prompt_with_config(mut config: PromptConfig) -> String {
     });
 
     // Build the request
-    let mut req = ChatCompletionRequest::new(
-        config
-            .model
-            .take()
-            .unwrap_or_else(|| GPT3_5_TURBO.to_string()),
-        messages,
-    );
+    let mut req = ChatCompletionRequest::new(model, messages);
 
     if let Some(max_tokens) = config.max_tokens {
         req.max_tokens = Some(max_tokens);
diff --git a/aiscript-vm/src/vm/state.rs b/aiscript-vm/src/vm/state.rs
index ad2f398..6f9205a 100644
--- a/aiscript-vm/src/vm/state.rs
+++ b/aiscript-vm/src/vm/state.rs
@@ -1013,15 +1013,21 @@ impl<'gc> State<'gc> {
         let result = match value {
             // Simple string case
             Value::String(s) => {
-                let input = s.to_str().unwrap().to_string();
-                ai::prompt_with_config(PromptConfig {
-                    input,
+                let mut config = PromptConfig {
+                    input: s.to_str().unwrap().to_string(),
                     ..Default::default()
-                })
+                };
+                if let Some(ai_cfg) = &self.ai_config {
+                    config.ai_config = Some(ai_cfg.clone());
+                }
+                ai::prompt_with_config(config)
             }
             // Object config case
             Value::Object(obj) => {
                 let mut config = PromptConfig::default();
+                if let Some(ai_cfg) = &self.ai_config {
+                    config.ai_config = Some(ai_cfg.clone());
+                }
                 let obj_ref = obj.borrow();
 
                 // Extract input (required)
@@ -1039,7 +1045,7 @@ impl<'gc> State<'gc> {
                     if let Some(Value::String(model)) =
                         obj_ref.fields.get(&self.intern(b"model"))
                     {
-                        config.set_model(model.to_str().unwrap().to_string());
+                        config.set_model(model.to_str().unwrap().to_string());
                     }
 
                     // Extract max_tokens (optional)
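Taken together, the `state.rs` changes give prompts a three-level model-resolution order: an explicit `model` field on the prompt object, then the model from `project.toml`, then the provider default. A condensed, self-contained mirror of that order, using local stand-in types rather than the real `PromptConfig`/`AiConfig` API:

```rust
// Local stand-in for the active provider config; not the real AiConfig.
enum AiConfig {
    OpenAI { model: Option<String> },
    DeepSeek { model: Option<String> },
}

fn resolve_model(project_config: Option<&AiConfig>, per_prompt_model: Option<&str>) -> String {
    // 1. An explicit per-prompt `model` field wins (PromptConfig::set_model).
    if let Some(m) = per_prompt_model {
        return m.to_string();
    }
    match project_config {
        // 2. Otherwise the model configured in project.toml (AiConfig::take_model),
        // 3. falling back to the provider default (default_model).
        Some(AiConfig::OpenAI { model }) => {
            model.clone().unwrap_or_else(|| "gpt-3.5-turbo".to_string())
        }
        Some(AiConfig::DeepSeek { model }) => {
            model.clone().unwrap_or_else(|| "deepseek-chat".to_string())
        }
        None => "gpt-3.5-turbo".to_string(),
    }
}

fn main() {
    let cfg = AiConfig::DeepSeek { model: None };
    assert_eq!(resolve_model(Some(&cfg), None), "deepseek-chat");
    assert_eq!(resolve_model(Some(&cfg), Some("deepseek-reasoner")), "deepseek-reasoner");
    assert_eq!(resolve_model(None, None), "gpt-3.5-turbo");
}
```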