Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 54 additions & 52 deletions src/puter-js/src/modules/AI.js
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ class AI{
// if string, it's treated as the prompt which is a shorthand for { messages: [{ content: prompt }] }
// if object, it's treated as the full argument object that the API expects
chat = async (...args) => {
let options = {};
let settings = {};
// requestParams: parameters that will be sent to the backend driver
let requestParams = {};
// userParams: parameters provided by the user in the function call
let userParams = {};
let testMode = false;

// default driver is openai-completion
Expand All @@ -185,12 +187,12 @@ class AI{

// ai.chat(prompt)
if(typeof args[0] === 'string'){
options = { messages: [{ content: args[0] }] };
requestParams = { messages: [{ content: args[0] }] };
}

// ai.chat(prompt, testMode)
if (typeof args[0] === 'string' && (!args[1] || typeof args[1] === 'boolean')) {
options = { messages: [{ content: args[0] }] };
requestParams = { messages: [{ content: args[0] }] };
}

// ai.chat(prompt, imageURL/File)
Expand All @@ -202,7 +204,7 @@ class AI{
}

// parse args[1] as an image_url object
options = {
requestParams = {
vision: true,
messages: [
{
Expand All @@ -224,7 +226,7 @@ class AI{
for (let i = 0; i < args[1].length; i++) {
args[1][i] = { image_url: { url: args[1][i] } };
}
options = {
requestParams = {
vision: true,
messages: [
{
Expand All @@ -238,7 +240,7 @@ class AI{
}
// chat([messages])
else if (Array.isArray(args[0])) {
options = { messages: args[0] };
requestParams = { messages: args[0] };
}

// determine if testMode is enabled
Expand All @@ -248,62 +250,62 @@ class AI{
testMode = true;
}

// if any of the args is an object, assume it's the settings object
// if any of the args is an object, assume it's the user parameters object
const is_object = v => {
return typeof v === 'object' &&
!Array.isArray(v) &&
v !== null;
};
for (let i = 0; i < args.length; i++) {
if (is_object(args[i])) {
settings = args[i];
userParams = args[i];
break;
}
}


// does settings contain `model`? add it to options
if (settings.model) {
options.model = settings.model;
// Copy relevant parameters from userParams to requestParams
if (userParams.model) {
requestParams.model = userParams.model;
}
if (settings.temperature) {
options.temperature = settings.temperature;
if (userParams.temperature) {
requestParams.temperature = userParams.temperature;
}
if (settings.max_tokens) {
options.max_tokens = settings.max_tokens;
if (userParams.max_tokens) {
requestParams.max_tokens = userParams.max_tokens;
}

// convert to the correct model name if necessary
if( options.model === 'claude-3-5-sonnet'){
options.model = 'claude-3-5-sonnet-latest';
if( requestParams.model === 'claude-3-5-sonnet'){
requestParams.model = 'claude-3-5-sonnet-latest';
}
if( options.model === 'claude-3-7-sonnet' || options.model === 'claude'){
options.model = 'claude-3-7-sonnet-latest';
if( requestParams.model === 'claude-3-7-sonnet' || requestParams.model === 'claude'){
requestParams.model = 'claude-3-7-sonnet-latest';
}
if ( options.model === 'mistral' ) {
options.model = 'mistral-large-latest';
if ( requestParams.model === 'mistral' ) {
requestParams.model = 'mistral-large-latest';
}
if ( options.model === 'groq' ) {
options.model = 'llama3-8b-8192';
if ( requestParams.model === 'groq' ) {
requestParams.model = 'llama3-8b-8192';
}
if ( options.model === 'deepseek' ) {
options.model = 'deepseek-chat';
if ( requestParams.model === 'deepseek' ) {
requestParams.model = 'deepseek-chat';
}

// map model to the appropriate driver
if (!options.model || options.model === 'gpt-4o' || options.model === 'gpt-4o-mini') {
if (!requestParams.model || requestParams.model === 'gpt-4o' || requestParams.model === 'gpt-4o-mini') {
driver = 'openai-completion';
}else if(
options.model === 'claude-3-haiku-20240307' ||
options.model === 'claude-3-5-sonnet-20240620' ||
options.model === 'claude-3-5-sonnet-20241022' ||
options.model === 'claude-3-5-sonnet-latest' ||
options.model === 'claude-3-7-sonnet-latest'
requestParams.model === 'claude-3-haiku-20240307' ||
requestParams.model === 'claude-3-5-sonnet-20240620' ||
requestParams.model === 'claude-3-5-sonnet-20241022' ||
requestParams.model === 'claude-3-5-sonnet-latest' ||
requestParams.model === 'claude-3-7-sonnet-latest'
){
driver = 'claude';
}else if(options.model === 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo' || options.model === 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo' || options.model === 'meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo' || options.model === `google/gemma-2-27b-it`){
}else if(requestParams.model === 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo' || requestParams.model === 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo' || requestParams.model === 'meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo' || requestParams.model === `google/gemma-2-27b-it`){
driver = 'together-ai';
}else if(options.model === 'mistral-large-latest' || options.model === 'codestral-latest'){
}else if(requestParams.model === 'mistral-large-latest' || requestParams.model === 'codestral-latest'){
driver = 'mistral';
}else if([
"distil-whisper-large-v3-en",
Expand All @@ -318,41 +320,41 @@ class AI{
"llama-guard-3-8b",
"mixtral-8x7b-32768",
"whisper-large-v3"
].includes(options.model)) {
].includes(requestParams.model)) {
driver = 'groq';
}else if(options.model === 'grok-beta') {
}else if(requestParams.model === 'grok-beta') {
driver = 'xai';
}
else if(
options.model === 'deepseek-chat' ||
options.model === 'deepseek-reasoner'
requestParams.model === 'deepseek-chat' ||
requestParams.model === 'deepseek-reasoner'
){
driver = 'deepseek';
}
else if(
options.model === 'gemini-1.5-flash' ||
options.model === 'gemini-2.0-flash'
requestParams.model === 'gemini-1.5-flash' ||
requestParams.model === 'gemini-2.0-flash'
){
driver = 'gemini';
}
else if ( options.model.startsWith('openrouter:') ) {
else if ( requestParams.model.startsWith('openrouter:') ) {
driver = 'openrouter';
}

// stream flag from settings
if(settings.stream !== undefined && typeof settings.stream === 'boolean'){
options.stream = settings.stream;
// stream flag from userParams
if(userParams.stream !== undefined && typeof userParams.stream === 'boolean'){
requestParams.stream = userParams.stream;
}

if ( settings.driver ) {
driver = settings.driver;
if ( userParams.driver ) {
driver = userParams.driver;
}

// settings to pass
const SETTINGS_TO_PASS = ['tools', 'response'];
for ( const name of SETTINGS_TO_PASS ) {
if ( settings[name] ) {
options[name] = settings[name];
// Additional parameters to pass from userParams to requestParams
const PARAMS_TO_PASS = ['tools', 'response'];
for ( const name of PARAMS_TO_PASS ) {
if ( userParams[name] ) {
requestParams[name] = userParams[name];
}
}

Expand All @@ -370,7 +372,7 @@ class AI{

return result;
}
}).call(this, options);
}).call(this, requestParams);
}

txt2img = async (...args) => {
Expand Down