Skip to content

Commit

Permalink
it appears to be working
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-pulley committed Aug 2, 2023
1 parent 3f4010d commit 5d99956
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 81 deletions.
7 changes: 5 additions & 2 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package datapasta

import (
"fmt"
"log"
)

// Database is the abstraction between the cloning tool and the database.
Expand Down Expand Up @@ -69,6 +70,8 @@ func GetRowIdentifier(pks map[string]string, row map[string]any) RecordID {
}

type Mapping struct {
TableName string
OriginalID, NewID any
RecordID
OriginalID any
}

var LogFunc = log.Printf
69 changes: 33 additions & 36 deletions merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package datapasta

import (
"fmt"
"log"
"reflect"
)

type MergeAction struct {
Expand All @@ -13,36 +11,37 @@ type MergeAction struct {
}

func (ma MergeAction) String() string {
return fmt.Sprintf(`%s %s %#v`, ma.Action, ma.ID, ma.Data)
if ma.Action == "delete" {
return fmt.Sprintf(`%s %s`, ma.Action, ma.ID)
}
return fmt.Sprintf(`%s %s %d columns`, ma.Action, ma.ID, len(ma.Data))
}

func FindRow(table, pk string, id any, dump DatabaseDump) map[string]any {
if id == nil {
return nil
}
needle := RecordID{Table: table, PrimaryKey: id}
for _, d := range dump {
if d[DumpTableKey] != table {
continue
}
if d[pk] == id {
test := RecordID{Table: d[DumpTableKey].(string), PrimaryKey: d[pk]}
if test.String() == needle.String() {
return d
}
}
return nil
}

func FindMapping(table string, id any, mapp []Mapping) Mapping {
func FindMapping(id RecordID, mapp []Mapping) Mapping {
if id.PrimaryKey == nil {
return Mapping{RecordID: id, OriginalID: id.PrimaryKey}
}
for _, m := range mapp {
if m.TableName != table {
continue
}
if m.NewID == id {
log.Printf(`%s: %T %#v == %T %#v`, table, m.NewID, m.NewID, id, id)
if m.RecordID.String() == id.String() {
return m
}
}
log.Printf("no mapping found for %s (%T %v)", table, id, id)
return Mapping{TableName: table, OriginalID: id, NewID: id}
LogFunc("no mapping found for %s (%T %v)", id.Table, id.PrimaryKey, id.PrimaryKey)
return Mapping{RecordID: id, OriginalID: id.PrimaryKey}
}

// reverse all the primary keys of a dump
Expand All @@ -51,26 +50,20 @@ func ReversePrimaryKeyMapping(pks map[string]string, mapp []Mapping, dump Databa
table := row[DumpTableKey].(string)
pk, hasPk := pks[table]
if !hasPk {
log.Println("no pk for", table)
LogFunc("no pk for %s", table)
continue
}
m := FindMapping(table, row[pk], mapp)
m := FindMapping(RecordID{Table: table, PrimaryKey: row[pk]}, mapp)
row[pk] = m.OriginalID
}
}

// reverse all the foreign keys of an indivdual row
func ReverseForeignKeyMappingRow(fks []ForeignKey, mapp []Mapping, row map[string]any) {
update := func(row map[string]any, col, otherTable string) {
for _, m := range mapp {
if m.TableName != otherTable {
continue
}
if m.NewID != row[col] {
continue
}
row[col] = m.OriginalID
}
target := RecordID{Table: otherTable, PrimaryKey: row[col]}
m := FindMapping(target, mapp)
row[col] = m.OriginalID
}

table := row[DumpTableKey].(string)
Expand Down Expand Up @@ -126,7 +119,7 @@ func FindModifiedRows(pks map[string]string, from, in DatabaseDump) map[RecordID

changes := make(map[string]any)
for k, v := range match {
if !reflect.DeepEqual(v, row[k]) {
if fmt.Sprintf(`%v`, v) != fmt.Sprintf(`%v`, row[k]) {
changes[k] = row[k]
}
}
Expand All @@ -152,7 +145,7 @@ func ApplyMergeStrategy(db Database, mapp []Mapping, mas []MergeAction) error {
if err != nil {
return fmt.Errorf(`creating %s: %s`, ma.ID, err.Error())
}
mapp = append(mapp, Mapping{TableName: ma.ID.Table, NewID: ma.ID.PrimaryKey, OriginalID: id})
mapp = append(mapp, Mapping{RecordID: ma.ID, OriginalID: id})
}

// do all the creates *while* updating the mapping
Expand Down Expand Up @@ -187,13 +180,13 @@ func ApplyMergeStrategy(db Database, mapp []Mapping, mas []MergeAction) error {
func GenerateMergeStrategy(pks map[string]string, base, main, branch DatabaseDump) []MergeAction {
out := make([]MergeAction, 0)

deletedInMain := make(map[RecordID]bool)
deletedInMain := make(map[string]bool)
for _, deleted := range FindMissingRows(pks, base, main) {
deletedInMain[GetRowIdentifier(pks, deleted)] = true
deletedInMain[GetRowIdentifier(pks, deleted).String()] = true
}
editedInMain := make(map[RecordID]bool)
editedInMain := make(map[string]bool)
for id := range FindModifiedRows(pks, main, base) {
editedInMain[id] = true
editedInMain[id.String()] = true
}

created := FindMissingRows(pks, branch, base)
Expand All @@ -205,8 +198,12 @@ func GenerateMergeStrategy(pks map[string]string, base, main, branch DatabaseDum

changes := FindModifiedRows(pks, branch, base)
for id, c := range changes {
if editedInMain[id] || deletedInMain[id] {
out = append(out, MergeAction{id, "conflict", c})
if editedInMain[id.String()] {
out = append(out, MergeAction{id, "conflicting_double_update", c})
continue
}
if deletedInMain[id.String()] {
out = append(out, MergeAction{id, "conflicting_update_deleted", c})
continue
}
out = append(out, MergeAction{id, "update", c})
Expand All @@ -215,8 +212,8 @@ func GenerateMergeStrategy(pks map[string]string, base, main, branch DatabaseDum
deleted := FindMissingRows(pks, base, branch)
for _, m := range deleted {
id := GetRowIdentifier(pks, m)
if editedInMain[id] {
out = append(out, MergeAction{id, "conflict", m})
if editedInMain[id.String()] {
out = append(out, MergeAction{id, "conflict_delete_updated", m})
continue
}
out = append(out, MergeAction{id, "delete", nil})
Expand Down
26 changes: 16 additions & 10 deletions merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestReversePrimaryKeyMapping(t *testing.T) {
},
}
mapp := []Mapping{
{TableName: "person", OriginalID: 8, NewID: 11},
{RecordID: RecordID{Table: "person", PrimaryKey: 11}, OriginalID: 8},
}
pks := map[string]string{"person": "id"}

Expand Down Expand Up @@ -97,7 +97,7 @@ func TestReverseForeignKeyMapping(t *testing.T) {
}

fks := []ForeignKey{{ReferencingTable: "person", ReferencingCol: "country", BaseTable: "country", BaseCol: "id"}}
mapp := []Mapping{{TableName: "country", OriginalID: 15, NewID: 20}}
mapp := []Mapping{{RecordID: RecordID{Table: "country", PrimaryKey: 20}, OriginalID: 15}}

ReverseForeignKeyMapping(fks, mapp, main)

Expand Down Expand Up @@ -315,19 +315,25 @@ func TestGenerateMergeStrategyWithMapping(t *testing.T) {
}
mapping := []Mapping{
{
TableName: "person",
RecordID: RecordID{
Table: "person",
PrimaryKey: 19,
},
OriginalID: 9,
NewID: 19,
},
{
TableName: "person",
RecordID: RecordID{
Table: "person",
PrimaryKey: 20,
},
OriginalID: 10,
NewID: 20,
},
{
TableName: "person",
RecordID: RecordID{
Table: "person",
PrimaryKey: 21,
},
OriginalID: 11,
NewID: 21,
},
}

Expand All @@ -351,8 +357,8 @@ func TestApplyMergeStrategy(t *testing.T) {
ok := assert.New(t)

mapp := []Mapping{
{"user", 1, 3},
{"user", 2, 4},
{RecordID{"user", 3}, 1},
{RecordID{"user", 4}, 2},
}
db := &mergeDB{T: t, id: 5, data: map[any]map[string]any{}}
db.data[1] = map[string]any{"name": "alica", "friend": 2}
Expand Down
28 changes: 11 additions & 17 deletions postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"time"

"github.com/Masterminds/squirrel"
Expand All @@ -13,8 +12,6 @@ import (
"github.com/jackc/pgx/v4"
)

var OptLogPostgres = true

// NewPostgres returns a pgdb that can generate a Database for datapasta Upload and Download functions.
func NewPostgres(ctx context.Context, c Postgreser) (pgdb, error) {
client := postgresQueries{db: c}
Expand Down Expand Up @@ -164,7 +161,7 @@ func (db pgbatchtx) InsertRecord(row map[string]any) (any, error) {
keys := make([]string, 0, len(row))
vals := make([]any, 0, len(row))
table := row[DumpTableKey].(string)
builder := db.builder.Insert(`"` + table + `"`)
builder := db.builder.Insert(`"` + table + `"`).Suffix("RETURNING id")
for k, v := range row {
if v == nil {
continue
Expand All @@ -182,7 +179,7 @@ func (db pgbatchtx) InsertRecord(row map[string]any) (any, error) {
return nil, err
}
var id any
if err := db.tx.db.QueryRow(db.ctx, sql, args).Scan(&id); err != nil {
if err := db.tx.db.QueryRow(db.ctx, sql, args...).Scan(&id); err != nil {
return nil, err
}
return id, nil
Expand All @@ -196,30 +193,29 @@ func (db pgbatchtx) Update(id RecordID, cols map[string]any) error {
if err != nil {
return err
}
cmd, err := db.tx.db.Exec(db.ctx, sql, args)
cmd, err := db.tx.db.Exec(db.ctx, sql, args...)
if err != nil {
return err
}
if cmd.RowsAffected() != 0 {
return fmt.Errorf("delete affected %d rows, expected 1", cmd.RowsAffected())
if cmd.RowsAffected() != 1 {
return fmt.Errorf("update affected %d rows, expected 1", cmd.RowsAffected())
}
return nil
}

func (db pgbatchtx) Delete(id RecordID) error {
table := id.Table
builder := db.builder.Delete(`"` + table + `"`).Where(squirrel.Eq{"id": id.PrimaryKey})
builder = builder.Limit(1)
sql, args, err := builder.ToSql()
if err != nil {
return err
}
cmd, err := db.tx.db.Exec(db.ctx, sql, args)
cmd, err := db.tx.db.Exec(db.ctx, sql, args...)
if err != nil {
return err
}
if cmd.RowsAffected() != 0 {
return fmt.Errorf("delete affected %d rows, expected 1", cmd.RowsAffected())
if cmd.RowsAffected() > 1 {
return fmt.Errorf("delete affected %d rows, expected 0 or 1", cmd.RowsAffected())
}
return nil
}
Expand All @@ -231,7 +227,7 @@ func (db pgbatchtx) Mapping() ([]Mapping, error) {
}
mapps := make([]Mapping, 0, len(rows))
for _, r := range rows {
mapps = append(mapps, Mapping{TableName: r.TableName, OriginalID: r.OriginalID, NewID: r.CloneID})
mapps = append(mapps, Mapping{RecordID: RecordID{Table: r.TableName, PrimaryKey: r.CloneID}, OriginalID: r.OriginalID})
}
return mapps, nil
}
Expand Down Expand Up @@ -312,7 +308,7 @@ func (db pgbatchtx) Insert(rows ...map[string]any) error {
}

prepped := time.Now()
log.Printf("batchrows:%d, followups:%d", batch.Len(), followup.Len())
LogFunc("batchrows:%d, followups:%d", batch.Len(), followup.Len())

res := db.tx.db.SendBatch(db.ctx, batch)
for i := 0; i < batch.Len(); i++ {
Expand All @@ -334,9 +330,7 @@ func (db pgbatchtx) Insert(rows ...map[string]any) error {
}
fks.Close()

if OptLogPostgres {
log.Printf("prepping: %s, batching: %s", prepped.Sub(start), time.Since(prepped))
}
LogFunc("prepping: %s, batching: %s", prepped.Sub(start), time.Since(prepped))

if err := res.Close(); err != nil {
return fmt.Errorf("failed to execute batch followup queries: %w", err)
Expand Down
Loading

0 comments on commit 5d99956

Please sign in to comment.