Skip to content

Commit

Permalink
allow api spec to keep both transports to http and ws upstreams
Browse files Browse the repository at this point in the history
  • Loading branch information
dencoded committed Jul 13, 2018
1 parent 863fccd commit 7c149d9
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
sudo: false
sudo: required
language: go

notifications:
Expand All @@ -14,7 +14,7 @@ addons:
sources:
- sourceline: 'ppa:opencpu/jq'
packages:
- python3-dev
- python3.5
- python3-pip
- libluajit-5.1-dev
- libjq-dev
Expand Down
2 changes: 2 additions & 0 deletions api_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ type APISpec struct {
ServiceRefreshInProgress bool
HTTPTransport http.RoundTripper
HTTPTransportCreated time.Time
WSTransport http.RoundTripper
WSTransportCreated time.Time
GlobalConfig config.Config
}

Expand Down
88 changes: 88 additions & 0 deletions gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,94 @@ func TestWebsocketsSeveralOpenClose(t *testing.T) {
conn3.Close()
}

func TestWebsocketsAndHTTPEndpointMatch(t *testing.T) {
globalConf := config.Global()
globalConf.HttpServerOptions.EnableWebSockets = true
config.SetGlobal(globalConf)
defer resetTestConfig()

ts := newTykTestServer()
defer ts.Close()

buildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = "/"
})

baseURL := strings.Replace(ts.URL, "http://", "ws://", -1)

// connect to ws, send 1st message and check reply
wsConn, _, err := websocket.DefaultDialer.Dial(baseURL+"/ws", nil)
if err != nil {
t.Fatalf("cannot make websocket connection: %v", err)
}
err = wsConn.WriteMessage(websocket.BinaryMessage, []byte("test message 1"))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}
_, p, err := wsConn.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v", err)
}
if string(p) != "reply to message: test message 1" {
t.Error("Unexpected reply:", string(p))
}

// make 1st http request
ts.Run(t, test.TestCase{
Method: "GET",
Path: "/abc",
Code: http.StatusOK,
})

// send second WS connection upgrade request
// connect to ws, send 1st message and check reply
wsConn2, _, err := websocket.DefaultDialer.Dial(baseURL+"/ws", nil)
if err != nil {
t.Fatalf("cannot make websocket connection: %v", err)
}
err = wsConn2.WriteMessage(websocket.BinaryMessage, []byte("test message 1 to ws 2"))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}
_, p, err = wsConn2.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v", err)
}
if string(p) != "reply to message: test message 1 to ws 2" {
t.Error("Unexpected reply:", string(p))
}
wsConn2.Close()

// send second message to WS and check reply
err = wsConn.WriteMessage(websocket.BinaryMessage, []byte("test message 2"))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}
_, p, err = wsConn.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v", err)
}
if string(p) != "reply to message: test message 2" {
t.Error("Unexpected reply:", string(p))
}

// make 2nd http request
ts.Run(t, test.TestCase{
Method: "GET",
Path: "/abc",
Code: http.StatusOK,
})

wsConn.Close()

// make 3d http request after closing WS connection
ts.Run(t, test.TestCase{
Method: "GET",
Path: "/abc",
Code: http.StatusOK,
})
}

func createTestUptream(t *testing.T, allowedConns int, readsPerConn int) net.Listener {
l, _ := net.Listen("tcp", "127.0.0.1:0")
go func() {
Expand Down
56 changes: 39 additions & 17 deletions reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,24 +483,47 @@ func httpTransport(timeOut int, rw http.ResponseWriter, req *http.Request, p *Re
}

func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Request, withCache bool) *http.Response {
// 1. Check if timeouts are set for this endpoint
outReqIsWebsocket := IsWebsocket(req)

var roundTripper http.RoundTripper

p.TykAPISpec.Lock()
if !outReqIsWebsocket { // check if it is a regular HTTP request
// create HTTP transport
createTransport := p.TykAPISpec.HTTPTransport == nil

// Check if timeouts are set for this endpoint
if !createTransport && config.Global().MaxConnTime != 0 {
createTransport = time.Since(p.TykAPISpec.HTTPTransportCreated) > time.Duration(config.Global().MaxConnTime)*time.Second
}

createTransport := p.TykAPISpec.HTTPTransport == nil
if createTransport {
_, timeout := p.CheckHardTimeoutEnforced(p.TykAPISpec, req)
p.TykAPISpec.HTTPTransport = httpTransport(timeout, rw, req, p)
p.TykAPISpec.HTTPTransportCreated = time.Now()
}

if !createTransport && config.Global().MaxConnTime != 0 {
createTransport = time.Since(p.TykAPISpec.HTTPTransportCreated) > time.Duration(config.Global().MaxConnTime)*time.Second
}
roundTripper = p.TykAPISpec.HTTPTransport
} else { // this is NEW WS-connection upgrade request
// create WS transport
createTransport := p.TykAPISpec.WSTransport == nil

// Check if timeouts are set for this endpoint
if !createTransport && config.Global().MaxConnTime != 0 {
createTransport = time.Since(p.TykAPISpec.WSTransportCreated) > time.Duration(config.Global().MaxConnTime)*time.Second
}

if createTransport {
_, timeout := p.CheckHardTimeoutEnforced(p.TykAPISpec, req)
p.TykAPISpec.WSTransport = httpTransport(timeout, rw, req, p)
p.TykAPISpec.WSTransportCreated = time.Now()
}

if createTransport {
_, timeout := p.CheckHardTimeoutEnforced(p.TykAPISpec, req)
p.TykAPISpec.HTTPTransport = httpTransport(timeout, rw, req, p)
p.TykAPISpec.HTTPTransportCreated = time.Now()
} else if IsWebsocket(req) { // check if it is an upgrade request to NEW WS-connection
// overwrite transport's ResponseWriter from previous upgrade request
// as it was already hijacked and now is being used for other connection
p.TykAPISpec.HTTPTransportCreated = time.Now()
p.TykAPISpec.HTTPTransport.(*WSDialer).RW = rw
p.TykAPISpec.WSTransport.(*WSDialer).RW = rw

roundTripper = p.TykAPISpec.WSTransport
}
p.TykAPISpec.Unlock()

Expand Down Expand Up @@ -550,7 +573,6 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques
outreq.Close = false

log.Debug("Outbound Request: ", outreq.URL.String())
outReqIsWebsocket := IsWebsocket(outreq)

// Do not modify outbound request headers if they are WS
if !outReqIsWebsocket {
Expand Down Expand Up @@ -590,9 +612,9 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques

p.TykAPISpec.Lock()
if outReqIsWebsocket {
p.TykAPISpec.HTTPTransport.(*WSDialer).TLSClientConfig.Certificates = tlsCertificates
roundTripper.(*WSDialer).TLSClientConfig.Certificates = tlsCertificates
} else {
p.TykAPISpec.HTTPTransport.(*http.Transport).TLSClientConfig.Certificates = tlsCertificates
roundTripper.(*http.Transport).TLSClientConfig.Certificates = tlsCertificates
}
p.TykAPISpec.Unlock()

Expand All @@ -605,14 +627,14 @@ func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Reques
p.ErrorHandler.HandleError(rw, logreq, "Service temporarily unnavailable.", 503)
return nil
}
res, err = p.TykAPISpec.HTTPTransport.RoundTrip(outreq)
res, err = roundTripper.RoundTrip(outreq)
if err != nil || res.StatusCode == http.StatusInternalServerError {
breakerConf.CB.Fail()
} else {
breakerConf.CB.Success()
}
} else {
res, err = p.TykAPISpec.HTTPTransport.RoundTrip(outreq)
res, err = roundTripper.RoundTrip(outreq)
}

if err != nil {
Expand Down

0 comments on commit 7c149d9

Please sign in to comment.