diff --git a/TablePro/Core/AI/AIProvider.swift b/TablePro/Core/AI/AIProvider.swift index b904fd8a9..862f0a8f6 100644 --- a/TablePro/Core/AI/AIProvider.swift +++ b/TablePro/Core/AI/AIProvider.swift @@ -2,28 +2,9 @@ // AIProvider.swift // TablePro // -// Protocol defining AI provider interface for streaming chat and model discovery. -// import Foundation -/// Protocol for AI provider implementations -protocol AIProvider: AnyObject { - /// Stream chat completions as an async sequence of events (text tokens and usage) - func streamChat( - messages: [AIChatMessage], - model: String, - systemPrompt: String? - ) -> AsyncThrowingStream - - /// Fetch available models from the provider - func fetchAvailableModels() async throws -> [String] - - /// Test connection to verify API key and endpoint - func testConnection() async throws -> Bool -} - -/// Errors that can occur during AI provider operations enum AIProviderError: Error, LocalizedError { case invalidEndpoint(String) case authenticationFailed(String) @@ -55,7 +36,6 @@ enum AIProviderError: Error, LocalizedError { } } - /// Base HTTP error mapping — providers can override for custom status codes static func mapHTTPError(statusCode: Int, body: String) -> AIProviderError { let message = parseErrorMessage(from: body) ?? body switch statusCode { @@ -70,8 +50,6 @@ enum AIProviderError: Error, LocalizedError { } } - /// Extract human-readable message from provider JSON error responses. - /// Supports Anthropic (`{"error":{"message":"..."}}`), OpenAI, and Gemini formats. static func parseErrorMessage(from body: String) -> String? { guard let data = body.data(using: .utf8), let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], @@ -84,9 +62,7 @@ enum AIProviderError: Error, LocalizedError { } } -// MARK: - Shared Helpers - -extension AIProvider { +extension ChatTransport { func collectErrorBody(from bytes: URLSession.AsyncBytes) async throws -> String { var body = "" for try await line in bytes.lines { diff --git a/TablePro/Core/AI/AIProviderFactory.swift b/TablePro/Core/AI/AIProviderFactory.swift index 75f1f2995..f576b8e49 100644 --- a/TablePro/Core/AI/AIProviderFactory.swift +++ b/TablePro/Core/AI/AIProviderFactory.swift @@ -2,30 +2,27 @@ // AIProviderFactory.swift // TablePro // -// Factory for creating AI provider instances. Resolves the active provider -// from settings (no per-feature routing). -// import Foundation import os enum AIProviderFactory { struct ResolvedProvider: Sendable { - let provider: AIProvider + let provider: ChatTransport let model: String let config: AIProviderConfig } private static let cacheLock = OSAllocatedUnfairLock( - initialState: [UUID: (config: AIProviderConfig, apiKey: String?, provider: AIProvider)]() + initialState: [UUID: (config: AIProviderConfig, apiKey: String?, provider: ChatTransport)]() ) - static func createProvider(for config: AIProviderConfig, apiKey: String?) -> AIProvider { + static func createProvider(for config: AIProviderConfig, apiKey: String?) -> ChatTransport { cacheLock.withLock { cache in if let cached = cache[config.id], cached.apiKey == apiKey, cached.config == config { return cached.provider } - let provider: AIProvider + let provider: ChatTransport if let descriptor = AIProviderRegistry.shared.descriptor(for: config.type.rawValue) { provider = descriptor.makeProvider(config, apiKey) } else { diff --git a/TablePro/Core/AI/AnthropicProvider.swift b/TablePro/Core/AI/AnthropicProvider.swift index 3c6cb01d1..54ddad4dd 100644 --- a/TablePro/Core/AI/AnthropicProvider.swift +++ b/TablePro/Core/AI/AnthropicProvider.swift @@ -2,14 +2,11 @@ // AnthropicProvider.swift // TablePro // -// Anthropic Claude API provider using the Messages API with SSE streaming. -// import Foundation import os -/// AI provider for Anthropic's Claude models -final class AnthropicProvider: AIProvider { +final class AnthropicProvider: ChatTransport { private static let logger = Logger(subsystem: "com.TablePro", category: "AnthropicProvider") private let endpoint: String @@ -24,22 +21,14 @@ final class AnthropicProvider: AIProvider { self.session = URLSession(configuration: .ephemeral) } - // MARK: - AIProvider - func streamChat( - messages: [AIChatMessage], - model: String, - systemPrompt: String? - ) -> AsyncThrowingStream { + turns: [ChatTurn], + options: ChatTransportOptions + ) -> AsyncThrowingStream { AsyncThrowingStream { continuation in let task = Task { do { - let request = try buildMessagesRequest( - messages: messages, - model: model, - systemPrompt: systemPrompt - ) - + let request = try buildMessagesRequest(turns: turns, options: options) let (bytes, response) = try await session.bytes(for: request) guard let httpResponse = response as? HTTPURLResponse else { @@ -72,7 +61,7 @@ final class AnthropicProvider: AIProvider { case "content_block_delta": if let delta = json["delta"] as? [String: Any], let text = delta["text"] as? String { - continuation.yield(.text(text)) + continuation.yield(.textDelta(text)) } case "message_start": if let message = json["message"] as? [String: Any], @@ -95,7 +84,6 @@ final class AnthropicProvider: AIProvider { } } - // Yield usage if we got any token data if inputTokens > 0 || outputTokens > 0 { continuation.yield(.usage(AITokenUsage( inputTokens: inputTokens, @@ -148,14 +136,9 @@ final class AnthropicProvider: AIProvider { ] func testConnection() async throws -> Bool { - let testMessage = AIChatMessage(role: .user, content: "Hi") - let request = try buildMessagesRequest( - messages: [testMessage], - model: "claude-haiku-4-5-20251001", - systemPrompt: nil, - maxTokens: 1, - stream: false - ) + let testTurn = ChatTurn(role: .user, blocks: [.text("Hi")]) + let testOptions = ChatTransportOptions(model: "claude-haiku-4-5-20251001", maxOutputTokens: 1) + let request = try buildMessagesRequest(turns: [testTurn], options: testOptions, stream: false) let (data, response) = try await session.data(for: request) @@ -165,7 +148,6 @@ final class AnthropicProvider: AIProvider { let statusCode = httpResponse.statusCode - // 200 = full success, 400 = key is valid but request was rejected (e.g. billing) if statusCode == 200 || statusCode == 400 { return true } @@ -178,16 +160,11 @@ final class AnthropicProvider: AIProvider { throw AIProviderError.mapHTTPError(statusCode: statusCode, body: body) } - // MARK: - Private - private func buildMessagesRequest( - messages: [AIChatMessage], - model: String, - systemPrompt: String?, - maxTokens: Int? = nil, + turns: [ChatTurn], + options: ChatTransportOptions, stream: Bool = true ) throws -> URLRequest { - let maxTokens = maxTokens ?? maxOutputTokens guard let url = URL(string: "\(endpoint)/v1/messages") else { throw AIProviderError.invalidEndpoint(endpoint) } @@ -199,20 +176,21 @@ final class AnthropicProvider: AIProvider { request.setValue("2023-06-01", forHTTPHeaderField: "anthropic-version") var body: [String: Any] = [ - "model": model, - "max_tokens": maxTokens, + "model": options.model, + "max_tokens": options.maxOutputTokens ?? maxOutputTokens, "stream": stream ] - if let systemPrompt { + if let systemPrompt = options.systemPrompt { body["system"] = systemPrompt } - // Convert messages (skip system role — handled via system parameter) - let apiMessages = messages + let apiMessages = turns .filter { $0.role != .system } - .map { message -> [String: String] in - ["role": message.role.rawValue, "content": message.content] + .compactMap { turn -> [String: String]? in + let text = turn.plainText + guard !text.isEmpty else { return nil } + return ["role": turn.role.rawValue, "content": text] } body["messages"] = apiMessages diff --git a/TablePro/Core/AI/Chat/ChatTransport.swift b/TablePro/Core/AI/Chat/ChatTransport.swift new file mode 100644 index 000000000..963233343 --- /dev/null +++ b/TablePro/Core/AI/Chat/ChatTransport.swift @@ -0,0 +1,53 @@ +// +// ChatTransport.swift +// TablePro +// + +import Foundation + +protocol ChatTransport: AnyObject, Sendable { + func streamChat( + turns: [ChatTurn], + options: ChatTransportOptions + ) -> AsyncThrowingStream + + func fetchAvailableModels() async throws -> [String] + + func testConnection() async throws -> Bool +} + +struct ChatTransportOptions: Sendable { + var model: String + var systemPrompt: String? + var maxOutputTokens: Int? + var temperature: Double? + var tools: [ChatToolSpec] + + init( + model: String, + systemPrompt: String? = nil, + maxOutputTokens: Int? = nil, + temperature: Double? = nil, + tools: [ChatToolSpec] = [] + ) { + self.model = model + self.systemPrompt = systemPrompt + self.maxOutputTokens = maxOutputTokens + self.temperature = temperature + self.tools = tools + } +} + +struct ChatToolSpec: Codable, Equatable, Sendable { + let name: String + let description: String + let inputSchema: JSONValue +} + +enum ChatStreamEvent: Sendable { + case textDelta(String) + case toolUseStart(id: String, name: String) + case toolUseDelta(id: String, inputJSONDelta: String) + case toolUseEnd(id: String) + case usage(AITokenUsage) +} diff --git a/TablePro/Core/AI/Chat/ChatTurn.swift b/TablePro/Core/AI/Chat/ChatTurn.swift new file mode 100644 index 000000000..27d7d19b4 --- /dev/null +++ b/TablePro/Core/AI/Chat/ChatTurn.swift @@ -0,0 +1,155 @@ +// +// ChatTurn.swift +// TablePro +// + +import Foundation + +enum ChatRole: String, Codable, Sendable { + case user + case assistant + case system +} + +struct ChatTurn: Codable, Equatable, Identifiable, Sendable { + let id: UUID + var role: ChatRole + var blocks: [ChatContentBlock] + let timestamp: Date + var usage: AITokenUsage? + var modelId: String? + var providerId: String? + + init( + id: UUID = UUID(), + role: ChatRole, + blocks: [ChatContentBlock], + timestamp: Date = Date(), + usage: AITokenUsage? = nil, + modelId: String? = nil, + providerId: String? = nil + ) { + self.id = id + self.role = role + self.blocks = blocks + self.timestamp = timestamp + self.usage = usage + self.modelId = modelId + self.providerId = providerId + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + id = try container.decode(UUID.self, forKey: .id) + role = try container.decode(ChatRole.self, forKey: .role) + timestamp = try container.decode(Date.self, forKey: .timestamp) + usage = try container.decodeIfPresent(AITokenUsage.self, forKey: .usage) + modelId = try container.decodeIfPresent(String.self, forKey: .modelId) + providerId = try container.decodeIfPresent(String.self, forKey: .providerId) + + if let decodedBlocks = try container.decodeIfPresent([ChatContentBlock].self, forKey: .blocks) { + blocks = decodedBlocks + } else if let legacyText = try container.decodeIfPresent(String.self, forKey: .content) { + blocks = [.text(legacyText)] + } else { + blocks = [] + } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(id, forKey: .id) + try container.encode(role, forKey: .role) + try container.encode(blocks, forKey: .blocks) + try container.encode(timestamp, forKey: .timestamp) + try container.encodeIfPresent(usage, forKey: .usage) + try container.encodeIfPresent(modelId, forKey: .modelId) + try container.encodeIfPresent(providerId, forKey: .providerId) + } + + private enum CodingKeys: String, CodingKey { + case id, role, blocks, content, timestamp, usage, modelId, providerId + } + + var plainText: String { + blocks.compactMap { block in + if case .text(let text) = block { return text } + return nil + }.joined() + } + + mutating func appendText(_ text: String) { + guard !text.isEmpty else { return } + if case .text(let existing) = blocks.last { + blocks[blocks.count - 1] = .text(existing + text) + } else { + blocks.append(.text(text)) + } + } +} + +enum ChatContentBlock: Codable, Equatable, Sendable { + case text(String) + case toolUse(ToolUseBlock) + case toolResult(ToolResultBlock) + case attachment(ContextItem) + + private enum CodingKeys: String, CodingKey { + case kind, text, toolUse, toolResult, attachment + } + + private enum Kind: String, Codable { + case text, toolUse, toolResult, attachment + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let kind = try container.decode(Kind.self, forKey: .kind) + switch kind { + case .text: + self = .text(try container.decode(String.self, forKey: .text)) + case .toolUse: + self = .toolUse(try container.decode(ToolUseBlock.self, forKey: .toolUse)) + case .toolResult: + self = .toolResult(try container.decode(ToolResultBlock.self, forKey: .toolResult)) + case .attachment: + self = .attachment(try container.decode(ContextItem.self, forKey: .attachment)) + } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case .text(let text): + try container.encode(Kind.text, forKey: .kind) + try container.encode(text, forKey: .text) + case .toolUse(let block): + try container.encode(Kind.toolUse, forKey: .kind) + try container.encode(block, forKey: .toolUse) + case .toolResult(let block): + try container.encode(Kind.toolResult, forKey: .kind) + try container.encode(block, forKey: .toolResult) + case .attachment(let item): + try container.encode(Kind.attachment, forKey: .kind) + try container.encode(item, forKey: .attachment) + } + } +} + +struct ToolUseBlock: Codable, Equatable, Sendable { + let id: String + let name: String + let input: JSONValue +} + +struct ToolResultBlock: Codable, Equatable, Sendable { + let toolUseId: String + let content: String + let isError: Bool + + init(toolUseId: String, content: String, isError: Bool = false) { + self.toolUseId = toolUseId + self.content = content + self.isError = isError + } +} diff --git a/TablePro/Core/AI/Chat/ContextItem.swift b/TablePro/Core/AI/Chat/ContextItem.swift new file mode 100644 index 000000000..6f248aaaa --- /dev/null +++ b/TablePro/Core/AI/Chat/ContextItem.swift @@ -0,0 +1,70 @@ +// +// ContextItem.swift +// TablePro +// + +import Foundation + +enum ContextItem: Codable, Equatable, Sendable { + case schema(connectionId: UUID) + case table(connectionId: UUID, name: String) + case currentQuery(text: String) + case queryResult(summary: String) + case savedQuery(id: UUID) + case file(url: URL) + + private enum CodingKeys: String, CodingKey { + case kind, connectionId, name, text, summary, id, url + } + + private enum Kind: String, Codable { + case schema, table, currentQuery, queryResult, savedQuery, file + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let kind = try container.decode(Kind.self, forKey: .kind) + switch kind { + case .schema: + let connectionId = try container.decode(UUID.self, forKey: .connectionId) + self = .schema(connectionId: connectionId) + case .table: + let connectionId = try container.decode(UUID.self, forKey: .connectionId) + let name = try container.decode(String.self, forKey: .name) + self = .table(connectionId: connectionId, name: name) + case .currentQuery: + self = .currentQuery(text: try container.decode(String.self, forKey: .text)) + case .queryResult: + self = .queryResult(summary: try container.decode(String.self, forKey: .summary)) + case .savedQuery: + self = .savedQuery(id: try container.decode(UUID.self, forKey: .id)) + case .file: + self = .file(url: try container.decode(URL.self, forKey: .url)) + } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case .schema(let connectionId): + try container.encode(Kind.schema, forKey: .kind) + try container.encode(connectionId, forKey: .connectionId) + case .table(let connectionId, let name): + try container.encode(Kind.table, forKey: .kind) + try container.encode(connectionId, forKey: .connectionId) + try container.encode(name, forKey: .name) + case .currentQuery(let text): + try container.encode(Kind.currentQuery, forKey: .kind) + try container.encode(text, forKey: .text) + case .queryResult(let summary): + try container.encode(Kind.queryResult, forKey: .kind) + try container.encode(summary, forKey: .summary) + case .savedQuery(let id): + try container.encode(Kind.savedQuery, forKey: .kind) + try container.encode(id, forKey: .id) + case .file(let url): + try container.encode(Kind.file, forKey: .kind) + try container.encode(url, forKey: .url) + } + } +} diff --git a/TablePro/Core/AI/Chat/JSONValue.swift b/TablePro/Core/AI/Chat/JSONValue.swift new file mode 100644 index 000000000..aaadf3cc7 --- /dev/null +++ b/TablePro/Core/AI/Chat/JSONValue.swift @@ -0,0 +1,75 @@ +// +// JSONValue.swift +// TablePro +// + +import Foundation + +enum JSONValue: Codable, Equatable, Sendable, Hashable { + case null + case bool(Bool) + case number(Double) + case integer(Int64) + case string(String) + case array([JSONValue]) + case object([String: JSONValue]) + + init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if container.decodeNil() { + self = .null + return + } + if let value = try? container.decode(Bool.self) { + self = .bool(value) + return + } + if let value = try? container.decode(Int64.self) { + self = .integer(value) + return + } + if let value = try? container.decode(Double.self) { + self = .number(value) + return + } + if let value = try? container.decode(String.self) { + self = .string(value) + return + } + if let value = try? container.decode([JSONValue].self) { + self = .array(value) + return + } + if let value = try? container.decode([String: JSONValue].self) { + self = .object(value) + return + } + throw DecodingError.dataCorruptedError( + in: container, + debugDescription: "Unsupported JSON value" + ) + } + + func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .null: try container.encodeNil() + case .bool(let value): try container.encode(value) + case .integer(let value): try container.encode(value) + case .number(let value): try container.encode(value) + case .string(let value): try container.encode(value) + case .array(let value): try container.encode(value) + case .object(let value): try container.encode(value) + } + } + + func decoded(as type: T.Type) throws -> T { + let data = try JSONEncoder().encode(self) + return try JSONDecoder().decode(type, from: data) + } + + static func encoded(_ value: some Encodable) throws -> JSONValue { + let data = try JSONEncoder().encode(value) + return try JSONDecoder().decode(JSONValue.self, from: data) + } +} diff --git a/TablePro/Core/AI/Copilot/CopilotChatProvider.swift b/TablePro/Core/AI/Copilot/CopilotChatProvider.swift index e13c295e1..6c181f8a3 100644 --- a/TablePro/Core/AI/Copilot/CopilotChatProvider.swift +++ b/TablePro/Core/AI/Copilot/CopilotChatProvider.swift @@ -6,23 +6,20 @@ import Foundation import os -final class CopilotChatProvider: AIProvider { +final class CopilotChatProvider: ChatTransport { private static let logger = Logger(subsystem: "com.TablePro", category: "CopilotChatProvider") private var conversationId: String? private var turnIds: [String] = [] private let progressHandlers = OSAllocatedUnfairLock( - initialState: [String: AsyncThrowingStream.Continuation]() + initialState: [String: AsyncThrowingStream.Continuation]() ) private var isProgressHandlerRegistered = false - // MARK: - AIProvider - func streamChat( - messages: [AIChatMessage], - model: String, - systemPrompt: String? - ) -> AsyncThrowingStream { + turns: [ChatTurn], + options: ChatTransportOptions + ) -> AsyncThrowingStream { AsyncThrowingStream { continuation in let task = Task { @MainActor [weak self] in guard let self else { @@ -44,19 +41,19 @@ final class CopilotChatProvider: AIProvider { self.progressHandlers.withLock { $0[token] = continuation } - let userMessage = messages.last(where: { $0.role == .user })?.content ?? "" - let effectiveModel: String? = model.isEmpty ? nil : model + let userMessage = turns.last(where: { $0.role == .user })?.plainText ?? "" + let effectiveModel: String? = options.model.isEmpty ? nil : options.model if self.conversationId == nil { - let systemPrefix = systemPrompt.map { $0 + "\n\n" } ?? "" - let turns = [CopilotConversationTurn( + let systemPrefix = options.systemPrompt.map { $0 + "\n\n" } ?? "" + let conversationTurns = [CopilotConversationTurn( request: systemPrefix + userMessage, response: "", turnId: "" )] let params = CopilotConversationCreateParams( workDoneToken: token, - turns: turns, + turns: conversationTurns, capabilities: CopilotConversationCapabilities( skills: ["current-editor"], allSkills: true @@ -107,8 +104,6 @@ final class CopilotChatProvider: AIProvider { await CopilotService.shared.isAuthenticated } - // MARK: - Conversation Lifecycle - func resetConversation() { isProgressHandlerRegistered = false let id = conversationId @@ -130,8 +125,6 @@ final class CopilotChatProvider: AIProvider { } } - // MARK: - Progress Handler - @MainActor private func ensureProgressHandler() async { guard !isProgressHandlerRegistered else { return } @@ -160,7 +153,7 @@ final class CopilotChatProvider: AIProvider { reply = last["reply"] as? String } if let reply, !reply.isEmpty { - continuation.yield(.text(reply)) + continuation.yield(.textDelta(reply)) } if let usage = value["tokenUsage"] as? [String: Any], diff --git a/TablePro/Core/AI/GeminiProvider.swift b/TablePro/Core/AI/GeminiProvider.swift index e8367b0df..217ee2892 100644 --- a/TablePro/Core/AI/GeminiProvider.swift +++ b/TablePro/Core/AI/GeminiProvider.swift @@ -2,14 +2,11 @@ // GeminiProvider.swift // TablePro // -// Google Gemini API provider using the Generative Language API with SSE streaming. -// import Foundation import os -/// AI provider for Google's Gemini models -final class GeminiProvider: AIProvider { +final class GeminiProvider: ChatTransport { private static let logger = Logger(subsystem: "com.TablePro", category: "GeminiProvider") private let endpoint: String @@ -24,22 +21,14 @@ final class GeminiProvider: AIProvider { self.session = URLSession(configuration: .ephemeral) } - // MARK: - AIProvider - func streamChat( - messages: [AIChatMessage], - model: String, - systemPrompt: String? - ) -> AsyncThrowingStream { + turns: [ChatTurn], + options: ChatTransportOptions + ) -> AsyncThrowingStream { AsyncThrowingStream { continuation in let task = Task { do { - let request = try buildStreamRequest( - messages: messages, - model: model, - systemPrompt: systemPrompt - ) - + let request = try buildStreamRequest(turns: turns, options: options) let (bytes, response) = try await session.bytes(for: request) guard let httpResponse = response as? HTTPURLResponse else { @@ -67,17 +56,15 @@ final class GeminiProvider: AIProvider { let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else { continue } - // Extract text from candidates[0].content.parts[0].text if let candidates = json["candidates"] as? [[String: Any]], let firstCandidate = candidates.first, let content = firstCandidate["content"] as? [String: Any], let parts = content["parts"] as? [[String: Any]], let firstPart = parts.first, let text = firstPart["text"] as? String { - continuation.yield(.text(text)) + continuation.yield(.textDelta(text)) } - // Extract usage from usageMetadata if let usageMetadata = json["usageMetadata"] as? [String: Any] { if let prompt = usageMetadata["promptTokenCount"] as? Int { inputTokens = prompt @@ -88,7 +75,6 @@ final class GeminiProvider: AIProvider { } } - // Yield usage if we got any token data if inputTokens > 0 || outputTokens > 0 { continuation.yield(.usage(AITokenUsage( inputTokens: inputTokens, @@ -152,7 +138,6 @@ final class GeminiProvider: AIProvider { let methods = model["supportedGenerationMethods"] as? [String], methods.contains("generateContent") else { return nil } - // Strip "models/" prefix: "models/gemini-2.0-flash" -> "gemini-2.0-flash" if name.hasPrefix("models/") { return String(name.dropFirst(7)) } @@ -191,14 +176,11 @@ final class GeminiProvider: AIProvider { return true } - // MARK: - Private - private func buildStreamRequest( - messages: [AIChatMessage], - model: String, - systemPrompt: String? + turns: [ChatTurn], + options: ChatTransportOptions ) throws -> URLRequest { - guard let encodedModel = model.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed), + guard let encodedModel = options.model.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed), let url = URL( string: "\(endpoint)/v1beta/models/\(encodedModel):streamGenerateContent?alt=sse" ) else { @@ -211,21 +193,22 @@ final class GeminiProvider: AIProvider { request.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key") var body: [String: Any] = [ - "generationConfig": ["maxOutputTokens": maxOutputTokens] + "generationConfig": ["maxOutputTokens": options.maxOutputTokens ?? maxOutputTokens] ] - if let systemPrompt, !systemPrompt.isEmpty { + if let systemPrompt = options.systemPrompt, !systemPrompt.isEmpty { body["systemInstruction"] = ["parts": [["text": systemPrompt]]] } - // Convert messages — Gemini uses "user" and "model" roles (not "assistant") - let contents = messages + let contents = turns .filter { $0.role != .system } - .map { message -> [String: Any] in - let role = message.role == .assistant ? "model" : "user" + .compactMap { turn -> [String: Any]? in + let text = turn.plainText + guard !text.isEmpty else { return nil } + let role = turn.role == .assistant ? "model" : "user" return [ "role": role, - "parts": [["text": message.content]] + "parts": [["text": text]] ] } body["contents"] = contents diff --git a/TablePro/Core/AI/InlineSuggestion/AIChatInlineSource.swift b/TablePro/Core/AI/InlineSuggestion/AIChatInlineSource.swift index 899883021..8d29d15ad 100644 --- a/TablePro/Core/AI/InlineSuggestion/AIChatInlineSource.swift +++ b/TablePro/Core/AI/InlineSuggestion/AIChatInlineSource.swift @@ -33,21 +33,20 @@ final class AIChatInlineSource: InlineSuggestionSource { } let userMessage = AIPromptTemplates.inlineSuggest(textBefore: context.textBefore, fullQuery: context.fullText) - let messages = [ - AIChatMessage(role: .user, content: userMessage) + let turns = [ + ChatTurn(role: .user, blocks: [.text(userMessage)]) ] let systemPrompt = await buildSystemPrompt() var accumulated = "" let stream = resolved.provider.streamChat( - messages: messages, - model: resolved.model, - systemPrompt: systemPrompt + turns: turns, + options: ChatTransportOptions(model: resolved.model, systemPrompt: systemPrompt) ) for try await event in stream { - if case .text(let token) = event { + if case .textDelta(let token) = event { accumulated += token } } diff --git a/TablePro/Core/AI/OpenAICompatibleProvider.swift b/TablePro/Core/AI/OpenAICompatibleProvider.swift index f9464f2b6..b139eda87 100644 --- a/TablePro/Core/AI/OpenAICompatibleProvider.swift +++ b/TablePro/Core/AI/OpenAICompatibleProvider.swift @@ -2,14 +2,11 @@ // OpenAICompatibleProvider.swift // TablePro // -// OpenAI-compatible API provider supporting OpenAI, OpenRouter, Ollama, and custom endpoints. -// import Foundation import os -/// AI provider for OpenAI-compatible APIs (OpenAI, OpenRouter, Ollama, custom) -final class OpenAICompatibleProvider: AIProvider { +final class OpenAICompatibleProvider: ChatTransport { private static let logger = Logger( subsystem: "com.TablePro", category: "OpenAICompatibleProvider" @@ -41,22 +38,14 @@ final class OpenAICompatibleProvider: AIProvider { self.session = session } - // MARK: - AIProvider - func streamChat( - messages: [AIChatMessage], - model: String, - systemPrompt: String? - ) -> AsyncThrowingStream { + turns: [ChatTurn], + options: ChatTransportOptions + ) -> AsyncThrowingStream { AsyncThrowingStream { continuation in let task = Task { do { - let request = try buildChatCompletionRequest( - messages: messages, - model: model, - systemPrompt: systemPrompt - ) - + let request = try buildChatCompletionRequest(turns: turns, options: options) let (bytes, response) = try await session.bytes(for: request) guard let httpResponse = response as? HTTPURLResponse else { @@ -79,34 +68,29 @@ final class OpenAICompatibleProvider: AIProvider { let jsonString: String if self.providerType == .ollama { - // Ollama: raw newline-delimited JSON (no SSE "data: " prefix) guard !line.isEmpty else { continue } jsonString = line } else { - // OpenAI/OpenRouter/Custom: SSE with "data: " prefix guard line.hasPrefix("data: ") else { continue } let payload = String(line.dropFirst(6)) guard payload != "[DONE]" else { break } jsonString = payload } - // Single JSON parse per SSE line guard let data = jsonString.data(using: .utf8), let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else { continue } - // Text extraction if let choices = json["choices"] as? [[String: Any]], let delta = choices.first?["delta"] as? [String: Any], let content = delta["content"] as? String { - continuation.yield(.text(content)) + continuation.yield(.textDelta(content)) } else if let message = json["message"] as? [String: Any], let content = message["content"] as? String, !content.isEmpty { - continuation.yield(.text(content)) + continuation.yield(.textDelta(content)) } - // Usage extraction if let usage = json["usage"] as? [String: Any], let promptTokens = usage["prompt_tokens"] as? Int, let completionTokens = usage["completion_tokens"] as? Int { @@ -119,7 +103,6 @@ final class OpenAICompatibleProvider: AIProvider { outputTokens = evalCount } - // Ollama signals completion with "done":true if json["done"] as? Bool == true { break } @@ -156,7 +139,6 @@ final class OpenAICompatibleProvider: AIProvider { func testConnection() async throws -> Bool { switch providerType { case .ollama: - // Ollama is local — verify reachability and model availability do { let models = try await fetchAvailableModels() if models.isEmpty { @@ -177,7 +159,6 @@ final class OpenAICompatibleProvider: AIProvider { ) } default: - // Send a minimal non-streaming chat request to verify auth let chatPath = "/v1/chat/completions" guard let url = URL(string: "\(endpoint)\(chatPath)") else { throw AIProviderError.invalidEndpoint(endpoint) @@ -208,7 +189,6 @@ final class OpenAICompatibleProvider: AIProvider { return false } - // Check response is JSON (confirms we reached an API, not a random web page) let contentType = httpResponse.value(forHTTPHeaderField: "Content-Type") ?? "" let isJSON = contentType.contains("application/json") || (try? JSONSerialization.jsonObject(with: data)) != nil @@ -217,7 +197,6 @@ final class OpenAICompatibleProvider: AIProvider { throw AIProviderError.authenticationFailed("") } - // Non-JSON response means wrong endpoint (e.g., HTML 404 page) if !isJSON { return false } @@ -226,12 +205,9 @@ final class OpenAICompatibleProvider: AIProvider { } } - // MARK: - Request Building - private func buildChatCompletionRequest( - messages: [AIChatMessage], - model: String, - systemPrompt: String? + turns: [ChatTurn], + options: ChatTransportOptions ) throws -> URLRequest { let chatPath = providerType == .ollama ? "/api/chat" @@ -251,29 +227,30 @@ final class OpenAICompatibleProvider: AIProvider { ) } - // Build messages array var apiMessages: [[String: String]] = [] - if let systemPrompt { + if let systemPrompt = options.systemPrompt { apiMessages.append(["role": "system", "content": systemPrompt]) } - for message in messages where message.role != .system { + for turn in turns where turn.role != .system { + let text = turn.plainText + guard !text.isEmpty else { continue } apiMessages.append([ - "role": message.role.rawValue, - "content": message.content + "role": turn.role.rawValue, + "content": text ]) } var body: [String: Any] = [ - "model": model, + "model": options.model, "messages": apiMessages, "stream": true ] - if let maxOutputTokens { - body["max_tokens"] = maxOutputTokens + let resolvedMaxTokens = options.maxOutputTokens ?? maxOutputTokens + if let resolvedMaxTokens { + body["max_tokens"] = resolvedMaxTokens } - // Request usage stats in stream (OpenAI/OpenRouter support this) if providerType != .ollama { body["stream_options"] = ["include_usage": true] } @@ -282,8 +259,6 @@ final class OpenAICompatibleProvider: AIProvider { return request } - // MARK: - Model Fetching - private func fetchOpenAIModels() async throws -> [String] { guard let url = URL(string: "\(endpoint)/v1/models") else { throw AIProviderError.invalidEndpoint(endpoint) diff --git a/TablePro/Core/AI/Registry/AIProviderDescriptor.swift b/TablePro/Core/AI/Registry/AIProviderDescriptor.swift index ef952c08a..99f01786f 100644 --- a/TablePro/Core/AI/Registry/AIProviderDescriptor.swift +++ b/TablePro/Core/AI/Registry/AIProviderDescriptor.swift @@ -24,5 +24,5 @@ struct AIProviderDescriptor: Sendable { let requiresAPIKey: Bool let capabilities: AIProviderCapabilities let symbolName: String - let makeProvider: @Sendable (AIProviderConfig, String?) -> AIProvider + let makeProvider: @Sendable (AIProviderConfig, String?) -> ChatTransport } diff --git a/TablePro/Models/AI/AIConversation.swift b/TablePro/Models/AI/AIConversation.swift index 8a88dedaf..0a844034d 100644 --- a/TablePro/Models/AI/AIConversation.swift +++ b/TablePro/Models/AI/AIConversation.swift @@ -11,7 +11,7 @@ import Foundation struct AIConversation: Codable, Equatable, Identifiable { let id: UUID var title: String - var messages: [AIChatMessage] + var messages: [ChatTurn] let createdAt: Date var updatedAt: Date var connectionName: String? @@ -19,7 +19,7 @@ struct AIConversation: Codable, Equatable, Identifiable { init( id: UUID = UUID(), title: String = "", - messages: [AIChatMessage] = [], + messages: [ChatTurn] = [], createdAt: Date = Date(), updatedAt: Date = Date(), connectionName: String? = nil @@ -38,7 +38,7 @@ struct AIConversation: Codable, Equatable, Identifiable { let firstUserMessage = messages.first(where: { $0.role == .user }) else { return } - let text = firstUserMessage.content.trimmingCharacters(in: .whitespacesAndNewlines) + let text = firstUserMessage.plainText.trimmingCharacters(in: .whitespacesAndNewlines) if (text as NSString).length > 50 { title = String(text.prefix(47)) + "..." } else { diff --git a/TablePro/Models/AI/AIModels.swift b/TablePro/Models/AI/AIModels.swift index 4b8eba0a7..1d06264d5 100644 --- a/TablePro/Models/AI/AIModels.swift +++ b/TablePro/Models/AI/AIModels.swift @@ -205,43 +205,8 @@ struct AISettings: Codable, Equatable, Sendable { } } -// MARK: - AI Chat Message - -struct AIChatMessage: Codable, Equatable, Identifiable, Sendable { - let id: UUID - var role: AIChatRole - var content: String - let timestamp: Date - var usage: AITokenUsage? - - init( - id: UUID = UUID(), - role: AIChatRole, - content: String, - timestamp: Date = Date(), - usage: AITokenUsage? = nil - ) { - self.id = id - self.role = role - self.content = content - self.timestamp = timestamp - self.usage = usage - } -} - -enum AIChatRole: String, Codable, Sendable { - case user - case assistant - case system -} - struct AITokenUsage: Codable, Equatable, Sendable { var inputTokens: Int var outputTokens: Int var totalTokens: Int { inputTokens + outputTokens } } - -enum AIStreamEvent: Sendable { - case text(String) - case usage(AITokenUsage) -} diff --git a/TablePro/ViewModels/AIChatViewModel.swift b/TablePro/ViewModels/AIChatViewModel.swift index 429def0c3..d56db838b 100644 --- a/TablePro/ViewModels/AIChatViewModel.swift +++ b/TablePro/ViewModels/AIChatViewModel.swift @@ -17,7 +17,7 @@ final class AIChatViewModel { // MARK: - Published State - var messages: [AIChatMessage] = [] + var messages: [ChatTurn] = [] var inputText: String = "" var isStreaming: Bool = false var errorMessage: String? @@ -74,11 +74,11 @@ final class AIChatViewModel { sendWithContext(prompt: prompt) } - func editMessage(_ message: AIChatMessage) { + func editMessage(_ message: ChatTurn) { guard message.role == .user, !isStreaming else { return } guard let idx = messages.firstIndex(where: { $0.id == message.id }) else { return } - inputText = message.content + inputText = message.plainText messages.removeSubrange(idx...) persistCurrentConversation() } @@ -117,7 +117,7 @@ final class AIChatViewModel { let text = inputText.trimmingCharacters(in: .whitespacesAndNewlines) guard !text.isEmpty else { return } - let userMessage = AIChatMessage(role: .user, content: text) + let userMessage = ChatTurn(role: .user, blocks: [.text(text)]) messages.append(userMessage) trimMessagesIfNeeded() inputText = "" @@ -128,7 +128,7 @@ final class AIChatViewModel { /// Send a pre-filled prompt func sendWithContext(prompt: String) { - let userMessage = AIChatMessage(role: .user, content: prompt) + let userMessage = ChatTurn(role: .user, blocks: [.text(prompt)]) messages.append(userMessage) trimMessagesIfNeeded() errorMessage = nil @@ -145,7 +145,7 @@ final class AIChatViewModel { // Remove empty assistant placeholder left by cancelled stream if let assistantID = streamingAssistantID, let idx = messages.firstIndex(where: { $0.id == assistantID }), - messages[idx].content.isEmpty { + messages[idx].plainText.isEmpty { messages.remove(at: idx) } streamingAssistantID = nil @@ -333,7 +333,7 @@ final class AIChatViewModel { streamingTask = nil if let id = streamingAssistantID, let idx = messages.firstIndex(where: { $0.id == id }), - messages[idx].content.isEmpty { + messages[idx].plainText.isEmpty { messages.remove(at: idx) } streamingAssistantID = nil @@ -369,7 +369,7 @@ final class AIChatViewModel { let promptContext = capturePromptContext(settings: settings) // Create assistant message placeholder - let assistantMessage = AIChatMessage(role: .assistant, content: "") + let assistantMessage = ChatTurn(role: .assistant, blocks: []) messages.append(assistantMessage) trimMessagesIfNeeded() let assistantID = assistantMessage.id @@ -401,7 +401,7 @@ final class AIChatViewModel { // Pre-send size check let totalSize = ((systemPrompt ?? "") as NSString).length - + chatMessages.reduce(0) { $0 + ($1.content as NSString).length } + + chatMessages.reduce(0) { $0 + ($1.plainText as NSString).length } if totalSize > 100_000 { await MainActor.run { [weak self] in guard let self else { return } @@ -418,9 +418,11 @@ final class AIChatViewModel { } let stream = resolved.provider.streamChat( - messages: chatMessages, - model: resolved.model, - systemPrompt: systemPrompt + turns: chatMessages, + options: ChatTransportOptions( + model: resolved.model, + systemPrompt: systemPrompt + ) ) // Batch tokens off the main actor, flush on interval @@ -432,10 +434,12 @@ final class AIChatViewModel { for try await event in stream { guard !Task.isCancelled else { break } switch event { - case .text(let token): + case .textDelta(let token): pendingContent += token case .usage(let usage): pendingUsage = usage + case .toolUseStart, .toolUseDelta, .toolUseEnd: + break } if ContinuousClock.now - lastFlushTime >= flushInterval { @@ -448,7 +452,7 @@ final class AIChatViewModel { let idx = self.messages.firstIndex(where: { $0.id == assistantID }) else { return } if !content.isEmpty { - self.messages[idx].content += content + self.messages[idx].appendText(content) } if let usage { self.messages[idx].usage = usage @@ -467,7 +471,7 @@ final class AIChatViewModel { let idx = self.messages.firstIndex(where: { $0.id == assistantID }) else { return } if !content.isEmpty { - self.messages[idx].content += content + self.messages[idx].appendText(content) } if let usage { self.messages[idx].usage = usage @@ -493,7 +497,7 @@ final class AIChatViewModel { // Remove empty assistant message on error if let idx = self.messages.firstIndex(where: { $0.id == assistantID }), - self.messages[idx].content.isEmpty { + self.messages[idx].plainText.isEmpty { self.messages.remove(at: idx) } } diff --git a/TablePro/Views/AIChat/AIChatMessageView.swift b/TablePro/Views/AIChat/AIChatMessageView.swift index 2f074876f..1a113098f 100644 --- a/TablePro/Views/AIChat/AIChatMessageView.swift +++ b/TablePro/Views/AIChat/AIChatMessageView.swift @@ -11,7 +11,7 @@ import SwiftUI /// Displays a single AI chat message with appropriate styling struct AIChatMessageView: View { - let message: AIChatMessage + let message: ChatTurn var onRetry: (() -> Void)? var onRegenerate: (() -> Void)? var onEdit: (() -> Void)? @@ -31,7 +31,7 @@ struct AIChatMessageView: View { .font(.caption2) .foregroundStyle(.secondary) - Markdown(message.content) + Markdown(message.plainText) .markdownTheme(.tableProChat) .textSelection(.enabled) .frame(maxWidth: .infinity, alignment: .leading) @@ -116,12 +116,12 @@ struct AIChatMessageView: View { @ViewBuilder private var messageContent: some View { - if message.content.isEmpty { + if message.plainText.isEmpty { TypingIndicatorView() .padding(.horizontal, 8) .padding(.vertical, 6) } else { - Markdown(message.content) + Markdown(message.plainText) .markdownTheme(.tableProChat) .textSelection(.enabled) .padding(.horizontal, 8) diff --git a/TablePro/Views/AIChat/AIChatPanelView.swift b/TablePro/Views/AIChat/AIChatPanelView.swift index 99c1ef7ed..95e7a5c31 100644 --- a/TablePro/Views/AIChat/AIChatPanelView.swift +++ b/TablePro/Views/AIChat/AIChatPanelView.swift @@ -229,7 +229,7 @@ struct AIChatPanelView: View { isUserScrolledUp = false scrollToBottom(proxy: proxy, animated: true) } - .onChange(of: viewModel.messages.last?.content) { + .onChange(of: viewModel.messages.last?.plainText) { guard !isUserScrolledUp else { return } let now = Date() guard now.timeIntervalSince(lastAutoScrollTime) >= 0.1 else { return } @@ -353,16 +353,16 @@ struct AIChatPanelView: View { viewModel.queryResults = queryResults } - private func shouldShowRetry(for message: AIChatMessage) -> Bool { + private func shouldShowRetry(for message: ChatTurn) -> Bool { message.role == .user && message.id == viewModel.messages.last?.id && viewModel.lastMessageFailed } - private func shouldShowRegenerate(for message: AIChatMessage) -> Bool { + private func shouldShowRegenerate(for message: ChatTurn) -> Bool { message.role == .assistant && message.id == viewModel.messages.last?.id && !viewModel.isStreaming - && !message.content.isEmpty + && !message.plainText.isEmpty } }