From 383e53a1a979bf68aca2a306df8b48d12597cb04 Mon Sep 17 00:00:00 2001 From: Tomasz Mielech Date: Tue, 1 Jun 2021 15:51:32 +0200 Subject: [PATCH] test --- cluster/cluster.go | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/cluster/cluster.go b/cluster/cluster.go index fb40d60e..1eb5dff4 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -126,28 +126,32 @@ func (c *clusterConnection) Do(ctx context.Context, req driver.Request) (driver. timeout = c.defaultTimeout } - serverCount := len(c.servers) - var specificServer driver.Connection + var server driver.Connection + var serverCount int + var durationPerRequest time.Duration + if v := ctx.Value(keyEndpoint); v != nil { if endpoint, ok := v.(string); ok { // Override pool to only specific server if it is found if s, ok := c.getSpecificServer(endpoint); ok { + server = s + durationPerRequest = timeout serverCount = 1 - specificServer = s } } } - timeoutDivider := math.Max(1.0, math.Min(3.0, float64(serverCount))) - attempt := 1 - s := specificServer - if s == nil { - s = c.getCurrentServer() + if server == nil { + server, serverCount = c.getCurrentServer() + timeoutDivider := math.Max(1.0, math.Min(3.0, float64(serverCount))) + durationPerRequest = time.Duration(float64(timeout) / timeoutDivider) } + + attempt := 1 for { // Send request to specific endpoint with a 1/3 timeout (so we get 3 attempts) - serverCtx, cancel := context.WithTimeout(ctx, time.Duration(float64(timeout)/timeoutDivider)) - resp, err := s.Do(serverCtx, req) + serverCtx, cancel := context.WithTimeout(ctx, durationPerRequest) + resp, err := server.Do(serverCtx, req) cancel() isNoLeaderResponse := false @@ -162,8 +166,8 @@ func (c *clusterConnection) Do(ctx context.Context, req driver.Request) (driver. err = aerr } } - } + if !isNoLeaderResponse || !followLeaderRedirect { if err == nil { // We're done @@ -189,15 +193,13 @@ func (c *clusterConnection) Do(ctx context.Context, req driver.Request) (driver. // Failed, try next server attempt++ - if specificServer != nil { + if attempt > serverCount { // A specific server was specified, no failover. - return nil, driver.WithStack(err) - } - if attempt > len(c.servers) { + // or // We've tried all servers. Giving up. return nil, driver.WithStack(err) } - s = c.getNextServer() + server = c.getNextServer() } } @@ -321,11 +323,11 @@ func (c *clusterConnection) Protocols() driver.ProtocolSet { return result } -// getCurrentServer returns the currently used server. -func (c *clusterConnection) getCurrentServer() driver.Connection { +// getCurrentServer returns the currently used server and number of servers. +func (c *clusterConnection) getCurrentServer() (driver.Connection, int) { c.mutex.RLock() defer c.mutex.RUnlock() - return c.servers[c.current] + return c.servers[c.current], len(c.servers) } // getSpecificServer returns the server with the given endpoint.