diff --git a/README.md b/README.md
index dd1ba23..7ab46a0 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ cli, err := pg.NewClient(ctx, c)
assert.NoError(err)

// download user id 50 - it will recursively find everything related to the user
-dl, trace, err := datapasta.DownloadWith(ctx, cli, "user", "id", 50)
+dl, trace, err := datapasta.Download(ctx, cli, "user", "id", 50)
assert.NoError(err)
```
`import.go`
@@ -44,7 +44,7 @@ assert.NoError(err)
cli, err := pg.NewClient(ctx, db)
assert.NoError(err)

-datapasta.UploadWith(ctx, cli, dump)
+datapasta.Upload(ctx, cli, dump)

// return the new id of the user (as postgres provided a new id)
return dump[0]["id"].(int32), nil
@@ -65,9 +65,9 @@ purchase (

If we export a `user`, the export will recurse into `purchase`, and then recurse into other `user` records that have made purchases, which will likely clone your entire database!

-This can be solved by telling DownloadWith not to recurse out of the `purchase` table, with `datapasta.DontRecurse("purchase")`.
+This can be solved by telling Download not to recurse out of the `purchase` table, with `datapasta.DontRecurse("purchase")`.

-This can also be solved by telling DownloadWith not to include the `user` table at all, with `datapasta.DontInclude("purchase")`.
+This can also be solved by telling Download not to include the `purchase` table at all, with `datapasta.DontInclude("purchase")`.

### Import Tips
diff --git a/clone.go b/clone.go
index bcdf9bb..9c5ff7f 100644
--- a/clone.go
+++ b/clone.go
@@ -7,63 +7,60 @@ import (
	"strings"
)

-// Database is the abstraction between the cloning tool and the database.
-// The NewPostgres.NewClient method gives you an implementation for Postgres.
-type Database interface {
-	// get foriegn key mapping
-	ForeignKeys() []ForeignKey
-
-	// SelectMatchingRows must return unseen records.
-	// a Database can't be reused between clones, because it must do internal deduping.
-	// `conds` will be a map of columns and the values they can have.
-	SelectMatchingRows(tname string, conds map[string][]any) ([]map[string]any, error)
-
-	// Insert uploads a batch of records.
-	// any changes to the records (such as newly generated primary keys) should mutate the record map directly.
-	// a Destination can't generally be reused between clones, as it may be inside a transaction.
-	// it's recommended that callers use a Database that wraps a transaction.
-	Insert(records ...map[string]any) error
-}
-
type (
+	// DatabaseDump is the output of a Download call, containing every record that was downloaded.
+	// It is safe to transport as JSON.
	DatabaseDump []map[string]any
-	Opt func(map[string]bool)
-)

-type ForeignKey struct {
-	BaseTable string `json:"base_table"`
-	BaseCol string `json:"base_col"`
-	ReferencingTable string `json:"referencing_table"`
-	ReferencingCol string `json:"referencing_col"`
-}
+	// Opt is a functional option that can be passed to Download.
+	Opt func(*downloadOpts)
+)

+// DontRecurse includes records from `table`, but does not recurse into references to it.
func DontRecurse(table string) Opt {
-	return func(m map[string]bool) {
-		m["dontrecurse."+table] = true
+	return func(m *downloadOpts) {
+		m.dontRecurse[table] = true
	}
}

+// DontInclude skips records from `table` entirely; records from other tables that reference it are still downloaded.
func DontInclude(table string) Opt {
-	return func(m map[string]bool) {
-		m["dontinclude."+table] = true
+	return func(m *downloadOpts) {
+		m.dontInclude[table] = true
+	}
+}
+
+// LimitSize causes the clone to fail if more than `limit` records have been collected.
+// Use a generous upper bound on how many records you expect to be exported.
+// The default is 0, which is treated as no limit.
+func LimitSize(limit int) Opt {
+	return func(m *downloadOpts) {
+		m.limit = limit
	}
}

const (
-	// we store table and primary key names in the dump, using these keys
-	// because it makes it much easier to transport and clone.
-	// we *could* stop tracking primary key, but it saves some repeated work on the upload.
+	// DumpTableKey is a special field present in every row of an export.
+	// It can be used to determine which table the row is from.
+	// Note that the export may have rows from a table interleaved with rows from other tables.
	DumpTableKey = "%_tablename"
)

-const MAX_LEN = 50000
+type downloadOpts struct {
+	dontInclude map[string]bool
+	dontRecurse map[string]bool
+	limit       int
+}

-// DownloadWith recursively downloads a dump of the database from a given starting point.
+// Download recursively downloads a dump of the database from a given starting point.
// the 2nd return is a trace that can help debug or understand what happened.
-func DownloadWith(ctx context.Context, db Database, startTable, startColumn string, startId any, opts ...Opt) (DatabaseDump, []string, error) {
-	flags := map[string]bool{}
+func Download(ctx context.Context, db Database, startTable, startColumn string, startId any, opts ...Opt) (DatabaseDump, []string, error) {
+	options := downloadOpts{
+		dontInclude: map[string]bool{},
+		dontRecurse: map[string]bool{},
+	}
	for _, o := range opts {
-		o(flags)
+		o(&options)
	}

	type searchParams struct {
@@ -80,6 +77,11 @@ func DownloadWith(ctx context.Context, db Database, startTable, startColumn stri
	var recurse func(int) error
	recurse = func(i int) error {
+		if options.limit != 0 && len(cloneInOrder) >= options.limit {
+			debugging = append(debugging, "hit export size limit")
+			return fmt.Errorf("export limit of %d records exceeded", options.limit)
+		}
+
		if lookupStatus[lookupQueue[i]] {
			return nil
		}
@@ -107,7 +109,7 @@ func DownloadWith(ctx context.Context, db Database, startTable, startColumn stri
		res[DumpTableKey] = tname

		for _, fk := range fks {
-			if fk.BaseTable != tname || flags["dontrecurse."+fk.BaseTable] {
+			if fk.BaseTable != tname || options.dontRecurse[fk.BaseTable] || options.dontInclude[fk.ReferencingTable] {
				continue
			}
			// foreign keys pointing to this record can come later
@@ -118,7 +120,7 @@ func DownloadWith(ctx context.Context, db Database, startTable, startColumn stri
			}
		}
		for _, fk := range fks {
-			if fk.ReferencingTable != tname || res[fk.ReferencingCol] == nil || flags["dontinclude."+fk.BaseTable] {
+			if fk.ReferencingTable != tname || res[fk.ReferencingCol] == nil || options.dontInclude[fk.BaseTable] {
				continue
			}
			// foreign keys referenced by this record must be grabbed before this record
@@ -144,17 +146,13 @@ func DownloadWith(ctx context.Context, db Database, startTable, startColumn stri
		if err := recurse(i); err != nil {
			return nil, debugging, err
		}
-		if len(lookupQueue) >= MAX_LEN {
-			debugging = append(debugging, "hit maximum recursion")
-			return nil, debugging, nil
-		}
	}

	return cloneInOrder, debugging, nil
}

-// UploadWith uploads, in naive order, every record in a dump.
+// Upload uploads, in naive order, every record in a dump.
// It mutates the elements of `dump`, so you can track changes (for example new primary keys).
-func UploadWith(ctx context.Context, db Database, dump DatabaseDump) error {
+func Upload(ctx context.Context, db Database, dump DatabaseDump) error {
	// keep track of old columns and their new values
	changes := map[string]map[any]any{}
diff --git a/clone_test.go b/clone_test.go
index b92af11..7094756 100644
--- a/clone_test.go
+++ b/clone_test.go
@@ -11,7 +11,7 @@ import (
func TestDownloadUpload(t *testing.T) {
	db, assert := testDB{T: t}, assert.New(t)

-	res, _, err := datapasta.DownloadWith(context.Background(), db, "company", "id", 10)
+	res, _, err := datapasta.Download(context.Background(), db, "company", "id", 10)
	assert.NoError(err)

	t.Log(res)
@@ -19,7 +19,12 @@ func TestDownloadUpload(t *testing.T) {
	assert.Equal("produces socks", res[1]["desc"])
	assert.Equal("socks", res[2]["name"])

-	assert.NoError(datapasta.UploadWith(context.Background(), db, res))
+	// users are expected to do some cleanup, so test that it works
+	for _, row := range res {
+		cleanup(row)
+	}
+
+	assert.NoError(datapasta.Upload(context.Background(), db, res))

	assert.Equal(11, res[0]["id"])
	assert.Equal(12, res[1]["id"])
@@ -43,9 +48,6 @@ func (d testDB) SelectMatchingRows(tname string, conds map[string][]any) ([]map[
		if conds["id"][0] == 10 {
			return []map[string]any{{"id": 10, "api_key": "secret_api_key"}}, nil
		}
-		if conds["id"][0] == 9 {
-			return []map[string]any{{"id": 9, "api_key": "secret_api_key"}}, nil
-		}
	case "product":
		if conds["factory_id"] != nil {
			// we revisit this table because its a dependency of factory as well
@@ -54,9 +56,6 @@ func (d testDB) SelectMatchingRows(tname string, conds map[string][]any) ([]map[
		if conds["company_id"][0] == 10 {
			return []map[string]any{{"id": 5, "name": "socks", "company_id": 10, "factory_id": 23}}, nil
		}
-		if conds["company_id"][0] == 9 {
-			return []map[string]any{}, nil
-		}
	case "factory":
		if conds["id"][0] == 23 {
			return []map[string]any{{"id": 23, "desc": "produces socks"}}, nil
		}
@@ -70,11 +69,10 @@ func (d testDB) SelectMatchingRows(tname string, conds map[string][]any) ([]map[
func (d testDB) Insert(records ...map[string]any) error {
	// test db only handles 1 insert at a time
	m := records[0]
+
+	d.Logf("inserting %#v", m)
+
	if m[datapasta.DumpTableKey] == "company" && m["id"] == 10 {
-		m["id"] = 11
-		return nil
-	}
-	if m[datapasta.DumpTableKey] == "company" && m["id"] == 9 {
		if m["api_key"] != "obfuscated" {
			d.Errorf("didn't obfuscated company 9's api key, got %s", m["api_key"])
		}
diff --git a/utils.go b/integrations/utils.go
similarity index 83%
rename from utils.go
rename to integrations/utils.go
index e663e06..aed9f58 100644
--- a/utils.go
+++ b/integrations/utils.go
@@ -1,11 +1,16 @@
-package datapasta
+// Package integrations houses utility functions for building or testing database integrations.
+package integrations

-import "testing"
+import (
+	"testing"
+
+	"github.com/ProlificLabs/datapasta"
+)

// TestDatabaseImplementation is a utility to test an implemention. it just makes sure that 1 row gets cloned correctly.
// the row you choose as the starting point must have be referenced as a foreign key by some other table.
// ci is not set up, but this function is one of many tests that run against local postgres for development.
-func TestDatabaseImplementation(t *testing.T, db Database, startTable, startCol string, startVal any) {
+func TestDatabaseImplementation(t *testing.T, db datapasta.Database, startTable, startCol string, startVal any) {
	// find some interesting columns
	cols := make(map[string]bool, 0)
	for _, fk := range db.ForeignKeys() {
diff --git a/interface.go b/interface.go
new file mode 100644
index 0000000..c5d628c
--- /dev/null
+++ b/interface.go
@@ -0,0 +1,30 @@
+package datapasta
+
+// Database is the abstraction between the cloning tool and the database.
+// The NewPostgres.NewClient method gives you an implementation for Postgres.
+type Database interface {
+
+	// SelectMatchingRows must return unseen records.
+	// a Database can't be reused between clones, because it must do internal deduping.
+	// `conds` will be a map of columns and the values they can have.
+	SelectMatchingRows(tname string, conds map[string][]any) ([]map[string]any, error)
+
+	// Insert uploads a batch of records.
+	// any changes to the records (such as newly generated primary keys) should mutate the record map directly.
+	// a Destination can't generally be reused between clones, as it may be inside a transaction.
+	// it's recommended that callers use a Database that wraps a transaction.
+	Insert(records ...map[string]any) error
+
+	// get foreign key mapping
+	ForeignKeys() []ForeignKey
+}
+
+// ForeignKey contains every REFERENCING column and the BASE column it refers to.
+// This is used to recurse the database as a graph.
+// Database implementations must provide a complete list of references.
+type ForeignKey struct {
+	BaseTable string `json:"base_table"`
+	BaseCol string `json:"base_col"`
+	ReferencingTable string `json:"referencing_table"`
+	ReferencingCol string `json:"referencing_col"`
+}
\ No newline at end of file
diff --git a/postgres.go b/postgres.go
index cde6851..f2c358b 100644
--- a/postgres.go
+++ b/postgres.go
@@ -13,8 +13,8 @@ import (
)

// NewPostgres returns a pgdb that can generate a Database for datapasta Upload and Download functions.
-func NewPostgres(ctx context.Context, c DBTX) (pgdb, error) {
-	client := Queries{db: c}
+func NewPostgres(ctx context.Context, c Postgreser) (pgdb, error) {
+	client := postgresQueries{db: c}
	sqlcPKs, err := client.GetPrimaryKeys(ctx)
	if err != nil {
		return pgdb{}, err
@@ -25,7 +25,7 @@ func NewPostgres(ctx context.Context, c DBTX) (pgdb, error) {
		return pgdb{}, err
	}

-	pkGroups := make(map[string]GetPrimaryKeysRow, len(sqlcPKs))
+	pkGroups := make(map[string]getPrimaryKeysRow, len(sqlcPKs))
	for _, pk := range sqlcPKs {
		pkGroups[pk.TableName] = pk
	}
@@ -44,10 +44,8 @@ func NewPostgres(ctx context.Context, c DBTX) (pgdb, error) {
}

type pgdb struct {
-	// client *pgxpool.Pool
-
-	// figured out from schema
-	pkGroups map[string]GetPrimaryKeysRow
+	pkGroups map[string]getPrimaryKeysRow
	fks []ForeignKey

	// squirrel instance to help with stuff
@@ -56,10 +54,10 @@ type pgdb struct {
// NewClient creates a pgtx that can be used as a Database for Upload and Download.
// it is recommended you pass an open transaction, so you can control committing or rolling it back.
-func (db pgdb) NewClient(ctx context.Context, tx DBTX) (pgtx, error) {
+func (db pgdb) NewClient(ctx context.Context, tx Postgreser) (pgtx, error) {
	return pgtx{
		pgdb: db,
-		tx: Queries{tx},
+		tx: postgresQueries{tx},
		ctx: ctx,
		found: map[string][]any{},
		foundWithoutPK: map[any]bool{},
@@ -75,7 +73,7 @@ type pgtx struct {
	ctx context.Context

	// as a destination, we need a tx
-	tx Queries
+	tx postgresQueries

	// as a source, we must not return already-found objects
	found map[string][]any
@@ -196,7 +194,9 @@ func (db pgtx) SelectMatchingRows(tname string, conds map[string][]any) ([]map[s
	return foundInThisScan, nil
}

-type DBTX interface {
+// Postgreser is the subset of a Postgres client's behavior that datapasta needs.
+// github.com/jackc/pgx/v4/pgxpool.Pool is one such implementation.
+type Postgreser interface {
	Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
	Query(context.Context, string, ...interface{}) (pgx.Rows, error)
	QueryRow(context.Context, string, ...interface{}) pgx.Row
@@ -204,8 +204,8 @@
	SendBatch(context.Context, *pgx.Batch) pgx.BatchResults
}

-type Queries struct {
-	db DBTX
+type postgresQueries struct {
+	db Postgreser
}

const getForeignKeys = `-- name: GetForeignKeys :many
SELECT
FROM pg_catalog.pg_constraint c
	join pg_catalog.pg_attribute a on c.confrelid=a.attrelid and a.attnum = ANY(confkey)
`
-type GetForeignKeysRow struct {
+type getForeignKeysRow struct {
	BaseTable string `json:"base_table"`
	BaseCol string `json:"base_col"`
	ReferencingTable string `json:"referencing_table"`
	ReferencingCol string `json:"referencing_col"`
}

-func (q *Queries) GetForeignKeys(ctx context.Context) ([]GetForeignKeysRow, error) {
+func (q *postgresQueries) GetForeignKeys(ctx context.Context) ([]getForeignKeysRow, error) {
	rows, err := q.db.Query(ctx, getForeignKeys)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
-	var items []GetForeignKeysRow
+	var items []getForeignKeysRow
	for rows.Next() {
-		var i GetForeignKeysRow
+		var i getForeignKeysRow
		if err := rows.Scan(
			&i.BaseTable,
			&i.BaseCol,
@@ -267,20 +267,20 @@
GROUP BY t.relname
HAVING COUNT(*) = 1
`
-type GetPrimaryKeysRow struct {
+type getPrimaryKeysRow struct {
	TableName string `json:"table_name"`
	ColumnName string `json:"column_name"`
}

-func (q *Queries) GetPrimaryKeys(ctx context.Context) ([]GetPrimaryKeysRow, error) {
+func (q *postgresQueries) GetPrimaryKeys(ctx context.Context) ([]getPrimaryKeysRow, error) {
	rows, err := q.db.Query(ctx, getPrimaryKeys)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
-	var items []GetPrimaryKeysRow
+	var items []getPrimaryKeysRow
	for rows.Next() {
-		var i GetPrimaryKeysRow
+		var i getPrimaryKeysRow
		if err := rows.Scan(&i.TableName, &i.ColumnName); err != nil {
			return nil, err
		}
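
A rough end-to-end sketch of how the renamed pieces in this diff fit together (`NewPostgres`, `NewClient`, `Download` with the new options, `Upload`, and `DumpTableKey`). The `user`/`purchase` schema, the `api_key` scrub, the record limit, and the `cloneUser` helper are illustrative assumptions, not code from this repository:

```go
package example

import (
	"context"
	"fmt"

	"github.com/ProlificLabs/datapasta"
	"github.com/jackc/pgx/v4/pgxpool"
)

// cloneUser copies one user, and everything reachable from it, inside the same database,
// returning the freshly generated id of the copy.
// The user/purchase tables, the api_key column, and the 10k limit are assumptions for illustration.
func cloneUser(ctx context.Context, pool *pgxpool.Pool, userID int) (int32, error) {
	pg, err := datapasta.NewPostgres(ctx, pool)
	if err != nil {
		return 0, err
	}

	// a client dedupes internally, so use a fresh one per clone; this one only reads
	src, err := pg.NewClient(ctx, pool)
	if err != nil {
		return 0, err
	}

	// don't fan out from purchases back into other users, and give up past 10k records
	dump, _, err := datapasta.Download(ctx, src, "user", "id", userID,
		datapasta.DontRecurse("purchase"),
		datapasta.LimitSize(10000),
	)
	if err != nil {
		return 0, err
	}

	// scrub sensitive columns before re-inserting, the same idea as the cleanup step in clone_test.go
	for _, row := range dump {
		if row[datapasta.DumpTableKey] == "user" {
			row["api_key"] = "obfuscated"
		}
	}

	// upload inside a transaction with a second client; Upload mutates the rows in place,
	// so the generated primary keys are readable from the dump afterwards
	tx, err := pool.Begin(ctx)
	if err != nil {
		return 0, err
	}
	defer tx.Rollback(ctx)

	dst, err := pg.NewClient(ctx, tx)
	if err != nil {
		return 0, err
	}
	if err := datapasta.Upload(ctx, dst, dump); err != nil {
		return 0, err
	}
	if err := tx.Commit(ctx); err != nil {
		return 0, err
	}

	newID, ok := dump[0]["id"].(int32)
	if !ok {
		return 0, fmt.Errorf("unexpected id type %T", dump[0]["id"])
	}
	return newID, nil
}
```

As in the README's import example, `Upload` mutates the dump in place, so the copy's newly generated id can be read back out of `dump[0]` once the transaction commits.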