Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 45 additions & 34 deletions packages/gateway/src/worker-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const MODELS_DEV_API = "https://models.dev/api.json";
/** Cached models.dev data: model entries for all supported providers. */
let cachedModelData: Map<string, ModelsDevEntry> | null = null;
/** `Date.now()` timestamp of the last successful cache fill; 0 while nothing is cached. */
let cachedModelDataAt = 0;
/** Promise of the request currently in progress, shared by concurrent callers; null when idle. */
let inflightFetch: Promise<Map<string, ModelsDevEntry>> | null = null;
/** Maximum age of the cached data before `fetchModelData` attempts a refresh. */
const CACHE_TTL_MS = 60 * 60 * 1000; // 1 hour

/** Providers to fetch pricing data for from models.dev. */
Expand Down Expand Up @@ -99,53 +100,62 @@ function fallbackEntry(modelID: string): ModelsDevEntry {
* Single HTTP request, cached for 1 hour. Returns a map of
* modelID → entry with cost and limit data across all supported providers.
*/
export function fetchModelData(): Promise<Map<string, ModelsDevEntry>> {
  // Serve straight from the cache while it is younger than the TTL.
  if (cachedModelData && Date.now() - cachedModelDataAt < CACHE_TTL_MS) {
    return Promise.resolve(cachedModelData);
  }

  // Deduplicate concurrent calls: every caller that arrives while a request
  // is running shares the same promise (and therefore the same Map instance).
  if (inflightFetch) return inflightFetch;

  // `.finally` callbacks always run asynchronously, so `request` is assigned
  // (and published in `inflightFetch`) before the cleanup can execute — even
  // when the underlying fetch fails immediately.
  const request: Promise<Map<string, ModelsDevEntry>> = fetchModelDataUncached().finally(() => {
    // Only release the slot if it still belongs to this request. After
    // clearModelDataCache() a newer request may own it, and wiping that one
    // would let a third caller start a duplicate fetch.
    if (inflightFetch === request) inflightFetch = null;
  });
  inflightFetch = request;
  return request;
}

/**
 * Performs a single models.dev request and, on success, refreshes the module
 * cache (`cachedModelData` / `cachedModelDataAt`).
 *
 * Failures (non-2xx status, network error, abort after 10 s, malformed JSON)
 * are logged and resolved to the previously cached map, or to an empty map
 * when nothing has been cached yet.
 *
 * @returns Map of modelID → entry across all `SUPPORTED_PROVIDERS`.
 */
async function fetchModelDataUncached(): Promise<Map<string, ModelsDevEntry>> {
  const controller = new AbortController();
  // The 10 s deadline stays armed until the body has been read, so a stalled
  // response cannot keep the shared in-flight promise pending forever.
  const timeout = setTimeout(() => controller.abort(), 10_000);

  try {
    const response = await fetch(MODELS_DEV_API, { signal: controller.signal });

    if (!response.ok) {
      log.warn(`models.dev API failed: ${response.status} ${response.statusText}`);
      return cachedModelData ?? new Map();
    }

    const data = (await response.json()) as ModelsDevResponse;
    const modelData = new Map<string, ModelsDevEntry>();

    for (const providerName of SUPPORTED_PROVIDERS) {
      const providerModels = data[providerName]?.models;
      if (!providerModels) {
        log.warn(`models.dev API: no ${providerName} provider found`);
        continue;
      }

      for (const [modelId, entry] of Object.entries(providerModels)) {
        const e: ModelsDevEntry = { ...entry, id: modelId };
        // Compute cache_write cost if not provided (typically 1.25× input price)
        if (e.cost && e.cost.cache_write == null && e.cost.input != null) {
          e.cost.cache_write = e.cost.input * 1.25;
        }
        modelData.set(modelId, e);
      }
    }

    cachedModelData = modelData;
    cachedModelDataAt = Date.now();

    log.info(`models.dev: loaded data for ${modelData.size} models across ${SUPPORTED_PROVIDERS.join(", ")}`);
    return modelData;
  } catch (e) {
    log.warn("models.dev API error:", e);
    return cachedModelData ?? new Map();
  } finally {
    // Runs on success and on every failure path, so the timer never outlives
    // the request.
    clearTimeout(timeout);
  }
}

/**
Expand Down Expand Up @@ -206,6 +216,7 @@ export function getModelEntrySync(modelID: string): ModelsDevEntry {
export function clearModelDataCache(): void {
  // Forget any request that is still running so the next caller starts a new one.
  inflightFetch = null;
  // Drop the cached entries and reset their timestamp together, so the
  // freshness check cannot see one without the other.
  cachedModelDataAt = 0;
  cachedModelData = null;
}

// ---------------------------------------------------------------------------
Expand Down
34 changes: 34 additions & 0 deletions packages/gateway/test/worker-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,40 @@ describe("fetchModelData", () => {
expect(data.size).toBe(0);
});

test("deduplicates concurrent in-flight requests", async () => {
  // Count how many times the network layer is actually hit.
  let fetchCalls = 0;
  const countingFetch = mock(async () => {
    fetchCalls += 1;
    const body = JSON.stringify(buildModelsDevResponse(DEFAULT_MODELS));
    return new Response(body, { status: 200 });
  });
  globalThis.fetch = countingFetch as unknown as typeof fetch;

  // Start three lookups in the same tick, before any of them can settle.
  const pending = [fetchModelData(), fetchModelData(), fetchModelData()];
  const [first, second, third] = await Promise.all(pending);

  // One HTTP request, and every caller receives the very same Map instance.
  expect(fetchCalls).toBe(1);
  expect(first).toBe(second);
  expect(second).toBe(third);
});

test("deduplicates concurrent calls even on network error", async () => {
  // Every fetch attempt is counted and then rejected.
  let fetchCalls = 0;
  const failingFetch = mock(async () => {
    fetchCalls += 1;
    throw new Error("Network error");
  });
  globalThis.fetch = failingFetch as unknown as typeof fetch;

  const pending = [fetchModelData(), fetchModelData()];
  const [first, second] = await Promise.all(pending);

  // A failing request is still shared, and both callers get the same empty fallback.
  expect(fetchCalls).toBe(1);
  expect(first).toBe(second);
  expect(first.size).toBe(0);
});

test("handles missing providers gracefully", async () => {
globalThis.fetch = mock(() =>
Promise.resolve(
Expand Down
Loading