/
header.go
65 lines (52 loc) · 1.5 KB
/
header.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
package packet
// headerLen reports how many bytes a fixed header occupies for a packet whose
// remaining length is rl.
func headerLen(rl int) int {
	// a single byte holds the packet type together with the flags; the rest
	// of the header is the variable-length encoding of the remaining length
	const typeAndFlagsLen = 1

	return typeAndFlagsLen + varintLen(uint64(rl))
}
// encodeHeader writes the fixed header of a packet of type t into dst.
//
// flags is OR-ed onto the default flags of t, rl is the remaining length that
// gets encoded after the first byte and tl is the total length the caller
// intends to write into dst. It returns the number of header bytes written,
// or zero and an error if dst is too small or the remaining length cannot be
// encoded.
func encodeHeader(dst []byte, flags byte, rl int, tl int, t Type) (int, error) {
	// dst has to fit the header itself as well as the total length requested
	// by the caller
	if avail := len(dst); avail < headerLen(rl) || avail < tl {
		return 0, insufficientBufferSize(t)
	}

	// first byte: packet type in the upper nibble, default flags combined
	// with the supplied flags in the lower nibble
	dst[0] = byte(t)<<4 | (t.defaultFlags() & 0xf) | flags

	// the remaining length follows directly after the first byte
	written, err := writeVarint(dst[1:], uint64(rl), t)
	if err != nil {
		return 0, err
	}

	return written + 1, nil
}
// decodeHeader parses the fixed header found at the start of src and checks it
// against the expected packet type t.
//
// It returns the number of bytes consumed, the flags from the lower nibble of
// the first byte and the decoded remaining length. The byte count is reported
// on failure as well (zero only when src is shorter than two bytes), while
// flags and remaining length are zero whenever an error is returned.
func decodeHeader(src []byte, t Type) (int, byte, int, error) {
	// the smallest header is the type/flags byte plus one length byte
	if len(src) < 2 {
		return 0, 0, 0, insufficientBufferSize(t)
	}

	// from here on the first byte counts as consumed
	first := src[0]
	pos := 1

	// the upper nibble must match the statically expected type
	if got := Type(first >> 4); got != t {
		return pos, 0, 0, makeError(t, "invalid type %d", got)
	}

	// the lower nibble holds the flags; every type except PUBLISH must carry
	// exactly its default flags
	flags := first & 0x0f
	if t != PUBLISH {
		if want := t.defaultFlags(); flags != want {
			return pos, 0, 0, makeError(t, "invalid flags, expected %d, got %d", want, flags)
		}
	}

	// decode the remaining length; the bytes read are accounted for even if
	// decoding fails
	length, n, err := readVarint(src[pos:], t)
	pos += n
	if err != nil {
		return pos, 0, 0, err
	}

	// the announced remaining length must fit into what is left of src
	rl := int(length)
	if rest := len(src[pos:]); rl > rest {
		return pos, 0, 0, makeError(t, "remaining length (%d) is greater than remaining buffer (%d)", rl, rest)
	}

	return pos, flags, rl, nil
}