README.md (22 additions, 0 deletions)

@@ -113,6 +113,28 @@ AIScript excels in these scenarios:

Check out the [examples](./examples) directory for more sample code.

## Supported AI Models

AIScript supports the following AI models:

- [x] OpenAI (uses the `OPENAI_API_KEY` environment variable by default)
- [x] DeepSeek
- [ ] Anthropic

Configure the provider and model in `project.toml`:

```toml
# use OpenAI
[ai.openai]
api_key = "YOUR_API_KEY"
model = "gpt-3.5-turbo"

# or use DeepSeek
[ai.deepseek]
api_key = "YOUR_API_KEY"
model = "deepseek-chat"
```

## Roadmap

See our [roadmap](https://aiscript.dev/guide/contribution/roadmap) for upcoming features and improvements.
aiscript-vm/src/ai/agent.rs (4 additions, 3 deletions)

@@ -8,7 +8,6 @@ use openai_api_rs::v1::{
ChatCompletionMessage, ChatCompletionMessageForResponse, ChatCompletionRequest, Content,
MessageRole, Tool, ToolCall, ToolChoiceType, ToolType,
},
common::GPT3_5_TURBO,
types::{self, FunctionParameters, JSONSchemaDefine},
};
use tokio::runtime::Handle;
@@ -278,6 +277,8 @@ pub async fn _run_agent<'gc>(
mut agent: Gc<'gc, Agent<'gc>>,
args: Vec<Value<'gc>>,
) -> Value<'gc> {
use super::default_model;

let message = args[0];
let debug = args[1].as_boolean();
let mut history = Vec::new();
@@ -288,11 +289,11 @@
tool_calls: None,
tool_call_id: None,
});
let mut client = super::openai_client();
let mut client = super::openai_client(state.ai_config.as_ref());
loop {
let mut messages = vec![agent.get_instruction_message()];
messages.extend(history.clone());
let mut req = ChatCompletionRequest::new(GPT3_5_TURBO.to_string(), messages);
let mut req = ChatCompletionRequest::new(default_model(state.ai_config.as_ref()), messages);
let tools = agent.get_tools();
if !tools.is_empty() {
req = req
aiscript-vm/src/ai/mod.rs (67 additions, 11 deletions)

@@ -4,16 +4,22 @@ mod prompt;
use std::env;

pub use agent::{Agent, run_agent};
use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::{api::OpenAIClient, common::GPT3_5_TURBO};
pub use prompt::{PromptConfig, prompt_with_config};

use serde::Deserialize;

#[derive(Debug, Clone, Deserialize, Default)]
pub struct AiConfig {
pub openai: Option<ModelConfig>,
pub anthropic: Option<ModelConfig>,
pub deepseek: Option<ModelConfig>,
const DEEPSEEK_API_ENDPOINT: &str = "https://api.deepseek.com/v1";
const DEEPSEEK_V3: &str = "deepseek-chat";

#[derive(Debug, Clone, Deserialize)]
pub enum AiConfig {
#[serde(rename = "openai")]
OpenAI(ModelConfig),
#[serde(rename = "anthropic")]
Anthropic(ModelConfig),
#[serde(rename = "deepseek")]
DeepSeek(ModelConfig),
}

#[derive(Debug, Clone, Deserialize)]
@@ -22,10 +28,60 @@ pub struct ModelConfig {
pub model: Option<String>,
}

impl AiConfig {
pub(crate) fn take_model(&mut self) -> Option<String> {
match self {
Self::OpenAI(ModelConfig { model, .. }) => model.take(),
Self::Anthropic(ModelConfig { model, .. }) => model.take(),
Self::DeepSeek(ModelConfig { model, .. }) => model.take(),
}
}

pub(crate) fn set_model(&mut self, m: String) {
match self {
Self::OpenAI(ModelConfig { model, .. }) => model.replace(m),
Self::Anthropic(ModelConfig { model, .. }) => model.replace(m),
Self::DeepSeek(ModelConfig { model, .. }) => model.replace(m),
};
}
}

#[allow(unused)]
pub(crate) fn openai_client() -> OpenAIClient {
OpenAIClient::builder()
.with_api_key(env::var("OPENAI_API_KEY").unwrap().to_string())
.build()
.unwrap()
pub(crate) fn openai_client(config: Option<&AiConfig>) -> OpenAIClient {
match config {
None => OpenAIClient::builder()
.with_api_key(env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
.build()
.unwrap(),
Some(AiConfig::OpenAI(model_config)) => {
let api_key = if model_config.api_key.is_empty() {
env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")
} else {
model_config.api_key.clone()
};
OpenAIClient::builder()
.with_api_key(api_key)
.build()
.unwrap()
}
Some(AiConfig::DeepSeek(ModelConfig { api_key, .. })) => OpenAIClient::builder()
.with_endpoint(DEEPSEEK_API_ENDPOINT)
.with_api_key(api_key)
.build()
.unwrap(),
Some(AiConfig::Anthropic(_)) => unimplemented!("Anthropic API not yet supported"),
}
}

pub(crate) fn default_model(config: Option<&AiConfig>) -> String {
match config {
None => GPT3_5_TURBO.to_string(),
Some(AiConfig::OpenAI(ModelConfig { model, .. })) => {
model.clone().unwrap_or(GPT3_5_TURBO.to_string())
}
Some(AiConfig::DeepSeek(ModelConfig { model, .. })) => {
model.clone().unwrap_or(DEEPSEEK_V3.to_string())
}
Some(AiConfig::Anthropic(_)) => unimplemented!("Anthropic API not yet supported"),
}
}
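
Note on the config shape: the `#[serde(rename = ...)]` attributes make `AiConfig` an externally tagged enum, so exactly one `[ai.<provider>]` table in `project.toml` selects the variant. A minimal sketch of that mapping (assumes the `toml` crate and a hypothetical `ProjectConfig` wrapper; the Anthropic variant is omitted for brevity, and this snippet is not part of the PR):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ModelConfig {
    api_key: String,
    model: Option<String>,
}

// Mirrors the enum in aiscript-vm/src/ai/mod.rs (Anthropic omitted).
#[derive(Debug, Deserialize)]
enum AiConfig {
    #[serde(rename = "openai")]
    OpenAI(ModelConfig),
    #[serde(rename = "deepseek")]
    DeepSeek(ModelConfig),
}

// Hypothetical wrapper standing in for however the host loads project.toml.
#[derive(Debug, Deserialize)]
struct ProjectConfig {
    ai: AiConfig,
}

fn main() {
    let cfg: ProjectConfig = toml::from_str(
        r#"
[ai.deepseek]
api_key = "YOUR_API_KEY"
model = "deepseek-chat"
"#,
    )
    .unwrap();
    // The single [ai.deepseek] table selects AiConfig::DeepSeek.
    println!("{cfg:?}");
}
```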
aiscript-vm/src/ai/prompt.rs (26 additions, 15 deletions)
@@ -1,9 +1,11 @@
use openai_api_rs::v1::common::GPT3_5_TURBO;
use tokio::runtime::Handle;

use super::{AiConfig, ModelConfig, default_model};

pub struct PromptConfig {
pub input: String,
pub model: Option<String>,
pub ai_config: Option<AiConfig>,
pub max_tokens: Option<i64>,
pub temperature: Option<f64>,
pub system_prompt: Option<String>,
@@ -13,27 +15,42 @@ impl Default for PromptConfig {
fn default() -> Self {
Self {
input: String::new(),
model: Some(GPT3_5_TURBO.to_string()),
ai_config: Some(AiConfig::OpenAI(ModelConfig {
api_key: Default::default(),
model: Some(GPT3_5_TURBO.to_string()),
})),
max_tokens: Default::default(),
temperature: Default::default(),
system_prompt: Default::default(),
}
}
}

impl PromptConfig {
fn take_model(&mut self) -> String {
self.ai_config
.as_mut()
.and_then(|config| config.take_model())
.unwrap_or_else(|| default_model(self.ai_config.as_ref()))
}

pub(crate) fn set_model(&mut self, model: String) {
if let Some(config) = self.ai_config.as_mut() {
config.set_model(model);
}
}
}

#[cfg(feature = "ai_test")]
async fn _prompt_with_config(config: PromptConfig) -> String {
return format!("AI: {}", config.input);
}

#[cfg(not(feature = "ai_test"))]
async fn _prompt_with_config(mut config: PromptConfig) -> String {
use openai_api_rs::v1::{
chat_completion::{self, ChatCompletionRequest},
common::GPT3_5_TURBO,
};

let mut client = super::openai_client();
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
let mut client = super::openai_client(config.ai_config.as_ref());
let model = config.take_model();

// Create system message if provided
let mut messages = Vec::new();
@@ -57,13 +74,7 @@ async fn _prompt_with_config(mut config: PromptConfig) -> String {
});

// Build the request
let mut req = ChatCompletionRequest::new(
config
.model
.take()
.unwrap_or_else(|| GPT3_5_TURBO.to_string()),
messages,
);
let mut req = ChatCompletionRequest::new(model, messages);

if let Some(max_tokens) = config.max_tokens {
req.max_tokens = Some(max_tokens);
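
Taken together, `take_model` and `default_model` give a provider-aware fallback: an explicit `model` in the config wins, otherwise the provider's default applies (`gpt-3.5-turbo` for OpenAI, `deepseek-chat` for DeepSeek). A rough illustration (a hypothetical snippet inside `aiscript-vm/src/ai/prompt.rs`, where the private `take_model` and the `GPT3_5_TURBO` import are already in scope; not part of the PR):

```rust
// Hypothetical demo of the fallback chain implemented above.
fn model_fallback_demo() {
    // PromptConfig::default() presets OpenAI with gpt-3.5-turbo.
    let mut config = PromptConfig::default();
    assert_eq!(config.take_model(), GPT3_5_TURBO.to_string());

    // With DeepSeek configured and no model given, the DeepSeek
    // default ("deepseek-chat") is chosen instead.
    config.ai_config = Some(AiConfig::DeepSeek(ModelConfig {
        api_key: String::new(),
        model: None,
    }));
    assert_eq!(config.take_model(), "deepseek-chat");
}
```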
aiscript-vm/src/vm/state.rs (11 additions, 5 deletions)

@@ -1013,15 +1013,21 @@ impl<'gc> State<'gc> {
let result = match value {
// Simple string case
Value::String(s) => {
let input = s.to_str().unwrap().to_string();
ai::prompt_with_config(PromptConfig {
input,
let mut config = PromptConfig {
input: s.to_str().unwrap().to_string(),
..Default::default()
})
};
if let Some(ai_cfg) = &self.ai_config {
config.ai_config = Some(ai_cfg.clone());
}
ai::prompt_with_config(config)
}
// Object config case
Value::Object(obj) => {
let mut config = PromptConfig::default();
if let Some(ai_cfg) = &self.ai_config {
config.ai_config = Some(ai_cfg.clone());
}
let obj_ref = obj.borrow();

// Extract input (required)
@@ -1039,7 +1045,7 @@
if let Some(Value::String(model)) =
obj_ref.fields.get(&self.intern(b"model"))
{
config.model = Some(model.to_str().unwrap().to_string());
config.set_model(model.to_str().unwrap().to_string());
}

// Extract max_tokens (optional)
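
One ordering detail worth noting in the `Value::Object` branch: `self.ai_config` (from `project.toml`) is cloned into the `PromptConfig` first, and an explicit `model` field on the prompt object then replaces only the model via `set_model`, leaving the provider and API key intact. A condensed sketch of that precedence (hypothetical, same-module access as in the sketch above; not part of the PR):

```rust
let mut config = PromptConfig::default();
// As if project.toml configured OpenAI with gpt-3.5-turbo.
config.ai_config = Some(AiConfig::OpenAI(ModelConfig {
    api_key: "YOUR_API_KEY".to_string(),
    model: Some("gpt-3.5-turbo".to_string()),
}));
// Per-prompt override from the object's `model` field.
config.set_model("gpt-4".to_string());
// The override wins; provider and api_key are unchanged.
assert_eq!(config.take_model(), "gpt-4");
```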