diff --git a/Cotabby.xcodeproj/project.pbxproj b/Cotabby.xcodeproj/project.pbxproj index 1004e61..8f398e3 100644 --- a/Cotabby.xcodeproj/project.pbxproj +++ b/Cotabby.xcodeproj/project.pbxproj @@ -39,6 +39,7 @@ 15FA56CEF6FB5FF54C2FBA6F /* PermissionAndContextModelTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E7F42112F14026E6253BB865 /* PermissionAndContextModelTests.swift */; }; 19CB55B62977376E9AE8D428 /* VisualContextStartCoalescer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2F01FAC4F57EB08471521196 /* VisualContextStartCoalescer.swift */; }; 1B3FFCB9A979F49BF86EAAD4 /* ScreenshotContextGeneratorTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = B2BFD19A159680A495EE02FD /* ScreenshotContextGeneratorTests.swift */; }; + 1C00CE3D553B58723EAE5F92 /* ConstrainedSamplerTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 6B0200AA5C4A644A3B52A3EC /* ConstrainedSamplerTests.swift */; }; 1D1C6FF0B8F50AC14A1000F4 /* SentenceBoundaryClassifierTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2D7360A6D4261989A66658ED /* SentenceBoundaryClassifierTests.swift */; }; 1F8CC88AFFE67C08944CF506 /* WindowScreenshotService.swift in Sources */ = {isa = PBXBuildFile; fileRef = 77B0121E7BB173F8A2B0B108 /* WindowScreenshotService.swift */; }; 2197B68F1E4D0C3497DAC061 /* LlamaSuggestionEngine.swift in Sources */ = {isa = PBXBuildFile; fileRef = BE04620C905041680116BE80 /* LlamaSuggestionEngine.swift */; }; @@ -139,6 +140,7 @@ 7EB20783E0D36715D1230A5C /* PromptSectionBudgetTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E260C4D08C786CDBD527B329 /* PromptSectionBudgetTests.swift */; }; 7FC103944F4EF39DB965F469 /* InMemoryLogging in Frameworks */ = {isa = PBXBuildFile; productRef = 88921938DC814625ED57D605 /* InMemoryLogging */; }; 814E348C663B697537594F0C /* EmojiRecentsTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 671689F289D45A124639C9C6 /* EmojiRecentsTests.swift */; }; + 81FC391E375B7EEFF965FF1B /* 
ConstrainedSampler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 47CF010A66DA31989524FCD0 /* ConstrainedSampler.swift */; }; 82D4ADEAF05337ABDE4C586C /* RuntimeBootstrapModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 60629DFE309C1A4BD8A7FB3B /* RuntimeBootstrapModel.swift */; }; 83EC3543DC45B1601F119BF9 /* InsertionSafetyGateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 43D627C4A55359EAF4676FF7 /* InsertionSafetyGateTests.swift */; }; 8441299082E6B68F7F88911B /* ShortcutConflictTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0850B07CCDBA67C756C6EC59 /* ShortcutConflictTests.swift */; }; @@ -148,6 +150,7 @@ 88BCD795A14E1C9308F7BB31 /* SuggestionAvailabilityEvaluatorTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = C05B0439348261163B37C508 /* SuggestionAvailabilityEvaluatorTests.swift */; }; 8B2DFC860803C0A7C4D34A36 /* ContextBuffer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 54EF3C7F5D9D6F3FA50FD51C /* ContextBuffer.swift */; }; 8DA36F1521B6A59D8C20AC59 /* Logging in Frameworks */ = {isa = PBXBuildFile; productRef = 5A60D1467BBFECB3DFEB39C2 /* Logging */; }; + 8EED2B55999A119AE3B67359 /* TokenProfile.swift in Sources */ = {isa = PBXBuildFile; fileRef = F3CEFE8C321E17BB3873C893 /* TokenProfile.swift */; }; 902B83CCB82E286FBEB9DAAD /* EmojiPickerPanelLayout.swift in Sources */ = {isa = PBXBuildFile; fileRef = 62EDF1199CC5E18BD7651661 /* EmojiPickerPanelLayout.swift */; }; 907A0BF56C3BB0CBAF2649AB /* SettingsCategory.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5D0AEFF86F8210CBE7CFCBAD /* SettingsCategory.swift */; }; 909EBE545CE644C6C57F1B5D /* SuggestionCoordinator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8F961F5DF2A392F6F5F94F8A /* SuggestionCoordinator.swift */; }; @@ -197,6 +200,7 @@ C71B594433F3B411CAE5DE7E /* FocusCapabilityResolverTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D4F6D5F94B238F7B4BE7C247 /* FocusCapabilityResolverTests.swift */; }; C9B815652CED38966C53A5E8 
/* EmojiVariantResolverTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE8BB19D8EC9A75CD3458A6B /* EmojiVariantResolverTests.swift */; }; CA5B2D226FBAA5419E78F14F /* SuggestionSessionReconciler.swift in Sources */ = {isa = PBXBuildFile; fileRef = CE0AA0503128B0FC3951D700 /* SuggestionSessionReconciler.swift */; }; + CA8F453AA4AD02FAA8C961F7 /* TokenProfileTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E7D0BF193110927BEB865748 /* TokenProfileTests.swift */; }; CB65A79F164269991FABC32E /* SuggestionStateHelperTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19DB9558F4D3AFB108D71649 /* SuggestionStateHelperTests.swift */; }; CC98B842D10574C5206BEFA7 /* FocusCapabilityResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = 70367FCC1E0F08EE3B8EB26F /* FocusCapabilityResolver.swift */; }; CCC83DC5AE51C17F153D5A6A /* PermissionModels.swift in Sources */ = {isa = PBXBuildFile; fileRef = 9D82FFC568527700EC17C07D /* PermissionModels.swift */; }; @@ -327,6 +331,7 @@ 44595B534DD7323F0AD60825 /* MenuBarPopoverDismisser.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MenuBarPopoverDismisser.swift; sourceTree = ""; }; 4696A84D17890B154533A08F /* PromptPolicyTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptPolicyTests.swift; sourceTree = ""; }; 4793D4EA5D36D7E5CC216C27 /* LanguageSupportTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LanguageSupportTests.swift; sourceTree = ""; }; + 47CF010A66DA31989524FCD0 /* ConstrainedSampler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConstrainedSampler.swift; sourceTree = ""; }; 51020F8CD58338BD643FBF63 /* ModelDownloadManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelDownloadManager.swift; sourceTree = ""; }; 52BAFA2F989C3C4F7FB892B5 /* MarkerSelectionSynthesizerTests.swift */ = {isa = PBXFileReference; 
lastKnownFileType = sourcecode.swift; path = MarkerSelectionSynthesizerTests.swift; sourceTree = ""; }; 53CF416511099C6818110F01 /* CompletionRenderModePolicy.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CompletionRenderModePolicy.swift; sourceTree = ""; }; @@ -356,6 +361,7 @@ 671689F289D45A124639C9C6 /* EmojiRecentsTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = EmojiRecentsTests.swift; sourceTree = ""; }; 67586807ACE8EB13C9014535 /* TickMarkSlider.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TickMarkSlider.swift; sourceTree = ""; }; 6A44BEC8C23FF227731DD0CD /* FocusCapabilityFlickerGate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FocusCapabilityFlickerGate.swift; sourceTree = ""; }; + 6B0200AA5C4A644A3B52A3EC /* ConstrainedSamplerTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConstrainedSamplerTests.swift; sourceTree = ""; }; 6B2D97BAA3618A7D0357AC44 /* SuggestionWorkController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SuggestionWorkController.swift; sourceTree = ""; }; 6DC693E00430F46E41CB56E6 /* RequestID.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = RequestID.swift; sourceTree = ""; }; 70367FCC1E0F08EE3B8EB26F /* FocusCapabilityResolver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FocusCapabilityResolver.swift; sourceTree = ""; }; @@ -479,6 +485,7 @@ E43E587E421AF544A8300CE4 /* CustomRulesCatalog.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CustomRulesCatalog.swift; sourceTree = ""; }; E5DAF68AEBFE334F68A65E82 /* AcceptanceModePickerView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AcceptanceModePickerView.swift; sourceTree = ""; }; E6423D6CC8CC371D2DA899DE /* PermissionOverlayTracker.swift */ = 
{isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PermissionOverlayTracker.swift; sourceTree = ""; }; + E7D0BF193110927BEB865748 /* TokenProfileTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenProfileTests.swift; sourceTree = ""; }; E7F42112F14026E6253BB865 /* PermissionAndContextModelTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PermissionAndContextModelTests.swift; sourceTree = ""; }; EAAE6B395FAB604DF059280A /* KeyCodeLabels.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = KeyCodeLabels.swift; sourceTree = ""; }; EB630F9814388203DD1CA2EC /* ShortcutsPaneView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ShortcutsPaneView.swift; sourceTree = ""; }; @@ -487,6 +494,7 @@ EE94342B888A5A2CCF66BC93 /* SuggestionRequestFactoryTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SuggestionRequestFactoryTests.swift; sourceTree = ""; }; EFD89799BB82AF7A92559AEB /* ClipboardContentDistillerTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ClipboardContentDistillerTests.swift; sourceTree = ""; }; F308F6E274CC645E27CB651F /* OverlayController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OverlayController.swift; sourceTree = ""; }; + F3CEFE8C321E17BB3873C893 /* TokenProfile.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenProfile.swift; sourceTree = ""; }; FA4B45B91D4DEAC979C3113E /* PromptContextSanitizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptContextSanitizer.swift; sourceTree = ""; }; FA878B447441BB4F3E327CC8 /* OnboardingTemplateRecommender.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OnboardingTemplateRecommender.swift; sourceTree = ""; }; FB317C82CE2CBC69056BA4B8 /* TagChip.swift 
*/ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TagChip.swift; sourceTree = ""; }; @@ -725,6 +733,7 @@ 90B0D133AB77A2503FB08827 /* ClipboardRelevanceFilterTests.swift */, D504BEB224E0C176F5FCFF6E /* CompletionRenderModePolicyTests.swift */, 06FF2B0A3094A952A8EBA9B5 /* ConfidenceSuppressionPolicyTests.swift */, + 6B0200AA5C4A644A3B52A3EC /* ConstrainedSamplerTests.swift */, AF1E065C7FFB697FCEB2FA5C /* CotabbyTestFixtures.swift */, AD752451330486FE270018B0 /* CustomRulesTests.swift */, C1C5DE0F3FF63545000E2453 /* DisplayCoordinateConverterTests.swift */, @@ -780,6 +789,7 @@ C71031E8DB171047318B92FC /* SyntheticReplacePlannerTests.swift */, 43E37A7E835D3BDE6265843C /* TerminalAppDetectorTests.swift */, FC24FD54860CE6737E65EF65 /* TextDirectionDetectorTests.swift */, + E7D0BF193110927BEB865748 /* TokenProfileTests.swift */, E19A5B462891263BDFB56607 /* TrailingDuplicationFilterTests.swift */, 050D929E13BE52E6282B64D2 /* VisualContextStartCoalescerTests.swift */, 1E0513E3B23937B099A3CFF2 /* WordCountFormatterTests.swift */, @@ -884,6 +894,7 @@ D3A2AC525DC664DB540D4F19 /* ClipboardRelevanceFilter.swift */, 53CF416511099C6818110F01 /* CompletionRenderModePolicy.swift */, 1BD71ECC2AE4821B643E0935 /* ConfidenceSuppressionPolicy.swift */, + 47CF010A66DA31989524FCD0 /* ConstrainedSampler.swift */, C7B2D34A6F3AC9DFD61350F7 /* CotabbyDebugOptions.swift */, 29ED42C4BDD0C521101AF95E /* DeviceInfo.swift */, 74BD1D4DB27D5D96D1E06096 /* DisplayCoordinateConverter.swift */, @@ -928,6 +939,7 @@ B424E2AC97C99D335B0D5751 /* SuggestionTextNormalizer.swift */, 7F4C4A7EAF886E0CC945BFEF /* TerminalAppDetector.swift */, 328847A0F494360033366791 /* TextDirectionDetector.swift */, + F3CEFE8C321E17BB3873C893 /* TokenProfile.swift */, D408D647412C59F3E692C42B /* TrailingDuplicationFilter.swift */, 2F01FAC4F57EB08471521196 /* VisualContextStartCoalescer.swift */, 815F2ABAF6AB75DA3AFBBCEF /* WordCountFormatter.swift */, @@ -1075,6 +1087,7 @@ 7C94725B4837DEC9ECF1BC54 /* 
CompletionRenderMode.swift in Sources */, 3985F0F2B3178DBB945B1064 /* CompletionRenderModePolicy.swift in Sources */, 429CE592897D8A952F2916C3 /* ConfidenceSuppressionPolicy.swift in Sources */, + 81FC391E375B7EEFF965FF1B /* ConstrainedSampler.swift in Sources */, 8B2DFC860803C0A7C4D34A36 /* ContextBuffer.swift in Sources */, AA2E09FF7E430D66ECA8ECD5 /* CotabbyApp.swift in Sources */, FCC571EC239846F06007BFCA /* CotabbyAppEnvironment.swift in Sources */, @@ -1208,6 +1221,7 @@ AB9C9C001F97F9D14F8B192A /* TerminalAppDetector.swift in Sources */, 96782E57CA26A16409368B69 /* TextDirectionDetector.swift in Sources */, 6014B31E2570EFFE45557E33 /* TickMarkSlider.swift in Sources */, + 8EED2B55999A119AE3B67359 /* TokenProfile.swift in Sources */, D3B43622E5A41B11E7AF527E /* TrailingDuplicationFilter.swift in Sources */, E9E4CC657771DF9F4C56183C /* VisualContextCoordinator.swift in Sources */, 4190F8A76196B16ED94D0A55 /* VisualContextModels.swift in Sources */, @@ -1235,6 +1249,7 @@ BFCA7FAFDAEBF586AB615567 /* ClipboardRelevanceFilterTests.swift in Sources */, 25F91CEF38400FD1ADB6B1AF /* CompletionRenderModePolicyTests.swift in Sources */, 91D8189EFCD1BA992EA6F038 /* ConfidenceSuppressionPolicyTests.swift in Sources */, + 1C00CE3D553B58723EAE5F92 /* ConstrainedSamplerTests.swift in Sources */, 5E10EFC426217CB7218A5847 /* CotabbyTestFixtures.swift in Sources */, 91D1F16B8C5DA281D4B7F699 /* CustomRulesTests.swift in Sources */, 56611BA0087710277140E9E6 /* DisplayCoordinateConverterTests.swift in Sources */, @@ -1290,6 +1305,7 @@ EF5BAB96DDADABB86F9E02D9 /* SyntheticReplacePlannerTests.swift in Sources */, DE236C9285635C686D66A2F6 /* TerminalAppDetectorTests.swift in Sources */, 5A441797D71A880A7482077D /* TextDirectionDetectorTests.swift in Sources */, + CA8F453AA4AD02FAA8C961F7 /* TokenProfileTests.swift in Sources */, DB1310FF3576ACA6472C4DB1 /* TrailingDuplicationFilterTests.swift in Sources */, D5CAF3B590E5EC2AFC72E57A /* VisualContextStartCoalescerTests.swift in Sources 
*/, 6AE0B46FB52D189D94E1F79A /* WordCountFormatterTests.swift in Sources */, @@ -1603,7 +1619,7 @@ isa = XCRemoteSwiftPackageReference; repositoryURL = "https://github.com/FuJacob/cotabbyinference.git"; requirement = { - branch = "feat/generation-quality-controls"; + branch = main; kind = branch; }; }; diff --git a/Cotabby/Models/LlamaRuntimeModels.swift b/Cotabby/Models/LlamaRuntimeModels.swift index d39f8ed..f725b2c 100644 --- a/Cotabby/Models/LlamaRuntimeModels.swift +++ b/Cotabby/Models/LlamaRuntimeModels.swift @@ -201,6 +201,13 @@ struct LlamaGenerationOptions: Equatable, Sendable { /// Average per-token log-probability below which a completion is suppressed as low-confidence. /// Defaults to -infinity, which disables suppression entirely. var confidenceFloor: Double = -.infinity + + /// Routes generation through the deterministic constrained decoder (logit read + admissibility + /// mask + argmax + manual token commit) instead of the engine's built-in stochastic sampler. + /// Default off so the shipping sampleNext path is unaffected until the constrained decoder is + /// validated on device. Changing it does not affect KV reuse, so it is intentionally excluded + /// from `SamplingFingerprint`. + var useConstrainedDecoder: Bool = false } /// The concrete runtime assets selected during bootstrap after checking available model files. diff --git a/Cotabby/Services/Runtime/LlamaRuntimeCore.swift b/Cotabby/Services/Runtime/LlamaRuntimeCore.swift index f230f16..a113a68 100644 --- a/Cotabby/Services/Runtime/LlamaRuntimeCore.swift +++ b/Cotabby/Services/Runtime/LlamaRuntimeCore.swift @@ -34,6 +34,13 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { private var autocompletePromptTokens: [Int32] = [] private var autocompleteSamplingFingerprint: SamplingFingerprint? + /// Per-model constrained-decoding token table, built lazily on the first constrained request and + /// reused across requests. Keyed by model URL so loading a different model rebuilds it. 
Read and + /// written only inside `runConstrainedDecode`, which runs under `autocompleteLock`, so it needs + /// no extra synchronization. Cleared on `shutdown()` to release the table when the model unloads. + private var cachedTokenProfile: TokenProfile? + private var cachedTokenProfileModelURL: URL? + /// Coordinates model lifecycle with in-flight operations. `generate()` and `summarize()` /// increment the active count on entry and decrement on exit. `shutdown()` sets the /// shutting-down flag and blocks until all active operations finish before unloading. @@ -188,6 +195,19 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { autocompleteSamplingFingerprint = fingerprint } + // The KV-trim defer above runs after whichever decoder returns. Both decoders share the + // prepared sequence and the same confidence-suppression contract; they differ only in how + // they pick each token (engine sampler vs. deterministic constrained selection). + return options.useConstrainedDecoder + ? try runConstrainedDecode(sequenceID: sequenceID, options: options) + : runEngineSampledDecode(sequenceID: sequenceID, options: options) + } + + // MARK: - Decoders + + /// The shipping decoder: delegates token selection to the engine's built-in sampler + /// (`sampleNext`), which applies temperature / top-k / top-p / min-p and commits each token. + private func runEngineSampledDecode(sequenceID: Int32, options: LlamaGenerationOptions) -> String { var generatedText = "" var tokensGenerated = 0 var sumLogprob = 0.0 @@ -230,24 +250,131 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { ] ) - // Confidence suppression: drop completions the model itself was unsure about. Disabled by - // default (confidenceFloor == -infinity); the KV-trim defer above still runs on early return. 
- if tokensGenerated > 0, - ConfidenceSuppressionPolicy.shouldSuppress( - averageLogprob: sumLogprob / Double(tokensGenerated), - floor: options.confidenceFloor - ) { + if Self.shouldSuppress(sumLogprob: sumLogprob, tokensGenerated: tokensGenerated, options: options) { + return "" + } + return generatedText + } + + /// The constrained decoder: reads the raw next-token logits, masks structural / excluded tokens + /// via the token profile, deterministically selects the highest-logit admissible token, and + /// commits it manually with `acceptToken`. This trades the sampler's randomness for reproducible, + /// leak-free continuations (no chat/control markers can surface as visible text). It honors the + /// same cancellation, single-line, and confidence-suppression contracts as the sampled path. + /// Mid-word word-continuation is already applied to the seed logits by `decodePrompt` (the engine + /// masks new-word-start tokens for the first step), so the first `getNextTokenLogits` row this + /// reads is already constrained when `forceWordContinuation` was set. + private func runConstrainedDecode(sequenceID: Int32, options: LlamaGenerationOptions) throws -> String { + let profile = try autocompleteTokenProfile() + let vocabSize = profile.vocabSize + guard vocabSize > 0 else { + throw LlamaRuntimeError.generationFailed("Vocabulary unavailable for constrained decoding.") + } + // `topK` bounds the candidate pool the selector ranks; clamp to a sane positive value so a + // zero/negative request still yields a full-vocab argmax rather than an empty pool. + let topK = options.topK > 0 ? 
options.topK : vocabSize + + var generatedBytes: [UInt8] = [] + var tokensGenerated = 0 + var sumLogprob = 0.0 + var stopReason = "budget_exhausted" + var logits = [Float](repeating: 0, count: vocabSize) + + for _ in 0 ..< options.maxPredictionTokens { + if Task.isCancelled { + stopReason = "cancelled" + break + } + + let written = logits.withUnsafeMutableBufferPointer { buffer in + Int(engine.getNextTokenLogits(sequenceID, buffer.baseAddress, Int32(buffer.count))) + } + guard written == vocabSize else { + stopReason = "no_logits" + break + } + + guard let tokenID = ConstrainedSampler.selectToken( + logits: logits, + profile: profile, + admissibleTokenIDs: nil, + topK: topK + ) else { + stopReason = "no_admissible_token" + break + } + + if profile.isEndOfGeneration(tokenID) { + stopReason = "eos" + break + } + // Single-line fields must never receive a line break; stop before emitting one so the + // partial completion so far is preserved (mirrors the sampler path's single_line mask). + if options.singleLine, profile.isNewline(tokenID) { + stopReason = "single_line" + break + } + + // Accumulate raw bytes and decode once at the end: a single token may carry only part of + // a multi-byte UTF-8 scalar, so per-token String decoding would corrupt CJK / emoji. + if let logProb = ConstrainedSampler.logProb(ofTokenAt: tokenID, in: logits) { + sumLogprob += logProb + } + generatedBytes.append(contentsOf: profile.bytes(for: tokenID)) + tokensGenerated += 1 + + if engine.acceptToken(sequenceID, Int32(tokenID)) != .ok { + stopReason = "accept_failed" + break + } + } + + // Lossy decode is deliberate: the accumulated bytes are valid UTF-8 except for a possible + // partial trailing scalar (the final token may carry only part of a multi-byte character). + // The failable `String(bytes:encoding:)` would discard the entire completion in that case; + // `String(decoding:)` keeps every complete scalar and renders only the fragment as U+FFFD. 
+ // swiftlint:disable:next optional_data_string_conversion + let generatedText = String(decoding: generatedBytes, as: UTF8.self) + CotabbyLogger.runtime.debug( + "Decode end", + metadata: [ + "kind": .string("generate_constrained"), + "tokens_generated": .stringConvertible(tokensGenerated), + "chars_generated": .stringConvertible(generatedText.count), + "stop_reason": .string(stopReason) + ] + ) + + if Self.shouldSuppress(sumLogprob: sumLogprob, tokensGenerated: tokensGenerated, options: options) { + return "" + } + return generatedText + } + + /// Shared low-confidence gate for both decoders: drop completions the model itself was unsure + /// about. Disabled by default (confidenceFloor == -infinity). The KV-trim defer in `generate` + /// still runs because the caller returns "" rather than throwing. + private static func shouldSuppress( + sumLogprob: Double, + tokensGenerated: Int, + options: LlamaGenerationOptions + ) -> Bool { + guard tokensGenerated > 0 else { return false } + let averageLogprob = sumLogprob / Double(tokensGenerated) + let suppress = ConfidenceSuppressionPolicy.shouldSuppress( + averageLogprob: averageLogprob, + floor: options.confidenceFloor + ) + if suppress { CotabbyLogger.runtime.debug( "Suppressed low-confidence completion", metadata: [ "tokens_generated": .stringConvertible(tokensGenerated), - "avg_logprob": .stringConvertible(sumLogprob / Double(tokensGenerated)) + "avg_logprob": .stringConvertible(averageLogprob) ] ) - return "" } - - return generatedText + return suppress } // MARK: - Cache and lifecycle @@ -302,6 +429,8 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { resetPromptCache() engine.unloadModel() preparedRuntime = nil + cachedTokenProfile = nil + cachedTokenProfileModelURL = nil CotabbyLogger.runtime.info("Runtime shutdown complete") lifecycleCondition.lock() @@ -405,6 +534,68 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { return Array(vec) } + /// Lazily builds and caches the 
constrained-decoding token profile for the loaded model. The + /// profile records each token's bytes and structural flags so the constrained decoder can mask + /// excluded tokens and detect stops without calling back into the engine per step. Building scans + /// the whole vocabulary once (one detokenize per token), so the result is cached and reused until + /// the model changes. Must be called while holding `autocompleteLock`. + private func autocompleteTokenProfile() throws -> TokenProfile { + let modelURL = preparedRuntime?.resolvedRuntime.modelFileURL + if let cachedTokenProfile, cachedTokenProfileModelURL == modelURL { + return cachedTokenProfile + } + + let vocabSize = Int(engine.getVocabSize()) + guard vocabSize > 0 else { + throw LlamaRuntimeError.generationFailed("Vocabulary unavailable for constrained decoding.") + } + + // Detokenize every token once up front; the build closures index this snapshot so each + // token's bytes are computed a single time and its control flag derives from the same bytes. + var tokenBytes: [[UInt8]] = [] + tokenBytes.reserveCapacity(vocabSize) + for id in 0 ..< vocabSize { + tokenBytes.append(detokenizeBytes(Int32(id))) + } + + let profile = TokenProfile.build( + vocabSize: vocabSize, + bytesFor: { tokenBytes[$0] }, + // A token that detokenizes to no visible bytes is a structural / special / control token + // (llama renders those empty when special rendering is off); never emit it as text. + isControl: { tokenBytes[$0].isEmpty }, + isEndOfGeneration: { self.engine.isEndOfGenerationToken(Int32($0)) } + ) + cachedTokenProfile = profile + cachedTokenProfileModelURL = modelURL + CotabbyLogger.runtime.debug( + "Built constrained-decode token profile", + metadata: ["vocab_size": .stringConvertible(vocabSize)] + ) + return profile + } + + /// The raw UTF-8 bytes a token detokenizes to, or empty for a structural token that renders to + /// nothing. 
`detokenize` returns the byte count, or a negative `-(required)` when the fixed + /// buffer is too small; the rare large-piece case retries once at the requested size. + private func detokenizeBytes(_ token: Int32) -> [UInt8] { + var buffer = [CChar](repeating: 0, count: 256) + let written = buffer.withUnsafeMutableBufferPointer { ptr in + Int(engine.detokenize(token, ptr.baseAddress, Int32(ptr.count))) + } + if written > 0 { + return buffer.prefix(written).map { UInt8(bitPattern: $0) } + } + if written < 0 { + var large = [CChar](repeating: 0, count: -written) + let writtenLarge = large.withUnsafeMutableBufferPointer { ptr in + Int(engine.detokenize(token, ptr.baseAddress, Int32(ptr.count))) + } + return writtenLarge > 0 ? large.prefix(writtenLarge).map { UInt8(bitPattern: $0) } : [] + } + return [] + } + private static func extractPiece(_ result: SampleResult) -> String { guard let piece = result.piece, result.piece_length > 0 else { return "" } let buffer = UnsafeBufferPointer( diff --git a/Cotabby/Services/Runtime/LlamaSuggestionEngine.swift b/Cotabby/Services/Runtime/LlamaSuggestionEngine.swift index f793531..21b0c0f 100644 --- a/Cotabby/Services/Runtime/LlamaSuggestionEngine.swift +++ b/Cotabby/Services/Runtime/LlamaSuggestionEngine.swift @@ -13,6 +13,15 @@ final class LlamaSuggestionEngine { private let runtimeManager: LlamaRuntimeManager private var promptCacheHintTracker = LlamaPromptCacheHintTracker() + /// UserDefaults key (no UI) that routes llama generation through the deterministic constrained + /// decoder instead of the engine's stochastic sampler. Default-off: decode quality can only be + /// judged with a real model in a real field, so this stays a hidden developer/dogfood toggle + /// until it is validated on device and promoted to the default. 
+ private static let constrainedDecoderDefaultsKey = "cotabbyConstrainedDecoderEnabled" + private static var isConstrainedDecoderEnabled: Bool { + UserDefaults.standard.bool(forKey: constrainedDecoderDefaultsKey) + } + init(runtimeManager: LlamaRuntimeManager) { self.runtimeManager = runtimeManager } @@ -50,7 +59,8 @@ final class LlamaSuggestionEngine { forceWordContinuation: MidWordContinuationPolicy.shouldForceContinuation( precedingText: request.context.precedingText, trailingText: request.context.trailingText - ) + ), + useConstrainedDecoder: Self.isConstrainedDecoderEnabled ) ) try Task.checkCancellation() diff --git a/Cotabby/Support/ConstrainedSampler.swift b/Cotabby/Support/ConstrainedSampler.swift new file mode 100644 index 0000000..0fe5031 --- /dev/null +++ b/Cotabby/Support/ConstrainedSampler.swift @@ -0,0 +1,129 @@ +import Foundation + +/// File overview: +/// Pure, deterministic token selection over a single step's logits, plus a confidence helper that +/// averages per-step log-probabilities. Selection skips excluded (control) tokens, optionally +/// restricts to a set of admissible ids, and returns the surviving token with the highest logit. +/// +/// Why this file exists: +/// Constrained decoding needs a selection step that is fully reproducible: the same logits and the +/// same constraints must always yield the same token, so behavior is testable and a suggestion can +/// be explained after the fact. This sampler is therefore deterministic argmax with no RNG and no +/// temperature. `topK` only bounds how large the candidate pool is before the argmax — it is a cost +/// guard, not a source of randomness, so a smaller `topK` can only ever change the result by +/// excluding lower-logit tokens that would not have won anyway among the unexcluded candidates. The +/// admissibility set (when present) is the byte-prefix constraint computed elsewhere; passing nil +/// means "no prefix constraint". 
Keeping this logic pure keeps the engine integration thin: the
+/// runtime supplies logits and the precomputed constraints, and this returns an id (or nil when
+/// nothing survives).
+enum ConstrainedSampler {
+    /// Selects the highest-logit token that survives the constraints, or nil when none survive.
+    ///
+    /// Survivors are tokens that are in-range, not `profile.isExcluded`, and — when
+    /// `admissibleTokenIDs` is non-nil — members of that set. `topK` bounds the candidate pool by
+    /// pre-ranking on logit before filtering: only the `topK` highest-logit token ids are considered.
+    /// Because selection is a plain argmax, bounding the pool cannot change which token wins unless
+    /// the winner sat outside the top `topK` by raw logit, so callers trade recall for cost by
+    /// lowering `topK`. A `topK` of zero or negative considers no candidates and returns nil.
+    ///
+    /// Determinism note: ties on logit are broken by the lower token id, so equal-logit inputs still
+    /// produce a single stable result.
+    static func selectToken(
+        logits: [Float],
+        profile: TokenProfile,
+        admissibleTokenIDs: Set<Int>?,
+        topK: Int
+    ) -> Int? {
+        guard topK > 0, !logits.isEmpty else {
+            return nil
+        }
+        if let admissible = admissibleTokenIDs, admissible.isEmpty {
+            // An explicit empty admissible set means the prefix constraint admits nothing this step.
+            return nil
+        }
+
+        let candidates = candidatePool(count: logits.count, logits: logits, limit: topK)
+
+        var best: Int?
+        var bestLogit: Float = -.infinity
+        for id in candidates {
+            if profile.isExcluded(id) {
+                continue
+            }
+            if let admissible = admissibleTokenIDs, !admissible.contains(id) {
+                continue
+            }
+            let logit = logits[id]
+            // Strict greater-than keeps the first-seen (lower-id, because the pool is id-ordered after
+            // the top-k cut) token on ties, which makes the result independent of iteration quirks.
+            if best == nil || logit > bestLogit {
+                best = id
+                bestLogit = logit
+            }
+        }
+        return best
+    }
+
+    /// Average per-step log-probability of a sequence of chosen tokens, a confidence summary suitable
+    /// for the existing low-confidence suppression policy.
+    ///
+    /// `fullRows[i]` is the full logits vector at step `i` and `chosenLogits[i]` is the logit of the
+    /// token actually committed at step `i` (the caller already knows which id it picked, so it passes
+    /// the scalar rather than the id). For each step this computes the softmax log-probability of the
+    /// chosen token, `chosenLogit - logSumExp(row)`, and returns the mean across steps. Returns nil
+    /// when there are no steps or the two inputs disagree in length, since an average is undefined
+    /// then. Pure and deterministic: a numerically stable log-sum-exp (shifted by the row maximum)
+    /// makes the result independent of constant offsets in the logits.
+    static func averageLogProb(of chosenLogits: [Float], over fullRows: [[Float]]) -> Double? {
+        guard !chosenLogits.isEmpty, chosenLogits.count == fullRows.count else {
+            return nil
+        }
+        var total = 0.0
+        for (chosen, row) in zip(chosenLogits, fullRows) {
+            guard !row.isEmpty else {
+                return nil
+            }
+            total += Double(chosen) - logSumExp(row)
+        }
+        return total / Double(chosenLogits.count)
+    }
+
+    /// The softmax log-probability of the token at `index` in `row`: `row[index] - logSumExp(row)`.
+    /// This is the single-step form of `averageLogProb`, for decoders that score each chosen token as
+    /// they go instead of retaining every logits row (retaining full rows for a whole completion would
+    /// cost vocab-size floats per step). Returns nil for an empty row or an out-of-range index.
+    static func logProb(ofTokenAt index: Int, in row: [Float]) -> Double? {
+        guard !row.isEmpty, index >= 0, index < row.count else {
+            return nil
+        }
+        return Double(row[index]) - logSumExp(row)
+    }
+
+    /// The token ids to consider this step, ordered by id. When `limit` is at least `count` every id
+    /// is returned (still id-ordered). Otherwise the ids are ranked by descending logit, the top
+    /// `limit` are kept, and that subset is re-sorted by id so downstream tie-breaking stays stable.
+    private static func candidatePool(count: Int, logits: [Float], limit: Int) -> [Int] {
+        guard limit < count else {
+            return Array(0..<count)
+        }
+        let ranked = (0..<count).sorted { lhs, rhs in
+            if logits[lhs] != logits[rhs] {
+                return logits[lhs] > logits[rhs]
+            }
+            // Stable id ordering for equal logits so the kept set is deterministic at the cut line.
+            return lhs < rhs
+        }
+        return ranked.prefix(limit).sorted()
+    }
+
+    /// Numerically stable log(sum(exp(row))): subtract the max before exponentiating so large logits
+    /// do not overflow. The caller guarantees `row` is non-empty.
+    private static func logSumExp(_ row: [Float]) -> Double {
+        let maxLogit = Double(row.max() ?? 0)
+        var sumExp = 0.0
+        for value in row {
+            sumExp += exp(Double(value) - maxLogit)
+        }
+        return maxLogit + log(sumExp)
+    }
+}
diff --git a/Cotabby/Support/TokenProfile.swift b/Cotabby/Support/TokenProfile.swift
new file mode 100644
index 0000000..57815c1
--- /dev/null
+++ b/Cotabby/Support/TokenProfile.swift
@@ -0,0 +1,136 @@
+import Foundation
+
+/// File overview:
+/// Per-token metadata used by constrained decoding: for every token id in a model's vocabulary it
+/// records the token's raw UTF-8 bytes plus a few classification flags (control, end-of-generation,
+/// whitespace-only, newline). Constrained sampling and prefix admissibility both read from this
+/// table instead of querying a live engine.
+///
+/// Why this file exists:
+/// Constrained decoding needs the *byte* shape of every candidate token (to test whether it can
+/// continue a required prefix) and a way to drop tokens that should never be sampled as visible text
+/// (control / structural tokens). A live tokenizer cannot be called from pure decision code without
+/// dragging the runtime in, and it would not be deterministically testable. Building this profile
+/// once from injected vocabulary data keeps the decoding rules pure: tests supply stub closures for
+/// the bytes and flags, and the same inputs always yield the same verdicts. The builder takes
+/// closures rather than concrete engine objects precisely so no runtime dependency leaks in.
+struct TokenProfile {
+    /// Classification flags for a single token, kept compact so the per-token table stays small even
+    /// for large vocabularies.
+    struct Entry {
+        /// The token's raw UTF-8 bytes. Byte-level (not String) because admissibility against a
+        /// partial-word prefix is a byte-prefix relationship: a token may encode only part of a
+        /// multi-byte scalar, and Strings cannot represent that fragment.
+        let bytes: [UInt8]
+        /// A structural / control token (for example a chat or special marker) that must never be
+        /// emitted as visible completion text.
+        let isControl: Bool
+        /// A token the engine treats as a stop / end-of-generation signal.
+        let isEndOfGeneration: Bool
+        /// The decoded bytes are non-empty and contain only whitespace.
+        let isWhitespaceOnly: Bool
+        /// The decoded bytes contain a line feed (`\n`).
+        let isNewline: Bool
+    }
+
+    /// Indexed by token id; `entries[id]` is the metadata for token `id`.
+    let entries: [Entry]
+
+    /// The number of tokens described by this profile.
+    var vocabSize: Int { entries.count }
+
+    /// Builds a profile for `vocabSize` tokens by pulling each token's bytes and flags from the
+    /// supplied closures. The closures are the only source of engine data, which is what keeps the
+    /// type pure and testable: a runtime passes detokenize / control / EOG lookups, and a test passes
+    /// stubs. Whitespace-only and newline are derived from the bytes here so callers cannot supply an
+    /// inconsistent classification.
+    static func build(
+        vocabSize: Int,
+        bytesFor: (Int) -> [UInt8],
+        isControl: (Int) -> Bool,
+        isEndOfGeneration: (Int) -> Bool
+    ) -> TokenProfile {
+        guard vocabSize > 0 else {
+            return TokenProfile(entries: [])
+        }
+        var entries: [Entry] = []
+        entries.reserveCapacity(vocabSize)
+        for id in 0..<vocabSize {
+            let bytes = bytesFor(id)
+            entries.append(
+                Entry(
+                    bytes: bytes,
+                    isControl: isControl(id),
+                    isEndOfGeneration: isEndOfGeneration(id),
+                    isWhitespaceOnly: Self.isWhitespaceOnly(bytes),
+                    isNewline: bytes.contains(Self.lineFeed)
+                )
+            )
+        }
+        return TokenProfile(entries: entries)
+    }
+
+    /// The raw UTF-8 bytes recorded for token `id`. Returns an empty array for an out-of-range id so
+    /// callers can probe arbitrary ids without crashing; an empty result therefore does not by
+    /// itself distinguish an unknown id from a token whose recorded bytes are empty.
+    func bytes(for id: Int) -> [UInt8] {
+        guard let entry = entry(for: id) else {
+            return []
+        }
+        return entry.bytes
+    }
+
+    /// Whether `id` must be excluded from ordinary sampling. Control tokens are excluded so structural
+    /// markers never surface as visible completion text. An out-of-range id is treated as excluded so
+    /// it can never be selected by accident.
+    func isExcluded(_ id: Int) -> Bool {
+        guard let entry = entry(for: id) else {
+            return true
+        }
+        return entry.isControl
+    }
+
+    /// Whether `id` is an end-of-generation token. False for an out-of-range id.
+    func isEndOfGeneration(_ id: Int) -> Bool {
+        entry(for: id)?.isEndOfGeneration ?? false
+    }
+
+    /// Whether `id` decodes to bytes containing a newline. False for an out-of-range id.
+    func isNewline(_ id: Int) -> Bool {
+        entry(for: id)?.isNewline ?? false
+    }
+
+    /// Whether `id` decodes to non-empty, whitespace-only bytes. False for an out-of-range id.
+    func isWhitespaceOnly(_ id: Int) -> Bool {
+        entry(for: id)?.isWhitespaceOnly ?? false
+    }
+
+    private func entry(for id: Int) -> Entry? {
+        guard id >= 0, id < entries.count else {
+            return nil
+        }
+        return entries[id]
+    }
+
+    private static let lineFeed: UInt8 = 0x0A
+
+    /// Non-empty and every byte is an ASCII whitespace character. Constrained to ASCII whitespace on
+    /// purpose: classifying arbitrary multi-byte Unicode whitespace would require decoding partial
+    /// scalars, which a single token's bytes may not form. Empty bytes are not whitespace-only.
+    private static func isWhitespaceOnly(_ bytes: [UInt8]) -> Bool {
+        guard !bytes.isEmpty else {
+            return false
+        }
+        return bytes.allSatisfy(Self.isASCIIWhitespace)
+    }
+
+    private static func isASCIIWhitespace(_ byte: UInt8) -> Bool {
+        switch byte {
+        case 0x20, 0x09, 0x0A, 0x0B, 0x0C, 0x0D:
+            // space, tab, line feed, vertical tab, form feed, carriage return
+            return true
+        default:
+            return false
+        }
+    }
+}
diff --git a/CotabbyTests/ConstrainedSamplerTests.swift b/CotabbyTests/ConstrainedSamplerTests.swift
new file mode 100644
index 0000000..08eb25c
--- /dev/null
+++ b/CotabbyTests/ConstrainedSamplerTests.swift
@@ -0,0 +1,242 @@
+import XCTest
+@testable import Cotabby
+
+/// Pure-function tests for deterministic constrained selection and the confidence helper. No RNG is
+/// involved, so every selection is an exact, repeatable argmax under the given constraints.
+final class ConstrainedSamplerTests: XCTestCase {
+
+    /// Profile of plain non-control, non-EOG tokens with single-letter bytes, one per logit slot.
+    private func plainProfile(count: Int, control: Set<Int> = []) -> TokenProfile {
+        TokenProfile.build(
+            vocabSize: count,
+            bytesFor: { [UInt8(65 + ($0 % 26))] },
+            isControl: { control.contains($0) },
+            isEndOfGeneration: { _ in false }
+        )
+    }
+
+    // MARK: - selectToken
+
+    func test_select_returnsHighestLogit() {
+        let logits: [Float] = [0.1, 2.5, 1.0, -3.0]
+        let id = ConstrainedSampler.selectToken(
+            logits: logits,
+            profile: plainProfile(count: 4),
+            admissibleTokenIDs: nil,
+            topK: 4
+        )
+        XCTAssertEqual(id, 1)
+    }
+
+    func test_select_skipsExcludedControlTokens() {
+        // Token 1 has the highest logit but is control, so it must be skipped in favor of token 2.
+        let logits: [Float] = [0.1, 5.0, 2.0, 1.0]
+        let id = ConstrainedSampler.selectToken(
+            logits: logits,
+            profile: plainProfile(count: 4, control: [1]),
+            admissibleTokenIDs: nil,
+            topK: 4
+        )
+        XCTAssertEqual(id, 2)
+    }
+
+    func test_select_honorsAdmissibleSet() {
+        // Highest logit is token 0, but only {2, 3} are admissible, so token 2 (higher of the two) wins.
+        let logits: [Float] = [9.0, 8.0, 3.0, 2.0]
+        let id = ConstrainedSampler.selectToken(
+            logits: logits,
+            profile: plainProfile(count: 4),
+            admissibleTokenIDs: [2, 3],
+            topK: 4
+        )
+        XCTAssertEqual(id, 2)
+    }
+
+    func test_select_emptyAdmissibleSet_returnsNil() {
+        // An explicit empty constraint admits nothing this step.
+        let logits: [Float] = [1.0, 2.0, 3.0]
+        let id = ConstrainedSampler.selectToken(
+            logits: logits,
+            profile: plainProfile(count: 3),
+            admissibleTokenIDs: [],
+            topK: 3
+        )
+        XCTAssertNil(id)
+    }
+
+    func test_select_allExcluded_returnsNil() {
+        let logits: [Float] = [1.0, 2.0, 3.0]
+        let id = ConstrainedSampler.selectToken(
+            logits: logits,
+            profile: plainProfile(count: 3, control: [0, 1, 2]),
+            admissibleTokenIDs: nil,
+            topK: 3
+        )
+        XCTAssertNil(id)
+    }
+
+    func test_select_admissibleIDOutOfRange_isIgnored() {
+        // An admissible id with no logit slot must not crash or be selected; the in-range admissible
+        // token wins instead.
+        let logits: [Float] = [1.0, 4.0]
+        let id = ConstrainedSampler.selectToken(
+            logits: logits,
+            profile: plainProfile(count: 2),
+            admissibleTokenIDs: [1, 99],
+            topK: 2
+        )
+        XCTAssertEqual(id, 1)
+    }
+
+    func test_select_tieBrokenByLowerID() {
+        // Equal logits must resolve to the lower id so the result is stable.
+        let logits: [Float] = [5.0, 5.0, 5.0]
+        let id = ConstrainedSampler.selectToken(
+            logits: logits,
+            profile: plainProfile(count: 3),
+            admissibleTokenIDs: nil,
+            topK: 3
+        )
+        XCTAssertEqual(id, 0)
+    }
+
+    func test_select_topKBoundsCandidatePool() {
+        // topK=2 keeps only the two highest-logit ids {0, 3}; the lower-logit ids 1 and 2 are never
+        // considered. Token 0 is excluded (control), so token 3 wins from within the bounded pool.
+        let logits: [Float] = [9.0, 1.0, 2.0, 8.0]
+        let id = ConstrainedSampler.selectToken(
+            logits: logits,
+            profile: plainProfile(count: 4, control: [0]),
+            admissibleTokenIDs: nil,
+            topK: 2
+        )
+        XCTAssertEqual(id, 3)
+    }
+
+    func test_select_topKTooSmallToReachAdmissible_returnsNil() {
+        // Admissible token 2 has a low logit; topK=1 keeps only the global max (token 0), which is not
+        // admissible, so nothing survives. Demonstrates that topK trades recall for cost.
+        let logits: [Float] = [9.0, 8.0, 0.5]
+        let id = ConstrainedSampler.selectToken(
+            logits: logits,
+            profile: plainProfile(count: 3),
+            admissibleTokenIDs: [2],
+            topK: 1
+        )
+        XCTAssertNil(id)
+    }
+
+    func test_select_topKZero_returnsNil() {
+        let id = ConstrainedSampler.selectToken(
+            logits: [1.0, 2.0],
+            profile: plainProfile(count: 2),
+            admissibleTokenIDs: nil,
+            topK: 0
+        )
+        XCTAssertNil(id)
+    }
+
+    func test_select_emptyLogits_returnsNil() {
+        let id = ConstrainedSampler.selectToken(
+            logits: [],
+            profile: plainProfile(count: 0),
+            admissibleTokenIDs: nil,
+            topK: 4
+        )
+        XCTAssertNil(id)
+    }
+
+    func test_select_isDeterministicAcrossRepeatedCalls() {
+        let logits: [Float] = [0.3, 0.31, 0.305, 0.31]
+        let profile = plainProfile(count: 4)
+        let first = ConstrainedSampler.selectToken(
+            logits: logits, profile: profile, admissibleTokenIDs: nil, topK: 4
+        )
+        for _ in 0..<20 {
+            let again = ConstrainedSampler.selectToken(
+                logits: logits, profile: profile, admissibleTokenIDs: nil, topK: 4
+            )
+            XCTAssertEqual(again, first)
+        }
+        // Tie between ids 1 and 3 resolves to the lower id.
+        XCTAssertEqual(first, 1)
+    }
+
+    // MARK: - averageLogProb
+
+    func test_averageLogProb_uniformRow_matchesNegativeLogVocab() {
+        // Every logit equal -> each token's probability is 1/N, so log-prob is -ln(N) every step.
+        let row: [Float] = [0, 0, 0, 0]
+        let value = ConstrainedSampler.averageLogProb(of: [0, 0], over: [row, row])
+        XCTAssertNotNil(value)
+        XCTAssertEqual(value ?? .nan, -log(4.0), accuracy: 1e-9)
+    }
+
+    func test_averageLogProb_isInvariantToConstantOffset() {
+        // Adding a constant to every logit in a row leaves softmax (and thus the log-prob) unchanged.
+        let base: [Float] = [1.0, 2.0, 0.5]
+        let shifted: [Float] = base.map { $0 + 100.0 }
+        let plain = ConstrainedSampler.averageLogProb(of: [2.0], over: [base])
+        let offset = ConstrainedSampler.averageLogProb(of: [102.0], over: [shifted])
+        XCTAssertNotNil(plain)
+        XCTAssertNotNil(offset)
+        XCTAssertEqual(plain ?? .nan, offset ?? .nan, accuracy: 1e-6)
+    }
+
+    func test_averageLogProb_averagesAcrossSteps() {
+        // Two steps with known per-step log-probs; the result is their mean.
+        let rowA: [Float] = [0, 0] // chosen logit 0 -> log(0.5)
+        let rowB: [Float] = [0, 0, 0, 0] // chosen logit 0 -> log(0.25)
+        let value = ConstrainedSampler.averageLogProb(of: [0, 0], over: [rowA, rowB])
+        let expected = (log(0.5) + log(0.25)) / 2.0
+        XCTAssertEqual(value ?? .nan, expected, accuracy: 1e-9)
+    }
+
+    func test_averageLogProb_emptyInput_returnsNil() {
+        XCTAssertNil(ConstrainedSampler.averageLogProb(of: [], over: []))
+    }
+
+    func test_averageLogProb_lengthMismatch_returnsNil() {
+        XCTAssertNil(ConstrainedSampler.averageLogProb(of: [1.0], over: [[1.0], [2.0]]))
+    }
+
+    func test_averageLogProb_emptyRow_returnsNil() {
+        XCTAssertNil(ConstrainedSampler.averageLogProb(of: [1.0], over: [[]]))
+    }
+
+    // MARK: - logProb (single-step)
+
+    func test_logProb_uniformRow_matchesNegativeLogVocab() {
+        // Uniform logits -> every token has probability 1/N -> log-prob is -ln(N).
+        let row: [Float] = [0, 0, 0, 0]
+        let value = ConstrainedSampler.logProb(ofTokenAt: 2, in: row)
+        XCTAssertEqual(value ?? .nan, -log(4.0), accuracy: 1e-9)
+    }
+
+    func test_logProb_matchesAverageLogProbForSingleStep() {
+        // The single-step helper must agree with averaging one row, since the decoder accumulates the
+        // former where the offline confidence helper uses the latter.
+        let row: [Float] = [1.0, 2.0, 0.5, -1.0]
+        let single = ConstrainedSampler.logProb(ofTokenAt: 1, in: row)
+        let averaged = ConstrainedSampler.averageLogProb(of: [row[1]], over: [row])
+        XCTAssertNotNil(single)
+        XCTAssertEqual(single ?? .nan, averaged ?? .nan, accuracy: 1e-9)
+    }
+
+    func test_logProb_isInvariantToConstantOffset() {
+        let base: [Float] = [1.0, 2.0, 0.5]
+        let shifted: [Float] = base.map { $0 + 50.0 }
+        let plain = ConstrainedSampler.logProb(ofTokenAt: 0, in: base)
+        let offset = ConstrainedSampler.logProb(ofTokenAt: 0, in: shifted)
+        XCTAssertEqual(plain ?? .nan, offset ?? .nan, accuracy: 1e-6)
+    }
+
+    func test_logProb_outOfRangeIndex_returnsNil() {
+        XCTAssertNil(ConstrainedSampler.logProb(ofTokenAt: 5, in: [1.0, 2.0]))
+        XCTAssertNil(ConstrainedSampler.logProb(ofTokenAt: -1, in: [1.0, 2.0]))
+    }
+
+    func test_logProb_emptyRow_returnsNil() {
+        XCTAssertNil(ConstrainedSampler.logProb(ofTokenAt: 0, in: []))
+    }
+}
diff --git a/CotabbyTests/TokenProfileTests.swift b/CotabbyTests/TokenProfileTests.swift
new file mode 100644
index 0000000..dc763ea
--- /dev/null
+++ b/CotabbyTests/TokenProfileTests.swift
@@ -0,0 +1,109 @@
+import XCTest
+@testable import Cotabby
+
+/// Pure-function tests for per-token constrained-decoding metadata. Engine data is supplied through
+/// stub closures, so every assertion is deterministic and no runtime is involved.
+final class TokenProfileTests: XCTestCase {
+
+    /// One token's stub data; a small struct rather than a tuple keeps the table readable and avoids
+    /// a large-tuple lint warning.
+    private struct Stub {
+        let bytes: [UInt8]
+        let control: Bool
+        let eog: Bool
+    }
+
+    /// Builds a profile from a literal table of stub entries indexed by token id.
+    private func makeProfile(_ table: [Stub]) -> TokenProfile {
+        TokenProfile.build(
+            vocabSize: table.count,
+            bytesFor: { table[$0].bytes },
+            isControl: { table[$0].control },
+            isEndOfGeneration: { table[$0].eog }
+        )
+    }
+
+    private func bytes(_ string: String) -> [UInt8] {
+        Array(string.utf8)
+    }
+
+    func test_build_recordsVocabSizeAndBytes() {
+        let profile = makeProfile([
+            Stub(bytes: bytes("the"), control: false, eog: false),
+            Stub(bytes: bytes(" dog"), control: false, eog: false)
+        ])
+        XCTAssertEqual(profile.vocabSize, 2)
+        XCTAssertEqual(profile.bytes(for: 0), bytes("the"))
+        XCTAssertEqual(profile.bytes(for: 1), bytes(" dog"))
+    }
+
+    func test_build_emptyVocab_producesEmptyProfile() {
+        let profile = TokenProfile.build(
+            vocabSize: 0,
+            bytesFor: { _ in [] },
+            isControl: { _ in false },
+            isEndOfGeneration: { _ in false }
+        )
+        XCTAssertEqual(profile.vocabSize, 0)
+    }
+
+    func test_controlToken_isExcluded() {
+        let profile = makeProfile([
+            Stub(bytes: bytes("hi"), control: false, eog: false),
+            Stub(bytes: bytes("<|end|>"), control: true, eog: false)
+        ])
+        XCTAssertFalse(profile.isExcluded(0))
+        XCTAssertTrue(profile.isExcluded(1))
+    }
+
+    func test_endOfGenerationFlag_isReported() {
+        let profile = makeProfile([
+            Stub(bytes: bytes("word"), control: false, eog: false),
+            Stub(bytes: bytes(""), control: true, eog: true)
+        ])
+        XCTAssertFalse(profile.isEndOfGeneration(0))
+        XCTAssertTrue(profile.isEndOfGeneration(1))
+    }
+
+    func test_whitespaceOnly_classification() {
+        let profile = makeProfile([
+            Stub(bytes: bytes(" "), control: false, eog: false),
+            Stub(bytes: bytes("\t \n"), control: false, eog: false),
+            Stub(bytes: bytes(" x"), control: false, eog: false),
+            Stub(bytes: [], control: false, eog: false)
+        ])
+        XCTAssertTrue(profile.isWhitespaceOnly(0))
+        XCTAssertTrue(profile.isWhitespaceOnly(1))
+        // A space followed by a letter is not whitespace-only.
+        XCTAssertFalse(profile.isWhitespaceOnly(2))
+        // Empty bytes are not whitespace-only.
+        XCTAssertFalse(profile.isWhitespaceOnly(3))
+    }
+
+    func test_newline_classification() {
+        let profile = makeProfile([
+            Stub(bytes: bytes("\n"), control: false, eog: false),
+            Stub(bytes: bytes("a\nb"), control: false, eog: false),
+            Stub(bytes: bytes("plain"), control: false, eog: false)
+        ])
+        XCTAssertTrue(profile.isNewline(0))
+        // Newline embedded among other bytes still counts.
+        XCTAssertTrue(profile.isNewline(1))
+        XCTAssertFalse(profile.isNewline(2))
+    }
+
+    func test_outOfRangeID_isDefensive() {
+        let profile = makeProfile([
+            Stub(bytes: bytes("only"), control: false, eog: false)
+        ])
+        // Out-of-range queries must never crash and must be treated as excluded / negative so a stray
+        // id can never be selected or misclassified.
+        XCTAssertEqual(profile.bytes(for: 5), [])
+        XCTAssertEqual(profile.bytes(for: -1), [])
+        XCTAssertTrue(profile.isExcluded(5))
+        XCTAssertTrue(profile.isExcluded(-1))
+        XCTAssertFalse(profile.isEndOfGeneration(5))
+        XCTAssertFalse(profile.isNewline(5))
+        XCTAssertFalse(profile.isWhitespaceOnly(5))
+    }
+}
diff --git a/project.yml b/project.yml
index b1727e8..273be9a 100644
--- a/project.yml
+++ b/project.yml
@@ -13,7 +13,7 @@ packages:
     exactVersion: 2.9.1
   CotabbyInference:
     url: https://github.com/FuJacob/cotabbyinference.git
-    branch: feat/generation-quality-controls
+    branch: main
   swift-log:
     url: https://github.com/apple/swift-log.git
     from: 1.12.1