Skip to content

Commit

Permalink
Merge pull request #11 from bessiambre/main
Browse files Browse the repository at this point in the history
Support for non id primary keys
  • Loading branch information
dan-pulley committed Dec 1, 2023
2 parents b234d53 + 94ff925 commit b650ae1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
14 changes: 14 additions & 0 deletions clone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func TestDownloadUpload(t *testing.T) {
assert.Equal(10, res[0]["id"])
assert.Equal("produces socks", res[1]["desc"])
assert.Equal("socks", res[2]["name"])
assert.Equal("socks are cool", res[3]["detail"])

// users are expected to do some cleanup, so test that it works
for _, row := range res {
Expand All @@ -29,6 +30,7 @@ func TestDownloadUpload(t *testing.T) {
assert.Equal(11, res[0]["id"])
assert.Equal(12, res[1]["id"])
assert.Equal(13, res[2]["id"])
assert.Equal(11, res[3]["company_id"])
}

func cleanup(row map[string]any) {
Expand Down Expand Up @@ -60,6 +62,10 @@ func (d testDB) SelectMatchingRows(tname string, conds map[string][]any) ([]map[
if conds["id"][0] == 23 {
return []map[string]any{{"id": 23, "desc": "produces socks"}}, nil
}
case "company_details":
if conds["company_id"][0] == 10 {
return []map[string]any{{"company_id": 10, "detail": "socks are cool"}}, nil
}
}

return nil, fmt.Errorf("no mock for %s where %#v", tname, conds)
Expand Down Expand Up @@ -99,6 +105,10 @@ func (d testDB) Insert(records ...map[string]any) error {
m["id"] = 13
continue
}
if m[datapasta.DumpTableKey] == "company_details" && m["company_id"] == 10 {
m["company_id"] = 11
continue
}
return fmt.Errorf("unexpected insert: %#v", m)
}
return nil
Expand All @@ -115,6 +125,10 @@ func (d testDB) ForeignKeys() []datapasta.ForeignKey {
BaseTable: "factory", BaseCol: "id",
ReferencingTable: "product", ReferencingCol: "factory_id",
},
{
BaseTable: "company", BaseCol: "id",
ReferencingTable: "company_details", ReferencingCol: "company_id",
},
}
}

Expand Down
15 changes: 10 additions & 5 deletions postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ func (db pgbatchtx) Insert(rows ...map[string]any) error {
builder := db.builder.Insert(`"` + table + `"`)
oldPK := row[pk]
if pk != "" {
builder = builder.Suffix("RETURNING " + pk)
builder = builder.Suffix("RETURNING " + pk + " as id")
builder = builder.Prefix("WITH inserted_row AS (")
builder = builder.Suffix(") INSERT INTO datapasta_clone (table_name, original_id, clone_id) SELECT ?, ?, id FROM inserted_row", table, oldPK)
delete(row, pk)
//delete(row, pk)
}

keys := make([]string, 0, len(row))
Expand All @@ -268,8 +268,11 @@ func (db pgbatchtx) Insert(rows ...map[string]any) error {
continue
}
deferred := false
foundForeign := false
for _, fk := range db.fks {

if fk.ReferencingCol == k && fk.ReferencingTable == table {
foundForeign = true
findInMap := squirrel.Expr("COALESCE((SELECT clone_id FROM datapasta_clone WHERE original_id = ? AND table_name = ?::text), ?)", v, fk.BaseTable, v)

if fk.BaseTable == table {
Expand All @@ -278,7 +281,7 @@ func (db pgbatchtx) Insert(rows ...map[string]any) error {
return fmt.Errorf("can't have self-referencing tables without primary key")
}
deferred = true
builder := db.builder.Update(`"`+table+`"`).Set(k, findInMap).Where("id=(SELECT clone_id FROM datapasta_clone WHERE original_id = ? AND table_name = ?::text)", oldPK, fk.BaseTable)
builder := db.builder.Update(`"`+table+`"`).Set(k, findInMap).Where(pk+"=(SELECT clone_id FROM datapasta_clone WHERE original_id = ? AND table_name = ?::text)", oldPK, fk.BaseTable)
sql, args, err := builder.ToSql()
if err != nil {
return fmt.Errorf(`build: %w, args: %s, sql: %s`, err, args, sql)
Expand All @@ -294,8 +297,10 @@ func (db pgbatchtx) Insert(rows ...map[string]any) error {
if deferred {
continue
}
keys = append(keys, fmt.Sprintf(`"%s"`, k))
vals = append(vals, v)
if foundForeign || k != pk {
keys = append(keys, fmt.Sprintf(`"%s"`, k))
vals = append(vals, v)
}
}

builder = builder.Columns(keys...).Values(vals...)
Expand Down

0 comments on commit b650ae1

Please sign in to comment.