From 911009a021752262842c7372cf584295a9af9832 Mon Sep 17 00:00:00 2001 From: Mithilesh Singh Date: Sun, 19 Apr 2026 08:04:01 +0000 Subject: [PATCH 1/3] feat/refactor tool-calling, add ToolCallResult, PromptConfig, and sanitization(#234) --- sdk/go/agent/agent.go | 7 + sdk/go/ai/tool_calling.go | 287 ++++++++++++++++++++++----------- sdk/go/ai/tool_calling_test.go | 248 ++++++++++++++++++++++++++++ 3 files changed, 446 insertions(+), 96 deletions(-) diff --git a/sdk/go/agent/agent.go b/sdk/go/agent/agent.go index 59831bcbe..7f584804c 100644 --- a/sdk/go/agent/agent.go +++ b/sdk/go/agent/agent.go @@ -2001,6 +2001,13 @@ func (a *Agent) AIWithTools(ctx context.Context, prompt string, config ai.ToolCa } callFn := func(ctx context.Context, target string, input map[string]interface{}) (map[string]interface{}, error) { + if strings.Contains(target, ":skill:") { + parts := strings.SplitN(target, ":skill:", 2) + target = parts[0] + "." + parts[1] + } else if strings.Contains(target, ":") { + parts := strings.SplitN(target, ":", 2) + target = parts[0] + "." + parts[1] + } return a.Call(ctx, target, input) } diff --git a/sdk/go/ai/tool_calling.go b/sdk/go/ai/tool_calling.go index 68e43fc21..e92d2b156 100644 --- a/sdk/go/ai/tool_calling.go +++ b/sdk/go/ai/tool_calling.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" "github.com/Agent-Field/agentfield/sdk/go/types" @@ -13,13 +14,17 @@ import ( type ToolCallConfig struct { MaxTurns int MaxToolCalls int + SystemPrompt string + PromptConfig *PromptConfig } // DefaultToolCallConfig returns default configuration for the tool-call loop. func DefaultToolCallConfig() ToolCallConfig { + pc := DefaultPromptConfig() return ToolCallConfig{ MaxTurns: 10, MaxToolCalls: 25, + PromptConfig: &pc, } } @@ -33,6 +38,19 @@ type ToolCallRecord struct { Turn int } +// PromptConfig customizes the tool-call loop's tool-facing prompt content. +type PromptConfig struct { + // ToolCallLimitReached is sent back to the model when it asks for more tool + // calls than MaxToolCalls allows. + ToolCallLimitReached string + // ToolErrorFormatter formats tool execution failures before they are sent + // back to the model. It may return a raw string or any JSON-marshable value. + ToolErrorFormatter func(toolName string, err error) interface{} + // ToolResultFormatter formats successful tool results before they are sent + // back to the model. It may return a raw string or any JSON-marshable value. + ToolResultFormatter func(toolName string, result map[string]interface{}) interface{} +} + // ToolCallTrace records the full trace of a tool-call loop. type ToolCallTrace struct { Calls []ToolCallRecord @@ -41,59 +59,48 @@ type ToolCallTrace struct { FinalResponse string } +// ToolCallResult wraps a tool-call response and its execution trace. +type ToolCallResult struct { + Response *Response + Trace *ToolCallTrace +} + +// Text returns the final text response from the tool-call loop. +func (r *ToolCallResult) Text() string { + if r == nil || r.Trace == nil { + return "" + } + return r.Trace.FinalResponse +} + // CallFunc is the function signature for dispatching tool calls. // It maps to agent.Call(ctx, target, input). type CallFunc func(ctx context.Context, target string, input map[string]interface{}) (map[string]interface{}, error) -// CapabilityToToolDefinition converts a ReasonerCapability to a ToolDefinition. +// DefaultPromptConfig returns the default prompt content for tool execution. +func DefaultPromptConfig() PromptConfig { + return PromptConfig{ + ToolCallLimitReached: "Tool call limit reached. Please provide a final response.", + ToolErrorFormatter: func(toolName string, err error) interface{} { + return map[string]string{ + "error": err.Error(), + "tool": toolName, + } + }, + ToolResultFormatter: func(_ string, result map[string]interface{}) interface{} { + return result + }, + } +} + +// CapabilityToToolDefinition converts a ReasonerCapability or SkillCapability +// to a ToolDefinition. func CapabilityToToolDefinition(cap interface{}) ToolDefinition { switch c := cap.(type) { case types.ReasonerCapability: - desc := "" - if c.Description != nil { - desc = *c.Description - } - if desc == "" { - desc = "Call " + c.InvocationTarget - } - params := c.InputSchema - if params == nil { - params = map[string]interface{}{"type": "object", "properties": map[string]interface{}{}} - } - if _, ok := params["type"]; !ok { - params = map[string]interface{}{"type": "object", "properties": params} - } - return ToolDefinition{ - Type: "function", - Function: ToolFunction{ - Name: c.InvocationTarget, - Description: desc, - Parameters: params, - }, - } + return capabilityToTool(c.InvocationTarget, c.Description, c.InputSchema) case types.SkillCapability: - desc := "" - if c.Description != nil { - desc = *c.Description - } - if desc == "" { - desc = "Call " + c.InvocationTarget - } - params := c.InputSchema - if params == nil { - params = map[string]interface{}{"type": "object", "properties": map[string]interface{}{}} - } - if _, ok := params["type"]; !ok { - params = map[string]interface{}{"type": "object", "properties": params} - } - return ToolDefinition{ - Type: "function", - Function: ToolFunction{ - Name: c.InvocationTarget, - Description: desc, - Parameters: params, - }, - } + return capabilityToTool(c.InvocationTarget, c.Description, c.InputSchema) default: return ToolDefinition{} } @@ -124,47 +131,61 @@ func (c *Client) ExecuteToolCallLoop( callFn CallFunc, opts ...Option, ) (*Response, *ToolCallTrace, error) { + result, err := c.ExecuteToolCallLoopResult(ctx, messages, tools, config, callFn, opts...) + if result == nil { + return nil, nil, err + } + return result.Response, result.Trace, err +} + +// ExecuteToolCallLoopResult runs the tool-call loop and returns a wrapped +// response plus the full execution trace. +func (c *Client) ExecuteToolCallLoopResult( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + config ToolCallConfig, + callFn CallFunc, + opts ...Option, +) (*ToolCallResult, error) { trace := &ToolCallTrace{} + result := &ToolCallResult{Trace: trace} totalCalls := 0 + promptConfig := resolvePromptConfig(config.PromptConfig) + effectiveOpts := opts + if strings.TrimSpace(config.SystemPrompt) != "" { + effectiveOpts = append([]Option{WithSystem(config.SystemPrompt)}, opts...) + } + loopMessages := append([]Message(nil), messages...) for turn := 0; turn < config.MaxTurns; turn++ { trace.TotalTurns = turn + 1 - // Build request - req := &Request{ - Messages: messages, - Model: c.config.Model, - Temperature: &c.config.Temperature, - MaxTokens: &c.config.MaxTokens, - Tools: tools, - ToolChoice: "auto", - } - - for _, opt := range opts { - if err := opt(req); err != nil { - return nil, trace, fmt.Errorf("apply option: %w", err) - } + req, err := c.buildToolCallRequest(loopMessages, tools, true, effectiveOpts) + if err != nil { + return result, err } resp, err := c.doRequest(ctx, req) if err != nil { - return nil, trace, fmt.Errorf("LLM call failed: %w", err) + return result, fmt.Errorf("LLM call failed: %w", err) } if !resp.HasToolCalls() { + result.Response = resp trace.FinalResponse = resp.Text() - return resp, trace, nil + return result, nil } // Append assistant message with tool calls - messages = append(messages, resp.Choices[0].Message) + loopMessages = append(loopMessages, resp.Choices[0].Message) // Execute each tool call for _, tc := range resp.ToolCalls() { if totalCalls >= config.MaxToolCalls { - messages = append(messages, Message{ + loopMessages = append(loopMessages, Message{ Role: "tool", - Content: []ContentPart{{Type: "text", Text: `{"error": "Tool call limit reached. Please provide a final response."}`}}, + Content: []ContentPart{{Type: "text", Text: encodeToolContent(map[string]string{"error": promptConfig.ToolCallLimitReached})}}, ToolCallID: tc.ID, }) continue @@ -177,34 +198,30 @@ func (c *Client) ExecuteToolCallLoop( if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { args = map[string]interface{}{} } + toolName := unsanitizeToolName(tc.Function.Name) record := ToolCallRecord{ - ToolName: tc.Function.Name, + ToolName: toolName, Arguments: args, Turn: turn, } start := time.Now() - result, err := callFn(ctx, tc.Function.Name, args) + toolResult, err := callFn(ctx, toolName, args) record.LatencyMs = float64(time.Since(start).Milliseconds()) if err != nil { record.Error = err.Error() - errJSON, _ := json.Marshal(map[string]string{ - "error": err.Error(), - "tool": tc.Function.Name, - }) - messages = append(messages, Message{ + loopMessages = append(loopMessages, Message{ Role: "tool", - Content: []ContentPart{{Type: "text", Text: string(errJSON)}}, + Content: []ContentPart{{Type: "text", Text: encodeToolContent(promptConfig.ToolErrorFormatter(toolName, err))}}, ToolCallID: tc.ID, }) } else { - record.Result = result - resultJSON, _ := json.Marshal(result) - messages = append(messages, Message{ + record.Result = toolResult + loopMessages = append(loopMessages, Message{ Role: "tool", - Content: []ContentPart{{Type: "text", Text: string(resultJSON)}}, + Content: []ContentPart{{Type: "text", Text: encodeToolContent(promptConfig.ToolResultFormatter(toolName, toolResult))}}, ToolCallID: tc.ID, }) } @@ -214,43 +231,121 @@ func (c *Client) ExecuteToolCallLoop( // If tool call limit reached, make one final call without tools if totalCalls >= config.MaxToolCalls { - req := &Request{ - Messages: messages, - Model: c.config.Model, - Temperature: &c.config.Temperature, - MaxTokens: &c.config.MaxTokens, - } - for _, opt := range opts { - if err := opt(req); err != nil { - return nil, trace, fmt.Errorf("apply option: %w", err) - } + req, err := c.buildToolCallRequest(loopMessages, nil, false, effectiveOpts) + if err != nil { + return result, err } resp, err := c.doRequest(ctx, req) if err != nil { - return nil, trace, fmt.Errorf("final LLM call failed: %w", err) + return result, fmt.Errorf("final LLM call failed: %w", err) } + result.Response = resp trace.FinalResponse = resp.Text() - return resp, trace, nil + return result, nil } } // Max turns reached - make final call without tools + req, err := c.buildToolCallRequest(loopMessages, nil, false, effectiveOpts) + if err != nil { + return result, err + } + resp, err := c.doRequest(ctx, req) + if err != nil { + return result, fmt.Errorf("final LLM call failed: %w", err) + } + result.Response = resp + trace.FinalResponse = resp.Text() + trace.TotalTurns = config.MaxTurns + return result, nil +} + +func capabilityToTool(invocationTarget string, description *string, inputSchema map[string]interface{}) ToolDefinition { + desc := "" + if description != nil { + desc = *description + } + if desc == "" { + desc = "Call " + invocationTarget + } + + return ToolDefinition{ + Type: "function", + Function: ToolFunction{ + Name: sanitizeToolName(invocationTarget), + Description: desc, + Parameters: normalizeToolParameters(inputSchema), + }, + } +} + +func normalizeToolParameters(inputSchema map[string]interface{}) map[string]interface{} { + if inputSchema == nil { + return map[string]interface{}{"type": "object", "properties": map[string]interface{}{}} + } + if _, ok := inputSchema["type"]; ok { + return inputSchema + } + return map[string]interface{}{"type": "object", "properties": inputSchema} +} + +func sanitizeToolName(name string) string { + return strings.ReplaceAll(name, ":", "__") +} + +func unsanitizeToolName(name string) string { + return strings.ReplaceAll(name, "__", ":") +} + +func resolvePromptConfig(config *PromptConfig) PromptConfig { + resolved := DefaultPromptConfig() + if config == nil { + return resolved + } + if strings.TrimSpace(config.ToolCallLimitReached) != "" { + resolved.ToolCallLimitReached = config.ToolCallLimitReached + } + if config.ToolErrorFormatter != nil { + resolved.ToolErrorFormatter = config.ToolErrorFormatter + } + if config.ToolResultFormatter != nil { + resolved.ToolResultFormatter = config.ToolResultFormatter + } + return resolved +} + +func encodeToolContent(content interface{}) string { + switch v := content.(type) { + case string: + return v + case []byte: + return string(v) + default: + data, err := json.Marshal(v) + if err != nil { + return "{}" + } + return string(data) + } +} + +func (c *Client) buildToolCallRequest(messages []Message, tools []ToolDefinition, includeTools bool, opts []Option) (*Request, error) { req := &Request{ Messages: messages, Model: c.config.Model, Temperature: &c.config.Temperature, MaxTokens: &c.config.MaxTokens, } + if includeTools { + req.Tools = tools + req.ToolChoice = "auto" + } + for _, opt := range opts { if err := opt(req); err != nil { - return nil, trace, fmt.Errorf("apply option: %w", err) + return nil, fmt.Errorf("apply option: %w", err) } } - resp, err := c.doRequest(ctx, req) - if err != nil { - return nil, trace, fmt.Errorf("final LLM call failed: %w", err) - } - trace.FinalResponse = resp.Text() - trace.TotalTurns = config.MaxTurns - return resp, trace, nil + + return req, nil } diff --git a/sdk/go/ai/tool_calling_test.go b/sdk/go/ai/tool_calling_test.go index db0c710fb..831793c9d 100644 --- a/sdk/go/ai/tool_calling_test.go +++ b/sdk/go/ai/tool_calling_test.go @@ -1,10 +1,15 @@ package ai import ( + "context" "encoding/json" + "net/http" + "sync/atomic" "testing" "github.com/Agent-Field/agentfield/sdk/go/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCapabilityToToolDefinition_Reasoner(t *testing.T) { @@ -50,6 +55,20 @@ func TestCapabilityToToolDefinition_Skill(t *testing.T) { } } +func TestCapabilityToToolDefinition_SanitizesColonTargets(t *testing.T) { + desc := "Lookup weather" + s := types.SkillCapability{ + ID: "lookup_weather", + Description: &desc, + InvocationTarget: "weather-agent:skill:lookup_weather", + } + + tool := CapabilityToToolDefinition(s) + if tool.Function.Name != "weather-agent__skill__lookup_weather" { + t.Errorf("expected sanitized name, got %q", tool.Function.Name) + } +} + func TestCapabilityToToolDefinition_NilSchema(t *testing.T) { r := types.ReasonerCapability{ ID: "test", @@ -104,6 +123,12 @@ func TestToolCallConfig_Defaults(t *testing.T) { if cfg.MaxToolCalls != 25 { t.Errorf("expected MaxToolCalls 25, got %d", cfg.MaxToolCalls) } + if cfg.PromptConfig == nil { + t.Fatal("expected PromptConfig to be initialized") + } + if cfg.PromptConfig.ToolCallLimitReached != "Tool call limit reached. Please provide a final response." { + t.Errorf("expected default prompt config, got %q", cfg.PromptConfig.ToolCallLimitReached) + } } func TestToolDefinition_JSONRoundTrip(t *testing.T) { @@ -151,3 +176,226 @@ func TestToolCallTrace(t *testing.T) { t.Errorf("expected no error, got %q", trace.Calls[0].Error) } } + +func TestToolCallResult(t *testing.T) { + resp := &Response{ + Choices: []Choice{{ + Message: Message{ + Role: "assistant", + Content: []ContentPart{{Type: "text", Text: "from response"}}, + }, + }}, + } + trace := &ToolCallTrace{ + FinalResponse: "from trace", + } + + result := &ToolCallResult{Response: resp, Trace: trace} + if result.Text() != "from trace" { + t.Fatalf("expected trace text, got %q", result.Text()) + } +} + +func TestSanitizeToolNameRoundTrip(t *testing.T) { + name := "worker:skill:lookup" + if got := sanitizeToolName(name); got != "worker__skill__lookup" { + t.Fatalf("expected sanitized name, got %q", got) + } + if got := unsanitizeToolName("worker__skill__lookup"); got != name { + t.Fatalf("expected unsanitized name %q, got %q", name, got) + } +} + +func TestExecuteToolCallLoopResult_UsesSystemPromptAndUnsanitizesToolNames(t *testing.T) { + var requestCount atomic.Int32 + client := newToolLoopClient(t, func(w http.ResponseWriter, r *http.Request) { + count := requestCount.Add(1) + var req Request + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + + switch count { + case 1: + require.Len(t, req.Messages, 2) + assert.Equal(t, "system", req.Messages[0].Role) + assert.Equal(t, "Use tools carefully.", req.Messages[0].Content[0].Text) + require.Len(t, req.Tools, 1) + assert.Equal(t, "worker__skill__lookup", req.Tools[0].Function.Name) + require.NoError(t, json.NewEncoder(w).Encode(Response{ + Choices: []Choice{{ + Message: Message{ + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call-1", + Type: "function", + Function: ToolCallFunction{ + Name: "worker__skill__lookup", + Arguments: `{"ticket":"123"}`, + }, + }}, + }, + FinishReason: "tool_calls", + }}, + })) + case 2: + require.Len(t, req.Messages, 4) + assert.Equal(t, "system", req.Messages[0].Role) + assert.Equal(t, "tool", req.Messages[3].Role) + require.NoError(t, json.NewEncoder(w).Encode(Response{ + Choices: []Choice{{ + Message: Message{ + Role: "assistant", + Content: []ContentPart{{Type: "text", Text: "resolved"}}, + }, + FinishReason: "stop", + }}, + })) + default: + t.Fatalf("unexpected request %d", count) + } + }) + + result, err := client.ExecuteToolCallLoopResult( + context.Background(), + []Message{{Role: "user", Content: []ContentPart{{Type: "text", Text: "Find ticket 123"}}}}, + []ToolDefinition{{Type: "function", Function: ToolFunction{Name: "worker__skill__lookup", Parameters: map[string]interface{}{"type": "object"}}}}, + ToolCallConfig{MaxTurns: 3, MaxToolCalls: 2, SystemPrompt: "Use tools carefully."}, + func(_ context.Context, target string, input map[string]interface{}) (map[string]interface{}, error) { + assert.Equal(t, "worker:skill:lookup", target) + assert.Equal(t, map[string]interface{}{"ticket": "123"}, input) + return map[string]interface{}{"status": "open"}, nil + }, + ) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "resolved", result.Text()) + require.NotNil(t, result.Response) + require.NotNil(t, result.Trace) + assert.Equal(t, "worker:skill:lookup", result.Trace.Calls[0].ToolName) +} + +func TestExecuteToolCallLoopResult_AppliesPromptConfig(t *testing.T) { + t.Run("custom limit message", func(t *testing.T) { + var requestCount atomic.Int32 + client := newToolLoopClient(t, func(w http.ResponseWriter, r *http.Request) { + count := requestCount.Add(1) + var req Request + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + + switch count { + case 1: + require.Len(t, req.Tools, 1) + require.NoError(t, json.NewEncoder(w).Encode(Response{ + Choices: []Choice{{ + Message: Message{ + Role: "assistant", + ToolCalls: []ToolCall{ + { + ID: "call-1", + Type: "function", + Function: ToolCallFunction{Name: "lookup", Arguments: `{"id":"1"}`}, + }, + { + ID: "call-2", + Type: "function", + Function: ToolCallFunction{Name: "lookup", Arguments: `{"id":"2"}`}, + }, + }, + }, + }}, + })) + case 2: + require.Len(t, req.Messages, 4) + var toolMessage map[string]string + require.NoError(t, json.Unmarshal([]byte(req.Messages[3].Content[0].Text), &toolMessage)) + assert.Equal(t, "custom limit reached", toolMessage["error"]) + require.NoError(t, json.NewEncoder(w).Encode(Response{ + Choices: []Choice{{Message: Message{ + Role: "assistant", + Content: []ContentPart{{Type: "text", Text: "final after limit"}}, + }}}, + })) + default: + t.Fatalf("unexpected request %d", count) + } + }) + + resp, trace, err := client.ExecuteToolCallLoop( + context.Background(), + []Message{{Role: "user", Content: []ContentPart{{Type: "text", Text: "lookup"}}}}, + []ToolDefinition{{Type: "function", Function: ToolFunction{Name: "lookup"}}}, + ToolCallConfig{ + MaxTurns: 3, + MaxToolCalls: 1, + PromptConfig: &PromptConfig{ToolCallLimitReached: "custom limit reached"}, + }, + func(_ context.Context, target string, input map[string]interface{}) (map[string]interface{}, error) { + assert.Equal(t, "lookup", target) + return map[string]interface{}{"ok": true}, nil + }, + ) + + require.NoError(t, err) + assert.Equal(t, "final after limit", resp.Text()) + assert.Equal(t, "final after limit", trace.FinalResponse) + }) + + t.Run("custom result formatter", func(t *testing.T) { + var requestCount atomic.Int32 + client := newToolLoopClient(t, func(w http.ResponseWriter, r *http.Request) { + count := requestCount.Add(1) + if count == 1 { + require.NoError(t, json.NewEncoder(w).Encode(Response{ + Choices: []Choice{{ + Message: Message{ + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call-1", + Type: "function", + Function: ToolCallFunction{ + Name: "worker__skill__lookup", + Arguments: `{"id":"1"}`, + }, + }}, + }, + }}, + })) + return + } + + var req Request + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + require.Len(t, req.Messages, 3) + assert.Equal(t, "tool=worker:skill:lookup status=open", req.Messages[2].Content[0].Text) + require.NoError(t, json.NewEncoder(w).Encode(Response{ + Choices: []Choice{{Message: Message{ + Role: "assistant", + Content: []ContentPart{{Type: "text", Text: "done"}}, + }}}, + })) + }) + + result, err := client.ExecuteToolCallLoopResult( + context.Background(), + []Message{{Role: "user", Content: []ContentPart{{Type: "text", Text: "lookup"}}}}, + []ToolDefinition{{Type: "function", Function: ToolFunction{Name: "worker__skill__lookup"}}}, + ToolCallConfig{ + MaxTurns: 2, + MaxToolCalls: 2, + PromptConfig: &PromptConfig{ + ToolResultFormatter: func(toolName string, result map[string]interface{}) interface{} { + return "tool=" + toolName + " status=" + result["status"].(string) + }, + }, + }, + func(_ context.Context, target string, input map[string]interface{}) (map[string]interface{}, error) { + assert.Equal(t, "worker:skill:lookup", target) + assert.Equal(t, map[string]interface{}{"id": "1"}, input) + return map[string]interface{}{"status": "open"}, nil + }, + ) + + require.NoError(t, err) + assert.Equal(t, "done", result.Text()) + }) +} From b597071172f57b3505d31d139ae6771c33f9d49f Mon Sep 17 00:00:00 2001 From: Mithilesh Singh Date: Sun, 19 Apr 2026 11:22:54 +0000 Subject: [PATCH 2/3] Added more unit tests for coverage(#234) --- sdk/go/agent/agent_test.go | 27 ++++++ sdk/go/ai/tool_calling_test.go | 150 +++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) diff --git a/sdk/go/agent/agent_test.go b/sdk/go/agent/agent_test.go index a038d66dc..412dac41e 100644 --- a/sdk/go/agent/agent_test.go +++ b/sdk/go/agent/agent_test.go @@ -1444,3 +1444,30 @@ func TestCallLocalUnknownReasoner(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "unknown reasoner") } + +func TestCall_TargetPrefixing(t *testing.T) { + var capturedPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + + resp := map[string]any{ + "status": "succeeded", + "result": map[string]any{"ok": true}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + agent, _ := New(Config{ + NodeID: "node-1", + Version: "1.0.0", + AgentFieldURL: server.URL, + Logger: log.New(io.Discard, "", 0), + }) + + _, err := agent.Call(context.Background(), "lookup", nil) + require.NoError(t, err) + + assert.Contains(t, capturedPath, "/execute/node-1.lookup") +} \ No newline at end of file diff --git a/sdk/go/ai/tool_calling_test.go b/sdk/go/ai/tool_calling_test.go index 831793c9d..4bff420ac 100644 --- a/sdk/go/ai/tool_calling_test.go +++ b/sdk/go/ai/tool_calling_test.go @@ -399,3 +399,153 @@ func TestExecuteToolCallLoopResult_AppliesPromptConfig(t *testing.T) { assert.Equal(t, "done", result.Text()) }) } + +func TestEncodeToolContent_AllCases(t *testing.T) { + // string + assert.Equal(t, "hello", encodeToolContent("hello")) + + // []byte + assert.Equal(t, "hi", encodeToolContent([]byte("hi"))) + + // JSON + out := encodeToolContent(map[string]string{"a": "b"}) + assert.Contains(t, out, `"a":"b"`) + + // invalid (marshal error) + ch := make(chan int) + assert.Equal(t, "{}", encodeToolContent(ch)) +} + +func TestNormalizeToolParameters_AllCases(t *testing.T) { + // nil + out := normalizeToolParameters(nil) + assert.Equal(t, "object", out["type"]) + + // already valid + in := map[string]interface{}{"type": "object"} + out = normalizeToolParameters(in) + assert.Equal(t, in, out) + + // missing type + in = map[string]interface{}{"field": "value"} + out = normalizeToolParameters(in) + assert.Equal(t, "object", out["type"]) +} + +func TestExecuteToolCallLoopResult_ErrorFormatter(t *testing.T) { + var requestCount atomic.Int32 + + client := newToolLoopClient(t, func(w http.ResponseWriter, r *http.Request) { + count := requestCount.Add(1) + + if count == 1 { + require.NoError(t, json.NewEncoder(w).Encode(Response{ + Choices: []Choice{{ + Message: Message{ + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call-1", + Type: "function", + Function: ToolCallFunction{ + Name: "lookup", + Arguments: `{"id":"1"}`, + }, + }}, + }, + }}, + })) + return + } + + var req Request + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + + assert.Contains(t, req.Messages[2].Content[0].Text, "custom-error") + + require.NoError(t, json.NewEncoder(w).Encode(Response{ + Choices: []Choice{{Message: Message{ + Role: "assistant", + Content: []ContentPart{{Type: "text", Text: "done"}}, + }}}, + })) + }) + + _, err := client.ExecuteToolCallLoopResult( + context.Background(), + []Message{{Role: "user", Content: []ContentPart{{Type: "text", Text: "lookup"}}}}, + []ToolDefinition{{Type: "function", Function: ToolFunction{Name: "lookup"}}}, + ToolCallConfig{ + MaxTurns: 2, + MaxToolCalls: 2, + PromptConfig: &PromptConfig{ + ToolErrorFormatter: func(tool string, err error) interface{} { + return map[string]string{"error": "custom-error"} + }, + }, + }, + func(context.Context, string, map[string]interface{}) (map[string]interface{}, error) { + return nil, assert.AnError + }, + ) + + require.NoError(t, err) +} + +func TestExecuteToolCallLoopResult_NoToolCalls(t *testing.T) { + client := newToolLoopClient(t, func(w http.ResponseWriter, r *http.Request) { + require.NoError(t, json.NewEncoder(w).Encode(Response{ + Choices: []Choice{{ + Message: Message{ + Role: "assistant", + Content: []ContentPart{{Type: "text", Text: "direct"}}, + }, + FinishReason: "stop", + }}, + })) + }) + + result, err := client.ExecuteToolCallLoopResult( + context.Background(), + []Message{{Role: "user", Content: []ContentPart{{Type: "text", Text: "hi"}}}}, + nil, + DefaultToolCallConfig(), + nil, + ) + + require.NoError(t, err) + assert.Equal(t, "direct", result.Text()) +} + +func TestSanitizeToolName_NoChange(t *testing.T) { + name := "simpletool" + assert.Equal(t, name, sanitizeToolName(name)) +} + +func TestUnsanitizeToolName_Standalone(t *testing.T) { + assert.Equal(t, "worker:skill:lookup", unsanitizeToolName("worker__skill__lookup")) +} +func TestSanitizeToolName_MultipleColons(t *testing.T) { + name := "a:b:c:d" + sanitized := sanitizeToolName(name) + assert.Equal(t, "a__b__c__d", sanitized) + assert.Equal(t, name, unsanitizeToolName(sanitized)) +} + +func TestResolvePromptConfig_Nil(t *testing.T) { + cfg := resolvePromptConfig(nil) + assert.NotNil(t, cfg) + assert.NotEmpty(t, cfg.ToolCallLimitReached) +} + +func TestToolCallResult_FallbackToResponse(t *testing.T) { + resp := &Response{ + Choices: []Choice{{ + Message: Message{ + Content: []ContentPart{{Type: "text", Text: "from response"}}, + }, + }}, + } + + result := &ToolCallResult{Response: resp, Trace: nil} + assert.Equal(t, "from response", result.Text()) +} \ No newline at end of file From a88d3f451b491756789790b6f731424824112303 Mon Sep 17 00:00:00 2001 From: Mithilesh Singh Date: Sun, 19 Apr 2026 11:28:24 +0000 Subject: [PATCH 3/3] Added more unit tests for coverage(#234) --- sdk/go/ai/tool_calling_test.go | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/sdk/go/ai/tool_calling_test.go b/sdk/go/ai/tool_calling_test.go index 4bff420ac..b01d1fba3 100644 --- a/sdk/go/ai/tool_calling_test.go +++ b/sdk/go/ai/tool_calling_test.go @@ -535,17 +535,4 @@ func TestResolvePromptConfig_Nil(t *testing.T) { cfg := resolvePromptConfig(nil) assert.NotNil(t, cfg) assert.NotEmpty(t, cfg.ToolCallLimitReached) -} - -func TestToolCallResult_FallbackToResponse(t *testing.T) { - resp := &Response{ - Choices: []Choice{{ - Message: Message{ - Content: []ContentPart{{Type: "text", Text: "from response"}}, - }, - }}, - } - - result := &ToolCallResult{Response: resp, Trace: nil} - assert.Equal(t, "from response", result.Text()) } \ No newline at end of file