diff --git a/CMakeLists.txt b/CMakeLists.txt index 02d2f1e..5769f6b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,11 +108,11 @@ set(SPARROW_IPC_HEADERS ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_serializer.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/config.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/sparrow_ipc_version.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/compression.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_fixedsizebinary_array.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_primitive_array.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_utils.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp - ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/encapsulated_message.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/flatbuffer_utils.hpp @@ -132,6 +132,7 @@ set(SPARROW_IPC_SRC ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_schema.cpp ${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_schema/private_data.cpp ${SPARROW_IPC_SOURCE_DIR}/chunk_memory_serializer.cpp + ${SPARROW_IPC_SOURCE_DIR}/compression.cpp ${SPARROW_IPC_SOURCE_DIR}/deserialize_fixedsizebinary_array.cpp ${SPARROW_IPC_SOURCE_DIR}/deserialize_utils.cpp ${SPARROW_IPC_SOURCE_DIR}/deserialize.cpp @@ -253,6 +254,8 @@ target_link_libraries(sparrow-ipc PUBLIC sparrow::sparrow flatbuffers::flatbuffers + PRIVATE + lz4::lz4 ) # Ensure generated headers are available when building sparrow-ipc @@ -318,6 +321,25 @@ if (TARGET flatbuffers) endif() endif() +if (TARGET lz4) + get_target_property(is_imported lz4 IMPORTED) + if(NOT is_imported) + # This means `lz4` was fetched using FetchContent + # We need to export `lz4` target explicitly + list(APPEND SPARROW_IPC_EXPORTED_TARGETS lz4) + endif() +endif() + +if (TARGET lz4_static) + get_target_property(is_imported lz4_static IMPORTED) + if(NOT is_imported) + # `lz4_static` is needed as this is the actual library + # and `lz4` is an interface pointing to it. + # If `lz4_shared` is used instead for some reason, modify this accordingly + list(APPEND SPARROW_IPC_EXPORTED_TARGETS lz4_static) + endif() +endif() + install(TARGETS ${SPARROW_IPC_EXPORTED_TARGETS} EXPORT ${PROJECT_NAME}-targets) diff --git a/cmake/Findlz4.cmake b/cmake/Findlz4.cmake new file mode 100644 index 0000000..2b9e9c0 --- /dev/null +++ b/cmake/Findlz4.cmake @@ -0,0 +1,42 @@ +# Find LZ4 library and headers + +# This module defines: +# LZ4_FOUND - True if lz4 is found +# LZ4_INCLUDE_DIRS - LZ4 include directories +# LZ4_LIBRARIES - Libraries needed to use LZ4 +# LZ4_VERSION - LZ4 version number +# + +find_package(PkgConfig) +if(PKG_CONFIG_FOUND) + pkg_check_modules(LZ4 QUIET liblz4) + if(NOT LZ4_FOUND) + message(STATUS "Did not find 'liblz4.pc', trying 'lz4.pc'") + pkg_check_modules(LZ4 QUIET lz4) + endif() +endif() + +find_path(LZ4_INCLUDE_DIR lz4.h) +# HINTS ${LZ4_INCLUDEDIR} ${LZ4_INCLUDE_DIRS}) +find_library(LZ4_LIBRARY NAMES lz4 liblz4) +# HINTS ${LZ4_LIBDIR} ${LZ4_LIBRARY_DIRS}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(lz4 DEFAULT_MSG + LZ4_LIBRARY LZ4_INCLUDE_DIR) +mark_as_advanced(LZ4_INCLUDE_DIR LZ4_LIBRARY) + +set(LZ4_LIBRARIES ${LZ4_LIBRARY}) +set(LZ4_INCLUDE_DIRS ${LZ4_INCLUDE_DIR}) + +if(LZ4_FOUND AND NOT TARGET lz4::lz4) + add_library(lz4::lz4 UNKNOWN IMPORTED) + set_target_properties(lz4::lz4 PROPERTIES + IMPORTED_LOCATION "${LZ4_LIBRARIES}" + INTERFACE_INCLUDE_DIRECTORIES "${LZ4_INCLUDE_DIRS}") + if (NOT TARGET LZ4::LZ4 AND TARGET lz4::lz4) + add_library(LZ4::LZ4 ALIAS lz4::lz4) + endif () +endif() + +#TODO add version? diff --git a/cmake/external_dependencies.cmake b/cmake/external_dependencies.cmake index 1d46d8f..9135817 100644 --- a/cmake/external_dependencies.cmake +++ b/cmake/external_dependencies.cmake @@ -11,8 +11,8 @@ endif() function(find_package_or_fetch) set(options) - set(oneValueArgs CONAN_PKG_NAME PACKAGE_NAME GIT_REPOSITORY TAG) - set(multiValueArgs) + set(oneValueArgs CONAN_PKG_NAME PACKAGE_NAME GIT_REPOSITORY TAG SOURCE_SUBDIR) + set(multiValueArgs CMAKE_ARGS) cmake_parse_arguments(PARSE_ARGV 0 arg "${options}" "${oneValueArgs}" "${multiValueArgs}" ) @@ -29,7 +29,14 @@ function(find_package_or_fetch) if(FETCH_DEPENDENCIES_WITH_CMAKE STREQUAL "ON" OR FETCH_DEPENDENCIES_WITH_CMAKE STREQUAL "MISSING") if(NOT ${actual_pkg_name}_FOUND) message(STATUS "📦 Fetching ${arg_PACKAGE_NAME}") - FetchContent_Declare( + # Apply CMAKE_ARGS before fetching + foreach(cmake_arg ${arg_CMAKE_ARGS}) + string(REGEX MATCH "^([^=]+)=(.*)$" _ ${cmake_arg}) + if(CMAKE_MATCH_1) + set(${CMAKE_MATCH_1} ${CMAKE_MATCH_2} CACHE BOOL "" FORCE) + endif() + endforeach() + set(fetch_args ${arg_PACKAGE_NAME} GIT_SHALLOW TRUE GIT_REPOSITORY ${arg_GIT_REPOSITORY} @@ -37,6 +44,10 @@ function(find_package_or_fetch) GIT_PROGRESS TRUE SYSTEM EXCLUDE_FROM_ALL) + if(arg_SOURCE_SUBDIR) + list(APPEND fetch_args SOURCE_SUBDIR ${arg_SOURCE_SUBDIR}) + endif() + FetchContent_Declare(${fetch_args}) FetchContent_MakeAvailable(${arg_PACKAGE_NAME}) message(STATUS "\t✅ Fetched ${arg_PACKAGE_NAME}") else() @@ -79,6 +90,25 @@ if(NOT TARGET flatbuffers::flatbuffers) endif() unset(FLATBUFFERS_BUILD_TESTS CACHE) +# Fetching lz4 +# Disable bundled mode to allow shared libraries if needed +# lz4 is built as static by default if bundled +# set(LZ4_BUNDLED_MODE OFF CACHE BOOL "" FORCE) +# set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) +find_package_or_fetch( + PACKAGE_NAME lz4 + GIT_REPOSITORY https://github.com/lz4/lz4.git + TAG v1.10.0 + SOURCE_SUBDIR build/cmake + CMAKE_ARGS + "LZ4_BUILD_CLI=OFF" + "LZ4_BUILD_LEGACY_LZ4C=OFF" +) + +if(NOT TARGET lz4::lz4) + add_library(lz4::lz4 ALIAS lz4) +endif() + if(SPARROW_IPC_BUILD_TESTS) find_package_or_fetch( PACKAGE_NAME doctest @@ -109,10 +139,18 @@ if(SPARROW_IPC_BUILD_TESTS) ) message(STATUS "\t✅ Fetched arrow-testing") - # Iterate over all the files in the arrow-testing-data source directiory. When it's a gz, extract in place. - file(GLOB_RECURSE arrow_testing_data_targz_files CONFIGURE_DEPENDS + # Fetch all the files in the cpp-21.0.0 directory + file(GLOB_RECURSE arrow_testing_data_targz_files_cpp_21 CONFIGURE_DEPENDS "${arrow-testing_SOURCE_DIR}/data/arrow-ipc-stream/integration/cpp-21.0.0/*.json.gz" ) + # Fetch all the files in the 2.0.0-compression directory + file(GLOB_RECURSE arrow_testing_data_targz_files_compression CONFIGURE_DEPENDS + "${arrow-testing_SOURCE_DIR}/data/arrow-ipc-stream/integration/2.0.0-compression/*.json.gz" + ) + + # Combine lists of files + list(APPEND arrow_testing_data_targz_files ${arrow_testing_data_targz_files_cpp_21} ${arrow_testing_data_targz_files_compression}) + # Iterate over all the files in the arrow-testing-data source directory. When it's a gz, extract in place. foreach(file_path IN LISTS arrow_testing_data_targz_files) cmake_path(GET file_path PARENT_PATH parent_dir) cmake_path(GET file_path STEM filename) @@ -128,5 +166,4 @@ if(SPARROW_IPC_BUILD_TESTS) endif() endif() endforeach() - endif() diff --git a/conanfile.py b/conanfile.py index 59916f8..e2f251a 100644 --- a/conanfile.py +++ b/conanfile.py @@ -45,6 +45,8 @@ def configure(self): def requirements(self): self.requires("sparrow/1.0.0") self.requires(f"flatbuffers/{self._flatbuffers_version}") + self.requires("lz4/1.9.4") + #self.requires("zstd/1.5.5") if self.options.get_safe("build_tests"): self.test_requires("doctest/2.4.12") diff --git a/environment-dev.yml b/environment-dev.yml index 7a3f086..ff84d2a 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -8,8 +8,10 @@ dependencies: - cxx-compiler # Libraries dependencies - flatbuffers + - lz4-c - nlohmann_json - sparrow-devel >=1.1.2 + # Testing dependencies - doctest # Documentation dependencies - doxygen diff --git a/include/sparrow_ipc/arrow_interface/arrow_array.hpp b/include/sparrow_ipc/arrow_interface/arrow_array.hpp index 2f1f72d..4faecf4 100644 --- a/include/sparrow_ipc/arrow_interface/arrow_array.hpp +++ b/include/sparrow_ipc/arrow_interface/arrow_array.hpp @@ -1,34 +1,86 @@ - #pragma once -#include +#include #include +#include #include "sparrow_ipc/config/config.hpp" +#include "sparrow_ipc/arrow_interface/arrow_array/private_data.hpp" namespace sparrow_ipc { - [[nodiscard]] SPARROW_IPC_API ArrowArray make_non_owning_arrow_array( + SPARROW_IPC_API void release_arrow_array_children_and_dictionary(ArrowArray* array); + + template + void arrow_array_release(ArrowArray* array) + { + SPARROW_ASSERT_TRUE(array != nullptr) + SPARROW_ASSERT_TRUE(array->release == std::addressof(arrow_array_release)) + + SPARROW_ASSERT_TRUE(array->private_data != nullptr); + + delete static_cast(array->private_data); + array->private_data = nullptr; + array->buffers = nullptr; // The buffers were deleted with the private data + + release_arrow_array_children_and_dictionary(array); + array->release = nullptr; + } + + template + void fill_arrow_array( + ArrowArray& array, int64_t length, int64_t null_count, int64_t offset, - std::vector&& buffers, size_t children_count, ArrowArray** children, - ArrowArray* dictionary - ); + ArrowArray* dictionary, + Arg&& private_data_arg + ) + { + SPARROW_ASSERT_TRUE(length >= 0); + SPARROW_ASSERT_TRUE(null_count >= -1); + SPARROW_ASSERT_TRUE(offset >= 0); - SPARROW_IPC_API void release_non_owning_arrow_array(ArrowArray* array); + array.length = length; + array.null_count = null_count; + array.offset = offset; + array.n_children = static_cast(children_count); + array.children = children; + array.dictionary = dictionary; - SPARROW_IPC_API void fill_non_owning_arrow_array( - ArrowArray& array, + auto private_data = new T(std::forward(private_data_arg)); + array.private_data = private_data; + array.n_buffers = private_data->n_buffers(); + array.buffers = private_data->buffers_ptrs(); + + array.release = &arrow_array_release; + } + + template + [[nodiscard]] ArrowArray make_arrow_array( int64_t length, int64_t null_count, int64_t offset, - std::vector&& buffers, size_t children_count, ArrowArray** children, - ArrowArray* dictionary - ); -} \ No newline at end of file + ArrowArray* dictionary, + Arg&& private_data_arg + ) + { + ArrowArray array{}; + fill_arrow_array( + array, + length, + null_count, + offset, + children_count, + children, + dictionary, + std::forward(private_data_arg) + ); + return array; + } +} diff --git a/include/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp b/include/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp index 90e633f..5ad6c90 100644 --- a/include/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp +++ b/include/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp @@ -1,5 +1,5 @@ #pragma once - +#include #include #include @@ -7,19 +7,40 @@ namespace sparrow_ipc { + template + concept ArrowPrivateData = requires(T& t) + { + { t.buffers_ptrs() } -> std::same_as; + { t.n_buffers() } -> std::convertible_to; + }; + + class owning_arrow_array_private_data + { + public: + + explicit owning_arrow_array_private_data(std::vector>&& buffers); + + [[nodiscard]] SPARROW_IPC_API const void** buffers_ptrs() noexcept; + [[nodiscard]] SPARROW_IPC_API std::size_t n_buffers() const noexcept; + + private: + std::vector> m_buffers; + std::vector m_buffer_pointers; + }; + class non_owning_arrow_array_private_data { public: explicit constexpr non_owning_arrow_array_private_data(std::vector&& buffers_pointers) - : m_buffers_pointers(std::move(buffers_pointers)) + : m_buffer_pointers(std::move(buffers_pointers)) { } [[nodiscard]] SPARROW_IPC_API const void** buffers_ptrs() noexcept; + [[nodiscard]] SPARROW_IPC_API std::size_t n_buffers() const noexcept; private: - - std::vector m_buffers_pointers; + std::vector m_buffer_pointers; }; } diff --git a/include/sparrow_ipc/chunk_memory_serializer.hpp b/include/sparrow_ipc/chunk_memory_serializer.hpp index 3b241f8..4868a42 100644 --- a/include/sparrow_ipc/chunk_memory_serializer.hpp +++ b/include/sparrow_ipc/chunk_memory_serializer.hpp @@ -1,7 +1,16 @@ #pragma once +#include +#include +#include +#include +#include + #include +#include "Message_generated.h" + +#include "sparrow_ipc/any_output_stream.hpp" #include "sparrow_ipc/chunk_memory_output_stream.hpp" #include "sparrow_ipc/config/config.hpp" #include "sparrow_ipc/memory_output_stream.hpp" @@ -33,8 +42,10 @@ namespace sparrow_ipc * @brief Constructs a chunk serializer with a reference to a chunked memory output stream. * * @param stream Reference to a chunked memory output stream that will receive the serialized chunks + * @param compression Optional: The compression type to use for record batch bodies. */ - chunk_serializer(chunked_memory_output_stream>>& stream); + // TODO Use enums and such to avoid including flatbuffers headers + chunk_serializer(chunked_memory_output_stream>>& stream, std::optional compression = std::nullopt); /** * @brief Writes a single record batch to the chunked stream. @@ -120,6 +131,7 @@ namespace sparrow_ipc std::vector m_dtypes; chunked_memory_output_stream>>* m_pstream; bool m_ended{false}; + std::optional m_compression; }; // Implementation @@ -148,10 +160,14 @@ namespace sparrow_ipc for (const auto& rb : record_batches) { + if (get_column_dtypes(rb) != m_dtypes) + { + throw std::invalid_argument("Record batch schema does not match serializer schema"); + } std::vector buffer; memory_output_stream stream(buffer); any_output_stream astream(stream); - serialize_record_batch(rb, astream); + serialize_record_batch(rb, astream, m_compression); m_pstream->write(std::move(buffer)); } } @@ -169,4 +185,4 @@ namespace sparrow_ipc write(record_batches); return *this; } -} \ No newline at end of file +} diff --git a/include/sparrow_ipc/compression.hpp b/include/sparrow_ipc/compression.hpp new file mode 100644 index 0000000..96b47ec --- /dev/null +++ b/include/sparrow_ipc/compression.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include +#include + +#include "Message_generated.h" + +#include "sparrow_ipc/config/config.hpp" + +namespace sparrow_ipc +{ +// TODO use these later if needed for wrapping purposes (flatbuffers/lz4) +// enum class CompressionType +// { +// NONE, +// LZ4, +// ZSTD +// }; + +// CompressionType to_compression_type(org::apache::arrow::flatbuf::CompressionType compression_type); + + constexpr auto CompressionHeaderSize = sizeof(std::int64_t); + + [[nodiscard]] SPARROW_IPC_API std::vector compress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span data); + [[nodiscard]] SPARROW_IPC_API std::variant, std::span> decompress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span data); +} diff --git a/include/sparrow_ipc/deserialize_primitive_array.hpp b/include/sparrow_ipc/deserialize_primitive_array.hpp index a1c5dad..76f7212 100644 --- a/include/sparrow_ipc/deserialize_primitive_array.hpp +++ b/include/sparrow_ipc/deserialize_primitive_array.hpp @@ -34,28 +34,45 @@ namespace sparrow_ipc nullptr, nullptr ); - const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count( - record_batch, - body, - buffer_index++ - ); - const auto primitive_buffer_metadata = record_batch.buffers()->Get(buffer_index++); - if (body.size() < (primitive_buffer_metadata->offset() + primitive_buffer_metadata->length())) + + const auto compression = record_batch.compression(); + std::vector> decompressed_buffers; + + auto validity_buffer_span = utils::get_and_decompress_buffer(record_batch, body, buffer_index, compression, decompressed_buffers); + + const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length()); + + auto data_buffer_span = utils::get_and_decompress_buffer(record_batch, body, buffer_index, compression, decompressed_buffers); + + ArrowArray array; + if (compression) { - throw std::runtime_error("Primitive buffer exceeds body size"); + array = make_arrow_array( + record_batch.length(), + null_count, + 0, + 0, + nullptr, + nullptr, + std::move(decompressed_buffers) + ); } - auto primitives_ptr = const_cast(body.data() + primitive_buffer_metadata->offset()); - std::vector buffers = {bitmap_ptr, primitives_ptr}; - ArrowArray array = make_non_owning_arrow_array( - record_batch.length(), - null_count, - 0, - std::move(buffers), - 0, - nullptr, - nullptr - ); + else + { + auto primitives_ptr = const_cast(data_buffer_span.data()); + std::vector buffers = {bitmap_ptr, primitives_ptr}; + array = make_arrow_array( + record_batch.length(), + null_count, + 0, + 0, + nullptr, + nullptr, + std::move(buffers) + ); + } + sparrow::arrow_proxy ap{std::move(array), std::move(schema)}; return sparrow::primitive_array{std::move(ap)}; } -} \ No newline at end of file +} diff --git a/include/sparrow_ipc/deserialize_utils.hpp b/include/sparrow_ipc/deserialize_utils.hpp index fc1ca05..36f93ad 100644 --- a/include/sparrow_ipc/deserialize_utils.hpp +++ b/include/sparrow_ipc/deserialize_utils.hpp @@ -2,15 +2,35 @@ #include #include +#include #include #include #include "Message_generated.h" -#include "Schema_generated.h" namespace sparrow_ipc::utils { + /** + * @brief Extracts bitmap pointer and null count from a validity buffer span. + * + * This function calculates the number of null values represented by the bitmap. + * + * @param validity_buffer_span The validity buffer as a byte span. + * @param length The Arrow RecordBatch length (number of values in the array). + * + * @return A pair containing: + * - First: Pointer to the bitmap data (nullptr if buffer is empty) + * - Second: Count of null values in the bitmap (0 if buffer is empty) + * + * @note If the bitmap buffer is empty, returns {nullptr, 0} + * @note The returned pointer is a non-const cast of the original const data + */ + [[nodiscard]] std::pair get_bitmap_pointer_and_null_count( + std::span validity_buffer_span, + const int64_t length + ); + /** * @brief Extracts bitmap pointer and null count from a RecordBatch buffer. * @@ -28,9 +48,35 @@ namespace sparrow_ipc::utils * @note If the bitmap buffer has zero length, returns {nullptr, 0} * @note The returned pointer is a non-const cast of the original const data */ + // TODO to be removed when not used anymore (after adding compression to deserialize_fixedsizebinary_array) [[nodiscard]] std::pair get_bitmap_pointer_and_null_count( const org::apache::arrow::flatbuf::RecordBatch& record_batch, std::span body, size_t index ); -} \ No newline at end of file + + /** + * @brief Extracts a buffer from a RecordBatch and decompresses it if necessary. + * + * This function retrieves a buffer span from the specified index, increments the index, + * and applies decompression if specified. If the buffer is decompressed, the new + * data is stored in `decompressed_storage` and the returned span will point to this new data. + * + * @param record_batch The Arrow RecordBatch containing buffer metadata. + * @param body The raw buffer data as a byte span. + * @param buffer_index The index of the buffer to retrieve. This value is incremented by the function. + * @param compression The compression algorithm to use. If nullptr, no decompression is performed. + * @param decompressed_storage A vector that will be used to store the data of any decompressed buffers. + * + * @return A span viewing the resulting buffer data. This will be a view of the original + * `body` if no decompression occurs, or a view of the newly added buffer in + * `decompressed_storage` if decompression occurs. + */ + [[nodiscard]] std::span get_and_decompress_buffer( + const org::apache::arrow::flatbuf::RecordBatch& record_batch, + std::span body, + size_t& buffer_index, + const org::apache::arrow::flatbuf::BodyCompression* compression, + std::vector>& decompressed_storage + ); +} diff --git a/include/sparrow_ipc/deserialize_variable_size_binary_array.hpp b/include/sparrow_ipc/deserialize_variable_size_binary_array.hpp index f6a5729..623776d 100644 --- a/include/sparrow_ipc/deserialize_variable_size_binary_array.hpp +++ b/include/sparrow_ipc/deserialize_variable_size_binary_array.hpp @@ -31,35 +31,47 @@ namespace sparrow_ipc nullptr, nullptr ); - const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count( - record_batch, - body, - buffer_index++ - ); - const auto offset_metadata = record_batch.buffers()->Get(buffer_index++); - if ((offset_metadata->offset() + offset_metadata->length()) > body.size()) + const auto compression = record_batch.compression(); + std::vector> decompressed_buffers; + + auto validity_buffer_span = utils::get_and_decompress_buffer(record_batch, body, buffer_index, compression, decompressed_buffers); + + const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length()); + + auto offset_buffer_span = utils::get_and_decompress_buffer(record_batch, body, buffer_index, compression, decompressed_buffers); + auto data_buffer_span = utils::get_and_decompress_buffer(record_batch, body, buffer_index, compression, decompressed_buffers); + + ArrowArray array; + if (compression) { - throw std::runtime_error("Offset buffer exceeds body size"); + array = make_arrow_array( + record_batch.length(), + null_count, + 0, + 0, + nullptr, + nullptr, + std::move(decompressed_buffers) + ); } - auto offset_ptr = const_cast(body.data() + offset_metadata->offset()); - const auto buffer_metadata = record_batch.buffers()->Get(buffer_index++); - if ((buffer_metadata->offset() + buffer_metadata->length()) > body.size()) + else { - throw std::runtime_error("Data buffer exceeds body size"); + auto offset_ptr = const_cast(offset_buffer_span.data()); + auto buffer_ptr = const_cast(data_buffer_span.data()); + std::vector buffers = {bitmap_ptr, offset_ptr, buffer_ptr}; + array = make_arrow_array( + record_batch.length(), + null_count, + 0, + 0, + nullptr, + nullptr, + std::move(buffers) + ); } - auto buffer_ptr = const_cast(body.data() + buffer_metadata->offset()); - std::vector buffers = {bitmap_ptr, offset_ptr, buffer_ptr}; - ArrowArray array = make_non_owning_arrow_array( - record_batch.length(), - null_count, - 0, - std::move(buffers), - 0, - nullptr, - nullptr - ); + sparrow::arrow_proxy ap{std::move(array), std::move(schema)}; return T{std::move(ap)}; } -} \ No newline at end of file +} diff --git a/include/sparrow_ipc/flatbuffer_utils.hpp b/include/sparrow_ipc/flatbuffer_utils.hpp index 87c322d..b26f0ed 100644 --- a/include/sparrow_ipc/flatbuffer_utils.hpp +++ b/include/sparrow_ipc/flatbuffer_utils.hpp @@ -213,15 +213,20 @@ namespace sparrow_ipc * format that conforms to the Arrow IPC specification. * * @param record_batch The source record batch containing the data to be serialized - * + * @param compression Optional: The compression algorithm to be used for the message body + * @param body_size Optional: An override for the total size of the message body + * If not provided, the size is calculated from the uncompressed buffers + * This is required when using compression + * @param compressed_buffers Optional: A pointer to a vector of buffer metadata. + * If provided, this metadata is used instead of generating it from the + * uncompressed record batch. This is required when using compression. * @return A FlatBufferBuilder containing the complete serialized message ready for * transmission or storage. The builder is finished and ready to be accessed * via GetBufferPointer() and GetSize(). * * @note The returned message uses Arrow IPC format version V5 - * @note Compression and variadic buffer counts are not currently implemented (set to 0) - * @note The body size is automatically calculated based on the record batch contents + * @note Variadic buffer counts is not currently implemented (set to 0) */ [[nodiscard]] flatbuffers::FlatBufferBuilder - get_record_batch_message_builder(const sparrow::record_batch& record_batch); -} \ No newline at end of file + get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional compression = std::nullopt, std::optional body_size = std::nullopt, const std::vector* compressed_buffers = nullptr); +} diff --git a/include/sparrow_ipc/serialize.hpp b/include/sparrow_ipc/serialize.hpp index 4a18e57..ab47646 100644 --- a/include/sparrow_ipc/serialize.hpp +++ b/include/sparrow_ipc/serialize.hpp @@ -26,6 +26,7 @@ namespace sparrow_ipc * @param record_batches Collection of record batches to serialize. All batches must have identical * schemas. * @param stream The output stream where the serialized data will be written. + * @param compression The compression type to use when serializing. * * @throws std::invalid_argument If record batches have inconsistent schemas or if the collection * contains batches that cannot be serialized together. @@ -35,7 +36,7 @@ namespace sparrow_ipc */ template requires std::same_as, sparrow::record_batch> - void serialize_record_batches_to_ipc_stream(const R& record_batches, any_output_stream& stream) + void serialize_record_batches_to_ipc_stream(const R& record_batches, any_output_stream& stream, std::optional compression) { if (record_batches.empty()) { @@ -51,7 +52,7 @@ namespace sparrow_ipc serialize_schema_message(record_batches[0], stream); for (const auto& rb : record_batches) { - serialize_record_batch(rb, stream); + serialize_record_batch(rb, stream, compression); } stream.write(end_of_stream); } @@ -68,13 +69,14 @@ namespace sparrow_ipc * * @param record_batch The sparrow record batch to serialize * @param stream The output stream where the serialized record batch will be written + * @param compression The compression type to use when serializing. * * @note The output follows Arrow IPC message format with proper alignment and * includes both metadata and data portions of the record batch */ SPARROW_IPC_API void - serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream); + serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression); /** * @brief Serializes a schema message for a record batch into a byte buffer. diff --git a/include/sparrow_ipc/serialize_utils.hpp b/include/sparrow_ipc/serialize_utils.hpp index ae881a5..5129b73 100644 --- a/include/sparrow_ipc/serialize_utils.hpp +++ b/include/sparrow_ipc/serialize_utils.hpp @@ -40,9 +40,10 @@ namespace sparrow_ipc * * @param record_batch The sparrow record batch to be serialized * @param stream The output stream where the serialized record batch will be written + * @param compression The compression type to use when serializing */ SPARROW_IPC_API void - serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream); + serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression); /** * @brief Calculates the total serialized size of a schema message. @@ -72,10 +73,11 @@ namespace sparrow_ipc * - Body data with 8-byte alignment between buffers * * @param record_batch The record batch to be measured + * @param compression The compression type to use when serializing * @return The total size in bytes that the serialized record batch would occupy */ [[nodiscard]] SPARROW_IPC_API std::size_t - calculate_record_batch_message_size(const sparrow::record_batch& record_batch); + calculate_record_batch_message_size(const sparrow::record_batch& record_batch, std::optional compression = std::nullopt); /** * @brief Calculates the total serialized size for a collection of record batches. @@ -85,12 +87,13 @@ namespace sparrow_ipc * * @tparam R Range type containing sparrow::record_batch objects * @param record_batches Collection of record batches to be measured + * @param compression The compression type to use when serializing * @return The total size in bytes for the complete serialized output * @throws std::invalid_argument if record batches have inconsistent schemas */ template requires std::same_as, sparrow::record_batch> - [[nodiscard]] std::size_t calculate_total_serialized_size(const R& record_batches) + [[nodiscard]] std::size_t calculate_total_serialized_size(const R& record_batches, std::optional compression = std::nullopt) { if (record_batches.empty()) { @@ -109,12 +112,30 @@ namespace sparrow_ipc // Calculate record batch message sizes for (const auto& record_batch : record_batches) { - total_size += calculate_record_batch_message_size(record_batch); + total_size += calculate_record_batch_message_size(record_batch, compression); } return total_size; } + /** + * @brief Generates the compressed message body and buffer metadata for a record batch. + * + * This function traverses the record batch, compresses each buffer using the specified + * compression algorithm, and constructs the message body. For each compressed buffer, + * it is prefixed by its 8-byte uncompressed size. Padding is added after each + * compressed buffer to ensure 8-byte alignment. + * + * @param record_batch The record batch to serialize. + * @param compression_type The compression algorithm to use (e.g., LZ4_FRAME, ZSTD). + * @return A std::pair containing: + * - first: A vector of bytes representing the complete compressed message body. + * - second: A vector of FlatBuffer Buffer objects describing the offset and + * size of each buffer within the compressed body. + */ + [[nodiscard]] SPARROW_IPC_API std::pair, std::vector> + generate_compressed_body_and_buffers(const sparrow::record_batch& record_batch, const org::apache::arrow::flatbuf::CompressionType compression_type); + /** * @brief Fills the body vector with serialized data from an arrow proxy and its children. * diff --git a/include/sparrow_ipc/serializer.hpp b/include/sparrow_ipc/serializer.hpp index 9a8c1e0..f0ebcb9 100644 --- a/include/sparrow_ipc/serializer.hpp +++ b/include/sparrow_ipc/serializer.hpp @@ -41,8 +41,8 @@ namespace sparrow_ipc * The serializer stores a pointer to this stream for later use. */ template - serializer(TStream& stream) - : m_stream(stream) + serializer(TStream& stream, std::optional compression = std::nullopt) + : m_stream(stream), m_compression(compression) { } @@ -94,7 +94,7 @@ namespace sparrow_ipc m_stream.size(), [this](size_t acc, const sparrow::record_batch& rb) { - return acc + calculate_record_batch_message_size(rb); + return acc + calculate_record_batch_message_size(rb, m_compression); } ) + (m_schema_received ? 0 : calculate_schema_message_size(*record_batches.begin())); @@ -115,7 +115,7 @@ namespace sparrow_ipc { throw std::invalid_argument("Record batch schema does not match serializer schema"); } - serialize_record_batch(rb, m_stream); + serialize_record_batch(rb, m_stream, m_compression); } } @@ -206,6 +206,7 @@ namespace sparrow_ipc std::vector m_dtypes; any_output_stream m_stream; bool m_ended{false}; + std::optional m_compression; }; inline serializer& end_stream(serializer& serializer) diff --git a/src/arrow_interface/arrow_array.cpp b/src/arrow_interface/arrow_array.cpp index ed0a0f2..a01006b 100644 --- a/src/arrow_interface/arrow_array.cpp +++ b/src/arrow_interface/arrow_array.cpp @@ -1,73 +1,40 @@ #include "sparrow_ipc/arrow_interface/arrow_array.hpp" -#include - #include -#include - -#include "sparrow_ipc/arrow_interface/arrow_array/private_data.hpp" -#include "sparrow_ipc/arrow_interface/arrow_array_schema_common_release.hpp" namespace sparrow_ipc { - void release_non_owning_arrow_array(ArrowArray* array) - { - SPARROW_ASSERT_FALSE(array == nullptr) - SPARROW_ASSERT_TRUE(array->release == std::addressof(release_non_owning_arrow_array)) - - release_common_non_owning_arrow(*array); - array->buffers = nullptr; // The buffers were deleted with the private data - } - - void fill_non_owning_arrow_array( - ArrowArray& array, - int64_t length, - int64_t null_count, - int64_t offset, - std::vector&& buffers, - size_t children_count, - ArrowArray** children, - ArrowArray* dictionary - ) - { - SPARROW_ASSERT_TRUE(length >= 0); - SPARROW_ASSERT_TRUE(null_count >= -1); - SPARROW_ASSERT_TRUE(offset >= 0); - - array.length = length; - array.null_count = null_count; - array.offset = offset; - array.n_buffers = static_cast(buffers.size()); - array.private_data = new non_owning_arrow_array_private_data(std::move(buffers)); - const auto private_data = static_cast(array.private_data); - array.buffers = private_data->buffers_ptrs(); - array.n_children = static_cast(children_count); - array.children = children; - array.dictionary = dictionary; - array.release = release_non_owning_arrow_array; - } - - ArrowArray make_non_owning_arrow_array( - int64_t length, - int64_t null_count, - int64_t offset, - std::vector&& buffers, - size_t children_count, - ArrowArray** children, - ArrowArray* dictionary - ) + void release_arrow_array_children_and_dictionary(ArrowArray* array) { - ArrowArray array{}; - fill_non_owning_arrow_array( - array, - length, - null_count, - offset, - std::move(buffers), - children_count, - children, - dictionary - ); - return array; + SPARROW_ASSERT_TRUE(array != nullptr) + + if (array->children) + { + for (int64_t i = 0; i < array->n_children; ++i) + { + ArrowArray* child = array->children[i]; + if (child) + { + if (child->release) + { + child->release(child); + } + delete child; + child = nullptr; + } + } + delete[] array->children; + array->children = nullptr; + } + + if (array->dictionary) + { + if (array->dictionary->release) + { + array->dictionary->release(array->dictionary); + } + delete array->dictionary; + array->dictionary = nullptr; + } } } diff --git a/src/arrow_interface/arrow_array/private_data.cpp b/src/arrow_interface/arrow_array/private_data.cpp index b133c8e..9c3738b 100644 --- a/src/arrow_interface/arrow_array/private_data.cpp +++ b/src/arrow_interface/arrow_array/private_data.cpp @@ -2,8 +2,33 @@ namespace sparrow_ipc { + owning_arrow_array_private_data::owning_arrow_array_private_data(std::vector>&& buffers) + : m_buffers(std::move(buffers)) + { + m_buffer_pointers.reserve(m_buffers.size()); + for (const auto& buffer : m_buffers) + { + m_buffer_pointers.push_back(buffer.data()); + } + } + + const void** owning_arrow_array_private_data::buffers_ptrs() noexcept + { + return m_buffer_pointers.data(); + } + + std::size_t owning_arrow_array_private_data::n_buffers() const noexcept + { + return m_buffers.size(); + } + const void** non_owning_arrow_array_private_data::buffers_ptrs() noexcept { - return const_cast(reinterpret_cast(m_buffers_pointers.data())); + return const_cast(reinterpret_cast(m_buffer_pointers.data())); + } + + std::size_t non_owning_arrow_array_private_data::n_buffers() const noexcept + { + return m_buffer_pointers.size(); } -} \ No newline at end of file +} diff --git a/src/chunk_memory_serializer.cpp b/src/chunk_memory_serializer.cpp index cbdfb4a..db2c8a2 100644 --- a/src/chunk_memory_serializer.cpp +++ b/src/chunk_memory_serializer.cpp @@ -6,8 +6,8 @@ namespace sparrow_ipc { - chunk_serializer::chunk_serializer(chunked_memory_output_stream>>& stream) - : m_pstream(&stream) + chunk_serializer::chunk_serializer(chunked_memory_output_stream>>& stream, std::optional compression) + : m_pstream(&stream), m_compression(compression) { } diff --git a/src/compression.cpp b/src/compression.cpp new file mode 100644 index 0000000..c8ad598 --- /dev/null +++ b/src/compression.cpp @@ -0,0 +1,154 @@ +#include + +#include + +#include "sparrow_ipc/compression.hpp" + +namespace sparrow_ipc +{ +// CompressionType to_compression_type(org::apache::arrow::flatbuf::CompressionType compression_type) +// { +// switch (compression_type) +// { +// case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME: +// return CompressionType::LZ4; +// // case org::apache::arrow::flatbuf::CompressionType::ZSTD: +// // // TODO: Add ZSTD support +// // break; +// default: +// return CompressionType::NONE; +// } +// } + + namespace + { + std::vector lz4_compress(std::span data) + { + const std::int64_t uncompressed_size = data.size(); + const size_t max_compressed_size = LZ4F_compressFrameBound(uncompressed_size, nullptr); + std::vector compressed_data(max_compressed_size); + const size_t compressed_size = LZ4F_compressFrame(compressed_data.data(), max_compressed_size, data.data(), uncompressed_size, nullptr); + if (LZ4F_isError(compressed_size)) + { + throw std::runtime_error("Failed to compress data with LZ4 frame format"); + } + compressed_data.resize(compressed_size); + return compressed_data; + } + + std::vector lz4_decompress(std::span data, const std::int64_t decompressed_size) + { + std::vector decompressed_data(decompressed_size); + LZ4F_dctx* dctx = nullptr; + LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION); + size_t compressed_size_in_out = data.size(); + size_t decompressed_size_in_out = decompressed_size; + size_t result = LZ4F_decompress(dctx, decompressed_data.data(), &decompressed_size_in_out, data.data(), &compressed_size_in_out, nullptr); + if (LZ4F_isError(result) || (decompressed_size_in_out != (size_t)decompressed_size)) + { + throw std::runtime_error("Failed to decompress data with LZ4 frame format"); + } + LZ4F_freeDecompressionContext(dctx); + return decompressed_data; + } + + // TODO These functions could be moved to serialize_utils and deserialize_utils if preferred + // as they are handling the header size + std::vector uncompressed_data_with_header(std::span data) + { + std::vector result; + result.reserve(CompressionHeaderSize + data.size()); + const std::int64_t header = -1; + result.insert(result.end(), reinterpret_cast(&header), reinterpret_cast(&header) + sizeof(header)); + result.insert(result.end(), data.begin(), data.end()); + return result; + } + + std::vector lz4_compress_with_header(std::span data) + { + const std::int64_t original_size = data.size(); + auto compressed_body = lz4_compress(data); + + if (compressed_body.size() >= static_cast(original_size)) + { + return uncompressed_data_with_header(data); + } + + std::vector result; + result.reserve(CompressionHeaderSize + compressed_body.size()); + result.insert(result.end(), reinterpret_cast(&original_size), reinterpret_cast(&original_size) + sizeof(original_size)); + result.insert(result.end(), compressed_body.begin(), compressed_body.end()); + return result; + } + + std::variant, std::span> lz4_decompress_with_header(std::span data) + { + if (data.size() < CompressionHeaderSize) + { + throw std::runtime_error("Invalid compressed data: missing decompressed size"); + } + const std::int64_t decompressed_size = *reinterpret_cast(data.data()); + const auto compressed_data = data.subspan(CompressionHeaderSize); + + if (decompressed_size == -1) + { + return compressed_data; + } + + return lz4_decompress(compressed_data, decompressed_size); + } + + std::span get_body_from_uncompressed_data(std::span data) + { + if (data.size() < CompressionHeaderSize) + { + throw std::runtime_error("Invalid data: missing header"); + } + return data.subspan(CompressionHeaderSize); + } + } + + std::vector compress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span data) + { + switch (compression_type) + { + case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME: + { + return lz4_compress_with_header(data); + } + case org::apache::arrow::flatbuf::CompressionType::ZSTD: + { + throw std::runtime_error("Compression using zstd is not supported yet."); + } + default: + return uncompressed_data_with_header(data); + } + } + + std::variant, std::span> decompress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span data) + { + // Handle empty input: an empty span is a valid representation for an empty buffer + // (e.g., a validity bitmap for a column with no nulls) and should decompress to an empty output. + // TODO if we don't call this fct anymore on validity buffers, remove this empty data handling + if (data.empty()) + { + return {}; + } + + switch (compression_type) + { + case org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME: + { + return lz4_decompress_with_header(data); + } + case org::apache::arrow::flatbuf::CompressionType::ZSTD: + { + throw std::runtime_error("Decompression using zstd is not supported yet."); + } + default: + { + return get_body_from_uncompressed_data(data); + } + } + } +} diff --git a/src/deserialize.cpp b/src/deserialize.cpp index 5779ca9..55f863d 100644 --- a/src/deserialize.cpp +++ b/src/deserialize.cpp @@ -49,7 +49,6 @@ namespace sparrow_ipc const std::vector>>& field_metadata ) { - const size_t length = static_cast(record_batch.length()); size_t buffer_index = 0; std::vector arrays; @@ -277,4 +276,4 @@ namespace sparrow_ipc } while (!data.empty()); return record_batches; } -} \ No newline at end of file +} diff --git a/src/deserialize_fixedsizebinary_array.cpp b/src/deserialize_fixedsizebinary_array.cpp index 63ea213..427f600 100644 --- a/src/deserialize_fixedsizebinary_array.cpp +++ b/src/deserialize_fixedsizebinary_array.cpp @@ -2,6 +2,7 @@ namespace sparrow_ipc { + // TODO add compression here and tests (not available for this type in apache arrow integration tests files) sparrow::fixed_width_binary_array deserialize_non_owning_fixedwidthbinary( const org::apache::arrow::flatbuf::RecordBatch& record_batch, std::span body, @@ -33,14 +34,14 @@ namespace sparrow_ipc } auto buffer_ptr = const_cast(body.data() + buffer_metadata->offset()); std::vector buffers = {bitmap_ptr, buffer_ptr}; - ArrowArray array = make_non_owning_arrow_array( + ArrowArray array = make_arrow_array( record_batch.length(), null_count, 0, - std::move(buffers), 0, nullptr, - nullptr + nullptr, + std::move(buffers) ); sparrow::arrow_proxy ap{std::move(array), std::move(schema)}; return sparrow::fixed_width_binary_array{std::move(ap)}; diff --git a/src/deserialize_utils.cpp b/src/deserialize_utils.cpp index d89be6c..a3ea3a0 100644 --- a/src/deserialize_utils.cpp +++ b/src/deserialize_utils.cpp @@ -1,7 +1,26 @@ #include "sparrow_ipc/deserialize_utils.hpp" +#include "sparrow_ipc/compression.hpp" + namespace sparrow_ipc::utils { + std::pair get_bitmap_pointer_and_null_count( + std::span validity_buffer_span, + const int64_t length + ) + { + if (validity_buffer_span.empty()) + { + return {nullptr, 0}; + } + auto ptr = const_cast(validity_buffer_span.data()); + const sparrow::dynamic_bitset_view bitmap_view{ + ptr, + static_cast(length) + }; + return {ptr, bitmap_view.null_count()}; + } + std::pair get_bitmap_pointer_and_null_count( const org::apache::arrow::flatbuf::RecordBatch& record_batch, std::span body, @@ -24,4 +43,47 @@ namespace sparrow_ipc::utils }; return {ptr, bitmap_view.null_count()}; } -} \ No newline at end of file + + std::span get_and_decompress_buffer( + const org::apache::arrow::flatbuf::RecordBatch& record_batch, + std::span body, + size_t& buffer_index, + const org::apache::arrow::flatbuf::BodyCompression* compression, + std::vector>& decompressed_storage + ) + { + const auto buffer_metadata = record_batch.buffers()->Get(buffer_index++); + if (body.size() < (buffer_metadata->offset() + buffer_metadata->length())) + { + throw std::runtime_error("Buffer metadata exceeds body size"); + } + auto buffer_span = body.subspan(buffer_metadata->offset(), buffer_metadata->length()); + + if (compression) + { + auto decompressed_result = decompress(compression->codec(), buffer_span); + return std::visit( + [&decompressed_storage](auto&& arg) -> std::span + { + using T = std::decay_t; + if constexpr (std::is_same_v>) + { + // Decompression happened + decompressed_storage.emplace_back(std::move(arg)); + return decompressed_storage.back(); + } + else // It's a std::span + { + // No decompression, but we are in a compression context, so we must copy the buffer + // to ensure the owning ArrowArray has access to all its data. + // TODO think about other strategies + decompressed_storage.emplace_back(arg.begin(), arg.end()); + return decompressed_storage.back(); + } + }, + decompressed_result + ); + } + return buffer_span; + } +} diff --git a/src/flatbuffer_utils.cpp b/src/flatbuffer_utils.cpp index 9f510b7..1e580f3 100644 --- a/src/flatbuffer_utils.cpp +++ b/src/flatbuffer_utils.cpp @@ -562,23 +562,28 @@ namespace sparrow_ipc return buffers; } - flatbuffers::FlatBufferBuilder get_record_batch_message_builder(const sparrow::record_batch& record_batch) + flatbuffers::FlatBufferBuilder get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional compression, std::optional body_size_override, const std::vector* compressed_buffers) { const std::vector nodes = create_fieldnodes(record_batch); - const std::vector buffers = get_buffers(record_batch); + const std::vector& buffers = compressed_buffers ? *compressed_buffers : get_buffers(record_batch); flatbuffers::FlatBufferBuilder record_batch_builder; auto nodes_offset = record_batch_builder.CreateVectorOfStructs(nodes); auto buffers_offset = record_batch_builder.CreateVectorOfStructs(buffers); + flatbuffers::Offset compression_offset = 0; + if (compression) + { + compression_offset = org::apache::arrow::flatbuf::CreateBodyCompression(record_batch_builder, compression.value(), org::apache::arrow::flatbuf::BodyCompressionMethod::BUFFER); + } const auto record_batch_offset = org::apache::arrow::flatbuf::CreateRecordBatch( record_batch_builder, static_cast(record_batch.nb_rows()), nodes_offset, buffers_offset, - 0, // TODO: Compression + compression_offset, 0 // TODO :variadic buffer Counts ); - const int64_t body_size = calculate_body_size(record_batch); + const int64_t body_size = body_size_override.value_or(calculate_body_size(record_batch)); const auto record_batch_message_offset = org::apache::arrow::flatbuf::CreateMessage( record_batch_builder, org::apache::arrow::flatbuf::MetadataVersion::V5, diff --git a/src/serialize.cpp b/src/serialize.cpp index a4e797d..644acf3 100644 --- a/src/serialize.cpp +++ b/src/serialize.cpp @@ -1,11 +1,11 @@ -#include "sparrow_ipc/serialize.hpp" +#include +#include "sparrow_ipc/serialize.hpp" #include "sparrow_ipc/flatbuffer_utils.hpp" namespace sparrow_ipc { void common_serialize( - const sparrow::record_batch& record_batch, const flatbuffers::FlatBufferBuilder& builder, any_output_stream& stream ) @@ -20,12 +20,23 @@ namespace sparrow_ipc void serialize_schema_message(const sparrow::record_batch& record_batch, any_output_stream& stream) { - common_serialize(record_batch, get_schema_message_builder(record_batch), stream); + common_serialize(get_schema_message_builder(record_batch), stream); } - void serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream) + void serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional compression) { - common_serialize(record_batch, get_record_batch_message_builder(record_batch), stream); - generate_body(record_batch, stream); + if (compression.has_value()) + { + // TODO Handle this inside get_record_batch_message_builder + auto [compressed_body, compressed_buffers] = generate_compressed_body_and_buffers(record_batch, compression.value()); + common_serialize(get_record_batch_message_builder(record_batch, compression, compressed_body.size(), &compressed_buffers), stream); + // TODO Use something equivalent to generate_body (stream wise, handling children etc) + stream.write(std::span(compressed_body.data(), compressed_body.size())); + } + else + { + common_serialize(get_record_batch_message_builder(record_batch, compression), stream); + generate_body(record_batch, stream); + } } -} \ No newline at end of file +} diff --git a/src/serialize_utils.cpp b/src/serialize_utils.cpp index 8545927..cbab29f 100644 --- a/src/serialize_utils.cpp +++ b/src/serialize_utils.cpp @@ -1,7 +1,9 @@ +#include "sparrow_ipc/compression.hpp" #include "sparrow_ipc/flatbuffer_utils.hpp" #include "sparrow_ipc/magic_values.hpp" #include "sparrow_ipc/serialize.hpp" -#include "sparrow_ipc/utils.hpp" +#include "sparrow_ipc/serialize_utils.hpp" + namespace sparrow_ipc { @@ -70,14 +72,24 @@ namespace sparrow_ipc return utils::align_to_8(total_size); } - std::size_t calculate_record_batch_message_size(const sparrow::record_batch& record_batch) + std::size_t calculate_record_batch_message_size(const sparrow::record_batch& record_batch, std::optional compression) { // Build the record batch message to get its exact metadata size - flatbuffers::FlatBufferBuilder record_batch_builder = get_record_batch_message_builder(record_batch); + flatbuffers::FlatBufferBuilder record_batch_builder = get_record_batch_message_builder(record_batch, compression); const flatbuffers::uoffset_t record_batch_len = record_batch_builder.GetSize(); - // Calculate body size (already includes 8-byte alignment for each buffer) - const int64_t body_size = calculate_body_size(record_batch); + std::size_t actual_body_size = 0; + if (compression.has_value()) + { + // If compressed, the body size is the sum of compressed buffer sizes + original size prefixes + padding + auto [compressed_body, compressed_buffers] = generate_compressed_body_and_buffers(record_batch, compression.value()); + actual_body_size = compressed_body.size(); + } + else + { + // If not compressed, the body size is the sum of uncompressed buffer sizes with padding + actual_body_size = static_cast(calculate_body_size(record_batch)); + } // Calculate total size: // - Continuation bytes (4) @@ -88,7 +100,39 @@ namespace sparrow_ipc std::size_t metadata_size = continuation.size() + sizeof(uint32_t) + record_batch_len; metadata_size = utils::align_to_8(metadata_size); - return metadata_size + static_cast(body_size); + return metadata_size + actual_body_size; + } + + std::pair, std::vector> + generate_compressed_body_and_buffers(const sparrow::record_batch& record_batch, const org::apache::arrow::flatbuf::CompressionType compression_type) + { + std::vector compressed_body; + std::vector compressed_buffers; + int64_t current_offset = 0; + + for (const auto& column : record_batch.columns()) + { + const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column); + for (const auto& buffer : arrow_proxy.buffers()) + { + // Compress the buffer. The returned buffer already has the correct size header. + std::vector compressed_buffer_with_header = compress(compression_type, std::span(buffer.data(), buffer.size())); + + const size_t aligned_chunk_size = utils::align_to_8(compressed_buffer_with_header.size()); + const size_t padding_needed = aligned_chunk_size - compressed_buffer_with_header.size(); + + // Write compressed data with header + compressed_body.insert(compressed_body.end(), compressed_buffer_with_header.begin(), compressed_buffer_with_header.end()); + + // Add padding + compressed_body.insert(compressed_body.end(), padding_needed, 0); + + // Update compressed buffer metadata + compressed_buffers.emplace_back(current_offset, aligned_chunk_size); + current_offset += aligned_chunk_size; + } + } + return {compressed_body, compressed_buffers}; } std::vector get_column_dtypes(const sparrow::record_batch& rb) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 11c2f9f..d1de914 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,6 +10,7 @@ set(SPARROW_IPC_TESTS_SRC test_arrow_schema.cpp test_chunk_memory_output_stream.cpp test_chunk_memory_serializer.cpp + test_compression.cpp test_de_serialization_with_files.cpp $<$>:test_flatbuffer_utils.cpp> test_memory_output_streams.cpp diff --git a/tests/test_compression.cpp b/tests/test_compression.cpp new file mode 100644 index 0000000..4c5946e --- /dev/null +++ b/tests/test_compression.cpp @@ -0,0 +1,92 @@ +#include +#include +#include + +#include + +#include + +namespace sparrow_ipc +{ + TEST_SUITE("De/Compression") + { + TEST_CASE("Unsupported ZSTD de/compression") + { + std::string original_string = "some data to compress"; + std::vector original_data(original_string.begin(), original_string.end()); + const auto compression_type = org::apache::arrow::flatbuf::CompressionType::ZSTD; + + // Test compression with ZSTD + CHECK_THROWS_WITH_AS(compress(compression_type, original_data), "Compression using zstd is not supported yet.", std::runtime_error); + + // Test decompression with ZSTD + CHECK_THROWS_WITH_AS(decompress(compression_type, original_data), "Decompression using zstd is not supported yet.", std::runtime_error); + } + + TEST_CASE("Empty data") + { + const std::vector empty_data; + const auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME; + + // Test compression of empty data + auto compressed = compress(compression_type, empty_data); + CHECK_EQ(compressed.size(), CompressionHeaderSize); + const std::int64_t header = *reinterpret_cast(compressed.data()); + CHECK_EQ(header, -1); + + // Test decompression of empty data + auto decompressed = decompress(compression_type, compressed); + std::visit([](const auto& value) { CHECK(value.empty()); }, decompressed); + } + + TEST_CASE("Data compression and decompression round-trip") + { + std::string original_string = "Hello world, this is a test of compression and decompression. But we need more words to make this compression worth it!"; + std::vector original_data(original_string.begin(), original_string.end()); + + // Compress data + auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME; + std::vector compressed_data = compress(compression_type, original_data); + + // Decompress + auto decompressed_result = decompress(compression_type, compressed_data); + std::visit( + [&original_data](const auto& decompressed_data) + { + CHECK_EQ(decompressed_data.size(), original_data.size()); + const std::vector vec(decompressed_data.begin(), decompressed_data.end()); + CHECK_EQ(vec, original_data); + }, + decompressed_result + ); + } + + TEST_CASE("Data compression with incompressible data") + { + std::string original_string = "abc"; + std::vector original_data(original_string.begin(), original_string.end()); + + // Compress data + auto compression_type = org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME; + std::vector compressed_data = compress(compression_type, original_data); + + // Decompress + auto decompressed_result = decompress(compression_type, compressed_data); + std::visit( + [&original_data](const auto& decompressed_data) + { + CHECK_EQ(decompressed_data.size(), original_data.size()); + const std::vector vec(decompressed_data.begin(), decompressed_data.end()); + CHECK_EQ(vec, original_data); + }, + decompressed_result + ); + + // Check that the compressed data is just the original data with a -1 header + const std::int64_t header = *reinterpret_cast(compressed_data.data()); + CHECK_EQ(header, -1); + std::vector body(compressed_data.begin() + sizeof(header), compressed_data.end()); + CHECK_EQ(body, original_data); + } + } +} diff --git a/tests/test_de_serialization_with_files.cpp b/tests/test_de_serialization_with_files.cpp index 3bcb79c..7b7d236 100644 --- a/tests/test_de_serialization_with_files.cpp +++ b/tests/test_de_serialization_with_files.cpp @@ -22,6 +22,9 @@ const std::filesystem::path arrow_testing_data_dir = ARROW_TESTING_DATA_DIR; const std::filesystem::path tests_resources_files_path = arrow_testing_data_dir / "data" / "arrow-ipc-stream" / "integration" / "cpp-21.0.0"; +const std::filesystem::path tests_resources_files_path_with_compression = arrow_testing_data_dir / "data" / "arrow-ipc-stream" + / "integration" / "2.0.0-compression"; + const std::vector files_paths_to_test = { tests_resources_files_path / "generated_primitive", // tests_resources_files_path / "generated_primitive_large_offsets", @@ -29,6 +32,14 @@ const std::vector files_paths_to_test = { // tests_resources_files_path / "generated_primitive_no_batches" }; +const std::vector files_paths_to_test_with_compression = { + tests_resources_files_path_with_compression / "generated_lz4", + tests_resources_files_path_with_compression/ "generated_uncompressible_lz4" +// tests_resources_files_path_with_compression / "generated_zstd" +// tests_resources_files_path_with_compression/ "generated_uncompressible_zstd" +}; + + size_t get_number_of_batches(const std::filesystem::path& json_path) { std::ifstream json_file(json_path); @@ -174,4 +185,55 @@ TEST_SUITE("Integration tests") } } } + + TEST_CASE("Compare record_batch serialization with stream file using LZ4 compression") + { + for (const auto& file_path : files_paths_to_test_with_compression) + { + std::filesystem::path json_path = file_path; + json_path.replace_extension(".json"); + const std::string test_name = "Testing LZ4 compression with " + file_path.filename().string(); + SUBCASE(test_name.c_str()) + { + // Load the JSON file + auto json_data = load_json_file(json_path); + CHECK(json_data != nullptr); + + const size_t num_batches = get_number_of_batches(json_path); + std::vector record_batches_from_json; + for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) + { + INFO("Processing batch " << batch_idx << " of " << num_batches); + record_batches_from_json.emplace_back( + sparrow::json_reader::build_record_batch_from_json(json_data, batch_idx) + ); + } + + // Load stream file + std::filesystem::path stream_file_path = file_path; + stream_file_path.replace_extension(".stream"); + std::ifstream stream_file(stream_file_path, std::ios::in | std::ios::binary); + REQUIRE(stream_file.is_open()); + const std::vector stream_data( + (std::istreambuf_iterator(stream_file)), + (std::istreambuf_iterator()) + ); + stream_file.close(); + + // Process the stream file + const auto record_batches_from_stream = sparrow_ipc::deserialize_stream( + std::span(stream_data) + ); + + std::vector serialized_data; + sparrow_ipc::memory_output_stream stream(serialized_data); + sparrow_ipc::serializer serializer(stream, org::apache::arrow::flatbuf::CompressionType::LZ4_FRAME); + serializer << record_batches_from_json << sparrow_ipc::end_stream; + const auto deserialized_serialized_data = sparrow_ipc::deserialize_stream( + std::span(serialized_data) + ); + compare_record_batches(record_batches_from_stream, deserialized_serialized_data); + } + } + } } diff --git a/tests/test_serialize_utils.cpp b/tests/test_serialize_utils.cpp index ea4011e..3aebbbb 100644 --- a/tests/test_serialize_utils.cpp +++ b/tests/test_serialize_utils.cpp @@ -111,7 +111,7 @@ namespace sparrow_ipc std::vector serialized; memory_output_stream stream(serialized); any_output_stream astream(stream); - serialize_schema_message(record_batch, astream ); + serialize_schema_message(record_batch, astream); CHECK_EQ(estimated_size, serialized.size()); } @@ -147,7 +147,7 @@ namespace sparrow_ipc std::vector serialized; memory_output_stream stream(serialized); any_output_stream astream(stream); - serialize_record_batch(record_batch, astream); + serialize_record_batch(record_batch, astream, std::nullopt); CHECK_EQ(estimated_size, serialized.size()); } @@ -164,7 +164,7 @@ namespace sparrow_ipc std::vector serialized; memory_output_stream stream(serialized); any_output_stream astream(stream); - serialize_record_batch(record_batch, astream); + serialize_record_batch(record_batch, astream, std::nullopt); CHECK_EQ(estimated_size, serialized.size()); } @@ -243,8 +243,8 @@ namespace sparrow_ipc std::vector serialized; memory_output_stream stream(serialized); any_output_stream astream(stream); - serialize_record_batch(record_batch, astream); - CHECK_GT(serialized.size(), 0); + serialize_record_batch(record_batch, astream, std::nullopt); + CHECK_GT(serialized.size(), 0); // Check that it starts with continuation bytes CHECK_GE(serialized.size(), continuation.size()); @@ -278,10 +278,10 @@ namespace sparrow_ipc std::vector serialized; memory_output_stream stream(serialized); any_output_stream astream(stream); - serialize_record_batch(empty_batch, astream); + serialize_record_batch(empty_batch, astream, std::nullopt); CHECK_GT(serialized.size(), 0); CHECK_GE(serialized.size(), continuation.size()); } } } -} \ No newline at end of file +}