diff --git a/go.mod b/go.mod index 90e0cbf..0e44cb7 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/adrg/xdg v0.5.3 github.com/alphadose/haxmap v1.4.1 github.com/anatol/vmtest v0.0.0-20250318022921-2f32244e2f0f - github.com/apoxy-dev/icx v0.7.2 + github.com/apoxy-dev/icx v0.8.0 github.com/avast/retry-go/v4 v4.6.1 github.com/bramvdbogaerde/go-scp v1.5.0 github.com/buraksezer/olric v0.5.6 @@ -45,6 +45,7 @@ require ( github.com/google/gopacket v1.1.19 github.com/google/uuid v1.6.0 github.com/hashicorp/go-discover v0.0.0-20240726212017-342faf50e5d4 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jedib0t/go-pretty/v6 v6.4.9 github.com/julienschmidt/httprouter v1.3.0 github.com/k3s-io/kine v0.13.2 diff --git a/go.sum b/go.sum index c360280..c64f2ed 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,8 @@ github.com/apoxy-dev/apiserver-runtime v0.0.0-20250420214109-979c605051d1 h1:sAS github.com/apoxy-dev/apiserver-runtime v0.0.0-20250420214109-979c605051d1/go.mod h1:zOVeivsnCWenmbgr6kiefIExoqlbuv2xyg9SXXfbs5U= github.com/apoxy-dev/connect-ip-go v0.0.0-20250530062404-603929a73f45 h1:SwPk1n/oSVX7YwlNpC9KNH9YaYkcL/k6OfqSGVnxyyI= github.com/apoxy-dev/connect-ip-go v0.0.0-20250530062404-603929a73f45/go.mod h1:z5rtgIizc+/K27UtB0occwZgqg/mz3IqgyUJW8aubbI= -github.com/apoxy-dev/icx v0.7.2 h1:6GqlqxkjwyEwaQBAJJ40+iM6D6w46IKmKWtE/43bCUk= -github.com/apoxy-dev/icx v0.7.2/go.mod h1:Muuk3bRXTp3YB5Xj+xHOGQ/T1xVxIKJuvmMfLBXhIN4= +github.com/apoxy-dev/icx v0.8.0 h1:Aj/LWtFokyBYYFuISFqgbiWBQJpMdIN6vCMa21MIROc= +github.com/apoxy-dev/icx v0.8.0/go.mod h1:Muuk3bRXTp3YB5Xj+xHOGQ/T1xVxIKJuvmMfLBXhIN4= github.com/apoxy-dev/quic-go v0.0.0-20250530165952-53cca597715e h1:10GIpiVyKoRgCyr0J2TvJtdn17bsFHN+ROWkeVJpcOU= github.com/apoxy-dev/quic-go v0.0.0-20250530165952-53cca597715e/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ= github.com/apoxy-dev/upgrade-cli v0.0.0-20240213232412-a56c3a52fa0e h1:FBNxMQD93z2ththupB/BYKLEaMWaEr+G+sJWJqU2wC4= @@ -643,6 +643,8 @@ github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/logutils v1.0.0 h1:dLEQVugN8vlakKOUE3ihGLTZJRB4j+M2cdTm/ORI65Y= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= diff --git a/pkg/cmd/alpha/tunnel.go b/pkg/cmd/alpha/tunnel.go index d14f22d..2ad8cfa 100644 --- a/pkg/cmd/alpha/tunnel.go +++ b/pkg/cmd/alpha/tunnel.go @@ -21,6 +21,7 @@ import ( "github.com/apoxy-dev/apoxy/pkg/netstack" "github.com/apoxy-dev/apoxy/pkg/tunnel/api" "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate" + "github.com/apoxy-dev/apoxy/pkg/tunnel/conntrackpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/router" ) @@ -60,6 +61,9 @@ var tunnelRunCmd = &cobra.Command{ defer pcGeneve.Close() defer pcQuic.Close() + pcQuicMultiplexed := conntrackpc.New(pcQuic, conntrackpc.Options{}) + defer pcQuicMultiplexed.Close() + // Context and goroutines. g, ctx := errgroup.WithContext(cmd.Context()) @@ -119,9 +123,20 @@ var tunnelRunCmd = &cobra.Command{ } relay := addr g.Go(func() error { - // TODO (dpeckett): we will need to create a kind of multiplexed packetconn - // so that each QUIC client gets its own virtual private connection from pcQuic. - // This will be based on the remote ip presumably. + relayAddr, err := resolveAddrPort(ctx, relay) + if err != nil { + return fmt.Errorf("failed to resolve relay addr %q: %w", relay, err) + } + + pcQuic, err := pcQuicMultiplexed.Open(&net.UDPAddr{ + IP: relayAddr.Addr().AsSlice(), + Port: int(relayAddr.Port()), + }) + if err != nil { + return fmt.Errorf("failed to create multiplexed packet conn for relay %q: %w", relay, err) + } + defer pcQuic.Close() + return manageRelayConnection(ctx, pcQuic, getHandler, relay, tlsConf) }) } @@ -237,10 +252,10 @@ func manageRelayConnection( return cleanupOnErr(fmt.Errorf("init router: %w", err)) } - // Parse relay addr - remoteAddr, err := netip.ParseAddrPort(relayAddr) + // Resolve relay addr (supports hostname:port and ip:port) + remoteAddr, err := resolveAddrPort(ctx, relayAddr) if err != nil { - return cleanupOnErr(fmt.Errorf("parse relay addr %q: %w", relayAddr, err)) + return cleanupOnErr(fmt.Errorf("resolve relay addr %q: %w", relayAddr, err)) } overlayAddrs, err := stringsToPrefixes(connectResp.Addresses) @@ -390,6 +405,51 @@ func manageKeyRotation( } } +// resolveAddrPort accepts "host:port" where host may be a hostname or IP +// (IPv4/IPv6, with or without brackets) and returns a concrete netip.AddrPort. +func resolveAddrPort(ctx context.Context, hostport string) (netip.AddrPort, error) { + host, portStr, err := net.SplitHostPort(hostport) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("split host/port: %w", err) + } + pn, err := net.LookupPort("udp", portStr) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("lookup port %q: %w", portStr, err) + } + port := uint16(pn) + + // If host is already an IP, use it. + if ip, err := netip.ParseAddr(host); err == nil { + return netip.AddrPortFrom(ip, port), nil + } + + // Resolve hostname. Prefer IPv4, then IPv6. + addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("lookup %q: %w", host, err) + } + var v4, v6 *netip.Addr + for _, a := range addrs { + if ip, ok := netip.AddrFromSlice(a.IP); ok { + if ip.Is4() && v4 == nil { + ipCopy := ip + v4 = &ipCopy + } else if ip.Is6() && v6 == nil { + ipCopy := ip + v6 = &ipCopy + } + } + } + switch { + case v4 != nil: + return netip.AddrPortFrom(*v4, port), nil + case v6 != nil: + return netip.AddrPortFrom(*v6, port), nil + default: + return netip.AddrPort{}, fmt.Errorf("no usable A/AAAA records for %q", host) + } +} + func stringsToPrefixes(addrs []string) ([]netip.Prefix, error) { prefixes := make([]netip.Prefix, 0, len(addrs)) for _, addr := range addrs { diff --git a/pkg/cmd/alpha/tunnel_test.go b/pkg/cmd/alpha/tunnel_test.go index 4fe30c8..e44bba3 100644 --- a/pkg/cmd/alpha/tunnel_test.go +++ b/pkg/cmd/alpha/tunnel_test.go @@ -23,7 +23,6 @@ import ( "github.com/apoxy-dev/apoxy/pkg/tunnel/connection" "github.com/apoxy-dev/apoxy/pkg/tunnel/controllers" "github.com/apoxy-dev/apoxy/pkg/tunnel/hasher" - "github.com/apoxy-dev/apoxy/pkg/tunnel/router" ) func TestTunnelRun(t *testing.T) { @@ -100,6 +99,10 @@ func startRelay(t *testing.T, token string, onConnect func(context.Context, stri rtr.On("Start", mock.Anything).Return(nil) rtr.On("Close").Return(nil) + rtr.On("AddAddr", mock.Anything, mock.Anything).Return(nil) + rtr.On("DelAddr", mock.Anything).Return(nil) + rtr.On("AddRoute", mock.Anything).Return(nil) + rtr.On("DelRoute", mock.Anything).Return(nil) r := tunnel.NewRelay("relay-it", pc, serverCert, h, idHasher, rtr) r.SetCredentials("test-tunnel", token) @@ -146,15 +149,6 @@ func (m *mockRouter) AddAddr(addr netip.Prefix, tun connection.Connection) error return args.Error(0) } -func (m *mockRouter) ListAddrs() ([]netip.Prefix, error) { - args := m.Called() - var addrs []netip.Prefix - if v := args.Get(0); v != nil { - addrs = v.([]netip.Prefix) - } - return addrs, args.Error(1) -} - func (m *mockRouter) DelAddr(addr netip.Prefix) error { args := m.Called(addr) return args.Error(0) @@ -170,24 +164,6 @@ func (m *mockRouter) DelRoute(dst netip.Prefix) error { return args.Error(0) } -func (m *mockRouter) ListRoutes() ([]router.TunnelRoute, error) { - args := m.Called() - var routes []router.TunnelRoute - if v := args.Get(0); v != nil { - routes = v.([]router.TunnelRoute) - } - return routes, args.Error(1) -} - -func (m *mockRouter) LocalAddresses() ([]netip.Prefix, error) { - args := m.Called() - var addrs []netip.Prefix - if v := args.Get(0); v != nil { - addrs = v.([]netip.Prefix) - } - return addrs, args.Error(1) -} - func (m *mockRouter) Close() error { args := m.Called() return args.Error(0) diff --git a/pkg/tunnel/adapter/connection.go b/pkg/tunnel/adapter/connection.go deleted file mode 100644 index 79cdac9..0000000 --- a/pkg/tunnel/adapter/connection.go +++ /dev/null @@ -1,132 +0,0 @@ -package adapter - -import ( - "fmt" - "net/netip" - "sync" - "sync/atomic" - - "github.com/apoxy-dev/apoxy/pkg/netstack" - "github.com/apoxy-dev/icx" -) - -// Connection is a connection like abstraction over an icx virtual network. -type Connection struct { - mu sync.Mutex - id string - handler *icx.Handler - localAddr netip.AddrPort - remoteAddr netip.AddrPort - vni *uint - overlayAddr *netip.Prefix - keyEpoch atomic.Uint32 -} - -// NewConnection creates a new Connection instance. -func NewConnection(id string, handler *icx.Handler, localAddr, remoteAddr netip.AddrPort) *Connection { - return &Connection{ - id: id, - handler: handler, - localAddr: localAddr, - remoteAddr: remoteAddr, - } -} - -func (c *Connection) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.vni != nil { - if err := c.handler.RemoveVirtualNetwork(*c.vni); err != nil { - return err - } - c.vni = nil - c.overlayAddr = nil - } - return nil -} - -func (c *Connection) ID() string { - return c.id -} - -func (c *Connection) VNI() *uint { - c.mu.Lock() - defer c.mu.Unlock() - - return c.vni -} - -// Set the VNI assigned to this connection. -func (c *Connection) SetVNI(vni uint) error { - c.mu.Lock() - defer c.mu.Unlock() - - // No change - if c.vni != nil && *c.vni == vni { - return nil - } - - // Remove existing VNI if set - if c.vni != nil { - if err := c.handler.RemoveVirtualNetwork(*c.vni); err != nil { - return err - } - c.vni = nil - } - - // Add new VNI - var addrs []netip.Prefix - if c.overlayAddr != nil { - addrs = []netip.Prefix{*c.overlayAddr} - } - - if err := c.handler.AddVirtualNetwork(vni, netstack.ToFullAddress(c.remoteAddr), addrs); err != nil { - return fmt.Errorf("failed to add virtual network %d: %w", vni, err) - } - c.vni = &vni - - return nil -} - -// OverlayAddress returns the overlay address/cidr assigned to this connection. -func (c *Connection) OverlayAddress() string { - c.mu.Lock() - defer c.mu.Unlock() - - if c.overlayAddr != nil { - return c.overlayAddr.String() - } - return "" -} - -// Set the overlay address/cidr assigned to this connection. -func (c *Connection) SetOverlayAddress(addr string) error { - c.mu.Lock() - defer c.mu.Unlock() - - p, err := netip.ParsePrefix(addr) - if err != nil { - return fmt.Errorf("failed to parse overlay address %q: %w", addr, err) - } - - // No change - if c.overlayAddr != nil && (*c.overlayAddr).String() == p.String() { - return nil - } - c.overlayAddr = &p - - // If a VNI is active, update its allowed prefixes in-place. - if c.vni != nil { - if err := c.handler.UpdateVirtualNetworkAddrs(*c.vni, []netip.Prefix{p}); err != nil { - return fmt.Errorf("failed to update virtual network %d with address %q: %w", *c.vni, addr, err) - } - } - - return nil -} - -// IncrementKeyEpoch increments and returns the current key epoch for this connection. -func (c *Connection) IncrementKeyEpoch() uint32 { - return c.keyEpoch.Add(1) -} diff --git a/pkg/tunnel/connection.go b/pkg/tunnel/connection.go new file mode 100644 index 0000000..a324c41 --- /dev/null +++ b/pkg/tunnel/connection.go @@ -0,0 +1,186 @@ +package tunnel + +import ( + "fmt" + "net/netip" + "sync" + "sync/atomic" + + "github.com/apoxy-dev/apoxy/pkg/netstack" + "github.com/apoxy-dev/apoxy/pkg/tunnel/controllers" + "github.com/apoxy-dev/apoxy/pkg/tunnel/router" + "github.com/apoxy-dev/icx" +) + +var _ controllers.Connection = (*connection)(nil) + +// connection is a connection like abstraction over an icx virtual network. +type connection struct { + mu sync.Mutex + id string + handler *icx.Handler + router router.Router + localAddr netip.AddrPort + remoteAddr netip.AddrPort + vni *uint + overlayAddr *netip.Prefix + keyEpoch atomic.Uint32 +} + +// Close tears down the VNI and removes any router state. +func (c *connection) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + // Remove router addr first so traffic stops before tearing down the VNI. + if c.overlayAddr != nil && c.router != nil { + if err := c.router.DelAddr(*c.overlayAddr); err != nil { + return fmt.Errorf("failed to remove router addr %q: %w", c.overlayAddr.String(), err) + } + if err := c.router.DelRoute(*c.overlayAddr); err != nil { + return fmt.Errorf("failed to remove router route %q: %w", c.overlayAddr.String(), err) + } + } + + if c.vni != nil { + if err := c.handler.RemoveVirtualNetwork(*c.vni); err != nil { + return err + } + c.vni = nil + c.overlayAddr = nil + } + + return nil +} + +func (c *connection) ID() string { + return c.id +} + +func (c *connection) VNI() *uint { + c.mu.Lock() + defer c.mu.Unlock() + + return c.vni +} + +// Set the VNI assigned to this connection. +func (c *connection) SetVNI(vni uint) error { + c.mu.Lock() + defer c.mu.Unlock() + + // No change + if c.vni != nil && *c.vni == vni { + return nil + } + + // Remove existing VNI if set + if c.vni != nil { + if err := c.handler.RemoveVirtualNetwork(*c.vni); err != nil { + return err + } + c.vni = nil + } + + // Add new VNI + var addrs []netip.Prefix + if c.overlayAddr != nil { + addrs = []netip.Prefix{*c.overlayAddr} + } + + if err := c.handler.AddVirtualNetwork(vni, netstack.ToFullAddress(c.remoteAddr), addrs); err != nil { + return fmt.Errorf("failed to add virtual network %d: %w", vni, err) + } + c.vni = &vni + + return nil +} + +// OverlayAddress returns the overlay address/cidr assigned to this connection. +func (c *connection) OverlayAddress() string { + c.mu.Lock() + defer c.mu.Unlock() + + if c.overlayAddr != nil { + return c.overlayAddr.String() + } + return "" +} + +// SetOverlayAddress sets the overlay address/cidr and updates router + VNI. +func (c *connection) SetOverlayAddress(addr string) error { + c.mu.Lock() + defer c.mu.Unlock() + + p, err := netip.ParsePrefix(addr) + if err != nil { + return fmt.Errorf("failed to parse overlay address %q: %w", addr, err) + } + + // No change + if c.overlayAddr != nil && c.overlayAddr.String() == p.String() { + return nil + } + + // Keep the old value for router rollback if needed. + var old *netip.Prefix + if c.overlayAddr != nil { + tmp := *c.overlayAddr + old = &tmp + } + + // Program router: add new, then delete old (to avoid a gap). + if c.router != nil { + if err := c.router.AddAddr(p, nil); err != nil { + return fmt.Errorf("router.AddAddr(%s) failed: %w", p.String(), err) + } + if err := c.router.AddRoute(p); err != nil { + // Try to roll back: remove the new addr to avoid duplicates. + _ = c.router.DelAddr(p) + _ = c.router.DelRoute(p) + return fmt.Errorf("router.AddRoute(%s) failed: %w", p.String(), err) + } + if old != nil { + if err := c.router.DelAddr(*old); err != nil { + // Try to roll back: remove the new addr to avoid duplicates. + _ = c.router.DelAddr(p) + return fmt.Errorf("router.DelAddr(%s) failed: %w", old.String(), err) + } + if err := c.router.DelRoute(*old); err != nil { + return fmt.Errorf("router.DelRoute(%s) failed: %w", old.String(), err) + } + } + } + + // Update in-memory state. + c.overlayAddr = &p + + // 2) If a VNI is active, update its allowed prefixes in-place. + if c.vni != nil { + if err := c.handler.UpdateVirtualNetworkAddrs(*c.vni, []netip.Prefix{p}); err != nil { + // Attempt to roll back router state to old addr on failure. + if c.router != nil { + _ = c.router.DelAddr(p) + _ = c.router.DelRoute(p) + if old != nil { + _ = c.router.AddAddr(*old, nil) + _ = c.router.AddRoute(*old) + } + } + // Restore in-memory value. + c.overlayAddr = old + if old == nil { + // If there was no old addr, also clear it. + c.overlayAddr = nil + } + return fmt.Errorf("failed to update virtual network %d with address %q: %w", *c.vni, addr, err) + } + } + + return nil +} + +// IncrementKeyEpoch increments and returns the current key epoch for this connection. +func (c *connection) IncrementKeyEpoch() uint32 { + return c.keyEpoch.Add(1) +} diff --git a/pkg/tunnel/conntrackpc/conntrackpc.go b/pkg/tunnel/conntrackpc/conntrackpc.go new file mode 100644 index 0000000..1d9f731 --- /dev/null +++ b/pkg/tunnel/conntrackpc/conntrackpc.go @@ -0,0 +1,418 @@ +// Package conntrackpc provides a conntrack-style multiplexer for net.PacketConn, +// suitable for QUIC clients that want multiple "virtual" PacketConns over one UDP socket. +package conntrackpc + +import ( + "errors" + "net" + "sync" + "time" + + "github.com/hashicorp/golang-lru/v2/expirable" +) + +type Options struct { + // If true, a new VirtualPacketConn is auto-created on the first inbound packet + // seen from a remote address not yet in the table. + AutoCreate bool + + // TTL is the idle timeout for a flow. If no traffic touches the flow for this duration, + // it is evicted and closed by the cache. + TTL time.Duration + + // MaxFlows bounds memory usage; oldest/expired flows are evicted first. + MaxFlows int + + // Size of each per-flow inbound queue (non-blocking fanout). + RxBufSize int + + // If true, vconn.WriteTo's addr parameter can change the remote and re-key the flow. + AllowAddrOverrideOnWrite bool +} + +func (o Options) withDefaults() Options { + if o.TTL <= 0 { + o.TTL = 2 * time.Minute + } + if o.MaxFlows <= 0 { + o.MaxFlows = 1024 + } + if o.RxBufSize <= 0 { + o.RxBufSize = 64 + } + return o +} + +type ConntrackPacketConn struct { + underlying net.PacketConn + localAddr net.Addr + opts Options + + mu sync.RWMutex + flows *expirable.LRU[string, *VirtualPacketConn] + closed bool + readErr error + + wg sync.WaitGroup + stopRead chan struct{} +} + +func New(underlying net.PacketConn, opts Options) *ConntrackPacketConn { + opts = opts.withDefaults() + + ct := &ConntrackPacketConn{ + underlying: underlying, + localAddr: underlying.LocalAddr(), + opts: opts, + stopRead: make(chan struct{}), + } + + // Only close the flow if the evicted key is still the vconn's current key. + // This makes removals during re-keying (oldKey removal) a no-op. + onEvicted := func(k string, v *VirtualPacketConn) { + if v != nil { + // If v.key changed (due to re-key), skip closing. + if v.key == k { + _ = v.closeLocked(errFlowExpired) + } + } + } + ct.flows = expirable.NewLRU[string, *VirtualPacketConn](opts.MaxFlows, onEvicted, opts.TTL) + + ct.wg.Add(1) + go ct.readLoop() + + return ct +} + +func (c *ConntrackPacketConn) LocalAddr() net.Addr { return c.localAddr } +func (c *ConntrackPacketConn) SetDeadline(t time.Time) error { return c.underlying.SetDeadline(t) } +func (c *ConntrackPacketConn) SetReadDeadline(t time.Time) error { + return c.underlying.SetReadDeadline(t) +} +func (c *ConntrackPacketConn) SetWriteDeadline(t time.Time) error { + return c.underlying.SetWriteDeadline(t) +} +func (c *ConntrackPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + return c.underlying.ReadFrom(b) +} +func (c *ConntrackPacketConn) WriteTo(b []byte, a net.Addr) (int, error) { + return c.underlying.WriteTo(b, a) +} + +func (c *ConntrackPacketConn) Close() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil + } + c.closed = true + close(c.stopRead) + + // drain/close all flows by removing keys (triggers OnEvicted) + keys := c.flows.Keys() + c.mu.Unlock() + + for _, k := range keys { + c.flows.Remove(k) + } + err := c.underlying.Close() + c.wg.Wait() + return err +} + +// Open returns (or creates) a VirtualPacketConn bound to the provided remote. +func (c *ConntrackPacketConn) Open(remote *net.UDPAddr) (*VirtualPacketConn, error) { + if remote == nil { + return nil, errors.New("remote addr required") + } + key := remote.String() + + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil, errConntrackClosed + } + + if v, ok := c.flows.Get(key); ok && !v.isClosed() { + // Refresh TTL by re-adding. + c.flows.Add(key, v) + return v, nil + } + + v := newVirtual(c, key, remote, c.opts.RxBufSize) + c.flows.Add(key, v) + return v, nil +} + +var ( + errConntrackClosed = errors.New("conntrack: closed") + errFlowExpired = errors.New("conntrack: flow expired") +) + +func (c *ConntrackPacketConn) readLoop() { + defer c.wg.Done() + buf := make([]byte, 64*1024) + + for { + select { + case <-c.stopRead: + return + default: + } + + n, from, err := c.underlying.ReadFrom(buf) + if err != nil { + c.mu.Lock() + c.readErr = err + // Close all flows + for _, k := range c.flows.Keys() { + if v, ok := c.flows.Peek(k); ok && v != nil { + _ = v.closeLocked(err) + } + } + c.mu.Unlock() + return + } + if n == 0 { + continue + } + + key := from.String() + + c.mu.Lock() + v, ok := c.flows.Get(key) + if !ok { + if !c.opts.AutoCreate { + c.mu.Unlock() + continue + } + udpFrom, _ := from.(*net.UDPAddr) + v = newVirtual(c, key, udpFrom, c.opts.RxBufSize) + c.flows.Add(key, v) // registers & sets TTL + } else { + // refresh TTL on activity + c.flows.Add(key, v) + } + + // non-blocking deliver; drop if back-pressured + select { + case v.inbound <- append([]byte(nil), buf[:n]...): + v.touch() + default: + // drop to avoid HOL blocking + } + c.mu.Unlock() + } +} + +type VirtualPacketConn struct { + parent *ConntrackPacketConn + key string + remote *net.UDPAddr + inbound chan []byte + closedCh chan struct{} + + rdMu sync.Mutex + rdDeadline time.Time + rdDeadlineSet bool + + wrMu sync.Mutex + wrDeadline time.Time + wrDeadlineSet bool +} + +func newVirtual(parent *ConntrackPacketConn, key string, remote *net.UDPAddr, rx int) *VirtualPacketConn { + return &VirtualPacketConn{ + parent: parent, + key: key, + remote: cloneUDPAddr(remote), + inbound: make(chan []byte, rx), + closedCh: make(chan struct{}), + } +} + +func (v *VirtualPacketConn) isClosed() bool { + select { + case <-v.closedCh: + return true + default: + return false + } +} + +func (v *VirtualPacketConn) closeLocked(_ error) error { + select { + case <-v.closedCh: + return nil + default: + close(v.closedCh) + // drain inbound + for { + select { + case <-v.inbound: + default: + return nil + } + } + } +} + +func (v *VirtualPacketConn) touch() { + // Refresh TTL by re-adding into the LRU. + v.parent.flows.Add(v.key, v) +} + +func (v *VirtualPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + // Handle deadline + timer := v.nextReadTimer() + if timer != nil { + defer timer.Stop() + } + + select { + case <-v.closedCh: + return 0, nil, net.ErrClosed + case <-timerC(timer): + return 0, nil, timeoutErr("read") + case pkt := <-v.inbound: + v.touch() + n := copy(b, pkt) + return n, cloneUDPAddr(v.remote), nil + } +} + +func (v *VirtualPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + if v.isClosed() { + return 0, net.ErrClosed + } + remote := v.remote + + // Optional remote override + re-keying + if v.parent.opts.AllowAddrOverrideOnWrite && addr != nil { + if ua, ok := addr.(*net.UDPAddr); ok { + newKey := ua.String() + if newKey != v.key { + // Re-key safely: update fields, add new key, then remove old key. + v.parent.mu.Lock() + oldKey := v.key + v.key = newKey + v.remote = cloneUDPAddr(ua) + v.parent.flows.Add(newKey, v) + v.parent.flows.Remove(oldKey) // onEvicted won't close us now + v.parent.mu.Unlock() + } else { + v.remote = cloneUDPAddr(ua) + } + remote = ua + } + } + + // Respect write deadline by temporarily setting it on the shared socket. + timer := v.nextWriteTimer() + if timer != nil { + defer timer.Stop() + } + if deadline, ok := v.getWriteDeadline(); ok { + _ = v.parent.underlying.SetWriteDeadline(deadline) + defer v.parent.underlying.SetWriteDeadline(time.Time{}) + } + + n, err := v.parent.underlying.WriteTo(b, remote) + if err == nil { + v.touch() + } + return n, err +} + +func (v *VirtualPacketConn) Close() error { + // Remove from LRU (will trigger OnEvicted -> closeLocked) + v.parent.flows.Remove(v.key) + return nil +} + +func (v *VirtualPacketConn) LocalAddr() net.Addr { return v.parent.localAddr } + +func (v *VirtualPacketConn) SetDeadline(t time.Time) error { + _ = v.SetReadDeadline(t) + _ = v.SetWriteDeadline(t) + return nil +} +func (v *VirtualPacketConn) SetReadDeadline(t time.Time) error { + v.rdMu.Lock() + v.rdDeadline = t + v.rdDeadlineSet = !t.IsZero() + v.rdMu.Unlock() + return nil +} +func (v *VirtualPacketConn) SetWriteDeadline(t time.Time) error { + v.wrMu.Lock() + v.wrDeadline = t + v.wrDeadlineSet = !t.IsZero() + v.wrMu.Unlock() + return nil +} + +func (v *VirtualPacketConn) nextReadTimer() *time.Timer { + v.rdMu.Lock() + defer v.rdMu.Unlock() + if !v.rdDeadlineSet { + return nil + } + d := time.Until(v.rdDeadline) + if d <= 0 { + d = time.Nanosecond + } + return time.NewTimer(d) +} +func (v *VirtualPacketConn) getWriteDeadline() (time.Time, bool) { + v.wrMu.Lock() + defer v.wrMu.Unlock() + return v.wrDeadline, v.wrDeadlineSet +} +func (v *VirtualPacketConn) nextWriteTimer() *time.Timer { + v.wrMu.Lock() + defer v.wrMu.Unlock() + if !v.wrDeadlineSet { + return nil + } + d := time.Until(v.wrDeadline) + if d <= 0 { + d = time.Nanosecond + } + return time.NewTimer(d) +} + +func timerC(t *time.Timer) <-chan time.Time { + if t == nil { + return nil + } + return t.C +} + +func cloneUDPAddr(a *net.UDPAddr) *net.UDPAddr { + if a == nil { + return nil + } + out := *a + if ip := a.IP; ip != nil { + cp := make([]byte, len(ip)) + copy(cp, ip) + out.IP = cp + } + return &out +} + +func timeoutErr(op string) error { + type t interface { + Timeout() bool + Error() string + } + return &net.OpError{Op: op, Err: errTimeout{}} +} + +type errTimeout struct{} + +func (errTimeout) Error() string { return "i/o timeout" } +func (errTimeout) Timeout() bool { return true } +func (errTimeout) Temporary() bool { return true } diff --git a/pkg/tunnel/conntrackpc/conntrackpc_test.go b/pkg/tunnel/conntrackpc/conntrackpc_test.go new file mode 100644 index 0000000..923754f --- /dev/null +++ b/pkg/tunnel/conntrackpc/conntrackpc_test.go @@ -0,0 +1,270 @@ +package conntrackpc_test + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/apoxy-dev/apoxy/pkg/tunnel/conntrackpc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpenSendReceive(t *testing.T) { + under, local := makeUDP(t) + ct := conntrackpc.New(under, conntrackpc.Options{ + AutoCreate: false, + TTL: time.Minute, + MaxFlows: 32, + RxBufSize: 8, + }) + t.Cleanup(func() { _ = ct.Close() }) + + peerPC, peerAddr := makeUDP(t) + + // Open a virtual connection bound to peerAddr and exchange packets. + v, err := ct.Open(peerAddr) + require.NoError(t, err) + + // Send from peer -> ct -> v + payload1 := []byte("hello from peer") + sendTo(t, peerPC, local, payload1) + + buf := make([]byte, 1500) + // Virtual read should deliver exactly what peer sent, and report peer's addr. + require.NoError(t, v.SetReadDeadline(time.Now().Add(2*time.Second))) + n, addr, err := v.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, peerAddr.String(), addr.String()) + assert.Equal(t, payload1, buf[:n]) + + // Send from v -> peer + payload2 := []byte("hi back from vconn") + nw, err := v.WriteTo(payload2, nil) // nil addr => use bound remote + require.NoError(t, err) + assert.Equal(t, len(payload2), nw) + + got, from := recvFrom(t, peerPC, 2*time.Second) + assert.Equal(t, local.String(), from.String()) + assert.Equal(t, payload2, got) +} + +func TestAutoCreateOnInbound(t *testing.T) { + under, local := makeUDP(t) + ct := conntrackpc.New(under, conntrackpc.Options{ + AutoCreate: true, + TTL: time.Minute, + MaxFlows: 32, + RxBufSize: 8, + }) + t.Cleanup(func() { _ = ct.Close() }) + + peerPC, peerAddr := makeUDP(t) + + // Deliver a packet before calling Open: should auto-create the flow and queue it. + payload := []byte("first contact") + sendTo(t, peerPC, local, payload) + + // Now Open should return the existing virtual conn, with the first packet waiting. + v, err := ct.Open(peerAddr) + require.NoError(t, err) + + buf := make([]byte, 1500) + require.NoError(t, v.SetReadDeadline(time.Now().Add(2*time.Second))) + n, addr, err := v.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, peerAddr.String(), addr.String()) + assert.Equal(t, payload, buf[:n]) +} + +func TestReadDeadlineTimeout(t *testing.T) { + under, _ := makeUDP(t) + ct := conntrackpc.New(under, conntrackpc.Options{ + AutoCreate: false, + TTL: time.Minute, + MaxFlows: 32, + RxBufSize: 8, + }) + t.Cleanup(func() { _ = ct.Close() }) + + _, peerAddr := makeUDP(t) + v, err := ct.Open(peerAddr) + require.NoError(t, err) + + // No packets inbound; a near-term read deadline should time out. + deadline := time.Now().Add(50 * time.Millisecond) + require.NoError(t, v.SetReadDeadline(deadline)) + + buf := make([]byte, 1500) + _, _, err = v.ReadFrom(buf) + var nerr net.Error + require.ErrorAs(t, err, &nerr) + assert.True(t, nerr.Timeout(), "expected timeout error") +} + +func TestTTLExpiryEvictsAndClosesFlow(t *testing.T) { + under, _ := makeUDP(t) + ct := conntrackpc.New(under, conntrackpc.Options{ + AutoCreate: false, + TTL: 80 * time.Millisecond, + MaxFlows: 32, + RxBufSize: 8, + }) + t.Cleanup(func() { _ = ct.Close() }) + + _, peerAddr := makeUDP(t) + v, err := ct.Open(peerAddr) + require.NoError(t, err) + + // Wait past TTL plus a little. The LRU eviction happens when TTL expires, + // driven by cache access/ops; NewLRU with expirable TTL evicts lazily on Ops. + // We trigger an op by opening another key to ensure eviction occurs. + time.Sleep(120 * time.Millisecond) + + // Touch the cache to provoke TTL cleanup; use a different dummy remote. + _, other := makeUDP(t) + _, _ = ct.Open(other) // triggers internal add + housekeeping + + // The old vconn should now be closed; a read should return net.ErrClosed quickly. + require.NoError(t, v.SetReadDeadline(time.Now().Add(50*time.Millisecond))) + _, _, err = v.ReadFrom(make([]byte, 1)) + require.Error(t, err) + assert.True(t, errors.Is(err, net.ErrClosed), "expected net.ErrClosed after TTL eviction") +} + +func TestMaxFlowsEvictsOldestAndCloses(t *testing.T) { + under, _ := makeUDP(t) + ct := conntrackpc.New(under, conntrackpc.Options{ + AutoCreate: false, + TTL: time.Minute, + MaxFlows: 1, // only one flow allowed + RxBufSize: 8, + }) + t.Cleanup(func() { _ = ct.Close() }) + + _, a := makeUDP(t) + va, err := ct.Open(a) + require.NoError(t, err) + + // Open a second flow; LRU should evict 'a' and close it. + _, b := makeUDP(t) + vb, err := ct.Open(b) + require.NoError(t, err) + require.NotNil(t, vb) + + // Reading from the evicted first flow should yield net.ErrClosed. + require.NoError(t, va.SetReadDeadline(time.Now().Add(50*time.Millisecond))) + _, _, err = va.ReadFrom(make([]byte, 1)) + require.Error(t, err) + assert.True(t, errors.Is(err, net.ErrClosed)) +} + +func TestAllowAddrOverrideOnWriteRekeysFlow(t *testing.T) { + under, local := makeUDP(t) + ct := conntrackpc.New(under, conntrackpc.Options{ + AutoCreate: true, + TTL: time.Minute, + MaxFlows: 32, + RxBufSize: 8, + AllowAddrOverrideOnWrite: true, + }) + t.Cleanup(func() { _ = ct.Close() }) + + // Two peers + peer1PC, peer1 := makeUDP(t) + peer2PC, peer2 := makeUDP(t) + + // Establish flow with peer1 + v, err := ct.Open(peer1) + require.NoError(t, err) + + // Write to peer2 using override; this should re-key the virtual flow to peer2. + msg := []byte("rekey to peer2") + nw, err := v.WriteTo(msg, peer2) + require.NoError(t, err) + assert.Equal(t, len(msg), nw) + + // Peer2 should receive it. + got, from := recvFrom(t, peer2PC, 2*time.Second) + assert.Equal(t, local.String(), from.String()) + assert.Equal(t, msg, got) + + // Now send from peer2 back to ct; v should receive it (flow has re-keyed). + reply := []byte("ack from peer2") + sendTo(t, peer2PC, local, reply) + + buf := make([]byte, 1500) + require.NoError(t, v.SetReadDeadline(time.Now().Add(2*time.Second))) + n, addr, err := v.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, peer2.String(), addr.String()) + assert.Equal(t, reply, buf[:n]) + + // And a packet from peer1 should auto-create a *new* flow (since v moved). + sendTo(t, peer1PC, local, []byte("peer1 still here")) + v1, err := ct.Open(peer1) // should be a different handle than v + require.NoError(t, err) + require.NoError(t, v1.SetReadDeadline(time.Now().Add(2*time.Second))) + n, _, err = v1.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, []byte("peer1 still here"), buf[:n]) +} + +func TestClosePropagatesToFlows(t *testing.T) { + under, _ := makeUDP(t) + ct := conntrackpc.New(under, conntrackpc.Options{ + AutoCreate: true, + TTL: time.Minute, + MaxFlows: 32, + RxBufSize: 8, + }) + _, peerAddr := makeUDP(t) + + v, err := ct.Open(peerAddr) + require.NoError(t, err) + + // Close conntrack; read on v should promptly return net.ErrClosed. + require.NoError(t, ct.Close()) + + require.NoError(t, v.SetReadDeadline(time.Now().Add(100*time.Millisecond))) + _, _, err = v.ReadFrom(make([]byte, 1)) + require.Error(t, err) + assert.True(t, errors.Is(err, net.ErrClosed)) + + // Writes from peer should now fail at the underlying since it's closed. + err = v.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + require.NoError(t, err) + _, err = v.WriteTo([]byte("x"), ct.LocalAddr()) + assert.Error(t, err) +} + +// makeUDP binds a UDP socket on loopback and returns it plus its *net.UDPAddr. +func makeUDP(t *testing.T) (net.PacketConn, *net.UDPAddr) { + t.Helper() + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { _ = pc.Close() }) + + ua, ok := pc.LocalAddr().(*net.UDPAddr) + require.True(t, ok) + return pc, ua +} + +// recvFrom reads one datagram with a short deadline. +func recvFrom(t *testing.T, pc net.PacketConn, d time.Duration) ([]byte, net.Addr) { + t.Helper() + require.NoError(t, pc.SetReadDeadline(time.Now().Add(d))) + buf := make([]byte, 64*1024) + n, from, err := pc.ReadFrom(buf) + require.NoError(t, err) + return append([]byte(nil), buf[:n]...), from +} + +// sendTo writes one datagram with a short deadline. +func sendTo(t *testing.T, pc net.PacketConn, to net.Addr, payload []byte) { + t.Helper() + require.NoError(t, pc.SetWriteDeadline(time.Now().Add(2*time.Second))) + _, err := pc.WriteTo(payload, to) + require.NoError(t, err) +} diff --git a/pkg/tunnel/controllers/tunnel_agent_reconciler.go b/pkg/tunnel/controllers/tunnel_agent_reconciler.go index 82cc560..543fc81 100644 --- a/pkg/tunnel/controllers/tunnel_agent_reconciler.go +++ b/pkg/tunnel/controllers/tunnel_agent_reconciler.go @@ -3,6 +3,7 @@ package controllers import ( "context" "fmt" + "log/slog" "github.com/alphadose/haxmap" "k8s.io/apimachinery/pkg/api/equality" @@ -192,7 +193,12 @@ func (r *TunnelAgentReconciler) AddConnection(ctx context.Context, agentName str // RemoveConnection deregisters a connection from the given agent by its ID. func (r *TunnelAgentReconciler) RemoveConnection(ctx context.Context, agentName, id string) error { // Drop from in-memory map. - r.conns.Del(id) + conn, ok := r.conns.GetAndDel(id) + if ok { + if err := conn.Close(); err != nil { + slog.Warn("Failed to close connection", slog.String("id", id), slog.Any("error", err)) + } + } // Remove from status.connections (by ID) if err := retry.RetryOnConflict(retry.DefaultBackoff, func() error { diff --git a/pkg/tunnel/relay.go b/pkg/tunnel/relay.go index 0cd8b67..e665953 100644 --- a/pkg/tunnel/relay.go +++ b/pkg/tunnel/relay.go @@ -23,7 +23,6 @@ import ( "github.com/quic-go/quic-go/http3" "golang.org/x/sync/errgroup" - "github.com/apoxy-dev/apoxy/pkg/tunnel/adapter" "github.com/apoxy-dev/apoxy/pkg/tunnel/api" "github.com/apoxy-dev/apoxy/pkg/tunnel/controllers" "github.com/apoxy-dev/apoxy/pkg/tunnel/hasher" @@ -42,8 +41,8 @@ type Relay struct { handler *icx.Handler idHasher *hasher.Hasher router router.Router - tokens *haxmap.Map[string, string] // map[tunnelName]token - conns *haxmap.Map[string, *adapter.Connection] // map[connectionID]Connection + tokens *haxmap.Map[string, string] // map[tunnelName]token + conns *haxmap.Map[string, *connection] // map[connectionID]Connection onConnect func(ctx context.Context, agentName string, conn controllers.Connection) error onDisconnect func(ctx context.Context, agentName, id string) error } @@ -57,7 +56,7 @@ func NewRelay(name string, pc net.PacketConn, cert tls.Certificate, handler *icx idHasher: idHasher, router: router, tokens: haxmap.New[string, string](), - conns: haxmap.New[string, *adapter.Connection](), + conns: haxmap.New[string, *connection](), } } @@ -172,7 +171,14 @@ func (r *Relay) handleConnect(w http.ResponseWriter, req *http.Request, ps httpr } id := r.idHasher.Hash(localAddr, remoteAddr) - conn := adapter.NewConnection(id, r.handler, localAddr, remoteAddr) + + conn := &connection{ + id: id, + handler: r.handler, + router: r.router, + localAddr: localAddr, + remoteAddr: remoteAddr, + } r.conns.Set(conn.ID(), conn) diff --git a/pkg/tunnel/relay_test.go b/pkg/tunnel/relay_test.go index b36b4a5..e1bc6fb 100644 --- a/pkg/tunnel/relay_test.go +++ b/pkg/tunnel/relay_test.go @@ -236,15 +236,6 @@ func (m *mockRouter) AddAddr(addr netip.Prefix, tun connection.Connection) error return args.Error(0) } -func (m *mockRouter) ListAddrs() ([]netip.Prefix, error) { - args := m.Called() - var addrs []netip.Prefix - if v := args.Get(0); v != nil { - addrs = v.([]netip.Prefix) - } - return addrs, args.Error(1) -} - func (m *mockRouter) DelAddr(addr netip.Prefix) error { args := m.Called(addr) return args.Error(0) diff --git a/pkg/tunnel/router/client_icx_netstack.go b/pkg/tunnel/router/client_icx_netstack.go index c922ac3..c1bec25 100644 --- a/pkg/tunnel/router/client_icx_netstack.go +++ b/pkg/tunnel/router/client_icx_netstack.go @@ -157,11 +157,6 @@ func (r *ICXNetstackRouter) AddAddr(addr netip.Prefix, tun connection.Connection return nil } -// ListAddrs returns a list of all addresses currently managed by the router. -func (r *ICXNetstackRouter) ListAddrs() ([]netip.Prefix, error) { - return r.net.LocalAddresses() -} - // DelAddr removes a tun by its addr from the router. func (r *ICXNetstackRouter) DelAddr(addr netip.Prefix) error { if err := r.net.DelAddr(addr); err != nil { @@ -186,26 +181,3 @@ func (r *ICXNetstackRouter) AddRoute(dst netip.Prefix) error { func (r *ICXNetstackRouter) DelRoute(dst netip.Prefix) error { return nil } - -// ListRoutes returns a list of all routes currently managed by the router. -func (r *ICXNetstackRouter) ListRoutes() ([]TunnelRoute, error) { - localAddrs, err := r.net.LocalAddresses() - if err != nil { - return nil, fmt.Errorf("failed to list local addresses: %w", err) - } - - var routes []TunnelRoute - for _, addr := range localAddrs { - routes = append(routes, TunnelRoute{ - Dst: addr, - State: TunnelRouteStateActive, - }) - } - - return routes, nil -} - -// LocalAddresses returns the list of local addresses that are assigned to the router. -func (r *ICXNetstackRouter) LocalAddresses() ([]netip.Prefix, error) { - return r.net.LocalAddresses() -} diff --git a/pkg/tunnel/router/client_netlink_linux.go b/pkg/tunnel/router/client_netlink_linux.go index 0deee7e..d8be740 100644 --- a/pkg/tunnel/router/client_netlink_linux.go +++ b/pkg/tunnel/router/client_netlink_linux.go @@ -41,7 +41,7 @@ var ( ) // NewClientNetlinkRouter creates a new client-side netlink-based tunnel router. -func NewClientNetlinkRouter(opts ...Option) (Router, error) { +func NewClientNetlinkRouter(opts ...Option) (*ClientNetlinkRouter, error) { return newClientNetlinkRouter(opts...) } @@ -271,28 +271,6 @@ func (r *ClientNetlinkRouter) AddAddr(addr netip.Prefix, tun connection.Connecti return r.smux.Add(addr, tun) } -// ListAddrs returns a list of addresses added to the TUN interface. -func (r *ClientNetlinkRouter) ListAddrs() ([]netip.Prefix, error) { - ifcAddrs, err := netlink.AddrList(r.tunLink, netlink.FAMILY_ALL) - if err != nil { - return nil, fmt.Errorf("failed to list addresses on TUN interface: %w", err) - } - - var out []netip.Prefix - for _, ifcAddr := range ifcAddrs { - addr := netip.Addr{} - if ifcAddr.IP.To16() != nil { - addr = netip.AddrFrom16([16]byte(ifcAddr.IP.To16())) - } else if ifcAddr.IP.To4() != nil { - addr = netip.AddrFrom4([4]byte(ifcAddr.IP.To4())) - } - ones, _ := ifcAddr.Mask.Size() - out = append(out, netip.PrefixFrom(addr, ones)) - } - - return out, nil -} - // DelAddr deletes an address from a TUN interface and removes the corresponding route. func (r *ClientNetlinkRouter) DelAddr(addr netip.Prefix) error { mask := net.CIDRMask(addr.Bits(), 128) diff --git a/pkg/tunnel/router/client_netstack.go b/pkg/tunnel/router/client_netstack.go index f59a4f7..ef78a95 100644 --- a/pkg/tunnel/router/client_netstack.go +++ b/pkg/tunnel/router/client_netstack.go @@ -127,11 +127,6 @@ func (r *NetstackRouter) AddAddr(addr netip.Prefix, conn connection.Connection) return r.smux.Add(addr, conn) } -// ListAddrs lists all addresses added to the tunnel. -func (r *NetstackRouter) ListAddrs() ([]netip.Prefix, error) { - return r.tunDev.LocalAddresses() -} - // DelAddr removes a dst route from the tunnel. func (r *NetstackRouter) DelAddr(addr netip.Prefix) error { if err := r.tunDev.DelAddr(addr); err != nil { @@ -162,20 +157,6 @@ func (r *NetstackRouter) DelAll(dst netip.Prefix) error { return nil } -// ListRoutes returns a list of all routes in the tunnel. -func (r *NetstackRouter) ListRoutes() ([]TunnelRoute, error) { - ps := r.smux.Prefixes() - rts := make([]TunnelRoute, 0, len(ps)) - for _, p := range ps { - rts = append(rts, TunnelRoute{ - Dst: p, - // TODO: Add connID, - State: TunnelRouteStateActive, - }) - } - return rts, nil -} - // Close releases any resources associated with the router. func (r *NetstackRouter) Close() error { var firstErr error @@ -203,8 +184,3 @@ func (r *NetstackRouter) Close() error { }) return firstErr } - -// LocalAddresses returns the list of local addresses that are assigned to the router. -func (r *NetstackRouter) LocalAddresses() ([]netip.Prefix, error) { - return r.tunDev.LocalAddresses() -} diff --git a/pkg/tunnel/router/router.go b/pkg/tunnel/router/router.go index a8a1460..b499e25 100644 --- a/pkg/tunnel/router/router.go +++ b/pkg/tunnel/router/router.go @@ -32,9 +32,6 @@ type Router interface { // AddAddr adds a tun with an associated address to the router. AddAddr(addr netip.Prefix, tun connection.Connection) error - // ListAddrs returns a list of all addresses currently managed by the router. - ListAddrs() ([]netip.Prefix, error) - // DelAddr removes a tun by its addr from the router. DelAddr(addr netip.Prefix) error @@ -49,10 +46,4 @@ type Router interface { // getting re-routed via a different tunnel or dropped (if no tunnel is available for // the given dst). DelRoute(dst netip.Prefix) error - - // ListRoutes returns a list of all routes currently managed by the router. - ListRoutes() ([]TunnelRoute, error) - - // LocalAddresses returns the list of local addresses that are assigned to the router. - LocalAddresses() ([]netip.Prefix, error) } diff --git a/pkg/tunnel/router/server_icx_netlink_linux.go b/pkg/tunnel/router/server_icx_netlink_linux.go index f6d79aa..30a9726 100644 --- a/pkg/tunnel/router/server_icx_netlink_linux.go +++ b/pkg/tunnel/router/server_icx_netlink_linux.go @@ -1,9 +1,11 @@ package router import ( + "bytes" "context" "errors" "fmt" + "log/slog" "math" "net" "net/netip" @@ -19,6 +21,9 @@ import ( "github.com/slavc/xdp" "github.com/vishvananda/netlink" "gvisor.dev/gvisor/pkg/tcpip" + proxyutil "k8s.io/kubernetes/pkg/proxy/util" + utiliptables "k8s.io/kubernetes/pkg/util/iptables" + utilexec "k8s.io/utils/exec" "github.com/apoxy-dev/apoxy/pkg/netstack" "github.com/apoxy-dev/apoxy/pkg/tunnel/connection" @@ -35,10 +40,12 @@ var ( type ICXNetlinkRouter struct { Handler *icx.Handler - vethDev *veth.Handle + tunDev *veth.Handle + tunLink netlink.Link ingressFilter *xdp.Program pcapFile *os.File tun *tunnel.Tunnel + iptV4, iptV6 utiliptables.Interface closeOnce sync.Once } @@ -68,12 +75,41 @@ func NewICXNetlinkRouter(opts ...Option) (*ICXNetlinkRouter, error) { return nil, fmt.Errorf("failed to get number of TX queues for interface %s: %w", options.extIfaceName, err) } - vethDev, err := veth.Create(options.tunIfaceName, numQueues, icx.MTU(extPathMTU)) + tunDev, err := veth.Create(options.tunIfaceName, numQueues, icx.MTU(extPathMTU)) if err != nil { return nil, fmt.Errorf("failed to create veth device: %w", err) } - virtMAC := tcpip.LinkAddress(vethDev.Link.Attrs().HardwareAddr) + tunLink, err := netlink.LinkByName(options.tunIfaceName) + if err != nil { + _ = tunDev.Close() + return nil, fmt.Errorf("failed to get veth interface: %w", err) + } + + if !options.extIPv6Prefix.IsValid() { + slog.Warn("external IPv6 prefix is not valid - ingress is disabled") + } + + for _, addr := range options.localAddresses { + ip := addr.Addr() + mask := net.CIDRMask(addr.Bits(), 128) + if ip.Is4() { + mask = net.CIDRMask(addr.Bits(), 32) + } + + if err := netlink.AddrAdd(tunLink, &netlink.Addr{ + IPNet: &net.IPNet{ + IP: ip.AsSlice(), + Mask: mask, + }, + }); err != nil { + _ = tunDev.Close() + return nil, fmt.Errorf("failed to add address to veth interface: %w", err) + } + slog.Info("Added address to veth interface", slog.String("addr", addr.String())) + } + + virtMAC := tcpip.LinkAddress(tunDev.Link.Attrs().HardwareAddr) handlerOpts := []icx.HandlerOption{ icx.WithLocalAddr(localAddr), @@ -85,13 +121,13 @@ func NewICXNetlinkRouter(opts ...Option) (*ICXNetlinkRouter, error) { handler, err := icx.NewHandler(handlerOpts...) if err != nil { - _ = vethDev.Close() + _ = tunDev.Close() return nil, fmt.Errorf("failed to create handler: %w", err) } ingressFilter, err := filter.Bind(addrs...) if err != nil { - _ = vethDev.Close() + _ = tunDev.Close() return nil, fmt.Errorf("failed to create ingress filter: %w", err) } @@ -100,7 +136,7 @@ func NewICXNetlinkRouter(opts ...Option) (*ICXNetlinkRouter, error) { if options.pcapPath != "" { pcapFile, err = os.Create(options.pcapPath) if err != nil { - _ = vethDev.Close() + _ = tunDev.Close() _ = ingressFilter.Close() return nil, fmt.Errorf("failed to create pcap file: %w", err) } @@ -111,19 +147,22 @@ func NewICXNetlinkRouter(opts ...Option) (*ICXNetlinkRouter, error) { } } - tun, err := tunnel.NewTunnel(options.extIfaceName, vethDev.Peer.Attrs().Name, ingressFilter, handler, pcapWriter) + tun, err := tunnel.NewTunnel(options.extIfaceName, tunDev.Peer.Attrs().Name, ingressFilter, handler, pcapWriter) if err != nil { - _ = vethDev.Close() + _ = tunDev.Close() _ = ingressFilter.Close() return nil, fmt.Errorf("failed to create tunnel: %w", err) } return &ICXNetlinkRouter{ Handler: handler, - vethDev: vethDev, + tunDev: tunDev, + tunLink: tunLink, ingressFilter: ingressFilter, pcapFile: pcapFile, tun: tun, + iptV4: utiliptables.New(utilexec.New(), utiliptables.ProtocolIPv4), + iptV6: utiliptables.New(utilexec.New(), utiliptables.ProtocolIPv6), }, nil } @@ -133,7 +172,7 @@ func (r *ICXNetlinkRouter) Close() error { if err := r.tun.Close(); err != nil && firstErr == nil { firstErr = err } - if err := r.vethDev.Close(); err != nil && firstErr == nil { + if err := r.tunDev.Close(); err != nil && firstErr == nil { firstErr = err } if err := r.ingressFilter.Close(); err != nil && firstErr == nil { @@ -157,20 +196,26 @@ func (r *ICXNetlinkRouter) Start(ctx context.Context) error { } // AddAddr adds a tun with an associated address to the router. -func (r *ICXNetlinkRouter) AddAddr(addr netip.Prefix, tun connection.Connection) error { - // TODO (dpeckett): implement - return nil -} +func (r *ICXNetlinkRouter) AddAddr(_ netip.Prefix, _ connection.Connection) error { + // Virtual networks are managed externally, so we just need to + // sync the DNAT rules to include the new address. -// ListAddrs returns a list of all addresses currently managed by the router. -func (r *ICXNetlinkRouter) ListAddrs() ([]netip.Prefix, error) { - // TODO (dpeckett): implement - return nil, nil + if err := r.syncDNATChain(); err != nil { + return fmt.Errorf("failed to sync DNAT chain: %w", err) + } + + return nil } // DelAddr removes a tun by its addr from the router. -func (r *ICXNetlinkRouter) DelAddr(addr netip.Prefix) error { - // TODO (dpeckett): implement +func (r *ICXNetlinkRouter) DelAddr(_ netip.Prefix) error { + // Virtual networks are managed externally, so we just need to + // sync the DNAT rules to remove the address. + + if err := r.syncDNATChain(); err != nil { + return fmt.Errorf("failed to sync DNAT chain: %w", err) + } + return nil } @@ -178,7 +223,26 @@ func (r *ICXNetlinkRouter) DelAddr(addr netip.Prefix) error { // If multiple tunnels are provided, the router will distribute traffic across them // uniformly. func (r *ICXNetlinkRouter) AddRoute(dst netip.Prefix) error { - // TODO (dpeckett): implement + slog.Info("Adding route", slog.String("addr", dst.String())) + + mask := net.CIDRMask(dst.Bits(), 128) + if dst.Addr().Is4() { + mask = net.CIDRMask(dst.Bits(), 32) + } + route := &netlink.Route{ + LinkIndex: r.tunLink.Attrs().Index, + Dst: &net.IPNet{ + IP: dst.Addr().AsSlice(), + Mask: mask, + }, + Scope: netlink.SCOPE_LINK, + } + if err := netlink.RouteAdd(route); err != nil { + return fmt.Errorf("failed to add route: %w", err) + } + + slog.Info("Route added", slog.String("dst", dst.String())) + return nil } @@ -188,20 +252,68 @@ func (r *ICXNetlinkRouter) AddRoute(dst netip.Prefix) error { // getting re-routed via a different tunnel or dropped (if no tunnel is available for // the given dst). func (r *ICXNetlinkRouter) DelRoute(dst netip.Prefix) error { - // TODO (dpeckett): implement + slog.Debug("Removing route", slog.String("prefix", dst.String())) + + mask := net.CIDRMask(dst.Bits(), 128) + if dst.Addr().Is4() { + mask = net.CIDRMask(dst.Bits(), 32) + } + route := &netlink.Route{ + LinkIndex: r.tunLink.Attrs().Index, + Dst: &net.IPNet{ + IP: dst.Addr().AsSlice(), + Mask: mask, + }, + Scope: netlink.SCOPE_LINK, + } + if err := netlink.RouteDel(route); err != nil { + return fmt.Errorf("failed to remove route: %w", err) + } + + slog.Info("Route removed", slog.String("dst", dst.String())) return nil } -// ListRoutes returns a list of all routes currently managed by the router. -func (r *ICXNetlinkRouter) ListRoutes() ([]TunnelRoute, error) { - // TODO (dpeckett): implement - return nil, nil -} +func (r *ICXNetlinkRouter) syncDNATChain() error { + natChains := proxyutil.NewLineBuffer() + natChains.Write(utiliptables.MakeChainLine(ChainA3yTunRules)) -// LocalAddresses returns the list of local addresses that are assigned to the router. -func (r *ICXNetlinkRouter) LocalAddresses() ([]netip.Prefix, error) { - // TODO (dpeckett): implement - return nil, nil + natRules := proxyutil.NewLineBuffer() + + peers := r.Handler.ListVirtualNetworks() + + for i, peer := range peers { + for _, addr := range peer.Addrs { + if addr.Addr().Is4() { // Skipping IPv4 peers - only IPv6 tunnel ingress is supported. + continue + } + natRules.Write( + "-A", string(ChainA3yTunRules), + "-m", "statistic", + "--mode", "random", + "--probability", probability(len(peers)-i), + "-j", "DNAT", + "--to-destination", addr.Addr().String(), + ) + } + } + + iptNewData := bytes.NewBuffer(nil) + iptNewData.WriteString("*nat\n") + iptNewData.Write(natChains.Bytes()) + iptNewData.Write(natRules.Bytes()) + iptNewData.WriteString("COMMIT\n") + + if err := r.iptV6.Restore( + utiliptables.TableNAT, + iptNewData.Bytes(), + utiliptables.NoFlushTables, + utiliptables.RestoreCounters, + ); err != nil { + return fmt.Errorf("failed to execute iptables-restore: %w", err) + } + + return nil } func addrsForInterface(link netlink.Link, port int) ([]net.Addr, error) { diff --git a/pkg/tunnel/router/server_netlink_linux.go b/pkg/tunnel/router/server_netlink_linux.go index 1306221..ba271ec 100644 --- a/pkg/tunnel/router/server_netlink_linux.go +++ b/pkg/tunnel/router/server_netlink_linux.go @@ -283,11 +283,6 @@ func (r *NetlinkRouter) AddAddr(addr netip.Prefix, tun connection.Connection) er return nil } -// ListAddrs lists all addresses on the tunnel added previously via AddAddr. -func (r *NetlinkRouter) ListAddrs() ([]netip.Prefix, error) { - return r.dmux.List() -} - // DelAddr removes a tunnel connection with the given address. The addr // is used by the multiplexer to route traffic to the correct tunnel based // on the destination IP of the incoming packet. @@ -349,20 +344,6 @@ func (r *NetlinkRouter) DelRoute(dst netip.Prefix) error { return nil } -// ListRoutes returns a list of all routes in the tunnel. -func (r *NetlinkRouter) ListRoutes() ([]TunnelRoute, error) { - ps := r.dmux.Prefixes() - rts := make([]TunnelRoute, 0, len(ps)) - for _, p := range ps { - rts = append(rts, TunnelRoute{ - Dst: p, - // TODO: Add connID, - State: TunnelRouteStateActive, - }) - } - return rts, nil -} - // Close releases any resources associated with the router. func (r *NetlinkRouter) Close() error { var firstErr error @@ -383,33 +364,3 @@ func (r *NetlinkRouter) Close() error { }) return firstErr } - -// LocalAddresses returns the list of local addresses that are assigned to the router. -func (r *NetlinkRouter) LocalAddresses() ([]netip.Prefix, error) { - if r.tunLink == nil { - return nil, nil - } - - addrs, err := netlink.AddrList(r.tunLink, netlink.FAMILY_V6) - if err != nil { - return nil, fmt.Errorf("failed to get addresses for link: %w", err) - } - - var prefixes []netip.Prefix - for _, addr := range addrs { - ip, ok := netip.AddrFromSlice(addr.IP) - if !ok { - slog.Warn("Failed to convert IP address", slog.String("ip", addr.IP.String())) - continue - } - if !ip.IsGlobalUnicast() { // Skip non-global unicast addresses. - slog.Debug("Skipping non-global unicast address", slog.String("ip", addr.IP.String())) - continue - } - - bits, _ := addr.Mask.Size() - prefixes = append(prefixes, netip.PrefixFrom(ip, bits)) - } - - return prefixes, nil -}