/
server.go
132 lines (117 loc) · 2.49 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package http
import (
"net"
"net/http"
"sync"
)
// StatsListener is a net.Listener that delegates to another listener and keeps
// connection counters in the embedded ConnStats. Start hands it to
// http.Server.Serve.
type StatsListener struct {
	// listener is the wrapped listener that actually accepts connections.
	listener net.Listener

	// ConnStats is updated by Accept and by the close callback that Accept
	// installs on every connection it returns.
	*ConnStats
}
// NewStatsListener wraps ln in a StatsListener backed by a freshly allocated,
// zero-valued ConnStats.
func NewStatsListener(ln net.Listener) *StatsListener {
	sl := new(StatsListener)
	sl.listener = ln
	sl.ConnStats = new(ConnStats)
	return sl
}
// Accept waits for the next connection on the wrapped listener.
//
// On success it calls ConnStats.ActiveConnection and returns the connection
// wrapped in a StatsConn. That StatsConn's close callback calls
// ConnStats.CloseConnection. On failure the wrapped listener's connection
// value and error are passed through unchanged.
func (sl *StatsListener) Accept() (net.Conn, error) {
	raw, err := sl.listener.Accept()
	if err == nil {
		sl.ConnStats.ActiveConnection()
		onClose := func() {
			sl.ConnStats.CloseConnection()
		}
		return &StatsConn{Conn: raw, cb: onClose}, nil
	}
	return raw, err
}
// Close closes the wrapped listener and returns its error. ConnStats is not
// modified.
func (sl *StatsListener) Close() error {
	inner := sl.listener
	return inner.Close()
}
// Addr reports the address of the wrapped listener.
func (sl *StatsListener) Addr() net.Addr {
	inner := sl.listener
	return inner.Addr()
}
// StatsConn wraps a net.Conn and invokes cb when the connection is closed.
// StatsListener.Accept uses it to report closes back to its ConnStats.
type StatsConn struct {
	net.Conn
	cb func()

	// closeOnce ensures cb runs at most once per connection, however many
	// times Close is called and whatever the wrapped Close returns.
	closeOnce sync.Once
}

// Close closes the wrapped connection, returns its error, and runs cb exactly
// once (on the first call).
//
// cb runs even when the wrapped Close fails. The caller is done with the
// connection either way, and skipping cb would leave it counted as active
// indefinitely. Later Close calls are still forwarded to the wrapped connection
// but never re-run cb. Some net.Conn implementations (net.Pipe, for one) return
// nil from every Close, which would otherwise report the close repeatedly.
// A nil cb is tolerated.
func (c *StatsConn) Close() error {
	err := c.Conn.Close()
	c.closeOnce.Do(func() {
		if c.cb != nil {
			c.cb()
		}
	})
	return err
}
// MockServer is a small HTTP server that dispatches requests by exact URL path
// through Mux and records connection, request and response statistics.
//
// The embedded *http.Server and *ServerStats promote their fields and methods
// onto MockServer.
type MockServer struct {
	*http.Server

	// ServerStats is nil until Start creates it.
	*ServerStats

	// Addr is the TCP address Start listens on. Because it is declared at
	// this level, srv.Addr refers to this field and hides http.Server's Addr.
	Addr string

	// Mux maps an exact URL path to its handler (see ServeHTTP).
	Mux map[string]func(http.ResponseWriter, *http.Request)

	// listener is created once by Start while holding lock.
	listener *StatsListener

	// lock serializes Start's one-time creation of listener and ServerStats.
	lock sync.Mutex
}
// NewMockServer returns a MockServer for addr with an empty Mux and then calls
// f on it, presumably so f can register routes. When f is nil,
// DefaultHTTPServe.Serve is used instead. The listener and ServerStats are not
// created here; Start does that.
func NewMockServer(addr string, f ServeFunc) *MockServer {
	register := f
	if register == nil {
		register = DefaultHTTPServe.Serve
	}
	srv := &MockServer{
		Server: new(http.Server),
		Addr:   addr,
		Mux:    map[string]func(http.ResponseWriter, *http.Request){},
	}
	register(srv)
	return srv
}
// HandleFunc registers handler for requests whose URL path equals pattern
// exactly. Registering the same pattern again replaces the earlier handler.
//
// NOTE(review): Mux is written without holding srv.lock, so registering while
// the server is serving races with ServeHTTP's lookup.
func (srv *MockServer) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
	routes := srv.Mux
	routes[pattern] = handler
}
// ResponseWriterWrapper wraps the http.ResponseWriter given to a MockServer
// handler and reports each response's status code to stats.
//
// Only the http.ResponseWriter method set of the wrapped writer is exposed.
// Optional interfaces it may implement (e.g. http.Flusher) are not reachable
// through the wrapper.
type ResponseWriterWrapper struct {
	http.ResponseWriter
	stats *ServerStats

	// wroteHeader is set once a final status has been reported. It keeps the
	// status from being counted twice, whether through a repeated WriteHeader
	// (which net/http ignores) or through the implicit 200 OK that Write
	// triggers.
	wroteHeader bool
}

// WriteHeader reports code to stats and forwards it to the wrapped writer.
// Only the first final status of a response is reported. Informational 1xx
// codes other than 101 Switching Protocols are reported but may still be
// followed by the final status.
func (w *ResponseWriterWrapper) WriteHeader(code int) {
	if !w.wroteHeader {
		w.stats.Response(int16(code))
		w.wroteHeader = code >= 200 || code == http.StatusSwitchingProtocols
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. If no final status has been written
// yet, net/http would send 200 OK implicitly, so Write first calls
// WriteHeader(200) to make that status visible to stats.
func (w *ResponseWriterWrapper) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.WriteHeader(http.StatusOK)
	}
	return w.ResponseWriter.Write(b)
}

// ServeHTTP counts the request in srv.ServerStats, then dispatches it to the
// handler registered in Mux for the exact URL path, or answers 404 Not Found
// when there is none. The response status is reported to srv.ServerStats.
//
// NOTE(review): srv.ServerStats is created by Start. Calling ServeHTTP before
// Start uses a nil *ServerStats; confirm that is safe. Mux is read without
// holding srv.lock.
func (srv *MockServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	srv.ServerStats.Request()
	ww := &ResponseWriterWrapper{
		ResponseWriter: w,
		stats:          srv.ServerStats,
	}

	handler, ok := srv.Mux[r.URL.Path]
	if !ok {
		ww.WriteHeader(http.StatusNotFound)
		return
	}
	handler(ww, r)

	// A handler that wrote neither a status nor a body still gets an implicit
	// 200 OK from net/http after it returns. Count it here. Only the stats are
	// updated; the wire response is left to net/http.
	if !ww.wroteHeader {
		ww.stats.Response(int16(http.StatusOK))
	}
}
// Start listens on srv.Addr and serves HTTP on that listener in the calling
// goroutine.
//
// The first successful call creates the StatsListener and a ServerStats backed
// by the listener's ConnStats. Any later call finds the listener already set
// and returns without serving.
//
// Start panics with the net.Listen error if the address cannot be listened on.
// srv.lock is released before that panic propagates, so a recovered panic
// neither leaves the server locked nor records a listener, and Start can be
// retried. The error returned by Serve is discarded because Start has no error
// result.
func (srv *MockServer) Start() {
	ln := func() *StatsListener {
		srv.lock.Lock()
		// Deferred so the lock is also released when net.Listen fails and we panic.
		defer srv.lock.Unlock()
		if srv.listener != nil {
			return nil // already started
		}
		l, err := net.Listen("tcp", srv.Addr)
		if err != nil {
			panic(err)
		}
		srv.listener = NewStatsListener(l)
		srv.ServerStats = NewServerStats(srv.listener.ConnStats)
		return srv.listener
	}()
	if ln == nil {
		return
	}

	// Send every path to ServeHTTP, which does its own exact-path lookup in
	// Mux. NOTE(review): ServeMux may still clean or redirect non-canonical
	// paths before ServeHTTP sees them.
	mux := &http.ServeMux{}
	mux.HandleFunc("/", srv.ServeHTTP)
	srv.Handler = mux
	// Serve on the local listener rather than re-reading srv.listener outside the lock.
	srv.Serve(ln)
}