Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for Google Gemini #115

Merged
merged 4 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/providers/google/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
export interface GoogleFetchPayload {
apiKey: string
body: Record<string, any>
model?: string
}

export const fetchChatCompletion = async(payload: GoogleFetchPayload) => {
const { apiKey, body, model } = payload || {}
const initOptions = {
headers: { 'Content-Type': 'application/json' },
method: 'POST',
body: JSON.stringify({ ...body }),
}
return fetch(`https://generativelanguage.googleapis.com/v1/models/${model}:generateContent?key=${apiKey}`, initOptions)
}
71 changes: 71 additions & 0 deletions src/providers/google/handler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { fetchChatCompletion } from './api'
import { parseMessageList } from './parser'
import type { Message } from '@/types/message'
import type { HandlerPayload, Provider } from '@/types/provider'

export const handlePrompt: Provider['handlePrompt'] = async(payload, signal?: AbortSignal) => {
if (payload.botId === 'chat_continuous')
return handleChatCompletion(payload, signal)
if (payload.botId === 'chat_single')
return handleChatCompletion(payload, signal)
}

export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, globalSettings) => {
const rapidPromptPayload = {
conversationId: 'temp',
conversationType: 'chat_single',
botId: 'temp',
globalSettings: {
...globalSettings,
model: 'gemini-pro',
},
botSettings: {},
prompt,
messages: { contents: [{ role: 'user', parts: [{ text: prompt }] }] },
} as unknown as HandlerPayload
const result = await handleChatCompletion(rapidPromptPayload)
if (typeof result === 'string')
return result
return ''
}

export const handleChatCompletion = async(payload: HandlerPayload, signal?: AbortSignal) => {
// An array to store the chat messages
const messages: Message[] = []

let maxTokens = payload.globalSettings.maxTokens as number
let messageHistorySize = payload.globalSettings.messageHistorySize as number

// Iterate through the message history
while (messageHistorySize > 0) {
messageHistorySize--
// Get the last message from the payload
const m = payload.messages.pop()
if (m === undefined)
break

if (maxTokens - m.content.length < 0)
break

maxTokens -= m.content.length
messages.unshift(m)
}

const response = await fetchChatCompletion({
apiKey: payload.globalSettings.apiKey as string,
body: {
contents: parseMessageList(messages),
},
model: payload.globalSettings.model as string,
})

if (response.ok) {
const json = await response.json()
// console.log('json', json)
const output = json.candidates[0].content.parts[0].text || json
return output as string
}

const text = await response.text()
throw new Error(`Failed to fetch chat completion: ${text}`)
}
70 changes: 70 additions & 0 deletions src/providers/google/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import {
handlePrompt,
handleRapidPrompt,
} from './handler'
import type { Provider } from '@/types/provider'

const providerGoogle = () => {
const provider: Provider = {
id: 'provider-google',
icon: 'i-simple-icons-google', // @unocss-include
name: 'Google',
globalSettings: [
{
key: 'apiKey',
name: 'API Key',
type: 'api-key',
},
{
key: 'model',
name: 'Google model',
description: 'Custom model for Google API.',
type: 'select',
options: [
{ value: 'gemini-pro', label: 'gemini-pro' },
],
default: 'gemini-pro',
},
{
key: 'maxTokens',
name: 'Max Tokens',
description: 'The maximum number of tokens to generate in the completion.',
type: 'slider',
min: 0,
max: 32768,
default: 2048,
step: 1,
},
{
key: 'messageHistorySize',
name: 'Max History Message Size',
description: 'The number of retained historical messages will be truncated if the length of the message exceeds the MaxToken parameter.',
type: 'slider',
min: 1,
max: 24,
default: 5,
step: 1,
},
],
bots: [
{
id: 'chat_continuous',
type: 'chat_continuous',
name: 'Continuous Chat',
settings: [],
},
{
id: 'chat_single',
type: 'chat_single',
name: 'Single Chat',
settings: [],
},

],
handlePrompt,
handleRapidPrompt,
}
return provider
}

export default providerGoogle
35 changes: 35 additions & 0 deletions src/providers/google/parser.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import type { Message } from '@/types/message'

export const parseMessageList = (rawList: Message[]) => {
interface GoogleGeminiMessage {
role: 'user' | 'model'
// TODO: Add support for image input
parts: [
{ text: string },
]
}

if (rawList.length === 0)
return []

const parsedList: GoogleGeminiMessage[] = []
// if first message is system message, insert an empty message after it
if (rawList[0].role === 'system') {
parsedList.push({ role: 'user', parts: [{ text: rawList[0].content }] })
parsedList.push({ role: 'model', parts: [{ text: 'OK.' }] })
}
// covert other messages
const roleDict = {
user: 'user',
assistant: 'model',
} as const
for (const message of rawList) {
if (message.role === 'system')
continue
parsedList.push({
role: roleDict[message.role],
parts: [{ text: message.content }],
})
}
return parsedList
}
2 changes: 2 additions & 0 deletions src/stores/provider.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import providerOpenAI from '@/providers/openai'
import providerAzure from '@/providers/azure'
import providerGoogle from '@/providers/google'
import providerReplicate from '@/providers/replicate'
import { allConversationTypes } from '@/types/conversation'
import type { BotMeta } from '@/types/app'
Expand All @@ -8,6 +9,7 @@ export const providerList = [
providerOpenAI(),
providerAzure(),
providerReplicate(),
providerGoogle(),
]

export const providerMetaList = providerList.map(provider => ({
Expand Down