From c14f1b65c6b3df248e38961409349ee22e1434dc Mon Sep 17 00:00:00 2001 From: Folyd Date: Mon, 17 Mar 2025 23:40:32 +0800 Subject: [PATCH] Add ai config --- aiscript-runtime/src/config/mod.rs | 7 +++++-- aiscript-runtime/src/endpoint.rs | 9 +++++++-- aiscript-vm/src/ai/mod.rs | 15 +++++++++++++++ aiscript-vm/src/lib.rs | 1 + aiscript-vm/src/vm/mod.rs | 5 ++++- aiscript-vm/src/vm/state.rs | 4 +++- aiscript/src/main.rs | 9 +++++++-- 7 files changed, 42 insertions(+), 8 deletions(-) diff --git a/aiscript-runtime/src/config/mod.rs b/aiscript-runtime/src/config/mod.rs index 72ffae6..cc53c5d 100644 --- a/aiscript-runtime/src/config/mod.rs +++ b/aiscript-runtime/src/config/mod.rs @@ -3,6 +3,7 @@ use std::{env, fmt::Display, fs, ops::Deref, path::Path, sync::OnceLock}; use auth::AuthConfig; use serde::Deserialize; +use aiscript_vm::AiConfig; use db::DatabaseConfig; pub use sso::{SsoConfig, get_sso_fields}; @@ -64,6 +65,8 @@ impl AsRef for EnvString { #[derive(Debug, Deserialize, Default)] pub struct Config { + #[serde(default)] + pub ai: Option, #[serde(default)] pub database: DatabaseConfig, #[serde(default)] @@ -116,9 +119,9 @@ impl Config { } } - pub fn load(path: &str) -> &Config { + pub fn load() -> &'static Config { CONFIG.get_or_init(|| { - Config::new(path).unwrap_or_else(|e| { + Config::new("project.toml").unwrap_or_else(|e| { eprintln!("Error loading config file: {}", e); Config::default() }) diff --git a/aiscript-runtime/src/endpoint.rs b/aiscript-runtime/src/endpoint.rs index 42163ea..5520e30 100644 --- a/aiscript-runtime/src/endpoint.rs +++ b/aiscript-runtime/src/endpoint.rs @@ -499,8 +499,13 @@ impl Future for RequestProcessor { let redis_connection = self.endpoint.redis_connection.clone(); let handle: JoinHandle> = task::spawn_blocking(move || { - let mut vm = - Vm::new(pg_connection, sqlite_connection, redis_connection); + let ai_config = Config::load().ai.clone(); + let mut vm = Vm::new( + pg_connection, + sqlite_connection, + redis_connection, + ai_config, + ); if let Some(fields) = sso_fields { vm.inject_sso_instance(fields); } diff --git a/aiscript-vm/src/ai/mod.rs b/aiscript-vm/src/ai/mod.rs index 65a5345..787715a 100644 --- a/aiscript-vm/src/ai/mod.rs +++ b/aiscript-vm/src/ai/mod.rs @@ -7,6 +7,21 @@ pub use agent::{Agent, run_agent}; use openai_api_rs::v1::api::OpenAIClient; pub use prompt::{PromptConfig, prompt_with_config}; +use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct AiConfig { + pub openai: Option, + pub anthropic: Option, + pub deepseek: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ModelConfig { + pub api_key: String, + pub model: Option, +} + #[allow(unused)] pub(crate) fn openai_client() -> OpenAIClient { OpenAIClient::builder() diff --git a/aiscript-vm/src/lib.rs b/aiscript-vm/src/lib.rs index 1a412ee..8f54c04 100644 --- a/aiscript-vm/src/lib.rs +++ b/aiscript-vm/src/lib.rs @@ -16,6 +16,7 @@ use std::collections::HashMap; use std::fmt::Display; use std::ops::Deref; +pub use ai::AiConfig; use aiscript_arena::Collect; use aiscript_arena::Mutation; pub(crate) use aiscript_lexer as lexer; diff --git a/aiscript-vm/src/vm/mod.rs b/aiscript-vm/src/vm/mod.rs index 556e975..fe22a87 100644 --- a/aiscript-vm/src/vm/mod.rs +++ b/aiscript-vm/src/vm/mod.rs @@ -6,6 +6,7 @@ pub use state::State; use crate::{ ReturnValue, Value, + ai::AiConfig, ast::ChunkId, builtins, stdlib, string::{InternedString, InternedStringSet}, @@ -35,7 +36,7 @@ impl Display for VmError { impl Default for Vm { fn default() -> Self { - Self::new(None, None, None) + Self::new(None, None, None, None) } } @@ -48,6 +49,7 @@ impl Vm { pg_connection: Option, sqlite_connection: Option, redis_connection: Option, + ai_config: Option, ) -> Self { let mut vm = Vm { arena: Arena::]>::new(|mc| { @@ -55,6 +57,7 @@ impl Vm { state.pg_connection = pg_connection; state.sqlite_connection = sqlite_connection; state.redis_connection = redis_connection; + state.ai_config = ai_config; state }), }; diff --git a/aiscript-vm/src/vm/state.rs b/aiscript-vm/src/vm/state.rs index 0706cc1..ad2f398 100644 --- a/aiscript-vm/src/vm/state.rs +++ b/aiscript-vm/src/vm/state.rs @@ -15,7 +15,7 @@ use sqlx::{PgPool, SqlitePool}; use crate::{ NativeFn, OpCode, ReturnValue, Value, - ai::{self, PromptConfig}, + ai::{self, AiConfig, PromptConfig}, ast::{ChunkId, Visibility}, builtins::BuiltinMethods, module::{ModuleKind, ModuleManager, ModuleSource}, @@ -110,6 +110,7 @@ pub struct State<'gc> { pub pg_connection: Option, pub sqlite_connection: Option, pub redis_connection: Option, + pub ai_config: Option, } unsafe impl Collect for State<'_> { @@ -152,6 +153,7 @@ impl<'gc> State<'gc> { pg_connection: None, sqlite_connection: None, redis_connection: None, + ai_config: None, } } diff --git a/aiscript/src/main.rs b/aiscript/src/main.rs index bbc8f3a..1700ee3 100644 --- a/aiscript/src/main.rs +++ b/aiscript/src/main.rs @@ -48,7 +48,7 @@ enum Commands { #[tokio::main] async fn main() { dotenv::dotenv().ok(); - Config::load("project.toml"); + let config = Config::load(); let cli = AIScriptCli::parse(); match cli.command { @@ -69,7 +69,12 @@ async fn main() { let sqlite_connection = aiscript_runtime::get_sqlite_connection().await; let redis_connection = aiscript_runtime::get_redis_connection().await; task::spawn_blocking(move || { - let mut vm = Vm::new(pg_connection, sqlite_connection, redis_connection); + let mut vm = Vm::new( + pg_connection, + sqlite_connection, + redis_connection, + config.ai.clone(), + ); vm.run_file(path); }) .await // must use await to wait for the thread to finish