/
store.go
159 lines (128 loc) · 3.83 KB
/
store.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
package store
import (
"context"
"database/sql/driver"
"encoding/json"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq" // Register lib/pq PostgreSQL driver
"github.com/sirupsen/logrus"
)
// ErrNoResults indicates that no data matching the query was found.
//
// It embeds an error value, so ErrNoResults itself satisfies the error
// interface and reports the embedded error's message. The zero value embeds a
// nil error; calling Error on it panics.
type ErrNoResults struct {
error
}
// ErrExists is returned if a unique record already exists.
//
// It embeds an error value, so ErrExists itself satisfies the error interface
// and reports the embedded error's message. The zero value embeds a nil error;
// calling Error on it panics.
// NOTE(review): presumably produced when PostgreSQL reports pgExists — confirm
// against the code that constructs it.
type ErrExists struct {
error
}
// ErrFKeyViolation is returned if inserting a record causes a foreign key violation.
//
// It embeds an error value, so ErrFKeyViolation itself satisfies the error
// interface and reports the embedded error's message. The zero value embeds a
// nil error; calling Error on it panics.
// NOTE(review): presumably produced when PostgreSQL reports pgFKeyViolation —
// confirm against the code that constructs it.
type ErrFKeyViolation struct {
error
}
// PostgreSQL SQLSTATE error codes (class 23, integrity constraint violation)
// that this package recognises.
const (
	// pgExists is the SQLSTATE code for unique_violation.
	pgExists = "23505"
	// pgFKeyViolation is the SQLSTATE code for foreign_key_violation.
	pgFKeyViolation = "23503"
)
// Service provides methods for storing data in a PostgreSQL database.
// Ping and Close both tolerate a nil db handle: Ping reports it as an error,
// Close treats it as nothing to do.
type Service struct {
// db is the sqlx database handle used for queries and transactions.
db *sqlx.DB
// logger receives errors that are logged instead of being returned to the
// caller (see logErr).
logger *logrus.Logger
}
// New creates a new store service from a dataSourceName. The logger is used to
// log errors that would not otherwise be returned such as issues rolling back
// transactions. The context passed is used for pinging the database.
//
// New returns an error if the database handle cannot be opened or if the
// initial Ping fails. In both cases the returned Service is the zero value;
// when the ping fails the freshly opened handle is closed first so that it is
// not leaked.
func New(ctx context.Context, dsn string, logger *logrus.Logger) (Service, error) {
	db, err := sqlx.Open("postgres", dsn)
	if err != nil {
		return Service{}, err
	}

	s := Service{db: db, logger: logger}
	if err := s.Ping(ctx); err != nil {
		// The caller only gets the ping error back, so a failure to close the
		// handle is logged (when a logger was supplied) rather than returned.
		if closeErr := db.Close(); closeErr != nil && logger != nil {
			logger.Errorf("unable to close database after failed ping: %v", closeErr)
		}
		return Service{}, err
	}
	return s, nil
}
// Ping checks that the underlying postgresql database is reachable by running
// a trivial "SELECT 1" query and scanning its result. db.Ping() is
// intentionally not used: according to the original author it does not
// actually ping the database.
//
// It returns an error if the service has no database handle, or whatever error
// the query or scan produces.
func (s *Service) Ping(ctx context.Context) error {
	if s.db == nil {
		return errors.New("not connected to postgresql")
	}

	var ok bool
	row := s.db.QueryRowContext(ctx, "SELECT 1")
	return row.Scan(&ok)
}
// Close closes the underlying postgresql database handle and returns any error
// from doing so. A service without a database handle has nothing to release,
// so Close returns nil for it.
func (s *Service) Close() error {
	if s.db == nil {
		return nil
	}
	return s.db.Close()
}
// logErr writes err to the service logger at error level. A nil err is
// ignored, so callers may pass the result of a fallible call directly.
func (s *Service) logErr(err error) {
	if err == nil {
		return
	}
	s.logger.Error(err)
}
// doTransaction runs txWrapper inside a database transaction.
//
// A transaction is begun with ctx and handed to txWrapper. If txWrapper
// returns an error the transaction is rolled back and the wrapper's error is
// returned (wrapped); a failure of the rollback itself is only logged, because
// the wrapper's error is the one the caller needs to see. If txWrapper
// succeeds the transaction is committed and any commit error is returned
// (wrapped).
func (s *Service) doTransaction(ctx context.Context, txWrapper func(*sqlx.Tx) error) error {
	tx, err := s.db.BeginTxx(ctx, nil)
	if err != nil {
		return errors.Wrap(err, "unable to begin transaction")
	}

	if err := txWrapper(tx); err != nil {
		// Log the rollback error itself (rbErr), not the wrapper error that
		// triggered the rollback; the latter is returned below.
		if rbErr := tx.Rollback(); rbErr != nil {
			s.logErr(errors.Wrap(rbErr, "unable to rollback transaction"))
		}
		return errors.Wrap(err, "error in transaction wrapper")
	}

	// errors.Wrap returns nil when given a nil error, so a successful commit
	// results in a nil return value.
	return errors.Wrap(tx.Commit(), "unable to commit transaction")
}
// UnixTime exists so that we can have times that look like time.Time's to
// database drivers and JSON marshallers/unmarshallers but are internally
// represented as unix timestamps for easier comparison.
type UnixTime struct {
	// Unix is the timestamp in whole seconds since January 1, 1970 UTC.
	Unix int64
}

// NewUnixFromTime creates a new UnixTime timestamp from a time.Time. Any
// sub-second component of t is dropped.
func NewUnixFromTime(t time.Time) UnixTime {
	return UnixTime{Unix: t.Unix()}
}

// NewUnixFromInt creates a new UnixTime timestamp from an int64 holding unix
// seconds.
func NewUnixFromInt(sec int64) UnixTime {
	return UnixTime{Unix: sec}
}

// Scan accepts either a time.Time or an int64 for scanning from a database into
// a unix timestamp. It returns an error if the receiver is nil or if src has
// any other type (including nil).
func (ut *UnixTime) Scan(src interface{}) error {
	if ut == nil {
		return errors.New("cannot scan into nil unix time")
	}

	switch v := src.(type) {
	case time.Time:
		ut.Unix = v.Unix()
	case int64:
		ut.Unix = v
	default:
		return fmt.Errorf("got invalid type for time: %T", src)
	}
	return nil
}

// Value returns a driver.Value that is always a time.Time that represents the
// internally stored unix time. The returned error is always nil.
func (ut UnixTime) Value() (driver.Value, error) {
	return time.Unix(ut.Unix, 0), nil
}

// MarshalJSON returns a []byte that represents this UnixTime in RFC 3339 format.
//
// It uses a value receiver so that both UnixTime and *UnixTime implement
// json.Marshaler. With a pointer receiver, encoding/json skips the method for
// values it cannot take the address of (for example a struct passed to
// json.Marshal by value) and would emit the raw {"Unix":...} object, which
// UnmarshalJSON cannot read back.
func (ut UnixTime) MarshalJSON() ([]byte, error) {
	return time.Unix(ut.Unix, 0).MarshalJSON()
}

// UnmarshalJSON accepts a []byte representing a time.Time value, and unmarshals
// it into a unix timestamp. It returns the decoding error, leaving the receiver
// untouched, if data is not a valid JSON-encoded time.
func (ut *UnixTime) UnmarshalJSON(data []byte) error {
	var t time.Time
	if err := json.Unmarshal(data, &t); err != nil {
		return err
	}
	ut.Unix = t.Unix()
	return nil
}