/
db.go
102 lines (93 loc) · 1.93 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package main
import (
"database/sql"
"math"
_ "github.com/mattn/go-sqlite3"
)
// Storage location and schema for the price cache.
const (
	// dbName is the SQLite database file passed to sql.Open by newRepository.
	// It is a relative path, so it resolves against the process's working
	// directory.
	dbName = "data.db"

	// initSql creates the prices table: one row per symbol (the primary key)
	// holding its latest price and the timestamp it was written.
	initSql = `
CREATE TABLE IF NOT EXISTS prices (
symbol TEXT NOT NULL PRIMARY KEY,
date DATETIME NOT NULL,
price REAL NOT NULL
);`
)
// Repository wraps the *sql.DB handle that all price-table methods
// (initDB, getDBSymbols, insertSymbols, getPrices) operate on.
type Repository struct{ db *sql.DB }
// newRepository opens the SQLite database named by dbName through the
// "sqlite3" database/sql driver and returns a Repository wrapping the handle.
//
// sql.Open may only validate its arguments without connecting, so an
// unusable database file is not necessarily reported here; such problems can
// instead surface from the first statement executed (e.g. in initDB).
func newRepository() (*Repository, error) {
	handle, openErr := sql.Open("sqlite3", dbName)
	if openErr != nil {
		return nil, openErr
	}
	repo := &Repository{db: handle}
	return repo, nil
}
// initDB executes initSql to create the prices table. Because the statement
// uses CREATE TABLE IF NOT EXISTS, calling it on an already-initialised
// database leaves the existing table untouched. Any Exec error is returned
// unchanged.
func (r *Repository) initDB() error {
	_, err := r.db.Exec(initSql)
	return err
}
// getDBSymbols returns every symbol stored in the prices table. The query has
// no ORDER BY, so the order is whatever SQLite yields. An empty table gives a
// nil slice and a nil error.
func (r *Repository) getDBSymbols() ([]string, error) {
	rows, err := r.db.Query("SELECT symbol FROM prices")
	if err != nil {
		return nil, err
	}
	// Next closes rows once it is exhausted, but an early return on a Scan
	// error would otherwise leave the cursor (and its pooled connection) open.
	defer rows.Close()

	var res []string
	for rows.Next() {
		var symbol string
		if err := rows.Scan(&symbol); err != nil {
			return nil, err
		}
		res = append(res, symbol)
	}
	// Next also returns false when iteration fails; report that instead of
	// returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return res, nil
}
// insertSymbols writes the given symbol→price map into the prices table:
//   - symbols not yet stored are inserted with date=CURRENT_TIMESTAMP;
//   - stored symbols whose price differs from the new one by more than 0.01
//     are updated (price and date);
//   - stored symbols within that tolerance are left untouched.
//
// All writes happen in a single transaction, so either every change is
// applied or none is. The caller's map is not modified.
func (r *Repository) insertSymbols(symbols map[string]float64) error {
	tx, err := r.db.Begin()
	if err != nil {
		return err
	}
	// Rollback is a no-op (it returns sql.ErrTxDone) once Commit has
	// succeeded, so this only undoes work on the error paths.
	defer tx.Rollback()

	// Load all stored prices before writing anything. Writing through
	// r.db while this SELECT cursor is still open would use a second pooled
	// connection, which SQLite can block with "database is locked".
	rows, err := tx.Query("SELECT symbol, price FROM prices")
	if err != nil {
		return err
	}
	defer rows.Close() // covers the early return on a Scan error
	stored := make(map[string]float64)
	for rows.Next() {
		var symbol string
		var price float64
		if err := rows.Scan(&symbol, &price); err != nil {
			return err
		}
		stored[symbol] = price
	}
	if err := rows.Err(); err != nil {
		return err
	}

	for symbol, price := range symbols {
		old, exists := stored[symbol]
		var execErr error
		switch {
		case !exists:
			_, execErr = tx.Exec("INSERT INTO prices VALUES(?, CURRENT_TIMESTAMP, ?)", symbol, price)
		case math.Abs(price-old) > 0.01:
			_, execErr = tx.Exec("UPDATE prices SET price=?, date=CURRENT_TIMESTAMP WHERE symbol=?", price, symbol)
		default:
			// Within tolerance: keep the stored row and its original date.
			continue
		}
		if execErr != nil {
			return execErr
		}
	}
	return tx.Commit()
}
// getPrices returns the stored price for every symbol in the prices table,
// keyed by symbol. An empty table yields an empty (non-nil) map.
func (r *Repository) getPrices() (map[string]float64, error) {
	rows, err := r.db.Query("SELECT symbol, price FROM prices")
	if err != nil {
		return nil, err
	}
	// Ensure the cursor is released on the Scan-error path too; after a
	// normal end of iteration Close is a harmless no-op.
	defer rows.Close()

	res := make(map[string]float64)
	for rows.Next() {
		var symbol string
		var price float64
		if err := rows.Scan(&symbol, &price); err != nil {
			return nil, err
		}
		res[symbol] = price
	}
	// Distinguish "no more rows" from an iteration failure, which would
	// otherwise produce an incomplete map with a nil error.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return res, nil
}