From 3b156aa19d05a9f14effe1c5baeed099da4ce1eb Mon Sep 17 00:00:00 2001
From: Daniel Sogl
Date: Sun, 21 Sep 2025 21:58:11 +0200
Subject: [PATCH] feat: add grok provider

---
 README.md                               |   1 +
 runner/codegen/genkit/models.ts         |   2 +
 runner/codegen/genkit/providers/grok.ts | 130 ++++++++++++++++++++++++
 3 files changed, 133 insertions(+)
 create mode 100644 runner/codegen/genkit/providers/grok.ts

diff --git a/README.md b/README.md
index 8e88691..1d9f41c 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@ npm install -g web-codegen-scorer
 export GEMINI_API_KEY="YOUR_API_KEY_HERE" # If you're using Gemini models
 export OPENAI_API_KEY="YOUR_API_KEY_HERE" # If you're using OpenAI models
 export ANTHROPIC_API_KEY="YOUR_API_KEY_HERE" # If you're using Anthropic models
+export XAI_API_KEY="YOUR_API_KEY_HERE" # If you're using xAI Grok models
 ```
 
 3. **Run an eval:**
diff --git a/runner/codegen/genkit/models.ts b/runner/codegen/genkit/models.ts
index 61d7de9..1c8f113 100644
--- a/runner/codegen/genkit/models.ts
+++ b/runner/codegen/genkit/models.ts
@@ -1,9 +1,11 @@
 import { GeminiModelProvider } from './providers/gemini.js';
 import { ClaudeModelProvider } from './providers/claude.js';
 import { OpenAiModelProvider } from './providers/open-ai.js';
+import { GrokModelProvider } from './providers/grok.js';
 
 export const MODEL_PROVIDERS = [
   new GeminiModelProvider(),
   new ClaudeModelProvider(),
   new OpenAiModelProvider(),
+  new GrokModelProvider(),
 ];
diff --git a/runner/codegen/genkit/providers/grok.ts b/runner/codegen/genkit/providers/grok.ts
new file mode 100644
index 0000000..cbbd4c3
--- /dev/null
+++ b/runner/codegen/genkit/providers/grok.ts
@@ -0,0 +1,130 @@
+import { xAI } from '@genkit-ai/compat-oai/xai';
+import { GenkitPlugin, GenkitPluginV2 } from 'genkit/plugin';
+import { RateLimiter } from 'limiter';
+import fetch from 'node-fetch';
+import {
+  GenkitModelProvider,
+  PromptDataForCounting,
+  RateLimitConfig,
+} from '../model-provider.js';
+
+/** Genkit model provider for xAI Grok models, keyed by `XAI_API_KEY`. */
+export class GrokModelProvider extends GenkitModelProvider {
+  readonly apiKeyVariableName = 'XAI_API_KEY';
+
+  protected readonly models = {
+    'grok-4': () => xAI.model('grok-4'),
+    'grok-code-fast-1': () => xAI.model('grok-code-fast-1'),
+  };
+
+  /**
+   * Counts prompt tokens using xAI's tokenize-text endpoint.
+   *
+   * @param modelName xAI model whose tokenizer should be used.
+   * @param prompt Prompt data that is flattened into text and tokenized.
+   * @returns The token count, or `null` when no API key is set or the
+   *   request does not yield a usable result.
+   */
+  private async countTokensWithXaiApi(
+    modelName: string,
+    prompt: PromptDataForCounting
+  ): Promise<number | null> {
+    const apiKey = this.getApiKey();
+    if (!apiKey) {
+      return null;
+    }
+
+    try {
+      const messages = this.genkitPromptToXaiFormat(prompt);
+      const text = messages.map((m) => `${m.role}: ${m.content}`).join('\n');
+
+      // The endpoint requires the model name and answers with `token_ids`.
+      const response = await fetch('https://api.x.ai/v1/tokenize-text', {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+          Authorization: `Bearer ${apiKey}`,
+        },
+        body: JSON.stringify({ model: modelName, text }),
+      });
+
+      if (!response.ok) {
+        console.warn(`xAI tokenize request failed: HTTP ${response.status}`);
+        return null;
+      }
+
+      const data = (await response.json()) as { token_ids?: unknown[] };
+      return Array.isArray(data.token_ids) ? data.token_ids.length : null;
+    } catch (error) {
+      console.warn('Failed to count tokens using xAI API', error);
+      return null;
+    }
+  }
+
+  /**
+   * Counts prompt tokens for the rate limiter.
+   *
+   * @param modelName xAI model the prompt is sent to.
+   * @param prompt Prompt data to count tokens for.
+   * @returns The xAI token count, or `0` when it could not be determined.
+   */
+  private async countTokensForModel(
+    modelName: string,
+    prompt: PromptDataForCounting
+  ): Promise<number> {
+    const xaiTokenCount = await this.countTokensWithXaiApi(modelName, prompt);
+    return xaiTokenCount ?? 0;
+  }
+
+  protected rateLimitConfig: Record<string, RateLimitConfig> = {
+    // XAI Grok rate limits https://docs.x.ai/docs/models
+    'xai/grok-4': {
+      requestPerMinute: new RateLimiter({
+        tokensPerInterval: 480,
+        interval: 1000 * 60 * 1.5, // Refresh tokens after 1.5 minutes to be on the safe side
+      }),
+      tokensPerMinute: new RateLimiter({
+        tokensPerInterval: 2_000_000 * 0.75,
+        interval: 1000 * 60 * 1.5, // Refresh tokens after 1.5 minutes to be on the safe side
+      }),
+      countTokens: (prompt) => this.countTokensForModel('grok-4', prompt),
+    },
+    'xai/grok-code-fast-1': {
+      requestPerMinute: new RateLimiter({
+        tokensPerInterval: 480,
+        interval: 1000 * 60 * 1.5, // Refresh tokens after 1.5 minutes to be on the safe side
+      }),
+      tokensPerMinute: new RateLimiter({
+        tokensPerInterval: 2_000_000 * 0.75,
+        interval: 1000 * 60 * 1.5, // Refresh tokens after 1.5 minutes to be on the safe side
+      }),
+      countTokens: (prompt) =>
+        this.countTokensForModel('grok-code-fast-1', prompt),
+    },
+  };
+
+  protected pluginFactory(apiKey: string): GenkitPlugin | GenkitPluginV2 {
+    return xAI({ apiKey });
+  }
+
+  getModelSpecificConfig(): object {
+    // Grok doesn't require special configuration at this time
+    return {};
+  }
+
+  /** Flattens the prompt into one `{ role, content }` entry per content part. */
+  private genkitPromptToXaiFormat(
+    prompt: PromptDataForCounting
+  ): Array<{ role: string; content: string }> {
+    const xaiPrompt: Array<{ role: string; content: string }> = [];
+    for (const part of prompt.messages) {
+      for (const c of part.content) {
+        xaiPrompt.push({
+          role: part.role,
+          content: 'media' in c ? c.media.url : c.text,
+        });
+      }
+    }
+    return [...xaiPrompt, { role: 'user', content: prompt.prompt }];
+  }
+}