Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/abliteration-provider.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"roo-cline": patch
---

Add abliteration.ai as a provider.
5 changes: 5 additions & 0 deletions apps/cli/src/lib/utils/__tests__/provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ describe("getApiKeyFromEnv", () => {
expect(getApiKeyFromEnv("anthropic")).toBe("test-anthropic-key")
})

it("should return API key from environment variable for abliteration.ai", () => {
process.env.ABLIT_KEY = "test-abliteration-key"
expect(getApiKeyFromEnv("abliteration")).toBe("test-abliteration-key")
})

it("should return API key from environment variable for openrouter", () => {
process.env.OPENROUTER_API_KEY = "test-openrouter-key"
expect(getApiKeyFromEnv("openrouter")).toBe("test-openrouter-key")
Expand Down
5 changes: 5 additions & 0 deletions apps/cli/src/lib/utils/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { RooCodeSettings } from "@roo-code/types"
import type { SupportedProvider } from "@/types/index.js"

const envVarMap: Record<SupportedProvider, string> = {
abliteration: "ABLIT_KEY",
anthropic: "ANTHROPIC_API_KEY",
"openai-native": "OPENAI_API_KEY",
gemini: "GOOGLE_API_KEY",
Expand All @@ -28,6 +29,10 @@ export function getProviderSettings(
const config: RooCodeSettings = { apiProvider: provider }

switch (provider) {
case "abliteration":
if (apiKey) config.abliterationApiKey = apiKey
if (model) config.apiModelId = model
break
case "anthropic":
if (apiKey) config.apiKey = apiKey
if (model) config.apiModelId = model
Expand Down
1 change: 1 addition & 0 deletions apps/cli/src/types/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { ProviderName, ReasoningEffortExtended } from "@roo-code/types"
import type { OutputFormat } from "./json-events.js"

export const supportedProviders = [
"abliteration",
"anthropic",
"openai-native",
"gemini",
Expand Down
1 change: 1 addition & 0 deletions packages/types/src/global-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ export type RooCodeSettings = GlobalSettings & ProviderSettings
export const SECRET_STATE_KEYS = [
"apiKey",
"openRouterApiKey",
"abliterationApiKey",
"awsAccessKey",
"awsApiKey",
"awsSecretKey",
Expand Down
14 changes: 14 additions & 0 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { z } from "zod"
import { modelInfoSchema, reasoningEffortSettingSchema, verbosityLevelsSchema, serviceTierSchema } from "./model.js"
import { codebaseIndexProviderSchema } from "./codebase-index.js"
import {
abliterationModels,
anthropicModels,
basetenModels,
bedrockModels,
Expand Down Expand Up @@ -110,6 +111,7 @@ export const providerNames = [
...internalProviders,
...customProviders,
...fauxProviders,
"abliteration",
"anthropic",
"bedrock",
"baseten",
Expand Down Expand Up @@ -209,6 +211,10 @@ const anthropicSchema = apiModelIdProviderModelSchema.extend({
anthropicBeta1MContext: z.boolean().optional(), // Enable 'context-1m-2025-08-07' beta for 1M context window.
})

const abliterationSchema = apiModelIdProviderModelSchema.extend({
abliterationApiKey: z.string().optional(),
})

const openRouterSchema = baseProviderSettingsSchema.extend({
openRouterApiKey: z.string().optional(),
openRouterModelId: z.string().optional(),
Expand Down Expand Up @@ -399,6 +405,7 @@ const defaultSchema = z.object({
})

export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
abliterationSchema.merge(z.object({ apiProvider: z.literal("abliteration") })),
anthropicSchema.merge(z.object({ apiProvider: z.literal("anthropic") })),
openRouterSchema.merge(z.object({ apiProvider: z.literal("openrouter") })),
bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })),
Expand Down Expand Up @@ -433,6 +440,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv

export const providerSettingsSchema = z.object({
apiProvider: providerNamesWithRetiredSchema.optional(),
...abliterationSchema.shape,
...anthropicSchema.shape,
...openRouterSchema.shape,
...bedrockSchema.shape,
Expand Down Expand Up @@ -511,6 +519,7 @@ export const isTypicalProvider = (key: unknown): key is TypicalProvider =>
isProviderName(key) && !isInternalProvider(key) && !isCustomProvider(key) && !isFauxProvider(key)

export const modelIdKeysByProvider: Record<TypicalProvider, ModelIdKey> = {
abliteration: "apiModelId",
anthropic: "apiModelId",
openrouter: "openRouterModelId",
bedrock: "apiModelId",
Expand Down Expand Up @@ -576,6 +585,11 @@ export const MODELS_BY_PROVIDER: Record<
Exclude<ProviderName, "fake-ai" | "gemini-cli" | "openai">,
{ id: ProviderName; label: string; models: string[] }
> = {
abliteration: {
id: "abliteration",
label: "abliteration.ai",
models: Object.keys(abliterationModels),
},
anthropic: {
id: "anthropic",
label: "Anthropic",
Expand Down
17 changes: 17 additions & 0 deletions packages/types/src/providers/abliteration.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import type { ModelInfo } from "../model.js"

// https://docs.abliteration.ai/models
export type AbliterationModelId = "abliterated-model"

export const abliterationDefaultModelId: AbliterationModelId = "abliterated-model"

export const abliterationModels = {
"abliterated-model": {
maxTokens: 8192,
contextWindow: 150_000,
supportsImages: true,
supportsPromptCache: false,
description:
"Default abliteration.ai model. Supports OpenAI-compatible chat completions, streaming, tool calling, and vision.",
},
} as const satisfies Record<string, ModelInfo>
4 changes: 4 additions & 0 deletions packages/types/src/providers/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export * from "./abliteration.js"
export * from "./anthropic.js"
export * from "./baseten.js"
export * from "./bedrock.js"
Expand Down Expand Up @@ -26,6 +27,7 @@ export * from "./vercel-ai-gateway.js"
export * from "./zai.js"
export * from "./minimax.js"

import { abliterationDefaultModelId } from "./abliteration.js"
import { anthropicDefaultModelId } from "./anthropic.js"
import { basetenDefaultModelId } from "./baseten.js"
import { bedrockDefaultModelId } from "./bedrock.js"
Expand Down Expand Up @@ -63,6 +65,8 @@ export function getProviderDefaultModelId(
options: { isChina?: boolean } = { isChina: false },
): string {
switch (provider) {
case "abliteration":
return abliterationDefaultModelId
case "openrouter":
return openRouterDefaultModelId
case "requesty":
Expand Down
3 changes: 3 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { ApiStream } from "./transform/stream"

import {
AnthropicHandler,
AbliterationHandler,
AwsBedrockHandler,
OpenRouterHandler,
PoeHandler,
Expand Down Expand Up @@ -119,6 +120,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
}

switch (apiProvider) {
case "abliteration":
return new AbliterationHandler(options)
case "anthropic":
return new AnthropicHandler(options)
case "openrouter":
Expand Down
129 changes: 129 additions & 0 deletions src/api/providers/__tests__/abliteration.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// npx vitest run src/api/providers/__tests__/abliteration.spec.ts

import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { abliterationDefaultModelId, abliterationModels } from "@roo-code/types"

import { AbliterationHandler } from "../abliteration"

const mockCreate = vi.fn()

vi.mock("openai", () => ({
default: vi.fn(() => ({
chat: {
completions: {
create: mockCreate,
},
},
})),
}))

describe("AbliterationHandler", () => {
let handler: AbliterationHandler

beforeEach(() => {
vi.clearAllMocks()
handler = new AbliterationHandler({ abliterationApiKey: "test-abliteration-api-key" })
})

it("should use the correct abliteration.ai base URL", () => {
new AbliterationHandler({ abliterationApiKey: "test-abliteration-api-key" })
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.abliteration.ai/v1" }))
})

it("should use the provided API key", () => {
const abliterationApiKey = "test-abliteration-api-key"
new AbliterationHandler({ abliterationApiKey })
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: abliterationApiKey }))
})

it("should throw error when API key is not provided", () => {
expect(() => new AbliterationHandler({})).toThrow("API key is required")
})

it("should return default model when no model is specified", () => {
const model = handler.getModel()
expect(model.id).toBe(abliterationDefaultModelId)
expect(model.info).toEqual(abliterationModels[abliterationDefaultModelId])
})

it("should return specified model when valid model is provided", () => {
const handlerWithModel = new AbliterationHandler({
apiModelId: "abliterated-model",
abliterationApiKey: "test-abliteration-api-key",
})
const model = handlerWithModel.getModel()
expect(model.id).toBe("abliterated-model")
expect(model.info).toEqual(abliterationModels["abliterated-model"])
})

it("completePrompt method should return text from abliteration.ai API", async () => {
const expectedResponse = "This is a test response from abliteration.ai"
mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
const result = await handler.completePrompt("test prompt")
expect(result).toBe(expectedResponse)
})

it("createMessage should yield text content from stream", async () => {
const testContent = "This is test content from abliteration.ai stream"

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: testContent } }] },
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "text", text: testContent })
})

it("createMessage should pass correct parameters to abliteration.ai client", async () => {
const modelInfo = abliterationModels[abliterationDefaultModelId]
const handlerWithModel = new AbliterationHandler({
apiModelId: abliterationDefaultModelId,
abliterationApiKey: "test-abliteration-api-key",
})

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

const systemPrompt = "Test system prompt for abliteration.ai"
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "Test message for abliteration.ai" },
]

const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: abliterationDefaultModelId,
max_tokens: modelInfo.maxTokens,
temperature: 0,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
stream_options: { include_usage: true },
}),
undefined,
)
})
})
18 changes: 18 additions & 0 deletions src/api/providers/abliteration.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { type AbliterationModelId, abliterationDefaultModelId, abliterationModels } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

export class AbliterationHandler extends BaseOpenAiCompatibleProvider<AbliterationModelId> {
constructor(options: ApiHandlerOptions) {
super({
...options,
providerName: "abliteration.ai",
baseURL: "https://api.abliteration.ai/v1",
apiKey: options.abliterationApiKey,
defaultProviderModelId: abliterationDefaultModelId,
providerModels: abliterationModels,
})
}
}
1 change: 1 addition & 0 deletions src/api/providers/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export { AbliterationHandler } from "./abliteration"
export { AnthropicVertexHandler } from "./anthropic-vertex"
export { AnthropicHandler } from "./anthropic"
export { AwsBedrockHandler } from "./bedrock"
Expand Down
1 change: 1 addition & 0 deletions src/shared/ProfileValidator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export class ProfileValidator {
switch (profile.apiProvider) {
case "openai":
return profile.openAiModelId
case "abliteration":
case "anthropic":
case "openai-native":
case "bedrock":
Expand Down
1 change: 1 addition & 0 deletions src/shared/__tests__/ProfileValidator.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ describe("ProfileValidator", () => {

// Test specific providers that use apiModelId
const apiModelProviders = [
"abliteration",
"anthropic",
"openai-native",
"bedrock",
Expand Down
10 changes: 10 additions & 0 deletions webview-ui/src/components/settings/ApiOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
type ProviderSettings,
isRetiredProvider,
DEFAULT_CONSECUTIVE_MISTAKE_LIMIT,
abliterationDefaultModelId,
openRouterDefaultModelId,
poeDefaultModelId,
requestyDefaultModelId,
Expand Down Expand Up @@ -94,6 +95,7 @@ import {
Fireworks,
VercelAiGateway,
MiniMax,
Abliteration,
} from "./providers"

import { MODELS_BY_PROVIDER, PROVIDERS } from "./constants"
Expand Down Expand Up @@ -362,6 +364,7 @@ const ApiOptions = ({
fireworks: { field: "apiModelId", default: fireworksDefaultModelId },
poe: { field: "apiModelId", default: poeDefaultModelId },
roo: { field: "apiModelId", default: rooDefaultModelId },
abliteration: { field: "apiModelId", default: abliterationDefaultModelId },
"vercel-ai-gateway": { field: "vercelAiGatewayModelId", default: vercelAiGatewayDefaultModelId },
openai: { field: "openAiModelId" },
ollama: { field: "ollamaModelId" },
Expand Down Expand Up @@ -546,6 +549,13 @@ const ApiOptions = ({
/>
)}

{selectedProvider === "abliteration" && (
<Abliteration
apiConfiguration={apiConfiguration}
setApiConfigurationField={setApiConfigurationField}
/>
)}

{selectedProvider === "openai-codex" && (
<OpenAICodex
apiConfiguration={apiConfiguration}
Expand Down
Loading
Loading