diff --git a/README.md b/README.md
index a3de5da..05005a0 100644
--- a/README.md
+++ b/README.md
@@ -623,11 +623,12 @@ Works with any server that implements the OpenAI `/v1/embeddings` API format (ll
     "dimensions": 768,
     "apiKey": "{env:EMBED_API_KEY}",
     "maxTokens": 8192,
-    "timeoutMs": 30000
+    "timeoutMs": 30000,
+    "maxBatchSize": 64
   }
 }
 ```
 
-Required fields: `baseUrl`, `model`, `dimensions` (positive integer). Optional: `apiKey`, `maxTokens`, `timeoutMs` (default: 30000). `{env:VAR_NAME}` placeholders are resolved before config validation for fields that are actually used and throw if the referenced environment variable is missing or malformed.
+Required fields: `baseUrl`, `model`, `dimensions` (positive integer). Optional: `apiKey`, `maxTokens`, `timeoutMs` (default: 30000), and `maxBatchSize` (alias: `max_batch_size`), which caps the number of inputs per `/embeddings` request for servers like text-embeddings-inference. `{env:VAR_NAME}` placeholders are resolved before config validation for fields that are actually used; resolution throws if the referenced environment variable is missing or malformed.
 
 ## ⚠️ Tradeoffs
diff --git a/src/config/schema.ts b/src/config/schema.ts
index 40adbaa..bfe815f 100644
--- a/src/config/schema.ts
+++ b/src/config/schema.ts
@@ -66,6 +66,9 @@ export interface CustomProviderConfig {
   concurrency?: number;
   /** Minimum delay between requests in milliseconds (default: 1000). Set to 0 for local servers. */
   requestIntervalMs?: number;
+  /** Maximum number of inputs per /embeddings request (optional; `max_batch_size` is accepted as an alias). */
+  maxBatchSize?: number;
+  max_batch_size?: number;
 }
 
 export interface CodebaseIndexConfig {
@@ -245,6 +248,11 @@ export function parseConfig(raw: unknown): ParsedCodebaseIndexConfig {
     timeoutMs: typeof rawCustom.timeoutMs === 'number' ? Math.max(1000, rawCustom.timeoutMs) : undefined,
     concurrency: typeof rawCustom.concurrency === 'number' ? Math.max(1, Math.floor(rawCustom.concurrency)) : undefined,
     requestIntervalMs: typeof rawCustom.requestIntervalMs === 'number' ? Math.max(0, Math.floor(rawCustom.requestIntervalMs)) : undefined,
+    maxBatchSize: typeof rawCustom.maxBatchSize === 'number'
+      ? Math.max(1, Math.floor(rawCustom.maxBatchSize))
+      : typeof rawCustom.max_batch_size === 'number'
+        ? Math.max(1, Math.floor(rawCustom.max_batch_size))
+        : undefined,
   };
   // Warn if baseUrl doesn't end with an API version path like /v1.
   // Note: using console.warn here because Logger isn't initialized yet at config parse time.
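A note on the parsing logic above: `maxBatchSize` takes precedence over the `max_batch_size` alias, and numeric values are floored and clamped to a minimum of 1. A minimal standalone sketch of that resolution (the `resolveMaxBatchSize` helper is hypothetical, for illustration only, not part of this patch):

```ts
// Hypothetical helper mirroring the parseConfig logic above: the camelCase
// key wins over the snake_case alias, numeric values are floored and
// clamped to >= 1, and anything non-numeric resolves to undefined.
function resolveMaxBatchSize(raw: {
  maxBatchSize?: unknown;
  max_batch_size?: unknown;
}): number | undefined {
  if (typeof raw.maxBatchSize === "number") {
    return Math.max(1, Math.floor(raw.maxBatchSize));
  }
  if (typeof raw.max_batch_size === "number") {
    return Math.max(1, Math.floor(raw.max_batch_size));
  }
  return undefined;
}

// resolveMaxBatchSize({ max_batch_size: 32 }) -> 32
// resolveMaxBatchSize({ maxBatchSize: 0 })    -> 1 (clamped, matching the test further below)
// resolveMaxBatchSize({})                     -> undefined
```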
diff --git a/src/embeddings/detector.ts b/src/embeddings/detector.ts
index 83f152f..9eb9dd3 100644
--- a/src/embeddings/detector.ts
+++ b/src/embeddings/detector.ts
@@ -15,6 +15,7 @@ export interface ProviderCredentials {
 export interface CustomModelInfo extends BaseModelInfo {
   provider: 'custom';
   timeoutMs: number;
+  maxBatchSize?: number;
 }
 
 export type ConfiguredProviderInfo = {
@@ -247,6 +248,7 @@ export function createCustomProviderInfo(config: CustomProviderConfig): Configur
       maxTokens: config.maxTokens ?? 8192,
       costPer1MTokens: 0,
       timeoutMs: config.timeoutMs ?? 30_000,
+      maxBatchSize: config.maxBatchSize,
     },
   };
 }
diff --git a/src/embeddings/provider.ts b/src/embeddings/provider.ts
index 69a7855..5259200 100644
--- a/src/embeddings/provider.ts
+++ b/src/embeddings/provider.ts
@@ -343,23 +343,28 @@ class CustomEmbeddingProvider implements EmbeddingProviderInterface {
     private modelInfo: CustomModelInfo
   ) { }
 
-  async embedQuery(query: string): Promise<EmbeddingResult> {
-    const result = await this.embedBatch([query]);
-    return {
-      embedding: result.embeddings[0],
-      tokensUsed: result.totalTokensUsed,
-    };
-  }
+  private splitIntoRequestBatches(texts: string[]): string[][] {
+    const maxBatchSize = this.modelInfo.maxBatchSize;
 
-  async embedDocument(document: string): Promise<EmbeddingResult> {
-    const result = await this.embedBatch([document]);
-    return {
-      embedding: result.embeddings[0],
-      tokensUsed: result.totalTokensUsed,
-    };
+    if (!maxBatchSize || texts.length <= maxBatchSize) {
+      return [texts];
+    }
+
+    const batches: string[][] = [];
+    for (let i = 0; i < texts.length; i += maxBatchSize) {
+      batches.push(texts.slice(i, i + maxBatchSize));
+    }
+    return batches;
   }
 
-  async embedBatch(texts: string[]): Promise<BatchEmbeddingResult> {
+  private async embedRequest(texts: string[]): Promise<BatchEmbeddingResult> {
+    if (texts.length === 0) {
+      return {
+        embeddings: [],
+        totalTokensUsed: 0,
+      };
+    }
+
     const headers: Record<string, string> = {
       "Content-Type": "application/json",
     };
@@ -444,6 +449,39 @@ class CustomEmbeddingProvider implements EmbeddingProviderInterface {
     throw new Error("Custom embedding API returned unexpected response format. Expected OpenAI-compatible format with data[].embedding.");
   }
 
+  async embedQuery(query: string): Promise<EmbeddingResult> {
+    const result = await this.embedBatch([query]);
+    return {
+      embedding: result.embeddings[0],
+      tokensUsed: result.totalTokensUsed,
+    };
+  }
+
+  async embedDocument(document: string): Promise<EmbeddingResult> {
+    const result = await this.embedBatch([document]);
+    return {
+      embedding: result.embeddings[0],
+      tokensUsed: result.totalTokensUsed,
+    };
+  }
+
+  async embedBatch(texts: string[]): Promise<BatchEmbeddingResult> {
+    const requestBatches = this.splitIntoRequestBatches(texts);
+    const embeddings: number[][] = [];
+    let totalTokensUsed = 0;
+
+    for (const batch of requestBatches) {
+      const result = await this.embedRequest(batch);
+      embeddings.push(...result.embeddings);
+      totalTokensUsed += result.totalTokensUsed;
+    }
+
+    return {
+      embeddings,
+      totalTokensUsed,
+    };
+  }
+
   getModelInfo(): CustomModelInfo {
     return this.modelInfo;
   }
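With the provider refactor above, `embedBatch` fans a large batch out into sequential `/embeddings` requests of at most `maxBatchSize` inputs each, concatenating embeddings in order and summing token usage, so callers still see a single logical batch. A minimal usage sketch against the patch's own API (the `baseUrl` and model values are placeholders):

```ts
import { createCustomProviderInfo } from "./src/embeddings/detector.js";
import { createEmbeddingProvider } from "./src/embeddings/provider.js";

const info = createCustomProviderInfo({
  baseUrl: "http://localhost:8080/v1", // e.g. a text-embeddings-inference server
  model: "nomic-embed-text",
  dimensions: 768,
  maxBatchSize: 2, // cap inputs per /embeddings request
});
const provider = createEmbeddingProvider(info);

// Five inputs at maxBatchSize 2 fan out into ceil(5 / 2) = 3 requests:
// ["a","b"], ["c","d"], ["e"]. Order is preserved and token usage is summed.
const result = await provider.embedBatch(["a", "b", "c", "d", "e"]);
console.log(result.embeddings.length); // 5
```

The fan-out loop is deliberately sequential (one `await` per chunk), so existing per-request retry and rate-limit handling continues to apply to each underlying request.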
diff --git a/tests/config.test.ts b/tests/config.test.ts
index 3336783..daa2438 100644
--- a/tests/config.test.ts
+++ b/tests/config.test.ts
@@ -623,6 +623,32 @@ describe("config schema", () => {
     expect(config.customProvider!.requestIntervalMs).toBe(0);
   });
 
+  it("should parse custom provider with maxBatchSize", () => {
+    const config = parseConfig({
+      embeddingProvider: "custom",
+      customProvider: {
+        baseUrl: "http://localhost:11434/v1",
+        model: "test",
+        dimensions: 768,
+        maxBatchSize: 64,
+      },
+    });
+    expect(config.customProvider!.maxBatchSize).toBe(64);
+  });
+
+  it("should parse custom provider with max_batch_size alias", () => {
+    const config = parseConfig({
+      embeddingProvider: "custom",
+      customProvider: {
+        baseUrl: "http://localhost:11434/v1",
+        model: "test",
+        dimensions: 768,
+        max_batch_size: 32,
+      },
+    });
+    expect(config.customProvider!.maxBatchSize).toBe(32);
+  });
+
   it("should clamp concurrency to minimum of 1", () => {
     const config = parseConfig({
       embeddingProvider: "custom",
@@ -636,6 +662,19 @@ describe("config schema", () => {
     expect(config.customProvider!.concurrency).toBe(1);
   });
 
+  it("should clamp maxBatchSize to minimum of 1", () => {
+    const config = parseConfig({
+      embeddingProvider: "custom",
+      customProvider: {
+        baseUrl: "http://localhost:11434/v1",
+        model: "test",
+        dimensions: 768,
+        maxBatchSize: 0,
+      },
+    });
+    expect(config.customProvider!.maxBatchSize).toBe(1);
+  });
+
   it("should leave concurrency undefined when not provided", () => {
     const config = parseConfig({
       embeddingProvider: "custom",
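The splitting test below pins down the wire format by decoding each mocked `fetch` call's body and asserting on its `input` array. For orientation, each fan-out request body follows the OpenAI embeddings shape, roughly as sketched here (the tests only assert on `input`; fields beyond `model` and `input` may vary by server):

```ts
// Approximate shape of one request body in the fan-out, as the tests observe
// it via JSON.parse(body).input (a sketch, not the provider's source).
interface EmbeddingsRequestBody {
  model: string;   // e.g. "nomic-embed-text"
  input: string[]; // at most maxBatchSize entries per request
}

const firstChunk: EmbeddingsRequestBody = {
  model: "nomic-embed-text",
  input: ["text1", "text2"], // first of three requests for a 5-text batch
};
```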
diff --git a/tests/custom-provider.test.ts b/tests/custom-provider.test.ts
index 913f826..50a6746 100644
--- a/tests/custom-provider.test.ts
+++ b/tests/custom-provider.test.ts
@@ -1,6 +1,6 @@
 import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
 import { createEmbeddingProvider, CustomProviderNonRetryableError } from "../src/embeddings/provider.js";
-import { createCustomProviderInfo } from "../src/embeddings/detector.js";
+import { createCustomProviderInfo, type ConfiguredProviderInfo } from "../src/embeddings/detector.js";
 import { Indexer } from "../src/indexer/index.js";
 import { parseConfig } from "../src/config/schema.js";
 import pRetry from "p-retry";
@@ -11,6 +11,30 @@ import * as path from "path";
 describe("CustomEmbeddingProvider", () => {
   let fetchSpy: ReturnType<typeof vi.spyOn>;
 
+  function getCustomProviderInfo(
+    info: ConfiguredProviderInfo
+  ): Extract<ConfiguredProviderInfo, { provider: "custom" }> {
+    expect(info.provider).toBe("custom");
+    if (info.provider !== "custom") {
+      throw new Error("Expected custom provider info");
+    }
+    return info;
+  }
+
+  function getRejectedError(promise: Promise<unknown>): Promise<Error> {
+    return promise.then(
+      () => {
+        throw new Error("Expected promise to reject");
+      },
+      (error: unknown) => {
+        if (error instanceof Error) {
+          return error;
+        }
+        return new Error(String(error));
+      }
+    );
+  }
+
   beforeEach(() => {
     fetchSpy = vi.spyOn(globalThis, "fetch");
   });
@@ -94,6 +118,47 @@ describe("CustomEmbeddingProvider", () => {
     expect(result.totalTokensUsed).toBe(30);
   });
 
+  it("should split custom provider requests by maxBatchSize", async () => {
+    fetchSpy
+      .mockResolvedValueOnce(new Response(JSON.stringify({
+        data: [
+          { embedding: new Array(768).fill(0.1) },
+          { embedding: new Array(768).fill(0.2) },
+        ],
+        usage: { total_tokens: 20 },
+      }), { status: 200 }))
+      .mockResolvedValueOnce(new Response(JSON.stringify({
+        data: [
+          { embedding: new Array(768).fill(0.3) },
+          { embedding: new Array(768).fill(0.4) },
+        ],
+        usage: { total_tokens: 22 },
+      }), { status: 200 }))
+      .mockResolvedValueOnce(new Response(JSON.stringify({
+        data: [
+          { embedding: new Array(768).fill(0.5) },
+        ],
+        usage: { total_tokens: 11 },
+      }), { status: 200 }));
+
+    const info = createCustomProviderInfo({
+      baseUrl: "http://localhost:11434/v1",
+      model: "nomic-embed-text",
+      dimensions: 768,
+      maxBatchSize: 2,
+    });
+    const provider = createEmbeddingProvider(info);
+
+    const result = await provider.embedBatch(["text1", "text2", "text3", "text4", "text5"]);
+
+    expect(fetchSpy).toHaveBeenCalledTimes(3);
+    expect(JSON.parse((fetchSpy.mock.calls[0] as [string, RequestInit])[1].body as string).input).toEqual(["text1", "text2"]);
+    expect(JSON.parse((fetchSpy.mock.calls[1] as [string, RequestInit])[1].body as string).input).toEqual(["text3", "text4"]);
+    expect(JSON.parse((fetchSpy.mock.calls[2] as [string, RequestInit])[1].body as string).input).toEqual(["text5"]);
+    expect(result.embeddings).toHaveLength(5);
+    expect(result.totalTokensUsed).toBe(53);
+  });
+
   it("should estimate tokens when usage is not provided", async () => {
     fetchSpy.mockResolvedValueOnce(new Response(JSON.stringify({
       data: [{ embedding: new Array(768).fill(0) }],
@@ -213,28 +278,28 @@ describe("CustomEmbeddingProvider", () => {
   });
 
   it("should default timeout to 30000ms", () => {
-    const info = createCustomProviderInfo({
+    const info = getCustomProviderInfo(createCustomProviderInfo({
       baseUrl: "http://localhost:11434/v1",
       model: "nomic-embed-text",
       dimensions: 768,
-    });
+    }));
     expect(info.modelInfo.timeoutMs).toBe(30000);
   });
 
   it("should use custom timeout value from config", () => {
-    const info = createCustomProviderInfo({
+    const info = getCustomProviderInfo(createCustomProviderInfo({
       baseUrl: "http://localhost:11434/v1",
       model: "nomic-embed-text",
       dimensions: 768,
       timeoutMs: 60000,
-    });
+    }));
     expect(info.modelInfo.timeoutMs).toBe(60000);
   });
 
   it("should throw non-retryable error on 4xx responses (except 429)", async () => {
     fetchSpy.mockResolvedValueOnce(new Response("Unauthorized", { status: 401 }));
     const provider = createProvider();
-    const error = await provider.embedQuery("test").catch((e: Error) => e);
+    const error = await getRejectedError(provider.embedQuery("test"));
     expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
     expect(error.message).toContain("non-retryable");
     expect(error.message).toContain("401");
@@ -243,21 +308,21 @@ describe("CustomEmbeddingProvider", () => {
   it("should throw non-retryable error on 400 Bad Request", async () => {
     fetchSpy.mockResolvedValueOnce(new Response("Bad model name", { status: 400 }));
     const provider = createProvider();
-    const error = await provider.embedQuery("test").catch((e: Error) => e);
+    const error = await getRejectedError(provider.embedQuery("test"));
     expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
   });
 
   it("should throw non-retryable error on 403 Forbidden", async () => {
     fetchSpy.mockResolvedValueOnce(new Response("Forbidden", { status: 403 }));
     const provider = createProvider();
-    const error = await provider.embedQuery("test").catch((e: Error) => e);
+    const error = await getRejectedError(provider.embedQuery("test"));
     expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
   });
 
   it("should throw retryable error on 429 rate limit", async () => {
     fetchSpy.mockResolvedValueOnce(new Response("Rate limited", { status: 429 }));
     const provider = createProvider();
-    const error = await provider.embedQuery("test").catch((e: Error) => e);
+    const error = await getRejectedError(provider.embedQuery("test"));
     expect(error).not.toBeInstanceOf(CustomProviderNonRetryableError);
     expect(error.message).toContain("429");
   });
@@ -265,7 +330,7 @@ describe("CustomEmbeddingProvider", () => {
   it("should throw retryable error on 5xx server errors", async () => {
     fetchSpy.mockResolvedValueOnce(new Response("Internal Server Error", { status: 500 }));
     const provider = createProvider();
-    const error = await provider.embedQuery("test").catch((e: Error) => e);
+    const error = await getRejectedError(provider.embedQuery("test"));
     expect(error).not.toBeInstanceOf(CustomProviderNonRetryableError);
     expect(error.message).toContain("500");
   });
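One design note on the test helpers above: the old `.catch((e: Error) => e)` pattern would let an unexpectedly *resolved* value flow into the error assertions while typed as `Error`. `getRejectedError` closes that hole by failing the test on resolution and normalizing non-`Error` rejections. A small sketch of the difference:

```ts
// Old pattern (sketch): a resolved promise slips through untyped.
const leaked = await Promise.resolve("ok").catch((e: Error) => e);
// leaked === "ok", yet the assertions would have treated it as an Error.

// New pattern: resolution is itself a test failure.
await getRejectedError(Promise.resolve("ok")); // throws "Expected promise to reject"
```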
diff --git a/tests/embeddings.test.ts b/tests/embeddings.test.ts
index d8378f9..9a76628 100644
--- a/tests/embeddings.test.ts
+++ b/tests/embeddings.test.ts
@@ -67,5 +67,15 @@ describe("embeddings detector", () => {
       });
       expect(info.modelInfo.maxTokens).toBe(4096);
     });
+
+    it("should pass through optional maxBatchSize", () => {
+      const info = createCustomProviderInfo({
+        baseUrl: "http://localhost/v1",
+        model: "test",
+        dimensions: 512,
+        maxBatchSize: 64,
+      });
+      expect(info.modelInfo.maxBatchSize).toBe(64);
+    });
   });
 });