Skip to content

Commit

Permalink
support vertex ai
Browse files Browse the repository at this point in the history
  • Loading branch information
flexchar committed Mar 29, 2024
1 parent 4878580 commit 5d30f4c
Show file tree
Hide file tree
Showing 8 changed files with 484 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export const ANYSCALE: string = 'anyscale';
export const PALM: string = 'palm';
export const TOGETHER_AI: string = 'together-ai';
export const GOOGLE: string = 'google';
export const GOOGLE_VERTEX_AI: string = 'vertex-ai';
export const PERPLEXITY_AI: string = 'perplexity-ai';
export const MISTRAL_AI: string = 'mistral-ai';
export const DEEPINFRA: string = 'deepinfra';
Expand All @@ -47,6 +48,7 @@ export const VALID_PROVIDERS = [
AZURE_OPEN_AI,
COHERE,
GOOGLE,
GOOGLE_VERTEX_AI,
MISTRAL_AI,
OPEN_AI,
PALM,
Expand Down
16 changes: 14 additions & 2 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ export const fetchProviderOptionsFromConfig = (
providerOptions[0].deploymentId = camelCaseConfig.deploymentId;
if (camelCaseConfig.apiVersion)
providerOptions[0].apiVersion = camelCaseConfig.apiVersion;
if (camelCaseConfig.apiVersion)
providerOptions[0].vertexProjectId = camelCaseConfig.vertexProjectId;
if (camelCaseConfig.apiVersion)
providerOptions[0].vertexRegion = camelCaseConfig.vertexRegion;
mode = 'single';
} else {
if (camelCaseConfig.strategy && camelCaseConfig.strategy.mode) {
Expand Down Expand Up @@ -391,7 +395,11 @@ export async function tryPostProxy(
c.set('requestOptions', [
...requestOptions,
{
providerOptions: { ...providerOption, requestURL: url, rubeusURL: fn },
providerOptions: {
...providerOption,
requestURL: url,
rubeusURL: fn,
},
requestParams: params,
response: mappedResponse.clone(),
cacheStatus: cacheStatus,
Expand Down Expand Up @@ -585,7 +593,11 @@ export async function tryPost(
c.set('requestOptions', [
...requestOptions,
{
providerOptions: { ...providerOption, requestURL: url, rubeusURL: fn },
providerOptions: {
...providerOption,
requestURL: url,
rubeusURL: fn,
},
requestParams: transformedRequestBody,
response: mappedResponse.clone(),
cacheStatus: cacheStatus,
Expand Down
21 changes: 19 additions & 2 deletions src/middlewares/requestValidator/schema/config.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { z } from 'zod';
import { OLLAMA, VALID_PROVIDERS } from '../../../globals';
import { OLLAMA, VALID_PROVIDERS, GOOGLE_VERTEX_AI } from '../../../globals';

export const configSchema: any = z
.object({
Expand All @@ -20,7 +20,9 @@ export const configSchema: any = z
provider: z
.string()
.refine((value) => VALID_PROVIDERS.includes(value), {
message: `Invalid 'provider' value. Must be one of: ${VALID_PROVIDERS.join(', ')}`,
message: `Invalid 'provider' value. Must be one of: ${VALID_PROVIDERS.join(
', '
)}`,
})
.optional(),
api_key: z.string().optional(),
Expand Down Expand Up @@ -57,6 +59,9 @@ export const configSchema: any = z
request_timeout: z.number().optional(),
custom_host: z.string().optional(),
forward_headers: z.array(z.string()).optional(),
// Google Vertex AI specific
vertex_project_id: z.string().optional(),
vertex_region: z.string().optional(),
})
.refine(
(value) => {
Expand Down Expand Up @@ -94,4 +99,16 @@ export const configSchema: any = z
{
message: 'Invalid custom host',
}
)
// Validate Google Vertex AI specific fields
.refine(
(value) => {
const isGoogleVertexAIProvider = value.provider === GOOGLE_VERTEX_AI;
const hasGoogleVertexAIFields =
value.vertex_project_id && value.vertex_region;
return !(isGoogleVertexAIProvider && !hasGoogleVertexAIFields);
},
{
message: `Invalid configuration. 'vertex_project_id' and 'vertex_region' are required for '${GOOGLE_VERTEX_AI}' provider. Example: { 'provider': 'vertex-ai', 'vertex_project_id': 'my-project-id', 'vertex_region': 'us-central1', api_key: 'ya29...' }`,
}
);
42 changes: 42 additions & 0 deletions src/providers/google-vertex-ai/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { ProviderAPIConfig } from '../types';

// Good reference for using REST: https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstarts/quickstart-multimodal#gemini-beginner-samples-drest
// Difference versus Studio AI: https://cloud.google.com/vertex-ai/docs/start/ai-platform-users
export const GoogleApiConfig: ProviderAPIConfig = {
getBaseURL: ({ providerOptions }) => {
const { vertexProjectId, vertexRegion } = providerOptions;

return `https://${vertexRegion}-aiplatform.googleapis.com/v1/projects/${vertexProjectId}/locations/${vertexRegion}/publishers/google`;
},
headers: ({ providerOptions }) => {
const { apiKey } = providerOptions;

return {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
};
},
getEndpoint: ({ fn, gatewayRequestBody }) => {
let mappedFn = fn;
const { model, stream } = gatewayRequestBody;
if (stream) {
mappedFn = `stream-${fn}`;
}
switch (mappedFn) {
case 'chatComplete': {
return `/models/${model}:generateContent`;
}
case 'stream-chatComplete': {
return `/models/${model}:streamGenerateContent`;
}

// Embed API is not yet implemented in the gateway
// This may be as easy as copy-paste from Google provider, but needs to be tested

default:
return '';
}
},
};

export default GoogleApiConfig;
Loading

0 comments on commit 5d30f4c

Please sign in to comment.