From fe69fbb432e746be2941e01e3cc536b2478ec1b6 Mon Sep 17 00:00:00 2001 From: Jacob Fu <141651335+FuJacob@users.noreply.github.com> Date: Mon, 1 Jun 2026 08:07:35 -0700 Subject: [PATCH] Add deterministic constrained decoder behind a default-off flag Introduces a logit-level constrained decoding path for the open-source llama runtime as an alternative to the engine's built-in stochastic sampler. Per step it reads the raw next-token logits, masks structural and control tokens via a per-model token profile, deterministically selects the highest-logit admissible token, and commits it manually with acceptToken. The result is reproducible, leak-free continuations: no chat or control markers can surface as visible completion text, and the same prompt always yields the same suggestion. New pure, unit-tested helpers carry the decision logic: TokenProfile (per-token bytes plus control/EOG/whitespace/newline flags, built once per model from the vocab) and ConstrainedSampler (deterministic argmax with exclusion, optional admissibility, top-k pool bound, and a stable single-step log-prob for confidence). The runtime builds the profile lazily on first use and caches it per model. Routing is gated by the hidden cotabbyConstrainedDecoderEnabled UserDefaults flag (default off), so the shipping sampleNext path is byte-for-byte unchanged until the constrained decoder is validated on device. The generate() lifecycle, KV reuse, cancellation, and the manager's task handling are untouched; only the inner decode loop branches on the flag. Bumps the CotabbyInference pin to main to consume the logits-read, token-accept, and vocab-introspection primitives. --- Cotabby.xcodeproj/project.pbxproj | 18 +- Cotabby/Models/LlamaRuntimeModels.swift | 7 + .../Services/Runtime/LlamaRuntimeCore.swift | 213 ++++++++++++++- .../Runtime/LlamaSuggestionEngine.swift | 12 +- Cotabby/Support/ConstrainedSampler.swift | 129 ++++++++++ Cotabby/Support/TokenProfile.swift | 136 ++++++++++ CotabbyTests/ConstrainedSamplerTests.swift | 242 ++++++++++++++++++ CotabbyTests/TokenProfileTests.swift | 109 ++++++++ project.yml | 2 +- 9 files changed, 854 insertions(+), 14 deletions(-) create mode 100644 Cotabby/Support/ConstrainedSampler.swift create mode 100644 Cotabby/Support/TokenProfile.swift create mode 100644 CotabbyTests/ConstrainedSamplerTests.swift create mode 100644 CotabbyTests/TokenProfileTests.swift diff --git a/Cotabby.xcodeproj/project.pbxproj b/Cotabby.xcodeproj/project.pbxproj index 1004e61..8f398e3 100644 --- a/Cotabby.xcodeproj/project.pbxproj +++ b/Cotabby.xcodeproj/project.pbxproj @@ -39,6 +39,7 @@ 15FA56CEF6FB5FF54C2FBA6F /* PermissionAndContextModelTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E7F42112F14026E6253BB865 /* PermissionAndContextModelTests.swift */; }; 19CB55B62977376E9AE8D428 /* VisualContextStartCoalescer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2F01FAC4F57EB08471521196 /* VisualContextStartCoalescer.swift */; }; 1B3FFCB9A979F49BF86EAAD4 /* ScreenshotContextGeneratorTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = B2BFD19A159680A495EE02FD /* ScreenshotContextGeneratorTests.swift */; }; + 1C00CE3D553B58723EAE5F92 /* ConstrainedSamplerTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 6B0200AA5C4A644A3B52A3EC /* ConstrainedSamplerTests.swift */; }; 1D1C6FF0B8F50AC14A1000F4 /* SentenceBoundaryClassifierTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2D7360A6D4261989A66658ED /* SentenceBoundaryClassifierTests.swift */; }; 1F8CC88AFFE67C08944CF506 /* WindowScreenshotService.swift in Sources */ = {isa = PBXBuildFile; fileRef = 77B0121E7BB173F8A2B0B108 /* WindowScreenshotService.swift */; }; 2197B68F1E4D0C3497DAC061 /* LlamaSuggestionEngine.swift in Sources */ = {isa = PBXBuildFile; fileRef = BE04620C905041680116BE80 /* LlamaSuggestionEngine.swift */; }; @@ -139,6 +140,7 @@ 7EB20783E0D36715D1230A5C /* PromptSectionBudgetTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E260C4D08C786CDBD527B329 /* PromptSectionBudgetTests.swift */; }; 7FC103944F4EF39DB965F469 /* InMemoryLogging in Frameworks */ = {isa = PBXBuildFile; productRef = 88921938DC814625ED57D605 /* InMemoryLogging */; }; 814E348C663B697537594F0C /* EmojiRecentsTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 671689F289D45A124639C9C6 /* EmojiRecentsTests.swift */; }; + 81FC391E375B7EEFF965FF1B /* ConstrainedSampler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 47CF010A66DA31989524FCD0 /* ConstrainedSampler.swift */; }; 82D4ADEAF05337ABDE4C586C /* RuntimeBootstrapModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 60629DFE309C1A4BD8A7FB3B /* RuntimeBootstrapModel.swift */; }; 83EC3543DC45B1601F119BF9 /* InsertionSafetyGateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 43D627C4A55359EAF4676FF7 /* InsertionSafetyGateTests.swift */; }; 8441299082E6B68F7F88911B /* ShortcutConflictTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0850B07CCDBA67C756C6EC59 /* ShortcutConflictTests.swift */; }; @@ -148,6 +150,7 @@ 88BCD795A14E1C9308F7BB31 /* SuggestionAvailabilityEvaluatorTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = C05B0439348261163B37C508 /* SuggestionAvailabilityEvaluatorTests.swift */; }; 8B2DFC860803C0A7C4D34A36 /* ContextBuffer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 54EF3C7F5D9D6F3FA50FD51C /* ContextBuffer.swift */; }; 8DA36F1521B6A59D8C20AC59 /* Logging in Frameworks */ = {isa = PBXBuildFile; productRef = 5A60D1467BBFECB3DFEB39C2 /* Logging */; }; + 8EED2B55999A119AE3B67359 /* TokenProfile.swift in Sources */ = {isa = PBXBuildFile; fileRef = F3CEFE8C321E17BB3873C893 /* TokenProfile.swift */; }; 902B83CCB82E286FBEB9DAAD /* EmojiPickerPanelLayout.swift in Sources */ = {isa = PBXBuildFile; fileRef = 62EDF1199CC5E18BD7651661 /* EmojiPickerPanelLayout.swift */; }; 907A0BF56C3BB0CBAF2649AB /* SettingsCategory.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5D0AEFF86F8210CBE7CFCBAD /* SettingsCategory.swift */; }; 909EBE545CE644C6C57F1B5D /* SuggestionCoordinator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8F961F5DF2A392F6F5F94F8A /* SuggestionCoordinator.swift */; }; @@ -197,6 +200,7 @@ C71B594433F3B411CAE5DE7E /* FocusCapabilityResolverTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D4F6D5F94B238F7B4BE7C247 /* FocusCapabilityResolverTests.swift */; }; C9B815652CED38966C53A5E8 /* EmojiVariantResolverTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE8BB19D8EC9A75CD3458A6B /* EmojiVariantResolverTests.swift */; }; CA5B2D226FBAA5419E78F14F /* SuggestionSessionReconciler.swift in Sources */ = {isa = PBXBuildFile; fileRef = CE0AA0503128B0FC3951D700 /* SuggestionSessionReconciler.swift */; }; + CA8F453AA4AD02FAA8C961F7 /* TokenProfileTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E7D0BF193110927BEB865748 /* TokenProfileTests.swift */; }; CB65A79F164269991FABC32E /* SuggestionStateHelperTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19DB9558F4D3AFB108D71649 /* SuggestionStateHelperTests.swift */; }; CC98B842D10574C5206BEFA7 /* FocusCapabilityResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = 70367FCC1E0F08EE3B8EB26F /* FocusCapabilityResolver.swift */; }; CCC83DC5AE51C17F153D5A6A /* PermissionModels.swift in Sources */ = {isa = PBXBuildFile; fileRef = 9D82FFC568527700EC17C07D /* PermissionModels.swift */; }; @@ -327,6 +331,7 @@ 44595B534DD7323F0AD60825 /* MenuBarPopoverDismisser.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MenuBarPopoverDismisser.swift; sourceTree = ""; }; 4696A84D17890B154533A08F /* PromptPolicyTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptPolicyTests.swift; sourceTree = ""; }; 4793D4EA5D36D7E5CC216C27 /* LanguageSupportTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LanguageSupportTests.swift; sourceTree = ""; }; + 47CF010A66DA31989524FCD0 /* ConstrainedSampler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConstrainedSampler.swift; sourceTree = ""; }; 51020F8CD58338BD643FBF63 /* ModelDownloadManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelDownloadManager.swift; sourceTree = ""; }; 52BAFA2F989C3C4F7FB892B5 /* MarkerSelectionSynthesizerTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MarkerSelectionSynthesizerTests.swift; sourceTree = ""; }; 53CF416511099C6818110F01 /* CompletionRenderModePolicy.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CompletionRenderModePolicy.swift; sourceTree = ""; }; @@ -356,6 +361,7 @@ 671689F289D45A124639C9C6 /* EmojiRecentsTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = EmojiRecentsTests.swift; sourceTree = ""; }; 67586807ACE8EB13C9014535 /* TickMarkSlider.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TickMarkSlider.swift; sourceTree = ""; }; 6A44BEC8C23FF227731DD0CD /* FocusCapabilityFlickerGate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FocusCapabilityFlickerGate.swift; sourceTree = ""; }; + 6B0200AA5C4A644A3B52A3EC /* ConstrainedSamplerTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConstrainedSamplerTests.swift; sourceTree = ""; }; 6B2D97BAA3618A7D0357AC44 /* SuggestionWorkController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SuggestionWorkController.swift; sourceTree = ""; }; 6DC693E00430F46E41CB56E6 /* RequestID.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = RequestID.swift; sourceTree = ""; }; 70367FCC1E0F08EE3B8EB26F /* FocusCapabilityResolver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FocusCapabilityResolver.swift; sourceTree = ""; }; @@ -479,6 +485,7 @@ E43E587E421AF544A8300CE4 /* CustomRulesCatalog.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CustomRulesCatalog.swift; sourceTree = ""; }; E5DAF68AEBFE334F68A65E82 /* AcceptanceModePickerView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AcceptanceModePickerView.swift; sourceTree = ""; }; E6423D6CC8CC371D2DA899DE /* PermissionOverlayTracker.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PermissionOverlayTracker.swift; sourceTree = ""; }; + E7D0BF193110927BEB865748 /* TokenProfileTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenProfileTests.swift; sourceTree = ""; }; E7F42112F14026E6253BB865 /* PermissionAndContextModelTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PermissionAndContextModelTests.swift; sourceTree = ""; }; EAAE6B395FAB604DF059280A /* KeyCodeLabels.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = KeyCodeLabels.swift; sourceTree = ""; }; EB630F9814388203DD1CA2EC /* ShortcutsPaneView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ShortcutsPaneView.swift; sourceTree = ""; }; @@ -487,6 +494,7 @@ EE94342B888A5A2CCF66BC93 /* SuggestionRequestFactoryTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SuggestionRequestFactoryTests.swift; sourceTree = ""; }; EFD89799BB82AF7A92559AEB /* ClipboardContentDistillerTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ClipboardContentDistillerTests.swift; sourceTree = ""; }; F308F6E274CC645E27CB651F /* OverlayController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OverlayController.swift; sourceTree = ""; }; + F3CEFE8C321E17BB3873C893 /* TokenProfile.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenProfile.swift; sourceTree = ""; }; FA4B45B91D4DEAC979C3113E /* PromptContextSanitizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PromptContextSanitizer.swift; sourceTree = ""; }; FA878B447441BB4F3E327CC8 /* OnboardingTemplateRecommender.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OnboardingTemplateRecommender.swift; sourceTree = ""; }; FB317C82CE2CBC69056BA4B8 /* TagChip.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TagChip.swift; sourceTree = ""; }; @@ -725,6 +733,7 @@ 90B0D133AB77A2503FB08827 /* ClipboardRelevanceFilterTests.swift */, D504BEB224E0C176F5FCFF6E /* CompletionRenderModePolicyTests.swift */, 06FF2B0A3094A952A8EBA9B5 /* ConfidenceSuppressionPolicyTests.swift */, + 6B0200AA5C4A644A3B52A3EC /* ConstrainedSamplerTests.swift */, AF1E065C7FFB697FCEB2FA5C /* CotabbyTestFixtures.swift */, AD752451330486FE270018B0 /* CustomRulesTests.swift */, C1C5DE0F3FF63545000E2453 /* DisplayCoordinateConverterTests.swift */, @@ -780,6 +789,7 @@ C71031E8DB171047318B92FC /* SyntheticReplacePlannerTests.swift */, 43E37A7E835D3BDE6265843C /* TerminalAppDetectorTests.swift */, FC24FD54860CE6737E65EF65 /* TextDirectionDetectorTests.swift */, + E7D0BF193110927BEB865748 /* TokenProfileTests.swift */, E19A5B462891263BDFB56607 /* TrailingDuplicationFilterTests.swift */, 050D929E13BE52E6282B64D2 /* VisualContextStartCoalescerTests.swift */, 1E0513E3B23937B099A3CFF2 /* WordCountFormatterTests.swift */, @@ -884,6 +894,7 @@ D3A2AC525DC664DB540D4F19 /* ClipboardRelevanceFilter.swift */, 53CF416511099C6818110F01 /* CompletionRenderModePolicy.swift */, 1BD71ECC2AE4821B643E0935 /* ConfidenceSuppressionPolicy.swift */, + 47CF010A66DA31989524FCD0 /* ConstrainedSampler.swift */, C7B2D34A6F3AC9DFD61350F7 /* CotabbyDebugOptions.swift */, 29ED42C4BDD0C521101AF95E /* DeviceInfo.swift */, 74BD1D4DB27D5D96D1E06096 /* DisplayCoordinateConverter.swift */, @@ -928,6 +939,7 @@ B424E2AC97C99D335B0D5751 /* SuggestionTextNormalizer.swift */, 7F4C4A7EAF886E0CC945BFEF /* TerminalAppDetector.swift */, 328847A0F494360033366791 /* TextDirectionDetector.swift */, + F3CEFE8C321E17BB3873C893 /* TokenProfile.swift */, D408D647412C59F3E692C42B /* TrailingDuplicationFilter.swift */, 2F01FAC4F57EB08471521196 /* VisualContextStartCoalescer.swift */, 815F2ABAF6AB75DA3AFBBCEF /* WordCountFormatter.swift */, @@ -1075,6 +1087,7 @@ 7C94725B4837DEC9ECF1BC54 /* CompletionRenderMode.swift in Sources */, 3985F0F2B3178DBB945B1064 /* CompletionRenderModePolicy.swift in Sources */, 429CE592897D8A952F2916C3 /* ConfidenceSuppressionPolicy.swift in Sources */, + 81FC391E375B7EEFF965FF1B /* ConstrainedSampler.swift in Sources */, 8B2DFC860803C0A7C4D34A36 /* ContextBuffer.swift in Sources */, AA2E09FF7E430D66ECA8ECD5 /* CotabbyApp.swift in Sources */, FCC571EC239846F06007BFCA /* CotabbyAppEnvironment.swift in Sources */, @@ -1208,6 +1221,7 @@ AB9C9C001F97F9D14F8B192A /* TerminalAppDetector.swift in Sources */, 96782E57CA26A16409368B69 /* TextDirectionDetector.swift in Sources */, 6014B31E2570EFFE45557E33 /* TickMarkSlider.swift in Sources */, + 8EED2B55999A119AE3B67359 /* TokenProfile.swift in Sources */, D3B43622E5A41B11E7AF527E /* TrailingDuplicationFilter.swift in Sources */, E9E4CC657771DF9F4C56183C /* VisualContextCoordinator.swift in Sources */, 4190F8A76196B16ED94D0A55 /* VisualContextModels.swift in Sources */, @@ -1235,6 +1249,7 @@ BFCA7FAFDAEBF586AB615567 /* ClipboardRelevanceFilterTests.swift in Sources */, 25F91CEF38400FD1ADB6B1AF /* CompletionRenderModePolicyTests.swift in Sources */, 91D8189EFCD1BA992EA6F038 /* ConfidenceSuppressionPolicyTests.swift in Sources */, + 1C00CE3D553B58723EAE5F92 /* ConstrainedSamplerTests.swift in Sources */, 5E10EFC426217CB7218A5847 /* CotabbyTestFixtures.swift in Sources */, 91D1F16B8C5DA281D4B7F699 /* CustomRulesTests.swift in Sources */, 56611BA0087710277140E9E6 /* DisplayCoordinateConverterTests.swift in Sources */, @@ -1290,6 +1305,7 @@ EF5BAB96DDADABB86F9E02D9 /* SyntheticReplacePlannerTests.swift in Sources */, DE236C9285635C686D66A2F6 /* TerminalAppDetectorTests.swift in Sources */, 5A441797D71A880A7482077D /* TextDirectionDetectorTests.swift in Sources */, + CA8F453AA4AD02FAA8C961F7 /* TokenProfileTests.swift in Sources */, DB1310FF3576ACA6472C4DB1 /* TrailingDuplicationFilterTests.swift in Sources */, D5CAF3B590E5EC2AFC72E57A /* VisualContextStartCoalescerTests.swift in Sources */, 6AE0B46FB52D189D94E1F79A /* WordCountFormatterTests.swift in Sources */, @@ -1603,7 +1619,7 @@ isa = XCRemoteSwiftPackageReference; repositoryURL = "https://github.com/FuJacob/cotabbyinference.git"; requirement = { - branch = "feat/generation-quality-controls"; + branch = main; kind = branch; }; }; diff --git a/Cotabby/Models/LlamaRuntimeModels.swift b/Cotabby/Models/LlamaRuntimeModels.swift index d39f8ed..f725b2c 100644 --- a/Cotabby/Models/LlamaRuntimeModels.swift +++ b/Cotabby/Models/LlamaRuntimeModels.swift @@ -201,6 +201,13 @@ struct LlamaGenerationOptions: Equatable, Sendable { /// Average per-token log-probability below which a completion is suppressed as low-confidence. /// Defaults to -infinity, which disables suppression entirely. var confidenceFloor: Double = -.infinity + + /// Routes generation through the deterministic constrained decoder (logit read + admissibility + /// mask + argmax + manual token commit) instead of the engine's built-in stochastic sampler. + /// Default off so the shipping sampleNext path is unaffected until the constrained decoder is + /// validated on device. Changing it does not affect KV reuse, so it is intentionally excluded + /// from `SamplingFingerprint`. + var useConstrainedDecoder: Bool = false } /// The concrete runtime assets selected during bootstrap after checking available model files. diff --git a/Cotabby/Services/Runtime/LlamaRuntimeCore.swift b/Cotabby/Services/Runtime/LlamaRuntimeCore.swift index f230f16..a113a68 100644 --- a/Cotabby/Services/Runtime/LlamaRuntimeCore.swift +++ b/Cotabby/Services/Runtime/LlamaRuntimeCore.swift @@ -34,6 +34,13 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { private var autocompletePromptTokens: [Int32] = [] private var autocompleteSamplingFingerprint: SamplingFingerprint? + /// Per-model constrained-decoding token table, built lazily on the first constrained request and + /// reused across requests. Keyed by model URL so loading a different model rebuilds it. Read and + /// written only inside `runConstrainedDecode`, which runs under `autocompleteLock`, so it needs + /// no extra synchronization. Cleared on `shutdown()` to release the table when the model unloads. + private var cachedTokenProfile: TokenProfile? + private var cachedTokenProfileModelURL: URL? + /// Coordinates model lifecycle with in-flight operations. `generate()` and `summarize()` /// increment the active count on entry and decrement on exit. `shutdown()` sets the /// shutting-down flag and blocks until all active operations finish before unloading. @@ -188,6 +195,19 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { autocompleteSamplingFingerprint = fingerprint } + // The KV-trim defer above runs after whichever decoder returns. Both decoders share the + // prepared sequence and the same confidence-suppression contract; they differ only in how + // they pick each token (engine sampler vs. deterministic constrained selection). + return options.useConstrainedDecoder + ? try runConstrainedDecode(sequenceID: sequenceID, options: options) + : runEngineSampledDecode(sequenceID: sequenceID, options: options) + } + + // MARK: - Decoders + + /// The shipping decoder: delegates token selection to the engine's built-in sampler + /// (`sampleNext`), which applies temperature / top-k / top-p / min-p and commits each token. + private func runEngineSampledDecode(sequenceID: Int32, options: LlamaGenerationOptions) -> String { var generatedText = "" var tokensGenerated = 0 var sumLogprob = 0.0 @@ -230,24 +250,131 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { ] ) - // Confidence suppression: drop completions the model itself was unsure about. Disabled by - // default (confidenceFloor == -infinity); the KV-trim defer above still runs on early return. - if tokensGenerated > 0, - ConfidenceSuppressionPolicy.shouldSuppress( - averageLogprob: sumLogprob / Double(tokensGenerated), - floor: options.confidenceFloor - ) { + if Self.shouldSuppress(sumLogprob: sumLogprob, tokensGenerated: tokensGenerated, options: options) { + return "" + } + return generatedText + } + + /// The constrained decoder: reads the raw next-token logits, masks structural / excluded tokens + /// via the token profile, deterministically selects the highest-logit admissible token, and + /// commits it manually with `acceptToken`. This trades the sampler's randomness for reproducible, + /// leak-free continuations (no chat/control markers can surface as visible text). It honors the + /// same cancellation, single-line, and confidence-suppression contracts as the sampled path. + /// Mid-word word-continuation is already applied to the seed logits by `decodePrompt` (the engine + /// masks new-word-start tokens for the first step), so the first `getNextTokenLogits` row this + /// reads is already constrained when `forceWordContinuation` was set. + private func runConstrainedDecode(sequenceID: Int32, options: LlamaGenerationOptions) throws -> String { + let profile = try autocompleteTokenProfile() + let vocabSize = profile.vocabSize + guard vocabSize > 0 else { + throw LlamaRuntimeError.generationFailed("Vocabulary unavailable for constrained decoding.") + } + // `topK` bounds the candidate pool the selector ranks; clamp to a sane positive value so a + // zero/negative request still yields a full-vocab argmax rather than an empty pool. + let topK = options.topK > 0 ? options.topK : vocabSize + + var generatedBytes: [UInt8] = [] + var tokensGenerated = 0 + var sumLogprob = 0.0 + var stopReason = "budget_exhausted" + var logits = [Float](repeating: 0, count: vocabSize) + + for _ in 0 ..< options.maxPredictionTokens { + if Task.isCancelled { + stopReason = "cancelled" + break + } + + let written = logits.withUnsafeMutableBufferPointer { buffer in + Int(engine.getNextTokenLogits(sequenceID, buffer.baseAddress, Int32(buffer.count))) + } + guard written == vocabSize else { + stopReason = "no_logits" + break + } + + guard let tokenID = ConstrainedSampler.selectToken( + logits: logits, + profile: profile, + admissibleTokenIDs: nil, + topK: topK + ) else { + stopReason = "no_admissible_token" + break + } + + if profile.isEndOfGeneration(tokenID) { + stopReason = "eos" + break + } + // Single-line fields must never receive a line break; stop before emitting one so the + // partial completion so far is preserved (mirrors the sampler path's single_line mask). + if options.singleLine, profile.isNewline(tokenID) { + stopReason = "single_line" + break + } + + // Accumulate raw bytes and decode once at the end: a single token may carry only part of + // a multi-byte UTF-8 scalar, so per-token String decoding would corrupt CJK / emoji. + if let logProb = ConstrainedSampler.logProb(ofTokenAt: tokenID, in: logits) { + sumLogprob += logProb + } + generatedBytes.append(contentsOf: profile.bytes(for: tokenID)) + tokensGenerated += 1 + + if engine.acceptToken(sequenceID, Int32(tokenID)) != .ok { + stopReason = "accept_failed" + break + } + } + + // Lossy decode is deliberate: the accumulated bytes are valid UTF-8 except for a possible + // partial trailing scalar (the final token may carry only part of a multi-byte character). + // The failable `String(bytes:encoding:)` would discard the entire completion in that case; + // `String(decoding:)` keeps every complete scalar and renders only the fragment as U+FFFD. + // swiftlint:disable:next optional_data_string_conversion + let generatedText = String(decoding: generatedBytes, as: UTF8.self) + CotabbyLogger.runtime.debug( + "Decode end", + metadata: [ + "kind": .string("generate_constrained"), + "tokens_generated": .stringConvertible(tokensGenerated), + "chars_generated": .stringConvertible(generatedText.count), + "stop_reason": .string(stopReason) + ] + ) + + if Self.shouldSuppress(sumLogprob: sumLogprob, tokensGenerated: tokensGenerated, options: options) { + return "" + } + return generatedText + } + + /// Shared low-confidence gate for both decoders: drop completions the model itself was unsure + /// about. Disabled by default (confidenceFloor == -infinity). The KV-trim defer in `generate` + /// still runs because the caller returns "" rather than throwing. + private static func shouldSuppress( + sumLogprob: Double, + tokensGenerated: Int, + options: LlamaGenerationOptions + ) -> Bool { + guard tokensGenerated > 0 else { return false } + let averageLogprob = sumLogprob / Double(tokensGenerated) + let suppress = ConfidenceSuppressionPolicy.shouldSuppress( + averageLogprob: averageLogprob, + floor: options.confidenceFloor + ) + if suppress { CotabbyLogger.runtime.debug( "Suppressed low-confidence completion", metadata: [ "tokens_generated": .stringConvertible(tokensGenerated), - "avg_logprob": .stringConvertible(sumLogprob / Double(tokensGenerated)) + "avg_logprob": .stringConvertible(averageLogprob) ] ) - return "" } - - return generatedText + return suppress } // MARK: - Cache and lifecycle @@ -302,6 +429,8 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { resetPromptCache() engine.unloadModel() preparedRuntime = nil + cachedTokenProfile = nil + cachedTokenProfileModelURL = nil CotabbyLogger.runtime.info("Runtime shutdown complete") lifecycleCondition.lock() @@ -405,6 +534,68 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable { return Array(vec) } + /// Lazily builds and caches the constrained-decoding token profile for the loaded model. The + /// profile records each token's bytes and structural flags so the constrained decoder can mask + /// excluded tokens and detect stops without calling back into the engine per step. Building scans + /// the whole vocabulary once (one detokenize per token), so the result is cached and reused until + /// the model changes. Must be called while holding `autocompleteLock`. + private func autocompleteTokenProfile() throws -> TokenProfile { + let modelURL = preparedRuntime?.resolvedRuntime.modelFileURL + if let cachedTokenProfile, cachedTokenProfileModelURL == modelURL { + return cachedTokenProfile + } + + let vocabSize = Int(engine.getVocabSize()) + guard vocabSize > 0 else { + throw LlamaRuntimeError.generationFailed("Vocabulary unavailable for constrained decoding.") + } + + // Detokenize every token once up front; the build closures index this snapshot so each + // token's bytes are computed a single time and its control flag derives from the same bytes. + var tokenBytes: [[UInt8]] = [] + tokenBytes.reserveCapacity(vocabSize) + for id in 0 ..< vocabSize { + tokenBytes.append(detokenizeBytes(Int32(id))) + } + + let profile = TokenProfile.build( + vocabSize: vocabSize, + bytesFor: { tokenBytes[$0] }, + // A token that detokenizes to no visible bytes is a structural / special / control token + // (llama renders those empty when special rendering is off); never emit it as text. + isControl: { tokenBytes[$0].isEmpty }, + isEndOfGeneration: { self.engine.isEndOfGenerationToken(Int32($0)) } + ) + cachedTokenProfile = profile + cachedTokenProfileModelURL = modelURL + CotabbyLogger.runtime.debug( + "Built constrained-decode token profile", + metadata: ["vocab_size": .stringConvertible(vocabSize)] + ) + return profile + } + + /// The raw UTF-8 bytes a token detokenizes to, or empty for a structural token that renders to + /// nothing. `detokenize` returns the byte count, or a negative `-(required)` when the fixed + /// buffer is too small; the rare large-piece case retries once at the requested size. + private func detokenizeBytes(_ token: Int32) -> [UInt8] { + var buffer = [CChar](repeating: 0, count: 256) + let written = buffer.withUnsafeMutableBufferPointer { ptr in + Int(engine.detokenize(token, ptr.baseAddress, Int32(ptr.count))) + } + if written > 0 { + return buffer.prefix(written).map { UInt8(bitPattern: $0) } + } + if written < 0 { + var large = [CChar](repeating: 0, count: -written) + let writtenLarge = large.withUnsafeMutableBufferPointer { ptr in + Int(engine.detokenize(token, ptr.baseAddress, Int32(ptr.count))) + } + return writtenLarge > 0 ? large.prefix(writtenLarge).map { UInt8(bitPattern: $0) } : [] + } + return [] + } + private static func extractPiece(_ result: SampleResult) -> String { guard let piece = result.piece, result.piece_length > 0 else { return "" } let buffer = UnsafeBufferPointer( diff --git a/Cotabby/Services/Runtime/LlamaSuggestionEngine.swift b/Cotabby/Services/Runtime/LlamaSuggestionEngine.swift index f793531..21b0c0f 100644 --- a/Cotabby/Services/Runtime/LlamaSuggestionEngine.swift +++ b/Cotabby/Services/Runtime/LlamaSuggestionEngine.swift @@ -13,6 +13,15 @@ final class LlamaSuggestionEngine { private let runtimeManager: LlamaRuntimeManager private var promptCacheHintTracker = LlamaPromptCacheHintTracker() + /// UserDefaults key (no UI) that routes llama generation through the deterministic constrained + /// decoder instead of the engine's stochastic sampler. Default-off: decode quality can only be + /// judged with a real model in a real field, so this stays a hidden developer/dogfood toggle + /// until it is validated on device and promoted to the default. + private static let constrainedDecoderDefaultsKey = "cotabbyConstrainedDecoderEnabled" + private static var isConstrainedDecoderEnabled: Bool { + UserDefaults.standard.bool(forKey: constrainedDecoderDefaultsKey) + } + init(runtimeManager: LlamaRuntimeManager) { self.runtimeManager = runtimeManager } @@ -50,7 +59,8 @@ final class LlamaSuggestionEngine { forceWordContinuation: MidWordContinuationPolicy.shouldForceContinuation( precedingText: request.context.precedingText, trailingText: request.context.trailingText - ) + ), + useConstrainedDecoder: Self.isConstrainedDecoderEnabled ) ) try Task.checkCancellation() diff --git a/Cotabby/Support/ConstrainedSampler.swift b/Cotabby/Support/ConstrainedSampler.swift new file mode 100644 index 0000000..0fe5031 --- /dev/null +++ b/Cotabby/Support/ConstrainedSampler.swift @@ -0,0 +1,129 @@ +import Foundation + +/// File overview: +/// Pure, deterministic token selection over a single step's logits, plus a confidence helper that +/// averages per-step log-probabilities. Selection skips excluded (control) tokens, optionally +/// restricts to a set of admissible ids, and returns the surviving token with the highest logit. +/// +/// Why this file exists: +/// Constrained decoding needs a selection step that is fully reproducible: the same logits and the +/// same constraints must always yield the same token, so behavior is testable and a suggestion can +/// be explained after the fact. This sampler is therefore deterministic argmax with no RNG and no +/// temperature. `topK` only bounds how large the candidate pool is before the argmax — it is a cost +/// guard, not a source of randomness, so a smaller `topK` can only ever change the result by +/// excluding lower-logit tokens that would not have won anyway among the unexcluded candidates. The +/// admissibility set (when present) is the byte-prefix constraint computed elsewhere; passing nil +/// means "no prefix constraint". Keeping this logic pure keeps the engine integration thin: the +/// runtime supplies logits and the precomputed constraints, and this returns an id (or nil when +/// nothing survives). +enum ConstrainedSampler { + /// Selects the highest-logit token that survives the constraints, or nil when none survive. + /// + /// Survivors are tokens that are in-range, not `profile.isExcluded`, and — when + /// `admissibleTokenIDs` is non-nil — members of that set. `topK` bounds the candidate pool by + /// pre-ranking on logit before filtering: only the `topK` highest-logit token ids are considered. + /// Because selection is a plain argmax, bounding the pool cannot change which token wins unless + /// the winner sat outside the top `topK` by raw logit, so callers trade recall for cost by + /// lowering `topK`. A `topK` of zero or negative considers no candidates and returns nil. + /// + /// Determinism note: ties on logit are broken by the lower token id, so equal-logit inputs still + /// produce a single stable result. + static func selectToken( + logits: [Float], + profile: TokenProfile, + admissibleTokenIDs: Set?, + topK: Int + ) -> Int? { + guard topK > 0, !logits.isEmpty else { + return nil + } + if let admissible = admissibleTokenIDs, admissible.isEmpty { + // An explicit empty admissible set means the prefix constraint admits nothing this step. + return nil + } + + let candidates = candidatePool(count: logits.count, logits: logits, limit: topK) + + var best: Int? + var bestLogit: Float = -.infinity + for id in candidates { + if profile.isExcluded(id) { + continue + } + if let admissible = admissibleTokenIDs, !admissible.contains(id) { + continue + } + let logit = logits[id] + // Strict greater-than keeps the first-seen (lower-id, because the pool is id-ordered after + // the top-k cut) token on ties, which makes the result independent of iteration quirks. + if best == nil || logit > bestLogit { + best = id + bestLogit = logit + } + } + return best + } + + /// Average per-step log-probability of a sequence of chosen tokens, a confidence summary suitable + /// for the existing low-confidence suppression policy. + /// + /// `fullRows[i]` is the full logits vector at step `i` and `chosenLogits[i]` is the logit of the + /// token actually committed at step `i` (the caller already knows which id it picked, so it passes + /// the scalar rather than the id). For each step this computes the softmax log-probability of the + /// chosen token, `chosenLogit - logSumExp(row)`, and returns the mean across steps. Returns nil + /// when there are no steps or the two inputs disagree in length, since an average is undefined + /// then. Pure and deterministic: a numerically stable log-sum-exp (shifted by the row maximum) + /// makes the result independent of constant offsets in the logits. + static func averageLogProb(of chosenLogits: [Float], over fullRows: [[Float]]) -> Double? { + guard !chosenLogits.isEmpty, chosenLogits.count == fullRows.count else { + return nil + } + var total = 0.0 + for (chosen, row) in zip(chosenLogits, fullRows) { + guard !row.isEmpty else { + return nil + } + total += Double(chosen) - logSumExp(row) + } + return total / Double(chosenLogits.count) + } + + /// The softmax log-probability of the token at `index` in `row`: `row[index] - logSumExp(row)`. + /// This is the single-step form of `averageLogProb`, for decoders that score each chosen token as + /// they go instead of retaining every logits row (retaining full rows for a whole completion would + /// cost vocab-size floats per step). Returns nil for an empty row or an out-of-range index. + static func logProb(ofTokenAt index: Int, in row: [Float]) -> Double? { + guard !row.isEmpty, index >= 0, index < row.count else { + return nil + } + return Double(row[index]) - logSumExp(row) + } + + /// The token ids to consider this step, ordered by id. When `limit` is at least `count` every id + /// is returned (still id-ordered). Otherwise the ids are ranked by descending logit, the top + /// `limit` are kept, and that subset is re-sorted by id so downstream tie-breaking stays stable. + private static func candidatePool(count: Int, logits: [Float], limit: Int) -> [Int] { + guard limit < count else { + return Array(0.. logits[rhs] + } + // Stable id ordering for equal logits so the kept set is deterministic at the cut line. + return lhs < rhs + } + return ranked.prefix(limit).sorted() + } + + /// Numerically stable log(sum(exp(row))): subtract the max before exponentiating so large logits + /// do not overflow. The caller guarantees `row` is non-empty. + private static func logSumExp(_ row: [Float]) -> Double { + let maxLogit = Double(row.max() ?? 0) + var sumExp = 0.0 + for value in row { + sumExp += exp(Double(value) - maxLogit) + } + return maxLogit + log(sumExp) + } +} diff --git a/Cotabby/Support/TokenProfile.swift b/Cotabby/Support/TokenProfile.swift new file mode 100644 index 0000000..57815c1 --- /dev/null +++ b/Cotabby/Support/TokenProfile.swift @@ -0,0 +1,136 @@ +import Foundation + +/// File overview: +/// Per-token metadata used by constrained decoding: for every token id in a model's vocabulary it +/// records the token's raw UTF-8 bytes plus a few classification flags (control, end-of-generation, +/// whitespace-only, newline). Constrained sampling and prefix admissibility both read from this +/// table instead of querying a live engine. +/// +/// Why this file exists: +/// Constrained decoding needs the *byte* shape of every candidate token (to test whether it can +/// continue a required prefix) and a way to drop tokens that should never be sampled as visible text +/// (control / structural tokens). A live tokenizer cannot be called from pure decision code without +/// dragging the runtime in, and it would not be deterministically testable. Building this profile +/// once from injected vocabulary data keeps the decoding rules pure: tests supply stub closures for +/// the bytes and flags, and the same inputs always yield the same verdicts. The builder takes +/// closures rather than concrete engine objects precisely so no runtime dependency leaks in. +struct TokenProfile { + /// Classification flags for a single token, kept compact so the per-token table stays small even + /// for large vocabularies. + struct Entry { + /// The token's raw UTF-8 bytes. Byte-level (not String) because admissibility against a + /// partial-word prefix is a byte-prefix relationship: a token may encode only part of a + /// multi-byte scalar, and Strings cannot represent that fragment. + let bytes: [UInt8] + /// A structural / control token (for example a chat or special marker) that must never be + /// emitted as visible completion text. + let isControl: Bool + /// A token the engine treats as a stop / end-of-generation signal. + let isEndOfGeneration: Bool + /// The decoded bytes are non-empty and contain only whitespace. + let isWhitespaceOnly: Bool + /// The decoded bytes contain a line feed (`\n`). + let isNewline: Bool + } + + /// Indexed by token id; `entries[id]` is the metadata for token `id`. + let entries: [Entry] + + /// The number of tokens described by this profile. + var vocabSize: Int { entries.count } + + /// Builds a profile for `vocabSize` tokens by pulling each token's bytes and flags from the + /// supplied closures. The closures are the only source of engine data, which is what keeps the + /// type pure and testable: a runtime passes detokenize / control / EOG lookups, and a test passes + /// stubs. Whitespace-only and newline are derived from the bytes here so callers cannot supply an + /// inconsistent classification. + static func build( + vocabSize: Int, + bytesFor: (Int) -> [UInt8], + isControl: (Int) -> Bool, + isEndOfGeneration: (Int) -> Bool + ) -> TokenProfile { + guard vocabSize > 0 else { + return TokenProfile(entries: []) + } + var entries: [Entry] = [] + entries.reserveCapacity(vocabSize) + for id in 0.. [UInt8] { + guard let entry = entry(for: id) else { + return [] + } + return entry.bytes + } + + /// Whether `id` must be excluded from ordinary sampling. Control tokens are excluded so structural + /// markers never surface as visible completion text. An out-of-range id is treated as excluded so + /// it can never be selected by accident. + func isExcluded(_ id: Int) -> Bool { + guard let entry = entry(for: id) else { + return true + } + return entry.isControl + } + + /// Whether `id` is an end-of-generation token. False for an out-of-range id. + func isEndOfGeneration(_ id: Int) -> Bool { + entry(for: id)?.isEndOfGeneration ?? false + } + + /// Whether `id` decodes to bytes containing a newline. False for an out-of-range id. + func isNewline(_ id: Int) -> Bool { + entry(for: id)?.isNewline ?? false + } + + /// Whether `id` decodes to non-empty, whitespace-only bytes. False for an out-of-range id. + func isWhitespaceOnly(_ id: Int) -> Bool { + entry(for: id)?.isWhitespaceOnly ?? false + } + + private func entry(for id: Int) -> Entry? { + guard id >= 0, id < entries.count else { + return nil + } + return entries[id] + } + + private static let lineFeed: UInt8 = 0x0A + + /// Non-empty and every byte is an ASCII whitespace character. Constrained to ASCII whitespace on + /// purpose: classifying arbitrary multi-byte Unicode whitespace would require decoding partial + /// scalars, which a single token's bytes may not form. Empty bytes are not whitespace-only. + private static func isWhitespaceOnly(_ bytes: [UInt8]) -> Bool { + guard !bytes.isEmpty else { + return false + } + return bytes.allSatisfy(Self.isASCIIWhitespace) + } + + private static func isASCIIWhitespace(_ byte: UInt8) -> Bool { + switch byte { + case 0x20, 0x09, 0x0A, 0x0B, 0x0C, 0x0D: + // space, tab, line feed, vertical tab, form feed, carriage return + return true + default: + return false + } + } +} diff --git a/CotabbyTests/ConstrainedSamplerTests.swift b/CotabbyTests/ConstrainedSamplerTests.swift new file mode 100644 index 0000000..08eb25c --- /dev/null +++ b/CotabbyTests/ConstrainedSamplerTests.swift @@ -0,0 +1,242 @@ +import XCTest +@testable import Cotabby + +/// Pure-function tests for deterministic constrained selection and the confidence helper. No RNG is +/// involved, so every selection is an exact, repeatable argmax under the given constraints. +final class ConstrainedSamplerTests: XCTestCase { + + /// Profile of plain non-control, non-EOG tokens with single-letter bytes, one per logit slot. + private func plainProfile(count: Int, control: Set = []) -> TokenProfile { + TokenProfile.build( + vocabSize: count, + bytesFor: { [UInt8(65 + ($0 % 26))] }, + isControl: { control.contains($0) }, + isEndOfGeneration: { _ in false } + ) + } + + // MARK: - selectToken + + func test_select_returnsHighestLogit() { + let logits: [Float] = [0.1, 2.5, 1.0, -3.0] + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: 4), + admissibleTokenIDs: nil, + topK: 4 + ) + XCTAssertEqual(id, 1) + } + + func test_select_skipsExcludedControlTokens() { + // Token 1 has the highest logit but is control, so it must be skipped in favor of token 2. + let logits: [Float] = [0.1, 5.0, 2.0, 1.0] + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: 4, control: [1]), + admissibleTokenIDs: nil, + topK: 4 + ) + XCTAssertEqual(id, 2) + } + + func test_select_honorsAdmissibleSet() { + // Highest logit is token 0, but only {2, 3} are admissible, so token 2 (higher of the two) wins. + let logits: [Float] = [9.0, 8.0, 3.0, 2.0] + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: 4), + admissibleTokenIDs: [2, 3], + topK: 4 + ) + XCTAssertEqual(id, 2) + } + + func test_select_emptyAdmissibleSet_returnsNil() { + // An explicit empty constraint admits nothing this step. + let logits: [Float] = [1.0, 2.0, 3.0] + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: 3), + admissibleTokenIDs: [], + topK: 3 + ) + XCTAssertNil(id) + } + + func test_select_allExcluded_returnsNil() { + let logits: [Float] = [1.0, 2.0, 3.0] + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: 3, control: [0, 1, 2]), + admissibleTokenIDs: nil, + topK: 3 + ) + XCTAssertNil(id) + } + + func test_select_admissibleIDOutOfRange_isIgnored() { + // An admissible id with no logit slot must not crash or be selected; the in-range admissible + // token wins instead. + let logits: [Float] = [1.0, 4.0] + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: 2), + admissibleTokenIDs: [1, 99], + topK: 2 + ) + XCTAssertEqual(id, 1) + } + + func test_select_tieBrokenByLowerID() { + // Equal logits must resolve to the lower id so the result is stable. + let logits: [Float] = [5.0, 5.0, 5.0] + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: 3), + admissibleTokenIDs: nil, + topK: 3 + ) + XCTAssertEqual(id, 0) + } + + func test_select_topKBoundsCandidatePool() { + // topK=2 keeps only the two highest-logit ids {0, 3}; the lower-logit ids 1 and 2 are never + // considered. Token 0 is excluded (control), so token 3 wins from within the bounded pool. + let logits: [Float] = [9.0, 1.0, 2.0, 8.0] + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: 4, control: [0]), + admissibleTokenIDs: nil, + topK: 2 + ) + XCTAssertEqual(id, 3) + } + + func test_select_topKTooSmallToReachAdmissible_returnsNil() { + // Admissible token 2 has a low logit; topK=1 keeps only the global max (token 0), which is not + // admissible, so nothing survives. Demonstrates that topK trades recall for cost. + let logits: [Float] = [9.0, 8.0, 0.5] + let id = ConstrainedSampler.selectToken( + logits: logits, + profile: plainProfile(count: 3), + admissibleTokenIDs: [2], + topK: 1 + ) + XCTAssertNil(id) + } + + func test_select_topKZero_returnsNil() { + let id = ConstrainedSampler.selectToken( + logits: [1.0, 2.0], + profile: plainProfile(count: 2), + admissibleTokenIDs: nil, + topK: 0 + ) + XCTAssertNil(id) + } + + func test_select_emptyLogits_returnsNil() { + let id = ConstrainedSampler.selectToken( + logits: [], + profile: plainProfile(count: 0), + admissibleTokenIDs: nil, + topK: 4 + ) + XCTAssertNil(id) + } + + func test_select_isDeterministicAcrossRepeatedCalls() { + let logits: [Float] = [0.3, 0.31, 0.305, 0.31] + let profile = plainProfile(count: 4) + let first = ConstrainedSampler.selectToken( + logits: logits, profile: profile, admissibleTokenIDs: nil, topK: 4 + ) + for _ in 0..<20 { + let again = ConstrainedSampler.selectToken( + logits: logits, profile: profile, admissibleTokenIDs: nil, topK: 4 + ) + XCTAssertEqual(again, first) + } + // Tie between ids 1 and 3 resolves to the lower id. + XCTAssertEqual(first, 1) + } + + // MARK: - averageLogProb + + func test_averageLogProb_uniformRow_matchesNegativeLogVocab() { + // Every logit equal -> each token's probability is 1/N, so log-prob is -ln(N) every step. + let row: [Float] = [0, 0, 0, 0] + let value = ConstrainedSampler.averageLogProb(of: [0, 0], over: [row, row]) + XCTAssertNotNil(value) + XCTAssertEqual(value ?? .nan, -log(4.0), accuracy: 1e-9) + } + + func test_averageLogProb_isInvariantToConstantOffset() { + // Adding a constant to every logit in a row leaves softmax (and thus the log-prob) unchanged. + let base: [Float] = [1.0, 2.0, 0.5] + let shifted: [Float] = base.map { $0 + 100.0 } + let plain = ConstrainedSampler.averageLogProb(of: [2.0], over: [base]) + let offset = ConstrainedSampler.averageLogProb(of: [102.0], over: [shifted]) + XCTAssertNotNil(plain) + XCTAssertNotNil(offset) + XCTAssertEqual(plain ?? .nan, offset ?? .nan, accuracy: 1e-6) + } + + func test_averageLogProb_averagesAcrossSteps() { + // Two steps with known per-step log-probs; the result is their mean. + let rowA: [Float] = [0, 0] // chosen logit 0 -> log(0.5) + let rowB: [Float] = [0, 0, 0, 0] // chosen logit 0 -> log(0.25) + let value = ConstrainedSampler.averageLogProb(of: [0, 0], over: [rowA, rowB]) + let expected = (log(0.5) + log(0.25)) / 2.0 + XCTAssertEqual(value ?? .nan, expected, accuracy: 1e-9) + } + + func test_averageLogProb_emptyInput_returnsNil() { + XCTAssertNil(ConstrainedSampler.averageLogProb(of: [], over: [])) + } + + func test_averageLogProb_lengthMismatch_returnsNil() { + XCTAssertNil(ConstrainedSampler.averageLogProb(of: [1.0], over: [[1.0], [2.0]])) + } + + func test_averageLogProb_emptyRow_returnsNil() { + XCTAssertNil(ConstrainedSampler.averageLogProb(of: [1.0], over: [[]])) + } + + // MARK: - logProb (single-step) + + func test_logProb_uniformRow_matchesNegativeLogVocab() { + // Uniform logits -> every token has probability 1/N -> log-prob is -ln(N). + let row: [Float] = [0, 0, 0, 0] + let value = ConstrainedSampler.logProb(ofTokenAt: 2, in: row) + XCTAssertEqual(value ?? .nan, -log(4.0), accuracy: 1e-9) + } + + func test_logProb_matchesAverageLogProbForSingleStep() { + // The single-step helper must agree with averaging one row, since the decoder accumulates the + // former where the offline confidence helper uses the latter. + let row: [Float] = [1.0, 2.0, 0.5, -1.0] + let single = ConstrainedSampler.logProb(ofTokenAt: 1, in: row) + let averaged = ConstrainedSampler.averageLogProb(of: [row[1]], over: [row]) + XCTAssertNotNil(single) + XCTAssertEqual(single ?? .nan, averaged ?? .nan, accuracy: 1e-9) + } + + func test_logProb_isInvariantToConstantOffset() { + let base: [Float] = [1.0, 2.0, 0.5] + let shifted: [Float] = base.map { $0 + 50.0 } + let plain = ConstrainedSampler.logProb(ofTokenAt: 0, in: base) + let offset = ConstrainedSampler.logProb(ofTokenAt: 0, in: shifted) + XCTAssertEqual(plain ?? .nan, offset ?? .nan, accuracy: 1e-6) + } + + func test_logProb_outOfRangeIndex_returnsNil() { + XCTAssertNil(ConstrainedSampler.logProb(ofTokenAt: 5, in: [1.0, 2.0])) + XCTAssertNil(ConstrainedSampler.logProb(ofTokenAt: -1, in: [1.0, 2.0])) + } + + func test_logProb_emptyRow_returnsNil() { + XCTAssertNil(ConstrainedSampler.logProb(ofTokenAt: 0, in: [])) + } +} diff --git a/CotabbyTests/TokenProfileTests.swift b/CotabbyTests/TokenProfileTests.swift new file mode 100644 index 0000000..dc763ea --- /dev/null +++ b/CotabbyTests/TokenProfileTests.swift @@ -0,0 +1,109 @@ +import XCTest +@testable import Cotabby + +/// Pure-function tests for per-token constrained-decoding metadata. Engine data is supplied through +/// stub closures, so every assertion is deterministic and no runtime is involved. +final class TokenProfileTests: XCTestCase { + + /// One token's stub data; a small struct rather than a tuple keeps the table readable and avoids + /// a large-tuple lint warning. + private struct Stub { + let bytes: [UInt8] + let control: Bool + let eog: Bool + } + + /// Builds a profile from a literal table of stub entries indexed by token id. + private func makeProfile(_ table: [Stub]) -> TokenProfile { + TokenProfile.build( + vocabSize: table.count, + bytesFor: { table[$0].bytes }, + isControl: { table[$0].control }, + isEndOfGeneration: { table[$0].eog } + ) + } + + private func bytes(_ string: String) -> [UInt8] { + Array(string.utf8) + } + + func test_build_recordsVocabSizeAndBytes() { + let profile = makeProfile([ + Stub(bytes: bytes("the"), control: false, eog: false), + Stub(bytes: bytes(" dog"), control: false, eog: false) + ]) + XCTAssertEqual(profile.vocabSize, 2) + XCTAssertEqual(profile.bytes(for: 0), bytes("the")) + XCTAssertEqual(profile.bytes(for: 1), bytes(" dog")) + } + + func test_build_emptyVocab_producesEmptyProfile() { + let profile = TokenProfile.build( + vocabSize: 0, + bytesFor: { _ in [] }, + isControl: { _ in false }, + isEndOfGeneration: { _ in false } + ) + XCTAssertEqual(profile.vocabSize, 0) + } + + func test_controlToken_isExcluded() { + let profile = makeProfile([ + Stub(bytes: bytes("hi"), control: false, eog: false), + Stub(bytes: bytes("<|end|>"), control: true, eog: false) + ]) + XCTAssertFalse(profile.isExcluded(0)) + XCTAssertTrue(profile.isExcluded(1)) + } + + func test_endOfGenerationFlag_isReported() { + let profile = makeProfile([ + Stub(bytes: bytes("word"), control: false, eog: false), + Stub(bytes: bytes(""), control: true, eog: true) + ]) + XCTAssertFalse(profile.isEndOfGeneration(0)) + XCTAssertTrue(profile.isEndOfGeneration(1)) + } + + func test_whitespaceOnly_classification() { + let profile = makeProfile([ + Stub(bytes: bytes(" "), control: false, eog: false), + Stub(bytes: bytes("\t \n"), control: false, eog: false), + Stub(bytes: bytes(" x"), control: false, eog: false), + Stub(bytes: [], control: false, eog: false) + ]) + XCTAssertTrue(profile.isWhitespaceOnly(0)) + XCTAssertTrue(profile.isWhitespaceOnly(1)) + // A space followed by a letter is not whitespace-only. + XCTAssertFalse(profile.isWhitespaceOnly(2)) + // Empty bytes are not whitespace-only. + XCTAssertFalse(profile.isWhitespaceOnly(3)) + } + + func test_newline_classification() { + let profile = makeProfile([ + Stub(bytes: bytes("\n"), control: false, eog: false), + Stub(bytes: bytes("a\nb"), control: false, eog: false), + Stub(bytes: bytes("plain"), control: false, eog: false) + ]) + XCTAssertTrue(profile.isNewline(0)) + // Newline embedded among other bytes still counts. + XCTAssertTrue(profile.isNewline(1)) + XCTAssertFalse(profile.isNewline(2)) + } + + func test_outOfRangeID_isDefensive() { + let profile = makeProfile([ + Stub(bytes: bytes("only"), control: false, eog: false) + ]) + // Out-of-range queries must never crash and must be treated as excluded / negative so a stray + // id can never be selected or misclassified. + XCTAssertEqual(profile.bytes(for: 5), []) + XCTAssertEqual(profile.bytes(for: -1), []) + XCTAssertTrue(profile.isExcluded(5)) + XCTAssertTrue(profile.isExcluded(-1)) + XCTAssertFalse(profile.isEndOfGeneration(5)) + XCTAssertFalse(profile.isNewline(5)) + XCTAssertFalse(profile.isWhitespaceOnly(5)) + } +} diff --git a/project.yml b/project.yml index b1727e8..273be9a 100644 --- a/project.yml +++ b/project.yml @@ -13,7 +13,7 @@ packages: exactVersion: 2.9.1 CotabbyInference: url: https://github.com/FuJacob/cotabbyinference.git - branch: feat/generation-quality-controls + branch: main swift-log: url: https://github.com/apple/swift-log.git from: 1.12.1