/
client.go
154 lines (139 loc) · 3.7 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package codec
import (
"bufio"
"github.com/alilestera/tinyrpc/pkg/compressor"
"github.com/alilestera/tinyrpc/pkg/header"
"github.com/alilestera/tinyrpc/pkg/serializer"
"hash/crc32"
"io"
"net/rpc"
"sync"
)
// clientCodec is the client side of the codec; NewClientCodec hands it out as
// an rpc.ClientCodec.
//
// Requests are written as a header frame followed by a body that has been
// serialized and then compressed. pending maps a request's sequence number to
// its service method: WriteRequest adds entries and ReadResponseHeader looks
// them up and removes them, both while holding mutex.
type clientCodec struct {
	// Stream endpoints: reads come from r, writes go to w, and c is closed by Close.
	r io.Reader
	w io.Writer
	c io.Closer

	// Encoding configuration fixed at construction time.
	compressor compressor.CompressType // raw, gzip, snappy, zlib
	serializer serializer.Serializer

	// Header of the response currently being decoded; ReadResponseHeader
	// fills it and ReadResponseBody reads the length/checksum/compression from it.
	response header.ResponseHeader

	mutex   sync.Mutex        // guards pending
	pending map[uint64]string // request Seq -> ServiceMethod
}
// WriteRequest Write the rpc request header and body to the io stream
// WriteRequest serializes param, compresses it with the codec's compressor and
// writes the request header and body to the stream.
//
// The request's Seq -> ServiceMethod mapping is recorded in c.pending so that
// ReadResponseHeader can fill in ServiceMethod for the matching response. The
// entry is only added once the body has been encoded successfully. It is
// removed again if sending fails, so failed calls do not leak pending entries.
//
// Returns NotFoundCompressorError if the configured compressor is not
// registered in compressor.Compressors. Otherwise it returns any error from
// the serializer, the compressor, or the underlying write.
func (c *clientCodec) WriteRequest(r *rpc.Request, param any) error {
	// Resolve the compressor once, before any bookkeeping, so an unknown
	// compress type leaves no state behind.
	comp, ok := compressor.Compressors[c.compressor]
	if !ok {
		return NotFoundCompressorError
	}

	reqBody, err := c.serializer.Marshal(param)
	if err != nil {
		return err
	}

	compressedReqBody, err := comp.Zip(reqBody)
	if err != nil {
		return err
	}

	// Register before sending: ReadResponseHeader consults pending under the
	// same mutex and may see the response as soon as the request is written.
	c.mutex.Lock()
	c.pending[r.Seq] = r.ServiceMethod
	c.mutex.Unlock()

	err = c.buildSendRequest(r, compressedReqBody)
	if err != nil {
		// The caller receives the error for this call; drop its entry instead
		// of leaving it in pending forever.
		c.mutex.Lock()
		delete(c.pending, r.Seq)
		c.mutex.Unlock()
		return err
	}
	return nil
}
// buildSendRequest build request and send frame
// reqBody is compressed
// buildSendRequest fills a pooled request header for r and writes it as a
// frame, then writes reqBody and flushes the writer.
//
// reqBody must already be compressed. The header's RequestLen and Checksum
// (CRC-32, IEEE polynomial) are computed over these compressed bytes.
//
// The Flush error is returned. With a buffered writer, Flush is the point
// where bytes actually reach the connection, so ignoring it would report a
// failed send as success. Writers without a Flush method need no flushing.
func (c *clientCodec) buildSendRequest(r *rpc.Request, reqBody []byte) error {
	h := header.RequestPool.Get().(*header.RequestHeader)
	defer func() {
		// Reset before handing the header back to the pool for reuse.
		h.ResetHeader()
		header.RequestPool.Put(h)
	}()

	h.ID = r.Seq
	h.Method = r.ServiceMethod
	h.RequestLen = uint32(len(reqBody))
	h.CompressType = c.compressor
	h.Checksum = crc32.ChecksumIEEE(reqBody)

	if err := sendFrame(c.w, h.Marshal()); err != nil {
		return err
	}
	if err := write(c.w, reqBody); err != nil {
		return err
	}

	// Checked assertion: no panic for non-flushing writers, and the
	// flush error propagates to the caller.
	if f, ok := c.w.(interface{ Flush() error }); ok {
		return f.Flush()
	}
	return nil
}
// ReadResponseHeader read the rpc response header from the io stream
// ReadResponseHeader reads the next response header frame from the stream
// into c.response and copies its ID and Error into r.
//
// r.ServiceMethod is taken from the pending table entry recorded by
// WriteRequest for that ID, and the entry is then removed. An unknown ID
// yields an empty ServiceMethod. Errors from reading the frame or from
// unmarshalling the header are returned unchanged.
func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error {
	// Start from a clean header so no fields survive from the previous response.
	c.response.ResetHeader()

	frame, err := recvFrame(c.r)
	if err != nil {
		return err
	}
	if err = c.response.Unmarshal(frame); err != nil {
		return err
	}

	seq := c.response.ID

	c.mutex.Lock()
	method := c.pending[seq]
	delete(c.pending, seq)
	c.mutex.Unlock()

	r.Seq = seq
	r.Error = c.response.Error
	r.ServiceMethod = method
	return nil
}
// ReadResponseBody read the rpc response body from the io stream
// ReadResponseBody reads the body belonging to the header most recently read
// by ReadResponseHeader and decodes it into param.
//
// If param is nil, ResponseLen bytes are read and discarded so the next header
// read starts after this body.
//
// Otherwise the body is read in full and processed in three steps:
//  1. It is checked against the header checksum (CRC-32, IEEE); a zero
//     checksum is not verified.
//  2. The response's compress type must equal the codec's.
//  3. It is decompressed and unmarshalled into param.
//
// Possible errors:
//   - UnexpectedChecksumError on a checksum mismatch.
//   - CompressorTypeMismatchError when the compress types differ.
//   - NotFoundCompressorError when no compressor is registered for the type.
//   - Any read, unzip or unmarshal error, unchanged.
func (c *clientCodec) ReadResponseBody(param any) error {
	if param == nil {
		// Discard the excess body bytes.
		if c.response.ResponseLen != 0 {
			if err := read(c.r, make([]byte, c.response.ResponseLen)); err != nil {
				return err
			}
		}
		return nil
	}

	// Consume the whole body before any validation, so that a rejected body
	// is still removed from the stream.
	respBody := make([]byte, c.response.ResponseLen)
	if err := read(c.r, respBody); err != nil {
		return err
	}

	if c.response.Checksum != 0 && crc32.ChecksumIEEE(respBody) != c.response.Checksum {
		return UnexpectedChecksumError
	}

	if c.response.GetCompressType() != c.compressor {
		return CompressorTypeMismatchError
	}

	// Checked lookup, mirroring WriteRequest. Indexing a missing key would
	// hand back the map's zero value and call Unzip on it.
	comp, ok := compressor.Compressors[c.compressor]
	if !ok {
		return NotFoundCompressorError
	}

	resp, err := comp.Unzip(respBody)
	if err != nil {
		return err
	}
	return c.serializer.Unmarshal(resp, param)
}
// Close closes the connection held in c.c and returns its error.
// It does not flush c.w beforehand.
func (c *clientCodec) Close() error {
	conn := c.c
	return conn.Close()
}
// NewClientCodec Create a new client codec
// NewClientCodec returns an rpc.ClientCodec that talks over conn.
//
// Request bodies are encoded with serializer and compressed with compressType.
// Responses must use the same compress type. Reads and writes go through bufio
// wrappers around conn, while the codec's Close method closes conn directly.
func NewClientCodec(conn io.ReadWriteCloser, compressType compressor.CompressType, serializer serializer.Serializer) rpc.ClientCodec {
	codec := &clientCodec{
		c:          conn,
		compressor: compressType,
		serializer: serializer,
		pending:    make(map[uint64]string),
	}
	// Buffer both directions of the same connection.
	codec.r = bufio.NewReader(conn)
	codec.w = bufio.NewWriter(conn)
	return codec
}