Skip to content

Commit

Permalink
clean up pgxcommon
Browse files Browse the repository at this point in the history
We used a "transaction factory" everywhere, but never used it to create a transaction (ever since we switched to implicit transactions).

This removes the extra abstraction. It also gives the DBReader interface
a better name (DBFuncQuerier), since it can be used for more than just reading.
  • Loading branch information
ecordell committed May 12, 2023
1 parent 0087ebf commit c574f6e
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 140 deletions.
16 changes: 2 additions & 14 deletions internal/datastore/crdb/caveat.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,10 @@ func (cr *crdbReader) ReadCaveatByName(ctx context.Context, name string) (*core.
return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err)
}

tx, txCleanup, err := cr.txSource(ctx)
if err != nil {
return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err)
}
defer txCleanup(ctx)

var definitionBytes []byte
var timestamp time.Time

err = tx.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
err = cr.txSource.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
return row.Scan(&definitionBytes, &timestamp)
}, sql, args...)
if err != nil {
Expand Down Expand Up @@ -94,15 +88,9 @@ func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) (
return nil, fmt.Errorf(errListCaveats, err)
}

tx, txCleanup, err := cr.txSource(ctx)
if err != nil {
return nil, fmt.Errorf(errListCaveats, err)
}
defer txCleanup(ctx)

var allDefinitionBytes []bytesAndTimestamp

err = tx.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error {
err = cr.txSource.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error {
for rows.Next() {
var defBytes []byte
var name string
Expand Down
22 changes: 6 additions & 16 deletions internal/datastore/crdb/crdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,43 +245,33 @@ type crdbDatastore struct {
}

func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader {
useImplicitTxFunc := func(ctx context.Context) (pgxcommon.DBReader, common.TxCleanupFunc, error) {
return cds.readPool, func(context.Context) {}, nil
}

querySplitter := common.TupleQuerySplitter{
Executor: pgxcommon.NewPGXExecutor(useImplicitTxFunc),
Executor: pgxcommon.NewPGXExecutor(cds.readPool),
UsersetBatchSize: cds.usersetBatchSize,
}

fromBuilder := func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder {
return query.From(fromStr + " AS OF SYSTEM TIME " + rev.String())
}

return &crdbReader{useImplicitTxFunc, querySplitter, noOverlapKeyer, nil, fromBuilder}
return &crdbReader{cds.readPool, querySplitter, noOverlapKeyer, nil, fromBuilder}
}

func noCleanup(context.Context) {}

func (cds *crdbDatastore) ReadWriteTx(
ctx context.Context,
f datastore.TxUserFunc,
) (datastore.Revision, error) {
var commitTimestamp revision.Decimal

err := cds.writePool.BeginFunc(ctx, func(tx pgx.Tx) error {
longLivedTx := func(context.Context) (pgxcommon.DBReader, common.TxCleanupFunc, error) {
return pgxcommon.DBReaderFor(tx), noCleanup, nil
}

querySplitter := common.TupleQuerySplitter{
Executor: pgxcommon.NewPGXExecutor(longLivedTx),
Executor: pgxcommon.NewPGXExecutor(pgxcommon.QuerierFuncsFor(tx)),
UsersetBatchSize: cds.usersetBatchSize,
}

rwt := &crdbReadWriteTXN{
&crdbReader{
longLivedTx,
pgxcommon.QuerierFuncsFor(tx),
querySplitter,
cds.writeOverlapKeyer,
make(keySet),
Expand All @@ -307,7 +297,7 @@ func (cds *crdbDatastore) ReadWriteTx(

if cds.disableStats {
var err error
commitTimestamp, err = readCRDBNow(ctx, pgxcommon.DBReaderFor(tx))
commitTimestamp, err = readCRDBNow(ctx, pgxcommon.QuerierFuncsFor(tx))
if err != nil {
return fmt.Errorf("error getting commit timestamp: %w", err)
}
Expand Down Expand Up @@ -413,7 +403,7 @@ func (cds *crdbDatastore) Features(ctx context.Context) (*datastore.Features, er
return &features, nil
}

func readCRDBNow(ctx context.Context, reader pgxcommon.DBReader) (revision.Decimal, error) {
func readCRDBNow(ctx context.Context, reader pgxcommon.DBFuncQuerier) (revision.Decimal, error) {
ctx, span := tracer.Start(ctx, "readCRDBNow")
defer span.End()

Expand Down
32 changes: 7 additions & 25 deletions internal/datastore/crdb/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ var (
)

type crdbReader struct {
txSource pgxcommon.TxFactory
txSource pgxcommon.DBFuncQuerier
querySplitter common.TupleQuerySplitter
keyer overlapKeyer
overlapKeySet keySet
Expand All @@ -59,13 +59,7 @@ func (cr *crdbReader) ReadNamespaceByName(
ctx context.Context,
nsName string,
) (*core.NamespaceDefinition, datastore.Revision, error) {
tx, txCleanup, err := cr.txSource(ctx)
if err != nil {
return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadConfig, err)
}
defer txCleanup(ctx)

config, timestamp, err := cr.loadNamespace(ctx, tx, nsName)
config, timestamp, err := cr.loadNamespace(ctx, cr.txSource, nsName)
if err != nil {
if errors.As(err, &datastore.ErrNamespaceNotFound{}) {
return nil, datastore.NoRevision, err
Expand All @@ -77,13 +71,7 @@ func (cr *crdbReader) ReadNamespaceByName(
}

func (cr *crdbReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) {
tx, txCleanup, err := cr.txSource(ctx)
if err != nil {
return nil, fmt.Errorf(errUnableToListNamespaces, err)
}
defer txCleanup(ctx)

nsDefs, err := loadAllNamespaces(ctx, tx, cr.fromBuilder)
nsDefs, err := loadAllNamespaces(ctx, cr.txSource, cr.fromBuilder)
if err != nil {
return nil, fmt.Errorf(errUnableToListNamespaces, err)
}
Expand All @@ -94,13 +82,7 @@ func (cr *crdbReader) LookupNamespacesWithNames(ctx context.Context, nsNames []s
if len(nsNames) == 0 {
return nil, nil
}
tx, txCleanup, err := cr.txSource(ctx)
if err != nil {
return nil, fmt.Errorf(errUnableToListNamespaces, err)
}
defer txCleanup(ctx)

nsDefs, err := cr.lookupNamespaces(ctx, tx, nsNames)
nsDefs, err := cr.lookupNamespaces(ctx, cr.txSource, nsNames)
if err != nil {
return nil, fmt.Errorf(errUnableToListNamespaces, err)
}
Expand Down Expand Up @@ -149,7 +131,7 @@ func (cr *crdbReader) ReverseQueryRelationships(
options.WithSort(queryOpts.SortForReverse))
}

func (cr crdbReader) loadNamespace(ctx context.Context, tx pgxcommon.DBReader, nsName string) (*core.NamespaceDefinition, time.Time, error) {
func (cr crdbReader) loadNamespace(ctx context.Context, tx pgxcommon.DBFuncQuerier, nsName string) (*core.NamespaceDefinition, time.Time, error) {
query := cr.fromBuilder(queryReadNamespace, tableNamespace).Where(sq.Eq{colNamespace: nsName})

sql, args, err := query.ToSql()
Expand Down Expand Up @@ -178,7 +160,7 @@ func (cr crdbReader) loadNamespace(ctx context.Context, tx pgxcommon.DBReader, n
return loaded, timestamp, nil
}

func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBReader, nsNames []string) ([]datastore.RevisionedNamespace, error) {
func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, nsNames []string) ([]datastore.RevisionedNamespace, error) {
clause := sq.Or{}
for _, nsName := range nsNames {
clause = append(clause, sq.Eq{colNamespace: nsName})
Expand Down Expand Up @@ -224,7 +206,7 @@ func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBReader
return nsDefs, nil
}

func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBReader, fromBuilder func(sq.SelectBuilder, string) sq.SelectBuilder) ([]datastore.RevisionedNamespace, error) {
func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, fromBuilder func(sq.SelectBuilder, string) sq.SelectBuilder) ([]datastore.RevisionedNamespace, error) {
query := fromBuilder(queryReadNamespace, tableNamespace)

sql, args, err := query.ToSql()
Expand Down
2 changes: 1 addition & 1 deletion internal/datastore/crdb/readwrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func (rwt *crdbReadWriteTXN) DeleteNamespaces(ctx context.Context, nsNames ...st
nsClauses := make([]sq.Sqlizer, 0, len(nsNames))
tplClauses := make([]sq.Sqlizer, 0, len(nsNames))
for _, nsName := range nsNames {
_, timestamp, err := rwt.loadNamespace(ctx, pgxcommon.DBReaderFor(rwt.tx), nsName)
_, timestamp, err := rwt.loadNamespace(ctx, pgxcommon.QuerierFuncsFor(rwt.tx), nsName)
if err != nil {
if errors.As(err, &datastore.ErrNamespaceNotFound{}) {
return err
Expand Down
2 changes: 1 addition & 1 deletion internal/datastore/crdb/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro
return fmt.Errorf("unable to read relationship count: %w", err)
}

nsDefs, err = loadAllNamespaces(ctx, pgxcommon.DBReaderFor(tx), func(sb squirrel.SelectBuilder, fromStr string) squirrel.SelectBuilder {
nsDefs, err = loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), func(sb squirrel.SelectBuilder, fromStr string) squirrel.SelectBuilder {
return sb.From(fromStr)
})
if err != nil {
Expand Down
16 changes: 2 additions & 14 deletions internal/datastore/postgres/caveat.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,9 @@ func (r *pgReader) ReadCaveatByName(ctx context.Context, name string) (*core.Cav
return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err)
}

tx, txCleanup, err := r.txSource(ctx)
if err != nil {
return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err)
}
defer txCleanup(ctx)

var txID xid8
var serializedDef []byte
err = tx.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
err = r.txSource.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
return row.Scan(&serializedDef, &txID)
}, sql, args...)
if err != nil {
Expand Down Expand Up @@ -90,14 +84,8 @@ func (r *pgReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]d
return nil, fmt.Errorf(errListCaveats, err)
}

tx, txCleanup, err := r.txSource(ctx)
if err != nil {
return nil, fmt.Errorf(errListCaveats, err)
}
defer txCleanup(ctx)

var caveats []datastore.RevisionedCaveat
err = tx.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error {
err = r.txSource.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error {
for rows.Next() {
var version xid8
var defBytes []byte
Expand Down
41 changes: 12 additions & 29 deletions internal/datastore/postgres/common/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,15 @@ import (
const errUnableToQueryTuples = "unable to query tuples: %w"

// NewPGXExecutor creates an executor that uses the pgx library to make the specified queries.
func NewPGXExecutor(txSource TxFactory) common.ExecuteQueryFunc {
func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteQueryFunc {
return func(ctx context.Context, sql string, args []any) ([]*corev1.RelationTuple, error) {
span := trace.SpanFromContext(ctx)

tx, txCleanup, err := txSource(ctx)
if err != nil {
return nil, fmt.Errorf("error getting tx from source: %w", fmt.Errorf(errUnableToQueryTuples, err))
}
defer txCleanup(ctx)
return queryTuples(ctx, sql, args, span, tx)
return queryTuples(ctx, sql, args, span, querier)
}
}

// queryTuples queries tuples for the given query and transaction.
func queryTuples(ctx context.Context, sqlStatement string, args []any, span trace.Span, tx DBReader) ([]*corev1.RelationTuple, error) {
func queryTuples(ctx context.Context, sqlStatement string, args []any, span trace.Span, tx DBFuncQuerier) ([]*corev1.RelationTuple, error) {
// TODO: this event name is misleading
span.AddEvent("DB transaction established")
var tuples []*corev1.RelationTuple
Expand Down Expand Up @@ -109,17 +103,13 @@ func ConfigurePGXLogger(connConfig *pgx.ConnConfig) {
connConfig.Tracer = &tracelog.TraceLog{Logger: levelMappingFn(l), LogLevel: tracelog.LogLevelInfo}
}

// DBReader copies enough of the common interface between pgxpool and tx to be useful
type DBReader interface {
// DBFuncQuerier is a callback-style query interface satisfied by RetryPool and
// QuerierFuncs (which can wrap a pgxpool or transaction).
type DBFuncQuerier interface {
	// ExecFunc executes sql and passes the resulting command tag (and any error) to tagFunc.
	ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, arguments ...any) error
	// QueryFunc executes sql and passes the resulting rows (and the ctx) to rowsFunc.
	QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, optionsAndArgs ...any) error
	// QueryRowFunc executes a single-row query and passes the resulting row to rowFunc.
	QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, optionsAndArgs ...any) error
}

// TxFactory returns a transaction, cleanup function, and any errors that may have
// occurred when building the transaction.
type TxFactory func(context.Context) (DBReader, common.TxCleanupFunc, error)

// PoolOptions is the set of configuration used for a pgx connection pool.
type PoolOptions struct {
ConnMaxIdleTime *time.Duration
Expand Down Expand Up @@ -166,23 +156,16 @@ func (opts PoolOptions) ConfigurePgx(pgxConfig *pgxpool.Config) {
ConfigurePGXLogger(pgxConfig.ConnConfig)
}

// DirectReader is satisfied by pgx.Tx and ConnPooler
type DirectReader interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, optionsAndArgs ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, optionsAndArgs ...any) pgx.Row
}

type Reader struct {
d DirectReader
// QuerierFuncs adapts a plain Querier (e.g. a pool or a transaction) to the
// callback-based DBFuncQuerier interface.
type QuerierFuncs struct {
	// d is the underlying Querier that actually executes the SQL.
	d Querier
}

func (t *Reader) ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, arguments ...any) error {
// ExecFunc executes sql with the given arguments on the underlying Querier and
// hands the resulting command tag (and any execution error) to tagFunc,
// returning tagFunc's error.
func (t *QuerierFuncs) ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, arguments ...any) error {
	tag, err := t.d.Exec(ctx, sql, arguments...)
	return tagFunc(ctx, tag, err)
}

func (t *Reader) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, optionsAndArgs ...any) error {
func (t *QuerierFuncs) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, optionsAndArgs ...any) error {
rows, err := t.d.Query(ctx, sql, optionsAndArgs...)
if err != nil {
return err
Expand All @@ -191,10 +174,10 @@ func (t *Reader) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Contex
return rowsFunc(ctx, rows)
}

func (t *Reader) QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, optionsAndArgs ...any) error {
// QueryRowFunc executes a single-row query on the underlying Querier and passes
// the resulting pgx.Row to rowFunc, returning rowFunc's error.
func (t *QuerierFuncs) QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, optionsAndArgs ...any) error {
	return rowFunc(ctx, t.d.QueryRow(ctx, sql, optionsAndArgs...))
}

func DBReaderFor(d DirectReader) DBReader {
return &Reader{d: d}
// QuerierFuncsFor wraps d in a QuerierFuncs so that it satisfies the
// DBFuncQuerier interface.
func QuerierFuncsFor(d Querier) DBFuncQuerier {
	return &QuerierFuncs{d: d}
}
18 changes: 4 additions & 14 deletions internal/datastore/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,24 +296,18 @@ type pgDatastore struct {
func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Reader {
rev := revRaw.(postgresRevision)

createTxFunc := func(ctx context.Context) (pgxcommon.DBReader, common.TxCleanupFunc, error) {
return pgxcommon.DBReaderFor(pgd.readPool), func(ctx context.Context) {}, nil
}

querySplitter := common.TupleQuerySplitter{
Executor: pgxcommon.NewPGXExecutor(createTxFunc),
Executor: pgxcommon.NewPGXExecutor(pgxcommon.QuerierFuncsFor(pgd.readPool)),
UsersetBatchSize: pgd.usersetBatchSize,
}

return &pgReader{
createTxFunc,
pgxcommon.QuerierFuncsFor(pgd.readPool),
querySplitter,
buildLivingObjectFilterForRevision(rev),
}
}

func noCleanup(context.Context) {}

// ReadWriteTx starts a read/write transaction, which will be committed if no error is
// returned and rolled back if an error is returned.
func (pgd *pgDatastore) ReadWriteTx(
Expand All @@ -331,18 +325,14 @@ func (pgd *pgDatastore) ReadWriteTx(
return err
}

longLivedTx := func(context.Context) (pgxcommon.DBReader, common.TxCleanupFunc, error) {
return pgxcommon.DBReaderFor(tx), noCleanup, nil
}

querySplitter := common.TupleQuerySplitter{
Executor: pgxcommon.NewPGXExecutor(longLivedTx),
Executor: pgxcommon.NewPGXExecutor(pgxcommon.QuerierFuncsFor(pgd.readPool)),
UsersetBatchSize: pgd.usersetBatchSize,
}

rwt := &pgReadWriteTXN{
&pgReader{
longLivedTx,
pgxcommon.QuerierFuncsFor(pgd.readPool),
querySplitter,
currentlyLivingObjects,
},
Expand Down

0 comments on commit c574f6e

Please sign in to comment.