Skip to content

Commit

Permalink
fix: use different TTL of multiple records in answer
Browse files Browse the repository at this point in the history
  • Loading branch information
0xERR0R committed Sep 25, 2023
1 parent 69f6ae4 commit f988593
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 14 deletions.
1 change: 1 addition & 0 deletions helpertest/helper.go
Expand Up @@ -19,6 +19,7 @@ import (
const (
A = dns.Type(dns.TypeA)
AAAA = dns.Type(dns.TypeAAAA)
CNAME = dns.Type(dns.TypeCNAME)
HTTPS = dns.Type(dns.TypeHTTPS)
MX = dns.Type(dns.TypeMX)
PTR = dns.Type(dns.TypePTR)
Expand Down
7 changes: 3 additions & 4 deletions redis/redis.go
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"math"
"strings"
"time"

Expand Down Expand Up @@ -315,11 +316,9 @@ func convertMessage(message *redisMessage, ttl time.Duration) (*CacheMessage, er

// getTTL of dns message or return defaultCacheTime if 0
func (c *Client) getTTL(dns *dns.Msg) time.Duration {
ttl := uint32(0)
ttl := uint32(math.MaxInt32)
for _, a := range dns.Answer {
if a.Header().Ttl > ttl {
ttl = a.Header().Ttl
}
ttl = min(ttl, a.Header().Ttl)
}

if ttl == 0 {
Expand Down
30 changes: 20 additions & 10 deletions resolver/caching_resolver.go
Expand Up @@ -2,6 +2,7 @@ package resolver

import (
"fmt"
"math"
"sync/atomic"
"time"

Expand Down Expand Up @@ -173,9 +174,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
resp.Rcode = val.resultMsg.Rcode

// Adjust TTL
for _, rr := range resp.Answer {
rr.Header().Ttl = uint32(ttl.Seconds())
}
setTTLInCachedResponse(resp, ttl)

if resp.Rcode == dns.RcodeSuccess {
return &model.Response{Res: resp, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil
Expand All @@ -198,6 +197,18 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
return response, err
}

func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
minTTL := uint32(math.MaxInt32)
// find smallest TTL first
for _, rr := range resp.Answer {
minTTL = min(minTTL, rr.Header().Ttl)
}

for _, rr := range resp.Answer {
rr.Header().Ttl = rr.Header().Ttl - minTTL + uint32(ttl.Seconds())
}
}

func (r *CachingResolver) trackQueryDomainNameCount(domain, cacheKey string, logger *logrus.Entry) {
if r.prefetchingNameCache != nil {
var domainCount int
Expand Down Expand Up @@ -256,16 +267,15 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,

if publish && r.redisClient != nil {
res := *respCopy
res.Answer = response.Res.Answer
r.redisClient.PublishCache(cacheKey, &res)
}
}

// adjustTTLs calculates and returns the max TTL (considers also the min and max cache time)
// adjustTTLs calculates and returns the min TTL (considers also the min and max cache time)
// for all records from answer or a negative cache time for empty answer
// adjust the TTL in the answer header accordingly
func (r *CachingResolver) adjustTTLs(answer []dns.RR) (maxTTL time.Duration) {
var max uint32
func (r *CachingResolver) adjustTTLs(answer []dns.RR) (ttl time.Duration) {
minTTL := uint32(math.MaxInt32)

if len(answer) == 0 {
return r.cfg.CacheTimeNegative.ToDuration()
Expand All @@ -286,12 +296,12 @@ func (r *CachingResolver) adjustTTLs(answer []dns.RR) (maxTTL time.Duration) {
}

headerTTL := atomic.LoadUint32(&a.Header().Ttl)
if max < headerTTL {
max = headerTTL
if minTTL > headerTTL {
minTTL = headerTTL
}
}

return time.Duration(max) * time.Second
return time.Duration(minTTL) * time.Second
}

func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) {
Expand Down
55 changes: 55 additions & 0 deletions resolver/caching_resolver_test.go
@@ -1,6 +1,7 @@
package resolver

import (
"fmt"
"time"

"github.com/0xERR0R/blocky/cache/expirationcache"
Expand Down Expand Up @@ -136,6 +137,60 @@ var _ = Describe("CachingResolver", func() {
})
})
})
When("caching with default values is enabled", func() {
BeforeEach(func() {
rr1, err := dns.NewRR(fmt.Sprintf("%s\t%d\tIN\t%s\t%s", "example.com.", 600, A, "1.2.3.4"))
Expect(err).Should(Succeed())

rr2, err := dns.NewRR(fmt.Sprintf("%s\t%d\tIN\t%s\t%s", "example.com.", 950, CNAME, "cname.example.com"))
Expect(err).Should(Succeed())

msg := new(dns.Msg)
msg.Answer = []dns.RR{rr1, rr2}
mockAnswer = msg
})
It("should cache response and use response's TTL for multiple records", func() {
By("first request", func() {
result, err := sut.Resolve(newRequest("example.com.", A))
Expect(err).Should(Succeed())
Expect(result).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
WithTransform(ToAnswer, SatisfyAll(
HaveLen(2),
)),
))

Expect(result.Res.Answer[0]).Should(HaveTTL(BeNumerically("==", 600)))
Expect(result.Res.Answer[1]).Should(HaveTTL(BeNumerically("==", 950)))

Expect(m.Calls).Should(HaveLen(1))
})

By("second request", func() {
Eventually(func(g Gomega) {
result, err := sut.Resolve(newRequest("example.com.", A))
g.Expect(err).Should(Succeed())
g.Expect(result).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReturnCode(dns.RcodeSuccess),
WithTransform(ToAnswer, SatisfyAll(
HaveLen(2),
))))

g.Expect(result.Res.Answer[0]).Should(HaveTTL(BeNumerically("<=", 599)))
g.Expect(result.Res.Answer[1]).Should(HaveTTL(BeNumerically("<=", 949)))

// still one call to upstream
g.Expect(m.Calls).Should(HaveLen(1))
}, "1s").Should(Succeed())
})
})
})
When("min caching time is defined", func() {
BeforeEach(func() {
sutConfig = config.CachingConfig{
Expand Down

0 comments on commit f988593

Please sign in to comment.