Commit

clean up pgxcommon
We used a "transaction factory" everywhere but never used it to create a transaction (ever since we switched to implicit transactions).

This removes the extra abstraction. It also gives the DBReader interface a better name, since it can be used for more than just reading.
ecordell committed May 18, 2023
1 parent 5746c03 commit 6b5924d
Showing 22 changed files with 193 additions and 223 deletions.
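
For context, here is a minimal sketch of the shape of this change, using hypothetical names and signatures inferred only from the call sites in the diffs below (it is not the actual pgxcommon code): the per-call "transaction factory" and its cleanup function are dropped, and the reader holds a DBFuncQuerier directly.

```go
package main

import "context"

// rowScanner stands in for pgx.Row so the sketch is self-contained.
type rowScanner interface{ Scan(dest ...any) error }

// DBFuncQuerier approximates the renamed interface; the signature is
// inferred from the QueryRowFunc call sites in this diff and may not match exactly.
type DBFuncQuerier interface {
	QueryRowFunc(ctx context.Context, f func(ctx context.Context, row rowScanner) error, sql string, args ...any) error
}

// Before: every read went through a factory that, since the switch to
// implicit transactions, only ever returned the pool and a no-op cleanup.
type txFactory func(ctx context.Context) (DBFuncQuerier, func(context.Context), error)

type readerBefore struct{ txSource txFactory }

func (r *readerBefore) readOne(ctx context.Context, sql string, args ...any) error {
	tx, cleanup, err := r.txSource(ctx)
	if err != nil {
		return err
	}
	defer cleanup(ctx)
	return tx.QueryRowFunc(ctx, func(ctx context.Context, row rowScanner) error {
		return nil // scanning elided
	}, sql, args...)
}

// After: the querier is stored directly, so both the factory and the
// cleanup disappear from every call site.
type readerAfter struct{ query DBFuncQuerier }

func (r *readerAfter) readOne(ctx context.Context, sql string, args ...any) error {
	return r.query.QueryRowFunc(ctx, func(ctx context.Context, row rowScanner) error {
		return nil // scanning elided
	}, sql, args...)
}

func main() {}
```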
16 changes: 2 additions & 14 deletions internal/datastore/crdb/caveat.go
@@ -40,16 +40,10 @@ func (cr *crdbReader) ReadCaveatByName(ctx context.Context, name string) (*core.
return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err)
}

tx, txCleanup, err := cr.txSource(ctx)
if err != nil {
return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err)
}
defer txCleanup(ctx)

var definitionBytes []byte
var timestamp time.Time

err = tx.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
err = cr.query.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
return row.Scan(&definitionBytes, &timestamp)
}, sql, args...)
if err != nil {
@@ -94,15 +88,9 @@ func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) (
return nil, fmt.Errorf(errListCaveats, err)
}

tx, txCleanup, err := cr.txSource(ctx)
if err != nil {
return nil, fmt.Errorf(errListCaveats, err)
}
defer txCleanup(ctx)

var allDefinitionBytes []bytesAndTimestamp

err = tx.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error {
err = cr.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error {
for rows.Next() {
var defBytes []byte
var name string
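
The lookupCaveats hunk above now calls QueryFunc on the stored querier and does all row iteration inside the callback. Below is a hedged sketch of that pattern with local stand-ins for the pgx row types; the real signatures and column order may differ.

```go
package main

import (
	"context"
	"time"
)

// rowIterator stands in for pgx.Rows in this sketch.
type rowIterator interface {
	Next() bool
	Scan(dest ...any) error
	Err() error
}

type funcQuerier interface {
	QueryFunc(ctx context.Context, f func(ctx context.Context, rows rowIterator) error, sql string, args ...any) error
}

type bytesAndTimestamp struct {
	definition []byte
	updated    time.Time
}

// collectDefinitions mirrors the shape of lookupCaveats after this change:
// no factory, no cleanup, just the querier and a scanning callback.
func collectDefinitions(ctx context.Context, q funcQuerier, sql string, args ...any) ([]bytesAndTimestamp, error) {
	var all []bytesAndTimestamp
	err := q.QueryFunc(ctx, func(ctx context.Context, rows rowIterator) error {
		for rows.Next() {
			var name string
			var defBytes []byte
			var ts time.Time
			// Column order here is illustrative only.
			if err := rows.Scan(&name, &defBytes, &ts); err != nil {
				return err
			}
			all = append(all, bytesAndTimestamp{definition: defBytes, updated: ts})
		}
		return rows.Err()
	}, sql, args...)
	if err != nil {
		return nil, err
	}
	return all, nil
}

func main() {}
```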
25 changes: 8 additions & 17 deletions internal/datastore/crdb/crdb.go
@@ -246,24 +246,18 @@ type crdbDatastore struct {
}

func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader {
useImplicitTxFunc := func(ctx context.Context) (pgxcommon.DBReader, common.TxCleanupFunc, error) {
return cds.readPool, func(context.Context) {}, nil
}

querySplitter := common.TupleQuerySplitter{
Executor: pgxcommon.NewPGXExecutor(useImplicitTxFunc),
Executor: pgxcommon.NewPGXExecutor(cds.readPool),
UsersetBatchSize: cds.usersetBatchSize,
}

fromBuilder := func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder {
return query.From(fromStr + " AS OF SYSTEM TIME " + rev.String())
}

return &crdbReader{useImplicitTxFunc, querySplitter, noOverlapKeyer, nil, fromBuilder}
return &crdbReader{cds.readPool, querySplitter, noOverlapKeyer, nil, fromBuilder}
}

func noCleanup(context.Context) {}

func (cds *crdbDatastore) ReadWriteTx(
ctx context.Context,
f datastore.TxUserFunc,
@@ -277,18 +271,15 @@ func (cds *crdbDatastore) ReadWriteTx(
}

err := cds.writePool.BeginFunc(ctx, func(tx pgx.Tx) error {
longLivedTx := func(context.Context) (pgxcommon.DBReader, common.TxCleanupFunc, error) {
return pgxcommon.DBReaderFor(tx), noCleanup, nil
}

querier := pgxcommon.QuerierFuncsFor(tx)
querySplitter := common.TupleQuerySplitter{
Executor: pgxcommon.NewPGXExecutor(longLivedTx),
Executor: pgxcommon.NewPGXExecutor(querier),
UsersetBatchSize: cds.usersetBatchSize,
}

rwt := &crdbReadWriteTXN{
&crdbReader{
longLivedTx,
querier,
querySplitter,
cds.writeOverlapKeyer,
make(keySet),
@@ -314,7 +305,7 @@ func (cds *crdbDatastore) ReadWriteTx(

if cds.disableStats {
var err error
commitTimestamp, err = readCRDBNow(ctx, pgxcommon.DBReaderFor(tx))
commitTimestamp, err = readCRDBNow(ctx, querier)
if err != nil {
return fmt.Errorf("error getting commit timestamp: %w", err)
}
@@ -420,7 +411,7 @@ func (cds *crdbDatastore) Features(ctx context.Context) (*datastore.Features, er
return &features, nil
}

func readCRDBNow(ctx context.Context, reader pgxcommon.DBReader) (revision.Decimal, error) {
func readCRDBNow(ctx context.Context, reader pgxcommon.DBFuncQuerier) (revision.Decimal, error) {
ctx, span := tracer.Start(ctx, "readCRDBNow")
defer span.End()

@@ -434,7 +425,7 @@ func readCRDBNow(ctx context.Context, reader pgxcommon.DBReader) (revision.Decim
return revision.NewFromDecimal(hlcNow), nil
}

func readClusterTTLNanos(ctx context.Context, conn *pool.RetryPool) (int64, error) {
func readClusterTTLNanos(ctx context.Context, conn pgxcommon.DBFuncQuerier) (int64, error) {
var target, configSQL string

if err := conn.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
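
In these crdb.go hunks, NewPGXExecutor takes the read pool directly for snapshot reads, while ReadWriteTx wraps the open transaction with QuerierFuncsFor, so both paths hand the executor the same interface. The following is a rough, hypothetical sketch of that adapter idea, not the actual pgxcommon implementation:

```go
package main

import "context"

// scannedRow stands in for pgx.Row.
type scannedRow interface{ Scan(dest ...any) error }

// queryRower is the minimal surface a pool and an open transaction both
// expose in this sketch.
type queryRower interface {
	QueryRow(ctx context.Context, sql string, args ...any) scannedRow
}

// funcQuerier adapts anything with QueryRow into the callback-style
// interface the executor expects (loosely, what QuerierFuncsFor does).
type funcQuerier struct{ q queryRower }

func querierFuncsFor(q queryRower) funcQuerier { return funcQuerier{q: q} }

func (f funcQuerier) QueryRowFunc(ctx context.Context, fn func(ctx context.Context, row scannedRow) error, sql string, args ...any) error {
	return fn(ctx, f.q.QueryRow(ctx, sql, args...))
}

func main() {}
```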
4 changes: 2 additions & 2 deletions internal/datastore/crdb/options.go
@@ -15,7 +15,7 @@ type crdbOptions struct {
followerReadDelay time.Duration
maxRevisionStalenessPercent float64
gcWindow time.Duration
maxRetries uint32
maxRetries uint8
splitAtUsersetCount uint16
overlapStrategy string
overlapKey string
@@ -259,7 +259,7 @@ func GCWindow(window time.Duration) Option {
// MaxRetries is the maximum number of times a retriable transaction will be
// client-side retried.
// Default: 5
func MaxRetries(maxRetries uint32) Option {
func MaxRetries(maxRetries uint8) Option {
return func(po *crdbOptions) { po.maxRetries = maxRetries }
}

42 changes: 31 additions & 11 deletions internal/datastore/crdb/pool/balancer.go
@@ -4,7 +4,6 @@ import (
"context"
"hash/maphash"
"math/rand"
"runtime"
"strconv"
"time"

@@ -13,17 +12,27 @@ import (
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/sync/semaphore"

log "github.com/authzed/spicedb/internal/logging"
)

var connectionsPerCRDBNodeCountGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "crdb_connections_per_node",
Help: "the number of connections spicedb has to each crdb node",
}, []string{"pool", "node_id"})
var (
connectionsPerCRDBNodeCountGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "crdb_connections_per_node",
Help: "the number of connections spicedb has to each crdb node",
}, []string{"pool", "node_id"})

pruningTimeHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "crdb_pruning_duration",
Help: "milliseconds spent on one iteration of pruning excess connections",
Buckets: []float64{.1, .2, .5, 1, 2, 5, 10, 20, 50, 100},
}, []string{"pool"})
)

func init() {
prometheus.MustRegister(connectionsPerCRDBNodeCountGauge)
prometheus.MustRegister(pruningTimeHistogram)
}

type balancePoolConn[C balanceConn] interface {
@@ -42,7 +51,7 @@ type balanceablePool[P balancePoolConn[C], C balanceConn] interface {
ID() string
AcquireAllIdle(ctx context.Context) []P
Node(conn C) uint32
GC(conn C) uint32
GC(conn C)
MaxConns() uint32
Range(func(conn C, nodeID uint32))
}
@@ -64,6 +73,7 @@ func NewNodeConnectionBalancer(pool *RetryPool, healthTracker *NodeHealthTracker
// testing purposes. Callers should use the exported NodeConnectionBalancer
type nodeConnectionBalancer[P balancePoolConn[C], C balanceConn] struct {
ticker *time.Ticker
sem *semaphore.Weighted
pool balanceablePool[P, C]
healthTracker *NodeHealthTracker
rnd *rand.Rand
@@ -76,6 +86,7 @@ func newNodeConnectionBalancer[P balancePoolConn[C], C balanceConn](pool balance
seed := int64(new(maphash.Hash).Sum64())
return &nodeConnectionBalancer[P, C]{
ticker: time.NewTicker(interval),
sem: semaphore.NewWeighted(1),
healthTracker: healthTracker,
pool: pool,
seed: seed,
@@ -91,8 +102,12 @@ func (p *nodeConnectionBalancer[P, C]) Prune(ctx context.Context) {
p.ticker.Stop()
return
case <-p.ticker.C:
p.pruneConnections(ctx)
runtime.GC()
if p.sem.TryAcquire(1) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
p.pruneConnections(ctx)
cancel()
p.sem.Release(1)
}
}
}
}
@@ -101,6 +116,10 @@ func (p *nodeConnectionBalancer[P, C]) Prune(ctx context.Context) {
// This causes the pool to reconnect, which over time will lead to a balanced number of connections
// across each node.
func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
start := time.Now()
defer func() {
pruningTimeHistogram.WithLabelValues(p.pool.ID()).Observe(float64(time.Since(start).Milliseconds()))
}()
conns := p.pool.AcquireAllIdle(ctx)
defer func() {
// release all acquired idle conns back
@@ -139,7 +158,7 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
Msg("connections per node")

// Delete metrics for nodes we no longer have connections for
p.healthTracker.Lock()
p.healthTracker.RLock()
for node := range p.healthTracker.nodesEverSeen {
// TODO: does this handle network interruptions correctly?
if _, ok := connectionCounts[node]; !ok {
Expand All @@ -149,7 +168,7 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
})
}
}
p.healthTracker.Unlock()
p.healthTracker.RUnlock()

nodes := maps.Keys(connectionCounts)
slices.Sort(nodes)
@@ -163,14 +182,15 @@ func (p *nodeConnectionBalancer[P, C]) pruneConnections(ctx context.Context) {
nodes[j], nodes[i] = nodes[i], nodes[j]
})

initialPerNodeMax := p.pool.MaxConns() / nodeCount
for i, node := range nodes {
count := connectionCounts[node]
connectionsPerCRDBNodeCountGauge.WithLabelValues(
p.pool.ID(),
strconv.FormatUint(uint64(node), 10),
).Set(float64(count))

perNodeMax := p.pool.MaxConns() / nodeCount
perNodeMax := initialPerNodeMax

// Assign MaxConns%(# of nodes) nodes an extra connection. This ensures that
// the sum of all perNodeMax values exactly equals the pool MaxConns.
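
The balancer hunks replace the unconditional prune plus runtime.GC() on every tick with a semaphore guard, a 10-second timeout, and a pruning-duration histogram. A minimal sketch of that loop pattern follows; the type and metric names are illustrative, and the real nodeConnectionBalancer carries more state.

```go
package main

import (
	"context"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/sync/semaphore"
)

// pruneDuration mirrors the new crdb_pruning_duration histogram in spirit;
// the name here is illustrative.
var pruneDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
	Name:    "example_pruning_duration",
	Help:    "milliseconds spent on one iteration of pruning excess connections",
	Buckets: []float64{.1, .2, .5, 1, 2, 5, 10, 20, 50, 100},
}, []string{"pool"})

type pruner struct {
	ticker *time.Ticker
	sem    *semaphore.Weighted // weight 1: at most one prune in flight
	poolID string
}

func newPruner(interval time.Duration, poolID string) *pruner {
	return &pruner{
		ticker: time.NewTicker(interval),
		sem:    semaphore.NewWeighted(1),
		poolID: poolID,
	}
}

// run loops until ctx is done, pruning on each tick unless a previous prune
// is still running, and bounding each prune with a 10-second timeout.
func (p *pruner) run(ctx context.Context, pruneOnce func(context.Context)) {
	for {
		select {
		case <-ctx.Done():
			p.ticker.Stop()
			return
		case <-p.ticker.C:
			if p.sem.TryAcquire(1) {
				start := time.Now()
				pruneCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
				pruneOnce(pruneCtx)
				cancel()
				pruneDuration.WithLabelValues(p.poolID).Observe(float64(time.Since(start).Milliseconds()))
				p.sem.Release(1)
			}
		}
	}
}

func main() {}
```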
56 changes: 56 additions & 0 deletions internal/datastore/crdb/pool/balancer_test.go
@@ -24,6 +24,62 @@ func TestNodeConnectionBalancerPrune(t *testing.T) {
conns: []uint32{1, 1, 1, 2, 2, 2, 3, 3, 3},
expectedGC: []uint32{},
},
{
name: "no extra, max 1",
nodes: []uint32{1, 2, 3},
maxConns: 1,
conns: []uint32{1},
expectedGC: []uint32{},
},
{
name: "prune 1, max 1",
nodes: []uint32{1, 2, 3},
maxConns: 1,
conns: []uint32{1, 2},
expectedGC: []uint32{2},
},
{
name: "no extra, max 2",
nodes: []uint32{1, 2, 3},
maxConns: 2,
conns: []uint32{1, 2},
expectedGC: []uint32{},
},
{
name: "prune 1, max 2",
nodes: []uint32{1, 2, 3},
maxConns: 2,
conns: []uint32{1, 2, 3},
expectedGC: []uint32{3},
},
{
name: "no extra, max 1 per node",
nodes: []uint32{1, 2, 3},
maxConns: 3,
conns: []uint32{1, 2, 3},
expectedGC: []uint32{},
},
{
name: "1 extra, max 1 per node",
nodes: []uint32{1, 2, 3},
maxConns: 3,
conns: []uint32{1, 2, 2, 3},
expectedGC: []uint32{2},
},
{
name: "no extra, max 2 per node",
nodes: []uint32{1, 2, 3},
maxConns: 6,
conns: []uint32{1, 1, 2, 2, 3, 3},
expectedGC: []uint32{},
},
{
name: "1 extra, max 2 per node",
nodes: []uint32{1, 2, 3},
maxConns: 6,
conns: []uint32{1, 1, 2, 2, 3, 3, 3},
expectedGC: []uint32{3},
},
{
name: "1 extra, prune 1",
nodes: []uint32{1, 2, 3},
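
The new test cases above pin down the per-node connection budget: MaxConns is divided evenly across the known nodes, and, per the comment in the pruneConnections hunk, MaxConns % nodeCount of the nodes get one extra connection so the budgets sum exactly to MaxConns. A small illustrative helper (not the actual pool code) showing that arithmetic:

```go
package main

import "fmt"

// perNodeBudgets splits maxConns across the given nodes: every node gets
// maxConns/len(nodes), and the first maxConns%len(nodes) nodes get one more
// so the budgets sum exactly to maxConns. Assumes at least one node.
func perNodeBudgets(maxConns uint32, nodes []uint32) map[uint32]uint32 {
	nodeCount := uint32(len(nodes))
	base := maxConns / nodeCount
	extra := maxConns % nodeCount
	budgets := make(map[uint32]uint32, nodeCount)
	for i, node := range nodes {
		b := base
		if uint32(i) < extra {
			b++ // this node absorbs part of the remainder
		}
		budgets[node] = b
	}
	return budgets
}

func main() {
	// 3 nodes, MaxConns=7: budgets 3, 2, 2 (sum = 7). With MaxConns=6 the
	// split is an even 2, 2, 2, matching the "max 2 per node" test cases.
	fmt.Println(perNodeBudgets(7, []uint32{1, 2, 3}))
	fmt.Println(perNodeBudgets(6, []uint32{1, 2, 3}))
}
```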
3 changes: 1 addition & 2 deletions internal/datastore/crdb/pool/fake_test.go
@@ -70,13 +70,12 @@ func (f *FakePool) Node(conn *FakeConn) uint32 {
return id
}

func (f *FakePool) GC(conn *FakeConn) uint32 {
func (f *FakePool) GC(conn *FakeConn) {
f.Lock()
defer f.Unlock()
id := f.nodeForConn[conn]
delete(f.nodeForConn, conn)
f.gc[conn] = id
return id
}

func (f *FakePool) MaxConns() uint32 {
12 changes: 6 additions & 6 deletions internal/datastore/crdb/pool/health.go
@@ -28,7 +28,7 @@ func init() {
//
// Consumers can manually mark a node healthy or unhealthy as well.
type NodeHealthTracker struct {
sync.Mutex
sync.RWMutex
connConfig *pgx.ConnConfig
healthyNodes map[uint32]struct{}
nodesEverSeen map[uint32]struct{}
@@ -77,7 +77,7 @@ func (t *NodeHealthTracker) tryConnect(interval time.Duration) {
if err = conn.Ping(ctx); err != nil {
return
}
log.Ctx(ctx).Info().
log.Ctx(ctx).Trace().
Uint32("nodeID", nodeID(conn)).
Msg("health check connected to node")

@@ -104,15 +104,15 @@ func (t *NodeHealthTracker) SetNodeHealth(nodeID uint32, healthy bool) {

// IsHealthy returns true if the given nodeID has been marked healthy.
func (t *NodeHealthTracker) IsHealthy(nodeID uint32) bool {
t.Lock()
t.RLock()
_, ok := t.healthyNodes[nodeID]
t.Unlock()
t.RUnlock()
return ok
}

// HealthyNodeCount returns the number of healthy nodes currently tracked.
func (t *NodeHealthTracker) HealthyNodeCount() int {
t.Lock()
defer t.Unlock()
t.RLock()
defer t.RUnlock()
return len(t.healthyNodes)
}
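
These health.go hunks move NodeHealthTracker to an RWMutex so read-only accessors such as IsHealthy and HealthyNodeCount no longer serialize behind writers. A minimal, illustrative sketch of that read/write split (not the full tracker):

```go
package main

import "sync"

type healthTracker struct {
	sync.RWMutex
	healthy map[uint32]struct{}
}

func (t *healthTracker) setNodeHealth(node uint32, healthy bool) {
	t.Lock() // exclusive lock: mutates the map
	defer t.Unlock()
	if healthy {
		t.healthy[node] = struct{}{}
	} else {
		delete(t.healthy, node)
	}
}

func (t *healthTracker) isHealthy(node uint32) bool {
	t.RLock() // shared lock: read-only lookups can run concurrently
	defer t.RUnlock()
	_, ok := t.healthy[node]
	return ok
}

func main() {}
```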
