From 63bfe5735709e507f42a6e2be36bb880a8a596fb Mon Sep 17 00:00:00 2001 From: Damian Peckett Date: Fri, 4 Apr 2025 12:15:10 +0200 Subject: [PATCH 1/3] [wg] the abstraction is now TunnelTransport --- pkg/wireguard/kernel_network.go | 19 ------ pkg/wireguard/kernel_transport.go | 19 ++++++ ...ork_linux.go => kernel_transport_linux.go} | 66 +++++++++---------- ...twork_test.go => kernel_transport_test.go} | 6 +- pkg/wireguard/{network.go => transport.go} | 6 +- ...pace_network.go => userspace_transport.go} | 50 +++++++------- ...rk_test.go => userspace_transport_test.go} | 6 +- 7 files changed, 86 insertions(+), 86 deletions(-) delete mode 100644 pkg/wireguard/kernel_network.go create mode 100644 pkg/wireguard/kernel_transport.go rename pkg/wireguard/{kernel_network_linux.go => kernel_transport_linux.go} (82%) rename pkg/wireguard/{kernel_network_test.go => kernel_transport_test.go} (95%) rename pkg/wireguard/{network.go => transport.go} (83%) rename pkg/wireguard/{userspace_network.go => userspace_transport.go} (83%) rename pkg/wireguard/{userspace_network_test.go => userspace_transport_test.go} (94%) diff --git a/pkg/wireguard/kernel_network.go b/pkg/wireguard/kernel_network.go deleted file mode 100644 index 40cd9b21..00000000 --- a/pkg/wireguard/kernel_network.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build !linux -// +build !linux - -package wireguard - -import ( - "errors" -) - -type KernelModeNetwork struct { - Network -} - -// NewKernelModeNetwork returns a new kernel mode wireguard network. -func NewKernelModeNetwork( - conf *DeviceConfig, -) (*KernelModeNetwork, error) { - return &KernelModeNetwork{}, errors.New("kernel mode networks are not supported on this platform") -} diff --git a/pkg/wireguard/kernel_transport.go b/pkg/wireguard/kernel_transport.go new file mode 100644 index 00000000..53d4ea10 --- /dev/null +++ b/pkg/wireguard/kernel_transport.go @@ -0,0 +1,19 @@ +//go:build !linux +// +build !linux + +package wireguard + +import ( + "errors" +) + +type KernelModeTransport struct { + TunnelTransport +} + +// NewKernelModeNetwork returns a new kernel mode wireguard network. +func NewKernelModeTransport( + conf *DeviceConfig, +) (*KernelModeTransport, error) { + return &KernelModeTransport{}, errors.New("kernel mode networks are not supported on this platform") +} diff --git a/pkg/wireguard/kernel_network_linux.go b/pkg/wireguard/kernel_transport_linux.go similarity index 82% rename from pkg/wireguard/kernel_network_linux.go rename to pkg/wireguard/kernel_transport_linux.go index bf283e04..f341a12f 100644 --- a/pkg/wireguard/kernel_network_linux.go +++ b/pkg/wireguard/kernel_transport_linux.go @@ -20,9 +20,9 @@ import ( "github.com/apoxy-dev/apoxy-cli/pkg/utils" ) -var _ Network = (*KernelModeNetwork)(nil) +var _ TunnelTransport = (*KernelModeTransport)(nil) -type KernelModeNetwork struct { +type KernelModeTransport struct { *network.FilteredNetwork privateKey wgtypes.Key ifaceName string @@ -30,10 +30,10 @@ type KernelModeNetwork struct { wgClient *wgctrl.Client } -// NewKernelModeNetwork returns a new kernel mode wireguard network. -func NewKernelModeNetwork( +// NewKernelModeTransport returns a new kernel mode wireguard network. +func NewKernelModeTransport( conf *DeviceConfig, -) (*KernelModeNetwork, error) { +) (*KernelModeTransport, error) { if conf.PrivateKey == nil { return nil, errors.New("private key is required") } @@ -111,7 +111,7 @@ func NewKernelModeNetwork( return nil, fmt.Errorf("could not configure WireGuard device: %w", err) } - return &KernelModeNetwork{ + return &KernelModeTransport{ FilteredNetwork: network.Filtered(&network.FilteredNetworkConfig{ AllowedDestinations: localAddresses, Upstream: network.Host(), @@ -123,23 +123,23 @@ func NewKernelModeNetwork( }, nil } -func (n *KernelModeNetwork) Close() error { - defer n.wgClient.Close() +func (t *KernelModeTransport) Close() error { + defer t.wgClient.Close() - link, err := netlink.LinkByName(n.ifaceName) + link, err := netlink.LinkByName(t.ifaceName) if err != nil { - return fmt.Errorf("could not find interface %s: %w", n.ifaceName, err) + return fmt.Errorf("could not find interface %s: %w", t.ifaceName, err) } if err := netlink.LinkDel(link); err != nil { - return fmt.Errorf("could not delete interface %s: %w", n.ifaceName, err) + return fmt.Errorf("could not delete interface %s: %w", t.ifaceName, err) } return nil } -func (n *KernelModeNetwork) Peers() ([]PeerConfig, error) { - device, err := n.wgClient.Device(n.ifaceName) +func (t *KernelModeTransport) Peers() ([]PeerConfig, error) { + device, err := t.wgClient.Device(t.ifaceName) if err != nil { return nil, fmt.Errorf("could not fetch WireGuard device info: %w", err) } @@ -164,7 +164,7 @@ func (n *KernelModeNetwork) Peers() ([]PeerConfig, error) { return peers, nil } -func (n *KernelModeNetwork) AddPeer(peerConf *PeerConfig) error { +func (t *KernelModeTransport) AddPeer(peerConf *PeerConfig) error { publicKey, err := wgtypes.ParseKey(*peerConf.PublicKey) if err != nil { return fmt.Errorf("invalid public key: %w", err) @@ -208,16 +208,16 @@ func (n *KernelModeNetwork) AddPeer(peerConf *PeerConfig) error { peer.PersistentKeepaliveInterval = ptr.To(time.Duration(*peerConf.PersistentKeepaliveIntervalSec) * time.Second) } - if err := n.wgClient.ConfigureDevice(n.ifaceName, wgtypes.Config{ + if err := t.wgClient.ConfigureDevice(t.ifaceName, wgtypes.Config{ Peers: []wgtypes.PeerConfig{peer}, }); err != nil { - return fmt.Errorf("could not add peer to interface %s: %w", n.ifaceName, err) + return fmt.Errorf("could not add peer to interface %s: %w", t.ifaceName, err) } // Add route to the peer's allowed IPs - link, err := netlink.LinkByName(n.ifaceName) + link, err := netlink.LinkByName(t.ifaceName) if err != nil { - return fmt.Errorf("could not find interface %s: %w", n.ifaceName, err) + return fmt.Errorf("could not find interface %s: %w", t.ifaceName, err) } for _, allowedIP := range peerConf.AllowedIPs { @@ -238,20 +238,20 @@ func (n *KernelModeNetwork) AddPeer(peerConf *PeerConfig) error { return fmt.Errorf("could not add route to %s: %w", allowedIP, err) } - n.FilteredNetwork.AddAllowedDestination(prefix) + t.FilteredNetwork.AddAllowedDestination(prefix) } return nil } -func (n *KernelModeNetwork) RemovePeer(publicKey string) error { +func (t *KernelModeTransport) RemovePeer(publicKey string) error { parsedKey, err := wgtypes.ParseKey(publicKey) if err != nil { return fmt.Errorf("invalid public key: %w", err) } // Remove route to the peer's allowed IPs - device, err := n.wgClient.Device(n.ifaceName) + device, err := t.wgClient.Device(t.ifaceName) if err != nil { return fmt.Errorf("could not fetch WireGuard device info: %w", err) } @@ -264,9 +264,9 @@ func (n *KernelModeNetwork) RemovePeer(publicKey string) error { } } - link, err := netlink.LinkByName(n.ifaceName) + link, err := netlink.LinkByName(t.ifaceName) if err != nil { - return fmt.Errorf("could not find interface %s: %w", n.ifaceName, err) + return fmt.Errorf("could not find interface %s: %w", t.ifaceName, err) } for _, allowedIP := range peerAllowedIPs { @@ -283,7 +283,7 @@ func (n *KernelModeNetwork) RemovePeer(publicKey string) error { } addr, _ := netip.AddrFromSlice(allowedIP.IP) - n.FilteredNetwork.RemoveAllowedDestination(netip.PrefixFrom(addr, len(allowedIP.Mask)*8)) + t.FilteredNetwork.RemoveAllowedDestination(netip.PrefixFrom(addr, len(allowedIP.Mask)*8)) } peer := wgtypes.PeerConfig{ @@ -291,23 +291,23 @@ func (n *KernelModeNetwork) RemovePeer(publicKey string) error { Remove: true, } - if err := n.wgClient.ConfigureDevice(n.ifaceName, wgtypes.Config{ + if err := t.wgClient.ConfigureDevice(t.ifaceName, wgtypes.Config{ Peers: []wgtypes.PeerConfig{peer}, }); err != nil { - return fmt.Errorf("could not remove peer from interface %s: %w", n.ifaceName, err) + return fmt.Errorf("could not remove peer from interface %s: %w", t.ifaceName, err) } return nil } -func (n *KernelModeNetwork) PublicKey() string { - return n.privateKey.PublicKey().String() +func (t *KernelModeTransport) PublicKey() string { + return t.privateKey.PublicKey().String() } -func (n *KernelModeNetwork) LocalAddresses() ([]netip.Prefix, error) { - link, err := netlink.LinkByName(n.ifaceName) +func (t *KernelModeTransport) LocalAddresses() ([]netip.Prefix, error) { + link, err := netlink.LinkByName(t.ifaceName) if err != nil { - return nil, fmt.Errorf("could not find interface %s: %w", n.ifaceName, err) + return nil, fmt.Errorf("could not find interface %s: %w", t.ifaceName, err) } addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL) @@ -328,8 +328,8 @@ func (n *KernelModeNetwork) LocalAddresses() ([]netip.Prefix, error) { return prefixes, nil } -func (n *KernelModeNetwork) ListenPort() (uint16, error) { - return n.listenPort, nil +func (t *KernelModeTransport) ListenPort() (uint16, error) { + return t.listenPort, nil } func findNextAvailableInterface() (string, error) { diff --git a/pkg/wireguard/kernel_network_test.go b/pkg/wireguard/kernel_transport_test.go similarity index 95% rename from pkg/wireguard/kernel_network_test.go rename to pkg/wireguard/kernel_transport_test.go index 3d533160..4459f823 100644 --- a/pkg/wireguard/kernel_network_test.go +++ b/pkg/wireguard/kernel_transport_test.go @@ -24,7 +24,7 @@ import ( "github.com/apoxy-dev/apoxy-cli/pkg/wireguard" ) -func TestKernelModeNetwork(t *testing.T) { +func TestKernelModeTransport(t *testing.T) { if testing.Verbose() { slog.SetLogLoggerLevel(slog.LevelDebug) } @@ -42,7 +42,7 @@ func TestKernelModeNetwork(t *testing.T) { projectID := uuid.New() wgAddress := tunnel.NewApoxy4To6Prefix(projectID, "kernel-node") - kernelWGNet, err := wireguard.NewKernelModeNetwork(&wireguard.DeviceConfig{ + kernelWGNet, err := wireguard.NewKernelModeTransport(&wireguard.DeviceConfig{ PrivateKey: ptr.To(kernelPrivateKey.String()), Address: []string{wgAddress.String()}, }) @@ -60,7 +60,7 @@ func TestKernelModeNetwork(t *testing.T) { wgAddress = tunnel.NewApoxy4To6Prefix(projectID, "userspace-node") - wgNet, err := wireguard.NewUserspaceNetwork(&wireguard.DeviceConfig{ + wgNet, err := wireguard.NewUserspaceTransport(&wireguard.DeviceConfig{ PrivateKey: ptr.To(privateKey.String()), ListenPort: ptr.To(listenPort), Address: []string{wgAddress.String()}, diff --git a/pkg/wireguard/network.go b/pkg/wireguard/transport.go similarity index 83% rename from pkg/wireguard/network.go rename to pkg/wireguard/transport.go index 680d2256..f026f33f 100644 --- a/pkg/wireguard/network.go +++ b/pkg/wireguard/transport.go @@ -7,10 +7,10 @@ import ( "github.com/dpeckett/network" ) -// Network is an interface that represents a WireGuard network. +// TunnelTransport is an interface that represents a WireGuard network. // It provides methods to manage peers, retrieve local addresses, and -// listen for incoming connections. -type Network interface { +// make/listen for connections. +type TunnelTransport interface { io.Closer network.Network // Peers returns known peers associated with the network. diff --git a/pkg/wireguard/userspace_network.go b/pkg/wireguard/userspace_transport.go similarity index 83% rename from pkg/wireguard/userspace_network.go rename to pkg/wireguard/userspace_transport.go index 27e870bd..d108aab3 100644 --- a/pkg/wireguard/userspace_network.go +++ b/pkg/wireguard/userspace_transport.go @@ -23,18 +23,18 @@ import ( "github.com/apoxy-dev/apoxy-cli/pkg/wireguard/uapi" ) -var _ Network = (*UserspaceNetwork)(nil) +var _ TunnelTransport = (*UserspaceTransport)(nil) -// UserspaceNetwork is a user-space network implementation that uses WireGuard. -type UserspaceNetwork struct { +// UserspaceTransport is a user-space network implementation that uses WireGuard. +type UserspaceTransport struct { *network.NetstackNetwork tun *tunDevice dev *device.Device privateKey wgtypes.Key } -// NewUserspaceNetwork returns a new userspace wireguard network. -func NewUserspaceNetwork(conf *DeviceConfig) (*UserspaceNetwork, error) { +// NewUserspaceTransport returns a new userspace wireguard network. +func NewUserspaceTransport(conf *DeviceConfig) (*UserspaceTransport, error) { if conf.PrivateKey == nil { return nil, errors.New("private key is required") } @@ -95,7 +95,7 @@ func NewUserspaceNetwork(conf *DeviceConfig) (*UserspaceNetwork, error) { Nameservers: conf.DNS, } - return &UserspaceNetwork{ + return &UserspaceTransport{ NetstackNetwork: network.Netstack(tun.stack, tun.nicID, resolveConf), tun: tun, dev: dev, @@ -103,20 +103,20 @@ func NewUserspaceNetwork(conf *DeviceConfig) (*UserspaceNetwork, error) { }, nil } -func (n *UserspaceNetwork) Close() error { - n.dev.Close() // Closes tun device internally. +func (t *UserspaceTransport) Close() error { + t.dev.Close() // Closes tun device internally. return nil } // PublicKey returns the public key for this peer on the WireGuard network. -func (n *UserspaceNetwork) PublicKey() string { - return n.privateKey.PublicKey().String() +func (t *UserspaceTransport) PublicKey() string { + return t.privateKey.PublicKey().String() } // ListenPort returns the local listen port of this end of the tunnel. -func (n *UserspaceNetwork) ListenPort() (uint16, error) { +func (t *UserspaceTransport) ListenPort() (uint16, error) { var uapiConf strings.Builder - if err := n.dev.IpcGetOperation(&uapiConf); err != nil { + if err := t.dev.IpcGetOperation(&uapiConf); err != nil { return 0, fmt.Errorf("failed to get device config: %w", err) } @@ -138,8 +138,8 @@ func (n *UserspaceNetwork) ListenPort() (uint16, error) { } // LocalAddresses returns the list of local addresses assigned to the WireGuard network. -func (n *UserspaceNetwork) LocalAddresses() ([]netip.Prefix, error) { - nic := n.tun.stack.NICInfo()[n.tun.nicID] +func (t *UserspaceTransport) LocalAddresses() ([]netip.Prefix, error) { + nic := t.tun.stack.NICInfo()[t.tun.nicID] var addrs []netip.Prefix for _, assignedAddr := range nic.ProtocolAddresses { @@ -153,28 +153,28 @@ func (n *UserspaceNetwork) LocalAddresses() ([]netip.Prefix, error) { } // FowardToLoopback forwards all inbound traffic to the loopback interface. -func (n *UserspaceNetwork) FowardToLoopback(ctx context.Context) error { +func (t *UserspaceTransport) FowardToLoopback(ctx context.Context) error { // Allow outgoing packets to have a source address different from the address // assigned to the NIC. - if tcpipErr := n.tun.stack.SetSpoofing(n.tun.nicID, true); tcpipErr != nil { + if tcpipErr := t.tun.stack.SetSpoofing(t.tun.nicID, true); tcpipErr != nil { return fmt.Errorf("failed to enable spoofing: %v", tcpipErr) } // Allow incoming packets to have a destination address different from the // address assigned to the NIC. - if tcpipErr := n.tun.stack.SetPromiscuousMode(n.tun.nicID, true); tcpipErr != nil { + if tcpipErr := t.tun.stack.SetPromiscuousMode(t.tun.nicID, true); tcpipErr != nil { return fmt.Errorf("failed to enable promiscuous mode: %v", tcpipErr) } - tcpForwarder := netstack.TCPForwarder(ctx, n.tun.stack, network.Loopback()) + tcpForwarder := netstack.TCPForwarder(ctx, t.tun.stack, network.Loopback()) - n.tun.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder) + t.tun.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder) return nil } // Peers returns the list of public keys for all peers on the WireGuard network. -func (n *UserspaceNetwork) Peers() ([]PeerConfig, error) { +func (n *UserspaceTransport) Peers() ([]PeerConfig, error) { var uapiConf strings.Builder if err := n.dev.IpcGetOperation(&uapiConf); err != nil { return nil, fmt.Errorf("failed to get device config: %w", err) @@ -200,7 +200,7 @@ func (n *UserspaceNetwork) Peers() ([]PeerConfig, error) { } // AddPeer adds, or updates, a peer to the WireGuard network. -func (n *UserspaceNetwork) AddPeer(peerConf *PeerConfig) error { +func (t *UserspaceTransport) AddPeer(peerConf *PeerConfig) error { if peerConf.Endpoint != nil { // If it's an address, resolve it. If it's a name pass it through unmodified. host, port, err := net.SplitHostPort(*peerConf.Endpoint) @@ -230,7 +230,7 @@ func (n *UserspaceNetwork) AddPeer(peerConf *PeerConfig) error { return fmt.Errorf("failed to marshal peer config: %w", err) } - if err := n.dev.IpcSet(uapiPeerConf); err != nil { + if err := t.dev.IpcSet(uapiPeerConf); err != nil { return fmt.Errorf("failed to add peer: %w", err) } @@ -252,7 +252,7 @@ func (n *UserspaceNetwork) AddPeer(peerConf *PeerConfig) error { slog.Warn("failed to marshal peer config", slog.Any("error", err)) } - if err := n.dev.IpcSet(uapiPeerConf); err != nil { + if err := t.dev.IpcSet(uapiPeerConf); err != nil { slog.Warn("failed to set persistent keep-alive interval", slog.Any("error", err)) } }() @@ -262,7 +262,7 @@ func (n *UserspaceNetwork) AddPeer(peerConf *PeerConfig) error { } // RemovePeer removes a peer from the WireGuard network. -func (n *UserspaceNetwork) RemovePeer(publicKey string) error { +func (t *UserspaceTransport) RemovePeer(publicKey string) error { peerConf := &PeerConfig{ PublicKey: ptr.To(publicKey), Remove: ptr.To(true), @@ -273,7 +273,7 @@ func (n *UserspaceNetwork) RemovePeer(publicKey string) error { return fmt.Errorf("failed to marshal peer config: %w", err) } - if err := n.dev.IpcSet(uapiPeerConf); err != nil { + if err := t.dev.IpcSet(uapiPeerConf); err != nil { return fmt.Errorf("failed to remove peer: %w", err) } diff --git a/pkg/wireguard/userspace_network_test.go b/pkg/wireguard/userspace_transport_test.go similarity index 94% rename from pkg/wireguard/userspace_network_test.go rename to pkg/wireguard/userspace_transport_test.go index 81ee6c6b..125cc259 100644 --- a/pkg/wireguard/userspace_network_test.go +++ b/pkg/wireguard/userspace_transport_test.go @@ -19,7 +19,7 @@ import ( "github.com/apoxy-dev/apoxy-cli/pkg/wireguard" ) -func TestUserspaceNetwork(t *testing.T) { +func TestUserspaceTransport(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, err := w.Write([]byte("Hello, World!")) @@ -48,7 +48,7 @@ func TestUserspaceNetwork(t *testing.T) { clientPort, err := utils.UnusedUDP4Port() require.NoError(t, err) - serverWGNet, err := wireguard.NewUserspaceNetwork(&wireguard.DeviceConfig{ + serverWGNet, err := wireguard.NewUserspaceTransport(&wireguard.DeviceConfig{ PrivateKey: ptr.To(base64.StdEncoding.EncodeToString(serverPrivateKey[:])), ListenPort: ptr.To(serverPort), Address: []string{"10.0.0.1/32"}, @@ -67,7 +67,7 @@ func TestUserspaceNetwork(t *testing.T) { require.NoError(t, serverWGNet.FowardToLoopback(context.Background())) - clientWGNet, err := wireguard.NewUserspaceNetwork(&wireguard.DeviceConfig{ + clientWGNet, err := wireguard.NewUserspaceTransport(&wireguard.DeviceConfig{ PrivateKey: ptr.To(base64.StdEncoding.EncodeToString(clientPrivateKey[:])), ListenPort: ptr.To(clientPort), Address: []string{"10.0.0.2/32"}, From a885843bbdecb7e1c967f1735bc6bd688488740c Mon Sep 17 00:00:00 2001 From: Damian Peckett Date: Fri, 4 Apr 2025 13:22:30 +0200 Subject: [PATCH 2/3] [connectip] add basic client and skeleton of server --- pkg/connectip/client.go | 186 ++++++++++++++++++++++ pkg/connectip/server.go | 33 ++++ pkg/connectip/splice.go | 73 +++++++++ pkg/connectip/transport.go | 12 ++ pkg/{wireguard => netstack}/tun_device.go | 75 +++++++-- pkg/tunnel/server.go | 8 +- pkg/wireguard/kernel_transport_linux.go | 3 +- pkg/wireguard/userspace_transport.go | 37 +---- 8 files changed, 374 insertions(+), 53 deletions(-) create mode 100644 pkg/connectip/client.go create mode 100644 pkg/connectip/server.go create mode 100644 pkg/connectip/splice.go create mode 100644 pkg/connectip/transport.go rename pkg/{wireguard => netstack}/tun_device.go (65%) diff --git a/pkg/connectip/client.go b/pkg/connectip/client.go new file mode 100644 index 00000000..93441372 --- /dev/null +++ b/pkg/connectip/client.go @@ -0,0 +1,186 @@ +package connectip + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "log/slog" + "net" + "net/http" + "net/netip" + "strings" + "sync" + "time" + + "github.com/apoxy-dev/apoxy-cli/pkg/netstack" + "github.com/dpeckett/network" + connectip "github.com/quic-go/connect-ip-go" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" + "github.com/yosida95/uritemplate/v3" +) + +var _ TunnelTransport = (*ClientTransport)(nil) + +type ClientConfig struct { + // The UUID identifying the client. + UUID string + // The authentication token for the client. + AuthToken string + // The optional path to a packet capture file. + PcapPath string + // Optional root CA certificates for TLS verification. + RootCAs *x509.CertPool +} + +type ClientTransport struct { + *network.NetstackNetwork + uuid string + authToken string + pcapPath string + rootCAs *x509.CertPool + + conn *connectip.Conn + tun *netstack.TunDevice + closeOnce sync.Once +} + +func NewClientTransport(conf *ClientConfig) *ClientTransport { + return &ClientTransport{ + uuid: conf.UUID, + authToken: conf.AuthToken, + pcapPath: conf.PcapPath, + rootCAs: conf.RootCAs, + } +} + +func (t *ClientTransport) Connect(ctx context.Context, serverAddr string) error { + tlsConfig := &tls.Config{ + ServerName: "proxy", + NextProtos: []string{http3.NextProtoH3}, + RootCAs: t.rootCAs, + } + + // Use the proxy address as the server name if it is a domain. + if addr, _, err := net.SplitHostPort(serverAddr); err == nil && net.ParseIP(addr) == nil { + tlsConfig.ServerName = addr + } + + qConn, err := quic.DialAddr( + ctx, + serverAddr, + tlsConfig, + &quic.Config{ + EnableDatagrams: true, + InitialPacketSize: 1350, + KeepAlivePeriod: 5 * time.Second, + MaxIdleTimeout: 5 * time.Minute, + }, + ) + if err != nil { + return fmt.Errorf("failed to dial QUIC connection: %w", err) + } + + tr := &http3.Transport{EnableDatagrams: true} + hconn := tr.NewClientConn(qConn) + + template := uritemplate.MustNew(fmt.Sprintf("https://proxy/connect/%s?token=%s", t.uuid, t.authToken)) + + var rsp *http.Response + t.conn, rsp, err = connectip.Dial(ctx, hconn, template) + if err != nil { + return fmt.Errorf("failed to dial connect-ip connection: %w", err) + } + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", rsp.StatusCode) + } + + slog.Info("Connected to server", slog.String("addr", serverAddr)) + + localPrefixes, err := t.conn.LocalPrefixes(ctx) + if err != nil { + return fmt.Errorf("failed to get local IP addresses: %w", err) + } + if len(localPrefixes) == 0 { + return errors.New("no local IP addresses available") + } + + // Filter out non-IPv6 addresses. + filteredLocalPrefixes := make([]netip.Prefix, 0, len(localPrefixes)) + for _, prefix := range localPrefixes { + if !prefix.Addr().Is6() { + slog.Warn("Skipping non-IPv6 address", slog.String("address", prefix.Addr().String())) + continue + } + + slog.Info("Adding IPv6 address", slog.String("prefix", prefix.String())) + filteredLocalPrefixes = append(filteredLocalPrefixes, prefix) + } + + resolveConf := &network.ResolveConfig{ + Nameservers: rsp.Header.Values("X-Apoxy-Nameservers"), + SearchDomains: rsp.Header.Values("X-Apoxy-DNS-SearchDomains"), + } + + // Parse DNS options from response headers. + if opts := rsp.Header.Values("X-Apoxy-DNS-Options"); len(opts) > 0 { + for _, opt := range opts { + if strings.HasPrefix(opt, "ndots:") { + var ndots int + if n, err := fmt.Sscanf(opt[6:], "%d", &ndots); err != nil || n != 1 { + ndots = 1 + } + resolveConf.NDots = &ndots + } + } + } + + slog.Info("Using DNS configuration", + slog.Any("nameservers", resolveConf.Nameservers), + slog.Any("searchDomains", resolveConf.SearchDomains), + slog.Any("nDots", resolveConf.NDots)) + + t.tun, err = netstack.NewTunDevice(filteredLocalPrefixes, nil, t.pcapPath) + if err != nil { + return fmt.Errorf("failed to create virtual TUN device: %w", err) + } + + t.NetstackNetwork = t.tun.Network(resolveConf) + + // TODO: how to bubble up errors from this? + go spliceConnToTunDevice(t.conn, t.tun) + + return nil +} + +func (t *ClientTransport) Close() error { + var closeErr error + + t.closeOnce.Do(func() { + if t.conn != nil { + if err := t.conn.Close(); err != nil { + closeErr = fmt.Errorf("failed to close connect-ip connection: %w", err) + } + } + + if t.tun != nil { + if err := t.tun.Close(); err != nil { + // combine errors if both fail + if closeErr != nil { + closeErr = fmt.Errorf("%v; also failed to close TUN device: %w", closeErr, err) + } else { + closeErr = fmt.Errorf("failed to close TUN device: %w", err) + } + } + } + }) + + return closeErr +} + +// FowardToLoopback forwards all inbound traffic to the loopback interface. +func (t *ClientTransport) FowardToLoopback(ctx context.Context) error { + return t.tun.ForwardTo(ctx, network.Loopback()) +} diff --git a/pkg/connectip/server.go b/pkg/connectip/server.go new file mode 100644 index 00000000..0c322cab --- /dev/null +++ b/pkg/connectip/server.go @@ -0,0 +1,33 @@ +package connectip + +import ( + "context" + + "github.com/dpeckett/network" +) + +var _ TunnelTransport = (*ServerTransport)(nil) + +type ServerConfig struct { + // TODO: Define server configuration options +} + +type ServerTransport struct { + network.Network +} + +func NewServerTransport(conf *ServerConfig) *ServerTransport { + return &ServerTransport{} +} + +func (t *ServerTransport) ListenForConnections(ctx context.Context) error { + // TODO: Implement server listening logic + + // TODO: use splice to move packets around + return nil +} + +func (t *ServerTransport) Close() error { + // TODO: Implement server closing logic + return nil +} diff --git a/pkg/connectip/splice.go b/pkg/connectip/splice.go new file mode 100644 index 00000000..3a704633 --- /dev/null +++ b/pkg/connectip/splice.go @@ -0,0 +1,73 @@ +package connectip + +import ( + "errors" + "fmt" + "log/slog" + "net" + + "github.com/apoxy-dev/apoxy-cli/pkg/netstack" + connectip "github.com/quic-go/connect-ip-go" + "golang.org/x/sync/errgroup" + "golang.zx2c4.com/wireguard/tun" +) + +func spliceConnToTunDevice(conn *connectip.Conn, tun tun.Device) error { + var g errgroup.Group + + g.Go(func() error { + var pkt [netstack.IPv6MinMTU]byte + sizes := make([]int, 1) + + for { + _, err := tun.Read([][]byte{pkt[:]}, sizes, 0) + if err != nil { + if errors.Is(err, net.ErrClosed) { + // TUN device is closed, exit the loop. + // TODO: is this the correct error + return nil + } + + return fmt.Errorf("failed to read from TUN: %w", err) + } + + slog.Debug("Read packet from TUN", slog.Int("len", sizes[0])) + + icmp, err := conn.WritePacket(pkt[:sizes[0]]) + if err != nil { + slog.Error("Failed to write to connection", slog.Any("error", err)) + continue + } + if len(icmp) > 0 { + slog.Debug("Sending ICMP packet") + + if _, err := t.tun.Write([][]byte{icmp}, 0); err != nil { + slog.Error("Failed to write ICMP packet", slog.Any("error", err)) + } + } + } + }) + + g.Go(func() error { + var pkt [netstack.IPv6MinMTU]byte + + for { + n, err := conn.ReadPacket(pkt[:]) + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return fmt.Errorf("failed to read from connection: %w", err) + } + + slog.Debug("Read from connection", slog.Int("bytes", n)) + + if _, err := tun.Write([][]byte{pkt[:n]}, 0); err != nil { + slog.Error("Failed to write to TUN", slog.Any("error", err)) + continue + } + } + }) + + return g.Wait() +} diff --git a/pkg/connectip/transport.go b/pkg/connectip/transport.go new file mode 100644 index 00000000..d74c39b1 --- /dev/null +++ b/pkg/connectip/transport.go @@ -0,0 +1,12 @@ +package connectip + +import ( + "io" + + "github.com/dpeckett/network" +) + +type TunnelTransport interface { + io.Closer + network.Network +} diff --git a/pkg/wireguard/tun_device.go b/pkg/netstack/tun_device.go similarity index 65% rename from pkg/wireguard/tun_device.go rename to pkg/netstack/tun_device.go index 1e3fe13c..cbeb78df 100644 --- a/pkg/wireguard/tun_device.go +++ b/pkg/netstack/tun_device.go @@ -1,11 +1,13 @@ -package wireguard +package netstack import ( + "context" "fmt" "net/netip" "os" "syscall" + "github.com/dpeckett/network" "golang.zx2c4.com/wireguard/tun" "k8s.io/utils/ptr" @@ -22,11 +24,11 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -const DefaultMTU = 1280 // IPv6 minimum MTU, required for some PPPoE links. +const IPv6MinMTU = 1280 // IPv6 minimum MTU, required for some PPPoE links. -var _ tun.Device = (*tunDevice)(nil) +var _ tun.Device = (*TunDevice)(nil) -type tunDevice struct { +type TunDevice struct { ep *channel.Endpoint stack *stack.Stack nicID tcpip.NICID @@ -36,7 +38,7 @@ type tunDevice struct { mtu int } -func newTunDevice(localAddresses []netip.Prefix, mtu *int, pcapPath string) (*tunDevice, error) { +func NewTunDevice(localAddresses []netip.Prefix, mtu *int, pcapPath string) (*TunDevice, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -65,7 +67,7 @@ func newTunDevice(localAddresses []netip.Prefix, mtu *int, pcapPath string) (*tu } if mtu == nil { - mtu = ptr.To(DefaultMTU) + mtu = ptr.To(IPv6MinMTU) } nicID := ipstack.NextNICID() @@ -122,7 +124,7 @@ func newTunDevice(localAddresses []netip.Prefix, mtu *int, pcapPath string) (*tu } } - tunDev := &tunDevice{ + tunDev := &TunDevice{ ep: linkEP, stack: ipstack, nicID: nicID, @@ -137,17 +139,17 @@ func newTunDevice(localAddresses []netip.Prefix, mtu *int, pcapPath string) (*tu return tunDev, nil } -func (tun *tunDevice) Name() (string, error) { return "go", nil } +func (tun *TunDevice) Name() (string, error) { return "go", nil } -func (tun *tunDevice) File() *os.File { return nil } +func (tun *TunDevice) File() *os.File { return nil } -func (tun *tunDevice) Events() <-chan tun.Event { return tun.events } +func (tun *TunDevice) Events() <-chan tun.Event { return tun.events } -func (tun *tunDevice) MTU() (int, error) { return tun.mtu, nil } +func (tun *TunDevice) MTU() (int, error) { return tun.mtu, nil } -func (tun *tunDevice) BatchSize() int { return 1 } +func (tun *TunDevice) BatchSize() int { return 1 } -func (tun *tunDevice) Read(buf [][]byte, sizes []int, offset int) (int, error) { +func (tun *TunDevice) Read(buf [][]byte, sizes []int, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed @@ -161,7 +163,7 @@ func (tun *tunDevice) Read(buf [][]byte, sizes []int, offset int) (int, error) { return 1, nil } -func (tun *tunDevice) Write(buf [][]byte, offset int) (int, error) { +func (tun *TunDevice) Write(buf [][]byte, offset int) (int, error) { for _, buf := range buf { packet := buf[offset:] if len(packet) == 0 { @@ -181,7 +183,7 @@ func (tun *tunDevice) Write(buf [][]byte, offset int) (int, error) { return len(buf), nil } -func (tun *tunDevice) WriteNotify() { +func (tun *TunDevice) WriteNotify() { pkt := tun.ep.Read() if pkt == nil { return @@ -193,7 +195,7 @@ func (tun *tunDevice) WriteNotify() { tun.incomingPacket <- view } -func (tun *tunDevice) Close() error { +func (tun *TunDevice) Close() error { tun.stack.RemoveNIC(tun.nicID) if tun.events != nil { @@ -212,3 +214,44 @@ func (tun *tunDevice) Close() error { return nil } + +// Network returns the network abstraction for the TUN device. +func (tun *TunDevice) Network(resolveConf *network.ResolveConfig) *network.NetstackNetwork { + return network.Netstack(tun.stack, tun.nicID, resolveConf) +} + +// LocalAddresses returns the list of local addresses assigned to the TUN device. +func (tun *TunDevice) LocalAddresses() ([]netip.Prefix, error) { + nic := tun.stack.NICInfo()[tun.nicID] + + var addrs []netip.Prefix + for _, assignedAddr := range nic.ProtocolAddresses { + addrs = append(addrs, netip.PrefixFrom( + addrFromNetstackIP(assignedAddr.AddressWithPrefix.Address), + assignedAddr.AddressWithPrefix.PrefixLen, + )) + } + + return addrs, nil +} + +// ForwardTo forwards all inbound traffic to the upstream network. +func (tun *TunDevice) ForwardTo(ctx context.Context, upstream network.Network) error { + // Allow outgoing packets to have a source address different from the address + // assigned to the NIC. + if tcpipErr := tun.stack.SetSpoofing(tun.nicID, true); tcpipErr != nil { + return fmt.Errorf("failed to enable spoofing: %v", tcpipErr) + } + + // Allow incoming packets to have a destination address different from the + // address assigned to the NIC. + if tcpipErr := tun.stack.SetPromiscuousMode(tun.nicID, true); tcpipErr != nil { + return fmt.Errorf("failed to enable promiscuous mode: %v", tcpipErr) + } + + tcpForwarder := TCPForwarder(ctx, tun.stack, upstream) + + tun.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder) + + return nil +} diff --git a/pkg/tunnel/server.go b/pkg/tunnel/server.go index 402c1a6d..538aec0e 100644 --- a/pkg/tunnel/server.go +++ b/pkg/tunnel/server.go @@ -36,14 +36,14 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "github.com/apoxy-dev/apoxy-cli/pkg/netstack" "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/token" corev1alpha "github.com/apoxy-dev/apoxy-cli/api/core/v1alpha" ) const ( - IPv6MinMTU = 1280 - tunOffset = device.MessageTransportHeaderSize + tunOffset = device.MessageTransportHeaderSize ) var ( @@ -51,7 +51,7 @@ var ( bufferPool = sync.Pool{ New: func() interface{} { - b := make([]byte, IPv6MinMTU+tunOffset) + b := make([]byte, netstack.IPv6MinMTU+tunOffset) return &b }, } @@ -184,7 +184,7 @@ func (t *TunnelServer) Start(ctx context.Context, mgr ctrl.Manager) error { // 1. Setup QUIC server. var err error - t.dev, err = tun.CreateTUN(t.options.tunName, IPv6MinMTU) + t.dev, err = tun.CreateTUN(t.options.tunName, netstack.IPv6MinMTU) if err != nil { return fmt.Errorf("failed to create TUN interface: %w", err) } diff --git a/pkg/wireguard/kernel_transport_linux.go b/pkg/wireguard/kernel_transport_linux.go index f341a12f..806d3a5a 100644 --- a/pkg/wireguard/kernel_transport_linux.go +++ b/pkg/wireguard/kernel_transport_linux.go @@ -17,6 +17,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "k8s.io/utils/ptr" + "github.com/apoxy-dev/apoxy-cli/pkg/netstack" "github.com/apoxy-dev/apoxy-cli/pkg/utils" ) @@ -67,7 +68,7 @@ func NewKernelModeTransport( link := &netlink.GenericLink{ LinkAttrs: netlink.LinkAttrs{ Name: ifaceName, - MTU: DefaultMTU, + MTU: netstack.IPv6MinMTU, }, LinkType: "wireguard", } diff --git a/pkg/wireguard/userspace_transport.go b/pkg/wireguard/userspace_transport.go index d108aab3..651669aa 100644 --- a/pkg/wireguard/userspace_transport.go +++ b/pkg/wireguard/userspace_transport.go @@ -15,7 +15,6 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "k8s.io/utils/ptr" "github.com/apoxy-dev/apoxy-cli/pkg/netstack" @@ -28,7 +27,7 @@ var _ TunnelTransport = (*UserspaceTransport)(nil) // UserspaceTransport is a user-space network implementation that uses WireGuard. type UserspaceTransport struct { *network.NetstackNetwork - tun *tunDevice + tun *netstack.TunDevice dev *device.Device privateKey wgtypes.Key } @@ -49,7 +48,7 @@ func NewUserspaceTransport(conf *DeviceConfig) (*UserspaceTransport, error) { return nil, fmt.Errorf("failed to parse local addresses: %w", err) } - tun, err := newTunDevice(localAddresses, conf.MTU, conf.PacketCapturePath) + tun, err := netstack.NewTunDevice(localAddresses, conf.MTU, conf.PacketCapturePath) if err != nil { return nil, fmt.Errorf("failed to create netstack device: %w", err) } @@ -96,7 +95,7 @@ func NewUserspaceTransport(conf *DeviceConfig) (*UserspaceTransport, error) { } return &UserspaceTransport{ - NetstackNetwork: network.Netstack(tun.stack, tun.nicID, resolveConf), + NetstackNetwork: tun.Network(resolveConf), tun: tun, dev: dev, privateKey: privateKey, @@ -139,38 +138,12 @@ func (t *UserspaceTransport) ListenPort() (uint16, error) { // LocalAddresses returns the list of local addresses assigned to the WireGuard network. func (t *UserspaceTransport) LocalAddresses() ([]netip.Prefix, error) { - nic := t.tun.stack.NICInfo()[t.tun.nicID] - - var addrs []netip.Prefix - for _, assignedAddr := range nic.ProtocolAddresses { - addrs = append(addrs, netip.PrefixFrom( - addrFromNetstackIP(assignedAddr.AddressWithPrefix.Address), - assignedAddr.AddressWithPrefix.PrefixLen, - )) - } - - return addrs, nil + return t.tun.LocalAddresses() } // FowardToLoopback forwards all inbound traffic to the loopback interface. func (t *UserspaceTransport) FowardToLoopback(ctx context.Context) error { - // Allow outgoing packets to have a source address different from the address - // assigned to the NIC. - if tcpipErr := t.tun.stack.SetSpoofing(t.tun.nicID, true); tcpipErr != nil { - return fmt.Errorf("failed to enable spoofing: %v", tcpipErr) - } - - // Allow incoming packets to have a destination address different from the - // address assigned to the NIC. - if tcpipErr := t.tun.stack.SetPromiscuousMode(t.tun.nicID, true); tcpipErr != nil { - return fmt.Errorf("failed to enable promiscuous mode: %v", tcpipErr) - } - - tcpForwarder := netstack.TCPForwarder(ctx, t.tun.stack, network.Loopback()) - - t.tun.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder) - - return nil + return t.tun.ForwardTo(ctx, network.Loopback()) } // Peers returns the list of public keys for all peers on the WireGuard network. From e3d6eb831db6c71a83cf923106e748fe9f4305ce Mon Sep 17 00:00:00 2001 From: Damian Peckett Date: Fri, 4 Apr 2025 18:16:02 +0200 Subject: [PATCH 3/3] [connect-ip] Add a muxed connection type that uses a triemap under-the-hood to route to different downstream prefixes --- go.mod | 12 +- go.sum | 14 +-- pkg/{connectip => connip}/client.go | 9 +- pkg/connip/muxed_connection.go | 158 +++++++++++++++++++++++++ pkg/connip/muxed_connection_test.go | 139 ++++++++++++++++++++++ pkg/{connectip => connip}/server.go | 2 +- pkg/{connectip => connip}/splice.go | 10 +- pkg/{connectip => connip}/transport.go | 2 +- pkg/tunnel/ipam.go | 2 +- pkg/tunnel/server.go | 157 ++++-------------------- pkg/tunnel/token/token.go | 2 +- pkg/wireguard/userspace_transport.go | 13 -- 12 files changed, 343 insertions(+), 177 deletions(-) rename pkg/{connectip => connip}/client.go (97%) create mode 100644 pkg/connip/muxed_connection.go create mode 100644 pkg/connip/muxed_connection_test.go rename pkg/{connectip => connip}/server.go (97%) rename pkg/{connectip => connip}/splice.go (87%) rename pkg/{connectip => connip}/transport.go (86%) diff --git a/go.mod b/go.mod index cab4f946..b34edf58 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.23.2 require ( github.com/ClickHouse/clickhouse-go/v2 v2.23.2 + github.com/alphadose/haxmap v1.4.1 github.com/buraksezer/olric v0.5.6 github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 github.com/coder/websocket v1.8.12 @@ -15,6 +16,7 @@ require ( github.com/docker/docker v27.2.0+incompatible github.com/dpeckett/contextio v0.5.1 github.com/dpeckett/network v0.3.1 + github.com/dpeckett/triemap v0.3.1 github.com/envoyproxy/gateway v0.5.0-rc.1.0.20240618131507-bdff5d56b59d github.com/envoyproxy/go-control-plane v0.13.0 github.com/envoyproxy/ratelimit v1.4.1-0.20230427142404-e2a87f41d3a7 @@ -24,6 +26,7 @@ require ( github.com/getsentry/sentry-go v0.26.0 github.com/go-logr/logr v1.4.2 github.com/goccy/go-json v0.9.11 + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/golang-migrate/migrate/v4 v4.17.1 github.com/golang/protobuf v1.5.4 github.com/google/go-cmp v0.6.0 @@ -39,6 +42,8 @@ require ( github.com/opencontainers/image-spec v1.1.0 github.com/opencontainers/runc v1.2.2 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 + github.com/quic-go/connect-ip-go v0.0.0-20241112091351-321f13c3d203 + github.com/quic-go/quic-go v0.50.1 github.com/shirou/gopsutil v3.21.11+incompatible github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.1 @@ -50,6 +55,7 @@ require ( github.com/vishvananda/netlink v1.3.0 github.com/vishvananda/netns v0.0.4 github.com/vmihailenco/msgpack v4.0.4+incompatible + github.com/yosida95/uritemplate/v3 v3.0.2 go.opentelemetry.io/proto/otlp v1.2.0 go.temporal.io/api v1.29.2 go.temporal.io/sdk v1.26.0 @@ -105,7 +111,6 @@ require ( github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/Rican7/retry v0.1.0 // indirect github.com/RoaringBitmap/roaring v1.2.1 // indirect - github.com/alphadose/haxmap v1.4.1 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df // indirect github.com/apache/thrift v0.20.0 // indirect @@ -173,7 +178,6 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang-jwt/jwt/v4 v4.5.1 // indirect - github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.7.0-rc.1 // indirect github.com/golang/snappy v0.0.4 // indirect @@ -277,9 +281,7 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.60.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/quic-go/connect-ip-go v0.0.0-20241112091351-321f13c3d203 // indirect github.com/quic-go/qpack v0.5.1 // indirect - github.com/quic-go/quic-go v0.50.1 // indirect github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect github.com/redis/go-redis/v9 v9.6.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect @@ -324,8 +326,6 @@ require ( github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect - github.com/yosida95/uritemplate v2.0.0+incompatible // indirect - github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.etcd.io/bbolt v1.3.8 // indirect diff --git a/go.sum b/go.sum index 2ad11f8a..82741500 100644 --- a/go.sum +++ b/go.sum @@ -893,12 +893,12 @@ github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dpeckett/contextio v0.5.1 h1:w19s6EThbZuRpa2z/Lu06v6+o3rrZhbBzmkol6en/hA= github.com/dpeckett/contextio v0.5.1/go.mod h1:IY/CQ1ee6y4C5j/mU0X0M/D84s2FxNisggbNClTPndc= -github.com/dpeckett/triemap v0.2.1 h1:qn4azAsnYMXBPYIdqtA7m8eYbmBssVN0Bo3ygSfeorE= -github.com/dpeckett/triemap v0.2.1/go.mod h1:pBxNH+K6m5I4lVo+W7u6JEanxP13adD4t2XYVMxfmTo= -github.com/dunglas/httpsfv v1.0.2 h1:iERDp/YAfnojSDJ7PW3dj1AReJz4MrwbECSSE59JWL0= -github.com/dunglas/httpsfv v1.0.2/go.mod h1:zID2mqw9mFsnt7YC3vYQ9/cjq30q41W+1AnDwH8TiMg= github.com/dpeckett/network v0.3.1 h1:rMDRLc85zc3v4mGcGfbOrNA9Kx69K2Xr8bD/Hc9MERY= github.com/dpeckett/network v0.3.1/go.mod h1:83quX+FE+BdOAKFEm5Om+QdI/1ZEQVNUBZSPl7V7erk= +github.com/dpeckett/triemap v0.3.1 h1:jzxCyKs/ATw9uCdD2bd0xFTPLIP9uZwX0iZUOOOIDoc= +github.com/dpeckett/triemap v0.3.1/go.mod h1:pBxNH+K6m5I4lVo+W7u6JEanxP13adD4t2XYVMxfmTo= +github.com/dunglas/httpsfv v1.0.2 h1:iERDp/YAfnojSDJ7PW3dj1AReJz4MrwbECSSE59JWL0= +github.com/dunglas/httpsfv v1.0.2/go.mod h1:zID2mqw9mFsnt7YC3vYQ9/cjq30q41W+1AnDwH8TiMg= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= @@ -1616,8 +1616,6 @@ github.com/quic-go/connect-ip-go v0.0.0-20241112091351-321f13c3d203 h1:/SLaObCHs github.com/quic-go/connect-ip-go v0.0.0-20241112091351-321f13c3d203/go.mod h1:eck9h1BsbP2ri3dIiBinxTfR6vMjsOqt3XgIsz6aKmo= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA= -github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/quic-go/quic-go v0.50.1 h1:unsgjFIUqW8a2oopkY7YNONpV1gYND6Nt9hnt1PN94Q= github.com/quic-go/quic-go v0.50.1/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E= github.com/rcrowley/go-metrics v0.0.0-20141108142129-dee209f2455f/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= @@ -1816,8 +1814,6 @@ github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtX github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= -github.com/yosida95/uritemplate v2.0.0+incompatible h1:j6LR/+4tiD14zc0Z0M8QilHLULqgZFD47XqgXQgCE1A= -github.com/yosida95/uritemplate v2.0.0+incompatible/go.mod h1:mksJanHNnLsh6wYgt/AbBRZ4ogsHsO2uiZlm/UURY5c= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= @@ -1918,8 +1914,6 @@ go.uber.org/fx v1.21.1/go.mod h1:HT2M7d7RHo+ebKGh9NRcrsrHHfpZ60nW3QRubMRfv48= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= diff --git a/pkg/connectip/client.go b/pkg/connip/client.go similarity index 97% rename from pkg/connectip/client.go rename to pkg/connip/client.go index 93441372..e063e460 100644 --- a/pkg/connectip/client.go +++ b/pkg/connip/client.go @@ -1,4 +1,4 @@ -package connectip +package connip import ( "context" @@ -14,12 +14,13 @@ import ( "sync" "time" - "github.com/apoxy-dev/apoxy-cli/pkg/netstack" "github.com/dpeckett/network" connectip "github.com/quic-go/connect-ip-go" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "github.com/yosida95/uritemplate/v3" + + "github.com/apoxy-dev/apoxy-cli/pkg/netstack" ) var _ TunnelTransport = (*ClientTransport)(nil) @@ -149,8 +150,8 @@ func (t *ClientTransport) Connect(ctx context.Context, serverAddr string) error t.NetstackNetwork = t.tun.Network(resolveConf) - // TODO: how to bubble up errors from this? - go spliceConnToTunDevice(t.conn, t.tun) + // TODO (dpeckett): how to bubble up errors from this? + go splice(t.tun, t.conn) return nil } diff --git a/pkg/connip/muxed_connection.go b/pkg/connip/muxed_connection.go new file mode 100644 index 00000000..2de913c5 --- /dev/null +++ b/pkg/connip/muxed_connection.go @@ -0,0 +1,158 @@ +package connip + +import ( + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/netip" + "sync" + + "github.com/dpeckett/triemap" + + "github.com/apoxy-dev/apoxy-cli/pkg/netstack" +) + +// Connection is a simple interface implemented by connect-ip-go and custom +// connection types. +type Connection interface { + io.Closer + ReadPacket([]byte) (int, error) + WritePacket([]byte) ([]byte, error) +} + +var _ Connection = (*MuxedConnection)(nil) + +// MuxedConnection is a connection that multiplexes multiple downstream +// connections over a single virtual connection. +type MuxedConnection struct { + // Maps tunnel destination address to CONNECT-IP connection. + conns *triemap.TrieMap[Connection] + incomingPackets chan *[]byte + packetBufferPool sync.Pool +} + +func NewMuxedConnection() *MuxedConnection { + return &MuxedConnection{ + conns: triemap.New[Connection](), + incomingPackets: make(chan *[]byte, 100), + packetBufferPool: sync.Pool{ + New: func() interface{} { + b := make([]byte, netstack.IPv6MinMTU) + return &b + }, + }, + } +} + +func (m *MuxedConnection) AddConnection(prefix netip.Prefix, conn Connection) { + if prefix.IsValid() && prefix.Addr().Is6() { + m.conns.Insert(prefix, conn) + go m.readPackets(conn) + } else { + slog.Warn("Invalid prefix for connection", slog.String("prefix", prefix.String())) + } +} + +func (m *MuxedConnection) RemoveConnection(prefix netip.Prefix) error { + if prefix.IsValid() && prefix.Addr().Is6() { + conn, ok := m.conns.Get(prefix.Addr()) + if !ok { + return fmt.Errorf("no connection found for prefix: %s", prefix.String()) + } + + // Close the connection and remove it from the map. + if err := conn.Close(); err != nil { + return fmt.Errorf("failed to close connection: %w", err) + } + + // Remove the connection from the map. + m.conns.Remove(prefix) + } else { + return fmt.Errorf("invalid prefix for connection: %s", prefix.String()) + } + return nil +} + +func (m *MuxedConnection) Close() error { + // Close all connections in the map. + m.conns.ForEach(func(prefix netip.Prefix, conn Connection) bool { + if err := conn.Close(); err != nil { + slog.Warn("Failed to close connection", + slog.String("prefix", prefix.String()), slog.Any("error", err)) + } + return true + }) + + // Clear the map. + m.conns.Clear() + + // Close the incoming packets channel. + close(m.incomingPackets) + + return nil +} + +func (m *MuxedConnection) ReadPacket(pkt []byte) (int, error) { + p, ok := <-m.incomingPackets + if !ok { + return 0, net.ErrClosed + } + + n := copy(pkt, *p) + + m.packetBufferPool.Put(p) + + return n, nil +} + +func (m *MuxedConnection) WritePacket(pkt []byte) ([]byte, error) { + var dstIP netip.Addr + switch pkt[0] >> 4 { + case 6: + // IPv6 packet (RFC 8200) + if len(pkt) >= 40 { + var addr [16]byte + copy(addr[:], pkt[24:40]) + dstIP = netip.AddrFrom16(addr) + } else { + return nil, fmt.Errorf("IPv6 packet too short: %d", len(pkt)) + } + default: + return nil, fmt.Errorf("unknown packet type: %d", pkt[0]>>4) + } + + if !dstIP.IsValid() || !dstIP.Is6() || !dstIP.IsGlobalUnicast() { + return nil, fmt.Errorf("invalid destination IP: %s", dstIP.String()) + } + + slog.Debug("Packet destination", slog.String("ip", dstIP.String())) + + conn, ok := m.conns.Get(dstIP) + if !ok { + return nil, fmt.Errorf("no matching tunnel found for destination IP: %s", dstIP.String()) + } + + return conn.WritePacket(pkt) +} + +func (m *MuxedConnection) readPackets(conn Connection) { + for { + pkt := m.packetBufferPool.Get().(*[]byte) + + n, err := conn.ReadPacket(*pkt) + if err != nil { + if !errors.Is(err, net.ErrClosed) { + slog.Error("Failed to read from connection", slog.Any("error", err)) + } + + break + } + + slog.Debug("Read packet from connection", slog.Int("bytes", n)) + + *pkt = (*pkt)[:n] + m.incomingPackets <- pkt + } +} diff --git a/pkg/connip/muxed_connection_test.go b/pkg/connip/muxed_connection_test.go new file mode 100644 index 00000000..8240d9f9 --- /dev/null +++ b/pkg/connip/muxed_connection_test.go @@ -0,0 +1,139 @@ +package connip_test + +import ( + "net" + "net/netip" + "testing" + "time" + + "github.com/apoxy-dev/apoxy-cli/pkg/connip" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockConnection struct { + mock.Mock + closed bool +} + +func (m *MockConnection) ReadPacket(p []byte) (int, error) { + if m.closed { + return 0, net.ErrClosed + } + + args := m.Called(p) + n := args.Int(0) + copy(p, args.Get(1).([]byte)) + return n, args.Error(2) +} + +func (m *MockConnection) WritePacket(p []byte) ([]byte, error) { + args := m.Called(p) + return args.Get(0).([]byte), args.Error(1) +} + +func (m *MockConnection) Close() error { + m.closed = true + return m.Called().Error(0) +} + +func TestMuxedConnection(t *testing.T) { + t.Run("Add and Remove Connection", func(t *testing.T) { + mux := connip.NewMuxedConnection() + mockConn := new(MockConnection) + mockConn.On("ReadPacket", mock.Anything).Return(0, []byte{}, nil).Maybe() + mockConn.On("Close").Return(nil).Once() + + prefix := netip.MustParsePrefix("2001:db8::/96") + mux.AddConnection(prefix, mockConn) + err := mux.RemoveConnection(prefix) + assert.NoError(t, err) + + // Try removing again should fail + err = mux.RemoveConnection(prefix) + assert.Error(t, err) + }) + + t.Run("Remove Connection - Invalid Prefix", func(t *testing.T) { + mux := connip.NewMuxedConnection() + prefix := netip.MustParsePrefix("192.0.2.0/24") + err := mux.RemoveConnection(prefix) + assert.Error(t, err) + }) + + t.Run("WritePacket - Success", func(t *testing.T) { + mux := connip.NewMuxedConnection() + mockConn := new(MockConnection) + mockConn.On("ReadPacket", mock.Anything).Return(0, []byte{}, nil).Maybe() + + prefix := netip.MustParsePrefix("2001:db8::/96") + mux.AddConnection(prefix, mockConn) + + pkt := make([]byte, 40) + pkt[0] = 0x60 // IPv6 + copy(pkt[24:40], netip.MustParseAddr("2001:db8::1").AsSlice()) + + mockConn.On("WritePacket", pkt).Return([]byte("ok"), nil).Once() + + resp, err := mux.WritePacket(pkt) + assert.NoError(t, err) + assert.Equal(t, []byte("ok"), resp) + mockConn.AssertExpectations(t) + }) + + t.Run("WritePacket - No Connection Found", func(t *testing.T) { + mux := connip.NewMuxedConnection() + + pkt := make([]byte, 40) + pkt[0] = 0x60 + copy(pkt[24:40], netip.MustParseAddr("2001:db8::1").AsSlice()) + + resp, err := mux.WritePacket(pkt) + assert.Nil(t, resp) + assert.ErrorContains(t, err, "no matching tunnel") + }) + + t.Run("ReadPacket - Success", func(t *testing.T) { + mux := connip.NewMuxedConnection() + mockConn := new(MockConnection) + + expected := []byte("hello") + mockConn.On("ReadPacket", mock.Anything).Return(len(expected), expected, nil) + + prefix := netip.MustParsePrefix("2001:db8::/96") + mux.AddConnection(prefix, mockConn) + + time.Sleep(10 * time.Millisecond) // let goroutine read once + + buf := make([]byte, 1500) + n, err := mux.ReadPacket(buf) + assert.NoError(t, err) + assert.Equal(t, len(expected), n) + assert.Equal(t, expected, buf[:n]) + mockConn.AssertExpectations(t) + }) + + t.Run("ReadPacket - Closed Channel", func(t *testing.T) { + mux := connip.NewMuxedConnection() + _ = mux.Close() + + buf := make([]byte, 1500) + _, err := mux.ReadPacket(buf) + + assert.ErrorIs(t, err, net.ErrClosed) + }) + + t.Run("Close - All Connections", func(t *testing.T) { + mux := connip.NewMuxedConnection() + mockConn := new(MockConnection) + mockConn.On("ReadPacket", mock.Anything).Return(0, []byte{}, nil).Maybe() + mockConn.On("Close").Return(nil).Once() + + prefix := netip.MustParsePrefix("2001:db8::/96") + mux.AddConnection(prefix, mockConn) + + err := mux.Close() + assert.NoError(t, err) + mockConn.AssertExpectations(t) + }) +} diff --git a/pkg/connectip/server.go b/pkg/connip/server.go similarity index 97% rename from pkg/connectip/server.go rename to pkg/connip/server.go index 0c322cab..60775ff0 100644 --- a/pkg/connectip/server.go +++ b/pkg/connip/server.go @@ -1,4 +1,4 @@ -package connectip +package connip import ( "context" diff --git a/pkg/connectip/splice.go b/pkg/connip/splice.go similarity index 87% rename from pkg/connectip/splice.go rename to pkg/connip/splice.go index 3a704633..3c629c49 100644 --- a/pkg/connectip/splice.go +++ b/pkg/connip/splice.go @@ -1,4 +1,4 @@ -package connectip +package connip import ( "errors" @@ -6,13 +6,13 @@ import ( "log/slog" "net" - "github.com/apoxy-dev/apoxy-cli/pkg/netstack" - connectip "github.com/quic-go/connect-ip-go" "golang.org/x/sync/errgroup" "golang.zx2c4.com/wireguard/tun" + + "github.com/apoxy-dev/apoxy-cli/pkg/netstack" ) -func spliceConnToTunDevice(conn *connectip.Conn, tun tun.Device) error { +func splice(tun tun.Device, conn Connection) error { var g errgroup.Group g.Go(func() error { @@ -41,7 +41,7 @@ func spliceConnToTunDevice(conn *connectip.Conn, tun tun.Device) error { if len(icmp) > 0 { slog.Debug("Sending ICMP packet") - if _, err := t.tun.Write([][]byte{icmp}, 0); err != nil { + if _, err := tun.Write([][]byte{icmp}, 0); err != nil { slog.Error("Failed to write ICMP packet", slog.Any("error", err)) } } diff --git a/pkg/connectip/transport.go b/pkg/connip/transport.go similarity index 86% rename from pkg/connectip/transport.go rename to pkg/connip/transport.go index d74c39b1..0b82ee9d 100644 --- a/pkg/connectip/transport.go +++ b/pkg/connip/transport.go @@ -1,4 +1,4 @@ -package connectip +package connip import ( "io" diff --git a/pkg/tunnel/ipam.go b/pkg/tunnel/ipam.go index f43914e8..c29f2f08 100644 --- a/pkg/tunnel/ipam.go +++ b/pkg/tunnel/ipam.go @@ -26,7 +26,7 @@ func (r *randomULA) Allocate(_ *http.Request) netip.Prefix { addr := apoxyULAPrefix.Addr().As16() // Generate 6 random bytes (48 bits) - this will fill the bits between /48 and /96 var randomBytes [6]byte - rand.Read(randomBytes[:]) + _, _ = rand.Read(randomBytes[:]) // Insert the random bytes into positions 6-11 (after the /48 prefix, before the /96 suffix) for i := 0; i < 6; i++ { diff --git a/pkg/tunnel/server.go b/pkg/tunnel/server.go index 538aec0e..4a6d9b46 100644 --- a/pkg/tunnel/server.go +++ b/pkg/tunnel/server.go @@ -3,18 +3,14 @@ package tunnel import ( - "bytes" "context" "crypto/tls" - goerrors "errors" "fmt" - "io" "log/slog" "net" "net/http" "net/netip" "strings" - "sync" "github.com/alphadose/haxmap" "github.com/google/uuid" @@ -23,8 +19,6 @@ import ( "github.com/quic-go/quic-go/http3" "github.com/vishvananda/netlink" "github.com/yosida95/uritemplate/v3" - "golang.org/x/sync/errgroup" - "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -36,31 +30,15 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "github.com/apoxy-dev/apoxy-cli/pkg/connip" "github.com/apoxy-dev/apoxy-cli/pkg/netstack" "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/token" corev1alpha "github.com/apoxy-dev/apoxy-cli/api/core/v1alpha" ) -const ( - tunOffset = device.MessageTransportHeaderSize -) - var ( connectTmpl = uritemplate.MustNew("https://proxy/connect") - - bufferPool = sync.Pool{ - New: func() interface{} { - b := make([]byte, netstack.IPv6MinMTU+tunOffset) - return &b - }, - } - - bytesBufferPool = sync.Pool{ - New: func() interface{} { - return new(bytes.Buffer) - }, - } ) type TunnelOption func(*tunnelOptions) @@ -144,8 +122,8 @@ type TunnelServer struct { dev tun.Device ln *quic.EarlyListener - // Maps tunnel destination address to CONNECT-IP connection. - tuns *haxmap.Map[string, *connectip.Conn] + // Connections + mux *connip.MuxedConnection // Maps tunnelNodes *haxmap.Map[string, *corev1alpha.TunnelNode] } @@ -163,7 +141,7 @@ func NewTunnelServer(opts ...TunnelOption) *TunnelServer { EnableDatagrams: true, }, options: options, - tuns: haxmap.New[string, *connectip.Conn](), + mux: connip.NewMuxedConnection(), tunnelNodes: haxmap.New[string, *corev1alpha.TunnelNode](), } @@ -218,28 +196,24 @@ func (t *TunnelServer) Start(ctx context.Context, mgr ctrl.Manager) error { return fmt.Errorf("failed to create QUIC listener: %w", err) } - g, ctx := errgroup.WithContext(context.Background()) - g.Go(func() error { - g.Go(func() error { - <-ctx.Done() - return t.Shutdown(context.Background()) - }) + ctx, cancel := context.WithCancel(ctx) + defer cancel() - slog.Info("Starting HTTP/3 server", slog.String("addr", t.ln.Addr().String())) + go func() { + <-ctx.Done() - return t.ServeListener(t.ln) - }) - g.Go(func() error { - g.Go(func() error { - <-ctx.Done() - return t.dev.Close() - }) + if err := t.dev.Close(); err != nil { + slog.Error("Failed to close TUN device", slog.Any("error", err)) + } - slog.Info("Starting TUN muxer") + if err := t.Shutdown(context.Background()); err != nil { + slog.Error("Failed to shutdown QUIC server", slog.Any("error", err)) + } + }() + + slog.Info("Starting HTTP/3 server", slog.String("addr", t.ln.Addr().String())) - return t.mux(ctx) - }) - return g.Wait() + return t.ServeListener(t.ln) } func (t *TunnelServer) Stop(ctx context.Context) error { @@ -332,32 +306,7 @@ func (t *TunnelServer) handleConnect(w http.ResponseWriter, r *http.Request) { return } - t.tuns.Set(peerPrefix.String(), conn) - - go func() { - b := bufferPool.Get().(*[]byte) - defer bufferPool.Put(b) - - // TODO (dpeckett): add support for writing batched packets. - for { - n, err := conn.ReadPacket((*b)[tunOffset:]) - if err != nil { - if goerrors.Is(err, net.ErrClosed) { - slog.Info("Connection closed") - return - } - slog.Error("Failed to read from connection", slog.Any("error", err)) - continue - } - - slog.Debug("Read from connection", slog.Int("bytes", n)) - - if _, err := t.dev.Write([][]byte{(*b)[:n+tunOffset]}, tunOffset); err != nil { - slog.Error("Failed to write to TUN", slog.Any("error", err)) - continue - } - } - }() + t.mux.AddConnection(peerPrefix, conn) agent := corev1alpha.AgentStatus{ Name: uuid.NewString(), @@ -389,7 +338,9 @@ func (t *TunnelServer) handleConnect(w http.ResponseWriter, r *http.Request) { slog.Error("Failed to close connection", slog.Any("error", err)) } - t.tuns.Del(peerPrefix.String()) + if err := t.mux.RemoveConnection(peerPrefix); err != nil { + slog.Error("Failed to remove connection", slog.Any("error", err)) + } if err := t.options.ipam.Release(peerPrefix); err != nil { slog.Error("Failed to deallocate IP address", slog.Any("error", err)) @@ -479,70 +430,6 @@ func (t *TunnelServer) removeTUNPeer(peer netip.Prefix) error { return nil } -func (t *TunnelServer) mux(ctx context.Context) error { - for { - b := bufferPool.Get().(*[]byte) - sizes := make([]int, 1) - _, err := t.dev.Read([][]byte{*b}, sizes, 0) - if goerrors.Is(err, io.EOF) { - bufferPool.Put(b) - return nil - } else if err != nil { - bufferPool.Put(b) - return fmt.Errorf("failed to read from TUN: %w", err) - } - slog.Debug("Read packet from TUN", slog.Int("len", sizes[0])) - - var dstIP netip.Addr - switch (*b)[0] >> 4 { - case 6: - // IPv6 packet (RFC 8200) - if sizes[0] >= 40 { - var addr [16]byte - copy(addr[:], (*b)[24:40]) - dstIP = netip.AddrFrom16(addr) - } else { - slog.Debug("IPv6 packet too short", slog.Int("length", len(*b))) - bufferPool.Put(b) - continue - } - default: - slog.Warn("Unknown packet type (expected IPv6)", slog.String("type", fmt.Sprintf("%#x", (*b)[0]>>4))) - bufferPool.Put(b) - continue - } - if !dstIP.IsValid() || !dstIP.Is6() || !dstIP.IsGlobalUnicast() { - slog.Debug("Invalid destination IP", slog.String("ip", dstIP.String())) - bufferPool.Put(b) - continue - } - - slog.Debug("Packet destination", slog.String("ip", dstIP.String())) - - dstPrefix := netip.PrefixFrom(dstIP, 96) - conn, ok := t.tuns.Get(dstPrefix.String()) - if !ok { - slog.Debug("No matching tunnel found", slog.String("ip", dstPrefix.String())) - bufferPool.Put(b) - continue - } - - icmp, err := conn.WritePacket((*b)[:sizes[0]]) - bufferPool.Put(b) - if err != nil { - slog.Error("Failed to write to connection", slog.Any("error", err)) - continue - } - if len(icmp) > 0 { - slog.Debug("Sending ICMP packet") - if _, err := t.dev.Write([][]byte{icmp}, 0); err != nil { - slog.Error("Failed to write ICMP packet", slog.Any("error", err)) - } - } - } - panic("unreachable") -} - func (t *TunnelServer) reconcile(ctx context.Context, request reconcile.Request) (reconcile.Result, error) { node := &corev1alpha.TunnelNode{} if err := t.Get(ctx, request.NamespacedName, node); errors.IsNotFound(err) { diff --git a/pkg/tunnel/token/token.go b/pkg/tunnel/token/token.go index 5d61a068..79c2972e 100644 --- a/pkg/tunnel/token/token.go +++ b/pkg/tunnel/token/token.go @@ -62,7 +62,7 @@ func (v *Validator) Validate(tokenStr, subject string) error { return errors.New("subject claim not found or invalid") } - if strings.ToLower(tokenSubject) != strings.ToLower(subject) { + if strings.EqualFold(tokenSubject, subject) { return fmt.Errorf("token subject %q does not match expected subject %q", tokenSubject, subject) } diff --git a/pkg/wireguard/userspace_transport.go b/pkg/wireguard/userspace_transport.go index 651669aa..23ca7105 100644 --- a/pkg/wireguard/userspace_transport.go +++ b/pkg/wireguard/userspace_transport.go @@ -14,7 +14,6 @@ import ( "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "gvisor.dev/gvisor/pkg/tcpip" "k8s.io/utils/ptr" "github.com/apoxy-dev/apoxy-cli/pkg/netstack" @@ -272,15 +271,3 @@ func parseAddressList(addrs []string) ([]netip.Prefix, error) { return parsed, nil } - -func addrFromNetstackIP(ip tcpip.Address) netip.Addr { - switch ip.Len() { - case 4: - ip := ip.As4() - return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]}) - case 16: - ip := ip.As16() - return netip.AddrFrom16(ip).Unmap() - } - return netip.Addr{} -}