From 1484ffb8809a12c8ca8ae7aee8e71ceb3f59deed Mon Sep 17 00:00:00 2001 From: Alexandre Vilain Date: Wed, 20 Aug 2025 14:32:02 +0200 Subject: [PATCH] feat(autocomplete): add FIM support with Mistral Codestral provider Add Fill-in-the-Middle (FIM) completion support by restructuring the autocomplete system. Introduce dedicated HoleFiller implementations for different completion strategies and add Mistral Codestral provider with native FIM capabilities. - Add ai-sdk-mistral-fim dependency for FIM support - Refactor HoleFiller interface to use prompt-based architecture - Extract DefaultHoleFiller from holeFiller.ts with existing hole-filling logic - Add MistralFimHoleFiller for native FIM completion - Create CodestralProvider with Mistral FIM integration - Add provider detection logic to select appropriate completion strategy - Consolidate autocomplete exports in index.ts - Update completion provider to dynamically choose HoleFiller based on provider type --- package-lock.json | 35 +++++++ package.json | 1 + src/autocomplete/context.ts | 7 -- src/autocomplete/defaultHoleFiller.ts | 114 +++++++++++++++++++++ src/autocomplete/holeFiller.ts | 120 +++-------------------- src/autocomplete/index.ts | 3 + src/autocomplete/mistralfimHoleFiller.ts | 14 +++ src/providers/codestral.ts | 24 +++++ src/providers/providers.ts | 11 ++- src/vscode/completionProvider.ts | 13 +-- src/vscode/profileCommandProvider.ts | 18 ++-- 11 files changed, 225 insertions(+), 135 deletions(-) delete mode 100644 src/autocomplete/context.ts create mode 100644 src/autocomplete/defaultHoleFiller.ts create mode 100644 src/autocomplete/index.ts create mode 100644 src/autocomplete/mistralfimHoleFiller.ts create mode 100644 src/providers/codestral.ts diff --git a/package-lock.json b/package-lock.json index b9bab89..7d65401 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,7 @@ "@ai-sdk/openai-compatible": "^1.0.7", "@ai-sdk/provider": "^2.0.0", "ai": "^5.0.14", + "ai-sdk-mistral-fim": "^0.0.1", "ai-sdk-ollama": "^0.5.0", "uuid": "^11.1.0" }, @@ -2730,6 +2731,40 @@ "zod": "^3.25.76 || ^4" } }, + "node_modules/ai-sdk-mistral-fim": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/ai-sdk-mistral-fim/-/ai-sdk-mistral-fim-0.0.1.tgz", + "integrity": "sha512-00QOOA8p2+URcJjZ747gQSmcmx+VSuK9JdkP5tMjI1iKpkbtd8pEr/Z3JUIDTwmEr5qbqfVmyXd7YQpkNzdpGA==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "^2.0.0", + "@ai-sdk/provider-utils": "^3.0.4" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4" + } + }, + "node_modules/ai-sdk-mistral-fim/node_modules/@ai-sdk/provider-utils": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.4.tgz", + "integrity": "sha512-/3Z6lfUp8r+ewFd9yzHkCmPlMOJUXup2Sx3aoUyrdXLhOmAfHRl6Z4lDbIdV0uvw/QYoBcVLJnvXN7ncYeS3uQ==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "2.0.0", + "@standard-schema/spec": "^1.0.0", + "eventsource-parser": "^3.0.3", + "zod-to-json-schema": "^3.24.1" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4" + } + }, "node_modules/ai-sdk-ollama": { "version": "0.5.0", "resolved": "https://registry.npmjs.org/ai-sdk-ollama/-/ai-sdk-ollama-0.5.0.tgz", diff --git a/package.json b/package.json index 088c38b..3325e83 100644 --- a/package.json +++ b/package.json @@ -150,6 +150,7 @@ "@ai-sdk/openai-compatible": "^1.0.7", "@ai-sdk/provider": "^2.0.0", "ai": "^5.0.14", + "ai-sdk-mistral-fim": "^0.0.1", "ai-sdk-ollama": "^0.5.0", "uuid": "^11.1.0" } diff --git a/src/autocomplete/context.ts b/src/autocomplete/context.ts deleted file mode 100644 index 78e6bf1..0000000 --- a/src/autocomplete/context.ts +++ /dev/null @@ -1,7 +0,0 @@ -export type AutoCompleteContext = { - textBeforeCursor: string, - textAfterCursor: string, - currentLineText: string, - filename?: string, - language?: string, -} \ No newline at end of file diff --git a/src/autocomplete/defaultHoleFiller.ts b/src/autocomplete/defaultHoleFiller.ts new file mode 100644 index 0000000..bda032c --- /dev/null +++ b/src/autocomplete/defaultHoleFiller.ts @@ -0,0 +1,114 @@ +import { HoleFiller, PromptArgs, AutoCompleteContext } from "./holeFiller"; + +// Source: continue/core/autocomplete/templating/AutocompleteTemplate.ts (holeFillerTemplate) +export class DefaultHoleFiller implements HoleFiller { + systemPrompt(): string { + // From https://github.com/VictorTaelin/AI-scripts + return `You are a HOLE FILLER. You are provided with a file containing holes, formatted as '{{HOLE_NAME}}'. + Your TASK is to complete with a string to replace this hole with, inside a XML tag, including context-aware indentation, if needed. + All completions MUST be truthful, accurate, well-written and correct. +## EXAMPLE QUERY: + + +function sum_evens(lim) { + var sum = 0; + for (var i = 0; i < lim; ++i) { + {{FILL_HERE}} + } + return sum; +} + + +TASK: Fill the {{FILL_HERE}} hole. + +## CORRECT COMPLETION + +if (i % 2 === 0) { + sum += i; + } + +## EXAMPLE QUERY: + + +def sum_list(lst): + total = 0 + for x in lst: + {{FILL_HERE}} + return total + +print sum_list([1, 2, 3]) + + +## CORRECT COMPLETION: + + total += x + +## EXAMPLE QUERY: + + +// data Tree a = Node (Tree a) (Tree a) | Leaf a + +// sum :: Tree Int -> Int +// sum (Node lft rgt) = sum lft + sum rgt +// sum (Leaf val) = val + +// convert to TypeScript: +{{FILL_HERE}} + + +## CORRECT COMPLETION: + +type Tree + = {$:"Node", lft: Tree, rgt: Tree} + | {$:"Leaf", val: T}; + +function sum(tree: Tree): number { + switch (tree.$) { + case "Node": + return sum(tree.lft) + sum(tree.rgt); + case "Leaf": + return tree.val; + } +} + +## EXAMPLE QUERY: + +The 5th {{FILL_HERE}} is Jupiter. + +## CORRECT COMPLETION: + +planet from the Sun + +## EXAMPLE QUERY: + +function hypothenuse(a, b) { + return Math.sqrt({{FILL_HERE}}b ** 2); +} + +## CORRECT COMPLETION: + +a ** 2 + +`; + } + + userPrompt(ctx: AutoCompleteContext): string { + let context = ''; + if (ctx.filename !== '') { + context += `// Filename: "${ctx.filename}" \n`; + } + if (ctx.language !== '') { + context += `// Programming language: "${ctx.language}" \n`; + } + return `${context}\n${ctx.textBeforeCursor}{{FILL_HERE}}${ctx.textAfterCursor}\n\nTASK: Fill the {{FILL_HERE}} hole. Answer only with the CORRECT completion, and NOTHING ELSE. Do it now.\n`; + } + + prompt(params: AutoCompleteContext): PromptArgs { + return { + messages: [ + { role: "system", content: this.systemPrompt() }, + { role: "user", content: this.userPrompt(params) }, + ], + }; + } +} + diff --git a/src/autocomplete/holeFiller.ts b/src/autocomplete/holeFiller.ts index 208e84b..f62d8bc 100644 --- a/src/autocomplete/holeFiller.ts +++ b/src/autocomplete/holeFiller.ts @@ -1,110 +1,18 @@ -import { type AutoCompleteContext } from "./context"; +import { Prompt } from 'ai'; +import { ProviderOptions } from '@ai-sdk/provider-utils'; export interface HoleFiller { - systemPrompt(): string - userPrompt(params: AutoCompleteContext): string -} - -// Source: continue/core/autocomplete/templating/AutocompleteTemplate.ts (holeFillerTemplate) -export class DefaultHoleFiller implements HoleFiller { - systemPrompt(): string { - // From https://github.com/VictorTaelin/AI-scripts - return `You are a HOLE FILLER. You are provided with a file containing holes, formatted as '{{HOLE_NAME}}'. - Your TASK is to complete with a string to replace this hole with, inside a XML tag, including context-aware indentation, if needed. - All completions MUST be truthful, accurate, well-written and correct. -## EXAMPLE QUERY: - - -function sum_evens(lim) { - var sum = 0; - for (var i = 0; i < lim; ++i) { - {{FILL_HERE}} - } - return sum; -} - - -TASK: Fill the {{FILL_HERE}} hole. - -## CORRECT COMPLETION - -if (i % 2 === 0) { - sum += i; - } - -## EXAMPLE QUERY: - - -def sum_list(lst): - total = 0 - for x in lst: - {{FILL_HERE}} - return total - -print sum_list([1, 2, 3]) - - -## CORRECT COMPLETION: - - total += x - -## EXAMPLE QUERY: - - -// data Tree a = Node (Tree a) (Tree a) | Leaf a - -// sum :: Tree Int -> Int -// sum (Node lft rgt) = sum lft + sum rgt -// sum (Leaf val) = val - -// convert to TypeScript: -{{FILL_HERE}} - - -## CORRECT COMPLETION: - -type Tree - = {$:"Node", lft: Tree, rgt: Tree} - | {$:"Leaf", val: T}; - -function sum(tree: Tree): number { - switch (tree.$) { - case "Node": - return sum(tree.lft) + sum(tree.rgt); - case "Leaf": - return tree.val; - } -} - -## EXAMPLE QUERY: - -The 5th {{FILL_HERE}} is Jupiter. - -## CORRECT COMPLETION: - -planet from the Sun - -## EXAMPLE QUERY: - -function hypothenuse(a, b) { - return Math.sqrt({{FILL_HERE}}b ** 2); -} - -## CORRECT COMPLETION: - -a ** 2 + -`; - } - - userPrompt(ctx: AutoCompleteContext): string { - let context = ''; - if (ctx.filename !== '') { - context += `// Filename: "${ctx.filename}" \n`; - } - if (ctx.language !== '') { - context += `// Programming language: "${ctx.language}" \n`; - } - return `${context}\n${ctx.textBeforeCursor}{{FILL_HERE}}${ctx.textAfterCursor}\n\nTASK: Fill the {{FILL_HERE}} hole. Answer only with the CORRECT completion, and NOTHING ELSE. Do it now.\n`; - } + prompt(params: AutoCompleteContext): PromptArgs } +export type PromptArgs = Prompt & { + providerOptions?: ProviderOptions; +}; + +export type AutoCompleteContext = { + textBeforeCursor: string, + textAfterCursor: string, + currentLineText: string, + filename?: string, + language?: string, +} \ No newline at end of file diff --git a/src/autocomplete/index.ts b/src/autocomplete/index.ts new file mode 100644 index 0000000..6aea0fd --- /dev/null +++ b/src/autocomplete/index.ts @@ -0,0 +1,3 @@ +export * from './defaultHoleFiller'; +export * from './mistralfimHoleFiller'; +export * from './holeFiller'; \ No newline at end of file diff --git a/src/autocomplete/mistralfimHoleFiller.ts b/src/autocomplete/mistralfimHoleFiller.ts new file mode 100644 index 0000000..d60bde4 --- /dev/null +++ b/src/autocomplete/mistralfimHoleFiller.ts @@ -0,0 +1,14 @@ +import { HoleFiller, PromptArgs, AutoCompleteContext } from "./holeFiller"; + +export class MistralFimHoleFiller implements HoleFiller { + prompt(params: AutoCompleteContext): PromptArgs { + return { + prompt: params.textBeforeCursor, + providerOptions: { + 'mistral.fim': { + suffix: params.textAfterCursor, + } + } + }; + } +} diff --git a/src/providers/codestral.ts b/src/providers/codestral.ts new file mode 100644 index 0000000..41e585e --- /dev/null +++ b/src/providers/codestral.ts @@ -0,0 +1,24 @@ +import { ProfileWithAPIKey, ProviderConnection, Model } from "../types"; +import { type LanguageModelV2 } from "@ai-sdk/provider"; +import { LanguageModelProvider } from "./providers"; +import { createMistralFim } from 'ai-sdk-mistral-fim'; + +export class CodestralProvider implements LanguageModelProvider { + languageModel(profile: ProfileWithAPIKey): LanguageModelV2 { + return createMistralFim({ + baseURL: profile.baseURL, + apiKey: profile.apiKey, + })(profile.modelId); + } + + async listModels(_conn: ProviderConnection): Promise { + return new Promise((resolve) => { + resolve([ + { + id: 'codestral-latest', + name: 'codestral-latest' + }, + ]); + }); + } +} diff --git a/src/providers/providers.ts b/src/providers/providers.ts index a31a9c4..e3fcc20 100644 --- a/src/providers/providers.ts +++ b/src/providers/providers.ts @@ -2,6 +2,7 @@ import { ProfileWithAPIKey, Provider, ProviderConnection, ProviderID, Model } fr import { OpenAICompatibleProvider } from "./openaiCompatible"; import { OllamaProvider } from "./ollama"; import { type LanguageModelV2 } from "@ai-sdk/provider"; +import { CodestralProvider } from "./codestral"; export interface LanguageModelProvider { languageModel(profile: ProfileWithAPIKey): LanguageModelV2 @@ -17,10 +18,11 @@ function languageModelProvider(providerId: ProviderID): LanguageModelProvider { case 'groq': case 'openai-compatible': case 'mistral': - case 'mistral-codestral': // TODO: we should support FIM endpoint. - return new OpenAICompatibleProvider(); + return new OpenAICompatibleProvider(); case 'ollama': - return new OllamaProvider(); + return new OllamaProvider(); + case 'mistral-codestral': + return new CodestralProvider(); default: throw new Error(`Unsupported provider: ${providerId}`); } @@ -34,6 +36,9 @@ export function getLanguageModelFromProfile(profile: ProfileWithAPIKey): Languag return languageModelProvider(profile.provider).languageModel(profile); } +export function isFimProvider(provider: ProviderID): boolean { + return provider === 'mistral-codestral'; +} export const providers: Provider[] = [ { diff --git a/src/vscode/completionProvider.ts b/src/vscode/completionProvider.ts index 3e5b88e..c8933d2 100644 --- a/src/vscode/completionProvider.ts +++ b/src/vscode/completionProvider.ts @@ -1,8 +1,7 @@ import * as vscode from 'vscode'; -import { HoleFiller, DefaultHoleFiller } from '../autocomplete/holeFiller'; -import { AutoCompleteContext } from '../autocomplete/context'; +import { MistralFimHoleFiller, AutoCompleteContext, DefaultHoleFiller } from '../autocomplete'; import { ProfileService } from '../services/profileService'; -import { getLanguageModelFromProfile } from '../providers/providers'; +import { getLanguageModelFromProfile, isFimProvider } from '../providers/providers'; import { generateText } from 'ai'; import { TabCoderStatusBarProvider } from './statusBarProvider'; import { logger } from '../utils/logger'; @@ -10,7 +9,6 @@ import { LanguageModelV2 } from '@ai-sdk/provider'; import { ProfileWithAPIKey } from '../types'; export class TabCoderInlineCompletionProvider implements vscode.InlineCompletionItemProvider { - private holeFiller: HoleFiller = new DefaultHoleFiller(); private profileService: ProfileService; private statusBarProvider: TabCoderStatusBarProvider; private debounceTimeout: NodeJS.Timeout | undefined; @@ -257,12 +255,11 @@ export class TabCoderInlineCompletionProvider implements vscode.InlineCompletion logger.debug(`Request ${requestId}: Using cached model for profile ${profile.id}`); } + const holeFiller = isFimProvider(profile.provider) ? new MistralFimHoleFiller() : new DefaultHoleFiller(); + const { text, usage } = await generateText({ model: this.cachedModel!, - messages: [ - { role: "system", content: this.holeFiller.systemPrompt() }, - { role: "user", content: this.holeFiller.userPrompt(params) }, - ], + ...holeFiller.prompt(params), abortSignal: this.currentAbortController.signal, }); diff --git a/src/vscode/profileCommandProvider.ts b/src/vscode/profileCommandProvider.ts index ebcc638..d61b15b 100644 --- a/src/vscode/profileCommandProvider.ts +++ b/src/vscode/profileCommandProvider.ts @@ -79,16 +79,12 @@ export class ProfileCommandProvider { apiKey = inputApiKey; } - let selectedModelId = ''; - if (selectedProvider.id === 'mistral-codestral') { - selectedModelId = 'codestral-latest'; - } else { - selectedModelId = await this.askForModel({ - id: selectedProvider.id, - baseURL, - apiKey - }); - } + const selectedModelId = await this.loadAndAskForModel({ + id: selectedProvider.id, + baseURL, + apiKey + }); + if (!selectedModelId) { return; @@ -192,7 +188,7 @@ export class ProfileCommandProvider { } } - async askForModel(conn: ProviderConnection): Promise { + async loadAndAskForModel(conn: ProviderConnection): Promise { // Load models for the selected provider. const qp = vscode.window.createQuickPick(); qp.title = 'Create New AI Profile - Step 4 of 5';