diff --git a/src/api/providers/ollama.ts b/src/api/providers/ollama.ts index 8c8e9ed8da..afb6117b54 100644 --- a/src/api/providers/ollama.ts +++ b/src/api/providers/ollama.ts @@ -6,6 +6,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format" import { convertToR1Format } from "../transform/r1-format" import { ApiStream } from "../transform/stream" import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./openai" +import { XmlMatcher } from "../../utils/xml-matcher" const OLLAMA_DEFAULT_TEMPERATURE = 0 @@ -35,15 +36,26 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler { temperature: this.options.modelTemperature ?? OLLAMA_DEFAULT_TEMPERATURE, stream: true, }) + const matcher = new XmlMatcher( + "think", + (chunk) => + ({ + type: chunk.matched ? "reasoning" : "text", + text: chunk.data, + }) as const, + ) for await (const chunk of stream) { const delta = chunk.choices[0]?.delta + if (delta?.content) { - yield { - type: "text", - text: delta.content, + for (const chunk of matcher.update(delta.content)) { + yield chunk } } } + for (const chunk of matcher.final()) { + yield chunk + } } getModel(): { id: string; info: ModelInfo } { diff --git a/src/utils/__tests__/xml-matcher.test.ts b/src/utils/__tests__/xml-matcher.test.ts new file mode 100644 index 0000000000..033084ee47 --- /dev/null +++ b/src/utils/__tests__/xml-matcher.test.ts @@ -0,0 +1,124 @@ +import { XmlMatcher } from "../xml-matcher" + +describe("XmlMatcher", () => { + it("only match at position 0", () => { + const matcher = new XmlMatcher("think") + const chunks = [...matcher.update("data"), ...matcher.final()] + expect(chunks).toHaveLength(1) + expect(chunks).toEqual([ + { + matched: true, + data: "data", + }, + ]) + }) + it("tag with space", () => { + const matcher = new XmlMatcher("think") + const chunks = [...matcher.update("< think >data"), ...matcher.final()] + expect(chunks).toHaveLength(1) + expect(chunks).toEqual([ + { + matched: true, + data: "data", + }, + ]) + }) 
+ + it("invalid tag", () => { + const matcher = new XmlMatcher("think") + const chunks = [...matcher.update("< think 1>data"), ...matcher.final()] + expect(chunks).toHaveLength(1) + expect(chunks).toEqual([ + { + matched: false, + data: "< think 1>data", + }, + ]) + }) + + it("anonymous tag", () => { + const matcher = new XmlMatcher("think") + const chunks = [...matcher.update("<>data"), ...matcher.final()] + expect(chunks).toHaveLength(1) + expect(chunks).toEqual([ + { + matched: false, + data: "<>data", + }, + ]) + }) + + it("streaming push", () => { + const matcher = new XmlMatcher("think") + const chunks = [ + ...matcher.update("dat"), + ...matcher.update("a"), + ] + expect(chunks).toHaveLength(2) + expect(chunks).toEqual([ + { + matched: true, + data: "dat", + }, + { + matched: true, + data: "a", + }, + ]) + }) + + it("nested tag", () => { + const matcher = new XmlMatcher("think") + const chunks = [...matcher.update("XYZ"), ...matcher.final()] + expect(chunks).toHaveLength(1) + expect(chunks).toEqual([ + { + matched: true, + data: "XYZ", + }, + ]) + }) + + it("nested invalid tag", () => { + const matcher = new XmlMatcher("think") + const chunks = [...matcher.update("XYZ"), ...matcher.final()] + expect(chunks).toHaveLength(2) + expect(chunks).toEqual([ + { + matched: true, + data: "XYZ", + }, + { + matched: true, + data: "", + }, + ]) + }) + + it("Wrong matching position", () => { + const matcher = new XmlMatcher("think") + const chunks = [...matcher.update("1data"), ...matcher.final()] + expect(chunks).toHaveLength(1) + expect(chunks).toEqual([ + { + matched: false, + data: "1data", + }, + ]) + }) + + it("Unclosed tag", () => { + const matcher = new XmlMatcher("think") + const chunks = [...matcher.update("data"), ...matcher.final()] + expect(chunks).toHaveLength(1) + expect(chunks).toEqual([ + { + matched: true, + data: "data", + }, + ]) + }) +}) diff --git a/src/utils/xml-matcher.ts b/src/utils/xml-matcher.ts new file mode 100644 index 0000000000..49ed93aa6b 
// File: src/utils/xml-matcher.ts

/**
 * One contiguous run of streamed text, flagged by whether it was found
 * inside the tracked tag (`matched: true`) or outside it (`matched: false`).
 */
export interface XmlMatcherResult {
	matched: boolean
	data: string
}

/**
 * Incremental, character-by-character splitter that separates the text inside
 * `<tagName> ... </tagName>` from the text around it while the input is still
 * arriving in pieces.
 *
 * - The outermost opening tag is only recognised when its `<` sits at a
 *   0-based offset of at most `position` in the overall stream. Once inside a
 *   matched region, further opening/closing tags are tracked through `depth`.
 * - The outermost pair of tags is removed from the output; nested pairs are
 *   kept verbatim as part of the matched text.
 * - Spaces are tolerated directly after `<` / `</` and directly before `>`.
 *
 * NOTE(review): a closing tag at an allowed start offset while nothing is open
 * is dropped from the output and leaves `depth` at -1.
 *
 * @typeParam Result - element type returned by `update` / `final`; it is
 * whatever `transform` produces, or `XmlMatcherResult` when no transform is given.
 */
export class XmlMatcher<Result = XmlMatcherResult> {
	/** Number of `tagName` characters matched so far inside the current tag. */
	index = 0
	/** Finished runs waiting to be handed out by the next `update` / `final`. */
	chunks: XmlMatcherResult[] = []
	/** Characters read but not yet committed to `chunks` (e.g. a partial tag). */
	cached: string[] = []
	/** True while the reader is inside at least one open tag. */
	matched: boolean = false
	state: "TEXT" | "TAG_OPEN" | "TAG_CLOSE" = "TEXT"
	/** Current nesting level of open tags. */
	depth = 0
	/** Total number of UTF-16 code units consumed so far. */
	pointer = 0

	/**
	 * @param tagName - tag to look for, without angle brackets.
	 * @param transform - optional mapper applied to every run before it is returned.
	 * @param position - largest 0-based offset at which the first `<` may appear.
	 */
	constructor(
		readonly tagName: string,
		readonly transform?: (chunks: XmlMatcherResult) => Result,
		readonly position = 0,
	) {}

	/** Commits `cached` to `chunks`, merging into the last run when the flag is the same. */
	private collect(): void {
		if (this.cached.length === 0) {
			return
		}
		const data = this.cached.join("")
		const last = this.chunks[this.chunks.length - 1]
		if (last !== undefined && last.matched === this.matched) {
			last.data += data
		} else {
			this.chunks.push({ data, matched: this.matched })
		}
		this.cached = []
	}

	/** Hands out and clears the finished runs, applying `transform` when present. */
	private pop(): Result[] {
		const ready = this.chunks
		this.chunks = []
		const transform = this.transform
		if (!transform) {
			// Without a transform, Result is the default XmlMatcherResult.
			return ready as Result[]
		}
		// Call with exactly one argument so map's index/array are not leaked.
		return ready.map((piece) => transform(piece))
	}

	/**
	 * Shared tail of the TAG_OPEN / TAG_CLOSE handling: skip padding spaces,
	 * advance through the tag name, or give up and treat the cached characters
	 * as ordinary text.
	 */
	private stepTagName(char: string): void {
		const atNameEdge = this.index === 0 || this.index === this.tagName.length
		if (char === " " && atNameEdge) {
			return
		}
		if (this.tagName.charAt(this.index) === char) {
			this.index++
			return
		}
		this.state = "TEXT"
		this.collect()
	}

	/** Feeds every character of `chunk` through the state machine. */
	private _update(chunk: string): void {
		for (let i = 0; i < chunk.length; i++) {
			const char = chunk.charAt(i)
			this.cached.push(char)
			this.pointer++
			const nameComplete = this.index === this.tagName.length

			switch (this.state) {
				case "TEXT":
					// pointer was already advanced, hence the "+ 1".
					if (char === "<" && (this.pointer <= this.position + 1 || this.matched)) {
						this.state = "TAG_OPEN"
						this.index = 0
					} else {
						this.collect()
					}
					break

				case "TAG_OPEN":
					if (char === ">" && nameComplete) {
						this.state = "TEXT"
						if (!this.matched) {
							// Outermost opening tag: drop it from the output.
							this.cached = []
						}
						this.depth++
						this.matched = true
					} else if (this.index === 0 && char === "/") {
						this.state = "TAG_CLOSE"
					} else {
						this.stepTagName(char)
					}
					break

				case "TAG_CLOSE":
					if (char === ">" && nameComplete) {
						this.state = "TEXT"
						this.depth--
						this.matched = this.depth > 0
						if (!this.matched) {
							// Outermost closing tag: drop it from the output.
							this.cached = []
						}
					} else {
						this.stepTagName(char)
					}
					break
			}
		}
	}

	/**
	 * Consumes an optional last piece of input, flushes whatever is still
	 * cached (including an unfinished tag) and returns the remaining runs.
	 */
	final(chunk?: string): Result[] {
		if (chunk) {
			this._update(chunk)
		}
		this.collect()
		return this.pop()
	}

	/**
	 * Consumes the next piece of input and returns the runs completed so far.
	 * Characters of a tag that is still being read stay cached until a later call.
	 */
	update(chunk: string): Result[] {
		this._update(chunk)
		return this.pop()
	}
}