-
Notifications
You must be signed in to change notification settings - Fork 649
/
writer_server.go
122 lines (101 loc) · 3.23 KB
/
writer_server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package gresponsewriter
import (
"context"
"errors"
"net/http"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/ava-labs/avalanchego/vms/rpcchainvm/ghttp/gconn"
"github.com/ava-labs/avalanchego/vms/rpcchainvm/ghttp/greader"
"github.com/ava-labs/avalanchego/vms/rpcchainvm/ghttp/gwriter"
"github.com/ava-labs/avalanchego/vms/rpcchainvm/grpcutils"
responsewriterpb "github.com/ava-labs/avalanchego/proto/pb/http/responsewriter"
readerpb "github.com/ava-labs/avalanchego/proto/pb/io/reader"
writerpb "github.com/ava-labs/avalanchego/proto/pb/io/writer"
connpb "github.com/ava-labs/avalanchego/proto/pb/net/conn"
)
// Errors returned when the wrapped http.ResponseWriter lacks an optional
// capability that a remote caller asked for.
var (
	// errUnsupportedFlushing is returned by Flush when the wrapped writer
	// does not implement http.Flusher.
	errUnsupportedFlushing = errors.New("response writer doesn't support flushing")
	// errUnsupportedHijacking is returned by Hijack when the wrapped writer
	// does not implement http.Hijacker.
	errUnsupportedHijacking = errors.New("response writer doesn't support hijacking")
)

// Compile-time assertion that *Server implements responsewriterpb.WriterServer.
var _ responsewriterpb.WriterServer = (*Server)(nil)
// Server exposes a local http.ResponseWriter over gRPC: it implements
// responsewriterpb.WriterServer, and its RPC methods (Write, WriteHeader,
// Flush, Hijack) operate on the wrapped writer.
type Server struct {
	// Embedding the generated Unsafe* interface (instead of an Unimplemented*
	// stub) opts out of protoc-gen-go-grpc forward compatibility: a new RPC
	// added to the service becomes a compile error here until implemented.
	responsewriterpb.UnsafeWriterServer
	// writer is the local response writer that every RPC call acts upon.
	writer http.ResponseWriter
}
// NewServer wraps writer in a *Server so that it can be registered as a
// responsewriterpb.WriterServer and driven by a remote client.
func NewServer(writer http.ResponseWriter) *Server {
	srv := &Server{}
	srv.writer = writer
	return srv
}
// Write replaces the wrapped writer's response headers with req.Headers and
// then writes req.Payload to it. On success the number of bytes written is
// reported in WriteResponse.Written; a write error is returned unchanged.
//
// Per net/http semantics, header changes made after the response header has
// already been sent have no effect on the wire (trailers excepted).
func (s *Server) Write(
	_ context.Context,
	req *responsewriterpb.WriteRequest,
) (*responsewriterpb.WriteResponse, error) {
	// The request's headers replace the current set wholesale, so existing
	// entries are dropped before the new ones are copied in.
	hdr := s.writer.Header()
	maps.Clear(hdr)
	for i := range req.Headers {
		elem := req.Headers[i]
		hdr[elem.Key] = elem.Values
	}

	written, writeErr := s.writer.Write(req.Payload)
	if writeErr != nil {
		return nil, writeErr
	}
	resp := &responsewriterpb.WriteResponse{Written: int32(written)}
	return resp, nil
}
// WriteHeader replaces the wrapped writer's response headers with req.Headers
// and then calls WriteHeader on it with req.StatusCode, after passing the code
// through grpcutils.EnsureValidResponseCode (presumably to map invalid codes
// to an acceptable one — confirm against grpcutils).
func (s *Server) WriteHeader(
	_ context.Context,
	req *responsewriterpb.WriteHeaderRequest,
) (*emptypb.Empty, error) {
	// Start from an empty header map so only the caller's headers are sent.
	hdr := s.writer.Header()
	maps.Clear(hdr)
	for i := range req.Headers {
		hdr[req.Headers[i].Key] = req.Headers[i].Values
	}

	status := grpcutils.EnsureValidResponseCode(int(req.StatusCode))
	s.writer.WriteHeader(status)
	return &emptypb.Empty{}, nil
}
// Flush calls Flush on the wrapped writer. It returns errUnsupportedFlushing
// when the wrapped writer does not implement http.Flusher.
func (s *Server) Flush(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
	if f, canFlush := s.writer.(http.Flusher); canFlush {
		f.Flush()
		return &emptypb.Empty{}, nil
	}
	return nil, errUnsupportedFlushing
}
// Hijack takes over the connection underlying the wrapped writer and makes it
// reachable through a newly started gRPC server. That server registers Conn,
// Reader and Writer services backed by the hijacked connection. The returned
// HijackResponse reports the connection's local/remote network and address
// strings, plus the address the new gRPC server listens on.
//
// It returns errUnsupportedHijacking when the wrapped writer does not
// implement http.Hijacker, or any error from Hijack or from creating the
// listener.
func (s *Server) Hijack(context.Context, *emptypb.Empty) (*responsewriterpb.HijackResponse, error) {
	hijacker, ok := s.writer.(http.Hijacker)
	if !ok {
		return nil, errUnsupportedHijacking
	}
	conn, readWriter, err := hijacker.Hijack()
	if err != nil {
		return nil, err
	}

	serverListener, err := grpcutils.NewListener()
	if err != nil {
		// After a successful Hijack the HTTP server no longer manages conn,
		// so it must be closed here or the client connection leaks. The
		// listener error is the one worth reporting; a close failure adds
		// nothing actionable.
		_ = conn.Close()
		return nil, err
	}

	server := grpcutils.NewServer()
	closer := grpcutils.ServerCloser{}
	closer.Add(server)
	connpb.RegisterConnServer(server, gconn.NewServer(conn, &closer))
	// Reads and writes go through readWriter rather than conn: its
	// bufio.Reader may already hold bytes the HTTP server read from the
	// client before the hijack.
	readerpb.RegisterReaderServer(server, greader.NewServer(readWriter))
	writerpb.RegisterWriterServer(server, gwriter.NewServer(readWriter))
	// NOTE(review): this goroutine is presumably stopped through closer once
	// the remote side closes the connection via the Conn service — confirm.
	go grpcutils.Serve(serverListener, server)

	local := conn.LocalAddr()
	remote := conn.RemoteAddr()
	return &responsewriterpb.HijackResponse{
		LocalNetwork:  local.Network(),
		LocalString:   local.String(),
		RemoteNetwork: remote.Network(),
		RemoteString:  remote.String(),
		ServerAddr:    serverListener.Addr().String(),
	}, nil
}