From 5c8c7171bfe591f34a18183501492ca271752471 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Fri, 15 May 2026 00:24:11 +0300 Subject: [PATCH 1/4] enable short DNS cache by default to make DP go easy on forwarders --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index 5f1be85..1413886 100644 --- a/main.go +++ b/main.go @@ -475,7 +475,7 @@ func parse_args() *CLIArgs { return nil }) flag.Var(&args.dnsPreferAddress, "dns-prefer-address", "address resolution preference (none/ipv4/ipv6)") - flag.DurationVar(&args.dnsCacheTTL, "dns-cache-ttl", 0, "enable DNS cache with specified fixed TTL") + flag.DurationVar(&args.dnsCacheTTL, "dns-cache-ttl", 10, "enable DNS cache with specified fixed TTL") flag.DurationVar(&args.dnsCacheNegTTL, "dns-cache-neg-ttl", time.Second, "TTL for negative responses of DNS cache") flag.DurationVar(&args.dnsCacheTimeout, "dns-cache-timeout", 5*time.Second, "timeout for shared resolves of DNS cache") flag.DurationVar(&args.reqHeaderTimeout, "req-header-timeout", 30*time.Second, "amount of time allowed to read request headers") From a84b0dd6fbaef921d39707ea178867d588c0313b Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Fri, 15 May 2026 01:27:57 +0300 Subject: [PATCH 2/4] extend DNS cache to be usable post filter; move resolver construction before any dialers --- dialer/rescache.go | 28 +++++++++++++++------------- main.go | 20 +++++++++++--------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/dialer/rescache.go b/dialer/rescache.go index bf7f0eb..93cf6da 100644 --- a/dialer/rescache.go +++ b/dialer/rescache.go @@ -28,20 +28,20 @@ type resolverCacheValue struct { } type NameResolveCachingDialer struct { - resolver Resolver - cache secache.Cache[resolverCacheKey, *resolverCacheValue] - sf singleflight.Group - posTTL time.Duration - negTTL time.Duration - timeout time.Duration - next Dialer + resolver Resolver + preFilter bool + cache secache.Cache[resolverCacheKey, *resolverCacheValue] + sf singleflight.Group + posTTL time.Duration + negTTL time.Duration + timeout time.Duration + next Dialer } -func NewNameResolveCachingDialer(next Dialer, resolver Resolver, posTTL, negTTL, timeout time.Duration) *NameResolveCachingDialer { - // func(c *ttlcache.Cache[resolverCacheKey, resolverCacheValue], key resolverCacheKey) *ttlcache.Item[resolverCacheKey, resolverCacheValue] { - // }, +func NewNameResolveCachingDialer(next Dialer, preFilter bool, resolver Resolver, posTTL, negTTL, timeout time.Duration) *NameResolveCachingDialer { return &NameResolveCachingDialer{ - resolver: resolver, + resolver: resolver, + preFilter: preFilter, cache: *(secache.New[resolverCacheKey, *resolverCacheValue]( 3, func(key resolverCacheKey, item *resolverCacheValue) bool { @@ -56,7 +56,7 @@ func NewNameResolveCachingDialer(next Dialer, resolver Resolver, posTTL, negTTL, } func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - if WantsHostname(ctx, network, address, nrcd.next) { + if nrcd.preFilter && WantsHostname(ctx, network, address, nrcd.next) { return nrcd.next.DialContext(ctx, network, address) } @@ -116,7 +116,9 @@ func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network, return nil, res.err } - ctx = dto.OrigDstToContext(ctx, address) + if nrcd.preFilter { + ctx = dto.OrigDstToContext(ctx, address) + } var dialErr error var conn net.Conn diff --git a/main.go b/main.go index 1413886..afa4a39 100644 --- a/main.go +++ b/main.go @@ -660,6 +660,16 @@ func run() int { filterRoot = access.NewDstAddrFilter(args.denyDstAddr.Value(), filterRoot) } + var nameResolver dialer.Resolver = net.DefaultResolver + if len(args.dnsServers) > 0 { + nameResolver, err = resolver.FastFromURLs(args.dnsServers...) + if err != nil { + mainLogger.Critical("Failed to create name resolver: %v", err) + return 3 + } + } + nameResolver = resolver.Prefer(nameResolver, args.dnsPreferAddress.Value()) + // construct dialers var dialerRoot dialer.Dialer = dialer.NewBoundDialer(new(net.Dialer), args.sourceIPHints) if len(args.proxy) > 0 { @@ -694,18 +704,10 @@ func run() int { dialerRoot = dialer.NewFilterDialer(filterRoot.Access, dialerRoot) // must follow after resolving in chain - var nameResolver dialer.Resolver = net.DefaultResolver - if len(args.dnsServers) > 0 { - nameResolver, err = resolver.FastFromURLs(args.dnsServers...) - if err != nil { - mainLogger.Critical("Failed to create name resolver: %v", err) - return 3 - } - } - nameResolver = resolver.Prefer(nameResolver, args.dnsPreferAddress.Value()) if args.dnsCacheTTL > 0 { dialerRoot = dialer.NewNameResolveCachingDialer( dialerRoot, + true, nameResolver, args.dnsCacheTTL, args.dnsCacheNegTTL, From 7ef0b39bc25373ef7e86fb90e98199552034ea00 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Fri, 15 May 2026 01:38:34 +0300 Subject: [PATCH 3/4] cache DNS resolution for proxy connections as well --- main.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/main.go b/main.go index afa4a39..8598ae7 100644 --- a/main.go +++ b/main.go @@ -660,6 +660,7 @@ func run() int { filterRoot = access.NewDstAddrFilter(args.denyDstAddr.Value(), filterRoot) } + // setup name resolution var nameResolver dialer.Resolver = net.DefaultResolver if len(args.dnsServers) > 0 { nameResolver, err = resolver.FastFromURLs(args.dnsServers...) @@ -672,6 +673,16 @@ func run() int { // construct dialers var dialerRoot dialer.Dialer = dialer.NewBoundDialer(new(net.Dialer), args.sourceIPHints) + if args.dnsCacheTTL > 0 { + dialerRoot = dialer.NewNameResolveCachingDialer( + dialerRoot, + false, + nameResolver, + args.dnsCacheTTL, + args.dnsCacheNegTTL, + args.dnsCacheTimeout, + ) + } if len(args.proxy) > 0 { for _, proxy := range args.proxy { if proxy.literal { From 5954fef750d1cdacccae52e404665dab4dfb062e Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Fri, 15 May 2026 02:14:34 +0300 Subject: [PATCH 4/4] refactor name caching from dialer to resolver level --- dialer/rescache.go | 103 ++++++++++----------------------------------- main.go | 32 +++++--------- 2 files changed, 33 insertions(+), 102 deletions(-) diff --git a/dialer/rescache.go b/dialer/rescache.go index 93cf6da..f6f79bc 100644 --- a/dialer/rescache.go +++ b/dialer/rescache.go @@ -2,18 +2,12 @@ package dialer import ( "context" - "errors" - "fmt" - "net" "net/netip" "strings" "time" "codeberg.org/yarmak/secache" - "github.com/hashicorp/go-multierror" "golang.org/x/sync/singleflight" - - "github.com/SenseUnit/dumbproxy/dialer/dto" ) type resolverCacheKey struct { @@ -27,21 +21,18 @@ type resolverCacheValue struct { err error } -type NameResolveCachingDialer struct { - resolver Resolver - preFilter bool - cache secache.Cache[resolverCacheKey, *resolverCacheValue] - sf singleflight.Group - posTTL time.Duration - negTTL time.Duration - timeout time.Duration - next Dialer +type CachingResolver struct { + next Resolver + cache secache.Cache[resolverCacheKey, *resolverCacheValue] + sf singleflight.Group + posTTL time.Duration + negTTL time.Duration + timeout time.Duration } -func NewNameResolveCachingDialer(next Dialer, preFilter bool, resolver Resolver, posTTL, negTTL, timeout time.Duration) *NameResolveCachingDialer { - return &NameResolveCachingDialer{ - resolver: resolver, - preFilter: preFilter, +func NewCachingResolver(next Resolver, posTTL, negTTL, timeout time.Duration) *CachingResolver { + return &CachingResolver{ + next: next, cache: *(secache.New[resolverCacheKey, *resolverCacheValue]( 3, func(key resolverCacheKey, item *resolverCacheValue) bool { @@ -51,62 +42,40 @@ func NewNameResolveCachingDialer(next Dialer, preFilter bool, resolver Resolver, posTTL: posTTL, negTTL: negTTL, timeout: timeout, - next: next, } } -func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - if nrcd.preFilter && WantsHostname(ctx, network, address, nrcd.next) { - return nrcd.next.DialContext(ctx, network, address) - } - - host, port, err := net.SplitHostPort(address) - if err != nil { - return nil, fmt.Errorf("failed to extract host and port from %s: %w", address, err) - } - +func (r *CachingResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { if addr, err := netip.ParseAddr(host); err == nil { // literal IP address, just do unmapping - return nrcd.next.DialContext(ctx, network, net.JoinHostPort(addr.Unmap().String(), port)) - } - - var resolveNetwork string - switch network { - case "udp4", "tcp4", "ip4": - resolveNetwork = "ip4" - case "udp6", "tcp6", "ip6": - resolveNetwork = "ip6" - case "udp", "tcp", "ip": - resolveNetwork = "ip" - default: - return nil, fmt.Errorf("resolving dial %q: unsupported network %q", address, network) + return r.next.LookupNetIP(ctx, network, addr.Unmap().String()) } host = strings.ToLower(host) key := resolverCacheKey{ - network: resolveNetwork, + network: network, host: host, } - res, ok := nrcd.cache.GetValidOrDelete(key) + res, ok := r.cache.GetValidOrDelete(key) if !ok { - v, _, _ := nrcd.sf.Do(key.network+":"+key.host, func() (any, error) { - ctx, cl := context.WithTimeout(context.Background(), nrcd.timeout) + v, _, _ := r.sf.Do(key.network+":"+key.host, func() (any, error) { + ctx, cl := context.WithTimeout(context.Background(), r.timeout) defer cl() - res, err := nrcd.resolver.LookupNetIP(ctx, key.network, key.host) + res, err := r.next.LookupNetIP(ctx, key.network, key.host) for i := range res { res[i] = res[i].Unmap() } - setTTL := nrcd.negTTL + setTTL := r.negTTL if err == nil { - setTTL = nrcd.posTTL + setTTL = r.posTTL } item := &resolverCacheValue{ expires: time.Now().Add(setTTL), addrs: res, err: err, } - nrcd.cache.Set(key, item) + r.cache.Set(key, item) return item, nil }) res = v.(*resolverCacheValue) @@ -116,35 +85,7 @@ func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network, return nil, res.err } - if nrcd.preFilter { - ctx = dto.OrigDstToContext(ctx, address) - } - - var dialErr error - var conn net.Conn - - for _, ip := range res.addrs { - conn, err = nrcd.next.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) - if err == nil { - return conn, nil - } - dialErr = multierror.Append(dialErr, err) - var sae dto.StopAddressIteration - if errors.As(err, &sae) { - break - } - } - - return nil, fmt.Errorf("failed to dial %s: %w", address, dialErr) -} - -func (nrcd *NameResolveCachingDialer) Dial(network, address string) (net.Conn, error) { - return nrcd.DialContext(context.Background(), network, address) -} - -func (nrcd *NameResolveCachingDialer) WantsHostname(ctx context.Context, net, address string) bool { - return WantsHostname(ctx, net, address, nrcd.next) + return res.addrs, nil } -var _ Dialer = new(NameResolveCachingDialer) -var _ HostnameWanter = new(NameResolveCachingDialer) +var _ Resolver = new(CachingResolver) diff --git a/main.go b/main.go index 8598ae7..b746df9 100644 --- a/main.go +++ b/main.go @@ -669,20 +669,20 @@ func run() int { return 3 } } - nameResolver = resolver.Prefer(nameResolver, args.dnsPreferAddress.Value()) - - // construct dialers - var dialerRoot dialer.Dialer = dialer.NewBoundDialer(new(net.Dialer), args.sourceIPHints) if args.dnsCacheTTL > 0 { - dialerRoot = dialer.NewNameResolveCachingDialer( - dialerRoot, - false, + nameResolver = dialer.NewCachingResolver( nameResolver, args.dnsCacheTTL, args.dnsCacheNegTTL, args.dnsCacheTimeout, ) } + nameResolver = resolver.Prefer(nameResolver, args.dnsPreferAddress.Value()) + + // construct dialers + var dialerRoot dialer.Dialer = dialer.NewBoundDialer(new(net.Dialer), args.sourceIPHints) + // this resolving dialer resolves dials unconditionally, for sake of cache or resolving privacy + dialerRoot = dialer.NewNameResolvingDialer(dialerRoot, nameResolver) if len(args.proxy) > 0 { for _, proxy := range args.proxy { if proxy.literal { @@ -713,20 +713,10 @@ func run() int { } } - dialerRoot = dialer.NewFilterDialer(filterRoot.Access, dialerRoot) // must follow after resolving in chain - - if args.dnsCacheTTL > 0 { - dialerRoot = dialer.NewNameResolveCachingDialer( - dialerRoot, - true, - nameResolver, - args.dnsCacheTTL, - args.dnsCacheNegTTL, - args.dnsCacheTimeout, - ) - } else { - dialerRoot = dialer.NewNameResolvingDialer(dialerRoot, nameResolver) - } + dialerRoot = dialer.NewFilterDialer(filterRoot.Access, dialerRoot) + // this resolving dialer resolves dials conditionally (unless upstream dialer tells not to) + // for sake of access filtering by destination address + dialerRoot = dialer.NewNameResolvingDialer(dialerRoot, nameResolver) // unholy plug if args.tt {