-
Notifications
You must be signed in to change notification settings - Fork 670
/
gzip_compressor.go
90 lines (77 loc) · 2.46 KB
/
gzip_compressor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package compression
import (
"bytes"
"compress/gzip"
"errors"
"fmt"
"io"
"math"
"sync"
)
// Compile-time check that *gzipCompressor implements Compressor.
var _ Compressor = (*gzipCompressor)(nil)

// Sentinel errors reported by the gzip compressor. The two size errors are
// wrapped with the offending and maximum sizes when returned, so callers
// should match them with errors.Is rather than ==.
var (
	// ErrInvalidMaxSizeCompressor is returned by NewGzipCompressor when it
	// rejects the requested maximum size.
	ErrInvalidMaxSizeCompressor = errors.New("invalid gzip compressor max size")

	// ErrDecompressedMsgTooLarge is returned by Decompress when the
	// decompressed payload is larger than the configured maximum size.
	ErrDecompressedMsgTooLarge = errors.New("decompressed msg too large")

	// ErrMsgTooLarge is returned by Compress when the input message is larger
	// than the configured maximum size.
	ErrMsgTooLarge = errors.New("msg too large to be compressed")
)
// gzipCompressor is a Compressor backed by compress/gzip. One size bound is
// enforced in both directions: Compress rejects inputs longer than maxSize,
// and Decompress rejects outputs longer than maxSize.
type gzipCompressor struct {
	// maxSize is the largest message length, in bytes, that Compress accepts
	// and that Decompress will return.
	maxSize int64
	// gzipWriterPool holds reusable *gzip.Writer values that Compress borrows
	// and resets for each call instead of allocating a new writer.
	gzipWriterPool sync.Pool
}
// Compress gzips [msg] and returns the compressed bytes.
//
// If [msg] is longer than the configured maximum size, it returns an error
// wrapping [ErrMsgTooLarge] without compressing anything. The gzip writer is
// taken from g.gzipWriterPool and handed back to it when the call returns.
func (g *gzipCompressor) Compress(msg []byte) ([]byte, error) {
	if int64(len(msg)) > g.maxSize {
		return nil, fmt.Errorf("%w: (%d) > (%d)", ErrMsgTooLarge, len(msg), g.maxSize)
	}

	w := g.gzipWriterPool.Get().(*gzip.Writer)
	defer g.gzipWriterPool.Put(w)

	// Each call writes into its own buffer, so the returned slice is never
	// shared with another call.
	var out bytes.Buffer
	w.Reset(&out)

	// Close flushes buffered data and writes the gzip footer; it is only
	// attempted when the write itself succeeded.
	_, err := w.Write(msg)
	if err == nil {
		err = w.Close()
	}
	if err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// Decompress gunzips [msg] and returns the decompressed bytes.
//
// At most g.maxSize+1 bytes are inflated, so a small, highly compressible
// payload cannot expand into an arbitrarily large allocation. If the output
// is longer than g.maxSize, an error wrapping [ErrDecompressedMsgTooLarge] is
// returned. Malformed input (bad header, corrupt deflate data, checksum
// mismatch) yields the error reported by compress/gzip. On any error the
// returned slice is nil.
func (g *gzipCompressor) Decompress(msg []byte) ([]byte, error) {
	gzipReader, err := gzip.NewReader(bytes.NewReader(msg))
	if err != nil {
		return nil, err
	}

	// Let one byte more than g.maxSize through the limit. An oversized
	// payload is then detected and rejected, rather than silently truncated
	// to exactly g.maxSize bytes. Reading until EOF also makes gzip verify the
	// trailer checksum. NewGzipCompressor forbids math.MaxInt64, so the +1
	// cannot overflow.
	decompressed, err := io.ReadAll(io.LimitReader(gzipReader, g.maxSize+1))
	if err != nil {
		return nil, err
	}
	if int64(len(decompressed)) > g.maxSize {
		return nil, fmt.Errorf("%w: (%d) > (%d)", ErrDecompressedMsgTooLarge, len(decompressed), g.maxSize)
	}

	// Check Close separately instead of returning its error alongside the
	// data, so callers never receive a result paired with a non-nil error.
	if err := gzipReader.Close(); err != nil {
		return nil, err
	}
	return decompressed, nil
}
// NewGzipCompressor returns a gzip Compressor that refuses to compress
// messages longer than [maxSize] bytes and refuses to return decompressed
// payloads longer than [maxSize] bytes.
//
// It returns [ErrInvalidMaxSizeCompressor] if [maxSize] is negative or equal
// to math.MaxInt64.
func NewGzipCompressor(maxSize int64) (Compressor, error) {
	if maxSize < 0 || maxSize == math.MaxInt64 {
		// A negative limit would make every Compress and Decompress call
		// fail, so reject that misconfiguration up front.
		//
		// "Decompress" creates "io.LimitReader" with max size + 1. If that
		// overflows (maxSize == math.MaxInt64), "io.LimitReader" reads
		// nothing and every Decompress call returns zero bytes, so require
		// maxSize < math.MaxInt64 to prevent the int64 overflow.
		return nil, ErrInvalidMaxSizeCompressor
	}
	return &gzipCompressor{
		maxSize: maxSize,
		gzipWriterPool: sync.Pool{
			// Writers are created without a destination; Compress calls
			// Reset to point each borrowed writer at its own output buffer.
			New: func() interface{} {
				return gzip.NewWriter(nil)
			},
		},
	}, nil
}