-
Notifications
You must be signed in to change notification settings - Fork 947
/
httpstream.go
246 lines (213 loc) · 8.14 KB
/
httpstream.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
package portforward
import (
"context"
"fmt"
"net/http"
"strconv"
"sync"
"time"
"github.com/alibaba/pouch/cri/stream/constant"
"github.com/alibaba/pouch/cri/stream/httpstream"
"github.com/alibaba/pouch/cri/stream/httpstream/spdy"
"github.com/alibaba/pouch/pkg/collect"
"github.com/alibaba/pouch/pkg/log"
)
// httpStreamReceived returns the httpstream.NewStreamHandler used for port
// forward streams. The returned handler inspects the port and stream type
// headers of every incoming stream and returns an error for any stream whose
// headers are missing or invalid. Every stream that passes validation is
// sent to the streams channel.
func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
	return func(stream httpstream.Stream, replySent <-chan struct{}) error {
		// The port header must be present, parse as an unsigned 16-bit
		// decimal number, and be non-zero.
		rawPort := stream.Headers().Get(constant.PortHeader)
		if rawPort == "" {
			return fmt.Errorf("%q header is required", constant.PortHeader)
		}
		portNum, parseErr := strconv.ParseUint(rawPort, 10, 16)
		switch {
		case parseErr != nil:
			return fmt.Errorf("unable to parse %q as a port: %v", rawPort, parseErr)
		case portNum < 1:
			return fmt.Errorf("port %q must be > 0", rawPort)
		}

		// The stream type header must be present and name either the error
		// stream or the data stream.
		kind := stream.Headers().Get(constant.StreamType)
		if kind == "" {
			return fmt.Errorf("%q header is required", constant.StreamType)
		}
		if !(kind == constant.StreamTypeError || kind == constant.StreamTypeData) {
			return fmt.Errorf("invalid stream type %q", kind)
		}

		streams <- stream
		return nil
	}
}
// handleHTTPStreams negotiates the port forward protocol with the client,
// upgrades the response to a streaming connection, and then runs an
// httpStreamHandler on that connection until its run loop returns.
//
// It returns the error from httpstream.Handshake if the handshake fails, an
// error if the upgrade yields a nil connection, and nil otherwise. The
// upgraded connection is closed before the function returns.
func handleHTTPStreams(ctx context.Context, w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, idleTimeout, streamCreationTimeout time.Duration, supportedPortForwardProtocols []string) error {
	// The negotiated protocol isn't currently used server side, but could be
	// in the future. On failure Handshake writes the error to the client, so
	// the error is only propagated here.
	if _, err := httpstream.Handshake(w, req, supportedPortForwardProtocols); err != nil {
		return err
	}

	incoming := make(chan httpstream.Stream, 1)

	log.With(ctx).Infof("upgrading port forward response")
	conn := spdy.NewResponseUpgrader().UpgradeResponse(w, req, httpStreamReceived(incoming))
	if conn == nil {
		return fmt.Errorf("unable to upgrade connection")
	}
	defer conn.Close()

	log.With(ctx).Infof("setting forwarding streaming connection idle timeout to %v", idleTimeout)
	conn.SetIdleTimeout(idleTimeout)

	handler := &httpStreamHandler{
		conn:                  conn,
		streamChan:            incoming,
		streamPairs:           collect.NewSafeMap(),
		streamCreationTimeout: streamCreationTimeout,
		pod:                   podName,
		forwarder:             portForwarder,
	}
	handler.run(ctx)

	return nil
}
// httpStreamHandler is capable of processing multiple port forward
// requests over a single httpstream.Connection.
type httpStreamHandler struct {
// conn is the upgraded connection; run exits when its CloseChan fires.
conn httpstream.Connection
// streamChan delivers the streams that run consumes.
streamChan chan httpstream.Stream
// streamPairs maps a request ID to its *httpStreamPair.
streamPairs *collect.SafeMap
// streamCreationTimeout is how long monitorStreamPair waits for a newly
// created pair to become complete.
streamCreationTimeout time.Duration
// pod is passed to forwarder.PortForward for every request.
pod string
// forwarder is invoked by portForward for each complete stream pair.
forwarder PortForwarder
}
// getStreamPair looks up the httpStreamPair registered for requestID and
// returns it. When no pair is registered yet, a new one is created, stored
// under requestID, and returned instead. The returned bool reports whether
// the pair was created by this call.
func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) {
	existing, found := h.streamPairs.Get(requestID).Result()
	if !found {
		log.With(nil).Infof("portforward of cri: creating new stream pair for request %s", requestID)
		fresh := newPortForwardPair(requestID)
		h.streamPairs.Put(requestID, fresh)
		return fresh, true
	}

	log.With(nil).Infof("portforward of cri: found existing stream pair for request %s", requestID)
	return existing.(*httpStreamPair), false
}
// removeStreamPair removes the stream pair identified by requestID from
// h.streamPairs. monitorStreamPair calls it once a pair has either completed
// or timed out.
func (h *httpStreamHandler) removeStreamPair(requestID string) {
h.streamPairs.Remove(requestID)
}
// monitorStreamPair blocks until either p.complete is closed (both the error
// and data streams have arrived) or the timeout channel fires, whichever
// happens first. On timeout a message is written to the pair's error stream,
// if one is set. In both cases the pair is then removed from the handler.
func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) {
	select {
	case <-p.complete:
		log.With(nil).Infof("portforward of cri: successfully received error and data streams of request %s", p.requestID)
	case <-timeout:
		p.printError(fmt.Sprintf("portforward of cri: timed out waiting for streams of request %s", p.requestID))
	}

	h.removeStreamPair(p.requestID)
}
// requestID returns the value of the port forward request ID header of
// stream. The value is returned as is, including when the header is empty.
func (h *httpStreamHandler) requestID(stream httpstream.Stream) string {
	// TODO: support the connection come from the older client
	// that isn't generating the request id header.
	return stream.Headers().Get(constant.PortForwardRequestIDHeader)
}
// run is the main loop for the httpStreamHandler. It processes new streams
// as they arrive on h.streamChan, grouping them into pairs by request ID and
// invoking portForward in its own goroutine for each pair that becomes
// complete. A monitor goroutine is started for every newly created pair.
// The loop exits when the httpstream.Connection is closed.
func (h *httpStreamHandler) run(ctx context.Context) {
	log.With(ctx).Infof("portforward of cri: waiting for streams")

	for {
		select {
		case <-h.conn.CloseChan():
			log.With(ctx).Infof("portforward of cri: upgraded connection closed")
			return

		case incoming := <-h.streamChan:
			id := h.requestID(incoming)
			log.With(ctx).Infof("portForward of cri: received new stream of type %s, request %s", incoming.Headers().Get(constant.StreamType), id)

			pair, isNew := h.getStreamPair(id)
			if isNew {
				// Bound how long the pair may wait for its second stream.
				go h.monitorStreamPair(pair, time.After(h.streamCreationTimeout))
			}

			ready, err := pair.add(incoming)
			switch {
			case err != nil:
				pair.printError(fmt.Sprintf("portforward of cri: error processing stream for request %s: %v", id, err))
			case ready:
				go h.portForward(ctx, pair)
			}
		}
	}
}
// portForward invokes the httpStreamHandler's forwarder.PortForward
// function for the given stream pair. The target port is read from the port
// header of the pair's data stream. Both streams of the pair are closed when
// the function returns.
//
// Any failure, whether parsing the port header or forwarding, is reported to
// the client through the pair's error stream.
func (h *httpStreamHandler) portForward(ctx context.Context, p *httpStreamPair) {
	defer p.dataStream.Close()
	defer p.errorStream.Close()

	portString := p.dataStream.Headers().Get(constant.PortHeader)
	port, err := strconv.ParseInt(portString, 10, 32)
	if err != nil {
		// Do not forward to a bogus port: on a parse failure ParseInt
		// returns 0 or a clamped value, neither of which was requested.
		msg := fmt.Sprintf("portforward of cri: unable to parse %q as a port: %v", portString, err)
		p.printError(msg)
		return
	}

	log.With(ctx).Infof("portforward of cri: invoking forwarder.PortForward for port %s of request %s", portString, p.requestID)
	err = h.forwarder.PortForward(ctx, h.pod, int32(port), p.dataStream)
	log.With(ctx).Infof("portforward of cri: done invoking forwarder.PortForward for port %s of request %s", portString, p.requestID)

	if err != nil {
		msg := fmt.Sprintf("portforward of cri: error forwarding port %d to pod %s: %v", port, h.pod, err)
		p.printError(msg)
	}
}
// httpStreamPair represents the error and data streams for a port
// forwarding request.
type httpStreamPair struct {
// lock is write-locked by add while it assigns the streams and read-locked
// by printError while it writes to errorStream.
lock sync.RWMutex
// requestID identifies the port forwarding request this pair belongs to.
requestID string
// dataStream is nil until add receives a stream of type StreamTypeData.
dataStream httpstream.Stream
// errorStream is nil until add receives a stream of type StreamTypeError.
errorStream httpstream.Stream
// complete is closed by add once both streams have been assigned.
complete chan struct{}
}
// newPortForwardPair returns a new httpStreamPair for requestID. The pair
// starts with neither stream assigned and with an open complete channel.
func newPortForwardPair(requestID string) *httpStreamPair {
	pair := new(httpStreamPair)
	pair.requestID = requestID
	pair.complete = make(chan struct{})
	return pair
}
// add stores stream in the slot of the httpStreamPair that matches the
// stream's type header. An error is returned if that slot is already
// occupied. The returned bool is true once both the data and error streams
// of this pair have been received; at that point p.complete is closed.
func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) {
	p.lock.Lock()
	defer p.lock.Unlock()

	kind := stream.Headers().Get(constant.StreamType)
	switch {
	case kind == constant.StreamTypeError:
		if p.errorStream != nil {
			return false, fmt.Errorf("error stream already assigned")
		}
		p.errorStream = stream
	case kind == constant.StreamTypeData:
		if p.dataStream != nil {
			return false, fmt.Errorf("data stream already assigned")
		}
		p.dataStream = stream
	}

	// The pair is still waiting for at least one of its streams.
	if p.errorStream == nil || p.dataStream == nil {
		return false, nil
	}

	close(p.complete)
	return true, nil
}
// printError writes s to p.errorStream. It does nothing when no error stream
// has been assigned to the pair yet.
func (p *httpStreamPair) printError(s string) {
	p.lock.RLock()
	defer p.lock.RUnlock()

	if p.errorStream == nil {
		return
	}
	fmt.Fprint(p.errorStream, s)
}