/
manager.go
245 lines (206 loc) · 5.95 KB
/
manager.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
package router
import (
"database/sql"
_ "embed"
"github.com/1f349/violet/proxy"
"github.com/1f349/violet/target"
"github.com/MrMelon54/rescheduler"
"log"
"net/http"
"strings"
"sync"
)
// Manager is a database and mutex wrapper around Router, allowing the router
// to be dynamically regenerated after the database of routes is updated.
type Manager struct {
// db is the handle used for every routes/redirects query and statement.
db *sql.DB
// s guards r: ServeHTTP takes the read lock to fetch the current router and
// threadCompile takes the write lock to swap in a new one.
s *sync.RWMutex
// r is the router currently used to serve requests.
r *Router
// p is the transport handed to New whenever a router is built.
p *proxy.HybridTransport
// z is the rescheduler wrapping threadCompile; Compile triggers it via Run.
z *rescheduler.Rescheduler
}
var (
//go:embed create-tables.sql
createTables string
)
// NewManager creates a new Manager backed by db and initialises the database
// by executing the embedded create-tables.sql script.
//
// The returned Manager starts with a router built by New(proxy); this
// function does not load any routes or redirects from the database itself, so
// call Compile to populate the router.
//
// If the table-creation script fails, the failure (including the underlying
// error) is logged and nil is returned, so callers must check for a nil
// result.
func NewManager(db *sql.DB, proxy *proxy.HybridTransport) *Manager {
	m := &Manager{
		db: db,
		s:  &sync.RWMutex{},
		r:  New(proxy),
		p:  proxy,
	}
	// threadCompile is bound to m, so the rescheduler can only be created
	// once m exists.
	m.z = rescheduler.NewRescheduler(m.threadCompile)

	// init routes table
	if _, err := m.db.Exec(createTables); err != nil {
		// Include the cause; without it the failure cannot be diagnosed.
		log.Printf("[WARN] Failed to generate tables: %s\n", err)
		return nil
	}
	return m
}
// ServeHTTP forwards the request to the router that is current at the moment
// the request arrives.
//
// The read lock is held only while the router pointer is copied, so a request
// keeps using the router it started with even if threadCompile swaps in a
// replacement while the request is still being served.
func (m *Manager) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	current := func() *Router {
		m.s.RLock()
		defer m.s.RUnlock()
		return m.r
	}()
	current.ServeHTTP(rw, req)
}
// Compile requests a rebuild of the router from the database by calling Run
// on the rescheduler that NewManager created around threadCompile.
//
// NOTE(review): whether Run blocks until the rebuild finishes or merely
// schedules it depends on the rescheduler package - confirm before relying on
// the new routes being live when Compile returns.
func (m *Manager) Compile() {
	scheduler := m.z
	scheduler.Run()
}
// threadCompile builds a brand-new router from the database and, only if that
// succeeds, installs it as the router used by ServeHTTP. On failure the error
// is logged and the previous router stays in place.
func (m *Manager) threadCompile() {
	// Populate a fresh router so a failed compile never leaves the live one
	// half-updated.
	fresh := New(m.p)
	if err := m.internalCompile(fresh); err != nil {
		log.Printf("[Manager] Compile failed: %s\n", err)
		return
	}

	// The write lock is only needed for the pointer swap.
	m.s.Lock()
	defer m.s.Unlock()
	m.r = fresh
}
// internalCompile fills router with every active route and redirect stored in
// the database. It is the database-facing half of threadCompile.
//
// Routes are loaded first, then redirects. The first query, scan or iteration
// error aborts the load and is returned; router may already contain some
// entries at that point, so callers should discard it on error.
func (m *Manager) internalCompile(router *Router) error {
	log.Println("[Manager] Updating routes from database")

	// Active routes.
	routeRows, err := m.db.Query(`SELECT source, destination, flags FROM routes WHERE active = 1`)
	if err != nil {
		return err
	}
	defer routeRows.Close()

	for routeRows.Next() {
		var (
			source, destination string
			routeFlags          target.Flags
		)
		if scanErr := routeRows.Scan(&source, &destination, &routeFlags); scanErr != nil {
			return scanErr
		}
		route := target.Route{
			Src:   source,
			Dst:   destination,
			Flags: routeFlags.NormaliseRouteFlags(),
		}
		router.AddRoute(route)
	}
	// Next returning false can also mean the iteration failed.
	if iterErr := routeRows.Err(); iterErr != nil {
		return iterErr
	}

	// Active redirects.
	redirectRows, err := m.db.Query(`SELECT source,destination,flags,code FROM redirects WHERE active = 1`)
	if err != nil {
		return err
	}
	defer redirectRows.Close()

	for redirectRows.Next() {
		var (
			source, destination string
			redirectFlags       target.Flags
			statusCode          int
		)
		if scanErr := redirectRows.Scan(&source, &destination, &redirectFlags, &statusCode); scanErr != nil {
			return scanErr
		}
		redirect := target.Redirect{
			Src:   source,
			Dst:   destination,
			Flags: redirectFlags.NormaliseRedirectFlags(),
			Code:  statusCode,
		}
		router.AddRedirect(redirect)
	}
	return redirectRows.Err()
}
// GetAllRoutes returns every route in the database (active or not) that is on
// at least one of the given hosts, as decided by RouteWithActive.OnDomain.
//
// When hosts is empty no query is made and an empty, non-nil slice is
// returned. A query, scan or row-iteration error yields a nil slice and the
// error. The result set is always closed before returning.
func (m *Manager) GetAllRoutes(hosts []string) ([]target.RouteWithActive, error) {
	if len(hosts) < 1 {
		return []target.RouteWithActive{}, nil
	}

	rows, err := m.db.Query(`SELECT source, destination, description, flags, active FROM routes`)
	if err != nil {
		return nil, err
	}
	// Release the underlying connection on every return path, including the
	// early return on a Scan error.
	defer rows.Close()

	// Non-nil even when nothing matches, matching the empty-hosts result.
	s := make([]target.RouteWithActive, 0)
	for rows.Next() {
		var a target.RouteWithActive
		if err := rows.Scan(&a.Src, &a.Dst, &a.Desc, &a.Flags, &a.Active); err != nil {
			return nil, err
		}
		for _, host := range hosts {
			// The query is unfiltered, so rows for other domains are
			// expected here and are simply skipped.
			if a.OnDomain(host) {
				s = append(s, a)
				break
			}
		}
	}
	// Next also returns false when iteration fails; without this check a
	// truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return s, nil
}
// InsertRoute writes route to the routes table. If a row with the same source
// already exists, its destination, description, flags and active columns are
// overwritten with the new values.
//
// Only the database is touched; this method does not call Compile.
func (m *Manager) InsertRoute(route target.RouteWithActive) error {
	const upsertRoute = `INSERT INTO routes (source, destination, description, flags, active) VALUES (?, ?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, description = excluded.description, flags = excluded.flags, active = excluded.active`

	if _, err := m.db.Exec(upsertRoute, route.Src, route.Dst, route.Desc, route.Flags, route.Active); err != nil {
		return err
	}
	return nil
}
// DeleteRoute removes the route whose source column equals source from the
// routes table and returns any error from executing the statement.
//
// Only the database is touched; this method does not call Compile.
func (m *Manager) DeleteRoute(source string) error {
	const deleteRoute = `DELETE FROM routes WHERE source = ?`

	if _, err := m.db.Exec(deleteRoute, source); err != nil {
		return err
	}
	return nil
}
// GetAllRedirects returns every redirect in the database (active or not) that
// is on at least one of the given hosts, as decided by
// RedirectWithActive.OnDomain.
//
// When hosts is empty no query is made and an empty, non-nil slice is
// returned. A query, scan or row-iteration error yields a nil slice and the
// error. The result set is always closed before returning.
func (m *Manager) GetAllRedirects(hosts []string) ([]target.RedirectWithActive, error) {
	if len(hosts) < 1 {
		return []target.RedirectWithActive{}, nil
	}

	rows, err := m.db.Query(`SELECT source, destination, description, flags, code, active FROM redirects`)
	if err != nil {
		return nil, err
	}
	// Release the underlying connection on every return path, including the
	// early return on a Scan error.
	defer rows.Close()

	// Non-nil even when nothing matches, matching the empty-hosts result.
	s := make([]target.RedirectWithActive, 0)
	for rows.Next() {
		var a target.RedirectWithActive
		if err := rows.Scan(&a.Src, &a.Dst, &a.Desc, &a.Flags, &a.Code, &a.Active); err != nil {
			return nil, err
		}
		for _, host := range hosts {
			// The query is unfiltered, so rows for other domains are
			// expected here and are simply skipped.
			if a.OnDomain(host) {
				s = append(s, a)
				break
			}
		}
	}
	// Next also returns false when iteration fails; without this check a
	// truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return s, nil
}
// InsertRedirect writes redirect to the redirects table. If a row with the
// same source already exists, its destination, description, flags, code and
// active columns are overwritten with the new values.
//
// Only the database is touched; this method does not call Compile.
func (m *Manager) InsertRedirect(redirect target.RedirectWithActive) error {
	const upsertRedirect = `INSERT INTO redirects (source, destination, description, flags, code, active) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(source) DO UPDATE SET destination = excluded.destination, description = excluded.description, flags = excluded.flags, code = excluded.code, active = excluded.active`

	if _, err := m.db.Exec(upsertRedirect, redirect.Src, redirect.Dst, redirect.Desc, redirect.Flags, redirect.Code, redirect.Active); err != nil {
		return err
	}
	return nil
}
// DeleteRedirect removes the redirect whose source column equals source from
// the redirects table and returns any error from executing the statement.
//
// Only the database is touched; this method does not call Compile.
func (m *Manager) DeleteRedirect(source string) error {
	const deleteRedirect = `DELETE FROM redirects WHERE source = ?`

	if _, err := m.db.Exec(deleteRedirect, source); err != nil {
		return err
	}
	return nil
}
// GenerateHostSearch builds a SQL WHERE clause that narrows a query to rows
// whose source column either ends with one of hosts or contains a host
// immediately followed by "/". It is intended to reduce the number of rows
// pulled from the database.
//
// It returns the clause, including the leading "WHERE ", together with the
// positional arguments for its ? placeholders: two per host, in host order.
// The LIKE patterns are assembled with the SQL string concatenation operator
// ||. A + would be evaluated as numeric addition by SQLite and the pattern
// would never match a hostname.
//
// When hosts is empty the clause is "WHERE 1 = 0", which matches no rows, and
// the argument slice is empty. This avoids emitting a dangling "WHERE " that
// would make the surrounding statement invalid.
//
// The LIKE checks are only a coarse pre-filter: LIKE wildcards inside a host
// are not escaped and suffix matches can include other domains, so results
// should still be checked per host.
//
// TODO(Melon) discover how to implement this correctly
func GenerateHostSearch(hosts []string) (string, []string) {
	if len(hosts) == 0 {
		return "WHERE 1 = 0", []string{}
	}

	var searchString strings.Builder
	searchString.WriteString("WHERE ")
	hostArgs := make([]string, 0, len(hosts)*2)
	for i, host := range hosts {
		if i != 0 {
			searchString.WriteString(" OR ")
		}
		// these like checks are not perfect but do reduce load on the database
		searchString.WriteString("source LIKE '%' || ? || '/%'")
		searchString.WriteString(" OR source LIKE '%' || ?")
		// each host feeds both of its placeholders
		hostArgs = append(hostArgs, host, host)
	}
	return searchString.String(), hostArgs
}