/
proxy-invoker.go
126 lines (112 loc) · 3.77 KB
/
proxy-invoker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package bifrost_rpc_access
import (
"errors"
"io"
bifrost_rpc "github.com/aperturerobotics/bifrost/rpc"
"github.com/aperturerobotics/starpc/rpcstream"
"github.com/aperturerobotics/starpc/srpc"
)
// ProxyInvoker is an srpc.Invoker that forwards each call through the proxy
// client: it asks the remote to look up the requested rpc service and then
// relays packets between the local stream and the remote.
type ProxyInvoker struct {
	// client opens the CallRpcService stream to the remote.
	client SRPCAccessRpcServiceClient
	// req is the base lookup request identifying the remote service.
	req *LookupRpcServiceRequest
	// waitAck waits for the remote's ack before starting the proxied rpc.
	waitAck bool
}
// NewProxyInvoker builds a ProxyInvoker that proxies calls over client, using
// req as the base service lookup request.
//
// When waitAck is true, each call waits for the remote to acknowledge the
// service lookup before the proxied rpc is started. This is usually not needed.
func NewProxyInvoker(client SRPCAccessRpcServiceClient, req *LookupRpcServiceRequest, waitAck bool) *ProxyInvoker {
	inv := &ProxyInvoker{
		client:  client,
		req:     req,
		waitAck: waitAck,
	}
	return inv
}
// InvokeMethod invokes the method matching the service & method ID by proxying
// the call to the remote through r.client.
//
// If serviceID is non-empty and differs from the configured request, a copy of
// the request is made with ServiceId overridden; if serviceID is empty it is
// ignored and the configured request is used as-is for the lookup.
//
// Packets are relayed in both directions until the remote completes the rpc,
// the remote reports an error, or forwarding from strm fails. A normal
// half-close by the local caller (io.EOF from strm.MsgRecv) is forwarded to
// the remote as a completion packet but does NOT end the call, so the remote's
// response is still relayed back to strm before returning.
//
// Returns false and an error if the proxied stream could not be opened or the
// call could not be started; otherwise returns true and the call's result.
func (r *ProxyInvoker) InvokeMethod(serviceID, methodID string, strm srpc.Stream) (bool, error) {
	req := r.req
	if serviceID != "" && serviceID != req.GetServiceId() {
		req = req.CloneVT()
		req.ServiceId = serviceID
	}
	componentID, err := req.MarshalComponentID()
	if err != nil {
		return false, err
	}

	// Remote will lookup the service, then return either an error or ack.
	rpcStream, err := rpcstream.OpenRpcStream(strm.Context(), r.client.CallRpcService, componentID, r.waitAck)
	if err != nil {
		return false, err
	}
	defer rpcStream.Close()

	// each packet in rpcStream is now either an Ack or a Body packet.
	// each Body packet contains a *srpc.Packet from the remote service.

	// Start the RPC with the remote
	startPkt := srpc.NewCallStartPacket(serviceID, methodID, nil, false)
	packetWriter := rpcstream.NewRpcStreamWriter(rpcStream)
	if err := packetWriter.WritePacket(startPkt); err != nil {
		return false, err
	}

	// Buffered so neither goroutine blocks on send after we return: the read
	// side sends at most twice (nil on completion, then ReadToHandler's result)
	// and the write side at most once.
	errCh := make(chan error, 3)

	// Read messages from prw -> write to invoker stream.
	go func() {
		proxyMsg := srpc.NewRawMessage(nil, false) // zero-copy mode
		// We have to handle the Packet here because srpc.Stream MsgSend will be
		// encoded and wrapped in a Body packet.
		handler := srpc.NewPacketDataHandler(func(pkt *srpc.Packet) error {
			switch body := pkt.GetBody().(type) {
			case *srpc.Packet_CallCancel:
				// unexpected from server -> client but handle anyway
				return errors.New("rpc canceled by the remote")
			case *srpc.Packet_CallData:
				data, dataIsZero := body.CallData.GetData(), body.CallData.GetDataIsZero()
				complete, errStr := body.CallData.GetComplete(), body.CallData.GetError()
				if len(data) != 0 || dataIsZero {
					proxyMsg.SetData(data)
					if err := strm.MsgSend(proxyMsg); err != nil {
						return err
					}
				}
				if errStr != "" {
					return errors.New(errStr)
				}
				if complete {
					// The remote finished successfully: report success, then
					// stop the read loop.
					errCh <- nil
					return io.EOF
				}
			}
			return nil
		})
		errCh <- rpcstream.ReadToHandler(rpcStream, handler)
	}()

	// Write messages from invoker stream -> rpc client.
	go func() {
		readMsg := srpc.NewRawMessage(nil, false) // zero-copy mode
		for {
			err := strm.MsgRecv(readMsg)
			if errors.Is(err, io.EOF) {
				// The caller finished sending (half-close). Tell the remote no
				// more data is coming, but do not end the call: the remote's
				// response still has to be relayed by the read side.
				if err := packetWriter.WritePacket(srpc.NewCallDataPacket(nil, false, true, nil)); err != nil {
					errCh <- err
				}
				return
			}
			if err == nil {
				callData := readMsg.GetData()
				err = packetWriter.WritePacket(srpc.NewCallDataPacket(callData, len(callData) == 0, false, nil))
			}
			if err != nil {
				// attempt to write the error back to the client rpc; the
				// original err is what we report, so a failure here is ignored.
				_ = packetWriter.WritePacket(srpc.NewCallDataPacket(nil, false, true, err))
				errCh <- err
				return
			}
		}
	}()

	// Wait for the first result: remote completion, remote error, or a
	// forwarding failure.
	resErr := <-errCh
	return true, resErr
}
// Compile-time checks that *ProxyInvoker implements the expected interfaces.
var _ srpc.Invoker = (*ProxyInvoker)(nil)

var _ bifrost_rpc.LookupRpcServiceValue = (*ProxyInvoker)(nil)