diff --git a/src/api/providers/ollama.ts b/src/api/providers/ollama.ts
index 8c8e9ed8da..afb6117b54 100644
--- a/src/api/providers/ollama.ts
+++ b/src/api/providers/ollama.ts
@@ -6,6 +6,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
import { convertToR1Format } from "../transform/r1-format"
import { ApiStream } from "../transform/stream"
import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./openai"
+import { XmlMatcher } from "../../utils/xml-matcher"
const OLLAMA_DEFAULT_TEMPERATURE = 0
@@ -35,15 +36,26 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
temperature: this.options.modelTemperature ?? OLLAMA_DEFAULT_TEMPERATURE,
stream: true,
})
+ const matcher = new XmlMatcher(
+ "think",
+ (chunk) =>
+ ({
+ type: chunk.matched ? "reasoning" : "text",
+ text: chunk.data,
+ }) as const,
+ )
for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta
+
if (delta?.content) {
- yield {
- type: "text",
- text: delta.content,
+ for (const chunk of matcher.update(delta.content)) {
+ yield chunk
}
}
}
+ for (const chunk of matcher.final()) {
+ yield chunk
+ }
}
getModel(): { id: string; info: ModelInfo } {
diff --git a/src/utils/__tests__/xml-matcher.test.ts b/src/utils/__tests__/xml-matcher.test.ts
new file mode 100644
index 0000000000..033084ee47
--- /dev/null
+++ b/src/utils/__tests__/xml-matcher.test.ts
@@ -0,0 +1,135 @@
+import { XmlMatcher } from "../xml-matcher"
+
+// Behavioural tests for XmlMatcher tracking the "think" tag.
+describe("XmlMatcher", () => {
+	// A tag opening at offset 0 is matched; only its content is reported.
+	it("only match at position 0", () => {
+		const matcher = new XmlMatcher("think")
+		const chunks = [...matcher.update("<think>data</think>"), ...matcher.final()]
+		expect(chunks).toHaveLength(1)
+		expect(chunks).toEqual([
+			{
+				matched: true,
+				data: "data",
+			},
+		])
+	})
+
+	// Spaces directly around the tag name are tolerated in both tags.
+	it("tag with space", () => {
+		const matcher = new XmlMatcher("think")
+		const chunks = [...matcher.update("< think >data</ think >"), ...matcher.final()]
+		expect(chunks).toHaveLength(1)
+		expect(chunks).toEqual([
+			{
+				matched: true,
+				data: "data",
+			},
+		])
+	})
+
+	// Extra characters after the tag name make it plain text, closing tag included.
+	it("invalid tag", () => {
+		const matcher = new XmlMatcher("think")
+		const chunks = [...matcher.update("< think 1>data</ think >"), ...matcher.final()]
+		expect(chunks).toHaveLength(1)
+		expect(chunks).toEqual([
+			{
+				matched: false,
+				data: "< think 1>data</ think >",
+			},
+		])
+	})
+
+	// An empty tag name never matches the tracked tag.
+	it("anonymous tag", () => {
+		const matcher = new XmlMatcher("think")
+		const chunks = [...matcher.update("<>data</>"), ...matcher.final()]
+		expect(chunks).toHaveLength(1)
+		expect(chunks).toEqual([
+			{
+				matched: false,
+				data: "<>data</>",
+			},
+		])
+	})
+
+	// Tags may be split across update() calls; each call returns what it completed.
+	it("streaming push", () => {
+		const matcher = new XmlMatcher("think")
+		const chunks = [
+			...matcher.update("<thi"),
+			...matcher.update("nk"),
+			...matcher.update(">dat"),
+			...matcher.update("a</"),
+			...matcher.update("think>"),
+		]
+		expect(chunks).toHaveLength(2)
+		expect(chunks).toEqual([
+			{
+				matched: true,
+				data: "dat",
+			},
+			{
+				matched: true,
+				data: "a",
+			},
+		])
+	})
+
+	// Tags nested inside a match are reported verbatim as part of the matched data.
+	it("nested tag", () => {
+		const matcher = new XmlMatcher("think")
+		const chunks = [...matcher.update("<think>X<think>Y</think>Z</think>"), ...matcher.final()]
+		expect(chunks).toHaveLength(1)
+		expect(chunks).toEqual([
+			{
+				matched: true,
+				data: "X<think>Y</think>Z",
+			},
+		])
+	})
+
+	// A nested tag left open is flushed by final() as matched data.
+	it("nested invalid tag", () => {
+		const matcher = new XmlMatcher("think")
+		const chunks = [...matcher.update("<think>X<think>Y</think>Z<think>"), ...matcher.final()]
+		expect(chunks).toHaveLength(2)
+		expect(chunks).toEqual([
+			{
+				matched: true,
+				data: "X<think>Y</think>Z",
+			},
+			{
+				matched: true,
+				data: "<think>",
+			},
+		])
+	})
+
+	// A tag that does not start at offset 0 is treated as plain text.
+	it("Wrong matching position", () => {
+		const matcher = new XmlMatcher("think")
+		const chunks = [...matcher.update("1<think>data</think>"), ...matcher.final()]
+		expect(chunks).toHaveLength(1)
+		expect(chunks).toEqual([
+			{
+				matched: false,
+				data: "1<think>data</think>",
+			},
+		])
+	})
+
+	// Content of a tag that is never closed is still reported as matched.
+	it("Unclosed tag", () => {
+		const matcher = new XmlMatcher("think")
+		const chunks = [...matcher.update("<think>data"), ...matcher.final()]
+		expect(chunks).toHaveLength(1)
+		expect(chunks).toEqual([
+			{
+				matched: true,
+				data: "data",
+			},
+		])
+	})
+})
diff --git a/src/utils/xml-matcher.ts b/src/utils/xml-matcher.ts
new file mode 100644
index 0000000000..49ed93aa6b
--- /dev/null
+++ b/src/utils/xml-matcher.ts
@@ -0,0 +1,166 @@
+/**
+ * A run of streamed text, labelled with the matcher's `matched` flag at the
+ * time the run was collected.
+ */
+export interface XmlMatcherResult {
+	/** Value of the matcher's `matched` flag when this run was collected. */
+	matched: boolean
+	/** The collected characters, concatenated. */
+	data: string
+}
+
+/**
+ * Character-by-character scanner that splits streamed text into runs read
+ * while a `<tagName>` element is open (`matched: true`) and runs read outside
+ * of it (`matched: false`).
+ *
+ * Feed text with `update()` as it arrives, then call `final()` once to flush
+ * whatever is still buffered. Parser state is kept between calls, so a tag may
+ * be split across chunks.
+ *
+ * @typeParam Result - Element type returned by `update()` / `final()`: the
+ * return type of `transform`, or the raw `XmlMatcherResult` by default.
+ */
+export class XmlMatcher<Result = XmlMatcherResult> {
+	/** Number of `tagName` characters matched so far in the tag being parsed. */
+	index = 0
+	/** Collected runs that have not been returned to the caller yet. */
+	chunks: XmlMatcherResult[] = []
+	/** Characters consumed since the last flush, including a tag still being parsed. */
+	cached: string[] = []
+	/** Set by a completed opening tag; recomputed as `depth > 0` after each closing tag. */
+	matched: boolean = false
+	/** Parser state: plain text, inside an opening tag, or inside a closing tag. */
+	state: "TEXT" | "TAG_OPEN" | "TAG_CLOSE" = "TEXT"
+	/** Opening tags seen minus closing tags seen. */
+	depth = 0
+	/** Total number of characters consumed so far. */
+	pointer = 0
+
+	/**
+	 * @param tagName - Tag name to track, without angle brackets.
+	 * @param transform - Optional mapper applied to every run before it is returned.
+	 * @param position - While nothing is matched, a `<` only starts a tag when it
+	 * is one of the first `position + 1` characters of the stream.
+	 */
+	constructor(
+		readonly tagName: string,
+		readonly transform?: (chunks: XmlMatcherResult) => Result,
+		readonly position = 0,
+	) {}
+
+	/** Moves the buffered characters into `chunks`, extending the last run when its flag is unchanged. */
+	private collect(): void {
+		if (!this.cached.length) {
+			return
+		}
+		const last = this.chunks[this.chunks.length - 1]
+		const data = this.cached.join("")
+		const matched = this.matched
+		if (last?.matched === matched) {
+			last.data += data
+		} else {
+			this.chunks.push({ data, matched })
+		}
+		this.cached = []
+	}
+
+	/** Empties the queue of pending runs and returns them, mapped through `transform` when one was given. */
+	private pop(): Result[] {
+		const chunks = this.chunks
+		this.chunks = []
+		const transform = this.transform
+		if (!transform) {
+			// Without a transform the raw runs are returned as-is.
+			return chunks as Result[]
+		}
+		// Call through a wrapper so map's index/array arguments are not forwarded.
+		return chunks.map((chunk) => transform(chunk))
+	}
+
+	/**
+	 * Runs the state machine over `chunk`, appending finished runs to `chunks`.
+	 * A tag that is still incomplete at the end of `chunk` stays in `cached`.
+	 */
+	private _update(chunk: string): void {
+		for (let i = 0; i < chunk.length; i++) {
+			const char = chunk.charAt(i)
+			this.cached.push(char)
+			this.pointer++
+
+			if (this.state === "TEXT") {
+				if (char === "<" && (this.pointer <= this.position + 1 || this.matched)) {
+					// Nested tag text may still be buffered in front of this "<".
+					// Flush it now so that dropping the upcoming tag cannot discard it.
+					if (this.cached.length > 1) {
+						this.cached.pop()
+						this.collect()
+						this.cached.push(char)
+					}
+					this.state = "TAG_OPEN"
+					this.index = 0
+				} else {
+					this.collect()
+				}
+			} else if (this.state === "TAG_OPEN") {
+				if (char === ">" && this.index === this.tagName.length) {
+					this.state = "TEXT"
+					if (!this.matched) {
+						// First opening tag of a match: its text is not part of the output.
+						this.cached = []
+					}
+					this.depth++
+					this.matched = true
+				} else if (this.index === 0 && char === "/") {
+					this.state = "TAG_CLOSE"
+				} else if (char === " " && (this.index === 0 || this.index === this.tagName.length)) {
+					// Spaces are tolerated right before and right after the tag name.
+					continue
+				} else if (this.tagName[this.index] === char) {
+					this.index++
+				} else {
+					// Not the tracked tag: flush the buffered characters under the current flag.
+					this.state = "TEXT"
+					this.collect()
+				}
+			} else if (this.state === "TAG_CLOSE") {
+				// NOTE(review): a closing tag parsed while `depth` is 0 is consumed as well
+				// and leaves `depth` negative - confirm that this is intended.
+				if (char === ">" && this.index === this.tagName.length) {
+					this.state = "TEXT"
+					this.depth--
+					this.matched = this.depth > 0
+					if (!this.matched) {
+						// Closing tag that ends the match: its text is not part of the output.
+						this.cached = []
+					}
+				} else if (char === " " && (this.index === 0 || this.index === this.tagName.length)) {
+					continue
+				} else if (this.tagName[this.index] === char) {
+					this.index++
+				} else {
+					this.state = "TEXT"
+					this.collect()
+				}
+			}
+		}
+	}
+
+	/**
+	 * Consumes an optional last chunk, flushes everything still buffered
+	 * (including an unfinished tag) and returns the remaining runs.
+	 */
+	final(chunk?: string): Result[] {
+		if (chunk) {
+			this._update(chunk)
+		}
+		this.collect()
+		return this.pop()
+	}
+
+	/** Consumes `chunk` and returns the runs completed so far; unfinished input stays buffered. */
+	update(chunk: string): Result[] {
+		this._update(chunk)
+		return this.pop()
+	}
+}