Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/chat-models/ollama/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ under the License.
<dependency>
<groupId>io.github.ollama4j</groupId>
<artifactId>ollama4j</artifactId>
<version>1.1.0</version>
<version>1.1.2</version>
</dependency>
</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.models.chat.*;
import io.github.ollama4j.models.request.OllamaChatEndpointCaller;
import io.github.ollama4j.tools.Tools;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
Expand All @@ -33,14 +31,9 @@
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.Tool;
import org.apache.flink.agents.api.tools.ToolParameters;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
Expand All @@ -66,8 +59,8 @@
* }</pre>
*/
public class OllamaChatModelConnection extends BaseChatModelConnection {
private final OllamaAPI client;
private final Pattern pattern;

private final OllamaChatEndpointCaller caller;

/**
* Creates a new ollama chat model connection.
Expand All @@ -83,13 +76,10 @@ public OllamaChatModelConnection(
if (endpoint == null || endpoint.isEmpty()) {
throw new IllegalArgumentException("endpoint should not be null or empty.");
}
this.client = new OllamaAPI(endpoint);
Integer maxChatToolCallRetries = descriptor.getArgument("maxChatToolCallRetries");
this.client.setMaxChatToolCallRetries(
maxChatToolCallRetries != null ? maxChatToolCallRetries : 10);
Integer requestTimeout = descriptor.getArgument("requestTimeout");
this.client.setRequestTimeoutSeconds(requestTimeout != null ? requestTimeout : 10);
this.pattern = Pattern.compile("<think>(.*?)</think>", Pattern.DOTALL);
this.caller =
new OllamaChatEndpointCaller(
endpoint, null, requestTimeout != null ? requestTimeout : 10);
}

/**
Expand All @@ -108,18 +98,20 @@ public OllamaChatModelConnection(
}

/**
* Registers tools with the Ollama client based on tool resource names.
* Converts Flink Agent tools to Ollama compatible tool specifications.
*
* <p>Each tool's input schema is expected to be a JSON schema containing "properties" and
* "required" keys. The schema is converted into the function/tool specification that Ollama
* understands, and a callable is wired to invoke the underlying BaseTool with ToolParameters.
* understands, and each tool is properly formatted for Ollama API integration.
*
* @param tools tools to be registered to the client
* @throws RuntimeException if schema parsing or registration fails
* @param tools List of Flink Agent tools to be converted to Ollama tools
* @return List of Ollama compatible tool specifications
* @throws RuntimeException if schema parsing or conversion fails
*/
@SuppressWarnings("unchecked")
private void registerTools(List<Tool> tools) {
private List<Tools.Tool> convertToOllamaTools(List<Tool> tools) {
final ObjectMapper mapper = new ObjectMapper();
final List<Tools.Tool> ollamaTools = new ArrayList<>();
try {
for (Tool tool : tools) {
final Map<String, Object> schema =
Expand All @@ -130,7 +122,7 @@ private void registerTools(List<Tool> tools) {
(Map<String, Map<String, String>>) schema.get("properties");
final List<String> required = (List<String>) schema.get("required");

Map<String, Tools.PromptFuncDefinition.Property> propertiesMap = new HashMap<>();
Map<String, Tools.Property> propertiesMap = new HashMap<>();

for (Map.Entry<String, Map<String, String>> entry : properties.entrySet()) {
final String paramName = entry.getKey();
Expand All @@ -140,40 +132,26 @@ private void registerTools(List<Tool> tools) {

propertiesMap.put(
paramName,
Tools.PromptFuncDefinition.Property.builder()
Tools.Property.builder()
.type(type)
.description(description)
.required(required.contains(paramName))
.build());
}

final Tools.ToolSpecification toolSpec =
Tools.ToolSpecification.builder()
.functionName(tool.getName())
.functionDescription(tool.getDescription())
.toolPrompt(
Tools.PromptFuncDefinition.builder()
.type("prompt")
.function(
Tools.PromptFuncDefinition.PromptFuncSpec
.builder()
.name(tool.getName())
.description(tool.getDescription())
.parameters(
Tools.PromptFuncDefinition
.Parameters
.builder()
.type("object")
.properties(
propertiesMap)
.build())
.build())
final Tools.Tool toolSpec =
Tools.Tool.builder()
.toolSpec(
Tools.ToolSpec.builder()
.name(tool.getName())
.description(tool.getDescription())
.parameters(Tools.Parameters.of(propertiesMap))
.build())
.toolFunction(arguments -> tool.call(new ToolParameters(arguments)))
.build();

this.client.registerTool(toolSpec);
ollamaTools.add(toolSpec);
}

return ollamaTools;
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -201,32 +179,78 @@ private OllamaChatMessage convertToOllamaChatMessages(ChatMessage message) {
public ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object> arguments) {
try {
registerTools(tools);
final boolean extractReasoning =
(boolean) arguments.getOrDefault("extract_reasoning", false);

final List<Tools.Tool> ollamaTools = this.convertToOllamaTools(tools);
final List<OllamaChatMessage> ollamaChatMessages =
messages.stream()
.map(this::convertToOllamaChatMessages)
.collect(Collectors.toList());

final OllamaChatResult ollamaChatResult =
this.client.chat((String) arguments.get("model"), ollamaChatMessages);
final OllamaChatRequest chatRequest =
OllamaChatRequest.builder()
.withMessages(ollamaChatMessages)
.withModel((String) arguments.get("model"))
.withThinking(extractReasoning)
.withUseTools(false)
.build();

chatRequest.setTools(ollamaTools);
final OllamaChatResult ollamaChatResult = this.caller.callSync(chatRequest);
final OllamaChatResponseModel ollamaChatResponse = ollamaChatResult.getResponseModel();
final OllamaChatMessage ollamaChatMessage = ollamaChatResponse.getMessage();

Map<String, Object> extraArgs = new HashMap<>();
if (extractReasoning) {
extraArgs.put("reasoning", ollamaChatMessage.getThinking());
}

final List<OllamaChatToolCalls> ollamaToolCalls = ollamaChatMessage.getToolCalls();
final ChatMessage chatMessage = ChatMessage.assistant(ollamaChatMessage.getResponse());
chatMessage.setExtraArgs(extraArgs);

if (ollamaToolCalls != null) {
final List<Map<String, Object>> toolCalls = convertToAgentsTools(ollamaToolCalls);
chatMessage.setToolCalls(toolCalls);
}

return extraReasoning(ollamaChatResult.getResponse());
return chatMessage;
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private ChatMessage extraReasoning(String response) {
Matcher matcher = pattern.matcher(response);
StringBuilder reasoning = new StringBuilder();
while (matcher.find()) {
reasoning.append(matcher.group(1));
/**
* Converts Ollama tool calls to the format expected by the Flink Agents framework.
*
* <p>This method transforms Ollama-specific tool call representations into a generic format
* that can be used by the Flink Agents framework. Each tool call is assigned a unique ID and
* structured with the appropriate function name and arguments.
*
* @param ollamaToolCalls the list of tool calls returned from Ollama API
* @return a list of tool calls formatted for Flink Agents, where each tool call is represented
* as a map containing id, type, and function details
*/
private List<Map<String, Object>> convertToAgentsTools(
List<OllamaChatToolCalls> ollamaToolCalls) {
final List<Map<String, Object>> toolCalls = new ArrayList<>(ollamaToolCalls.size());
for (OllamaChatToolCalls ollamaToolCall : ollamaToolCalls) {
final UUID id = UUID.randomUUID();
final Map<String, Object> toolCall =
Map.of(
"id",
id,
"type",
"function",
"function",
Map.of(
"name",
ollamaToolCall.getFunction().getName(),
"arguments",
ollamaToolCall.getFunction().getArguments()));
toolCalls.add(toolCall);
}
response = matcher.replaceAll("").strip();
ChatMessage responseMessage = ChatMessage.assistant(response);
Map<String, Object> extraArgs = new HashMap<>();
extraArgs.put("reasoning", reasoning.toString().strip());
responseMessage.setExtraArgs(extraArgs);
return responseMessage;
return toolCalls;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@
public class OllamaChatModelSetup extends BaseChatModelSetup {

private final String model;
private final boolean extractReasoning;

public OllamaChatModelSetup(
ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) {
super(descriptor, getResource);
this.model = descriptor.getArgument("model");
this.extractReasoning = Boolean.parseBoolean(descriptor.getArgument("extract_reasoning"));
}

/**
Expand Down Expand Up @@ -88,6 +90,7 @@ public OllamaChatModelSetup(
public Map<String, Object> getParameters() {
Map<String, Object> params = new HashMap<>();
params.put("model", model);
params.put("extract_reasoning", extractReasoning);
return params;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@
import org.apache.flink.agents.api.tools.ToolResponse;
import org.apache.flink.agents.plan.JavaFunction;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.*;

/** Built-in action for processing chat request and tool call result. */
public class ChatModelAction {
Expand Down Expand Up @@ -83,7 +80,8 @@ public static void chat(
toolCallContext.put(initialRequestId, messages);
}
List<ChatMessage> messageContext =
(List<ChatMessage>) toolCallContext.get(initialRequestId);
new ArrayList<>((List<ChatMessage>) toolCallContext.get(initialRequestId));

messageContext.add(response);
stm.set(TOOL_CALL_CONTEXT, toolCallContext);

Expand Down Expand Up @@ -159,7 +157,9 @@ public static void processChatRequestOrToolResponse(Event event, RunnerContext c
Map<UUID, Object> toolCallContext =
(Map<UUID, Object>) stm.get(TOOL_CALL_CONTEXT).getValue();
// update tool call context
List<ChatMessage> messages = (List<ChatMessage>) toolCallContext.get(initialRequestId);
List<ChatMessage> messages =
new ArrayList<>((List<ChatMessage>) toolCallContext.get(initialRequestId));

for (Map.Entry<String, ToolResponse> entry : responses.entrySet()) {
Map<String, Object> extraArgs = new HashMap<>();
String toolCallId = entry.getKey();
Expand Down