Skip to content

Commit

Permalink
upstream: fix goroutines leak
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Nov 23, 2023
1 parent 65b5293 commit dc6b77c
Showing 1 changed file with 47 additions and 42 deletions.
89 changes: 47 additions & 42 deletions upstream/parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package upstream

import (
"context"
"fmt"
"net/netip"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)

const (
Expand All @@ -21,37 +23,40 @@ const (

// ExchangeParallel returns the dirst successful response from one of u. It
// returns an error if all upstreams failed to exchange the request.
func ExchangeParallel(u []Upstream, req *dns.Msg) (reply *dns.Msg, resolved Upstream, err error) {
upsNum := len(u)
func ExchangeParallel(ups []Upstream, req *dns.Msg) (reply *dns.Msg, resolved Upstream, err error) {
upsNum := len(ups)
switch upsNum {
case 0:
return nil, nil, ErrNoUpstreams
case 1:
reply, err = exchangeAndLog(u[0], req)
reply, err = exchangeAndLog(ups[0], req)

return reply, u[0], err
return reply, ups[0], err
default:
// Go on.
}

resCh := make(chan *ExchangeAllResult)
errCh := make(chan error)
for _, f := range u {
go exchangeAsync(f, req, resCh, errCh)
resCh := make(chan any, upsNum)
for _, f := range ups {
go exchangeAsync(f, req, resCh)
}

errs := []error{}
for range u {
select {
case excErr := <-errCh:
errs = append(errs, excErr)
case rep := <-resCh:
if rep.Resp != nil {
return rep.Resp, rep.Upstream, nil
for range ups {
var r *ExchangeAllResult
r, err = receiveAsyncResult(resCh)
if err != nil {
if !errors.Is(err, ErrNoReply) {
errs = append(errs, err)
}
} else {
return r.Resp, r.Upstream, nil
}
}

// TODO(e.burkov): Probably it's better to return the joined error from
// each upstream that returned no response, and get rid of multiple
// [errors.Is] calls. This will change the behavior though.
if len(errs) == 0 {
return nil, nil, errors.Error("none of upstream servers responded")
}
Expand All @@ -72,8 +77,8 @@ type ExchangeAllResult struct {
// ExchangeAll returns the responses from all of u. It returns an error only if
// all upstreams failed to exchange the request.
func ExchangeAll(ups []Upstream, req *dns.Msg) (res []ExchangeAllResult, err error) {
upsl := len(ups)
switch upsl {
upsNum := len(ups)
switch upsNum {
case 0:
return nil, ErrNoUpstreams
case 1:
Expand All @@ -90,62 +95,60 @@ func ExchangeAll(ups []Upstream, req *dns.Msg) (res []ExchangeAllResult, err err
// Go on.
}

res = make([]ExchangeAllResult, 0, upsl)
res = make([]ExchangeAllResult, 0, upsNum)
var errs []error

resCh := make(chan *ExchangeAllResult)
errCh := make(chan error)
resCh := make(chan any, upsNum)

// Start exchanging concurrently.
for _, u := range ups {
go exchangeAsync(u, req, resCh, errCh)
go exchangeAsync(u, req, resCh)
}

// Wait for all exchanges to finish.
for range ups {
var r *ExchangeAllResult
r, err = receiveAsyncResult(resCh, errCh)
r, err = receiveAsyncResult(resCh)
if err != nil {
errs = append(errs, err)
} else {
res = append(res, *r)
}
}

if len(errs) == upsl {
if len(errs) == upsNum {
// TODO(e.burkov): Use [errors.Join] in Go 1.20.
return res, errors.List("all upstreams failed to exchange", errs...)
}

return res, nil
return slices.Clip(res), nil
}

// receiveAsyncResult receives a single result from resCh or an error from
// errCh. It returns either a non-nil result or an error.
func receiveAsyncResult(
resCh chan *ExchangeAllResult,
errCh chan error,
) (res *ExchangeAllResult, err error) {
select {
case err = <-errCh:
return nil, err
case rep := <-resCh:
if rep.Resp == nil {
func receiveAsyncResult(resCh chan any) (res *ExchangeAllResult, err error) {
switch res := (<-resCh).(type) {
case error:
return nil, res
case *ExchangeAllResult:
if res.Resp == nil {
return nil, ErrNoReply
}

return rep, nil
return res, nil
default:
return nil, fmt.Errorf("unexpected type %T of result", res)
}
}

// exchangeAsync tries to resolve DNS request with one upstream and sends the
// result to respCh.
func exchangeAsync(u Upstream, req *dns.Msg, respCh chan *ExchangeAllResult, errCh chan error) {
func exchangeAsync(u Upstream, req *dns.Msg, resCh chan any) {
reply, err := exchangeAndLog(u, req)
if err != nil {
errCh <- err
resCh <- err
} else {
respCh <- &ExchangeAllResult{Resp: reply, Upstream: u}
resCh <- &ExchangeAllResult{Resp: reply, Upstream: u}
}
}

Expand All @@ -156,12 +159,14 @@ func exchangeAndLog(u Upstream, req *dns.Msg) (resp *dns.Msg, err error) {

start := time.Now()
reply, err := u.Exchange(req)
elapsed := time.Since(start)
dur := time.Since(start)

if q := &req.Question[0]; err == nil {
log.Debug("dnsproxy: upstream %s exchanged %s successfully in %s", addr, q, elapsed)
} else {
log.Debug("dnsproxy: upstream %s failed to exchange %s in %s: %s", addr, q, elapsed, err)
if len(req.Question) > 0 {
if q := &req.Question[0]; err == nil {
log.Debug("dnsproxy: upstream %s exchanged %s successfully in %s", addr, q, dur)
} else {
log.Debug("dnsproxy: upstream %s failed to exchange %s in %s: %s", addr, q, dur, err)
}
}

return reply, err
Expand Down

0 comments on commit dc6b77c

Please sign in to comment.