Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 1 addition & 25 deletions TablePro/Core/AI/AIProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,9 @@
// AIProvider.swift
// TablePro
//
// Protocol defining AI provider interface for streaming chat and model discovery.
//

import Foundation

/// Protocol for AI provider implementations
protocol AIProvider: AnyObject {
/// Stream chat completions as an async sequence of events (text tokens and usage)
func streamChat(
messages: [AIChatMessage],
model: String,
systemPrompt: String?
) -> AsyncThrowingStream<AIStreamEvent, Error>

/// Fetch available models from the provider
func fetchAvailableModels() async throws -> [String]

/// Test connection to verify API key and endpoint
func testConnection() async throws -> Bool
}

/// Errors that can occur during AI provider operations
enum AIProviderError: Error, LocalizedError {
case invalidEndpoint(String)
case authenticationFailed(String)
Expand Down Expand Up @@ -55,7 +36,6 @@ enum AIProviderError: Error, LocalizedError {
}
}

/// Base HTTP error mapping — providers can override for custom status codes
static func mapHTTPError(statusCode: Int, body: String) -> AIProviderError {
let message = parseErrorMessage(from: body) ?? body
switch statusCode {
Expand All @@ -70,8 +50,6 @@ enum AIProviderError: Error, LocalizedError {
}
}

/// Extract human-readable message from provider JSON error responses.
/// Supports Anthropic (`{"error":{"message":"..."}}`), OpenAI, and Gemini formats.
static func parseErrorMessage(from body: String) -> String? {
guard let data = body.data(using: .utf8),
let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
Expand All @@ -84,9 +62,7 @@ enum AIProviderError: Error, LocalizedError {
}
}

// MARK: - Shared Helpers

extension AIProvider {
extension ChatTransport {
func collectErrorBody(from bytes: URLSession.AsyncBytes) async throws -> String {
var body = ""
for try await line in bytes.lines {
Expand Down
11 changes: 4 additions & 7 deletions TablePro/Core/AI/AIProviderFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,27 @@
// AIProviderFactory.swift
// TablePro
//
// Factory for creating AI provider instances. Resolves the active provider
// from settings (no per-feature routing).
//

import Foundation
import os

enum AIProviderFactory {
struct ResolvedProvider: Sendable {
let provider: AIProvider
let provider: ChatTransport
let model: String
let config: AIProviderConfig
}

private static let cacheLock = OSAllocatedUnfairLock(
initialState: [UUID: (config: AIProviderConfig, apiKey: String?, provider: AIProvider)]()
initialState: [UUID: (config: AIProviderConfig, apiKey: String?, provider: ChatTransport)]()
)

static func createProvider(for config: AIProviderConfig, apiKey: String?) -> AIProvider {
static func createProvider(for config: AIProviderConfig, apiKey: String?) -> ChatTransport {
cacheLock.withLock { cache in
if let cached = cache[config.id], cached.apiKey == apiKey, cached.config == config {
return cached.provider
}
let provider: AIProvider
let provider: ChatTransport
if let descriptor = AIProviderRegistry.shared.descriptor(for: config.type.rawValue) {
provider = descriptor.makeProvider(config, apiKey)
} else {
Expand Down
60 changes: 19 additions & 41 deletions TablePro/Core/AI/AnthropicProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
// AnthropicProvider.swift
// TablePro
//
// Anthropic Claude API provider using the Messages API with SSE streaming.
//

import Foundation
import os

/// AI provider for Anthropic's Claude models
final class AnthropicProvider: AIProvider {
final class AnthropicProvider: ChatTransport {
private static let logger = Logger(subsystem: "com.TablePro", category: "AnthropicProvider")

private let endpoint: String
Expand All @@ -24,22 +21,14 @@ final class AnthropicProvider: AIProvider {
self.session = URLSession(configuration: .ephemeral)
}

// MARK: - AIProvider

func streamChat(
messages: [AIChatMessage],
model: String,
systemPrompt: String?
) -> AsyncThrowingStream<AIStreamEvent, Error> {
turns: [ChatTurn],
options: ChatTransportOptions
) -> AsyncThrowingStream<ChatStreamEvent, Error> {
AsyncThrowingStream { continuation in
let task = Task {
do {
let request = try buildMessagesRequest(
messages: messages,
model: model,
systemPrompt: systemPrompt
)

let request = try buildMessagesRequest(turns: turns, options: options)
let (bytes, response) = try await session.bytes(for: request)

guard let httpResponse = response as? HTTPURLResponse else {
Expand Down Expand Up @@ -72,7 +61,7 @@ final class AnthropicProvider: AIProvider {
case "content_block_delta":
if let delta = json["delta"] as? [String: Any],
let text = delta["text"] as? String {
continuation.yield(.text(text))
continuation.yield(.textDelta(text))
}
case "message_start":
if let message = json["message"] as? [String: Any],
Expand All @@ -95,7 +84,6 @@ final class AnthropicProvider: AIProvider {
}
}

// Yield usage if we got any token data
if inputTokens > 0 || outputTokens > 0 {
continuation.yield(.usage(AITokenUsage(
inputTokens: inputTokens,
Expand Down Expand Up @@ -148,14 +136,9 @@ final class AnthropicProvider: AIProvider {
]

func testConnection() async throws -> Bool {
let testMessage = AIChatMessage(role: .user, content: "Hi")
let request = try buildMessagesRequest(
messages: [testMessage],
model: "claude-haiku-4-5-20251001",
systemPrompt: nil,
maxTokens: 1,
stream: false
)
let testTurn = ChatTurn(role: .user, blocks: [.text("Hi")])
let testOptions = ChatTransportOptions(model: "claude-haiku-4-5-20251001", maxOutputTokens: 1)
let request = try buildMessagesRequest(turns: [testTurn], options: testOptions, stream: false)

let (data, response) = try await session.data(for: request)

Expand All @@ -165,7 +148,6 @@ final class AnthropicProvider: AIProvider {

let statusCode = httpResponse.statusCode

// 200 = full success, 400 = key is valid but request was rejected (e.g. billing)
if statusCode == 200 || statusCode == 400 {
return true
}
Expand All @@ -178,16 +160,11 @@ final class AnthropicProvider: AIProvider {
throw AIProviderError.mapHTTPError(statusCode: statusCode, body: body)
}

// MARK: - Private

private func buildMessagesRequest(
messages: [AIChatMessage],
model: String,
systemPrompt: String?,
maxTokens: Int? = nil,
turns: [ChatTurn],
options: ChatTransportOptions,
stream: Bool = true
) throws -> URLRequest {
let maxTokens = maxTokens ?? maxOutputTokens
guard let url = URL(string: "\(endpoint)/v1/messages") else {
throw AIProviderError.invalidEndpoint(endpoint)
}
Expand All @@ -199,20 +176,21 @@ final class AnthropicProvider: AIProvider {
request.setValue("2023-06-01", forHTTPHeaderField: "anthropic-version")

var body: [String: Any] = [
"model": model,
"max_tokens": maxTokens,
"model": options.model,
"max_tokens": options.maxOutputTokens ?? maxOutputTokens,
"stream": stream
]

if let systemPrompt {
if let systemPrompt = options.systemPrompt {
body["system"] = systemPrompt
}

// Convert messages (skip system role — handled via system parameter)
let apiMessages = messages
let apiMessages = turns
.filter { $0.role != .system }
.map { message -> [String: String] in
["role": message.role.rawValue, "content": message.content]
.compactMap { turn -> [String: String]? in
let text = turn.plainText
guard !text.isEmpty else { return nil }
return ["role": turn.role.rawValue, "content": text]
}
body["messages"] = apiMessages

Expand Down
53 changes: 53 additions & 0 deletions TablePro/Core/AI/Chat/ChatTransport.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//
// ChatTransport.swift
// TablePro
//

import Foundation

protocol ChatTransport: AnyObject, Sendable {
func streamChat(
turns: [ChatTurn],
options: ChatTransportOptions
) -> AsyncThrowingStream<ChatStreamEvent, Error>

func fetchAvailableModels() async throws -> [String]

func testConnection() async throws -> Bool
}

struct ChatTransportOptions: Sendable {
var model: String
var systemPrompt: String?
var maxOutputTokens: Int?
var temperature: Double?
var tools: [ChatToolSpec]

init(
model: String,
systemPrompt: String? = nil,
maxOutputTokens: Int? = nil,
temperature: Double? = nil,
tools: [ChatToolSpec] = []
) {
self.model = model
self.systemPrompt = systemPrompt
self.maxOutputTokens = maxOutputTokens
self.temperature = temperature
self.tools = tools
}
}

struct ChatToolSpec: Codable, Equatable, Sendable {
let name: String
let description: String
let inputSchema: JSONValue
}

enum ChatStreamEvent: Sendable {
case textDelta(String)
case toolUseStart(id: String, name: String)
case toolUseDelta(id: String, inputJSONDelta: String)
case toolUseEnd(id: String)
case usage(AITokenUsage)
}
Loading
Loading