Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/sparrow_ipc/serializer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ namespace sparrow_ipc
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
void write(const R& record_batches)
{
if (std::ranges::empty(record_batches))
{
return;
}

if (m_ended)
{
throw std::runtime_error("Cannot append to a serializer that has been ended");
Expand Down
17 changes: 13 additions & 4 deletions src/deserialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ namespace sparrow_ipc
const std::optional<std::vector<sparrow::metadata_pair>>& metadata = field_metadata[field_idx++];
const std::string name = field->name() == nullptr ? "" : field->name()->str();
const auto field_type = field->type_type();
// TODO rename all the deserialize_non_owning... fcts since this is not correct anymore
const auto deserialize_non_owning_primitive_array_lambda = [&]<typename T>()
{
return deserialize_non_owning_primitive_array<T>(
Expand Down Expand Up @@ -207,8 +208,20 @@ namespace sparrow_ipc
std::vector<std::optional<std::vector<sparrow::metadata_pair>>> fields_metadata;
do
{
// Check for end-of-stream marker here as data could contain only that (if no record batches present/written)
if (data.size() >= 8 && is_end_of_stream(data.subspan(0, 8)))
{
break;
}

const auto [encapsulated_message, rest] = extract_encapsulated_message(data);
const org::apache::arrow::flatbuf::Message* message = encapsulated_message.flat_buffer_message();

if (message == nullptr)
{
throw std::invalid_argument("Extracted flatbuffers message is null.");
}

switch (message->header_type())
{
case org::apache::arrow::flatbuf::MessageHeader::Schema:
Expand Down Expand Up @@ -269,10 +282,6 @@ namespace sparrow_ipc
throw std::runtime_error("Unknown message header type.");
}
data = rest;
if (is_end_of_stream(data.subspan(0, 8)))
{
break;
}
} while (!data.empty());
return record_batches;
}
Expand Down
17 changes: 14 additions & 3 deletions src/deserialize_fixedsizebinary_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

namespace sparrow_ipc
{
// TODO add compression here and tests (not available for this type in apache arrow integration tests files)
sparrow::fixed_width_binary_array deserialize_non_owning_fixedwidthbinary(
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
std::span<const uint8_t> body,
Expand All @@ -23,10 +22,22 @@ namespace sparrow_ipc
nullptr
);

const auto compression = record_batch.compression();
std::vector<arrow_array_private_data::optionally_owned_buffer> buffers;

auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index);
buffers.push_back(validity_buffer_span);
buffers.push_back(utils::get_buffer(record_batch, body, buffer_index));
auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index);

if (compression)
{
buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression));
}
else
{
buffers.push_back(validity_buffer_span);
buffers.push_back(data_buffer_span);
}

// TODO bitmap_ptr is not used anymore... Leave it for now, and remove later if no need confirmed
const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length());
Expand Down
4 changes: 2 additions & 2 deletions src/encapsulated_message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ namespace sparrow_ipc
const std::span<const uint8_t> continuation_span = data.subspan(0, 4);
if (!is_continuation(continuation_span))
{
throw std::runtime_error("Buffer starts with continuation bytes, expected a valid message.");
throw std::runtime_error("Buffer should start with continuation bytes, expected a valid message.");
}
encapsulated_message message(data);
std::span<const uint8_t> rest = data.subspan(message.total_length());
return {std::move(message), std::move(rest)};
}
}
}
66 changes: 61 additions & 5 deletions tests/test_de_serialization_with_files.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,19 @@ const std::filesystem::path tests_resources_files_path_with_compression = arrow_

const std::vector<std::filesystem::path> files_paths_to_test = {
tests_resources_files_path / "generated_primitive",
// tests_resources_files_path / "generated_primitive_large_offsets",
tests_resources_files_path / "generated_primitive_zerolength",
// tests_resources_files_path / "generated_primitive_no_batches"
tests_resources_files_path / "generated_primitive_no_batches",
tests_resources_files_path / "generated_binary",
tests_resources_files_path / "generated_large_binary",
tests_resources_files_path / "generated_binary_zerolength",
tests_resources_files_path / "generated_binary_no_batches",
};

const std::vector<std::filesystem::path> files_paths_to_test_with_compression = {
tests_resources_files_path_with_compression / "generated_lz4",
tests_resources_files_path_with_compression/ "generated_uncompressible_lz4"
// tests_resources_files_path_with_compression / "generated_zstd"
// tests_resources_files_path_with_compression/ "generated_uncompressible_zstd"
tests_resources_files_path_with_compression/ "generated_uncompressible_lz4",
// tests_resources_files_path_with_compression / "generated_zstd",
// tests_resources_files_path_with_compression/ "generated_uncompressible_zstd",
};


Expand Down Expand Up @@ -236,4 +239,57 @@ TEST_SUITE("Integration tests")
}
}
}

    TEST_CASE("Round trip of classic test files serialization/deserialization using LZ4 compression")
    {
        // For each non-compressed reference file: build the expected record batches
        // from the JSON description, deserialize the reference .stream file, then
        // re-serialize the JSON-built batches WITH LZ4 compression and deserialize
        // the result. The LZ4 round trip must yield batches identical to the ones
        // read from the uncompressed reference stream.
        for (const auto& file_path : files_paths_to_test)
        {
            std::filesystem::path json_path = file_path;
            json_path.replace_extension(".json");

            // Load the JSON file
            auto json_data = load_json_file(json_path);
            // NOTE(review): CHECK continues on failure, so a null json_data would be
            // dereferenced below — consider REQUIRE here (confirm intended doctest semantics).
            CHECK(json_data != nullptr);

            // Build every record batch described by the JSON file; num_batches may be 0
            // for the *_no_batches fixtures, in which case this loop body never runs.
            const size_t num_batches = get_number_of_batches(json_path);
            std::vector<sparrow::record_batch> record_batches_from_json;
            for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx)
            {
                INFO("Processing batch " << batch_idx << " of " << num_batches);
                record_batches_from_json.emplace_back(
                    sparrow::json_reader::build_record_batch_from_json(json_data, batch_idx)
                );
            }

            // Load stream file
            std::filesystem::path stream_file_path = file_path;
            stream_file_path.replace_extension(".stream");
            std::ifstream stream_file(stream_file_path, std::ios::in | std::ios::binary);
            REQUIRE(stream_file.is_open());
            // Extra parentheses around the iterator arguments avoid the most-vexing-parse
            // (this would otherwise declare a function instead of a vector).
            const std::vector<uint8_t> stream_data(
                (std::istreambuf_iterator<char>(stream_file)),
                (std::istreambuf_iterator<char>())
            );
            stream_file.close();

            // Process the stream file
            const auto record_batches_from_stream = sparrow_ipc::deserialize_stream(
                std::span<const uint8_t>(stream_data)
            );

            // Serialize from json with LZ4 compression
            std::vector<uint8_t> serialized_data;
            sparrow_ipc::memory_output_stream stream(serialized_data);
            sparrow_ipc::serializer serializer(stream, sparrow_ipc::CompressionType::LZ4_FRAME);
            serializer << record_batches_from_json << sparrow_ipc::end_stream;

            // Deserialize
            const auto deserialized_serialized_data = sparrow_ipc::deserialize_stream(
                std::span<const uint8_t>(serialized_data)
            );

            // Compare
            compare_record_batches(record_batches_from_stream, deserialized_serialized_data);
        }
    }
}
}
Loading