[zstd_stream] Don't block in reader.Read if a zstd block is available
reader.Read used to try to fully read an internal buffer until EOF or
until the buffer was filled. That buffer was sized to ZSTD_DStreamInSize,
which is larger than any zstd block.

This meant that reader.Read could try to buffer much more data than
was needed to process and return a single block from the Read
method.

This was an issue because by blocking we could miss an urgent Flush
from a corresponding Writer. (A typical use case is instant
messaging.) It also went against the general convention of io.Reader
that a single call should return whatever data is immediately
available instead of waiting for more.
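
As an illustration, here is a minimal sketch of the messaging scenario
this unblocks. It assumes the package is imported as
github.com/DataDog/zstd (an assumption, not stated on this page) and
uses the NewWriter/NewReader/Flush API that appears in the diff below:

    package main

    import (
        "log"
        "os"

        "github.com/DataDog/zstd" // assumed import path
    )

    func main() {
        // An os.Pipe behaves like a network connection here: buffered,
        // and no EOF while the write end is still open.
        pr, pw, err := os.Pipe()
        if err != nil {
            log.Fatal(err)
        }
        defer pr.Close()
        defer pw.Close()

        w := zstd.NewWriter(pw)
        r := zstd.NewReader(pr)

        go func() {
            // Sender side: write a tiny message and flush it right away,
            // as an instant-messaging client would.
            w.Write([]byte("hi"))
            w.Flush()
        }()

        buf := make([]byte, 16)
        // With this fix, Read returns the flushed block as soon as it is
        // decompressed, instead of blocking until ZSTD_DStreamInSize
        // bytes of compressed input have arrived.
        n, err := r.Read(buf)
        if err != nil {
            log.Fatal(err)
        }
        log.Printf("received %q", buf[:n])
    }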

Interestingly enough, the test case should have caught this, but
because we used a bytes.Buffer, Read returned io.EOF after reading
the entire buffer, even if we appended to the buffer later on.
This commit also fixes the test case.
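
To see why bytes.Buffer masked the bug: its Read reports io.EOF
whenever the buffer happens to be empty at that moment. A
self-contained demonstration:

    package main

    import (
        "bytes"
        "fmt"
    )

    func main() {
        var b bytes.Buffer
        b.WriteString("x")

        p := make([]byte, 4)
        n, err := b.Read(p)
        fmt.Println(n, err) // 1 <nil>: drains the buffer

        n, err = b.Read(p)
        fmt.Println(n, err) // 0 EOF: the buffer is empty right now

        // Data appended later is readable again, but a stream reader
        // that already treated io.EOF as end-of-stream has stopped.
        b.WriteString("y")
        n, err = b.Read(p)
        fmt.Println(n, err) // 1 <nil>
    }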

Fixes: #95
delthas committed Jul 9, 2021
1 parent e292af4 commit 42c5dcb
Showing 2 changed files with 76 additions and 49 deletions.
108 changes: 64 additions & 44 deletions zstd_stream.go
@@ -422,29 +422,66 @@ func (r *reader) Read(p []byte) (int, error) {
         return 0, r.firstError
     }
 
-    // If we already have enough bytes, return
-    if r.decompSize-r.decompOff >= len(p) {
-        copy(p, r.decompressionBuffer[r.decompOff:])
-        r.decompOff += len(p)
-        return len(p), nil
+    if len(p) == 0 {
+        return 0, nil
+    }
+
+    // If we already have some uncompressed bytes, return without blocking
+    if r.decompSize > r.decompOff {
+        if r.decompSize-r.decompOff > len(p) {
+            copy(p, r.decompressionBuffer[r.decompOff:])
+            r.decompOff += len(p)
+            return len(p), nil
+        }
+        // From https://golang.org/pkg/io/#Reader
+        // > Read conventionally returns what is available instead of waiting for more.
+        copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
+        got := r.decompSize - r.decompOff
+        r.decompOff = r.decompSize
+        return got, nil
     }
 
-    copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
-    got := r.decompSize - r.decompOff
-    r.decompSize = 0
-    r.decompOff = 0
-
-    for got < len(p) {
-        // Populate src
-        src := r.compressionBuffer
-        reader := r.underlyingReader
-        n, err := TryReadFull(reader, src[r.compressionLeft:])
-        if err != nil && err != errShortRead { // Handle underlying reader errors first
-            return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
-        } else if n == 0 && r.compressionLeft == 0 {
-            return got, io.EOF
-        }
-        src = src[:r.compressionLeft+n]
+    // Repeatedly read from the underlying reader until we get
+    // at least one zstd block, so that we don't block if the
+    // other end has flushed a block.
+    for {
+        // - If the last decompression didn't entirely fill the decompression buffer,
+        //   zstd flushed all it could, and needs new data. In that case, do 1 Read.
+        // - If the last decompression did entirely fill the decompression buffer,
+        //   it might have needed more room to decompress the input. In that case,
+        //   don't do any unnecessary Read that might block.
+        needsData := r.decompSize < len(r.decompressionBuffer)
+
+        var src []byte
+        if !needsData {
+            src = r.compressionBuffer[:r.compressionLeft]
+        } else {
+            src = r.compressionBuffer
+            var n int
+            var err error
+            // Read until data arrives or an error occurs.
+            for n == 0 && err == nil {
+                n, err = r.underlyingReader.Read(src[r.compressionLeft:])
+            }
+            if err != nil && err != io.EOF { // Handle underlying reader errors first
+                return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
+            }
+            if n == 0 {
+                // Ideally, we'd return with ErrUnexpectedEOF in all cases where the stream was unexpectedly EOF'd
+                // during a block or frame, i.e. when there are incomplete, pending compression data.
+                // However, it's hard to detect those cases with zstd. Namely, there is no way to know the size of
+                // the current buffered compression data in the zstd stream internal buffers.
+                // Best effort: throw ErrUnexpectedEOF if we still have some pending buffered compression data that
+                // zstd doesn't want to accept.
+                // If we don't have any buffered compression data but zstd still has some in its internal buffers,
+                // we will return with EOF instead.
+                if r.compressionLeft > 0 {
+                    return 0, io.ErrUnexpectedEOF
+                }
+                return 0, io.EOF
+            }
+            src = src[:r.compressionLeft+n]
+        }
 
         // C code
         var srcPtr *byte // Do not point anywhere, if src is empty
@@ -462,9 +499,9 @@ func (r *reader) Read(p []byte) (int, error) {
         )
         retCode := int(r.resultBuffer.return_code)
 
-        // Keep src here eventhough we reuse later, the code might be deleted at some point
+        // Keep src here even though we reuse later, the code might be deleted at some point
         runtime.KeepAlive(src)
-        if err = getError(retCode); err != nil {
+        if err := getError(retCode); err != nil {
             return 0, fmt.Errorf("failed to decompress: %s", err)
         }
 
@@ -474,10 +511,9 @@ func (r *reader) Read(p []byte) (int, error) {
             left := src[bytesConsumed:]
             copy(r.compressionBuffer, left)
         }
-        r.compressionLeft = len(src) - int(bytesConsumed)
+        r.compressionLeft = len(src) - bytesConsumed
         r.decompSize = int(r.resultBuffer.bytes_written)
-        r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize])
-        got += r.decompOff
+        r.decompOff = copy(p, r.decompressionBuffer[:r.decompSize])
 
         // Resize buffers
         nsize := retCode // Hint for next src buffer size
@@ -489,25 +525,9 @@ func (r *reader) Read(p []byte) (int, error) {
             nsize = r.compressionLeft
         }
         r.compressionBuffer = resize(r.compressionBuffer, nsize)
-    }
-    return got, nil
-}
 
-// TryReadFull reads buffer just as ReadFull does
-// Here we expect that buffer may end and we do not return ErrUnexpectedEOF as ReadAtLeast does.
-// We return errShortRead instead to distinguish short reads and failures.
-// We cannot use ReadFull/ReadAtLeast because it masks Reader errors, such as network failures
-// and causes panic instead of error.
-func TryReadFull(r io.Reader, buf []byte) (n int, err error) {
-    for n < len(buf) && err == nil {
-        var nn int
-        nn, err = r.Read(buf[n:])
-        n += nn
-    }
-    if n == len(buf) && err == io.EOF {
-        err = nil // EOF at the end is somewhat expected
-    } else if err == io.EOF {
-        err = errShortRead
-    }
-    return
-}
+        if r.decompOff > 0 {
+            return r.decompOff, nil
+        }
+    }
+}
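
The caller-visible consequence of this change, reflected in the test
changes below: Read may now legitimately return fewer bytes than
len(p) with a nil error, so callers that need an exact count should
wrap it in io.ReadFull. A sketch of the caller-side pattern (again
assuming the github.com/DataDog/zstd import path):

    package main

    import (
        "bytes"
        "io"
        "log"

        "github.com/DataDog/zstd" // assumed import path
    )

    func main() {
        payload := []byte("hello zstd")
        compressed, err := zstd.Compress(nil, payload)
        if err != nil {
            log.Fatal(err)
        }

        r := zstd.NewReader(bytes.NewReader(compressed))
        defer r.Close()

        dst := make([]byte, len(payload))
        // A bare r.Read(dst) may now return early with a partial result;
        // io.ReadFull loops until dst is filled or the stream ends.
        if _, err := io.ReadFull(r, dst); err != nil {
            log.Fatal(err)
        }
        log.Printf("%s", dst)
    }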
17 changes: 12 additions & 5 deletions zstd_stream_test.go
@@ -39,7 +39,7 @@ func testCompressionDecompression(t *testing.T, dict []byte, payload []byte) {
     // Decompress
     r := NewReaderDict(rr, dict)
     dst := make([]byte, len(payload))
-    n, err := r.Read(dst)
+    n, err := io.ReadFull(r, dst)
     if err != nil {
         failOnError(t, "Failed to read for decompression", err)
     }
@@ -211,9 +211,16 @@ func TestStreamEmptyPayload(t *testing.T) {
 }
 
 func TestStreamFlush(t *testing.T) {
-    var w bytes.Buffer
-    writer := NewWriter(&w)
-    reader := NewReader(&w)
+    // use an actual os pipe so that
+    // - it's buffered and we don't get a 1-read = 1-write behaviour (io.Pipe)
+    // - reading doesn't send EOF when we're done reading the buffer (bytes.Buffer)
+    pr, pw, err := os.Pipe()
+    failOnError(t, "Failed creating pipe", err)
+    defer pw.Close()
+    defer pr.Close()
+
+    writer := NewWriter(pw)
+    reader := NewReader(pr)
 
     payload := "cc" // keep the payload short to make sure it will not be automatically flushed by zstd
     buf := make([]byte, len(payload))
@@ -429,7 +436,7 @@ func BenchmarkStreamDecompression(b *testing.B) {
     for i := 0; i < b.N; i++ {
         rr := bytes.NewReader(compressed)
         r := NewReader(rr)
-        _, err := r.Read(dst)
+        _, err := io.ReadFull(r, dst)
         if err != nil {
             b.Fatalf("Failed to decompress: %s", err)
         }
