Skip to content

Commit

Permalink
Use extensions package
Browse files Browse the repository at this point in the history
  • Loading branch information
abursavich committed Sep 9, 2020
1 parent 87a8132 commit 8844d63
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 91 deletions.
88 changes: 40 additions & 48 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import (
"net/textproto"
"net/url"
"path/filepath"
"strconv"
"strings"

"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/extensions"
)

// AcceptOptions represents Accept's options.
Expand Down Expand Up @@ -118,9 +120,10 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
w.Header().Set("Sec-WebSocket-Protocol", subproto)
}

copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
exts, _ := extensions.Parse(r.Header.Values(extensions.Header))
copts, ok := selectDeflate(opts.CompressionMode, exts)
if ok {
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
w.Header().Set(extensions.Header, copts.String())
}

w.WriteHeader(http.StatusSwitchingProtocols)
Expand Down Expand Up @@ -230,44 +233,61 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
return ""
}

func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
func selectDeflate(mode CompressionMode, exts extensions.Extensions) (*compressionOptions, bool) {
if mode == CompressionDisabled {
return nil, false
}
for _, ext := range extensions {
switch ext.name {
for _, ext := range exts {
switch ext.Name {
case "permessage-deflate":
if copts, ok := acceptDeflate(ext, mode); ok {
if copts, ok := acceptDeflate(mode, ext.Params); ok {
return copts, true
}
}
}
return nil, false
}

func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
func acceptDeflate(mode CompressionMode, params extensions.Params) (*compressionOptions, bool) {
copts := mode.opts()
for _, p := range ext.params {
switch p {
seen := make(map[string]bool)
for _, p := range params {
if seen[p.Name] {
return nil, false
}
seen[p.Name] = true

switch p.Name {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
if p.Value == "" {
copts.clientNoContextTakeover = true
continue
}
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
case "client_max_window_bits",
"server_max_window_bits=15":
continue
}
if strings.HasPrefix(p, "client_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
if p.Value == "" {
copts.serverNoContextTakeover = true
continue
}
case "client_max_window_bits":
if p.Value == "" || isValidWindowBits(p.Value) {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
case "server_max_window_bits":
if p.Value == "15" {
continue
}
}
return nil, false
}
return copts, true
}

func isValidWindowBits(s string) bool {
i, err := strconv.Atoi(s)
return err == nil && i >= 8 && i <= 15
}

func headerContainsToken(h http.Header, key, token string) bool {
token = strings.ToLower(token)

Expand All @@ -279,34 +299,6 @@ func headerContainsToken(h http.Header, key, token string) bool {
return false
}

type websocketExtension struct {
name string
params []string
}

func websocketExtensions(h http.Header) []websocketExtension {
var exts []websocketExtension
extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
for _, extStr := range extStrs {
if extStr == "" {
continue
}

vals := strings.Split(extStr, ";")
for i := range vals {
vals[i] = strings.TrimSpace(vals[i])
}

e := websocketExtension{
name: vals[0],
params: vals[1:],
}

exts = append(exts, e)
}
return exts
}

func headerTokens(h http.Header, key string) []string {
key = textproto.CanonicalMIMEHeaderKey(key)
var tokens []string
Expand Down
34 changes: 30 additions & 4 deletions accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"
"testing"

"nhooyr.io/websocket/internal/extensions"
"nhooyr.io/websocket/internal/test/assert"
)

Expand Down Expand Up @@ -380,14 +381,40 @@ func Test_selectDeflate(t *testing.T) {
},
expOK: true,
},
{
name: "permessage-deflate/first",
mode: CompressionContextTakeover,
header: "permessage-deflate; server_no_context_takeover; client_no_context_takeover, permessage-deflate",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
{
name: "permessage-deflate/duplicate-parameter",
mode: CompressionContextTakeover,
header: "permessage-deflate; server_no_context_takeover; server_no_context_takeover",
expOK: false,
},
{
name: "permessage-deflate/duplicate-parameter/with-fallback",
mode: CompressionContextTakeover,
header: "permessage-deflate; server_no_context_takeover; server_no_context_takeover, permessage-deflate; server_no_context_takeover",
expCopts: &compressionOptions{
clientNoContextTakeover: false,
serverNoContextTakeover: true,
},
expOK: true,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow",
expOK: false,
},
{
name: "permessage-deflate/unknown-parameter",
name: "permessage-deflate/unknown-parameter/with-fallback",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
Expand All @@ -403,9 +430,8 @@ func Test_selectDeflate(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

h := http.Header{}
h.Set("Sec-WebSocket-Extensions", tc.header)
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
exts, _ := extensions.Parse([]string{tc.header})
copts, ok := selectDeflate(tc.mode, exts)
assert.Equal(t, "selected options", tc.expOK, ok)
assert.Equal(t, "compression options", tc.expCopts, copts)
})
Expand Down
66 changes: 37 additions & 29 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"time"

"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/extensions"
)

// DialOptions represents Dial's options.
Expand Down Expand Up @@ -84,12 +85,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
}

var copts *compressionOptions
if opts.CompressionMode != CompressionDisabled {
copts = opts.CompressionMode.opts()
}

resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey)
if err != nil {
return nil, resp, err
}
Expand All @@ -111,7 +107,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
}
}()

copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
copts, err := verifyServerResponse(opts, secWebSocketKey, resp)
if err != nil {
return nil, resp, err
}
Expand All @@ -132,7 +128,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
}), resp, nil
}

func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) {
if opts.HTTPClient.Timeout > 0 {
return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
}
Expand Down Expand Up @@ -161,8 +157,8 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if copts != nil {
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
if opts.CompressionMode != CompressionDisabled {
req.Header.Set("Sec-WebSocket-Extensions", opts.CompressionMode.opts().String())
}

resp, err := opts.HTTPClient.Do(req)
Expand All @@ -184,7 +180,7 @@ func secWebSocketKey(rr io.Reader) (string, error) {
return base64.StdEncoding.EncodeToString(b), nil
}

func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
Expand All @@ -209,7 +205,11 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo
return nil, err
}

return verifyServerExtensions(copts, resp.Header)
exts, err := extensions.Parse(resp.Header.Values(extensions.Header))
if err != nil {
return nil, fmt.Errorf("WebSocket protcol violation: invalid %s header: %v", extensions.Header, err)
}
return verifyServerExtensions(opts.CompressionMode, exts)
}

func verifySubprotocol(subprotos []string, resp *http.Response) error {
Expand All @@ -227,34 +227,42 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}

func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
exts := websocketExtensions(h)
func verifyServerExtensions(mode CompressionMode, exts extensions.Extensions) (*compressionOptions, error) {
if len(exts) == 0 {
return nil, nil
}

ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
if len(exts) > 1 || mode == CompressionDisabled || ext.Name != "permessage-deflate" {
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %q", exts)
}

copts = &*copts
copts := mode.opts()
seen := make(map[string]bool)
for _, p := range ext.Params {
if seen[p.Name] {
return nil, fmt.Errorf("WebSocket protcol violation: duplicate permessage-deflate extension parameter %q from server", p.Name)
}
seen[p.Name] = true
switch p.Name {

for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
if p.Value == "" {
copts.clientNoContextTakeover = true
continue
}
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
if p.Value == "" {
copts.serverNoContextTakeover = true
continue
}
case "server_max_window_bits":
if p.Value == "" || isValidWindowBits(p.Value) {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
}
if strings.HasPrefix(p, "server_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}

return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
return nil, fmt.Errorf("WebSocket protcol violation: unsupported permessage-deflate extension parameter from server: %v=%q", p.Name, p.Value)
}

return copts, nil
Expand Down
Loading

0 comments on commit 8844d63

Please sign in to comment.