Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
wsheaders: Add package
  • Loading branch information
abursavich committed Sep 12, 2020
1 parent 2da2886 commit 99a67e7
Show file tree
Hide file tree
Showing 6 changed files with 589 additions and 97 deletions.
50 changes: 21 additions & 29 deletions accept.go
Expand Up @@ -4,8 +4,6 @@ package websocket

import (
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
Expand All @@ -17,6 +15,7 @@ import (
"strings"

"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/wsheaders"
)

// AcceptOptions represents Accept's options.
Expand Down Expand Up @@ -107,11 +106,11 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
return nil, err
}

w.Header().Set("Upgrade", "websocket")
w.Header().Set("Connection", "Upgrade")
wsheaders.SetUpgrade(w.Header())
wsheaders.SetConnection(w.Header())

key := r.Header.Get("Sec-WebSocket-Key")
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
challenge, _ := wsheaders.GetChallenge(r.Header)
wsheaders.SetAccept(w.Header(), challenge)

subproto := selectSubprotocol(r, opts.Subprotocols)
if subproto != "" {
Expand Down Expand Up @@ -159,29 +158,32 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
}

if !headerContainsToken(r.Header, "Connection", "Upgrade") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
if err := wsheaders.VerifyConnection(r.Header); err != nil {
wsheaders.SetConnection(w.Header())
wsheaders.SetUpgrade(w.Header())
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: %v", err)
}

if !headerContainsToken(r.Header, "Upgrade", "websocket") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
if err := wsheaders.VerifyClientUpgrade(r.Header); err != nil {
wsheaders.SetConnection(w.Header())
wsheaders.SetUpgrade(w.Header())
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: %v", err)
}

if r.Method != "GET" {
return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
}

if r.Header.Get("Sec-WebSocket-Version") != "13" {
w.Header().Set("Sec-WebSocket-Version", "13")
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
if v, err := wsheaders.GetVersion(r.Header); v != 13 {
wsheaders.SetVersion(w.Header(), 13)
if err != nil {
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: %v", err)
}
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %d", v)
}

if r.Header.Get("Sec-WebSocket-Key") == "" {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
if _, err := wsheaders.GetChallenge(r.Header); err != nil {
return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: %v", err)
}

return 0, nil
Expand Down Expand Up @@ -320,13 +322,3 @@ func headerTokens(h http.Header, key string) []string {
}
return tokens
}

var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

func secWebSocketAccept(secWebSocketKey string) string {
h := sha1.New()
h.Write([]byte(secWebSocketKey))
h.Write(keyGUID)

return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
65 changes: 34 additions & 31 deletions accept_test.go
Expand Up @@ -12,8 +12,11 @@ import (
"testing"

"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/wsheaders"
)

const validChallenge = "dGhlIHNhbXBsZSBub25jZQ=="

func TestAccept(t *testing.T) {
t.Parallel()

Expand All @@ -32,10 +35,10 @@ func TestAccept(t *testing.T) {

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
wsheaders.SetConnection(r.Header)
wsheaders.SetUpgrade(r.Header)
wsheaders.SetVersion(r.Header, 13)
wsheaders.SetChallenge(r.Header, validChallenge)
r.Header.Set("Origin", "harhar.com")

_, err := Accept(w, r, nil)
Expand All @@ -47,10 +50,10 @@ func TestAccept(t *testing.T) {

newRequest := func(extensions string) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
wsheaders.SetConnection(r.Header)
wsheaders.SetUpgrade(r.Header)
wsheaders.SetVersion(r.Header, 13)
wsheaders.SetChallenge(r.Header, validChallenge)
r.Header.Set("Sec-WebSocket-Extensions", extensions)
return r
}
Expand Down Expand Up @@ -93,10 +96,10 @@ func TestAccept(t *testing.T) {

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
wsheaders.SetConnection(r.Header)
wsheaders.SetUpgrade(r.Header)
wsheaders.SetVersion(r.Header, 13)
wsheaders.SetChallenge(r.Header, validChallenge)

_, err := Accept(w, r, nil)
assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
Expand All @@ -113,10 +116,10 @@ func TestAccept(t *testing.T) {
}

r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
wsheaders.SetConnection(r.Header)
wsheaders.SetUpgrade(r.Header)
wsheaders.SetVersion(r.Header, 13)
wsheaders.SetChallenge(r.Header, validChallenge)

_, err := Accept(w, r, nil)
assert.Contains(t, err, `failed to hijack connection`)
Expand Down Expand Up @@ -157,37 +160,37 @@ func Test_verifyClientHandshake(t *testing.T) {
{
name: "badWebSocketVersion",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "14",
"Connection": "Upgrade",
"Upgrade": "websocket",
wsheaders.VersionKey: "14",
},
},
{
name: "badWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "",
"Connection": "Upgrade",
"Upgrade": "websocket",
wsheaders.VersionKey: "13",
wsheaders.ChallengeKey: "",
},
},
{
name: "badHTTPVersion",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "meow123",
"Connection": "Upgrade",
"Upgrade": "websocket",
wsheaders.VersionKey: "13",
wsheaders.ChallengeKey: validChallenge,
},
http1: true,
},
{
name: "success",
h: map[string]string{
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "meow123",
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
wsheaders.VersionKey: "13",
wsheaders.ChallengeKey: validChallenge,
},
success: true,
},
Expand Down
24 changes: 11 additions & 13 deletions dial.go
Expand Up @@ -19,6 +19,7 @@ import (
"time"

"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/wsheaders"
)

// DialOptions represents Dial's options.
Expand Down Expand Up @@ -154,10 +155,10 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts

req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
req.Header = opts.HTTPHeader.Clone()
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
wsheaders.SetConnection(req.Header)
wsheaders.SetUpgrade(req.Header)
wsheaders.SetVersion(req.Header, 13)
wsheaders.SetChallenge(req.Header, secWebSocketKey)
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
Expand Down Expand Up @@ -189,19 +190,16 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}

if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
if err := wsheaders.VerifyConnection(resp.Header); err != nil {
return nil, fmt.Errorf("WebSocket protocol violation: %v", err)
}

if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
if err := wsheaders.VerifyServerUpgrade(resp.Header); err != nil {
return nil, fmt.Errorf("WebSocket protocol violation: %v", err)
}

if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
secWebSocketKey,
)
if err := wsheaders.VerifyAccept(resp.Header, secWebSocketKey); err != nil {
return nil, fmt.Errorf("WebSocket protocol violation: %v", err)
}

err := verifySubprotocol(opts.Subprotocols, resp)
Expand Down
53 changes: 29 additions & 24 deletions dial_test.go
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/wsheaders"
)

func TestBadDials(t *testing.T) {
Expand Down Expand Up @@ -98,10 +99,15 @@ func TestBadDials(t *testing.T) {
defer cancel()

rt := func(r *http.Request) (*http.Response, error) {
challenge, err := wsheaders.GetChallenge(r.Header)
if err != nil {
return nil, err
}

h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
wsheaders.SetConnection(h)
wsheaders.SetUpgrade(h)
wsheaders.SetAccept(h, challenge)

return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Expand Down Expand Up @@ -143,7 +149,7 @@ func Test_verifyServerHandshake(t *testing.T) {
{
name: "badUpgrade",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
wsheaders.SetConnection(w.Header())
w.Header().Set("Upgrade", "???")
w.WriteHeader(http.StatusSwitchingProtocols)
},
Expand All @@ -152,18 +158,18 @@ func Test_verifyServerHandshake(t *testing.T) {
{
name: "badSecWebSocketAccept",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Sec-WebSocket-Accept", "xd")
wsheaders.SetConnection(w.Header())
wsheaders.SetUpgrade(w.Header())
w.Header().Set(wsheaders.AcceptKey, "xd")
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: false,
},
{
name: "badSecWebSocketProtocol",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
wsheaders.SetConnection(w.Header())
wsheaders.SetUpgrade(w.Header())
w.Header().Set("Sec-WebSocket-Protocol", "xd")
w.WriteHeader(http.StatusSwitchingProtocols)
},
Expand All @@ -172,8 +178,8 @@ func Test_verifyServerHandshake(t *testing.T) {
{
name: "unsupportedExtension",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
wsheaders.SetConnection(w.Header())
wsheaders.SetUpgrade(w.Header())
w.Header().Set("Sec-WebSocket-Extensions", "meow")
w.WriteHeader(http.StatusSwitchingProtocols)
},
Expand All @@ -182,8 +188,8 @@ func Test_verifyServerHandshake(t *testing.T) {
{
name: "unsupportedDeflateParam",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
wsheaders.SetConnection(w.Header())
wsheaders.SetUpgrade(w.Header())
w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow")
w.WriteHeader(http.StatusSwitchingProtocols)
},
Expand All @@ -192,8 +198,8 @@ func Test_verifyServerHandshake(t *testing.T) {
{
name: "success",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
wsheaders.SetConnection(w.Header())
wsheaders.SetUpgrade(w.Header())
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: true,
Expand All @@ -205,21 +211,20 @@ func Test_verifyServerHandshake(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

w := httptest.NewRecorder()
tc.response(w)
resp := w.Result()

r := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest("GET", "/", nil)
key, err := secWebSocketKey(rand.Reader)
assert.Success(t, err)
r.Header.Set("Sec-WebSocket-Key", key)
wsheaders.SetChallenge(req.Header, key)

if resp.Header.Get("Sec-WebSocket-Accept") == "" {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
w := httptest.NewRecorder()
tc.response(w)
resp := w.Result()
if resp.Header.Get(wsheaders.AcceptKey) == "" {
wsheaders.SetAccept(resp.Header, key)
}

opts := &DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
Subprotocols: strings.Split(req.Header.Get("Sec-WebSocket-Protocol"), ","),
}
_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
if tc.success {
Expand Down

0 comments on commit 99a67e7

Please sign in to comment.