Skip to content

Commit

Permalink
[zstd] Add a sanity limit to decompress buffer size allocation
Browse files Browse the repository at this point in the history
Fix #60
Before we were blindly trusting the data returned by ZSTD_getDecompressedSize. This mean with a specifically crafter payload, we would try to allocate a lot of memory resulting in potential DOS.
Now set a sane limit and fall back to streaming
  • Loading branch information
Viq111 committed Apr 6, 2022
1 parent eaf4b06 commit 30c4b29
Showing 1 changed file with 33 additions and 12 deletions.
45 changes: 33 additions & 12 deletions zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ var (
ErrEmptySlice = errors.New("Bytes slice is empty")
)

const (
zstdFrameHeaderSizeMax = 18 // From zstd.h. Since it's experimental API, hardcoding it
)

// CompressBound returns the worst case size needed for a destination buffer,
// which can be used to preallocate a destination buffer or select a previously
// allocated buffer from a pool.
Expand All @@ -46,6 +50,30 @@ func cCompressBound(srcSize int) int {
return int(C.ZSTD_compressBound(C.size_t(srcSize)))
}

// decompressSizeHint tries to give a hint on how much of the output buffer size we should have
// based on zstd frame descriptors. To prevent DOS from maliciously-created payloads, limit the size
func decompressSizeHint(src []byte) int {
// 1 MB or 10x input size
upperBound := 10 * len(src)
if upperBound < 1000*1000 {
upperBound = 1000 * 1000
}

hint := upperBound
if len(src) >= zstdFrameHeaderSizeMax {
hint = int(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
if hint < 0 { // On error, just use upperBound
hint = upperBound
}
}

// Take the minimum of both
if hint > upperBound {
return upperBound
}
return hint
}

// Compress src into dst. If you have a buffer to use, you can pass it to
// prevent allocation. If it is too small, or if nil is passed, a new buffer
// will be allocated and returned.
Expand Down Expand Up @@ -113,18 +141,11 @@ func Decompress(dst, src []byte) ([]byte, error) {
return dst[:written], nil
}

if len(dst) == 0 {
// Attempt to use zStd to determine decompressed size (may result in error or 0)
size := int(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
if err := getError(size); err != nil {
return nil, err
}

if size > 0 {
dst = make([]byte, size)
} else {
dst = make([]byte, len(src)*3) // starting guess
}
bound := decompressSizeHint(src)
if cap(dst) >= bound {
dst = dst[0:cap(dst)]
} else {
dst = make([]byte, bound)
}
for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
result, err := decompress(dst, src)
Expand Down

0 comments on commit 30c4b29

Please sign in to comment.