diff --git a/pkg/cmd/alpha/tunnel_relay.go b/pkg/cmd/alpha/tunnel_relay.go index 7ca3689..27b0e52 100644 --- a/pkg/cmd/alpha/tunnel_relay.go +++ b/pkg/cmd/alpha/tunnel_relay.go @@ -12,6 +12,8 @@ import ( "github.com/alphadose/haxmap" "github.com/spf13/cobra" + "github.com/apoxy-dev/icx" + "github.com/apoxy-dev/apoxy/pkg/cryptoutils" "github.com/apoxy-dev/apoxy/pkg/tunnel" "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" @@ -21,7 +23,6 @@ import ( tunnet "github.com/apoxy-dev/apoxy/pkg/tunnel/net" "github.com/apoxy-dev/apoxy/pkg/tunnel/router" "github.com/apoxy-dev/apoxy/pkg/tunnel/vni" - "github.com/apoxy-dev/icx" ) var ( @@ -151,7 +152,7 @@ var tunnelRelayCmd = &cobra.Command{ relay.SetOnDisconnect(func(_ context.Context, agentName, id string) error { if cm, ok := connections.Get(id); ok { if err := agentIPAM.Release(cm.prefix); err != nil { - slog.Error("Failed to release prefix", err, + slog.Error("Failed to release prefix", slog.Any("error", err), slog.String("agent", agentName), slog.String("connID", id), slog.String("prefix", cm.prefix.String())) } diff --git a/pkg/cmd/alpha/tunnel_run.go b/pkg/cmd/alpha/tunnel_run.go index 21181b2..a951282 100644 --- a/pkg/cmd/alpha/tunnel_run.go +++ b/pkg/cmd/alpha/tunnel_run.go @@ -3,14 +3,15 @@ package alpha import ( "context" "crypto/tls" + "errors" "fmt" "log/slog" - "math/rand" "net" + "net/http" "net/netip" "net/url" "strings" - "sync" + "sync/atomic" "time" "github.com/apoxy-dev/icx" @@ -18,15 +19,24 @@ import ( "github.com/dpeckett/network" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" + "k8s.io/apimachinery/pkg/util/sets" "github.com/apoxy-dev/apoxy/pkg/netstack" "github.com/apoxy-dev/apoxy/pkg/tunnel/api" + "github.com/apoxy-dev/apoxy/pkg/tunnel/randalloc" + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate" "github.com/apoxy-dev/apoxy/pkg/tunnel/conntrackpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/router" ) +// Watchdog tuning knobs +const ( + watchdogMaxSilence = 120 * time.Second + watchdogInterval = 5 * time.Second +) + var ( agentName string // agent identifier tunnelName string // tunnel identifier @@ -36,8 +46,12 @@ var ( insecureSkipVerify bool // skip TLS verification (testing only) socksListenAddr string // SOCKS listen address pcapPath string // optional pcap path + healthAddr string // listen address for health endpoint (e.g. ":8080"); empty disables ) +// connectionHealthCounter tracks how many relay sessions are currently live. +var connectionHealthCounter atomic.Int32 + var tunnelRunCmd = &cobra.Command{ Use: "run", Short: "Run a tunnel", @@ -47,386 +61,615 @@ var tunnelRunCmd = &cobra.Command{ return fmt.Errorf("--min-conns must be at least 1") } - // One UDP socket shared between Geneve (data) and QUIC (control). - lis, err := net.ListenPacket("udp", ":0") + g, ctx := errgroup.WithContext(cmd.Context()) + + // Start health endpoint server if configured. 
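+ // One goroutine serves GET /healthz (see healthHandler below); a second
+ // waits for ctx cancellation and shuts the server down with a 5-second timeout.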
+ if strings.TrimSpace(healthAddr) != "" { + mux := http.NewServeMux() + mux.HandleFunc("/healthz", healthHandler) + + healthServer := &http.Server{ + Addr: healthAddr, + Handler: mux, + } + + g.Go(func() error { + slog.Info("Starting health endpoint server", slog.String("address", healthAddr)) + if err := healthServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + slog.Error("Health server failed", slog.Any("error", err)) + return err + } + return nil + }) + + g.Go(func() error { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return healthServer.Shutdown(shutdownCtx) + }) + } + + packetPlane, err := newPacketPlane() if err != nil { - return fmt.Errorf("failed to create UDP socket: %w", err) + return err } + defer packetPlane.Close() - pc, err := batchpc.New("udp", lis) + tlsConf := &tls.Config{InsecureSkipVerify: insecureSkipVerify} + + // Bootstrap against the seed relay to learn MTU/DNS/routes, keys, VNI, and the relay address pool. + boot, err := bootstrapSession(ctx, seedRelayAddr, packetPlane.QuicMux, tlsConf) if err != nil { - return fmt.Errorf("failed to create batch packet conn: %w", err) + return err } - pcGeneve, pcQuic := bifurcate.Bifurcate(pc) - defer pcGeneve.Close() - defer pcQuic.Close() + // Initialize and start the router. + r, handler, err := initRouter( + ctx, + g, + boot.Connect, + routerInitOpts{ + pcGeneve: packetPlane.Geneve, + socksListenAddr: socksListenAddr, + pcapPath: pcapPath, + }, + ) + if err != nil { + return err + } + defer r.Close() + + // Create an relay address pool that ensures we never connect to the same + // relay from multiple slots at once. + relayAddressPool := randalloc.NewRandAllocator(boot.RelayAddresses) + + // Spawn minConns independent connection slots. + // Each slot: + // - acquires a unique relay from the allocator + // - connects & manages that session + // - when the session ends, releases the relay + for i := 0; i < minConns; i++ { + g.Go(func() error { + return manageConnectionSlot(ctx, packetPlane.QuicMux, handler, relayAddressPool, tlsConf) + }) + } - // Share a single QUIC socket across multiple relays. - pcQuicMultiplexed := conntrackpc.New(pcQuic, conntrackpc.Options{}) - defer pcQuicMultiplexed.Close() + return g.Wait() + }, +} - g, ctx := errgroup.WithContext(cmd.Context()) +func init() { + tunnelRunCmd.Flags().StringVarP(&agentName, "agent", "a", "", "The name of this agent.") + tunnelRunCmd.Flags().StringVarP(&tunnelName, "name", "n", "", "The logical name of the tunnel to connect to.") + tunnelRunCmd.Flags().StringVarP(&seedRelayAddr, "relay-addr", "r", "", "Seed relay address (host:port). The client bootstraps here, then uses the returned relay list.") + tunnelRunCmd.Flags().IntVar(&minConns, "min-conns", 1, "Minimum number of relays to maintain connections to (randomly selected from the server-provided list).") + tunnelRunCmd.Flags().StringVarP(&token, "token", "k", "", "The token to use for authenticating with the tunnel relays.") + tunnelRunCmd.Flags().BoolVar(&insecureSkipVerify, "insecure-skip-verify", false, "Skip TLS certificate verification for relay connections.") + tunnelRunCmd.Flags().StringVarP(&pcapPath, "pcap", "p", "", "Path to an optional packet capture file to write.") + tunnelRunCmd.Flags().StringVar(&socksListenAddr, "socks-addr", "localhost:1080", "Listen address for SOCKS proxy.") + tunnelRunCmd.Flags().StringVar(&healthAddr, "health-addr", "localhost:8080", "Listen address for health endpoint (e.g. \":8080\"). 
Empty disables.") - var ( - routerOnce sync.Once - routerErr error - r *router.ICXNetstackRouter - handler *icx.Handler - ) + cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("agent")) + cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("name")) + cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("relay-addr")) + cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("token")) - // Lazily create router/handler on first successful Connect. - getHandler := func(connectResp *api.ConnectResponse) (*icx.Handler, error) { - routerOnce.Do(func() { - routerOpts := []router.Option{ - router.WithPacketConn(pcGeneve), - router.WithTunnelMTU(connectResp.MTU), - } + tunnelCmd.AddCommand(tunnelRunCmd) +} - if socksListenAddr != "" { - routerOpts = append(routerOpts, router.WithSocksListenAddr(socksListenAddr)) - } - if pcapPath != "" { - routerOpts = append(routerOpts, router.WithPcapPath(pcapPath)) - } - if connectResp.DNS != nil { - resolveConf := &network.ResolveConfig{ - Nameservers: connectResp.DNS.Servers, - SearchDomains: connectResp.DNS.SearchDomains, - NDots: connectResp.DNS.NDots, - } - routerOpts = append(routerOpts, router.WithResolveConfig(resolveConf)) - } +// packetPlane bundles the shared UDP socket and its derived logical planes: +// - Geneve/data plane (pcGeneve) +// - QUIC/control plane mux (pcQuicMux) +type packetPlane struct { + Geneve batchpc.BatchPacketConn + QuicMux *conntrackpc.ConntrackPacketConn + closers []func() +} - r, routerErr = router.NewICXNetstackRouter(routerOpts...) - if routerErr != nil { - return - } - handler = r.Handler +// newPacketPlane: +// - creates a UDP socket bound to :0 +// - wraps it in a BatchPacketConn +// - bifurcates into Geneve (data plane) and QUIC (control) +// - wraps QUIC side in a conntrack multiplexer +func newPacketPlane() (*packetPlane, error) { + lis, err := net.ListenPacket("udp", ":0") + if err != nil { + return nil, fmt.Errorf("failed to create UDP socket: %w", err) + } - for _, addrStr := range connectResp.Addresses { - slog.Info("Adding address", slog.String("address", addrStr)) + bpc, err := batchpc.New("udp", lis) + if err != nil { + lis.Close() + return nil, fmt.Errorf("failed to create batch packet conn: %w", err) + } - addr, err := netip.ParsePrefix(addrStr) - if err != nil { - slog.Warn("Failed to parse address", slog.String("address", addrStr), slog.Any("error", err)) - continue - } + pcGeneveInner, pcQuicInner := bifurcate.Bifurcate(bpc) + pcQuicMuxInner := conntrackpc.New(pcQuicInner, conntrackpc.Options{}) + + return &packetPlane{ + Geneve: pcGeneveInner, + QuicMux: pcQuicMuxInner, + closers: []func(){ + func() { pcGeneveInner.Close() }, + func() { pcQuicMuxInner.Close() }, + func() { pcQuicInner.Close() }, + }, + }, nil +} - if err := r.AddAddr(addr, nil); err != nil { - slog.Warn("Failed to add address", slog.String("address", addrStr), slog.Any("error", err)) - } - } +func (pp *packetPlane) Close() { + for _, c := range pp.closers { + c() + } +} - for _, route := range connectResp.Routes { - slog.Info("Adding route", slog.String("destination", route.Destination)) +type bootstrapInfo struct { + Connect *api.ConnectResponse + RelayAddresses sets.Set[string] +} - dst, err := netip.ParsePrefix(route.Destination) - if err != nil { - slog.Warn("Failed to parse route prefix", slog.String("prefix", route.Destination), slog.Any("error", err)) - continue - } - if err := r.AddRoute(dst); err != nil { - slog.Warn("Failed to add route", slog.String("prefix", route.Destination), slog.Any("error", err)) - } - } +// bootstrapSession connects to the seed relay, 
retrieves tunnel config and +// the relay address pool, disconnects, and returns that bootstrap data. +func bootstrapSession( + ctx context.Context, + seedRelayAddr string, + pcQuicMux *conntrackpc.ConntrackPacketConn, + tlsConf *tls.Config, +) (*bootstrapInfo, error) { + seedAddr := strings.TrimSpace(seedRelayAddr) - g.Go(func() error { return r.Start(ctx) }) - }) - return handler, routerErr - } + seedResolved, err := resolveAddrPort(ctx, seedAddr) + if err != nil { + return nil, fmt.Errorf("failed to resolve seed relay addr %q: %w", seedAddr, err) + } - defer func() { - if r != nil { - _ = r.Close() - } - }() + seedPcQuic, err := pcQuicMux.Open(&net.UDPAddr{ + IP: seedResolved.Addr().AsSlice(), + Port: int(seedResolved.Port()), + }) + if err != nil { + return nil, fmt.Errorf("failed to create multiplexed packet conn for seed relay %q: %w", seedAddr, err) + } + defer seedPcQuic.Close() + + client, err := api.NewClient(api.ClientOptions{ + BaseURL: (&url.URL{Scheme: "https", Host: seedAddr}).String(), + Agent: agentName, + TunnelName: tunnelName, + Token: token, + TLSConfig: tlsConf, + PacketConn: seedPcQuic, + }) + if err != nil { + return nil, fmt.Errorf("create seed API client: %w", err) + } + defer client.Close() - tlsConf := &tls.Config{InsecureSkipVerify: insecureSkipVerify} + slog.Info("Bootstrapping against seed relay", slog.String("relay", seedAddr)) - // Bootstrap via seed relay to fetch MTU/DNS/routes and the relay pool. - seedAddr := strings.TrimSpace(seedRelayAddr) - seedResolved, err := resolveAddrPort(ctx, seedAddr) - if err != nil { - return fmt.Errorf("failed to resolve seed relay addr %q: %w", seedAddr, err) - } + connectResp, err := client.Connect(ctx) + if err != nil { + return nil, fmt.Errorf("bootstrap connect to seed relay %q: %w", seedAddr, err) + } - seedPcQuic, err := pcQuicMultiplexed.Open(&net.UDPAddr{ - IP: seedResolved.Addr().AsSlice(), - Port: int(seedResolved.Port()), - }) - if err != nil { - return fmt.Errorf("failed to create multiplexed packet conn for seed relay %q: %w", seedAddr, err) - } - defer seedPcQuic.Close() - - seedBaseURL := url.URL{Scheme: "https", Host: seedAddr} - seedClient, err := api.NewClient(api.ClientOptions{ - BaseURL: seedBaseURL.String(), - Agent: agentName, - TunnelName: tunnelName, - Token: token, - TLSConfig: tlsConf, - PacketConn: seedPcQuic, - }) - if err != nil { - return fmt.Errorf("create seed API client: %w", err) + // We're only using this connection for discovery. Close it gracefully. + if err := client.Disconnect(ctx, connectResp.ID); err != nil { + slog.Warn("Failed to disconnect bootstrap session", + slog.String("id", connectResp.ID), + slog.Any("error", err)) + } + + // Build a deduped set of relay addresses (seed + server-provided). + addrSet := sets.New[string]() + + trimmedSeed := strings.TrimSpace(seedAddr) + if trimmedSeed != "" { + addrSet.Insert(trimmedSeed) + } + for _, a := range connectResp.RelayAddresses { + a = strings.TrimSpace(a) + if a != "" { + addrSet.Insert(a) } + } + + return &bootstrapInfo{ + Connect: connectResp, + RelayAddresses: addrSet, + }, nil +} + +type routerInitOpts struct { + pcGeneve batchpc.BatchPacketConn + socksListenAddr string + pcapPath string +} - slog.Info("Bootstrapping against seed relay", slog.String("relay", seedAddr)) +// initRouter creates and starts the ICXNetstackRouter / icx.Handler using the +// bootstrap response. 
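+// The router is started on the supplied errgroup; the caller remains
+// responsible for closing the returned router when it is done with it.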
+func initRouter( + ctx context.Context, + g *errgroup.Group, + connectResp *api.ConnectResponse, + opts routerInitOpts, +) (*router.ICXNetstackRouter, *icx.Handler, error) { + routerOpts := []router.Option{ + router.WithPacketConn(opts.pcGeneve), + router.WithTunnelMTU(connectResp.MTU), + } - seedResp, err := seedClient.Connect(ctx) + if opts.socksListenAddr != "" { + routerOpts = append(routerOpts, router.WithSocksListenAddr(opts.socksListenAddr)) + } + if opts.pcapPath != "" { + routerOpts = append(routerOpts, router.WithPcapPath(opts.pcapPath)) + } + if connectResp.DNS != nil { + routerOpts = append(routerOpts, router.WithResolveConfig(&network.ResolveConfig{ + Nameservers: connectResp.DNS.Servers, + SearchDomains: connectResp.DNS.SearchDomains, + NDots: connectResp.DNS.NDots, + })) + } + + r, err := router.NewICXNetstackRouter(routerOpts...) + if err != nil { + return nil, nil, err + } + + h := r.Handler + + // Add assigned addresses. + for _, addrStr := range connectResp.Addresses { + slog.Info("Adding address", slog.String("address", addrStr)) + + addr, err := netip.ParsePrefix(addrStr) if err != nil { - _ = seedClient.Close() - return fmt.Errorf("bootstrap connect to seed relay %q: %w", seedAddr, err) + slog.Warn("Failed to parse address", + slog.String("address", addrStr), + slog.Any("error", err)) + continue } - // Initialize router (MTU, DNS, routes) based on bootstrap response. - if _, err := getHandler(seedResp); err != nil { - _ = seedClient.Close() - return fmt.Errorf("init router: %w", err) + if err := r.AddAddr(addr, nil); err != nil { + slog.Warn("Failed to add address", + slog.String("address", addrStr), + slog.Any("error", err)) } + } - // Close bootstrap session; steady-state connections are created below. - if err := seedClient.Disconnect(ctx, seedResp.ID); err != nil { - slog.Warn("Failed to disconnect bootstrap session", slog.String("id", seedResp.ID), slog.Any("error", err)) - } - _ = seedClient.Close() - - // Build unique relay pool (ensure seed included once). - pool := make([]string, 0, len(seedResp.RelayAddresses)+1) - seen := map[string]struct{}{} - add := func(a string) { - a = strings.TrimSpace(a) - if a == "" { - return - } - if _, ok := seen[a]; ok { - return - } - seen[a] = struct{}{} - pool = append(pool, a) + // Add routes. + for _, rt := range connectResp.Routes { + slog.Info("Adding route", slog.String("destination", rt.Destination)) + + dst, err := netip.ParsePrefix(rt.Destination) + if err != nil { + slog.Warn("Failed to parse route prefix", + slog.String("prefix", rt.Destination), + slog.Any("error", err)) + continue } - for _, a := range seedResp.RelayAddresses { - add(a) + if err := r.AddRoute(dst); err != nil { + slog.Warn("Failed to add route", + slog.String("prefix", rt.Destination), + slog.Any("error", err)) } - add(seedAddr) + } - if len(pool) == 0 { - return fmt.Errorf("server did not return any relay addresses and seed was empty") - } + // Start the router. + g.Go(func() error { return r.Start(ctx) }) - // Randomly pick up to minConns relays. - rand.Shuffle(len(pool), func(i, j int) { pool[i], pool[j] = pool[j], pool[i] }) - n := minConns - if n > len(pool) { - n = len(pool) - } - selected := pool[:n] + return r, h, nil +} - slog.Info("Selected relays for steady-state connections", - slog.Int("minConns", minConns), - slog.Int("selected", len(selected)), - slog.Any("relays", selected), - ) +// connectAndInitSession dials the relay, runs Connect, and returns the live +// api.Client, the ConnectResponse, and the handler. 
It also wires the relay +// into the handler via AddVirtualNetwork. +func connectAndInitSession( + ctx context.Context, + pcQuic net.PacketConn, + handler *icx.Handler, + relayAddr string, + tlsConf *tls.Config, +) (*api.Client, *api.ConnectResponse, *icx.Handler, error) { + client, err := api.NewClient(api.ClientOptions{ + BaseURL: (&url.URL{Scheme: "https", Host: relayAddr}).String(), + Agent: agentName, + TunnelName: tunnelName, + Token: token, + TLSConfig: tlsConf, + PacketConn: pcQuic, + }) + if err != nil { + return nil, nil, nil, fmt.Errorf("create API client: %w", err) + } - // One connection manager per relay. - for _, relay := range selected { - relay := relay - g.Go(func() error { - relayAddr, err := resolveAddrPort(ctx, relay) - if err != nil { - return fmt.Errorf("failed to resolve relay addr %q: %w", relay, err) - } + cleanupOnErr := func(e error) (*api.Client, *api.ConnectResponse, *icx.Handler, error) { + _ = client.Close() + return nil, nil, nil, e + } - pcQuic, err := pcQuicMultiplexed.Open(&net.UDPAddr{ - IP: relayAddr.Addr().AsSlice(), - Port: int(relayAddr.Port()), - }) - if err != nil { - return fmt.Errorf("failed to create multiplexed packet conn for relay %q: %w", relay, err) - } - defer pcQuic.Close() + slog.Info("Connecting to relay", slog.String("relay", relayAddr)) - return manageRelayConnection(ctx, pcQuic, getHandler, relay, tlsConf) - }) + connectResp, err := client.Connect(ctx) + if err != nil { + return cleanupOnErr(fmt.Errorf("connect to relay: %w", err)) + } + + remoteAddr, err := resolveAddrPort(ctx, relayAddr) + if err != nil { + return cleanupOnErr(fmt.Errorf("resolve relay addr %q: %w", relayAddr, err)) + } + + overlayAddrs, err := parsePrefixes(connectResp.Addresses) + if err != nil { + return cleanupOnErr(fmt.Errorf("parse assigned addresses: %w", err)) + } + + for _, route := range connectResp.Routes { + dst, err := netip.ParsePrefix(route.Destination) + if err != nil { + slog.Warn("Failed to parse route prefix", + slog.String("prefix", route.Destination), + slog.Any("error", err)) + continue } + overlayAddrs = append(overlayAddrs, dst) + } - return g.Wait() - }, -} + if err := handler.AddVirtualNetwork( + connectResp.VNI, + netstack.ToFullAddress(remoteAddr), + overlayAddrs, + ); err != nil { + return cleanupOnErr(fmt.Errorf("add virtual network: %w", err)) + } -func init() { - tunnelRunCmd.Flags().StringVarP(&agentName, "agent", "a", "", "The name of this agent.") - tunnelRunCmd.Flags().StringVarP(&tunnelName, "name", "n", "", "The logical name of the tunnel to connect to.") - tunnelRunCmd.Flags().StringVarP(&seedRelayAddr, "relay-addr", "r", "", "Seed relay address (host:port). 
The client bootstraps here, then uses the returned relay list.") - tunnelRunCmd.Flags().IntVar(&minConns, "min-conns", 1, "Minimum number of relays to maintain connections to (randomly selected from the server-provided list).") - tunnelRunCmd.Flags().StringVarP(&token, "token", "k", "", "The token to use for authenticating with the tunnel relays.") - tunnelRunCmd.Flags().BoolVar(&insecureSkipVerify, "insecure-skip-verify", false, "Skip TLS certificate verification for relay connections.") - tunnelRunCmd.Flags().StringVarP(&pcapPath, "pcap", "p", "", "Path to an optional packet capture file to write.") - tunnelRunCmd.Flags().StringVar(&socksListenAddr, "socks-addr", "localhost:1080", "Listen address for SOCKS proxy.") + slog.Info("Connected to relay", + slog.String("relay", relayAddr), + slog.String("id", connectResp.ID), + slog.Int("vni", int(connectResp.VNI)), + slog.Int("mtu", connectResp.MTU), + ) - cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("agent")) - cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("name")) - cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("relay-addr")) - cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("token")) + return client, connectResp, handler, nil +} - tunnelCmd.AddCommand(tunnelRunCmd) +// closeSession best-effort disconnect + close of an active session. +func closeSession(client *api.Client, connID string) { + if client == nil || connID == "" { + return + } + disconnectCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := client.Disconnect(disconnectCtx, connID); err != nil { + slog.Error("Failed to disconnect from tunnel", + slog.String("id", connID), + slog.Any("error", err)) + } + slog.Info("Disconnected from tunnel", slog.String("id", connID)) + _ = client.Close() } -// manageRelayConnection keeps a single relay session alive (connect → rotate-keys → reconnect). -func manageRelayConnection( +// manageRelayConnectionOnce establishes and maintains a single relay session +// to the specified relayAddr over pcQuic. It will: +// +// - retry Connect() until it succeeds or ctx is canceled +// - once connected, run key rotation and watchdog concurrently +// - whichever fails first ends the session +func manageRelayConnectionOnce( ctx context.Context, pcQuic net.PacketConn, - getHandler func(*api.ConnectResponse) (*icx.Handler, error), + handler *icx.Handler, relayAddr string, tlsConf *tls.Config, + onConnected func(*api.ConnectResponse), ) error { - baseURL := url.URL{Scheme: "https", Host: relayAddr} - var ( currentClient *api.Client currentConnID string + connectResp *api.ConnectResponse ) - // Best-effort disconnect/close of the active session. - disconnectClient := func() { - if currentClient == nil || currentConnID == "" { - return - } - disconnectCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := currentClient.Disconnect(disconnectCtx, currentConnID); err != nil { - slog.Error("Failed to disconnect from tunnel", slog.String("id", currentConnID), slog.Any("error", err)) + // When this function returns, that relay session is down, so decrement + // if we had actually marked it active. + defer func() { + if currentConnID != "" { + connectionHealthCounter.Add(-1) } - slog.Info("Disconnected from tunnel", slog.String("id", currentConnID)) - _ = currentClient.Close() - currentClient = nil - currentConnID = "" - } - defer disconnectClient() + closeSession(currentClient, currentConnID) + }() - // Session lifecycle loop. 
- for { - select { - case <-ctx.Done(): + // Keep retrying connect until context canceled. + err := retry.Do( + func() error { + c, cr, _, err := connectAndInitSession(ctx, pcQuic, handler, relayAddr, tlsConf) + if err != nil { + return err + } + currentClient = c + currentConnID = cr.ID + connectResp = cr + + // Publish latest info to the caller (e.g., update relay pool). + if onConnected != nil { + onConnected(cr) + } + + return nil + }, + retry.Context(ctx), + retry.OnRetry(func(n uint, err error) { + slog.Warn("Reconnect attempt failed; backing off", + slog.String("relay", relayAddr), + slog.Uint64("attempt", uint64(n+1)), + slog.Any("error", err)) + }), + retry.LastErrorOnly(true), + ) + if err != nil { + if ctx.Err() != nil { return ctx.Err() - default: } + slog.Error("Failed to (re)connect to relay", + slog.String("relay", relayAddr), + slog.Any("error", err)) + return fmt.Errorf("failed to connect to relay %q: %w", relayAddr, err) + } - var ( - connectResp *api.ConnectResponse - handler *icx.Handler + // Successful session establishment: mark this connection active. + connectionHealthCounter.Add(1) + + // Once connected, run key rotation and watchdog concurrently. + sessionCtx, sessionCancel := context.WithCancel(ctx) + defer sessionCancel() + + g, gctx := errgroup.WithContext(sessionCtx) + g.Go(func() error { + return manageKeyRotation( + gctx, + handler, + currentClient, + currentConnID, + connectResp.VNI, + connectResp.Keys, + ) + }) + g.Go(func() error { + return relayWatchdog( + gctx, + handler, + connectResp.VNI, + watchdogMaxSilence, + watchdogInterval, ) + }) - // Connect with exponential backoff. - err := retry.Do( - func() error { - client, err := api.NewClient(api.ClientOptions{ - BaseURL: baseURL.String(), - Agent: agentName, - TunnelName: tunnelName, - Token: token, - TLSConfig: tlsConf, - PacketConn: pcQuic, - }) - if err != nil { - return fmt.Errorf("create API client: %w", err) - } + // Wait for either goroutine to return an error. + waitErr := g.Wait() - cleanupOnErr := func(e error) error { - _ = client.Close() - return e - } + if ctx.Err() != nil { + return ctx.Err() + } - slog.Info("Connecting to relay", slog.String("relay", relayAddr)) + if waitErr != nil && waitErr != context.Canceled { + slog.Warn("Connection ended", + slog.String("relay", relayAddr), + slog.Any("error", waitErr)) + } + return waitErr +} - connectResp, err = client.Connect(ctx) - if err != nil { - return cleanupOnErr(fmt.Errorf("connect to relay: %w", err)) - } +// manageConnectionSlot owns one "connection slot" that we promised to keep +// active. It repeatedly: +// +// - asks the relay address pool for an exclusive relay address +// - opens a PacketConn to that relay +// - runs manageRelayConnectionOnce +// - when that session ends, releases the relay back to the pool +// +// If minConns > number of relays, extra goroutines will block in Acquire() +// until another slot releases a relay. This enforces "no two sessions to the +// same relay address" at any instant. +func manageConnectionSlot( + ctx context.Context, + pcQuicMux *conntrackpc.ConntrackPacketConn, + handler *icx.Handler, + relayAddressPool *randalloc.RandAllocator[string], + tlsConf *tls.Config, +) error { + for { + // Block here until we get exclusive rights to a relay, + // or until ctx is canceled. + relayAddr, err := relayAddressPool.Acquire(ctx) + if err != nil { + return err // ctx canceled, etc. 
+ } - handler, err = getHandler(connectResp) - if err != nil { - return cleanupOnErr(fmt.Errorf("init router: %w", err)) - } + slog.Info("Acquired relay slot", + slog.String("relay", relayAddr)) - remoteAddr, err := resolveAddrPort(ctx, relayAddr) - if err != nil { - return cleanupOnErr(fmt.Errorf("resolve relay addr %q: %w", relayAddr, err)) - } + // We'll run the session in an inner func so we can defer cleanup + // (pcQuic.Close) per-session but still always Release() after. + err = func() error { + // Resolve relay -> concrete IP:port. + relayAddrParsed, err := resolveAddrPort(ctx, relayAddr) + if err != nil { + slog.Warn("failed to resolve relay, will pick a new relay", + slog.String("relay", relayAddr), + slog.Any("error", err)) + return nil // we'll just loop and Acquire again + } - overlayAddrs, err := stringsToPrefixes(connectResp.Addresses) - if err != nil { - return cleanupOnErr(fmt.Errorf("parse assigned addresses: %w", err)) - } + // Open per-relay PacketConn off the shared mux. + pcQuic, err := pcQuicMux.Open(&net.UDPAddr{ + IP: relayAddrParsed.Addr().AsSlice(), + Port: int(relayAddrParsed.Port()), + }) + if err != nil { + slog.Warn("failed to create multiplexed packet conn for relay, will pick a new relay", + slog.String("relay", relayAddr), + slog.Any("error", err)) + return nil // loop again + } - for _, route := range connectResp.Routes { - dst, err := netip.ParsePrefix(route.Destination) - if err != nil { - slog.Warn("Failed to parse route prefix", slog.String("prefix", route.Destination), slog.Any("error", err)) - continue - } + // Make sure we close the PacketConn when the session ends. + defer pcQuic.Close() - overlayAddrs = append(overlayAddrs, dst) + // Updater that refreshes the allocator from the server's latest view. + onConnected := func(cr *api.ConnectResponse) { + if cr == nil { + return } - - if err := handler.AddVirtualNetwork(connectResp.VNI, netstack.ToFullAddress(remoteAddr), overlayAddrs); err != nil { - return cleanupOnErr(fmt.Errorf("add virtual network: %w", err)) + newSet := sets.New[string]() + for _, a := range cr.RelayAddresses { + a = strings.TrimSpace(a) + if a != "" { + newSet.Insert(a) + } + } + // Ensure the currently-connected relay remains in the pool so + // running sessions aren't stranded. It won't be handed out to + // another slot until we Release() this one anyway. + if relayAddr != "" { + newSet.Insert(strings.TrimSpace(relayAddr)) } - currentClient = client - currentConnID = connectResp.ID + relayAddressPool.Replace(newSet) - slog.Info("Connected to relay", - slog.String("relay", relayAddr), - slog.String("id", connectResp.ID), - slog.Int("vni", int(connectResp.VNI)), - slog.Int("mtu", connectResp.MTU), - ) + slog.Info("Updated relay address pool from connect response", + slog.Int("count", newSet.Len())) + } - return nil - }, - retry.Context(ctx), - retry.Attempts(0), // until ctx canceled - retry.OnRetry(func(n uint, err error) { - slog.Warn("Reconnect attempt failed; backing off", - slog.String("relay", relayAddr), - slog.Uint64("attempt", uint64(n+1)), - slog.Any("error", err)) - }), - retry.LastErrorOnly(true), - ) + // Run the actual session lifecycle (watchdog, key rotation, etc). 
+ sessErr := manageRelayConnectionOnce(ctx, pcQuic, handler, relayAddr, tlsConf, onConnected) - if err != nil { if ctx.Err() != nil { return ctx.Err() } - slog.Error("Failed to (re)connect to relay", slog.String("relay", relayAddr), slog.Any("error", err)) - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(2 * time.Second): - } - continue - } - // Live connection: rotate keys until failure or shutdown. - waitErr := manageKeyRotation(ctx, handler, currentClient, currentConnID, connectResp.VNI, connectResp.Keys) + if sessErr != nil && !errors.Is(sessErr, context.Canceled) { + slog.Warn("Connection to relay ended; rotating to a new relay", + slog.String("relay", relayAddr), + slog.Any("error", sessErr)) + } + return nil + }() - disconnectClient() + // Release the relay for other slots before the next loop iteration. + relayAddressPool.Release(relayAddr) - if ctx.Err() != nil { - return ctx.Err() - } - if waitErr != nil && waitErr != context.Canceled { - slog.Warn("Key rotation ended; will attempt to reconnect", - slog.String("relay", relayAddr), slog.Any("error", waitErr)) + if err != nil { + return err } + + // loop: grab a (maybe different) relay next time } } -// manageKeyRotation applies initial keys and refreshes at half-life with retry on failures. +// manageKeyRotation applies initial keys and refreshes at half-life with retry +// on failures. func manageKeyRotation( ctx context.Context, handler *icx.Handler, @@ -455,6 +698,7 @@ func manageKeyRotation( select { case <-ctx.Done(): return ctx.Err() + case <-timer.C: var upd *api.UpdateKeysResponse err := retry.Do( @@ -464,7 +708,7 @@ func manageKeyRotation( return err }, retry.Context(ctx), - retry.Attempts(0), // until ctx canceled + retry.Attempts(0), // keep trying until ctx canceled retry.OnRetry(func(n uint, err error) { slog.Warn("Key update failed; backing off", slog.Uint64("attempt", uint64(n+1)), @@ -473,58 +717,95 @@ func manageKeyRotation( retry.LastErrorOnly(true), ) if err != nil { - return err // includes context cancellation + return err } - slog.Info("Rotated tunnel keys", slog.Uint64("epoch", uint64(upd.Keys.Epoch))) + + slog.Info("Rotated tunnel keys", + slog.Uint64("epoch", uint64(upd.Keys.Epoch))) + timer.Reset(applyAndSchedule(upd.Keys)) } } } -// resolveAddrPort resolves "host:port" (IPv4/IPv6/hostname) to a concrete AddrPort, preferring IPv4. -func resolveAddrPort(ctx context.Context, hostport string) (netip.AddrPort, error) { - host, portStr, err := net.SplitHostPort(hostport) - if err != nil { - return netip.AddrPort{}, fmt.Errorf("split host/port: %w", err) - } - pn, err := net.LookupPort("udp", portStr) - if err != nil { - return netip.AddrPort{}, fmt.Errorf("lookup port %q: %w", portStr, err) - } - port := uint16(pn) +// relayWatchdog monitors RX silence for a specific VNI and returns an error if +// we haven't received any packet from the remote in maxSilence. +// It polls at checkInterval and exits if ctx is canceled. +func relayWatchdog( + ctx context.Context, + handler *icx.Handler, + vni uint, + maxSilence time.Duration, + checkInterval time.Duration, +) error { + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() - // Fast-path for literal IPs. 
- if ip, err := netip.ParseAddr(host); err == nil { - return netip.AddrPortFrom(ip, port), nil - } + for { + select { + case <-ctx.Done(): + return ctx.Err() - addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) - if err != nil { - return netip.AddrPort{}, fmt.Errorf("lookup %q: %w", host, err) - } - var v4, v6 *netip.Addr - for _, a := range addrs { - if ip, ok := netip.AddrFromSlice(a.IP); ok { - if ip.Is4() && v4 == nil { - ipCopy := ip - v4 = &ipCopy - } else if ip.Is6() && v6 == nil { - ipCopy := ip - v6 = &ipCopy + case <-ticker.C: + vnet, ok := handler.GetVirtualNetwork(vni) + if !ok { + // The VNI disappeared out from under us; treat as dead. + return fmt.Errorf("relayWatchdog: VNI %d no longer present", vni) + } + + lastRxNs := vnet.Stats.LastRXUnixNano.Load() + now := time.Now() + + // If we've never received anything (0 == not set), this is suspicious, + // but we don't want to instantly kill a brand new session. + // We'll treat "never received" as "lastRx == connect time == now", + // so it only trips after maxSilence has actually elapsed. + var lastRx time.Time + if lastRxNs == 0 { + lastRx = now + } else { + lastRx = time.Unix(0, lastRxNs) + } + + silence := now.Sub(lastRx) + if silence > maxSilence { + slog.Warn("relayWatchdog: RX silence threshold exceeded; declaring tunnel dead", + slog.Uint64("vni", uint64(vni)), + slog.Duration("silence", silence), + slog.Duration("maxSilence", maxSilence), + slog.Time("lastRx", lastRx), + ) + return fmt.Errorf("rx silence (%s) exceeded max (%s)", silence, maxSilence) } } } - switch { - case v4 != nil: - return netip.AddrPortFrom(*v4, port), nil - case v6 != nil: - return netip.AddrPortFrom(*v6, port), nil - default: - return netip.AddrPort{}, fmt.Errorf("no usable A/AAAA records for %q", host) +} + +// resolveAddrPort resolves a host:port string into a netip.AddrPort by doing a +// short-lived UDP dial. This both resolves DNS and also captures the concrete +// remote address the OS actually chose. +func resolveAddrPort(ctx context.Context, relayAddr string) (netip.AddrPort, error) { + // Create a short-lived UDP connection to the host:port. + // This triggers the OS resolver and routing logic. + dialer := net.Dialer{} + conn, err := dialer.DialContext(ctx, "udp", relayAddr) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("probe dial failed for %q: %w", relayAddr, err) + } + defer conn.Close() + + // Extract the resolved remote address that the OS actually chose. + ra := conn.RemoteAddr() + udpAddr, ok := ra.(*net.UDPAddr) + if !ok { + return netip.AddrPort{}, fmt.Errorf("unexpected remote addr type: %T", ra) } + + return netip.AddrPortFrom(netip.MustParseAddr(udpAddr.IP.String()), uint16(udpAddr.Port)), nil } -func stringsToPrefixes(addrs []string) ([]netip.Prefix, error) { +// parsePrefixes parses a list of string addresses into netip.Prefixes. +func parsePrefixes(addrs []string) ([]netip.Prefix, error) { prefixes := make([]netip.Prefix, 0, len(addrs)) for _, addr := range addrs { p, err := netip.ParsePrefix(addr) @@ -535,3 +816,24 @@ func stringsToPrefixes(addrs []string) ([]netip.Prefix, error) { } return prefixes, nil } + +// healthHandler returns 200 OK when at least one tunnel connection is active, +// 503 otherwise. This is used by external health checks. +// +// Response codes: +// - 200 OK: At least one tunnel connection is active +// - 503 Service Unavailable: No active tunnel connections +// +// Body is plain text with a short summary. 
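+//
+// Example check against the default --health-addr:
+//
+//	curl -fsS http://localhost:8080/healthz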
+func healthHandler(w http.ResponseWriter, r *http.Request) { + active := connectionHealthCounter.Load() + + if active > 0 { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "OK - %d active connection(s)\n", active) + return + } + + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprintf(w, "UNHEALTHY - no active connections\n") +} diff --git a/pkg/tunnel/randalloc/random_allocator.go b/pkg/tunnel/randalloc/random_allocator.go new file mode 100644 index 0000000..a6b04f3 --- /dev/null +++ b/pkg/tunnel/randalloc/random_allocator.go @@ -0,0 +1,127 @@ +package randalloc + +import ( + "context" + "math/rand" + "sync" + + "k8s.io/apimachinery/pkg/util/sets" +) + +// RandAllocator hands out items such that, at any moment, +// no two callers hold the same item concurrently. +// Call Acquire() to get exclusive use of an item, and Release() when done. +// +// If all items are busy, Acquire() will block until one is released +// or the context is canceled. +type RandAllocator[T comparable] struct { + mu sync.Mutex + items []T + inUse map[T]bool + + // waitCh is used to notify waiters that something changed + // (i.e. an item was released). It's always a non-nil channel. + // On every Release(), we close the old channel and make a new one. + waitCh chan struct{} +} + +// NewRandAllocator constructs a RandAllocator[T] from a set of items. +func NewRandAllocator[T comparable](vals sets.Set[T]) *RandAllocator[T] { + list := vals.UnsortedList() + + ra := &RandAllocator[T]{ + items: list, + inUse: make(map[T]bool, len(list)), + waitCh: make(chan struct{}), // open, will be closed to wake waiters + } + return ra +} + +// Acquire returns an item that is not currently in use by any other caller. +// It randomizes selection among the currently-free items. +// If none are free, it waits until one is released or ctx is canceled. +func (ra *RandAllocator[T]) Acquire(ctx context.Context) (T, error) { + var zero T + + for { + ra.mu.Lock() + + // Try to grab a free item immediately. + if item, ok := ra.pickFreeLocked(); ok { + ra.inUse[item] = true + ra.mu.Unlock() + return item, nil + } + + // No item free right now. + // If the caller's context is already done, abort. + if err := ctx.Err(); err != nil { + ra.mu.Unlock() + return zero, err + } + + // Take a snapshot of the current waitCh so we can wait + // without holding the mutex. + ch := ra.waitCh + + ra.mu.Unlock() + + // Wait until either: + // - context is canceled, or + // - someone calls Release() and closes ch. + select { + case <-ctx.Done(): + return zero, ctx.Err() + case <-ch: + // An item was released; loop and retry. + } + } +} + +// Release marks an item as free again and wakes any waiters. +func (ra *RandAllocator[T]) Release(item T) { + ra.mu.Lock() + defer ra.mu.Unlock() + + if ra.inUse[item] { + delete(ra.inUse, item) + } + + // Wake all current waiters by closing waitCh, + // then create a fresh channel for future waiters. + close(ra.waitCh) + ra.waitCh = make(chan struct{}) +} + +// Replace atomically swaps the candidate item set and wakes all waiters. +// Any items currently in use may remain absent from the new item set; they +// simply won't be handed out again once released. +func (ra *RandAllocator[T]) Replace(vals sets.Set[T]) { + ra.mu.Lock() + defer ra.mu.Unlock() + + ra.items = vals.UnsortedList() + + // Wake all current waiters so they can observe the new set. + close(ra.waitCh) + ra.waitCh = make(chan struct{}) +} + +// pickFreeLocked picks a currently-free item at random. +// caller must hold ra.mu. 
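+// Selection is O(n) via rand.Perm, which is fine for small relay pools.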
+func (ra *RandAllocator[T]) pickFreeLocked() (T, bool) { + var zero T + + n := len(ra.items) + if n == 0 { + return zero, false + } + + for _, idx := range rand.Perm(n) { + item := ra.items[idx] + if !ra.inUse[item] { + return item, true + } + } + return zero, false +} diff --git a/pkg/tunnel/randalloc/random_allocator_test.go b/pkg/tunnel/randalloc/random_allocator_test.go new file mode 100644 index 0000000..37d2ddd --- /dev/null +++ b/pkg/tunnel/randalloc/random_allocator_test.go @@ -0,0 +1,317 @@ +package randalloc_test + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/apoxy-dev/apoxy/pkg/tunnel/randalloc" +) + +// helper to build a RandAllocator[string] with deterministic items +func newAllocator(addrs ...string) *randalloc.RandAllocator[string] { + return randalloc.NewRandAllocator(sets.New[string](addrs...)) +} + +func TestAcquireAndReleaseSingle(t *testing.T) { + ra := newAllocator("r1", "r2") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + addr, err := ra.Acquire(ctx) + require.NoError(t, err, "first acquire should succeed") + assert.Contains(t, []string{"r1", "r2"}, addr, "acquired addr must be from pool") + + // If we release it, we should be able to Acquire it again. + ra.Release(addr) + + addr2, err := ra.Acquire(ctx) + require.NoError(t, err, "second acquire after release should succeed") + assert.Contains(t, []string{"r1", "r2"}, addr2) +} + +func TestConcurrentUniqueAcquires(t *testing.T) { + ra := newAllocator("a", "b", "c") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var wg sync.WaitGroup + var mu sync.Mutex + got := make([]string, 0, 3) + + acquireOnce := func() { + defer wg.Done() + addr, err := ra.Acquire(ctx) + require.NoError(t, err) + + mu.Lock() + got = append(got, addr) + mu.Unlock() + } + + // Grab 3 relays in parallel; allocator size is 3 + wg.Add(3) + for i := 0; i < 3; i++ { + go acquireOnce() + } + wg.Wait() + + require.Len(t, got, 3) + + // All acquired addrs must be unique + seen := sets.New[string]() + for _, addr := range got { + assert.False(t, seen.Has(addr), "duplicate addr acquired concurrently: %s", addr) + seen.Insert(addr) + } +} + +func TestAcquireBlocksUntilRelease(t *testing.T) { + ra := newAllocator("only-one") + + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) + defer cancel1() + + // First acquire should grab the only relay. + addr1, err := ra.Acquire(ctx1) + require.NoError(t, err) + assert.Equal(t, "only-one", addr1) + + // Second acquire should block until we release. 
+ startCh := make(chan struct{}) + gotCh := make(chan string) + errCh := make(chan error) + + go func() { + close(startCh) // signal goroutine started + ctx2, cancel2 := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel2() + + addr2, err2 := ra.Acquire(ctx2) + if err2 != nil { + errCh <- err2 + return + } + gotCh <- addr2 + }() + + // make sure goroutine is actually running before we release + <-startCh + + // Briefly sleep to convince ourselves that goroutine would still be blocked + time.Sleep(50 * time.Millisecond) + + select { + case <-gotCh: + t.Fatalf("Acquire should still be blocked before Release") + case <-errCh: + t.Fatalf("Acquire errored before Release") + default: + // good, still blocked + } + + // Now Release the relay so the goroutine can proceed + ra.Release("only-one") + + // Now we expect the goroutine to finish successfully with same addr + select { + case addr2 := <-gotCh: + assert.Equal(t, "only-one", addr2, "after release, waiter should get freed relay") + case err2 := <-errCh: + t.Fatalf("blocked Acquire unexpectedly errored: %v", err2) + case <-time.After(time.Second): + t.Fatalf("Acquire did not unblock after Release") + } +} + +func TestAcquireContextCancel(t *testing.T) { + ra := newAllocator("busy") + + // Take the only relay so future Acquire will block. + ctxFirst, cancelFirst := context.WithTimeout(context.Background(), time.Second) + defer cancelFirst() + + addr, err := ra.Acquire(ctxFirst) + require.NoError(t, err) + assert.Equal(t, "busy", addr) + + // Now try to Acquire again, but with a short-lived context. + ctxBlocked, cancelBlocked := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancelBlocked() + + start := time.Now() + addr2, err2 := ra.Acquire(ctxBlocked) + elapsed := time.Since(start) + + // We expect an error due to context timeout. + require.Error(t, err2, "Acquire should fail due to context timeout while no relay is available") + assert.Empty(t, addr2, "no addr should be returned on context cancel") + + // sanity check: it should have actually waited (i.e. not return instantly) + assert.GreaterOrEqual(t, elapsed, 40*time.Millisecond) +} + +func TestReleaseBroadcastsEvenIfNotInUse(t *testing.T) { + ra := newAllocator("x") + + // Take it so allocator marks it in-use. + ctxMain, cancelMain := context.WithTimeout(context.Background(), time.Second) + defer cancelMain() + + addr, err := ra.Acquire(ctxMain) + require.NoError(t, err) + assert.Equal(t, "x", addr) + + // Now start a waiter that will block. + gotCh := make(chan string) + errCh := make(chan error) + + go func() { + ctxWait, cancelWait := context.WithTimeout(context.Background(), time.Second) + defer cancelWait() + a, err := ra.Acquire(ctxWait) + if err != nil { + errCh <- err + return + } + gotCh <- a + }() + + // Release once (normal path) to free it. + ra.Release("x") + + // Releasing again should no-op but still broadcast. + // This mainly exercises the "if ra.inUse[item]" branch not being taken. + // It should still wake waiters. + ra.Release("x") + + select { + case got := <-gotCh: + assert.Equal(t, "x", got, "waiter should eventually acquire x") + case err := <-errCh: + t.Fatalf("waiter got unexpected error: %v", err) + case <-time.After(time.Second): + t.Fatalf("waiter did not wake after Release broadcasts") + } +} + +func TestReplaceWakesWaitersAndUsesNewSet(t *testing.T) { + ra := newAllocator("old") + + // Take the only old item so future Acquire blocks. 
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + gotOld, err := ra.Acquire(ctx) + require.NoError(t, err) + assert.Equal(t, "old", gotOld) + + // Start a waiter that will block until Replace or Release happens. + gotCh := make(chan string) + errCh := make(chan error) + go func() { + ctxW, cancelW := context.WithTimeout(context.Background(), time.Second) + defer cancelW() + addr, err := ra.Acquire(ctxW) + if err != nil { + errCh <- err + return + } + gotCh <- addr + }() + + // Replace the set while "old" is still in use. + ra.Replace(sets.New[string]("new1", "new2")) + + // The waiter should wake and get one of the *new* items. + select { + case got := <-gotCh: + assert.Contains(t, []string{"new1", "new2"}, got, "waiter should receive an item from the NEW set after Replace()") + case err := <-errCh: + t.Fatalf("waiter got unexpected error: %v", err) + case <-time.After(time.Second): + t.Fatalf("waiter did not wake after Replace() broadcast") + } +} + +func TestReplaceRemovesOldItemsOnceReleased(t *testing.T) { + ra := newAllocator("keep", "drop") + + // Acquire both items so allocator marks them in use. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + a1, err := ra.Acquire(ctx) + require.NoError(t, err) + a2, err := ra.Acquire(ctx) + require.NoError(t, err) + got := sets.New[string](a1, a2) + require.True(t, got.Has("keep") && got.Has("drop"), "sanity: acquired both keep and drop") + + // Replace the item set: keep "keep", drop "drop", add "new". + ra.Replace(sets.New[string]("keep", "new")) + + // Release both old items. + ra.Release("drop") + ra.Release("keep") + + // Now future acquires should *never* return "drop". + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + defer cancel2() + + seen := sets.New[string]() + for i := 0; i < 2; i++ { + addr, err := ra.Acquire(ctx2) + require.NoError(t, err) + seen.Insert(addr) + } + + // We only expect {"keep","new"} to be available. + assert.False(t, seen.Has("drop"), `"drop" should not be reissued after being removed by Replace`) + assert.Equal(t, sets.New[string]("keep", "new"), seen, "post-Replace pool should be exactly keep+new") +} + +func TestReplaceSubsetAndReacquireBehavior(t *testing.T) { + ra := newAllocator("a", "b", "c") + + // Acquire two so they're marked in-use. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + first, err := ra.Acquire(ctx) + require.NoError(t, err) + second, err := ra.Acquire(ctx) + require.NoError(t, err) + + inUse := sets.New[string](first, second) + require.Len(t, inUse, 2) + + // Replace with a subset that keeps only "b" plus add "d". + ra.Replace(sets.New[string]("b", "d")) + + // Release both in-use items; only "b" should be eligible again (if it was one of the in-use), + // and "d" should be available. Any item not in the new set (like "a" or "c") must not reappear. + for _, it := range inUse.UnsortedList() { + ra.Release(it) + } + + // Collect the next two acquires. + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + defer cancel2() + + got := sets.New[string]() + for i := 0; i < 2; i++ { + addr, err := ra.Acquire(ctx2) + require.NoError(t, err) + got.Insert(addr) + } + + // Expect only from {"b","d"}; never "a" or "c". 
+ assert.Subset(t, []string{"b", "d"}, got.UnsortedList(), "acquires should come from the new set only") + assert.False(t, got.Has("a"), `"a" was removed by Replace and should not be returned`) + assert.False(t, got.Has("c"), `"c" was removed by Replace and should not be returned`) +}
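
A minimal usage sketch of the new randalloc API, mirroring the acquire/run/release pattern the connection slots in tunnel_run.go follow (the relay addresses here are illustrative, not taken from the patch):

package main

import (
	"context"
	"fmt"
	"time"

	"k8s.io/apimachinery/pkg/util/sets"

	"github.com/apoxy-dev/apoxy/pkg/tunnel/randalloc"
)

func main() {
	// Pool of relays; no two holders ever share an address at the same time.
	pool := randalloc.NewRandAllocator(sets.New[string]("relay-a:9443", "relay-b:9443"))

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// Acquire blocks until a relay is free or ctx is canceled.
	addr, err := pool.Acquire(ctx)
	if err != nil {
		fmt.Println("no relay available:", err)
		return
	}
	// Release hands the relay back so another slot can pick it up.
	defer pool.Release(addr)

	fmt.Println("slot would run its session against", addr)
}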