diff --git a/ext/zstdruby/streaming_compress.c b/ext/zstdruby/streaming_compress.c index 8f491eb..0bcfd7e 100644 --- a/ext/zstdruby/streaming_compress.c +++ b/ext/zstdruby/streaming_compress.c @@ -105,13 +105,13 @@ static VALUE no_compress(struct streaming_compress_t* sc, ZSTD_EndDirective endOp) { ZSTD_inBuffer input = { NULL, 0, 0 }; - const char* output_data = RSTRING_PTR(sc->buf); VALUE result = rb_str_new(0, 0); size_t ret; do { + const char* output_data = RSTRING_PTR(sc->buf); ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 }; - size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, endOp, false); + ret = zstd_stream_compress(sc->ctx, &output, &input, endOp, false); if (ZSTD_isError(ret)) { rb_raise(rb_eRuntimeError, "flush error error code: %s", ZSTD_getErrorName(ret)); } @@ -131,9 +131,9 @@ rb_streaming_compress_compress(VALUE obj, VALUE src) struct streaming_compress_t* sc; TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc); - const char* output_data = RSTRING_PTR(sc->buf); VALUE result = rb_str_new(0, 0); while (input.pos < input.size) { + const char* output_data = RSTRING_PTR(sc->buf); ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 }; size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false); if (ZSTD_isError(ret)) { @@ -150,7 +150,6 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj) size_t total = 0; struct streaming_compress_t* sc; TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc); - const char* output_data = RSTRING_PTR(sc->buf); while (argc-- > 0) { VALUE str = *argv++; @@ -160,18 +159,20 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj) ZSTD_inBuffer input = { input_data, input_size, 0 }; while (input.pos < input.size) { + const char* output_data = RSTRING_PTR(sc->buf); ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 }; size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false); if (ZSTD_isError(ret)) { rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret)); } - /* collect produced bytes */ + /* Directly append to the pending buffer */ if (output.pos > 0) { rb_str_cat(sc->pending, output.dst, output.pos); } - total += RSTRING_LEN(str); } + total += RSTRING_LEN(str); } + return SIZET2NUM(total); } @@ -202,9 +203,9 @@ rb_streaming_compress_flush(VALUE obj) struct streaming_compress_t* sc; TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc); VALUE drained = no_compress(sc, ZSTD_e_flush); - rb_str_cat(sc->pending, RSTRING_PTR(drained), RSTRING_LEN(drained)); - VALUE out = sc->pending; - sc->pending = rb_str_new(0, 0); + VALUE out = rb_str_dup(sc->pending); + rb_str_cat(out, RSTRING_PTR(drained), RSTRING_LEN(drained)); + rb_str_resize(sc->pending, 0); return out; } @@ -214,9 +215,9 @@ rb_streaming_compress_finish(VALUE obj) struct streaming_compress_t* sc; TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc); VALUE drained = no_compress(sc, ZSTD_e_end); - rb_str_cat(sc->pending, RSTRING_PTR(drained), RSTRING_LEN(drained)); - VALUE out = sc->pending; - sc->pending = rb_str_new(0, 0); + VALUE out = rb_str_dup(sc->pending); + rb_str_cat(out, RSTRING_PTR(drained), RSTRING_LEN(drained)); + rb_str_resize(sc->pending, 0); return out; } diff --git a/ext/zstdruby/streaming_decompress.c b/ext/zstdruby/streaming_decompress.c index 0336967..ac65e97 100644 --- a/ext/zstdruby/streaming_decompress.c +++ b/ext/zstdruby/streaming_decompress.c @@ -100,15 +100,22 @@ rb_streaming_decompress_decompress(VALUE obj, VALUE src) struct streaming_decompress_t* sd; TypedData_Get_Struct(obj, struct streaming_decompress_t, &streaming_decompress_type, sd); - const char* output_data = RSTRING_PTR(sd->buf); VALUE result = rb_str_new(0, 0); + while (input.pos < input.size) { + const char* output_data = RSTRING_PTR(sd->buf); ZSTD_outBuffer output = { (void*)output_data, sd->buf_size, 0 }; size_t const ret = zstd_stream_decompress(sd->dctx, &output, &input, false); + if (ZSTD_isError(ret)) { rb_raise(rb_eRuntimeError, "decompress error error code: %s", ZSTD_getErrorName(ret)); } - rb_str_cat(result, output.dst, output.pos); + if (output.pos > 0) { + rb_str_cat(result, output.dst, output.pos); + } + if (ret == 0 && output.pos == 0) { + break; + } } return result; } diff --git a/ext/zstdruby/zstdruby.c b/ext/zstdruby/zstdruby.c index e63b95c..1d41522 100644 --- a/ext/zstdruby/zstdruby.c +++ b/ext/zstdruby/zstdruby.c @@ -39,61 +39,91 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self) return output; } -static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* input_data, size_t input_size) -{ - ZSTD_inBuffer input = { input_data, input_size, 0 }; - VALUE result = rb_str_new(0, 0); +static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t size, VALUE kwargs) { + VALUE out = rb_str_buf_new(0); + size_t cap = ZSTD_DStreamOutSize(); + char *buf = ALLOC_N(char, cap); + ZSTD_inBuffer in = (ZSTD_inBuffer){ src, size, 0 }; - while (input.pos < input.size) { - ZSTD_outBuffer output = { NULL, 0, 0 }; - output.size += ZSTD_DStreamOutSize(); - VALUE output_string = rb_str_new(NULL, output.size); - output.dst = RSTRING_PTR(output_string); + ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only); + set_decompress_params(dctx, kwargs); - size_t ret = zstd_stream_decompress(dctx, &output, &input, false); + for (;;) { + ZSTD_outBuffer o = (ZSTD_outBuffer){ buf, cap, 0 }; + size_t ret = ZSTD_decompressStream(dctx, &o, &in); if (ZSTD_isError(ret)) { - ZSTD_freeDCtx(dctx); - rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(ret)); + xfree(buf); + rb_raise(rb_eRuntimeError, "ZSTD_decompressStream failed: %s", ZSTD_getErrorName(ret)); + } + if (o.pos) { + rb_str_cat(out, buf, o.pos); + } + if (ret == 0) { + break; } - rb_str_cat(result, output.dst, output.pos); } - ZSTD_freeDCtx(dctx); - return result; + xfree(buf); + return out; +} + +static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* data, size_t len) { + return decode_one_frame(dctx, (const unsigned char*)data, len, Qnil); } static VALUE rb_decompress(int argc, VALUE *argv, VALUE self) { - VALUE input_value; - VALUE kwargs; + VALUE input_value, kwargs; rb_scan_args(argc, argv, "10:", &input_value, &kwargs); StringValue(input_value); - char* input_data = RSTRING_PTR(input_value); - size_t input_size = RSTRING_LEN(input_value); - ZSTD_DCtx* const dctx = ZSTD_createDCtx(); - if (dctx == NULL) { - rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed"); - } - set_decompress_params(dctx, kwargs); - unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size); - if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) { - rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size)); - } - // ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness. - // Therefore, we will not standardize on ZSTD_decompressStream - if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) { - return decompress_buffered(dctx, input_data, input_size); - } + size_t in_size = RSTRING_LEN(input_value); + const unsigned char *in_r = (const unsigned char *)RSTRING_PTR(input_value); + unsigned char *in = ALLOC_N(unsigned char, in_size); + memcpy(in, in_r, in_size); + + size_t off = 0; + const uint32_t ZSTD_MAGIC = 0xFD2FB528U; + const uint32_t SKIP_LO = 0x184D2A50U; /* ...5F */ + + while (off + 4 <= in_size) { + uint32_t magic = (uint32_t)in[off] + | ((uint32_t)in[off+1] << 8) + | ((uint32_t)in[off+2] << 16) + | ((uint32_t)in[off+3] << 24); + + if ((magic & 0xFFFFFFF0U) == (SKIP_LO & 0xFFFFFFF0U)) { + if (off + 8 > in_size) break; + uint32_t skipLen = (uint32_t)in[off+4] + | ((uint32_t)in[off+5] << 8) + | ((uint32_t)in[off+6] << 16) + | ((uint32_t)in[off+7] << 24); + size_t adv = (size_t)8 + (size_t)skipLen; + if (off + adv > in_size) break; + off += adv; + continue; + } - VALUE output = rb_str_new(NULL, uncompressed_size); - char* output_data = RSTRING_PTR(output); + if (magic == ZSTD_MAGIC) { + ZSTD_DCtx *dctx = ZSTD_createDCtx(); + if (!dctx) { + xfree(in); + rb_raise(rb_eRuntimeError, "ZSTD_createDCtx failed"); + } + + VALUE out = decode_one_frame(dctx, in + off, in_size - off, kwargs); - size_t const decompress_size = zstd_decompress(dctx, output_data, uncompressed_size, input_data, input_size, false); - if (ZSTD_isError(decompress_size)) { - rb_raise(rb_eRuntimeError, "%s: %s", "decompress error", ZSTD_getErrorName(decompress_size)); + ZSTD_freeDCtx(dctx); + xfree(in); + RB_GC_GUARD(input_value); + return out; + } + + off += 1; } - ZSTD_freeDCtx(dctx); - return output; + + xfree(in); + RB_GC_GUARD(input_value); + rb_raise(rb_eRuntimeError, "not a zstd frame (magic not found)"); } static void free_cdict(void *dict)