Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remoteconfig: fine grained locking #2458

Merged
merged 9 commits into from
Jan 8, 2024
87 changes: 63 additions & 24 deletions internal/remoteconfig/remoteconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,15 @@ type Client struct {
repository *rc.Repository
stop chan struct{}

callbacks []Callback
products map[string]struct{}
productsWithCallbacks map[string]ProductCallback
capabilities map[Capability]struct{}
// When acquiring several locks and using defer to release them, make sure to acquire the locks in the following order:
callbacks []Callback
_callbacksMu sync.RWMutex
products map[string]struct{}
_productsMu sync.RWMutex
productsWithCallbacks map[string]ProductCallback
_productsWithCallbacksMu sync.RWMutex
capabilities map[Capability]struct{}
_capabilitiesMu sync.RWMutex

lastError error
}
Expand Down Expand Up @@ -243,12 +248,18 @@ func Subscribe(product string, callback ProductCallback, capabilities ...Capabil
if client == nil {
return ErrClientNotStarted
}
client.Lock()
defer client.Unlock()
client._productsMu.RLock()
defer client._productsMu.RUnlock()
if _, found := client.products[product]; found {
return fmt.Errorf("product %s already registered via RegisterProduct", product)
}

client._productsWithCallbacksMu.Lock()
defer client._productsWithCallbacksMu.Unlock()
client.productsWithCallbacks[product] = callback

client._capabilitiesMu.Lock()
defer client._capabilitiesMu.Unlock()
for _, cap := range capabilities {
client.capabilities[cap] = struct{}{}
Hellzy marked this conversation as resolved.
Show resolved Hide resolved
}
Expand All @@ -262,8 +273,8 @@ func RegisterCallback(f Callback) error {
if client == nil {
return ErrClientNotStarted
}
client.Lock()
defer client.Unlock()
client._callbacksMu.Lock()
defer client._callbacksMu.Unlock()
client.callbacks = append(client.callbacks, f)
return nil
}
Expand All @@ -274,12 +285,13 @@ func UnregisterCallback(f Callback) error {
if client == nil {
return ErrClientNotStarted
}
client.Lock()
defer client.Unlock()
client._callbacksMu.Lock()
defer client._callbacksMu.Unlock()
fValue := reflect.ValueOf(f)
for i, callback := range client.callbacks {
if reflect.ValueOf(callback) == fValue {
client.callbacks = append(client.callbacks[:i], client.callbacks[i+1:]...)
break
}
}
return nil
Expand All @@ -290,8 +302,10 @@ func RegisterProduct(p string) error {
if client == nil {
return ErrClientNotStarted
}
client.Lock()
defer client.Unlock()
client._productsMu.Lock()
defer client._productsMu.Unlock()
client._productsWithCallbacksMu.RLock()
defer client._productsWithCallbacksMu.RUnlock()
if _, found := client.productsWithCallbacks[p]; found {
return fmt.Errorf("product %s already registered via Subscribe", p)
}
Expand All @@ -304,8 +318,8 @@ func UnregisterProduct(p string) error {
if client == nil {
return ErrClientNotStarted
}
client.Lock()
defer client.Unlock()
client._productsMu.Lock()
defer client._productsMu.Unlock()
delete(client.products, p)
return nil
}
Expand All @@ -315,8 +329,10 @@ func HasProduct(p string) (bool, error) {
if client == nil {
return false, ErrClientNotStarted
}
client.RLock()
defer client.RUnlock()
client._productsMu.RLock()
defer client._productsMu.RUnlock()
client._productsWithCallbacksMu.RLock()
defer client._productsWithCallbacksMu.RUnlock()
_, found := client.products[p]
_, foundWithCallback := client.productsWithCallbacks[p]
return found || foundWithCallback, nil
Expand All @@ -328,8 +344,8 @@ func RegisterCapability(cap Capability) error {
if client == nil {
return ErrClientNotStarted
}
client.Lock()
defer client.Unlock()
client._capabilitiesMu.Lock()
defer client._capabilitiesMu.Unlock()
client.capabilities[cap] = struct{}{}
return nil
}
Expand All @@ -340,8 +356,8 @@ func UnregisterCapability(cap Capability) error {
if client == nil {
return ErrClientNotStarted
}
client.Lock()
defer client.Unlock()
client._capabilitiesMu.Lock()
defer client._capabilitiesMu.Unlock()
delete(client.capabilities, cap)
return nil
}
Expand All @@ -351,13 +367,35 @@ func HasCapability(cap Capability) (bool, error) {
if client == nil {
return false, ErrClientNotStarted
}
client.RLock()
defer client.RUnlock()
client._capabilitiesMu.RLock()
defer client._capabilitiesMu.RUnlock()
_, found := client.capabilities[cap]
return found, nil
}

func (c *Client) globalCallbacks() []Callback {
c._callbacksMu.RLock()
defer c._callbacksMu.RUnlock()
callbacks := make([]Callback, len(c.callbacks))
copy(callbacks, c.callbacks)
return callbacks
}

func (c *Client) productCallbacks() map[string]ProductCallback {
c._productsWithCallbacksMu.RLock()
defer c._productsWithCallbacksMu.RUnlock()
callbacks := make(map[string]ProductCallback, len(c.productsWithCallbacks))
for k, v := range c.productsWithCallbacks {
callbacks[k] = v
}
return callbacks
}

func (c *Client) allProducts() []string {
c._productsMu.RLock()
defer c._productsMu.RUnlock()
c._productsWithCallbacksMu.RLock()
defer c._productsWithCallbacksMu.RUnlock()
products := make([]string, 0, len(c.products)+len(c.productsWithCallbacks))
for p := range c.products {
products = append(products, p)
Expand Down Expand Up @@ -447,7 +485,7 @@ func (c *Client) applyUpdate(pbUpdate *clientGetConfigsResponse) error {
// 3 - ApplyStateAcknowledged
// This makes sure that any product that would need to re-receive the config in a subsequent update will be allowed to
statuses := make(map[string]rc.ApplyStatus)
for _, fn := range c.callbacks {
for _, fn := range c.globalCallbacks() {
for path, status := range fn(productUpdates) {
if s, ok := statuses[path]; !ok || status.State == rc.ApplyStateError ||
s.State == rc.ApplyStateAcknowledged && status.State == rc.ApplyStateUnacknowledged {
Expand All @@ -456,8 +494,9 @@ func (c *Client) applyUpdate(pbUpdate *clientGetConfigsResponse) error {
}
}
// Call the product-specific callbacks registered via Subscribe
productCallbacks := c.productCallbacks()
for product, update := range productUpdates {
if fn, ok := c.productsWithCallbacks[product]; ok {
if fn, ok := productCallbacks[product]; ok {
for path, status := range fn(update) {
statuses[path] = status
}
Expand Down
72 changes: 72 additions & 0 deletions internal/remoteconfig/remoteconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"crypto/sha256"
"encoding/json"
"fmt"
"math/rand"
"reflect"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -377,3 +379,73 @@ func TestNewUpdateRequest(t *testing.T) {
require.Equal(t, "app-version", req.Client.ClientTracer.AppVersion)
require.True(t, req.Client.IsTracer)
}

// TestAsync starts many goroutines that use the exported client API to make sure no deadlocks occur
func TestAsync(t *testing.T) {
require.NoError(t, Start(DefaultClientConfig()))
defer Stop()
const iterations = 10000
var wg sync.WaitGroup

// Subscriptions
for i := 0; i < iterations; i++ {
product := fmt.Sprintf("%d", rand.Int()%10)
capability := Capability(rand.Uint32() % 10)
wg.Add(1)
go func() {
callback := func(update ProductUpdate) map[string]rc.ApplyStatus { return nil }
Subscribe(product, callback, capability)
wg.Done()
}()
}

// Products
for i := 0; i < iterations; i++ {
wg.Add(1)
go func() {
defer wg.Done()
RegisterProduct(fmt.Sprintf("%d", rand.Int()%10))
}()
}
for i := 0; i < iterations; i++ {
wg.Add(1)
go func() {
defer wg.Done()
UnregisterProduct(fmt.Sprintf("%d", rand.Int()%10))
}()
}

// Capabilities
for i := 0; i < iterations; i++ {
wg.Add(1)
go func() {
defer wg.Done()
RegisterCapability(Capability(rand.Uint32() % 10))
}()
}
for i := 0; i < iterations; i++ {
wg.Add(1)
go func() {
defer wg.Done()
UnregisterCapability(Capability(rand.Uint32() % 10))
}()
}

// Callbacks
callback := func(updates map[string]ProductUpdate) map[string]rc.ApplyStatus { return nil }
for i := 0; i < iterations; i++ {
wg.Add(1)
go func() {
defer wg.Done()
RegisterCallback(callback)
}()
}
for i := 0; i < iterations; i++ {
wg.Add(1)
go func() {
defer wg.Done()
UnregisterCallback(callback)
}()
}
wg.Wait()
}