/
protocol.go
97 lines (78 loc) · 2.56 KB
/
protocol.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Package transportws implements the subscriptions-transport-ws protocol.
package transportws
import (
"context"
"net/http"
"time"
"github.com/99designs/gqlgen/graphql"
"github.com/Desuuuu/gqlgenws/internal/util"
"github.com/Desuuuu/gqlgenws/wsutil"
"nhooyr.io/websocket"
)
const (
	// ProtocolName is the WebSocket subprotocol name used by the
	// subscriptions-transport-ws protocol.
	ProtocolName = "graphql-ws"

	// defaultInitTimeout is the InitTimeout applied when none is configured.
	defaultInitTimeout = 3 * time.Second
)
// Protocol is an implementation of the subscriptions-transport-ws protocol,
// documented at https://github.com/apollographql/subscriptions-transport-ws.
//
// It is usable both as a gqlgenws.Protocol and directly as a gqlgen transport.
type Protocol struct {
	// InitFunc, when set, is invoked once the "connection_init" message has
	// been received. It is given the HTTP request of the WebSocket handshake
	// and the payload of that message.
	//
	// A non-nil returned Context is handed to GraphQL resolvers; the
	// connection is closed as soon as that Context is done.
	//
	// A non-nil returned ObjectPayload becomes the payload of the
	// "connection_ack" message.
	//
	// Returning a non-nil error closes the connection.
	//
	// A nil InitFunc accepts every connection.
	InitFunc func(*http.Request, wsutil.ObjectPayload) (context.Context, wsutil.ObjectPayload, error)

	// InitTimeout bounds how long the server waits for the "connection_init"
	// message; the connection is closed when it elapses.
	//
	// Defaults to 3 seconds.
	InitTimeout time.Duration

	// KeepAliveInterval, when set, causes a "ka" message to be sent whenever
	// no message has been received for that long.
	KeepAliveInterval time.Duration

	// AcceptOptions holds the options applied to the WebSocket handshake.
	AcceptOptions websocket.AcceptOptions
}
// Compile-time check that *Protocol satisfies graphql.Transport.
var _ graphql.Transport = (*Protocol)(nil)
// Supports reports whether this protocol is willing to handle r.
//
// Requests that wsutil.IsUpgrade does not recognize are rejected. An upgrade
// request carrying no Sec-WebSocket-Protocol header is accepted; otherwise the
// header has to contain ProtocolName.
func (*Protocol) Supports(r *http.Request) (res bool) {
	const subprotocolHeader = "Sec-WebSocket-Protocol"

	switch {
	case !wsutil.IsUpgrade(r):
		return false
	case !util.HasHeader(r.Header, subprotocolHeader):
		// No subprotocol requested by the client: accept.
		return true
	default:
		return util.HeaderContains(r.Header, subprotocolHeader, ProtocolName)
	}
}
// Do performs the WebSocket handshake for r and, on success, hands the
// resulting connection to Run.
//
// The handshake uses p.AcceptOptions. When no subprotocols are configured
// there, ProtocolName is offered as the only subprotocol. If the handshake
// fails, Do returns without calling Run.
func (p *Protocol) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
	// Apply the default to a per-call copy. Do may be invoked concurrently
	// for several requests, so writing the default back into the shared
	// Protocol would be a data race and would silently alter the caller's
	// configuration.
	opts := p.AcceptOptions
	if len(opts.Subprotocols) == 0 {
		opts.Subprotocols = []string{ProtocolName}
	}

	c, err := websocket.Accept(w, r, &opts)
	if err != nil {
		// Nothing more to do here; per its documentation websocket.Accept
		// writes the HTTP error response itself.
		return
	}

	p.Run(r, c, exec)
}
// Name returns ProtocolName, the subprotocol name this Protocol handles.
func (*Protocol) Name() string {
	return ProtocolName
}
// Run serves the protocol on the already-established WebSocket connection c.
//
// r is the HTTP request of the handshake; its Context becomes the connection's
// context. exec is the executor handed to the connection. Run returns once the
// connection's run loop has finished and the connection has been closed with
// the loop's result.
//
// A non-positive InitTimeout is replaced by the 3 second default before the
// connection starts.
func (p *Protocol) Run(r *http.Request, c *websocket.Conn, exec graphql.GraphExecutor) {
	// NOTE(review): this stores the default on the shared Protocol; calls to
	// Run on the same Protocol from several goroutines would all perform this
	// write — confirm whether that needs synchronization.
	if p.InitTimeout <= 0 {
		p.InitTimeout = defaultInitTimeout
	}

	session := &connection{
		protocol: p,
		conn:     c,
		req:      r,
		ctx:      r.Context(),
		exec:     exec,
	}

	session.close(session.run())
}