From 7da4724c0315e71fe19070a18b5a861bde4ee297 Mon Sep 17 00:00:00 2001 From: Dimitri Herzog Date: Fri, 3 Mar 2023 21:39:44 +0100 Subject: [PATCH] feat: add cache entire DNS response (#833) (#909) --- resolver/caching_resolver.go | 47 +++++++++++++++--------------------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 7575b13db..5e945923d 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -37,8 +37,8 @@ type CachingResolver struct { // cacheValue includes query answer and prefetch flag type cacheValue struct { - answer []dns.RR - prefetch bool + resultMsg *dns.Msg + prefetch bool } // NewCachingResolver creates a new resolver instance @@ -124,7 +124,7 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time. if response.Res.Rcode == dns.RcodeSuccess { evt.Bus().Publish(evt.CachingDomainPrefetched, domainName) - return cacheValue{response.Res.Answer, true}, r.adjustTTLs(response.Res.Answer) + return cacheValue{response.Res, true}, r.adjustTTLs(response.Res.Answer) } } else { util.LogOnError(fmt.Sprintf("can't prefetch '%s' ", domainName), err) @@ -170,9 +170,6 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo return r.next.Resolve(request) } - resp := new(dns.Msg) - resp.SetReply(request.Req) - for _, question := range request.Req.Question { domain := util.ExtractDomain(question) cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain) @@ -187,26 +184,24 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo evt.Bus().Publish(evt.CachingResultCacheHit, domain) - v, ok := val.(cacheValue) - if ok { - if v.prefetch { - // Hit from prefetch cache - evt.Bus().Publish(evt.CachingPrefetchCacheHit, domain) - } + v := val.(cacheValue) + if v.prefetch { + // Hit from prefetch cache + evt.Bus().Publish(evt.CachingPrefetchCacheHit, domain) + } - // Answer from successful request - for _, rr := range v.answer { - // make copy here since entries in cache can be modified by other goroutines (e.g. redis cache) - cp := dns.Copy(rr) - cp.Header().Ttl = uint32(ttl.Seconds()) + resp := v.resultMsg.Copy() + resp.SetReply(request.Req) + resp.Rcode = v.resultMsg.Rcode - resp.Answer = append(resp.Answer, cp) - } + // Adjust TTL + for _, rr := range resp.Answer { + rr.Header().Ttl = uint32(ttl.Seconds()) + } + if resp.Rcode == dns.RcodeSuccess { return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil } - // Answer with response code != OK - resp.Rcode = val.(int) return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil } @@ -241,15 +236,13 @@ func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, log } func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, prefetch, publish bool) { - answer := response.Res.Answer - if response.Res.Rcode == dns.RcodeSuccess { // put value into cache - r.resultCache.Put(cacheKey, cacheValue{answer, prefetch}, r.adjustTTLs(answer)) + r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.adjustTTLs(response.Res.Answer)) } else if response.Res.Rcode == dns.RcodeNameError { if r.cacheTimeNegative > 0 { - // put return code if NXDOMAIN - r.resultCache.Put(cacheKey, response.Res.Rcode, r.cacheTimeNegative) + // put negative cache if result code is NXDOMAIN + r.resultCache.Put(cacheKey, cacheValue{response.Res, prefetch}, r.cacheTimeNegative) } } @@ -257,7 +250,7 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, if publish && r.redisClient != nil { res := *response.Res - res.Answer = answer + res.Answer = response.Res.Answer r.redisClient.PublishCache(cacheKey, &res) } }