diff --git a/go.mod b/go.mod index 48202a8..90e0cbf 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( github.com/getsavvyinc/upgrade-cli v0.3.0 github.com/getsentry/sentry-go v0.26.0 github.com/go-logr/logr v1.4.3 + github.com/go-logr/stdr v1.2.2 github.com/goccy/go-json v0.9.11 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-jwt/jwt/v5 v5.3.0 @@ -60,6 +61,7 @@ require ( github.com/quic-go/quic-go v0.50.1 github.com/shirou/gopsutil v3.21.11+incompatible github.com/sirupsen/logrus v1.9.3 + github.com/slavc/xdp v0.3.4 github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 github.com/telepresenceio/watchable v0.0.0-20220726211108-9bb86f92afa7 @@ -180,7 +182,6 @@ require ( github.com/ghodss/yaml v1.0.0 // indirect github.com/go-faster/city v1.0.1 // indirect github.com/go-faster/errors v0.7.1 // indirect - github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.21.1 // indirect @@ -305,6 +306,7 @@ require ( github.com/robfig/cron v1.2.0 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/safchain/ethtool v0.6.1 // indirect github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 // indirect github.com/seccomp/libseccomp-golang v0.10.0 // indirect github.com/segmentio/asm v1.2.0 // indirect diff --git a/go.sum b/go.sum index fc73222..c360280 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,6 @@ github.com/apoxy-dev/apiserver-runtime v0.0.0-20250420214109-979c605051d1 h1:sAS github.com/apoxy-dev/apiserver-runtime v0.0.0-20250420214109-979c605051d1/go.mod h1:zOVeivsnCWenmbgr6kiefIExoqlbuv2xyg9SXXfbs5U= github.com/apoxy-dev/connect-ip-go v0.0.0-20250530062404-603929a73f45 h1:SwPk1n/oSVX7YwlNpC9KNH9YaYkcL/k6OfqSGVnxyyI= github.com/apoxy-dev/connect-ip-go v0.0.0-20250530062404-603929a73f45/go.mod h1:z5rtgIizc+/K27UtB0occwZgqg/mz3IqgyUJW8aubbI= -github.com/apoxy-dev/icx v0.7.1 h1:1uEvkyc2+IYHTvn8FMv/sbCOr8poJiQNmWLudGBuguY= -github.com/apoxy-dev/icx v0.7.1/go.mod h1:Muuk3bRXTp3YB5Xj+xHOGQ/T1xVxIKJuvmMfLBXhIN4= github.com/apoxy-dev/icx v0.7.2 h1:6GqlqxkjwyEwaQBAJJ40+iM6D6w46IKmKWtE/43bCUk= github.com/apoxy-dev/icx v0.7.2/go.mod h1:Muuk3bRXTp3YB5Xj+xHOGQ/T1xVxIKJuvmMfLBXhIN4= github.com/apoxy-dev/quic-go v0.0.0-20250530165952-53cca597715e h1:10GIpiVyKoRgCyr0J2TvJtdn17bsFHN+ROWkeVJpcOU= @@ -200,6 +198,7 @@ github.com/checkpoint-restore/go-criu/v6 v6.3.0/go.mod h1:rrRTN/uSwY2X+BPRl/gkul github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/cilium/ebpf v0.4.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs= github.com/cilium/ebpf v0.18.0 h1:OsSwqS4y+gQHxaKgg2U/+Fev834kdnsQbtzRnbVC6Gs= github.com/cilium/ebpf v0.18.0/go.mod h1:vmsAT73y4lW2b4peE+qcOqw6MxvWQdC+LiU5gd/xyo4= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= @@ -368,6 +367,7 @@ github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8 github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= +github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= @@ -732,6 +732,7 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -800,6 +801,7 @@ github.com/metal-stack/go-ipam v1.14.12/go.mod h1:B6R3ADxm1r5C1DJafhI90oB3+DRRby github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= +github.com/miekg/dns v1.1.35/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= github.com/miekg/dns v1.1.45/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY= github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs= @@ -1004,6 +1006,8 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/safchain/ethtool v0.6.1 h1:mhRnXE1H8fV8TTXh/HdqE4tXtb57r//BQh5pPYMuM5k= +github.com/safchain/ethtool v0.6.1/go.mod h1:JzoNbG8xeg/BeVeVoMCtCb3UPWoppZZbFpA+1WFh+M0= github.com/samuel/go-thrift v0.0.0-20190219015601-e8b6b52668fe/go.mod h1:Vrkh1pnjV9Bl8c3P9zH0/D4NlOHWP5d4/hF4YTULaec= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= @@ -1050,6 +1054,8 @@ github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slavc/xdp v0.3.4 h1:UFvt36LBz0pIqeRairAo4IP/sDWQ7mgT8LJRuF3MS8M= +github.com/slavc/xdp v0.3.4/go.mod h1:+pr19iCFDgI8wBXh5MUQftx5UuLlKaF7UvDAFWZq0zw= github.com/softlayer/softlayer-go v0.0.0-20180806151055-260589d94c7d h1:bVQRCxQvfjNUeRqaY/uT0tFuvuFY0ulgnczuR684Xic= github.com/softlayer/softlayer-go v0.0.0-20180806151055-260589d94c7d/go.mod h1:Cw4GTlQccdRGSEf6KiMju767x0NEHE0YIVPJSaXjlsw= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= @@ -1168,8 +1174,10 @@ github.com/vbatts/tar-split v0.11.3/go.mod h1:9QlHN18E+fEH7RdG+QAJJcuya3rqT7eXST github.com/vektah/gqlparser v1.1.2/go.mod h1:1ycwN7Ij5njmMkPPAOaRFY4rET2Enx7IkVv3vaXspKw= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= +github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI= @@ -1439,6 +1447,7 @@ golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190321052220-f7bb7a8bee54/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1541,6 +1550,7 @@ golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= diff --git a/pkg/cmd/alpha/alpha.go b/pkg/cmd/alpha/alpha.go index b80719c..d0ae12d 100644 --- a/pkg/cmd/alpha/alpha.go +++ b/pkg/cmd/alpha/alpha.go @@ -19,4 +19,5 @@ func Cmd() *cobra.Command { func init() { alphaCmd.AddCommand(rateLimitCmd) + alphaCmd.AddCommand(tunnelCmd) } diff --git a/pkg/cmd/alpha/tunnel.go b/pkg/cmd/alpha/tunnel.go new file mode 100644 index 0000000..01b4f95 --- /dev/null +++ b/pkg/cmd/alpha/tunnel.go @@ -0,0 +1,215 @@ +package alpha + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "net/netip" + "net/url" + "time" + + "github.com/dpeckett/network" + "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" + + "github.com/apoxy-dev/apoxy/pkg/netstack" + "github.com/apoxy-dev/apoxy/pkg/tunnel/api" + "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate" + "github.com/apoxy-dev/apoxy/pkg/tunnel/router" +) + +var ( + agentName string + tunnelName string + relayAddr string + token string + insecureSkipVerify bool + socksListenAddr string + pcapPath string +) + +var tunnelCmd = &cobra.Command{ + Use: "tunnel", + Short: "Manage tunnels", + Long: "Manage icx tunnels and connect to the remote Apoxy Edge fabric.", +} + +var tunnelRunCmd = &cobra.Command{ + Use: "run", + Short: "Run a tunnel", + Long: "Create a secure tunnel to the remote Apoxy Edge fabric.", + RunE: func(cmd *cobra.Command, args []string) error { + pc, err := net.ListenPacket("udp", ":0") + if err != nil { + return fmt.Errorf("failed to create UDP socket: %w", err) + } + defer pc.Close() + + pcGeneve, pcQuic := bifurcate.Bifurcate(pc) + defer pcGeneve.Close() + defer pcQuic.Close() + + apiURL := url.URL{ + Scheme: "https", + Host: relayAddr, + } + if insecureSkipVerify { + apiURL.Scheme = "http" + } + + tlsConf := &tls.Config{ + InsecureSkipVerify: insecureSkipVerify, + } + + client, err := api.NewClient(api.ClientOptions{ + BaseURL: apiURL.String(), + Agent: agentName, + TunnelName: tunnelName, + Token: token, + TLSConfig: tlsConf, + }) + if err != nil { + return fmt.Errorf("failed to create tunnel API client: %w", err) + } + defer client.Close() + + connectResp, err := client.Connect(cmd.Context()) + if err != nil { + return fmt.Errorf("failed to connect to tunnel relay: %w", err) + } + + // Ensure we disconnect when the command exits + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + + err := client.Disconnect(ctx, connectResp.ID) + cancel() + if err != nil { + slog.Error("Failed to disconnect from tunnel", slog.Any("error", err)) + } + }() + + slog.Info("Connected to tunnel relay", slog.String("id", connectResp.ID), + slog.Int("vni", int(connectResp.VNI)), slog.Int("mtu", connectResp.MTU)) + + var routerOpts []router.Option + + if connectResp.DNS != nil { + resolveConf := &network.ResolveConfig{ + Nameservers: connectResp.DNS.Servers, + SearchDomains: connectResp.DNS.SearchDomains, + NDots: connectResp.DNS.NDots, + } + routerOpts = append(routerOpts, router.WithResolveConfig(resolveConf)) + } + + if socksListenAddr != "" { + routerOpts = append(routerOpts, router.WithSocksListenAddr(socksListenAddr)) + } + + if pcapPath != "" { + routerOpts = append(routerOpts, router.WithPcapPath(pcapPath)) + } + + r, err := router.NewICXNetstackRouter(pcGeneve, connectResp.MTU, routerOpts...) + if err != nil { + return fmt.Errorf("failed to create ICX netstack router: %w", err) + } + defer r.Close() + + remoteAddr, err := netip.ParseAddrPort(relayAddr) + if err != nil { + return fmt.Errorf("failed to parse relay address: %w", err) + } + + overlayAddrs, err := stringsToPrefixes(connectResp.Addresses) + if err != nil { + return fmt.Errorf("failed to parse assigned addresses: %w", err) + } + + if err := r.Handler.AddVirtualNetwork(connectResp.VNI, netstack.ToFullAddress(remoteAddr), overlayAddrs); err != nil { + return fmt.Errorf("failed to add virtual network to ICX handler: %w", err) + } + + g, ctx := errgroup.WithContext(cmd.Context()) + + g.Go(func() error { + // Rotate keys at half-life; retry with a short backoff on failure. + apply := func(k api.Keys) time.Duration { + // Apply new keys to the ICX handler. + if err := r.Handler.UpdateVirtualNetworkKeys(connectResp.VNI, k.Epoch, k.Recv, k.Send, k.ExpiresAt); err != nil { + slog.Error("Failed to apply new keys to router", slog.Any("error", err)) + } + + // Compute next refresh: half of remaining lifetime. + remaining := time.Until(k.ExpiresAt) + next := remaining / 2 + // Clamp to a sensible minimum to avoid tight loops. + if next < 10*time.Second { + next = 10 * time.Second + } + return next + } + + // Seed initial schedule from the keys we got on Connect. + next := apply(connectResp.Keys) + + timer := time.NewTimer(next) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + // Try to rotate keys. + upd, err := client.UpdateKeys(ctx, connectResp.ID) + if err != nil { + slog.Warn("Key update failed; retrying soon", slog.Any("error", err)) + timer.Reset(5 * time.Second) + continue + } + + slog.Info("Rotated tunnel keys", slog.Uint64("epoch", uint64(upd.Keys.Epoch))) + timer.Reset(apply(upd.Keys)) + } + } + }) + + g.Go(func() error { + return r.Start(ctx) + }) + + return g.Wait() + }, +} + +func init() { + tunnelRunCmd.Flags().StringVarP(&agentName, "agent", "a", "", "The name of this agent.") + tunnelRunCmd.Flags().StringVarP(&tunnelName, "name", "n", "", "The name of the tunnel to connect to.") + tunnelRunCmd.Flags().StringVarP(&relayAddr, "relay-addr", "r", "", "The address of the tunnel relay to connect to.") + tunnelRunCmd.Flags().StringVarP(&token, "token", "k", "", "The token to use for authenticating with the tunnel relay.") + tunnelRunCmd.Flags().BoolVar(&insecureSkipVerify, "insecure-skip-verify", false, "Skip TLS certificate verification.") + tunnelRunCmd.Flags().StringVarP(&pcapPath, "pcap", "p", "", "Path to an optional packet capture file to write.") + tunnelRunCmd.Flags().StringVar(&socksListenAddr, "socks-addr", "localhost:1080", "Listen address for SOCKS proxy.") + cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("agent")) + cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("name")) + cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("relay-addr")) + cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("token")) + + tunnelCmd.AddCommand(tunnelRunCmd) +} + +func stringsToPrefixes(addrs []string) ([]netip.Prefix, error) { + prefixes := make([]netip.Prefix, 0, len(addrs)) + for _, addr := range addrs { + p, err := netip.ParsePrefix(addr) + if err != nil { + return nil, fmt.Errorf("failed to parse address %q: %w", addr, err) + } + prefixes = append(prefixes, p) + } + return prefixes, nil +} diff --git a/pkg/netstack/icx_network.go b/pkg/netstack/icx_network.go index edfb1d5..081b5df 100644 --- a/pkg/netstack/icx_network.go +++ b/pkg/netstack/icx_network.go @@ -47,7 +47,7 @@ type ICXNetwork struct { // NewICXNetwork creates a new ICXNetwork instance with the given handler, physical connection, MTU, and resolve configuration. // If pcapPath is provided, it will create a packet sniffer that writes to the specified file. // The handler must be configured in layer3 mode. -func NewICXNetwork(handler *icx.Handler, phy *l2pc.L2PacketConn, pathMTU int, resolveConf *network.ResolveConfig, pcapPath string) (*ICXNetwork, error) { +func NewICXNetwork(handler *icx.Handler, phy *l2pc.L2PacketConn, mtu int, resolveConf *network.ResolveConfig, pcapPath string) (*ICXNetwork, error) { ipt := newIPTables() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -79,7 +79,7 @@ func NewICXNetwork(handler *icx.Handler, phy *l2pc.L2PacketConn, pathMTU int, re } nicID := ipstack.NextNICID() - linkEP := channel.New(4096, uint32(icx.MTU(pathMTU)), "") + linkEP := channel.New(4096, uint32(mtu), "") var nicEP stack.LinkEndpoint = linkEP var pcapFile *os.File @@ -116,7 +116,7 @@ func NewICXNetwork(handler *icx.Handler, phy *l2pc.L2PacketConn, pathMTU int, re incomingPacket: make(chan *buffer.View), pktPool: sync.Pool{ New: func() any { - b := make([]byte, 0, pathMTU) + b := make([]byte, 0, 65535) return &b }, }, @@ -291,6 +291,21 @@ func (net *ICXNetwork) DelAddr(addr netip.Prefix) error { return nil } +// LocalAddresses returns the list of local addresses assigned to the network. +func (net *ICXNetwork) LocalAddresses() ([]netip.Prefix, error) { + nic := net.stack.NICInfo()[net.nicID] + + var addrs []netip.Prefix + for _, assignedAddr := range nic.ProtocolAddresses { + addrs = append(addrs, netip.PrefixFrom( + addrFromNetstackIP(assignedAddr.AddressWithPrefix.Address), + assignedAddr.AddressWithPrefix.PrefixLen, + )) + } + + return addrs, nil +} + // ForwardTo forwards all inbound TCP traffic to the upstream network. func (net *ICXNetwork) ForwardTo(ctx context.Context, upstream network.Network) error { // Allow outgoing packets to have a source address different from the NIC. @@ -305,5 +320,9 @@ func (net *ICXNetwork) ForwardTo(ctx context.Context, upstream network.Network) tcpForwarder := TCPForwarder(ctx, net.stack, upstream) net.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder) + + udpForwarder := UDPForwarder(ctx, net.stack, upstream) + net.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder) + return nil } diff --git a/pkg/netstack/icx_network_test.go b/pkg/netstack/icx_network_test.go index 7dc9cb5..6071aa4 100644 --- a/pkg/netstack/icx_network_test.go +++ b/pkg/netstack/icx_network_test.go @@ -54,12 +54,15 @@ func TestICXNetwork_Speed(t *testing.T) { uaA := pcA.LocalAddr().(*net.UDPAddr) uaB := pcB.LocalAddr().(*net.UDPAddr) + aA := netip.AddrPortFrom(netip.MustParseAddr(uaA.IP.String()), uint16(uaA.Port)) + aB := netip.AddrPortFrom(netip.MustParseAddr(uaB.IP.String()), uint16(uaB.Port)) + // Build ICX handlers in L3 mode and link them together - hA, err := icx.NewHandler(icx.WithLocalAddr(toFullAddress(uaA)), + hA, err := icx.NewHandler(icx.WithLocalAddr(netstack.ToFullAddress(aA)), icx.WithVirtMAC(tcpip.GetRandMacAddr()), icx.WithLayer3VirtFrames()) require.NoError(t, err) - hB, err := icx.NewHandler(icx.WithLocalAddr(toFullAddress(uaB)), + hB, err := icx.NewHandler(icx.WithLocalAddr(netstack.ToFullAddress(aB)), icx.WithVirtMAC(tcpip.GetRandMacAddr()), icx.WithLayer3VirtFrames()) require.NoError(t, err) @@ -68,10 +71,10 @@ func TestICXNetwork_Speed(t *testing.T) { // Advertise a shared /24 so each side routes via the tunnel. route := netip.MustParsePrefix("10.1.0.0/24") - err = hA.AddVirtualNetwork(vni, toFullAddress(uaB), []netip.Prefix{route}) + err = hA.AddVirtualNetwork(vni, netstack.ToFullAddress(aB), []netip.Prefix{route}) require.NoError(t, err) - err = hB.AddVirtualNetwork(vni, toFullAddress(uaA), []netip.Prefix{route}) + err = hB.AddVirtualNetwork(vni, netstack.ToFullAddress(aA), []netip.Prefix{route}) require.NoError(t, err) var key [16]byte @@ -89,11 +92,12 @@ func TestICXNetwork_Speed(t *testing.T) { }) // Create two networks on top of the handlers - netA, err := netstack.NewICXNetwork(hA, l2A, 1500, nil, "") + mtu := icx.MTU(1500) // compute the inner MTU based on the path MTU + netA, err := netstack.NewICXNetwork(hA, l2A, mtu, nil, "") require.NoError(t, err) t.Cleanup(func() { require.NoError(t, netA.Close()) }) - netB, err := netstack.NewICXNetwork(hB, l2B, 1500, nil, "") + netB, err := netstack.NewICXNetwork(hB, l2B, mtu, nil, "") require.NoError(t, err) t.Cleanup(func() { require.NoError(t, netB.Close()) }) @@ -236,16 +240,3 @@ func TestICXNetwork_Speed(t *testing.T) { numStreams, totalBytes, totalRead, elapsed, mbpsBytes, mbps, gbps) }) } - -func toFullAddress(addr *net.UDPAddr) *tcpip.FullAddress { - if addr.IP.To4() != nil { - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom4Slice(addr.IP.To4()[:]), - Port: uint16(addr.Port), - } - } - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom16Slice(addr.IP.To16()[:]), - Port: uint16(addr.Port), - } -} diff --git a/pkg/netstack/utils.go b/pkg/netstack/utils.go new file mode 100644 index 0000000..0868e93 --- /dev/null +++ b/pkg/netstack/utils.go @@ -0,0 +1,23 @@ +package netstack + +import ( + "net/netip" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func ToFullAddress(addrPort netip.AddrPort) *tcpip.FullAddress { + if addrPort.Addr().Is4() { + addrv4 := addrPort.Addr().As4() + return &tcpip.FullAddress{ + Addr: tcpip.AddrFrom4Slice(addrv4[:]), + Port: uint16(addrPort.Port()), + } + } else { + addrv6 := addrPort.Addr().As16() + return &tcpip.FullAddress{ + Addr: tcpip.AddrFrom16Slice(addrv6[:]), + Port: uint16(addrPort.Port()), + } + } +} diff --git a/pkg/tunnel/adapter/connection.go b/pkg/tunnel/adapter/connection.go index 091c4c0..79cdac9 100644 --- a/pkg/tunnel/adapter/connection.go +++ b/pkg/tunnel/adapter/connection.go @@ -6,8 +6,8 @@ import ( "sync" "sync/atomic" + "github.com/apoxy-dev/apoxy/pkg/netstack" "github.com/apoxy-dev/icx" - "gvisor.dev/gvisor/pkg/tcpip" ) // Connection is a connection like abstraction over an icx virtual network. @@ -81,7 +81,7 @@ func (c *Connection) SetVNI(vni uint) error { addrs = []netip.Prefix{*c.overlayAddr} } - if err := c.handler.AddVirtualNetwork(vni, toFullAddress(c.remoteAddr), addrs); err != nil { + if err := c.handler.AddVirtualNetwork(vni, netstack.ToFullAddress(c.remoteAddr), addrs); err != nil { return fmt.Errorf("failed to add virtual network %d: %w", vni, err) } c.vni = &vni @@ -130,19 +130,3 @@ func (c *Connection) SetOverlayAddress(addr string) error { func (c *Connection) IncrementKeyEpoch() uint32 { return c.keyEpoch.Add(1) } - -func toFullAddress(addrPort netip.AddrPort) *tcpip.FullAddress { - if addrPort.Addr().Is4() { - addrv4 := addrPort.Addr().As4() - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom4Slice(addrv4[:]), - Port: uint16(addrPort.Port()), - } - } else { - addrv6 := addrPort.Addr().As16() - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom16Slice(addrv6[:]), - Port: uint16(addrPort.Port()), - } - } -} diff --git a/pkg/tunnel/connection/icx_conn.go b/pkg/tunnel/connection/icx_conn.go deleted file mode 100644 index a416d0a..0000000 --- a/pkg/tunnel/connection/icx_conn.go +++ /dev/null @@ -1,166 +0,0 @@ -package connection - -import ( - "errors" - "fmt" - "log/slog" - "net" - "sync" - - "github.com/apoxy-dev/icx" - "github.com/apoxy-dev/icx/udp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -var ( - _ Connection = (*ICXConn)(nil) -) - -var ErrInvalidFrame = errors.New("invalid frame") - -type ICXConn struct { - pc net.PacketConn - handler *icx.Handler - localAddr *tcpip.FullAddress - pktPool sync.Pool -} - -// NewICXConn creates a new ICXConn instance from a PacketConn and an ICX handler. -// The ICX handler should be configured in layer 3 mode. -func NewICXConn(pc net.PacketConn, handler *icx.Handler) (*ICXConn, error) { - localAddr := pc.LocalAddr().(*net.UDPAddr) - if localAddr == nil { - return nil, fmt.Errorf("failed to get local address from PacketConn") - } - - return &ICXConn{ - pc: pc, - handler: handler, - localAddr: toFullAddress(localAddr), - pktPool: sync.Pool{ - New: func() any { - b := make([]byte, 0, 65535) - return &b - }, - }, - }, nil -} - -func (c *ICXConn) Close() error { - return c.pc.Close() -} - -func (c *ICXConn) ReadPacket(pkt []byte) (int, error) { - phyFrame := c.pktPool.Get().(*[]byte) - defer c.pktPool.Put(phyFrame) - *phyFrame = (*phyFrame)[:cap(*phyFrame)] - - // Temporarily read into the start of the buffer - n, raddr, err := c.pc.ReadFrom((*phyFrame)[:]) - if err != nil { - return 0, err - } - remoteAddr := raddr.(*net.UDPAddr) - - // Determine payload offset based on remoteAddr family - payloadOffset := udp.PayloadOffsetIPv4 - if remoteAddr.IP.To4() == nil { - payloadOffset = udp.PayloadOffsetIPv6 - } - - // Ensure there's enough space for move - if payloadOffset+n > cap(*phyFrame) { - return 0, errors.New("packet too large to fit in buffer with offset") - } - - // Shift the received data to payloadOffset - copy((*phyFrame)[payloadOffset:], (*phyFrame)[:n]) - - // Encode the UDP frame in-place starting from offset - phyFrameLen, err := udp.Encode(*phyFrame, toFullAddress(remoteAddr), c.localAddr, n, true) - if err != nil { - return 0, err - } - - pktLen := c.handler.PhyToVirt((*phyFrame)[:phyFrameLen], pkt) - if pktLen <= 0 { - slog.Warn("Invalid frame received", slog.String("remote", remoteAddr.String()), slog.Int("len", n)) - return 0, nil - } - - return pktLen, nil -} - -func (c *ICXConn) WritePacket(pkt []byte) ([]byte, error) { - phyFrame := c.pktPool.Get().(*[]byte) - defer c.pktPool.Put(phyFrame) - *phyFrame = (*phyFrame)[:cap(*phyFrame)] - - phyFrameLen, loopback := c.handler.VirtToPhy(pkt, *phyFrame) - if phyFrameLen <= 0 { - return nil, ErrInvalidFrame - } - *phyFrame = (*phyFrame)[:phyFrameLen] - - if loopback { - // If the frame is a loopback, we don't send it out. - // TODO: we should be clear if this is L2 or L3 loopback. - // Either way for now there will be no loopbacks expected in L3 mode. - loopbackPacket := make([]byte, phyFrameLen) - copy(loopbackPacket, *phyFrame) - return loopbackPacket, nil - } - - eth := header.Ethernet((*phyFrame)[:header.EthernetMinimumSize]) - ethType := eth.Type() - - // Extract the destination address and payload offset from the frame - var payloadOffset int - var raddr *net.UDPAddr - if ethType == header.IPv6ProtocolNumber { - payloadOffset = udp.PayloadOffsetIPv6 - - ip := header.IPv6((*phyFrame)[header.EthernetMinimumSize:]) - udp := header.UDP(ip.Payload()) - - raddr = &net.UDPAddr{ - IP: net.IP(ip.DestinationAddressSlice()), - Port: int(udp.DestinationPort()), - } - } else if ethType == header.IPv4ProtocolNumber { - payloadOffset = udp.PayloadOffsetIPv4 - - ip := header.IPv4((*phyFrame)[header.EthernetMinimumSize:]) - udp := header.UDP(ip.Payload()) - - raddr = &net.UDPAddr{ - IP: net.IP(ip.DestinationAddressSlice()), - Port: int(udp.DestinationPort()), - } - } else { - return nil, fmt.Errorf("unsupported ethertype: %d", ethType) - } - - // Send the packet out - _, err := c.pc.WriteTo((*phyFrame)[payloadOffset:], raddr) - if err != nil { - return nil, fmt.Errorf("failed to write packet: %w", err) - } - - return nil, nil -} - -func toFullAddress(addr *net.UDPAddr) *tcpip.FullAddress { - if addr.IP.To4() != nil { - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom4Slice(addr.IP.To4()[:]), - Port: uint16(addr.Port), - } - } else { - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom16Slice(addr.IP.To16()[:]), - Port: uint16(addr.Port), - } - } -} diff --git a/pkg/tunnel/connection/icx_conn_test.go b/pkg/tunnel/connection/icx_conn_test.go deleted file mode 100644 index 741c033..0000000 --- a/pkg/tunnel/connection/icx_conn_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package connection_test - -import ( - "log/slog" - "net" - "net/netip" - "testing" - "time" - - "github.com/apoxy-dev/icx" - "github.com/stretchr/testify/require" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - - "github.com/apoxy-dev/apoxy/pkg/tunnel/connection" -) - -func TestICXConn(t *testing.T) { - if testing.Verbose() { - slog.SetLogLoggerLevel(slog.LevelDebug) - } - - // Create two local packet connections - laddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") - require.NoError(t, err) - - pc1, err := net.ListenUDP("udp", laddr) - require.NoError(t, err) - - fa1 := mustNewFullAddress(pc1.LocalAddr().String()) - - laddr, err = net.ResolveUDPAddr("udp", "127.0.0.1:0") - require.NoError(t, err) - - pc2, err := net.ListenUDP("udp", laddr) - require.NoError(t, err) - - fa2 := mustNewFullAddress(pc2.LocalAddr().String()) - - vni := uint(0x12345) - - var key [16]byte - copy(key[:], []byte("0123456789abcdef")) - - // Setup ICX handlers - handler1, err := icx.NewHandler(icx.WithLocalAddr(fa1), - icx.WithVirtMAC(tcpip.GetRandMacAddr()), icx.WithLayer3VirtFrames()) - require.NoError(t, err) - - err = handler1.AddVirtualNetwork( - vni, - fa2, - []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, - ) - require.NoError(t, err) - - err = handler1.UpdateVirtualNetworkKeys( - vni, - 1, key, key, - time.Now().Add(10*time.Minute), - ) - require.NoError(t, err) - - handler2, err := icx.NewHandler(icx.WithLocalAddr(fa2), - icx.WithVirtMAC(tcpip.GetRandMacAddr()), icx.WithLayer3VirtFrames()) - require.NoError(t, err) - - err = handler2.AddVirtualNetwork( - vni, - fa1, - []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, - ) - require.NoError(t, err) - - err = handler2.UpdateVirtualNetworkKeys( - vni, - 1, key, key, - time.Now().Add(10*time.Minute), - ) - require.NoError(t, err) - - // Build ICX connections - conn1, err := connection.NewICXConn(pc1, handler1) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, conn1.Close()) - }) - - conn2, err := connection.NewICXConn(pc2, handler2) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, conn2.Close()) - }) - - // Send a packet from conn1 to conn2 - ipPacket := makeIPv4UDPPacket() - _, err = conn1.WritePacket(ipPacket) - require.NoError(t, err) - - // Read the packet on conn2 - buf := make([]byte, 1500) - n, err := conn2.ReadPacket(buf) - require.NoError(t, err) - require.Greater(t, n, 0) - require.Equal(t, ipPacket, buf[:n]) -} - -func makeIPv4UDPPacket() []byte { - ipPacket := make([]byte, header.IPv4MinimumSize+header.UDPMinimumSize) - - ip := header.IPv4(ipPacket) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(ipPacket)), - TTL: 64, - Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: tcpip.AddrFrom4Slice(net.IPv4(192, 168, 1, 1).To4()), - DstAddr: tcpip.AddrFrom4Slice(net.IPv4(192, 168, 1, 2).To4()), - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - udp := header.UDP(ipPacket[header.IPv4MinimumSize:]) - udp.Encode(&header.UDPFields{ - SrcPort: 1234, - DstPort: 5678, - Length: header.UDPMinimumSize, - }) - - return ipPacket -} - -func mustNewFullAddress(addrPortStr string) *tcpip.FullAddress { - addrPort := netip.MustParseAddrPort(addrPortStr) - - switch addrPort.Addr().BitLen() { - case 32: - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom4Slice(addrPort.Addr().AsSlice()), - Port: addrPort.Port(), - } - case 128: - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom16Slice(addrPort.Addr().AsSlice()), - Port: addrPort.Port(), - } - default: - panic("Unsupported IP address length") - } -} diff --git a/pkg/tunnel/relay.go b/pkg/tunnel/relay.go index 4c911af..ba018dc 100644 --- a/pkg/tunnel/relay.go +++ b/pkg/tunnel/relay.go @@ -27,6 +27,7 @@ import ( "github.com/apoxy-dev/apoxy/pkg/tunnel/api" "github.com/apoxy-dev/apoxy/pkg/tunnel/controllers" "github.com/apoxy-dev/apoxy/pkg/tunnel/hasher" + "github.com/apoxy-dev/apoxy/pkg/tunnel/router" ) const ( @@ -40,19 +41,21 @@ type Relay struct { cert tls.Certificate handler *icx.Handler idHasher *hasher.Hasher + router router.Router tokens *haxmap.Map[string, string] // map[tunnelName]token conns *haxmap.Map[string, *adapter.Connection] // map[connectionID]Connection onConnect func(ctx context.Context, agentName string, conn controllers.Connection) error onDisconnect func(ctx context.Context, agentName, id string) error } -func NewRelay(name string, pc net.PacketConn, cert tls.Certificate, handler *icx.Handler, idHasher *hasher.Hasher) *Relay { +func NewRelay(name string, pc net.PacketConn, cert tls.Certificate, handler *icx.Handler, idHasher *hasher.Hasher, router router.Router) *Relay { return &Relay{ name: name, pc: pc, cert: cert, handler: handler, idHasher: idHasher, + router: router, tokens: haxmap.New[string, string](), conns: haxmap.New[string, *adapter.Connection](), } @@ -117,10 +120,9 @@ func (r *Relay) Start(ctx context.Context) error { slog.Info("Stopping relay", slog.String("addr", ln.Addr().String())) - // TODO (dpeckett): implement ICX router - /*if err := r.router.Close(); err != nil { + if err := r.router.Close(); err != nil { slog.Error("Failed to close router", slog.Any("error", err)) - }*/ + } shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -132,18 +134,18 @@ func (r *Relay) Start(ctx context.Context) error { return srv.Close() }) + // Start the router to handle network traffic. g.Go(func() error { - slog.Info("Starting relay", slog.String("addr", ln.Addr().String())) - return srv.ServeListener(ln) + return r.router.Start(ctx) }) - // TODO (dpeckett): implement ICX router - /* - // Start the router to handle network traffic. - g.Go(func() error { - return r.router.Start(ctx) - }) - */ + g.Go(func() error { + slog.Info("Starting relay", slog.String("addr", ln.Addr().String())) + if err := srv.ServeListener(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil + }) return g.Wait() } diff --git a/pkg/tunnel/relay_test.go b/pkg/tunnel/relay_test.go index d517f22..d26e876 100644 --- a/pkg/tunnel/relay_test.go +++ b/pkg/tunnel/relay_test.go @@ -15,14 +15,18 @@ import ( "github.com/apoxy-dev/icx" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "gvisor.dev/gvisor/pkg/tcpip" "github.com/apoxy-dev/apoxy/pkg/cryptoutils" + "github.com/apoxy-dev/apoxy/pkg/netstack" "github.com/apoxy-dev/apoxy/pkg/tunnel" "github.com/apoxy-dev/apoxy/pkg/tunnel/api" + "github.com/apoxy-dev/apoxy/pkg/tunnel/connection" "github.com/apoxy-dev/apoxy/pkg/tunnel/controllers" "github.com/apoxy-dev/apoxy/pkg/tunnel/hasher" + "github.com/apoxy-dev/apoxy/pkg/tunnel/router" ) func TestRelay_Connect_UpdateKeys_Disconnect(t *testing.T) { @@ -93,7 +97,7 @@ func startRelay(t *testing.T, token string, onConnect func(context.Context, stri caCert, serverCert, err := cryptoutils.GenerateSelfSignedTLSCert("localhost") require.NoError(t, err) - h, err := icx.NewHandler(icx.WithLocalAddr(toFullAddress(netip.MustParseAddrPort("127.0.0.1:6081"))), + h, err := icx.NewHandler(icx.WithLocalAddr(netstack.ToFullAddress(netip.MustParseAddrPort("127.0.0.1:6081"))), icx.WithVirtMAC(tcpip.GetRandMacAddr())) require.NoError(t, err) @@ -103,7 +107,12 @@ func startRelay(t *testing.T, token string, onConnect func(context.Context, stri idHasher := hasher.NewHasher(idKey) - r := tunnel.NewRelay("relay-it", pc, serverCert, h, idHasher) + rtr := &mockRouter{} + + rtr.On("Start", mock.Anything).Return(nil) + rtr.On("Close").Return(nil) + + r := tunnel.NewRelay("relay-it", pc, serverCert, h, idHasher, rtr) r.SetCredentials("test-tunnel", token) r.SetOnConnect(onConnect) r.SetOnDisconnect(onDisconnect) @@ -112,7 +121,9 @@ func startRelay(t *testing.T, token string, onConnect func(context.Context, stri done := make(chan struct{}) go func() { - _ = r.Start(ctx) // on shutdown we don't assert the error path + if err := r.Start(ctx); err != nil { + t.Errorf("Relay stopped with error: %v", err) + } close(done) }() @@ -211,18 +222,63 @@ func clientForRelay(t *testing.T, r *tunnel.Relay, caCert tls.Certificate, token return c } -func toFullAddress(addrPort netip.AddrPort) *tcpip.FullAddress { - if addrPort.Addr().Is4() { - addrv4 := addrPort.Addr().As4() - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom4Slice(addrv4[:]), - Port: uint16(addrPort.Port()), - } - } else { - addrv6 := addrPort.Addr().As16() - return &tcpip.FullAddress{ - Addr: tcpip.AddrFrom16Slice(addrv6[:]), - Port: uint16(addrPort.Port()), - } +type mockRouter struct { + mock.Mock +} + +func (m *mockRouter) Start(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *mockRouter) AddAddr(addr netip.Prefix, tun connection.Connection) error { + args := m.Called(addr, tun) + return args.Error(0) +} + +func (m *mockRouter) ListAddrs() ([]netip.Prefix, error) { + args := m.Called() + var addrs []netip.Prefix + if v := args.Get(0); v != nil { + addrs = v.([]netip.Prefix) } + return addrs, args.Error(1) +} + +func (m *mockRouter) DelAddr(addr netip.Prefix) error { + args := m.Called(addr) + return args.Error(0) +} + +func (m *mockRouter) AddRoute(dst netip.Prefix) error { + args := m.Called(dst) + return args.Error(0) +} + +func (m *mockRouter) DelRoute(dst netip.Prefix) error { + args := m.Called(dst) + return args.Error(0) +} + +func (m *mockRouter) ListRoutes() ([]router.TunnelRoute, error) { + args := m.Called() + var routes []router.TunnelRoute + if v := args.Get(0); v != nil { + routes = v.([]router.TunnelRoute) + } + return routes, args.Error(1) +} + +func (m *mockRouter) LocalAddresses() ([]netip.Prefix, error) { + args := m.Called() + var addrs []netip.Prefix + if v := args.Get(0); v != nil { + addrs = v.([]netip.Prefix) + } + return addrs, args.Error(1) +} + +func (m *mockRouter) Close() error { + args := m.Called() + return args.Error(0) } diff --git a/pkg/tunnel/router/client_icx_netstack.go b/pkg/tunnel/router/client_icx_netstack.go new file mode 100644 index 0000000..c922ac3 --- /dev/null +++ b/pkg/tunnel/router/client_icx_netstack.go @@ -0,0 +1,211 @@ +package router + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "net/netip" + "strconv" + "sync" + + "github.com/apoxy-dev/icx" + "github.com/dpeckett/network" + "golang.org/x/sync/errgroup" + "gvisor.dev/gvisor/pkg/tcpip" + + "github.com/apoxy-dev/apoxy/pkg/netstack" + "github.com/apoxy-dev/apoxy/pkg/socksproxy" + "github.com/apoxy-dev/apoxy/pkg/tunnel/connection" + "github.com/apoxy-dev/apoxy/pkg/tunnel/l2pc" +) + +var ( + _ Router = (*ICXNetstackRouter)(nil) +) + +type ICXNetstackRouter struct { + Handler *icx.Handler + phy *l2pc.L2PacketConn + net *netstack.ICXNetwork + proxy *socksproxy.ProxyServer + closeOnce sync.Once +} + +func NewICXNetstackRouter(pc net.PacketConn, mtu int, opts ...Option) (*ICXNetstackRouter, error) { + options := defaultOptions() + for _, opt := range opts { + opt(options) + } + + phy, err := l2pc.NewL2PacketConn(pc) + if err != nil { + return nil, fmt.Errorf("failed to create L2 packet connection phy: %w", err) + } + + localUDPAddr := pc.LocalAddr().(*net.UDPAddr) + + localAddr := netip.AddrPortFrom(netip.MustParseAddr(localUDPAddr.IP.String()), + uint16(localUDPAddr.Port)) + + handler, err := icx.NewHandler( + icx.WithLocalAddr(netstack.ToFullAddress(localAddr)), + icx.WithVirtMAC(tcpip.GetRandMacAddr()), icx.WithLayer3VirtFrames()) + if err != nil { + return nil, fmt.Errorf("failed to create ICX handler: %w", err) + } + + net, err := netstack.NewICXNetwork(handler, phy, mtu, options.resolveConf, options.pcapPath) + if err != nil { + _ = phy.Close() + return nil, fmt.Errorf("failed to create ICX network: %w", err) + } + + proxy := socksproxy.NewServer( + options.socksListenAddr, + net.Network, + network.Host(), + ) + + return &ICXNetstackRouter{ + Handler: handler, + phy: phy, + net: net, + proxy: proxy, + }, nil +} + +func (r *ICXNetstackRouter) Close() error { + var err error + r.closeOnce.Do(func() { + if err = r.proxy.Close(); err != nil { + err = fmt.Errorf("failed to close SOCKS proxy: %w", err) + return + } + + if err = r.net.Close(); err != nil { + err = fmt.Errorf("failed to close ICX network: %w", err) + return + } + + err = r.phy.Close() + }) + return err +} + +// Start initializes the router and starts forwarding traffic. +// It's a blocking call that should be run in a separate goroutine. +func (r *ICXNetstackRouter) Start(ctx context.Context) error { + _, socksListenPortStr, err := net.SplitHostPort(r.proxy.Addr) + if err != nil { + return fmt.Errorf("failed to parse SOCKS listen address: %w", err) + } + + socksListenPort, err := strconv.Atoi(socksListenPortStr) + if err != nil { + return fmt.Errorf("failed to parse SOCKS listen port: %w", err) + } + + slog.Info("Forwarding all inbound traffic to loopback interface") + + if err := r.net.ForwardTo(ctx, network.Filtered(&network.FilteredNetworkConfig{ + DeniedPorts: []uint16{uint16(socksListenPort)}, + Upstream: network.Host(), + })); err != nil { + return fmt.Errorf("failed to forward to loopback: %w", err) + } + + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + <-ctx.Done() + slog.Debug("Closing router") + return r.Close() + }) + + g.Go(func() error { + slog.Info("Splicing packets between netstack and ICX") + + // This will be terminated when the router is closed. + if err := r.net.Start(); err != nil && !errors.Is(err, net.ErrClosed) { + return fmt.Errorf("failed to splice packets: %w", err) + } + + return nil + }) + + g.Go(func() error { + slog.Info("Starting SOCKS5 proxy", slog.String("listenAddr", r.proxy.Addr)) + + if err := r.proxy.ListenAndServe(ctx); err != nil { + slog.Error("SOCKS proxy error", slog.String("error", err.Error())) + } + + return nil + }) + + return g.Wait() +} + +// AddAddr adds a tun with an associated address to the router. +func (r *ICXNetstackRouter) AddAddr(addr netip.Prefix, tun connection.Connection) error { + if err := r.net.AddAddr(addr); err != nil { + return fmt.Errorf("failed to add address to ICX network: %w", err) + } + + return nil +} + +// ListAddrs returns a list of all addresses currently managed by the router. +func (r *ICXNetstackRouter) ListAddrs() ([]netip.Prefix, error) { + return r.net.LocalAddresses() +} + +// DelAddr removes a tun by its addr from the router. +func (r *ICXNetstackRouter) DelAddr(addr netip.Prefix) error { + if err := r.net.DelAddr(addr); err != nil { + return fmt.Errorf("failed to remove address from ICX network: %w", err) + } + + return nil +} + +// AddRoute adds a dst prefix to be routed through the given tunnel connection. +// If multiple tunnels are provided, the router will distribute traffic across them +// uniformly. +func (r *ICXNetstackRouter) AddRoute(dst netip.Prefix) error { + return nil +} + +// Del removes a routing associations for a given destination prefix and Connection name. +// New matching flows will stop being routed through the tunnel immediately while +// existing flows may continue to use the tunnel for some draining period before +// getting re-routed via a different tunnel or dropped (if no tunnel is available for +// the given dst). +func (r *ICXNetstackRouter) DelRoute(dst netip.Prefix) error { + return nil +} + +// ListRoutes returns a list of all routes currently managed by the router. +func (r *ICXNetstackRouter) ListRoutes() ([]TunnelRoute, error) { + localAddrs, err := r.net.LocalAddresses() + if err != nil { + return nil, fmt.Errorf("failed to list local addresses: %w", err) + } + + var routes []TunnelRoute + for _, addr := range localAddrs { + routes = append(routes, TunnelRoute{ + Dst: addr, + State: TunnelRouteStateActive, + }) + } + + return routes, nil +} + +// LocalAddresses returns the list of local addresses that are assigned to the router. +func (r *ICXNetstackRouter) LocalAddresses() ([]netip.Prefix, error) { + return r.net.LocalAddresses() +} diff --git a/pkg/tunnel/router/options.go b/pkg/tunnel/router/options.go index 2c970f6..ed739b9 100644 --- a/pkg/tunnel/router/options.go +++ b/pkg/tunnel/router/options.go @@ -19,6 +19,7 @@ type routerOptions struct { socksListenAddr string cksumRecalc bool preserveDefaultGwDsts []netip.Prefix + sourcePortHashing bool } func defaultOptions() *routerOptions { @@ -97,3 +98,11 @@ func WithPreserveDefaultGwDsts(dsts []netip.Prefix) Option { o.preserveDefaultGwDsts = dsts } } + +// WithSourcePortHashing enables or disables source port hashing for routing decisions. +// Only valid for ICX routers. +func WithSourcePortHashing(enable bool) Option { + return func(o *routerOptions) { + o.sourcePortHashing = enable + } +} diff --git a/pkg/tunnel/router/server_icx_netlink.go b/pkg/tunnel/router/server_icx_netlink.go new file mode 100644 index 0000000..7366e12 --- /dev/null +++ b/pkg/tunnel/router/server_icx_netlink.go @@ -0,0 +1,9 @@ +//go:build !linux + +package router + +import "fmt" + +func NewICXNetlinkRouter(_ ...Option) (Router, error) { + return nil, fmt.Errorf("netlink router is not supported on this platform") +} diff --git a/pkg/tunnel/router/server_icx_netlink_linux.go b/pkg/tunnel/router/server_icx_netlink_linux.go new file mode 100644 index 0000000..f6d79aa --- /dev/null +++ b/pkg/tunnel/router/server_icx_netlink_linux.go @@ -0,0 +1,256 @@ +package router + +import ( + "context" + "errors" + "fmt" + "math" + "net" + "net/netip" + "os" + "sync" + + "github.com/apoxy-dev/icx" + "github.com/apoxy-dev/icx/filter" + "github.com/apoxy-dev/icx/tunnel" + "github.com/apoxy-dev/icx/veth" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcapgo" + "github.com/slavc/xdp" + "github.com/vishvananda/netlink" + "gvisor.dev/gvisor/pkg/tcpip" + + "github.com/apoxy-dev/apoxy/pkg/netstack" + "github.com/apoxy-dev/apoxy/pkg/tunnel/connection" +) + +const ( + icxDefaultPort = 6081 + extPathMTU = 1500 +) + +var ( + _ Router = (*ICXNetlinkRouter)(nil) +) + +type ICXNetlinkRouter struct { + Handler *icx.Handler + vethDev *veth.Handle + ingressFilter *xdp.Program + pcapFile *os.File + tun *tunnel.Tunnel + closeOnce sync.Once +} + +func NewICXNetlinkRouter(opts ...Option) (*ICXNetlinkRouter, error) { + options := defaultOptions() + for _, opt := range opts { + opt(options) + } + + phy, err := netlink.LinkByName(options.extIfaceName) + if err != nil { + return nil, fmt.Errorf("failed to find interface %s: %w", options.extIfaceName, err) + } + + addrs, err := addrsForInterface(phy, icxDefaultPort) + if err != nil { + return nil, fmt.Errorf("failed to get addresses for interface %s: %w", options.extIfaceName, err) + } + + localAddr, err := selectSourceAddr(addrs) + if err != nil { + return nil, fmt.Errorf("failed to select source address: %w", err) + } + + numQueues, err := tunnel.NumQueues(phy) + if err != nil { + return nil, fmt.Errorf("failed to get number of TX queues for interface %s: %w", options.extIfaceName, err) + } + + vethDev, err := veth.Create(options.tunIfaceName, numQueues, icx.MTU(extPathMTU)) + if err != nil { + return nil, fmt.Errorf("failed to create veth device: %w", err) + } + + virtMAC := tcpip.LinkAddress(vethDev.Link.Attrs().HardwareAddr) + + handlerOpts := []icx.HandlerOption{ + icx.WithLocalAddr(localAddr), + icx.WithVirtMAC(virtMAC), + } + if options.sourcePortHashing { + handlerOpts = append(handlerOpts, icx.WithSourcePortHashing()) + } + + handler, err := icx.NewHandler(handlerOpts...) + if err != nil { + _ = vethDev.Close() + return nil, fmt.Errorf("failed to create handler: %w", err) + } + + ingressFilter, err := filter.Bind(addrs...) + if err != nil { + _ = vethDev.Close() + return nil, fmt.Errorf("failed to create ingress filter: %w", err) + } + + var pcapFile *os.File + var pcapWriter *pcapgo.Writer + if options.pcapPath != "" { + pcapFile, err = os.Create(options.pcapPath) + if err != nil { + _ = vethDev.Close() + _ = ingressFilter.Close() + return nil, fmt.Errorf("failed to create pcap file: %w", err) + } + + pcapWriter = pcapgo.NewWriter(pcapFile) + if err := pcapWriter.WriteFileHeader(uint32(math.MaxUint16), layers.LinkTypeEthernet); err != nil { + return nil, fmt.Errorf("failed to write PCAP header: %w", err) + } + } + + tun, err := tunnel.NewTunnel(options.extIfaceName, vethDev.Peer.Attrs().Name, ingressFilter, handler, pcapWriter) + if err != nil { + _ = vethDev.Close() + _ = ingressFilter.Close() + return nil, fmt.Errorf("failed to create tunnel: %w", err) + } + + return &ICXNetlinkRouter{ + Handler: handler, + vethDev: vethDev, + ingressFilter: ingressFilter, + pcapFile: pcapFile, + tun: tun, + }, nil +} + +func (r *ICXNetlinkRouter) Close() error { + var firstErr error + r.closeOnce.Do(func() { + if err := r.tun.Close(); err != nil && firstErr == nil { + firstErr = err + } + if err := r.vethDev.Close(); err != nil && firstErr == nil { + firstErr = err + } + if err := r.ingressFilter.Close(); err != nil && firstErr == nil { + firstErr = err + } + if err := r.pcapFile.Close(); err != nil && firstErr == nil { + firstErr = err + } + }) + return firstErr +} + +// Start initializes the router and starts forwarding traffic. +// It's a blocking call that should be run in a separate goroutine. +func (r *ICXNetlinkRouter) Start(ctx context.Context) error { + if err := r.tun.Start(ctx); err != nil && !errors.Is(err, context.Canceled) { + return fmt.Errorf("failed to start tunnel: %w", err) + } + + return nil +} + +// AddAddr adds a tun with an associated address to the router. +func (r *ICXNetlinkRouter) AddAddr(addr netip.Prefix, tun connection.Connection) error { + // TODO (dpeckett): implement + return nil +} + +// ListAddrs returns a list of all addresses currently managed by the router. +func (r *ICXNetlinkRouter) ListAddrs() ([]netip.Prefix, error) { + // TODO (dpeckett): implement + return nil, nil +} + +// DelAddr removes a tun by its addr from the router. +func (r *ICXNetlinkRouter) DelAddr(addr netip.Prefix) error { + // TODO (dpeckett): implement + return nil +} + +// AddRoute adds a dst prefix to be routed through the given tunnel connection. +// If multiple tunnels are provided, the router will distribute traffic across them +// uniformly. +func (r *ICXNetlinkRouter) AddRoute(dst netip.Prefix) error { + // TODO (dpeckett): implement + return nil +} + +// Del removes a routing associations for a given destination prefix and Connection name. +// New matching flows will stop being routed through the tunnel immediately while +// existing flows may continue to use the tunnel for some draining period before +// getting re-routed via a different tunnel or dropped (if no tunnel is available for +// the given dst). +func (r *ICXNetlinkRouter) DelRoute(dst netip.Prefix) error { + // TODO (dpeckett): implement + return nil +} + +// ListRoutes returns a list of all routes currently managed by the router. +func (r *ICXNetlinkRouter) ListRoutes() ([]TunnelRoute, error) { + // TODO (dpeckett): implement + return nil, nil +} + +// LocalAddresses returns the list of local addresses that are assigned to the router. +func (r *ICXNetlinkRouter) LocalAddresses() ([]netip.Prefix, error) { + // TODO (dpeckett): implement + return nil, nil +} + +func addrsForInterface(link netlink.Link, port int) ([]net.Addr, error) { + nlAddrs, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return nil, fmt.Errorf("failed to get addresses for interface: %w", err) + } + + var addrs []net.Addr + for _, addr := range nlAddrs { + if addr.IP == nil { + continue + } + addrs = append(addrs, &net.UDPAddr{ + IP: addr.IP, + Port: port, + }) + } + + return addrs, nil +} + +func selectSourceAddr(addrs []net.Addr) (*tcpip.FullAddress, error) { + var localUDP *net.UDPAddr + bestScore := -1 + for _, a := range addrs { + if ua, ok := a.(*net.UDPAddr); ok && ua.IP != nil { + score := 0 + if ua.IP.IsGlobalUnicast() { + // Prefer IPv4 over IPv6 for most underlays unless otherwise configured. + if ua.IP.To4() != nil { + score = 3 + } else { + score = 2 + } + } else { + // Still consider non-global addresses as a last resort. + score = 1 + } + if score > bestScore { + bestScore = score + localUDP = ua + } + } + } + + if localUDP == nil { + return nil, fmt.Errorf("no valid UDP address found") + } + + return netstack.ToFullAddress(netip.MustParseAddrPort(localUDP.String())), nil +}