-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.go
174 lines (159 loc) · 3.91 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
package codec
import (
"bufio"
"github.com/alilestera/tinyrpc/pkg/compressor"
"github.com/alilestera/tinyrpc/pkg/header"
"github.com/alilestera/tinyrpc/pkg/serializer"
"hash/crc32"
"io"
"net/rpc"
"sync"
)
// requestContext holds the per-request values that WriteResponse needs in
// order to answer a request. It stores the client's request ID and the
// compress type the request was sent with. ReadRequestHeader stores one per
// request in serverCodec.pending, keyed by the local sequence number.
type requestContext struct {
	// requestID is copied from the ID field of the decoded request header and
	// is echoed back as the response header ID.
	requestID uint64
	// compareType is copied from the request header's CompressType. The
	// response body is compressed with the same compressor. (The name looks
	// like a typo for "compressType".)
	compareType compressor.CompressType
}
// serverCodec is the rpc.ServerCodec implementation returned by
// NewServerCodec. It serves one connection.
//
// Each request header is decoded into the reused request field. The request
// is then assigned a local sequence number, which is handed to net/rpc.
// pending maps that number back to the client's request ID and compress type
// until WriteResponse consumes it.
type serverCodec struct {
	r          io.Reader                  // source of request frames (a *bufio.Reader when built by NewServerCodec)
	w          io.Writer                  // sink for response frames (a *bufio.Writer when built by NewServerCodec)
	c          io.Closer                  // underlying connection, closed by Close
	request    header.RequestHeader       // header of the request currently being read; reset on every ReadRequestHeader
	serializer serializer.Serializer      // encodes response bodies and decodes request bodies
	mutex      sync.Mutex                 // guards seq and pending
	seq        uint64                     // last sequence number handed out to net/rpc
	pending    map[uint64]*requestContext // seq -> context of requests still awaiting a response
}
// ReadRequestHeader decodes the next request header frame from the
// connection and fills r for net/rpc.
//
// net/rpc only sees a locally generated sequence number in r.Seq. The
// client's own request ID and compress type are parked in s.pending under
// that number so WriteResponse can echo them back later. Errors from reading
// or unmarshalling the frame are returned unchanged.
func (s *serverCodec) ReadRequestHeader(r *rpc.Request) error {
	// Drop whatever the previous request left in the reused header.
	s.request.ResetHeader()

	frame, err := recvFrame(s.r)
	if err != nil {
		return err
	}
	if err = s.request.Unmarshal(frame); err != nil {
		return err
	}

	rc := &requestContext{
		requestID:   s.request.ID,
		compareType: s.request.CompressType,
	}

	s.mutex.Lock()
	defer s.mutex.Unlock()
	s.seq++
	s.pending[s.seq] = rc
	r.ServiceMethod = s.request.Method
	r.Seq = s.seq
	return nil
}
// ReadRequestBody reads the body belonging to the header most recently
// decoded by ReadRequestHeader and unmarshals it into param.
//
// If param is nil, the RequestLen body bytes are read and thrown away, which
// keeps the stream aligned on the next frame. net/rpc does this to skip a body
// it cannot use.
//
// Otherwise the body goes through these steps:
//   - it is checked against the header's CRC32 (IEEE) checksum; a checksum of
//     0 skips the check;
//   - it is decompressed with the compressor named in the header;
//   - it is decoded with the codec's serializer.
//
// Returns UnexpectedChecksumError on a checksum mismatch and
// NotFoundCompressorError if the compress type is unknown. Read, unzip and
// unmarshal errors are returned as is.
func (s *serverCodec) ReadRequestBody(param any) error {
	if param == nil {
		if s.request.RequestLen != 0 { // discard excess
			if err := read(s.r, make([]byte, s.request.RequestLen)); err != nil {
				return err
			}
		}
		return nil
	}

	// Read exactly RequestLen bytes of (still compressed) request body.
	// NOTE(review): RequestLen comes straight off the wire and is used as an
	// allocation size; consider enforcing an upper bound.
	reqBody := make([]byte, s.request.RequestLen)
	if err := read(s.r, reqBody); err != nil {
		return err
	}

	// The checksum covers the compressed bytes; 0 is treated as "not set".
	if s.request.Checksum != 0 && crc32.ChecksumIEEE(reqBody) != s.request.Checksum {
		return UnexpectedChecksumError
	}

	// Look the compressor up once and reuse it for decompression.
	comp, ok := compressor.Compressors[s.request.GetCompressType()]
	if !ok {
		return NotFoundCompressorError
	}
	plain, err := comp.Unzip(reqBody)
	if err != nil {
		return err
	}
	return s.serializer.Unmarshal(plain, param)
}
// WriteResponse serializes and compresses one response, then writes its
// header and body to the connection.
//
// r.Seq must be a sequence number previously issued by ReadRequestHeader;
// otherwise InvalidSequenceError is returned. The matching pending entry is
// removed before anything is written. The client's request ID and compress
// type are recovered from that entry, so the reply uses the same compressor
// as the request.
//
// When r.Error is non-empty the call failed, and param is ignored: an empty
// body is compressed and sent. Returns NotFoundCompressorError if the
// request's compress type is unknown.
func (s *serverCodec) WriteResponse(r *rpc.Response, param any) error {
	s.mutex.Lock()
	reqCtx, ok := s.pending[r.Seq]
	delete(s.pending, r.Seq) // no-op when the key is absent
	s.mutex.Unlock()
	if !ok {
		return InvalidSequenceError
	}

	// A failed call carries only the error string; never serialize its reply.
	if r.Error != "" {
		param = nil
	}

	// Single lookup: validate the compressor and keep it for zipping below.
	comp, ok := compressor.Compressors[reqCtx.compareType]
	if !ok {
		return NotFoundCompressorError
	}

	var (
		respBody []byte
		err      error
	)
	if param != nil {
		if respBody, err = s.serializer.Marshal(param); err != nil {
			return err
		}
	}

	compressed, err := comp.Zip(respBody)
	if err != nil {
		return err
	}
	return s.buildSendResponse(reqCtx, r, compressed)
}
// buildSendResponse fills a pooled response header from reqCtx and r. It
// then writes the header frame followed by respBody and flushes the writer.
//
// respBody must already be compressed with reqCtx.compareType. Its length
// and CRC32 (IEEE) checksum are recorded in the header. Errors from writing
// the header, writing the body, or flushing are all returned. The flush error
// matters most: when s.w is buffered, that is where the network write
// actually happens.
func (s *serverCodec) buildSendResponse(reqCtx *requestContext, r *rpc.Response, respBody []byte) error {
	h := header.ResponsePool.Get().(*header.ResponseHeader)
	defer func() {
		// Clear the header before pooling so the next user starts clean.
		h.ResetHeader()
		header.ResponsePool.Put(h)
	}()

	h.ID = reqCtx.requestID
	h.Error = r.Error
	h.ResponseLen = uint32(len(respBody))
	h.Checksum = crc32.ChecksumIEEE(respBody)
	h.CompressType = reqCtx.compareType

	if err := sendFrame(s.w, h.Marshal()); err != nil {
		return err
	}
	if err := write(s.w, respBody); err != nil {
		return err
	}
	// Previously the flush error was discarded (and a non-bufio writer would
	// panic on the type assertion). Report it so callers see failed writes.
	if f, ok := s.w.(interface{ Flush() error }); ok {
		return f.Flush()
	}
	return nil
}
// Close closes the connection the codec was built on and returns its error.
// net/rpc's server calls Close when it is finished with the connection.
func (s *serverCodec) Close() error {
	conn := s.c
	return conn.Close()
}
// NewServerCodec wraps conn in an rpc.ServerCodec that speaks the tinyrpc
// framing. Reads and writes go through bufio buffers over conn. Request and
// response bodies are encoded with ser. Closing the codec closes conn.
//
// Typical use: rpc.ServeCodec(NewServerCodec(conn, ser)).
func NewServerCodec(conn io.ReadWriteCloser, ser serializer.Serializer) rpc.ServerCodec {
	codec := &serverCodec{c: conn, serializer: ser}
	codec.r = bufio.NewReader(conn)
	codec.w = bufio.NewWriter(conn)
	codec.pending = make(map[uint64]*requestContext)
	return codec
}