diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d0d0d9..52387a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,13 +29,12 @@ set(SPARROW_IPC_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) set(SPARROW_IPC_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) set(SPARROW_IPC_HEADERS - # config ${SPARROW_IPC_INCLUDE_DIR}/config/config.hpp - ${SPARROW_IPC_INCLUDE_DIR}/sparrow-ipc.hpp + ${SPARROW_IPC_INCLUDE_DIR}/serialize.hpp ) set(SPARROW_IPC_SRC - ${SPARROW_IPC_SOURCE_DIR}/sparrow-ipc.cpp + ${SPARROW_IPC_SOURCE_DIR}/serialize.cpp ) set(SCHEMA_DIR ${CMAKE_BINARY_DIR}/format) @@ -102,6 +101,13 @@ find_package(sparrow CONFIG REQUIRED) add_library(sparrow-ipc ${SPARROW_IPC_LIBRARY_TYPE} ${SPARROW_IPC_SRC} ${SPARROW_IPC_HEADERS}) target_compile_definitions(sparrow-ipc PUBLIC ${SPARROW_IPC_COMPILE_DEFINITIONS}) + +if(UNIX) + target_compile_options(sparrow-ipc PRIVATE "-fvisibility=hidden") +else() + target_compile_definitions(sparrow-ipc PRIVATE SPARROW_IPC_EXPORTS) +endif() + target_include_directories(sparrow-ipc PUBLIC ${SPARROW_IPC_INCLUDE_DIR} PRIVATE ${SPARROW_IPC_SOURCE_DIR} ) target_link_libraries(sparrow-ipc PRIVATE flatbuffers_interface) target_link_libraries(sparrow-ipc PUBLIC flatbuffers::flatbuffers sparrow::sparrow) diff --git a/include/serialize.hpp b/include/serialize.hpp new file mode 100644 index 0000000..6444530 --- /dev/null +++ b/include/serialize.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include +#include "sparrow.hpp" + +#include "config/config.hpp" + +//TODO split serialize/deserialize fcts in two different files or just rename the current one? +template +SPARROW_IPC_API std::vector serialize_primitive_array(const sparrow::primitive_array& arr); + +template +SPARROW_IPC_API sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer); diff --git a/include/sparrow-ipc.hpp b/include/sparrow-ipc.hpp deleted file mode 100644 index 76f269e..0000000 --- a/include/sparrow-ipc.hpp +++ /dev/null @@ -1,3 +0,0 @@ -#include "../include/config/config.hpp" - -SPARROW_IPC_API void fake_func_for_now(); diff --git a/src/serialize.cpp b/src/serialize.cpp new file mode 100644 index 0000000..be5e84a --- /dev/null +++ b/src/serialize.cpp @@ -0,0 +1,308 @@ +#include +#include +#include +#include +#include +#include + +#include "Message_generated.h" +#include "Schema_generated.h" + +#include "serialize.hpp" + +namespace +{ + // Aligns a value to the next multiple of 8, as required by the Arrow IPC format for message bodies. + int64_t align_to_8(int64_t n) + { + return (n + 7) & -8; + } + + // TODO Complete this with all possible formats? + std::pair> + get_flatbuffer_type(flatbuffers::FlatBufferBuilder& builder, const char* format_str) + { + if (format_str == sparrow::data_type_to_format(sparrow::data_type::INT32)) + { + auto int_type = org::apache::arrow::flatbuf::CreateInt(builder, 32, true); + return {org::apache::arrow::flatbuf::Type::Int, int_type.Union()}; + } + else if (format_str == sparrow::data_type_to_format(sparrow::data_type::FLOAT)) + { + auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, org::apache::arrow::flatbuf::Precision::SINGLE); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + else if (format_str == sparrow::data_type_to_format(sparrow::data_type::DOUBLE)) + { + auto fp_type = org::apache::arrow::flatbuf::CreateFloatingPoint( + builder, org::apache::arrow::flatbuf::Precision::DOUBLE); + return {org::apache::arrow::flatbuf::Type::FloatingPoint, fp_type.Union()}; + } + else + { + throw std::runtime_error("Unsupported data type for serialization"); + } + } +} + +template +std::vector serialize_primitive_array(const sparrow::primitive_array& arr) +{ + // This function serializes a sparrow::primitive_array into a byte vector that is compliant + // with the Apache Arrow IPC Streaming Format. It constructs a stream containing two messages: + // 1. A Schema message: Describes the data's metadata (field name, type, nullability). + // 2. A RecordBatch message: Contains the actual array data (null count, length, and raw buffers). + // This two-part structure makes the data self-describing and readable by other Arrow-native tools. + // The implementation adheres to the specification by correctly handling: + // - Message order (Schema first, then RecordBatch). + // - The encapsulated message format (4-byte metadata length prefix). + // - 8-byte padding and alignment for the message body. + // - Correctly populating the Flatbuffer-defined metadata for both messages. + + // Create a mutable copy of the input array to allow moving its internal structures + sparrow::primitive_array mutable_arr = arr; + auto [arrow_arr, arrow_schema] = sparrow::extract_arrow_structures(std::move(mutable_arr)); + + // This will be the final buffer holding the complete IPC stream. + std::vector final_buffer; + + // I - Serialize the Schema message + // An Arrow IPC stream must start with a Schema message + { + // Create a new builder for the Schema message's metadata + flatbuffers::FlatBufferBuilder schema_builder; + + // Create the Field metadata, which describes a single column (or array) + flatbuffers::Offset fb_name_offset = 0; + if (arrow_schema.name) + { + fb_name_offset = schema_builder.CreateString(arrow_schema.name); + } + + // Determine the Flatbuffer type information from the C schema's format string + auto [type_enum, type_offset] = get_flatbuffer_type(schema_builder, arrow_schema.format); + + // Handle metadata + flatbuffers::Offset>> + fb_metadata_offset = 0; + + if (arr.metadata()) + { + sparrow::key_value_view metadata_view = *(arr.metadata()); + std::vector> kv_offsets; + + auto mv_it = metadata_view.cbegin(); + for (auto i = 0; i < metadata_view.size(); ++i, ++mv_it) + { + auto key_offset = schema_builder.CreateString(std::string((*mv_it).first)); + auto value_offset = schema_builder.CreateString(std::string((*mv_it).second)); + kv_offsets.push_back( + org::apache::arrow::flatbuf::CreateKeyValue(schema_builder, key_offset, value_offset)); + } + fb_metadata_offset = schema_builder.CreateVector(kv_offsets); + } + + // Build the Field object + auto fb_field = org::apache::arrow::flatbuf::CreateField( + schema_builder, + fb_name_offset, + (arrow_schema.flags & static_cast(sparrow::ArrowFlag::NULLABLE)) != 0, + type_enum, + type_offset, + 0, // dictionary + 0, // children + fb_metadata_offset); + + // A Schema contains a vector of fields. For this primitive array, there is only one + std::vector> fields_vec = {fb_field}; + auto fb_fields = schema_builder.CreateVector(fields_vec); + + // Build the Schema object from the vector of fields + auto schema_offset = org::apache::arrow::flatbuf::CreateSchema(schema_builder, org::apache::arrow::flatbuf::Endianness::Little, fb_fields); + + // Wrap the Schema in a top-level Message, which is the standard IPC envelope + auto schema_message_offset = org::apache::arrow::flatbuf::CreateMessage( + schema_builder, + org::apache::arrow::flatbuf::MetadataVersion::V5, + org::apache::arrow::flatbuf::MessageHeader::Schema, + schema_offset.Union(), + 0 + ); + schema_builder.Finish(schema_message_offset); + + // Assemble the Schema message bytes + uint32_t schema_len = schema_builder.GetSize(); // Get the size of the serialized metadata + final_buffer.resize(sizeof(uint32_t) + schema_len); // Resize the buffer to hold the message + // Copy the metadata into the buffer, after the 4-byte length prefix + memcpy(final_buffer.data() + sizeof(uint32_t), schema_builder.GetBufferPointer(), schema_len); + // Write the 4-byte metadata length at the beginning of the message + *(reinterpret_cast(final_buffer.data())) = schema_len; + } + + // II - Serialize the RecordBatch message + // After the Schema, we send the RecordBatch containing the actual data + { + // Create a new builder for the RecordBatch message's metadata + flatbuffers::FlatBufferBuilder batch_builder; + + // arrow_arr.buffers[0] is the validity bitmap + // arrow_arr.buffers[1] is the data buffer + const uint8_t* validity_bitmap = reinterpret_cast(arrow_arr.buffers[0]); + const uint8_t* data_buffer = reinterpret_cast(arrow_arr.buffers[1]); + + // Calculate the size of the validity and data buffers + int64_t validity_size = (arrow_arr.length + 7) / 8; + int64_t data_size = arrow_arr.length * sizeof(T); + int64_t body_len = validity_size + data_size; // The total size of the message body + + // Create Flatbuffer descriptions for the data buffers + org::apache::arrow::flatbuf::Buffer validity_buffer_struct(0, validity_size); + org::apache::arrow::flatbuf::Buffer data_buffer_struct(validity_size, data_size); + // Create the FieldNode, which describes the layout of the array data + org::apache::arrow::flatbuf::FieldNode field_node_struct(arrow_arr.length, arrow_arr.null_count); + + // A RecordBatch contains a vector of nodes and a vector of buffers + auto fb_nodes_vector = batch_builder.CreateVectorOfStructs(&field_node_struct, 1); + std::vector buffers_vec = {validity_buffer_struct, data_buffer_struct}; + auto fb_buffers_vector = batch_builder.CreateVectorOfStructs(buffers_vec); + + // Build the RecordBatch metadata object + auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch(batch_builder, arrow_arr.length, fb_nodes_vector, fb_buffers_vector); + + // Wrap the RecordBatch in a top-level Message + auto batch_message_offset = org::apache::arrow::flatbuf::CreateMessage( + batch_builder, + org::apache::arrow::flatbuf::MetadataVersion::V5, + org::apache::arrow::flatbuf::MessageHeader::RecordBatch, + record_batch_offset.Union(), + body_len + ); + batch_builder.Finish(batch_message_offset); + + // III - Append the RecordBatch message to the final buffer + uint32_t batch_meta_len = batch_builder.GetSize(); // Get the size of the batch metadata + int64_t aligned_batch_meta_len = align_to_8(batch_meta_len); // Calculate the padded length + + size_t current_size = final_buffer.size(); // Get the current size (which is the end of the Schema message) + // Resize the buffer to append the new message + final_buffer.resize(current_size + sizeof(uint32_t) + aligned_batch_meta_len + body_len); + uint8_t* dst = final_buffer.data() + current_size; // Get a pointer to where the new message will start + + // Write the 4-byte metadata length for the RecordBatch message + *(reinterpret_cast(dst)) = batch_meta_len; + dst += sizeof(uint32_t); + // Copy the RecordBatch metadata into the buffer + memcpy(dst, batch_builder.GetBufferPointer(), batch_meta_len); + // Add padding to align the body to an 8-byte boundary + memset(dst + batch_meta_len, 0, aligned_batch_meta_len - batch_meta_len); + dst += aligned_batch_meta_len; + // Copy the actual data buffers (the message body) into the buffer + if (validity_bitmap) + { + memcpy(dst, validity_bitmap, validity_size); + } + else + { + // If validity_bitmap is null, it means there are no nulls + memset(dst, 0xFF, validity_size); + } + dst += validity_size; + if (data_buffer) + { + memcpy(dst, data_buffer, data_size); + } + } + + // Release the memory managed by the C structures + arrow_arr.release(&arrow_arr); + arrow_schema.release(&arrow_schema); + + // Return the final buffer containing the complete IPC stream + return final_buffer; +} + +template +sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer) { + const uint8_t* buf_ptr = buffer.data(); + size_t current_offset = 0; + + // I - Deserialize the Schema message + uint32_t schema_meta_len = *(reinterpret_cast(buf_ptr + current_offset)); + current_offset += sizeof(uint32_t); + auto schema_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset); + if (schema_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::Schema) + { + throw std::runtime_error("Expected Schema message at the start of the buffer."); + } + auto flatbuffer_schema = static_cast(schema_message->header()); + auto fields = flatbuffer_schema->fields(); + if (fields->size() != 1) + { + throw std::runtime_error("Expected schema with exactly one field for primitive_array."); + } + bool is_nullable = fields->Get(0)->nullable(); + current_offset += schema_meta_len; + + // II - Deserialize the RecordBatch message + uint32_t batch_meta_len = *(reinterpret_cast(buf_ptr + current_offset)); + current_offset += sizeof(uint32_t); + auto batch_message = org::apache::arrow::flatbuf::GetMessage(buf_ptr + current_offset); + if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch) + { + throw std::runtime_error("Expected RecordBatch message, but got a different type."); + } + auto record_batch = static_cast(batch_message->header()); + current_offset += align_to_8(batch_meta_len); + const uint8_t* body_ptr = buf_ptr + current_offset; + + // Extract metadata from the RecordBatch + auto buffers_meta = record_batch->buffers(); + auto nodes_meta = record_batch->nodes(); + auto node_meta = nodes_meta->Get(0); + + // The body contains the validity bitmap and the data buffer concatenated + // We need to copy this data into memory owned by the new ArrowArray + int64_t validity_len = buffers_meta->Get(0)->length(); + int64_t data_len = buffers_meta->Get(1)->length(); + + uint8_t* validity_buffer_copy = new uint8_t[validity_len]; + memcpy(validity_buffer_copy, body_ptr + buffers_meta->Get(0)->offset(), validity_len); + + uint8_t* data_buffer_copy = new uint8_t[data_len]; + memcpy(data_buffer_copy, body_ptr + buffers_meta->Get(1)->offset(), data_len); + + // Get name + std::optional name; + const flatbuffers::String* fb_name_flatbuffer = fields->Get(0)->name(); + if (fb_name_flatbuffer) + { + name = std::string_view(fb_name_flatbuffer->c_str(), fb_name_flatbuffer->size()); + } + + // Handle metadata + std::optional> metadata; + auto fb_metadata = fields->Get(0)->custom_metadata(); + if (fb_metadata && !fb_metadata->empty()) + { + metadata = std::vector(); + metadata->reserve(fb_metadata->size()); + for (const auto& kv : *fb_metadata) + { + metadata->emplace_back(kv->key()->c_str(), kv->value()->c_str()); + } + } + + auto data = sparrow::u8_buffer(reinterpret_cast(data_buffer_copy), node_meta->length()); + auto bitmap = sparrow::validity_bitmap(validity_buffer_copy, node_meta->length()); + + return sparrow::primitive_array(std::move(data), node_meta->length(), std::move(bitmap), name, metadata); +} + +// Explicit template instantiation +template SPARROW_IPC_API std::vector serialize_primitive_array(const sparrow::primitive_array& arr); +template SPARROW_IPC_API sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer); +template SPARROW_IPC_API std::vector serialize_primitive_array(const sparrow::primitive_array& arr); +template SPARROW_IPC_API sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer); +template SPARROW_IPC_API std::vector serialize_primitive_array(const sparrow::primitive_array& arr); +template SPARROW_IPC_API sparrow::primitive_array deserialize_primitive_array(const std::vector& buffer); diff --git a/src/sparrow-ipc.cpp b/src/sparrow-ipc.cpp deleted file mode 100644 index 172824c..0000000 --- a/src/sparrow-ipc.cpp +++ /dev/null @@ -1,7 +0,0 @@ -#include "sparrow/sparrow.hpp" - -#include "../include/sparrow-ipc.hpp" -#include "../generated/Schema_generated.h" - -void fake_func_for_now() -{} diff --git a/tests/test.cpp b/tests/test.cpp index 853cc45..d81c9aa 100644 --- a/tests/test.cpp +++ b/tests/test.cpp @@ -1,25 +1,158 @@ #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN -#include "sparrow/sparrow.hpp" +#include +#include +#include +#include + #include "doctest/doctest.h" +#include "sparrow.hpp" + +#include "../include/serialize.hpp" + +using testing_types = std::tuple< + int, + float, + double>; + +template +void compare_bitmap(sparrow::primitive_array& pa1, sparrow::primitive_array& pa2) +{ + const auto pa1_bitmap = pa1.bitmap(); + const auto pa2_bitmap = pa2.bitmap(); + + CHECK_EQ(pa1_bitmap.size(), pa2_bitmap.size()); + auto pa1_it = pa1_bitmap.begin(); + auto pa2_it = pa2_bitmap.begin(); + for (size_t i = 0; i < pa1_bitmap.size(); ++i) + { + CHECK_EQ(*pa1_it, *pa2_it); + ++pa1_it; + ++pa2_it; + } +} + +template +void compare_metadata(sparrow::primitive_array& pa1, sparrow::primitive_array& pa2) +{ + if (!pa1.metadata().has_value()) + { + CHECK(!pa2.metadata().has_value()); + return; + } + + CHECK(pa2.metadata().has_value()); + sparrow::key_value_view kvs1_view = *(pa1.metadata()); + sparrow::key_value_view kvs2_view = *(pa2.metadata()); + + CHECK_EQ(kvs1_view.size(), kvs2_view.size()); + std::vector> kvs1, kvs2; + auto kvs1_it = kvs1_view.cbegin(); + auto kvs2_it = kvs2_view.cbegin(); + for (auto i = 0; i < kvs1_view.size(); ++i) + { + CHECK_EQ(*kvs1_it, *kvs2_it); + ++kvs1_it; + ++kvs2_it; + } +} + +TEST_CASE_TEMPLATE_DEFINE("Serialize and Deserialize primitive_array", T, primitive_array_types) +{ + namespace sp = sparrow; + + auto create_primitive_array = []() -> sp::primitive_array { + if constexpr (std::is_same_v) + { + return {10, 20, 30, 40, 50}; + } + else if constexpr (std::is_same_v) + { + return {10.5f, 20.5f, 30.5f, 40.5f, 50.5f}; + } + else if constexpr (std::is_same_v) + { + return {10.1, 20.2, 30.3, 40.4, 50.5}; + } + else + { + FAIL("Unsupported type for templated test case"); + } + }; + + sp::primitive_array ar = create_primitive_array(); + + std::vector serialized_data = serialize_primitive_array(ar); + + CHECK(serialized_data.size() > 0); + + sp::primitive_array deserialized_ar = deserialize_primitive_array(serialized_data); -#include "../generated/Schema_generated.h" + CHECK_EQ(ar, deserialized_ar); -// NOTE this is just testing sparrow internals usability, -// for now we are not testing anything with serialization/deserialization -TEST_CASE("Use sparrow primitive_array") + compare_bitmap(ar, deserialized_ar); + compare_metadata(ar, deserialized_ar); +} + +TEST_CASE_TEMPLATE_APPLY(primitive_array_types, testing_types); + +TEST_CASE("Serialize and Deserialize primitive_array - int with nulls") +{ + namespace sp = sparrow; + + // Data buffer + sp::u8_buffer data_buffer = {100, 200, 300, 400, 500}; + + // Validity bitmap: 100 (valid), 200 (valid), 300 (null), 400 (valid), 500 (null) + sp::validity_bitmap validity(5, true); // All valid initially + validity.set(2, false); // Set index 2 to null + validity.set(4, false); // Set index 4 to null + + sp::primitive_array ar(std::move(data_buffer), std::move(validity)); + + std::vector serialized_data = serialize_primitive_array(ar); + + CHECK(serialized_data.size() > 0); + + sp::primitive_array deserialized_ar = deserialize_primitive_array(serialized_data); + + CHECK_EQ(ar, deserialized_ar); + + compare_bitmap(ar, deserialized_ar); + compare_metadata(ar, deserialized_ar); +} + +TEST_CASE("Serialize and Deserialize primitive_array - with name and metadata") { namespace sp = sparrow; - sp::primitive_array ar = { 1, 3, 5, 7, 9 }; - CHECK_EQ(ar.size(), 5); + // Data buffer + sp::u8_buffer data_buffer = {1, 2, 3}; + + // Validity bitmap: All valid + sp::validity_bitmap validity(3, true); + + // Custom metadata + std::vector metadata = { + {"key1", "value1"}, + {"key2", "value2"} + }; + + sp::primitive_array ar( + std::move(data_buffer), + std::move(validity), + "my_named_array", // name + std::make_optional(std::vector{{"key1", "value1"}, {"key2", "value2"}}) + ); + + std::vector serialized_data = serialize_primitive_array(ar); + + CHECK(serialized_data.size() > 0); - auto [arrow_array, arrow_schema] = sp::extract_arrow_structures(std::move(ar)); - CHECK_EQ(arrow_array.length, 5); + sp::primitive_array deserialized_ar = deserialize_primitive_array(serialized_data); - // Serialize - // Deserialize + CHECK_EQ(ar, deserialized_ar); - arrow_array.release(&arrow_array); - arrow_schema.release(&arrow_schema); + compare_bitmap(ar, deserialized_ar); + compare_metadata(ar, deserialized_ar); }