/
websocket.go
118 lines (101 loc) · 2.77 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package websocket
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net/http"
"net/url"
"time"
"github.com/absmach/mproxy/pkg/session"
mptls "github.com/absmach/mproxy/pkg/tls"
"github.com/gorilla/websocket"
)
// Proxy is a WebSocket proxy that relays client connections to a broker.
type Proxy struct {
	// target, path and scheme become the Host, Path and Scheme of the URL
	// dialed for the broker side of every proxied connection.
	target, path, scheme string
	// handler and interceptor are handed to session.Stream for each
	// proxied connection.
	handler     session.Handler
	interceptor session.Interceptor
	// logger receives upgrade, dial and stream errors.
	logger *slog.Logger
}
// New returns a WS proxy configured with the given broker target, path and
// scheme, the session handler and interceptor, and the logger to report to.
func New(target, path, scheme string, handler session.Handler, interceptor session.Interceptor, logger *slog.Logger) *Proxy {
	p := new(Proxy)
	p.target, p.path, p.scheme = target, path, scheme
	p.handler, p.interceptor = handler, interceptor
	p.logger = logger
	return p
}
// upgrader holds the settings used to upgrade incoming HTTP requests to
// WebSocket connections.
var upgrader = websocket.Upgrader{
	// Accept upgrade requests from any origin (allows CORS).
	CheckOrigin: func(*http.Request) bool { return true },
	// The Paho JS client expects the Sec-WebSocket-Protocol:mqtt header in
	// the Upgrade response during the handshake.
	Subprotocols: []string{"mqttv3.1", "mqtt"},
	// Time limit for the WS upgrade request handshake.
	HandshakeTimeout: 10 * time.Second,
}
// Handler returns the http.Handler that proxies WS traffic.
func (p Proxy) Handler() http.Handler {
	h := p.handle()
	return h
}
// handle builds the HTTP handler that upgrades an incoming request to a
// WebSocket connection and hands the upgraded connection to pass on a new
// goroutine.
func (p Proxy) handle() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cconn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Upgrade has already replied to the client with an HTTP error
			// response, so only log here; writing a second response would
			// be superfluous.
			p.logger.Error("Error upgrading connection", slog.Any("error", err))
			return
		}
		// net/http cancels the request context as soon as this handler
		// returns, but the proxied session runs on after that. Detach the
		// cancellation while keeping the request-scoped values.
		go p.pass(context.WithoutCancel(r.Context()), cconn)
	})
}
// pass proxies one upgraded client connection to the broker.
//
// It dials the broker at the URL built from the proxy's scheme, target and
// path (requesting the "mqtt" subprotocol), wraps both WebSocket connections
// with newConn, reads the client certificate from the client's underlying
// connection, and then runs session.Stream between the two. Dial and
// certificate errors are logged and end the call early; the error returned by
// session.Stream is logged as a warning. The connections are closed when pass
// returns.
func (p Proxy) pass(ctx context.Context, in *websocket.Conn) {
	defer in.Close()

	// Named target rather than url so the net/url package is not shadowed.
	target := url.URL{
		Scheme: p.scheme,
		Host:   p.target,
		Path:   p.path,
	}

	dialer := &websocket.Dialer{
		Subprotocols: []string{"mqtt"},
	}
	srv, _, err := dialer.Dial(target.String(), nil)
	if err != nil {
		p.logger.Error("Unable to connect to broker", slog.Any("error", err))
		return
	}

	inboundConn := newConn(in)
	outboundConn := newConn(srv)
	defer inboundConn.Close()
	defer outboundConn.Close()

	clientCert, err := mptls.ClientCert(in.UnderlyingConn())
	if err != nil {
		p.logger.Error("Failed to get client certificate", slog.Any("error", err))
		return
	}

	err = session.Stream(ctx, inboundConn, outboundConn, p.handler, p.interceptor, clientCert)
	p.logger.Warn("Broken connection for client", slog.Any("error", err))
}
// Listen starts a plain HTTP server on the given port. No handler is set on
// the server, so requests are served by http.DefaultServeMux.
func (p Proxy) Listen(wsPort string) error {
	var srv http.Server
	srv.Addr = fmt.Sprintf(":%s", wsPort)
	return srv.ListenAndServe()
}
// ListenTLS is the TLS counterpart of Listen: it starts an HTTPS server on
// the given port using tlsCfg and the certificate and key files crt and key.
func (p Proxy) ListenTLS(tlsCfg *tls.Config, crt, key, wssPort string) error {
	var srv http.Server
	srv.Addr = fmt.Sprintf(":%s", wssPort)
	srv.TLSConfig = tlsCfg
	return srv.ListenAndServeTLS(crt, key)
}