|
| 1 | +package dns |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "net" |
| 7 | + "sync" |
| 8 | + "testing" |
| 9 | + |
| 10 | + mdns "github.com/miekg/dns" |
| 11 | + "github.com/stretchr/testify/assert" |
| 12 | + "github.com/stretchr/testify/require" |
| 13 | +) |
| 14 | + |
| 15 | +// queryLog records the FQDNs queried in order. |
| 16 | +type queryLog struct { |
| 17 | + mu sync.Mutex |
| 18 | + queries []string |
| 19 | +} |
| 20 | + |
| 21 | +func (l *queryLog) append(name string) { |
| 22 | + l.mu.Lock() |
| 23 | + defer l.mu.Unlock() |
| 24 | + l.queries = append(l.queries, name) |
| 25 | +} |
| 26 | + |
| 27 | +func (l *queryLog) list() []string { |
| 28 | + l.mu.Lock() |
| 29 | + defer l.mu.Unlock() |
| 30 | + out := make([]string, len(l.queries)) |
| 31 | + copy(out, l.queries) |
| 32 | + return out |
| 33 | +} |
| 34 | + |
| 35 | +// fakeUpstreamServer starts a local DNS server that answers queries based on |
| 36 | +// a set of known names. Unknown names get NXDOMAIN. |
| 37 | +type fakeUpstreamServer struct { |
| 38 | + addr string |
| 39 | + known map[string]string // FQDN -> IP answer (A record) |
| 40 | + log *queryLog |
| 41 | + server *mdns.Server |
| 42 | + udpConn net.PacketConn |
| 43 | +} |
| 44 | + |
| 45 | +func newFakeUpstream(t *testing.T, known map[string]string, log *queryLog) *fakeUpstreamServer { |
| 46 | + t.Helper() |
| 47 | + pc, err := net.ListenPacket("udp", "127.0.0.1:0") |
| 48 | + require.NoError(t, err) |
| 49 | + |
| 50 | + f := &fakeUpstreamServer{ |
| 51 | + addr: pc.LocalAddr().String(), |
| 52 | + known: known, |
| 53 | + log: log, |
| 54 | + udpConn: pc, |
| 55 | + } |
| 56 | + mux := mdns.NewServeMux() |
| 57 | + mux.HandleFunc(".", f.handler) |
| 58 | + f.server = &mdns.Server{ |
| 59 | + PacketConn: pc, |
| 60 | + Handler: mux, |
| 61 | + } |
| 62 | + go func() { _ = f.server.ActivateAndServe() }() |
| 63 | + return f |
| 64 | +} |
| 65 | + |
| 66 | +func (f *fakeUpstreamServer) handler(w mdns.ResponseWriter, r *mdns.Msg) { |
| 67 | + q := r.Question[0] |
| 68 | + f.log.append(q.Name) |
| 69 | + |
| 70 | + resp := new(mdns.Msg) |
| 71 | + resp.SetReply(r) |
| 72 | + |
| 73 | + if ip, ok := f.known[q.Name]; ok && q.Qtype == mdns.TypeA { |
| 74 | + resp.Answer = append(resp.Answer, &mdns.A{ |
| 75 | + Hdr: mdns.RR_Header{ |
| 76 | + Name: q.Name, |
| 77 | + Rrtype: mdns.TypeA, |
| 78 | + Class: mdns.ClassINET, |
| 79 | + Ttl: 60, |
| 80 | + }, |
| 81 | + A: net.ParseIP(ip), |
| 82 | + }) |
| 83 | + } else { |
| 84 | + resp.Rcode = mdns.RcodeNameError |
| 85 | + } |
| 86 | + w.WriteMsg(resp) |
| 87 | +} |
| 88 | + |
| 89 | +func (f *fakeUpstreamServer) close() { |
| 90 | + f.server.Shutdown() |
| 91 | + f.udpConn.Close() |
| 92 | +} |
| 93 | + |
| 94 | +// upstreamPort returns just the port number so we can override the upstream |
| 95 | +// client to hit it. |
| 96 | +func (f *fakeUpstreamServer) port() string { |
| 97 | + _, port, _ := net.SplitHostPort(f.addr) |
| 98 | + return port |
| 99 | +} |
| 100 | + |
| 101 | +// serveDNSHelper builds an upstream, sends a query, and returns the response. |
| 102 | +func serveDNSHelper(t *testing.T, u *upstream, name string, qtype uint16) *mdns.Msg { |
| 103 | + t.Helper() |
| 104 | + r := new(mdns.Msg) |
| 105 | + r.SetQuestion(mdns.Fqdn(name), qtype) |
| 106 | + |
| 107 | + rec := &dnsRecorder{} |
| 108 | + _, err := u.ServeDNS(context.Background(), rec, r) |
| 109 | + require.NoError(t, err) |
| 110 | + require.NotNil(t, rec.msg, "expected a DNS response") |
| 111 | + return rec.msg |
| 112 | +} |
| 113 | + |
| 114 | +// dnsRecorder captures the response written via WriteMsg. |
| 115 | +type dnsRecorder struct { |
| 116 | + msg *mdns.Msg |
| 117 | +} |
| 118 | + |
| 119 | +func (r *dnsRecorder) WriteMsg(m *mdns.Msg) error { r.msg = m; return nil } |
| 120 | +func (r *dnsRecorder) LocalAddr() net.Addr { return nil } |
| 121 | +func (r *dnsRecorder) RemoteAddr() net.Addr { return nil } |
| 122 | +func (r *dnsRecorder) Write([]byte) (int, error) { return 0, fmt.Errorf("not implemented") } |
| 123 | +func (r *dnsRecorder) Close() error { return nil } |
| 124 | +func (r *dnsRecorder) TsigStatus() error { return nil } |
| 125 | +func (r *dnsRecorder) TsigTimersOnly(bool) {} |
| 126 | +func (r *dnsRecorder) Hijack() {} |
| 127 | + |
| 128 | +func TestServeDNS_NdotsSearchDomainPriority(t *testing.T) { |
| 129 | + tests := []struct { |
| 130 | + name string |
| 131 | + queryName string // domain to query |
| 132 | + ndots int |
| 133 | + searchDomains []string |
| 134 | + knownNames map[string]string // FQDN->IP that the fake upstream knows |
| 135 | + wantRcode int |
| 136 | + wantAnswer string // expected IP in answer, "" if NXDOMAIN expected |
| 137 | + wantFirstQuery string // first FQDN the upstream should have received |
| 138 | + }{ |
| 139 | + { |
| 140 | + name: "dotCount < ndots: search domains tried first", |
| 141 | + queryName: "apache.default", |
| 142 | + ndots: 5, |
| 143 | + searchDomains: []string{"svc.cluster.local", "cluster.local"}, |
| 144 | + knownNames: map[string]string{ |
| 145 | + "apache.default.svc.cluster.local.": "10.0.0.1", |
| 146 | + }, |
| 147 | + wantRcode: mdns.RcodeSuccess, |
| 148 | + wantAnswer: "10.0.0.1", |
| 149 | + wantFirstQuery: "apache.default.svc.cluster.local.", |
| 150 | + }, |
| 151 | + { |
| 152 | + name: "dotCount >= ndots: bare name tried first", |
| 153 | + queryName: "api.example.com", |
| 154 | + ndots: 2, |
| 155 | + searchDomains: []string{"svc.cluster.local"}, |
| 156 | + knownNames: map[string]string{ |
| 157 | + "api.example.com.": "1.2.3.4", |
| 158 | + }, |
| 159 | + wantRcode: mdns.RcodeSuccess, |
| 160 | + wantAnswer: "1.2.3.4", |
| 161 | + wantFirstQuery: "api.example.com.", |
| 162 | + }, |
| 163 | + { |
| 164 | + name: "dotCount >= ndots: bare fails, falls back to search domain", |
| 165 | + queryName: "api.example.com", |
| 166 | + ndots: 2, |
| 167 | + searchDomains: []string{"svc.cluster.local"}, |
| 168 | + knownNames: map[string]string{ |
| 169 | + "api.example.com.svc.cluster.local.": "10.0.0.2", |
| 170 | + }, |
| 171 | + wantRcode: mdns.RcodeSuccess, |
| 172 | + wantAnswer: "10.0.0.2", |
| 173 | + wantFirstQuery: "api.example.com.", |
| 174 | + }, |
| 175 | + { |
| 176 | + name: "no search domains: direct upstream query", |
| 177 | + queryName: "apache.default", |
| 178 | + ndots: 5, |
| 179 | + searchDomains: nil, |
| 180 | + knownNames: map[string]string{ |
| 181 | + "apache.default.": "10.0.0.3", |
| 182 | + }, |
| 183 | + wantRcode: mdns.RcodeSuccess, |
| 184 | + wantAnswer: "10.0.0.3", |
| 185 | + wantFirstQuery: "apache.default.", |
| 186 | + }, |
| 187 | + { |
| 188 | + name: "dotCount < ndots: all search domains fail, falls back to bare name", |
| 189 | + queryName: "myservice.ns", |
| 190 | + ndots: 5, |
| 191 | + searchDomains: []string{"svc.cluster.local", "cluster.local"}, |
| 192 | + knownNames: map[string]string{ |
| 193 | + "myservice.ns.": "10.0.0.4", |
| 194 | + }, |
| 195 | + wantRcode: mdns.RcodeSuccess, |
| 196 | + wantAnswer: "10.0.0.4", |
| 197 | + wantFirstQuery: "myservice.ns.svc.cluster.local.", |
| 198 | + }, |
| 199 | + { |
| 200 | + name: "dotCount < ndots: everything fails returns NXDOMAIN", |
| 201 | + queryName: "nonexistent.svc", |
| 202 | + ndots: 5, |
| 203 | + searchDomains: []string{"svc.cluster.local"}, |
| 204 | + knownNames: map[string]string{}, |
| 205 | + wantRcode: mdns.RcodeNameError, |
| 206 | + wantAnswer: "", |
| 207 | + }, |
| 208 | + } |
| 209 | + |
| 210 | + for _, tt := range tests { |
| 211 | + t.Run(tt.name, func(t *testing.T) { |
| 212 | + log := &queryLog{} |
| 213 | + fake := newFakeUpstream(t, tt.knownNames, log) |
| 214 | + defer fake.close() |
| 215 | + |
| 216 | + host, port, err := net.SplitHostPort(fake.addr) |
| 217 | + require.NoError(t, err) |
| 218 | + |
| 219 | + u := &upstream{ |
| 220 | + Upstreams: []string{host}, |
| 221 | + SearchDomains: tt.searchDomains, |
| 222 | + Ndots: tt.ndots, |
| 223 | + } |
| 224 | + // Override the upstream port for testing. |
| 225 | + origPort := upstreamPort |
| 226 | + defer func() { setUpstreamPort(origPort) }() |
| 227 | + setUpstreamPort(mustAtoi(port)) |
| 228 | + |
| 229 | + resp := serveDNSHelper(t, u, tt.queryName, mdns.TypeA) |
| 230 | + |
| 231 | + assert.Equal(t, tt.wantRcode, resp.Rcode, "unexpected rcode") |
| 232 | + |
| 233 | + if tt.wantAnswer != "" { |
| 234 | + require.Len(t, resp.Answer, 1, "expected one answer") |
| 235 | + a, ok := resp.Answer[0].(*mdns.A) |
| 236 | + require.True(t, ok) |
| 237 | + assert.Equal(t, tt.wantAnswer, a.A.String()) |
| 238 | + } |
| 239 | + |
| 240 | + if tt.wantFirstQuery != "" { |
| 241 | + queries := log.list() |
| 242 | + require.NotEmpty(t, queries, "expected at least one upstream query") |
| 243 | + assert.Equal(t, tt.wantFirstQuery, queries[0], "first query mismatch") |
| 244 | + } |
| 245 | + }) |
| 246 | + } |
| 247 | +} |
| 248 | + |
| 249 | +func mustAtoi(s string) int { |
| 250 | + n := 0 |
| 251 | + for _, c := range s { |
| 252 | + n = n*10 + int(c-'0') |
| 253 | + } |
| 254 | + return n |
| 255 | +} |
0 commit comments