Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions ext/zstdruby/streaming_compress.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ static VALUE
no_compress(struct streaming_compress_t* sc, ZSTD_EndDirective endOp)
{
ZSTD_inBuffer input = { NULL, 0, 0 };
const char* output_data = RSTRING_PTR(sc->buf);
VALUE result = rb_str_new(0, 0);
size_t ret;
do {
const char* output_data = RSTRING_PTR(sc->buf);
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };

size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, endOp, false);
ret = zstd_stream_compress(sc->ctx, &output, &input, endOp, false);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "flush error error code: %s", ZSTD_getErrorName(ret));
}
Expand All @@ -131,9 +131,9 @@ rb_streaming_compress_compress(VALUE obj, VALUE src)
struct streaming_compress_t* sc;
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);

const char* output_data = RSTRING_PTR(sc->buf);
VALUE result = rb_str_new(0, 0);
while (input.pos < input.size) {
const char* output_data = RSTRING_PTR(sc->buf);
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
if (ZSTD_isError(ret)) {
Expand All @@ -150,7 +150,6 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj)
size_t total = 0;
struct streaming_compress_t* sc;
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);
const char* output_data = RSTRING_PTR(sc->buf);

while (argc-- > 0) {
VALUE str = *argv++;
Expand All @@ -160,18 +159,20 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj)
ZSTD_inBuffer input = { input_data, input_size, 0 };

while (input.pos < input.size) {
const char* output_data = RSTRING_PTR(sc->buf);
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
}
/* collect produced bytes */
/* Directly append to the pending buffer */
if (output.pos > 0) {
rb_str_cat(sc->pending, output.dst, output.pos);
}
total += RSTRING_LEN(str);
}
total += RSTRING_LEN(str);
}

return SIZET2NUM(total);
}

Expand Down Expand Up @@ -202,9 +203,9 @@ rb_streaming_compress_flush(VALUE obj)
struct streaming_compress_t* sc;
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);
VALUE drained = no_compress(sc, ZSTD_e_flush);
rb_str_cat(sc->pending, RSTRING_PTR(drained), RSTRING_LEN(drained));
VALUE out = sc->pending;
sc->pending = rb_str_new(0, 0);
VALUE out = rb_str_dup(sc->pending);
rb_str_cat(out, RSTRING_PTR(drained), RSTRING_LEN(drained));
rb_str_resize(sc->pending, 0);
return out;
}

Expand All @@ -214,9 +215,9 @@ rb_streaming_compress_finish(VALUE obj)
struct streaming_compress_t* sc;
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);
VALUE drained = no_compress(sc, ZSTD_e_end);
rb_str_cat(sc->pending, RSTRING_PTR(drained), RSTRING_LEN(drained));
VALUE out = sc->pending;
sc->pending = rb_str_new(0, 0);
VALUE out = rb_str_dup(sc->pending);
rb_str_cat(out, RSTRING_PTR(drained), RSTRING_LEN(drained));
rb_str_resize(sc->pending, 0);
return out;
}

Expand Down
11 changes: 9 additions & 2 deletions ext/zstdruby/streaming_decompress.c
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,22 @@ rb_streaming_decompress_decompress(VALUE obj, VALUE src)

struct streaming_decompress_t* sd;
TypedData_Get_Struct(obj, struct streaming_decompress_t, &streaming_decompress_type, sd);
const char* output_data = RSTRING_PTR(sd->buf);
VALUE result = rb_str_new(0, 0);

while (input.pos < input.size) {
const char* output_data = RSTRING_PTR(sd->buf);
ZSTD_outBuffer output = { (void*)output_data, sd->buf_size, 0 };
size_t const ret = zstd_stream_decompress(sd->dctx, &output, &input, false);

if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "decompress error error code: %s", ZSTD_getErrorName(ret));
}
rb_str_cat(result, output.dst, output.pos);
if (output.pos > 0) {
rb_str_cat(result, output.dst, output.pos);
}
if (ret == 0 && output.pos == 0) {
break;
}
}
return result;
}
Expand Down
110 changes: 70 additions & 40 deletions ext/zstdruby/zstdruby.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,61 +39,91 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
return output;
}

static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* input_data, size_t input_size)
{
ZSTD_inBuffer input = { input_data, input_size, 0 };
VALUE result = rb_str_new(0, 0);
static VALUE decode_one_frame(ZSTD_DCtx* dctx, const unsigned char* src, size_t size, VALUE kwargs) {
VALUE out = rb_str_buf_new(0);
size_t cap = ZSTD_DStreamOutSize();
char *buf = ALLOC_N(char, cap);
ZSTD_inBuffer in = (ZSTD_inBuffer){ src, size, 0 };

while (input.pos < input.size) {
ZSTD_outBuffer output = { NULL, 0, 0 };
output.size += ZSTD_DStreamOutSize();
VALUE output_string = rb_str_new(NULL, output.size);
output.dst = RSTRING_PTR(output_string);
ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
set_decompress_params(dctx, kwargs);

size_t ret = zstd_stream_decompress(dctx, &output, &input, false);
for (;;) {
ZSTD_outBuffer o = (ZSTD_outBuffer){ buf, cap, 0 };
size_t ret = ZSTD_decompressStream(dctx, &o, &in);
if (ZSTD_isError(ret)) {
ZSTD_freeDCtx(dctx);
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(ret));
xfree(buf);
rb_raise(rb_eRuntimeError, "ZSTD_decompressStream failed: %s", ZSTD_getErrorName(ret));
}
if (o.pos) {
rb_str_cat(out, buf, o.pos);
}
if (ret == 0) {
break;
}
rb_str_cat(result, output.dst, output.pos);
}
ZSTD_freeDCtx(dctx);
return result;
xfree(buf);
return out;
}

/*
 * Thin wrapper around decode_one_frame().
 *
 * Reinterprets `data` as unsigned bytes and forwards it, together with
 * `dctx` and `len`, to decode_one_frame(), passing Qnil as the kwargs
 * argument. Returns whatever decode_one_frame() returns.
 */
static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* data, size_t len)
{
  const unsigned char* bytes = (const unsigned char*)data;
  VALUE decoded = decode_one_frame(dctx, bytes, len, Qnil);
  return decoded;
}

static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
{
VALUE input_value;
VALUE kwargs;
VALUE input_value, kwargs;
rb_scan_args(argc, argv, "10:", &input_value, &kwargs);
StringValue(input_value);
char* input_data = RSTRING_PTR(input_value);
size_t input_size = RSTRING_LEN(input_value);
ZSTD_DCtx* const dctx = ZSTD_createDCtx();
if (dctx == NULL) {
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
}
set_decompress_params(dctx, kwargs);

unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size));
}
// ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness.
// Therefore, we will not standardize on ZSTD_decompressStream
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
return decompress_buffered(dctx, input_data, input_size);
}
size_t in_size = RSTRING_LEN(input_value);
const unsigned char *in_r = (const unsigned char *)RSTRING_PTR(input_value);
unsigned char *in = ALLOC_N(unsigned char, in_size);
memcpy(in, in_r, in_size);

size_t off = 0;
const uint32_t ZSTD_MAGIC = 0xFD2FB528U;
const uint32_t SKIP_LO = 0x184D2A50U; /* ...5F */

while (off + 4 <= in_size) {
uint32_t magic = (uint32_t)in[off]
| ((uint32_t)in[off+1] << 8)
| ((uint32_t)in[off+2] << 16)
| ((uint32_t)in[off+3] << 24);

if ((magic & 0xFFFFFFF0U) == (SKIP_LO & 0xFFFFFFF0U)) {
if (off + 8 > in_size) break;
uint32_t skipLen = (uint32_t)in[off+4]
| ((uint32_t)in[off+5] << 8)
| ((uint32_t)in[off+6] << 16)
| ((uint32_t)in[off+7] << 24);
size_t adv = (size_t)8 + (size_t)skipLen;
if (off + adv > in_size) break;
off += adv;
continue;
}

VALUE output = rb_str_new(NULL, uncompressed_size);
char* output_data = RSTRING_PTR(output);
if (magic == ZSTD_MAGIC) {
ZSTD_DCtx *dctx = ZSTD_createDCtx();
if (!dctx) {
xfree(in);
rb_raise(rb_eRuntimeError, "ZSTD_createDCtx failed");
}

VALUE out = decode_one_frame(dctx, in + off, in_size - off, kwargs);

size_t const decompress_size = zstd_decompress(dctx, output_data, uncompressed_size, input_data, input_size, false);
if (ZSTD_isError(decompress_size)) {
rb_raise(rb_eRuntimeError, "%s: %s", "decompress error", ZSTD_getErrorName(decompress_size));
ZSTD_freeDCtx(dctx);
xfree(in);
RB_GC_GUARD(input_value);
return out;
}

off += 1;
}
ZSTD_freeDCtx(dctx);
return output;

xfree(in);
RB_GC_GUARD(input_value);
rb_raise(rb_eRuntimeError, "not a zstd frame (magic not found)");
}

static void free_cdict(void *dict)
Expand Down