From 8f4bef9010bfb7518be1d137ab38f5759eb71906 Mon Sep 17 00:00:00 2001 From: pionxe Date: Mon, 13 Apr 2026 21:26:39 +0800 Subject: [PATCH 1/9] =?UTF-8?q?feat(gateway):=20=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E7=8B=AC=E7=AB=8B=E8=BF=9B=E7=A8=8B=E9=AA=A8=E6=9E=B6=E4=B8=8E?= =?UTF-8?q?=E6=9C=AC=E5=9C=B0IPC=E6=8E=A2=E6=B4=BB=E5=A5=91=E7=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 6 + cmd/neocode-gateway/main.go | 87 +++++++ go.mod | 1 + go.sum | 2 + internal/gateway/server.go | 241 ++++++++++++++++++ internal/gateway/server_test.go | 110 ++++++++ internal/gateway/transport/address_unix.go | 20 ++ internal/gateway/transport/address_windows.go | 10 + internal/gateway/transport/listen.go | 28 ++ internal/gateway/transport/listen_unix.go | 54 ++++ .../gateway/transport/listen_unix_test.go | 58 +++++ internal/gateway/transport/listen_windows.go | 19 ++ .../gateway/transport/listen_windows_test.go | 51 ++++ internal/gateway/types.go | 2 + internal/gateway/validate.go | 5 +- internal/gateway/validate_test.go | 9 + 16 files changed, 702 insertions(+), 1 deletion(-) create mode 100644 cmd/neocode-gateway/main.go create mode 100644 internal/gateway/server.go create mode 100644 internal/gateway/server_test.go create mode 100644 internal/gateway/transport/address_unix.go create mode 100644 internal/gateway/transport/address_windows.go create mode 100644 internal/gateway/transport/listen.go create mode 100644 internal/gateway/transport/listen_unix.go create mode 100644 internal/gateway/transport/listen_unix_test.go create mode 100644 internal/gateway/transport/listen_windows.go create mode 100644 internal/gateway/transport/listen_windows_test.go diff --git a/README.md b/README.md index 7b58fc96..43c5a21b 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,12 @@ cd neo-code go run ./cmd/neocode ``` +Gateway 独立进程(Step 1 骨架): + +```bash +go run ./cmd/neocode-gateway +``` + 设置 API Key 示例(按你使用的 provider 选择): ```bash diff --git a/cmd/neocode-gateway/main.go b/cmd/neocode-gateway/main.go new file mode 100644 index 00000000..cb231449 --- /dev/null +++ b/cmd/neocode-gateway/main.go @@ -0,0 +1,87 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "os" + "os/signal" + "strings" + "syscall" + + "neo-code/internal/gateway" +) + +const ( + defaultLogLevel = "info" +) + +var errHelpRequested = errors.New("help requested") + +// main 负责启动 Gateway 独立进程,并在收到系统信号时优雅退出。 +func main() { + if err := run(); err != nil { + fmt.Fprintf(os.Stderr, "neocode-gateway: %v\n", err) + os.Exit(1) + } +} + +// run 解析启动参数并驱动网关服务生命周期。 +func run() error { + listenAddress, logLevel, err := parseFlags() + if err != nil { + if errors.Is(err, errHelpRequested) { + return nil + } + return err + } + + logger := log.New(os.Stderr, "neocode-gateway: ", log.LstdFlags) + logger.Printf("starting gateway (log-level=%s)", logLevel) + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + server, err := gateway.NewServer(gateway.ServerOptions{ + ListenAddress: listenAddress, + Logger: logger, + }) + if err != nil { + return err + } + defer func() { + _ = server.Close(context.Background()) + }() + + logger.Printf("gateway listen address: %s", server.ListenAddress()) + return server.Serve(ctx, nil) +} + +// parseFlags 解析命令行参数并执行基础校验。 +func parseFlags() (listenAddress string, logLevel string, err error) { + fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + fs.SetOutput(os.Stdout) + + var listen string + var level string + fs.StringVar(&listen, "listen", "", "gateway listen address (optional override)") + fs.StringVar(&level, "log-level", defaultLogLevel, "gateway log level: debug|info|warn|error") + + if parseErr := fs.Parse(os.Args[1:]); parseErr != nil { + if errors.Is(parseErr, flag.ErrHelp) { + return "", "", errHelpRequested + } + return "", "", parseErr + } + + normalizedLevel := strings.ToLower(strings.TrimSpace(level)) + switch normalizedLevel { + case "debug", "info", "warn", "error": + default: + return "", "", fmt.Errorf("invalid --log-level %q: must be debug|info|warn|error", level) + } + + return strings.TrimSpace(listen), normalizedLevel, nil +} diff --git a/go.mod b/go.mod index 7d1726d8..b5b9bda2 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( ) require ( + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/alecthomas/chroma/v2 v2.20.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect diff --git a/go.sum b/go.sum index 56abd244..59439be3 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= github.com/alecthomas/chroma/v2 v2.20.0 h1:sfIHpxPyR07/Oylvmcai3X/exDlE8+FA820NTz+9sGw= diff --git a/internal/gateway/server.go b/internal/gateway/server.go new file mode 100644 index 00000000..0542ed46 --- /dev/null +++ b/internal/gateway/server.go @@ -0,0 +1,241 @@ +package gateway + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "strings" + "sync" + + "neo-code/internal/gateway/transport" +) + +// ServerOptions 描述网关服务启动所需的可选配置。 +type ServerOptions struct { + ListenAddress string + Logger *log.Logger +} + +// Server 提供基于本地 IPC 的网关服务骨架实现。 +type Server struct { + listenAddress string + logger *log.Logger + + mu sync.Mutex + listener net.Listener + conns map[net.Conn]struct{} + wg sync.WaitGroup +} + +// NewServer 创建网关服务实例,并解析默认监听地址。 +func NewServer(options ServerOptions) (*Server, error) { + listenAddress := strings.TrimSpace(options.ListenAddress) + if listenAddress == "" { + resolved, err := transport.DefaultListenAddress() + if err != nil { + return nil, err + } + listenAddress = resolved + } + + logger := options.Logger + if logger == nil { + logger = log.New(os.Stderr, "gateway: ", log.LstdFlags) + } + + return &Server{ + listenAddress: listenAddress, + logger: logger, + conns: make(map[net.Conn]struct{}), + }, nil +} + +// ListenAddress 返回当前服务绑定的监听地址。 +func (s *Server) ListenAddress() string { + return s.listenAddress +} + +// Serve 启动 IPC 监听并处理客户端请求。 +func (s *Server) Serve(ctx context.Context, runtimePort RuntimePort) error { + listener, err := transport.Listen(s.listenAddress) + if err != nil { + return err + } + + s.mu.Lock() + if s.listener != nil { + s.mu.Unlock() + _ = listener.Close() + return fmt.Errorf("gateway: server is already serving") + } + s.listener = listener + s.mu.Unlock() + + s.logger.Printf("listening on %s", s.listenAddress) + + go func() { + <-ctx.Done() + _ = s.Close(context.Background()) + }() + + for { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + if errors.Is(acceptErr, net.ErrClosed) || ctx.Err() != nil || s.isClosed() { + return nil + } + return fmt.Errorf("gateway: accept connection: %w", acceptErr) + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.trackConnection(conn) + defer s.untrackConnection(conn) + s.handleConnection(ctx, conn, runtimePort) + }() + } +} + +// Close 关闭监听器并等待所有连接处理协程退出。 +func (s *Server) Close(ctx context.Context) error { + s.mu.Lock() + listener := s.listener + s.listener = nil + s.mu.Unlock() + + var closeErr error + if listener != nil { + closeErr = listener.Close() + } + + for conn := range s.snapshotConnections() { + closeErr = errors.Join(closeErr, conn.Close()) + } + + waitDone := make(chan struct{}) + go func() { + s.wg.Wait() + close(waitDone) + }() + + select { + case <-ctx.Done(): + closeErr = errors.Join(closeErr, ctx.Err()) + case <-waitDone: + } + + return closeErr +} + +// isClosed 判断监听器是否已经关闭。 +func (s *Server) isClosed() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.listener == nil +} + +// snapshotConnections 返回当前连接集合的拷贝,用于关闭流程安全遍历。 +func (s *Server) snapshotConnections() map[net.Conn]struct{} { + s.mu.Lock() + defer s.mu.Unlock() + + copied := make(map[net.Conn]struct{}, len(s.conns)) + for conn := range s.conns { + copied[conn] = struct{}{} + } + return copied +} + +// trackConnection 记录活跃连接,便于关闭时统一清理。 +func (s *Server) trackConnection(conn net.Conn) { + s.mu.Lock() + defer s.mu.Unlock() + s.conns[conn] = struct{}{} +} + +// untrackConnection 移除已结束连接,避免连接集合持续增长。 +func (s *Server) untrackConnection(conn net.Conn) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.conns, conn) +} + +// handleConnection 在单连接上循环处理消息帧并返回响应帧。 +func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePort RuntimePort) { + defer func() { + _ = conn.Close() + }() + + decoder := json.NewDecoder(conn) + encoder := json.NewEncoder(conn) + + for { + select { + case <-ctx.Done(): + return + default: + } + + var frame MessageFrame + if err := decoder.Decode(&frame); err != nil { + if errors.Is(err, io.EOF) { + return + } + + s.logger.Printf("decode frame failed: %v", err) + _ = encoder.Encode(errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame"))) + return + } + + response := s.dispatchFrame(ctx, frame, runtimePort) + if err := encoder.Encode(response); err != nil { + s.logger.Printf("write frame failed: %v", err) + return + } + } +} + +// dispatchFrame 根据请求动作生成响应帧。 +func (s *Server) dispatchFrame(_ context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + _ = runtimePort + + if validationErr := ValidateFrame(frame); validationErr != nil { + return errorFrame(frame, validationErr) + } + + if frame.Type != FrameTypeRequest { + return errorFrame(frame, NewFrameError(ErrorCodeInvalidFrame, "only request frames are supported")) + } + + switch frame.Action { + case FrameActionPing: + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionPing, + RequestID: frame.RequestID, + Payload: map[string]string{ + "message": "pong", + }, + } + default: + return errorFrame(frame, NewFrameError(ErrorCodeUnsupportedAction, "action is not implemented in gateway step 1")) + } +} + +// errorFrame 构建统一错误响应帧。 +func errorFrame(frame MessageFrame, frameErr *FrameError) MessageFrame { + return MessageFrame{ + Type: FrameTypeError, + Action: frame.Action, + RequestID: frame.RequestID, + Error: frameErr, + } +} + +var _ Gateway = (*Server)(nil) diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go new file mode 100644 index 00000000..a7e6d362 --- /dev/null +++ b/internal/gateway/server_test.go @@ -0,0 +1,110 @@ +package gateway + +import ( + "context" + "encoding/json" + "net" + "testing" + "time" +) + +func TestServerHandleConnectionPing(t *testing.T) { + t.Parallel() + + server := &Server{} + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + encoder := json.NewEncoder(clientConn) + decoder := json.NewDecoder(clientConn) + + if err := encoder.Encode(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionPing, + RequestID: "req-1", + }); err != nil { + t.Fatalf("encode request: %v", err) + } + + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode response: %v", err) + } + + if response.Type != FrameTypeAck { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeAck) + } + if response.Action != FrameActionPing { + t.Fatalf("response action = %q, want %q", response.Action, FrameActionPing) + } + if response.RequestID != "req-1" { + t.Fatalf("response request_id = %q, want %q", response.RequestID, "req-1") + } + + payloadMap, ok := response.Payload.(map[string]any) + if !ok { + t.Fatalf("response payload type = %T, want map[string]any", response.Payload) + } + if got, _ := payloadMap["message"].(string); got != "pong" { + t.Fatalf("response payload message = %q, want %q", got, "pong") + } + + _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} + +func TestServerHandleConnectionUnsupportedAction(t *testing.T) { + t.Parallel() + + server := &Server{} + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + encoder := json.NewEncoder(clientConn) + decoder := json.NewDecoder(clientConn) + + if err := encoder.Encode(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRun, + RequestID: "req-2", + InputText: "hello", + }); err != nil { + t.Fatalf("encode request: %v", err) + } + + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode response: %v", err) + } + + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil { + t.Fatal("response error is nil") + } + if response.Error.Code != ErrorCodeUnsupportedAction.String() { + t.Fatalf("error code = %q, want %q", response.Error.Code, ErrorCodeUnsupportedAction.String()) + } + + _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} diff --git a/internal/gateway/transport/address_unix.go b/internal/gateway/transport/address_unix.go new file mode 100644 index 00000000..ed6b0c89 --- /dev/null +++ b/internal/gateway/transport/address_unix.go @@ -0,0 +1,20 @@ +//go:build !windows + +package transport + +import ( + "fmt" + "os" + "path/filepath" +) + +const defaultUnixSocketRelativePath = ".neocode/run/gateway.sock" + +// DefaultListenAddress 返回 Unix 系统默认监听地址。 +func DefaultListenAddress() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("gateway: resolve user home dir: %w", err) + } + return filepath.Join(homeDir, defaultUnixSocketRelativePath), nil +} diff --git a/internal/gateway/transport/address_windows.go b/internal/gateway/transport/address_windows.go new file mode 100644 index 00000000..f6b00306 --- /dev/null +++ b/internal/gateway/transport/address_windows.go @@ -0,0 +1,10 @@ +//go:build windows + +package transport + +const defaultWindowsNamedPipePath = `\\.\pipe\neocode-gateway` + +// DefaultListenAddress 返回 Windows 系统默认监听地址。 +func DefaultListenAddress() (string, error) { + return defaultWindowsNamedPipePath, nil +} diff --git a/internal/gateway/transport/listen.go b/internal/gateway/transport/listen.go new file mode 100644 index 00000000..91514564 --- /dev/null +++ b/internal/gateway/transport/listen.go @@ -0,0 +1,28 @@ +package transport + +import ( + "errors" + "net" +) + +// cleanupListener 在关闭底层监听器后执行额外清理逻辑。 +type cleanupListener struct { + net.Listener + cleanup func() error +} + +// newCleanupListener 包装监听器并注入清理钩子。 +func newCleanupListener(listener net.Listener, cleanup func() error) net.Listener { + if cleanup == nil { + return listener + } + return &cleanupListener{ + Listener: listener, + cleanup: cleanup, + } +} + +// Close 关闭监听器并执行额外清理。 +func (l *cleanupListener) Close() error { + return errors.Join(l.Listener.Close(), l.cleanup()) +} diff --git a/internal/gateway/transport/listen_unix.go b/internal/gateway/transport/listen_unix.go new file mode 100644 index 00000000..3bc9c82b --- /dev/null +++ b/internal/gateway/transport/listen_unix.go @@ -0,0 +1,54 @@ +//go:build !windows + +package transport + +import ( + "fmt" + "net" + "os" + "path/filepath" +) + +// Listen 在 Unix 系统上启动 UDS 监听并在关闭时清理 socket 文件。 +func Listen(address string) (net.Listener, error) { + if err := os.MkdirAll(filepath.Dir(address), 0o755); err != nil { + return nil, fmt.Errorf("gateway: create socket dir: %w", err) + } + + if err := removeStaleUnixSocket(address); err != nil { + return nil, err + } + + listener, err := net.Listen("unix", address) + if err != nil { + return nil, fmt.Errorf("gateway: listen unix socket: %w", err) + } + + return newCleanupListener(listener, func() error { + if err := os.Remove(address); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("gateway: remove unix socket: %w", err) + } + return nil + }), nil +} + +// removeStaleUnixSocket 清理历史残留的 socket 文件,避免监听失败。 +func removeStaleUnixSocket(address string) error { + info, err := os.Lstat(address) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("gateway: stat unix socket path: %w", err) + } + + if info.Mode()&os.ModeSocket == 0 { + return fmt.Errorf("gateway: unix socket path exists and is not socket: %s", address) + } + + if err := os.Remove(address); err != nil { + return fmt.Errorf("gateway: remove stale unix socket: %w", err) + } + + return nil +} diff --git a/internal/gateway/transport/listen_unix_test.go b/internal/gateway/transport/listen_unix_test.go new file mode 100644 index 00000000..f0844d46 --- /dev/null +++ b/internal/gateway/transport/listen_unix_test.go @@ -0,0 +1,58 @@ +//go:build !windows + +package transport + +import ( + "net" + "os" + "path/filepath" + "testing" + "time" +) + +func TestListenUnixAcceptsConnectionAndCleansSocket(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "gateway.sock") + listener, err := Listen(socketPath) + if err != nil { + t.Fatalf("listen unix socket: %v", err) + } + defer func() { + _ = listener.Close() + }() + + acceptDone := make(chan error, 1) + go func() { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + acceptDone <- acceptErr + return + } + _ = conn.Close() + acceptDone <- nil + }() + + conn, err := net.Dial("unix", socketPath) + if err != nil { + t.Fatalf("dial unix socket: %v", err) + } + _ = conn.Close() + + select { + case acceptErr := <-acceptDone: + if acceptErr != nil { + t.Fatalf("accept connection: %v", acceptErr) + } + case <-time.After(2 * time.Second): + t.Fatal("accept timed out") + } + + if err := listener.Close(); err != nil { + t.Fatalf("close listener: %v", err) + } + + if _, err := os.Stat(socketPath); !os.IsNotExist(err) { + t.Fatalf("socket file should be removed on close, stat err: %v", err) + } +} diff --git a/internal/gateway/transport/listen_windows.go b/internal/gateway/transport/listen_windows.go new file mode 100644 index 00000000..eebb6ccc --- /dev/null +++ b/internal/gateway/transport/listen_windows.go @@ -0,0 +1,19 @@ +//go:build windows + +package transport + +import ( + "fmt" + "net" + + "github.com/Microsoft/go-winio" +) + +// Listen 在 Windows 系统上启动 Named Pipe 监听。 +func Listen(address string) (net.Listener, error) { + listener, err := winio.ListenPipe(address, nil) + if err != nil { + return nil, fmt.Errorf("gateway: listen named pipe: %w", err) + } + return newCleanupListener(listener, nil), nil +} diff --git a/internal/gateway/transport/listen_windows_test.go b/internal/gateway/transport/listen_windows_test.go new file mode 100644 index 00000000..7c331c10 --- /dev/null +++ b/internal/gateway/transport/listen_windows_test.go @@ -0,0 +1,51 @@ +//go:build windows + +package transport + +import ( + "fmt" + "testing" + "time" + + "github.com/Microsoft/go-winio" +) + +func TestListenNamedPipeAcceptsConnection(t *testing.T) { + t.Parallel() + + pipePath := fmt.Sprintf(`\\.\pipe\neocode-gateway-test-%d`, time.Now().UnixNano()) + listener, err := Listen(pipePath) + if err != nil { + t.Fatalf("listen named pipe: %v", err) + } + defer func() { + _ = listener.Close() + }() + + acceptDone := make(chan error, 1) + go func() { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + acceptDone <- acceptErr + return + } + _ = conn.Close() + acceptDone <- nil + }() + + timeout := 2 * time.Second + conn, err := winio.DialPipe(pipePath, &timeout) + if err != nil { + t.Fatalf("dial named pipe: %v", err) + } + _ = conn.Close() + + select { + case acceptErr := <-acceptDone: + if acceptErr != nil { + t.Fatalf("accept connection: %v", acceptErr) + } + case <-time.After(3 * time.Second): + t.Fatal("accept timed out") + } +} diff --git a/internal/gateway/types.go b/internal/gateway/types.go index 34807ba8..851fffc4 100644 --- a/internal/gateway/types.go +++ b/internal/gateway/types.go @@ -18,6 +18,8 @@ const ( type FrameAction string const ( + // FrameActionPing 表示探活动作,用于验证网关可用性。 + FrameActionPing FrameAction = "ping" // FrameActionRun 表示发起一次运行。 FrameActionRun FrameAction = "run" // FrameActionCompact 表示触发一次手动压缩。 diff --git a/internal/gateway/validate.go b/internal/gateway/validate.go index d83d1a15..aae8283f 100644 --- a/internal/gateway/validate.go +++ b/internal/gateway/validate.go @@ -30,6 +30,8 @@ func validateRequestFrame(frame MessageFrame) *FrameError { } switch frame.Action { + case FrameActionPing: + return nil case FrameActionRun: return validateRunFrame(frame) case FrameActionCompact, FrameActionLoadSession: @@ -168,7 +170,8 @@ func isValidFrameType(frameType FrameType) bool { // isValidFrameAction 判断动作是否属于协议定义集合。 func isValidFrameAction(action FrameAction) bool { switch action { - case FrameActionRun, + case FrameActionPing, + FrameActionRun, FrameActionCompact, FrameActionCancel, FrameActionListSessions, diff --git a/internal/gateway/validate_test.go b/internal/gateway/validate_test.go index 48267aa0..85f2dad1 100644 --- a/internal/gateway/validate_test.go +++ b/internal/gateway/validate_test.go @@ -13,6 +13,15 @@ func TestValidateFrame_BasicRules(t *testing.T) { wantCode string wantField string }{ + { + name: "valid ping request", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionPing, + RequestID: "req-ping", + }, + wantNil: true, + }, { name: "valid run with input_text", frame: MessageFrame{ From f01183b9d76efd70f1d76f1a94980ca73d7818cf Mon Sep 17 00:00:00 2001 From: pionxe Date: Mon, 13 Apr 2026 22:02:57 +0800 Subject: [PATCH 2/9] =?UTF-8?q?fix(gateway):=20=E4=BF=AE=E5=A4=8D=20Accept?= =?UTF-8?q?=20=E4=B8=8E=20Close=20=E4=B9=8B=E9=97=B4=E7=9A=84=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E7=99=BB=E8=AE=B0=E7=AB=9E=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/gateway/server.go | 26 ++++- internal/gateway/server_race_test.go | 136 +++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 internal/gateway/server_race_test.go diff --git a/internal/gateway/server.go b/internal/gateway/server.go index 0542ed46..2f2795cf 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -19,12 +19,14 @@ import ( type ServerOptions struct { ListenAddress string Logger *log.Logger + listenFn func(address string) (net.Listener, error) } // Server 提供基于本地 IPC 的网关服务骨架实现。 type Server struct { listenAddress string logger *log.Logger + listenFn func(address string) (net.Listener, error) mu sync.Mutex listener net.Listener @@ -47,10 +49,15 @@ func NewServer(options ServerOptions) (*Server, error) { if logger == nil { logger = log.New(os.Stderr, "gateway: ", log.LstdFlags) } + listenFn := options.listenFn + if listenFn == nil { + listenFn = transport.Listen + } return &Server{ listenAddress: listenAddress, logger: logger, + listenFn: listenFn, conns: make(map[net.Conn]struct{}), }, nil } @@ -62,7 +69,7 @@ func (s *Server) ListenAddress() string { // Serve 启动 IPC 监听并处理客户端请求。 func (s *Server) Serve(ctx context.Context, runtimePort RuntimePort) error { - listener, err := transport.Listen(s.listenAddress) + listener, err := s.listenFn(s.listenAddress) if err != nil { return err } @@ -92,10 +99,14 @@ func (s *Server) Serve(ctx context.Context, runtimePort RuntimePort) error { return fmt.Errorf("gateway: accept connection: %w", acceptErr) } + if !s.registerConnection(conn) { + _ = conn.Close() + continue + } + s.wg.Add(1) go func() { defer s.wg.Done() - s.trackConnection(conn) defer s.untrackConnection(conn) s.handleConnection(ctx, conn, runtimePort) }() @@ -159,6 +170,17 @@ func (s *Server) trackConnection(conn net.Conn) { s.conns[conn] = struct{}{} } +// registerConnection 在服务可用时登记连接,若网关已关闭则拒绝登记。 +func (s *Server) registerConnection(conn net.Conn) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.listener == nil { + return false + } + s.conns[conn] = struct{}{} + return true +} + // untrackConnection 移除已结束连接,避免连接集合持续增长。 func (s *Server) untrackConnection(conn net.Conn) { s.mu.Lock() diff --git a/internal/gateway/server_race_test.go b/internal/gateway/server_race_test.go new file mode 100644 index 00000000..5798b14f --- /dev/null +++ b/internal/gateway/server_race_test.go @@ -0,0 +1,136 @@ +package gateway + +import ( + "context" + "errors" + "io" + "log" + "net" + "os" + "strings" + "sync" + "testing" + "time" +) + +func TestServeCloseDuringAcceptDoesNotLeakConnection(t *testing.T) { + t.Parallel() + + listener := newStubListener() + server, err := NewServer(ServerOptions{ + ListenAddress: "stub://gateway", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return listener, nil + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serveDone := make(chan error, 1) + go func() { + serveDone <- server.Serve(ctx, nil) + }() + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + closeDone := make(chan error, 1) + listener.onAccept = func() { + go func() { + closeDone <- server.Close(context.Background()) + }() + } + + listener.acceptCh <- serverConn + + select { + case closeErr := <-closeDone: + if closeErr != nil { + t.Fatalf("close server: %v", closeErr) + } + case <-time.After(2 * time.Second): + t.Fatal("close timed out") + } + + select { + case serveErr := <-serveDone: + if serveErr != nil { + t.Fatalf("serve returned error: %v", serveErr) + } + case <-time.After(2 * time.Second): + t.Fatal("serve did not exit") + } + + readDone := make(chan error, 1) + go func() { + var buf [1]byte + _, readErr := clientConn.Read(buf[:]) + readDone <- readErr + }() + + select { + case readErr := <-readDone: + if errors.Is(readErr, io.EOF) || errors.Is(readErr, net.ErrClosed) || errors.Is(readErr, os.ErrClosed) { + return + } + if readErr != nil && strings.Contains(readErr.Error(), "closed pipe") { + return + } + t.Fatalf("expected closed connection after server close, got %v", readErr) + case <-time.After(300 * time.Millisecond): + t.Fatal("connection was not closed by server") + } +} + +type stubListener struct { + acceptCh chan net.Conn + closeCh chan struct{} + + onAccept func() + closeOnce sync.Once +} + +func newStubListener() *stubListener { + return &stubListener{ + acceptCh: make(chan net.Conn, 1), + closeCh: make(chan struct{}), + } +} + +func (l *stubListener) Accept() (net.Conn, error) { + select { + case <-l.closeCh: + return nil, net.ErrClosed + case conn := <-l.acceptCh: + if l.onAccept != nil { + l.onAccept() + } + return conn, nil + } +} + +func (l *stubListener) Close() error { + l.closeOnce.Do(func() { + close(l.closeCh) + }) + return nil +} + +func (l *stubListener) Addr() net.Addr { + return stubAddr("stub://gateway") +} + +type stubAddr string + +func (a stubAddr) Network() string { + return "stub" +} + +func (a stubAddr) String() string { + return string(a) +} From 2435651efc1682cb472a5c221a32d775ad74ddeb Mon Sep 17 00:00:00 2001 From: pionxe Date: Mon, 13 Apr 2026 22:03:51 +0800 Subject: [PATCH 3/9] =?UTF-8?q?fix(gateway):=20=E4=BF=AE=E5=A4=8D=20Accept?= =?UTF-8?q?=20=E4=B8=8E=20Close=20=E4=B9=8B=E9=97=B4=E7=9A=84=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E7=99=BB=E8=AE=B0=E7=AB=9E=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/gateway/server.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/internal/gateway/server.go b/internal/gateway/server.go index 2f2795cf..6558da1a 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -163,13 +163,6 @@ func (s *Server) snapshotConnections() map[net.Conn]struct{} { return copied } -// trackConnection 记录活跃连接,便于关闭时统一清理。 -func (s *Server) trackConnection(conn net.Conn) { - s.mu.Lock() - defer s.mu.Unlock() - s.conns[conn] = struct{}{} -} - // registerConnection 在服务可用时登记连接,若网关已关闭则拒绝登记。 func (s *Server) registerConnection(conn net.Conn) bool { s.mu.Lock() From 79d08ddfa3b130af8c822123db846d17ab111292 Mon Sep 17 00:00:00 2001 From: pionxe Date: Mon, 13 Apr 2026 22:22:49 +0800 Subject: [PATCH 4/9] =?UTF-8?q?fix(gateway):=20=E4=B8=BA=20JSON=20?= =?UTF-8?q?=E8=A7=A3=E7=A0=81=E5=A2=9E=E5=8A=A0=E6=9C=80=E5=A4=A7=E5=B8=A7?= =?UTF-8?q?=E9=95=BF=E5=BA=A6=E9=99=90=E5=88=B6=EF=BC=8C=E9=98=B2=E6=AD=A2?= =?UTF-8?q?=E5=86=85=E5=AD=98=E6=94=BE=E5=A4=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/gateway/server.go | 86 +++++++++++++++++++++++++++++++-- internal/gateway/server_test.go | 85 +++++++++++++++++++++++++++++++- 2 files changed, 166 insertions(+), 5 deletions(-) diff --git a/internal/gateway/server.go b/internal/gateway/server.go index 6558da1a..22c8c780 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -1,6 +1,8 @@ package gateway import ( + "bufio" + "bytes" "context" "encoding/json" "errors" @@ -15,6 +17,16 @@ import ( "neo-code/internal/gateway/transport" ) +const ( + // MaxFrameSize 定义单条 JSON 帧允许的最大字节数,避免异常输入导致内存放大。 + MaxFrameSize int64 = 1 << 20 // 1 MiB +) + +var ( + errFrameTooLarge = errors.New("frame exceeds max size") + errFrameEmpty = errors.New("empty frame") +) + // ServerOptions 描述网关服务启动所需的可选配置。 type ServerOptions struct { ListenAddress string @@ -49,6 +61,7 @@ func NewServer(options ServerOptions) (*Server, error) { if logger == nil { logger = log.New(os.Stderr, "gateway: ", log.LstdFlags) } + listenFn := options.listenFn if listenFn == nil { listenFn = transport.Listen @@ -187,7 +200,7 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor _ = conn.Close() }() - decoder := json.NewDecoder(conn) + reader := bufio.NewReader(conn) encoder := json.NewEncoder(conn) for { @@ -197,11 +210,22 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor default: } - var frame MessageFrame - if err := decoder.Decode(&frame); err != nil { + frame, err := decodeFrame(reader) + if err != nil { if errors.Is(err, io.EOF) { return } + if errors.Is(err, errFrameEmpty) { + continue + } + if errors.Is(err, errFrameTooLarge) { + s.logger.Printf("decode frame failed: %v", err) + _ = encoder.Encode(errorFrame(MessageFrame{}, NewFrameError( + ErrorCodeInvalidFrame, + fmt.Sprintf("frame exceeds max size %d bytes", MaxFrameSize), + ))) + return + } s.logger.Printf("decode frame failed: %v", err) _ = encoder.Encode(errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame"))) @@ -216,6 +240,62 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor } } +// decodeFrame 从连接读取一条 JSON 帧并执行长度与格式校验。 +func decodeFrame(reader *bufio.Reader) (MessageFrame, error) { + payload, err := readFramePayload(reader, MaxFrameSize) + if err != nil { + return MessageFrame{}, err + } + + limitedReader := &io.LimitedReader{R: bytes.NewReader(payload), N: MaxFrameSize} + decoder := json.NewDecoder(limitedReader) + + var frame MessageFrame + if err := decoder.Decode(&frame); err != nil { + return MessageFrame{}, err + } + + var trailing any + if err := decoder.Decode(&trailing); !errors.Is(err, io.EOF) { + return MessageFrame{}, fmt.Errorf("frame contains trailing json values") + } + + return frame, nil +} + +// readFramePayload 按换行边界读取单条帧,并限制单帧最大字节数。 +func readFramePayload(reader *bufio.Reader, maxSize int64) ([]byte, error) { + var payload []byte + + for { + chunk, err := reader.ReadSlice('\n') + if int64(len(payload)+len(chunk)) > maxSize { + return nil, errFrameTooLarge + } + payload = append(payload, chunk...) + + if err == nil { + break + } + if errors.Is(err, bufio.ErrBufferFull) { + continue + } + if errors.Is(err, io.EOF) { + if len(payload) == 0 { + return nil, io.EOF + } + break + } + return nil, err + } + + payload = bytes.TrimSpace(payload) + if len(payload) == 0 { + return nil, errFrameEmpty + } + return payload, nil +} + // dispatchFrame 根据请求动作生成响应帧。 func (s *Server) dispatchFrame(_ context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { _ = runtimePort diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go index a7e6d362..aa7cae26 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -3,7 +3,12 @@ package gateway import ( "context" "encoding/json" + "errors" + "fmt" + "io" + "log" "net" + "strings" "testing" "time" ) @@ -11,7 +16,7 @@ import ( func TestServerHandleConnectionPing(t *testing.T) { t.Parallel() - server := &Server{} + server := &Server{logger: log.New(io.Discard, "", 0)} serverConn, clientConn := net.Pipe() done := make(chan struct{}) @@ -65,7 +70,7 @@ func TestServerHandleConnectionPing(t *testing.T) { func TestServerHandleConnectionUnsupportedAction(t *testing.T) { t.Parallel() - server := &Server{} + server := &Server{logger: log.New(io.Discard, "", 0)} serverConn, clientConn := net.Pipe() done := make(chan struct{}) @@ -108,3 +113,79 @@ func TestServerHandleConnectionUnsupportedAction(t *testing.T) { t.Fatal("handleConnection did not exit") } } + +func TestServerHandleConnectionRejectsOversizedFrame(t *testing.T) { + t.Parallel() + + server := &Server{logger: log.New(io.Discard, "", 0)} + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + decoder := json.NewDecoder(clientConn) + oversizedPayload := strings.Repeat("a", int(MaxFrameSize)+128) + requestFrame := fmt.Sprintf( + `{"type":"request","action":"ping","request_id":"req-oversize","input_text":"%s"}`+"\n", + oversizedPayload, + ) + + writeDone := make(chan error, 1) + go func() { + _, err := io.WriteString(clientConn, requestFrame) + writeDone <- err + }() + + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode oversized response: %v", err) + } + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil { + t.Fatal("response error is nil") + } + if response.Error.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("error code = %q, want %q", response.Error.Code, ErrorCodeInvalidFrame.String()) + } + if !strings.Contains(response.Error.Message, "frame exceeds max size") { + t.Fatalf("error message = %q, want contains %q", response.Error.Message, "frame exceeds max size") + } + + select { + case <-writeDone: + case <-time.After(2 * time.Second): + t.Fatal("write oversized frame timed out") + } + + readDone := make(chan error, 1) + go func() { + var buf [1]byte + _, err := clientConn.Read(buf[:]) + readDone <- err + }() + + select { + case err := <-readDone: + if errors.Is(err, io.EOF) { + break + } + if err != nil && strings.Contains(err.Error(), "closed pipe") { + break + } + t.Fatalf("expected connection close after oversized frame, got %v", err) + case <-time.After(500 * time.Millisecond): + t.Fatal("connection was not closed after oversized frame") + } + + _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} From d5a33fa76ec58c9f23a86dc9397baea0fcabebaa Mon Sep 17 00:00:00 2001 From: pionxe Date: Mon, 13 Apr 2026 22:25:14 +0800 Subject: [PATCH 5/9] =?UTF-8?q?fix(gateway):=20=E6=94=B6=E7=B4=A7=20Unix?= =?UTF-8?q?=20socket=20=E7=9B=AE=E5=BD=95=E4=B8=8E=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E6=9D=83=E9=99=90=E8=87=B3=200700/0600?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/gateway/transport/listen_unix.go | 17 ++++++++++++++++- internal/gateway/transport/listen_unix_test.go | 17 +++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/internal/gateway/transport/listen_unix.go b/internal/gateway/transport/listen_unix.go index 3bc9c82b..d5a3ad20 100644 --- a/internal/gateway/transport/listen_unix.go +++ b/internal/gateway/transport/listen_unix.go @@ -9,11 +9,22 @@ import ( "path/filepath" ) +const ( + // unixSocketDirPerm 定义 Unix socket 父目录权限(仅当前用户可访问)。 + unixSocketDirPerm os.FileMode = 0o700 + // unixSocketFilePerm 定义 Unix socket 文件权限(仅当前用户可读写)。 + unixSocketFilePerm os.FileMode = 0o600 +) + // Listen 在 Unix 系统上启动 UDS 监听并在关闭时清理 socket 文件。 func Listen(address string) (net.Listener, error) { - if err := os.MkdirAll(filepath.Dir(address), 0o755); err != nil { + socketDir := filepath.Dir(address) + if err := os.MkdirAll(socketDir, unixSocketDirPerm); err != nil { return nil, fmt.Errorf("gateway: create socket dir: %w", err) } + if err := os.Chmod(socketDir, unixSocketDirPerm); err != nil { + return nil, fmt.Errorf("gateway: set socket dir permission: %w", err) + } if err := removeStaleUnixSocket(address); err != nil { return nil, err @@ -23,6 +34,10 @@ func Listen(address string) (net.Listener, error) { if err != nil { return nil, fmt.Errorf("gateway: listen unix socket: %w", err) } + if err := os.Chmod(address, unixSocketFilePerm); err != nil { + _ = listener.Close() + return nil, fmt.Errorf("gateway: set socket file permission: %w", err) + } return newCleanupListener(listener, func() error { if err := os.Remove(address); err != nil && !os.IsNotExist(err) { diff --git a/internal/gateway/transport/listen_unix_test.go b/internal/gateway/transport/listen_unix_test.go index f0844d46..79dcdf38 100644 --- a/internal/gateway/transport/listen_unix_test.go +++ b/internal/gateway/transport/listen_unix_test.go @@ -14,6 +14,7 @@ func TestListenUnixAcceptsConnectionAndCleansSocket(t *testing.T) { t.Parallel() socketPath := filepath.Join(t.TempDir(), "gateway.sock") + socketDir := filepath.Dir(socketPath) listener, err := Listen(socketPath) if err != nil { t.Fatalf("listen unix socket: %v", err) @@ -39,6 +40,22 @@ func TestListenUnixAcceptsConnectionAndCleansSocket(t *testing.T) { } _ = conn.Close() + socketInfo, err := os.Stat(socketPath) + if err != nil { + t.Fatalf("stat socket file: %v", err) + } + if got := socketInfo.Mode() & os.ModePerm; got != unixSocketFilePerm { + t.Fatalf("socket file perm = %#o, want %#o", got, unixSocketFilePerm) + } + + dirInfo, err := os.Stat(socketDir) + if err != nil { + t.Fatalf("stat socket dir: %v", err) + } + if got := dirInfo.Mode() & os.ModePerm; got != unixSocketDirPerm { + t.Fatalf("socket dir perm = %#o, want %#o", got, unixSocketDirPerm) + } + select { case acceptErr := <-acceptDone: if acceptErr != nil { From 74459ff8224e2f3a2f46735cca196f0ecd8edb40 Mon Sep 17 00:00:00 2001 From: pionxe Date: Mon, 13 Apr 2026 22:30:23 +0800 Subject: [PATCH 6/9] =?UTF-8?q?fix(gateway):=20=E4=B8=BA=20Windows=20Named?= =?UTF-8?q?=20Pipe=20=E8=AE=BE=E7=BD=AE=E6=9C=80=E5=B0=8F=E5=8C=96?= =?UTF-8?q?=E8=AE=BF=E9=97=AE=E6=8E=A7=E5=88=B6=20ACL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/gateway/transport/listen_windows.go | 75 ++++++++++++++++++- .../gateway/transport/listen_windows_test.go | 41 ++++++++++ 2 files changed, 114 insertions(+), 2 deletions(-) diff --git a/internal/gateway/transport/listen_windows.go b/internal/gateway/transport/listen_windows.go index eebb6ccc..fb6993ae 100644 --- a/internal/gateway/transport/listen_windows.go +++ b/internal/gateway/transport/listen_windows.go @@ -7,13 +7,84 @@ import ( "net" "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" ) -// Listen 在 Windows 系统上启动 Named Pipe 监听。 +const ( + pipeSDDLDiscretionaryACL = "D:P" +) + +// Listen 在 Windows 系统上启动 Named Pipe 监听,并显式收敛访问控制。 func Listen(address string) (net.Listener, error) { - listener, err := winio.ListenPipe(address, nil) + config, err := newRestrictedPipeConfig() + if err != nil { + return nil, err + } + + listener, err := winio.ListenPipe(address, config) if err != nil { return nil, fmt.Errorf("gateway: listen named pipe: %w", err) } return newCleanupListener(listener, nil), nil } + +// newRestrictedPipeConfig 构建最小权限 PipeConfig,仅允许 SYSTEM、管理员组与当前用户访问。 +func newRestrictedPipeConfig() (*winio.PipeConfig, error) { + securityDescriptor, err := buildRestrictedPipeSecurityDescriptor() + if err != nil { + return nil, err + } + return &winio.PipeConfig{SecurityDescriptor: securityDescriptor}, nil +} + +// buildRestrictedPipeSecurityDescriptor 生成管道 ACL 的 SDDL 表达式。 +func buildRestrictedPipeSecurityDescriptor() (string, error) { + currentUserSID, err := currentProcessUserSID() + if err != nil { + return "", err + } + + systemSID, err := wellKnownSIDString(windows.WinLocalSystemSid) + if err != nil { + return "", fmt.Errorf("gateway: resolve local-system sid: %w", err) + } + + administratorsSID, err := wellKnownSIDString(windows.WinBuiltinAdministratorsSid) + if err != nil { + return "", fmt.Errorf("gateway: resolve administrators sid: %w", err) + } + + return fmt.Sprintf( + "%s(%s)(%s)(%s)", + pipeSDDLDiscretionaryACL, + allowGenericAllAce(systemSID), + allowGenericAllAce(administratorsSID), + allowGenericAllAce(currentUserSID), + ), nil +} + +// currentProcessUserSID 返回当前进程用户的 SID 字符串。 +func currentProcessUserSID() (string, error) { + tokenUser, err := windows.GetCurrentProcessToken().GetTokenUser() + if err != nil { + return "", fmt.Errorf("gateway: query current token user: %w", err) + } + if tokenUser == nil || tokenUser.User.Sid == nil { + return "", fmt.Errorf("gateway: current token user sid is empty") + } + return tokenUser.User.Sid.String(), nil +} + +// wellKnownSIDString 将系统内置 SID 类型转换为 SID 字符串。 +func wellKnownSIDString(sidType windows.WELL_KNOWN_SID_TYPE) (string, error) { + sid, err := windows.CreateWellKnownSid(sidType) + if err != nil { + return "", err + } + return sid.String(), nil +} + +// allowGenericAllAce 为指定 SID 生成“完全控制”ACE。 +func allowGenericAllAce(sid string) string { + return fmt.Sprintf("A;;GA;;;%s", sid) +} diff --git a/internal/gateway/transport/listen_windows_test.go b/internal/gateway/transport/listen_windows_test.go index 7c331c10..338f0b6b 100644 --- a/internal/gateway/transport/listen_windows_test.go +++ b/internal/gateway/transport/listen_windows_test.go @@ -4,10 +4,12 @@ package transport import ( "fmt" + "strings" "testing" "time" "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" ) func TestListenNamedPipeAcceptsConnection(t *testing.T) { @@ -49,3 +51,42 @@ func TestListenNamedPipeAcceptsConnection(t *testing.T) { t.Fatal("accept timed out") } } + +func TestNewRestrictedPipeConfigContainsExpectedSIDs(t *testing.T) { + t.Parallel() + + config, err := newRestrictedPipeConfig() + if err != nil { + t.Fatalf("new restricted pipe config: %v", err) + } + if config == nil { + t.Fatal("pipe config is nil") + } + if config.SecurityDescriptor == "" { + t.Fatal("security descriptor is empty") + } + + currentUserSID, err := currentProcessUserSID() + if err != nil { + t.Fatalf("current user sid: %v", err) + } + if !strings.Contains(config.SecurityDescriptor, currentUserSID) { + t.Fatalf("security descriptor does not contain current user sid %q", currentUserSID) + } + + systemSID, err := wellKnownSIDString(windows.WinLocalSystemSid) + if err != nil { + t.Fatalf("system sid: %v", err) + } + if !strings.Contains(config.SecurityDescriptor, systemSID) { + t.Fatalf("security descriptor does not contain system sid %q", systemSID) + } + + adminSID, err := wellKnownSIDString(windows.WinBuiltinAdministratorsSid) + if err != nil { + t.Fatalf("administrators sid: %v", err) + } + if !strings.Contains(config.SecurityDescriptor, adminSID) { + t.Fatalf("security descriptor does not contain administrators sid %q", adminSID) + } +} From 7b04ef51adac280ebf49e84202d4bea593cd4ec7 Mon Sep 17 00:00:00 2001 From: pionxe Date: Mon, 13 Apr 2026 22:51:33 +0800 Subject: [PATCH 7/9] =?UTF-8?q?feat:=E8=A1=A5=E9=BD=90=E8=A6=86=E7=9B=96?= =?UTF-8?q?=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/gateway/server.go | 4 +- internal/gateway/server_additional_test.go | 391 ++++++++++++++++++ internal/gateway/transport/listen_windows.go | 20 +- .../transport/listen_windows_acl_test.go | 215 ++++++++++ internal/gateway/validate_additional_test.go | 89 ++++ 5 files changed, 712 insertions(+), 7 deletions(-) create mode 100644 internal/gateway/server_additional_test.go create mode 100644 internal/gateway/transport/listen_windows_acl_test.go create mode 100644 internal/gateway/validate_additional_test.go diff --git a/internal/gateway/server.go b/internal/gateway/server.go index 22c8c780..b253c26d 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -25,6 +25,8 @@ const ( var ( errFrameTooLarge = errors.New("frame exceeds max size") errFrameEmpty = errors.New("empty frame") + + defaultListenAddressFn = transport.DefaultListenAddress ) // ServerOptions 描述网关服务启动所需的可选配置。 @@ -50,7 +52,7 @@ type Server struct { func NewServer(options ServerOptions) (*Server, error) { listenAddress := strings.TrimSpace(options.ListenAddress) if listenAddress == "" { - resolved, err := transport.DefaultListenAddress() + resolved, err := defaultListenAddressFn() if err != nil { return nil, err } diff --git a/internal/gateway/server_additional_test.go b/internal/gateway/server_additional_test.go new file mode 100644 index 00000000..5689ddec --- /dev/null +++ b/internal/gateway/server_additional_test.go @@ -0,0 +1,391 @@ +package gateway + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "io" + "log" + "net" + "strings" + "sync" + "testing" + "time" +) + +func TestNewServerUsesDefaultsAndOverrides(t *testing.T) { + originalDefaultListenAddress := defaultListenAddressFn + defaultListenAddressFn = func() (string, error) { + return "default-address", nil + } + t.Cleanup(func() { + defaultListenAddressFn = originalDefaultListenAddress + }) + + server, err := NewServer(ServerOptions{}) + if err != nil { + t.Fatalf("new server with defaults: %v", err) + } + if server.ListenAddress() != "default-address" { + t.Fatalf("default listen address = %q, want %q", server.ListenAddress(), "default-address") + } + if server.logger == nil { + t.Fatal("default logger should not be nil") + } + if server.listenFn == nil { + t.Fatal("default listen function should not be nil") + } + + customLogger := log.New(io.Discard, "custom", 0) + customServer, err := NewServer(ServerOptions{ + ListenAddress: " custom-address ", + Logger: customLogger, + listenFn: func(string) (net.Listener, error) { + return nil, nil + }, + }) + if err != nil { + t.Fatalf("new server with custom options: %v", err) + } + if customServer.ListenAddress() != "custom-address" { + t.Fatalf("custom listen address = %q, want %q", customServer.ListenAddress(), "custom-address") + } + if customServer.logger != customLogger { + t.Fatal("custom logger was not used") + } +} + +func TestNewServerReturnsDefaultAddressError(t *testing.T) { + originalDefaultListenAddress := defaultListenAddressFn + defaultListenAddressFn = func() (string, error) { + return "", errors.New("default address failed") + } + t.Cleanup(func() { + defaultListenAddressFn = originalDefaultListenAddress + }) + + _, err := NewServer(ServerOptions{}) + if err == nil { + t.Fatal("expected error when default listen address fails") + } + if !strings.Contains(err.Error(), "default address failed") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestServerIsClosedState(t *testing.T) { + server := &Server{} + if !server.isClosed() { + t.Fatal("expected server to be closed when listener is nil") + } + + server.listener = &simpleListener{} + if server.isClosed() { + t.Fatal("expected server to be open when listener exists") + } +} + +func TestServeReturnsListenError(t *testing.T) { + server, err := NewServer(ServerOptions{ + ListenAddress: "listen-error", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return nil, errors.New("listen failed") + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + + serveErr := server.Serve(context.Background(), nil) + if serveErr == nil || !strings.Contains(serveErr.Error(), "listen failed") { + t.Fatalf("expected listen failure, got %v", serveErr) + } +} + +func TestServeRejectsAlreadyServing(t *testing.T) { + created := &simpleListener{} + server, err := NewServer(ServerOptions{ + ListenAddress: "already-serving", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return created, nil + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + server.listener = &simpleListener{} + + serveErr := server.Serve(context.Background(), nil) + if serveErr == nil || !strings.Contains(serveErr.Error(), "already serving") { + t.Fatalf("expected already serving error, got %v", serveErr) + } + if !created.closed { + t.Fatal("newly created listener should be closed when server is already serving") + } +} + +func TestServeReturnsAcceptError(t *testing.T) { + listener := &scriptedListener{results: []acceptResult{{err: errors.New("accept failed")}}} + server, err := NewServer(ServerOptions{ + ListenAddress: "accept-error", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return listener, nil + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + + serveErr := server.Serve(context.Background(), nil) + if serveErr == nil || !strings.Contains(serveErr.Error(), "accept connection") { + t.Fatalf("expected accept error, got %v", serveErr) + } +} + +func TestServeSkipsConnectionWhenRegisterRejected(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + listener := &scriptedListener{ + results: []acceptResult{ + { + conn: serverConn, + }, + {err: net.ErrClosed}, + }, + } + + server, err := NewServer(ServerOptions{ + ListenAddress: "register-reject", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return listener, nil + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + + listener.results[0].beforeReturn = func() { + server.mu.Lock() + server.listener = nil + server.mu.Unlock() + } + + serveErr := server.Serve(context.Background(), nil) + if serveErr != nil { + t.Fatalf("serve should exit cleanly when listener closed, got %v", serveErr) + } + + readDone := make(chan error, 1) + go func() { + var buf [1]byte + _, err := clientConn.Read(buf[:]) + readDone <- err + }() + + select { + case err := <-readDone: + if !errors.Is(err, io.EOF) && (err == nil || !strings.Contains(err.Error(), "closed pipe")) { + t.Fatalf("expected rejected connection to be closed, got %v", err) + } + case <-time.After(500 * time.Millisecond): + t.Fatal("rejected connection was not closed") + } +} + +func TestCloseReturnsContextErrorWhenWaitCanceled(t *testing.T) { + server := &Server{conns: make(map[net.Conn]struct{})} + server.wg.Add(1) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := server.Close(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("close error = %v, want context canceled", err) + } + + server.wg.Done() +} + +func TestDecodeFrameTrailingJSON(t *testing.T) { + reader := bufio.NewReader(strings.NewReader(`{"type":"request","action":"ping"} {"extra":1}` + "\n")) + _, err := decodeFrame(reader) + if err == nil || !strings.Contains(err.Error(), "trailing") { + t.Fatalf("expected trailing json error, got %v", err) + } +} + +func TestReadFramePayloadBranches(t *testing.T) { + if _, err := readFramePayload(bufio.NewReader(strings.NewReader("")), MaxFrameSize); !errors.Is(err, io.EOF) { + t.Fatalf("empty payload error = %v, want io.EOF", err) + } + + payload, err := readFramePayload(bufio.NewReader(strings.NewReader("{\"type\":\"request\"}")), MaxFrameSize) + if err != nil { + t.Fatalf("payload without newline should decode at EOF: %v", err) + } + if string(payload) != `{"type":"request"}` { + t.Fatalf("payload mismatch: %q", string(payload)) + } + + tooLarge := strings.Repeat("a", 5000) + if _, err := readFramePayload(bufio.NewReaderSize(strings.NewReader(tooLarge), 64), 1024); !errors.Is(err, errFrameTooLarge) { + t.Fatalf("oversized payload error = %v, want errFrameTooLarge", err) + } + + if _, err := readFramePayload(bufio.NewReader(&failingReader{}), MaxFrameSize); err == nil || err.Error() != "read failed" { + t.Fatalf("expected read failure, got %v", err) + } +} + +func TestDispatchFrameNonRequest(t *testing.T) { + server := &Server{} + response := server.dispatchFrame(context.Background(), MessageFrame{Type: FrameTypeEvent, Action: FrameActionPing}, nil) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("response error = %#v, want invalid frame", response.Error) + } +} + +func TestDispatchFrameValidationError(t *testing.T) { + server := &Server{} + response := server.dispatchFrame(context.Background(), MessageFrame{Type: FrameType("invalid")}, nil) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("response error = %#v, want invalid frame", response.Error) + } +} + +func TestServerHandleConnectionSkipsEmptyFrame(t *testing.T) { + server := &Server{logger: log.New(io.Discard, "", 0)} + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + _, _ = io.WriteString(clientConn, "\n") + _, _ = io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"empty-then-ping"}`+"\n") + + decoder := json.NewDecoder(clientConn) + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode response: %v", err) + } + if response.Type != FrameTypeAck || response.Action != FrameActionPing { + t.Fatalf("unexpected response after empty frame: %#v", response) + } + + _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} + +func TestServerHandleConnectionInvalidJSONFrame(t *testing.T) { + server := &Server{logger: log.New(io.Discard, "", 0)} + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + _, _ = io.WriteString(clientConn, "{invalid-json}\n") + decoder := json.NewDecoder(clientConn) + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode response: %v", err) + } + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("response error = %#v, want invalid frame", response.Error) + } + + _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} + +type failingReader struct{} + +func (r *failingReader) Read(_ []byte) (int, error) { + return 0, errors.New("read failed") +} + +type simpleListener struct { + closed bool +} + +func (l *simpleListener) Accept() (net.Conn, error) { + return nil, net.ErrClosed +} + +func (l *simpleListener) Close() error { + l.closed = true + return nil +} + +func (l *simpleListener) Addr() net.Addr { + return stubAddr("simple") +} + +type acceptResult struct { + conn net.Conn + err error + beforeReturn func() +} + +type scriptedListener struct { + mu sync.Mutex + results []acceptResult + closed bool +} + +func (l *scriptedListener) Accept() (net.Conn, error) { + l.mu.Lock() + defer l.mu.Unlock() + if len(l.results) == 0 { + return nil, net.ErrClosed + } + result := l.results[0] + l.results = l.results[1:] + if result.beforeReturn != nil { + result.beforeReturn() + } + if result.err != nil { + return nil, result.err + } + if result.conn == nil { + return nil, net.ErrClosed + } + return result.conn, nil +} + +func (l *scriptedListener) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + l.closed = true + return nil +} + +func (l *scriptedListener) Addr() net.Addr { + return stubAddr("scripted") +} diff --git a/internal/gateway/transport/listen_windows.go b/internal/gateway/transport/listen_windows.go index fb6993ae..4ed05073 100644 --- a/internal/gateway/transport/listen_windows.go +++ b/internal/gateway/transport/listen_windows.go @@ -14,6 +14,14 @@ const ( pipeSDDLDiscretionaryACL = "D:P" ) +var ( + listenPipeFn = winio.ListenPipe + currentProcessUserSIDFn = currentProcessUserSID + wellKnownSIDStringFn = wellKnownSIDString + getCurrentProcessTokenFn = windows.GetCurrentProcessToken + createWellKnownSIDFn = windows.CreateWellKnownSid +) + // Listen 在 Windows 系统上启动 Named Pipe 监听,并显式收敛访问控制。 func Listen(address string) (net.Listener, error) { config, err := newRestrictedPipeConfig() @@ -21,7 +29,7 @@ func Listen(address string) (net.Listener, error) { return nil, err } - listener, err := winio.ListenPipe(address, config) + listener, err := listenPipeFn(address, config) if err != nil { return nil, fmt.Errorf("gateway: listen named pipe: %w", err) } @@ -39,17 +47,17 @@ func newRestrictedPipeConfig() (*winio.PipeConfig, error) { // buildRestrictedPipeSecurityDescriptor 生成管道 ACL 的 SDDL 表达式。 func buildRestrictedPipeSecurityDescriptor() (string, error) { - currentUserSID, err := currentProcessUserSID() + currentUserSID, err := currentProcessUserSIDFn() if err != nil { return "", err } - systemSID, err := wellKnownSIDString(windows.WinLocalSystemSid) + systemSID, err := wellKnownSIDStringFn(windows.WinLocalSystemSid) if err != nil { return "", fmt.Errorf("gateway: resolve local-system sid: %w", err) } - administratorsSID, err := wellKnownSIDString(windows.WinBuiltinAdministratorsSid) + administratorsSID, err := wellKnownSIDStringFn(windows.WinBuiltinAdministratorsSid) if err != nil { return "", fmt.Errorf("gateway: resolve administrators sid: %w", err) } @@ -65,7 +73,7 @@ func buildRestrictedPipeSecurityDescriptor() (string, error) { // currentProcessUserSID 返回当前进程用户的 SID 字符串。 func currentProcessUserSID() (string, error) { - tokenUser, err := windows.GetCurrentProcessToken().GetTokenUser() + tokenUser, err := getCurrentProcessTokenFn().GetTokenUser() if err != nil { return "", fmt.Errorf("gateway: query current token user: %w", err) } @@ -77,7 +85,7 @@ func currentProcessUserSID() (string, error) { // wellKnownSIDString 将系统内置 SID 类型转换为 SID 字符串。 func wellKnownSIDString(sidType windows.WELL_KNOWN_SID_TYPE) (string, error) { - sid, err := windows.CreateWellKnownSid(sidType) + sid, err := createWellKnownSIDFn(sidType) if err != nil { return "", err } diff --git a/internal/gateway/transport/listen_windows_acl_test.go b/internal/gateway/transport/listen_windows_acl_test.go new file mode 100644 index 00000000..c92b8ae9 --- /dev/null +++ b/internal/gateway/transport/listen_windows_acl_test.go @@ -0,0 +1,215 @@ +//go:build windows + +package transport + +import ( + "errors" + "net" + "strings" + "testing" + + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" +) + +func TestDefaultListenAddressWindows(t *testing.T) { + t.Parallel() + + address, err := DefaultListenAddress() + if err != nil { + t.Fatalf("default listen address: %v", err) + } + if address != defaultWindowsNamedPipePath { + t.Fatalf("default address = %q, want %q", address, defaultWindowsNamedPipePath) + } +} + +func TestNewCleanupListenerBranches(t *testing.T) { + t.Parallel() + + base := &stubNetListener{} + if got := newCleanupListener(base, nil); got != base { + t.Fatal("expected original listener when cleanup is nil") + } + + closeErr := errors.New("close failed") + cleanupErr := errors.New("cleanup failed") + wrapped := newCleanupListener(&stubNetListener{closeErr: closeErr}, func() error { return cleanupErr }) + if err := wrapped.Close(); err == nil { + t.Fatal("expected joined error") + } else { + if !errors.Is(err, closeErr) { + t.Fatalf("joined error should include close error, got %v", err) + } + if !errors.Is(err, cleanupErr) { + t.Fatalf("joined error should include cleanup error, got %v", err) + } + } +} + +func TestBuildRestrictedPipeSecurityDescriptorContainsExpectedACEs(t *testing.T) { + t.Parallel() + + sddl, err := buildRestrictedPipeSecurityDescriptor() + if err != nil { + t.Fatalf("build restricted descriptor: %v", err) + } + if !strings.HasPrefix(sddl, pipeSDDLDiscretionaryACL) { + t.Fatalf("sddl prefix = %q, want starts with %q", sddl, pipeSDDLDiscretionaryACL) + } + if strings.Count(sddl, "A;;GA;;;") != 3 { + t.Fatalf("sddl should contain 3 allow full-control ACE entries, got %q", sddl) + } +} + +func TestNewRestrictedPipeConfigErrorBranch(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + currentProcessUserSIDFn = func() (string, error) { + return "", errors.New("current user failed") + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + }) + + _, err := newRestrictedPipeConfig() + if err == nil || !strings.Contains(err.Error(), "current user failed") { + t.Fatalf("expected current user error, got %v", err) + } +} + +func TestListenReturnsConfigError(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + currentProcessUserSIDFn = func() (string, error) { + return "", errors.New("restricted config failed") + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + }) + + _, err := Listen(`\\.\pipe\neocode-gateway-config-error-test`) + if err == nil || !strings.Contains(err.Error(), "restricted config failed") { + t.Fatalf("expected config build failure, got %v", err) + } +} + +func TestBuildRestrictedPipeSecurityDescriptorSystemErrorBranch(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + originalWellKnown := wellKnownSIDStringFn + currentProcessUserSIDFn = func() (string, error) { return "S-1-5-21-1", nil } + wellKnownSIDStringFn = func(sidType windows.WELL_KNOWN_SID_TYPE) (string, error) { + if sidType == windows.WinLocalSystemSid { + return "", errors.New("system sid failed") + } + return "S-1-5-32-544", nil + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + wellKnownSIDStringFn = originalWellKnown + }) + + _, err := buildRestrictedPipeSecurityDescriptor() + if err == nil || !strings.Contains(err.Error(), "system sid failed") { + t.Fatalf("expected system sid error, got %v", err) + } +} + +func TestBuildRestrictedPipeSecurityDescriptorAdminErrorBranch(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + originalWellKnown := wellKnownSIDStringFn + currentProcessUserSIDFn = func() (string, error) { return "S-1-5-21-1", nil } + wellKnownSIDStringFn = func(sidType windows.WELL_KNOWN_SID_TYPE) (string, error) { + if sidType == windows.WinBuiltinAdministratorsSid { + return "", errors.New("admin sid failed") + } + return "S-1-5-18", nil + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + wellKnownSIDStringFn = originalWellKnown + }) + + _, err := buildRestrictedPipeSecurityDescriptor() + if err == nil || !strings.Contains(err.Error(), "admin sid failed") { + t.Fatalf("expected admin sid error, got %v", err) + } +} + +func TestListenReturnsListenPipeError(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + originalWellKnown := wellKnownSIDStringFn + originalListenPipe := listenPipeFn + currentProcessUserSIDFn = func() (string, error) { return "S-1-5-21-1", nil } + wellKnownSIDStringFn = func(sidType windows.WELL_KNOWN_SID_TYPE) (string, error) { + if sidType == windows.WinLocalSystemSid { + return "S-1-5-18", nil + } + if sidType == windows.WinBuiltinAdministratorsSid { + return "S-1-5-32-544", nil + } + return "", errors.New("unexpected sid type") + } + listenPipeFn = func(_ string, _ *winio.PipeConfig) (net.Listener, error) { + return nil, errors.New("listen pipe failed") + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + wellKnownSIDStringFn = originalWellKnown + listenPipeFn = originalListenPipe + }) + + _, err := Listen(`\\.\pipe\neocode-gateway-error-test`) + if err == nil || !strings.Contains(err.Error(), "listen pipe failed") { + t.Fatalf("expected listen pipe failure, got %v", err) + } +} + +func TestCurrentProcessUserSIDErrorBranch(t *testing.T) { + originalTokenFn := getCurrentProcessTokenFn + getCurrentProcessTokenFn = func() windows.Token { + return windows.Token(0) + } + t.Cleanup(func() { + getCurrentProcessTokenFn = originalTokenFn + }) + + _, err := currentProcessUserSID() + if err == nil { + t.Fatal("expected current process token user error") + } +} + +func TestWellKnownSIDStringErrorBranch(t *testing.T) { + originalCreateWellKnownSID := createWellKnownSIDFn + createWellKnownSIDFn = func(_ windows.WELL_KNOWN_SID_TYPE) (*windows.SID, error) { + return nil, errors.New("create sid failed") + } + t.Cleanup(func() { + createWellKnownSIDFn = originalCreateWellKnownSID + }) + + _, err := wellKnownSIDString(windows.WinLocalSystemSid) + if err == nil || !strings.Contains(err.Error(), "create sid failed") { + t.Fatalf("expected create sid failure, got %v", err) + } +} + +type stubNetListener struct { + closeErr error +} + +func (l *stubNetListener) Accept() (net.Conn, error) { + return nil, net.ErrClosed +} + +func (l *stubNetListener) Close() error { + return l.closeErr +} + +func (l *stubNetListener) Addr() net.Addr { + return pipeAddr("stub") +} + +type pipeAddr string + +func (a pipeAddr) Network() string { return "pipe" } +func (a pipeAddr) String() string { return string(a) } diff --git a/internal/gateway/validate_additional_test.go b/internal/gateway/validate_additional_test.go new file mode 100644 index 00000000..e8611842 --- /dev/null +++ b/internal/gateway/validate_additional_test.go @@ -0,0 +1,89 @@ +package gateway + +import ( + "strings" + "testing" +) + +func TestDecodePermissionResolutionInputAdditionalBranches(t *testing.T) { + t.Parallel() + + t.Run("nil permission pointer", func(t *testing.T) { + var input *PermissionResolutionInput + _, err := decodePermissionResolutionInput(input) + if err == nil || !strings.Contains(err.Error(), "is nil") { + t.Fatalf("expected nil pointer error, got %v", err) + } + }) + + t.Run("marshal error", func(t *testing.T) { + payload := map[string]any{"bad": func() {}} + _, err := decodePermissionResolutionInput(payload) + if err == nil { + t.Fatal("expected marshal error") + } + }) + + t.Run("unmarshal error", func(t *testing.T) { + _, err := decodePermissionResolutionInput([]byte("not-json-object")) + if err == nil { + t.Fatal("expected unmarshal error") + } + }) +} + +func TestValidateRequestFrameRunsInputPartsValidationForCompact(t *testing.T) { + t.Parallel() + + err := ValidateFrame(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionCompact, + SessionID: "sess-1", + InputParts: []InputPart{{ + Type: InputPartTypeText, + Text: " ", + }}, + }) + if err == nil { + t.Fatal("expected input_parts validation error") + } + if err.Code != ErrorCodeInvalidMultimodalPayload.String() { + t.Fatalf("error code = %q, want %q", err.Code, ErrorCodeInvalidMultimodalPayload.String()) + } +} + +func TestValidateFrameCancelAndListSessions(t *testing.T) { + t.Parallel() + + cancelErr := ValidateFrame(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionCancel, + }) + if cancelErr != nil { + t.Fatalf("cancel request should be valid, got %v", cancelErr) + } + + listErr := ValidateFrame(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionListSessions, + }) + if listErr != nil { + t.Fatalf("list_sessions request should be valid, got %v", listErr) + } +} + +func TestValidateResolvePermissionInvalidPayloadType(t *testing.T) { + t.Parallel() + + err := ValidateFrame(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionResolvePermission, + Payload: make(chan int), + }) + if err == nil { + t.Fatal("expected invalid resolve_permission payload error") + } + if err.Code != ErrorCodeInvalidAction.String() { + t.Fatalf("error code = %q, want %q", err.Code, ErrorCodeInvalidAction.String()) + } +} From 542ba74d7b8c3c7b3364872419a4073b884eea87 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Mon, 13 Apr 2026 15:00:32 +0000 Subject: [PATCH 8/9] fix(gateway): avoid chmod on pre-existing unix socket directory Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- cmd/neocode-gateway/main_test.go | 89 ++++++++++++ .../gateway/transport/address_unix_test.go | 23 +++ internal/gateway/transport/listen_unix.go | 30 +++- .../gateway/transport/listen_unix_test.go | 132 +++++++++++++++++- 4 files changed, 269 insertions(+), 5 deletions(-) create mode 100644 cmd/neocode-gateway/main_test.go create mode 100644 internal/gateway/transport/address_unix_test.go diff --git a/cmd/neocode-gateway/main_test.go b/cmd/neocode-gateway/main_test.go new file mode 100644 index 00000000..6f706e3d --- /dev/null +++ b/cmd/neocode-gateway/main_test.go @@ -0,0 +1,89 @@ +package main + +import ( + "errors" + "flag" + "os" + "strings" + "testing" +) + +func TestParseFlagsValid(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--listen", " /tmp/gateway.sock ", "--log-level", " WARN "}, func() { + listen, level, err := parseFlags() + if err != nil { + t.Fatalf("parse flags: %v", err) + } + if listen != "/tmp/gateway.sock" { + t.Fatalf("listen = %q, want %q", listen, "/tmp/gateway.sock") + } + if level != "warn" { + t.Fatalf("log level = %q, want %q", level, "warn") + } + }) +} + +func TestParseFlagsHelp(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--help"}, func() { + _, _, err := parseFlags() + if !errors.Is(err, errHelpRequested) { + t.Fatalf("parse flags error = %v, want %v", err, errHelpRequested) + } + }) +} + +func TestParseFlagsInvalidLogLevel(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--log-level", "trace"}, func() { + _, _, err := parseFlags() + if err == nil { + t.Fatal("expected invalid log level error") + } + if !strings.Contains(err.Error(), "invalid --log-level") { + t.Fatalf("error = %v, want contains %q", err, "invalid --log-level") + } + }) +} + +func TestParseFlagsUnknownFlag(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--unknown"}, func() { + _, _, err := parseFlags() + if err == nil { + t.Fatal("expected parse error") + } + if errors.Is(err, flag.ErrHelp) { + t.Fatalf("error = %v, should not be help error", err) + } + }) +} + +func TestRunHelp(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--help"}, func() { + if err := run(); err != nil { + t.Fatalf("run help: %v", err) + } + }) +} + +func TestRunInvalidLogLevel(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--log-level", "trace"}, func() { + err := run() + if err == nil { + t.Fatal("expected run error") + } + if !strings.Contains(err.Error(), "invalid --log-level") { + t.Fatalf("error = %v, want contains %q", err, "invalid --log-level") + } + }) +} + +func withArgs(t *testing.T, args []string, fn func()) { + t.Helper() + + originalArgs := os.Args + os.Args = args + defer func() { + os.Args = originalArgs + }() + + fn() +} diff --git a/internal/gateway/transport/address_unix_test.go b/internal/gateway/transport/address_unix_test.go new file mode 100644 index 00000000..568d3f43 --- /dev/null +++ b/internal/gateway/transport/address_unix_test.go @@ -0,0 +1,23 @@ +//go:build !windows + +package transport + +import ( + "path/filepath" + "testing" +) + +func TestDefaultListenAddress(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + address, err := DefaultListenAddress() + if err != nil { + t.Fatalf("default listen address: %v", err) + } + + want := filepath.Join(home, defaultUnixSocketRelativePath) + if address != want { + t.Fatalf("default listen address = %q, want %q", address, want) + } +} diff --git a/internal/gateway/transport/listen_unix.go b/internal/gateway/transport/listen_unix.go index d5a3ad20..de2c3430 100644 --- a/internal/gateway/transport/listen_unix.go +++ b/internal/gateway/transport/listen_unix.go @@ -19,11 +19,14 @@ const ( // Listen 在 Unix 系统上启动 UDS 监听并在关闭时清理 socket 文件。 func Listen(address string) (net.Listener, error) { socketDir := filepath.Dir(address) - if err := os.MkdirAll(socketDir, unixSocketDirPerm); err != nil { - return nil, fmt.Errorf("gateway: create socket dir: %w", err) + created, err := ensureSocketDir(socketDir) + if err != nil { + return nil, err } - if err := os.Chmod(socketDir, unixSocketDirPerm); err != nil { - return nil, fmt.Errorf("gateway: set socket dir permission: %w", err) + if created { + if err := os.Chmod(socketDir, unixSocketDirPerm); err != nil { + return nil, fmt.Errorf("gateway: set socket dir permission: %w", err) + } } if err := removeStaleUnixSocket(address); err != nil { @@ -47,6 +50,25 @@ func Listen(address string) (net.Listener, error) { }), nil } +// ensureSocketDir 确保 socket 父目录可用,并返回该目录是否由当前流程创建。 +func ensureSocketDir(socketDir string) (bool, error) { + info, err := os.Stat(socketDir) + if err == nil { + if !info.IsDir() { + return false, fmt.Errorf("gateway: socket dir path exists and is not directory: %s", socketDir) + } + return false, nil + } + if !os.IsNotExist(err) { + return false, fmt.Errorf("gateway: stat socket dir: %w", err) + } + + if err := os.MkdirAll(socketDir, unixSocketDirPerm); err != nil { + return false, fmt.Errorf("gateway: create socket dir: %w", err) + } + return true, nil +} + // removeStaleUnixSocket 清理历史残留的 socket 文件,避免监听失败。 func removeStaleUnixSocket(address string) error { info, err := os.Lstat(address) diff --git a/internal/gateway/transport/listen_unix_test.go b/internal/gateway/transport/listen_unix_test.go index 79dcdf38..4e103a5d 100644 --- a/internal/gateway/transport/listen_unix_test.go +++ b/internal/gateway/transport/listen_unix_test.go @@ -3,9 +3,11 @@ package transport import ( + "errors" "net" "os" "path/filepath" + "strings" "testing" "time" ) @@ -13,7 +15,7 @@ import ( func TestListenUnixAcceptsConnectionAndCleansSocket(t *testing.T) { t.Parallel() - socketPath := filepath.Join(t.TempDir(), "gateway.sock") + socketPath := filepath.Join(t.TempDir(), "run", "gateway.sock") socketDir := filepath.Dir(socketPath) listener, err := Listen(socketPath) if err != nil { @@ -73,3 +75,131 @@ func TestListenUnixAcceptsConnectionAndCleansSocket(t *testing.T) { t.Fatalf("socket file should be removed on close, stat err: %v", err) } } + +func TestListenUnixDoesNotChmodExistingDir(t *testing.T) { + t.Parallel() + + parentDir := filepath.Join(t.TempDir(), "existing") + if err := os.MkdirAll(parentDir, 0o755); err != nil { + t.Fatalf("create parent dir: %v", err) + } + + socketPath := filepath.Join(parentDir, "gateway.sock") + listener, err := Listen(socketPath) + if err != nil { + t.Fatalf("listen unix socket: %v", err) + } + defer func() { + _ = listener.Close() + }() + + dirInfo, err := os.Stat(parentDir) + if err != nil { + t.Fatalf("stat parent dir: %v", err) + } + if got := dirInfo.Mode() & os.ModePerm; got != 0o755 { + t.Fatalf("existing dir perm = %#o, want %#o", got, 0o755) + } +} + +func TestListenUnixSocketDirPathIsFile(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + filePath := filepath.Join(baseDir, "not-dir") + if err := os.WriteFile(filePath, []byte("x"), 0o600); err != nil { + t.Fatalf("write marker file: %v", err) + } + + socketPath := filepath.Join(filePath, "gateway.sock") + _, err := Listen(socketPath) + if err == nil { + t.Fatal("expected error when socket dir path is file") + } + if !strings.Contains(err.Error(), "is not directory") { + t.Fatalf("error = %v, want contains %q", err, "is not directory") + } +} + +func TestRemoveStaleUnixSocket(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "gateway.sock") + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("listen unix socket: %v", err) + } + _ = listener.Close() + + if err := removeStaleUnixSocket(socketPath); err != nil { + t.Fatalf("remove stale socket: %v", err) + } + if _, err := os.Stat(socketPath); !os.IsNotExist(err) { + t.Fatalf("socket should be removed, stat err: %v", err) + } +} + +func TestRemoveStaleUnixSocketNonSocketPath(t *testing.T) { + t.Parallel() + + filePath := filepath.Join(t.TempDir(), "plain-file") + if err := os.WriteFile(filePath, []byte("x"), 0o600); err != nil { + t.Fatalf("write marker file: %v", err) + } + + err := removeStaleUnixSocket(filePath) + if err == nil { + t.Fatal("expected error when stale path is non-socket") + } + if !strings.Contains(err.Error(), "is not socket") { + t.Fatalf("error = %v, want contains %q", err, "is not socket") + } +} + +func TestRemoveStaleUnixSocketNotExist(t *testing.T) { + t.Parallel() + + err := removeStaleUnixSocket(filepath.Join(t.TempDir(), "missing.sock")) + if err != nil { + t.Fatalf("remove missing stale socket: %v", err) + } +} + +func TestNewCleanupListenerWithoutCleanup(t *testing.T) { + t.Parallel() + + baseListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp: %v", err) + } + defer func() { + _ = baseListener.Close() + }() + + wrapped := newCleanupListener(baseListener, nil) + if wrapped != baseListener { + t.Fatal("expected original listener when cleanup is nil") + } +} + +func TestCleanupListenerCloseReturnsJoinedError(t *testing.T) { + t.Parallel() + + baseListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp: %v", err) + } + + cleanupErr := errors.New("cleanup failed") + wrapped := newCleanupListener(baseListener, func() error { + return cleanupErr + }) + + closeErr := wrapped.Close() + if closeErr == nil { + t.Fatal("expected close error") + } + if !errors.Is(closeErr, cleanupErr) { + t.Fatalf("close error = %v, want contains cleanup err %v", closeErr, cleanupErr) + } +} From 1bd05cfb8c82acf9a4e3a4662ddbe78801384fb3 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Tue, 14 Apr 2026 00:34:18 +0000 Subject: [PATCH 9/9] fix(gateway): resolve review issues for server concurrency and ipc hardening Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- internal/gateway/server.go | 132 +++++++++++++++--- internal/gateway/server_additional_test.go | 103 +++++++++++++- .../transport/listen_windows_acl_test.go | 6 - .../gateway/transport/listen_windows_test.go | 4 - 4 files changed, 213 insertions(+), 32 deletions(-) diff --git a/internal/gateway/server.go b/internal/gateway/server.go index b253c26d..9078d082 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -13,6 +13,7 @@ import ( "os" "strings" "sync" + "time" "neo-code/internal/gateway/transport" ) @@ -20,6 +21,13 @@ import ( const ( // MaxFrameSize 定义单条 JSON 帧允许的最大字节数,避免异常输入导致内存放大。 MaxFrameSize int64 = 1 << 20 // 1 MiB + + // DefaultMaxConnections 定义服务允许的最大并发连接数,超过上限的连接会被快速拒绝。 + DefaultMaxConnections = 128 + // DefaultReadTimeout 定义单次读帧的最大等待时间,避免慢连接长期占用资源。 + DefaultReadTimeout = 30 * time.Second + // DefaultWriteTimeout 定义单次写帧的最大等待时间,避免写阻塞占用处理协程。 + DefaultWriteTimeout = 30 * time.Second ) var ( @@ -31,16 +39,22 @@ var ( // ServerOptions 描述网关服务启动所需的可选配置。 type ServerOptions struct { - ListenAddress string - Logger *log.Logger - listenFn func(address string) (net.Listener, error) + ListenAddress string + Logger *log.Logger + MaxConnections int + ReadTimeout time.Duration + WriteTimeout time.Duration + listenFn func(address string) (net.Listener, error) } // Server 提供基于本地 IPC 的网关服务骨架实现。 type Server struct { - listenAddress string - logger *log.Logger - listenFn func(address string) (net.Listener, error) + listenAddress string + logger *log.Logger + listenFn func(address string) (net.Listener, error) + maxConnections int + readTimeout time.Duration + writeTimeout time.Duration mu sync.Mutex listener net.Listener @@ -48,6 +62,14 @@ type Server struct { wg sync.WaitGroup } +type registerConnectionResult int + +const ( + registerConnectionAccepted registerConnectionResult = iota + registerConnectionServerClosed + registerConnectionLimitExceeded +) + // NewServer 创建网关服务实例,并解析默认监听地址。 func NewServer(options ServerOptions) (*Server, error) { listenAddress := strings.TrimSpace(options.ListenAddress) @@ -69,11 +91,29 @@ func NewServer(options ServerOptions) (*Server, error) { listenFn = transport.Listen } + maxConnections := options.MaxConnections + if maxConnections <= 0 { + maxConnections = DefaultMaxConnections + } + + readTimeout := options.ReadTimeout + if readTimeout <= 0 { + readTimeout = DefaultReadTimeout + } + + writeTimeout := options.WriteTimeout + if writeTimeout <= 0 { + writeTimeout = DefaultWriteTimeout + } + return &Server{ - listenAddress: listenAddress, - logger: logger, - listenFn: listenFn, - conns: make(map[net.Conn]struct{}), + listenAddress: listenAddress, + logger: logger, + listenFn: listenFn, + maxConnections: maxConnections, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + conns: make(map[net.Conn]struct{}), }, nil } @@ -114,12 +154,17 @@ func (s *Server) Serve(ctx context.Context, runtimePort RuntimePort) error { return fmt.Errorf("gateway: accept connection: %w", acceptErr) } - if !s.registerConnection(conn) { + switch s.registerConnection(conn) { + case registerConnectionAccepted: + case registerConnectionServerClosed: + _ = conn.Close() + continue + case registerConnectionLimitExceeded: + s.logger.Printf("reject connection: max connections %d reached", s.maxConnections) _ = conn.Close() continue } - s.wg.Add(1) go func() { defer s.wg.Done() defer s.untrackConnection(conn) @@ -178,15 +223,19 @@ func (s *Server) snapshotConnections() map[net.Conn]struct{} { return copied } -// registerConnection 在服务可用时登记连接,若网关已关闭则拒绝登记。 -func (s *Server) registerConnection(conn net.Conn) bool { +// registerConnection 在服务可用且未超限时登记连接,并原子增加连接处理 WaitGroup 计数。 +func (s *Server) registerConnection(conn net.Conn) registerConnectionResult { s.mu.Lock() defer s.mu.Unlock() if s.listener == nil { - return false + return registerConnectionServerClosed + } + if len(s.conns) >= s.maxConnections { + return registerConnectionLimitExceeded } s.conns[conn] = struct{}{} - return true + s.wg.Add(1) + return registerConnectionAccepted } // untrackConnection 移除已结束连接,避免连接集合持续增长。 @@ -212,6 +261,11 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor default: } + if err := s.applyReadDeadline(conn); err != nil { + s.logger.Printf("set read deadline failed: %v", err) + return + } + frame, err := decodeFrame(reader) if err != nil { if errors.Is(err, io.EOF) { @@ -220,9 +274,13 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor if errors.Is(err, errFrameEmpty) { continue } + if isTimeoutError(err) { + s.logger.Printf("read frame timeout: %v", err) + return + } if errors.Is(err, errFrameTooLarge) { s.logger.Printf("decode frame failed: %v", err) - _ = encoder.Encode(errorFrame(MessageFrame{}, NewFrameError( + _ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError( ErrorCodeInvalidFrame, fmt.Sprintf("frame exceeds max size %d bytes", MaxFrameSize), ))) @@ -230,18 +288,52 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor } s.logger.Printf("decode frame failed: %v", err) - _ = encoder.Encode(errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame"))) + _ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame"))) return } response := s.dispatchFrame(ctx, frame, runtimePort) - if err := encoder.Encode(response); err != nil { - s.logger.Printf("write frame failed: %v", err) + if !s.writeFrame(conn, encoder, response) { return } } } +// applyReadDeadline 为当前连接设置下一次读操作超时,避免慢读连接长期占用协程。 +func (s *Server) applyReadDeadline(conn net.Conn) error { + if s.readTimeout <= 0 { + return nil + } + return conn.SetReadDeadline(time.Now().Add(s.readTimeout)) +} + +// applyWriteDeadline 为当前连接设置下一次写操作超时,避免写阻塞导致协程泄漏。 +func (s *Server) applyWriteDeadline(conn net.Conn) error { + if s.writeTimeout <= 0 { + return nil + } + return conn.SetWriteDeadline(time.Now().Add(s.writeTimeout)) +} + +// writeFrame 统一处理响应写回及写超时设置,失败时返回 false 供上层快速终止连接循环。 +func (s *Server) writeFrame(conn net.Conn, encoder *json.Encoder, frame MessageFrame) bool { + if err := s.applyWriteDeadline(conn); err != nil { + s.logger.Printf("set write deadline failed: %v", err) + return false + } + if err := encoder.Encode(frame); err != nil { + s.logger.Printf("write frame failed: %v", err) + return false + } + return true +} + +// isTimeoutError 判断错误是否为网络超时,用于区分慢连接超时与协议错误。 +func isTimeoutError(err error) bool { + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} + // decodeFrame 从连接读取一条 JSON 帧并执行长度与格式校验。 func decodeFrame(reader *bufio.Reader) (MessageFrame, error) { payload, err := readFramePayload(reader, MaxFrameSize) diff --git a/internal/gateway/server_additional_test.go b/internal/gateway/server_additional_test.go index 5689ddec..3fcc7118 100644 --- a/internal/gateway/server_additional_test.go +++ b/internal/gateway/server_additional_test.go @@ -36,11 +36,23 @@ func TestNewServerUsesDefaultsAndOverrides(t *testing.T) { if server.listenFn == nil { t.Fatal("default listen function should not be nil") } + if server.maxConnections != DefaultMaxConnections { + t.Fatalf("default max connections = %d, want %d", server.maxConnections, DefaultMaxConnections) + } + if server.readTimeout != DefaultReadTimeout { + t.Fatalf("default read timeout = %v, want %v", server.readTimeout, DefaultReadTimeout) + } + if server.writeTimeout != DefaultWriteTimeout { + t.Fatalf("default write timeout = %v, want %v", server.writeTimeout, DefaultWriteTimeout) + } customLogger := log.New(io.Discard, "custom", 0) customServer, err := NewServer(ServerOptions{ - ListenAddress: " custom-address ", - Logger: customLogger, + ListenAddress: " custom-address ", + Logger: customLogger, + MaxConnections: 7, + ReadTimeout: 150 * time.Millisecond, + WriteTimeout: 250 * time.Millisecond, listenFn: func(string) (net.Listener, error) { return nil, nil }, @@ -54,6 +66,15 @@ func TestNewServerUsesDefaultsAndOverrides(t *testing.T) { if customServer.logger != customLogger { t.Fatal("custom logger was not used") } + if customServer.maxConnections != 7 { + t.Fatalf("custom max connections = %d, want %d", customServer.maxConnections, 7) + } + if customServer.readTimeout != 150*time.Millisecond { + t.Fatalf("custom read timeout = %v, want %v", customServer.readTimeout, 150*time.Millisecond) + } + if customServer.writeTimeout != 250*time.Millisecond { + t.Fatalf("custom write timeout = %v, want %v", customServer.writeTimeout, 250*time.Millisecond) + } } func TestNewServerReturnsDefaultAddressError(t *testing.T) { @@ -324,6 +345,84 @@ func TestServerHandleConnectionInvalidJSONFrame(t *testing.T) { } } +func TestRegisterConnectionRejectsWhenLimitExceeded(t *testing.T) { + server := &Server{ + listener: &simpleListener{}, + maxConnections: 1, + conns: make(map[net.Conn]struct{}), + } + + conn1Server, conn1Client := net.Pipe() + defer conn1Client.Close() + defer conn1Server.Close() + if got := server.registerConnection(conn1Server); got != registerConnectionAccepted { + t.Fatalf("first register result = %v, want accepted", got) + } + + conn2Server, conn2Client := net.Pipe() + defer conn2Client.Close() + defer conn2Server.Close() + if got := server.registerConnection(conn2Server); got != registerConnectionLimitExceeded { + t.Fatalf("second register result = %v, want limit exceeded", got) + } + + server.untrackConnection(conn1Server) + server.wg.Done() +} + +func TestServerHandleConnectionReadTimeoutClosesConnection(t *testing.T) { + server := &Server{ + logger: log.New(io.Discard, "", 0), + readTimeout: 20 * time.Millisecond, + } + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("handleConnection should exit after read timeout") + } + + var buf [1]byte + _, err := clientConn.Read(buf[:]) + if !errors.Is(err, io.EOF) && (err == nil || !strings.Contains(err.Error(), "closed pipe")) { + t.Fatalf("expected closed connection after timeout, got %v", err) + } + _ = clientConn.Close() +} + +func TestServerHandleConnectionWriteTimeoutClosesConnection(t *testing.T) { + server := &Server{ + logger: log.New(io.Discard, "", 0), + readTimeout: time.Second, + writeTimeout: 20 * time.Millisecond, + } + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + _, err := io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"write-timeout"}`+"\n") + if err != nil { + t.Fatalf("write request: %v", err) + } + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("handleConnection should exit after write timeout") + } + + _ = clientConn.Close() +} + type failingReader struct{} func (r *failingReader) Read(_ []byte) (int, error) { diff --git a/internal/gateway/transport/listen_windows_acl_test.go b/internal/gateway/transport/listen_windows_acl_test.go index c92b8ae9..6d10e7ab 100644 --- a/internal/gateway/transport/listen_windows_acl_test.go +++ b/internal/gateway/transport/listen_windows_acl_test.go @@ -13,8 +13,6 @@ import ( ) func TestDefaultListenAddressWindows(t *testing.T) { - t.Parallel() - address, err := DefaultListenAddress() if err != nil { t.Fatalf("default listen address: %v", err) @@ -25,8 +23,6 @@ func TestDefaultListenAddressWindows(t *testing.T) { } func TestNewCleanupListenerBranches(t *testing.T) { - t.Parallel() - base := &stubNetListener{} if got := newCleanupListener(base, nil); got != base { t.Fatal("expected original listener when cleanup is nil") @@ -48,8 +44,6 @@ func TestNewCleanupListenerBranches(t *testing.T) { } func TestBuildRestrictedPipeSecurityDescriptorContainsExpectedACEs(t *testing.T) { - t.Parallel() - sddl, err := buildRestrictedPipeSecurityDescriptor() if err != nil { t.Fatalf("build restricted descriptor: %v", err) diff --git a/internal/gateway/transport/listen_windows_test.go b/internal/gateway/transport/listen_windows_test.go index 338f0b6b..826734ec 100644 --- a/internal/gateway/transport/listen_windows_test.go +++ b/internal/gateway/transport/listen_windows_test.go @@ -13,8 +13,6 @@ import ( ) func TestListenNamedPipeAcceptsConnection(t *testing.T) { - t.Parallel() - pipePath := fmt.Sprintf(`\\.\pipe\neocode-gateway-test-%d`, time.Now().UnixNano()) listener, err := Listen(pipePath) if err != nil { @@ -53,8 +51,6 @@ func TestListenNamedPipeAcceptsConnection(t *testing.T) { } func TestNewRestrictedPipeConfigContainsExpectedSIDs(t *testing.T) { - t.Parallel() - config, err := newRestrictedPipeConfig() if err != nil { t.Fatalf("new restricted pipe config: %v", err)