Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cotabby/Services/Runtime/LlamaRuntimeCore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ nonisolated final class LlamaRuntimeCore: @unchecked Sendable {
topK: topK,
noRepeatNgramSize: Self.noRepeatNgramSize
),
isSingleLine: options.singleLine
isSingleLine: options.singleLine,
isMidWord: options.forceWordContinuation
)
let best = candidates.first
CotabbyLogger.runtime.debug(
Expand Down
15 changes: 12 additions & 3 deletions Cotabby/Support/ConstrainedBeamSearch.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ enum ConstrainedBeamSearch {
nextLogits: @escaping BeamLogitsProvider,
profile: TokenProfile,
configuration: BeamSearchConfiguration,
isSingleLine: Bool
isSingleLine: Bool,
isMidWord: Bool = false
) -> [BeamCandidate] {
Engine(
nextLogits: nextLogits,
profile: profile,
configuration: configuration,
isSingleLine: isSingleLine
isSingleLine: isSingleLine,
isMidWord: isMidWord
).run()
}
}
Expand All @@ -85,6 +87,7 @@ private struct Engine {
let profile: TokenProfile
let configuration: BeamSearchConfiguration
let isSingleLine: Bool
let isMidWord: Bool

func run() -> [BeamCandidate] {
var frontier: [BeamCandidate] = [BeamCandidate(tokenIDs: [], bytes: [], cumulativeLogprob: 0)]
Expand Down Expand Up @@ -121,13 +124,19 @@ private struct Engine {
history: branch.tokenIDs,
ngramSize: configuration.noRepeatNgramSize
)
let candidates = ConstrainedSampler.rankedAdmissibleTokens(
var candidates = ConstrainedSampler.rankedAdmissibleTokens(
logits: logits,
profile: profile,
admissibleTokenIDs: nil,
topK: configuration.topK,
blockedTokenIDs: blocked
)
// Mid-word: the first generated token must finish the current word, not start a new token with
// punctuation / whitespace / a symbol. Applies only to the first step; later tokens generate
// freely once the word is being continued.
if isMidWord, branch.tokenIDs.isEmpty {
candidates = candidates.filter { profile.continuesWordMidStream($0) }
}
for tokenID in candidates {
if profile.isEndOfGeneration(tokenID) {
completed.append(branch)
Expand Down
22 changes: 22 additions & 0 deletions Cotabby/Support/TokenProfile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,27 @@ struct TokenProfile {
entry(for: id)?.isWhitespaceOnly ?? false
}

/// Whether `id` can continue the current word mid-stream: the first character of the token's
/// (lossily UTF-8 decoded) bytes is a Unicode letter or number, or one of the two common
/// within-word marks (apostrophe or hyphen). Tokens that begin with whitespace, breaking
/// punctuation, or a symbol (ASCII or not) are rejected, so a mid-word completion finishes the
/// word instead of starting a new token.
/// False for an out-of-range id or a token whose byte sequence is empty.
Comment on lines +108 to +112
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The doc comment says "False for an out-of-range or empty (control) token", but isControl is never consulted — only bytes.isEmpty is checked. A non-control token whose bytes happen to be empty would also return false, and a control token with non-empty bytes starting with a letter would return true (though rankedAdmissibleTokens would already have excluded it). The parenthetical is misleading; tightening the wording removes the ambiguity.

Suggested change
/// Whether `id` can continue the current word mid-stream: its first byte is an ASCII letter or
/// digit, a common within-word mark (apostrophe or hyphen), or a non-ASCII lead byte (which starts
/// a multi-byte letter or ideograph). Tokens that begin with whitespace, breaking punctuation, or a
/// symbol are rejected, so a mid-word completion finishes the word instead of starting a new token.
/// False for an out-of-range or empty (control) token.
/// Whether `id` can continue the current word mid-stream: its first byte is an ASCII letter or
/// digit, a common within-word mark (apostrophe or hyphen), or a non-ASCII lead byte (which starts
/// a multi-byte letter or ideograph). Tokens that begin with whitespace, breaking punctuation, or a
/// symbol are rejected, so a mid-word completion finishes the word instead of starting a new token.
/// False for an out-of-range id or a token whose byte sequence is empty.

Fix in Codex Fix in Claude Code

func continuesWordMidStream(_ id: Int) -> Bool {
    // An unknown id, or a token that carries no bytes at all, can never extend a word.
    guard let tokenBytes = entry(for: id)?.bytes, !tokenBytes.isEmpty else { return false }

    // Classification is done on the leading character rather than the leading byte, so letters
    // from any script (CJK included) and digits count as word material, while whitespace,
    // punctuation, and symbols (an em dash or an arrow, for instance) do not. Decoding lossily is
    // acceptable here: only the first character matters, and an invalid lead sequence is repaired
    // to U+FFFD, which is neither a letter nor a number and therefore falls through to `false`.
    // swiftlint:disable:next optional_data_string_conversion
    let decoded = String(decoding: tokenBytes, as: UTF8.self)
    guard let lead = decoded.first else { return false }

    switch lead {
    case "'", "-":
        // The two common within-word marks (as in "don't" and "co-op").
        return true
    default:
        return lead.isLetter || lead.isNumber
    }
}

private func entry(for id: Int) -> Entry? {
guard id >= 0, id < entries.count else {
return nil
Expand Down Expand Up @@ -133,4 +154,5 @@ struct TokenProfile {
return false
}
}

}
18 changes: 18 additions & 0 deletions CotabbyTests/ConstrainedBeamSearchTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,24 @@ final class ConstrainedBeamSearchTests: XCTestCase {
XCTAssertFalse(recorder.paths.contains([0, 1]), "search must stop at the sentence and not step past it")
}

func test_search_midWord_firstTokenMustContinueTheWord() {
    // Two-token vocabulary: id 0 opens with punctuation (a word breaker) and carries the larger
    // logit, while id 1 is a plain suffix that extends the word being typed.
    let profile = makeProfile(byteStrings: [", and", "ing"])
    let rows: [[Int]: [Float]] = [[]: row([0: 9, 1: 1], vocabSize: 2)]

    // Runs a one-step, width-1 search with a fresh provider and configuration each time; the
    // mid-word flag is the only thing that differs between the two invocations below.
    func runSearch(isMidWord: Bool) -> [BeamCandidate] {
        ConstrainedBeamSearch.search(
            nextLogits: provider(vocabSize: 2, rows: rows),
            profile: profile,
            configuration: BeamSearchConfiguration(beamWidth: 1, maxTokens: 1, topK: 5),
            isSingleLine: false,
            isMidWord: isMidWord
        )
    }

    let unconstrained = runSearch(isMidWord: false)
    let constrained = runSearch(isMidWord: true)

    XCTAssertEqual(unconstrained.first?.tokenIDs, [0], "without mid-word, the highest-logit token wins")
    XCTAssertEqual(constrained.first?.tokenIDs, [1], "mid-word, the word-breaking token is filtered out")
}

func test_search_respectsMaxTokenBudget() {
// No EOG / sentence end: every token keeps generating, so the budget bounds the length.
let profile = makeProfile(byteStrings: ["a", "b"])
Comment on lines 150 to 173
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 No test for all-candidates-filtered path

The new test covers the normal filtered case, but there's no assertion for when isMidWord: true and every admissible token is word-breaking (e.g. the only tokens in the vocabulary are " x", ".y"). In that scenario candidates is empty after filtering, the branch is silently dropped from both live and completed, and search returns []. The caller in LlamaRuntimeCore handles this correctly with guard let best else { return "" }, but adding a single test case would pin that contract and prevent a future refactor from introducing a crash or incorrect fallback.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Fix in Codex Fix in Claude Code

Expand Down
22 changes: 22 additions & 0 deletions CotabbyTests/TokenProfileTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,26 @@ final class TokenProfileTests: XCTestCase {
XCTAssertFalse(profile.isNewline(5))
XCTAssertFalse(profile.isWhitespaceOnly(5))
}

func test_continuesWordMidStream_acceptsWordCharactersAndRejectsBreakers() {
    // Ids 0...4 begin with word material; ids 5...9 begin with something that breaks a word
    // (or, for id 9, have no bytes at all).
    let stubs = [
        Stub(bytes: bytes("rrow"), control: false, eog: false),   // 0: plain letters
        Stub(bytes: bytes("3rd"), control: false, eog: false),    // 1: starts with a digit
        Stub(bytes: bytes("'t"), control: false, eog: false),     // 2: apostrophe, as in "don't"
        Stub(bytes: bytes("-op"), control: false, eog: false),    // 3: hyphen, as in "co-op"
        Stub(bytes: bytes("中文"), control: false, eog: false),    // 4: CJK letters
        Stub(bytes: bytes(" word"), control: false, eog: false),  // 5: starts with a space
        Stub(bytes: bytes(".rrow"), control: false, eog: false),  // 6: starts with a period
        Stub(bytes: bytes("!stop"), control: false, eog: false),  // 7: starts with punctuation
        Stub(bytes: bytes("→x"), control: false, eog: false),     // 8: starts with a non-ASCII symbol
        Stub(bytes: [], control: true, eog: false)                // 9: empty control token
    ]
    let profile = makeProfile(stubs)

    for id in 0...4 {
        XCTAssertTrue(profile.continuesWordMidStream(id), "id \(id) should continue a word")
    }
    for id in 5...9 {
        XCTAssertFalse(profile.continuesWordMidStream(id), "id \(id) should not continue a word")
    }
    // An id below the valid range is rejected as well.
    XCTAssertFalse(profile.continuesWordMidStream(-1))
}
}