/
exchange.go
152 lines (125 loc) · 3.91 KB
/
exchange.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package proxy
import (
"fmt"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"gonum.org/v1/gonum/stat/sampleuv"
)
// exchangeUpstreams resolves req using the given upstreams.  It returns the DNS
// response, the upstream that successfully resolved the request, and the error
// if any.
//
// In [UModeParallel] the exchange is delegated to [upstream.ExchangeParallel].
// In [UModeFastestAddr] A and AAAA requests are delegated to the fastest-address
// logic.  Every other mode and query type falls through to load balancing:
// upstreams are tried one at a time, picked at random with weights computed by
// p.calcWeights, until one of them succeeds.  In that mode an empty ups results
// in an error instead of a panic.
func (p *Proxy) exchangeUpstreams(
	req *dns.Msg,
	ups []upstream.Upstream,
) (resp *dns.Msg, u upstream.Upstream, err error) {
	switch p.UpstreamMode {
	case UModeParallel:
		return upstream.ExchangeParallel(ups, req)
	case UModeFastestAddr:
		switch req.Question[0].Qtype {
		case dns.TypeA, dns.TypeAAAA:
			return p.fastestAddr.ExchangeFastest(req, ups)
		default:
			// Go on to the load-balancing mode.
		}
	default:
		// Go on to the load-balancing mode.
	}

	switch len(ups) {
	case 0:
		// sampleuv.Weighted cannot sample from an empty set of weights: its
		// Take method indexes the internal heap unconditionally and would
		// panic below.
		return nil, nil, errors.Error("no upstreams specified")
	case 1:
		u = ups[0]
		resp, _, err = exchange(u, req, p.time)
		// TODO(e.burkov): p.updateRTT(u.Address(), elapsed)

		return resp, u, err
	default:
		// Go on to the weighted random selection.
	}

	w := sampleuv.NewWeighted(p.calcWeights(ups), p.randSrc)
	var errs []error
	for i, ok := w.Take(); ok; i, ok = w.Take() {
		u = ups[i]

		var elapsed time.Duration
		resp, elapsed, err = exchange(u, req, p.time)
		if err == nil {
			p.updateRTT(u.Address(), elapsed)

			return resp, u, nil
		}

		errs = append(errs, err)

		// TODO(e.burkov): Use the actual configured timeout or, perhaps, the
		// actual measured elapsed time.
		p.updateRTT(u.Address(), defaultTimeout)
	}

	err = fmt.Errorf("all upstreams failed to exchange request: %w", errors.Join(errs...))

	return nil, nil, err
}
// exchange sends req to u and returns the reply, the duration of the exchange
// as measured by c, and the exchange error, if any.  Failures are logged at the
// error level and successes at the debug level.
//
// req must contain at least one question, since the first one is included in
// the log messages.
func exchange(u upstream.Upstream, req *dns.Msg, c clock) (resp *dns.Msg, dur time.Duration, err error) {
	startTime := c.Now()
	resp, err = u.Exchange(req)
	// Don't use [time.Since] because it uses [time.Now].
	dur = c.Now().Sub(startTime)

	addr := u.Address()
	q := req.Question[0].String()
	if err != nil {
		log.Error("dnsproxy: upstream %s failed to exchange %s in %s: %s", addr, q, dur, err)
	} else {
		log.Debug("dnsproxy: upstream %s successfully finished exchange of %s; elapsed %s", addr, q, dur)
	}

	return resp, dur, err
}
// upstreamRTTStats accumulates round-trip time samples for a single upstream.
type upstreamRTTStats struct {
	// rttSum is the total of all recorded round-trip times, in microseconds.
	// A float64 holds integers exactly up to 2^53, which is roughly 285 years
	// worth of microseconds, so precision is not a practical concern.
	rttSum float64

	// reqNum is the number of recorded samples.  It is a float64 to avoid
	// conversions when computing the mean together with rttSum.
	reqNum float64
}

// update returns a copy of s with rtt recorded as one more sample.  The
// receiver is not modified.  rtt is truncated to whole microseconds before it
// is added.
func (s upstreamRTTStats) update(rtt time.Duration) (updated upstreamRTTStats) {
	updated = s
	updated.rttSum += float64(rtt.Microseconds())
	updated.reqNum++

	return updated
}
// calcWeights returns one sampling weight per upstream, where weights[i]
// belongs to ups[i].  An upstream with recorded statistics is weighted by the
// inverse of its mean round-trip time, so a lower mean yields a higher weight.
// An upstream without usable statistics gets the default weight of 1.
func (p *Proxy) calcWeights(ups []upstream.Upstream) (weights []float64) {
	p.rttLock.Lock()
	defer p.rttLock.Unlock()

	weights = make([]float64, len(ups))
	for i, u := range ups {
		st := p.upstreamRTTStats[u.Address()]

		// Without samples, or with a zero total, the mean is either undefined
		// or zero, so fall back to the neutral default.
		if st.reqNum == 0 || st.rttSum == 0 {
			weights[i] = 1

			continue
		}

		weights[i] = 1 / (st.rttSum / st.reqNum)
	}

	return weights
}
// updateRTT records rtt as another round-trip sample for the upstream with the
// given address.  The statistics map is created on first use.
func (p *Proxy) updateRTT(address string, rtt time.Duration) {
	p.rttLock.Lock()
	defer p.rttLock.Unlock()

	stats := p.upstreamRTTStats
	if stats == nil {
		stats = make(map[string]upstreamRTTStats)
		p.upstreamRTTStats = stats
	}

	stats[address] = stats[address].update(rtt)
}