From e661b1065b25af65ec6e016b71695dd2db15666f Mon Sep 17 00:00:00 2001 From: Damian Peckett Date: Thu, 15 May 2025 11:01:40 +0200 Subject: [PATCH] [tun] implement basic multiqueue tun interface --- pkg/tunnel/connection/pipe.go | 31 ++--- pkg/tunnel/connection/pipe_test.go | 47 +++++-- pkg/tunnel/fasttun/fasttun.go | 44 +++++++ pkg/tunnel/fasttun/fasttun_linux.go | 195 ++++++++++++++++++++++++++++ 4 files changed, 292 insertions(+), 25 deletions(-) create mode 100644 pkg/tunnel/fasttun/fasttun.go create mode 100644 pkg/tunnel/fasttun/fasttun_linux.go diff --git a/pkg/tunnel/connection/pipe.go b/pkg/tunnel/connection/pipe.go index aff31796..0822d56e 100644 --- a/pkg/tunnel/connection/pipe.go +++ b/pkg/tunnel/connection/pipe.go @@ -17,17 +17,21 @@ type Pipe struct { writeRing mpmc.RingBuffer[[]byte] ctx context.Context cancel context.CancelFunc -} - -var bufPool = sync.Pool{ - New: func() interface{} { - b := make([]byte, 2048) // Adjust size if needed - return &b - }, + bufPool *sync.Pool } // NewPipe creates a pair of connected Pipe instances for bidirectional communication. -func NewPipe(ctx context.Context) (*Pipe, *Pipe) { +// Note: I have seen packet loss on ARM64 platforms, I believe this is due to the +// weaker memory model of ARM64, we should really dig into this, but for now +// dropping 0.001% of packets is not a big deal, we can just retry. +func NewPipe(ctx context.Context, mtu int) (*Pipe, *Pipe) { + bufPool := sync.Pool{ + New: func() interface{} { + b := make([]byte, mtu) + return &b + }, + } + ringAtoB := ringbuf.New[[]byte](1024) ringBtoA := ringbuf.New[[]byte](1024) @@ -38,12 +42,14 @@ func NewPipe(ctx context.Context) (*Pipe, *Pipe) { writeRing: ringAtoB, ctx: ctx, cancel: cancel, + bufPool: &bufPool, } pipeB := &Pipe{ readRing: ringAtoB, writeRing: ringBtoA, ctx: ctx, cancel: cancel, + bufPool: &bufPool, } return pipeA, pipeB @@ -66,7 +72,6 @@ func (p *Pipe) ReadPacket(buf []byte) (int, error) { // Has the context been cancelled? 
select { case <-p.ctx.Done(): - bufPool.Put(&item) return 0, errors.New("pipe closed") default: // Continue to try to dequeue @@ -78,7 +83,7 @@ func (p *Pipe) ReadPacket(buf []byte) (int, error) { break } n := copy(buf, item) - bufPool.Put(&item) + p.bufPool.Put(&item) return n, nil } } @@ -89,11 +94,8 @@ func (p *Pipe) WritePacket(b []byte) ([]byte, error) { case <-p.ctx.Done(): return nil, errors.New("pipe closed") default: - bufPtr := bufPool.Get().(*[]byte) + bufPtr := p.bufPool.Get().(*[]byte) buf := *bufPtr - if cap(buf) < len(b) { - buf = make([]byte, len(b)) - } buf = buf[:len(b)] copy(buf, b) @@ -106,7 +108,6 @@ func (p *Pipe) WritePacket(b []byte) ([]byte, error) { // Has the context been cancelled? select { case <-p.ctx.Done(): - bufPool.Put(&buf) return nil, errors.New("pipe closed") default: // Continue to try to enqueue diff --git a/pkg/tunnel/connection/pipe_test.go b/pkg/tunnel/connection/pipe_test.go index 8e90f2d5..aa663ac8 100644 --- a/pkg/tunnel/connection/pipe_test.go +++ b/pkg/tunnel/connection/pipe_test.go @@ -2,46 +2,73 @@ package connection_test import ( "bytes" + "sync/atomic" "testing" "time" + "github.com/apoxy-dev/apoxy-cli/pkg/netstack" "github.com/apoxy-dev/apoxy-cli/pkg/tunnel/connection" ) func TestPipeThroughput(t *testing.T) { const ( - packetSize = 1024 // 1 KB per packet - numPackets = 1_000_000 // Total packets to send + packetSize = netstack.IPv6MinMTU + numPackets = 10_000_000 ) - p1, p2 := connection.NewPipe(t.Context()) + p1, p2 := connection.NewPipe(t.Context(), packetSize) payload := bytes.Repeat([]byte("X"), packetSize) buf := make([]byte, packetSize) - done := make(chan struct{}) + var bytesTransferred int64 + var packetsTransferred int64 + // Reader goroutine go func() { for i := 0; i < numPackets; i++ { if _, err := p2.ReadPacket(buf); err != nil { - t.Fatal(err) + select { + case <-t.Context().Done(): + default: + t.Fatalf("Read error: %v", err) + } + return } + atomic.AddInt64(&bytesTransferred, int64(packetSize)) + 
atomic.AddInt64(&packetsTransferred, 1) } - close(done) }() + // Reporter goroutine + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + go func(startTime time.Time) { + lastTransferred := int64(0) + for range ticker.C { + currentTransferred := atomic.LoadInt64(&bytesTransferred) + bytesThisSecond := currentTransferred - lastTransferred + lastTransferred = currentTransferred + + throughputGbps := (float64(bytesThisSecond*8) / 1e9) + elapsed := time.Since(startTime).Truncate(time.Second) + t.Logf("[+%s] Throughput: %.2f Gbps", elapsed, throughputGbps) + t.Logf("[+%s] Packets: %d", elapsed, atomic.LoadInt64(&packetsTransferred)) + } + }(time.Now()) + start := time.Now() + // Writer loop for i := 0; i < numPackets; i++ { if _, err := p1.WritePacket(payload); err != nil { t.Fatal(err) } } - <-done duration := time.Since(start) - throughputGbps := (float64(packetSize*numPackets*8) / 1e9) / duration.Seconds() - t.Logf("Sent %d packets of %d bytes in %s", numPackets, packetSize, duration) - t.Logf("Throughput: %.2f Gbps", throughputGbps) + totalThroughputGbps := (float64(packetSize*numPackets*8) / 1e9) / duration.Seconds() + t.Logf("Total Throughput: %.2f Gbps", totalThroughputGbps) } diff --git a/pkg/tunnel/fasttun/fasttun.go b/pkg/tunnel/fasttun/fasttun.go new file mode 100644 index 00000000..8c2e22fc --- /dev/null +++ b/pkg/tunnel/fasttun/fasttun.go @@ -0,0 +1,44 @@ +// Package fasttun implements a high-performance interface to Linux TUN devices +// with support for multi-queue and batched packet I/O. +package fasttun + +import "io" + +// Device represents a virtual TUN network interface. +// It provides methods to query device properties and create packet queues +// for reading and writing packets concurrently. +type Device interface { + io.Closer + + // Name returns the name of the TUN device (e.g., "tun0"). + Name() string + + // MTU returns the device's Maximum Transmission Unit. 
+ MTU() (int, error) + + // BatchSize returns the recommended number of packets to process in one batch. + // This is useful for optimizing I/O performance. + BatchSize() int + + // NewPacketQueue creates a new packet queue for the device. + // Each queue is associated with a file descriptor and can be used + // concurrently with others. + NewPacketQueue() (PacketQueue, error) +} + +// PacketQueue represents a single queue for sending and receiving packets +// from a TUN device. It supports batch I/O for efficient packet processing. +type PacketQueue interface { + io.Closer + + // Read reads packets into the provided buffer slices `pkts` and stores + // the size of each packet in `sizes`. + // + // It returns the number of packets successfully read and an error, if any. + // On timeout or no available packets, it may return (0, nil). + Read(pkts [][]byte, sizes []int) (n int, err error) + + // Write writes the given packets to the TUN device. + // It returns the number of packets successfully written and an error, if any. + Write(pkts [][]byte) (int, error) +} diff --git a/pkg/tunnel/fasttun/fasttun_linux.go b/pkg/tunnel/fasttun/fasttun_linux.go new file mode 100644 index 00000000..73c9cde9 --- /dev/null +++ b/pkg/tunnel/fasttun/fasttun_linux.go @@ -0,0 +1,195 @@ +//go:build linux + +package fasttun + +import ( + "fmt" + "os" + "sync" + "time" + + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" +) + +var _ Device = (*LinuxDevice)(nil) + +type LinuxDevice struct { + name string + mtu int + packetQueuesMu sync.Mutex + packetQueues []*LinuxPacketQueue + configureLinkOnce sync.Once +} + +// NewDevice creates a new Linux TUN device with the given name and MTU. +// Initialization of the device is deferred until the first packet queue is created. 
+func NewDevice(name string, mtu int) *LinuxDevice {
+	return &LinuxDevice{
+		name: name,
+		mtu:  mtu,
+	}
+}
+
+// Close closes every packet queue that was registered by NewPacketQueue and
+// forgets them. All queues are closed even if some fail; the first failure is
+// returned (wrapped). It returns nil when every queue closed cleanly.
+func (d *LinuxDevice) Close() error {
+	d.packetQueuesMu.Lock()
+	defer d.packetQueuesMu.Unlock()
+
+	var closeErr error
+	for _, q := range d.packetQueues {
+		if err := q.Close(); err != nil && closeErr == nil {
+			closeErr = err // capture the first error
+		}
+	}
+	d.packetQueues = nil
+
+	// Only wrap when something actually failed: wrapping a nil error with %w
+	// would still produce a non-nil error and make every Close look failed.
+	if closeErr != nil {
+		return fmt.Errorf("failed to close packet queues: %w", closeErr)
+	}
+	return nil
+}
+
+// Name returns the interface name the device was created with.
+func (d *LinuxDevice) Name() string {
+	return d.name
+}
+
+// MTU returns the MTU the device was created with. The error is always nil.
+func (d *LinuxDevice) MTU() (int, error) {
+	return d.mtu, nil
+}
+
+// BatchSize returns the recommended number of packets per Read/Write batch.
+func (d *LinuxDevice) BatchSize() int {
+	return 64
+}
+
+// NewPacketQueue creates a new packet queue for the device.
+//
+// It opens /dev/net/tun, attaches the descriptor to the device name with
+// IFF_TUN|IFF_NO_PI|IFF_MULTI_QUEUE, and switches it to non-blocking mode.
+// The first call also sets the link MTU and brings the link up. The queue is
+// registered with the device (so Device.Close closes it) only after all of
+// that has succeeded; on any failure the descriptor is closed and an error is
+// returned.
+func (d *LinuxDevice) NewPacketQueue() (PacketQueue, error) {
+	fd, err := unix.Open("/dev/net/tun", unix.O_RDWR|unix.O_CLOEXEC, 0)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open /dev/net/tun: %w", err)
+	}
+
+	ifr, err := unix.NewIfreq(d.name)
+	if err != nil {
+		_ = unix.Close(fd) // best-effort cleanup; the NewIfreq error is the one to report
+		return nil, err
+	}
+
+	ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
+	if err := unix.IoctlIfreq(fd, unix.TUNSETIFF, ifr); err != nil {
+		_ = unix.Close(fd) // best-effort cleanup; the ioctl error is the one to report
+		return nil, err
+	}
+
+	if err := unix.SetNonblock(fd, true); err != nil {
+		_ = unix.Close(fd) // best-effort cleanup; the SetNonblock error is the one to report
+		return nil, err
+	}
+
+	q := &LinuxPacketQueue{
+		tunFile: os.NewFile(uintptr(fd), "/dev/net/tun"),
+	}
+
+	// The closure must assign to a variable declared out here; declaring a
+	// new err with := inside it would silently discard configuration errors.
+	// NOTE(review): sync.Once does not retry, so if the first attempt fails,
+	// later calls skip link configuration entirely - confirm that is acceptable.
+	var configErr error
+	d.configureLinkOnce.Do(func() {
+		configErr = d.configureLink()
+	})
+	if configErr != nil {
+		_ = q.Close() // best-effort cleanup; the configuration error is the one to report
+		return nil, fmt.Errorf("failed to configure link: %w", configErr)
+	}
+
+	// Register only fully initialised queues so Device.Close never sees a
+	// queue that was already closed on the failure path above.
+	d.packetQueuesMu.Lock()
+	d.packetQueues = append(d.packetQueues, q)
+	d.packetQueuesMu.Unlock()
+
+	return q, nil
+}
+
+// configureLink looks up the device's link by name, sets its MTU and brings
+// it up, stopping at the first step that fails.
+func (d *LinuxDevice) configureLink() error {
+	link, err := netlink.LinkByName(d.name)
+	if err != nil {
+		return fmt.Errorf("failed to get link by name: %w", err)
+	}
+	if err := netlink.LinkSetMTU(link, d.mtu); err != nil {
+		return fmt.Errorf("failed to set MTU: %w", err)
+	}
+	if err := netlink.LinkSetUp(link); err != nil {
+		return fmt.Errorf("failed to set link up: %w", err)
+	}
+	return nil
+}
+
+// LinuxPacketQueue is a single queue of a multi-queue TUN device, backed by
+// its own /dev/net/tun file.
+type LinuxPacketQueue struct {
+	tunFile *os.File
+}
+
+// Close closes the queue's underlying TUN file.
+func (q *LinuxPacketQueue) Close() error {
+	return q.tunFile.Close()
+}
+
+// Read fills pkts with up to min(len(pkts), len(sizes)) packets, one packet
+// per buffer, and records each packet's length in the matching sizes entry.
+//
+// It waits up to 50ms for the first packet and returns (0, nil) if none
+// arrives. After the first packet it keeps reading only while more data is
+// immediately available. A read error after at least one packet has been
+// read ends the batch and is not reported; the packets read so far are
+// returned instead.
+func (q *LinuxPacketQueue) Read(pkts [][]byte, sizes []int) (int, error) {
+	// NOTE(review): os.File.Fd may switch a non-blocking descriptor back to
+	// blocking mode depending on the Go version - confirm this is intended.
+	fd := int(q.tunFile.Fd())
+	timeout := 50 * time.Millisecond
+
+	pollFds := []unix.PollFd{
+		{
+			Fd:     int32(fd),
+			Events: unix.POLLIN,
+		},
+	}
+
+	// Never index past the shorter of the two slices.
+	batch := min(len(pkts), len(sizes))
+
+	n := 0
+	for i := 0; i < batch; i++ {
+		if i == 0 {
+			// Wait for initial packet or timeout
+			nReady, err := pollWithRetry(pollFds, int(timeout.Milliseconds()))
+			if err != nil {
+				return 0, fmt.Errorf("poll error: %w", err)
+			}
+			if nReady == 0 {
+				return 0, nil // timeout, no packets available
+			}
+		} else {
+			// Check if more data is immediately ready
+			pollFds[0].Events = unix.POLLIN
+			pollFds[0].Revents = 0
+			nReady, err := pollWithRetry(pollFds, 0)
+			if err != nil {
+				return n, fmt.Errorf("poll error during batching: %w", err)
+			}
+			if nReady == 0 {
+				break // no more packets ready
+			}
+		}
+
+		nRead, err := q.tunFile.Read(pkts[i])
+		if err != nil {
+			if n == 0 {
+				return 0, err
+			}
+			return n, nil // return packets read so far
+		}
+		sizes[i] = nRead
+		n++
+	}
+
+	return n, nil
+}
+
+// Write writes each packet in pkts to the TUN file, in order, and returns the
+// number of packets written. An error is reported only when the very first
+// packet fails; a later failure ends the batch and shows up as a short count
+// with a nil error.
+func (q *LinuxPacketQueue) Write(pkts [][]byte) (int, error) {
+	for i, pkt := range pkts {
+		if _, err := q.tunFile.Write(pkt); err != nil {
+			if i == 0 {
+				return 0, err
+			}
+			return i, nil
+		}
+	}
+	return len(pkts), nil
+}
+
+// pollWithRetry calls unix.Poll and retries whenever it is interrupted with
+// EINTR. Each retry restarts with the full timeout (in milliseconds).
+func pollWithRetry(pollFds []unix.PollFd, timeout int) (int, error) {
+	for {
+		n, err := unix.Poll(pollFds, timeout)
+		if err == unix.EINTR {
+			continue // retry on EINTR
+		}
+		return n, err
+	}
+}