diff --git a/integrations/chat-models/ollama/pom.xml b/integrations/chat-models/ollama/pom.xml index 96850c44..04779623 100644 --- a/integrations/chat-models/ollama/pom.xml +++ b/integrations/chat-models/ollama/pom.xml @@ -46,7 +46,7 @@ under the License. io.github.ollama4j ollama4j - 1.1.0 + 1.1.2 diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java index 72b291c9..633514c7 100644 --- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java +++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java @@ -20,11 +20,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.exceptions.RoleNotFoundException; -import io.github.ollama4j.models.chat.OllamaChatMessage; -import io.github.ollama4j.models.chat.OllamaChatMessageRole; -import io.github.ollama4j.models.chat.OllamaChatResult; +import io.github.ollama4j.models.chat.*; +import io.github.ollama4j.models.request.OllamaChatEndpointCaller; import io.github.ollama4j.tools.Tools; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; @@ -33,14 +31,9 @@ import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; -import org.apache.flink.agents.api.tools.ToolParameters; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.function.BiFunction; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Collectors; /** @@ -66,8 +59,8 @@ * } */ public class OllamaChatModelConnection extends BaseChatModelConnection { - private final OllamaAPI client; - private final Pattern pattern; + + private final OllamaChatEndpointCaller caller; /** * Creates a new ollama chat model connection. @@ -83,13 +76,10 @@ public OllamaChatModelConnection( if (endpoint == null || endpoint.isEmpty()) { throw new IllegalArgumentException("endpoint should not be null or empty."); } - this.client = new OllamaAPI(endpoint); - Integer maxChatToolCallRetries = descriptor.getArgument("maxChatToolCallRetries"); - this.client.setMaxChatToolCallRetries( - maxChatToolCallRetries != null ? maxChatToolCallRetries : 10); Integer requestTimeout = descriptor.getArgument("requestTimeout"); - this.client.setRequestTimeoutSeconds(requestTimeout != null ? requestTimeout : 10); - this.pattern = Pattern.compile("(.*?)", Pattern.DOTALL); + this.caller = + new OllamaChatEndpointCaller( + endpoint, null, requestTimeout != null ? requestTimeout : 10); } /** @@ -108,18 +98,20 @@ public OllamaChatModelConnection( } /** - * Registers tools with the Ollama client based on tool resource names. + * Converts Flink Agent tools to Ollama compatible tool specifications. * *

Each tool's input schema is expected to be a JSON schema containing "properties" and * "required" keys. The schema is converted into the function/tool specification that Ollama - * understands, and a callable is wired to invoke the underlying BaseTool with ToolParameters. + * understands, and each tool is properly formatted for Ollama API integration. * - * @param tools tools to be registered to the client - * @throws RuntimeException if schema parsing or registration fails + * @param tools List of Flink Agent tools to be converted to Ollama tools + * @return List of Ollama compatible tool specifications + * @throws RuntimeException if schema parsing or conversion fails */ @SuppressWarnings("unchecked") - private void registerTools(List tools) { + private List convertToOllamaTools(List tools) { final ObjectMapper mapper = new ObjectMapper(); + final List ollamaTools = new ArrayList<>(); try { for (Tool tool : tools) { final Map schema = @@ -130,7 +122,7 @@ private void registerTools(List tools) { (Map>) schema.get("properties"); final List required = (List) schema.get("required"); - Map propertiesMap = new HashMap<>(); + Map propertiesMap = new HashMap<>(); for (Map.Entry> entry : properties.entrySet()) { final String paramName = entry.getKey(); @@ -140,40 +132,26 @@ private void registerTools(List tools) { propertiesMap.put( paramName, - Tools.PromptFuncDefinition.Property.builder() + Tools.Property.builder() .type(type) .description(description) .required(required.contains(paramName)) .build()); } - final Tools.ToolSpecification toolSpec = - Tools.ToolSpecification.builder() - .functionName(tool.getName()) - .functionDescription(tool.getDescription()) - .toolPrompt( - Tools.PromptFuncDefinition.builder() - .type("prompt") - .function( - Tools.PromptFuncDefinition.PromptFuncSpec - .builder() - .name(tool.getName()) - .description(tool.getDescription()) - .parameters( - Tools.PromptFuncDefinition - .Parameters - .builder() - .type("object") - .properties( - propertiesMap) - .build()) - .build()) + final Tools.Tool toolSpec = + Tools.Tool.builder() + .toolSpec( + Tools.ToolSpec.builder() + .name(tool.getName()) + .description(tool.getDescription()) + .parameters(Tools.Parameters.of(propertiesMap)) .build()) - .toolFunction(arguments -> tool.call(new ToolParameters(arguments))) .build(); - - this.client.registerTool(toolSpec); + ollamaTools.add(toolSpec); } + + return ollamaTools; } catch (Exception e) { throw new RuntimeException(e); } @@ -201,32 +179,78 @@ private OllamaChatMessage convertToOllamaChatMessages(ChatMessage message) { public ChatMessage chat( List messages, List tools, Map arguments) { try { - registerTools(tools); + final boolean extractReasoning = + (boolean) arguments.getOrDefault("extract_reasoning", false); + + final List ollamaTools = this.convertToOllamaTools(tools); final List ollamaChatMessages = messages.stream() .map(this::convertToOllamaChatMessages) .collect(Collectors.toList()); - final OllamaChatResult ollamaChatResult = - this.client.chat((String) arguments.get("model"), ollamaChatMessages); + final OllamaChatRequest chatRequest = + OllamaChatRequest.builder() + .withMessages(ollamaChatMessages) + .withModel((String) arguments.get("model")) + .withThinking(extractReasoning) + .withUseTools(false) + .build(); + + chatRequest.setTools(ollamaTools); + final OllamaChatResult ollamaChatResult = this.caller.callSync(chatRequest); + final OllamaChatResponseModel ollamaChatResponse = ollamaChatResult.getResponseModel(); + final OllamaChatMessage ollamaChatMessage = ollamaChatResponse.getMessage(); + + Map extraArgs = new HashMap<>(); + if (extractReasoning) { + extraArgs.put("reasoning", ollamaChatMessage.getThinking()); + } + + final List ollamaToolCalls = ollamaChatMessage.getToolCalls(); + final ChatMessage chatMessage = ChatMessage.assistant(ollamaChatMessage.getResponse()); + chatMessage.setExtraArgs(extraArgs); + + if (ollamaToolCalls != null) { + final List> toolCalls = convertToAgentsTools(ollamaToolCalls); + chatMessage.setToolCalls(toolCalls); + } - return extraReasoning(ollamaChatResult.getResponse()); + return chatMessage; } catch (Exception e) { throw new RuntimeException(e); } } - private ChatMessage extraReasoning(String response) { - Matcher matcher = pattern.matcher(response); - StringBuilder reasoning = new StringBuilder(); - while (matcher.find()) { - reasoning.append(matcher.group(1)); + /** + * Converts Ollama tool calls to the format expected by the Flink Agents framework. + * + *

This method transforms Ollama-specific tool call representations into a generic format + * that can be used by the Flink Agents framework. Each tool call is assigned a unique ID and + * structured with the appropriate function name and arguments. + * + * @param ollamaToolCalls the list of tool calls returned from Ollama API + * @return a list of tool calls formatted for Flink Agents, where each tool call is represented + * as a map containing id, type, and function details + */ + private List> convertToAgentsTools( + List ollamaToolCalls) { + final List> toolCalls = new ArrayList<>(ollamaToolCalls.size()); + for (OllamaChatToolCalls ollamaToolCall : ollamaToolCalls) { + final UUID id = UUID.randomUUID(); + final Map toolCall = + Map.of( + "id", + id, + "type", + "function", + "function", + Map.of( + "name", + ollamaToolCall.getFunction().getName(), + "arguments", + ollamaToolCall.getFunction().getArguments())); + toolCalls.add(toolCall); } - response = matcher.replaceAll("").strip(); - ChatMessage responseMessage = ChatMessage.assistant(response); - Map extraArgs = new HashMap<>(); - extraArgs.put("reasoning", reasoning.toString().strip()); - responseMessage.setExtraArgs(extraArgs); - return responseMessage; + return toolCalls; } } diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java index 80bdeb33..8b78f8f9 100644 --- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java +++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java @@ -56,11 +56,13 @@ public class OllamaChatModelSetup extends BaseChatModelSetup { private final String model; + private final boolean extractReasoning; public OllamaChatModelSetup( ResourceDescriptor descriptor, BiFunction getResource) { super(descriptor, getResource); this.model = descriptor.getArgument("model"); + this.extractReasoning = Boolean.parseBoolean(descriptor.getArgument("extract_reasoning")); } /** @@ -88,6 +90,7 @@ public OllamaChatModelSetup( public Map getParameters() { Map params = new HashMap<>(); params.put("model", model); + params.put("extract_reasoning", extractReasoning); return params; } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index f20fdc17..6fdbd922 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -31,10 +31,7 @@ import org.apache.flink.agents.api.tools.ToolResponse; import org.apache.flink.agents.plan.JavaFunction; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; +import java.util.*; /** Built-in action for processing chat request and tool call result. */ public class ChatModelAction { @@ -83,7 +80,8 @@ public static void chat( toolCallContext.put(initialRequestId, messages); } List messageContext = - (List) toolCallContext.get(initialRequestId); + new ArrayList<>((List) toolCallContext.get(initialRequestId)); + messageContext.add(response); stm.set(TOOL_CALL_CONTEXT, toolCallContext); @@ -159,7 +157,9 @@ public static void processChatRequestOrToolResponse(Event event, RunnerContext c Map toolCallContext = (Map) stm.get(TOOL_CALL_CONTEXT).getValue(); // update tool call context - List messages = (List) toolCallContext.get(initialRequestId); + List messages = + new ArrayList<>((List) toolCallContext.get(initialRequestId)); + for (Map.Entry entry : responses.entrySet()) { Map extraArgs = new HashMap<>(); String toolCallId = entry.getKey();