Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions python/flink_agents/integrations/chat_models/ollama_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
import re
import uuid
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

from ollama import Client, Message
from pydantic import Field
Expand Down Expand Up @@ -67,6 +68,10 @@ class OllamaChatModel(BaseChatModel):
default="5m",
description="controls how long the model will stay loaded into memory following the request(default: 5m)",
)
extract_reasoning: bool = Field(
default=False,
description="If True, extracts content within <think></think> tags from the response and stores it in additional_kwargs.",
)

__client: Client = None
__tools: Sequence[Mapping[str, Any]] = []
Expand All @@ -80,6 +85,7 @@ def __init__(
request_timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT,
additional_kwargs: Optional[Dict[str, Any]] = None,
keep_alive: Optional[Union[float, str]] = None,
extract_reasoning: Optional[bool] = False,
**kwargs: Any,
) -> None:
"""Init method."""
Expand All @@ -93,6 +99,7 @@ def __init__(
request_timeout=request_timeout,
additional_kwargs=additional_kwargs,
keep_alive=keep_alive,
extract_reasoning=extract_reasoning,
**kwargs,
)
# bind tools
Expand Down Expand Up @@ -125,6 +132,34 @@ def model_kwargs(self) -> Dict[str, Any]:
**self.additional_kwargs,
}

@staticmethod
def __extract_think_tags(content: str) -> Tuple[str, Optional[str]]:
"""Extract content within <think></think> tags and clean the remaining content.

Args:
content: Original content text

Returns:
Tuple containing (cleaned_content, reasoning_content)
"""
think_pattern = r"<think>(.*?)</think>"
reasoning = None

# Find all <think> tag content
think_matches = re.findall(think_pattern, content, re.DOTALL)
if think_matches:
reasoning = "\n".join(think_matches)

# Remove <think> tags and their content from the original text
cleaned_content = re.sub(think_pattern, "", content, flags=re.DOTALL)

# Clean up any extra whitespace that might have been created
cleaned_content = re.sub(r'\n{3,}', '\n\n', cleaned_content)
cleaned_content = re.sub(r' {2,}', ' ', cleaned_content)
cleaned_content = cleaned_content.strip()

return cleaned_content, reasoning

def chat(self, messages: Sequence[ChatMessage]) -> ChatMessage:
"""Process a sequence of messages, and return a response."""
if self.prompt is not None:
Expand Down Expand Up @@ -156,10 +191,22 @@ def chat(self, messages: Sequence[ChatMessage]) -> ChatMessage:
},
}
tool_calls.append(tool_call)

content = response.message.content
extra_args = {}

# Process reasoning if extract_reasoning is enabled
if self.extract_reasoning and content:
cleaned_content, reasoning = self.__extract_think_tags(content)
content = cleaned_content
if reasoning:
extra_args["reasoning"] = reasoning

return ChatMessage(
role=MessageRole(response.message.role),
content=response.message.content,
content=content,
tool_calls=tool_calls,
extra_args=extra_args,
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import subprocess
import sys
from pathlib import Path
from unittest.mock import MagicMock

import pytest
from ollama import Client
Expand Down Expand Up @@ -101,3 +102,64 @@ def test_ollama_chat_with_tools() -> None: # noqa :D103
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert add(**tool_call["function"]["arguments"]) == 3


def test_extract_think_tags() -> None:
"""Test the static method that extracts content from <think></think> tags."""
# Test with a think tag at the beginning (most common case)
content = "<think>First, I need to understand the question.\nThen I need to formulate an answer.</think>The answer is 42."
cleaned, reasoning = OllamaChatModel._OllamaChatModel__extract_think_tags(content)
assert cleaned == "The answer is 42."
assert reasoning == "First, I need to understand the question.\nThen I need to formulate an answer."
# Test with a think tag only
content = "<think>This is just my thought process.</think>"
cleaned, reasoning = OllamaChatModel._OllamaChatModel__extract_think_tags(content)
assert cleaned == ""
assert reasoning == "This is just my thought process."

# Test with no think tags
content = "This is a regular response without any thinking tags."
cleaned, reasoning = OllamaChatModel._OllamaChatModel__extract_think_tags(content)
assert cleaned == content
assert reasoning is None


def test_ollama_chat_with_extract_reasoning() -> None:
"""Test that extract_reasoning functionality works correctly."""
# Create mock objects for client and response
mock_client = MagicMock()
mock_response = MagicMock()
# Use a more realistic reasoning pattern at the beginning
mock_response.message.content = "<think>To answer what the meaning of life is, I should consider philosophical perspectives. The question is often associated with the number 42 from Hitchhiker's Guide to the Galaxy.</think>The meaning of life is often considered to be 42, according to the Hitchhiker's Guide to the Galaxy."
mock_response.message.role = "assistant"
mock_response.message.tool_calls = None

# Configure mock client to return our mock response
mock_client.chat.return_value = mock_response
# Create model with mocked client
llm = OllamaChatModel(
name="ollama",
model=test_model,
extract_reasoning=True
)

# Replace the real client with our mock client
llm._OllamaChatModel__client = mock_client

# Call the chat method
response = llm.chat([
ChatMessage(
role=MessageRole.USER,
content="What's the meaning of life?",
)
])

# Verify our mock was called correctly
mock_client.chat.assert_called_once()

# Check that the response content has been cleaned
assert response.content == "The meaning of life is often considered to be 42, according to the Hitchhiker's Guide to the Galaxy."
# Check that the reasoning has been extracted and stored
assert "reasoning" in response.extra_args
assert "philosophical perspectives" in response.extra_args["reasoning"]
assert "Hitchhiker's Guide to the Galaxy" in response.extra_args["reasoning"]
Loading