From 207dd8eac60fd5143186d7fefca5afcec76eaa4c Mon Sep 17 00:00:00 2001 From: Rafael Dantas Justo Date: Mon, 29 Aug 2022 21:45:57 +0100 Subject: [PATCH] contrib/database/sql: Fix race condition when registering drivers Library failing when registering drivers in parallel on different goroutines. Internal `driverRegistry` wasn't goroutine safe, causing race conditions with multiple writes in the internal maps. ``` $ go test -race -run TestRegister WARNING: DATA RACE Write at 0x00c000254000 by goroutine 9: runtime.mapaccessK() /opt/homebrew/Cellar/go/1.19/libexec/src/runtime/map.go:518 +0x1ec gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql.(*driverRegistry).add() /Users/rafael.justo/dev/go/src/github.com/rafaeljusto/dd-trace-go/contrib/database/sql/sql.go:60 +0x144 gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql.TestRegister.func1() /Users/rafael.justo/dev/go/src/github.com/rafaeljusto/dd-trace-go/contrib/database/sql/sql_test.go:282 +0x54 gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql.TestRegister.func2() /Users/rafael.justo/dev/go/src/github.com/rafaeljusto/dd-trace-go/contrib/database/sql/sql_test.go:285 +0x44 Previous write at 0x00c000254000 by goroutine 15: runtime.mapaccessK() /opt/homebrew/Cellar/go/1.19/libexec/src/runtime/map.go:518 +0x1ec gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql.(*driverRegistry).add() /Users/rafael.justo/dev/go/src/github.com/rafaeljusto/dd-trace-go/contrib/database/sql/sql.go:60 +0x144 gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql.TestRegister.func1() /Users/rafael.justo/dev/go/src/github.com/rafaeljusto/dd-trace-go/contrib/database/sql/sql_test.go:282 +0x54 gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql.TestRegister.func2() /Users/rafael.justo/dev/go/src/github.com/rafaeljusto/dd-trace-go/contrib/database/sql/sql_test.go:285 +0x44 ``` --- contrib/database/sql/sql.go | 18 ++++++++++++++++-- contrib/database/sql/sql_test.go | 16 ++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/contrib/database/sql/sql.go b/contrib/database/sql/sql.go index d855f70cc9..2f198deea9 100644 --- a/contrib/database/sql/sql.go +++ b/contrib/database/sql/sql.go @@ -9,11 +9,10 @@ // We start by telling the package which driver we will be using. For example, if we are using "github.com/lib/pq", // we would do as follows: // -// sqltrace.Register("pq", pq.Driver{}) +// sqltrace.Register("pq", pq.Driver{}) // db, err := sqltrace.Open("pq", "postgres://pqgotest:password@localhost...") // // The rest of our application would continue as usual, but with tracing enabled. -// package sql import ( @@ -23,6 +22,7 @@ import ( "errors" "math" "reflect" + "sync" "time" "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql/internal" @@ -44,11 +44,15 @@ type driverRegistry struct { drivers map[string]driver.Driver // configs maps keys to their registered configuration. configs map[string]*config + // mu protects the above maps. + mu sync.RWMutex } // isRegistered reports whether the name matches an existing entry // in the driver registry. func (d *driverRegistry) isRegistered(name string) bool { + d.mu.RLock() + defer d.mu.RUnlock() _, ok := d.configs[name] return ok } @@ -58,6 +62,8 @@ func (d *driverRegistry) add(name string, driver driver.Driver, cfg *config) { if d.isRegistered(name) { return } + d.mu.Lock() + defer d.mu.Unlock() d.keys[reflect.TypeOf(driver)] = name d.drivers[name] = driver d.configs[name] = cfg @@ -65,24 +71,32 @@ func (d *driverRegistry) add(name string, driver driver.Driver, cfg *config) { // name returns the name of the driver stored in the registry. func (d *driverRegistry) name(driver driver.Driver) (string, bool) { + d.mu.RLock() + defer d.mu.RUnlock() name, ok := d.keys[reflect.TypeOf(driver)] return name, ok } // driver returns the driver stored in the registry with the provided name. func (d *driverRegistry) driver(name string) (driver.Driver, bool) { + d.mu.RLock() + defer d.mu.RUnlock() driver, ok := d.drivers[name] return driver, ok } // config returns the config stored in the registry with the provided name. func (d *driverRegistry) config(name string) (*config, bool) { + d.mu.RLock() + defer d.mu.RUnlock() config, ok := d.configs[name] return config, ok } // unregister is used to make tests idempotent. func (d *driverRegistry) unregister(name string) { + d.mu.Lock() + defer d.mu.Unlock() driver := d.drivers[name] delete(d.keys, reflect.TypeOf(driver)) delete(d.configs, name) diff --git a/contrib/database/sql/sql_test.go b/contrib/database/sql/sql_test.go index f4dc52d3b0..b02c00b2de 100644 --- a/contrib/database/sql/sql_test.go +++ b/contrib/database/sql/sql_test.go @@ -13,6 +13,8 @@ import ( "log" "math" "os" + "strconv" + "sync" "testing" "time" @@ -261,3 +263,17 @@ func TestConnectCancelledCtx(t *testing.T) { assert.Equal("hangingConnector.query", s.OperationName()) assert.Equal("Connect", s.Tag("sql.query_type")) } + +func TestRegister(t *testing.T) { + var wg sync.WaitGroup + + for i := 1; i < 10; i++ { + wg.Add(1) + go func(i int64) { + Register("test"+strconv.FormatInt(i, 10), &mysql.MySQLDriver{}) + wg.Done() + }(int64(i)) + } + + wg.Wait() +}