/
websocket.go
161 lines (138 loc) · 3.9 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package websocket
import (
"fmt"
"log"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/awdng/triebwerk/model"
"github.com/gorilla/websocket"
)
// Transport accepts websocket connections over HTTP and hands each accepted
// connection to caller-supplied callbacks.
type Transport struct {
	// upgrader turns incoming HTTP requests into websocket connections.
	upgrader websocket.Upgrader

	// register is invoked for every newly upgraded connection; unregister is
	// invoked through Transport.Unregister.
	register, unregister func(conn model.Connection)

	// port is the listening port; Run overwrites it with the port actually bound.
	port int
	// address is the host part reported by GetAddress (Run does not bind to it).
	address string
}
// NewTransport returns a Transport for the given address and port whose
// upgrader uses 1 KiB read and write buffers.
//
// The connection callbacks start out unset; install them with
// RegisterNewConnHandler and UnregisterConnHandler before connections arrive.
func NewTransport(address string, port int) *Transport {
	// NOTE(review): every Origin header is accepted, i.e. the upgrader's
	// same-origin protection is disabled. Confirm cross-origin browser
	// clients are intentionally allowed.
	acceptAnyOrigin := func(*http.Request) bool { return true }

	t := &Transport{
		upgrader: websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			CheckOrigin:     acceptAnyOrigin,
		},
		port:    port,
		address: address,
	}
	return t
}
// GetAddress returns the transport's address in "host:port" form. It is built
// from the configured address and the current port, which Run updates to the
// port it actually bound.
//
// An IPv6 literal host is bracketed ("[::1]:8080") so the result is a valid
// network address. A host that is already bracketed is not bracketed twice.
func (t *Transport) GetAddress() string {
	host := t.address
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	// JoinHostPort re-adds the brackets whenever host contains a colon.
	return net.JoinHostPort(host, strconv.Itoa(t.port))
}
// RegisterNewConnHandler sets the callback that receives every newly accepted
// connection. A later call replaces the previously set callback.
func (t *Transport) RegisterNewConnHandler(handler func(conn model.Connection)) {
	t.register = handler
}
// UnregisterConnHandler sets the callback that Transport.Unregister forwards
// connections to. A later call replaces the previously set callback.
func (t *Transport) UnregisterConnHandler(handler func(conn model.Connection)) {
	t.unregister = handler
}
// Unregister passes conn to the callback installed with UnregisterConnHandler.
// It does nothing when no callback has been installed, rather than panicking
// on a nil function call.
func (t *Transport) Unregister(conn model.Connection) {
	if t.unregister == nil {
		return
	}
	t.unregister(conn)
}
// Init installs the "/echo" websocket endpoint on http.DefaultServeMux.
//
// Each request is upgraded with the transport's upgrader. Upgrade failures are
// logged. Every successfully upgraded connection is wrapped in a Connection
// and handed to the callback set via RegisterNewConnHandler. If no callback is
// set, the connection is closed, because nothing else would ever read from or
// close it.
//
// NOTE(review): registering the same pattern twice on the default mux panics,
// so Init should presumably be called only once per process.
func (t *Transport) Init() {
	http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
		ws, err := t.upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.Println(err)
			return
		}
		if t.register == nil {
			log.Println("websocket: no connection handler registered, closing connection from", ws.RemoteAddr())
			ws.Close()
			return
		}
		t.register(NewConnection(ws))
	})
}
// Run listens for TCP connections on t.port (on all interfaces) and serves
// HTTP from http.DefaultServeMux, including the endpoint installed by Init.
//
// After listening succeeds, t.port is overwritten with the port the listener
// actually bound. That port differs from the requested one when 0 was
// requested. A listen failure is returned to the caller instead of panicking.
// Otherwise Run returns whatever http.Serve returns.
func (t *Transport) Run() error {
	// NOTE(review): t.address is not used for binding; confirm that listening
	// on every interface is intended.
	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", t.port))
	if err != nil {
		return err
	}
	t.port = listener.Addr().(*net.TCPAddr).Port
	log.Printf("Starting Triebwerk Websocket Server on %s...", t.GetAddress())
	return http.Serve(listener, nil)
}
// Connection is the per-client handle produced by Transport: it wraps one
// gorilla websocket connection and is what Init hands to the registered
// callback as a model.Connection.
type Connection struct{ conn *websocket.Conn }
// NewConnection wraps an already-upgraded websocket connection in a
// Connection. The websocket is stored as-is; no I/O is performed.
func NewConnection(conn *websocket.Conn) *Connection {
	c := new(Connection)
	c.conn = conn
	return c
}
// Identifier describes the connection by its remote endpoint, formatted as
// "<network> - <address>" (for a TCP peer, e.g. "tcp - 10.0.0.1:5000").
func (c *Connection) Identifier() string {
	remote := c.conn.RemoteAddr()
	return remote.Network() + " - " + remote.String()
}
// Close ends the connection.
//
// With graceful == false the underlying connection is closed immediately.
//
// With graceful == true a websocket close frame
// (https://tools.ietf.org/html/rfc6455#section-5.5.1) is sent with writeWait
// as the write deadline. The underlying connection is then left open.
// NOTE(review): presumably the read side closes it once the peer answers the
// close handshake; confirm.
//
// If the close frame cannot be sent, no handshake can take place. In that
// case the connection is closed immediately instead of being leaked.
func (c *Connection) Close(writeWait time.Duration, graceful bool) {
	if !graceful {
		c.conn.Close()
		return
	}
	err := c.conn.SetWriteDeadline(time.Now().Add(writeWait))
	if err == nil {
		err = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
	}
	if err != nil {
		c.conn.Close()
	}
}
// PrepareRead configures the read side of the connection. It caps incoming
// messages at maxMessageSize and sets the read deadline to pongWait from now.
// It also installs a pong handler that pushes the deadline out by pongWait
// again each time a pong arrives, so reads time out once the peer stops
// answering.
func (c *Connection) PrepareRead(maxMessageSize int64, pongWait time.Duration) {
	extendDeadline := func() {
		c.conn.SetReadDeadline(time.Now().Add(pongWait))
	}

	c.conn.SetReadLimit(maxMessageSize)
	extendDeadline()
	c.conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})
}
// Read returns the payload of the next message received on the connection.
// The message type (text/binary) is discarded. On failure it returns a nil
// payload and the error from the underlying websocket connection, unmodified.
func (c *Connection) Read() ([]byte, error) {
	// TODO: distinguish unexpected close errors (anything other than
	// CloseGoingAway / CloseAbnormalClosure, as tested by
	// websocket.IsUnexpectedCloseError) from orderly shutdowns, e.g. by
	// wrapping them in a dedicated error.
	// NOTE(review): gorilla's IsCloseError helpers appear to rely on a plain
	// type assertion, so a wrapper may hide close codes from callers. Verify
	// before wrapping.
	_, message, err := c.conn.ReadMessage()
	if err != nil {
		return nil, err
	}
	return message, nil
}
// PrepareWrite sets the connection's write deadline to writeWait from now.
// The deadline applies to subsequent writes until it is set again.
func (c *Connection) PrepareWrite(writeWait time.Duration) {
	deadline := time.Now().Add(writeWait)
	c.conn.SetWriteDeadline(deadline)
}
// Write sends data to the peer as a single binary websocket message, subject
// to the current write deadline (see PrepareWrite).
//
// It returns any error encountered while starting the message, writing the
// payload, or flushing the frame to the network.
func (c *Connection) Write(data []byte) error {
	// WriteMessage obtains a message writer, writes the payload and closes
	// (flushes) the writer, propagating an error from every step.
	return c.conn.WriteMessage(websocket.BinaryMessage, data)
}
// Ping sends a websocket ping frame to the peer, with writeWait as the write
// deadline.
//
// A failed ping is not reported, because the method has no error result.
// NOTE(review): presumably a dead peer is detected later through the
// pong-driven read deadline set in PrepareRead; confirm.
func (c *Connection) Ping(writeWait time.Duration) {
	c.conn.SetWriteDeadline(time.Now().Add(writeWait))
	// Deliberately discarded; see the note above.
	_ = c.conn.WriteMessage(websocket.PingMessage, nil)
}