diff --git a/go.mod b/go.mod index 5f009d6b..75886586 100644 --- a/go.mod +++ b/go.mod @@ -222,7 +222,6 @@ require ( github.com/hashicorp/mdns v1.0.1 // indirect github.com/hashicorp/memberlist v0.5.0 // indirect github.com/hashicorp/vic v1.5.1-0.20190403131502-bbfe86ec9443 // indirect - github.com/hedzr/go-ringbuf/v2 v2.2.1 // indirect github.com/iancoleman/strcase v0.3.0 // indirect github.com/imdario/mergo v0.3.16 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 215340f1..abc20321 100644 --- a/go.sum +++ b/go.sum @@ -612,8 +612,6 @@ github.com/hashicorp/memberlist v0.5.0 h1:EtYPN8DpAURiapus508I4n9CzHs2W+8NZGbmmR github.com/hashicorp/memberlist v0.5.0/go.mod h1:yvyXLpo0QaGE59Y7hDTsTzDD25JYBZ4mHgHUZ8lrOI0= github.com/hashicorp/vic v1.5.1-0.20190403131502-bbfe86ec9443 h1:O/pT5C1Q3mVXMyuqg7yuAWUg/jMZR1/0QTzTRdNR6Uw= github.com/hashicorp/vic v1.5.1-0.20190403131502-bbfe86ec9443/go.mod h1:bEpDU35nTu0ey1EXjwNwPjI9xErAsoOCmcMb9GKvyxo= -github.com/hedzr/go-ringbuf/v2 v2.2.1 h1:bnIRxSCWYt4vs5UCDCOYf+r1C8cQC7tkcOdjOTaVzNk= -github.com/hedzr/go-ringbuf/v2 v2.2.1/go.mod h1:N3HsRpbHvPkX9GsykpkPoR2vD6WRR6GbU7tx/9GLE4M= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= diff --git a/pkg/tunnel/connection/pipe.go b/pkg/tunnel/connection/pipe.go deleted file mode 100644 index 0822d56e..00000000 --- a/pkg/tunnel/connection/pipe.go +++ /dev/null @@ -1,130 +0,0 @@ -package connection - -import ( - "context" - "errors" - "runtime" - "sync" - - "github.com/hedzr/go-ringbuf/v2" - "github.com/hedzr/go-ringbuf/v2/mpmc" -) - -var _ Connection = (*Pipe)(nil) - -type Pipe struct { - readRing mpmc.RingBuffer[[]byte] - writeRing mpmc.RingBuffer[[]byte] - ctx context.Context - cancel context.CancelFunc - bufPool *sync.Pool -} - -// NewPipe creates a pair of connected Pipe instances for bidirectional communication. -// Note: I have seen packet loss on ARM64 platforms, I believe this is due to the -// weaker memory model of ARM64, we should really dig into this, but for now -// dropping 0.001% of packets is not a big deal, we can just retry. -func NewPipe(ctx context.Context, mtu int) (*Pipe, *Pipe) { - bufPool := sync.Pool{ - New: func() interface{} { - b := make([]byte, mtu) - return &b - }, - } - - ringAtoB := ringbuf.New[[]byte](1024) - ringBtoA := ringbuf.New[[]byte](1024) - - ctx, cancel := context.WithCancel(ctx) - - pipeA := &Pipe{ - readRing: ringBtoA, - writeRing: ringAtoB, - ctx: ctx, - cancel: cancel, - bufPool: &bufPool, - } - pipeB := &Pipe{ - readRing: ringAtoB, - writeRing: ringBtoA, - ctx: ctx, - cancel: cancel, - bufPool: &bufPool, - } - - return pipeA, pipeB -} - -// ReadPacket reads a packet into the provided buffer. -func (p *Pipe) ReadPacket(buf []byte) (int, error) { - select { - case <-p.ctx.Done(): - return 0, errors.New("pipe closed") - default: - var item []byte - var err error - for { - item, err = p.readRing.Dequeue() - if err != nil { - if errors.Is(err, mpmc.ErrQueueEmpty) { - runtime.Gosched() - - // Has the context been cancelled? - select { - case <-p.ctx.Done(): - return 0, errors.New("pipe closed") - default: - // Continue to try to dequeue - continue - } - } - return 0, err - } - break - } - n := copy(buf, item) - p.bufPool.Put(&item) - return n, nil - } -} - -// WritePacket writes a packet from the provided buffer. -func (p *Pipe) WritePacket(b []byte) ([]byte, error) { - select { - case <-p.ctx.Done(): - return nil, errors.New("pipe closed") - default: - bufPtr := p.bufPool.Get().(*[]byte) - buf := *bufPtr - buf = buf[:len(b)] - copy(buf, b) - - for { - err := p.writeRing.Enqueue(buf) - if err != nil { - if errors.Is(err, mpmc.ErrQueueFull) { - runtime.Gosched() - - // Has the context been cancelled? - select { - case <-p.ctx.Done(): - return nil, errors.New("pipe closed") - default: - // Continue to try to enqueue - continue - } - } - return nil, err - } - break - } - - return nil, nil - } -} - -// Close terminates the pipe. -func (p *Pipe) Close() error { - p.cancel() - return nil -} diff --git a/pkg/tunnel/connection/pipe_test.go b/pkg/tunnel/connection/pipe_test.go deleted file mode 100644 index aa663ac8..00000000 --- a/pkg/tunnel/connection/pipe_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package connection_test - -import ( - "bytes" - "sync/atomic" - "testing" - "time" - - "github.com/apoxy-dev/apoxy-cli/pkg/netstack" - "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/connection" -) - -func TestPipeThroughput(t *testing.T) { - const ( - packetSize = netstack.IPv6MinMTU - numPackets = 10_000_000 - ) - - p1, p2 := connection.NewPipe(t.Context(), packetSize) - - payload := bytes.Repeat([]byte("X"), packetSize) - buf := make([]byte, packetSize) - - var bytesTransferred int64 - var packetsTransferred int64 - - // Reader goroutine - go func() { - for i := 0; i < numPackets; i++ { - if _, err := p2.ReadPacket(buf); err != nil { - select { - case <-t.Context().Done(): - default: - t.Fatalf("Read error: %v", err) - } - return - } - atomic.AddInt64(&bytesTransferred, int64(packetSize)) - atomic.AddInt64(&packetsTransferred, 1) - } - }() - - // Reporter goroutine - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - go func(startTime time.Time) { - lastTransferred := int64(0) - for range ticker.C { - currentTransferred := atomic.LoadInt64(&bytesTransferred) - bytesThisSecond := currentTransferred - lastTransferred - lastTransferred = currentTransferred - - throughputGbps := (float64(bytesThisSecond*8) / 1e9) - elapsed := time.Since(startTime).Truncate(time.Second) - t.Logf("[+%s] Throughput: %.2f Gbps", elapsed, throughputGbps) - t.Logf("[+%s] Packets: %d", elapsed, atomic.LoadInt64(&packetsTransferred)) - } - }(time.Now()) - - start := time.Now() - - // Writer loop - for i := 0; i < numPackets; i++ { - if _, err := p1.WritePacket(payload); err != nil { - t.Fatal(err) - } - } - - duration := time.Since(start) - - totalThroughputGbps := (float64(packetSize*numPackets*8) / 1e9) / duration.Seconds() - t.Logf("Total Throughput: %.2f Gbps", totalThroughputGbps) -} diff --git a/pkg/tunnel/fasttun/fasttun.go b/pkg/tunnel/fasttun/fasttun.go index 8c2e22fc..8a626caf 100644 --- a/pkg/tunnel/fasttun/fasttun.go +++ b/pkg/tunnel/fasttun/fasttun.go @@ -16,10 +16,6 @@ type Device interface { // MTU returns the device's Maximum Transmission Unit. MTU() (int, error) - // BatchSize returns the recommended number of packets to process in one batch. - // This is useful for optimizing I/O performance. - BatchSize() int - // NewPacketQueue creates a new packet queue for the device. // Each queue is associated with a file descriptor and can be used // concurrently with others. @@ -31,6 +27,10 @@ type Device interface { type PacketQueue interface { io.Closer + // BatchSize returns the recommended number of packets to process in one batch. + // This is useful for optimizing I/O performance. + BatchSize() int + // Read reads packets into the provided buffer slices `pkts` and stores // the size of each packet in `sizes`. // diff --git a/pkg/tunnel/fasttun/fasttun_linux.go b/pkg/tunnel/fasttun/fasttun_linux.go index 73c9cde9..cbee68bb 100644 --- a/pkg/tunnel/fasttun/fasttun_linux.go +++ b/pkg/tunnel/fasttun/fasttun_linux.go @@ -43,7 +43,11 @@ func (d *LinuxDevice) Close() error { } d.packetQueues = nil - return fmt.Errorf("failed to close packet queues: %w", closeErr) + if closeErr != nil { + return fmt.Errorf("failed to close packet queues: %w", closeErr) + } + + return nil } func (d *LinuxDevice) Name() string { @@ -54,10 +58,6 @@ func (d *LinuxDevice) MTU() (int, error) { return d.mtu, nil } -func (d *LinuxDevice) BatchSize() int { - return 64 -} - // NewPacketQueue creates a new packet queue for the device. func (d *LinuxDevice) NewPacketQueue() (PacketQueue, error) { fd, err := unix.Open("/dev/net/tun", unix.O_RDWR|unix.O_CLOEXEC, 0) @@ -121,6 +121,10 @@ func (q *LinuxPacketQueue) Close() error { return q.tunFile.Close() } +func (q *LinuxPacketQueue) BatchSize() int { + return 64 +} + func (q *LinuxPacketQueue) Read(pkts [][]byte, sizes []int) (int, error) { fd := int(q.tunFile.Fd()) timeout := 50 * time.Millisecond diff --git a/pkg/tunnel/fasttun/fasttun_linux_test.go b/pkg/tunnel/fasttun/fasttun_linux_test.go new file mode 100644 index 00000000..c37f2d6d --- /dev/null +++ b/pkg/tunnel/fasttun/fasttun_linux_test.go @@ -0,0 +1,225 @@ +//go:build linux + +package fasttun_test + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/netip" + "os" + "os/exec" + "runtime" + "strconv" + "testing" + "time" + + "github.com/apoxy-dev/apoxy-cli/pkg/netstack" + "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/fasttun" + "github.com/apoxy-dev/apoxy-cli/pkg/utils/vm" + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" + "github.com/vishvananda/netns" + "golang.org/x/sync/errgroup" +) + +func TestLinuxDeviceThroughput(t *testing.T) { + child := vm.RunTestInVM(t, vm.WithPackages("iperf3")) + if !child { + return + } + + if testing.Verbose() { + slog.SetLogLoggerLevel(slog.LevelDebug) + } + + // Is iperf3 installed? + if _, err := exec.LookPath("iperf3"); err != nil { + t.Skipf("skipping test: %v", err) + } + + iperf3Major, iperf3Minor, err := iperf3Version() + require.NoError(t, err) + require.GreaterOrEqual(t, iperf3Major, 3) + parallelSupport := iperf3Major > 3 || (iperf3Major == 3 && iperf3Minor >= 16) + + // Backup the host network namespace + hostns, err := netns.Get() + require.NoError(t, err) + + // Create network namespaces + ns1, err := netns.NewNamed("fasttun-ns1") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ns1.Close()) + require.NoError(t, netns.DeleteNamed("fasttun-ns1")) + }) + + ns2, err := netns.NewNamed("fasttun-ns2") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ns2.Close()) + require.NoError(t, netns.DeleteNamed("fasttun-ns2")) + }) + + tun1 := fasttun.NewDevice("tun1", netstack.IPv6MinMTU) + t.Cleanup(func() { + require.NoError(t, tun1.Close()) + }) + + // Create packet queues + nPacketQueues := runtime.NumCPU() + tun1Queues := make([]fasttun.PacketQueue, nPacketQueues) + for i := 0; i < nPacketQueues; i++ { + q, err := tun1.NewPacketQueue() + require.NoError(t, err) + tun1Queues[i] = q + } + + err = configureTun(tun1.Name(), netip.MustParsePrefix("fd00::1/64"), hostns, ns1) + require.NoError(t, err) + + tun2 := fasttun.NewDevice("tun2", netstack.IPv6MinMTU) + t.Cleanup(func() { + require.NoError(t, tun2.Close()) + }) + + // Create packet queues + tun2Queues := make([]fasttun.PacketQueue, nPacketQueues) + for i := 0; i < nPacketQueues; i++ { + q, err := tun2.NewPacketQueue() + require.NoError(t, err) + tun2Queues[i] = q + } + + err = configureTun(tun2.Name(), netip.MustParsePrefix("fd00::2/64"), hostns, ns2) + require.NoError(t, err) + + g, ctx := errgroup.WithContext(t.Context()) + + g.Go(func() error { + <-ctx.Done() + + if err := tun1.Close(); err != nil { + return fmt.Errorf("failed to close tun1: %w", err) + } + + if err := tun2.Close(); err != nil { + return fmt.Errorf("failed to close tun2: %w", err) + } + + return nil + }) + + for i := 0; i < nPacketQueues; i++ { + i := i + g.Go(func() error { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + return fasttun.Splice(ctx, tun1Queues[i], tun2Queues[i], tun1Queues[i].BatchSize(), netstack.IPv6MinMTU) + }) + } + + g.Go(func() error { + // Wait for the TUN devices to be ready + time.Sleep(1 * time.Second) + + // Start iperf3 server in ns1 + server := exec.CommandContext(ctx, "ip", "netns", "exec", "fasttun-ns1", + "iperf3", "-V", "-s") + server.Stdout = os.Stdout + server.Stderr = os.Stderr + if err := server.Start(); err != nil { + return fmt.Errorf("iperf3 server start failed: %w", err) + } + defer server.Process.Kill() + + // Give server time to start + time.Sleep(1 * time.Second) + + // Run iperf3 client in ns2 + clientArgs := []string{"-V", "-C", "cubic", "-c", "fd00::1", "-t", "10"} + if parallelSupport { + clientArgs = append(clientArgs, "-P", strconv.Itoa(nPacketQueues)) + } + + cmdArgs := append([]string{"netns", "exec", "fasttun-ns2", "iperf3"}, clientArgs...) + client := exec.CommandContext(ctx, "ip", cmdArgs...) + client.Stdout = os.Stdout + client.Stderr = os.Stderr + if err := client.Run(); err != nil { + return fmt.Errorf("iperf3 client failed: %w", err) + } + + return context.Canceled // Signal completion. + }) + + if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) { + t.Fatal(err) + } +} + +func configureTun(name string, addr netip.Prefix, hostns, ns netns.NsHandle) error { + link, err := netlink.LinkByName(name) + if err != nil { + return fmt.Errorf("failed to get TUN device link: %w", err) + } + + // Move the link to the target namespace + if err := netlink.LinkSetNsFd(link, int(ns)); err != nil { + return fmt.Errorf("failed to set TUN device namespace: %w", err) + } + + // Jump to the target namespace + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + if err := netns.Set(ns); err != nil { + return fmt.Errorf("failed to set target namespace: %w", err) + } + defer netns.Set(hostns) + + nlAddr, err := netlink.ParseAddr(addr.String()) + if err != nil { + return fmt.Errorf("failed to parse TUN device address: %w", err) + } + + if err := netlink.AddrAdd(link, nlAddr); err != nil { + return fmt.Errorf("failed to add address to TUN device: %w", err) + } + + if err := netlink.LinkSetMTU(link, netstack.IPv6MinMTU); err != nil { + return fmt.Errorf("failed to set TUN device MTU: %w", err) + } + + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to set TUN device up: %w", err) + } + + return nil +} + +func iperf3Version() (major, minor int, err error) { + output, err := exec.Command("iperf3", "-v").Output() + if err != nil { + return 0, 0, fmt.Errorf("iperf3 not found or failed to run: %w", err) + } + + var version string + _, err = fmt.Sscanf(string(output), "iperf %s", &version) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse iperf3 version: %w", err) + } + + n, err := fmt.Sscanf(version, "%d.%d", &major, &minor) + if err != nil { + return 0, 0, fmt.Errorf("failed to extract major.minor version: %w", err) + } + if n != 2 { + return 0, 0, fmt.Errorf("unexpected version format") + } + + return major, minor, nil +} diff --git a/pkg/tunnel/fasttun/fasttun_test.go b/pkg/tunnel/fasttun/fasttun_test.go new file mode 100644 index 00000000..25aa3812 --- /dev/null +++ b/pkg/tunnel/fasttun/fasttun_test.go @@ -0,0 +1,16 @@ +//go:build !linux + +package fasttun_test + +import ( + "testing" + + "github.com/apoxy-dev/apoxy-cli/pkg/utils/vm" +) + +// A stub for non-linux operating systems, when the test is compiled for the VM +// it will use the linux version of this test. +func TestLinuxDeviceThroughput(t *testing.T) { + // Run the test in a linux VM. + vm.RunTestInVM(t) +} diff --git a/pkg/tunnel/fasttun/splice.go b/pkg/tunnel/fasttun/splice.go new file mode 100644 index 00000000..540d37bf --- /dev/null +++ b/pkg/tunnel/fasttun/splice.go @@ -0,0 +1,67 @@ +package fasttun + +import ( + "context" + "errors" + "fmt" + "strings" + + "golang.org/x/sync/errgroup" +) + +// Splice copies packets bidirectionally between two PacketQueues. +func Splice(ctx context.Context, qA, qB PacketQueue, batchSize, mtu int) error { + g, ctx := errgroup.WithContext(ctx) + + copyPackets := func(src, dst PacketQueue) error { + pkts := make([][]byte, batchSize) + for i := range pkts { + pkts[i] = make([]byte, mtu) + } + sizes := make([]int, batchSize) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + n, err := src.Read(pkts, sizes) + if err != nil { + return fmt.Errorf("read error: %w", err) + } + if n == 0 { + continue + } + + toWrite := make([][]byte, n) + for i := 0; i < n; i++ { + toWrite[i] = pkts[i][:sizes[i]] + } + + written := 0 + for written < n { + m, err := dst.Write(toWrite[written:]) + if err != nil { + return fmt.Errorf("write error: %w", err) + } + written += m + } + } + } + + g.Go(func() error { + return copyPackets(qA, qB) + }) + + g.Go(func() error { + return copyPackets(qB, qA) + }) + + if err := g.Wait(); err != nil && !(errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "closed")) { + return err + } + + return nil +} diff --git a/pkg/utils/vm/cloud-config.yaml b/pkg/utils/vm/cloud-config.yaml.tmpl similarity index 67% rename from pkg/utils/vm/cloud-config.yaml rename to pkg/utils/vm/cloud-config.yaml.tmpl index 2d9e3052..b091a259 100644 --- a/pkg/utils/vm/cloud-config.yaml +++ b/pkg/utils/vm/cloud-config.yaml.tmpl @@ -6,3 +6,10 @@ users: shell: /bin/bash sudo: ALL=(ALL) NOPASSWD:ALL ssh_pwauth: true +{{- if .Packages }} +package_update: true +packages: +{{- range .Packages }} + - {{ . }} +{{- end }} +{{- end }} \ No newline at end of file diff --git a/pkg/utils/vm/vm.go b/pkg/utils/vm/vm.go index b0a56b3e..61c16cd6 100644 --- a/pkg/utils/vm/vm.go +++ b/pkg/utils/vm/vm.go @@ -12,6 +12,7 @@ import ( "path/filepath" "runtime" "testing" + "text/template" "time" _ "embed" @@ -24,8 +25,8 @@ import ( "golang.org/x/crypto/ssh" ) -//go:embed cloud-config.yaml -var userData string +//go:embed cloud-config.yaml.tmpl +var userDataTemplate string //go:embed metadata.yaml var metaData string @@ -33,10 +34,33 @@ var metaData string //go:embed network-config.yaml var networkConfig string +type Option func(*options) + +type options struct { + packages []string +} + +func defaultOptions() *options { + return &options{} +} + +// WithPackages sets the packages to install in the VM. +func WithPackages(pkgs ...string) Option { + return func(o *options) { + o.packages = append(o.packages, pkgs...) + } +} + // RunTestInVM runs the test as root inside a linux VM using QEMU. -func RunTestInVM(t *testing.T) bool { +func RunTestInVM(t *testing.T, opts ...Option) bool { t.Helper() + options := defaultOptions() + for _, o := range opts { + o(options) + } + + // Check if we are running in a VM if cpuid.CPU.VM() { // We are the child running in the VM, nothing we need to do. return true @@ -98,7 +122,23 @@ func RunTestInVM(t *testing.T) bool { t.Logf("Creating cloud-init ISO at %s...\n", cloudInitISOPath) - err = createCloudInitISO(cloudInitISOFile, userData, networkConfig, metaData) + tmpl, err := template.New("cloud-config").Parse(userDataTemplate) + if err != nil { + t.Fatalf("failed to parse cloud-config template: %v", err) + return false + } + + tmplData := map[string]any{ + "Packages": options.packages, + } + + var userData bytes.Buffer + if err := tmpl.Execute(&userData, tmplData); err != nil { + t.Fatalf("failed to execute cloud-config template: %v", err) + return false + } + + err = createCloudInitISO(cloudInitISOFile, userData.String(), networkConfig, metaData) _ = cloudInitISOFile.Close() if err != nil { t.Fatalf("failed to create cloud-init ISO: %v", err) @@ -127,7 +167,7 @@ func RunTestInVM(t *testing.T) bool { } // Launch the QEMU VM using vmtest - opts := vmtest.QemuOptions{ + qemuOpts := vmtest.QemuOptions{ Architecture: vmtest.QEMU_X86_64, OperatingSystem: vmtest.OS_LINUX, Disks: []vmtest.QemuDisk{ @@ -139,7 +179,7 @@ func RunTestInVM(t *testing.T) bool { CdRom: cloudInitISOPath, } - qemu, err := vmtest.NewQemu(&opts) + qemu, err := vmtest.NewQemu(&qemuOpts) if err != nil { t.Fatalf("failed to create QEMU instance: %v", err) } @@ -151,11 +191,13 @@ func RunTestInVM(t *testing.T) bool { return false } - t.Logf("Compiling test binary from %s...\n", testSourceFile) + testSourceDir := filepath.Dir(testSourceFile) + + t.Logf("Compiling test binary from %s...\n", testSourceDir) // Compile the test binary testBinary := filepath.Join(tempDir, "testbin") - cmd := exec.Command("go", "test", "-c", "-o", testBinary, testSourceFile) + cmd := exec.Command("go", "test", "-c", "-o", testBinary, testSourceDir) cmd.Env = append(os.Environ(), "GOOS=linux", "GOARCH=amd64") cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr