Skip to content

Commit e28d330

Browse files
dilyevsky and claude committed
[tunnel] Add multi-ping edge selection with outlier removal
Enhance the tunnel endpoint latency selector to run multiple pings per endpoint and discard outliers before selecting the best edge.

- Add pingsPerEndpoint option (default: 3) to LatencySelector
- Refactor probe() to reuse QUIC connection across multiple pings
- Implement trimmed mean aggregation (discards high/low values)
- Add comprehensive tests for aggregateLatencies function

This reduces susceptibility to noisy measurements from transient network conditions when selecting tunnel endpoints.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 6ab2f66 commit e28d330

File tree

3 files changed

+203
-41
lines changed

3 files changed

+203
-41
lines changed

pkg/cmd/tunnel/run.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,9 @@ var tunnelRunCmd = &cobra.Command{
149149
if err != nil {
150150
return fmt.Errorf("unable to parse endpoint selection strategy: %w", err)
151151
}
152-
selectorOpts := []endpointselect.Option{}
152+
selectorOpts := []endpointselect.Option{
153+
endpointselect.WithPingsPerEndpoint(3),
154+
}
153155
if insecureSkipVerify {
154156
selectorOpts = append(selectorOpts, endpointselect.WithInsecureSkipVerify(true))
155157
}

pkg/tunnel/endpointselect/latency.go

Lines changed: 122 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,18 @@ const (
2121
DefaultProbeTimeout = 3 * time.Second
2222
// DefaultMaxConcurrent is the default maximum number of concurrent probes.
2323
DefaultMaxConcurrent = 10
24+
// DefaultPingsPerEndpoint is the default number of ping requests per endpoint.
25+
DefaultPingsPerEndpoint = 3
2426
)
2527

2628
// Option configures a LatencySelector.
2729
type Option func(*latencyOptions)
2830

2931
type latencyOptions struct {
30-
probeTimeout time.Duration
31-
maxConcurrent int
32-
insecureSkip bool
32+
probeTimeout time.Duration
33+
maxConcurrent int
34+
insecureSkip bool
35+
pingsPerEndpoint int
3336
}
3437

3538
// WithProbeTimeout sets the timeout for each endpoint probe.
@@ -53,6 +56,17 @@ func WithInsecureSkipVerify(skip bool) Option {
5356
}
5457
}
5558

59+
// WithPingsPerEndpoint sets the number of ping requests per endpoint.
60+
// The latencies are aggregated using a trimmed mean (removing outliers).
61+
func WithPingsPerEndpoint(n int) Option {
62+
return func(o *latencyOptions) {
63+
if n < 1 {
64+
n = 1
65+
}
66+
o.pingsPerEndpoint = n
67+
}
68+
}
69+
5670
// LatencySelector selects endpoints based on QUIC handshake latency.
5771
type LatencySelector struct {
5872
opts latencyOptions
@@ -61,8 +75,9 @@ type LatencySelector struct {
6175
// NewLatencySelector creates a new LatencySelector.
6276
func NewLatencySelector(opts ...Option) *LatencySelector {
6377
options := latencyOptions{
64-
probeTimeout: DefaultProbeTimeout,
65-
maxConcurrent: DefaultMaxConcurrent,
78+
probeTimeout: DefaultProbeTimeout,
79+
maxConcurrent: DefaultMaxConcurrent,
80+
pingsPerEndpoint: DefaultPingsPerEndpoint,
6681
}
6782
for _, opt := range opts {
6883
opt(&options)
@@ -153,17 +168,9 @@ func (s *LatencySelector) probeAll(ctx context.Context, endpoints []string) []Pr
153168
return results
154169
}
155170

156-
// probe measures the round-trip latency to a single endpoint by making
157-
// an HTTP/3 request to the /ping endpoint.
158-
func (s *LatencySelector) probe(ctx context.Context, addr string) ProbeResult {
159-
result := ProbeResult{
160-
Addr: addr,
161-
ProbedAt: time.Now(),
162-
}
163-
164-
probeCtx, cancel := context.WithTimeout(ctx, s.opts.probeTimeout)
165-
defer cancel()
166-
171+
// dialEndpoint establishes a QUIC connection and creates an HTTP/3 client connection.
172+
// Returns the QUIC connection, HTTP/3 client connection, and any error.
173+
func (s *LatencySelector) dialEndpoint(ctx context.Context, addr string) (quic.Connection, *http3.ClientConn, error) {
167174
// Extract hostname from address for TLS ServerName.
168175
serverName := "proxy"
169176
if host, _, err := net.SplitHostPort(addr); err == nil && net.ParseIP(host) == nil {
@@ -181,55 +188,130 @@ func (s *LatencySelector) probe(ctx context.Context, addr string) ProbeResult {
181188
InitialPacketSize: 1350,
182189
}
183190

184-
start := time.Now()
185-
186191
// Dial QUIC connection.
187-
qConn, err := quic.DialAddr(probeCtx, addr, tlsConfig, quicConfig)
192+
qConn, err := quic.DialAddr(ctx, addr, tlsConfig, quicConfig)
188193
if err != nil {
189-
result.Error = err
190-
slog.Debug("Endpoint probe failed (QUIC dial)",
191-
slog.String("addr", addr),
192-
slog.Any("error", err))
193-
return result
194+
return nil, nil, err
194195
}
195-
defer qConn.CloseWithError(0, "probe complete")
196196

197-
// Make HTTP/3 request to /ping endpoint.
197+
// Create HTTP/3 client connection.
198198
tr := &http3.Transport{EnableDatagrams: true}
199199
hConn := tr.NewClientConn(qConn)
200200

201-
req, err := http.NewRequestWithContext(probeCtx, "GET", "https://proxy/ping", nil)
201+
return qConn, hConn, nil
202+
}
203+
204+
// pingSingle performs a single HTTP/3 GET /ping request over an existing connection.
205+
// Returns the round-trip latency or an error.
206+
func (s *LatencySelector) pingSingle(ctx context.Context, hConn *http3.ClientConn) (time.Duration, error) {
207+
req, err := http.NewRequestWithContext(ctx, "GET", "https://proxy/ping", nil)
202208
if err != nil {
203-
result.Error = err
204-
slog.Debug("Endpoint probe failed (request creation)",
205-
slog.String("addr", addr),
206-
slog.Any("error", err))
207-
return result
209+
return 0, err
208210
}
209211

212+
start := time.Now()
210213
resp, err := hConn.RoundTrip(req)
214+
if err != nil {
215+
return 0, err
216+
}
217+
defer resp.Body.Close()
218+
219+
if resp.StatusCode != http.StatusOK {
220+
return 0, fmt.Errorf("ping returned status %d", resp.StatusCode)
221+
}
222+
223+
return time.Since(start), nil
224+
}
225+
226+
// aggregateLatencies reduces a set of latency samples to a single value.
//
//   - no samples:  0
//   - one sample:  that sample
//   - two samples: their average
//   - three or more: the mean of what remains after dropping the single
//     lowest and single highest sample (a trimmed mean)
//
// The input slice is not modified.
func aggregateLatencies(pings []time.Duration) time.Duration {
	switch len(pings) {
	case 0:
		return 0
	case 1:
		return pings[0]
	case 2:
		// Order is irrelevant for the average of two values.
		return (pings[0] + pings[1]) / 2
	}

	// Work on a private, ascending copy so the caller's slice stays untouched.
	ordered := append([]time.Duration(nil), pings...)
	sort.Slice(ordered, func(a, b int) bool {
		return ordered[a] < ordered[b]
	})

	// Skip index 0 (minimum) and the last index (maximum).
	var total time.Duration
	for idx := 1; idx < len(ordered)-1; idx++ {
		total += ordered[idx]
	}
	return total / time.Duration(len(ordered)-2)
}
257+
258+
// probe measures the round-trip latency to a single endpoint by making
259+
// multiple HTTP/3 requests to the /ping endpoint and aggregating the results.
260+
func (s *LatencySelector) probe(ctx context.Context, addr string) ProbeResult {
261+
result := ProbeResult{
262+
Addr: addr,
263+
ProbedAt: time.Now(),
264+
}
265+
266+
probeCtx, cancel := context.WithTimeout(ctx, s.opts.probeTimeout)
267+
defer cancel()
268+
269+
// Establish connection.
270+
qConn, hConn, err := s.dialEndpoint(probeCtx, addr)
211271
if err != nil {
212272
result.Error = err
213-
slog.Debug("Endpoint probe failed (HTTP/3 request)",
273+
slog.Debug("Endpoint probe failed (QUIC dial)",
214274
slog.String("addr", addr),
215275
slog.Any("error", err))
216276
return result
217277
}
218-
defer resp.Body.Close()
278+
defer qConn.CloseWithError(0, "probe complete")
219279

220-
if resp.StatusCode != http.StatusOK {
221-
result.Error = fmt.Errorf("ping returned status %d", resp.StatusCode)
222-
slog.Debug("Endpoint probe failed (bad status)",
280+
// Perform multiple pings and collect latencies.
281+
var pings []time.Duration
282+
for i := 0; i < s.opts.pingsPerEndpoint; i++ {
283+
latency, err := s.pingSingle(probeCtx, hConn)
284+
if err != nil {
285+
slog.Debug("Endpoint ping failed",
286+
slog.String("addr", addr),
287+
slog.Int("ping", i+1),
288+
slog.Any("error", err))
289+
// Continue to collect as many pings as possible.
290+
continue
291+
}
292+
pings = append(pings, latency)
293+
slog.Debug("Endpoint ping succeeded",
223294
slog.String("addr", addr),
224-
slog.Int("status", resp.StatusCode))
295+
slog.Int("ping", i+1),
296+
slog.Duration("latency", latency))
297+
}
298+
299+
// If no pings succeeded, return an error.
300+
if len(pings) == 0 {
301+
result.Error = fmt.Errorf("all %d pings failed", s.opts.pingsPerEndpoint)
302+
slog.Debug("Endpoint probe failed (all pings failed)",
303+
slog.String("addr", addr))
225304
return result
226305
}
227306

228-
result.Latency = time.Since(start)
307+
// Aggregate latencies using trimmed mean.
308+
result.Latency = aggregateLatencies(pings)
229309

230310
slog.Debug("Endpoint probe succeeded",
231311
slog.String("addr", addr),
232-
slog.Duration("latency", result.Latency))
312+
slog.Int("successful_pings", len(pings)),
313+
slog.Int("total_pings", s.opts.pingsPerEndpoint),
314+
slog.Duration("aggregated_latency", result.Latency))
233315

234316
return result
235317
}

pkg/tunnel/endpointselect/selector_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,88 @@ func TestLatencySelector_Select(t *testing.T) {
107107
WithProbeTimeout(500*time.Millisecond),
108108
WithMaxConcurrent(5),
109109
WithInsecureSkipVerify(true),
110+
WithPingsPerEndpoint(5),
110111
)
111112
assert.Equal(t, 500*time.Millisecond, s.opts.probeTimeout)
112113
assert.Equal(t, 5, s.opts.maxConcurrent)
113114
assert.True(t, s.opts.insecureSkip)
115+
assert.Equal(t, 5, s.opts.pingsPerEndpoint)
116+
})
117+
118+
t.Run("default pings per endpoint is 3", func(t *testing.T) {
119+
s := NewLatencySelector()
120+
assert.Equal(t, 3, s.opts.pingsPerEndpoint)
121+
})
122+
123+
t.Run("pings per endpoint minimum is 1", func(t *testing.T) {
124+
s := NewLatencySelector(WithPingsPerEndpoint(0))
125+
assert.Equal(t, 1, s.opts.pingsPerEndpoint)
126+
127+
s = NewLatencySelector(WithPingsPerEndpoint(-5))
128+
assert.Equal(t, 1, s.opts.pingsPerEndpoint)
129+
})
130+
}
131+
132+
func TestAggregateLatencies(t *testing.T) {
133+
t.Run("empty slice returns zero", func(t *testing.T) {
134+
result := aggregateLatencies([]time.Duration{})
135+
assert.Equal(t, time.Duration(0), result)
136+
})
137+
138+
t.Run("single value returns that value", func(t *testing.T) {
139+
result := aggregateLatencies([]time.Duration{100 * time.Millisecond})
140+
assert.Equal(t, 100*time.Millisecond, result)
141+
})
142+
143+
t.Run("two values returns average", func(t *testing.T) {
144+
result := aggregateLatencies([]time.Duration{
145+
100 * time.Millisecond,
146+
200 * time.Millisecond,
147+
})
148+
assert.Equal(t, 150*time.Millisecond, result)
149+
})
150+
151+
t.Run("three values returns median (middle value)", func(t *testing.T) {
152+
// With 3 pings: discards high and low, returns the one remaining (median).
153+
result := aggregateLatencies([]time.Duration{
154+
100 * time.Millisecond, // low - discarded
155+
150 * time.Millisecond, // middle - kept
156+
300 * time.Millisecond, // high - discarded
157+
})
158+
assert.Equal(t, 150*time.Millisecond, result)
159+
})
160+
161+
t.Run("three values in different order returns same median", func(t *testing.T) {
162+
// Verify sorting works correctly.
163+
result := aggregateLatencies([]time.Duration{
164+
300 * time.Millisecond, // high - discarded
165+
100 * time.Millisecond, // low - discarded
166+
150 * time.Millisecond, // middle - kept
167+
})
168+
assert.Equal(t, 150*time.Millisecond, result)
169+
})
170+
171+
t.Run("five values discards outliers and averages middle three", func(t *testing.T) {
172+
result := aggregateLatencies([]time.Duration{
173+
50 * time.Millisecond, // low - discarded
174+
100 * time.Millisecond, // kept
175+
150 * time.Millisecond, // kept
176+
200 * time.Millisecond, // kept
177+
500 * time.Millisecond, // high - discarded
178+
})
179+
// Average of 100, 150, 200 = 450/3 = 150ms
180+
assert.Equal(t, 150*time.Millisecond, result)
181+
})
182+
183+
t.Run("removes outliers from noisy measurements", func(t *testing.T) {
184+
// Simulate realistic scenario: 2 normal pings and 1 outlier.
185+
result := aggregateLatencies([]time.Duration{
186+
25 * time.Millisecond, // normal
187+
30 * time.Millisecond, // normal
188+
500 * time.Millisecond, // outlier (network blip)
189+
})
190+
// Should return 30ms (middle value after sorting: 25, 30, 500).
191+
assert.Equal(t, 30*time.Millisecond, result)
114192
})
115193
}
116194

0 commit comments

Comments
 (0)