-
Notifications
You must be signed in to change notification settings - Fork 134
/
upsert.go
194 lines (181 loc) · 5.94 KB
/
upsert.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
package database
import (
"context"
"fmt"
"reflect"
"strings"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/google/uuid"
"github.com/jackc/pgx/v4"
"github.com/pkg/errors"
)
// UpsertWithTransaction runs Upsert for records against tableName inside a
// transaction begun on db via BeginTxFunc.
//
// pgx's BeginTxFunc commits when the callback returns nil and rolls back when
// it returns an error, so the upsert is applied all-or-nothing.
// The transaction uses READ COMMITTED isolation and READ WRITE access; the
// DEFERRABLE flag is also requested, although PostgreSQL only honours it for
// SERIALIZABLE READ ONLY transactions.
//
// An empty records slice returns nil without starting a transaction.
func UpsertWithTransaction[T any](ctx context.Context, db *pgxpool.Pool, tableName string, records []T) error {
	if len(records) < 1 {
		return nil
	}

	opts := pgx.TxOptions{
		DeferrableMode: pgx.Deferrable,
		AccessMode:     pgx.ReadWrite,
		IsoLevel:       pgx.ReadCommitted,
	}
	upsertAll := func(tx pgx.Tx) error {
		return Upsert(ctx, tx, tableName, records)
	}
	return db.BeginTxFunc(ctx, opts, upsertAll)
}
// Upsert is an optimised SQL call for bulk upserts into tableName within tx.
//
// For efficiency, this function:
//  1. Creates an empty temporary SQL table (ON COMMIT DROP) holding only the
//     target table's columns that are named by the records' "db" tags.
//  2. Inserts all records into the temporary table using the postgres-specific COPY wire protocol.
//  3. Upserts all rows from the temporary table into the target table (as specified by tableName)
//     with INSERT ... ON CONFLICT DO UPDATE.
//
// The COPY protocol can be faster than repeated inserts for as little as 5 rows; see
// https://www.postgresql.org/docs/current/populate.html
// https://pkg.go.dev/github.com/jackc/pgx/v4#hdr-Copy_Protocol
//
// The records to write should be structs with fields marked with "db" tags.
// Field names and values are extracted using the NamesValuesFromRecord function;
// see its definition for details. At least two tagged fields are required. The
// first tagged field is the ON CONFLICT target, so the target table needs a
// unique index or constraint on that column (typically its primary key); every
// other tagged column is overwritten when a row with the same key exists.
// Columns without a matching tag are left out of the INSERT entirely, so new
// rows get the column's DEFAULT and existing rows keep their current value.
//
// Tag values are quoted as SQL identifiers and must match column names exactly.
// tableName is interpolated into the SQL verbatim: it must be a trusted, valid
// (optionally schema-qualified) table reference, never user input.
//
// PostgreSQL rejects an ON CONFLICT DO UPDATE that affects the same row twice,
// so records should not contain duplicate keys.
//
// An empty records slice returns nil without issuing any SQL.
func Upsert[T any](ctx context.Context, tx pgx.Tx, tableName string, records []T) error {
	if len(records) == 0 {
		return nil
	}

	// The column list is taken from the first record; records are expected to
	// share one struct type, and each row is checked against it during COPY.
	names, _ := NamesValuesFromRecord(records[0])
	if len(names) < 2 {
		return errors.Errorf("records must have at least 2 db-tagged fields (a conflict key and a column to update), but got %v", names)
	}

	// Quote column names so the SQL below refers to exactly the same columns
	// that pgx.CopyFrom (which always quotes) writes to.
	columns := make([]string, len(names))
	for i, name := range names {
		columns[i] = pgx.Identifier{name}.Sanitize()
	}
	columnList := strings.Join(columns, ", ")

	// Quote the temporary table's name for the same reason: an unquoted name is
	// case-folded by PostgreSQL, whereas CopyFrom's identifier is not.
	tempTableName := uniqueTableName(tableName)
	tempTable := pgx.Identifier{tempTableName}.Sanitize()

	// Only the tagged columns are cloned, so untagged target columns never pass
	// through the temporary table as NULLs.
	_, err := tx.Exec(ctx, fmt.Sprintf(
		"CREATE TEMPORARY TABLE %s ON COMMIT DROP AS SELECT %s FROM %s LIMIT 0;",
		tempTable, columnList, tableName))
	if err != nil {
		return errors.Wrapf(err, "creating temporary table for %s", tableName)
	}

	// Load every record in a single COPY operation.
	n, err := tx.CopyFrom(ctx,
		pgx.Identifier{tempTableName},
		names,
		pgx.CopyFromSlice(len(records), func(i int) ([]interface{}, error) {
			// NamesValuesFromRecord walks struct fields in declaration order, so
			// the values line up with names as long as the record type matches.
			_, values := NamesValuesFromRecord(records[i])
			if len(values) != len(names) {
				return nil, errors.Errorf("record %d has %d db-tagged fields, expected %d", i, len(values), len(names))
			}
			return values, nil
		}),
	)
	if err != nil {
		return errors.Wrapf(err, "copying records into temporary table for %s", tableName)
	}
	if n != int64(len(records)) {
		return errors.Errorf("only %d out of %d rows were inserted", n, len(records))
	}

	// Move the rows into the main table, overwriting existing rows on key
	// conflict. The conflict column itself is not reassigned: it already equals
	// EXCLUDED's value for the conflicting row.
	var b strings.Builder
	fmt.Fprintf(&b, "INSERT INTO %s (%s) SELECT %s FROM %s ", tableName, columnList, columnList, tempTable)
	fmt.Fprintf(&b, "ON CONFLICT (%s) DO UPDATE SET ", columns[0])
	for i, column := range columns[1:] {
		if i > 0 {
			b.WriteString(", ")
		}
		fmt.Fprintf(&b, "%s = EXCLUDED.%s", column, column)
	}
	b.WriteString(";")

	if _, err := tx.Exec(ctx, b.String()); err != nil {
		return errors.Wrapf(err, "upserting into %s", tableName)
	}
	return nil
}
// NamesFromRecord returns the "db" tag values of the exported, tagged fields of
// the struct x, in field declaration order.
//
// For example, if x is an instance of a struct with definition
//
//	type Rectangle struct {
//		Width  int `db:"width"`
//		Height int `db:"height"`
//	}
//
// it returns ["width", "height"].
//
// x may be a struct or a (possibly nil) pointer to one; only its type is
// inspected. Unexported fields are skipped even when tagged, because their
// values cannot be read through reflection by ValuesFromRecord or
// NamesValuesFromRecord. A nil x returns nil. Any other non-struct x panics,
// as reflect.Type.NumField does.
func NamesFromRecord(x any) []string {
	t := reflect.TypeOf(x)
	if t == nil {
		return nil
	}
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}

	names := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		name := field.Tag.Get("db")
		if name == "" || !field.IsExported() {
			continue
		}
		names = append(names, name)
	}
	return names
}
// ValuesFromRecord returns the values of the exported, "db"-tagged fields of
// the struct x, in field declaration order.
//
// For example, if x is an instance of a struct with definition
//
//	type Rectangle struct {
//		Name   string
//		Width  int `db:"width"`
//		Height int `db:"height"`
//	}
//
// where Width = 5 and Height = 10, it returns [5, 10].
//
// x may be a struct or a pointer to one; pointers are dereferenced. Unexported
// fields are skipped even when tagged, because reflection cannot read their
// values (reflect.Value.Interface would panic). A nil x or nil pointer returns
// nil. Any other non-struct x panics, as reflect.Type.NumField does.
func ValuesFromRecord(x any) []any {
	v := reflect.ValueOf(x)
	for v.Kind() == reflect.Pointer {
		v = v.Elem()
	}
	if !v.IsValid() {
		return nil
	}

	t := v.Type()
	values := make([]any, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if field.Tag.Get("db") == "" || !field.IsExported() {
			continue
		}
		values = append(values, v.Field(i).Interface())
	}
	return values
}
// NamesValuesFromRecord returns the "db" tag values and the corresponding field
// values of the exported, tagged fields of the struct x, both in field
// declaration order (so names[i] describes values[i]).
//
// For example, if x is an instance of a struct with definition
//
//	type Rectangle struct {
//		Width  int `db:"width"`
//		Height int `db:"height"`
//	}
//
// where Width = 10 and Height = 5,
// it returns ["width", "height"], [10, 5].
//
// x may be a struct or a pointer to one; pointers are dereferenced, so both
// Rectangle{} and &Rectangle{} are accepted. Unexported fields are skipped even
// when tagged, because reflection cannot read their values. A nil x or nil
// pointer returns nil, nil. Any other non-struct x panics, as
// reflect.Type.NumField does.
func NamesValuesFromRecord(x any) ([]string, []any) {
	v := reflect.ValueOf(x)
	for v.Kind() == reflect.Pointer {
		v = v.Elem()
	}
	if !v.IsValid() {
		return nil, nil
	}

	t := v.Type()
	names := make([]string, 0, t.NumField())
	values := make([]any, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		name := field.Tag.Get("db")
		if name == "" || !field.IsExported() {
			continue
		}
		names = append(names, name)
		values = append(values, v.Field(i).Interface())
	}
	return names, values
}
// uniqueTableName returns a fresh name for a temporary table derived from table.
//
// The result has the form "<base>_tmp_<32 hex chars>", where the hex suffix is
// a random UUID without dashes. <base> is built from table as follows:
//   - Any schema qualifier (text up to the last '.') is dropped, because
//     PostgreSQL does not accept a schema name when creating a temporary table.
//   - The remainder is lowercased and reduced to [a-z0-9_], so the result is a
//     valid identifier that reads the same whether or not it is quoted.
//   - It is prefixed with "t" if it would otherwise be empty or start with a digit.
//   - It is truncated so the whole name fits PostgreSQL's default 63-byte
//     identifier limit; otherwise the server would silently truncate the name
//     and could cut off the random suffix that keeps it unique.
func uniqueTableName(table string) string {
	const (
		maxIdentifierLen = 63 // NAMEDATALEN-1 in a default PostgreSQL build
		separator        = "_tmp_"
		suffixLen        = 32 // UUID rendered as hex without dashes
		maxBaseLen       = maxIdentifierLen - len(separator) - suffixLen
	)

	if i := strings.LastIndexByte(table, '.'); i >= 0 {
		table = table[i+1:]
	}
	base := strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9', r == '_':
			return r
		default:
			return -1 // drop quotes, punctuation and non-ASCII characters
		}
	}, strings.ToLower(table))
	if base == "" || (base[0] >= '0' && base[0] <= '9') {
		base = "t" + base
	}
	if len(base) > maxBaseLen {
		base = base[:maxBaseLen]
	}

	suffix := strings.ReplaceAll(uuid.New().String(), "-", "")
	return base + separator + suffix
}