-
Notifications
You must be signed in to change notification settings - Fork 0
/
wsstomp.go
93 lines (83 loc) · 2.48 KB
/
wsstomp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package wsstomp
import (
"context"
"fmt"
"net/http"
"time"
"nhooyr.io/websocket"
)
// WebsocketSTOMP adapts a websocket connection to byte-stream style Read,
// Write and Close methods, buffering data in both directions.
type WebsocketSTOMP struct {
// connection is the underlying websocket connection.
connection *websocket.Conn
// readerBuffer holds the not-yet-returned remainder of the last message read.
readerBuffer []byte
// writeBuffer collects written bytes until they are sent as one message.
writeBuffer []byte
}
// Byte values that Write inspects to decide when buffered data is sent.
const (
// NullByte is the NUL byte; a write ending in it is treated as the end of a
// STOMP frame.
NullByte = 0x00
// LineFeedByte is the line feed byte; written on its own it is treated as a
// heartbeat.
LineFeedByte = 0x0a
)
// Read copies data received over the websocket into p, following the
// io.Reader contract.
//
// If no data is left over from an earlier call, Read blocks until the next
// non-empty websocket message arrives. At most one message is consumed per
// call, so Read may return fewer than len(p) bytes; whatever part of the
// message does not fit into p is kept for the next Read call.
//
// It returns the number of bytes copied into p, or the error reported by the
// underlying connection.
func (w *WebsocketSTOMP) Read(p []byte) (int, error) {
	// Nothing can be copied into an empty slice; return right away instead of
	// blocking on the connection.
	if len(p) == 0 {
		return 0, nil
	}

	// Refill from the websocket once the buffered data is exhausted. Empty
	// messages are skipped so that Read never reports (0, nil) for a non-empty
	// p, which io.Reader discourages.
	for len(w.readerBuffer) == 0 {
		_, msg, err := w.connection.Read(context.Background())
		if err != nil {
			return 0, err
		}
		w.readerBuffer = msg
	}

	n := copy(p, w.readerBuffer)
	w.readerBuffer = w.readerBuffer[n:]
	return n, nil
}
// Write buffers p and sends the buffered data as a single websocket text
// message once a full STOMP frame has been written.
//
// A frame is considered complete when p ends with NullByte. A lone
// LineFeedByte written while nothing else is buffered is treated as a
// heartbeat and sent immediately. Every other write is only buffered.
//
// Write reports len(p) bytes as written even when sending fails; in that case
// the send error is returned and the buffered data is discarded. Sending is
// limited to ten seconds. An empty p is a no-op.
func (w *WebsocketSTOMP) Write(p []byte) (int, error) {
	// Upper bound for sending one buffered message.
	const writeTimeout = 10 * time.Second

	// An empty write carries no data and cannot complete a frame. Returning
	// early also keeps the p[len(p)-1] lookup below from panicking.
	if len(p) == 0 {
		return 0, nil
	}

	w.writeBuffer = append(w.writeBuffer, p...)

	frameComplete := p[len(p)-1] == NullByte
	// With p non-empty, a one-byte buffer means the buffer is exactly p.
	heartbeat := len(w.writeBuffer) == 1 && p[0] == LineFeedByte
	if !frameComplete && !heartbeat {
		// Hold the data back until the rest of the frame arrives.
		return len(p), nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), writeTimeout)
	defer cancel()
	err := w.connection.Write(ctx, websocket.MessageText, w.writeBuffer)
	// TODO: preserve write buffer if write fails?
	w.writeBuffer = nil
	return len(p), err
}
// Close closes the underlying websocket connection with a normal-closure
// status and returns the error reported by that close, if any.
//
// Close is a no-op that returns nil when there is no connection to close, for
// instance on a value obtained from a failed Connect, so it is safe to call
// unconditionally.
func (w *WebsocketSTOMP) Close() error {
	if w == nil || w.connection == nil {
		return nil
	}
	return w.connection.Close(websocket.StatusNormalClosure, "terminating connection")
}
// Connect establishes a websocket connection with the provided URL and wraps
// it in a WebsocketSTOMP.
//
// ctx is only used for the connection handshake, not for the full lifetime of
// the connection.
//
// options may be nil. The caller's options are never modified: Connect works
// on a copy. When no HTTPClient is set, a default client is used that has a
// 30 second timeout and rewrites ws/wss redirect targets to http/https.
//
// On failure Connect returns a nil *WebsocketSTOMP together with the dial
// error.
func Connect(ctx context.Context, url string, options *websocket.DialOptions) (*WebsocketSTOMP, error) {
	// Copy the options so that filling in defaults below does not leak into
	// the struct owned by the caller.
	var opts websocket.DialOptions
	if options != nil {
		opts = *options
	}

	if opts.HTTPClient == nil {
		opts.HTTPClient = &http.Client{
			// fix for https://github.com/nhooyr/websocket/issues/333
			CheckRedirect: func(req *http.Request, via []*http.Request) error {
				switch req.URL.Scheme {
				case "ws":
					req.URL.Scheme = "http"
				case "wss":
					req.URL.Scheme = "https"
				case "http", "https":
					// Already an HTTP scheme; nothing to rewrite.
				default:
					return fmt.Errorf("unexpected url scheme: %q", req.URL.Scheme)
				}
				return nil
			},
			// sane timeout
			Timeout: time.Second * 30,
		}
	}

	// The handshake response is not needed here, only the connection.
	conn, _, err := websocket.Dial(ctx, url, &opts)
	if err != nil {
		// Do not hand out a wrapper around a nil connection.
		return nil, err
	}
	return &WebsocketSTOMP{connection: conn}, nil
}