forked from jackc/pgtype
/
pgxtype.go
145 lines (119 loc) · 3.72 KB
/
pgxtype.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package pgxtype
import (
	"context"
	"errors"
	"fmt"

	"github.com/jackc/pgconn"
	"github.com/jackc/pgtype"
	"github.com/jackc/pgx/v4"
)
// Querier is the set of query methods the loaders in this package require of a database
// handle. LoadDataType, GetArrayElementOID, GetCompositeFields and GetEnumMembers only call
// Query and QueryRow; Exec is part of the interface but is not called in this file.
// NOTE(review): presumably satisfied by pgx connections, pools and transactions — confirm.
type Querier interface {
	Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
	Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error)
	QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row
}
// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for
// registration on ci.
//
// typeName is resolved to an OID with a ::regtype cast. Three kinds of type are supported:
//
//   - array types (typtype 'b' with typcategory 'A'): the element type must already be
//     registered on ci, and its Value must implement pgtype.ValueTranscoder;
//   - composite types (typtype 'c'): fields are read with GetCompositeFields and the
//     DataType is built by pgtype.NewCompositeType using ci;
//   - enum types (typtype 'e'): labels are read with GetEnumMembers.
//
// Any other type, including non-array base types, yields an error. Query and scan
// errors are returned unchanged.
func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) {
	var oid uint32
	err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
	if err != nil {
		return pgtype.DataType{}, err
	}

	// typtype alone cannot identify arrays: array types and ordinary base types (int4,
	// text, name, point, ...) all have typtype 'b'. Array types have typcategory 'A'.
	var typtype, typcategory string
	err = conn.QueryRow(ctx, "select typtype::text, typcategory::text from pg_type where oid=$1", oid).Scan(&typtype, &typcategory)
	if err != nil {
		return pgtype.DataType{}, err
	}

	switch typtype {
	case "b": // base type; only array types are supported
		if typcategory != "A" {
			return pgtype.DataType{}, fmt.Errorf("type %q is a base type but not an array type", typeName)
		}

		elementOID, err := GetArrayElementOID(ctx, conn, oid)
		if err != nil {
			return pgtype.DataType{}, err
		}

		// Require the element type to be registered up front. Otherwise element would stay
		// nil, and newElement would hand nil to pgtype.NewValue the first time an element
		// is built. NOTE(review): upstream pgtype panics in that case.
		dt, ok := ci.DataTypeForOID(elementOID)
		if !ok {
			return pgtype.DataType{}, fmt.Errorf("array element OID %d of type %q not registered", elementOID, typeName)
		}
		element, ok := dt.Value.(pgtype.ValueTranscoder)
		if !ok {
			return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder")
		}

		// Each array element gets its own fresh value derived from the registered element.
		newElement := func() pgtype.ValueTranscoder {
			return pgtype.NewValue(element).(pgtype.ValueTranscoder)
		}

		at := pgtype.NewArrayType(typeName, elementOID, newElement)
		return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil

	case "c": // composite
		fields, err := GetCompositeFields(ctx, conn, oid)
		if err != nil {
			return pgtype.DataType{}, err
		}
		ct, err := pgtype.NewCompositeType(typeName, fields, ci)
		if err != nil {
			return pgtype.DataType{}, err
		}
		return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil

	case "e": // enum
		members, err := GetEnumMembers(ctx, conn, oid)
		if err != nil {
			return pgtype.DataType{}, err
		}
		return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil

	default:
		return pgtype.DataType{}, fmt.Errorf("unknown typtype %q for type %q", typtype, typeName)
	}
}
// GetArrayElementOID returns the typelem column of the pg_type row whose oid is oid.
// For an array type, that is the OID of its element type. If the query or scan fails,
// it returns a zero OID along with the error, unchanged.
func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) {
	var elemOID uint32
	row := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid)
	if err := row.Scan(&elemOID); err != nil {
		return 0, err
	}
	return elemOID, nil
}
// GetCompositeFields gets the fields of a composite type.
//
// oid identifies the composite type in pg_type. The function looks up the type's
// typrelid, then reads the attributes of that relation from pg_attribute. Each field
// carries its name and type OID, and fields are returned in attnum order. System
// columns (attnum <= 0) and dropped attributes are skipped. Query, scan and
// iteration errors are returned unchanged.
func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) {
	var typrelid uint32
	err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
	if err != nil {
		return nil, err
	}

	// attnum <= 0 rows are system columns (present when the composite is a table's row
	// type). Dropped attributes leave placeholder rows with atttypid 0. Neither kind is
	// a real field of the composite.
	rows, err := conn.Query(ctx, `select attname, atttypid
from pg_attribute
where attrelid=$1
  and attnum > 0
  and not attisdropped
order by attnum`, typrelid)
	if err != nil {
		return nil, err
	}
	// pgx Rows must be closed before the connection can be reused. The defer also covers
	// the early return when Scan fails.
	defer rows.Close()

	var fields []pgtype.CompositeTypeField
	for rows.Next() {
		var f pgtype.CompositeTypeField
		if err := rows.Scan(&f.Name, &f.OID); err != nil {
			return nil, err
		}
		fields = append(fields, f)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return fields, nil
}
// GetEnumMembers gets the possible values of the enum by oid.
//
// Labels are read from pg_enum for enumtypid = oid and returned in enumsortorder
// order. If no labels are found, the result is an empty, non-nil slice. Query, scan
// and iteration errors are returned unchanged with a nil slice.
func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) {
	rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid)
	if err != nil {
		return nil, err
	}
	// pgx Rows must be closed before the connection can be reused. The defer also covers
	// the early return when Scan fails.
	defer rows.Close()

	members := []string{}
	for rows.Next() {
		var m string
		if err := rows.Scan(&m); err != nil {
			return nil, err
		}
		members = append(members, m)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return members, nil
}