Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix "Go pointer to Go pointer" panics #97

Merged
merged 2 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,26 @@ func CompressLevel(dst, src []byte, level int) ([]byte, error) {
dst = make([]byte, bound)
}

var srcPtr *byte // Do not point anywhere, if src is empty
if len(src) > 0 {
srcPtr = &src[0]
// We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
// This means we need to special case empty input. See:
// https://github.com/golang/go/issues/14210#issuecomment-346402945
var cWritten C.size_t
if len(src) == 0 {
cWritten = C.ZSTD_compress(
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(nil),
C.size_t(0),
C.int(level))
} else {
cWritten = C.ZSTD_compress(
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src)),
C.int(level))
}

cWritten := C.ZSTD_compress(
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(srcPtr),
C.size_t(len(src)),
C.int(level))

written := int(cWritten)
// Check if the return is an Error code
if err := getError(written); err != nil {
Expand Down
31 changes: 20 additions & 11 deletions zstd_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,28 @@ func (c *ctx) CompressLevel(dst, src []byte, level int) ([]byte, error) {
dst = make([]byte, bound)
}

var srcPtr *byte // Do not point anywhere, if src is empty
if len(src) > 0 {
srcPtr = &src[0]
// We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
// This means we need to special case empty input. See:
// https://github.com/golang/go/issues/14210#issuecomment-346402945
var cWritten C.size_t
if len(src) == 0 {
cWritten = C.ZSTD_compressCCtx(
c.cctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(nil),
C.size_t(0),
C.int(level))
} else {
cWritten = C.ZSTD_compressCCtx(
c.cctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(&src[0]),
C.size_t(len(src)),
C.int(level))
}

cWritten := C.ZSTD_compressCCtx(
c.cctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
unsafe.Pointer(srcPtr),
C.size_t(len(src)),
C.int(level))

written := int(cWritten)
// Check if the return is an Error code
if err := getError(written); err != nil {
Expand Down
7 changes: 7 additions & 0 deletions zstd_ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ func TestCtxCompressLevel(t *testing.T) {
}
}

func TestCtxCompressLevelNoGoPointers(t *testing.T) {
testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
cctx := NewCtx()
return cctx.CompressLevel(nil, input, BestSpeed)
})
}

func TestCtxEmptySliceCompress(t *testing.T) {
ctx := NewCtx()

Expand Down
10 changes: 5 additions & 5 deletions zstd_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,17 @@ func (w *Writer) Write(p []byte) (int, error) {
srcData = w.srcBuffer
}

var srcPtr *byte // Do not point anywhere, if src is empty
if len(srcData) > 0 {
srcPtr = &srcData[0]
if len(srcData) == 0 {
// this is technically unnecessary: srcData is p or w.srcBuffer, and len() > 0 checked above
// but this ensures the code can change without dereferencing an srcData[0]
return 0, nil
}

C.ZSTD_compressStream2_wrapper(
w.resultBuffer,
w.ctx,
unsafe.Pointer(&w.dstBuffer[0]),
C.size_t(len(w.dstBuffer)),
unsafe.Pointer(srcPtr),
unsafe.Pointer(&srcData[0]),
C.size_t(len(srcData)),
)
ret := int(w.resultBuffer.return_code)
Expand Down
16 changes: 16 additions & 0 deletions zstd_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,22 @@ func TestStreamDecompressionChunks(t *testing.T) {
}
}

func TestStreamWriteNoGoPointers(t *testing.T) {
testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
buf := &bytes.Buffer{}
zw := NewWriter(buf)
_, err := zw.Write(input)
if err != nil {
return nil, err
}
err = zw.Close()
if err != nil {
return nil, err
}
return buf.Bytes(), nil
})
}

func BenchmarkStreamCompression(b *testing.B) {
if raw == nil {
b.Fatal(ErrNoPayloadEnv)
Expand Down
37 changes: 37 additions & 0 deletions zstd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,43 @@ func TestCompressLevel(t *testing.T) {
}
}

// structWithGoPointers contains a byte buffer and a pointer to Go objects (slice). This means
// Cgo checks can fail when passing a pointer to buffer:
// "panic: runtime error: cgo argument has Go pointer to Go pointer"
// https://github.com/golang/go/issues/14210#issuecomment-346402945
type structWithGoPointers struct {
buffer [1]byte
slice []byte
}

// testCompressDecompressByte ensures that functions use the correct unsafe.Pointer assignment
// to avoid "Go pointer to Go pointer" panics.
func testCompressNoGoPointers(t *testing.T, compressFunc func(input []byte) ([]byte, error)) {
t.Helper()

s := structWithGoPointers{}
s.buffer[0] = 0x42
s.slice = s.buffer[:1]

compressed, err := compressFunc(s.slice)
if err != nil {
t.Fatal(err)
}
decompressed, err := Decompress(nil, compressed)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(decompressed, s.slice) {
t.Errorf("decompressed=%#v input=%#v", decompressed, s.slice)
}
}

func TestCompressLevelNoGoPointers(t *testing.T) {
testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
return CompressLevel(nil, input, BestSpeed)
})
}

func doCompressLevel(payload []byte, out []byte) error {
out, err := CompressLevel(out, payload, DefaultCompression)
if err != nil {
Expand Down