diff --git a/internal/runtime/events.go b/internal/runtime/events.go index ff1fa660..62f3f202 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -227,6 +227,29 @@ type RepositoryContextUnavailablePayload struct { Reason string `json:"reason"` } +// HookEventPayload 描述 hook 生命周期事件负载。 +type HookEventPayload struct { + HookID string `json:"hook_id"` + Point string `json:"point"` + Scope string `json:"scope"` + Kind string `json:"kind"` + Mode string `json:"mode"` + Status string `json:"status,omitempty"` + StartedAt time.Time `json:"started_at,omitempty"` + DurationMS int64 `json:"duration_ms,omitempty"` + Error string `json:"error,omitempty"` +} + +// HookBlockedPayload 描述 hook 阻断事件负载。 +type HookBlockedPayload struct { + HookID string `json:"hook_id"` + Point string `json:"point"` + ToolCallID string `json:"tool_call_id,omitempty"` + ToolName string `json:"tool_name,omitempty"` + Reason string `json:"reason,omitempty"` + Enforced bool `json:"enforced"` +} + const ( // EventUserMessage 表示用户消息已写入会话。 EventUserMessage EventType = "user_message" @@ -302,6 +325,14 @@ const ( EventAssetSaveFailed EventType = "asset_save_failed" // EventRepositoryContextUnavailable 表示本轮 repository 事实本应获取但失败,已降级为空上下文。 EventRepositoryContextUnavailable EventType = "repository_context_unavailable" + // EventHookStarted 表示 hook 执行开始。 + EventHookStarted EventType = "hook_started" + // EventHookFinished 表示 hook 执行结束。 + EventHookFinished EventType = "hook_finished" + // EventHookFailed 表示 hook 执行失败。 + EventHookFailed EventType = "hook_failed" + // EventHookBlocked 表示某个 hook 返回 block(是否生效由 payload.enforced 决定)。 + EventHookBlocked EventType = "hook_blocked" ) // TokenUsagePayload 承载单轮 token 用量统计。 diff --git a/internal/runtime/hooks_integration.go b/internal/runtime/hooks_integration.go new file mode 100644 index 00000000..ee235291 --- /dev/null +++ b/internal/runtime/hooks_integration.go @@ -0,0 +1,170 @@ +package runtime + +import ( + "context" + "strings" + + runtimehooks "neo-code/internal/runtime/hooks" +) + +const ( + // hookErrorClassBlocked 标识由 before_tool_call hook 拦截产生的工具错误分类。 + hookErrorClassBlocked = "hook_blocked" +) + +type hookContextKey string + +const hookRuntimeEnvelopeKey hookContextKey = "runtime_hook_envelope" + +type hookRuntimeEnvelope struct { + RunID string + SessionID string + Turn int + Phase string +} + +// HookExecutor 定义 runtime 调用 hook 的最小执行契约。 +type HookExecutor interface { + Run(ctx context.Context, point runtimehooks.HookPoint, input runtimehooks.HookContext) runtimehooks.RunOutput +} + +type hookRuntimeEventEmitter struct { + service *Service +} + +func newHookRuntimeEventEmitter(service *Service) *hookRuntimeEventEmitter { + return &hookRuntimeEventEmitter{service: service} +} + +// EmitHookEvent 将 hooks 包内事件桥接为 runtime 事件,供 TUI 与日志统一消费。 +func (e *hookRuntimeEventEmitter) EmitHookEvent(ctx context.Context, event runtimehooks.HookEvent) error { + if e == nil || e.service == nil { + return nil + } + envelope, _ := runtimeHookEnvelopeFromContext(ctx) + kind := EventType(strings.TrimSpace(string(event.Type))) + if kind == "" { + return nil + } + return e.service.emitWithEnvelope(ctx, RuntimeEvent{ + Type: kind, + RunID: envelope.RunID, + SessionID: envelope.SessionID, + Turn: envelope.Turn, + Phase: envelope.Phase, + PayloadVersion: 0, + Payload: HookEventPayload{ + HookID: event.HookID, + Point: string(event.Point), + Scope: string(event.Scope), + Kind: string(event.Kind), + Mode: string(event.Mode), + Status: string(event.Status), + StartedAt: event.StartedAt, + DurationMS: event.DurationMS, + Error: event.Error, + }, + }) +} + +// runHookPoint 在指定运行态上下文执行一个 hook 点,并自动注入 run/session 元数据。 +func (s *Service) runHookPoint( + ctx context.Context, + state *runState, + point runtimehooks.HookPoint, + input runtimehooks.HookContext, +) runtimehooks.RunOutput { + if s == nil || s.hookExecutor == nil { + return runtimehooks.RunOutput{} + } + input.RunID = firstNonBlank(input.RunID, hookRunIDFromState(state)) + input.SessionID = firstNonBlank(input.SessionID, hookSessionIDFromState(state)) + scopedCtx := withRuntimeHookEnvelope(ctx, hookRuntimeEnvelope{ + RunID: hookRunIDFromState(state), + SessionID: hookSessionIDFromState(state), + Turn: hookTurnFromState(state), + Phase: hookPhaseFromState(state), + }) + return s.hookExecutor.Run(scopedCtx, point, input) +} + +func withRuntimeHookEnvelope(ctx context.Context, envelope hookRuntimeEnvelope) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, hookRuntimeEnvelopeKey, envelope) +} + +func runtimeHookEnvelopeFromContext(ctx context.Context) (hookRuntimeEnvelope, bool) { + if ctx == nil { + return hookRuntimeEnvelope{}, false + } + raw := ctx.Value(hookRuntimeEnvelopeKey) + envelope, ok := raw.(hookRuntimeEnvelope) + return envelope, ok +} + +func hookRunIDFromState(state *runState) string { + if state == nil { + return "" + } + return strings.TrimSpace(state.runID) +} + +func hookSessionIDFromState(state *runState) string { + if state == nil { + return "" + } + return strings.TrimSpace(state.session.ID) +} + +func hookTurnFromState(state *runState) int { + if state == nil { + return turnUnspecified + } + return state.turn +} + +func hookPhaseFromState(state *runState) string { + if state == nil { + return "" + } + if state.lifecycle == "" { + return "" + } + return string(state.lifecycle) +} + +func firstNonBlank(values ...string) string { + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed != "" { + return trimmed + } + } + return "" +} + +func findHookBlockMessage(output runtimehooks.RunOutput) string { + if !output.Blocked { + return "" + } + for _, result := range output.Results { + if !strings.EqualFold(strings.TrimSpace(result.HookID), strings.TrimSpace(output.BlockedBy)) { + continue + } + message := strings.TrimSpace(result.Message) + if message != "" { + return message + } + errText := strings.TrimSpace(result.Error) + if errText != "" { + return errText + } + break + } + if blockedBy := strings.TrimSpace(output.BlockedBy); blockedBy != "" { + return "hook blocked by " + blockedBy + } + return "hook blocked" +} diff --git a/internal/runtime/hooks_integration_test.go b/internal/runtime/hooks_integration_test.go new file mode 100644 index 00000000..48e6ec5f --- /dev/null +++ b/internal/runtime/hooks_integration_test.go @@ -0,0 +1,422 @@ +package runtime + +import ( + "context" + "errors" + "strings" + "sync" + "testing" + "time" + + providertypes "neo-code/internal/provider/types" + approvalflow "neo-code/internal/runtime/approval" + "neo-code/internal/runtime/controlplane" + runtimehooks "neo-code/internal/runtime/hooks" + "neo-code/internal/tools" +) + +func TestExecuteOneToolCallBlocksWhenBeforeToolHookReturnsBlock(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := newRuntimeSession("session-hook-before-tool-block") + store.sessions[session.ID] = cloneSession(session) + + toolManager := &stubToolManager{ + result: tools.ToolResult{Name: "filesystem_read_file", Content: "should not execute"}, + } + service := &Service{ + sessionStore: store, + toolManager: toolManager, + approvalBroker: approvalflow.NewBroker(), + events: make(chan RuntimeEvent, 32), + } + state := newRunState("run-hook-before-tool-block", session) + + registry := runtimehooks.NewRegistry() + if err := registry.Register(runtimehooks.HookSpec{ + ID: "block-before-tool", + Point: runtimehooks.HookPointBeforeToolCall, + Handler: func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { + return runtimehooks.HookResult{Status: runtimehooks.HookResultBlock, Message: "blocked by test hook"} + }, + }); err != nil { + t.Fatalf("register hook: %v", err) + } + service.SetHookExecutor(runtimehooks.NewExecutor(registry, newHookRuntimeEventEmitter(service), time.Second)) + + result, wrote, err := service.executeOneToolCall( + context.Background(), + &state, + TurnBudgetSnapshot{Workdir: t.TempDir(), ToolTimeout: time.Second}, + providertypes.ToolCall{ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + &sync.Mutex{}, + func() bool { return false }, + ) + if err != nil { + t.Fatalf("executeOneToolCall() error = %v", err) + } + if wrote { + t.Fatalf("executeOneToolCall() wrote = true, want false") + } + if !result.IsError { + t.Fatalf("tool result should be error when blocked by hook") + } + if result.ErrorClass != hookErrorClassBlocked { + t.Fatalf("result.ErrorClass = %q, want %q", result.ErrorClass, hookErrorClassBlocked) + } + + toolManager.mu.Lock() + executeCalls := toolManager.executeCalls + toolManager.mu.Unlock() + if executeCalls != 0 { + t.Fatalf("tool manager execute calls = %d, want 0", executeCalls) + } + + events := collectRuntimeEvents(service.Events()) + assertEventContains(t, events, EventHookStarted) + assertEventContains(t, events, EventHookFinished) + assertEventContains(t, events, EventHookBlocked) + assertEventContains(t, events, EventToolResult) + assertNoEventType(t, events, EventToolStart) + if eventIndex(events, EventHookBlocked) > eventIndex(events, EventToolResult) { + t.Fatalf("hook_blocked should be emitted before tool_result") + } + + hookStartedIndex := eventIndex(events, EventHookStarted) + if hookStartedIndex >= 0 { + started := events[hookStartedIndex] + if started.RunID != state.runID { + t.Fatalf("hook_started run id = %q, want %q", started.RunID, state.runID) + } + if started.SessionID != state.session.ID { + t.Fatalf("hook_started session id = %q, want %q", started.SessionID, state.session.ID) + } + } +} + +func TestExecuteOneToolCallTriggersAfterToolResultHookWithoutMutatingResult(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := newRuntimeSession("session-hook-after-tool-result") + store.sessions[session.ID] = cloneSession(session) + + toolManager := &stubToolManager{ + result: tools.ToolResult{Name: "filesystem_read_file", Content: "ok"}, + } + service := &Service{ + sessionStore: store, + toolManager: toolManager, + approvalBroker: approvalflow.NewBroker(), + events: make(chan RuntimeEvent, 32), + } + state := newRunState("run-hook-after-tool-result", session) + + var ( + called bool + metadata map[string]any + ) + registry := runtimehooks.NewRegistry() + if err := registry.Register(runtimehooks.HookSpec{ + ID: "observe-after-tool", + Point: runtimehooks.HookPointAfterToolResult, + Handler: func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { + called = true + metadata = input.Metadata + return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} + }, + }); err != nil { + t.Fatalf("register hook: %v", err) + } + service.SetHookExecutor(runtimehooks.NewExecutor(registry, newHookRuntimeEventEmitter(service), time.Second)) + + result, _, err := service.executeOneToolCall( + context.Background(), + &state, + TurnBudgetSnapshot{Workdir: t.TempDir(), ToolTimeout: time.Second}, + providertypes.ToolCall{ID: "call-2", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + &sync.Mutex{}, + func() bool { return false }, + ) + if err != nil { + t.Fatalf("executeOneToolCall() error = %v", err) + } + if !called { + t.Fatalf("after_tool_result hook should be called") + } + if got := result.Content; got != "ok" { + t.Fatalf("tool result content = %q, want %q", got, "ok") + } + if got := metadata["result_content_preview"]; got != "ok" { + t.Fatalf("result_content_preview = %#v, want %q", got, "ok") + } +} + +func TestExecuteOneToolCallCanceledStillTriggersAfterToolResultHook(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session := newRuntimeSession("session-hook-after-tool-result-canceled") + store.sessions[session.ID] = cloneSession(session) + + toolManager := &stubToolManager{ + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + return tools.ToolResult{Name: input.Name}, context.Canceled + }, + } + service := &Service{ + sessionStore: store, + toolManager: toolManager, + approvalBroker: approvalflow.NewBroker(), + events: make(chan RuntimeEvent, 32), + } + state := newRunState("run-hook-after-tool-result-canceled", session) + + var ( + called bool + errMsg string + ) + registry := runtimehooks.NewRegistry() + if err := registry.Register(runtimehooks.HookSpec{ + ID: "observe-after-tool-canceled", + Point: runtimehooks.HookPointAfterToolResult, + Handler: func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { + called = true + if raw, ok := input.Metadata["execution_error"]; ok { + if text, ok := raw.(string); ok { + errMsg = text + } + } + return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} + }, + }); err != nil { + t.Fatalf("register hook: %v", err) + } + service.SetHookExecutor(runtimehooks.NewExecutor(registry, newHookRuntimeEventEmitter(service), time.Second)) + + _, _, err := service.executeOneToolCall( + context.Background(), + &state, + TurnBudgetSnapshot{Workdir: t.TempDir(), ToolTimeout: time.Second}, + providertypes.ToolCall{ID: "call-3", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + &sync.Mutex{}, + func() bool { return false }, + ) + if !errors.Is(err, context.Canceled) { + t.Fatalf("executeOneToolCall() error = %v, want context.Canceled", err) + } + if !called { + t.Fatalf("after_tool_result hook should be called when tool execution is canceled") + } + if errMsg == "" { + t.Fatalf("expected execution_error metadata for canceled execution") + } +} + +func TestRunBeforeCompletionDecisionHookBlockIsObservedOnly(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + store := newMemoryStore() + scripted := &scriptedProvider{ + streams: [][]providertypes.StreamEvent{ + { + providertypes.NewTextDeltaStreamEvent("final answer"), + providertypes.NewMessageDoneStreamEvent("", nil), + }, + }, + } + service := NewWithFactory(manager, &stubToolManager{}, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + + registry := runtimehooks.NewRegistry() + if err := registry.Register(runtimehooks.HookSpec{ + ID: "block-before-completion", + Point: runtimehooks.HookPointBeforeCompletionDecision, + Handler: func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { + return runtimehooks.HookResult{Status: runtimehooks.HookResultBlock, Message: "blocked but non-authoritative"} + }, + }); err != nil { + t.Fatalf("register hook: %v", err) + } + service.SetHookExecutor(runtimehooks.NewExecutor(registry, newHookRuntimeEventEmitter(service), time.Second)) + + if err := service.Run(context.Background(), UserInput{ + RunID: "run-hook-before-completion", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + + events := collectRuntimeEvents(service.Events()) + assertEventContains(t, events, EventHookBlocked) + assertEventContains(t, events, EventAgentDone) + if eventIndex(events, EventHookBlocked) > eventIndex(events, EventVerificationStarted) { + t.Fatalf("before_completion_decision hook_blocked should be emitted before verification_started") + } + + blockedIndex := eventIndex(events, EventHookBlocked) + if blockedIndex >= 0 { + payload, ok := events[blockedIndex].Payload.(HookBlockedPayload) + if !ok { + t.Fatalf("hook_blocked payload type = %T, want HookBlockedPayload", events[blockedIndex].Payload) + } + if payload.Enforced { + t.Fatalf("before_completion_decision block should be observed only, got enforced=true") + } + if payload.Point != string(runtimehooks.HookPointBeforeCompletionDecision) { + t.Fatalf("payload.Point = %q, want %q", payload.Point, runtimehooks.HookPointBeforeCompletionDecision) + } + } +} + +func TestHookIntegrationHelpersBranches(t *testing.T) { + t.Parallel() + + if got := firstNonBlank(" ", "\n", "value", "ignored"); got != "value" { + t.Fatalf("firstNonBlank() = %q, want value", got) + } + if got := firstNonBlank(" ", "\n"); got != "" { + t.Fatalf("firstNonBlank() = %q, want empty", got) + } + + if got := findHookBlockMessage(runtimehooks.RunOutput{}); got != "" { + t.Fatalf("findHookBlockMessage() for non-blocked output = %q, want empty", got) + } + if got := findHookBlockMessage(runtimehooks.RunOutput{ + Blocked: true, + BlockedBy: "hook-1", + Results: []runtimehooks.HookResult{{HookID: "hook-1", Message: " denied "}}, + }); got != "denied" { + t.Fatalf("findHookBlockMessage() from message = %q, want denied", got) + } + if got := findHookBlockMessage(runtimehooks.RunOutput{ + Blocked: true, + BlockedBy: "hook-2", + Results: []runtimehooks.HookResult{{HookID: "hook-2", Error: " failed "}}, + }); got != "failed" { + t.Fatalf("findHookBlockMessage() from error = %q, want failed", got) + } + if got := findHookBlockMessage(runtimehooks.RunOutput{ + Blocked: true, + BlockedBy: "hook-3", + Results: []runtimehooks.HookResult{{HookID: "other", Message: "ignored"}}, + }); got != "hook blocked by hook-3" { + t.Fatalf("findHookBlockMessage() fallback by hook id = %q", got) + } + if got := findHookBlockMessage(runtimehooks.RunOutput{ + Blocked: true, + Results: []runtimehooks.HookResult{{HookID: "other", Message: "ignored"}}, + }); got != "hook blocked" { + t.Fatalf("findHookBlockMessage() default fallback = %q", got) + } + + wrapped := withRuntimeHookEnvelope(nil, hookRuntimeEnvelope{RunID: "run-1"}) + envelope, ok := runtimeHookEnvelopeFromContext(wrapped) + if !ok || envelope.RunID != "run-1" { + t.Fatalf("runtimeHookEnvelopeFromContext() = (%+v,%v), want run-1", envelope, ok) + } + if _, ok := runtimeHookEnvelopeFromContext(nil); ok { + t.Fatalf("runtimeHookEnvelopeFromContext(nil) should return ok=false") + } + if _, ok := runtimeHookEnvelopeFromContext(context.Background()); ok { + t.Fatalf("runtimeHookEnvelopeFromContext(background) should return ok=false") + } + + state := newRunState(" run-id ", newRuntimeSession("session-x")) + state.turn = 3 + if got := hookRunIDFromState(&state); got != "run-id" { + t.Fatalf("hookRunIDFromState() = %q", got) + } + if got := hookSessionIDFromState(&state); got != "session-x" { + t.Fatalf("hookSessionIDFromState() = %q", got) + } + if got := hookTurnFromState(&state); got != 3 { + t.Fatalf("hookTurnFromState() = %d", got) + } + if got := hookPhaseFromState(&state); got != "" { + t.Fatalf("hookPhaseFromState() without lifecycle = %q, want empty", got) + } + state.lifecycle = controlplane.RunStateExecute + if got := hookPhaseFromState(&state); got != string(controlplane.RunStateExecute) { + t.Fatalf("hookPhaseFromState() with lifecycle = %q", got) + } + if got := hookRunIDFromState(nil); got != "" { + t.Fatalf("hookRunIDFromState(nil) = %q, want empty", got) + } + if got := hookSessionIDFromState(nil); got != "" { + t.Fatalf("hookSessionIDFromState(nil) = %q, want empty", got) + } + if got := hookTurnFromState(nil); got != turnUnspecified { + t.Fatalf("hookTurnFromState(nil) = %d, want %d", got, turnUnspecified) + } +} + +func TestSummarizeHookResultContentTruncatesLongContent(t *testing.T) { + t.Parallel() + + if got := summarizeHookResultContent(" short "); got != "short" { + t.Fatalf("summarizeHookResultContent() short = %q", got) + } + long := strings.Repeat("x", 300) + got := summarizeHookResultContent(long) + if len(got) != 256 { + t.Fatalf("summarizeHookResultContent() len = %d, want 256", len(got)) + } +} + +func TestHookRuntimeEventEmitterBranches(t *testing.T) { + t.Parallel() + + if err := (&hookRuntimeEventEmitter{}).EmitHookEvent(context.Background(), runtimehooks.HookEvent{ + Type: runtimehooks.HookEventStarted, + }); err != nil { + t.Fatalf("EmitHookEvent() with nil service error = %v", err) + } + + service := &Service{events: make(chan RuntimeEvent, 8)} + emitter := newHookRuntimeEventEmitter(service) + if err := emitter.EmitHookEvent(context.Background(), runtimehooks.HookEvent{}); err != nil { + t.Fatalf("EmitHookEvent() blank type error = %v", err) + } + if got := len(collectRuntimeEvents(service.Events())); got != 0 { + t.Fatalf("expected blank event type to be ignored, got %d events", got) + } + + startedAt := time.Date(2026, 4, 20, 10, 30, 0, 0, time.UTC) + ctx := withRuntimeHookEnvelope(context.Background(), hookRuntimeEnvelope{ + RunID: "run-evt", + SessionID: "session-evt", + Turn: 2, + Phase: "execute", + }) + if err := emitter.EmitHookEvent(ctx, runtimehooks.HookEvent{ + Type: runtimehooks.HookEventFinished, + HookID: "hook-evt", + Point: runtimehooks.HookPointAfterToolResult, + Scope: runtimehooks.HookScopeInternal, + Kind: runtimehooks.HookKindFunction, + Mode: runtimehooks.HookModeSync, + Status: runtimehooks.HookResultPass, + StartedAt: startedAt, + DurationMS: 12, + Error: "", + }); err != nil { + t.Fatalf("EmitHookEvent() finished error = %v", err) + } + events := collectRuntimeEvents(service.Events()) + if len(events) != 1 { + t.Fatalf("events len = %d, want 1", len(events)) + } + evt := events[0] + if evt.Type != EventHookFinished || evt.RunID != "run-evt" || evt.SessionID != "session-evt" || evt.Turn != 2 || evt.Phase != "execute" { + t.Fatalf("unexpected runtime event envelope: %+v", evt) + } + payload, ok := evt.Payload.(HookEventPayload) + if !ok { + t.Fatalf("payload type = %T, want HookEventPayload", evt.Payload) + } + if payload.HookID != "hook-evt" || payload.Point != string(runtimehooks.HookPointAfterToolResult) || payload.DurationMS != 12 { + t.Fatalf("unexpected payload: %+v", payload) + } +} diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 389ada87..d1e61ced 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -18,6 +18,7 @@ import ( providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/acceptance" "neo-code/internal/runtime/controlplane" + runtimehooks "neo-code/internal/runtime/hooks" "neo-code/internal/runtime/streaming" agentsession "neo-code/internal/session" "neo-code/internal/tools" @@ -253,6 +254,26 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, controlplane.RunStateVerify); err != nil { return s.handleRunError(err) } + completionHookOutput := s.runHookPoint( + ctx, + &state, + runtimehooks.HookPointBeforeCompletionDecision, + runtimehooks.HookContext{ + Metadata: map[string]any{ + "completion_passed": completed, + "has_tool_calls": hasToolCalls, + "assistant_role": strings.TrimSpace(turnOutput.assistant.Role), + }, + }, + ) + if completionHookOutput.Blocked { + s.emitRunScoped(ctx, EventHookBlocked, &state, HookBlockedPayload{ + HookID: strings.TrimSpace(completionHookOutput.BlockedBy), + Point: string(runtimehooks.HookPointBeforeCompletionDecision), + Reason: findHookBlockMessage(completionHookOutput), + Enforced: false, + }) + } s.emitRunScoped(ctx, EventVerificationStarted, &state, VerificationStartedPayload{ CompletionPassed: completed, diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 28bde9f4..35f246a2 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -17,6 +17,7 @@ import ( "neo-code/internal/provider/builtin" providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/approval" + runtimehooks "neo-code/internal/runtime/hooks" "neo-code/internal/security" agentsession "neo-code/internal/session" "neo-code/internal/skills" @@ -130,6 +131,7 @@ type Service struct { memoExtractor MemoExtractor skillsRegistry skills.Registry budgetResolver BudgetResolver + hookExecutor HookExecutor events chan RuntimeEvent sessionMu sync.Mutex @@ -181,7 +183,7 @@ func NewWithFactory( }) } - return &Service{ + service := &Service{ configManager: configManager, sessionStore: sessionStore, toolManager: toolManager, @@ -196,6 +198,8 @@ func NewWithFactory( activeRunByID: make(map[string]uint64), activeRunTokenIDs: make(map[uint64]string), } + service.hookExecutor = runtimehooks.NewExecutor(runtimehooks.NewRegistry(), newHookRuntimeEventEmitter(service), runtimehooks.DefaultHookTimeout) + return service } // SetMemoExtractor 设置可选记忆提取钩子,由 Run 在结束时异步触发。 @@ -349,3 +353,8 @@ func isRuntimeSessionAlreadyExistsError(err error) bool { func (s *Service) SetBudgetResolver(resolver BudgetResolver) { s.budgetResolver = resolver } + +// SetHookExecutor 设置 runtime 生命周期 hook 执行器;传入 nil 可禁用 hook 执行。 +func (s *Service) SetHookExecutor(executor HookExecutor) { + s.hookExecutor = executor +} diff --git a/internal/runtime/toolexec.go b/internal/runtime/toolexec.go index 757d4af7..5778762f 100644 --- a/internal/runtime/toolexec.go +++ b/internal/runtime/toolexec.go @@ -7,6 +7,7 @@ import ( "sync" providertypes "neo-code/internal/provider/types" + runtimehooks "neo-code/internal/runtime/hooks" "neo-code/internal/tools" ) @@ -110,6 +111,37 @@ func (s *Service) executeOneToolCall( toolLock.Lock() defer toolLock.Unlock() + beforeToolHookOutput := s.runHookPoint(ctx, state, runtimehooks.HookPointBeforeToolCall, runtimehooks.HookContext{ + Metadata: map[string]any{ + "tool_call_id": strings.TrimSpace(call.ID), + "tool_name": strings.TrimSpace(call.Name), + "tool_arguments": strings.TrimSpace(call.Arguments), + "workdir": strings.TrimSpace(snapshot.Workdir), + }, + }) + if beforeToolHookOutput.Blocked { + reason := findHookBlockMessage(beforeToolHookOutput) + result := tools.NewErrorResult(call.Name, hookErrorClassBlocked, reason, map[string]any{ + "hook_id": beforeToolHookOutput.BlockedBy, + "point": string(runtimehooks.HookPointBeforeToolCall), + }) + result.ToolCallID = call.ID + result.ErrorClass = hookErrorClassBlocked + s.emitRunScoped(ctx, EventHookBlocked, state, HookBlockedPayload{ + HookID: strings.TrimSpace(beforeToolHookOutput.BlockedBy), + Point: string(runtimehooks.HookPointBeforeToolCall), + ToolCallID: strings.TrimSpace(call.ID), + ToolName: strings.TrimSpace(call.Name), + Reason: reason, + Enforced: true, + }) + if err := s.appendToolMessageAndSave(ctx, state, call, result); err != nil { + return result, false, err + } + s.emitRunScoped(ctx, EventToolResult, state, result) + return result, false, nil + } + s.emitRunScoped(ctx, EventToolStart, state, call) result, execErr := s.executeToolCallWithPermission(ctx, permissionExecutionInput{ @@ -125,11 +157,13 @@ func (s *Service) executeOneToolCall( }) if errors.Is(execErr, context.Canceled) { + s.emitAfterToolResultHook(ctx, state, call, result, execErr) return result, false, execErr } if execErr != nil && strings.TrimSpace(result.Content) == "" { result.Content = execErr.Error() } + s.emitAfterToolResultHook(ctx, state, call, result, execErr) if err := s.appendToolMessageAndSave(ctx, state, call, result); err != nil { if execErr != nil && errors.Is(err, context.Canceled) { @@ -246,3 +280,35 @@ func hasSuccessfulWorkspaceWriteFact(result tools.ToolResult, execErr error) boo } return result.Facts.WorkspaceWrite } + +func summarizeHookResultContent(content string) string { + trimmed := strings.TrimSpace(content) + if len(trimmed) <= 256 { + return trimmed + } + return trimmed[:256] +} + +// emitAfterToolResultHook 在工具结果确定后触发 after_tool_result 挂点,仅提供只读摘要元信息。 +func (s *Service) emitAfterToolResultHook( + ctx context.Context, + state *runState, + call providertypes.ToolCall, + result tools.ToolResult, + execErr error, +) { + afterToolHookMetadata := map[string]any{ + "tool_call_id": strings.TrimSpace(call.ID), + "tool_name": strings.TrimSpace(call.Name), + "is_error": result.IsError, + "error_class": strings.TrimSpace(result.ErrorClass), + "result_content_preview": summarizeHookResultContent(result.Content), + "result_metadata_present": len(result.Metadata) > 0, + } + if execErr != nil { + afterToolHookMetadata["execution_error"] = strings.TrimSpace(execErr.Error()) + } + _ = s.runHookPoint(ctx, state, runtimehooks.HookPointAfterToolResult, runtimehooks.HookContext{ + Metadata: afterToolHookMetadata, + }) +} diff --git a/internal/runtime/turn_control.go b/internal/runtime/turn_control.go index bca9b757..09aff177 100644 --- a/internal/runtime/turn_control.go +++ b/internal/runtime/turn_control.go @@ -241,6 +241,9 @@ func normalizeToolResultContent(content string) string { // classifyToolError 为错误结果生成轻量分类,避免直接依赖完整错误文案。 func classifyToolError(result tools.ToolResult) string { + if explicit := strings.TrimSpace(result.ErrorClass); explicit != "" { + return explicit + } trimmed := strings.ToLower(strings.TrimSpace(result.Content)) switch { case strings.Contains(trimmed, "timeout"): diff --git a/internal/runtime/turn_control_test.go b/internal/runtime/turn_control_test.go index 96e85732..ba4b44db 100644 --- a/internal/runtime/turn_control_test.go +++ b/internal/runtime/turn_control_test.go @@ -137,3 +137,16 @@ func TestHasSuccessfulVerificationResultRequiresStructuredFacts(t *testing.T) { t.Fatalf("expected incomplete verification facts to be ignored") } } + +func TestClassifyToolErrorPrefersExplicitErrorClass(t *testing.T) { + t.Parallel() + + got := classifyToolError(tools.ToolResult{ + IsError: true, + ErrorClass: " hook_blocked ", + Content: "permission denied", + }) + if got != "hook_blocked" { + t.Fatalf("classifyToolError() = %q, want hook_blocked", got) + } +} diff --git a/internal/tools/types.go b/internal/tools/types.go index 68b30038..786a0a26 100644 --- a/internal/tools/types.go +++ b/internal/tools/types.go @@ -94,6 +94,8 @@ type ToolResult struct { Name string Content string IsError bool + // ErrorClass 表示机器可读的错误分类(例如 hook_blocked/permission_denied)。 + ErrorClass string Metadata map[string]any Facts ToolExecutionFacts } diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index 76e43ba4..4fb5283c 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -1659,6 +1659,87 @@ var runtimeEventHandlerRegistry = map[tuiservices.EventType]func(*App, tuiservic tuiservices.EventSkillActivated: runtimeEventSkillActivatedHandler, tuiservices.EventSkillDeactivated: runtimeEventSkillDeactivatedHandler, tuiservices.EventSkillMissing: runtimeEventSkillMissingHandler, + tuiservices.EventHookStarted: runtimeEventHookStartedHandler, + tuiservices.EventHookFinished: runtimeEventHookFinishedHandler, + tuiservices.EventHookFailed: runtimeEventHookFailedHandler, + tuiservices.EventHookBlocked: runtimeEventHookBlockedHandler, +} + +func runtimeEventHookStartedHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.HookEventPayload) + if !ok { + return false + } + hookID := strings.TrimSpace(payload.HookID) + if hookID == "" { + hookID = "unknown_hook" + } + point := strings.TrimSpace(payload.Point) + if point == "" { + point = "unknown_point" + } + a.appendActivity("hook", "Hook started", hookID+" @ "+point, false) + return false +} + +func runtimeEventHookFinishedHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.HookEventPayload) + if !ok { + return false + } + hookID := strings.TrimSpace(payload.HookID) + if hookID == "" { + hookID = "unknown_hook" + } + status := strings.TrimSpace(payload.Status) + if status == "" { + status = "pass" + } + detail := fmt.Sprintf("%s (%dms)", status, payload.DurationMS) + a.appendActivity("hook", "Hook finished: "+hookID, detail, false) + return false +} + +func runtimeEventHookFailedHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.HookEventPayload) + if !ok { + return false + } + hookID := strings.TrimSpace(payload.HookID) + if hookID == "" { + hookID = "unknown_hook" + } + detail := strings.TrimSpace(payload.Error) + if detail == "" { + detail = "hook execution failed" + } + a.appendActivity("hook", "Hook failed: "+hookID, detail, true) + return false +} + +func runtimeEventHookBlockedHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.HookBlockedPayload) + if !ok { + return false + } + hookID := strings.TrimSpace(payload.HookID) + if hookID == "" { + hookID = "unknown_hook" + } + point := strings.TrimSpace(payload.Point) + if point == "" { + point = "unknown_point" + } + reason := strings.TrimSpace(payload.Reason) + if reason == "" { + reason = "hook returned block" + } + title := "Hook blocked: " + hookID + " @ " + point + if !payload.Enforced { + title = "Hook block observed: " + hookID + " @ " + point + } + a.appendActivity("hook", title, reason, payload.Enforced) + return false } func runtimeEventPhaseChangedHandler(a *App, event tuiservices.RuntimeEvent) bool { diff --git a/internal/tui/core/app/update_runtime_events_test.go b/internal/tui/core/app/update_runtime_events_test.go index 55bd6aef..745e7d22 100644 --- a/internal/tui/core/app/update_runtime_events_test.go +++ b/internal/tui/core/app/update_runtime_events_test.go @@ -189,6 +189,75 @@ func TestRuntimeEventHandlerRegistryContainsRenamedEvents(t *testing.T) { if _, ok := runtimeEventHandlerRegistry[agentruntime.EventAcceptanceDecided]; !ok { t.Fatalf("expected acceptance_decided handler to be registered") } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventHookStarted]; !ok { + t.Fatalf("expected hook_started handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventHookFinished]; !ok { + t.Fatalf("expected hook_finished handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventHookFailed]; !ok { + t.Fatalf("expected hook_failed handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventHookBlocked]; !ok { + t.Fatalf("expected hook_blocked handler to be registered") + } +} + +func TestRuntimeHookEventHandlers(t *testing.T) { + t.Parallel() + + app, _ := newTestApp(t) + if handled := runtimeEventHookStartedHandler(&app, agentruntime.RuntimeEvent{Payload: 1}); handled { + t.Fatalf("expected invalid hook_started payload to return false") + } + runtimeEventHookStartedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.HookEventPayload{HookID: " ", Point: ""}, + }) + last := app.activities[len(app.activities)-1] + if last.Title != "Hook started" || !strings.Contains(last.Detail, "unknown_hook @ unknown_point") { + t.Fatalf("unexpected hook started activity: %+v", last) + } + + if handled := runtimeEventHookFinishedHandler(&app, agentruntime.RuntimeEvent{Payload: "bad"}); handled { + t.Fatalf("expected invalid hook_finished payload to return false") + } + runtimeEventHookFinishedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.HookEventPayload{HookID: "hook-1", DurationMS: 7}, + }) + last = app.activities[len(app.activities)-1] + if last.Title != "Hook finished: hook-1" || !strings.Contains(last.Detail, "pass (7ms)") { + t.Fatalf("unexpected hook finished activity: %+v", last) + } + + if handled := runtimeEventHookFailedHandler(&app, agentruntime.RuntimeEvent{Payload: true}); handled { + t.Fatalf("expected invalid hook_failed payload to return false") + } + runtimeEventHookFailedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.HookEventPayload{HookID: "", Error: ""}, + }) + last = app.activities[len(app.activities)-1] + if !last.IsError || last.Title != "Hook failed: unknown_hook" || last.Detail != "hook execution failed" { + t.Fatalf("unexpected hook failed activity: %+v", last) + } + + if handled := runtimeEventHookBlockedHandler(&app, agentruntime.RuntimeEvent{Payload: map[string]any{}}); handled { + t.Fatalf("expected invalid hook_blocked payload to return false") + } + runtimeEventHookBlockedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.HookBlockedPayload{HookID: "", Point: "", Reason: "", Enforced: false}, + }) + last = app.activities[len(app.activities)-1] + if last.IsError || last.Title != "Hook block observed: unknown_hook @ unknown_point" || last.Detail != "hook returned block" { + t.Fatalf("unexpected hook blocked observed activity: %+v", last) + } + + runtimeEventHookBlockedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.HookBlockedPayload{HookID: "hook-2", Point: "before_tool_call", Reason: "deny", Enforced: true}, + }) + last = app.activities[len(app.activities)-1] + if !last.IsError || last.Title != "Hook blocked: hook-2 @ before_tool_call" || last.Detail != "deny" { + t.Fatalf("unexpected hook blocked enforced activity: %+v", last) + } } func TestShouldHandleRuntimeEventFiltersBySessionAndRun(t *testing.T) { diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go index c97dbfd9..9146a473 100644 --- a/internal/tui/services/gateway_stream_client.go +++ b/internal/tui/services/gateway_stream_client.go @@ -233,6 +233,10 @@ func restoreRuntimePayload(eventType EventType, payload any) (any, error) { return decodeRuntimePayload[AssetSavedPayload](payload) case EventAssetSaveFailed: return decodeRuntimePayload[AssetSaveFailedPayload](payload) + case EventHookStarted, EventHookFinished, EventHookFailed: + return decodeRuntimePayload[HookEventPayload](payload) + case EventHookBlocked: + return decodeRuntimePayload[HookBlockedPayload](payload) case EventTodoUpdated, EventTodoConflict: return decodeRuntimePayload[TodoEventPayload](payload) case EventType(RuntimeEventRunContext): diff --git a/internal/tui/services/gateway_stream_client_test.go b/internal/tui/services/gateway_stream_client_test.go index bb37697f..19f20e00 100644 --- a/internal/tui/services/gateway_stream_client_test.go +++ b/internal/tui/services/gateway_stream_client_test.go @@ -63,6 +63,7 @@ func TestDecodeRuntimeEventFromGatewayNotificationRestoresToolResultPayload(t *t "Name": "bash", "Content": "ok", "IsError": false, + "ErrorClass": "hook_blocked", }, }, }) @@ -78,6 +79,78 @@ func TestDecodeRuntimeEventFromGatewayNotificationRestoresToolResultPayload(t *t if toolResult.ToolCallID != "call-1" || toolResult.Name != "bash" || toolResult.Content != "ok" || toolResult.IsError { t.Fatalf("unexpected tool result payload: %#v", toolResult) } + if toolResult.ErrorClass != "hook_blocked" { + t.Fatalf("toolResult.ErrorClass = %q, want %q", toolResult.ErrorClass, "hook_blocked") + } +} + +func TestDecodeRuntimeEventFromGatewayNotificationRestoresHookBlockedPayload(t *testing.T) { + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + SessionID: "session-hook", + RunID: "run-hook", + Payload: map[string]any{ + "runtime_event_type": string(EventHookBlocked), + "payload_version": runtimeEventPayloadVersion, + "payload": map[string]any{ + "hook_id": "block-before-tool", + "point": "before_tool_call", + "tool_call_id": "call-2", + "tool_name": "bash", + "reason": "blocked by policy", + "enforced": true, + }, + }, + }) + + event, err := decodeRuntimeEventFromGatewayNotification(notification) + if err != nil { + t.Fatalf("decodeRuntimeEventFromGatewayNotification() error = %v", err) + } + payload, ok := event.Payload.(HookBlockedPayload) + if !ok { + t.Fatalf("event.Payload type = %T, want HookBlockedPayload", event.Payload) + } + if payload.HookID != "block-before-tool" || payload.Point != "before_tool_call" || payload.ToolName != "bash" || !payload.Enforced { + t.Fatalf("unexpected hook blocked payload: %#v", payload) + } +} + +func TestDecodeRuntimeEventFromGatewayNotificationRestoresHookLifecyclePayload(t *testing.T) { + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + SessionID: "session-hook", + RunID: "run-hook", + Payload: map[string]any{ + "runtime_event_type": string(EventHookStarted), + "payload_version": runtimeEventPayloadVersion, + "payload": map[string]any{ + "hook_id": "observe-after-tool", + "point": "after_tool_result", + "status": "pass", + "duration_ms": 9, + "scope": "internal", + "kind": "function", + "mode": "sync", + "started_at": "2026-04-20T10:30:00Z", + "error": "", + }, + }, + }) + + event, err := decodeRuntimeEventFromGatewayNotification(notification) + if err != nil { + t.Fatalf("decodeRuntimeEventFromGatewayNotification() error = %v", err) + } + payload, ok := event.Payload.(HookEventPayload) + if !ok { + t.Fatalf("event.Payload type = %T, want HookEventPayload", event.Payload) + } + if payload.HookID != "observe-after-tool" || payload.Point != "after_tool_result" || payload.Status != "pass" || payload.DurationMS != 9 { + t.Fatalf("unexpected hook lifecycle payload: %#v", payload) + } } func TestDecodeRuntimeEventFromGatewayNotificationSupportsNestedEnvelope(t *testing.T) { diff --git a/internal/tui/services/runtime_contract.go b/internal/tui/services/runtime_contract.go index 2898d3dc..67d129b1 100644 --- a/internal/tui/services/runtime_contract.go +++ b/internal/tui/services/runtime_contract.go @@ -302,6 +302,29 @@ type AssetSaveFailedPayload struct { Message string `json:"message"` } +// HookEventPayload 描述 hook 生命周期事件。 +type HookEventPayload struct { + HookID string `json:"hook_id"` + Point string `json:"point"` + Scope string `json:"scope"` + Kind string `json:"kind"` + Mode string `json:"mode"` + Status string `json:"status,omitempty"` + StartedAt time.Time `json:"started_at,omitempty"` + DurationMS int64 `json:"duration_ms,omitempty"` + Error string `json:"error,omitempty"` +} + +// HookBlockedPayload 描述 hook 阻断事件。 +type HookBlockedPayload struct { + HookID string `json:"hook_id"` + Point string `json:"point"` + ToolCallID string `json:"tool_call_id,omitempty"` + ToolName string `json:"tool_name,omitempty"` + Reason string `json:"reason,omitempty"` + Enforced bool `json:"enforced"` +} + const ( EventUserMessage EventType = "user_message" EventAgentChunk EventType = "agent_chunk" @@ -336,4 +359,8 @@ const ( EventInputNormalized EventType = "input_normalized" EventAssetSaved EventType = "asset_saved" EventAssetSaveFailed EventType = "asset_save_failed" + EventHookStarted EventType = "hook_started" + EventHookFinished EventType = "hook_finished" + EventHookFailed EventType = "hook_failed" + EventHookBlocked EventType = "hook_blocked" )