Skip to content

Commit

Permalink
handling websocket connection upgrade fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
dencoded authored and buger committed Jan 6, 2018
1 parent ae6bf5f commit 8f320ad
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 28 deletions.
1 change: 1 addition & 0 deletions TESTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ There is also few special URLs with specific behavior:
- `/get` accepts only `GET` requests
- `/post` accepts only `POST` requests
- `/jwk.json` used for cases when JWK token downloaded from upsteram
- `/ws` used for testing WebSockets
- `/bundles` built in plugin bundle web server, more details below

### Coprocess plugin testing
Expand Down
127 changes: 99 additions & 28 deletions gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,32 +660,7 @@ func TestWithCacheAllSafeRequests(t *testing.T) {
}...)
}

func TestWebsocketsUpstream(t *testing.T) {
// setup and run web socket upstream
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

wsHandler := func(w http.ResponseWriter, req *http.Request) {
conn, err := upgrader.Upgrade(w, req, nil)
if err != nil {
t.Error("cannot upgrade:", err)
http.Error(w, fmt.Sprintf("cannot upgrade: %v", err), http.StatusInternalServerError)
}
mt, p, err := conn.ReadMessage()
if err != nil {
t.Error("cannot read message:", err)
return
}
conn.WriteMessage(mt, []byte("reply to message:"+string(p)))
}
wsServer := httptest.NewServer(http.HandlerFunc(wsHandler))
defer wsServer.Close()
u, _ := url.Parse(wsServer.URL)
u.Scheme = "ws"
targetUrl := u.String()

func TestWebsocketsUpstreamUpgradeRequest(t *testing.T) {
// setup spec and do test HTTP upgrade-request
config.Global.HttpServerOptions.EnableWebSockets = true
defer resetTestConfig()
Expand All @@ -695,16 +670,112 @@ func TestWebsocketsUpstream(t *testing.T) {

buildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = "/"
spec.Proxy.TargetURL = targetUrl
})

ts.Run(t, test.TestCase{
Code: http.StatusSwitchingProtocols,
Path: "/ws",
Headers: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Version": "13",
"Sec-Websocket-Key": "abc",
},
Code: http.StatusSwitchingProtocols,
})
}

func TestWebsocketsSeveralOpenClose(t *testing.T) {
config.Global.HttpServerOptions.EnableWebSockets = true
defer resetTestConfig()

ts := newTykTestServer()
defer ts.Close()

buildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = "/"
})

baseURL := strings.Replace(ts.URL, "http://", "ws://", -1)

// connect 1st time, send and read message, close connection
conn1, _, err := websocket.DefaultDialer.Dial(baseURL+"/ws", nil)
if err != nil {
t.Fatalf("cannot make websocket connection: %v", err)
}
err = conn1.WriteMessage(websocket.BinaryMessage, []byte("test message 1"))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}
_, p, err := conn1.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v", err)
}
if string(p) != "reply to message: test message 1" {
t.Error("Unexpected reply:", string(p))
}
conn1.Close()

// connect 2nd time, send and read message, but don't close yet
conn2, _, err := websocket.DefaultDialer.Dial(baseURL+"/ws", nil)
if err != nil {
t.Fatalf("cannot make websocket connection: %v", err)
}
err = conn2.WriteMessage(websocket.BinaryMessage, []byte("test message 2"))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}
_, p, err = conn2.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v", err)
}
if string(p) != "reply to message: test message 2" {
t.Error("Unexpected reply:", string(p))
}

// connect 3d time having one connection already open before, send and read message
conn3, _, err := websocket.DefaultDialer.Dial(baseURL+"/ws", nil)
if err != nil {
t.Fatalf("cannot make websocket connection: %v", err)
}
err = conn3.WriteMessage(websocket.BinaryMessage, []byte("test message 3"))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}
_, p, err = conn3.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v", err)
}
if string(p) != "reply to message: test message 3" {
t.Error("Unexpected reply:", string(p))
}

// check that we still can interact via 2nd connection we did before
err = conn2.WriteMessage(websocket.BinaryMessage, []byte("new test message 2"))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}
_, p, err = conn2.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v", err)
}
if string(p) != "reply to message: new test message 2" {
t.Error("Unexpected reply:", string(p))
}

// check that we still can interact via 3d connection we did before
err = conn3.WriteMessage(websocket.BinaryMessage, []byte("new test message 3"))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}
_, p, err = conn3.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v", err)
}
if string(p) != "reply to message: new test message 3" {
t.Error("Unexpected reply:", string(p))
}

// clean up connections
conn2.Close()
conn3.Close()
}
28 changes: 28 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/miekg/dns"
"github.com/satori/go.uuid"

Expand Down Expand Up @@ -90,6 +91,29 @@ const (
)

func testHttpHandler() http.Handler {
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

wsHandler := func(w http.ResponseWriter, req *http.Request) {
conn, err := upgrader.Upgrade(w, req, nil)
if err != nil {
http.Error(w, fmt.Sprintf("cannot upgrade: %v", err), http.StatusInternalServerError)
}

// start simple reader/writer per connection
go func() {
for {
mt, p, err := conn.ReadMessage()
if err != nil {
return
}
conn.WriteMessage(mt, []byte("reply to message: "+string(p)))
}
}()
}

httpError := func(w http.ResponseWriter, status int) {
http.Error(w, http.StatusText(status), status)
}
Expand Down Expand Up @@ -121,6 +145,7 @@ func testHttpHandler() http.Handler {
mux.HandleFunc("/", handleMethod(""))
mux.HandleFunc("/get", handleMethod("GET"))
mux.HandleFunc("/post", handleMethod("POST"))
mux.HandleFunc("/ws", wsHandler)
mux.HandleFunc("/jwk.json", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, jwkTestJson)
})
Expand Down Expand Up @@ -182,6 +207,7 @@ type tykTestServerConfig struct {
type tykTestServer struct {
ln net.Listener
cln net.Listener
URL string

globalConfig config.Config
config tykTestServerConfig
Expand Down Expand Up @@ -211,6 +237,8 @@ func (s *tykTestServer) Start() {
} else {
listen(s.ln, s.cln, fmt.Errorf("Without goagain"))
}

s.URL = "http://" + s.ln.Addr().String()
}

func (s *tykTestServer) Close() {
Expand Down
4 changes: 4 additions & 0 deletions reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,10 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques
if p.TykAPISpec.HTTPTransport == nil {
_, timeout := p.CheckHardTimeoutEnforced(p.TykAPISpec, req)
p.TykAPISpec.HTTPTransport = httpTransport(timeout, rw, req, p)
} else if IsWebsocket(req) { // check if it is an upgrade request to NEW WS-connection
// overwrite transport's ResponseWriter from previous upgrade request
// as it was already hijacked and now is being used for other connection
p.TykAPISpec.HTTPTransport.(*WSDialer).RW = rw
}

ctx := req.Context()
Expand Down

0 comments on commit 8f320ad

Please sign in to comment.