From 36e3824433a234fad4049fe3054f074a7241f68a Mon Sep 17 00:00:00 2001 From: Damian Peckett Date: Wed, 7 May 2025 13:26:30 +0200 Subject: [PATCH] [connect-ip] client can also run in kernel mode now via TUN --- go.mod | 15 +- go.sum | 5 + pkg/cmd/tunnel/tunnelnode.go | 19 +- pkg/netstack/tun_device.go | 14 ++ pkg/socksproxy/server.go | 9 + pkg/tunnel/client.go | 181 +++++++++++------- pkg/tunnel/dns/resolver.go | 5 +- pkg/tunnel/net/pcap.go | 113 ++++++++++++ pkg/tunnel/router/netlink.go | 9 + pkg/tunnel/router/netlink_linux.go | 94 ++++++---- pkg/tunnel/router/netstack.go | 59 +----- pkg/tunnel/router/options.go | 72 ++++++++ pkg/tunnel/router/router.go | 3 + pkg/tunnel/server.go | 22 +-- pkg/tunnel/token/jwks.go | 3 +- pkg/tunnel/tunnel_test.go | 285 ++++++++++++++++++++++++++--- pkg/utils/vm/vm.go | 23 ++- 17 files changed, 721 insertions(+), 210 deletions(-) create mode 100644 pkg/tunnel/net/pcap.go create mode 100644 pkg/tunnel/router/netlink.go create mode 100644 pkg/tunnel/router/options.go diff --git a/go.mod b/go.mod index c4f8e48c..75886586 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,11 @@ require ( github.com/ClickHouse/clickhouse-go/v2 v2.23.2 github.com/MicahParks/jwkset v0.9.5 github.com/MicahParks/keyfunc/v3 v3.3.10 + github.com/adrg/xdg v0.5.3 github.com/alphadose/haxmap v1.4.1 + github.com/anatol/vmtest v0.0.0-20250318022921-2f32244e2f0f + github.com/avast/retry-go/v4 v4.6.1 + github.com/bramvdbogaerde/go-scp v1.5.0 github.com/buraksezer/olric v0.5.6 github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42 github.com/coder/websocket v1.8.12 @@ -38,10 +42,13 @@ require ( github.com/google/go-cmp v0.7.0 github.com/google/go-containerregistry v0.19.1 github.com/google/go-github/v61 v61.0.0 + github.com/google/gopacket v1.1.19 github.com/google/uuid v1.6.0 github.com/hashicorp/go-discover v0.0.0-20240726212017-342faf50e5d4 github.com/jedib0t/go-pretty/v6 v6.4.9 github.com/k3s-io/kine v0.13.2 + github.com/kdomanski/iso9660 v0.4.0 + github.com/klauspost/cpuid/v2 v2.2.10 github.com/metal-stack/go-ipam v1.14.7 github.com/miekg/dns v1.1.63 github.com/mitchellh/mapstructure v1.5.0 @@ -65,6 +72,7 @@ require ( go.opentelemetry.io/proto/otlp v1.3.1 go.temporal.io/api v1.29.2 go.temporal.io/sdk v1.26.0 + golang.org/x/crypto v0.37.0 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 golang.org/x/net v0.39.0 golang.org/x/sync v0.13.0 @@ -116,21 +124,17 @@ require ( github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/Rican7/retry v0.1.0 // indirect github.com/RoaringBitmap/roaring v1.2.1 // indirect - github.com/adrg/xdg v0.5.3 // indirect - github.com/anatol/vmtest v0.0.0-20250318022921-2f32244e2f0f // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/apache/thrift v0.20.0 // indirect github.com/apparentlymart/go-cidr v1.1.0 // indirect github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da // indirect github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect - github.com/avast/retry-go/v4 v4.6.1 // indirect github.com/aws/aws-sdk-go v1.55.5 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bits-and-blooms/bitset v1.2.0 // indirect github.com/blang/semver/v4 v4.0.0 // indirect - github.com/bramvdbogaerde/go-scp v1.5.0 // indirect github.com/buraksezer/consistent v0.10.0 // indirect github.com/cactus/go-statsd-client/statsd v0.0.0-20200423205355-cb0885a1018c // indirect github.com/cactus/go-statsd-client/v5 v5.1.0 // indirect @@ -234,10 +238,8 @@ require ( github.com/joyent/triton-go v0.0.0-20180628001255-830d2b111e62 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/kdomanski/iso9660 v0.4.0 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/labstack/echo/v4 v4.10.0 // indirect github.com/labstack/gommon v0.4.0 // indirect @@ -366,7 +368,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect - golang.org/x/crypto v0.37.0 // indirect golang.org/x/mod v0.21.0 // indirect golang.org/x/oauth2 v0.29.0 // indirect golang.org/x/term v0.31.0 // indirect diff --git a/go.sum b/go.sum index 9ff0e4bb..abc20321 100644 --- a/go.sum +++ b/go.sum @@ -528,6 +528,8 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= @@ -1258,9 +1260,11 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -1444,6 +1448,7 @@ golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= diff --git a/pkg/cmd/tunnel/tunnelnode.go b/pkg/cmd/tunnel/tunnelnode.go index 31b6572c..55698c35 100644 --- a/pkg/cmd/tunnel/tunnelnode.go +++ b/pkg/cmd/tunnel/tunnelnode.go @@ -222,17 +222,26 @@ func (t *tunnelNodeReconciler) Reconcile(ctx context.Context, req ctrl.Request) cOpts = append(cOpts, tunnel.WithInsecureSkipVerify(true)) } + if t.tunC != nil { + log.Info("Closing existing tunnel client") + if err := t.tunC.Close(); err != nil { + log.Error(err, "Failed to close existing tunnel client") + } + t.tunC = nil + } + if t.tunC, err = tunnel.NewTunnelClient(cOpts...); err != nil { log.Error(err, "Failed to create tunnel client") t.doneCh <- fmt.Errorf("failed to create tunnel client: %w", err) return ctrl.Result{}, nil // Unrecoverable error. } - if err := t.tunC.Start(ctx); err != nil { - log.Error(err, "Failed to start tunnel client") - t.doneCh <- fmt.Errorf("failed to start tunnel client: %w", err) - return ctrl.Result{}, nil // Unrecoverable error. - } + go func() { + if err := t.tunC.Start(ctx); err != nil { + log.Error(err, "Failed to start tunnel client") + t.doneCh <- fmt.Errorf("failed to start tunnel client: %w", err) + } + }() return ctrl.Result{}, nil } diff --git a/pkg/netstack/tun_device.go b/pkg/netstack/tun_device.go index e0a75d19..475d241f 100644 --- a/pkg/netstack/tun_device.go +++ b/pkg/netstack/tun_device.go @@ -5,6 +5,7 @@ import ( "fmt" "net/netip" "os" + "sync/atomic" "syscall" "github.com/dpeckett/network" @@ -35,6 +36,7 @@ type TunDevice struct { events chan tun.Event incomingPacket chan *buffer.View mtu int + closed atomic.Bool } func NewTunDevice(localAddresses []netip.Prefix, pcapPath string) (*TunDevice, error) { @@ -145,6 +147,10 @@ func (tun *TunDevice) MTU() (int, error) { return tun.mtu, nil } func (tun *TunDevice) BatchSize() int { return 1 } func (tun *TunDevice) Read(buf [][]byte, sizes []int, offset int) (int, error) { + if tun.closed.Load() { + return 0, os.ErrClosed + } + view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed @@ -159,6 +165,10 @@ func (tun *TunDevice) Read(buf [][]byte, sizes []int, offset int) (int, error) { } func (tun *TunDevice) Write(buf [][]byte, offset int) (int, error) { + if tun.closed.Load() { + return 0, os.ErrClosed + } + for _, buf := range buf { packet := buf[offset:] if len(packet) == 0 { @@ -191,6 +201,10 @@ func (tun *TunDevice) WriteNotify() { } func (tun *TunDevice) Close() error { + if tun.closed.Swap(true) { + return nil + } + tun.stack.RemoveNIC(tun.nicID) if tun.events != nil { diff --git a/pkg/socksproxy/server.go b/pkg/socksproxy/server.go index 333be33e..4f231c8e 100644 --- a/pkg/socksproxy/server.go +++ b/pkg/socksproxy/server.go @@ -29,6 +29,7 @@ func NewServer(addr string, upstream network.Network, fallback network.Network) socks5.WithDial((&dialer{upstream: upstream, fallback: fallback}).DialContext), socks5.WithResolver(&resolver{net: upstream}), socks5.WithBufferPool(bufferpool.NewPool(256 * 1024)), + socks5.WithLogger(&logger{}), // No auth as we'll be binding exclusively to a local interface. socks5.WithAuthMethods([]socks5.Authenticator{socks5.NoAuthAuthenticator{}}), } @@ -113,6 +114,8 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (net. return d.fallback.DialContext(ctx, network, address) } + slog.Debug("Address is private - dialing upstream", slog.String("address", addr.String())) + return d.upstream.DialContext(ctx, network, address) } @@ -140,3 +143,9 @@ func (r *resolver) Resolve(ctx context.Context, name string) (context.Context, n return ctx, ip, nil } + +type logger struct{} + +func (l *logger) Errorf(format string, arg ...any) { + slog.Error(fmt.Sprintf(format, arg...)) +} diff --git a/pkg/tunnel/client.go b/pkg/tunnel/client.go index 96bcb629..f7756056 100644 --- a/pkg/tunnel/client.go +++ b/pkg/tunnel/client.go @@ -10,7 +10,7 @@ import ( "net" "net/http" "net/netip" - "strconv" + "reflect" "strings" "sync" "time" @@ -22,27 +22,40 @@ import ( "github.com/quic-go/quic-go/http3" "github.com/yosida95/uritemplate/v3" - "github.com/apoxy-dev/apoxy-cli/pkg/netstack" - "github.com/apoxy-dev/apoxy-cli/pkg/socksproxy" - "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/connection" + "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/router" ) type TunnelClientOption func(*tunnelClientOptions) +type TunnelClientMode string + +const ( + // TunnelClientModeKernel indicates that the tunnel client will use the kernel mode router. + // This mode requires root privileges and is more efficient for routing traffic. + TunnelClientModeKernel TunnelClientMode = "kernel" + // TunnelClientModeUser indicates that the tunnel client will use the user mode router. + TunnelClientModeUser TunnelClientMode = "user" +) + type tunnelClientOptions struct { serverAddr string - insecureSkipVerify bool uuid uuid.UUID authToken string - pcapPath string + mode TunnelClientMode + insecureSkipVerify bool rootCAs *x509.CertPool - socksListenAddr string + pcapPath string + // Kernel mode options + extIfaceName string + tunIfaceName string + // Userspace options + socksListenAddr string } func defaultClientOptions() *tunnelClientOptions { return &tunnelClientOptions{ - serverAddr: "localhost:9443", - socksListenAddr: "localhost:1080", + serverAddr: "localhost:9443", + mode: TunnelClientModeUser, } } @@ -54,13 +67,6 @@ func WithServerAddr(addr string) TunnelClientOption { } } -// WithInsecureSkipVerify skips TLS certificate verification of the server. -func WithInsecureSkipVerify(skip bool) TunnelClientOption { - return func(o *tunnelClientOptions) { - o.insecureSkipVerify = skip - } -} - // WithUUID sets the UUID for the tunnel client. func WithUUID(uuid uuid.UUID) TunnelClientOption { return func(o *tunnelClientOptions) { @@ -75,10 +81,17 @@ func WithAuthToken(token string) TunnelClientOption { } } -// WithPcapPath sets the optional path to a packet capture file for the tunnel client. -func WithPcapPath(path string) TunnelClientOption { +// WithMode sets the mode of the tunnel client (kernel or userspace). +func WithMode(mode TunnelClientMode) TunnelClientOption { return func(o *tunnelClientOptions) { - o.pcapPath = path + o.mode = mode + } +} + +// WithInsecureSkipVerify skips TLS certificate verification of the server. +func WithInsecureSkipVerify(skip bool) TunnelClientOption { + return func(o *tunnelClientOptions) { + o.insecureSkipVerify = skip } } @@ -89,7 +102,31 @@ func WithRootCAs(caCerts *x509.CertPool) TunnelClientOption { } } +// WithPcapPath sets the optional path to a packet capture file for the tunnel client. +func WithPcapPath(path string) TunnelClientOption { + return func(o *tunnelClientOptions) { + o.pcapPath = path + } +} + +// WithExternalInterface sets the external interface name. +// This is only valid in kernel mode. +func WithExternalInterface(name string) TunnelClientOption { + return func(o *tunnelClientOptions) { + o.extIfaceName = name + } +} + +// WithTunnelInterface sets the tunnel interface name. +// This is only valid in kernel mode. +func WithTunnelInterface(name string) TunnelClientOption { + return func(o *tunnelClientOptions) { + o.tunIfaceName = name + } +} + // WithSocksListenAddr sets the listen address for the local SOCKS5 proxy server. +// Only valid in user mode. func WithSocksListenAddr(addr string) TunnelClientOption { return func(o *tunnelClientOptions) { o.socksListenAddr = addr @@ -103,19 +140,15 @@ type TunnelClient struct { insecureSkipVerify bool uuid uuid.UUID authToken string - pcapPath string rootCAs *x509.CertPool + router router.Router + + hConn *http3.ClientConn + conn *connectip.Conn - hConn *http3.ClientConn - conn *connectip.Conn - tun *netstack.TunDevice - netstack *network.NetstackNetwork - proxy *socksproxy.ProxyServer closeOnce sync.Once } -// NewTunnelClient creates a new SOCKS5 proxy and loopback reverse proxy, -// that forwards and receives traffic via QUIC tunnels. func NewTunnelClient(opts ...TunnelClientOption) (*TunnelClient, error) { options := defaultClientOptions() for _, opt := range opts { @@ -133,7 +166,6 @@ func NewTunnelClient(opts ...TunnelClientOption) (*TunnelClient, error) { options: options, uuid: options.uuid, authToken: options.authToken, - pcapPath: options.pcapPath, rootCAs: options.rootCAs, insecureSkipVerify: options.insecureSkipVerify, } @@ -141,9 +173,6 @@ func NewTunnelClient(opts ...TunnelClientOption) (*TunnelClient, error) { return client, nil } -// Start establishes a connection to the server and begins forwarding traffic. -// TODO: this is non blocking and does not match the behavior of the router.Start() -// method, we should probably change it. func (c *TunnelClient) Start(ctx context.Context) error { tlsConfig := &tls.Config{ ServerName: "proxy", @@ -227,47 +256,64 @@ func (c *TunnelClient) Start(ctx context.Context) error { slog.Any("searchDomains", resolveConf.SearchDomains), slog.Any("nDots", resolveConf.NDots)) - c.tun, err = netstack.NewTunDevice(filteredLocalPrefixes, c.pcapPath) - if err != nil { - return fmt.Errorf("failed to create virtual TUN device: %w", err) + routerOpts := []router.Option{ + router.WithLocalAddresses(filteredLocalPrefixes), + router.WithResolveConfig(resolveConf), } - c.netstack = c.tun.Network(resolveConf) + if c.options.pcapPath != "" { + routerOpts = append(routerOpts, router.WithPcapPath(c.options.pcapPath)) + } - go connection.Splice(c.tun, c.conn) + if c.options.extIfaceName != "" { + routerOpts = append(routerOpts, router.WithExternalInterface(c.options.extIfaceName)) + } - _, socksListenPortStr, err := net.SplitHostPort(c.options.socksListenAddr) - if err != nil { - return fmt.Errorf("failed to parse SOCKS listen address: %w", err) + if c.options.tunIfaceName != "" { + routerOpts = append(routerOpts, router.WithTunnelInterface(c.options.tunIfaceName)) } - socksListenPort, err := strconv.Atoi(socksListenPortStr) - if err != nil { - return fmt.Errorf("failed to parse SOCKS listen port: %w", err) + if c.options.socksListenAddr != "" { + routerOpts = append(routerOpts, router.WithSocksListenAddr(c.options.socksListenAddr)) } - slog.Info("Forwarding all inbound traffic to loopback interface") + if c.options.mode == TunnelClientModeKernel { + c.router, err = router.NewNetlinkRouter(routerOpts...) + if err != nil { + return fmt.Errorf("failed to create kernel router: %w", err) + } + } else if c.options.mode == TunnelClientModeUser { + c.router, err = router.NewNetstackRouter(routerOpts...) + if err != nil { + return fmt.Errorf("failed to create user mode router: %w", err) + } + } - if err := c.tun.ForwardTo(ctx, network.Filtered(&network.FilteredNetworkConfig{ - DeniedPorts: []uint16{uint16(socksListenPort)}, - Upstream: network.Loopback(), - })); err != nil { - return fmt.Errorf("failed to forward to loopback: %w", err) + routes, err := c.conn.Routes(ctx) + if err != nil { + return fmt.Errorf("failed to get routes: %w", err) } - slog.Info("Starting SOCKS5 proxy", slog.String("listenAddr", c.options.socksListenAddr)) + for _, route := range routes { + for _, prefix := range route.Prefixes() { + slog.Info("Adding route", slog.String("prefix", prefix.String())) - c.proxy = socksproxy.NewServer(c.options.socksListenAddr, c.netstack, network.Host()) - go func() { - if err := c.proxy.ListenAndServe(ctx); err != nil { - slog.Error("SOCKS proxy error", slog.String("error", err.Error())) + _, _, err := c.router.AddPeer(prefix, c.conn) + if err != nil { + return fmt.Errorf("failed to add peer route %s: %w", prefix.String(), err) + } } - }() + } + + slog.Info("Starting router") + + if err := c.router.Start(ctx); err != nil { + return fmt.Errorf("failed to start router: %w", err) + } return nil } -// Stop closes the tunnel client and stops forwarding traffic. func (c *TunnelClient) Close() error { var firstErr error c.closeOnce.Do(func() { @@ -280,15 +326,6 @@ func (c *TunnelClient) Close() error { } } - if c.tun != nil { - if err := c.tun.Close(); err != nil { - slog.Error("Failed to close TUN device", slog.Any("error", err)) - if firstErr == nil { - firstErr = fmt.Errorf("failed to close TUN device: %w", err) - } - } - } - if c.hConn != nil { if err := c.hConn.CloseWithError(ApplicationCodeOK, ""); err != nil { slog.Error("Failed to close HTTP/3 connection", slog.Any("error", err)) @@ -298,12 +335,10 @@ func (c *TunnelClient) Close() error { } } - if c.proxy != nil { - if err := c.proxy.Close(); err != nil { - slog.Error("Failed to close SOCKS proxy", slog.Any("error", err)) - if firstErr == nil { - firstErr = fmt.Errorf("failed to close SOCKS proxy: %w", err) - } + if err := c.router.Close(); err != nil { + slog.Error("Failed to close router", slog.Any("error", err)) + if firstErr == nil { + firstErr = fmt.Errorf("failed to close router: %w", err) } } }) @@ -311,5 +346,9 @@ func (c *TunnelClient) Close() error { } func (c *TunnelClient) LocalAddresses() ([]netip.Prefix, error) { - return c.tun.LocalAddresses() + if c.router == nil || reflect.ValueOf(c.router).IsNil() { + return nil, nil + } + + return c.router.LocalAddresses() } diff --git a/pkg/tunnel/dns/resolver.go b/pkg/tunnel/dns/resolver.go index 16802eab..c648b7bb 100644 --- a/pkg/tunnel/dns/resolver.go +++ b/pkg/tunnel/dns/resolver.go @@ -166,7 +166,10 @@ func (r *TunnelNodeDNSReconciler) serveDNS(ctx context.Context, next plugin.Hand return dns.RcodeServerFailure, nil } - w.WriteMsg(msg) + if err := w.WriteMsg(msg); err != nil { + log.Error("Failed to write response", slog.Any("error", err)) + return dns.RcodeServerFailure, err + } return dns.RcodeSuccess, nil } diff --git a/pkg/tunnel/net/pcap.go b/pkg/tunnel/net/pcap.go new file mode 100644 index 00000000..57ab9be8 --- /dev/null +++ b/pkg/tunnel/net/pcap.go @@ -0,0 +1,113 @@ +package net + +import ( + "fmt" + "log/slog" + "os" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcapgo" + "golang.zx2c4.com/wireguard/tun" +) + +var _ tun.Device = (*PcapDevice)(nil) + +type PcapDevice struct { + dev tun.Device + w *pcapgo.Writer +} + +func NewPcapDevice(dev tun.Device, pcapPath string) (*PcapDevice, error) { + f, err := os.Create(pcapPath) + if err != nil { + return nil, err + } + + w := pcapgo.NewWriter(f) + if err := w.WriteFileHeader(65535, layers.LinkTypeIPv6); err != nil { + return nil, err + } + + return &PcapDevice{ + dev: dev, + w: w, + }, nil +} + +func (d *PcapDevice) Write(bufs [][]byte, offset int) (int, error) { + for _, buf := range bufs { + if len(buf) <= offset { + slog.Warn("PcapDevice.Write: skipping short buffer", + slog.Int("len", len(buf)), slog.Int("offset", offset)) + continue + } + packetData := buf[offset:] + ci := gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(packetData), + Length: len(packetData), + } + if err := d.w.WritePacket(ci, packetData); err != nil { + return 0, fmt.Errorf("failed to write packet: %w", err) + } + } + + n, err := d.dev.Write(bufs, offset) + if err != nil { + return n, err + } + + return n, nil +} + +func (d *PcapDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + n, err = d.dev.Read(bufs, sizes, offset) + if err != nil { + return n, err + } + + for i := 0; i < n; i++ { + if len(bufs[i]) < offset+sizes[i] { + slog.Warn("PcapDevice.Read: skipping short buffer", + slog.Int("len", len(bufs[i])), slog.Int("offset", offset), slog.Int("size", sizes[i])) + continue + } + packetData := bufs[i][offset : offset+sizes[i]] + ci := gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(packetData), + Length: len(packetData), + } + if err := d.w.WritePacket(ci, packetData); err != nil { + return 0, fmt.Errorf("failed to write packet: %w", err) + } + } + + return n, nil +} + +func (d *PcapDevice) BatchSize() int { + return d.dev.BatchSize() +} + +func (d *PcapDevice) Close() error { + return d.dev.Close() +} + +func (d *PcapDevice) Events() <-chan tun.Event { + return d.dev.Events() +} + +func (d *PcapDevice) File() *os.File { + return d.dev.File() +} + +func (d *PcapDevice) MTU() (int, error) { + return d.dev.MTU() +} + +func (d *PcapDevice) Name() (string, error) { + return d.dev.Name() +} diff --git a/pkg/tunnel/router/netlink.go b/pkg/tunnel/router/netlink.go new file mode 100644 index 00000000..9d2daf91 --- /dev/null +++ b/pkg/tunnel/router/netlink.go @@ -0,0 +1,9 @@ +//go:build !linux + +package router + +import "fmt" + +func NewNetlinkRouter(_ ...Option) (Router, error) { + return nil, fmt.Errorf("netlink router is not supported on this platform") +} diff --git a/pkg/tunnel/router/netlink_linux.go b/pkg/tunnel/router/netlink_linux.go index 36ec5391..b7fb3704 100644 --- a/pkg/tunnel/router/netlink_linux.go +++ b/pkg/tunnel/router/netlink_linux.go @@ -13,7 +13,6 @@ import ( "slices" "strings" "sync" - "sync/atomic" "github.com/vishvananda/netlink" "golang.org/x/sync/errgroup" @@ -24,6 +23,7 @@ import ( "github.com/apoxy-dev/apoxy-cli/pkg/netstack" "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/connection" + tunnet "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/net" ) var ( @@ -43,7 +43,6 @@ type NetlinkRouter struct { mux *connection.MuxedConnection closeOnce sync.Once - closed atomic.Bool } func extPrefixes(link netlink.Link) (netip.Addr, []netip.Prefix, error) { @@ -87,38 +86,8 @@ func extPrefixes(link netlink.Link) (netip.Addr, []netip.Prefix, error) { } // NewNetlinkRouter creates a new netlink-based tunnel router. -// NetlinkRouterOption represents a router configuration option. -type NetlinkRouterOption func(*netlinkRouterOptions) - -type netlinkRouterOptions struct { - extIfaceName string - tunIfaceName string -} - -func defaultNetlinkOptions() *netlinkRouterOptions { - return &netlinkRouterOptions{ - extIfaceName: "eth0", - tunIfaceName: "tun0", - } -} - -// WithExternalInterface sets the external interface name. -func WithExternalInterface(name string) NetlinkRouterOption { - return func(o *netlinkRouterOptions) { - o.extIfaceName = name - } -} - -// WithTunnelInterface sets the tunnel interface name. -func WithTunnelInterface(name string) NetlinkRouterOption { - return func(o *netlinkRouterOptions) { - o.tunIfaceName = name - } -} - -// NewNetlinkRouter creates a new netlink-based tunnel router. -func NewNetlinkRouter(opts ...NetlinkRouterOption) (*NetlinkRouter, error) { - options := defaultNetlinkOptions() +func NewNetlinkRouter(opts ...Option) (*NetlinkRouter, error) { + options := defaultOptions() for _, opt := range opts { opt(options) } @@ -137,6 +106,14 @@ func NewNetlinkRouter(opts ...NetlinkRouterOption) (*NetlinkRouter, error) { return nil, fmt.Errorf("failed to create TUN interface: %w", err) } + if options.pcapPath != "" { + tunDev, err = tunnet.NewPcapDevice(tunDev, options.pcapPath) + if err != nil { + tunDev.Close() + return nil, fmt.Errorf("failed to create pcap device: %w", err) + } + } + // Get the actual tun name (may differ from requested name). actualTunName, err := tunDev.Name() if err != nil { @@ -150,6 +127,25 @@ func NewNetlinkRouter(opts ...NetlinkRouterOption) (*NetlinkRouter, error) { return nil, fmt.Errorf("failed to get TUN interface: %w", err) } + for _, addr := range options.localAddresses { + ip := addr.Addr() + mask := net.CIDRMask(addr.Bits(), 32) + if ip.Is6() { + mask = net.CIDRMask(addr.Bits(), 128) + } + + if err := netlink.AddrAdd(tunLink, &netlink.Addr{ + IPNet: &net.IPNet{ + IP: ip.AsSlice(), + Mask: mask, + }, + }); err != nil { + tunDev.Close() + return nil, fmt.Errorf("failed to add address to TUN interface: %w", err) + } + slog.Info("Added address to TUN interface", slog.String("addr", addr.String())) + } + if err := netlink.LinkSetUp(tunLink); err != nil { tunDev.Close() return nil, fmt.Errorf("failed to bring up TUN interface: %w", err) @@ -364,3 +360,33 @@ func (r *NetlinkRouter) Close() error { }) return firstErr } + +// LocalAddresses returns the list of local addresses that are assigned to the router. +func (r *NetlinkRouter) LocalAddresses() ([]netip.Prefix, error) { + if r.tunLink == nil { + return nil, nil + } + + addrs, err := netlink.AddrList(r.tunLink, netlink.FAMILY_V6) + if err != nil { + return nil, fmt.Errorf("failed to get addresses for link: %w", err) + } + + var prefixes []netip.Prefix + for _, addr := range addrs { + ip, ok := netip.AddrFromSlice(addr.IP) + if !ok { + slog.Warn("Failed to convert IP address", slog.String("ip", addr.IP.String())) + continue + } + if !ip.IsGlobalUnicast() { // Skip non-global unicast addresses. + slog.Debug("Skipping non-global unicast address", slog.String("ip", addr.IP.String())) + continue + } + + bits, _ := addr.Mask.Size() + prefixes = append(prefixes, netip.PrefixFrom(ip, bits)) + } + + return prefixes, nil +} diff --git a/pkg/tunnel/router/netstack.go b/pkg/tunnel/router/netstack.go index 39be389e..ef650d21 100644 --- a/pkg/tunnel/router/netstack.go +++ b/pkg/tunnel/router/netstack.go @@ -21,52 +21,6 @@ var ( _ Router = (*NetstackRouter)(nil) ) -type NetstackRouterOption func(*netstackRouterOptions) - -type netstackRouterOptions struct { - localAddresses []netip.Prefix - socksListenAddr string - resolveConf *network.ResolveConfig // If not set system default resolver is used - pcapPath string -} - -func defaultClientOptions() *netstackRouterOptions { - return &netstackRouterOptions{ - localAddresses: []netip.Prefix{ - netip.MustParsePrefix("fd00::/64"), - }, - socksListenAddr: "localhost:1080", - } -} - -// WithLocalAddresses sets the local addresses for the netstack router. -func WithLocalAddresses(localAddresses []netip.Prefix) NetstackRouterOption { - return func(o *netstackRouterOptions) { - o.localAddresses = localAddresses - } -} - -// WithSocksListenAddr sets the SOCKS listen address for the netstack router. -func WithSocksListenAddr(addr string) NetstackRouterOption { - return func(o *netstackRouterOptions) { - o.socksListenAddr = addr - } -} - -// WithResolveConfig sets the DNS configuration for the netstack router. -func WithResolveConfig(conf *network.ResolveConfig) NetstackRouterOption { - return func(o *netstackRouterOptions) { - o.resolveConf = conf - } -} - -// WithPcapPath sets the optional path to a packet capture file for the netstack router. -func WithPcapPath(path string) NetstackRouterOption { - return func(o *netstackRouterOptions) { - o.pcapPath = path - } -} - // NetstackRouter implements Router using a userspace network stack. type NetstackRouter struct { tunDev *netstack.TunDevice @@ -79,8 +33,8 @@ type NetstackRouter struct { } // NewNetstackRouter creates a new netstack-based tunnel router. -func NewNetstackRouter(opts ...NetstackRouterOption) (*NetstackRouter, error) { - options := defaultClientOptions() +func NewNetstackRouter(opts ...Option) (*NetstackRouter, error) { + options := defaultOptions() for _, opt := range opts { opt(options) } @@ -153,16 +107,12 @@ func (r *NetstackRouter) Start(ctx context.Context) error { // AddPeer adds a peer route to the tunnel. func (r *NetstackRouter) AddPeer(peer netip.Prefix, conn connection.Connection) (netip.Addr, []netip.Prefix, error) { - slog.Debug("Adding route in netstack", slog.String("prefix", peer.String())) - r.mux.AddConnection(peer, conn) return peer.Addr(), r.localAddresses, nil } // RemovePeer removes a peer route from the tunnel. func (r *NetstackRouter) RemovePeer(peer netip.Prefix) error { - slog.Debug("Removing route in netstack", slog.String("prefix", peer.String())) - if err := r.mux.RemoveConnection(peer); err != nil { slog.Error("failed to remove connection", slog.Any("error", err)) } @@ -202,3 +152,8 @@ func (r *NetstackRouter) Close() error { }) return firstErr } + +// LocalAddresses returns the list of local addresses that are assigned to the router. +func (r *NetstackRouter) LocalAddresses() ([]netip.Prefix, error) { + return r.tunDev.LocalAddresses() +} diff --git a/pkg/tunnel/router/options.go b/pkg/tunnel/router/options.go new file mode 100644 index 00000000..993d3463 --- /dev/null +++ b/pkg/tunnel/router/options.go @@ -0,0 +1,72 @@ +package router + +import ( + "net/netip" + + "github.com/dpeckett/network" +) + +// Option represents a router configuration option. +type Option func(*routerOptions) + +type routerOptions struct { + localAddresses []netip.Prefix + resolveConf *network.ResolveConfig // If not set system default resolver is used + pcapPath string + extIfaceName string + tunIfaceName string + socksListenAddr string +} + +func defaultOptions() *routerOptions { + return &routerOptions{ + extIfaceName: "eth0", + tunIfaceName: "tun0", + socksListenAddr: "localhost:1080", + } +} + +// WithLocalAddresses sets the local addresses for the router. +func WithLocalAddresses(localAddresses []netip.Prefix) Option { + return func(o *routerOptions) { + o.localAddresses = localAddresses + } +} + +// WithPcapPath sets the optional path to a packet capture file for the netstack router. +func WithPcapPath(path string) Option { + return func(o *routerOptions) { + o.pcapPath = path + } +} + +// WithResolveConfig sets the DNS configuration for the netstack router. +func WithResolveConfig(conf *network.ResolveConfig) Option { + return func(o *routerOptions) { + o.resolveConf = conf + } +} + +// WithExternalInterface sets the external interface name. +// Only valid for netlink routers. +func WithExternalInterface(name string) Option { + return func(o *routerOptions) { + o.extIfaceName = name + } +} + +// WithTunnelInterface sets the tunnel interface name. +// Only valid for netlink routers. +func WithTunnelInterface(name string) Option { + return func(o *routerOptions) { + o.tunIfaceName = name + } +} + +// WithSocksListenAddr sets the SOCKS listen address for the netstack router. +// Only valid for netstack routers. +func WithSocksListenAddr(addr string) Option { + return func(o *routerOptions) { + o.socksListenAddr = addr + } +} diff --git a/pkg/tunnel/router/router.go b/pkg/tunnel/router/router.go index 85940c65..e50998ab 100644 --- a/pkg/tunnel/router/router.go +++ b/pkg/tunnel/router/router.go @@ -22,4 +22,7 @@ type Router interface { // RemovePeer removes a peer route from the tunnel identified by the given prefix. RemovePeer(peer netip.Prefix) error + + // LocalAddresses returns the list of local addresses that are assigned to the router. + LocalAddresses() ([]netip.Prefix, error) } diff --git a/pkg/tunnel/server.go b/pkg/tunnel/server.go index 5eca488b..136995b9 100644 --- a/pkg/tunnel/server.go +++ b/pkg/tunnel/server.go @@ -10,6 +10,7 @@ import ( "net/http" "net/netip" "strings" + "time" "github.com/alphadose/haxmap" "github.com/google/uuid" @@ -108,9 +109,6 @@ type TunnelServer struct { mux *connection.MuxedConnection // Maps tunnelNodes *haxmap.Map[string, *corev1alpha.TunnelNode] - - tunnelCtx context.Context - tunnelCtxCancel context.CancelFunc } // NewTunnelServer creates a new server proxy that routes traffic via @@ -154,8 +152,6 @@ func (t *TunnelServer) SetupWithManager(mgr ctrl.Manager) error { } func (t *TunnelServer) Start(ctx context.Context) error { - t.tunnelCtx, t.tunnelCtxCancel = context.WithCancel(ctx) - bindTo, err := netip.ParseAddrPort(t.options.proxyAddr) if err != nil { return fmt.Errorf("failed to parse bind address: %w", err) @@ -185,7 +181,7 @@ func (t *TunnelServer) Start(ctx context.Context) error { return fmt.Errorf("failed to create QUIC listener: %w", err) } - g, ctx := errgroup.WithContext(t.tunnelCtx) + g, ctx := errgroup.WithContext(ctx) g.Go(func() error { <-ctx.Done() @@ -222,17 +218,15 @@ func upsertAgentStatus(s *corev1alpha.TunnelNodeStatus, agent *corev1alpha.Agent } func (t *TunnelServer) Stop() error { - if err := t.Shutdown(context.Background()); err != nil { - slog.Error("Failed to shutdown server", slog.Any("error", err)) + if err := t.router.Close(); err != nil { + slog.Error("Failed to close router", slog.Any("error", err)) } - // Stop any background tasks if they are running. - if t.tunnelCtxCancel != nil { - t.tunnelCtxCancel() - } + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - if err := t.router.Close(); err != nil { - slog.Error("Failed to close router", slog.Any("error", err)) + if err := t.Shutdown(shutdownCtx); err != nil { + slog.Error("Failed to shutdown server", slog.Any("error", err)) } return t.Server.Close() diff --git a/pkg/tunnel/token/jwks.go b/pkg/tunnel/token/jwks.go index 1e15b821..e04fff0e 100644 --- a/pkg/tunnel/token/jwks.go +++ b/pkg/tunnel/token/jwks.go @@ -1,6 +1,7 @@ package token import ( + "context" "crypto/ecdsa" "crypto/sha256" "fmt" @@ -37,7 +38,7 @@ func NewJWKSHandler(publicKeyPEM []byte) (http.HandlerFunc, error) { } jwkSet := jwkset.NewMemoryStorage() - if err := jwkSet.KeyWrite(nil, jwk); err != nil { + if err := jwkSet.KeyWrite(context.Background(), jwk); err != nil { return nil, fmt.Errorf("failed to write JWK: %w", err) } diff --git a/pkg/tunnel/tunnel_test.go b/pkg/tunnel/tunnel_test.go index ca3fdf29..f65d3be9 100644 --- a/pkg/tunnel/tunnel_test.go +++ b/pkg/tunnel/tunnel_test.go @@ -10,10 +10,12 @@ import ( "log/slog" "net" "net/http" + "net/netip" "path/filepath" "testing" "time" + "github.com/avast/retry-go/v4" "github.com/golang-jwt/jwt" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -30,9 +32,15 @@ import ( "github.com/apoxy-dev/apoxy-cli/pkg/tunnel" "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/router" "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/token" + "github.com/apoxy-dev/apoxy-cli/pkg/utils/vm" ) -func TestTunnelEndToEnd(t *testing.T) { +func TestTunnelEndToEnd_UserModeClient(t *testing.T) { + child := vm.RunTestInVM(t) + if !child { + return + } + if testing.Verbose() { slog.SetLogLoggerLevel(slog.LevelDebug) } @@ -88,16 +96,13 @@ func TestTunnelEndToEnd(t *testing.T) { jwtValidator, err := token.NewInMemoryValidator(jwtPublicKeyPEM) require.NoError(t, err) - netstackRouter, err := router.NewNetstackRouter( - router.WithSocksListenAddr("localhost:1080"), - router.WithPcapPath("server.pcap"), - ) + serverRouter, err := router.NewNetlinkRouter() require.NoError(t, err) server := tunnel.NewTunnelServer( kubeClient, jwtValidator, - netstackRouter, + serverRouter, tunnel.WithCertPath(filepath.Join(certsDir, "server.crt")), tunnel.WithKeyPath(filepath.Join(certsDir, "server.key")), ) @@ -115,11 +120,9 @@ func TestTunnelEndToEnd(t *testing.T) { ) require.NoError(t, err) - gCtx, gCancel := context.WithCancel(ctx) - t.Cleanup(gCancel) - g, gctx := errgroup.WithContext(gCtx) + g, ctx := errgroup.WithContext(ctx) - // Start a little http server listening on localhost (to test the tunnel) + // Start a little http server listening on the client side. httpListener, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) @@ -133,7 +136,7 @@ func TestTunnelEndToEnd(t *testing.T) { g.Go(func() error { g.Go(func() error { - <-gctx.Done() + <-ctx.Done() t.Log("Closing HTTP test server") return httpServer.Close() }) @@ -154,7 +157,7 @@ func TestTunnelEndToEnd(t *testing.T) { t.Log("Starting tunnel server") - if err := server.Start(gctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := server.Start(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("unable to start server: %v", err) } @@ -165,28 +168,269 @@ func TestTunnelEndToEnd(t *testing.T) { g.Go(func() error { defer t.Log("Tunnel client closed") - defer gCancel() // Abort everything when the client is done - // Wait for the server to start time.Sleep(1 * time.Second) t.Log("Starting tunnel client") - if err := client.Start(gCtx); err != nil { + if err := client.Start(ctx); err != nil { return fmt.Errorf("unable to connect to server: %v", err) } defer func() { _ = client.Close() }() - clientAddresses, err := client.LocalAddresses() + return nil + }) + + // Run the test + g.Go(func() error { + // Cancel the context when the test is done + defer cancel() + + var clientAddresses []netip.Prefix + err := retry.Do( + func() error { + var err error + clientAddresses, err = client.LocalAddresses() + if err != nil { + return err + } + if len(clientAddresses) == 0 { + return fmt.Errorf("no addresses yet") + } + return nil + }, + retry.Context(ctx), + retry.Attempts(10), + retry.Delay(time.Second), + ) if err != nil { - return fmt.Errorf("unable to get local addresses: %v", err) + return fmt.Errorf("failed to get client addresses: %w", err) } t.Logf("Assigned client addresses: %v", clientAddresses) - // Connect to the netstack routers / servers socks5 proxy + t.Log("Connecting to HTTP server running on client via the tunnel") + + httpPort := httpListener.Addr().(*net.TCPAddr).Port + resp, err := http.Get("http://" + net.JoinHostPort(clientAddresses[0].Addr().String(), fmt.Sprintf("%d", httpPort))) + require.NoError(t, err) + defer resp.Body.Close() + + // Read the response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "Hello, world!\n", string(body)) + + t.Log("Connection successful") + + return nil + }) + + require.NoError(t, g.Wait()) +} + +func TestTunnelEndToEnd_KernelModeClient(t *testing.T) { + child := vm.RunTestInVM(t) + if !child { + return + } + + if testing.Verbose() { + slog.SetLogLoggerLevel(slog.LevelDebug) + } + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + caCert, serverCert, err := cryptoutils.GenerateSelfSignedTLSCert("localhost") + require.NoError(t, err) + + certsDir := t.TempDir() + + // Save the server certificate and private key to the temporary directory as PEM files + err = cryptoutils.SaveCertificatePEM(serverCert, certsDir, "server", false) + require.NoError(t, err) + + // Create a client UUID and JWT token + // This UUID is used to identify the client in the server's tunnel node list. + // The JWT token is used for authentication and contains the client's UUID as the subject. + clientUUID := uuid.New() + + jwtPrivateKeyPEM, jwtPublicKeyPEM, err := cryptoutils.GenerateEllipticKeyPair() + require.NoError(t, err) + + jwtPrivateKey, err := cryptoutils.ParseEllipticPrivateKeyPEM(jwtPrivateKeyPEM) + require.NoError(t, err) + + clientAuthToken, err := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ + "sub": clientUUID.String(), + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Minute * 5).Unix(), + }).SignedString(jwtPrivateKey) + require.NoError(t, err) + + scheme := runtime.NewScheme() + require.NoError(t, corev1alpha.Install(scheme)) + + clientTunnelNode := &corev1alpha.TunnelNode{ + ObjectMeta: metav1.ObjectMeta{ + Name: "client", + UID: apimachinerytypes.UID(clientUUID.String()), + }, + Status: corev1alpha.TunnelNodeStatus{ + Credentials: &corev1alpha.TunnelNodeCredentials{ + Token: clientAuthToken, + }, + }, + } + + kubeClient := fake.NewClientBuilder().WithScheme(scheme). + WithObjects(clientTunnelNode).WithStatusSubresource(clientTunnelNode).Build() + + jwtValidator, err := token.NewInMemoryValidator(jwtPublicKeyPEM) + require.NoError(t, err) + + serverRouter, err := router.NewNetstackRouter( + // We need to assign atleast one local address to the server for netstack to work. + router.WithLocalAddresses([]netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + }), + router.WithPcapPath("server.pcap"), + ) + require.NoError(t, err) + + server := tunnel.NewTunnelServer( + kubeClient, + jwtValidator, + serverRouter, + tunnel.WithCertPath(filepath.Join(certsDir, "server.crt")), + tunnel.WithKeyPath(filepath.Join(certsDir, "server.key")), + ) + + // Register the client with the server + server.AddTunnelNode(clientTunnelNode) + + // Create a new tunnel client + client, err := tunnel.NewTunnelClient( + tunnel.WithUUID(clientUUID), + tunnel.WithAuthToken(clientAuthToken), + tunnel.WithRootCAs(cryptoutils.CertPoolForCertificate(caCert)), + tunnel.WithMode(tunnel.TunnelClientModeKernel), + tunnel.WithPcapPath("client.pcap"), + ) + require.NoError(t, err) + + g, ctx := errgroup.WithContext(ctx) + + // Start the server + g.Go(func() error { + defer t.Log("Tunnel server closed") + + t.Log("Starting tunnel server") + + if err := server.Start(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("unable to start server: %v", err) + } + + return nil + }) + + // Start the client + g.Go(func() error { + defer t.Log("Tunnel client closed") + + // Wait for the server to start + time.Sleep(1 * time.Second) + + t.Log("Starting tunnel client") + + if err := client.Start(ctx); err != nil { + return fmt.Errorf("unable to connect to server: %v", err) + } + defer func() { + _ = client.Close() + }() + + return nil + }) + + var httpListener net.Listener + g.Go(func() error { + var clientAddresses []netip.Prefix + err := retry.Do( + func() error { + var err error + clientAddresses, err = client.LocalAddresses() + if err != nil { + return err + } + if len(clientAddresses) == 0 { + return fmt.Errorf("no addresses yet") + } + return nil + }, + retry.Context(ctx), + retry.Attempts(10), + retry.Delay(time.Second), + ) + if err != nil { + return fmt.Errorf("failed to get client addresses: %w", err) + } + + // Start a little http server listening on the client side. + httpListener, err = net.Listen("tcp", net.JoinHostPort(clientAddresses[0].Addr().String(), "0")) + require.NoError(t, err) + + httpServer := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, "Hello, world!") + }), + } + + g.Go(func() error { + <-ctx.Done() + t.Log("Closing HTTP test server") + return httpServer.Close() + }) + + defer t.Log("HTTP test server closed") + t.Log("Starting HTTP test server") + + if err := httpServer.Serve(httpListener); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("unable to start HTTP server: %v", err) + } + + return nil + }) + + g.Go(func() error { + // Cancel the context when the test is done + defer cancel() + + err := retry.Do( + func() error { + if httpListener == nil { + return fmt.Errorf("http listener not ready") + } + return nil + }, + retry.Context(ctx), + retry.Attempts(10), + retry.Delay(1*time.Second), + ) + if err != nil { + return fmt.Errorf("listener never became ready: %w", err) + } + + t.Logf("Connecting to HTTP server running on client via the tunnel: %s", + httpListener.Addr().(*net.TCPAddr).String()) + dialer, err := proxyclient.SOCKS5("tcp", "localhost:1080", nil, proxyclient.Direct) require.NoError(t, err) @@ -196,10 +440,7 @@ func TestTunnelEndToEnd(t *testing.T) { }, } - t.Log("Connecting to HTTP server through server SOCKS5 proxy") - - httpPort := httpListener.Addr().(*net.TCPAddr).Port - resp, err := client.Get("http://" + net.JoinHostPort(clientAddresses[0].Addr().String(), fmt.Sprintf("%d", httpPort))) + resp, err := client.Get("http://" + httpListener.Addr().(*net.TCPAddr).String()) require.NoError(t, err) defer resp.Body.Close() diff --git a/pkg/utils/vm/vm.go b/pkg/utils/vm/vm.go index 5a9f781f..b0a56b3e 100644 --- a/pkg/utils/vm/vm.go +++ b/pkg/utils/vm/vm.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "os" "os/exec" @@ -51,7 +52,7 @@ func RunTestInVM(t *testing.T) bool { // Download the image if not already present if _, err := os.Stat(imagePath); os.IsNotExist(err) { - imageURL := fmt.Sprintf("https://cdimage.debian.org/images/cloud/bookworm/20250428-2096/debian-12-generic-amd64-20250428-2096.qcow2") + imageURL := fmt.Sprintf("https://cdimage.debian.org/images/cloud/bookworm/latest/debian-12-generic-amd64.qcow2") t.Logf("Downloading image from %s...\n", imageURL) @@ -104,9 +105,16 @@ func RunTestInVM(t *testing.T) bool { return false } + sshPort, err := getFreePort() + if err != nil { + t.Fatalf("failed to find free SSH port: %v", err) + return false + } + t.Logf("Using random SSH host port: %d", sshPort) + qemuParams := []string{ "-m", "1024M", - "-netdev", "user,id=net0,hostfwd=tcp::10022-:22", + "-netdev", fmt.Sprintf("user,id=net0,hostfwd=tcp::%d-:22", sshPort), "-device", "e1000,netdev=net0,mac=52:54:00:12:34:56", "-snapshot", } @@ -175,7 +183,7 @@ func RunTestInVM(t *testing.T) bool { // Wait for SSH to become available var conn *ssh.Client for i := 0; i < 10; i++ { - conn, err = ssh.Dial("tcp", "localhost:10022", config) + conn, err = ssh.Dial("tcp", fmt.Sprintf("localhost:%d", sshPort), config) if err == nil { break } @@ -271,3 +279,12 @@ func createCloudInitISO(w io.Writer, userData, networkConfig, metaData string) e return nil } + +func getFreePort() (int, error) { + l, err := net.Listen("tcp", ":0") + if err != nil { + return 0, err + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port, nil +}