Merge pull request #115 from LyuLumos/main
feat: Add support for Google Gemini
ddiu8081 committed Dec 19, 2023
2 parents 43169a4 + 29f4521 commit fbd4fb2
Showing 5 changed files with 193 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/providers/google/api.ts
@@ -0,0 +1,15 @@
export interface GoogleFetchPayload {
apiKey: string
body: Record<string, any>
model?: string
}

export const fetchChatCompletion = async(payload: GoogleFetchPayload) => {
const { apiKey, body, model } = payload || {}
const initOptions = {
headers: { 'Content-Type': 'application/json' },
method: 'POST',
body: JSON.stringify({ ...body }),
}
return fetch(`https://generativelanguage.googleapis.com/v1/models/${model}:generateContent?key=${apiKey}`, initOptions)
}
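
For context, a minimal sketch of calling this helper directly. The API key and prompt are placeholders, and the contents shape mirrors what handler.ts sends; a module context with top-level await is assumed:

import { fetchChatCompletion } from './api'

// Hypothetical standalone call; GEMINI_API_KEY is a placeholder env var.
const response = await fetchChatCompletion({
  apiKey: process.env.GEMINI_API_KEY as string,
  model: 'gemini-pro',
  body: {
    contents: [{ role: 'user', parts: [{ text: 'Hello, Gemini!' }] }],
  },
})
console.log(response.status) // 200 on success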
71 changes: 71 additions & 0 deletions src/providers/google/handler.ts
@@ -0,0 +1,71 @@
import { fetchChatCompletion } from './api'
import { parseMessageList } from './parser'
import type { Message } from '@/types/message'
import type { HandlerPayload, Provider } from '@/types/provider'

export const handlePrompt: Provider['handlePrompt'] = async(payload, signal?: AbortSignal) => {
if (payload.botId === 'chat_continuous')
return handleChatCompletion(payload, signal)
if (payload.botId === 'chat_single')
return handleChatCompletion(payload, signal)
}

export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, globalSettings) => {
const rapidPromptPayload = {
conversationId: 'temp',
conversationType: 'chat_single',
botId: 'temp',
globalSettings: {
...globalSettings,
model: 'gemini-pro',
},
botSettings: {},
prompt,
messages: { contents: [{ role: 'user', parts: [{ text: prompt }] }] },
} as unknown as HandlerPayload
const result = await handleChatCompletion(rapidPromptPayload)
if (typeof result === 'string')
return result
return ''
}

export const handleChatCompletion = async(payload: HandlerPayload, signal?: AbortSignal) => {
// An array to store the chat messages
const messages: Message[] = []

let maxTokens = payload.globalSettings.maxTokens as number
let messageHistorySize = payload.globalSettings.messageHistorySize as number

// Walk the history from newest to oldest, keeping as many recent messages as the token budget allows
while (messageHistorySize > 0) {
messageHistorySize--
// Get the last message from the payload
const m = payload.messages.pop()
if (m === undefined)
break

if (maxTokens - m.content.length < 0)
break

maxTokens -= m.content.length
messages.unshift(m)
}

const response = await fetchChatCompletion({
apiKey: payload.globalSettings.apiKey as string,
body: {
contents: parseMessageList(messages),
},
model: payload.globalSettings.model as string,
})

if (response.ok) {
const json = await response.json()
// Gemini puts the generated text at candidates[0].content.parts[0].text;
// fall back to the serialized payload if that field is missing
const output = json.candidates?.[0]?.content?.parts?.[0]?.text ?? JSON.stringify(json)
return output as string
}

const text = await response.text()
throw new Error(`Failed to fetch chat completion: ${text}`)
}
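
For reference, the response.ok branch above expects the v1 generateContent JSON to look roughly like the abridged sketch below; only the fields the handler actually reads are shown, and the reply text is invented:

// Abridged example of a successful generateContent response body.
const exampleResponse = {
  candidates: [
    {
      content: {
        role: 'model',
        parts: [{ text: 'Hello! How can I help you today?' }],
      },
      finishReason: 'STOP',
    },
  ],
}
// The handler reads: exampleResponse.candidates[0].content.parts[0].text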
70 changes: 70 additions & 0 deletions src/providers/google/index.ts
@@ -0,0 +1,70 @@
import {
handlePrompt,
handleRapidPrompt,
} from './handler'
import type { Provider } from '@/types/provider'

const providerGoogle = () => {
const provider: Provider = {
id: 'provider-google',
icon: 'i-simple-icons-google', // @unocss-include
name: 'Google',
globalSettings: [
{
key: 'apiKey',
name: 'API Key',
type: 'api-key',
},
{
key: 'model',
name: 'Google model',
description: 'The Gemini model to use for requests to the Google API.',
type: 'select',
options: [
{ value: 'gemini-pro', label: 'gemini-pro' },
],
default: 'gemini-pro',
},
{
key: 'maxTokens',
name: 'Max Tokens',
description: 'The maximum number of tokens to generate in the completion.',
type: 'slider',
min: 0,
max: 32768,
default: 2048,
step: 1,
},
{
key: 'messageHistorySize',
name: 'Max History Message Size',
description: 'The number of recent messages to retain as history. Retained messages are further truncated when their total length exceeds the Max Tokens setting.',
type: 'slider',
min: 1,
max: 24,
default: 5,
step: 1,
},
],
bots: [
{
id: 'chat_continuous',
type: 'chat_continuous',
name: 'Continuous Chat',
settings: [],
},
{
id: 'chat_single',
type: 'chat_single',
name: 'Single Chat',
settings: [],
},
],
handlePrompt,
handleRapidPrompt,
}
return provider
}

export default providerGoogle
35 changes: 35 additions & 0 deletions src/providers/google/parser.ts
@@ -0,0 +1,35 @@
import type { Message } from '@/types/message'

export const parseMessageList = (rawList: Message[]) => {
interface GoogleGeminiMessage {
role: 'user' | 'model'
// TODO: Add support for image input
parts: [
{ text: string },
]
}

if (rawList.length === 0)
return []

const parsedList: GoogleGeminiMessage[] = []
// Gemini has no system role: send a leading system message as a user turn, followed by a stub model acknowledgement
if (rawList[0].role === 'system') {
parsedList.push({ role: 'user', parts: [{ text: rawList[0].content }] })
parsedList.push({ role: 'model', parts: [{ text: 'OK.' }] })
}
// convert the remaining messages, mapping the assistant role to Gemini's model role
const roleDict = {
user: 'user',
assistant: 'model',
} as const
for (const message of rawList) {
if (message.role === 'system')
continue
parsedList.push({
role: roleDict[message.role],
parts: [{ text: message.content }],
})
}
return parsedList
}
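
A quick sketch of the transformation with hypothetical messages (assuming Message only needs role and content): a leading system message becomes a user turn plus a stub model reply, and assistant messages are mapped to the model role:

import type { Message } from '@/types/message'
import { parseMessageList } from './parser'

// Hypothetical input demonstrating the system-message workaround.
const parsed = parseMessageList([
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'Hi!' },
] as Message[])
// parsed equals:
// [
//   { role: 'user', parts: [{ text: 'You are a helpful assistant.' }] },
//   { role: 'model', parts: [{ text: 'OK.' }] },
//   { role: 'user', parts: [{ text: 'Hi!' }] },
// ]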
2 changes: 2 additions & 0 deletions src/stores/provider.ts
@@ -1,5 +1,6 @@
import providerOpenAI from '@/providers/openai'
import providerAzure from '@/providers/azure'
import providerGoogle from '@/providers/google'
import providerReplicate from '@/providers/replicate'
import { allConversationTypes } from '@/types/conversation'
import type { BotMeta } from '@/types/app'
@@ -8,6 +9,7 @@ export const providerList = [
providerOpenAI(),
providerAzure(),
providerReplicate(),
providerGoogle(),
]

export const providerMetaList = providerList.map(provider => ({