diff --git a/include/sparrow_ipc/serializer.hpp b/include/sparrow_ipc/serializer.hpp index e867f76..989f735 100644 --- a/include/sparrow_ipc/serializer.hpp +++ b/include/sparrow_ipc/serializer.hpp @@ -82,6 +82,11 @@ namespace sparrow_ipc requires std::same_as, sparrow::record_batch> void write(const R& record_batches) { + if (std::ranges::empty(record_batches)) + { + return; + } + if (m_ended) { throw std::runtime_error("Cannot append to a serializer that has been ended"); diff --git a/src/deserialize.cpp b/src/deserialize.cpp index 55f863d..d4b98d1 100644 --- a/src/deserialize.cpp +++ b/src/deserialize.cpp @@ -61,6 +61,7 @@ namespace sparrow_ipc const std::optional>& metadata = field_metadata[field_idx++]; const std::string name = field->name() == nullptr ? "" : field->name()->str(); const auto field_type = field->type_type(); + // TODO rename all the deserialize_non_owning... fcts since this is not correct anymore const auto deserialize_non_owning_primitive_array_lambda = [&]() { return deserialize_non_owning_primitive_array( @@ -207,8 +208,20 @@ namespace sparrow_ipc std::vector>> fields_metadata; do { + // Check for end-of-stream marker here as data could contain only that (if no record batches present/written) + if (data.size() >= 8 && is_end_of_stream(data.subspan(0, 8))) + { + break; + } + const auto [encapsulated_message, rest] = extract_encapsulated_message(data); const org::apache::arrow::flatbuf::Message* message = encapsulated_message.flat_buffer_message(); + + if (message == nullptr) + { + throw std::invalid_argument("Extracted flatbuffers message is null."); + } + switch (message->header_type()) { case org::apache::arrow::flatbuf::MessageHeader::Schema: @@ -269,10 +282,6 @@ namespace sparrow_ipc throw std::runtime_error("Unknown message header type."); } data = rest; - if (is_end_of_stream(data.subspan(0, 8))) - { - break; - } } while (!data.empty()); return record_batches; } diff --git a/src/deserialize_fixedsizebinary_array.cpp b/src/deserialize_fixedsizebinary_array.cpp index e711e6f..279afa5 100644 --- a/src/deserialize_fixedsizebinary_array.cpp +++ b/src/deserialize_fixedsizebinary_array.cpp @@ -2,7 +2,6 @@ namespace sparrow_ipc { - // TODO add compression here and tests (not available for this type in apache arrow integration tests files) sparrow::fixed_width_binary_array deserialize_non_owning_fixedwidthbinary( const org::apache::arrow::flatbuf::RecordBatch& record_batch, std::span body, @@ -23,10 +22,22 @@ namespace sparrow_ipc nullptr ); + const auto compression = record_batch.compression(); std::vector buffers; + auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index); - buffers.push_back(validity_buffer_span); - buffers.push_back(utils::get_buffer(record_batch, body, buffer_index)); + auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index); + + if (compression) + { + buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression)); + buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression)); + } + else + { + buffers.push_back(validity_buffer_span); + buffers.push_back(data_buffer_span); + } // TODO bitmap_ptr is not used anymore... Leave it for now, and remove later if no need confirmed const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length()); diff --git a/src/encapsulated_message.cpp b/src/encapsulated_message.cpp index 548d878..c0f87b7 100644 --- a/src/encapsulated_message.cpp +++ b/src/encapsulated_message.cpp @@ -106,10 +106,10 @@ namespace sparrow_ipc const std::span continuation_span = data.subspan(0, 4); if (!is_continuation(continuation_span)) { - throw std::runtime_error("Buffer starts with continuation bytes, expected a valid message."); + throw std::runtime_error("Buffer should start with continuation bytes, expected a valid message."); } encapsulated_message message(data); std::span rest = data.subspan(message.total_length()); return {std::move(message), std::move(rest)}; } -} \ No newline at end of file +} diff --git a/tests/test_de_serialization_with_files.cpp b/tests/test_de_serialization_with_files.cpp index 360d67c..b0243a1 100644 --- a/tests/test_de_serialization_with_files.cpp +++ b/tests/test_de_serialization_with_files.cpp @@ -27,16 +27,19 @@ const std::filesystem::path tests_resources_files_path_with_compression = arrow_ const std::vector files_paths_to_test = { tests_resources_files_path / "generated_primitive", - // tests_resources_files_path / "generated_primitive_large_offsets", tests_resources_files_path / "generated_primitive_zerolength", - // tests_resources_files_path / "generated_primitive_no_batches" + tests_resources_files_path / "generated_primitive_no_batches", + tests_resources_files_path / "generated_binary", + tests_resources_files_path / "generated_large_binary", + tests_resources_files_path / "generated_binary_zerolength", + tests_resources_files_path / "generated_binary_no_batches", }; const std::vector files_paths_to_test_with_compression = { tests_resources_files_path_with_compression / "generated_lz4", - tests_resources_files_path_with_compression/ "generated_uncompressible_lz4" -// tests_resources_files_path_with_compression / "generated_zstd" -// tests_resources_files_path_with_compression/ "generated_uncompressible_zstd" + tests_resources_files_path_with_compression/ "generated_uncompressible_lz4", +// tests_resources_files_path_with_compression / "generated_zstd", +// tests_resources_files_path_with_compression/ "generated_uncompressible_zstd", }; @@ -236,4 +239,57 @@ TEST_SUITE("Integration tests") } } } + + TEST_CASE("Round trip of classic test files serialization/deserialization using LZ4 compression") + { + for (const auto& file_path : files_paths_to_test) + { + std::filesystem::path json_path = file_path; + json_path.replace_extension(".json"); + + // Load the JSON file + auto json_data = load_json_file(json_path); + CHECK(json_data != nullptr); + + const size_t num_batches = get_number_of_batches(json_path); + std::vector record_batches_from_json; + for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) + { + INFO("Processing batch " << batch_idx << " of " << num_batches); + record_batches_from_json.emplace_back( + sparrow::json_reader::build_record_batch_from_json(json_data, batch_idx) + ); + } + + // Load stream file + std::filesystem::path stream_file_path = file_path; + stream_file_path.replace_extension(".stream"); + std::ifstream stream_file(stream_file_path, std::ios::in | std::ios::binary); + REQUIRE(stream_file.is_open()); + const std::vector stream_data( + (std::istreambuf_iterator(stream_file)), + (std::istreambuf_iterator()) + ); + stream_file.close(); + + // Process the stream file + const auto record_batches_from_stream = sparrow_ipc::deserialize_stream( + std::span(stream_data) + ); + + // Serialize from json with LZ4 compression + std::vector serialized_data; + sparrow_ipc::memory_output_stream stream(serialized_data); + sparrow_ipc::serializer serializer(stream, sparrow_ipc::CompressionType::LZ4_FRAME); + serializer << record_batches_from_json << sparrow_ipc::end_stream; + + // Deserialize + const auto deserialized_serialized_data = sparrow_ipc::deserialize_stream( + std::span(serialized_data) + ); + + // Compare + compare_record_batches(record_batches_from_stream, deserialized_serialized_data); + } + } }