-
Notifications
You must be signed in to change notification settings - Fork 0
/
proxy.go
121 lines (98 loc) · 2.99 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package tcp_to_ws
import (
"github.com/Roman-Mitusov/middleware/proxy"
"github.com/fasthttp/websocket"
"github.com/gofiber/fiber/v2"
"github.com/savsgio/gotils"
"github.com/sirupsen/logrus"
"github.com/valyala/fasthttp"
"io"
"net"
)
// DefaultTcpDialer is the TCP dialer ProxyTcpToWS falls back to when
// TcpToWSProxy.TcpDialer is nil.
var DefaultTcpDialer = &fasthttp.TCPDialer{Concurrency: 1000}

// DefaultWebSocketUpgrader is the upgrader ProxyTcpToWS falls back to when
// TcpToWSProxy.Upgrader is nil. It uses 1 KiB read and write buffers.
var DefaultWebSocketUpgrader = &websocket.FastHTTPUpgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
}
// TcpToWSProxy holds the optional hooks and collaborators used by
// ProxyTcpToWS. Every field may be left at its zero value.
type TcpToWSProxy struct {
	// PrepareRequest, when non-nil, is invoked before the TCP connection is
	// dialed; a non-nil error aborts the request and is returned to the caller.
	PrepareRequest func(ctx *fiber.Ctx) error

	// TcpDialer opens the TCP connection to the target host.
	// DefaultTcpDialer is used when it is nil.
	TcpDialer *fasthttp.TCPDialer

	// Upgrader performs the WebSocket handshake.
	// DefaultWebSocketUpgrader is used when it is nil.
	Upgrader *websocket.FastHTTPUpgrader
}
// ProxyTcpToWS dials a TCP connection to the host named in the request URI,
// upgrades the incoming request to a WebSocket connection and relays the bytes
// read from the TCP connection to the WebSocket client.
//
// The optional PrepareRequest hook runs first; its error aborts the request.
// TcpDialer and Upgrader fall back to DefaultTcpDialer and
// DefaultWebSocketUpgrader when they are nil. Origin checking is disabled for
// every upgrade.
//
// Only the TCP -> WebSocket direction is relayed. The returned error is
// non-nil when PrepareRequest, the TCP dial, or the upgrade fails; errors that
// happen while relaying are logged from the upgrade handler.
func (p *TcpToWSProxy) ProxyTcpToWS(ctx *fiber.Ctx) (err error) {
	logger := proxy.DefaultLogger()

	if b := websocket.FastHTTPIsWebSocketUpgrade(ctx.Context()); b {
		logger.Infof("Request is upgraded %v", b)
	}

	dialer := DefaultTcpDialer
	if p.TcpDialer != nil {
		dialer = p.TcpDialer
	}

	baseUpgrader := DefaultWebSocketUpgrader
	if p.Upgrader != nil {
		baseUpgrader = p.Upgrader
	}

	if p.PrepareRequest != nil {
		if err = p.PrepareRequest(ctx); err != nil {
			return err
		}
	}

	// Work on a per-request copy: assigning CheckOrigin on the shared upgrader
	// (package default or caller-supplied) would be a data race between
	// concurrently served requests.
	upgrader := *baseUpgrader
	upgrader.CheckOrigin = func(*fasthttp.RequestCtx) bool {
		return true
	}

	targetHost := gotils.B2S(ctx.Request().URI().Host())
	tcpConn, err := dialer.Dial(targetHost)
	if err != nil {
		logger.Errorf("Unable to establish tcp connection to host=%s with error=%v", targetHost, err)
		return err
	}

	err = upgrader.Upgrade(ctx.Context(), func(clientConn *websocket.Conn) {
		// Release both ends of the proxy once relaying is over.
		defer clientConn.Close()
		defer tcpConn.Close()

		// NOTE: the fiber context must not be touched in here. This handler
		// runs on the hijacked connection after the handshake response has
		// been written, so response headers set at this point never reach the
		// client. Subprotocols have to be negotiated through
		// Upgrader.Subprotocols instead.
		logger.Info("Upgrade handler working")

		errClient := make(chan error, 1)
		go func() {
			// Closing the channel once the copier returns ends the loop
			// below, even when the TCP side finished without an error.
			defer close(errClient)
			copyTcpResponseToWebSocketConnection(clientConn, tcpConn, errClient, logger)
		}()

		// A handler-local error variable is used on purpose: the named result
		// of ProxyTcpToWS belongs to the request handler, which has already
		// returned by the time this code runs.
		for copyErr := range errClient {
			// log error except '*websocket.CloseError'
			if _, ok := copyErr.(*websocket.CloseError); !ok {
				logger.Errorf("tcptowsproxy: Error when copying response from tcp to ws: %v", copyErr)
			}
		}
	})
	if err != nil {
		// The upgrade handler will never run, so the dialed connection has to
		// be released here. Its close error is irrelevant next to the upgrade
		// failure that is reported to the caller.
		_ = tcpConn.Close()
		logger.Errorf("tcptowsproxy: couldn't upgrade %s", err)
		return err
	}

	return nil
}
// copyTcpResponseToWebSocketConnection streams everything read from src into
// one binary WebSocket message on dst until src reports EOF or an error.
//
// At most one error is sent on errChan: the failure to obtain the message
// writer, or otherwise the first error from copying or from closing the
// writer. Nothing is sent when the copy ends cleanly. Progress and errors are
// logged through logger.
func copyTcpResponseToWebSocketConnection(dst *websocket.Conn, src net.Conn, errChan chan error, logger *logrus.Logger) {
	logger.Info("tcptowsproxy: start copying tcp response bytes to client WebSocket connection")

	logger.Info("tcptowsproxy: obtaining WebSocket connection writer....")
	writer, err := dst.NextWriter(websocket.BinaryMessage)
	if err != nil {
		errChan <- err
		logger.Errorf("tcptowsproxy: error obtaining writer from WebSocket client connection err=%v", err)
		// Without a writer there is nothing to copy into; continuing would
		// hand a nil writer to io.Copy.
		return
	}

	logger.Info("tcptowsproxy: start copying tcp response to WebSocket client connection ...")
	_, err = io.Copy(writer, src)

	// Closing the writer flushes the buffered data and sends the final frame
	// of the message. A close error is only reported when the copy itself
	// succeeded, so the root cause is not masked.
	if closeErr := writer.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		errChan <- err
		logger.Errorf("tcptowsproxy: error copy tcp response to WebSocket client connection err=%v", err)
	}
}