From 2800ca1979b5b415ba6056d0d775b1d5feede679 Mon Sep 17 00:00:00 2001 From: Daniel Toye Date: Fri, 21 Jul 2023 13:01:33 -0600 Subject: [PATCH 1/4] refactors, merge strategy, working on database methods --- clone.go | 53 +------ clone_test.go | 20 ++- integrations/utils.go | 4 +- interface.go | 53 ++++++- merge.go | 236 ++++++++++++++++++++++++++++ merge_test.go | 347 ++++++++++++++++++++++++++++++++++++++++++ postgres.go | 52 ++++++- postgres_test.go | 118 -------------- pulley_tests.go | 247 ++++++++++++++++++++++++++++++ 9 files changed, 943 insertions(+), 187 deletions(-) create mode 100644 merge.go create mode 100644 merge_test.go delete mode 100644 postgres_test.go create mode 100644 pulley_tests.go diff --git a/clone.go b/clone.go index 96c42e8..01a3d35 100644 --- a/clone.go +++ b/clone.go @@ -3,7 +3,6 @@ package datapasta import ( "context" "fmt" - "log" "strings" ) @@ -153,55 +152,5 @@ func Download(ctx context.Context, db Database, startTable, startColumn string, // Upload uploads, in naive order, every record in a dump. // It mutates the elements of `dump`, so you can track changes (for example new primary keys). func Upload(ctx context.Context, db Database, dump DatabaseDump) error { - fkm := NewForeignKeyMapper(db) - return db.Insert(fkm, dump...) -} - -type ForeignKeyMapper func(row map[string]any) func() - -// NewForeignKeyMapper returns a function that will update foreign key references in a row to their new values. -// each update returns a function that must be called after the row has been updated with new primary keys. -func NewForeignKeyMapper(db Database) ForeignKeyMapper { - changes := make(map[string]map[any]any) - - for _, fk := range db.ForeignKeys() { - changes[fk.BaseTable+"."+fk.BaseCol] = map[any]any{} - } - - return func(row map[string]any) func() { - table := row[DumpTableKey].(string) - for k, v := range row { - for _, fk := range db.ForeignKeys() { - if fk.ReferencingTable != table || fk.ReferencingCol != k || v == nil || changes[fk.BaseTable+`.`+fk.BaseCol] == nil { - continue - } - - newID, ok := changes[fk.BaseTable+`.`+fk.BaseCol][v] - if !ok { - log.Printf("unable to find mapped id for %s[%s]=%v in %s", table, k, v, fk.BaseTable) - } else { - row[k] = newID - } - } - } - - copy := make(map[string]any, len(row)) - for k, v := range row { - // does anyone care about this value? - if changes[table+`.`+k] == nil { - continue - } - copy[k] = v - } - - return func() { - table := row[DumpTableKey].(string) - for k, v := range row { - if changes[table+"."+k] == nil { - continue - } - changes[table+"."+k][copy[k]] = v - } - } - } + return db.Insert(dump...) 
} diff --git a/clone_test.go b/clone_test.go index ee53bf4..5916d1e 100644 --- a/clone_test.go +++ b/clone_test.go @@ -65,10 +65,23 @@ func (d testDB) SelectMatchingRows(tname string, conds map[string][]any) ([]map[ return nil, fmt.Errorf("no mock for %s where %#v", tname, conds) } +func (d testDB) PrimaryKeys() map[string]string { + return nil +} + +func (d testDB) InsertRecord(map[string]any) (any, error) { return nil, nil } + +// apply the updates from the cols to the row +func (d testDB) Update(id datapasta.RecordID, cols map[string]any) error { return nil } + +// delete the row +func (d testDB) Delete(id datapasta.RecordID) error { return nil } + +func (d testDB) Mapping() ([]datapasta.Mapping, error) { return nil, nil } + // upload a batch of records -func (d testDB) Insert(fkm datapasta.ForeignKeyMapper, records ...map[string]any) error { +func (d testDB) Insert(records ...map[string]any) error { for _, m := range records { - finish := fkm(m) d.Logf("inserting %#v", m) if m[datapasta.DumpTableKey] == "company" && m["id"] == 10 { @@ -76,17 +89,14 @@ func (d testDB) Insert(fkm datapasta.ForeignKeyMapper, records ...map[string]any d.Errorf("didn't obfuscated company 9's api key, got %s", m["api_key"]) } m["id"] = 11 - finish() continue } if m[datapasta.DumpTableKey] == "factory" && m["id"] == 23 { m["id"] = 12 - finish() continue } if m[datapasta.DumpTableKey] == "product" && m["id"] == 5 { m["id"] = 13 - finish() continue } return fmt.Errorf("unexpected insert: %#v", m) diff --git a/integrations/utils.go b/integrations/utils.go index 7e9f380..aed9f58 100644 --- a/integrations/utils.go +++ b/integrations/utils.go @@ -33,8 +33,7 @@ func TestDatabaseImplementation(t *testing.T, db datapasta.Database, startTable, old[k] = v } - fkm := datapasta.NewForeignKeyMapper(db) - if err := db.Insert(fkm, found[0]); err != nil { + if err := db.Insert(found[0]); err != nil { t.Fatalf("error inserting row: %s", err.Error()) return } @@ -62,4 +61,3 @@ func TestDatabaseImplementation(t *testing.T, db datapasta.Database, startTable, return } } - diff --git a/interface.go b/interface.go index 098365d..7fdf3e5 100644 --- a/interface.go +++ b/interface.go @@ -1,30 +1,69 @@ package datapasta +import "fmt" + // Database is the abstraction between the cloning tool and the database. // The NewPostgres.NewClient method gives you an implementation for Postgres. type Database interface { - // SelectMatchingRows must return unseen records. // a Database can't be reused between clones, because it must do internal deduping. // `conds` will be a map of columns and the values they can have. SelectMatchingRows(tname string, conds map[string][]any) ([]map[string]any, error) - + + // insert one record, returning the new id + InsertRecord(record map[string]any) (any, error) + + // apply the updates from the cols to the row + Update(id RecordID, cols map[string]any) error + + // delete the row + Delete(id RecordID) error + // Insert uploads a batch of records. - // any changes to the records (such as newly generated primary keys) should mutate the record map directly. // a Destination can't generally be reused between clones, as it may be inside a transaction. // it's recommended that callers use a Database that wraps a transaction. - Insert(mapper ForeignKeyMapper, records ...map[string]any) error - + // + // the records will have primary keys which must be handled. + // the Database is responsible for exposing the resulting primary key mapping in some manner. 
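+ // + // a minimal sketch of the intended call pattern (hypothetical caller code, not part of this package): + // if err := db.Insert(dump...); err != nil { return err } + // mapping, err := db.Mapping() // one Mapping per inserted row, relating each original id to its new id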
+ Insert(records ...map[string]any) error + + // Mapping must return whatever mapping has been created by prior Inserts. + // the implementation may internally choose to track this in the database or in memory. + Mapping() ([]Mapping, error) + // get foreign key mapping ForeignKeys() []ForeignKey + + // get primary key mapping + PrimaryKeys() map[string]string } // ForeignKey contains every REFERENCING column and the BASE column it refers to. -// This is used to recurse the database as a graph. +// This is used to recurse the database as a graph. // Database implementations must provide a complete list of references. type ForeignKey struct { BaseTable string `json:"base_table"` BaseCol string `json:"base_col"` ReferencingTable string `json:"referencing_table"` ReferencingCol string `json:"referencing_col"` -} \ No newline at end of file +} + +type RecordID struct { + Table string + PrimaryKey any +} + +func (r RecordID) String() string { + return fmt.Sprintf(`%s(%v)`, r.Table, r.PrimaryKey) +} + +func GetRowIdentifier(pks map[string]string, row map[string]any) RecordID { + table := row[DumpTableKey].(string) + pk := row[pks[table]] + return RecordID{Table: table, PrimaryKey: pk} +} + +type Mapping struct { + TableName string + OriginalID, NewID any +} diff --git a/merge.go b/merge.go new file mode 100644 index 0000000..380eda5 --- /dev/null +++ b/merge.go @@ -0,0 +1,236 @@ +package datapasta + +import "fmt" + +type MergeAction struct { + ID RecordID + Action string + Data map[string]any +} + +func (ma MergeAction) String() string { + return fmt.Sprintf(`%s %s %#v`, ma.Action, ma.ID, ma.Data) +} + +func FindRow(table, pk string, id any, dump DatabaseDump) map[string]any { + if id == nil { + return nil + } + for _, d := range dump { + if d[DumpTableKey] != table { + continue + } + if d[pk] == id { + return d + } + } + return nil +} + +func FindMapping(table string, id any, mapp []Mapping) Mapping { + for _, m := range mapp { + if m.TableName != table { + continue + } + if m.NewID == id { + return m + } + } + return Mapping{TableName: table, OriginalID: id, NewID: id} +} + +// reverse all the primary keys of a dump +func ReversePrimaryKeyMapping(pks map[string]string, mapp []Mapping, dump DatabaseDump) { + for _, row := range dump { + table := row[DumpTableKey].(string) + pk, hasPk := pks[table] + if !hasPk { + continue + } + m := FindMapping(table, row[pk], mapp) + row[pk] = m.OriginalID + } +} + +// reverse all the foreign keys of an individual row +func ReverseForeignKeyMappingRow(fks []ForeignKey, mapp []Mapping, row map[string]any) { + update := func(row map[string]any, col, otherTable string) { + for _, m := range mapp { + if m.TableName != otherTable { + continue + } + if m.NewID != row[col] { + continue + } + row[col] = m.OriginalID + } + } + + table := row[DumpTableKey].(string) + for _, fk := range fks { + if fk.ReferencingTable != table { + continue + } + update(row, fk.ReferencingCol, fk.BaseTable) + } +} + +// reverse all the foreign keys of a dump +func ReverseForeignKeyMapping(fks []ForeignKey, mapp []Mapping, rows DatabaseDump) { + for _, row := range rows { + ReverseForeignKeyMappingRow(fks, mapp, row) + } +} + +// find rows in "from" that are missing in "in" +func FindMissingRows(pks map[string]string, from, in DatabaseDump) DatabaseDump { + out := make(DatabaseDump, 0) + for _, row := range from { + table := row[DumpTableKey].(string) + pk, hasPk := pks[table] + if !hasPk { + continue + } + match := FindRow(table, pk, row[pk], in) + if match != nil { + continue + } + + out =
append(out, row) + } + return out +} + +// return a map of updates or deletes that would make "in" equal "from" +// the map key identifies the changed row (its table and primary key) +// and the value maps each changed column to its new value +func FindModifiedRows(pks map[string]string, from, in DatabaseDump) map[RecordID]map[string]any { + all := make(map[RecordID]map[string]any) + for _, row := range from { + table := row[DumpTableKey].(string) + pk, hasPk := pks[table] + if !hasPk { + continue + } + match := FindRow(table, pk, row[pk], in) + if match == nil { + continue + } + + changes := make(map[string]any) + for k, v := range match { + if v != row[k] { + changes[k] = row[k] + } + } + + if len(changes) == 0 { + continue + } + all[RecordID{Table: table, PrimaryKey: row[pk]}] = changes + } + return all +} + +func ApplyMergeStrategy(db Database, mapp []Mapping, dump DatabaseDump, mas []MergeAction) error { + fks := db.ForeignKeys() + pks := db.PrimaryKeys() + + for _, ma := range mas { + if ma.Action != "create" { + continue + } + ReverseForeignKeyMappingRow(fks, mapp, ma.Data) + origID := ma.Data[pks[ma.ID.Table]] + delete(ma.Data, pks[ma.ID.Table]) + id, err := db.InsertRecord(ma.Data) + if err != nil { + return err + } + mapp = append(mapp, Mapping{TableName: ma.ID.Table, OriginalID: origID, NewID: id}) + } + + // do all the creates *while* updating the mapping + // do all the updates + // do all the deletes + return nil +} + +// GenerateMergeStrategy returns every create, update, or delete needed to merge branch into main +// note that conflicts will be intermingled in updates and deletes +func GenerateMergeStrategy(pks map[string]string, base, main, branch DatabaseDump) []MergeAction { + out := make([]MergeAction, 0) + + deletedInMain := make(map[RecordID]bool) + for _, deleted := range FindMissingRows(pks, base, main) { + deletedInMain[GetRowIdentifier(pks, deleted)] = true + } + editedInMain := make(map[RecordID]bool) + for id := range FindModifiedRows(pks, main, base) { + editedInMain[id] = true + } + + created := FindMissingRows(pks, branch, base) + for _, m := range created { + id := GetRowIdentifier(pks, m) + out = append(out, MergeAction{id, "create", m}) + } + + changes := FindModifiedRows(pks, branch, base) + for id, c := range changes { + if editedInMain[id] || deletedInMain[id] { + out = append(out, MergeAction{id, "conflict", c}) + continue + } + out = append(out, MergeAction{id, "update", c}) + } + + deleted := FindMissingRows(pks, base, branch) + for _, m := range deleted { + id := GetRowIdentifier(pks, m) + if editedInMain[id] { + out = append(out, MergeAction{id, "conflict", m}) + continue + } + out = append(out, MergeAction{id, "delete", m}) + } + + return out +} + +// // ThreeWayMerge applies a three way merge to a Diffing interface +// func ThreeWayMerge(db Diffing, mapp []Mapping, root, main, branch DatabaseDump) (actions MergeAction) { +// // existing definitions: +// // type DatabaseDump []map[string]any +// // type Mapping struct { +// // TableName string +// // OriginalID, NewID any +// // } +// // type Diffing interface { +// // Insert(table string, record map[string]any) (pk any, err error) +// // Update(table string, record map[string]any) error +// // } + +// // all DatabaseDump rows have an "id" field of type any that is used for finding the old id (unmapping) + +// // define slice for merge actions + +// // find new items in branch not in root +// // for each new item: +// // apply id unmapping to this item +// // insert it into the db +// // append new id to mapping +// // append this change as a
nonconflicting change +// // apply id unmapping to everything in branch + +// // find all modified or deleted items in main +// // find all modified or deleted items in branch +// // for each diff item in branch: +// // if it exists in the main diff +// // append this as a conflicting merge actions +// // otherwise +// // append this to the merge actions +// // update the record in the db + +// // return conflicting and nonconflicting changes +// } diff --git a/merge_test.go b/merge_test.go new file mode 100644 index 0000000..def0e95 --- /dev/null +++ b/merge_test.go @@ -0,0 +1,347 @@ +package datapasta + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFindNewRows(t *testing.T) { + main := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + }, + } + branch := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + }, + { + DumpTableKey: "person", + "id": 11, + }, + } + pks := map[string]string{"person": "id"} + + newRows := FindMissingRows(pks, branch, main) + + ok := assert.New(t) + ok.Len(newRows, 1) + ok.Equal(11, newRows[0]["id"]) +} + +func TestReversePrimaryKeyMapping(t *testing.T) { + branch := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + }, + { + DumpTableKey: "person", + "id": 11, + }, + } + mapp := []Mapping{ + {TableName: "person", OriginalID: 8, NewID: 11}, + } + pks := map[string]string{"person": "id"} + + ReversePrimaryKeyMapping(pks, mapp, branch) + + ok := assert.New(t) + ok.Equal(10, branch[0]["id"]) + ok.Equal(8, branch[1]["id"]) +} + +func TestFindModifiedRows(t *testing.T) { + pks := map[string]string{"person": "id"} + + main := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + "name": "alice", + }, + } + branch := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + "name": "alicia", + }, + } + + mods := FindModifiedRows(pks, branch, main) + + change := RecordID{Table: "person", PrimaryKey: 10} + + ok := assert.New(t) + ok.Len(mods, 1) + ok.Len(mods[change], 1) + ok.Equal("alicia", mods[change]["name"]) +} + +func TestReverseForeignKeyMapping(t *testing.T) { + main := DatabaseDump{ + { + DumpTableKey: "person", + "country": 20, + }, + { + DumpTableKey: "country", + "id": 10, + }, + } + + fks := []ForeignKey{{ReferencingTable: "person", ReferencingCol: "country", BaseTable: "country", BaseCol: "id"}} + mapp := []Mapping{{TableName: "country", OriginalID: 15, NewID: 20}} + + ReverseForeignKeyMapping(fks, mapp, main) + + ok := assert.New(t) + ok.Equal(15, main[0]["country"]) +} + +func TestGenerateMergeStrategy(t *testing.T) { + base := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + "name": "left_alone", + }, + { + DumpTableKey: "person", + "id": 11, + "name": "name_changed_in_main", + }, + { + DumpTableKey: "person", + "id": 12, + "name": "name_changed_in_branch", + }, + { + DumpTableKey: "person", + "id": 13, + "name": "deleted_in_main", + }, + { + DumpTableKey: "person", + "id": 14, + "name": "deleted_in_branch", + }, + + // conflicts + { + DumpTableKey: "person", + "id": 17, + "name": "deleted_main_updated_branch", + }, + { + DumpTableKey: "person", + "id": 18, + "name": "deleted_branch_updated_main", + }, + { + DumpTableKey: "person", + "id": 19, + "name": "deleted_both", // not a conflict + }, + { + DumpTableKey: "person", + "id": 20, + "name": "updated_both", + }, + } + main := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + "name": "left_alone", + }, + { + DumpTableKey: "person", + "id": 11, + "name": "name_changed_in_main_completed", + }, + { + DumpTableKey: "person", + "id": 12, 
+ "name": "name_changed_in_branch", + }, + { + DumpTableKey: "person", + "id": 14, + "name": "deleted_in_branch", + }, + { + DumpTableKey: "person", + "id": 15, + "name": "created_in_main", + }, + + // conflicts + { + DumpTableKey: "person", + "id": 18, + "name": "deleted_branch_updated_main_complete", + }, + + { + DumpTableKey: "person", + "id": 20, + "name": "updated_both_complete_main", + }, + } + branch := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + "name": "left_alone", + }, + { + DumpTableKey: "person", + "id": 11, + "name": "name_changed_in_main", + }, + { + DumpTableKey: "person", + "id": 12, + "name": "name_changed_in_branch_completed", + }, + { + DumpTableKey: "person", + "id": 13, + "name": "deleted_in_main", + }, + { + DumpTableKey: "person", + "id": 16, + "name": "created_in_branch", + }, + + // conflicts + { + DumpTableKey: "person", + "id": 17, + "name": "deleted_main_updated_branch_complete", + }, + { + DumpTableKey: "person", + "id": 20, + "name": "updated_both_complete_branch", + }, + } + pks := map[string]string{"person": "id"} + + actions := GenerateMergeStrategy(pks, base, main, branch) + + for _, ma := range actions { + t.Logf("%#v", ma) + } + + ok := assert.New(t) + ok.Len(actions, 7) + + // creation is not included in the merge + ok.Equal("create", actions[0].Action) + ok.Equal(16, actions[0].ID.PrimaryKey) + + ok.Equal("update", actions[1].Action) + ok.Equal(12, actions[1].ID.PrimaryKey) + + ok.Contains(actions, MergeAction{ID: RecordID{Table: "person", PrimaryKey: 20}, Action: "conflict", Data: map[string]interface{}{"name": "updated_both_complete_branch"}}) + ok.Contains(actions, MergeAction{ID: RecordID{Table: "person", PrimaryKey: 12}, Action: "update", Data: map[string]interface{}{"name": "name_changed_in_branch_completed"}}) + ok.Contains(actions, MergeAction{ID: RecordID{Table: "person", PrimaryKey: 20}, Action: "conflict", Data: map[string]interface{}{"name": "updated_both_complete_branch"}}) + + ok.Equal("conflict", actions[5].Action) + ok.Equal(18, actions[5].ID.PrimaryKey) + + ok.Equal("delete", actions[6].Action) + ok.Equal(19, actions[6].ID.PrimaryKey) +} + +func TestGenerateMergeStrategyWithMapping(t *testing.T) { + base := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + "friend": 9, + }, + { + DumpTableKey: "person", + "id": 11, + "friend": 10, + }, + } + main := DatabaseDump{ + { + DumpTableKey: "person", + "id": 10, + "friend": 9, + }, + { + DumpTableKey: "person", + "id": 11, + "friend": 10, + }, + } + branch := DatabaseDump{ + { + DumpTableKey: "person", + "id": 20, + "friend": 19, + }, + { + DumpTableKey: "person", + "id": 22, + }, + { + DumpTableKey: "person", + "id": 21, + "friend": 22, + }, + } + pks := map[string]string{"person": "id"} + fks := []ForeignKey{ + { + BaseTable: "person", + BaseCol: "id", + ReferencingTable: "person", + ReferencingCol: "friend", + }, + } + mapping := []Mapping{ + { + TableName: "person", + OriginalID: 9, + NewID: 19, + }, + { + TableName: "person", + OriginalID: 10, + NewID: 20, + }, + { + TableName: "person", + OriginalID: 11, + NewID: 21, + }, + } + + ReversePrimaryKeyMapping(pks, mapping, branch) + ReverseForeignKeyMapping(fks, mapping, branch) + mas := GenerateMergeStrategy(pks, base, main, branch) + + ok := assert.New(t) + ok.Len(mas, 2) + + ok.Equal("create", mas[0].Action) + ok.Equal(22, mas[0].ID.PrimaryKey) + + ok.Equal("update", mas[1].Action) + ok.Equal(11, mas[1].ID.PrimaryKey) + ok.Len(mas[1].Data, 1) + ok.Equal(22, mas[1].Data["friend"]) +} diff --git a/postgres.go 
b/postgres.go index fc9ec2c..c3acbc2 100644 --- a/postgres.go +++ b/postgres.go @@ -33,12 +33,12 @@ func NewPostgres(ctx context.Context, c Postgreser) (pgdb, error) { pkGroups[pk.TableName] = pk } - builder := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) fks := make([]ForeignKey, 0, len(sqlcFKs)) for _, fk := range sqlcFKs { fks = append(fks, ForeignKey(fk)) } + builder := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) return pgdb{ fks: fks, pkGroups: pkGroups, @@ -59,6 +59,14 @@ func (db pgdb) ForeignKeys() []ForeignKey { return db.fks } +func (db pgdb) PrimaryKeys() map[string]string { + out := make(map[string]string) + for _, r := range db.pkGroups { + out[r.TableName] = r.ColumnName + } + return out +} + type pgtx struct { pgdb ctx context.Context @@ -152,7 +160,47 @@ func (db pgbatchtx) SelectMatchingRows(tname string, conds map[string][]any) ([] return foundInThisScan, nil } -func (db pgbatchtx) Insert(fkm ForeignKeyMapper, rows ...map[string]any) error { +func (db pgbatchtx) InsertRecord(row map[string]any) (any, error) { + keys := make([]string, 0, len(row)) + vals := make([]any, 0, len(row)) + table := row[DumpTableKey].(string) + builder := db.builder.Insert(`"` + table + `"`) + for k, v := range row { + if v == nil { + continue + } + if k == DumpTableKey { + continue + } + keys = append(keys, fmt.Sprintf(`"%s"`, k)) + vals = append(vals, v) + } + + // scan back the generated key; note this assumes the primary key column is named "id" + builder = builder.Columns(keys...).Values(vals...).Suffix(`RETURNING "id"`) + sql, args, err := builder.ToSql() + if err != nil { + return nil, err + } + var id any + if err := db.tx.db.QueryRow(db.ctx, sql, args...).Scan(&id); err != nil { + return nil, err + } + return id, nil +} + +func (db pgbatchtx) Update(id RecordID, cols map[string]any) error { + return nil +} + +func (db pgbatchtx) Delete(id RecordID) error { + return nil +} + +func (db pgbatchtx) Mapping() ([]Mapping, error) { + return nil, nil +} + +func (db pgbatchtx) Insert(rows ...map[string]any) error { if _, err := db.tx.db.Exec(db.ctx, "CREATE TEMPORARY TABLE IF NOT EXISTS datapasta_clone(table_name text, original_id integer, clone_id integer) ON COMMIT DROP"); err != nil { return err } diff --git a/postgres_test.go b/postgres_test.go deleted file mode 100644 index c677f3f..0000000 --- a/postgres_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package datapasta - -import ( - "context" - "encoding/json" - "log" - "strings" - "testing" - "time" - - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgxpool" - "github.com/stretchr/testify/assert" -) - -const largeCompany = 2127511261 -const testCompany = 1296515245 - -func TestWithLocalPostgres(t *testing.T) { - t.Skipf("test is used for development against real pulley schema") - - company := testCompany - log.Println("starting to clone company", company) - - ok := assert.New(t) - conn, err := pgxpool.Connect(context.Background(), `postgresql://postgres:postgres@localhost:5432/postgres`) - ok.NoError(err) - db, err := NewPostgres(context.Background(), conn) - ok.NoError(err) - - tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{}) - ok.NoError(err) - defer tx.Rollback(context.Background()) - - cli, err := db.NewBatchClient(context.Background(), tx) - ok.NoError(err) - - exportOpts := []Opt{ - DontInclude("user"), - DontInclude("firm"), - DontRecurse("stakeholder"), - DontInclude("sandbox_clone"), - DontInclude("sandbox"), - } - - startDL := time.Now() - res, _, err := Download(context.Background(), cli, "company", "id", company, exportOpts...)
- ok.NoError(err) - ok.NotEmpty(res) - download := time.Since(startDL) - - for _, row := range res { - CleanupRow(row) - log.Println("cloning", row[DumpTableKey], row["id"]) - } - - in, _ := json.Marshal(res) - out := make([]map[string]any, 0, len(res)) - json.Unmarshal(in, &out) - - fkm := NewForeignKeyMapper(cli) - start := time.Now() - - log.Println("starting to insert company", company) - ok.NoError(cli.Insert(fkm, out...)) - upload := time.Since(start) - - var newID int64 - switch any(cli).(type) { - case pgbatchtx: - ok.NoError(tx.QueryRow(context.Background(), "SELECT clone_id FROM datapasta_clone WHERE original_id = $1 AND table_name = 'company'", company).Scan(&newID)) - case pgtx: - newID = int64(out[0]["id"].(int32)) - } - - t.Logf("new id: %d", newID) - - log.Println("starting to download company", newID) - newRes, deb, err := Download(context.Background(), cli, "company", "id", newID, exportOpts...) - ok.NoError(err) - - for _, l := range deb { - if !strings.HasSuffix(l, " 0 rows") { - t.Logf("debug: %s ... %s", l[:20], l[len(l)-20:]) - } - } - - for _, out := range newRes { - if out[DumpTableKey] == "company" { - t.Logf("found cloned company %v", out["id"]) - } - } - - ok.Equalf(len(res), len(newRes), "expected clone to have the same size export") - - t.Logf("durations: download(%s), upload(%s)", download, upload) -} - -// postgres rows need some pulley-specific cleanup -func CleanupRow(obj map[string]any) { - if obj[DumpTableKey] == "security" { - obj["change_email_token"] = nil - } - if obj[DumpTableKey] == "task" { - obj["access_code"] = nil - } - if obj[DumpTableKey] == "company" { - obj["stripe_customer_id"] = nil - } - for col, raw := range obj { - switch val := raw.(type) { - case time.Time: - if val.IsZero() || val.Year() <= 1 { - obj[col] = time.Now() // there's a few invalid timestamps - } - } - } -} diff --git a/pulley_tests.go b/pulley_tests.go new file mode 100644 index 0000000..42bc1e8 --- /dev/null +++ b/pulley_tests.go @@ -0,0 +1,247 @@ +package datapasta + +import ( + "context" + "encoding/json" + "log" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/stretchr/testify/assert" +) + +const runPulleyTests = true + +// This file houses tests that run against Pulley schema. +// TODO: have a docker database to run tests in a schema in datapasta. + +const largeCompany = 2127511261 +const testCompany = 1296515245 + +func TestWithLocalPostgres(t *testing.T) { + if !runPulleyTests { + t.Skipf("test is used for development against real pulley schema") + } + + company := testCompany + log.Println("starting to clone company", company) + + ok := assert.New(t) + conn, err := pgxpool.Connect(context.Background(), `postgresql://postgres:postgres@localhost:5432/postgres`) + ok.NoError(err) + db, err := NewPostgres(context.Background(), conn) + ok.NoError(err) + + tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{}) + ok.NoError(err) + defer tx.Rollback(context.Background()) + + cli, err := db.NewBatchClient(context.Background(), tx) + ok.NoError(err) + + exportOpts := []Opt{ + DontInclude("user"), + DontInclude("firm"), + DontRecurse("stakeholder"), + DontInclude("sandbox_clone"), + DontInclude("sandbox"), + } + + startDL := time.Now() + res, _, err := Download(context.Background(), cli, "company", "id", company, exportOpts...) 
+ ok.NoError(err) + ok.NotEmpty(res) + download := time.Since(startDL) + + for _, row := range res { + CleanupRow(row) + log.Println("cloning", row[DumpTableKey], row["id"]) + } + + in, _ := json.Marshal(res) + out := make([]map[string]any, 0, len(res)) + json.Unmarshal(in, &out) + + start := time.Now() + + log.Println("starting to insert company", company) + ok.NoError(cli.Insert(out...)) + upload := time.Since(start) + + var newID int64 + switch any(cli).(type) { + case pgbatchtx: + ok.NoError(tx.QueryRow(context.Background(), "SELECT clone_id FROM datapasta_clone WHERE original_id = $1 AND table_name = 'company'", company).Scan(&newID)) + case pgtx: + newID = int64(out[0]["id"].(int32)) + } + + t.Logf("new id: %d", newID) + + log.Println("starting to download company", newID) + newRes, deb, err := Download(context.Background(), cli, "company", "id", newID, exportOpts...) + ok.NoError(err) + + for _, l := range deb { + if !strings.HasSuffix(l, " 0 rows") { + t.Logf("debug: %s ... %s", l[:20], l[len(l)-20:]) + } + } + + for _, out := range newRes { + if out[DumpTableKey] == "company" { + t.Logf("found cloned company %v", out["id"]) + } + } + + ok.Equalf(len(res), len(newRes), "expected clone to have the same size export") + + t.Logf("durations: download(%s), upload(%s)", download, upload) +} + +// postgres rows need some pulley-specific cleanup +func CleanupRow(obj map[string]any) { + if obj[DumpTableKey] == "security" { + obj["change_email_token"] = nil + } + if obj[DumpTableKey] == "task" { + obj["access_code"] = nil + } + if obj[DumpTableKey] == "company" { + obj["stripe_customer_id"] = nil + } + for col, raw := range obj { + switch val := raw.(type) { + case time.Time: + if val.IsZero() || val.Year() <= 1 { + obj[col] = time.Now() // there's a few invalid timestamps + } + } + } +} + +// func TestDiffTwoCompaniesWithLocalPostgres(t *testing.T) { +// if !runPulleyTests { +// t.Skipf("test is used for development against real pulley schema") +// } + +// // declare and initialize both company ids +// company1 := largeCompany // To be replaced with actual ID +// company2 := testCompany // To be replaced with actual ID + +// ok := assert.New(t) +// conn, err := pgxpool.Connect(context.Background(), `postgresql://postgres:postgres@localhost:5432/postgres`) +// ok.NoError(err) +// db, err := NewPostgres(context.Background(), conn) +// ok.NoError(err) + +// cli, err := db.NewBatchClient(context.Background(), conn) +// ok.NoError(err) + +// exportOpts := []Opt{ +// DontInclude("user"), +// DontInclude("firm"), +// DontRecurse("stakeholder"), +// DontInclude("sandbox_clone"), +// DontInclude("sandbox"), +// } + +// // Downloading data for company1 +// res1, _, err := Download(context.Background(), cli, "company", "id", company1, exportOpts...) +// ok.NoError(err) + +// // Downloading data for company2 +// res2, _, err := Download(context.Background(), cli, "company", "id", company2, exportOpts...) 
+// ok.NoError(err) + +// // Perform diff between the two companies +// nonConflicting, conflicting := ThreeWayDiff(cli, res1, res1, res2) + +// // Number of conflicting and non-conflicting differences should be printed or used in assert +// t.Logf("Number of conflicting diffs: %d", len(conflicting)) +// t.Logf("Number of nonconflicting diffs: %d", len(nonConflicting)) +// } + +// func TestFullProcessWithRealDatabase(t *testing.T) { +// if !runPulleyTests { +// t.Skipf("test is used for development against real pulley schema") +// } + +// company := testCompany // replace with actual company ID + +// log.Println("starting to clone company", company) + +// ok := assert.New(t) +// conn, err := pgxpool.Connect(context.Background(), `postgresql://postgres:postgres@localhost:5432/postgres`) +// ok.NoError(err) + +// db, err := NewPostgres(context.Background(), conn) +// ok.NoError(err) + +// // Begin a transaction specifically for cloning and reverse mapping +// tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{}) +// ok.NoError(err) +// defer tx.Rollback(context.Background()) + +// cli, err := db.NewBatchClient(context.Background(), tx) // This client is used for cloning and reverse mapping +// ok.NoError(err) + +// exportOpts := []Opt{ +// DontInclude("user"), +// DontInclude("firm"), +// DontRecurse("stakeholder"), +// DontInclude("sandbox_clone"), +// DontInclude("sandbox"), +// } + +// // Downloading data for initial company +// connCli1, err := db.NewBatchClient(context.Background(), conn) // New client for this download +// ok.NoError(err) +// initial, _, err := Download(context.Background(), connCli1, "company", "id", company, exportOpts...) +// ok.NoError(err) + +// for _, row := range initial { +// CleanupRow(row) +// } + +// // Copy 'initial' to 'out' to clone the initial company +// in, _ := json.Marshal(initial) +// out := make([]map[string]any, 0, len(initial)) +// json.Unmarshal(in, &out) + +// log.Println("starting to insert company", company) +// ok.NoError(cli.Insert(out...)) // Insert cloned company + +// mapping, err := cli.GetMapping() // Acquire the mapping between old and new Ids +// ok.NoError(err) + +// for _, m := range mapping { +// log.Println("mapped", m.TableName, m.OriginalID, "to", m.NewID) +// } + +// // Re-download both initial and cloned company +// connCli2, err := db.NewBatchClient(context.Background(), conn) // New client for this download +// ok.NoError(err) +// initialClone, _, err := Download(context.Background(), connCli2, "company", "id", company, exportOpts...) +// ok.NoError(err) +// ok.NotEmpty(initialClone) + +// connCli3, err := db.NewBatchClient(context.Background(), conn) // New client for this download +// ok.NoError(err) +// clonedCompany, _, err := Download(context.Background(), connCli3, "company", "id", mapping[0].NewID, exportOpts...) 
+// ok.NoError(err) +// ok.NotEmpty(clonedCompany) + +// // Reverse map the downloaded clone +// ok.NoError(ReverseMapping(cli, clonedCompany, mapping)) + +// log.Println("sample initial clone", initialClone[:5]) +// log.Println("sample reverse mapped clone ", clonedCompany[:5]) + +// newDiff := FindDiffs(cli, initialClone, clonedCompany) // Perform diff between initial and cloned data after reverse mapping + +// ok.Empty(newDiff, "Expect no diffs, but got some") // Ensure there is no diff +// } From 3f4010dcf92a41a5c5a08508cfad4bca3e35abdc Mon Sep 17 00:00:00 2001 From: Daniel Toye Date: Tue, 1 Aug 2023 16:10:12 -0600 Subject: [PATCH 2/4] so close, just need to figure out types getting munged during transport, email is bigint --- clone.go | 1 + interface.go | 9 +- merge.go | 82 +++++++--------- merge_test.go | 96 +++++++++++++++++++ postgres.go | 71 +++++++++++++- pulley_test.go | 220 ++++++++++++++++++++++++++++++++++++++++++ pulley_tests.go | 247 ------------------------------------------------ 7 files changed, 430 insertions(+), 296 deletions(-) create mode 100644 pulley_test.go delete mode 100644 pulley_tests.go diff --git a/clone.go b/clone.go index 01a3d35..69bc499 100644 --- a/clone.go +++ b/clone.go @@ -146,6 +146,7 @@ func Download(ctx context.Context, db Database, startTable, startColumn string, return nil, debugging, err } } + return cloneInOrder, debugging, nil } diff --git a/interface.go b/interface.go index 7fdf3e5..6c0af48 100644 --- a/interface.go +++ b/interface.go @@ -1,6 +1,8 @@ package datapasta -import "fmt" +import ( + "fmt" +) // Database is the abstraction between the cloning tool and the database. // The NewPostgres.NewClient method gives you an implementation for Postgres. @@ -59,7 +61,10 @@ func (r RecordID) String() string { func GetRowIdentifier(pks map[string]string, row map[string]any) RecordID { table := row[DumpTableKey].(string) - pk := row[pks[table]] + pk, ok := row[pks[table]] + if !ok { + panic("unable to get row identifier") + } return RecordID{Table: table, PrimaryKey: pk} } diff --git a/merge.go b/merge.go index 380eda5..ba82b70 100644 --- a/merge.go +++ b/merge.go @@ -1,6 +1,10 @@ package datapasta -import "fmt" +import ( + "fmt" + "log" + "reflect" +) type MergeAction struct { ID RecordID @@ -33,9 +37,11 @@ func FindMapping(table string, id any, mapp []Mapping) Mapping { continue } if m.NewID == id { + log.Printf(`%s: %T %#v == %T %#v`, table, m.NewID, m.NewID, id, id) return m } } + log.Printf("no mapping found for %s (%T %v)", table, id, id) return Mapping{TableName: table, OriginalID: id, NewID: id} } @@ -45,6 +51,7 @@ func ReversePrimaryKeyMapping(pks map[string]string, mapp []Mapping, dump Databa table := row[DumpTableKey].(string) pk, hasPk := pks[table] if !hasPk { + log.Println("no pk for", table) continue } m := FindMapping(table, row[pk], mapp) @@ -119,7 +126,7 @@ func FindModifiedRows(pks map[string]string, from, in DatabaseDump) map[RecordID changes := make(map[string]any) for k, v := range match { - if v != row[k] { + if !reflect.DeepEqual(v, row[k]) { changes[k] = row[k] } } @@ -132,27 +139,46 @@ func FindModifiedRows(pks map[string]string, from, in DatabaseDump) map[RecordID return all } -func ApplyMergeStrategy(db Database, mapp []Mapping, dump DatabaseDump, mas []MergeAction) error { +func ApplyMergeStrategy(db Database, mapp []Mapping, mas []MergeAction) error { fks := db.ForeignKeys() - pks := db.PrimaryKeys() for _, ma := range mas { if ma.Action != "create" { continue } + ma.Data[DumpTableKey] = ma.ID.Table 
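+ // ReverseForeignKeyMappingRow reads DumpTableKey to decide which foreign keys apply, hence the re-tagging above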
ReverseForeignKeyMappingRow(fks, mapp, ma.Data) - origID := ma.Data[pks[ma.ID.Table]] - delete(ma.Data, pks[ma.ID.Table]) id, err := db.InsertRecord(ma.Data) if err != nil { - return err + return fmt.Errorf(`creating %s: %s`, ma.ID, err.Error()) } - mapp = append(mapp, Mapping{TableName: ma.ID.Table, OriginalID: origID, NewID: id}) + mapp = append(mapp, Mapping{TableName: ma.ID.Table, NewID: ma.ID.PrimaryKey, OriginalID: id}) } // do all the creates *while* updating the mapping // do all the updates + for _, ma := range mas { + if ma.Action != "update" { + continue + } + ma.Data[DumpTableKey] = ma.ID.Table + ReverseForeignKeyMappingRow(fks, mapp, ma.Data) + delete(ma.Data, DumpTableKey) + if err := db.Update(ma.ID, ma.Data); err != nil { + return fmt.Errorf(`updating %s: %s`, ma.ID, err.Error()) + } + } + // do all the deletes + for _, ma := range mas { + if ma.Action != "delete" { + continue + } + if err := db.Delete(ma.ID); err != nil { + return fmt.Errorf(`deleting %s: %s`, ma.ID, err.Error()) + } + } + return nil } @@ -173,6 +199,7 @@ func GenerateMergeStrategy(pks map[string]string, base, main, branch DatabaseDum created := FindMissingRows(pks, branch, base) for _, m := range created { id := GetRowIdentifier(pks, m) + delete(m, pks[id.Table]) out = append(out, MergeAction{id, "create", m}) } @@ -192,45 +219,8 @@ func GenerateMergeStrategy(pks map[string]string, base, main, branch DatabaseDum out = append(out, MergeAction{id, "conflict", m}) continue } - out = append(out, MergeAction{id, "delete", m}) + out = append(out, MergeAction{id, "delete", nil}) } return out } - -// // ThreeWayMerge applies a three way merge to a Diffing interface -// func ThreeWayMerge(db Diffing, mapp []Mapping, root, main, branch DatabaseDump) (actions MergeAction) { -// // existing definitions: -// // type DatabaseDump []map[string]any -// // type Mapping struct { -// // TableName string -// // OriginalID, NewID any -// // } -// // type Diffing interface { -// // Insert(table string, record map[string]any) (pk any, err error) -// // Update(table string, record map[string]any) error -// // } - -// // all DatabaseDump rows have an "id" field of type any that is used for finding the old id (unmapping) - -// // define slice for merge actions - -// // find new items in branch not in root -// // for each new item: -// // apply id unmapping to this item -// // insert it into the db -// // append new id to mapping -// // append this change as a nonconflicting change -// // apply id unmapping to everything in branch - -// // find all modified or deleted items in main -// // find all modified or deleted items in branch -// // for each diff item in branch: -// // if it exists in the main diff -// // append this as a conflicting merge actions -// // otherwise -// // append this to the merge actions -// // update the record in the db - -// // return conflicting and nonconflicting changes -// } diff --git a/merge_test.go b/merge_test.go index def0e95..b1d2429 100644 --- a/merge_test.go +++ b/merge_test.go @@ -1,6 +1,7 @@ package datapasta import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -345,3 +346,98 @@ func TestGenerateMergeStrategyWithMapping(t *testing.T) { ok.Len(mas[1].Data, 1) ok.Equal(22, mas[1].Data["friend"]) } + +func TestApplyMergeStrategy(t *testing.T) { + ok := assert.New(t) + + mapp := []Mapping{ + {"user", 1, 3}, + {"user", 2, 4}, + } + db := &mergeDB{T: t, id: 5, data: map[any]map[string]any{}} + db.data[1] = map[string]any{"name": "alica", "friend": 2} + db.data[2] =
map[string]any{"name": "bob"} + + // "alice" (1) is friends with bob (2) + // "alica" is cloned to 3 + // "bob" is cloned to 4 + // "alica" renames to "alicia" and becomes friends with "jeff" (5) + // "bob" is deleted + + mas := []MergeAction{} + mas = append(mas, MergeAction{ + ID: RecordID{"user", 1}, + Action: "update", + Data: map[string]any{"name": "alicia", "friend": 5}, + }) + mas = append(mas, MergeAction{ + ID: RecordID{"user", 5}, + Action: "create", + Data: map[string]any{"name": "jeff"}, + }) + mas = append(mas, MergeAction{ + ID: RecordID{"user", 2}, + Action: "delete", + }) + + ok.NoError(ApplyMergeStrategy(db, mapp, mas)) + + ok.Equal("alicia", db.data[1]["name"]) + ok.Equal("jeff", db.data[db.data[1]["friend"]]["name"]) + ok.NotContains(db.data, "bob") +} + +type mergeDB struct { + *testing.T + id int + data map[any]map[string]any +} + +// get foriegn key mapping +func (d *mergeDB) ForeignKeys() []ForeignKey { + d.Log(`ForeignKeys`) + return []ForeignKey{ + { + BaseTable: "user", BaseCol: "id", + ReferencingTable: "user", ReferencingCol: "friend", + }, + } +} + +func (d *mergeDB) InsertRecord(i map[string]any) (any, error) { + d.Log(`InsertRecord`, i) + d.id++ + d.data[d.id] = i + return d.id, nil +} + +func (d *mergeDB) Update(id RecordID, cols map[string]any) error { + if _, ok := d.data[id.PrimaryKey]; !ok { + return fmt.Errorf(`cant update nonexistant row %s`, id) + } + for k, v := range cols { + d.data[id.PrimaryKey][k] = v + } + d.Log(`Update`, id, cols) + return nil +} + +// delete the row +func (d *mergeDB) Delete(id RecordID) error { + if _, ok := d.data[id.PrimaryKey]; !ok { + return fmt.Errorf(`cant delete nonexistant row %s`, id) + } + d.Log(`Delete`, id) + return nil +} + +// stubbed for interface + +func (d mergeDB) Mapping() ([]Mapping, error) { return nil, nil } +func (d mergeDB) PrimaryKeys() map[string]string { return nil } +func (d mergeDB) Insert(records ...map[string]any) error { return nil } +func (d mergeDB) SelectMatchingRows(tname string, conds map[string][]any) ([]map[string]any, error) { + return nil, nil +} + +var _ Database = new(mergeDB) diff --git a/postgres.go b/postgres.go index c3acbc2..00d4b55 100644 --- a/postgres.go +++ b/postgres.go @@ -189,15 +189,51 @@ func (db pgbatchtx) InsertRecord(row map[string]any) (any, error) { } func (db pgbatchtx) Update(id RecordID, cols map[string]any) error { + table := id.Table + builder := db.builder.Update(`"` + table + `"`) + builder = builder.SetMap(cols).Where(squirrel.Eq{"id": id.PrimaryKey}) + sql, args, err := builder.ToSql() + if err != nil { + return err + } + cmd, err := db.tx.db.Exec(db.ctx, sql, args) + if err != nil { + return err + } + if cmd.RowsAffected() != 0 { + return fmt.Errorf("delete affected %d rows, expected 1", cmd.RowsAffected()) + } return nil } func (db pgbatchtx) Delete(id RecordID) error { + table := id.Table + builder := db.builder.Delete(`"` + table + `"`).Where(squirrel.Eq{"id": id.PrimaryKey}) + builder = builder.Limit(1) + sql, args, err := builder.ToSql() + if err != nil { + return err + } + cmd, err := db.tx.db.Exec(db.ctx, sql, args) + if err != nil { + return err + } + if cmd.RowsAffected() != 0 { + return fmt.Errorf("delete affected %d rows, expected 1", cmd.RowsAffected()) + } return nil } func (db pgbatchtx) Mapping() ([]Mapping, error) { - return nil, nil + rows, err := db.tx.GetMapping(db.ctx) + if err != nil { + return nil, err + } + mapps := make([]Mapping, 0, len(rows)) + for _, r := range rows { + mapps = append(mapps, Mapping{TableName: r.TableName, 
OriginalID: r.OriginalID, NewID: r.CloneID}) + } + return mapps, nil } func (db pgbatchtx) Insert(rows ...map[string]any) error { @@ -322,6 +358,39 @@ type postgresQueries struct { db Postgreser } +const getMapping = ` + SELECT table_name, original_id, clone_id FROM datapasta_clone +` + +type getMappingRow struct { + TableName string + OriginalID, CloneID int32 +} + +func (q *postgresQueries) GetMapping(ctx context.Context) ([]getMappingRow, error) { + rows, err := q.db.Query(ctx, getMapping) + if err != nil { + return nil, err + } + defer rows.Close() + var items []getMappingRow + for rows.Next() { + var i getMappingRow + if err := rows.Scan( + &i.TableName, + &i.OriginalID, + &i.CloneID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getForeignKeys = `-- name: GetForeignKeys :many SELECT (select r.relname from pg_catalog.pg_class r where r.oid = c.confrelid)::text as base_table, diff --git a/pulley_test.go b/pulley_test.go new file mode 100644 index 0000000..734781a --- /dev/null +++ b/pulley_test.go @@ -0,0 +1,220 @@ +package datapasta + +import ( + "context" + "encoding/json" + "log" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/stretchr/testify/assert" +) + +const runPulleyTests = true + +// This file houses tests that run against Pulley schema. +// TODO: have a docker database to run tests in a schema in datapasta. + +const largeCompany = 2127511261 +const testCompany = 1296515245 + +func TestWithLocalPostgres(t *testing.T) { + if !runPulleyTests { + t.Skipf("test is used for development against real pulley schema") + } + + company := testCompany + log.Println("starting to clone company", company) + + ok := assert.New(t) + conn, err := pgxpool.Connect(context.Background(), `postgresql://postgres:postgres@localhost:5432/postgres`) + ok.NoError(err) + db, err := NewPostgres(context.Background(), conn) + ok.NoError(err) + + tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{}) + ok.NoError(err) + defer tx.Rollback(context.Background()) + + cli, err := db.NewBatchClient(context.Background(), tx) + ok.NoError(err) + + exportOpts := []Opt{ + DontInclude("user"), + DontInclude("firm"), + DontRecurse("stakeholder"), + DontInclude("sandbox_clone"), + DontInclude("sandbox"), + } + + startDL := time.Now() + res, _, err := Download(context.Background(), cli, "company", "id", company, exportOpts...) + ok.NoError(err) + ok.NotEmpty(res) + download := time.Since(startDL) + + for _, row := range res { + CleanupRow(row) + log.Println("cloning", row[DumpTableKey], row["id"]) + } + + in, _ := json.Marshal(res) + out := make([]map[string]any, 0, len(res)) + json.Unmarshal(in, &out) + + start := time.Now() + + log.Println("starting to insert company", company) + ok.NoError(cli.Insert(out...)) + upload := time.Since(start) + + var newID int64 + switch any(cli).(type) { + case pgbatchtx: + ok.NoError(tx.QueryRow(context.Background(), "SELECT clone_id FROM datapasta_clone WHERE original_id = $1 AND table_name = 'company'", company).Scan(&newID)) + case pgtx: + newID = int64(out[0]["id"].(int32)) + } + + t.Logf("new id: %d", newID) + + log.Println("starting to download company", newID) + newRes, deb, err := Download(context.Background(), cli, "company", "id", newID, exportOpts...) + ok.NoError(err) + + for _, l := range deb { + if !strings.HasSuffix(l, " 0 rows") { + t.Logf("debug: %s ... 
%s", l[:20], l[len(l)-20:]) + } + } + + for _, out := range newRes { + if out[DumpTableKey] == "company" { + t.Logf("found cloned company %v", out["id"]) + } + } + + ok.Equalf(len(res), len(newRes), "expected clone to have the same size export") + + t.Logf("durations: download(%s), upload(%s)", download, upload) +} + +// postgres rows need some pulley-specific cleanup +func CleanupRow(obj map[string]any) { + if obj[DumpTableKey] == "security" { + obj["change_email_token"] = nil + } + if obj[DumpTableKey] == "task" { + obj["access_code"] = nil + } + if obj[DumpTableKey] == "company" { + obj["stripe_customer_id"] = nil + } + for col, raw := range obj { + switch val := raw.(type) { + case time.Time: + if val.IsZero() || val.Year() <= 1 { + obj[col] = time.Now() // there's a few invalid timestamps + } + } + } +} + +func TestFullProcessWithRealDatabase(t *testing.T) { + if !runPulleyTests { + t.Skipf("test is used for development against real pulley schema") + } + + company := testCompany // replace with actual company ID + + log.Println("starting to clone company", company) + + ok := assert.New(t) + conn, err := pgxpool.Connect(context.Background(), `postgresql://postgres:postgres@localhost:5432/postgres`) + ok.NoError(err) + + db, err := NewPostgres(context.Background(), conn) + ok.NoError(err) + + // Begin a transaction specifically for cloning and reverse mapping + tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{}) + ok.NoError(err) + defer tx.Rollback(context.Background()) + + exportOpts := []Opt{ + DontInclude("user"), + DontInclude("firm_company"), + DontRecurse("stakeholder"), + DontInclude("sandbox_clone"), + DontInclude("sandbox"), + } + + // Downloading data for initial company + connCli1, err := db.NewBatchClient(context.Background(), conn) // New client for this download + ok.NoError(err) + initial, _, err := Download(context.Background(), connCli1, "company", "id", company, exportOpts...) + ok.NoError(err) + ok.NotEmpty(initial) + + for _, row := range initial { + CleanupRow(row) + } + + cli, err := db.NewBatchClient(context.Background(), tx) // This client is used for cloning and reverse mapping + in, _ := json.Marshal(initial) + { + out := make([]map[string]any, 0, len(initial)) + json.Unmarshal(in, &out) + + log.Println("starting to insert company", company) + ok.NoError(err) + ok.NoError(cli.Insert(out...)) // Insert cloned company + } + mapping, err := cli.Mapping() // Acquire the mapping between old and new Ids + ok.NoError(err) + + for _, m := range mapping { + log.Printf("mapped %s (%T %v to %T %v)", m.TableName, m.OriginalID, m.OriginalID, m.NewID, m.NewID) + } + + // we can mutate company and clone here to generate merge actions + + // Re-download both initial and cloned company + connCli2, err := db.NewBatchClient(context.Background(), conn) // New client for this download + ok.NoError(err) + currentCompany, _, err := Download(context.Background(), connCli2, "company", "id", company, exportOpts...) + ok.NoError(err) + ok.NotEmpty(currentCompany) + + // get the clone using the tx + connCli3, err := db.NewBatchClient(context.Background(), tx) // New client for this download + ok.NoError(err) + clonedCompany, _, err := Download(context.Background(), connCli3, "company", "id", mapping[0].NewID, exportOpts...) 
+ ok.NoError(err) + ok.NotEmpty(clonedCompany) + + t.Logf(`before mapping: %#v`, clonedCompany[0]) + + ReverseForeignKeyMapping(db.ForeignKeys(), mapping, clonedCompany) + ReversePrimaryKeyMapping(db.PrimaryKeys(), mapping, clonedCompany) + t.Logf(`after mapping: %#v`, clonedCompany[0]) + t.Logf(`previously: %#v`, initial[0]) + t.Logf(`currently: %#v`, currentCompany[0]) + + t.Logf(`after mapping: %T`, clonedCompany[0]["id"]) + t.Logf(`previously: %T`, initial[0]["id"]) + t.Logf(`currently: %T`, currentCompany[0]["id"]) + + mas := GenerateMergeStrategy(db.PrimaryKeys(), initial, currentCompany, clonedCompany) + ok.Empty(mas) + + if len(mas) > 5 { + for _, ma := range mas[:5] { + t.Log("merge action:", ma) + } + + } +} diff --git a/pulley_tests.go b/pulley_tests.go deleted file mode 100644 index 42bc1e8..0000000 --- a/pulley_tests.go +++ /dev/null @@ -1,247 +0,0 @@ -package datapasta - -import ( - "context" - "encoding/json" - "log" - "strings" - "testing" - "time" - - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/pgxpool" - "github.com/stretchr/testify/assert" -) - -const runPulleyTests = true - -// This file houses tests that run against Pulley schema. -// TODO: have a docker database to run tests in a schema in datapasta. - -const largeCompany = 2127511261 -const testCompany = 1296515245 - -func TestWithLocalPostgres(t *testing.T) { - if !runPulleyTests { - t.Skipf("test is used for development against real pulley schema") - } - - company := testCompany - log.Println("starting to clone company", company) - - ok := assert.New(t) - conn, err := pgxpool.Connect(context.Background(), `postgresql://postgres:postgres@localhost:5432/postgres`) - ok.NoError(err) - db, err := NewPostgres(context.Background(), conn) - ok.NoError(err) - - tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{}) - ok.NoError(err) - defer tx.Rollback(context.Background()) - - cli, err := db.NewBatchClient(context.Background(), tx) - ok.NoError(err) - - exportOpts := []Opt{ - DontInclude("user"), - DontInclude("firm"), - DontRecurse("stakeholder"), - DontInclude("sandbox_clone"), - DontInclude("sandbox"), - } - - startDL := time.Now() - res, _, err := Download(context.Background(), cli, "company", "id", company, exportOpts...) - ok.NoError(err) - ok.NotEmpty(res) - download := time.Since(startDL) - - for _, row := range res { - CleanupRow(row) - log.Println("cloning", row[DumpTableKey], row["id"]) - } - - in, _ := json.Marshal(res) - out := make([]map[string]any, 0, len(res)) - json.Unmarshal(in, &out) - - start := time.Now() - - log.Println("starting to insert company", company) - ok.NoError(cli.Insert(out...)) - upload := time.Since(start) - - var newID int64 - switch any(cli).(type) { - case pgbatchtx: - ok.NoError(tx.QueryRow(context.Background(), "SELECT clone_id FROM datapasta_clone WHERE original_id = $1 AND table_name = 'company'", company).Scan(&newID)) - case pgtx: - newID = int64(out[0]["id"].(int32)) - } - - t.Logf("new id: %d", newID) - - log.Println("starting to download company", newID) - newRes, deb, err := Download(context.Background(), cli, "company", "id", newID, exportOpts...) - ok.NoError(err) - - for _, l := range deb { - if !strings.HasSuffix(l, " 0 rows") { - t.Logf("debug: %s ... 
%s", l[:20], l[len(l)-20:]) - } - } - - for _, out := range newRes { - if out[DumpTableKey] == "company" { - t.Logf("found cloned company %v", out["id"]) - } - } - - ok.Equalf(len(res), len(newRes), "expected clone to have the same size export") - - t.Logf("durations: download(%s), upload(%s)", download, upload) -} - -// postgres rows need some pulley-specific cleanup -func CleanupRow(obj map[string]any) { - if obj[DumpTableKey] == "security" { - obj["change_email_token"] = nil - } - if obj[DumpTableKey] == "task" { - obj["access_code"] = nil - } - if obj[DumpTableKey] == "company" { - obj["stripe_customer_id"] = nil - } - for col, raw := range obj { - switch val := raw.(type) { - case time.Time: - if val.IsZero() || val.Year() <= 1 { - obj[col] = time.Now() // there's a few invalid timestamps - } - } - } -} - -// func TestDiffTwoCompaniesWithLocalPostgres(t *testing.T) { -// if !runPulleyTests { -// t.Skipf("test is used for development against real pulley schema") -// } - -// // declare and initialize both company ids -// company1 := largeCompany // To be replaced with actual ID -// company2 := testCompany // To be replaced with actual ID - -// ok := assert.New(t) -// conn, err := pgxpool.Connect(context.Background(), `postgresql://postgres:postgres@localhost:5432/postgres`) -// ok.NoError(err) -// db, err := NewPostgres(context.Background(), conn) -// ok.NoError(err) - -// cli, err := db.NewBatchClient(context.Background(), conn) -// ok.NoError(err) - -// exportOpts := []Opt{ -// DontInclude("user"), -// DontInclude("firm"), -// DontRecurse("stakeholder"), -// DontInclude("sandbox_clone"), -// DontInclude("sandbox"), -// } - -// // Downloading data for company1 -// res1, _, err := Download(context.Background(), cli, "company", "id", company1, exportOpts...) -// ok.NoError(err) - -// // Downloading data for company2 -// res2, _, err := Download(context.Background(), cli, "company", "id", company2, exportOpts...) 
-// ok.NoError(err) - -// // Perform diff between the two companies -// nonConflicting, conflicting := ThreeWayDiff(cli, res1, res1, res2) - -// // Number of conflicting and non-conflicting differences should be printed or used in assert -// t.Logf("Number of conflicting diffs: %d", len(conflicting)) -// t.Logf("Number of nonconflicting diffs: %d", len(nonConflicting)) -// } - -// func TestFullProcessWithRealDatabase(t *testing.T) { -// if !runPulleyTests { -// t.Skipf("test is used for development against real pulley schema") -// } - -// company := testCompany // replace with actual company ID - -// log.Println("starting to clone company", company) - -// ok := assert.New(t) -// conn, err := pgxpool.Connect(context.Background(), `postgresql://postgres:postgres@localhost:5432/postgres`) -// ok.NoError(err) - -// db, err := NewPostgres(context.Background(), conn) -// ok.NoError(err) - -// // Begin a transaction specifically for cloning and reverse mapping -// tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{}) -// ok.NoError(err) -// defer tx.Rollback(context.Background()) - -// cli, err := db.NewBatchClient(context.Background(), tx) // This client is used for cloning and reverse mapping -// ok.NoError(err) - -// exportOpts := []Opt{ -// DontInclude("user"), -// DontInclude("firm"), -// DontRecurse("stakeholder"), -// DontInclude("sandbox_clone"), -// DontInclude("sandbox"), -// } - -// // Downloading data for initial company -// connCli1, err := db.NewBatchClient(context.Background(), conn) // New client for this download -// ok.NoError(err) -// initial, _, err := Download(context.Background(), connCli1, "company", "id", company, exportOpts...) -// ok.NoError(err) - -// for _, row := range initial { -// CleanupRow(row) -// } - -// // Copy 'initial' to 'out' to clone the initial company -// in, _ := json.Marshal(initial) -// out := make([]map[string]any, 0, len(initial)) -// json.Unmarshal(in, &out) - -// log.Println("starting to insert company", company) -// ok.NoError(cli.Insert(out...)) // Insert cloned company - -// mapping, err := cli.GetMapping() // Acquire the mapping between old and new Ids -// ok.NoError(err) - -// for _, m := range mapping { -// log.Println("mapped", m.TableName, m.OriginalID, "to", m.NewID) -// } - -// // Re-download both initial and cloned company -// connCli2, err := db.NewBatchClient(context.Background(), conn) // New client for this download -// ok.NoError(err) -// initialClone, _, err := Download(context.Background(), connCli2, "company", "id", company, exportOpts...) -// ok.NoError(err) -// ok.NotEmpty(initialClone) - -// connCli3, err := db.NewBatchClient(context.Background(), conn) // New client for this download -// ok.NoError(err) -// clonedCompany, _, err := Download(context.Background(), connCli3, "company", "id", mapping[0].NewID, exportOpts...) 
-// 	ok.NoError(err)
-// 	ok.NotEmpty(clonedCompany)
-
-// 	// Reverse map the downloaded clone
-// 	ok.NoError(ReverseMapping(cli, clonedCompany, mapping))
-
-// 	log.Println("sample initial clone", initialClone[:5])
-// 	log.Println("sample reverse mapped clone ", clonedCompany[:5])
-
-// 	newDiff := FindDiffs(cli, initialClone, clonedCompany) // Perform diff between initial and cloned data after reverse mapping
-
-// 	ok.Empty(newDiff, "Expect no diffs, but got some") // Ensure there is no diff
-// }

From 5d999562e3295863ca65a3367b2f0e261d5ab241 Mon Sep 17 00:00:00 2001
From: Daniel Toye
Date: Wed, 2 Aug 2023 13:33:08 -0600
Subject: [PATCH 3/4] it appears to be working

---
 interface.go   |  7 +++--
 merge.go       | 69 ++++++++++++++++++++++++--------------------------
 merge_test.go  | 26 +++++++++++--------
 postgres.go    | 28 ++++++++------------
 pulley_test.go | 45 ++++++++++++++++++++------------
 5 files changed, 94 insertions(+), 81 deletions(-)

diff --git a/interface.go b/interface.go
index 6c0af48..16f6e3e 100644
--- a/interface.go
+++ b/interface.go
@@ -2,6 +2,7 @@ package datapasta

 import (
 	"fmt"
+	"log"
 )

 // Database is the abstraction between the cloning tool and the database.
@@ -69,6 +70,8 @@ func GetRowIdentifier(pks map[string]string, row map[string]any) RecordID {
 }

 type Mapping struct {
-	TableName         string
-	OriginalID, NewID any
+	RecordID
+	OriginalID any
 }
+
+var LogFunc = log.Printf
diff --git a/merge.go b/merge.go
index ba82b70..0ab0ada 100644
--- a/merge.go
+++ b/merge.go
@@ -2,8 +2,6 @@ package datapasta

 import (
 	"fmt"
-	"log"
-	"reflect"
 )

 type MergeAction struct {
@@ -13,36 +11,37 @@ type MergeAction struct {
 }

 func (ma MergeAction) String() string {
-	return fmt.Sprintf(`%s %s %#v`, ma.Action, ma.ID, ma.Data)
+	if ma.Action == "delete" {
+		return fmt.Sprintf(`%s %s`, ma.Action, ma.ID)
+	}
+	return fmt.Sprintf(`%s %s %d columns`, ma.Action, ma.ID, len(ma.Data))
 }

 func FindRow(table, pk string, id any, dump DatabaseDump) map[string]any {
 	if id == nil {
 		return nil
 	}
+	needle := RecordID{Table: table, PrimaryKey: id}
 	for _, d := range dump {
-		if d[DumpTableKey] != table {
-			continue
-		}
-		if d[pk] == id {
+		test := RecordID{Table: d[DumpTableKey].(string), PrimaryKey: d[pk]}
+		if test.String() == needle.String() {
 			return d
 		}
 	}
 	return nil
 }

-func FindMapping(table string, id any, mapp []Mapping) Mapping {
+func FindMapping(id RecordID, mapp []Mapping) Mapping {
+	if id.PrimaryKey == nil {
+		return Mapping{RecordID: id, OriginalID: id.PrimaryKey}
+	}
 	for _, m := range mapp {
-		if m.TableName != table {
-			continue
-		}
-		if m.NewID == id {
-			log.Printf(`%s: %T %#v == %T %#v`, table, m.NewID, m.NewID, id, id)
+		if m.RecordID.String() == id.String() {
 			return m
 		}
 	}
-	log.Printf("no mapping found for %s (%T %v)", table, id, id)
-	return Mapping{TableName: table, OriginalID: id, NewID: id}
+	LogFunc("no mapping found for %s (%T %v)", id.Table, id.PrimaryKey, id.PrimaryKey)
+	return Mapping{RecordID: id, OriginalID: id.PrimaryKey}
 }

 // reverse all the primary keys of a dump
@@ -51,10 +50,10 @@ func ReversePrimaryKeyMapping(pks map[string]string, mapp []Mapping, dump Databa
 		table := row[DumpTableKey].(string)
 		pk, hasPk := pks[table]
 		if !hasPk {
-			log.Println("no pk for", table)
+			LogFunc("no pk for %s", table)
 			continue
 		}
-		m := FindMapping(table, row[pk], mapp)
+		m := FindMapping(RecordID{Table: table, PrimaryKey: row[pk]}, mapp)
 		row[pk] = m.OriginalID
 	}
 }
@@ -62,15 +61,9 @@
 // reverse all the foreign keys of an individual row
 func ReverseForeignKeyMappingRow(fks []ForeignKey, mapp []Mapping, row map[string]any) {
 	update := func(row map[string]any, col, otherTable string) {
-		for _, m := range mapp {
-			if m.TableName != otherTable {
-				continue
-			}
-			if m.NewID != row[col] {
-				continue
-			}
-			row[col] = m.OriginalID
-		}
+		target := RecordID{Table: otherTable, PrimaryKey: row[col]}
+		m := FindMapping(target, mapp)
+		row[col] = m.OriginalID
 	}

 	table := row[DumpTableKey].(string)
@@ -126,7 +119,7 @@ func FindModifiedRows(pks map[string]string, from, in DatabaseDump) map[RecordID
 		changes := make(map[string]any)
 		for k, v := range match {
-			if !reflect.DeepEqual(v, row[k]) {
+			if fmt.Sprintf(`%v`, v) != fmt.Sprintf(`%v`, row[k]) {
 				changes[k] = row[k]
 			}
 		}
@@ -152,7 +145,7 @@ func ApplyMergeStrategy(db Database, mapp []Mapping, mas []MergeAction) error {
 		if err != nil {
 			return fmt.Errorf(`creating %s: %s`, ma.ID, err.Error())
 		}
-		mapp = append(mapp, Mapping{TableName: ma.ID.Table, NewID: ma.ID.PrimaryKey, OriginalID: id})
+		mapp = append(mapp, Mapping{RecordID: ma.ID, OriginalID: id})
 	}

 	// do all the creates *while* updating the mapping
@@ -187,13 +180,13 @@ func GenerateMergeStrategy(pks map[string]string, base, main, branch DatabaseDum
 	out := make([]MergeAction, 0)
-	deletedInMain := make(map[RecordID]bool)
+	deletedInMain := make(map[string]bool)
 	for _, deleted := range FindMissingRows(pks, base, main) {
-		deletedInMain[GetRowIdentifier(pks, deleted)] = true
+		deletedInMain[GetRowIdentifier(pks, deleted).String()] = true
 	}
-	editedInMain := make(map[RecordID]bool)
+	editedInMain := make(map[string]bool)
 	for id := range FindModifiedRows(pks, main, base) {
-		editedInMain[id] = true
+		editedInMain[id.String()] = true
 	}

 	created := FindMissingRows(pks, branch, base)
@@ -205,8 +198,12 @@ func GenerateMergeStrategy(pks map[string]string, base, main, branch DatabaseDum

 	changes := FindModifiedRows(pks, branch, base)
 	for id, c := range changes {
-		if editedInMain[id] || deletedInMain[id] {
-			out = append(out, MergeAction{id, "conflict", c})
+		if editedInMain[id.String()] {
+			out = append(out, MergeAction{id, "conflicting_double_update", c})
+			continue
+		}
+		if deletedInMain[id.String()] {
+			out = append(out, MergeAction{id, "conflicting_update_deleted", c})
 			continue
 		}
 		out = append(out, MergeAction{id, "update", c})
@@ -215,8 +212,8 @@ func GenerateMergeStrategy(pks map[string]string, base, main, branch DatabaseDum
 	deleted := FindMissingRows(pks, base, branch)
 	for _, m := range deleted {
 		id := GetRowIdentifier(pks, m)
-		if editedInMain[id] {
-			out = append(out, MergeAction{id, "conflict", m})
+		if editedInMain[id.String()] {
+			out = append(out, MergeAction{id, "conflict_delete_updated", m})
 			continue
 		}
 		out = append(out, MergeAction{id, "delete", nil})
diff --git a/merge_test.go b/merge_test.go
index b1d2429..37f69c4 100644
--- a/merge_test.go
+++ b/merge_test.go
@@ -45,7 +45,7 @@ func TestReversePrimaryKeyMapping(t *testing.T) {
 		},
 	}
 	mapp := []Mapping{
-		{TableName: "person", OriginalID: 8, NewID: 11},
+		{RecordID: RecordID{Table: "person", PrimaryKey: 11}, OriginalID: 8},
 	}

 	pks := map[string]string{"person": "id"}
@@ -97,7 +97,7 @@ func TestReverseForeignKeyMapping(t *testing.T) {
 	}

 	fks := []ForeignKey{{ReferencingTable: "person", ReferencingCol: "country", BaseTable: "country", BaseCol: "id"}}
-	mapp := []Mapping{{TableName: "country", OriginalID: 15, NewID: 20}}
+	mapp := []Mapping{{RecordID: RecordID{Table: "country", PrimaryKey: 20}, OriginalID: 15}}

 	ReverseForeignKeyMapping(fks, mapp, main)
@@ -315,19 +315,25 @@ func TestGenerateMergeStrategyWithMapping(t *testing.T) {
 	}
 	mapping := []Mapping{
 		{
-			TableName:  "person",
+			RecordID: RecordID{
+				Table:      "person",
+				PrimaryKey: 19,
+			},
 			OriginalID: 9,
-			NewID:      19,
 		},
 		{
-			TableName:  "person",
+			RecordID: RecordID{
+				Table:      "person",
+				PrimaryKey: 20,
+			},
 			OriginalID: 10,
-			NewID:      20,
 		},
 		{
-			TableName:  "person",
+			RecordID: RecordID{
+				Table:      "person",
+				PrimaryKey: 21,
+			},
 			OriginalID: 11,
-			NewID:      21,
 		},
 	}
@@ -351,8 +357,8 @@ func TestApplyMergeStrategy(t *testing.T) {
 	ok := assert.New(t)

 	mapp := []Mapping{
-		{"user", 1, 3},
-		{"user", 2, 4},
+		{RecordID{"user", 3}, 1},
+		{RecordID{"user", 4}, 2},
 	}
 	db := &mergeDB{T: t, id: 5, data: map[any]map[string]any{}}
 	db.data[1] = map[string]any{"name": "alica", "friend": 2}
diff --git a/postgres.go b/postgres.go
index 00d4b55..1bd9b8a 100644
--- a/postgres.go
+++ b/postgres.go
@@ -4,7 +4,6 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"log"
 	"time"

 	"github.com/Masterminds/squirrel"
@@ -13,8 +12,6 @@ import (
 	"github.com/jackc/pgx/v4"
 )

-var OptLogPostgres = true
-
 // NewPostgres returns a pgdb that can generate a Database for datapasta Upload and Download functions.
 func NewPostgres(ctx context.Context, c Postgreser) (pgdb, error) {
 	client := postgresQueries{db: c}
@@ -164,7 +161,7 @@ func (db pgbatchtx) InsertRecord(row map[string]any) (any, error) {
 	keys := make([]string, 0, len(row))
 	vals := make([]any, 0, len(row))
 	table := row[DumpTableKey].(string)
-	builder := db.builder.Insert(`"` + table + `"`)
+	builder := db.builder.Insert(`"` + table + `"`).Suffix("RETURNING id")
 	for k, v := range row {
 		if v == nil {
 			continue
@@ -182,7 +179,7 @@ func (db pgbatchtx) InsertRecord(row map[string]any) (any, error) {
 		return nil, err
 	}
 	var id any
-	if err := db.tx.db.QueryRow(db.ctx, sql, args).Scan(&id); err != nil {
+	if err := db.tx.db.QueryRow(db.ctx, sql, args...).Scan(&id); err != nil {
 		return nil, err
 	}
 	return id, nil
@@ -196,12 +193,12 @@ func (db pgbatchtx) Update(id RecordID, cols map[string]any) error {
 	if err != nil {
 		return err
 	}
-	cmd, err := db.tx.db.Exec(db.ctx, sql, args)
+	cmd, err := db.tx.db.Exec(db.ctx, sql, args...)
 	if err != nil {
 		return err
 	}
-	if cmd.RowsAffected() != 0 {
-		return fmt.Errorf("delete affected %d rows, expected 1", cmd.RowsAffected())
+	if cmd.RowsAffected() != 1 {
+		return fmt.Errorf("update affected %d rows, expected 1", cmd.RowsAffected())
 	}
 	return nil
 }
@@ -209,17 +206,16 @@ func (db pgbatchtx) Delete(id RecordID) error {
 	table := id.Table
 	builder := db.builder.Delete(`"` + table + `"`).Where(squirrel.Eq{"id": id.PrimaryKey})
-	builder = builder.Limit(1)
 	sql, args, err := builder.ToSql()
 	if err != nil {
 		return err
 	}
-	cmd, err := db.tx.db.Exec(db.ctx, sql, args)
+	cmd, err := db.tx.db.Exec(db.ctx, sql, args...)
 	if err != nil {
 		return err
 	}
-	if cmd.RowsAffected() != 0 {
-		return fmt.Errorf("delete affected %d rows, expected 1", cmd.RowsAffected())
+	if cmd.RowsAffected() > 1 {
+		return fmt.Errorf("delete affected %d rows, expected 0 or 1", cmd.RowsAffected())
 	}
 	return nil
 }
@@ -231,7 +227,7 @@ func (db pgbatchtx) Mapping() ([]Mapping, error) {
 	}
 	mapps := make([]Mapping, 0, len(rows))
 	for _, r := range rows {
-		mapps = append(mapps, Mapping{TableName: r.TableName, OriginalID: r.OriginalID, NewID: r.CloneID})
+		mapps = append(mapps, Mapping{RecordID: RecordID{Table: r.TableName, PrimaryKey: r.CloneID}, OriginalID: r.OriginalID})
 	}
 	return mapps, nil
 }
@@ -312,7 +308,7 @@ func (db pgbatchtx) Insert(rows ...map[string]any) error {
 	}

 	prepped := time.Now()
-	log.Printf("batchrows:%d, followups:%d", batch.Len(), followup.Len())
+	LogFunc("batchrows:%d, followups:%d", batch.Len(), followup.Len())

 	res := db.tx.db.SendBatch(db.ctx, batch)
 	for i := 0; i < batch.Len(); i++ {
@@ -334,9 +330,7 @@ func (db pgbatchtx) Insert(rows ...map[string]any) error {
 	}
 	fks.Close()

-	if OptLogPostgres {
-		log.Printf("prepping: %s, batching: %s", prepped.Sub(start), time.Since(prepped))
-	}
+	LogFunc("prepping: %s, batching: %s", prepped.Sub(start), time.Since(prepped))

 	if err := res.Close(); err != nil {
 		return fmt.Errorf("failed to execute batch followup queries: %w", err)
diff --git a/pulley_test.go b/pulley_test.go
index 734781a..d80d208 100644
--- a/pulley_test.go
+++ b/pulley_test.go
@@ -13,7 +13,7 @@ import (
 	"github.com/stretchr/testify/assert"
 )

-const runPulleyTests = true
+const runPulleyTests = false

 // This file houses tests that run against Pulley schema.
 // TODO: have a docker database to run tests in a schema in datapasta.
@@ -175,12 +175,25 @@ func TestFullProcessWithRealDatabase(t *testing.T) {
 	}

 	mapping, err := cli.Mapping() // Acquire the mapping between old and new Ids
 	ok.NoError(err)
+	cloneID := mapping[0].PrimaryKey

 	for _, m := range mapping {
-		log.Printf("mapped %s (%T %v to %T %v)", m.TableName, m.OriginalID, m.OriginalID, m.NewID, m.NewID)
+		log.Printf("mapped %s (%T %v to %T %v)", m.Table, m.OriginalID, m.OriginalID, m.PrimaryKey, m.PrimaryKey)
 	}

 	// we can mutate company and clone here to generate merge actions
+	deleteStuff := `DELETE FROM "security" WHERE id IN (SELECT id FROM "security" WHERE company_id=$1 LIMIT 1)`
+	_, deleteErr := tx.Exec(context.Background(), deleteStuff, cloneID)
+	ok.NoError(deleteErr)
+	changeStuff := `UPDATE "security" SET issue_date=NOW() WHERE id IN (SELECT id FROM "security" WHERE company_id=$1 LIMIT 1)`
+	_, changeErr := tx.Exec(context.Background(), changeStuff, cloneID)
+	ok.NoError(changeErr)
+	addStuff := `INSERT INTO "security" (company_id) VALUES ($1)`
+	_, addErr := tx.Exec(context.Background(), addStuff, cloneID)
+	ok.NoError(addErr)
+
+	// here's the juice: given the "initial" snapshot, we can export both the main and sandbox dumps,
+	// then use them to generate and apply a 3-way merge changeset

 	// Re-download both initial and cloned company
 	connCli2, err := db.NewBatchClient(context.Background(), conn) // New client for this download
@@ -189,10 +202,15 @@ func TestFullProcessWithRealDatabase(t *testing.T) {
 	ok.NoError(err)
 	ok.NotEmpty(currentCompany)

+	// clean up the current company so that the prior cleanup doesn't count as conflicts
+	for _, row := range currentCompany {
+		CleanupRow(row)
+	}
+
 	// get the clone using the tx
 	connCli3, err := db.NewBatchClient(context.Background(), tx) // New client for this download
 	ok.NoError(err)
-	clonedCompany, _, err := Download(context.Background(), connCli3, "company", "id", mapping[0].NewID, exportOpts...)
+	clonedCompany, _, err := Download(context.Background(), connCli3, "company", "id", cloneID, exportOpts...)
 	ok.NoError(err)
 	ok.NotEmpty(clonedCompany)
@@ -200,21 +218,16 @@ func TestFullProcessWithRealDatabase(t *testing.T) {
 	ReverseForeignKeyMapping(db.ForeignKeys(), mapping, clonedCompany)
 	ReversePrimaryKeyMapping(db.PrimaryKeys(), mapping, clonedCompany)

-	t.Logf(`after mapping: %#v`, clonedCompany[0])
-	t.Logf(`previously: %#v`, initial[0])
-	t.Logf(`currently: %#v`, currentCompany[0])
-
-	t.Logf(`after mapping: %T`, clonedCompany[0]["id"])
-	t.Logf(`previously: %T`, initial[0]["id"])
-	t.Logf(`currently: %T`, currentCompany[0]["id"])
 	mas := GenerateMergeStrategy(db.PrimaryKeys(), initial, currentCompany, clonedCompany)
-	ok.Empty(mas)
-
-	if len(mas) > 5 {
-		for _, ma := range mas[:5] {
-			t.Log("merge action:", ma)
-		}
+	actions := map[string]bool{}
+	for _, ma := range mas {
+		actions[ma.Action] = true
+		t.Log("merge action:", ma)
 	}
+	ok.Equal(map[string]bool{"create": true, "delete": true, "update": true}, actions)
+
+	mergeErr := ApplyMergeStrategy(connCli3, mapping, mas)
+	ok.NoError(mergeErr)
 }

From 589cdbc77eac5f6e7990995ef16140d8d5ce7838 Mon Sep 17 00:00:00 2001
From: Daniel Toye
Date: Thu, 3 Aug 2023 10:50:22 -0600
Subject: [PATCH 4/4] fix tests

---
 merge_test.go | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/merge_test.go b/merge_test.go
index 37f69c4..43b2880 100644
--- a/merge_test.go
+++ b/merge_test.go
@@ -249,14 +249,11 @@ func TestGenerateMergeStrategy(t *testing.T) {
 	ok.Equal("create", actions[0].Action)
 	ok.Equal(16, actions[0].ID.PrimaryKey)

-	ok.Equal("update", actions[1].Action)
-	ok.Equal(12, actions[1].ID.PrimaryKey)
-
-	ok.Contains(actions, MergeAction{ID: RecordID{Table: "person", PrimaryKey: 20}, Action: "conflict", Data: map[string]interface{}{"name": "updated_both_complete_branch"}})
+	ok.Contains(actions, MergeAction{ID: RecordID{Table: "person", PrimaryKey: 20}, Action: "conflicting_double_update", Data: map[string]interface{}{"name": "updated_both_complete_branch"}})
 	ok.Contains(actions, MergeAction{ID: RecordID{Table: "person", PrimaryKey: 12}, Action: "update", Data: map[string]interface{}{"name": "name_changed_in_branch_completed"}})
-	ok.Contains(actions, MergeAction{ID: RecordID{Table: "person", PrimaryKey: 20}, Action: "conflict", Data: map[string]interface{}{"name": "updated_both_complete_branch"}})
+	ok.Contains(actions, MergeAction{ID: RecordID{Table: "person", PrimaryKey: 20}, Action: "conflicting_double_update", Data: map[string]interface{}{"name": "updated_both_complete_branch"}})

-	ok.Equal("conflict", actions[5].Action)
+	ok.Equal("conflict_delete_updated", actions[5].Action)
 	ok.Equal(18, actions[5].ID.PrimaryKey)

 	ok.Equal("delete", actions[6].Action)
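
Editor's note: the sketch below is not part of the patch set. It is one illustrative way the pieces introduced in PATCH 3/4 could compose for a caller, using only the exported names visible in these diffs (Download, GenerateMergeStrategy, ApplyMergeStrategy, ReverseForeignKeyMapping, ReversePrimaryKeyMapping, Mapping, Database). The helper name MergeSandbox, its conflict-refusing policy, and its placement inside package datapasta (with "fmt" imported) are assumptions, not the library's actual API.

// MergeSandbox is a hypothetical helper showing the intended three-way merge flow.
// Assumptions: db wraps a transaction on the main schema; mapping came from the
// batch client's Mapping() call after the original clone; base is the dump taken
// when the sandbox was created; main and branch are fresh dumps of the main
// schema and the sandbox, respectively.
func MergeSandbox(db Database, mapping []Mapping, base, main, branch DatabaseDump) error {
	// rewrite the sandbox dump back into the original key space, so that all
	// three dumps identify rows by their original primary keys
	ReverseForeignKeyMapping(db.ForeignKeys(), mapping, branch)
	ReversePrimaryKeyMapping(db.PrimaryKeys(), mapping, branch)

	// three-way diff: changes made in branch since base, checked against edits
	// and deletes that landed in main in the meantime; conflicts surface as
	// "conflicting_double_update", "conflicting_update_deleted", or
	// "conflict_delete_updated" actions
	mas := GenerateMergeStrategy(db.PrimaryKeys(), base, main, branch)

	// example policy (hypothetical): refuse to apply anything if a conflict exists
	for _, ma := range mas {
		switch ma.Action {
		case "create", "update", "delete":
		default:
			return fmt.Errorf("refusing to merge, found conflict: %s", ma)
		}
	}

	// creates run first (extending the mapping as new ids are generated),
	// then updates, then deletes
	return ApplyMergeStrategy(db, mapping, mas)
}

As in TestFullProcessWithRealDatabase above, a caller would typically run this against a transaction-backed client so the whole merge can be rolled back if any action fails.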