/
db.go
104 lines (91 loc) · 2.8 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package db
import (
"database/sql"
"fmt"
"github.com/aclindsa/gorp"
"github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"log"
"strings"
)
// luaMaxLengthBuffer is intended to be enough bytes such that a given string
// no longer than models.LuaMaxLength is sure to fit within a database
// implementation's string type specified by the same.
const luaMaxLengthBuffer int = 4096
func getDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
var dialect gorp.Dialect
if dbtype == config.SQLite {
dialect = gorp.SqliteDialect{}
} else if dbtype == config.MySQL {
dialect = gorp.MySQLDialect{
Engine: "InnoDB",
Encoding: "UTF8",
}
} else if dbtype == config.Postgres {
dialect = gorp.PostgresDialect{
LowercaseFields: true,
}
} else {
return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", dbtype.String())
}
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId")
dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId")
dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId")
dbmap.AddTableWithName(Price{}, "prices").SetKeys(true, "PriceId")
dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId")
dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId")
dbmap.AddTableWithName(Split{}, "splits").SetKeys(true, "SplitId")
rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId")
rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer)
err := dbmap.CreateTablesIfNotExists()
if err != nil {
return nil, err
}
return dbmap, nil
}
func getDSN(dbtype config.DbType, dsn string) string {
if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") {
log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!")
}
return dsn
}
type DbStore struct {
dbMap *gorp.DbMap
}
func (db *DbStore) Empty() error {
return db.dbMap.TruncateTables()
}
func (db *DbStore) Begin() (store.Tx, error) {
tx, err := db.dbMap.Begin()
if err != nil {
return nil, err
}
return &Tx{db.dbMap.Dialect, tx}, nil
}
func (db *DbStore) Close() error {
err := db.dbMap.Db.Close()
db.dbMap = nil
return err
}
func GetStore(dbtype config.DbType, dsn string) (store store.Store, err error) {
dsn = getDSN(dbtype, dsn)
database, err := sql.Open(dbtype.String(), dsn)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
database.Close()
}
}()
dbmap, err := getDbMap(database, dbtype)
if err != nil {
return nil, err
}
return &DbStore{dbmap}, nil
}