Skip to content

Commit

Permalink
refactor(config): ensure upstreams.timeout is always valid
Browse files Browse the repository at this point in the history
  • Loading branch information
ThinkChaos committed Dec 6, 2023
1 parent 0f69630 commit ef29cdc
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 15 deletions.
1 change: 1 addition & 0 deletions config/config.go
Expand Up @@ -578,6 +578,7 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool {

func (cfg *Config) validate(logger *logrus.Entry) {
cfg.MinTLSServeVer.validate(logger)
cfg.Upstreams.validate(logger)
}

// ConvertPort converts string representation into a valid port (0 - 65535)
Expand Down
11 changes: 10 additions & 1 deletion config/upstreams.go
Expand Up @@ -10,14 +10,23 @@ const UpstreamDefaultCfgName = "default"
// Upstreams upstream servers configuration
type Upstreams struct {
Init Init `yaml:"init"`
Timeout Duration `yaml:"timeout" default:"2s"`
Timeout Duration `yaml:"timeout" default:"2s"` // always > 0
Groups UpstreamGroups `yaml:"groups"`
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
UserAgent string `yaml:"userAgent"`
}

type UpstreamGroups map[string][]Upstream

func (c *Upstreams) validate(logger *logrus.Entry) {
defaults := mustDefault[Upstreams]()

if !c.Timeout.IsAboveZero() {
logger.Warnf("upstreams.timeout <= 0, setting to %s", defaults.Timeout)
c.Timeout = defaults.Timeout
}
}

// IsEnabled implements `config.Configurable`.
func (c *Upstreams) IsEnabled() bool {
return len(c.Groups) != 0
Expand Down
19 changes: 19 additions & 0 deletions config/upstreams_test.go
Expand Up @@ -61,6 +61,25 @@ var _ = Describe("ParallelBestConfig", func() {
))
})
})

Describe("validate", func() {
It("should compute defaults", func() {
cfg.Timeout = -1

cfg.validate(logger)

Expect(cfg.Timeout).Should(BeNumerically(">", 0))

Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout")))
})

It("should not override valid user values", func() {
cfg.validate(logger)

Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("timeout")))
})
})
})

Context("UpstreamGroupConfig", func() {
Expand Down
8 changes: 2 additions & 6 deletions resolver/bootstrap.go
Expand Up @@ -143,12 +143,8 @@ func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string
return ips, nil
}

if b.cfg.timeout.IsAboveZero() {
var cancel context.CancelFunc

ctx, cancel = context.WithTimeout(ctx, b.cfg.timeout.ToDuration())
defer cancel()
}
ctx, cancel := context.WithTimeout(ctx, b.cfg.timeout.ToDuration())
defer cancel()

// Use system resolver if no bootstrap is configured
if b.resolver == nil {
Expand Down
10 changes: 2 additions & 8 deletions resolver/upstream_resolver.go
Expand Up @@ -274,14 +274,8 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
ip = ips.Current()
upstreamURL := r.upstreamClient.fmtURL(ip, r.cfg.Port, r.cfg.Path)

ctx := ctx // make sure we don't overwrite the outer function's context

if r.cfg.Timeout.IsAboveZero() {
var cancel context.CancelFunc

ctx, cancel = context.WithTimeout(ctx, r.cfg.Timeout.ToDuration())
defer cancel()
}
ctx, cancel := context.WithTimeout(ctx, r.cfg.Timeout.ToDuration())
defer cancel()

response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol)
if err != nil {
Expand Down

0 comments on commit ef29cdc

Please sign in to comment.