From 96b3ae4f74df8d00a98ca1e57357419425fed7d5 Mon Sep 17 00:00:00 2001
From: Damian Peckett
Date: Wed, 14 May 2025 19:01:49 +0200
Subject: [PATCH] [connect-ip] further enhance packet splice and add high performance connection.Pipe implementation

---
 go.mod                             |   1 +
 go.sum                             |   2 +
 pkg/tunnel/connection/pipe.go      | 129 ++++++++++++++++++
 pkg/tunnel/connection/pipe_test.go |  47 +++++++
 pkg/tunnel/connection/splice.go    | 203 ++++++++++++++++++++++++++---
 5 files changed, 364 insertions(+), 18 deletions(-)
 create mode 100644 pkg/tunnel/connection/pipe.go
 create mode 100644 pkg/tunnel/connection/pipe_test.go

diff --git a/go.mod b/go.mod
index 75886586..5f009d6b 100644
--- a/go.mod
+++ b/go.mod
@@ -222,6 +222,7 @@ require (
 	github.com/hashicorp/mdns v1.0.1 // indirect
 	github.com/hashicorp/memberlist v0.5.0 // indirect
 	github.com/hashicorp/vic v1.5.1-0.20190403131502-bbfe86ec9443 // indirect
+	github.com/hedzr/go-ringbuf/v2 v2.2.1 // indirect
 	github.com/iancoleman/strcase v0.3.0 // indirect
 	github.com/imdario/mergo v0.3.16 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
diff --git a/go.sum b/go.sum
index abc20321..215340f1 100644
--- a/go.sum
+++ b/go.sum
@@ -612,6 +612,8 @@ github.com/hashicorp/memberlist v0.5.0 h1:EtYPN8DpAURiapus508I4n9CzHs2W+8NZGbmmR
 github.com/hashicorp/memberlist v0.5.0/go.mod h1:yvyXLpo0QaGE59Y7hDTsTzDD25JYBZ4mHgHUZ8lrOI0=
 github.com/hashicorp/vic v1.5.1-0.20190403131502-bbfe86ec9443 h1:O/pT5C1Q3mVXMyuqg7yuAWUg/jMZR1/0QTzTRdNR6Uw=
 github.com/hashicorp/vic v1.5.1-0.20190403131502-bbfe86ec9443/go.mod h1:bEpDU35nTu0ey1EXjwNwPjI9xErAsoOCmcMb9GKvyxo=
+github.com/hedzr/go-ringbuf/v2 v2.2.1 h1:bnIRxSCWYt4vs5UCDCOYf+r1C8cQC7tkcOdjOTaVzNk=
+github.com/hedzr/go-ringbuf/v2 v2.2.1/go.mod h1:N3HsRpbHvPkX9GsykpkPoR2vD6WRR6GbU7tx/9GLE4M=
 github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
 github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI=
 github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho=
diff --git a/pkg/tunnel/connection/pipe.go b/pkg/tunnel/connection/pipe.go
new file mode 100644
index 00000000..aff31796
--- /dev/null
+++ b/pkg/tunnel/connection/pipe.go
@@ -0,0 +1,129 @@
+package connection
+
+import (
+	"context"
+	"errors"
+	"runtime"
+	"sync"
+
+	"github.com/hedzr/go-ringbuf/v2"
+	"github.com/hedzr/go-ringbuf/v2/mpmc"
+)
+
+var _ Connection = (*Pipe)(nil)
+
+type Pipe struct {
+	readRing  mpmc.RingBuffer[[]byte]
+	writeRing mpmc.RingBuffer[[]byte]
+	ctx       context.Context
+	cancel    context.CancelFunc
+}
+
+var bufPool = sync.Pool{
+	New: func() interface{} {
+		b := make([]byte, 2048) // Adjust size if needed
+		return &b
+	},
+}
+
+// NewPipe creates a pair of connected Pipe instances for bidirectional communication.
+func NewPipe(ctx context.Context) (*Pipe, *Pipe) {
+	ringAtoB := ringbuf.New[[]byte](1024)
+	ringBtoA := ringbuf.New[[]byte](1024)
+
+	ctx, cancel := context.WithCancel(ctx)
+
+	pipeA := &Pipe{
+		readRing:  ringBtoA,
+		writeRing: ringAtoB,
+		ctx:       ctx,
+		cancel:    cancel,
+	}
+	pipeB := &Pipe{
+		readRing:  ringAtoB,
+		writeRing: ringBtoA,
+		ctx:       ctx,
+		cancel:    cancel,
+	}
+
+	return pipeA, pipeB
+}
+
+// ReadPacket reads a packet into the provided buffer.
+func (p *Pipe) ReadPacket(buf []byte) (int, error) {
+	select {
+	case <-p.ctx.Done():
+		return 0, errors.New("pipe closed")
+	default:
+		var item []byte
+		var err error
+		for {
+			item, err = p.readRing.Dequeue()
+			if err != nil {
+				if errors.Is(err, mpmc.ErrQueueEmpty) {
+					runtime.Gosched()
+
+					// Has the context been cancelled?
+					select {
+					case <-p.ctx.Done():
+						// Nothing was dequeued, so there is no buffer to recycle.
+						return 0, errors.New("pipe closed")
+					default:
+						// Continue to try to dequeue
+						continue
+					}
+				}
+				return 0, err
+			}
+			break
+		}
+		n := copy(buf, item)
+		bufPool.Put(&item)
+		return n, nil
+	}
+}
+
+// WritePacket writes a packet from the provided buffer.
+func (p *Pipe) WritePacket(b []byte) ([]byte, error) {
+	select {
+	case <-p.ctx.Done():
+		return nil, errors.New("pipe closed")
+	default:
+		bufPtr := bufPool.Get().(*[]byte)
+		buf := *bufPtr
+		if cap(buf) < len(b) {
+			buf = make([]byte, len(b))
+		}
+		buf = buf[:len(b)]
+		copy(buf, b)
+
+		for {
+			err := p.writeRing.Enqueue(buf)
+			if err != nil {
+				if errors.Is(err, mpmc.ErrQueueFull) {
+					runtime.Gosched()
+
+					// Has the context been cancelled?
+					select {
+					case <-p.ctx.Done():
+						bufPool.Put(&buf)
+						return nil, errors.New("pipe closed")
+					default:
+						// Continue to try to enqueue
+						continue
+					}
+				}
+				return nil, err
+			}
+			break
+		}
+
+		return nil, nil
+	}
+}
+
+// Close terminates the pipe.
+func (p *Pipe) Close() error {
+	p.cancel()
+	return nil
+}
diff --git a/pkg/tunnel/connection/pipe_test.go b/pkg/tunnel/connection/pipe_test.go
new file mode 100644
index 00000000..8e90f2d5
--- /dev/null
+++ b/pkg/tunnel/connection/pipe_test.go
@@ -0,0 +1,47 @@
+package connection_test
+
+import (
+	"bytes"
+	"testing"
+	"time"
+
+	"github.com/apoxy-dev/apoxy-cli/pkg/tunnel/connection"
+)
+
+func TestPipeThroughput(t *testing.T) {
+	const (
+		packetSize = 1024      // 1 KB per packet
+		numPackets = 1_000_000 // Total packets to send
+	)
+
+	p1, p2 := connection.NewPipe(t.Context())
+
+	payload := bytes.Repeat([]byte("X"), packetSize)
+	buf := make([]byte, packetSize)
+
+	done := make(chan struct{})
+
+	go func() {
+		defer close(done)
+		for i := 0; i < numPackets; i++ {
+			if _, err := p2.ReadPacket(buf); err != nil {
+				t.Error(err) // t.Fatal must not be called from a non-test goroutine.
+			}
+		}
+	}()
+
+	start := time.Now()
+
+	for i := 0; i < numPackets; i++ {
+		if _, err := p1.WritePacket(payload); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	<-done
+	duration := time.Since(start)
+
+	throughputGbps := (float64(packetSize*numPackets*8) / 1e9) / duration.Seconds()
+	t.Logf("Sent %d packets of %d bytes in %s", numPackets, packetSize, duration)
+	t.Logf("Throughput: %.2f Gbps", throughputGbps)
+}
diff --git a/pkg/tunnel/connection/splice.go b/pkg/tunnel/connection/splice.go
index f274042f..d220fa23 100644
--- a/pkg/tunnel/connection/splice.go
+++ b/pkg/tunnel/connection/splice.go
@@ -6,10 +6,12 @@ import (
 	"log/slog"
 	"net"
 	"strings"
+	"sync"
 
 	"golang.org/x/sync/errgroup"
 	"golang.zx2c4.com/wireguard/device"
 	"golang.zx2c4.com/wireguard/tun"
+	"k8s.io/utils/ptr"
 
 	"github.com/apoxy-dev/apoxy-cli/pkg/netstack"
 )
@@ -21,9 +23,15 @@ const (
 func Splice(tunDev tun.Device, conn Connection) error {
 	var g errgroup.Group
 
+	stats := newSpliceStats()
+
 	batchSize := tunDev.BatchSize()
 
 	g.Go(func() error {
+		defer func() {
+			slog.Debug("Stopped reading from TUN")
+		}()
+
 		defer conn.Close()
 
 		sizes := make([]int, batchSize)
@@ -35,29 +43,44 @@ func Splice(tunDev tun.Device, conn Connection) error {
 		for {
 			n, err := tunDev.Read(pkts, sizes, 0)
 			if err != nil {
+				if strings.Contains(err.Error(), "closed") {
+					slog.Debug("TUN device closed")
+					return net.ErrClosed
+				}
+
 				if errors.Is(err, tun.ErrTooManySegments) {
 					slog.Warn("Dropped packets from multi-segment TUN read", slog.Any("error", err))
 					continue
 				}
-				if strings.Contains(err.Error(), "closed") {
-					slog.Debug("TUN device closed")
-					return nil
-				}
+
 				return fmt.Errorf("failed to read from TUN: %w", err)
 			}
 
+			stats.recordReadBatch(n)
+
 			for i := 0; i < n; i++ {
 				slog.Debug("Read packet from TUN", slog.Int("len", sizes[i]))
 
 				icmp, err := conn.WritePacket(pkts[i][:sizes[i]])
 				if err != nil {
+					if strings.Contains(err.Error(), "closed") {
+						slog.Debug("Connection closed")
+						return net.ErrClosed
+					}
+
 					slog.Error("Failed to write to connection", slog.Any("error", err))
-					continue
+					return fmt.Errorf("failed to write to connection: %w", err)
 				}
 
 				if len(icmp) > 0 {
 					slog.Debug("Sending ICMP packet")
 					if _, err := tunDev.Write([][]byte{icmp}, 0); err != nil {
+						if strings.Contains(err.Error(), "closed") {
+							slog.Debug("TUN device closed")
+							return net.ErrClosed
+						}
+
 						slog.Error("Failed to write ICMP packet", slog.Any("error", err))
+						return fmt.Errorf("failed to write ICMP packet: %w", err)
 					}
 				}
 			}
@@ -65,25 +88,85 @@ func Splice(tunDev tun.Device, conn Connection) error {
 	})
 
 	g.Go(func() error {
-		pkts := make([][]byte, batchSize)
-		for i := range pkts {
-			pkts[i] = make([]byte, netstack.IPv6MinMTU+tunOffset)
+		defer func() {
+			slog.Debug("Stopped reading from connection")
+		}()
+
+		var pktPool = sync.Pool{
+			New: func() any {
+				return ptr.To(make([]byte, netstack.IPv6MinMTU+tunOffset))
+			},
 		}
 
-		// TODO: batched write to TUN device, unfortunately ReadPacket() is blocking
-		// and not batched which makes this tricky.
+		pktCh := make(chan *[]byte, batchSize)
 
-		for {
-			n, err := conn.ReadPacket(pkts[0][tunOffset:])
-			if err != nil {
-				return fmt.Errorf("failed to read from connection: %w", err)
+		g.Go(func() error {
+			defer close(pktCh)
+
+			for {
+				pkt := pktPool.Get().(*[]byte)
+				n, err := conn.ReadPacket((*pkt)[tunOffset:])
+				if err != nil {
+					if strings.Contains(err.Error(), "closed") {
+						slog.Debug("Connection closed")
+						return net.ErrClosed
+					}
+
+					slog.Error("Failed to read from connection", slog.Any("error", err))
+					return fmt.Errorf("failed to read from connection: %w", err)
+				}
+
+				slog.Debug("Read packet from connection", slog.Int("len", n))
+
+				*pkt = (*pkt)[:n+tunOffset]
+				pktCh <- pkt
 			}
+		})
+
+		pkts := make([][]byte, batchSize)
+
+		for {
+			select {
+			case pkt, ok := <-pktCh:
+				if !ok {
+					return nil
+				}
+
+				pkts[0] = *pkt
+				batchCount := 1
 
-			slog.Debug("Read from connection", slog.Int("bytes", n))
+				closed := false
+			gatherBatch:
+				for batchCount < batchSize && !closed {
+					select {
+					case pkt, ok := <-pktCh:
+						if !ok {
+							closed = true
+							break
+						}
+						pkts[batchCount] = *pkt
+						batchCount++
+					default:
+						break gatherBatch
+					}
+				}
+
+				stats.recordWriteBatch(batchCount)
+
+				if _, err := tunDev.Write(pkts[:batchCount], tunOffset); err != nil {
+					if strings.Contains(err.Error(), "closed") {
+						slog.Debug("TUN device closed")
+						return net.ErrClosed
+					}
 
-			if _, err := tunDev.Write([][]byte{pkts[0][:n+tunOffset]}, tunOffset); err != nil {
-				slog.Error("Failed to write to TUN", slog.Any("error", err))
-				continue
+					slog.Error("Failed to write to TUN", slog.Any("error", err))
+					return fmt.Errorf("failed to write to TUN: %w", err)
+				}
+
+				for i := 0; i < batchCount; i++ {
+					pkt := pkts[i][:cap(pkts[i])]
+					pktPool.Put(&pkt)
+				}
 			}
 		}
 	})
@@ -92,5 +175,89 @@ func Splice(tunDev tun.Device, conn Connection) error {
 		return fmt.Errorf("failed to splice: %w", err)
 	}
 
+	name, _ := tunDev.Name()
+
+	slog.Debug("Splice completed",
slog.String("name", name), + slog.Int("batch_size", batchSize), + slog.Any("read_summary", stats.readSummary()), + slog.Any("write_summary", stats.writeSummary()), + ) + return nil } + +type spliceStats struct { + mu sync.Mutex + readBatchSizes map[int]int + writeBatchSizes map[int]int +} + +func newSpliceStats() *spliceStats { + return &spliceStats{ + readBatchSizes: make(map[int]int), + writeBatchSizes: make(map[int]int), + } +} + +func (s *spliceStats) recordReadBatch(n int) { + s.mu.Lock() + defer s.mu.Unlock() + s.readBatchSizes[n]++ +} + +func (s *spliceStats) recordWriteBatch(n int) { + s.mu.Lock() + defer s.mu.Unlock() + s.writeBatchSizes[n]++ +} + +func (s *spliceStats) readSummary() batchSummary { + s.mu.Lock() + defer s.mu.Unlock() + return computeSummary(s.readBatchSizes) +} + +func (s *spliceStats) writeSummary() batchSummary { + s.mu.Lock() + defer s.mu.Unlock() + return computeSummary(s.writeBatchSizes) +} + +type batchSummary struct { + TotalBatches int + MinSize int + MaxSize int + AvgSize float64 +} + +func computeSummary(hist map[int]int) batchSummary { + if len(hist) == 0 { + return batchSummary{} + } + + var ( + totalCount int + totalSize int + minSize = int(^uint(0) >> 1) // Max int + maxSize int + ) + + for size, count := range hist { + if size < minSize { + minSize = size + } + if size > maxSize { + maxSize = size + } + totalCount += count + totalSize += size * count + } + + return batchSummary{ + TotalBatches: totalCount, + MinSize: minSize, + MaxSize: maxSize, + AvgSize: float64(totalSize) / float64(totalCount), + } +}