@@ -21,15 +21,18 @@ const (
2121 DefaultProbeTimeout = 3 * time .Second
2222 // DefaultMaxConcurrent is the default maximum number of concurrent probes.
2323 DefaultMaxConcurrent = 10
24+ // DefaultPingsPerEndpoint is the default number of ping requests per endpoint.
25+ DefaultPingsPerEndpoint = 3
2426)
2527
// Option configures a LatencySelector. Each Option mutates the selector's
// internal settings while the selector is being constructed.
type Option func(*latencyOptions)

// latencyOptions holds the tunable settings of a LatencySelector.
type latencyOptions struct {
	// probeTimeout is the timeout applied to each endpoint probe.
	probeTimeout time.Duration
	// maxConcurrent is the maximum number of concurrent probes.
	maxConcurrent int
	// insecureSkip is presumably the flag set by WithInsecureSkipVerify to
	// disable TLS certificate verification — confirm against dialEndpoint.
	insecureSkip bool
	// pingsPerEndpoint is the number of ping requests sent to each endpoint.
	pingsPerEndpoint int
}
3437
3538// WithProbeTimeout sets the timeout for each endpoint probe.
@@ -53,6 +56,17 @@ func WithInsecureSkipVerify(skip bool) Option {
5356 }
5457}
5558
59+ // WithPingsPerEndpoint sets the number of ping requests per endpoint.
60+ // The latencies are aggregated using a trimmed mean (removing outliers).
61+ func WithPingsPerEndpoint (n int ) Option {
62+ return func (o * latencyOptions ) {
63+ if n < 1 {
64+ n = 1
65+ }
66+ o .pingsPerEndpoint = n
67+ }
68+ }
69+
5670// LatencySelector selects endpoints based on QUIC handshake latency.
5771type LatencySelector struct {
5872 opts latencyOptions
@@ -61,8 +75,9 @@ type LatencySelector struct {
6175// NewLatencySelector creates a new LatencySelector.
6276func NewLatencySelector (opts ... Option ) * LatencySelector {
6377 options := latencyOptions {
64- probeTimeout : DefaultProbeTimeout ,
65- maxConcurrent : DefaultMaxConcurrent ,
78+ probeTimeout : DefaultProbeTimeout ,
79+ maxConcurrent : DefaultMaxConcurrent ,
80+ pingsPerEndpoint : DefaultPingsPerEndpoint ,
6681 }
6782 for _ , opt := range opts {
6883 opt (& options )
@@ -153,17 +168,9 @@ func (s *LatencySelector) probeAll(ctx context.Context, endpoints []string) []Pr
153168 return results
154169}
155170
156- // probe measures the round-trip latency to a single endpoint by making
157- // an HTTP/3 request to the /ping endpoint.
158- func (s * LatencySelector ) probe (ctx context.Context , addr string ) ProbeResult {
159- result := ProbeResult {
160- Addr : addr ,
161- ProbedAt : time .Now (),
162- }
163-
164- probeCtx , cancel := context .WithTimeout (ctx , s .opts .probeTimeout )
165- defer cancel ()
166-
171+ // dialEndpoint establishes a QUIC connection and creates an HTTP/3 client connection.
172+ // Returns the QUIC connection, HTTP/3 client connection, and any error.
173+ func (s * LatencySelector ) dialEndpoint (ctx context.Context , addr string ) (quic.Connection , * http3.ClientConn , error ) {
167174 // Extract hostname from address for TLS ServerName.
168175 serverName := "proxy"
169176 if host , _ , err := net .SplitHostPort (addr ); err == nil && net .ParseIP (host ) == nil {
@@ -181,55 +188,130 @@ func (s *LatencySelector) probe(ctx context.Context, addr string) ProbeResult {
181188 InitialPacketSize : 1350 ,
182189 }
183190
184- start := time .Now ()
185-
186191 // Dial QUIC connection.
187- qConn , err := quic .DialAddr (probeCtx , addr , tlsConfig , quicConfig )
192+ qConn , err := quic .DialAddr (ctx , addr , tlsConfig , quicConfig )
188193 if err != nil {
189- result .Error = err
190- slog .Debug ("Endpoint probe failed (QUIC dial)" ,
191- slog .String ("addr" , addr ),
192- slog .Any ("error" , err ))
193- return result
194+ return nil , nil , err
194195 }
195- defer qConn .CloseWithError (0 , "probe complete" )
196196
197- // Make HTTP/3 request to /ping endpoint .
197+ // Create HTTP/3 client connection .
198198 tr := & http3.Transport {EnableDatagrams : true }
199199 hConn := tr .NewClientConn (qConn )
200200
201- req , err := http .NewRequestWithContext (probeCtx , "GET" , "https://proxy/ping" , nil )
201+ return qConn , hConn , nil
202+ }
203+
204+ // pingSingle performs a single HTTP/3 GET /ping request over an existing connection.
205+ // Returns the round-trip latency or an error.
206+ func (s * LatencySelector ) pingSingle (ctx context.Context , hConn * http3.ClientConn ) (time.Duration , error ) {
207+ req , err := http .NewRequestWithContext (ctx , "GET" , "https://proxy/ping" , nil )
202208 if err != nil {
203- result .Error = err
204- slog .Debug ("Endpoint probe failed (request creation)" ,
205- slog .String ("addr" , addr ),
206- slog .Any ("error" , err ))
207- return result
209+ return 0 , err
208210 }
209211
212+ start := time .Now ()
210213 resp , err := hConn .RoundTrip (req )
214+ if err != nil {
215+ return 0 , err
216+ }
217+ defer resp .Body .Close ()
218+
219+ if resp .StatusCode != http .StatusOK {
220+ return 0 , fmt .Errorf ("ping returned status %d" , resp .StatusCode )
221+ }
222+
223+ return time .Since (start ), nil
224+ }
225+
// aggregateLatencies reduces a set of ping latencies to a single value using
// a trimmed mean.
//
//   - no samples:  returns 0
//   - one sample:  returns that sample
//   - two samples: returns their average
//   - three or more: drops the single lowest and single highest sample and
//     averages what remains
//
// The input slice is not modified; sorting happens on a private copy.
// Division is integer division on time.Duration, so results truncate.
func aggregateLatencies(pings []time.Duration) time.Duration {
	switch len(pings) {
	case 0:
		return 0
	case 1:
		return pings[0]
	}

	// Work on a copy so the caller's ordering is preserved.
	ordered := append([]time.Duration(nil), pings...)
	sort.Slice(ordered, func(a, b int) bool {
		return ordered[a] < ordered[b]
	})

	// With exactly two samples there is nothing to trim; otherwise shave one
	// sample off each end of the sorted list.
	lo, hi := 0, len(ordered)
	if len(ordered) > 2 {
		lo, hi = 1, len(ordered)-1
	}

	var total time.Duration
	for _, sample := range ordered[lo:hi] {
		total += sample
	}
	return total / time.Duration(hi-lo)
}
257+
258+ // probe measures the round-trip latency to a single endpoint by making
259+ // multiple HTTP/3 requests to the /ping endpoint and aggregating the results.
260+ func (s * LatencySelector ) probe (ctx context.Context , addr string ) ProbeResult {
261+ result := ProbeResult {
262+ Addr : addr ,
263+ ProbedAt : time .Now (),
264+ }
265+
266+ probeCtx , cancel := context .WithTimeout (ctx , s .opts .probeTimeout )
267+ defer cancel ()
268+
269+ // Establish connection.
270+ qConn , hConn , err := s .dialEndpoint (probeCtx , addr )
211271 if err != nil {
212272 result .Error = err
213- slog .Debug ("Endpoint probe failed (HTTP/3 request )" ,
273+ slog .Debug ("Endpoint probe failed (QUIC dial )" ,
214274 slog .String ("addr" , addr ),
215275 slog .Any ("error" , err ))
216276 return result
217277 }
218- defer resp . Body . Close ( )
278+ defer qConn . CloseWithError ( 0 , "probe complete" )
219279
220- if resp .StatusCode != http .StatusOK {
221- result .Error = fmt .Errorf ("ping returned status %d" , resp .StatusCode )
222- slog .Debug ("Endpoint probe failed (bad status)" ,
280+ // Perform multiple pings and collect latencies.
281+ var pings []time.Duration
282+ for i := 0 ; i < s .opts .pingsPerEndpoint ; i ++ {
283+ latency , err := s .pingSingle (probeCtx , hConn )
284+ if err != nil {
285+ slog .Debug ("Endpoint ping failed" ,
286+ slog .String ("addr" , addr ),
287+ slog .Int ("ping" , i + 1 ),
288+ slog .Any ("error" , err ))
289+ // Continue to collect as many pings as possible.
290+ continue
291+ }
292+ pings = append (pings , latency )
293+ slog .Debug ("Endpoint ping succeeded" ,
223294 slog .String ("addr" , addr ),
224- slog .Int ("status" , resp .StatusCode ))
295+ slog .Int ("ping" , i + 1 ),
296+ slog .Duration ("latency" , latency ))
297+ }
298+
299+ // If no pings succeeded, return an error.
300+ if len (pings ) == 0 {
301+ result .Error = fmt .Errorf ("all %d pings failed" , s .opts .pingsPerEndpoint )
302+ slog .Debug ("Endpoint probe failed (all pings failed)" ,
303+ slog .String ("addr" , addr ))
225304 return result
226305 }
227306
228- result .Latency = time .Since (start )
307+ // Aggregate latencies using trimmed mean.
308+ result .Latency = aggregateLatencies (pings )
229309
230310 slog .Debug ("Endpoint probe succeeded" ,
231311 slog .String ("addr" , addr ),
232- slog .Duration ("latency" , result .Latency ))
312+ slog .Int ("successful_pings" , len (pings )),
313+ slog .Int ("total_pings" , s .opts .pingsPerEndpoint ),
314+ slog .Duration ("aggregated_latency" , result .Latency ))
233315
234316 return result
235317}
0 commit comments