Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

contrib/database/sql: prevent DBM propagation full mode with incompatible dbs #2328

Merged
merged 9 commits into from
Dec 12, 2023
27 changes: 27 additions & 0 deletions contrib/database/sql/propagation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,33 @@ func TestDBMTraceContextTagging(t *testing.T) {
}
}

func TestDBMPropagation_PreventFullMode(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might make sense to have a test that goes through the implicit driver registration (i.e. when a user doesn't call Register and uses Open directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch! adding it 👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙇

tr := mocktracer.Start()
defer tr.Stop()

// test we prevent full mode with incompatible drivers
driverName := "sqlserver"
opts := []Option{WithDBMPropagation(tracer.DBMPropagationModeFull)}
// use the mock driver, as the real mssql driver does not implement Execer and Querier interfaces and always falls back
// to Prepare which always uses service propagation mode, so we can't test whether the DBM propagation mode gets downgraded or not.
Register(driverName, &internal.MockDriver{}, opts...)
defer unregister(driverName)

db, err := Open(driverName, "sqlserver://sa:myPassw0rd@127.0.0.1:1433?database=master", opts...)
require.NoError(t, err)
defer db.Close()

s, ctx := tracer.StartSpanFromContext(context.Background(), "test.call", tracer.WithSpanID(1))
_, err = db.ExecContext(ctx, "SELECT * FROM INFORMATION_SCHEMA.TABLES")
require.NoError(t, err)
s.Finish()

spans := tr.FinishedSpans()
for _, s := range spansOfType(spans, QueryTypeExec) {
assert.NotContains(t, s.Tags(), keyDBMTraceInjected)
}
}

func spansOfType(spans []mocktracer.Span, spanType string) (filtered []mocktracer.Span) {
filtered = make([]mocktracer.Span, 0)
for _, s := range spans {
Expand Down
34 changes: 28 additions & 6 deletions contrib/database/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ func Register(driverName string, driver driver.Driver, opts ...RegisterOption) {

cfg := new(config)
defaults(cfg, driverName, nil)
for _, fn := range opts {
fn(cfg)
}
processOptions(cfg, driverName, opts...)
log.Debug("contrib/database/sql: Registering driver: %s %#v", driverName, cfg)
registeredDrivers.add(driverName, driver, cfg)
}
Expand Down Expand Up @@ -196,9 +194,7 @@ func OpenDB(c driver.Connector, opts ...Option) *sql.DB {
driverName = reflect.TypeOf(c.Driver()).String()
defaults(cfg, driverName, nil)
}
for _, fn := range opts {
fn(cfg)
}
processOptions(cfg, driverName, opts...)
tc := &tracedConnector{
connector: c,
driverName: driverName,
Expand Down Expand Up @@ -236,3 +232,29 @@ func Open(driverName, dataSourceName string, opts ...Option) (*sql.DB, error) {
}
return OpenDB(&dsnConnector{dsn: dataSourceName, driver: d}, opts...), nil
}

func processOptions(cfg *config, driverName string, opts ...Option) {
for _, fn := range opts {
fn(cfg)
}
checkDBMPropagation(cfg, driverName)
}

func checkDBMPropagation(cfg *config, driverName string) {
fullModeSupported := func() bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious about style: why make this an inner anonymous function rather than a named function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was just a matter of namespacing (this way I could use a shorter function name instead of dbmPropagationFullModeSupported). Anyway, I refactored this code, please take a look at the new version 🙏

unsupportedDrivers := []string{"sqlserver", "oracle"}
for _, dr := range unsupportedDrivers {
if dr == driverName {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks a bit brittle, tell me if I'm wrong, but it looks like driverName is just a key used to retrieve it later and/or make sure init is only done once, and it can take pretty much any value ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Obviously take my perspective with a grain of salt since I'm not an owner/maintainer for the go tracer but I agree that the driverName based logic is somewhat brittle. I know it's done in other places but I wonder if we could still consider doing type checking on the driver.Driver or driver.Connector.

Something like

if connector.Driver().(go_ora.OracleDriver) {
// we know this is one of the oracle drivers
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree! I've added a new way to detect the driver, following this logic:

  1. check driver package path using reflection
  2. check dsn prefix
  3. lastly check the driverName like I was doing in this version

please let me know what you think @vandonr @alexandre-normand

return false
}
}
return true
}
if cfg.dbmPropagationMode == tracer.DBMPropagationModeFull && !fullModeSupported() {
log.Warn("Using DBM_PROPAGATION_MODE in 'full' mode is not supported for %s. See "+
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might make sense here to state that we're downgrading to service mode.

"https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.",
driverName,
)
cfg.dbmPropagationMode = tracer.DBMPropagationModeService
}
}