Skip to content

Commit ca4745a

Browse files
committed
[dns] fix ndots search domain resolution and start DNS proxy in tunnel runtime
The internal DNS server had identical logic in both branches of the ndots check — it always queried the bare name first, so search domains like svc.cluster.local never got priority even with ndots:5. Restructure ServeDNS so that when dotCount < ndots, search domains are tried before the bare name, matching standard resolver behavior. Also start the DNS proxy (127.0.0.1:8053) in the `apoxy run` tunnel path, which was missing compared to `apoxy tunnel run`. Add a configurable DNSAddr field to TunnelConfig for this.
1 parent 0e1f8d8 commit ca4745a

File tree

4 files changed

+310
-13
lines changed

4 files changed

+310
-13
lines changed

api/config/v1alpha1/config_types.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ type TunnelConfig struct {
133133
// Only use for development/testing.
134134
// +optional
135135
InsecureSkipVerify bool `json:"insecureSkipVerify,omitempty"`
136+
// DNSAddr is the address to listen on for the internal DNS proxy.
137+
// Defaults to "127.0.0.1:8053". Set to empty string to disable.
138+
// +optional
139+
DNSAddr string `json:"dnsAddr,omitempty"`
136140
}
137141

138142
// TunnelMode is the mode of the tunnel.

pkg/cmd/run/tunnel.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
corev1alpha "github.com/apoxy-dev/apoxy/api/core/v1alpha"
3636
"github.com/apoxy-dev/apoxy/client/versioned"
3737
"github.com/apoxy-dev/apoxy/pkg/log"
38+
"github.com/apoxy-dev/apoxy/pkg/net/dns"
3839
"github.com/apoxy-dev/apoxy/pkg/tunnel"
3940
"github.com/apoxy-dev/apoxy/pkg/tunnel/endpointselect"
4041
)
@@ -68,6 +69,9 @@ func resolveTunnelConfig(in *configv1alpha1.TunnelConfig) *configv1alpha1.Tunnel
6869
if out.SocksPort == nil {
6970
out.SocksPort = ptr.To(1080)
7071
}
72+
if out.DNSAddr == "" {
73+
out.DNSAddr = "127.0.0.1:8053"
74+
}
7175
return out
7276
}
7377

@@ -482,6 +486,19 @@ func runTunnel(ctx context.Context, cfg *configv1alpha1.Config, tc *configv1alph
482486
return mgr.Start(gctx)
483487
})
484488

489+
// Start internal DNS proxy so search domain expansion works.
490+
if tc.DNSAddr != "" {
491+
dnsAddr := tc.DNSAddr
492+
g.Go(func() error {
493+
slog.Info("Starting DNS proxy", slog.String("address", dnsAddr))
494+
if err := dns.ListenAndServe(dnsAddr); err != nil {
495+
slog.Error("DNS server failed", slog.Any("error", err))
496+
return fmt.Errorf("dns server: %w", err)
497+
}
498+
return nil
499+
})
500+
}
501+
485502
// Start the router — each pod independently connects to the tunnel server.
486503
g.Go(func() error {
487504
if err := r.Start(gctx); err != nil {

pkg/net/dns/upstream.go

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@ import (
1818
"github.com/apoxy-dev/apoxy/pkg/log"
1919
)
2020

21-
const (
21+
var (
2222
upstreamPort = 53
2323
)
2424

25+
// setUpstreamPort overrides the upstream port (for testing).
26+
func setUpstreamPort(port int) {
27+
upstreamPort = port
28+
}
29+
2530
// upstream is a plugin that sends queries to a random upstream.
2631
type upstream struct {
2732
Next plugin.Handler
@@ -49,29 +54,43 @@ func (u *upstream) ServeDNS(ctx context.Context, w mdns.ResponseWriter, r *mdns.
4954
return u.Next.ServeDNS(ctx, w, r)
5055
}
5156

52-
// Try the original query first
53-
response, rcode, err := u.queryUpstream(r)
54-
if err != nil {
55-
return rcode, err
56-
}
57+
var response *mdns.Msg
5758

58-
// Apply search domain logic based on ndots and search domains
5959
if len(u.SearchDomains) > 0 && len(r.Question) > 0 {
6060
originalName := r.Question[0].Name
6161
dotCount := u.countDots(originalName)
6262

63-
// Determine search strategy based on ndots
6463
if dotCount < u.Ndots {
65-
// Try search domains first, then original name if all fail
66-
if response.Rcode == mdns.RcodeNameError {
67-
response = u.trySearchDomains(r, originalName, response)
64+
// Few dots: try search domains first, bare name as fallback.
65+
response = u.trySearchDomains(r, originalName, nil)
66+
if response == nil {
67+
var rcode int
68+
var err error
69+
response, rcode, err = u.queryUpstream(r)
70+
if err != nil {
71+
return rcode, err
72+
}
6873
}
6974
} else {
70-
// Try original name first (already done), then search domains if it failed
75+
// Enough dots: try bare name first, search domains as fallback.
76+
var rcode int
77+
var err error
78+
response, rcode, err = u.queryUpstream(r)
79+
if err != nil {
80+
return rcode, err
81+
}
7182
if response.Rcode == mdns.RcodeNameError {
7283
response = u.trySearchDomains(r, originalName, response)
7384
}
7485
}
86+
} else {
87+
// No search domains: query upstream directly.
88+
var rcode int
89+
var err error
90+
response, rcode, err = u.queryUpstream(r)
91+
if err != nil {
92+
return rcode, err
93+
}
7594
}
7695

7796
// Block responses referencing non-global unicast IPs if enabled
@@ -127,6 +146,8 @@ func (u *upstream) countDots(name string) int {
127146
}
128147

129148
// trySearchDomains attempts to resolve a name using configured search domains.
149+
// fallbackResponse may be nil when called before any upstream query (ndots path).
150+
// Returns nil if no search domain succeeded and fallbackResponse was nil.
130151
func (u *upstream) trySearchDomains(originalQuery *mdns.Msg, originalName string, fallbackResponse *mdns.Msg) *mdns.Msg {
131152
for _, domain := range u.SearchDomains {
132153
// Create a new query with the search domain appended
@@ -154,7 +175,7 @@ func (u *upstream) trySearchDomains(originalQuery *mdns.Msg, originalName string
154175
return searchResponse
155176
}
156177
}
157-
// If no search domain worked, return the original response
178+
// If no search domain worked, return the fallback (may be nil).
158179
return fallbackResponse
159180
}
160181

pkg/net/dns/upstream_test.go

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
package dns
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"sync"
8+
"testing"
9+
10+
mdns "github.com/miekg/dns"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
// queryLog records the FQDNs queried in order.
16+
type queryLog struct {
17+
mu sync.Mutex
18+
queries []string
19+
}
20+
21+
func (l *queryLog) append(name string) {
22+
l.mu.Lock()
23+
defer l.mu.Unlock()
24+
l.queries = append(l.queries, name)
25+
}
26+
27+
func (l *queryLog) list() []string {
28+
l.mu.Lock()
29+
defer l.mu.Unlock()
30+
out := make([]string, len(l.queries))
31+
copy(out, l.queries)
32+
return out
33+
}
34+
35+
// fakeUpstreamServer starts a local DNS server that answers queries based on
36+
// a set of known names. Unknown names get NXDOMAIN.
37+
type fakeUpstreamServer struct {
38+
addr string
39+
known map[string]string // FQDN -> IP answer (A record)
40+
log *queryLog
41+
server *mdns.Server
42+
udpConn net.PacketConn
43+
}
44+
45+
func newFakeUpstream(t *testing.T, known map[string]string, log *queryLog) *fakeUpstreamServer {
46+
t.Helper()
47+
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
48+
require.NoError(t, err)
49+
50+
f := &fakeUpstreamServer{
51+
addr: pc.LocalAddr().String(),
52+
known: known,
53+
log: log,
54+
udpConn: pc,
55+
}
56+
mux := mdns.NewServeMux()
57+
mux.HandleFunc(".", f.handler)
58+
f.server = &mdns.Server{
59+
PacketConn: pc,
60+
Handler: mux,
61+
}
62+
go func() { _ = f.server.ActivateAndServe() }()
63+
return f
64+
}
65+
66+
func (f *fakeUpstreamServer) handler(w mdns.ResponseWriter, r *mdns.Msg) {
67+
q := r.Question[0]
68+
f.log.append(q.Name)
69+
70+
resp := new(mdns.Msg)
71+
resp.SetReply(r)
72+
73+
if ip, ok := f.known[q.Name]; ok && q.Qtype == mdns.TypeA {
74+
resp.Answer = append(resp.Answer, &mdns.A{
75+
Hdr: mdns.RR_Header{
76+
Name: q.Name,
77+
Rrtype: mdns.TypeA,
78+
Class: mdns.ClassINET,
79+
Ttl: 60,
80+
},
81+
A: net.ParseIP(ip),
82+
})
83+
} else {
84+
resp.Rcode = mdns.RcodeNameError
85+
}
86+
w.WriteMsg(resp)
87+
}
88+
89+
func (f *fakeUpstreamServer) close() {
90+
f.server.Shutdown()
91+
f.udpConn.Close()
92+
}
93+
94+
// upstreamPort returns just the port number so we can override the upstream
95+
// client to hit it.
96+
func (f *fakeUpstreamServer) port() string {
97+
_, port, _ := net.SplitHostPort(f.addr)
98+
return port
99+
}
100+
101+
// serveDNSHelper builds an upstream, sends a query, and returns the response.
102+
func serveDNSHelper(t *testing.T, u *upstream, name string, qtype uint16) *mdns.Msg {
103+
t.Helper()
104+
r := new(mdns.Msg)
105+
r.SetQuestion(mdns.Fqdn(name), qtype)
106+
107+
rec := &dnsRecorder{}
108+
_, err := u.ServeDNS(context.Background(), rec, r)
109+
require.NoError(t, err)
110+
require.NotNil(t, rec.msg, "expected a DNS response")
111+
return rec.msg
112+
}
113+
114+
// dnsRecorder captures the response written via WriteMsg.
115+
type dnsRecorder struct {
116+
msg *mdns.Msg
117+
}
118+
119+
func (r *dnsRecorder) WriteMsg(m *mdns.Msg) error { r.msg = m; return nil }
120+
func (r *dnsRecorder) LocalAddr() net.Addr { return nil }
121+
func (r *dnsRecorder) RemoteAddr() net.Addr { return nil }
122+
func (r *dnsRecorder) Write([]byte) (int, error) { return 0, fmt.Errorf("not implemented") }
123+
func (r *dnsRecorder) Close() error { return nil }
124+
func (r *dnsRecorder) TsigStatus() error { return nil }
125+
func (r *dnsRecorder) TsigTimersOnly(bool) {}
126+
func (r *dnsRecorder) Hijack() {}
127+
128+
func TestServeDNS_NdotsSearchDomainPriority(t *testing.T) {
129+
tests := []struct {
130+
name string
131+
queryName string // domain to query
132+
ndots int
133+
searchDomains []string
134+
knownNames map[string]string // FQDN->IP that the fake upstream knows
135+
wantRcode int
136+
wantAnswer string // expected IP in answer, "" if NXDOMAIN expected
137+
wantFirstQuery string // first FQDN the upstream should have received
138+
}{
139+
{
140+
name: "dotCount < ndots: search domains tried first",
141+
queryName: "apache.default",
142+
ndots: 5,
143+
searchDomains: []string{"svc.cluster.local", "cluster.local"},
144+
knownNames: map[string]string{
145+
"apache.default.svc.cluster.local.": "10.0.0.1",
146+
},
147+
wantRcode: mdns.RcodeSuccess,
148+
wantAnswer: "10.0.0.1",
149+
wantFirstQuery: "apache.default.svc.cluster.local.",
150+
},
151+
{
152+
name: "dotCount >= ndots: bare name tried first",
153+
queryName: "api.example.com",
154+
ndots: 2,
155+
searchDomains: []string{"svc.cluster.local"},
156+
knownNames: map[string]string{
157+
"api.example.com.": "1.2.3.4",
158+
},
159+
wantRcode: mdns.RcodeSuccess,
160+
wantAnswer: "1.2.3.4",
161+
wantFirstQuery: "api.example.com.",
162+
},
163+
{
164+
name: "dotCount >= ndots: bare fails, falls back to search domain",
165+
queryName: "api.example.com",
166+
ndots: 2,
167+
searchDomains: []string{"svc.cluster.local"},
168+
knownNames: map[string]string{
169+
"api.example.com.svc.cluster.local.": "10.0.0.2",
170+
},
171+
wantRcode: mdns.RcodeSuccess,
172+
wantAnswer: "10.0.0.2",
173+
wantFirstQuery: "api.example.com.",
174+
},
175+
{
176+
name: "no search domains: direct upstream query",
177+
queryName: "apache.default",
178+
ndots: 5,
179+
searchDomains: nil,
180+
knownNames: map[string]string{
181+
"apache.default.": "10.0.0.3",
182+
},
183+
wantRcode: mdns.RcodeSuccess,
184+
wantAnswer: "10.0.0.3",
185+
wantFirstQuery: "apache.default.",
186+
},
187+
{
188+
name: "dotCount < ndots: all search domains fail, falls back to bare name",
189+
queryName: "myservice.ns",
190+
ndots: 5,
191+
searchDomains: []string{"svc.cluster.local", "cluster.local"},
192+
knownNames: map[string]string{
193+
"myservice.ns.": "10.0.0.4",
194+
},
195+
wantRcode: mdns.RcodeSuccess,
196+
wantAnswer: "10.0.0.4",
197+
wantFirstQuery: "myservice.ns.svc.cluster.local.",
198+
},
199+
{
200+
name: "dotCount < ndots: everything fails returns NXDOMAIN",
201+
queryName: "nonexistent.svc",
202+
ndots: 5,
203+
searchDomains: []string{"svc.cluster.local"},
204+
knownNames: map[string]string{},
205+
wantRcode: mdns.RcodeNameError,
206+
wantAnswer: "",
207+
},
208+
}
209+
210+
for _, tt := range tests {
211+
t.Run(tt.name, func(t *testing.T) {
212+
log := &queryLog{}
213+
fake := newFakeUpstream(t, tt.knownNames, log)
214+
defer fake.close()
215+
216+
host, port, err := net.SplitHostPort(fake.addr)
217+
require.NoError(t, err)
218+
219+
u := &upstream{
220+
Upstreams: []string{host},
221+
SearchDomains: tt.searchDomains,
222+
Ndots: tt.ndots,
223+
}
224+
// Override the upstream port for testing.
225+
origPort := upstreamPort
226+
defer func() { setUpstreamPort(origPort) }()
227+
setUpstreamPort(mustAtoi(port))
228+
229+
resp := serveDNSHelper(t, u, tt.queryName, mdns.TypeA)
230+
231+
assert.Equal(t, tt.wantRcode, resp.Rcode, "unexpected rcode")
232+
233+
if tt.wantAnswer != "" {
234+
require.Len(t, resp.Answer, 1, "expected one answer")
235+
a, ok := resp.Answer[0].(*mdns.A)
236+
require.True(t, ok)
237+
assert.Equal(t, tt.wantAnswer, a.A.String())
238+
}
239+
240+
if tt.wantFirstQuery != "" {
241+
queries := log.list()
242+
require.NotEmpty(t, queries, "expected at least one upstream query")
243+
assert.Equal(t, tt.wantFirstQuery, queries[0], "first query mismatch")
244+
}
245+
})
246+
}
247+
}
248+
249+
func mustAtoi(s string) int {
250+
n := 0
251+
for _, c := range s {
252+
n = n*10 + int(c-'0')
253+
}
254+
return n
255+
}

0 commit comments

Comments
 (0)