Feat: gemini system instruction mapping #371

Merged · 2 commits · May 20, 2024
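This change teaches the Google and Google Vertex AI providers to lift an OpenAI-style system message out of the messages array and send it as Gemini's top-level systemInstruction field. A minimal sketch of the intended mapping, with made-up payload values (the real transforms are in the diffs below):

// OpenAI-style input as received by the gateway (illustrative values).
const openAIStyleRequest = {
  model: 'gemini-1.5-pro-latest',
  messages: [
    { role: 'system', content: 'You are a terse assistant.' },
    { role: 'user', content: 'Hello' },
  ],
};

// Roughly what the transformed Gemini request body should contain:
// the system turn moves out of contents and into systemInstruction.
const geminiStyleBody = {
  contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
  systemInstruction: {
    role: 'system',
    parts: [{ text: 'You are a terse assistant.' }],
  },
};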
169 changes: 107 additions & 62 deletions src/providers/google-vertex-ai/chatComplete.ts
@@ -25,85 +25,130 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
    required: true,
    default: 'gemini-1.0-pro',
  },
  messages: [
    {
      param: 'contents',
      default: '',
      transform: (params: Params) => {
        let lastRole: 'user' | 'model' | undefined;
        const messages: { role: string; parts: { text: string }[] }[] = [];

        params.messages?.forEach((message: Message) => {
          if (message.role === 'system') return;

          const role = message.role === 'assistant' ? 'model' : 'user';

          let parts = [];
          if (typeof message.content === 'string') {
            parts.push({
              text: message.content,
            });
          }

          if (message.content && typeof message.content === 'object') {
            message.content.forEach((c: ContentType) => {
              if (c.type === 'text') {
                parts.push({
                  text: c.text,
                });
              }
              if (c.type === 'image_url') {
                const { url } = c.image_url || {};

                if (!url) {
                  // Shouldn't throw error?
                  return;
                }

                // Example: data:image/png;base64,abcdefg...
                if (url.startsWith('data:')) {
                  const [mimeTypeWithPrefix, base64Image] =
                    url.split(';base64,');
                  const mimeType = mimeTypeWithPrefix.split(':')[1];

                  parts.push({
                    inlineData: {
                      mimeType: mimeType,
                      data: base64Image,
                    },
                  });

                  return;
                }

                // This part is problematic because URLs are not supported in the current implementation.
                // Two problems exist:
                // 1. Only Google Cloud Storage URLs are supported.
                // 2. MimeType is not supported in OpenAI API, but it is required in Google Vertex AI API.
                // Google will return an error here if any other URL is provided.
                parts.push({
                  fileData: {
                    mimeType: 'image/jpeg',
                    fileUri: url,
                  },
                });
              }
            });
          }

          // @NOTE: This takes care of the "Please ensure that multiturn requests alternate between user and model."
          // error that occurs when we have multiple user messages in a row.
          const shouldAppendEmptyModeChat =
            lastRole === 'user' &&
            role === 'user' &&
            !params.model?.includes('vision');

          if (shouldAppendEmptyModeChat) {
            messages.push({ role: 'model', parts: [{ text: '' }] });
          }

          messages.push({ role, parts });
          lastRole = role;
        });

        return messages;
      },
    },
    {
      param: 'systemInstruction',
      default: '',
      transform: (params: Params) => {
        const firstMessage = params.messages?.[0] || null;
        if (!firstMessage) return;

        if (
          firstMessage.role === 'system' &&
          typeof firstMessage.content === 'string'
        ) {
          return {
            parts: [
              {
                text: firstMessage.content,
              },
            ],
            role: 'system',
          };
        }

        if (
          firstMessage.role === 'system' &&
          typeof firstMessage.content === 'object' &&
          firstMessage.content?.[0]?.text
        ) {
          return {
            parts: [
              {
                text: firstMessage.content?.[0].text,
              },
            ],
            role: 'system',
          };
        }

        return;
      },
    },
  ],
  temperature: {
    param: 'generationConfig',
    transform: (params: Params) => transformGenerationConfig(params),
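The Vertex AI config now emits two Gemini parameters from the one OpenAI messages array: contents (with system turns skipped) and systemInstruction (built from a leading system message). A standalone sketch of the systemInstruction half, using a simplified message shape instead of the gateway's own Params/Message types:

// Simplified stand-in for the gateway's message type, for illustration only.
type SimpleMessage = {
  role: string;
  content: string | { type: string; text?: string }[];
};

// Mirrors the systemInstruction transform above: only a leading system
// message is promoted; anything else yields undefined so the field is omitted.
const toSystemInstruction = (messages: SimpleMessage[]) => {
  const first = messages[0];
  if (!first || first.role !== 'system') return undefined;

  if (typeof first.content === 'string') {
    return { role: 'system', parts: [{ text: first.content }] };
  }
  if (Array.isArray(first.content) && first.content[0]?.text) {
    return { role: 'system', parts: [{ text: first.content[0].text }] };
  }
  return undefined;
};

// toSystemInstruction([{ role: 'system', content: 'Be brief.' }])
// // => { role: 'system', parts: [{ text: 'Be brief.' }] }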
153 changes: 109 additions & 44 deletions src/providers/google/chatComplete.ts
@@ -30,6 +30,16 @@ const transformGenerationConfig = (params: Params) => {
  return generationConfig;
};

// models for which systemInstruction is not supported
const SYSTEM_INSTRUCTION_DISABLED_MODELS = [
  'gemini-1.0-pro',
  'gemini-1.0-pro-001',
  'gemini-1.0-pro-latest',
  'gemini-1.0-pro-vision-latest',
  'gemini-pro',
  'gemini-pro-vision',
];

// TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model.

export const GoogleChatCompleteConfig: ProviderConfig = {
@@ -38,57 +48,112 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
    required: true,
    default: 'gemini-pro',
  },
messages: {
param: 'contents',
default: '',
transform: (params: Params) => {
let lastRole: 'user' | 'model' | undefined;
const messages: { role: string; parts: { text: string }[] }[] = [];
messages: [
{
param: 'contents',
default: '',
transform: (params: Params) => {
let lastRole: 'user' | 'model' | 'system' | undefined;
const messages: { role: string; parts: { text: string }[] }[] = [];

params.messages?.forEach((message: Message) => {
const role = message.role === 'assistant' ? 'model' : 'user';
let parts = [];
if (typeof message.content === 'string') {
parts.push({
text: message.content,
});
}
params.messages?.forEach((message: Message) => {
// From gemini-1.5 onwards, systemInstruction is supported
// Skipping system message and sending it in systemInstruction for gemini 1.5 models
if (
message.role === 'system' &&
!SYSTEM_INSTRUCTION_DISABLED_MODELS.includes(params.model as string)
)
return;

if (message.content && typeof message.content === 'object') {
message.content.forEach((c: ContentType) => {
if (c.type === 'text') {
parts.push({
text: c.text,
});
}
if (c.type === 'image_url') {
parts.push({
inlineData: {
mimeType: 'image/jpeg',
data: c.image_url?.url,
},
});
}
});
}
const role = message.role === 'assistant' ? 'model' : 'user';
let parts = [];
if (typeof message.content === 'string') {
parts.push({
text: message.content,
});
}

// @NOTE: This takes care of the "Please ensure that multiturn requests alternate between user and model."
// error that occurs when we have multiple user messages in a row.
const shouldAppendEmptyModeChat =
lastRole === 'user' &&
role === 'user' &&
!params.model?.includes('vision');
if (message.content && typeof message.content === 'object') {
message.content.forEach((c: ContentType) => {
if (c.type === 'text') {
parts.push({
text: c.text,
});
}
if (c.type === 'image_url') {
parts.push({
inlineData: {
mimeType: 'image/jpeg',
data: c.image_url?.url,
},
});
}
});
}

if (shouldAppendEmptyModeChat) {
messages.push({ role: 'model', parts: [{ text: '' }] });
// @NOTE: This takes care of the "Please ensure that multiturn requests alternate between user and model."
// error that occurs when we have multiple user messages in a row.
const shouldAppendEmptyModeChat =
lastRole === 'user' &&
role === 'user' &&
!params.model?.includes('vision');

if (shouldAppendEmptyModeChat) {
messages.push({ role: 'model', parts: [{ text: '' }] });
}

messages.push({ role, parts });
lastRole = role;
});

return messages;
},
},
{
param: 'systemInstruction',
default: '',
transform: (params: Params) => {
// systemInstruction is only supported from gemini 1.5 models
if (SYSTEM_INSTRUCTION_DISABLED_MODELS.includes(params.model as string))
return;

const firstMessage = params.messages?.[0] || null;

if (!firstMessage) return;

if (
firstMessage.role === 'system' &&
typeof firstMessage.content === 'string'
) {
return {
parts: [
{
text: firstMessage.content,
},
],
role: 'system',
};
}

messages.push({ role, parts });
lastRole = role;
});
return messages;
if (
firstMessage.role === 'system' &&
typeof firstMessage.content === 'object' &&
firstMessage.content?.[0]?.text
) {
return {
parts: [
{
text: firstMessage.content?.[0].text,
},
],
role: 'system',
};
}

return;
},
},
},
],
temperature: {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
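For the Google (Gemini API) provider the new field is gated on the model, since the gemini-1.0 / gemini-pro family rejects systemInstruction. A condensed sketch of that routing decision (the helper name is made up; the diff inlines the check in both transforms):

// Models listed in the diff as not supporting systemInstruction.
const SYSTEM_INSTRUCTION_DISABLED_MODELS = [
  'gemini-1.0-pro',
  'gemini-1.0-pro-001',
  'gemini-1.0-pro-latest',
  'gemini-1.0-pro-vision-latest',
  'gemini-pro',
  'gemini-pro-vision',
];

// Hypothetical helper condensing the check used in both transforms above.
const supportsSystemInstruction = (model: string): boolean =>
  !SYSTEM_INSTRUCTION_DISABLED_MODELS.includes(model);

// supportsSystemInstruction('gemini-pro');            // false: system turn stays in contents
// supportsSystemInstruction('gemini-1.5-pro-latest'); // true: system turn moves to systemInstruction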