diff --git a/internal/core/src/CMakeLists.txt b/internal/core/src/CMakeLists.txt
index 28480eb8a1a8..50e34e99b5e2 100644
--- a/internal/core/src/CMakeLists.txt
+++ b/internal/core/src/CMakeLists.txt
@@ -33,6 +33,4 @@ add_subdirectory( query )
 add_subdirectory( segcore )
 add_subdirectory( indexbuilder )
 add_subdirectory( exec )
-if(USE_DYNAMIC_SIMD)
-    add_subdirectory( simd )
-endif()
+add_subdirectory( bitset )
diff --git a/internal/core/src/bitset/CMakeLists.txt b/internal/core/src/bitset/CMakeLists.txt
new file mode 100644
index 000000000000..d6217efb7d7d
--- /dev/null
+++ b/internal/core/src/bitset/CMakeLists.txt
@@ -0,0 +1,33 @@
+set(BITSET_SRCS
+    detail/platform/dynamic.cpp
+)
+
+if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
+    list(APPEND BITSET_SRCS
+        detail/platform/x86/avx2-inst.cpp
+        detail/platform/x86/avx512-inst.cpp
+        detail/platform/x86/instruction_set.cpp
+    )
+
+    set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
+    set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")
+
+    # set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
+    # set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm*")
+    list(APPEND BITSET_SRCS
+        detail/platform/arm/neon-inst.cpp
+        detail/platform/arm/sve-inst.cpp
+    )
+
+    # targeting AWS graviton,
+    # https://github.com/aws/aws-graviton-getting-started/blob/main/c-c%2B%2B.md
+
+    # let dynamic.cpp know that SVE is available
+    # comment it out for now
+    # set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mcpu=neoverse-v1")
+
+    set_source_files_properties(detail/platform/arm/sve-inst.cpp PROPERTIES COMPILE_FLAGS "-mcpu=neoverse-v1")
+endif()
+
+add_library(milvus_bitset ${BITSET_SRCS})
diff --git a/internal/core/src/bitset/bitset.h b/internal/core/src/bitset/bitset.h
new file mode 100644
index 000000000000..27a659ae1456
--- /dev/null
+++ b/internal/core/src/bitset/bitset.h
@@ -0,0 +1,1081 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <utility>
+
+#include "common.h"
+
+namespace milvus {
+namespace bitset {
+
+namespace {
+
+// A supporting facility for checking out of range.
+// It is needed to add a capability to verify that we won't go out of
+// range even for the Release build.
+template <bool IsRangeCheckEnabled>
+struct RangeChecker {};
+
+// disabled.
+template <> +struct RangeChecker { + // Check if a < max + template + static inline void + lt(const SizeT a, const SizeT max) { + } + + // Check if a <= max + template + static inline void + le(const SizeT a, const SizeT max) { + } + + // Check if a == b + template + static inline void + eq(const SizeT a, const SizeT b) { + } +}; + +// enabled. +template <> +struct RangeChecker { + // Check if a < max + template + static inline void + lt(const SizeT a, const SizeT max) { + // todo: replace + assert(a < max); + } + + // Check if a <= max + template + static inline void + le(const SizeT a, const SizeT max) { + // todo: replace + assert(a <= max); + } + + // Check if a == b + template + static inline void + eq(const SizeT a, const SizeT b) { + // todo: replace + assert(a == b); + } +}; + +} // namespace + +// CRTP + +// Bitset view, which does not own the data. +template +class BitsetView; + +// Bitset, which owns the data. +template +class Bitset; + +// This is the base CRTP class. +template +class BitsetBase { + template + friend class BitsetView; + + template + friend class Bitset; + + public: + using policy_type = PolicyT; + using data_type = typename policy_type::data_type; + using size_type = typename policy_type::size_type; + using proxy_type = typename policy_type::proxy_type; + using const_proxy_type = typename policy_type::const_proxy_type; + + using range_checker = RangeChecker; + + // + inline data_type* + data() { + return as_derived().data_impl(); + } + + // + inline const data_type* + data() const { + return as_derived().data_impl(); + } + + // Return the number of bits we're working with. + inline size_type + size() const { + return as_derived().size_impl(); + } + + // Return the number of bytes which is needed to + // contain all our bits. + inline size_type + size_in_bytes() const { + return policy_type::get_required_size_in_bytes(this->size()); + } + + // Return the number of elements which is needed to + // contain all our bits. + inline size_type + size_in_elements() const { + return policy_type::get_required_size_in_elements(this->size()); + } + + // + inline bool + empty() const { + return (this->size() == 0); + } + + // + inline proxy_type + operator[](const size_type bit_idx) { + range_checker::lt(bit_idx, this->size()); + + const size_type idx_v = bit_idx + this->offset(); + return policy_type::get_proxy(this->data(), idx_v); + } + + // + inline bool + operator[](const size_type bit_idx) const { + range_checker::lt(bit_idx, this->size()); + + const size_type idx_v = bit_idx + this->offset(); + const auto proxy = policy_type::get_proxy(this->data(), idx_v); + return proxy.operator bool(); + } + + // Set all bits to true. + inline void + set() { + policy_type::op_set(this->data(), this->offset(), this->size()); + } + + // Set a given bit to a given value. + inline void + set(const size_type bit_idx, const bool value = true) { + this->operator[](bit_idx) = value; + } + + // Set all bits to false. + inline void + reset() { + policy_type::op_reset(this->data(), this->offset(), this->size()); + } + + // Set a given bit to false. + inline void + reset(const size_type bit_idx) { + this->operator[](bit_idx) = false; + } + + // Return whether all bits are set to true. + inline bool + all() const { + return policy_type::op_all(this->data(), this->offset(), this->size()); + } + + // Return whether any of the bits is set to true. + inline bool + any() const { + return (!this->none()); + } + + // Return whether all bits are set to false. 
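Note on the `RangeChecker` pattern above: the `bool` template parameter selects either an empty specialization (the checks compile away entirely) or an asserting one, so even a Release build can opt into bounds checks. A minimal standalone sketch of the same idiom (simplified names, not the repository's code):

```cpp
#include <cassert>
#include <cstddef>

template <bool Enabled>
struct Checker {};

// no-op variant: the calls vanish after inlining
template <>
struct Checker<false> {
    static inline void
    lt(const std::size_t /*a*/, const std::size_t /*max*/) {
    }
};

// asserting variant, usable even when NDEBUG would normally apply
template <>
struct Checker<true> {
    static inline void
    lt(const std::size_t a, const std::size_t max) {
        assert(a < max);
    }
};

template <bool IsRangeCheckEnabled>
std::size_t
get(const std::size_t* data, const std::size_t n, const std::size_t i) {
    Checker<IsRangeCheckEnabled>::lt(i, n);  // zero cost when disabled
    return data[i];
}

int main() {
    const std::size_t values[3] = {7, 8, 9};
    return (get<true>(values, 3, 1) == 8) ? 0 : 1;
}
```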
+ inline bool + none() const { + return policy_type::op_none(this->data(), this->offset(), this->size()); + } + + // Inplace and. + template + inline void + inplace_and(const BitsetBase& other, const size_type size) { + range_checker::le(size, this->size()); + range_checker::le(size, other.size()); + + policy_type::op_and( + this->data(), other.data(), this->offset(), other.offset(), size); + } + + // Inplace and. A given bitset / bitset view is expected to have the same size. + template + inline ImplT& + operator&=(const BitsetBase& other) { + range_checker::eq(other.size(), this->size()); + + this->inplace_and(other, this->size()); + return as_derived(); + } + + // Inplace or. + template + inline void + inplace_or(const BitsetBase& other, const size_type size) { + range_checker::le(size, this->size()); + range_checker::le(size, other.size()); + + policy_type::op_or( + this->data(), other.data(), this->offset(), other.offset(), size); + } + + // Inplace or. A given bitset / bitset view is expected to have the same size. + template + inline ImplT& + operator|=(const BitsetBase& other) { + range_checker::eq(other.size(), this->size()); + + this->inplace_or(other, this->size()); + return as_derived(); + } + + // Revert all bits. + inline void + flip() { + policy_type::op_flip(this->data(), this->offset(), this->size()); + } + + // + inline BitsetView + operator+(const size_type offset) { + return this->view(offset); + } + + // Create a view of a given size from the given position. + inline BitsetView + view(const size_type offset, const size_type size) { + range_checker::le(offset, this->size()); + range_checker::le(offset + size, this->size()); + + return BitsetView( + this->data(), this->offset() + offset, size); + } + + // Create a const view of a given size from the given position. + inline BitsetView + view(const size_type offset, const size_type size) const { + range_checker::le(offset, this->size()); + range_checker::le(offset + size, this->size()); + + return BitsetView( + const_cast(this->data()), + this->offset() + offset, + size); + } + + // Create a view from the given position, which uses all available size. + inline BitsetView + view(const size_type offset) { + range_checker::le(offset, this->size()); + + return BitsetView( + this->data(), this->offset() + offset, this->size() - offset); + } + + // Create a const view from the given position, which uses all available size. + inline const BitsetView + view(const size_type offset) const { + range_checker::le(offset, this->size()); + + return BitsetView( + const_cast(this->data()), + this->offset() + offset, + this->size() - offset); + } + + // Create a view. + inline BitsetView + view() { + return this->view(0); + } + + // Create a const view. + inline const BitsetView + view() const { + return this->view(0); + } + + // Return the number of bits which are set to true. + inline size_type + count() const { + return policy_type::op_count( + this->data(), this->offset(), this->size()); + } + + // Compare the current bitset with another bitset / bitset view. + template + inline bool + operator==(const BitsetBase& other) { + if (this->size() != other.size()) { + return false; + } + + return policy_type::op_eq(this->data(), + other.data(), + this->offset(), + other.offset(), + this->size()); + } + + // Compare the current bitset with another bitset / bitset view. + template + inline bool + operator!=(const BitsetBase& other) { + return (!(*this == other)); + } + + // Inplace xor. 
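The `view()` family above returns non-owning `BitsetView` objects that alias the parent's storage. A usage sketch, assuming the header paths and the `<policy, container, range-check flag>` parameter order reconstructed from the surrounding text:

```cpp
#include <cstdint>
#include <vector>

#include "bitset/bitset.h"           // paths assumed; adjust to the build
#include "bitset/detail/bit_wise.h"

using namespace milvus::bitset;

// assumed parameter order: <policy, container, range-check flag>
using TestBitset =
    Bitset<detail::BitWiseBitsetPolicy<uint64_t>, std::vector<uint64_t>, true>;

int main() {
    TestBitset bits(100, false);  // 100 bits, all reset
    auto v = bits.view(10, 20);   // aliases bits [10, 30) of the parent
    v.set(0);                     // writes parent bit 10 through the view
    return bits[10] ? 0 : 1;      // the parent observes the write
}
```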
+ template + inline void + inplace_xor(const BitsetBase& other, const size_type size) { + range_checker::le(size, this->size()); + range_checker::le(size, other.size()); + + policy_type::op_xor( + this->data(), other.data(), this->offset(), other.offset(), size); + } + + // Inplace xor. A given bitset / bitset view is expected to have the same size. + template + inline ImplT& + operator^=(const BitsetBase& other) { + range_checker::eq(other.size(), this->size()); + + this->inplace_xor(other, this->size()); + return as_derived(); + } + + // Inplace sub. + template + inline void + inplace_sub(const BitsetBase& other, const size_type size) { + range_checker::le(size, this->size()); + range_checker::le(size, other.size()); + + policy_type::op_sub( + this->data(), other.data(), this->offset(), other.offset(), size); + } + + // Inplace sub. A given bitset / bitset view is expected to have the same size. + template + inline ImplT& + operator-=(const BitsetBase& other) { + range_checker::eq(other.size(), this->size()); + + this->inplace_sub(other, this->size()); + return as_derived(); + } + + // Find the index of the first bit set to true. + inline std::optional + find_first() const { + return policy_type::op_find( + this->data(), this->offset(), this->size(), 0); + } + + // Find the index of the first bit set to true, starting from a given bit index. + inline std::optional + find_next(const size_type starting_bit_idx) const { + const size_type size_v = this->size(); + if (starting_bit_idx + 1 >= size_v) { + return std::nullopt; + } + + return policy_type::op_find( + this->data(), this->offset(), this->size(), starting_bit_idx + 1); + } + + // Read multiple bits starting from a given bit index. + inline data_type + read(const size_type starting_bit_idx, const size_type nbits) { + range_checker::le(nbits, sizeof(data_type)); + + return policy_type::op_read( + this->data(), this->offset() + starting_bit_idx, nbits); + } + + // Write multiple bits starting from a given bit index. 
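`read()` above (and `write()` just below) pack and unpack up to one element's worth of bits as an LSB-first integer. A small check of the intended semantics against the naive `BitWiseBitsetPolicy` from `detail/bit_wise.h` (illustration only; the include path is assumed):

```cpp
#include <cassert>
#include <cstdint>

#include "bitset/detail/bit_wise.h"  // path assumed

int main() {
    using policy = milvus::bitset::detail::BitWiseBitsetPolicy<uint64_t>;

    uint64_t storage[2] = {0, 0};
    // op_write stores bit i of the value into global bit (5 + i)
    policy::op_write(storage, /*start=*/5, /*nbits=*/8, /*value=*/0b10110001);
    // op_read recovers the same LSB-first packed integer
    assert(policy::op_read(storage, 5, 8) == 0b10110001);
    return 0;
}
```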
+ inline void + write(const size_type starting_bit_idx, + const data_type value, + const size_type nbits) { + range_checker::le(nbits, sizeof(data_type)); + + policy_type::op_write( + this->data(), this->offset() + starting_bit_idx, nbits, value); + } + + // Compare two arrays element-wise + template + void + inplace_compare_column(const T* const __restrict t, + const U* const __restrict u, + const size_type size, + CompareOpType op) { + if (op == CompareOpType::EQ) { + this->inplace_compare_column(t, u, size); + } else if (op == CompareOpType::GE) { + this->inplace_compare_column(t, u, size); + } else if (op == CompareOpType::GT) { + this->inplace_compare_column(t, u, size); + } else if (op == CompareOpType::LE) { + this->inplace_compare_column(t, u, size); + } else if (op == CompareOpType::LT) { + this->inplace_compare_column(t, u, size); + } else if (op == CompareOpType::NE) { + this->inplace_compare_column(t, u, size); + } else { + // unimplemented + } + } + + template + void + inplace_compare_column(const T* const __restrict t, + const U* const __restrict u, + const size_type size) { + range_checker::le(size, this->size()); + + policy_type::template op_compare_column( + this->data(), this->offset(), t, u, size); + } + + // Compare elements of an given array with a given value + template + void + inplace_compare_val(const T* const __restrict t, + const size_type size, + const T& value, + CompareOpType op) { + if (op == CompareOpType::EQ) { + this->inplace_compare_val(t, size, value); + } else if (op == CompareOpType::GE) { + this->inplace_compare_val(t, size, value); + } else if (op == CompareOpType::GT) { + this->inplace_compare_val(t, size, value); + } else if (op == CompareOpType::LE) { + this->inplace_compare_val(t, size, value); + } else if (op == CompareOpType::LT) { + this->inplace_compare_val(t, size, value); + } else if (op == CompareOpType::NE) { + this->inplace_compare_val(t, size, value); + } else { + // unimplemented + } + } + + template + void + inplace_compare_val(const T* const __restrict t, + const size_type size, + const T& value) { + range_checker::le(size, this->size()); + + policy_type::template op_compare_val( + this->data(), this->offset(), t, size, value); + } + + // + template + void + inplace_within_range_column(const T* const __restrict lower, + const T* const __restrict upper, + const T* const __restrict values, + const size_type size, + const RangeType op) { + if (op == RangeType::IncInc) { + this->inplace_within_range_column( + lower, upper, values, size); + } else if (op == RangeType::IncExc) { + this->inplace_within_range_column( + lower, upper, values, size); + } else if (op == RangeType::ExcInc) { + this->inplace_within_range_column( + lower, upper, values, size); + } else if (op == RangeType::ExcExc) { + this->inplace_within_range_column( + lower, upper, values, size); + } else { + // unimplemented + } + } + + template + void + inplace_within_range_column(const T* const __restrict lower, + const T* const __restrict upper, + const T* const __restrict values, + const size_type size) { + range_checker::le(size, this->size()); + + policy_type::template op_within_range_column( + this->data(), this->offset(), lower, upper, values, size); + } + + // + template + void + inplace_within_range_val(const T& lower, + const T& upper, + const T* const __restrict values, + const size_type size, + const RangeType op) { + if (op == RangeType::IncInc) { + this->inplace_within_range_val( + lower, upper, values, size); + } else if (op == RangeType::IncExc) { + 
this->inplace_within_range_val( + lower, upper, values, size); + } else if (op == RangeType::ExcInc) { + this->inplace_within_range_val( + lower, upper, values, size); + } else if (op == RangeType::ExcExc) { + this->inplace_within_range_val( + lower, upper, values, size); + } else { + // unimplemented + } + } + + template + void + inplace_within_range_val(const T& lower, + const T& upper, + const T* const __restrict values, + const size_type size) { + range_checker::le(size, this->size()); + + policy_type::template op_within_range_val( + this->data(), this->offset(), lower, upper, values, size); + } + + // + template + void + inplace_arith_compare(const T* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_type size, + const ArithOpType a_op, + const CompareOpType cmp_op) { + if (a_op == ArithOpType::Add) { + if (cmp_op == CompareOpType::EQ) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::LE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::LT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::NE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else { + // unimplemented + } + } else if (a_op == ArithOpType::Sub) { + if (cmp_op == CompareOpType::EQ) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::LE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::LT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::NE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else { + // unimplemented + } + } else if (a_op == ArithOpType::Mul) { + if (cmp_op == CompareOpType::EQ) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::LE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::LT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::NE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else { + // unimplemented + } + } else if (a_op == ArithOpType::Div) { + if (cmp_op == CompareOpType::EQ) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::LE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == 
CompareOpType::LT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::NE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else { + // unimplemented + } + } else if (a_op == ArithOpType::Mod) { + if (cmp_op == CompareOpType::EQ) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::GT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::LE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::LT) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else if (cmp_op == CompareOpType::NE) { + this->inplace_arith_compare( + src, right_operand, value, size); + } else { + // unimplemented + } + } else { + // unimplemented + } + } + + template + void + inplace_arith_compare(const T* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_type size) { + range_checker::le(size, this->size()); + + policy_type::template op_arith_compare( + this->data(), this->offset(), src, right_operand, value, size); + } + + // + // Inplace and. Also, counts the number of active bits. + template + inline size_type + inplace_and_with_count(const BitsetBase& other, + const size_type size) { + range_checker::le(size, this->size()); + range_checker::le(size, other.size()); + + return policy_type::op_and_with_count( + this->data(), other.data(), this->offset(), other.offset(), size); + } + + // Inplace or. Also, counts the number of inactive bits. + template + inline size_type + inplace_or_with_count(const BitsetBase& other, + const size_type size) { + range_checker::le(size, this->size()); + range_checker::le(size, other.size()); + + return policy_type::op_or_with_count( + this->data(), other.data(), this->offset(), other.offset(), size); + } + + private: + // Return the starting bit offset in our container. 
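These long `if`/`else if` chains exist to turn a pair of runtime enums into compile-time template arguments, so the per-element predicate inlines and the hot loop carries no branch on the operation. The idiom, reduced to a self-contained sketch:

```cpp
#include <cstddef>

enum class CmpOp { EQ, LT };

// the comparison is a compile-time constant, so the loop body inlines
// to a single branch-free predicate per element
template <CmpOp Op>
void
kernel(bool* out, const int* in, const int val, const std::size_t n) {
    for (std::size_t i = 0; i < n; i++) {
        if constexpr (Op == CmpOp::EQ) {
            out[i] = (in[i] == val);
        } else {
            out[i] = (in[i] < val);
        }
    }
}

// branch once on the runtime enum, then jump into the templated kernel
void
dispatch(bool* out,
         const int* in,
         const int val,
         const std::size_t n,
         const CmpOp op) {
    if (op == CmpOp::EQ) {
        kernel<CmpOp::EQ>(out, in, val, n);
    } else if (op == CmpOp::LT) {
        kernel<CmpOp::LT>(out, in, val, n);
    }  // else: unimplemented, mirroring the original's silent fall-through
}

int main() {
    int in[2] = {1, 2};
    bool out[2];
    dispatch(out, in, 2, 2, CmpOp::LT);
    return (out[0] && !out[1]) ? 0 : 1;
}
```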
+ inline size_type + offset() const { + return as_derived().offset_impl(); + } + + // CRTP + inline ImplT& + as_derived() { + return static_cast(*this); + } + + // CRTP + inline const ImplT& + as_derived() const { + return static_cast(*this); + } +}; + +// Bitset view +template +class BitsetView : public BitsetBase, + IsRangeCheckEnabled> { + friend class BitsetBase, + IsRangeCheckEnabled>; + + public: + using policy_type = PolicyT; + using data_type = typename policy_type::data_type; + using size_type = typename policy_type::size_type; + using proxy_type = typename policy_type::proxy_type; + using const_proxy_type = typename policy_type::const_proxy_type; + + using range_checker = RangeChecker; + + BitsetView() { + } + BitsetView(const BitsetView&) = default; + BitsetView(BitsetView&&) = default; + BitsetView& + operator=(const BitsetView&) = default; + BitsetView& + operator=(BitsetView&&) = default; + + template + BitsetView(BitsetBase& bitset) + : Data{bitset.data()}, Size{bitset.size()}, Offset{bitset.offset()} { + } + + BitsetView(void* data, const size_type size) + : Data{reinterpret_cast(data)}, Size{size}, Offset{0} { + } + + BitsetView(void* data, const size_type offset, const size_type size) + : Data{reinterpret_cast(data)}, Size{size}, Offset{offset} { + } + + private: + // the referenced bits are [Offset, Offset + Size) + data_type* Data = nullptr; + // measured in bits + size_type Size = 0; + // measured in bits + size_type Offset = 0; + + inline data_type* + data_impl() { + return Data; + } + inline const data_type* + data_impl() const { + return Data; + } + inline size_type + size_impl() const { + return Size; + } + inline size_type + offset_impl() const { + return Offset; + } +}; + +// Bitset +template +class Bitset + : public BitsetBase, + IsRangeCheckEnabled> { + friend class BitsetBase, + IsRangeCheckEnabled>; + + public: + using policy_type = PolicyT; + using data_type = typename policy_type::data_type; + using size_type = typename policy_type::size_type; + using proxy_type = typename policy_type::proxy_type; + using const_proxy_type = typename policy_type::const_proxy_type; + + // This is the container type. + using container_type = ContainerT; + // This is how the data is stored. For example, we may operate using + // uint64_t values, but store the data in std::vector container. + // This is useful if we need to convert a bitset into a container + // using move operator. + using container_data_type = typename container_type::value_type; + + using range_checker = RangeChecker; + + // Allocate an empty one. + Bitset() { + } + // Allocate the given number of bits. + Bitset(const size_type size) + : Data(get_required_size_in_container_elements(size)), Size{size} { + } + // Allocate the given number of bits, initialize with a given value. + Bitset(const size_type size, const bool init) + : Data(get_required_size_in_container_elements(size), + init ? data_type(-1) : 0), + Size{size} { + } + // Do not allow implicit copies (Rust style). + Bitset(const Bitset&) = delete; + // Allow default move. + Bitset(Bitset&&) = default; + // Do not allow implicit copies (Rust style). + Bitset& + operator=(const Bitset&) = delete; + // Allow default move. 
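The deleted copy constructor above, together with `clone()` and `into()` defined just below, gives the Rust-style ownership surface: copies are explicit, moves are cheap, and the caller can take the underlying container. A sketch under the same assumed aliases as before:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

#include "bitset/bitset.h"           // paths assumed
#include "bitset/detail/bit_wise.h"

using TestBitset = milvus::bitset::Bitset<
    milvus::bitset::detail::BitWiseBitsetPolicy<uint64_t>,
    std::vector<uint64_t>,
    true>;

int main() {
    TestBitset a(64, true);                           // 64 bits, all set
    TestBitset b = a.clone();                         // explicit deep copy
    TestBitset c = std::move(a);                      // moves stay cheap
    std::vector<uint64_t> raw = std::move(c).into();  // surrender the storage
    return (!raw.empty() && b.all()) ? 0 : 1;
}
```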
+ Bitset& + operator=(Bitset&&) = default; + + template + Bitset(const BitsetBase& other) { + Data = container_type( + get_required_size_in_container_elements(other.size())); + Size = other.size(); + + policy_type::op_copy(other.data(), + other.offset(), + this->data(), + this->offset(), + other.size()); + } + + // Clone a current bitset (Rust style). + Bitset + clone() const { + Bitset cloned; + cloned.Data = Data; + cloned.Size = Size; + return cloned; + } + + // Rust style. + inline container_type + into() && { + return std::move(this->Data); + } + + // Resize. + void + resize(const size_type new_size) { + const size_type new_size_in_container_elements = + get_required_size_in_container_elements(new_size); + Data.resize(new_size_in_container_elements); + Size = new_size; + } + + // Resize and initialize new bits with a given value if grown. + void + resize(const size_type new_size, const bool init) { + const size_type old_size = this->size(); + this->resize(new_size); + + if (new_size > old_size) { + policy_type::op_fill( + this->data(), old_size, new_size - old_size, init); + } + } + + // Append data from another bitset / bitset view in + // [starting_bit_idx, starting_bit_idx + count) range + // to the end of this bitset. + template + void + append(const BitsetBase& other, + const size_type starting_bit_idx, + const size_type count) { + range_checker::le(starting_bit_idx, other.size()); + + const size_type old_size = this->size(); + this->resize(this->size() + count); + + policy_type::op_copy(other.data(), + other.offset() + starting_bit_idx, + this->data(), + this->offset() + old_size, + count); + } + + // Append data from another bitset / bitset view + // to the end of this bitset. + template + void + append(const BitsetBase& other) { + this->append(other, 0, other.size()); + } + + // Make bitset empty. + inline void + clear() { + Data.clear(); + Size = 0; + } + + // Reserve + inline void + reserve(const size_type capacity) { + const size_type capacity_in_container_elements = + get_required_size_in_container_elements(capacity); + Data.reserve(capacity_in_container_elements); + } + + // Return a new bitset, equal to a | b + template + friend Bitset + operator|(const BitsetBase& a, + const BitsetBase& b) { + Bitset clone(a); + return std::move(clone |= b); + } + + // Return a new bitset, equal to a - b + template + friend Bitset + operator-(const BitsetBase& a, + const BitsetBase& b) { + Bitset clone(a); + return std::move(clone -= b); + } + + protected: + // the container + container_type Data; + // the actual number of bits + size_type Size = 0; + + inline data_type* + data_impl() { + return reinterpret_cast(Data.data()); + } + inline const data_type* + data_impl() const { + return reinterpret_cast(Data.data()); + } + inline size_type + size_impl() const { + return Size; + } + inline size_type + offset_impl() const { + return 0; + } + + // + static inline size_type + get_required_size_in_container_elements(const size_t size) { + const size_type size_in_bytes = + policy_type::get_required_size_in_bytes(size); + return (size_in_bytes + sizeof(container_data_type) - 1) / + sizeof(container_data_type); + } +}; + +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/common.h b/internal/core/src/bitset/common.h new file mode 100644 index 000000000000..662813e91c2b --- /dev/null +++ b/internal/core/src/bitset/common.h @@ -0,0 +1,147 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace milvus { +namespace bitset { + +// this option is only somewhat supported +// #define BITSET_HEADER_ONLY + +// a supporting utility +template +inline constexpr bool always_false_v = false; + +// a ? b +enum class CompareOpType { + GT = 1, + GE = 2, + LT = 3, + LE = 4, + EQ = 5, + NE = 6, +}; + +template +struct CompareOperator { + template + static inline bool + compare(const T& t, const U& u) { + if constexpr (Op == CompareOpType::EQ) { + return (t == u); + } else if constexpr (Op == CompareOpType::GE) { + return (t >= u); + } else if constexpr (Op == CompareOpType::GT) { + return (t > u); + } else if constexpr (Op == CompareOpType::LE) { + return (t <= u); + } else if constexpr (Op == CompareOpType::LT) { + return (t < u); + } else if constexpr (Op == CompareOpType::NE) { + return (t != u); + } else { + // unimplemented + static_assert(always_false_v, "unimplemented"); + } + } +}; + +// a ? v && v ? b +enum class RangeType { + // [a, b] + IncInc, + // [a, b) + IncExc, + // (a, b] + ExcInc, + // (a, b) + ExcExc +}; + +template +struct RangeOperator { + template + static inline bool + within_range(const T& lower, const T& upper, const T& value) { + if constexpr (Op == RangeType::IncInc) { + return (lower <= value && value <= upper); + } else if constexpr (Op == RangeType::ExcInc) { + return (lower < value && value <= upper); + } else if constexpr (Op == RangeType::IncExc) { + return (lower <= value && value < upper); + } else if constexpr (Op == RangeType::ExcExc) { + return (lower < value && value < upper); + } else { + // unimplemented + static_assert(always_false_v, "unimplemented"); + } + } +}; + +// +template +struct Range2Compare { + static constexpr inline CompareOpType lower = + (Op == RangeType::IncInc || Op == RangeType::IncExc) + ? CompareOpType::LE + : CompareOpType::LT; + static constexpr inline CompareOpType upper = + (Op == RangeType::IncInc || Op == RangeType::ExcInc) + ? 
CompareOpType::LE + : CompareOpType::LT; +}; + +// The following operation is Milvus-specific +enum class ArithOpType { Add, Sub, Mul, Div, Mod }; + +template +using ArithHighPrecisionType = + std::conditional_t && !std::is_same_v, + int64_t, + T>; + +template +struct ArithCompareOperator { + template + static inline bool + compare(const T& left, + const ArithHighPrecisionType& right, + const ArithHighPrecisionType& value) { + if constexpr (AOp == ArithOpType::Add) { + return CompareOperator::compare(left + right, value); + } else if constexpr (AOp == ArithOpType::Sub) { + return CompareOperator::compare(left - right, value); + } else if constexpr (AOp == ArithOpType::Mul) { + return CompareOperator::compare(left * right, value); + } else if constexpr (AOp == ArithOpType::Div) { + return CompareOperator::compare(left / right, value); + } else if constexpr (AOp == ArithOpType::Mod) { + return CompareOperator::compare(fmod(left, right), value); + } else { + // unimplemented + static_assert(always_false_v, "unimplemented"); + } + } +}; + +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/bit_wise.h b/internal/core/src/bitset/detail/bit_wise.h new file mode 100644 index 000000000000..5e8c1a37914c --- /dev/null +++ b/internal/core/src/bitset/detail/bit_wise.h @@ -0,0 +1,416 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "proxy.h" + +namespace milvus { +namespace bitset { +namespace detail { + +// This is a naive reference policy that operates on bit level. +// No optimizations are applied. +// This is little-endian based. 
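A compile-time illustration of `ArithHighPrecisionType`, assuming the garbled `std::conditional_t` condition reads `std::is_integral_v<T> && !std::is_same_v<T, bool>` (include path assumed): narrow integers are widened to `int64_t` before the arithmetic comparison, while floating-point types and `bool` pass through unchanged.

```cpp
#include <cstdint>
#include <type_traits>

#include "bitset/common.h"  // path assumed

using milvus::bitset::ArithHighPrecisionType;

static_assert(std::is_same_v<ArithHighPrecisionType<int8_t>, int64_t>,
              "narrow integers are widened");
static_assert(std::is_same_v<ArithHighPrecisionType<int32_t>, int64_t>,
              "all non-bool integrals are widened");
static_assert(std::is_same_v<ArithHighPrecisionType<float>, float>,
              "floating-point types pass through");
static_assert(std::is_same_v<ArithHighPrecisionType<bool>, bool>,
              "bool is excluded from widening");

int main() {
    return 0;
}
```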
+template +struct BitWiseBitsetPolicy { + using data_type = ElementT; + constexpr static auto data_bits = sizeof(data_type) * 8; + + using size_type = size_t; + + using self_type = BitWiseBitsetPolicy; + + using proxy_type = Proxy; + using const_proxy_type = ConstProxy; + + static inline size_type + get_element(const size_t idx) { + return idx / data_bits; + } + + static inline size_type + get_shift(const size_t idx) { + return idx % data_bits; + } + + static inline size_type + get_required_size_in_elements(const size_t size) { + return (size + data_bits - 1) / data_bits; + } + + static inline size_type + get_required_size_in_bytes(const size_t size) { + return get_required_size_in_elements(size) * sizeof(data_type); + } + + static inline proxy_type + get_proxy(data_type* const __restrict data, const size_type idx) { + data_type& element = data[get_element(idx)]; + const size_type shift = get_shift(idx); + return proxy_type{element, shift}; + } + + static inline const_proxy_type + get_proxy(const data_type* const __restrict data, const size_type idx) { + const data_type& element = data[get_element(idx)]; + const size_type shift = get_shift(idx); + return const_proxy_type{element, shift}; + } + + static inline data_type + op_read(const data_type* const data, + const size_type start, + const size_type nbits) { + data_type value = 0; + for (size_type i = 0; i < nbits; i++) { + const auto proxy = get_proxy(data, start + i); + value += proxy ? (data_type(1) << i) : 0; + } + + return value; + } + + static void + op_write(data_type* const data, + const size_type start, + const size_type nbits, + const data_type value) { + for (size_type i = 0; i < nbits; i++) { + auto proxy = get_proxy(data, start + i); + data_type mask = data_type(1) << i; + if ((value & mask) == mask) { + proxy = true; + } else { + proxy = false; + } + } + } + + static inline void + op_flip(data_type* const data, + const size_type start, + const size_type size) { + for (size_type i = 0; i < size; i++) { + auto proxy = get_proxy(data, start + i); + proxy.flip(); + } + } + + static inline void + op_and(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + // todo: check if intersect + + for (size_type i = 0; i < size; i++) { + auto proxy_left = get_proxy(left, start_left + i); + auto proxy_right = get_proxy(right, start_right + i); + + proxy_left &= proxy_right; + } + } + + static inline void + op_or(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + // todo: check if intersect + + for (size_type i = 0; i < size; i++) { + auto proxy_left = get_proxy(left, start_left + i); + auto proxy_right = get_proxy(right, start_right + i); + + proxy_left |= proxy_right; + } + } + + static inline void + op_set(data_type* const data, const size_type start, const size_type size) { + for (size_type i = 0; i < size; i++) { + get_proxy(data, start + i) = true; + } + } + + static inline void + op_reset(data_type* const data, + const size_type start, + const size_type size) { + for (size_type i = 0; i < size; i++) { + get_proxy(data, start + i) = false; + } + } + + static inline bool + op_all(const data_type* const data, + const size_type start, + const size_type size) { + for (size_type i = 0; i < size; i++) { + if (!get_proxy(data, start + i)) { + return false; + } + } + + return true; + } + + static inline bool + op_none(const data_type* const data, + const size_type start, + const 
size_type size) { + for (size_type i = 0; i < size; i++) { + if (get_proxy(data, start + i)) { + return false; + } + } + + return true; + } + + static void + op_copy(const data_type* const src, + const size_type start_src, + data_type* const dst, + const size_type start_dst, + const size_type size) { + for (size_type i = 0; i < size; i++) { + const auto src_p = get_proxy(src, start_src + i); + auto dst_p = get_proxy(dst, start_dst + i); + dst_p = src_p.operator bool(); + } + } + + static void + op_fill(data_type* const dst, + const size_type start_dst, + const size_type size, + const bool value) { + for (size_type i = 0; i < size; i++) { + auto dst_p = get_proxy(dst, start_dst + i); + dst_p = value; + } + } + + static inline size_type + op_count(const data_type* const data, + const size_type start, + const size_type size) { + size_type count = 0; + + for (size_type i = 0; i < size; i++) { + auto proxy = get_proxy(data, start + i); + count += (proxy) ? 1 : 0; + } + + return count; + } + + static inline bool + op_eq(const data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + for (size_type i = 0; i < size; i++) { + const auto proxy_left = get_proxy(left, start_left + i); + const auto proxy_right = get_proxy(right, start_right + i); + + if (proxy_left != proxy_right) { + return false; + } + } + + return true; + } + + static inline void + op_xor(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + // todo: check if intersect + + for (size_type i = 0; i < size; i++) { + auto proxy_left = get_proxy(left, start_left + i); + const auto proxy_right = get_proxy(right, start_right + i); + + proxy_left ^= proxy_right; + } + } + + static inline void + op_sub(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + // todo: check if intersect + + for (size_type i = 0; i < size; i++) { + auto proxy_left = get_proxy(left, start_left + i); + const auto proxy_right = get_proxy(right, start_right + i); + + proxy_left &= ~proxy_right; + } + } + + // + static inline std::optional + op_find(const data_type* const data, + const size_type start, + const size_type size, + const size_type starting_idx) { + for (size_type i = starting_idx; i < size; i++) { + const auto proxy = get_proxy(data, start + i); + if (proxy) { + return i; + } + } + + return std::nullopt; + } + + // + template + static inline void + op_compare_column(data_type* const __restrict data, + const size_type start, + const T* const __restrict t, + const U* const __restrict u, + const size_type size) { + for (size_type i = 0; i < size; i++) { + get_proxy(data, start + i) = + CompareOperator::compare(t[i], u[i]); + } + } + + // + template + static inline void + op_compare_val(data_type* const __restrict data, + const size_type start, + const T* const __restrict t, + const size_type size, + const T& value) { + for (size_type i = 0; i < size; i++) { + get_proxy(data, start + i) = + CompareOperator::compare(t[i], value); + } + } + + template + static inline void + op_within_range_column(data_type* const __restrict data, + const size_type start, + const T* const __restrict lower, + const T* const __restrict upper, + const T* const __restrict values, + const size_type size) { + for (size_type i = 0; i < size; i++) { + get_proxy(data, start + i) = + RangeOperator::within_range(lower[i], upper[i], values[i]); + } + } + + 
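The addressing math every policy in this file shares, worked through for 64-bit elements (include path assumed):

```cpp
#include <cassert>
#include <cstdint>

#include "bitset/detail/bit_wise.h"  // path assumed

int main() {
    using policy = milvus::bitset::detail::BitWiseBitsetPolicy<uint64_t>;

    // global bit 197 lives in element 197 / 64 = 3, at shift 197 % 64 = 5
    assert(policy::get_element(197) == 3);
    assert(policy::get_shift(197) == 5);

    // 197 bits round up to ceil(197 / 64) = 4 elements, i.e. 32 bytes
    assert(policy::get_required_size_in_elements(197) == 4);
    assert(policy::get_required_size_in_bytes(197) == 32);
    return 0;
}
```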
// + template + static inline void + op_within_range_val(data_type* const __restrict data, + const size_type start, + const T& lower, + const T& upper, + const T* const __restrict values, + const size_type size) { + for (size_type i = 0; i < size; i++) { + get_proxy(data, start + i) = + RangeOperator::within_range(lower, upper, values[i]); + } + } + + // + template + static inline void + op_arith_compare(data_type* const __restrict data, + const size_type start, + const T* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_type size) { + for (size_type i = 0; i < size; i++) { + get_proxy(data, start + i) = + ArithCompareOperator::compare( + src[i], right_operand, value); + } + } + + // + static inline size_t + op_and_with_count(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + // todo: check if intersect + + size_t active = 0; + for (size_type i = 0; i < size; i++) { + auto proxy_left = get_proxy(left, start_left + i); + auto proxy_right = get_proxy(right, start_right + i); + + const bool b = proxy_left & proxy_right; + proxy_left = b; + + active += b ? 1 : 0; + } + + return active; + } + + static inline size_t + op_or_with_count(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + // todo: check if intersect + + size_t inactive = 0; + for (size_type i = 0; i < size; i++) { + auto proxy_left = get_proxy(left, start_left + i); + auto proxy_right = get_proxy(right, start_right + i); + + const bool b = proxy_left | proxy_right; + proxy_left = b; + + inactive += b ? 0 : 1; + } + + return inactive; + } +}; + +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/ctz.h b/internal/core/src/bitset/detail/ctz.h new file mode 100644 index 000000000000..fb758cb84a8a --- /dev/null +++ b/internal/core/src/bitset/detail/ctz.h @@ -0,0 +1,65 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <cstdint>
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+
+// returns 8 * sizeof(T) for 0
+// returns 1 for 0b10
+// returns 2 for 0b100
+template <typename T>
+struct CtzHelper {};
+
+template <>
+struct CtzHelper<uint8_t> {
+    static inline auto
+    ctz(const uint8_t value) {
+        return __builtin_ctz(value);
+    }
+};
+
+template <>
+struct CtzHelper<unsigned int> {
+    static inline auto
+    ctz(const unsigned int value) {
+        return __builtin_ctz(value);
+    }
+};
+
+template <>
+struct CtzHelper<unsigned long> {
+    static inline auto
+    ctz(const unsigned long value) {
+        return __builtin_ctzl(value);
+    }
+};
+
+template <>
+struct CtzHelper<unsigned long long> {
+    static inline auto
+    ctz(const unsigned long long value) {
+        return __builtin_ctzll(value);
+    }
+};
+
+}  // namespace detail
+}  // namespace bitset
+}  // namespace milvus
diff --git a/internal/core/src/bitset/detail/element_vectorized.h b/internal/core/src/bitset/detail/element_vectorized.h
new file mode 100644
index 000000000000..393f9d01ae28
--- /dev/null
+++ b/internal/core/src/bitset/detail/element_vectorized.h
@@ -0,0 +1,447 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
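A usage sketch for the `CtzHelper` introduced above (include path assumed): the specialization is picked by the word type, and the result is the index of the lowest set bit, which is exactly what a word-at-a-time `find` needs.

```cpp
#include <cassert>
#include <cstdint>

#include "bitset/detail/ctz.h"  // path assumed

int main() {
    using milvus::bitset::detail::CtzHelper;

    assert(CtzHelper<unsigned long long>::ctz(0b1000ULL) == 3);
    assert(CtzHelper<uint8_t>::ctz(uint8_t{0b10}) == 1);
    // the underlying __builtin_ctz* builtins are undefined for a zero
    // input, so callers typically test the word for zero first
    return 0;
}
```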
+ +#pragma once + +#include +#include +#include +#include + +#include "proxy.h" +#include "element_wise.h" + +namespace milvus { +namespace bitset { +namespace detail { + +// SIMD applied on top of ElementWiseBitsetPolicy +template +struct VectorizedElementWiseBitsetPolicy { + using data_type = ElementT; + constexpr static auto data_bits = sizeof(data_type) * 8; + + using size_type = size_t; + + using self_type = VectorizedElementWiseBitsetPolicy; + + using proxy_type = Proxy; + using const_proxy_type = ConstProxy; + + static inline size_type + get_element(const size_t idx) { + return idx / data_bits; + } + + static inline size_type + get_shift(const size_t idx) { + return idx % data_bits; + } + + static inline size_type + get_required_size_in_elements(const size_t size) { + return (size + data_bits - 1) / data_bits; + } + + static inline size_type + get_required_size_in_bytes(const size_t size) { + return get_required_size_in_elements(size) * sizeof(data_type); + } + + static inline proxy_type + get_proxy(data_type* const __restrict data, const size_type idx) { + data_type& element = data[get_element(idx)]; + const size_type shift = get_shift(idx); + return proxy_type{element, shift}; + } + + static inline const_proxy_type + get_proxy(const data_type* const __restrict data, const size_type idx) { + const data_type& element = data[get_element(idx)]; + const size_type shift = get_shift(idx); + return const_proxy_type{element, shift}; + } + + static inline void + op_flip(data_type* const data, + const size_type start, + const size_type size) { + ElementWiseBitsetPolicy::op_flip(data, start, size); + } + + static inline void + op_and(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + } + + static inline void + op_or(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + } + + static inline void + op_set(data_type* const data, const size_type start, const size_type size) { + ElementWiseBitsetPolicy::op_set(data, start, size); + } + + static inline void + op_reset(data_type* const data, + const size_type start, + const size_type size) { + ElementWiseBitsetPolicy::op_reset(data, start, size); + } + + static inline bool + op_all(const data_type* const data, + const size_type start, + const size_type size) { + return ElementWiseBitsetPolicy::op_all(data, start, size); + } + + static inline bool + op_none(const data_type* const data, + const size_type start, + const size_type size) { + return ElementWiseBitsetPolicy::op_none(data, start, size); + } + + static void + op_copy(const data_type* const src, + const size_type start_src, + data_type* const dst, + const size_type start_dst, + const size_type size) { + ElementWiseBitsetPolicy::op_copy( + src, start_src, dst, start_dst, size); + } + + static inline size_type + op_count(const data_type* const data, + const size_type start, + const size_type size) { + return ElementWiseBitsetPolicy::op_count(data, start, size); + } + + static inline bool + op_eq(const data_type* const left, + const data_type* const right, + const size_type start_left, + const size_type start_right, + const size_type size) { + return ElementWiseBitsetPolicy::op_eq( + left, right, start_left, start_right, size); + } + + static inline void + op_xor(data_type* 
const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + } + + static inline void + op_sub(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + } + + static void + op_fill(data_type* const data, + const size_type start, + const size_type size, + const bool value) { + ElementWiseBitsetPolicy::op_fill(data, start, size, value); + } + + // + static inline std::optional + op_find(const data_type* const data, + const size_type start, + const size_type size, + const size_type starting_idx) { + return ElementWiseBitsetPolicy::op_find( + data, start, size, starting_idx); + } + + // + template + static inline void + op_compare_column(data_type* const __restrict data, + const size_type start, + const T* const __restrict t, + const U* const __restrict u, + const size_type size) { + op_func( + start, + size, + [data, t, u](const size_type starting_bit, + const size_type ptr_offset, + const size_type nbits) { + ElementWiseBitsetPolicy:: + template op_compare_column(data, + starting_bit, + t + ptr_offset, + u + ptr_offset, + nbits); + }, + [data, t, u](const size_type starting_element, + const size_type ptr_offset, + const size_type nbits) { + return VectorizedT::template op_compare_column( + reinterpret_cast(data + starting_element), + t + ptr_offset, + u + ptr_offset, + nbits); + }); + } + + // + template + static inline void + op_compare_val(data_type* const __restrict data, + const size_type start, + const T* const __restrict t, + const size_type size, + const T& value) { + op_func( + start, + size, + [data, t, value](const size_type starting_bit, + const size_type ptr_offset, + const size_type nbits) { + ElementWiseBitsetPolicy::template op_compare_val( + data, starting_bit, t + ptr_offset, nbits, value); + }, + [data, t, value](const size_type starting_element, + const size_type ptr_offset, + const size_type nbits) { + return VectorizedT::template op_compare_val( + reinterpret_cast(data + starting_element), + t + ptr_offset, + nbits, + value); + }); + } + + // + template + static inline void + op_within_range_column(data_type* const __restrict data, + const size_type start, + const T* const __restrict lower, + const T* const __restrict upper, + const T* const __restrict values, + const size_type size) { + op_func( + start, + size, + [data, lower, upper, values](const size_type starting_bit, + const size_type ptr_offset, + const size_type nbits) { + ElementWiseBitsetPolicy:: + template op_within_range_column(data, + starting_bit, + lower + ptr_offset, + upper + ptr_offset, + values + ptr_offset, + nbits); + }, + [data, lower, upper, values](const size_type starting_element, + const size_type ptr_offset, + const size_type nbits) { + return VectorizedT::template op_within_range_column( + reinterpret_cast(data + starting_element), + lower + ptr_offset, + upper + ptr_offset, + values + ptr_offset, + nbits); + }); + } + + // + template + static inline void + op_within_range_val(data_type* const __restrict data, + const size_type start, + const T& lower, + const T& upper, + const T* const __restrict values, + const size_type size) { + op_func( + start, + size, + [data, lower, upper, values](const size_type starting_bit, + const size_type ptr_offset, + const size_type nbits) { + 
ElementWiseBitsetPolicy:: + template op_within_range_val(data, + starting_bit, + lower, + upper, + values + ptr_offset, + nbits); + }, + [data, lower, upper, values](const size_type starting_element, + const size_type ptr_offset, + const size_type nbits) { + return VectorizedT::template op_within_range_val( + reinterpret_cast(data + starting_element), + lower, + upper, + values + ptr_offset, + nbits); + }); + } + + // + template + static inline void + op_arith_compare(data_type* const __restrict data, + const size_type start, + const T* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_type size) { + op_func( + start, + size, + [data, src, right_operand, value](const size_type starting_bit, + const size_type ptr_offset, + const size_type nbits) { + ElementWiseBitsetPolicy:: + template op_arith_compare(data, + starting_bit, + src + ptr_offset, + right_operand, + value, + nbits); + }, + [data, src, right_operand, value](const size_type starting_element, + const size_type ptr_offset, + const size_type nbits) { + return VectorizedT::template op_arith_compare( + reinterpret_cast(data + starting_element), + src + ptr_offset, + right_operand, + value, + nbits); + }); + } + + // + static inline size_t + op_and_with_count(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return ElementWiseBitsetPolicy::op_and_with_count( + left, right, start_left, start_right, size); + } + + static inline size_t + op_or_with_count(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return ElementWiseBitsetPolicy::op_or_with_count( + left, right, start_left, start_right, size); + } + + // void FuncBaseline(const size_t starting_bit, const size_type ptr_offset, const size_type nbits) + // bool FuncVectorized(const size_type starting_element, const size_type ptr_offset, const size_type nbits) + template + static inline void + op_func(const size_type start, + const size_type size, + FuncBaseline func_baseline, + FuncVectorized func_vectorized) { + if (size == 0) { + return; + } + + auto start_element = get_element(start); + const auto end_element = get_element(start + size); + + const auto start_shift = get_shift(start); + const auto end_shift = get_shift(start + size); + + // same element? 
+ if (start_element == end_element) { + func_baseline(start, 0, size); + return; + } + + // + uintptr_t ptr_offset = 0; + + // process the first element + if (start_shift != 0) { + // it is possible to do vectorized masking here, but it is not worth it + func_baseline(start, 0, size); + + // start from the next element + start_element += 1; + ptr_offset += data_bits - start_shift; + } + + // process the middle + { + const size_t starting_bit_idx = start_element * data_bits; + const size_t nbits = (end_element - start_element) * data_bits; + + // check if vectorized implementation is available + if (!func_vectorized(start_element, ptr_offset, nbits)) { + // vectorized implementation is not available, invoke the default one + func_baseline(starting_bit_idx, ptr_offset, nbits); + } + + // + ptr_offset += nbits; + } + + // process the last element + if (end_shift != 0) { + // it is possible to do vectorized masking here, but it is not worth it + const size_t starting_bit_idx = end_element * data_bits; + + func_baseline(starting_bit_idx, ptr_offset, end_shift); + } + } +}; + +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/element_wise.h b/internal/core/src/bitset/detail/element_wise.h new file mode 100644 index 000000000000..062b1442909a --- /dev/null +++ b/internal/core/src/bitset/detail/element_wise.h @@ -0,0 +1,979 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
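To make `op_func`'s head/middle/tail split in `element_vectorized.h` concrete, here is its boundary arithmetic worked for one example (a standalone sketch, not the repository's code): a partial first word goes to the scalar baseline, whole middle words go to the vectorized kernel, and a partial last word falls back to the baseline again.

```cpp
#include <cstddef>
#include <cstdio>

// start = 70, size = 260, 64-bit elements
int main() {
    const std::size_t data_bits = 64;
    const std::size_t start = 70;
    const std::size_t size = 260;

    const std::size_t start_element = start / data_bits;         // 1
    const std::size_t end_element = (start + size) / data_bits;  // 5
    const std::size_t start_shift = start % data_bits;           // 6
    const std::size_t end_shift = (start + size) % data_bits;    // 10

    // head: 58 trailing bits of element 1, handled by the scalar baseline
    std::printf("head:   %zu bits\n", data_bits - start_shift);
    // middle: elements 2..4 (192 bits), handed to the vectorized kernel
    std::printf("middle: %zu bits\n",
                (end_element - start_element - 1) * data_bits);
    // tail: 10 leading bits of element 5, scalar baseline again
    std::printf("tail:   %zu bits\n", end_shift);
    return 0;
}
```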
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+
+#include "proxy.h"
+
+#include "ctz.h"
+#include "popcount.h"
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+
+// This one is similar to boost::dynamic_bitset
+template <typename ElementT>
+struct ElementWiseBitsetPolicy {
+    using data_type = ElementT;
+    constexpr static auto data_bits = sizeof(data_type) * 8;
+
+    using size_type = size_t;
+
+    using self_type = ElementWiseBitsetPolicy<ElementT>;
+
+    using proxy_type = Proxy<self_type>;
+    using const_proxy_type = ConstProxy<self_type>;
+
+    static inline size_type
+    get_element(const size_t idx) {
+        return idx / data_bits;
+    }
+
+    static inline size_type
+    get_shift(const size_t idx) {
+        return idx % data_bits;
+    }
+
+    static inline size_type
+    get_required_size_in_elements(const size_t size) {
+        return (size + data_bits - 1) / data_bits;
+    }
+
+    static inline size_type
+    get_required_size_in_bytes(const size_t size) {
+        return get_required_size_in_elements(size) * sizeof(data_type);
+    }
+
+    static inline proxy_type
+    get_proxy(data_type* const __restrict data, const size_type idx) {
+        data_type& element = data[get_element(idx)];
+        const size_type shift = get_shift(idx);
+        return proxy_type{element, shift};
+    }
+
+    static inline const_proxy_type
+    get_proxy(const data_type* const __restrict data, const size_type idx) {
+        const data_type& element = data[get_element(idx)];
+        const size_type shift = get_shift(idx);
+        return const_proxy_type{element, shift};
+    }
+
+    static inline data_type
+    op_read(const data_type* const data,
+            const size_type start,
+            const size_type nbits) {
+        if (nbits == 0) {
+            return 0;
+        }
+
+        const auto start_element = get_element(start);
+        const auto end_element = get_element(start + nbits - 1);
+
+        const auto start_shift = get_shift(start);
+        const auto end_shift = get_shift(start + nbits - 1);
+
+        if (start_element == end_element) {
+            // read from 1 element only
+            const data_type m1 = get_shift_mask_end(start_shift);
+            const data_type m2 = get_shift_mask_begin(end_shift + 1);
+            const data_type mask = m1 & m2;
+
+            // read and shift
+            const data_type element = data[start_element];
+            const data_type value = (element & mask) >> start_shift;
+            return value;
+        } else {
+            // read from 2 elements
+            const data_type first_v = data[start_element];
+            const data_type second_v = data[start_element + 1];
+
+            const data_type first_mask = get_shift_mask_end(start_shift);
+            const data_type second_mask = get_shift_mask_begin(end_shift + 1);
+
+            const data_type value1 = (first_v & first_mask) >> start_shift;
+            const data_type value2 = (second_v & second_mask);
+            const data_type value =
+                value1 | (value2 << (data_bits - start_shift));
+
+            return value;
+        }
+    }
+
+    static inline void
+    op_write(data_type* const data,
+             const size_type start,
+             const size_type nbits,
+             const data_type value) {
+        if (nbits == 0) {
+            return;
+        }
+
+        const auto start_element = get_element(start);
+        const auto end_element = get_element(start + nbits - 1);
+
+        const auto start_shift = get_shift(start);
+        const auto end_shift = get_shift(start + nbits - 1);
+
+        if (start_element == end_element) {
+            // write into a single element
+            const data_type m1 = get_shift_mask_end(start_shift);
+            const data_type m2 = get_shift_mask_begin(end_shift + 1);
+            const data_type mask = m1 & m2;
+
+            // read an existing value
+            const data_type element = data[start_element];
+            // combine a new value
+            const data_type new_value =
+                (element
& (~mask)) | ((value << start_shift) & mask); + // write it back + data[start_element] = new_value; + } else { + // write into two elements + const data_type first_v = data[start_element]; + const data_type second_v = data[start_element + 1]; + + const data_type first_mask = get_shift_mask_end(start_shift); + const data_type second_mask = get_shift_mask_begin(end_shift + 1); + + const data_type value1 = (first_v & (~first_mask)) | + ((value << start_shift) & first_mask); + const data_type value2 = + (second_v & (~second_mask)) | + ((value >> (data_bits - start_shift)) & second_mask); + + data[start_element] = value1; + data[start_element + 1] = value2; + } + } + + static inline void + op_flip(data_type* const data, + const size_type start, + const size_type size) { + if (size == 0) { + return; + } + + auto start_element = get_element(start); + const auto end_element = get_element(start + size); + + const auto start_shift = get_shift(start); + const auto end_shift = get_shift(start + size); + + // same element to modify? + if (start_element == end_element) { + const data_type existing_v = data[start_element]; + const data_type new_v = ~existing_v; + + const data_type existing_mask = get_shift_mask_begin(start_shift) | + get_shift_mask_end(end_shift); + const data_type new_mask = get_shift_mask_end(start_shift) & + get_shift_mask_begin(end_shift); + + data[start_element] = + (existing_v & existing_mask) | (new_v & new_mask); + return; + } + + // process the first element + if (start_shift != 0) { + const data_type existing_v = data[start_element]; + const data_type new_v = ~existing_v; + + const data_type existing_mask = get_shift_mask_begin(start_shift); + const data_type new_mask = get_shift_mask_end(start_shift); + + data[start_element] = + (existing_v & existing_mask) | (new_v & new_mask); + start_element += 1; + } + + // process the middle + for (size_type i = start_element; i < end_element; i++) { + data[i] = ~data[i]; + } + + // process the last element + if (end_shift != 0) { + const data_type existing_v = data[end_element]; + const data_type new_v = ~existing_v; + + const data_type existing_mask = get_shift_mask_end(end_shift); + const data_type new_mask = get_shift_mask_begin(end_shift); + + data[end_element] = + (existing_v & existing_mask) | (new_v & new_mask); + } + } + + static inline void + op_and(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + op_func(left, + right, + start_left, + start_right, + size, + [](const data_type left_v, const data_type right_v) { + return left_v & right_v; + }); + } + + static inline void + op_or(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + op_func(left, + right, + start_left, + start_right, + size, + [](const data_type left_v, const data_type right_v) { + return left_v | right_v; + }); + } + + static inline data_type + get_shift_mask_begin(const size_type shift) { + // 0 -> 0b00000000 + // 1 -> 0b00000001 + // 2 -> 0b00000011 + if (shift == data_bits) { + return data_type(-1); + } + + return (data_type(1) << shift) - data_type(1); + } + + static inline data_type + get_shift_mask_end(const size_type shift) { + // 0 -> 0b11111111 + // 1 -> 0b11111110 + // 2 -> 0b11111100 + return ~(get_shift_mask_begin(shift)); + } + + static inline void + op_set(data_type* const data, const size_type start, const size_type size) { + op_fill(data, start, size, true); + } + + static inline 
void + op_reset(data_type* const data, + const size_type start, + const size_type size) { + op_fill(data, start, size, false); + } + + static inline bool + op_all(const data_type* const data, + const size_type start, + const size_type size) { + if (size == 0) { + return true; + } + + auto start_element = get_element(start); + const auto end_element = get_element(start + size); + + const auto start_shift = get_shift(start); + const auto end_shift = get_shift(start + size); + + // same element? + if (start_element == end_element) { + const data_type existing_v = data[start_element]; + + const data_type existing_mask = get_shift_mask_end(start_shift) & + get_shift_mask_begin(end_shift); + + return ((existing_v & existing_mask) == existing_mask); + } + + // process the first element + if (start_shift != 0) { + const data_type existing_v = data[start_element]; + + const data_type existing_mask = get_shift_mask_end(start_shift); + if ((existing_v & existing_mask) != existing_mask) { + return false; + } + + start_element += 1; + } + + // process the middle + for (size_type i = start_element; i < end_element; i++) { + if (data[i] != data_type(-1)) { + return false; + } + } + + // process the last element + if (end_shift != 0) { + const data_type existing_v = data[end_element]; + + const data_type existing_mask = get_shift_mask_begin(end_shift); + + if ((existing_v & existing_mask) != existing_mask) { + return false; + } + } + + return true; + } + + static inline bool + op_none(const data_type* const data, + const size_type start, + const size_type size) { + if (size == 0) { + return true; + } + + auto start_element = get_element(start); + const auto end_element = get_element(start + size); + + const auto start_shift = get_shift(start); + const auto end_shift = get_shift(start + size); + + // same element? 
+ if (start_element == end_element) { + const data_type existing_v = data[start_element]; + + const data_type existing_mask = get_shift_mask_end(start_shift) & + get_shift_mask_begin(end_shift); + + return ((existing_v & existing_mask) == data_type(0)); + } + + // process the first element + if (start_shift != 0) { + const data_type existing_v = data[start_element]; + + const data_type existing_mask = get_shift_mask_end(start_shift); + if ((existing_v & existing_mask) != data_type(0)) { + return false; + } + + start_element += 1; + } + + // process the middle + for (size_type i = start_element; i < end_element; i++) { + if (data[i] != data_type(0)) { + return false; + } + } + + // process the last element + if (end_shift != 0) { + const data_type existing_v = data[end_element]; + + const data_type existing_mask = get_shift_mask_begin(end_shift); + + if ((existing_v & existing_mask) != data_type(0)) { + return false; + } + } + + return true; + } + + static void + op_copy(const data_type* const src, + const size_type start_src, + data_type* const dst, + const size_type start_dst, + const size_type size) { + if (size == 0) { + return; + } + + // process big blocks + const size_type size_b = (size / data_bits) * data_bits; + + if ((start_src % data_bits) == 0) { + if ((start_dst % data_bits) == 0) { + // plain memcpy + for (size_type i = 0; i < size_b; i += data_bits) { + const data_type src_v = src[(start_src + i) / data_bits]; + dst[(start_dst + i) / data_bits] = src_v; + } + } else { + // easier read + for (size_type i = 0; i < size_b; i += data_bits) { + const data_type src_v = src[(start_src + i) / data_bits]; + op_write(dst, start_dst + i, data_bits, src_v); + } + } + } else { + if ((start_dst % data_bits) == 0) { + // easier write + for (size_type i = 0; i < size_b; i += data_bits) { + const data_type src_v = + op_read(src, start_src + i, data_bits); + dst[(start_dst + i) / data_bits] = src_v; + } + } else { + // general case + for (size_type i = 0; i < size_b; i += data_bits) { + const data_type src_v = + op_read(src, start_src + i, data_bits); + op_write(dst, start_dst + i, data_bits, src_v); + } + } + } + + // process leftovers + if (size_b != size) { + const data_type src_v = + op_read(src, start_src + size_b, size - size_b); + op_write(dst, start_dst + size_b, size - size_b, src_v); + } + } + + static void + op_fill(data_type* const data, + const size_type start, + const size_type size, + const bool value) { + if (size == 0) { + return; + } + + const data_type new_v = (value) ? data_type(-1) : data_type(0); + + // + auto start_element = get_element(start); + const auto end_element = get_element(start + size); + + const auto start_shift = get_shift(start); + const auto end_shift = get_shift(start + size); + + // same element to modify? 
+ if (start_element == end_element) { + const data_type existing_v = data[start_element]; + + const data_type existing_mask = get_shift_mask_begin(start_shift) | + get_shift_mask_end(end_shift); + const data_type new_mask = get_shift_mask_end(start_shift) & + get_shift_mask_begin(end_shift); + + data[start_element] = + (existing_v & existing_mask) | (new_v & new_mask); + return; + } + + // process the first element + if (start_shift != 0) { + const data_type existing_v = data[start_element]; + + const data_type existing_mask = get_shift_mask_begin(start_shift); + const data_type new_mask = get_shift_mask_end(start_shift); + + data[start_element] = + (existing_v & existing_mask) | (new_v & new_mask); + start_element += 1; + } + + // process the middle + for (size_type i = start_element; i < end_element; i++) { + data[i] = new_v; + } + + // process the last element + if (end_shift != 0) { + const data_type existing_v = data[end_element]; + + const data_type existing_mask = get_shift_mask_end(end_shift); + const data_type new_mask = get_shift_mask_begin(end_shift); + + data[end_element] = + (existing_v & existing_mask) | (new_v & new_mask); + } + } + + static inline size_type + op_count(const data_type* const data, + const size_type start, + const size_type size) { + if (size == 0) { + return 0; + } + + size_type count = 0; + + auto start_element = get_element(start); + const auto end_element = get_element(start + size); + + const auto start_shift = get_shift(start); + const auto end_shift = get_shift(start + size); + + // same element? + if (start_element == end_element) { + const data_type existing_v = data[start_element]; + + const data_type existing_mask = get_shift_mask_end(start_shift) & + get_shift_mask_begin(end_shift); + + return PopCountHelper::count(existing_v & existing_mask); + } + + // process the first element + if (start_shift != 0) { + const data_type existing_v = data[start_element]; + const data_type existing_mask = get_shift_mask_end(start_shift); + + count = + PopCountHelper::count(existing_v & existing_mask); + + start_element += 1; + } + + // process the middle + for (size_type i = start_element; i < end_element; i++) { + count += PopCountHelper::count(data[i]); + } + + // process the last element + if (end_shift != 0) { + const data_type existing_v = data[end_element]; + const data_type existing_mask = get_shift_mask_begin(end_shift); + + count += + PopCountHelper::count(existing_v & existing_mask); + } + + return count; + } + + static inline bool + op_eq(const data_type* const left, + const data_type* const right, + const size_type start_left, + const size_type start_right, + const size_type size) { + if (size == 0) { + return true; + } + + // process big chunks + const size_type size_b = (size / data_bits) * data_bits; + + if ((start_left % data_bits) == 0) { + if ((start_right % data_bits) == 0) { + // plain "memcpy" + size_type start_left_idx = start_left / data_bits; + size_type start_right_idx = start_right / data_bits; + + for (size_type i = 0, j = 0; i < size_b; + i += data_bits, j += 1) { + const data_type left_v = left[start_left_idx + j]; + const data_type right_v = right[start_right_idx + j]; + if (left_v != right_v) { + return false; + } + } + } else { + // easier left + size_type start_left_idx = start_left / data_bits; + + for (size_type i = 0, j = 0; i < size_b; + i += data_bits, j += 1) { + const data_type left_v = left[start_left_idx + j]; + const data_type right_v = + op_read(right, start_right + i, data_bits); + if (left_v != right_v) { + return 
false; + } + } + } + } else { + if ((start_right % data_bits) == 0) { + // easier right + size_type start_right_idx = start_right / data_bits; + + for (size_type i = 0, j = 0; i < size_b; + i += data_bits, j += 1) { + const data_type left_v = + op_read(left, start_left + i, data_bits); + const data_type right_v = right[start_right_idx + j]; + if (left_v != right_v) { + return false; + } + } + } else { + // general case + for (size_type i = 0; i < size_b; i += data_bits) { + const data_type left_v = + op_read(left, start_left + i, data_bits); + const data_type right_v = + op_read(right, start_right + i, data_bits); + if (left_v != right_v) { + return false; + } + } + } + } + + // process leftovers + if (size_b != size) { + const data_type left_v = + op_read(left, start_left + size_b, size - size_b); + const data_type right_v = + op_read(right, start_right + size_b, size - size_b); + if (left_v != right_v) { + return false; + } + } + + return true; + } + + static inline void + op_xor(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + op_func(left, + right, + start_left, + start_right, + size, + [](const data_type left_v, const data_type right_v) { + return left_v ^ right_v; + }); + } + + static inline void + op_sub(data_type* const left, + const data_type* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + if (size == 0) { + return; + } + + op_func(left, + right, + start_left, + start_right, + size, + [](const data_type left_v, const data_type right_v) { + return left_v & ~right_v; + }); + } + + // + static inline std::optional + op_find(const data_type* const data, + const size_type start, + const size_type size, + const size_type starting_idx) { + if (size == 0) { + return std::nullopt; + } + + // + auto start_element = get_element(start + starting_idx); + const auto end_element = get_element(start + size); + + const auto start_shift = get_shift(start + starting_idx); + const auto end_shift = get_shift(start + size); + + size_type extra_offset = 0; + + // same element? 
+        if (start_element == end_element) {
+            const data_type existing_v = data[start_element];
+
+            const data_type existing_mask = get_shift_mask_end(start_shift) &
+                                            get_shift_mask_begin(end_shift);
+
+            const data_type value = existing_v & existing_mask;
+            if (value != 0) {
+                const auto ctz = CtzHelper<data_type>::ctz(value);
+                return size_type(ctz) + start_element * data_bits - start;
+            } else {
+                return std::nullopt;
+            }
+        }
+
+        // process the first element
+        if (start_shift != 0) {
+            const data_type existing_v = data[start_element];
+            const data_type existing_mask = get_shift_mask_end(start_shift);
+
+            const data_type value = existing_v & existing_mask;
+            if (value != 0) {
+                const auto ctz = CtzHelper<data_type>::ctz(value) +
+                                 start_element * data_bits - start;
+                return size_type(ctz);
+            }
+
+            start_element += 1;
+            extra_offset += data_bits - start_shift;
+        }
+
+        // process the middle
+        for (size_type i = start_element; i < end_element; i++) {
+            const data_type value = data[i];
+            if (value != 0) {
+                const auto ctz = CtzHelper<data_type>::ctz(value);
+                return size_type(ctz) + i * data_bits - start;
+            }
+        }
+
+        // process the last element
+        if (end_shift != 0) {
+            const data_type existing_v = data[end_element];
+            const data_type existing_mask = get_shift_mask_begin(end_shift);
+
+            const data_type value = existing_v & existing_mask;
+            if (value != 0) {
+                const auto ctz = CtzHelper<data_type>::ctz(value);
+                return size_type(ctz) + end_element * data_bits - start;
+            }
+        }
+
+        return std::nullopt;
+    }
+
+    //
+    template <typename T, typename U, CompareOpType Op>
+    static inline void
+    op_compare_column(data_type* const __restrict data,
+                      const size_type start,
+                      const T* const __restrict t,
+                      const U* const __restrict u,
+                      const size_type size) {
+        for (size_type i = 0; i < size; i++) {
+            get_proxy(data, start + i) =
+                CompareOperator<Op>::compare(t[i], u[i]);
+        }
+    }
+
+    //
+    template <typename T, CompareOpType Op>
+    static inline void
+    op_compare_val(data_type* const __restrict data,
+                   const size_type start,
+                   const T* const __restrict t,
+                   const size_type size,
+                   const T& value) {
+        for (size_type i = 0; i < size; i++) {
+            get_proxy(data, start + i) =
+                CompareOperator<Op>::compare(t[i], value);
+        }
+    }
+
+    //
+    template <typename T, RangeType Op>
+    static inline void
+    op_within_range_column(data_type* const __restrict data,
+                           const size_type start,
+                           const T* const __restrict lower,
+                           const T* const __restrict upper,
+                           const T* const __restrict values,
+                           const size_type size) {
+        for (size_type i = 0; i < size; i++) {
+            get_proxy(data, start + i) =
+                RangeOperator<Op>::within_range(lower[i], upper[i], values[i]);
+        }
+    }
+
+    //
+    template <typename T, RangeType Op>
+    static inline void
+    op_within_range_val(data_type* const __restrict data,
+                        const size_type start,
+                        const T& lower,
+                        const T& upper,
+                        const T* const __restrict values,
+                        const size_type size) {
+        for (size_type i = 0; i < size; i++) {
+            get_proxy(data, start + i) =
+                RangeOperator<Op>::within_range(lower, upper, values[i]);
+        }
+    }
+
+    //
+    template <typename T, ArithOpType AOp, CompareOpType CmpOp>
+    static inline void
+    op_arith_compare(data_type* const __restrict data,
+                     const size_type start,
+                     const T* const __restrict src,
+                     const ArithHighPrecisionType<T>& right_operand,
+                     const ArithHighPrecisionType<T>& value,
+                     const size_type size) {
+        for (size_type i = 0; i < size; i++) {
+            get_proxy(data, start + i) =
+                ArithCompareOperator<AOp, CmpOp>::compare(
+                    src[i], right_operand, value);
+        }
+    }
+
+    //
+    static inline size_t
+    op_and_with_count(data_type* const left,
+                      const data_type* const right,
+                      const size_t start_left,
+                      const size_t start_right,
+                      const size_t size) {
+        size_t active = 0;
+
+        op_func(left,
+                right,
+                start_left,
+                start_right,
+                size,
+                [&active](const data_type left_v, const data_type right_v) {
+                    const data_type result = left_v & right_v;
+                    active += PopCountHelper<data_type>::count(result);
+
+                    return result;
+                });
+
+        return active;
+    }
+
+    static inline size_t
+    op_or_with_count(data_type* const left,
+                     const data_type* const right,
+                     const size_t start_left,
+                     const size_t start_right,
+                     const size_t size) {
+        size_t inactive = 0;
+
+        op_func(left,
+                right,
+                start_left,
+                start_right,
+                size,
+                [&inactive](const data_type left_v, const data_type right_v) {
+                    const data_type result = left_v | right_v;
+                    inactive +=
+                        (data_bits - PopCountHelper<data_type>::count(result));
+
+                    return result;
+                });
+
+        return inactive;
+    }
+
+    // data_type Func(const data_type left_v, const data_type right_v);
+    template <typename Func>
+    static inline void
+    op_func(data_type* const left,
+            const data_type* const right,
+            const size_t start_left,
+            const size_t start_right,
+            const size_t size,
+            Func func) {
+        if (size == 0) {
+            return;
+        }
+
+        // process big blocks
+        const size_type size_b = (size / data_bits) * data_bits;
+        if ((start_left % data_bits) == 0) {
+            if ((start_right % data_bits) == 0) {
+                // plain "memcpy".
+                // A compiler auto-vectorization is expected.
+                size_type start_left_idx = start_left / data_bits;
+                size_type start_right_idx = start_right / data_bits;
+
+                for (size_type i = 0, j = 0; i < size_b;
+                     i += data_bits, j += 1) {
+                    data_type& left_v = left[start_left_idx + j];
+                    const data_type right_v = right[start_right_idx + j];
+
+                    const data_type result_v = func(left_v, right_v);
+                    left_v = result_v;
+                }
+            } else {
+                // easier write: left is aligned, so it can be indexed
+                // directly, while right has to be read bit-wise
+                size_type start_left_idx = start_left / data_bits;
+
+                for (size_type i = 0, j = 0; i < size_b;
+                     i += data_bits, j += 1) {
+                    data_type& left_v = left[start_left_idx + j];
+                    const data_type right_v =
+                        op_read(right, start_right + i, data_bits);
+
+                    const data_type result_v = func(left_v, right_v);
+                    left_v = result_v;
+                }
+            }
+        } else {
+            if ((start_right % data_bits) == 0) {
+                // easier read: right is aligned, while left has to be
+                // read bit-wise and written back at the same position
+                size_type start_right_idx = start_right / data_bits;
+
+                for (size_type i = 0, j = 0; i < size_b;
+                     i += data_bits, j += 1) {
+                    const data_type left_v =
+                        op_read(left, start_left + i, data_bits);
+                    const data_type right_v = right[start_right_idx + j];
+
+                    const data_type result_v = func(left_v, right_v);
+                    op_write(left, start_left + i, data_bits, result_v);
+                }
+            } else {
+                // general case
+                for (size_type i = 0; i < size_b; i += data_bits) {
+                    const data_type left_v =
+                        op_read(left, start_left + i, data_bits);
+                    const data_type right_v =
+                        op_read(right, start_right + i, data_bits);
+
+                    const data_type result_v = func(left_v, right_v);
+                    op_write(left, start_left + i, data_bits, result_v);
+                }
+            }
+        }
+
+        // process leftovers
+        if (size_b != size) {
+            const data_type left_v =
+                op_read(left, start_left + size_b, size - size_b);
+            const data_type right_v =
+                op_read(right, start_right + size_b, size - size_b);
+
+            const data_type result_v = func(left_v, right_v);
+            op_write(left, start_left + size_b, size - size_b, result_v);
+        }
+    }
+};
+
+} // namespace detail
+} // namespace bitset
+} // namespace milvus
diff --git a/internal/core/src/bitset/detail/platform/arm/neon-decl.h b/internal/core/src/bitset/detail/platform/arm/neon-decl.h
new file mode 100644
index 000000000000..c92bb37c0fc4
--- /dev/null
+++ b/internal/core/src/bitset/detail/platform/arm/neon-decl.h
@@ -0,0 +1,201 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// ARM NEON declaration
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#include "bitset/common.h"
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+namespace arm {
+namespace neon {
+
+///////////////////////////////////////////////////////////////////////////
+// a facility to run through all acceptable data types
+#define ALL_DATATYPES_1(FUNC) \
+    FUNC(int8_t);             \
+    FUNC(int16_t);            \
+    FUNC(int32_t);            \
+    FUNC(int64_t);            \
+    FUNC(float);              \
+    FUNC(double);
+
+///////////////////////////////////////////////////////////////////////////
+
+// the default implementation does nothing
+template <typename T, typename U, CompareOpType Op>
+struct OpCompareColumnImpl {
+    static bool
+    op_compare_column(uint8_t* const __restrict bitmask,
+                      const T* const __restrict t,
+                      const U* const __restrict u,
+                      const size_t size) {
+        return false;
+    }
+};
+
+// the following use cases are handled
+#define DECLARE_PARTIAL_OP_COMPARE_COLUMN(TTYPE)             \
+    template <CompareOpType Op>                              \
+    struct OpCompareColumnImpl<TTYPE, TTYPE, Op> {           \
+        static bool                                          \
+        op_compare_column(uint8_t* const __restrict bitmask, \
+                          const TTYPE* const __restrict t,   \
+                          const TTYPE* const __restrict u,   \
+                          const size_t size);                \
+    };
+
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_COLUMN)
+
+#undef DECLARE_PARTIAL_OP_COMPARE_COLUMN
+
+///////////////////////////////////////////////////////////////////////////
+
+// the default implementation does nothing
+template <typename T, CompareOpType Op>
+struct OpCompareValImpl {
+    static inline bool
+    op_compare_val(uint8_t* const __restrict bitmask,
+                   const T* const __restrict t,
+                   const size_t size,
+                   const T& value) {
+        return false;
+    }
+};
+
+// the following use cases are handled
+#define DECLARE_PARTIAL_OP_COMPARE_VAL(TTYPE)             \
+    template <CompareOpType Op>                           \
+    struct OpCompareValImpl<TTYPE, Op> {                  \
+        static bool                                       \
+        op_compare_val(uint8_t* const __restrict bitmask, \
+                       const TTYPE* const __restrict t,   \
+                       const size_t size,                 \
+                       const TTYPE& value);               \
+    };
+
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_VAL)
+
+#undef DECLARE_PARTIAL_OP_COMPARE_VAL
+
+///////////////////////////////////////////////////////////////////////////
+
+// the default implementation does nothing
+template <typename T, RangeType Op>
+struct OpWithinRangeColumnImpl {
+    static inline bool
+    op_within_range_column(uint8_t* const __restrict bitmask,
+                           const T* const __restrict lower,
+                           const T* const __restrict upper,
+                           const T* const __restrict values,
+                           const size_t size) {
+        return false;
+    }
+};
+
+// the following use cases are handled
+#define DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN(TTYPE)                \
+    template <RangeType Op>                                          \
+    struct OpWithinRangeColumnImpl<TTYPE, Op> {                      \
+        static bool                                                  \
+        op_within_range_column(uint8_t* const __restrict bitmask,    \
+                               const TTYPE* const __restrict lower,  \
+                               const TTYPE* const __restrict upper,  \
+                               const TTYPE* const __restrict values, \
+                               const size_t size);                   \
+    };
+
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN)
+
+#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN
+
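Note on the pattern used throughout this header: the primary template returns false, meaning "no vectorized kernel exists for this type/op combination, fall back to the scalar policy", while the macro stamps out per-type partial specializations whose definitions live in a separately compiled translation unit that can be built with target-specific compiler flags. A distilled sketch of that pattern; the names (MyOpType, MyOpImpl, op_apply) are illustrative, not from this patch:

    #include <cstddef>
    #include <cstdint>

    enum class MyOpType { EQ, NE };

    // primary template: no vectorized kernel for this type/op;
    // returning false tells the caller to use the scalar implementation
    template <typename T, MyOpType Op>
    struct MyOpImpl {
        static bool
        op_apply(uint8_t* const bitmask, const T* const src, const size_t size) {
            return false;
        }
    };

    // partial specialization for a supported type: declared here only,
    // defined in a .cpp that is compiled with the SIMD-specific flags
    template <MyOpType Op>
    struct MyOpImpl<int8_t, Op> {
        static bool
        op_apply(uint8_t* const bitmask, const int8_t* const src, const size_t size);
    };

This keeps intrinsics out of the declaration header, so code that includes it never needs the target-specific flags itself.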
+/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpWithinRangeValImpl { + static inline bool + op_within_range_val(uint8_t* const __restrict bitmask, + const T& lower, + const T& upper, + const T* const __restrict values, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL(TTYPE) \ + template \ + struct OpWithinRangeValImpl { \ + static bool \ + op_within_range_val(uint8_t* const __restrict bitmask, \ + const TTYPE& lower, \ + const TTYPE& upper, \ + const TTYPE* const __restrict values, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL) + +#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpArithCompareImpl { + static inline bool + op_arith_compare(uint8_t* const __restrict bitmask, + const T* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_ARITH_COMPARE(TTYPE) \ + template \ + struct OpArithCompareImpl { \ + static bool \ + op_arith_compare(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict src, \ + const ArithHighPrecisionType& right_operand, \ + const ArithHighPrecisionType& value, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) + +#undef DECLARE_PARTIAL_OP_ARITH_COMPARE + +/////////////////////////////////////////////////////////////////////////// + +#undef ALL_DATATYPES_1 + +} // namespace neon +} // namespace arm +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/arm/neon-impl.h b/internal/core/src/bitset/detail/platform/arm/neon-impl.h new file mode 100644 index 000000000000..0547665d9f6c --- /dev/null +++ b/internal/core/src/bitset/detail/platform/arm/neon-impl.h @@ -0,0 +1,1819 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ARM NEON implementation + +#pragma once + +#include + +#include +#include +#include +#include + +#include "neon-decl.h" + +#include "bitset/common.h" + +namespace milvus { +namespace bitset { +namespace detail { +namespace arm { +namespace neon { + +namespace { + +// this function is missing somewhy +inline uint64x2_t +vmvnq_u64(const uint64x2_t value) { + const uint64x2_t m1 = vreinterpretq_u64_u32(vdupq_n_u32(0xFFFFFFFF)); + return veorq_u64(value, m1); +} + +// draft: movemask functions from sse2neon library. +// todo: can this be made better? 
+ +// todo: optimize +inline uint8_t +movemask(const uint8x8_t cmp) { + static const int8_t shifts[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + // shift right by 7, leaving 1 bit + const uint8x8_t sh = vshr_n_u8(cmp, 7); + // load shifts + const int8x8_t shifts_v = vld1_s8(shifts); + // shift each of 8 lanes with 1 bit values differently + const uint8x8_t shifted_bits = vshl_u8(sh, shifts_v); + // horizontal sum of bits on different positions + return vaddv_u8(shifted_bits); +} + +// todo: optimize +// https://lemire.me/blog/2017/07/10/pruning-spaces-faster-on-arm-processors-with-vector-table-lookups/ (?) +inline uint16_t +movemask(const uint8x16_t cmp) { + uint16x8_t high_bits = vreinterpretq_u16_u8(vshrq_n_u8(cmp, 7)); + uint32x4_t paired16 = + vreinterpretq_u32_u16(vsraq_n_u16(high_bits, high_bits, 7)); + uint64x2_t paired32 = + vreinterpretq_u64_u32(vsraq_n_u32(paired16, paired16, 14)); + uint8x16_t paired64 = + vreinterpretq_u8_u64(vsraq_n_u64(paired32, paired32, 28)); + return vgetq_lane_u8(paired64, 0) | ((int)vgetq_lane_u8(paired64, 8) << 8); +} + +// todo: optimize +inline uint32_t +movemask(const uint8x16x2_t cmp) { + return (uint32_t)(movemask(cmp.val[0])) | + ((uint32_t)(movemask(cmp.val[1])) << 16); +} + +// todo: optimize +inline uint8_t +movemask(const uint16x8_t cmp) { + static const int16_t shifts[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + // shift right by 15, leaving 1 bit + const uint16x8_t sh = vshrq_n_u16(cmp, 15); + // load shifts + const int16x8_t shifts_v = vld1q_s16(shifts); + // shift each of 8 lanes with 1 bit values differently + const uint16x8_t shifted_bits = vshlq_u16(sh, shifts_v); + // horizontal sum of bits on different positions + return vaddvq_u16(shifted_bits); +} + +// todo: optimize +inline uint16_t +movemask(const uint16x8x2_t cmp) { + return (uint16_t)(movemask(cmp.val[0])) | + ((uint16_t)(movemask(cmp.val[1])) << 8); +} + +// todo: optimize +inline uint32_t +movemask(const uint32x4_t cmp) { + static const int32_t shifts[4] = {0, 1, 2, 3}; + // shift right by 31, leaving 1 bit + const uint32x4_t sh = vshrq_n_u32(cmp, 31); + // load shifts + const int32x4_t shifts_v = vld1q_s32(shifts); + // shift each of 4 lanes with 1 bit values differently + const uint32x4_t shifted_bits = vshlq_u32(sh, shifts_v); + // horizontal sum of bits on different positions + return vaddvq_u32(shifted_bits); +} + +// todo: optimize +inline uint32_t +movemask(const uint32x4x2_t cmp) { + return movemask(cmp.val[0]) | (movemask(cmp.val[1]) << 4); +} + +// todo: optimize +inline uint8_t +movemask(const uint64x2_t cmp) { + // shift right by 63, leaving 1 bit + const uint64x2_t sh = vshrq_n_u64(cmp, 63); + return vgetq_lane_u64(sh, 0) | (vgetq_lane_u64(sh, 1) << 1); +} + +// todo: optimize +inline uint8_t +movemask(const uint64x2x4_t cmp) { + return movemask(cmp.val[0]) | (movemask(cmp.val[1]) << 2) | + (movemask(cmp.val[2]) << 4) | (movemask(cmp.val[3]) << 6); +} + +// +template +struct CmpHelper {}; + +template <> +struct CmpHelper { + static inline uint8x8_t + compare(const int8x8_t a, const int8x8_t b) { + return vceq_s8(a, b); + } + + static inline uint8x16x2_t + compare(const int8x16x2_t a, const int8x16x2_t b) { + return {vceqq_s8(a.val[0], b.val[0]), vceqq_s8(a.val[1], b.val[1])}; + } + + static inline uint16x8_t + compare(const int16x8_t a, const int16x8_t b) { + return vceqq_s16(a, b); + } + + static inline uint16x8x2_t + compare(const int16x8x2_t a, const int16x8x2_t b) { + return {vceqq_s16(a.val[0], b.val[0]), vceqq_s16(a.val[1], b.val[1])}; + } + + static inline uint32x4x2_t + 
compare(const int32x4x2_t a, const int32x4x2_t b) { + return {vceqq_s32(a.val[0], b.val[0]), vceqq_s32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const int64x2x4_t a, const int64x2x4_t b) { + return {vceqq_s64(a.val[0], b.val[0]), + vceqq_s64(a.val[1], b.val[1]), + vceqq_s64(a.val[2], b.val[2]), + vceqq_s64(a.val[3], b.val[3])}; + } + + static inline uint32x4x2_t + compare(const float32x4x2_t a, const float32x4x2_t b) { + return {vceqq_f32(a.val[0], b.val[0]), vceqq_f32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const float64x2x4_t a, const float64x2x4_t b) { + return {vceqq_f64(a.val[0], b.val[0]), + vceqq_f64(a.val[1], b.val[1]), + vceqq_f64(a.val[2], b.val[2]), + vceqq_f64(a.val[3], b.val[3])}; + } +}; + +template <> +struct CmpHelper { + static inline uint8x8_t + compare(const int8x8_t a, const int8x8_t b) { + return vcge_s8(a, b); + } + + static inline uint8x16x2_t + compare(const int8x16x2_t a, const int8x16x2_t b) { + return {vcgeq_s8(a.val[0], b.val[0]), vcgeq_s8(a.val[1], b.val[1])}; + } + + static inline uint16x8_t + compare(const int16x8_t a, const int16x8_t b) { + return vcgeq_s16(a, b); + } + + static inline uint16x8x2_t + compare(const int16x8x2_t a, const int16x8x2_t b) { + return {vcgeq_s16(a.val[0], b.val[0]), vcgeq_s16(a.val[1], b.val[1])}; + } + + static inline uint32x4x2_t + compare(const int32x4x2_t a, const int32x4x2_t b) { + return {vcgeq_s32(a.val[0], b.val[0]), vcgeq_s32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const int64x2x4_t a, const int64x2x4_t b) { + return {vcgeq_s64(a.val[0], b.val[0]), + vcgeq_s64(a.val[1], b.val[1]), + vcgeq_s64(a.val[2], b.val[2]), + vcgeq_s64(a.val[3], b.val[3])}; + } + + static inline uint32x4x2_t + compare(const float32x4x2_t a, const float32x4x2_t b) { + return {vcgeq_f32(a.val[0], b.val[0]), vcgeq_f32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const float64x2x4_t a, const float64x2x4_t b) { + return {vcgeq_f64(a.val[0], b.val[0]), + vcgeq_f64(a.val[1], b.val[1]), + vcgeq_f64(a.val[2], b.val[2]), + vcgeq_f64(a.val[3], b.val[3])}; + } +}; + +template <> +struct CmpHelper { + static inline uint8x8_t + compare(const int8x8_t a, const int8x8_t b) { + return vcgt_s8(a, b); + } + + static inline uint8x16x2_t + compare(const int8x16x2_t a, const int8x16x2_t b) { + return {vcgtq_s8(a.val[0], b.val[0]), vcgtq_s8(a.val[1], b.val[1])}; + } + + static inline uint16x8_t + compare(const int16x8_t a, const int16x8_t b) { + return vcgtq_s16(a, b); + } + + static inline uint16x8x2_t + compare(const int16x8x2_t a, const int16x8x2_t b) { + return {vcgtq_s16(a.val[0], b.val[0]), vcgtq_s16(a.val[1], b.val[1])}; + } + + static inline uint32x4x2_t + compare(const int32x4x2_t a, const int32x4x2_t b) { + return {vcgtq_s32(a.val[0], b.val[0]), vcgtq_s32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const int64x2x4_t a, const int64x2x4_t b) { + return {vcgtq_s64(a.val[0], b.val[0]), + vcgtq_s64(a.val[1], b.val[1]), + vcgtq_s64(a.val[2], b.val[2]), + vcgtq_s64(a.val[3], b.val[3])}; + } + + static inline uint32x4x2_t + compare(const float32x4x2_t a, const float32x4x2_t b) { + return {vcgtq_f32(a.val[0], b.val[0]), vcgtq_f32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const float64x2x4_t a, const float64x2x4_t b) { + return {vcgtq_f64(a.val[0], b.val[0]), + vcgtq_f64(a.val[1], b.val[1]), + vcgtq_f64(a.val[2], b.val[2]), + vcgtq_f64(a.val[3], b.val[3])}; + } +}; + +template <> +struct CmpHelper { + static inline uint8x8_t + 
compare(const int8x8_t a, const int8x8_t b) { + return vcle_s8(a, b); + } + + static inline uint8x16x2_t + compare(const int8x16x2_t a, const int8x16x2_t b) { + return {vcleq_s8(a.val[0], b.val[0]), vcleq_s8(a.val[1], b.val[1])}; + } + + static inline uint16x8_t + compare(const int16x8_t a, const int16x8_t b) { + return vcleq_s16(a, b); + } + + static inline uint16x8x2_t + compare(const int16x8x2_t a, const int16x8x2_t b) { + return {vcleq_s16(a.val[0], b.val[0]), vcleq_s16(a.val[1], b.val[1])}; + } + + static inline uint32x4x2_t + compare(const int32x4x2_t a, const int32x4x2_t b) { + return {vcleq_s32(a.val[0], b.val[0]), vcleq_s32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const int64x2x4_t a, const int64x2x4_t b) { + return {vcleq_s64(a.val[0], b.val[0]), + vcleq_s64(a.val[1], b.val[1]), + vcleq_s64(a.val[2], b.val[2]), + vcleq_s64(a.val[3], b.val[3])}; + } + + static inline uint32x4x2_t + compare(const float32x4x2_t a, const float32x4x2_t b) { + return {vcleq_f32(a.val[0], b.val[0]), vcleq_f32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const float64x2x4_t a, const float64x2x4_t b) { + return {vcleq_f64(a.val[0], b.val[0]), + vcleq_f64(a.val[1], b.val[1]), + vcleq_f64(a.val[2], b.val[2]), + vcleq_f64(a.val[3], b.val[3])}; + } +}; + +template <> +struct CmpHelper { + static inline uint8x8_t + compare(const int8x8_t a, const int8x8_t b) { + return vclt_s8(a, b); + } + + static inline uint8x16x2_t + compare(const int8x16x2_t a, const int8x16x2_t b) { + return {vcltq_s8(a.val[0], b.val[0]), vcltq_s8(a.val[1], b.val[1])}; + } + + static inline uint16x8_t + compare(const int16x8_t a, const int16x8_t b) { + return vcltq_s16(a, b); + } + + static inline uint16x8x2_t + compare(const int16x8x2_t a, const int16x8x2_t b) { + return {vcltq_s16(a.val[0], b.val[0]), vcltq_s16(a.val[1], b.val[1])}; + } + + static inline uint32x4x2_t + compare(const int32x4x2_t a, const int32x4x2_t b) { + return {vcltq_s32(a.val[0], b.val[0]), vcltq_s32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const int64x2x4_t a, const int64x2x4_t b) { + return {vcltq_s64(a.val[0], b.val[0]), + vcltq_s64(a.val[1], b.val[1]), + vcltq_s64(a.val[2], b.val[2]), + vcltq_s64(a.val[3], b.val[3])}; + } + + static inline uint32x4x2_t + compare(const float32x4x2_t a, const float32x4x2_t b) { + return {vcltq_f32(a.val[0], b.val[0]), vcltq_f32(a.val[1], b.val[1])}; + } + + static inline uint64x2x4_t + compare(const float64x2x4_t a, const float64x2x4_t b) { + return {vcltq_f64(a.val[0], b.val[0]), + vcltq_f64(a.val[1], b.val[1]), + vcltq_f64(a.val[2], b.val[2]), + vcltq_f64(a.val[3], b.val[3])}; + } +}; + +template <> +struct CmpHelper { + static inline uint8x8_t + compare(const int8x8_t a, const int8x8_t b) { + return vmvn_u8(vceq_s8(a, b)); + } + + static inline uint8x16x2_t + compare(const int8x16x2_t a, const int8x16x2_t b) { + return {vmvnq_u8(vceqq_s8(a.val[0], b.val[0])), + vmvnq_u8(vceqq_s8(a.val[1], b.val[1]))}; + } + + static inline uint16x8_t + compare(const int16x8_t a, const int16x8_t b) { + return vmvnq_u16(vceqq_s16(a, b)); + } + + static inline uint16x8x2_t + compare(const int16x8x2_t a, const int16x8x2_t b) { + return {vmvnq_u16(vceqq_s16(a.val[0], b.val[0])), + vmvnq_u16(vceqq_s16(a.val[1], b.val[1]))}; + } + + static inline uint32x4x2_t + compare(const int32x4x2_t a, const int32x4x2_t b) { + return {vmvnq_u32(vceqq_s32(a.val[0], b.val[0])), + vmvnq_u32(vceqq_s32(a.val[1], b.val[1]))}; + } + + static inline uint64x2x4_t + compare(const int64x2x4_t a, const 
int64x2x4_t b) { + return {vmvnq_u64(vceqq_s64(a.val[0], b.val[0])), + vmvnq_u64(vceqq_s64(a.val[1], b.val[1])), + vmvnq_u64(vceqq_s64(a.val[2], b.val[2])), + vmvnq_u64(vceqq_s64(a.val[3], b.val[3]))}; + } + + static inline uint32x4x2_t + compare(const float32x4x2_t a, const float32x4x2_t b) { + return {vmvnq_u32(vceqq_f32(a.val[0], b.val[0])), + vmvnq_u32(vceqq_f32(a.val[1], b.val[1]))}; + } + + static inline uint64x2x4_t + compare(const float64x2x4_t a, const float64x2x4_t b) { + return {vmvnq_u64(vceqq_f64(a.val[0], b.val[0])), + vmvnq_u64(vceqq_f64(a.val[1], b.val[1])), + vmvnq_u64(vceqq_f64(a.val[2], b.val[2])), + vmvnq_u64(vceqq_f64(a.val[3], b.val[3]))}; + } +}; + +} // namespace + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const int8_t* const __restrict src, + const size_t size, + const int8_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + const int8x16x2_t target = {vdupq_n_s8(val), vdupq_n_s8(val)}; + + // todo: aligned reads & writes + + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const int8x16x2_t v0 = {vld1q_s8(src + i), vld1q_s8(src + i + 16)}; + const uint8x16x2_t cmp = CmpHelper::compare(v0, target); + const uint32_t mmask = movemask(cmp); + + res_u32[i / 32] = mmask; + } + + for (size_t i = size32; i < size; i += 8) { + const int8x8_t v0 = vld1_s8(src + i); + const uint8x8_t cmp = CmpHelper::compare(v0, vdup_n_s8(val)); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict src, + const size_t size, + const int16_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + const int16x8x2_t target = {vdupq_n_s16(val), vdupq_n_s16(val)}; + + // todo: aligned reads & writes + + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const int16x8x2_t v0 = {vld1q_s16(src + i), vld1q_s16(src + i + 8)}; + const uint16x8x2_t cmp = CmpHelper::compare(v0, target); + const uint16_t mmask = movemask(cmp); + + res_u16[i / 16] = mmask; + } + + if (size16 != size) { + // 8 elements to process + const int16x8_t v0 = vld1q_s16(src + size16); + const uint16x8_t cmp = CmpHelper::compare(v0, target.val[0]); + const uint8_t mmask = movemask(cmp); + + res_u8[size16 / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict src, + const size_t size, + const int32_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + const int32x4x2_t target = {vdupq_n_s32(val), vdupq_n_s32(val)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int32x4x2_t v0 = {vld1q_s32(src + i), vld1q_s32(src + i + 4)}; + const uint32x4x2_t cmp = CmpHelper::compare(v0, target); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict src, + const size_t size, + const int64_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + 
const int64x2x4_t target = { + vdupq_n_s64(val), vdupq_n_s64(val), vdupq_n_s64(val), vdupq_n_s64(val)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int64x2x4_t v0 = {vld1q_s64(src + i), + vld1q_s64(src + i + 2), + vld1q_s64(src + i + 4), + vld1q_s64(src + i + 6)}; + const uint64x2x4_t cmp = CmpHelper::compare(v0, target); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const float* const __restrict src, + const size_t size, + const float& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + const float32x4x2_t target = {vdupq_n_f32(val), vdupq_n_f32(val)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float32x4x2_t v0 = {vld1q_f32(src + i), vld1q_f32(src + i + 4)}; + const uint32x4x2_t cmp = CmpHelper::compare(v0, target); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const double* const __restrict src, + const size_t size, + const double& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + const float64x2x4_t target = { + vdupq_n_f64(val), vdupq_n_f64(val), vdupq_n_f64(val), vdupq_n_f64(val)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float64x2x4_t v0 = {vld1q_f64(src + i), + vld1q_f64(src + i + 2), + vld1q_f64(src + i + 4), + vld1q_f64(src + i + 6)}; + const uint64x2x4_t cmp = CmpHelper::compare(v0, target); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict left, + const int8_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const int8x16x2_t v0l = {vld1q_s8(left + i), vld1q_s8(left + i + 16)}; + const int8x16x2_t v0r = {vld1q_s8(right + i), vld1q_s8(right + i + 16)}; + const uint8x16x2_t cmp = CmpHelper::compare(v0l, v0r); + const uint32_t mmask = movemask(cmp); + + res_u32[i / 32] = mmask; + } + + for (size_t i = size32; i < size; i += 8) { + const int8x8_t v0l = vld1_s8(left + i); + const int8x8_t v0r = vld1_s8(right + i); + const uint8x8_t cmp = CmpHelper::compare(v0l, v0r); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict left, + const int16_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const int16x8x2_t v0l = {vld1q_s16(left + i), vld1q_s16(left + i + 8)}; + const int16x8x2_t v0r = {vld1q_s16(right + i), + 
vld1q_s16(right + i + 8)}; + const uint16x8x2_t cmp = CmpHelper::compare(v0l, v0r); + const uint16_t mmask = movemask(cmp); + + res_u16[i / 16] = mmask; + } + + if (size16 != size) { + // 8 elements to process + const int16x8_t v0l = vld1q_s16(left + size16); + const int16x8_t v0r = vld1q_s16(right + size16); + const uint16x8_t cmp = CmpHelper::compare(v0l, v0r); + const uint8_t mmask = movemask(cmp); + + res_u8[size16 / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict left, + const int32_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int32x4x2_t v0l = {vld1q_s32(left + i), vld1q_s32(left + i + 4)}; + const int32x4x2_t v0r = {vld1q_s32(right + i), + vld1q_s32(right + i + 4)}; + const uint32x4x2_t cmp = CmpHelper::compare(v0l, v0r); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict left, + const int64_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int64x2x4_t v0l = {vld1q_s64(left + i), + vld1q_s64(left + i + 2), + vld1q_s64(left + i + 4), + vld1q_s64(left + i + 6)}; + const int64x2x4_t v0r = {vld1q_s64(right + i), + vld1q_s64(right + i + 2), + vld1q_s64(right + i + 4), + vld1q_s64(right + i + 6)}; + const uint64x2x4_t cmp = CmpHelper::compare(v0l, v0r); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const float* const __restrict left, + const float* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float32x4x2_t v0l = {vld1q_f32(left + i), + vld1q_f32(left + i + 4)}; + const float32x4x2_t v0r = {vld1q_f32(right + i), + vld1q_f32(right + i + 4)}; + const uint32x4x2_t cmp = CmpHelper::compare(v0l, v0r); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const double* const __restrict left, + const double* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float64x2x4_t v0l = {vld1q_f64(left + i), + vld1q_f64(left + i + 2), + vld1q_f64(left + i + 4), + vld1q_f64(left + i + 6)}; + const float64x2x4_t v0r = {vld1q_f64(right + i), + vld1q_f64(right + i + 2), + vld1q_f64(right + i + 4), + vld1q_f64(right + i + 6)}; + const uint64x2x4_t cmp = CmpHelper::compare(v0l, v0r); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* 
const __restrict res_u8, + const int8_t* const __restrict lower, + const int8_t* const __restrict upper, + const int8_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const int8x16x2_t v0l = {vld1q_s8(lower + i), vld1q_s8(lower + i + 16)}; + const int8x16x2_t v0u = {vld1q_s8(upper + i), vld1q_s8(upper + i + 16)}; + const int8x16x2_t v0v = {vld1q_s8(values + i), + vld1q_s8(values + i + 16)}; + const uint8x16x2_t cmp0l = + CmpHelper::lower>::compare(v0l, v0v); + const uint8x16x2_t cmp0u = + CmpHelper::upper>::compare(v0v, v0u); + const uint8x16x2_t cmp = {vandq_u8(cmp0l.val[0], cmp0u.val[0]), + vandq_u8(cmp0l.val[1], cmp0u.val[1])}; + const uint32_t mmask = movemask(cmp); + + res_u32[i / 32] = mmask; + } + + for (size_t i = size32; i < size; i += 8) { + const int8x8_t v0l = vld1_s8(lower + i); + const int8x8_t v0u = vld1_s8(upper + i); + const int8x8_t v0v = vld1_s8(values + i); + const uint8x8_t cmp0l = + CmpHelper::lower>::compare(v0l, v0v); + const uint8x8_t cmp0u = + CmpHelper::upper>::compare(v0v, v0u); + const uint8x8_t cmp = vand_u8(cmp0l, cmp0u); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict lower, + const int16_t* const __restrict upper, + const int16_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const int16x8x2_t v0l = {vld1q_s16(lower + i), + vld1q_s16(lower + i + 8)}; + const int16x8x2_t v0u = {vld1q_s16(upper + i), + vld1q_s16(upper + i + 8)}; + const int16x8x2_t v0v = {vld1q_s16(values + i), + vld1q_s16(values + i + 8)}; + const uint16x8x2_t cmp0l = + CmpHelper::lower>::compare(v0l, v0v); + const uint16x8x2_t cmp0u = + CmpHelper::upper>::compare(v0v, v0u); + const uint16x8x2_t cmp = {vandq_u16(cmp0l.val[0], cmp0u.val[0]), + vandq_u16(cmp0l.val[1], cmp0u.val[1])}; + const uint16_t mmask = movemask(cmp); + + res_u16[i / 16] = mmask; + } + + if (size16 != size) { + // 8 elements to process + const int16x8_t v0l = vld1q_s16(lower + size16); + const int16x8_t v0u = vld1q_s16(upper + size16); + const int16x8_t v0v = vld1q_s16(values + size16); + const uint16x8_t cmp0l = + CmpHelper::lower>::compare(v0l, v0v); + const uint16x8_t cmp0u = + CmpHelper::upper>::compare(v0v, v0u); + const uint16x8_t cmp = vandq_u16(cmp0l, cmp0u); + const uint8_t mmask = movemask(cmp); + + res_u8[size16 / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict lower, + const int32_t* const __restrict upper, + const int32_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int32x4x2_t v0l = {vld1q_s32(lower + i), + vld1q_s32(lower + i + 4)}; + const int32x4x2_t v0u = {vld1q_s32(upper + i), + vld1q_s32(upper + i + 4)}; + 
const int32x4x2_t v0v = {vld1q_s32(values + i), + vld1q_s32(values + i + 4)}; + const uint32x4x2_t cmp0l = + CmpHelper::lower>::compare(v0l, v0v); + const uint32x4x2_t cmp0u = + CmpHelper::upper>::compare(v0v, v0u); + const uint32x4x2_t cmp = {vandq_u32(cmp0l.val[0], cmp0u.val[0]), + vandq_u32(cmp0l.val[1], cmp0u.val[1])}; + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict lower, + const int64_t* const __restrict upper, + const int64_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int64x2x4_t v0l = {vld1q_s64(lower + i), + vld1q_s64(lower + i + 2), + vld1q_s64(lower + i + 4), + vld1q_s64(lower + i + 6)}; + const int64x2x4_t v0u = {vld1q_s64(upper + i), + vld1q_s64(upper + i + 2), + vld1q_s64(upper + i + 4), + vld1q_s64(upper + i + 6)}; + const int64x2x4_t v0v = {vld1q_s64(values + i), + vld1q_s64(values + i + 2), + vld1q_s64(values + i + 4), + vld1q_s64(values + i + 6)}; + const uint64x2x4_t cmp0l = + CmpHelper::lower>::compare(v0l, v0v); + const uint64x2x4_t cmp0u = + CmpHelper::upper>::compare(v0v, v0u); + const uint64x2x4_t cmp = {vandq_u64(cmp0l.val[0], cmp0u.val[0]), + vandq_u64(cmp0l.val[1], cmp0u.val[1]), + vandq_u64(cmp0l.val[2], cmp0u.val[2]), + vandq_u64(cmp0l.val[3], cmp0u.val[3])}; + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const float* const __restrict lower, + const float* const __restrict upper, + const float* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float32x4x2_t v0l = {vld1q_f32(lower + i), + vld1q_f32(lower + i + 4)}; + const float32x4x2_t v0u = {vld1q_f32(upper + i), + vld1q_f32(upper + i + 4)}; + const float32x4x2_t v0v = {vld1q_f32(values + i), + vld1q_f32(values + i + 4)}; + const uint32x4x2_t cmp0l = + CmpHelper::lower>::compare(v0l, v0v); + const uint32x4x2_t cmp0u = + CmpHelper::upper>::compare(v0v, v0u); + const uint32x4x2_t cmp = {vandq_u32(cmp0l.val[0], cmp0u.val[0]), + vandq_u32(cmp0l.val[1], cmp0u.val[1])}; + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const double* const __restrict lower, + const double* const __restrict upper, + const double* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float64x2x4_t v0l = {vld1q_f64(lower + i), + vld1q_f64(lower + i + 2), + vld1q_f64(lower + i + 4), + vld1q_f64(lower + i + 6)}; + const float64x2x4_t v0u = {vld1q_f64(upper + i), + vld1q_f64(upper + i + 2), + vld1q_f64(upper + i + 4), + vld1q_f64(upper + i + 6)}; + const float64x2x4_t v0v = {vld1q_f64(values + i), + vld1q_f64(values + i + 2), + vld1q_f64(values + i + 4), + vld1q_f64(values + i + 6)}; + const uint64x2x4_t cmp0l = + 
CmpHelper::lower>::compare(v0l, v0v); + const uint64x2x4_t cmp0u = + CmpHelper::upper>::compare(v0v, v0u); + const uint64x2x4_t cmp = {vandq_u64(cmp0l.val[0], cmp0u.val[0]), + vandq_u64(cmp0l.val[1], cmp0u.val[1]), + vandq_u64(cmp0l.val[2], cmp0u.val[2]), + vandq_u64(cmp0l.val[3], cmp0u.val[3])}; + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int8_t& lower, + const int8_t& upper, + const int8_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const int8x16x2_t lower_v = {vdupq_n_s8(lower), vdupq_n_s8(lower)}; + const int8x16x2_t upper_v = {vdupq_n_s8(upper), vdupq_n_s8(upper)}; + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const int8x16x2_t v0v = {vld1q_s8(values + i), + vld1q_s8(values + i + 16)}; + const uint8x16x2_t cmp0l = + CmpHelper::lower>::compare(lower_v, v0v); + const uint8x16x2_t cmp0u = + CmpHelper::upper>::compare(v0v, upper_v); + const uint8x16x2_t cmp = {vandq_u8(cmp0l.val[0], cmp0u.val[0]), + vandq_u8(cmp0l.val[1], cmp0u.val[1])}; + const uint32_t mmask = movemask(cmp); + + res_u32[i / 32] = mmask; + } + + for (size_t i = size32; i < size; i += 8) { + const int8x8_t lower_v1 = vdup_n_s8(lower); + const int8x8_t upper_v1 = vdup_n_s8(upper); + const int8x8_t v0v = vld1_s8(values + i); + const uint8x8_t cmp0l = + CmpHelper::lower>::compare(lower_v1, v0v); + const uint8x8_t cmp0u = + CmpHelper::upper>::compare(v0v, upper_v1); + const uint8x8_t cmp = vand_u8(cmp0l, cmp0u); + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int16_t& lower, + const int16_t& upper, + const int16_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const int16x8x2_t lower_v = {vdupq_n_s16(lower), vdupq_n_s16(lower)}; + const int16x8x2_t upper_v = {vdupq_n_s16(upper), vdupq_n_s16(upper)}; + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const int16x8x2_t v0v = {vld1q_s16(values + i), + vld1q_s16(values + i + 8)}; + const uint16x8x2_t cmp0l = + CmpHelper::lower>::compare(lower_v, v0v); + const uint16x8x2_t cmp0u = + CmpHelper::upper>::compare(v0v, upper_v); + const uint16x8x2_t cmp = {vandq_u16(cmp0l.val[0], cmp0u.val[0]), + vandq_u16(cmp0l.val[1], cmp0u.val[1])}; + const uint16_t mmask = movemask(cmp); + + res_u16[i / 16] = mmask; + } + + if (size16 != size) { + // 8 elements to process + const int16x8_t v0v = vld1q_s16(values + size16); + const uint16x8_t cmp0l = + CmpHelper::lower>::compare(lower_v.val[0], v0v); + const uint16x8_t cmp0u = + CmpHelper::upper>::compare(v0v, upper_v.val[0]); + const uint16x8_t cmp = vandq_u16(cmp0l, cmp0u); + const uint8_t mmask = movemask(cmp); + + res_u8[size16 / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int32_t& lower, + const int32_t& upper, + const int32_t* const __restrict values, 
+ const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const int32x4x2_t lower_v = {vdupq_n_s32(lower), vdupq_n_s32(lower)}; + const int32x4x2_t upper_v = {vdupq_n_s32(upper), vdupq_n_s32(upper)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int32x4x2_t v0v = {vld1q_s32(values + i), + vld1q_s32(values + i + 4)}; + const uint32x4x2_t cmp0l = + CmpHelper::lower>::compare(lower_v, v0v); + const uint32x4x2_t cmp0u = + CmpHelper::upper>::compare(v0v, upper_v); + const uint32x4x2_t cmp = {vandq_u32(cmp0l.val[0], cmp0u.val[0]), + vandq_u32(cmp0l.val[1], cmp0u.val[1])}; + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int64_t& lower, + const int64_t& upper, + const int64_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const int64x2x4_t lower_v = {vdupq_n_s64(lower), + vdupq_n_s64(lower), + vdupq_n_s64(lower), + vdupq_n_s64(lower)}; + const int64x2x4_t upper_v = {vdupq_n_s64(upper), + vdupq_n_s64(upper), + vdupq_n_s64(upper), + vdupq_n_s64(upper)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int64x2x4_t v0v = {vld1q_s64(values + i), + vld1q_s64(values + i + 2), + vld1q_s64(values + i + 4), + vld1q_s64(values + i + 6)}; + const uint64x2x4_t cmp0l = + CmpHelper::lower>::compare(lower_v, v0v); + const uint64x2x4_t cmp0u = + CmpHelper::upper>::compare(v0v, upper_v); + const uint64x2x4_t cmp = {vandq_u64(cmp0l.val[0], cmp0u.val[0]), + vandq_u64(cmp0l.val[1], cmp0u.val[1]), + vandq_u64(cmp0l.val[2], cmp0u.val[2]), + vandq_u64(cmp0l.val[3], cmp0u.val[3])}; + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const float& lower, + const float& upper, + const float* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const float32x4x2_t lower_v = {vdupq_n_f32(lower), vdupq_n_f32(lower)}; + const float32x4x2_t upper_v = {vdupq_n_f32(upper), vdupq_n_f32(upper)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float32x4x2_t v0v = {vld1q_f32(values + i), + vld1q_f32(values + i + 4)}; + const uint32x4x2_t cmp0l = + CmpHelper::lower>::compare(lower_v, v0v); + const uint32x4x2_t cmp0u = + CmpHelper::upper>::compare(v0v, upper_v); + const uint32x4x2_t cmp = {vandq_u32(cmp0l.val[0], cmp0u.val[0]), + vandq_u32(cmp0l.val[1], cmp0u.val[1])}; + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const double& lower, + const double& upper, + const double* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const float64x2x4_t lower_v = {vdupq_n_f64(lower), + vdupq_n_f64(lower), + vdupq_n_f64(lower), + vdupq_n_f64(lower)}; + const float64x2x4_t upper_v = {vdupq_n_f64(upper), + vdupq_n_f64(upper), + vdupq_n_f64(upper), + vdupq_n_f64(upper)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < 
size8; i += 8) { + const float64x2x4_t v0v = {vld1q_f64(values + i), + vld1q_f64(values + i + 2), + vld1q_f64(values + i + 4), + vld1q_f64(values + i + 6)}; + const uint64x2x4_t cmp0l = + CmpHelper::lower>::compare(lower_v, v0v); + const uint64x2x4_t cmp0u = + CmpHelper::upper>::compare(v0v, upper_v); + const uint64x2x4_t cmp = {vandq_u64(cmp0l.val[0], cmp0u.val[0]), + vandq_u64(cmp0l.val[1], cmp0u.val[1]), + vandq_u64(cmp0l.val[2], cmp0u.val[2]), + vandq_u64(cmp0l.val[3], cmp0u.val[3])}; + const uint8_t mmask = movemask(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +namespace { + +// +template +struct ArithHelperI64 {}; + +template +struct ArithHelperI64 { + static inline uint64x2x4_t + op(const int64x2x4_t left, + const int64x2x4_t right, + const int64x2x4_t value) { + // left + right == value + const int64x2x4_t lr = {vaddq_s64(left.val[0], right.val[0]), + vaddq_s64(left.val[1], right.val[1]), + vaddq_s64(left.val[2], right.val[2]), + vaddq_s64(left.val[3], right.val[3])}; + return CmpHelper::compare(lr, value); + } +}; + +template +struct ArithHelperI64 { + static inline uint64x2x4_t + op(const int64x2x4_t left, + const int64x2x4_t right, + const int64x2x4_t value) { + // left - right == value + const int64x2x4_t lr = {vsubq_s64(left.val[0], right.val[0]), + vsubq_s64(left.val[1], right.val[1]), + vsubq_s64(left.val[2], right.val[2]), + vsubq_s64(left.val[3], right.val[3])}; + return CmpHelper::compare(lr, value); + } +}; + +// template +// struct ArithHelperI64 { +// // todo draft: https://stackoverflow.com/questions/60236627/facing-problem-in-implementing-multiplication-of-64-bit-variables-using-arm-neon +// inline int64x2_t arm_vmulq_s64(const int64x2_t a, const int64x2_t b) +// { +// const auto ac = vmovn_s64(a); +// const auto pr = vmovn_s64(b); + +// const auto hi = vmulq_s32(b, vrev64q_s32(a)); + +// return vmlal_u32(vshlq_n_s64(vpaddlq_u32(hi), 32), ac, pr); +// } + +// static inline uint64x2x4_t op(const int64x2x4_t left, const int64x2x4_t right, const int64x2x4_t value) { +// // left * right == value +// const int64x2x4_t lr = { +// arm_vmulq_s64(left.val[0], right.val[0]), +// arm_vmulq_s64(left.val[1], right.val[1]), +// arm_vmulq_s64(left.val[2], right.val[2]), +// arm_vmulq_s64(left.val[3], right.val[3]) +// }; +// return CmpHelper::compare(lr, value); +// } +// }; + +// +template +struct ArithHelperF32 {}; + +template +struct ArithHelperF32 { + static inline uint32x4x2_t + op(const float32x4x2_t left, + const float32x4x2_t right, + const float32x4x2_t value) { + // left + right == value + const float32x4x2_t lr = {vaddq_f32(left.val[0], right.val[0]), + vaddq_f32(left.val[1], right.val[1])}; + return CmpHelper::compare(lr, value); + } +}; + +template +struct ArithHelperF32 { + static inline uint32x4x2_t + op(const float32x4x2_t left, + const float32x4x2_t right, + const float32x4x2_t value) { + // left - right == value + const float32x4x2_t lr = {vsubq_f32(left.val[0], right.val[0]), + vsubq_f32(left.val[1], right.val[1])}; + return CmpHelper::compare(lr, value); + } +}; + +template +struct ArithHelperF32 { + static inline uint32x4x2_t + op(const float32x4x2_t left, + const float32x4x2_t right, + const float32x4x2_t value) { + // left * right == value + const float32x4x2_t lr = {vmulq_f32(left.val[0], right.val[0]), + vmulq_f32(left.val[1], right.val[1])}; + return CmpHelper::compare(lr, value); + } +}; + +template +struct ArithHelperF32 { + static inline uint32x4x2_t + op(const 
float32x4x2_t left, + const float32x4x2_t right, + const float32x4x2_t value) { + // left == right * value + const float32x4x2_t rv = {vmulq_f32(right.val[0], value.val[0]), + vmulq_f32(right.val[1], value.val[1])}; + return CmpHelper::compare(left, rv); + } +}; + +// +template +struct ArithHelperF64 {}; + +template +struct ArithHelperF64 { + static inline uint64x2x4_t + op(const float64x2x4_t left, + const float64x2x4_t right, + const float64x2x4_t value) { + // left + right == value + const float64x2x4_t lr = {vaddq_f64(left.val[0], right.val[0]), + vaddq_f64(left.val[1], right.val[1]), + vaddq_f64(left.val[2], right.val[2]), + vaddq_f64(left.val[3], right.val[3])}; + return CmpHelper::compare(lr, value); + } +}; + +template +struct ArithHelperF64 { + static inline uint64x2x4_t + op(const float64x2x4_t left, + const float64x2x4_t right, + const float64x2x4_t value) { + // left - right == value + const float64x2x4_t lr = {vsubq_f64(left.val[0], right.val[0]), + vsubq_f64(left.val[1], right.val[1]), + vsubq_f64(left.val[2], right.val[2]), + vsubq_f64(left.val[3], right.val[3])}; + return CmpHelper::compare(lr, value); + } +}; + +template +struct ArithHelperF64 { + static inline uint64x2x4_t + op(const float64x2x4_t left, + const float64x2x4_t right, + const float64x2x4_t value) { + // left * right == value + const float64x2x4_t lr = {vmulq_f64(left.val[0], right.val[0]), + vmulq_f64(left.val[1], right.val[1]), + vmulq_f64(left.val[2], right.val[2]), + vmulq_f64(left.val[3], right.val[3])}; + return CmpHelper::compare(lr, value); + } +}; + +template +struct ArithHelperF64 { + static inline uint64x2x4_t + op(const float64x2x4_t left, + const float64x2x4_t right, + const float64x2x4_t value) { + // left == right * value + const float64x2x4_t rv = {vmulq_f64(right.val[0], value.val[0]), + vmulq_f64(right.val[1], value.val[1]), + vmulq_f64(right.val[2], value.val[2]), + vmulq_f64(right.val[3], value.val[3])}; + return CmpHelper::compare(left, rv); + } +}; + +} // namespace + +// todo: Mul, Div, Mod + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mul || AOp == ArithOpType::Div || + AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const int64x2x4_t right_v = {vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand)}; + const int64x2x4_t value_v = {vdupq_n_s64(value), + vdupq_n_s64(value), + vdupq_n_s64(value), + vdupq_n_s64(value)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int8x8_t v0v_i8 = vld1_s8(src + i); + const int16x8_t v0v_i16 = vmovl_s8(v0v_i8); + const int32x4x2_t v0v_i32 = {vmovl_s16(vget_low_s16(v0v_i16)), + vmovl_s16(vget_high_s16(v0v_i16))}; + const int64x2x4_t v0v_i64 = { + vmovl_s32(vget_low_s32(v0v_i32.val[0])), + vmovl_s32(vget_high_s32(v0v_i32.val[0])), + vmovl_s32(vget_low_s32(v0v_i32.val[1])), + vmovl_s32(vget_high_s32(v0v_i32.val[1]))}; + + const uint64x2x4_t cmp = + ArithHelperI64::op(v0v_i64, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const 
int16_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mul || AOp == ArithOpType::Div || + AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const int64x2x4_t right_v = {vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand)}; + const int64x2x4_t value_v = {vdupq_n_s64(value), + vdupq_n_s64(value), + vdupq_n_s64(value), + vdupq_n_s64(value)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int16x8_t v0v_i16 = vld1q_s16(src + i); + const int32x4x2_t v0v_i32 = {vmovl_s16(vget_low_s16(v0v_i16)), + vmovl_s16(vget_high_s16(v0v_i16))}; + const int64x2x4_t v0v_i64 = { + vmovl_s32(vget_low_s32(v0v_i32.val[0])), + vmovl_s32(vget_high_s32(v0v_i32.val[0])), + vmovl_s32(vget_low_s32(v0v_i32.val[1])), + vmovl_s32(vget_high_s32(v0v_i32.val[1]))}; + + const uint64x2x4_t cmp = + ArithHelperI64::op(v0v_i64, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mul || AOp == ArithOpType::Div || + AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const int64x2x4_t right_v = {vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand)}; + const int64x2x4_t value_v = {vdupq_n_s64(value), + vdupq_n_s64(value), + vdupq_n_s64(value), + vdupq_n_s64(value)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int32x4x2_t v0v_i32 = {vld1q_s32(src + i), + vld1q_s32(src + i + 4)}; + const int64x2x4_t v0v_i64 = { + vmovl_s32(vget_low_s32(v0v_i32.val[0])), + vmovl_s32(vget_high_s32(v0v_i32.val[0])), + vmovl_s32(vget_low_s32(v0v_i32.val[1])), + vmovl_s32(vget_high_s32(v0v_i32.val[1]))}; + + const uint64x2x4_t cmp = + ArithHelperI64::op(v0v_i64, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mul || AOp == ArithOpType::Div || + AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const int64x2x4_t right_v = {vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand), + vdupq_n_s64(right_operand)}; + const int64x2x4_t value_v = {vdupq_n_s64(value), + vdupq_n_s64(value), + vdupq_n_s64(value), + vdupq_n_s64(value)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const int64x2x4_t v0v = {vld1q_s64(src + i), + vld1q_s64(src + i + 2), + vld1q_s64(src + i + 4), + 
vld1q_s64(src + i + 6)}; + const uint64x2x4_t cmp = + ArithHelperI64::op(v0v, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const float* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + + // + const float32x4x2_t right_v = {vdupq_n_f32(right_operand), + vdupq_n_f32(right_operand)}; + const float32x4x2_t value_v = {vdupq_n_f32(value), vdupq_n_f32(value)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float32x4x2_t v0v = {vld1q_f32(src + i), + vld1q_f32(src + i + 4)}; + const uint32x4x2_t cmp = + ArithHelperF32::op(v0v, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const double* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + + // + const float64x2x4_t right_v = {vdupq_n_f64(right_operand), + vdupq_n_f64(right_operand), + vdupq_n_f64(right_operand), + vdupq_n_f64(right_operand)}; + const float64x2x4_t value_v = {vdupq_n_f64(value), + vdupq_n_f64(value), + vdupq_n_f64(value), + vdupq_n_f64(value)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float64x2x4_t v0v = {vld1q_f64(src + i), + vld1q_f64(src + i + 2), + vld1q_f64(src + i + 4), + vld1q_f64(src + i + 6)}; + const uint64x2x4_t cmp = + ArithHelperF64::op(v0v, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } +} + +/////////////////////////////////////////////////////////////////////////// + +} // namespace neon +} // namespace arm +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/arm/neon-inst.cpp b/internal/core/src/bitset/detail/platform/arm/neon-inst.cpp new file mode 100644 index 000000000000..01069b1fa00a --- /dev/null +++ b/internal/core/src/bitset/detail/platform/arm/neon-inst.cpp @@ -0,0 +1,199 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
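The translation unit below pre-instantiates every supported (element type, operation) combination, so the NEON kernels above are compiled once into the milvus_bitset library instead of in every including translation unit. As a sketch of what a single macro invocation produces (assuming the CompareOpType enumeration comes from bitset/common.h), ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_NEON, float) expands to six explicit instantiations of the form:

    template bool
    OpCompareValImpl<float, CompareOpType::EQ>::op_compare_val(
        uint8_t* const __restrict bitmask,
        const float* const __restrict src,
        const size_t size,
        const float& val);
    // ...and likewise for GE, GT, LE, LT and NE.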
+ +// ARM NEON instantiation + +#include "bitset/common.h" + +#ifndef BITSET_HEADER_ONLY + +#include "neon-decl.h" +#include "neon-impl.h" + +#include <cstddef> +#include <cstdint> + +namespace milvus { +namespace bitset { +namespace detail { +namespace arm { +namespace neon { + +// a facility to run through all possible compare operations +#define ALL_COMPARE_OPS(FUNC, ...) \ + FUNC(__VA_ARGS__, EQ); \ + FUNC(__VA_ARGS__, GE); \ + FUNC(__VA_ARGS__, GT); \ + FUNC(__VA_ARGS__, LE); \ + FUNC(__VA_ARGS__, LT); \ + FUNC(__VA_ARGS__, NE); + +// a facility to run through all possible range operations +#define ALL_RANGE_OPS(FUNC, ...) \ + FUNC(__VA_ARGS__, IncInc); \ + FUNC(__VA_ARGS__, IncExc); \ + FUNC(__VA_ARGS__, ExcInc); \ + FUNC(__VA_ARGS__, ExcExc); + +// a facility to run through all possible arithmetic compare operations +#define ALL_ARITH_CMP_OPS(FUNC, ...) \ + FUNC(__VA_ARGS__, Add, EQ); \ + FUNC(__VA_ARGS__, Add, GE); \ + FUNC(__VA_ARGS__, Add, GT); \ + FUNC(__VA_ARGS__, Add, LE); \ + FUNC(__VA_ARGS__, Add, LT); \ + FUNC(__VA_ARGS__, Add, NE); \ + FUNC(__VA_ARGS__, Sub, EQ); \ + FUNC(__VA_ARGS__, Sub, GE); \ + FUNC(__VA_ARGS__, Sub, GT); \ + FUNC(__VA_ARGS__, Sub, LE); \ + FUNC(__VA_ARGS__, Sub, LT); \ + FUNC(__VA_ARGS__, Sub, NE); \ + FUNC(__VA_ARGS__, Mul, EQ); \ + FUNC(__VA_ARGS__, Mul, GE); \ + FUNC(__VA_ARGS__, Mul, GT); \ + FUNC(__VA_ARGS__, Mul, LE); \ + FUNC(__VA_ARGS__, Mul, LT); \ + FUNC(__VA_ARGS__, Mul, NE); \ + FUNC(__VA_ARGS__, Div, EQ); \ + FUNC(__VA_ARGS__, Div, GE); \ + FUNC(__VA_ARGS__, Div, GT); \ + FUNC(__VA_ARGS__, Div, LE); \ + FUNC(__VA_ARGS__, Div, LT); \ + FUNC(__VA_ARGS__, Div, NE); \ + FUNC(__VA_ARGS__, Mod, EQ); \ + FUNC(__VA_ARGS__, Mod, GE); \ + FUNC(__VA_ARGS__, Mod, GT); \ + FUNC(__VA_ARGS__, Mod, LE); \ + FUNC(__VA_ARGS__, Mod, LT); \ + FUNC(__VA_ARGS__, Mod, NE); + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_COMPARE_VAL_NEON(TTYPE, OP) \ + template bool OpCompareValImpl<TTYPE, CompareOpType::OP>::op_compare_val( \ + uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict src, \ + const size_t size, \ + const TTYPE& val); + +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_NEON, int8_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_NEON, int16_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_NEON, int32_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_NEON, int64_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_NEON, float) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_NEON, double) + +#undef INSTANTIATE_COMPARE_VAL_NEON + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_COMPARE_COLUMN_NEON(TTYPE, OP) \ + template bool \ + OpCompareColumnImpl<TTYPE, TTYPE, CompareOpType::OP>::op_compare_column( \ + uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict left, \ + const TTYPE* const __restrict right, \ + const size_t size); + +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_NEON, int8_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_NEON, int16_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_NEON, int32_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_NEON, int64_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_NEON, float) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_NEON, double) + +#undef INSTANTIATE_COMPARE_COLUMN_NEON + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_WITHIN_RANGE_COLUMN_NEON(TTYPE, OP) \ + template bool \ + OpWithinRangeColumnImpl<TTYPE, RangeType::OP>::op_within_range_column( \ + uint8_t* const __restrict res_u8, \ + const TTYPE* const __restrict lower, \ + const TTYPE* const 
__restrict upper, \ + const TTYPE* const __restrict values, \ + const size_t size); + +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_NEON, int8_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_NEON, int16_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_NEON, int32_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_NEON, int64_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_NEON, float) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_NEON, double) + +#undef INSTANTIATE_WITHIN_RANGE_COLUMN_NEON + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_WITHIN_RANGE_VAL_NEON(TTYPE, OP) \ + template bool \ + OpWithinRangeValImpl<TTYPE, RangeType::OP>::op_within_range_val( \ + uint8_t* const __restrict res_u8, \ + const TTYPE& lower, \ + const TTYPE& upper, \ + const TTYPE* const __restrict values, \ + const size_t size); + +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_NEON, int8_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_NEON, int16_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_NEON, int32_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_NEON, int64_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_NEON, float) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_NEON, double) + +#undef INSTANTIATE_WITHIN_RANGE_VAL_NEON + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_ARITH_COMPARE_NEON(TTYPE, OP, CMP) \ + template bool \ + OpArithCompareImpl<TTYPE, ArithOpType::OP, CompareOpType::CMP>:: \ + op_arith_compare(uint8_t* const __restrict res_u8, \ + const TTYPE* const __restrict src, \ + const ArithHighPrecisionType<TTYPE>& right_operand, \ + const ArithHighPrecisionType<TTYPE>& value, \ + const size_t size); + +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_NEON, int8_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_NEON, int16_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_NEON, int32_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_NEON, int64_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_NEON, float) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_NEON, double) + +#undef INSTANTIATE_ARITH_COMPARE_NEON + +/////////////////////////////////////////////////////////////////////////// + +// +#undef ALL_COMPARE_OPS +#undef ALL_RANGE_OPS +#undef ALL_ARITH_CMP_OPS + +} // namespace neon +} // namespace arm +} // namespace detail +} // namespace bitset +} // namespace milvus + +#endif diff --git a/internal/core/src/bitset/detail/platform/arm/neon.h b/internal/core/src/bitset/detail/platform/arm/neon.h new file mode 100644 index 000000000000..004547506e40 --- /dev/null +++ b/internal/core/src/bitset/detail/platform/arm/neon.h @@ -0,0 +1,63 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
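A minimal caller-side sketch for the instantiations above (buffer sizes and values are illustrative; the CompareOpType enumeration is assumed to come from bitset/common.h):

    #include <cstdint>

    void example() {
        // one output bit per element, so a 1024-element column needs 128 bytes
        uint8_t bitmask[1024 / 8];
        int32_t src[1024] = {};
        const bool handled = milvus::bitset::detail::arm::neon::
            OpCompareValImpl<int32_t, milvus::bitset::CompareOpType::GT>::
                op_compare_val(bitmask, src, 1024, 42);
        // handled == false signals the caller to use a non-NEON fallback
    }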
+ +#pragma once + +#include <cstddef> +#include <cstdint> + +#include "bitset/common.h" + +#include "neon-decl.h" + +#ifdef BITSET_HEADER_ONLY +#include "neon-impl.h" +#endif + +namespace milvus { +namespace bitset { +namespace detail { +namespace arm { + +/////////////////////////////////////////////////////////////////////////// + +// +struct VectorizedNeon { + template <typename T, typename U, CompareOpType Op> + static constexpr inline auto op_compare_column = + neon::OpCompareColumnImpl<T, U, Op>::op_compare_column; + + template <typename T, CompareOpType Op> + static constexpr inline auto op_compare_val = + neon::OpCompareValImpl<T, Op>::op_compare_val; + + template <typename T, RangeType Op> + static constexpr inline auto op_within_range_column = + neon::OpWithinRangeColumnImpl<T, Op>::op_within_range_column; + + template <typename T, RangeType Op> + static constexpr inline auto op_within_range_val = + neon::OpWithinRangeValImpl<T, Op>::op_within_range_val; + + template <typename T, ArithOpType AOp, CompareOpType CmpOp> + static constexpr inline auto op_arith_compare = + neon::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare; +}; + +} // namespace arm +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/arm/sve-decl.h b/internal/core/src/bitset/detail/platform/arm/sve-decl.h new file mode 100644 index 000000000000..f563041e1505 --- /dev/null +++ b/internal/core/src/bitset/detail/platform/arm/sve-decl.h @@ -0,0 +1,201 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
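VectorizedNeon above is a stateless facade: since the kernel signatures are fixed, a platform dispatcher can bind one of these variable templates at compile time, e.g. (usage sketch only, not part of the patch):

    constexpr auto cmp = milvus::bitset::detail::arm::VectorizedNeon::
        op_compare_val<int32_t, milvus::bitset::CompareOpType::GT>;
    // cmp(bitmask, src, size, val) now runs the NEON kernel

The SVE declarations below mirror the same structure.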
+ +// ARM SVE declaration + +#pragma once + +#include <cstddef> +#include <cstdint> + +#include "bitset/common.h" + +namespace milvus { +namespace bitset { +namespace detail { +namespace arm { +namespace sve { + +/////////////////////////////////////////////////////////////////////////// +// a facility to run through all acceptable data types +#define ALL_DATATYPES_1(FUNC) \ + FUNC(int8_t); \ + FUNC(int16_t); \ + FUNC(int32_t); \ + FUNC(int64_t); \ + FUNC(float); \ + FUNC(double); + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template <typename T, typename U, CompareOpType Op> +struct OpCompareColumnImpl { + static bool + op_compare_column(uint8_t* const __restrict bitmask, + const T* const __restrict t, + const U* const __restrict u, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_COMPARE_COLUMN(TTYPE) \ + template <CompareOpType Op> \ + struct OpCompareColumnImpl<TTYPE, TTYPE, Op> { \ + static bool \ + op_compare_column(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict t, \ + const TTYPE* const __restrict u, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_COLUMN) + +#undef DECLARE_PARTIAL_OP_COMPARE_COLUMN + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template <typename T, CompareOpType Op> +struct OpCompareValImpl { + static inline bool + op_compare_val(uint8_t* const __restrict bitmask, + const T* const __restrict t, + const size_t size, + const T& value) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_COMPARE_VAL(TTYPE) \ + template <CompareOpType Op> \ + struct OpCompareValImpl<TTYPE, Op> { \ + static bool \ + op_compare_val(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict t, \ + const size_t size, \ + const TTYPE& value); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_VAL) + +#undef DECLARE_PARTIAL_OP_COMPARE_VAL + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template <typename T, RangeType Op> +struct OpWithinRangeColumnImpl { + static inline bool + op_within_range_column(uint8_t* const __restrict bitmask, + const T* const __restrict lower, + const T* const __restrict upper, + const T* const __restrict values, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN(TTYPE) \ + template <RangeType Op> \ + struct OpWithinRangeColumnImpl<TTYPE, Op> { \ + static bool \ + op_within_range_column(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict lower, \ + const TTYPE* const __restrict upper, \ + const TTYPE* const __restrict values, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN) + +#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template <typename T, RangeType Op> +struct OpWithinRangeValImpl { + static inline bool + op_within_range_val(uint8_t* const __restrict bitmask, + const T& lower, + const T& upper, + const T* const __restrict values, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL(TTYPE) \ + template <RangeType Op> \ + struct OpWithinRangeValImpl<TTYPE, Op> { \ + static bool \ + op_within_range_val(uint8_t* const __restrict bitmask, \ + const TTYPE& lower, \ + const TTYPE& upper, \ + const TTYPE* const __restrict values, \ + const size_t size); \ + }; + 
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL) + +#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template <typename T, ArithOpType AOp, CompareOpType CmpOp> +struct OpArithCompareImpl { + static inline bool + op_arith_compare(uint8_t* const __restrict bitmask, + const T* const __restrict src, + const ArithHighPrecisionType<T>& right_operand, + const ArithHighPrecisionType<T>& value, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_ARITH_COMPARE(TTYPE) \ + template <ArithOpType AOp, CompareOpType CmpOp> \ + struct OpArithCompareImpl<TTYPE, AOp, CmpOp> { \ + static bool \ + op_arith_compare(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict src, \ + const ArithHighPrecisionType<TTYPE>& right_operand, \ + const ArithHighPrecisionType<TTYPE>& value, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) + +#undef DECLARE_PARTIAL_OP_ARITH_COMPARE + +/////////////////////////////////////////////////////////////////////////// + +#undef ALL_DATATYPES_1 + +} // namespace sve +} // namespace arm +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/arm/sve-impl.h b/internal/core/src/bitset/detail/platform/arm/sve-impl.h new file mode 100644 index 000000000000..18433402d04d --- /dev/null +++ b/internal/core/src/bitset/detail/platform/arm/sve-impl.h @@ -0,0 +1,1745 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
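The primary templates above return false on purpose: that bool is the dispatch contract, letting a caller attempt the vectorized kernel and drop to a scalar loop whenever a (type, op) combination has no specialization. A hedged caller-side sketch (all names local to this example; ArithOpType and CompareOpType assumed from bitset/common.h):

    #include <cstddef>
    #include <cstdint>

    void example(uint8_t* bitmask, const int64_t* src, size_t size) {
        using namespace milvus::bitset;
        const bool vectorized = detail::arm::sve::
            OpArithCompareImpl<int64_t, ArithOpType::Add, CompareOpType::EQ>::
                op_arith_compare(bitmask, src, /*right_operand=*/1, /*value=*/3, size);
        if (!vectorized) {
            // no SVE specialization matched: evaluate src[i] + 1 == 3 in scalar code
        }
    }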
+ +// ARM SVE implementation + +#pragma once + +#include + +#include +#include +#include +#include + +#include "sve-decl.h" + +#include "bitset/common.h" + +// #include + +namespace milvus { +namespace bitset { +namespace detail { +namespace arm { +namespace sve { + +namespace { + +// +constexpr size_t MAX_SVE_WIDTH = 2048; + +constexpr uint8_t SVE_LANES_8[MAX_SVE_WIDTH / 8] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, + 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, + 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, + 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, + + 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A, + 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, + 0x56, 0x57, 0x58, 0x59, 0x5A, 0x5B, 0x5C, 0x5D, 0x5E, 0x5F, 0x60, + 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x6B, + 0x6C, 0x6D, 0x6E, 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, + 0x77, 0x78, 0x79, 0x7A, 0x7B, 0x7C, 0x7D, 0x7E, 0x7F, + + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8A, + 0x8B, 0x8C, 0x8D, 0x8E, 0x8F, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, + 0x96, 0x97, 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F, 0xA0, + 0xA1, 0xA2, 0xA3, 0xA4, 0xA5, 0xA6, 0xA7, 0xA8, 0xA9, 0xAA, 0xAB, + 0xAC, 0xAD, 0xAE, 0xAF, 0xB0, 0xB1, 0xB2, 0xB3, 0xB4, 0xB5, 0xB6, + 0xB7, 0xB8, 0xB9, 0xBA, 0xBB, 0xBC, 0xBD, 0xBE, 0xBF, + + 0xC0, 0xC1, 0xC2, 0xC3, 0xC4, 0xC5, 0xC6, 0xC7, 0xC8, 0xC9, 0xCA, + 0xCB, 0xCC, 0xCD, 0xCE, 0xCF, 0xD0, 0xD1, 0xD2, 0xD3, 0xD4, 0xD5, + 0xD6, 0xD7, 0xD8, 0xD9, 0xDA, 0xDB, 0xDC, 0xDD, 0xDE, 0xDF, 0xE0, + 0xE1, 0xE2, 0xE3, 0xE4, 0xE5, 0xE6, 0xE7, 0xE8, 0xE9, 0xEA, 0xEB, + 0xEC, 0xED, 0xEE, 0xEF, 0xF0, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, + 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF}; + +constexpr uint16_t SVE_LANES_16[MAX_SVE_WIDTH / 16] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, + 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, + 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, + 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F, + + 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A, + 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, + 0x56, 0x57, 0x58, 0x59, 0x5A, 0x5B, 0x5C, 0x5D, 0x5E, 0x5F, 0x60, + 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x6B, + 0x6C, 0x6D, 0x6E, 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, + 0x77, 0x78, 0x79, 0x7A, 0x7B, 0x7C, 0x7D, 0x7E, 0x7F}; + +constexpr uint32_t SVE_LANES_32[MAX_SVE_WIDTH / 32] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, + 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, + 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, + 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0x3E, 0x3F}; + +constexpr uint64_t SVE_LANES_64[MAX_SVE_WIDTH / 64] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, + 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F}; + +/* +// 
debugging facilities + +// +void print_svbool_t(const svbool_t value) { + // 2048 bits, 256 bytes => 256 bits bitmask, 32 bytes + uint8_t v[MAX_SVE_WIDTH / 64]; + *((svbool_t*)v) = value; + + const size_t sve_width = svcntb(); + for (size_t i = 0; i < sve_width / 8; i++) { + printf("%d ", int(v[i])); + } + printf("\n"); +} + +// +void print_svuint8_t(const svuint8_t value) { + uint8_t v[MAX_SVE_WIDTH / 8]; + *((svuint8_t*)v) = value; + + const size_t sve_width = svcntb(); + for (size_t i = 0; i < sve_width; i++) { + printf("%d ", int(v[i])); + } + printf("\n"); +} + +*/ + +/////////////////////////////////////////////////////////////////////////// + +// todo: replace with pext whenever available + +// generate 16-bit bitmask from 8 serialized 16-bit svbool_t values +void +write_bitmask_16_8x(uint8_t* const __restrict res_u8, + const svbool_t pred_op, + const svbool_t pred_write, + const uint8_t* const __restrict pred_buf) { + // perform parallel pext + // 2048b -> 32 bytes mask -> 256 bytes total, 128 uint16_t values + // 512b -> 8 bytes mask -> 64 bytes total, 32 uint16_t values + // 256b -> 4 bytes mask -> 32 bytes total, 16 uint16_t values + // 128b -> 2 bytes mask -> 16 bytes total, 8 uint16_t values + + // this code does reduction of 16-bit 0b0A0B0C0D0E0F0G0H words into + // uint8_t values 0bABCDEFGH, then writes ones to the memory + + // we need to operate in uint8_t + const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf); + + const svuint8_t mask_04_8b = svand_n_u8_z(pred_op, mask_8b, 0x01); + const svuint8_t mask_15_8b = svand_n_u8_z(pred_op, mask_8b, 0x04); + const svuint8_t mask_15s_8b = svlsr_n_u8_z(pred_op, mask_15_8b, 1); + const svuint8_t mask_26_8b = svand_n_u8_z(pred_op, mask_8b, 0x10); + const svuint8_t mask_26s_8b = svlsr_n_u8_z(pred_op, mask_26_8b, 2); + const svuint8_t mask_37_8b = svand_n_u8_z(pred_op, mask_8b, 0x40); + const svuint8_t mask_37s_8b = svlsr_n_u8_z(pred_op, mask_37_8b, 3); + + const svuint8_t mask_0347_8b = svorr_u8_z(pred_op, mask_04_8b, mask_37s_8b); + const svuint8_t mask_1256_8b = + svorr_u8_z(pred_op, mask_15s_8b, mask_26s_8b); + const svuint8_t mask_cmb_8b = + svorr_u8_z(pred_op, mask_0347_8b, mask_1256_8b); + + // + const svuint16_t shifts_16b = svdup_u16(0x0400UL); + const svuint8_t shifts_8b = svreinterpret_u8_u16(shifts_16b); + const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_cmb_8b, shifts_8b); + + const svuint8_t zero_8b = svdup_n_u8(0); + + const svuint8_t shifted_8b_m3 = + svorr_u8_z(pred_op, + svuzp1_u8(shifted_8b_m0, zero_8b), + svuzp2_u8(shifted_8b_m0, zero_8b)); + + // write a finished bitmask + svst1_u8(pred_write, res_u8, shifted_8b_m3); +} + +// generate 32-bit bitmask from 8 serialized 32-bit svbool_t values +void +write_bitmask_32_8x(uint8_t* const __restrict res_u8, + const svbool_t pred_op, + const svbool_t pred_write, + const uint8_t* const __restrict pred_buf) { + // perform parallel pext + // 2048b -> 32 bytes mask -> 256 bytes total, 64 uint32_t values + // 512b -> 8 bytes mask -> 64 bytes total, 16 uint32_t values + // 256b -> 4 bytes mask -> 32 bytes total, 8 uint32_t values + // 128b -> 2 bytes mask -> 16 bytes total, 4 uint32_t values + + // this code does reduction of 32-bit 0b000A000B000C000D... 
dwords into + // uint8_t values 0bABCDEFGH, then writes ones to the memory + + // we need to operate in uint8_t + const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf); + + const svuint8_t mask_024_8b = svand_n_u8_z(pred_op, mask_8b, 0x01); + const svuint8_t mask_135s_8b = svlsr_n_u8_z(pred_op, mask_8b, 3); + const svuint8_t mask_cmb_8b = + svorr_u8_z(pred_op, mask_024_8b, mask_135s_8b); + + // + const svuint32_t shifts_32b = svdup_u32(0x06040200UL); + const svuint8_t shifts_8b = svreinterpret_u8_u32(shifts_32b); + const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_cmb_8b, shifts_8b); + + const svuint8_t zero_8b = svdup_n_u8(0); + + const svuint8_t shifted_8b_m2 = + svorr_u8_z(pred_op, + svuzp1_u8(shifted_8b_m0, zero_8b), + svuzp2_u8(shifted_8b_m0, zero_8b)); + const svuint8_t shifted_8b_m3 = + svorr_u8_z(pred_op, + svuzp1_u8(shifted_8b_m2, zero_8b), + svuzp2_u8(shifted_8b_m2, zero_8b)); + + // write a finished bitmask + svst1_u8(pred_write, res_u8, shifted_8b_m3); +} + +// generate 64-bit bitmask from 8 serialized 64-bit svbool_t values +void +write_bitmask_64_8x(uint8_t* const __restrict res_u8, + const svbool_t pred_op, + const svbool_t pred_write, + const uint8_t* const __restrict pred_buf) { + // perform parallel pext + // 2048b -> 32 bytes mask -> 256 bytes total, 32 uint64_t values + // 512b -> 8 bytes mask -> 64 bytes total, 4 uint64_t values + // 256b -> 4 bytes mask -> 32 bytes total, 2 uint64_t values + // 128b -> 2 bytes mask -> 16 bytes total, 1 uint64_t values + + // this code does reduction of 64-bit 0b0000000A0000000B... qwords into + // uint8_t values 0bABCDEFGH, then writes ones to the memory + + // we need to operate in uint8_t + const svuint8_t mask_8b = svld1_u8(pred_op, pred_buf); + const svuint64_t shifts_64b = svdup_u64(0x706050403020100ULL); + const svuint8_t shifts_8b = svreinterpret_u8_u64(shifts_64b); + const svuint8_t shifted_8b_m0 = svlsl_u8_z(pred_op, mask_8b, shifts_8b); + + const svuint8_t zero_8b = svdup_n_u8(0); + + const svuint8_t shifted_8b_m1 = + svorr_u8_z(pred_op, + svuzp1_u8(shifted_8b_m0, zero_8b), + svuzp2_u8(shifted_8b_m0, zero_8b)); + const svuint8_t shifted_8b_m2 = + svorr_u8_z(pred_op, + svuzp1_u8(shifted_8b_m1, zero_8b), + svuzp2_u8(shifted_8b_m1, zero_8b)); + const svuint8_t shifted_8b_m3 = + svorr_u8_z(pred_op, + svuzp1_u8(shifted_8b_m2, zero_8b), + svuzp2_u8(shifted_8b_m2, zero_8b)); + + // write a finished bitmask + svst1_u8(pred_write, res_u8, shifted_8b_m3); +} + +/////////////////////////////////////////////////////////////////////////// + +// +inline svbool_t +get_pred_op_8(const size_t n_elements) { + const svbool_t pred_all_8 = svptrue_b8(); + const svuint8_t lanes_8 = svld1_u8(pred_all_8, SVE_LANES_8); + const svuint8_t leftovers_op = svdup_n_u8(n_elements); + const svbool_t pred_op = svcmpgt_u8(pred_all_8, leftovers_op, lanes_8); + return pred_op; +} + +// +inline svbool_t +get_pred_op_16(const size_t n_elements) { + const svbool_t pred_all_16 = svptrue_b16(); + const svuint16_t lanes_16 = svld1_u16(pred_all_16, SVE_LANES_16); + const svuint16_t leftovers_op = svdup_n_u16(n_elements); + const svbool_t pred_op = svcmpgt_u16(pred_all_16, leftovers_op, lanes_16); + return pred_op; +} + +// +inline svbool_t +get_pred_op_32(const size_t n_elements) { + const svbool_t pred_all_32 = svptrue_b32(); + const svuint32_t lanes_32 = svld1_u32(pred_all_32, SVE_LANES_32); + const svuint32_t leftovers_op = svdup_n_u32(n_elements); + const svbool_t pred_op = svcmpgt_u32(pred_all_32, leftovers_op, lanes_32); + return pred_op; +} + +// +inline 
svbool_t +get_pred_op_64(const size_t n_elements) { + const svbool_t pred_all_64 = svptrue_b64(); + const svuint64_t lanes_64 = svld1_u64(pred_all_64, SVE_LANES_64); + const svuint64_t leftovers_op = svdup_n_u64(n_elements); + const svbool_t pred_op = svcmpgt_u64(pred_all_64, leftovers_op, lanes_64); + return pred_op; +} + +// +template +struct GetPredHelper {}; + +template <> +struct GetPredHelper { + inline static svbool_t + get_pred_op(const size_t n_elements) { + return get_pred_op_8(n_elements); + } +}; + +template <> +struct GetPredHelper { + inline static svbool_t + get_pred_op(const size_t n_elements) { + return get_pred_op_16(n_elements); + } +}; + +template <> +struct GetPredHelper { + inline static svbool_t + get_pred_op(const size_t n_elements) { + return get_pred_op_32(n_elements); + } +}; + +template <> +struct GetPredHelper { + inline static svbool_t + get_pred_op(const size_t n_elements) { + return get_pred_op_64(n_elements); + } +}; + +template <> +struct GetPredHelper { + inline static svbool_t + get_pred_op(const size_t n_elements) { + return get_pred_op_32(n_elements); + } +}; + +template <> +struct GetPredHelper { + inline static svbool_t + get_pred_op(const size_t n_elements) { + return get_pred_op_64(n_elements); + } +}; + +template +inline svbool_t +get_pred_op(const size_t n_elements) { + return GetPredHelper::get_pred_op(n_elements); +} + +/////////////////////////////////////////////////////////////////////////// + +// +template +struct CmpHelper {}; + +template <> +struct CmpHelper { + static inline svbool_t + compare(const svbool_t pred, const svint8_t a, const svint8_t b) { + return svcmpeq_s8(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint16_t a, const svint16_t b) { + return svcmpeq_s16(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint32_t a, const svint32_t b) { + return svcmpeq_s32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint64_t a, const svint64_t b) { + return svcmpeq_s64(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat32_t a, const svfloat32_t b) { + return svcmpeq_f32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat64_t a, const svfloat64_t b) { + return svcmpeq_f64(pred, a, b); + } +}; + +template <> +struct CmpHelper { + static inline svbool_t + compare(const svbool_t pred, const svint8_t a, const svint8_t b) { + return svcmpge_s8(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint16_t a, const svint16_t b) { + return svcmpge_s16(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint32_t a, const svint32_t b) { + return svcmpge_s32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint64_t a, const svint64_t b) { + return svcmpge_s64(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat32_t a, const svfloat32_t b) { + return svcmpge_f32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat64_t a, const svfloat64_t b) { + return svcmpge_f64(pred, a, b); + } +}; + +template <> +struct CmpHelper { + static inline svbool_t + compare(const svbool_t pred, const svint8_t a, const svint8_t b) { + return svcmpgt_s8(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint16_t a, const svint16_t b) { + return svcmpgt_s16(pred, a, b); + } + + static inline svbool_t + 
compare(const svbool_t pred, const svint32_t a, const svint32_t b) { + return svcmpgt_s32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint64_t a, const svint64_t b) { + return svcmpgt_s64(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat32_t a, const svfloat32_t b) { + return svcmpgt_f32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat64_t a, const svfloat64_t b) { + return svcmpgt_f64(pred, a, b); + } +}; + +template <> +struct CmpHelper { + static inline svbool_t + compare(const svbool_t pred, const svint8_t a, const svint8_t b) { + return svcmple_s8(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint16_t a, const svint16_t b) { + return svcmple_s16(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint32_t a, const svint32_t b) { + return svcmple_s32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint64_t a, const svint64_t b) { + return svcmple_s64(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat32_t a, const svfloat32_t b) { + return svcmple_f32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat64_t a, const svfloat64_t b) { + return svcmple_f64(pred, a, b); + } +}; + +template <> +struct CmpHelper { + static inline svbool_t + compare(const svbool_t pred, const svint8_t a, const svint8_t b) { + return svcmplt_s8(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint16_t a, const svint16_t b) { + return svcmplt_s16(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint32_t a, const svint32_t b) { + return svcmplt_s32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint64_t a, const svint64_t b) { + return svcmplt_s64(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat32_t a, const svfloat32_t b) { + return svcmplt_f32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat64_t a, const svfloat64_t b) { + return svcmplt_f64(pred, a, b); + } +}; + +template <> +struct CmpHelper { + static inline svbool_t + compare(const svbool_t pred, const svint8_t a, const svint8_t b) { + return svcmpne_s8(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint16_t a, const svint16_t b) { + return svcmpne_s16(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint32_t a, const svint32_t b) { + return svcmpne_s32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svint64_t a, const svint64_t b) { + return svcmpne_s64(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat32_t a, const svfloat32_t b) { + return svcmpne_f32(pred, a, b); + } + + static inline svbool_t + compare(const svbool_t pred, const svfloat64_t a, const svfloat64_t b) { + return svcmpne_f64(pred, a, b); + } +}; + +/////////////////////////////////////////////////////////////////////////// + +template +struct SVEVector {}; + +template <> +struct SVEVector { + using data_type = int8_t; + using sve_type = svint8_t; + + // measured in the number of elements that an SVE register can hold + static inline size_t + width() { + return svcntb(); + } + + static inline svbool_t + pred_all() { + return svptrue_b8(); + } + + inline static sve_type + set1(const data_type 
value) { + return svdup_n_s8(value); + } + + inline static sve_type + load(const svbool_t pred, const data_type* value) { + return svld1_s8(pred, value); + } +}; + +template <> +struct SVEVector<int16_t> { + using data_type = int16_t; + using sve_type = svint16_t; + + // measured in the number of elements that an SVE register can hold + static inline size_t + width() { + return svcnth(); + } + + static inline svbool_t + pred_all() { + return svptrue_b16(); + } + + inline static sve_type + set1(const data_type value) { + return svdup_n_s16(value); + } + + inline static sve_type + load(const svbool_t pred, const data_type* value) { + return svld1_s16(pred, value); + } +}; + +template <> +struct SVEVector<int32_t> { + using data_type = int32_t; + using sve_type = svint32_t; + + // measured in the number of elements that an SVE register can hold + static inline size_t + width() { + return svcntw(); + } + + static inline svbool_t + pred_all() { + return svptrue_b32(); + } + + inline static sve_type + set1(const data_type value) { + return svdup_n_s32(value); + } + + inline static sve_type + load(const svbool_t pred, const data_type* value) { + return svld1_s32(pred, value); + } +}; + +template <> +struct SVEVector<int64_t> { + using data_type = int64_t; + using sve_type = svint64_t; + + // measured in the number of elements that an SVE register can hold + static inline size_t + width() { + return svcntd(); + } + + static inline svbool_t + pred_all() { + return svptrue_b64(); + } + + inline static sve_type + set1(const data_type value) { + return svdup_n_s64(value); + } + + inline static sve_type + load(const svbool_t pred, const data_type* value) { + return svld1_s64(pred, value); + } +}; + +template <> +struct SVEVector<float> { + using data_type = float; + using sve_type = svfloat32_t; + + // measured in the number of elements that an SVE register can hold + static inline size_t + width() { + return svcntw(); + } + + static inline svbool_t + pred_all() { + return svptrue_b32(); + } + + inline static sve_type + set1(const data_type value) { + return svdup_n_f32(value); + } + + inline static sve_type + load(const svbool_t pred, const data_type* value) { + return svld1_f32(pred, value); + } +}; + +template <> +struct SVEVector<double> { + using data_type = double; + using sve_type = svfloat64_t; + + // measured in the number of elements that an SVE register can hold + static inline size_t + width() { + return svcntd(); + } + + static inline svbool_t + pred_all() { + return svptrue_b64(); + } + + inline static sve_type + set1(const data_type value) { + return svdup_n_f64(value); + } + + inline static sve_type + load(const svbool_t pred, const data_type* value) { + return svld1_f64(pred, value); + } +}; + +/////////////////////////////////////////////////////////////////////////// + +// an interesting discussion here: +// https://stackoverflow.com/questions/77834169/what-is-a-fast-fallback-algorithm-which-emulates-pdep-and-pext-in-software + +// SVE2 has bitperm, which contains the implementation of pext + +// todo: replace with pext whenever available + +// +template <size_t width> +struct MaskHelper {}; + +template <> +struct MaskHelper<1> { + static inline void + write(uint8_t* const __restrict bitmask, + const size_t size, + const svbool_t pred0, + const svbool_t pred1, + const svbool_t pred2, + const svbool_t pred3, + const svbool_t pred4, + const svbool_t pred5, + const svbool_t pred6, + const svbool_t pred7) { + const size_t sve_width = svcntb(); + if (size == 8 * sve_width) { + // perform a full write + *((svbool_t*)(bitmask + 0 * sve_width 
/ 8)) = pred0; + *((svbool_t*)(bitmask + 1 * sve_width / 8)) = pred1; + *((svbool_t*)(bitmask + 2 * sve_width / 8)) = pred2; + *((svbool_t*)(bitmask + 3 * sve_width / 8)) = pred3; + *((svbool_t*)(bitmask + 4 * sve_width / 8)) = pred4; + *((svbool_t*)(bitmask + 5 * sve_width / 8)) = pred5; + *((svbool_t*)(bitmask + 6 * sve_width / 8)) = pred6; + *((svbool_t*)(bitmask + 7 * sve_width / 8)) = pred7; + } else { + // perform a partial write + + // this is the buffer for the maximum possible case of 2048 bits + uint8_t pred_buf[MAX_SVE_WIDTH / 8]; + *((volatile svbool_t*)(pred_buf + 0 * sve_width / 8)) = pred0; + *((volatile svbool_t*)(pred_buf + 1 * sve_width / 8)) = pred1; + *((volatile svbool_t*)(pred_buf + 2 * sve_width / 8)) = pred2; + *((volatile svbool_t*)(pred_buf + 3 * sve_width / 8)) = pred3; + *((volatile svbool_t*)(pred_buf + 4 * sve_width / 8)) = pred4; + *((volatile svbool_t*)(pred_buf + 5 * sve_width / 8)) = pred5; + *((volatile svbool_t*)(pred_buf + 6 * sve_width / 8)) = pred6; + *((volatile svbool_t*)(pred_buf + 7 * sve_width / 8)) = pred7; + + // make the write mask + const svbool_t pred_write = get_pred_op_8(size / 8); + + // load the buffer + const svuint8_t mask_u8 = svld1_u8(pred_write, pred_buf); + // write it to the bitmask + svst1_u8(pred_write, bitmask, mask_u8); + } + } +}; + +template <> +struct MaskHelper<2> { + static inline void + write(uint8_t* const __restrict bitmask, + const size_t size, + const svbool_t pred0, + const svbool_t pred1, + const svbool_t pred2, + const svbool_t pred3, + const svbool_t pred4, + const svbool_t pred5, + const svbool_t pred6, + const svbool_t pred7) { + const size_t sve_width = svcnth(); + + // this is the buffer for the maximum possible case of 2048 bits + uint8_t pred_buf[MAX_SVE_WIDTH / 8]; + *((volatile svbool_t*)(pred_buf + 0 * sve_width / 4)) = pred0; + *((volatile svbool_t*)(pred_buf + 1 * sve_width / 4)) = pred1; + *((volatile svbool_t*)(pred_buf + 2 * sve_width / 4)) = pred2; + *((volatile svbool_t*)(pred_buf + 3 * sve_width / 4)) = pred3; + *((volatile svbool_t*)(pred_buf + 4 * sve_width / 4)) = pred4; + *((volatile svbool_t*)(pred_buf + 5 * sve_width / 4)) = pred5; + *((volatile svbool_t*)(pred_buf + 6 * sve_width / 4)) = pred6; + *((volatile svbool_t*)(pred_buf + 7 * sve_width / 4)) = pred7; + + const svbool_t pred_op_8 = get_pred_op_8(size / 4); + const svbool_t pred_write_8 = get_pred_op_8(size / 8); + write_bitmask_16_8x(bitmask, pred_op_8, pred_write_8, pred_buf); + } +}; + +template <> +struct MaskHelper<4> { + static inline void + write(uint8_t* const __restrict bitmask, + const size_t size, + const svbool_t pred0, + const svbool_t pred1, + const svbool_t pred2, + const svbool_t pred3, + const svbool_t pred4, + const svbool_t pred5, + const svbool_t pred6, + const svbool_t pred7) { + const size_t sve_width = svcntw(); + + // this is the buffer for the maximum possible case of 2048 bits + uint8_t pred_buf[MAX_SVE_WIDTH / 8]; + *((volatile svbool_t*)(pred_buf + 0 * sve_width / 2)) = pred0; + *((volatile svbool_t*)(pred_buf + 1 * sve_width / 2)) = pred1; + *((volatile svbool_t*)(pred_buf + 2 * sve_width / 2)) = pred2; + *((volatile svbool_t*)(pred_buf + 3 * sve_width / 2)) = pred3; + *((volatile svbool_t*)(pred_buf + 4 * sve_width / 2)) = pred4; + *((volatile svbool_t*)(pred_buf + 5 * sve_width / 2)) = pred5; + *((volatile svbool_t*)(pred_buf + 6 * sve_width / 2)) = pred6; + *((volatile svbool_t*)(pred_buf + 7 * sve_width / 2)) = pred7; + + const svbool_t pred_op_8 = get_pred_op_8(size / 2); + const svbool_t 
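+// A hedged note on the compaction step used by MaskHelper<2>/<4>/<8>: an SVE
+// predicate always stores one bit per byte of vector data, so for 32-bit
+// lanes only every 4th bit of pred_buf is meaningful. The write_bitmask_32_8x
+// helper invoked below (presumably defined earlier in this file) densifies
+// those bits into the 1-bit-per-element output; conceptually, for every
+// element e:
+//
+//   out_bit[e] = pred_buf_bit[4 * e];  // keep every 4th predicate bit
+//
+// with strides 2 and 8 for the 16-bit and 64-bit variants respectively.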
pred_write_8 = get_pred_op_8(size / 8); + write_bitmask_32_8x(bitmask, pred_op_8, pred_write_8, pred_buf); + } +}; + +template <> +struct MaskHelper<8> { + static inline void + write(uint8_t* const __restrict bitmask, + const size_t size, + const svbool_t pred0, + const svbool_t pred1, + const svbool_t pred2, + const svbool_t pred3, + const svbool_t pred4, + const svbool_t pred5, + const svbool_t pred6, + const svbool_t pred7) { + const size_t sve_width = svcntd(); + + // this is the buffer for the maximum possible case of 2048 bits + uint8_t pred_buf[MAX_SVE_WIDTH / 8]; + *((volatile svbool_t*)(pred_buf + 0 * sve_width)) = pred0; + *((volatile svbool_t*)(pred_buf + 1 * sve_width)) = pred1; + *((volatile svbool_t*)(pred_buf + 2 * sve_width)) = pred2; + *((volatile svbool_t*)(pred_buf + 3 * sve_width)) = pred3; + *((volatile svbool_t*)(pred_buf + 4 * sve_width)) = pred4; + *((volatile svbool_t*)(pred_buf + 5 * sve_width)) = pred5; + *((volatile svbool_t*)(pred_buf + 6 * sve_width)) = pred6; + *((volatile svbool_t*)(pred_buf + 7 * sve_width)) = pred7; + + const svbool_t pred_op_8 = get_pred_op_8(size / 1); + const svbool_t pred_write_8 = get_pred_op_8(size / 8); + write_bitmask_64_8x(bitmask, pred_op_8, pred_write_8, pred_buf); + } +}; + +/////////////////////////////////////////////////////////////////////////// + +// the facility that handles all bitset processing for SVE +template +bool +op_mask_helper(uint8_t* const __restrict res_u8, const size_t size, Func func) { + // the restriction of the API + assert((size % 8) == 0); + + // + using sve_t = SVEVector; + + // SVE width in elements + const size_t sve_width = sve_t::width(); + assert((sve_width % 8) == 0); + + // process large blocks + const size_t size_sve8 = (size / (8 * sve_width)) * (8 * sve_width); + { + for (size_t i = 0; i < size_sve8; i += 8 * sve_width) { + const svbool_t pred_all = sve_t::pred_all(); + + const svbool_t cmp0 = func(pred_all, i + 0 * sve_width); + const svbool_t cmp1 = func(pred_all, i + 1 * sve_width); + const svbool_t cmp2 = func(pred_all, i + 2 * sve_width); + const svbool_t cmp3 = func(pred_all, i + 3 * sve_width); + const svbool_t cmp4 = func(pred_all, i + 4 * sve_width); + const svbool_t cmp5 = func(pred_all, i + 5 * sve_width); + const svbool_t cmp6 = func(pred_all, i + 6 * sve_width); + const svbool_t cmp7 = func(pred_all, i + 7 * sve_width); + + MaskHelper::write(res_u8 + i / 8, + sve_width * 8, + cmp0, + cmp1, + cmp2, + cmp3, + cmp4, + cmp5, + cmp6, + cmp7); + } + } + + // process leftovers + if (size_sve8 != size) { + auto get_partial_pred = [sve_width, size, size_sve8](const size_t j) { + const size_t start = size_sve8 + j * sve_width; + const size_t end = size_sve8 + (j + 1) * sve_width; + + const size_t amount = (end < size) ? 
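+            // A worked example of the block math above, assuming a 256-bit
+            // SVE implementation (sve_width == 8 for 32-bit elements) and
+            // size == 100: size_sve8 == (100 / 64) * 64 == 64, so one full
+            // 8-predicate block covers elements [0, 64) and this leftover
+            // pass covers [64, 100). That takes ceil(36 / 8) == 5 handler
+            // calls: cmp0..cmp3 get a full predicate (amount == 8), cmp4
+            // gets a partial one (amount == 100 - 96 == 4), and cmp5..cmp7
+            // stay svpfalse_b().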
sve_width : (size - start); + const svbool_t pred_op = get_pred_op(amount); + + return pred_op; + }; + + const svbool_t pred_none = svpfalse_b(); + svbool_t cmp0 = pred_none; + svbool_t cmp1 = pred_none; + svbool_t cmp2 = pred_none; + svbool_t cmp3 = pred_none; + svbool_t cmp4 = pred_none; + svbool_t cmp5 = pred_none; + svbool_t cmp6 = pred_none; + svbool_t cmp7 = pred_none; + + const size_t jcount = (size - size_sve8 + sve_width - 1) / sve_width; + if (jcount > 0) { + cmp0 = func(get_partial_pred(0), size_sve8 + 0 * sve_width); + } + if (jcount > 1) { + cmp1 = func(get_partial_pred(1), size_sve8 + 1 * sve_width); + } + if (jcount > 2) { + cmp2 = func(get_partial_pred(2), size_sve8 + 2 * sve_width); + } + if (jcount > 3) { + cmp3 = func(get_partial_pred(3), size_sve8 + 3 * sve_width); + } + if (jcount > 4) { + cmp4 = func(get_partial_pred(4), size_sve8 + 4 * sve_width); + } + if (jcount > 5) { + cmp5 = func(get_partial_pred(5), size_sve8 + 5 * sve_width); + } + if (jcount > 6) { + cmp6 = func(get_partial_pred(6), size_sve8 + 6 * sve_width); + } + if (jcount > 7) { + cmp7 = func(get_partial_pred(7), size_sve8 + 7 * sve_width); + } + + MaskHelper::write(res_u8 + size_sve8 / 8, + size - size_sve8, + cmp0, + cmp1, + cmp2, + cmp3, + cmp4, + cmp5, + cmp6, + cmp7); + } + + return true; +} + +} // namespace + +/////////////////////////////////////////////////////////////////////////// + +namespace { + +template +bool +op_compare_val_impl(uint8_t* const __restrict res_u8, + const T* const __restrict src, + const size_t size, + const T& val) { + auto handler = [src, val](const svbool_t pred, const size_t idx) { + using sve_t = SVEVector; + + const auto target = sve_t::set1(val); + const auto v = sve_t::load(pred, src + idx); + const svbool_t cmp = CmpHelper::compare(pred, v, target); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); +} + +} // namespace + +// +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const int8_t* const __restrict src, + const size_t size, + const int8_t& val) { + return op_compare_val_impl(res_u8, src, size, val); +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict src, + const size_t size, + const int16_t& val) { + return op_compare_val_impl(res_u8, src, size, val); +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict src, + const size_t size, + const int32_t& val) { + return op_compare_val_impl(res_u8, src, size, val); +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict src, + const size_t size, + const int64_t& val) { + return op_compare_val_impl(res_u8, src, size, val); +} + +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const float* const __restrict src, + const size_t size, + const float& val) { + return op_compare_val_impl(res_u8, src, size, val); +} + +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const double* const __restrict src, + const size_t size, + const double& val) { + return op_compare_val_impl(res_u8, src, size, val); +} + +/////////////////////////////////////////////////////////////////////////// + +namespace { + +template +bool +op_compare_column_impl(uint8_t* const __restrict res_u8, + const T* const __restrict left, + const T* const __restrict right, + const size_t size) { + auto handler = [left, 
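+// The contract shared by every handler lambda in this file: it receives the
+// active-lane predicate plus a base element index, and returns an svbool_t
+// with one result bit per lane; op_mask_helper owns the 8x unrolling, the
+// tail predication and the bitmask packing. A minimal handler testing
+// src[i] != 0 would look like this (a sketch, not code from this patch):
+//
+//   auto handler = [src](const svbool_t pred, const size_t idx) {
+//       const svint32_t v = svld1_s32(pred, src + idx);
+//       return svcmpne_n_s32(pred, v, 0);
+//   };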
right](const svbool_t pred, const size_t idx) { + using sve_t = SVEVector; + + const auto left_v = sve_t::load(pred, left + idx); + const auto right_v = sve_t::load(pred, right + idx); + const svbool_t cmp = CmpHelper::compare(pred, left_v, right_v); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); +} + +} // namespace + +// +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict left, + const int8_t* const __restrict right, + const size_t size) { + return op_compare_column_impl(res_u8, left, right, size); +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict left, + const int16_t* const __restrict right, + const size_t size) { + return op_compare_column_impl(res_u8, left, right, size); +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict left, + const int32_t* const __restrict right, + const size_t size) { + return op_compare_column_impl(res_u8, left, right, size); +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict left, + const int64_t* const __restrict right, + const size_t size) { + return op_compare_column_impl(res_u8, left, right, size); +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const float* const __restrict left, + const float* const __restrict right, + const size_t size) { + return op_compare_column_impl(res_u8, left, right, size); +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const double* const __restrict left, + const double* const __restrict right, + const size_t size) { + return op_compare_column_impl(res_u8, left, right, size); +} + +/////////////////////////////////////////////////////////////////////////// + +namespace { + +template +bool +op_within_range_column_impl(uint8_t* const __restrict res_u8, + const T* const __restrict lower, + const T* const __restrict upper, + const T* const __restrict values, + const size_t size) { + auto handler = [lower, upper, values](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto lower_v = sve_t::load(pred, lower + idx); + const auto upper_v = sve_t::load(pred, upper + idx); + const auto values_v = sve_t::load(pred, values + idx); + + const svbool_t cmpl = CmpHelper::lower>::compare( + pred, lower_v, values_v); + const svbool_t cmpu = CmpHelper::upper>::compare( + pred, values_v, upper_v); + const svbool_t cmp = svand_b_z(pred, cmpl, cmpu); + + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); +} + +} // namespace + +// +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict lower, + const int8_t* const __restrict upper, + const int8_t* const __restrict values, + const size_t size) { + return op_within_range_column_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict lower, + const int16_t* const __restrict upper, + const int16_t* const __restrict values, + const size_t size) { + return op_within_range_column_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const 
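+// How a RangeType resolves into the two AND-ed comparisons above (the
+// natural reading of the Range2Compare trait, presumably defined in
+// bitset/common.h):
+//
+//   IncInc: lower <= v && v <= upper
+//   IncExc: lower <= v && v <  upper
+//   ExcInc: lower <  v && v <= upper
+//   ExcExc: lower <  v && v <  upper
+//
+// svand_b_z uses zeroing predication, so result bits for inactive tail
+// lanes are cleared rather than left undefined.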
__restrict res_u8, + const int32_t* const __restrict lower, + const int32_t* const __restrict upper, + const int32_t* const __restrict values, + const size_t size) { + return op_within_range_column_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict lower, + const int64_t* const __restrict upper, + const int64_t* const __restrict values, + const size_t size) { + return op_within_range_column_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const float* const __restrict lower, + const float* const __restrict upper, + const float* const __restrict values, + const size_t size) { + return op_within_range_column_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const double* const __restrict lower, + const double* const __restrict upper, + const double* const __restrict values, + const size_t size) { + return op_within_range_column_impl( + res_u8, lower, upper, values, size); +} + +/////////////////////////////////////////////////////////////////////////// + +namespace { + +template +bool +op_within_range_val_impl(uint8_t* const __restrict res_u8, + const T& lower, + const T& upper, + const T* const __restrict values, + const size_t size) { + auto handler = [lower, upper, values](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto lower_v = sve_t::set1(lower); + const auto upper_v = sve_t::set1(upper); + const auto values_v = sve_t::load(pred, values + idx); + + const svbool_t cmpl = CmpHelper::lower>::compare( + pred, lower_v, values_v); + const svbool_t cmpu = CmpHelper::upper>::compare( + pred, values_v, upper_v); + const svbool_t cmp = svand_b_z(pred, cmpl, cmpu); + + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); +} + +} // namespace + +// +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int8_t& lower, + const int8_t& upper, + const int8_t* const __restrict values, + const size_t size) { + return op_within_range_val_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int16_t& lower, + const int16_t& upper, + const int16_t* const __restrict values, + const size_t size) { + return op_within_range_val_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int32_t& lower, + const int32_t& upper, + const int32_t* const __restrict values, + const size_t size) { + return op_within_range_val_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int64_t& lower, + const int64_t& upper, + const int64_t* const __restrict values, + const size_t size) { + return op_within_range_val_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const float& lower, + const float& upper, + const float* const __restrict values, + const size_t size) { + return op_within_range_val_impl( + res_u8, lower, upper, values, size); +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( 
+ uint8_t* const __restrict res_u8, + const double& lower, + const double& upper, + const double* const __restrict values, + const size_t size) { + return op_within_range_val_impl( + res_u8, lower, upper, values, size); +} + +/////////////////////////////////////////////////////////////////////////// + +namespace { + +template +struct ArithHelperI64 {}; + +template +struct ArithHelperI64 { + static inline svbool_t + op(const svbool_t pred, + const svint64_t left, + const svint64_t right, + const svint64_t value) { + // left + right == value + return CmpHelper::compare( + pred, svadd_s64_z(pred, left, right), value); + } +}; + +template +struct ArithHelperI64 { + static inline svbool_t + op(const svbool_t pred, + const svint64_t left, + const svint64_t right, + const svint64_t value) { + // left - right == value + return CmpHelper::compare( + pred, svsub_s64_z(pred, left, right), value); + } +}; + +template +struct ArithHelperI64 { + static inline svbool_t + op(const svbool_t pred, + const svint64_t left, + const svint64_t right, + const svint64_t value) { + // left * right == value + return CmpHelper::compare( + pred, svmul_s64_z(pred, left, right), value); + } +}; + +template +struct ArithHelperI64 { + static inline svbool_t + op(const svbool_t pred, + const svint64_t left, + const svint64_t right, + const svint64_t value) { + // left / right == value + return CmpHelper::compare( + pred, svdiv_s64_z(pred, left, right), value); + } +}; + +// +template +struct ArithHelperF32 {}; + +template +struct ArithHelperF32 { + static inline svbool_t + op(const svbool_t pred, + const svfloat32_t left, + const svfloat32_t right, + const svfloat32_t value) { + // left + right == value + return CmpHelper::compare( + pred, svadd_f32_z(pred, left, right), value); + } +}; + +template +struct ArithHelperF32 { + static inline svbool_t + op(const svbool_t pred, + const svfloat32_t left, + const svfloat32_t right, + const svfloat32_t value) { + // left - right == value + return CmpHelper::compare( + pred, svsub_f32_z(pred, left, right), value); + } +}; + +template +struct ArithHelperF32 { + static inline svbool_t + op(const svbool_t pred, + const svfloat32_t left, + const svfloat32_t right, + const svfloat32_t value) { + // left * right == value + return CmpHelper::compare( + pred, svmul_f32_z(pred, left, right), value); + } +}; + +template +struct ArithHelperF32 { + static inline svbool_t + op(const svbool_t pred, + const svfloat32_t left, + const svfloat32_t right, + const svfloat32_t value) { + // left == right * value + return CmpHelper::compare( + pred, left, svmul_f32_z(pred, right, value)); + } +}; + +// +template +struct ArithHelperF64 {}; + +template +struct ArithHelperF64 { + static inline svbool_t + op(const svbool_t pred, + const svfloat64_t left, + const svfloat64_t right, + const svfloat64_t value) { + // left + right == value + return CmpHelper::compare( + pred, svadd_f64_z(pred, left, right), value); + } +}; + +template +struct ArithHelperF64 { + static inline svbool_t + op(const svbool_t pred, + const svfloat64_t left, + const svfloat64_t right, + const svfloat64_t value) { + // left - right == value + return CmpHelper::compare( + pred, svsub_f64_z(pred, left, right), value); + } +}; + +template +struct ArithHelperF64 { + static inline svbool_t + op(const svbool_t pred, + const svfloat64_t left, + const svfloat64_t right, + const svfloat64_t value) { + // left * right == value + return CmpHelper::compare( + pred, svmul_f64_z(pred, left, right), value); + } +}; + +template +struct ArithHelperF64 
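+// (A hedged caveat on the Div specialization that follows: like its f32
+// counterpart above, it rewrites `left / right == value` as
+// `left == right * value`, trading a vector division for a multiply. The
+// two forms agree for finite, non-zero `right`; for the ordered comparisons
+// (GT/GE/LT/LE) a negative `right` would also flip the comparison
+// direction, so the rewrite implicitly assumes a positive divisor.)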
{ + static inline svbool_t + op(const svbool_t pred, + const svfloat64_t left, + const svfloat64_t right, + const svfloat64_t value) { + // left == right * value + return CmpHelper::compare( + pred, left, svmul_f64_z(pred, right, value)); + } +}; + +} // namespace + +// todo: Mod + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mod) { + return false; + } else { + using T = int64_t; + + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto right_v = svdup_n_s64(right_operand); + const auto value_v = svdup_n_s64(value); + const svint64_t src_v = svld1sb_s64(pred, src + idx); + + const svbool_t cmp = + ArithHelperI64::op(pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mod) { + return false; + } else { + using T = int64_t; + + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto right_v = svdup_n_s64(right_operand); + const auto value_v = svdup_n_s64(value); + const svint64_t src_v = svld1sh_s64(pred, src + idx); + + const svbool_t cmp = + ArithHelperI64::op(pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mod) { + return false; + } else { + using T = int64_t; + + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto right_v = svdup_n_s64(right_operand); + const auto value_v = svdup_n_s64(value); + const svint64_t src_v = svld1sw_s64(pred, src + idx); + + const svbool_t cmp = + ArithHelperI64::op(pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mod) { + return false; + } else { + using T = int64_t; + + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto right_v = svdup_n_s64(right_operand); + const auto value_v = svdup_n_s64(value); + const svint64_t src_v = svld1_s64(pred, src + idx); + + const svbool_t cmp = + ArithHelperI64::op(pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const float* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr 
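+// A note on the integer overloads above: svld1sb_s64, svld1sh_s64 and
+// svld1sw_s64 are sign-extending contiguous loads of 8-, 16- and 32-bit
+// elements into 64-bit lanes, so the arithmetic genuinely runs at the
+// ArithHighPrecisionType width (int64_t), at the cost of touching only
+// svcntd() elements per vector. Mod has no SVE path in any overload;
+// returning false presumably signals the caller to use a non-SVE fallback.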
(AOp == ArithOpType::Mod) { + return false; + } else { + using T = float; + + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto right_v = svdup_n_f32(right_operand); + const auto value_v = svdup_n_f32(value); + const svfloat32_t src_v = svld1_f32(pred, src + idx); + + const svbool_t cmp = + ArithHelperF32::op(pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const double* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mod) { + return false; + } else { + using T = double; + + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto right_v = svdup_n_f64(right_operand); + const auto value_v = svdup_n_f64(value); + const svfloat64_t src_v = svld1_f64(pred, src + idx); + + const svbool_t cmp = + ArithHelperF64::op(pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); + } +} + +/////////////////////////////////////////////////////////////////////////// + +} // namespace sve +} // namespace arm +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/arm/sve-inst.cpp b/internal/core/src/bitset/detail/platform/arm/sve-inst.cpp new file mode 100644 index 000000000000..ae2dd946a93a --- /dev/null +++ b/internal/core/src/bitset/detail/platform/arm/sve-inst.cpp @@ -0,0 +1,199 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ARM SVE instantiation + +#include "bitset/common.h" + +#ifndef BITSET_HEADER_ONLY + +#include "sve-decl.h" +#include "sve-impl.h" + +#include +#include + +namespace milvus { +namespace bitset { +namespace detail { +namespace arm { +namespace sve { + +// a facility to run through all possible compare operations +#define ALL_COMPARE_OPS(FUNC, ...) \ + FUNC(__VA_ARGS__, EQ); \ + FUNC(__VA_ARGS__, GE); \ + FUNC(__VA_ARGS__, GT); \ + FUNC(__VA_ARGS__, LE); \ + FUNC(__VA_ARGS__, LT); \ + FUNC(__VA_ARGS__, NE); + +// a facility to run through all possible range operations +#define ALL_RANGE_OPS(FUNC, ...) \ + FUNC(__VA_ARGS__, IncInc); \ + FUNC(__VA_ARGS__, IncExc); \ + FUNC(__VA_ARGS__, ExcInc); \ + FUNC(__VA_ARGS__, ExcExc); + +// a facility to run through all possible arithmetic compare operations +#define ALL_ARITH_CMP_OPS(FUNC, ...) 
\ + FUNC(__VA_ARGS__, Add, EQ); \ + FUNC(__VA_ARGS__, Add, GE); \ + FUNC(__VA_ARGS__, Add, GT); \ + FUNC(__VA_ARGS__, Add, LE); \ + FUNC(__VA_ARGS__, Add, LT); \ + FUNC(__VA_ARGS__, Add, NE); \ + FUNC(__VA_ARGS__, Sub, EQ); \ + FUNC(__VA_ARGS__, Sub, GE); \ + FUNC(__VA_ARGS__, Sub, GT); \ + FUNC(__VA_ARGS__, Sub, LE); \ + FUNC(__VA_ARGS__, Sub, LT); \ + FUNC(__VA_ARGS__, Sub, NE); \ + FUNC(__VA_ARGS__, Mul, EQ); \ + FUNC(__VA_ARGS__, Mul, GE); \ + FUNC(__VA_ARGS__, Mul, GT); \ + FUNC(__VA_ARGS__, Mul, LE); \ + FUNC(__VA_ARGS__, Mul, LT); \ + FUNC(__VA_ARGS__, Mul, NE); \ + FUNC(__VA_ARGS__, Div, EQ); \ + FUNC(__VA_ARGS__, Div, GE); \ + FUNC(__VA_ARGS__, Div, GT); \ + FUNC(__VA_ARGS__, Div, LE); \ + FUNC(__VA_ARGS__, Div, LT); \ + FUNC(__VA_ARGS__, Div, NE); \ + FUNC(__VA_ARGS__, Mod, EQ); \ + FUNC(__VA_ARGS__, Mod, GE); \ + FUNC(__VA_ARGS__, Mod, GT); \ + FUNC(__VA_ARGS__, Mod, LE); \ + FUNC(__VA_ARGS__, Mod, LT); \ + FUNC(__VA_ARGS__, Mod, NE); + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_COMPARE_VAL_SVE(TTYPE, OP) \ + template bool OpCompareValImpl::op_compare_val( \ + uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict src, \ + const size_t size, \ + const TTYPE& val); + +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_SVE, int8_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_SVE, int16_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_SVE, int32_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_SVE, int64_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_SVE, float) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_SVE, double) + +#undef INSTANTIATE_COMPARE_VAL_SVE + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_COMPARE_COLUMN_SVE(TTYPE, OP) \ + template bool \ + OpCompareColumnImpl::op_compare_column( \ + uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict left, \ + const TTYPE* const __restrict right, \ + const size_t size); + +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_SVE, int8_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_SVE, int16_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_SVE, int32_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_SVE, int64_t) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_SVE, float) +ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_SVE, double) + +#undef INSTANTIATE_COMPARE_COLUMN_SVE + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_WITHIN_RANGE_COLUMN_SVE(TTYPE, OP) \ + template bool \ + OpWithinRangeColumnImpl::op_within_range_column( \ + uint8_t* const __restrict res_u8, \ + const TTYPE* const __restrict lower, \ + const TTYPE* const __restrict upper, \ + const TTYPE* const __restrict values, \ + const size_t size); + +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_SVE, int8_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_SVE, int16_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_SVE, int32_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_SVE, int64_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_SVE, float) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_SVE, double) + +#undef INSTANTIATE_WITHIN_RANGE_COLUMN_SVE + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_WITHIN_RANGE_VAL_SVE(TTYPE, OP) \ + template bool \ + OpWithinRangeValImpl::op_within_range_val( \ + uint8_t* const __restrict res_u8, \ + const TTYPE& lower, \ + const TTYPE& upper, \ + const TTYPE* const __restrict values, \ + const size_t size); + 
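+// For reference, a single expansion of the macro above (with the template
+// arguments written out the way sve-decl.h presumably declares them)
+// produces one explicit instantiation per (type, range-op) pair:
+//
+//   template bool
+//   OpWithinRangeValImpl<int8_t, RangeType::IncInc>::op_within_range_val(
+//       uint8_t* const __restrict res_u8,
+//       const int8_t& lower,
+//       const int8_t& upper,
+//       const int8_t* const __restrict values,
+//       const size_t size);
+//
+// so the ALL_RANGE_OPS lines below emit 6 element types x 4 range ops = 24
+// instantiations, and the other macros in this file follow the same scheme.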
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_SVE, int8_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_SVE, int16_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_SVE, int32_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_SVE, int64_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_SVE, float) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_SVE, double) + +#undef INSTANTIATE_WITHIN_RANGE_VAL_SVE + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_ARITH_COMPARE_SVE(TTYPE, OP, CMP) \ + template bool \ + OpArithCompareImpl:: \ + op_arith_compare(uint8_t* const __restrict res_u8, \ + const TTYPE* const __restrict src, \ + const ArithHighPrecisionType& right_operand, \ + const ArithHighPrecisionType& value, \ + const size_t size); + +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_SVE, int8_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_SVE, int16_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_SVE, int32_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_SVE, int64_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_SVE, float) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_SVE, double) + +#undef INSTANTIATE_ARITH_COMPARE_SVE + +/////////////////////////////////////////////////////////////////////////// + +// +#undef ALL_COMPARE_OPS +#undef ALL_RANGE_OPS +#undef ALL_ARITH_CMP_OPS + +} // namespace sve +} // namespace arm +} // namespace detail +} // namespace bitset +} // namespace milvus + +#endif diff --git a/internal/core/src/bitset/detail/platform/arm/sve.h b/internal/core/src/bitset/detail/platform/arm/sve.h new file mode 100644 index 000000000000..615431373dcf --- /dev/null +++ b/internal/core/src/bitset/detail/platform/arm/sve.h @@ -0,0 +1,63 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
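+
+// An editorial note on how this header is consumed: with BITSET_HEADER_ONLY
+// defined it includes sve-impl.h directly, so every including TU compiles
+// the kernels itself; without it, only the declarations from sve-decl.h are
+// visible here and the definitions come from the explicit instantiations in
+// sve-inst.cpp, which the CMakeLists above builds with -mcpu=neoverse-v1 so
+// the SVE intrinsics are available.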
+ +#pragma once + +#include +#include + +#include "bitset/common.h" + +#include "sve-decl.h" + +#ifdef BITSET_HEADER_ONLY +#include "sve-impl.h" +#endif + +namespace milvus { +namespace bitset { +namespace detail { +namespace arm { + +/////////////////////////////////////////////////////////////////////////// + +// +struct VectorizedSve { + template + static constexpr inline auto op_compare_column = + sve::OpCompareColumnImpl::op_compare_column; + + template + static constexpr inline auto op_compare_val = + sve::OpCompareValImpl::op_compare_val; + + template + static constexpr inline auto op_within_range_column = + sve::OpWithinRangeColumnImpl::op_within_range_column; + + template + static constexpr inline auto op_within_range_val = + sve::OpWithinRangeValImpl::op_within_range_val; + + template + static constexpr inline auto op_arith_compare = + sve::OpArithCompareImpl::op_arith_compare; +}; + +} // namespace arm +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/dynamic.cpp b/internal/core/src/bitset/detail/platform/dynamic.cpp new file mode 100644 index 000000000000..0b9f35bcac29 --- /dev/null +++ b/internal/core/src/bitset/detail/platform/dynamic.cpp @@ -0,0 +1,625 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dynamic.h" + +#include +#include +#include + +#if defined(__x86_64__) +#include "x86/instruction_set.h" +#include "x86/avx2.h" +#include "x86/avx512.h" + +using namespace milvus::bitset::detail::x86; +#endif + +#if defined(__aarch64__) +#include "arm/neon.h" + +#ifdef __ARM_FEATURE_SVE +#include "arm/sve.h" +#endif + +using namespace milvus::bitset::detail::arm; + +#endif + +#include "vectorized_ref.h" + +// a facility to run through all possible compare operations +#define ALL_COMPARE_OPS(FUNC, ...) \ + FUNC(__VA_ARGS__, EQ); \ + FUNC(__VA_ARGS__, GE); \ + FUNC(__VA_ARGS__, GT); \ + FUNC(__VA_ARGS__, LE); \ + FUNC(__VA_ARGS__, LT); \ + FUNC(__VA_ARGS__, NE); + +// a facility to run through all possible range operations +#define ALL_RANGE_OPS(FUNC, ...) \ + FUNC(__VA_ARGS__, IncInc); \ + FUNC(__VA_ARGS__, IncExc); \ + FUNC(__VA_ARGS__, ExcInc); \ + FUNC(__VA_ARGS__, ExcExc); + +// a facility to run through all possible arithmetic compare operations +#define ALL_ARITH_CMP_OPS(FUNC, ...) 
\ + FUNC(__VA_ARGS__, Add, EQ); \ + FUNC(__VA_ARGS__, Add, GE); \ + FUNC(__VA_ARGS__, Add, GT); \ + FUNC(__VA_ARGS__, Add, LE); \ + FUNC(__VA_ARGS__, Add, LT); \ + FUNC(__VA_ARGS__, Add, NE); \ + FUNC(__VA_ARGS__, Sub, EQ); \ + FUNC(__VA_ARGS__, Sub, GE); \ + FUNC(__VA_ARGS__, Sub, GT); \ + FUNC(__VA_ARGS__, Sub, LE); \ + FUNC(__VA_ARGS__, Sub, LT); \ + FUNC(__VA_ARGS__, Sub, NE); \ + FUNC(__VA_ARGS__, Mul, EQ); \ + FUNC(__VA_ARGS__, Mul, GE); \ + FUNC(__VA_ARGS__, Mul, GT); \ + FUNC(__VA_ARGS__, Mul, LE); \ + FUNC(__VA_ARGS__, Mul, LT); \ + FUNC(__VA_ARGS__, Mul, NE); \ + FUNC(__VA_ARGS__, Div, EQ); \ + FUNC(__VA_ARGS__, Div, GE); \ + FUNC(__VA_ARGS__, Div, GT); \ + FUNC(__VA_ARGS__, Div, LE); \ + FUNC(__VA_ARGS__, Div, LT); \ + FUNC(__VA_ARGS__, Div, NE); \ + FUNC(__VA_ARGS__, Mod, EQ); \ + FUNC(__VA_ARGS__, Mod, GE); \ + FUNC(__VA_ARGS__, Mod, GT); \ + FUNC(__VA_ARGS__, Mod, LE); \ + FUNC(__VA_ARGS__, Mod, LT); \ + FUNC(__VA_ARGS__, Mod, NE); + +// +namespace milvus { +namespace bitset { +namespace detail { + +///////////////////////////////////////////////////////////////////////////// +// op_compare_column + +// Define pointers for op_compare +template +using OpCompareColumnPtr = bool (*)(uint8_t* const __restrict output, + const T* const __restrict t, + const U* const __restrict u, + const size_t size); + +#define DECLARE_OP_COMPARE_COLUMN(TTYPE, UTYPE, OP) \ + OpCompareColumnPtr \ + op_compare_column_##TTYPE##_##UTYPE##_##OP = VectorizedRef:: \ + template op_compare_column; + +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_COLUMN, int8_t, int8_t) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_COLUMN, int16_t, int16_t) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_COLUMN, int32_t, int32_t) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_COLUMN, int64_t, int64_t) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_COLUMN, float, float) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_COLUMN, double, double) + +#undef DECLARE_OP_COMPARE_COLUMN + +// +namespace dynamic { + +#define DISPATCH_OP_COMPARE_COLUMN_IMPL(TTYPE, OP) \ + template <> \ + bool \ + OpCompareColumnImpl::op_compare_column( \ + uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict t, \ + const TTYPE* const __restrict u, \ + const size_t size) { \ + return op_compare_column_##TTYPE##_##TTYPE##_##OP( \ + bitmask, t, u, size); \ + } + +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_COLUMN_IMPL, int8_t) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_COLUMN_IMPL, int16_t) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_COLUMN_IMPL, int32_t) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_COLUMN_IMPL, int64_t) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_COLUMN_IMPL, float) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_COLUMN_IMPL, double) + +#undef DISPATCH_OP_COMPARE_COLUMN_IMPL + +} // namespace dynamic + +///////////////////////////////////////////////////////////////////////////// +// op_compare_val +template +using OpCompareValPtr = bool (*)(uint8_t* const __restrict output, + const T* const __restrict t, + const size_t size, + const T& value); + +#define DECLARE_OP_COMPARE_VAL(TTYPE, OP) \ + OpCompareValPtr op_compare_val_##TTYPE##_##OP = \ + VectorizedRef::template op_compare_val; + +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_VAL, int8_t) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_VAL, int16_t) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_VAL, int32_t) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_VAL, int64_t) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_VAL, float) +ALL_COMPARE_OPS(DECLARE_OP_COMPARE_VAL, double) + +#undef DECLARE_OP_COMPARE_VAL + +namespace dynamic { + +#define DISPATCH_OP_COMPARE_VAL_IMPL(TTYPE, OP) \ + template <> \ + bool 
OpCompareValImpl::op_compare_val( \ + uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict t, \ + const size_t size, \ + const TTYPE& value) { \ + return op_compare_val_##TTYPE##_##OP(bitmask, t, size, value); \ + } + +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_VAL_IMPL, int8_t) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_VAL_IMPL, int16_t) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_VAL_IMPL, int32_t) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_VAL_IMPL, int64_t) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_VAL_IMPL, float) +ALL_COMPARE_OPS(DISPATCH_OP_COMPARE_VAL_IMPL, double) + +#undef DISPATCH_OP_COMPARE_VAL_IMPL + +} // namespace dynamic + +///////////////////////////////////////////////////////////////////////////// +// op_within_range column +template +using OpWithinRangeColumnPtr = bool (*)(uint8_t* const __restrict output, + const T* const __restrict lower, + const T* const __restrict upper, + const T* const __restrict values, + const size_t size); + +#define DECLARE_OP_WITHIN_RANGE_COLUMN(TTYPE, OP) \ + OpWithinRangeColumnPtr \ + op_within_range_column_##TTYPE##_##OP = \ + VectorizedRef::template op_within_range_column; + +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_COLUMN, int8_t) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_COLUMN, int16_t) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_COLUMN, int32_t) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_COLUMN, int64_t) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_COLUMN, float) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_COLUMN, double) + +#undef DECLARE_OP_WITHIN_RANGE_COLUMN + +// +namespace dynamic { + +#define DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL(TTYPE, OP) \ + template <> \ + bool \ + OpWithinRangeColumnImpl::op_within_range_column( \ + uint8_t* const __restrict output, \ + const TTYPE* const __restrict lower, \ + const TTYPE* const __restrict upper, \ + const TTYPE* const __restrict values, \ + const size_t size) { \ + return op_within_range_column_##TTYPE##_##OP( \ + output, lower, upper, values, size); \ + } + +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, int8_t) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, int16_t) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, int32_t) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, int64_t) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, float) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, double) + +#undef DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL +} // namespace dynamic + +///////////////////////////////////////////////////////////////////////////// +// op_within_range val +template +using OpWithinRangeValPtr = bool (*)(uint8_t* const __restrict output, + const T& lower, + const T& upper, + const T* const __restrict values, + const size_t size); + +#define DECLARE_OP_WITHIN_RANGE_VAL(TTYPE, OP) \ + OpWithinRangeValPtr \ + op_within_range_val_##TTYPE##_##OP = \ + VectorizedRef::template op_within_range_val; + +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_VAL, int8_t) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_VAL, int16_t) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_VAL, int32_t) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_VAL, int64_t) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_VAL, float) +ALL_RANGE_OPS(DECLARE_OP_WITHIN_RANGE_VAL, double) + +#undef DECLARE_OP_WITHIN_RANGE_VAL + +// +namespace dynamic { + +#define DISPATCH_OP_WITHIN_RANGE_VAL_IMPL(TTYPE, OP) \ + template <> \ + bool OpWithinRangeValImpl::op_within_range_val( \ + uint8_t* const __restrict output, \ + const TTYPE& lower, \ + const TTYPE& upper, \ + const TTYPE* const __restrict values, \ + const size_t size) { \ + return op_within_range_val_##TTYPE##_##OP( \ + output, 
lower, upper, values, size); \ + } + +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, int8_t) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, int16_t) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, int32_t) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, int64_t) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, float) +ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, double) + +} // namespace dynamic + +///////////////////////////////////////////////////////////////////////////// +// op_arith_compare +template +using OpArithComparePtr = + bool (*)(uint8_t* const __restrict output, + const T* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size); + +#define DECLARE_OP_ARITH_COMPARE(TTYPE, AOP, CMPOP) \ + OpArithComparePtr \ + op_arith_compare_##TTYPE##_##AOP##_##CMPOP = \ + VectorizedRef::template op_arith_compare; + +ALL_ARITH_CMP_OPS(DECLARE_OP_ARITH_COMPARE, int8_t) +ALL_ARITH_CMP_OPS(DECLARE_OP_ARITH_COMPARE, int16_t) +ALL_ARITH_CMP_OPS(DECLARE_OP_ARITH_COMPARE, int32_t) +ALL_ARITH_CMP_OPS(DECLARE_OP_ARITH_COMPARE, int64_t) +ALL_ARITH_CMP_OPS(DECLARE_OP_ARITH_COMPARE, float) +ALL_ARITH_CMP_OPS(DECLARE_OP_ARITH_COMPARE, double) + +#undef DECLARE_OP_ARITH_COMPARE + +// +namespace dynamic { + +#define DISPATCH_OP_ARITH_COMPARE(TTYPE, AOP, CMPOP) \ + template <> \ + bool OpArithCompareImpl:: \ + op_arith_compare(uint8_t* const __restrict output, \ + const TTYPE* const __restrict src, \ + const ArithHighPrecisionType& right_operand, \ + const ArithHighPrecisionType& value, \ + const size_t size) { \ + return op_arith_compare_##TTYPE##_##AOP##_##CMPOP( \ + output, src, right_operand, value, size); \ + } + +ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, int8_t) +ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, int16_t) +ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, int32_t) +ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, int64_t) +ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, float) +ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, double) + +} // namespace dynamic + +} // namespace detail +} // namespace bitset +} // namespace milvus + +// +static void +init_dynamic_hook() { + using namespace milvus::bitset; + using namespace milvus::bitset::detail; + +#if defined(__x86_64__) + // AVX512 ? 
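+    // Every branch below follows the same shape: the op_* symbols declared
+    // above are plain function pointers, statically bound to the
+    // VectorizedRef fallback, and the first branch whose CPU check passes
+    // rebinds all of them and returns. A minimal sketch of the idea (names
+    // here are hypothetical, not part of this patch):
+    //
+    //   static bool (*op)(const int32_t*, size_t) = ref_impl;  // default
+    //   void init() {
+    //       if (cpu_support_avx512()) { op = avx512_impl; return; }
+    //       if (cpu_support_avx2())   { op = avx2_impl;   return; }
+    //   }
+    //
+    // init_dynamic_hook() itself runs from the static initializer at the
+    // bottom of this file, so the table is rebound before main(), subject
+    // to the usual cross-TU static-initialization-order caveats.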
+ if (cpu_support_avx512()) { +#define SET_OP_COMPARE_COLUMN_AVX512(TTYPE, UTYPE, OP) \ + op_compare_column_##TTYPE##_##UTYPE##_##OP = VectorizedAvx512:: \ + template op_compare_column; +#define SET_OP_COMPARE_VAL_AVX512(TTYPE, OP) \ + op_compare_val_##TTYPE##_##OP = \ + VectorizedAvx512::template op_compare_val; +#define SET_OP_WITHIN_RANGE_COLUMN_AVX512(TTYPE, OP) \ + op_within_range_column_##TTYPE##_##OP = \ + VectorizedAvx512::template op_within_range_column; +#define SET_OP_WITHIN_RANGE_VAL_AVX512(TTYPE, OP) \ + op_within_range_val_##TTYPE##_##OP = \ + VectorizedAvx512::template op_within_range_val; +#define SET_ARITH_COMPARE_AVX512(TTYPE, AOP, CMPOP) \ + op_arith_compare_##TTYPE##_##AOP##_##CMPOP = \ + VectorizedAvx512::template op_arith_compare; + + // assign AVX512-related pointers + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX512, int8_t, int8_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX512, int16_t, int16_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX512, int32_t, int32_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX512, int64_t, int64_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX512, float, float) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX512, double, double) + + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX512, int8_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX512, int16_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX512, int32_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX512, int64_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX512, float) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX512, double) + + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX512, int8_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX512, int16_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX512, int32_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX512, int64_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX512, float) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX512, double) + + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX512, int8_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX512, int16_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX512, int32_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX512, int64_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX512, float) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX512, double) + + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, int8_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, int16_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, int32_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, int64_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, float) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, double) + +#undef SET_OP_COMPARE_COLUMN_AVX512 +#undef SET_OP_COMPARE_VAL_AVX512 +#undef SET_OP_WITHIN_RANGE_COLUMN_AVX512 +#undef SET_OP_WITHIN_RANGE_VAL_AVX512 +#undef SET_ARITH_COMPARE_AVX512 + + return; + } + + // AVX2 ? 
+ if (cpu_support_avx2()) { +#define SET_OP_COMPARE_COLUMN_AVX2(TTYPE, UTYPE, OP) \ + op_compare_column_##TTYPE##_##UTYPE##_##OP = VectorizedAvx2:: \ + template op_compare_column; +#define SET_OP_COMPARE_VAL_AVX2(TTYPE, OP) \ + op_compare_val_##TTYPE##_##OP = \ + VectorizedAvx2::template op_compare_val; +#define SET_OP_WITHIN_RANGE_COLUMN_AVX2(TTYPE, OP) \ + op_within_range_column_##TTYPE##_##OP = \ + VectorizedAvx2::template op_within_range_column; +#define SET_OP_WITHIN_RANGE_VAL_AVX2(TTYPE, OP) \ + op_within_range_val_##TTYPE##_##OP = \ + VectorizedAvx2::template op_within_range_val; +#define SET_ARITH_COMPARE_AVX2(TTYPE, AOP, CMPOP) \ + op_arith_compare_##TTYPE##_##AOP##_##CMPOP = \ + VectorizedAvx2::template op_arith_compare; + + // assign AVX2-related pointers + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX2, int8_t, int8_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX2, int16_t, int16_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX2, int32_t, int32_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX2, int64_t, int64_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX2, float, float) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_AVX2, double, double) + + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX2, int8_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX2, int16_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX2, int32_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX2, int64_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX2, float) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_AVX2, double) + + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX2, int8_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX2, int16_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX2, int32_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX2, int64_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX2, float) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_AVX2, double) + + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX2, int8_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX2, int16_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX2, int32_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX2, int64_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX2, float) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_AVX2, double) + + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, int8_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, int16_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, int32_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, int64_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, float) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, double) + +#undef SET_OP_COMPARE_COLUMN_AVX2 +#undef SET_OP_COMPARE_VAL_AVX2 +#undef SET_OP_WITHIN_RANGE_COLUMN_AVX2 +#undef SET_OP_WITHIN_RANGE_VAL_AVX2 +#undef SET_ARITH_COMPARE_AVX2 + + return; + } +#endif + +#if defined(__aarch64__) + // neon ? 
+ { +#define SET_OP_COMPARE_COLUMN_NEON(TTYPE, UTYPE, OP) \ + op_compare_column_##TTYPE##_##UTYPE##_##OP = VectorizedNeon:: \ + template op_compare_column; +#define SET_OP_COMPARE_VAL_NEON(TTYPE, OP) \ + op_compare_val_##TTYPE##_##OP = \ + VectorizedNeon::template op_compare_val; +#define SET_OP_WITHIN_RANGE_COLUMN_NEON(TTYPE, OP) \ + op_within_range_column_##TTYPE##_##OP = \ + VectorizedNeon::template op_within_range_column; +#define SET_OP_WITHIN_RANGE_VAL_NEON(TTYPE, OP) \ + op_within_range_val_##TTYPE##_##OP = \ + VectorizedNeon::template op_within_range_val; +#define SET_ARITH_COMPARE_NEON(TTYPE, AOP, CMPOP) \ + op_arith_compare_##TTYPE##_##AOP##_##CMPOP = \ + VectorizedNeon::template op_arith_compare; + + // assign NEON-related pointers + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_NEON, int8_t, int8_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_NEON, int16_t, int16_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_NEON, int32_t, int32_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_NEON, int64_t, int64_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_NEON, float, float) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_NEON, double, double) + + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_NEON, int8_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_NEON, int16_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_NEON, int32_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_NEON, int64_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_NEON, float) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_NEON, double) + + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_NEON, int8_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_NEON, int16_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_NEON, int32_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_NEON, int64_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_NEON, float) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_NEON, double) + + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_NEON, int8_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_NEON, int16_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_NEON, int32_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_NEON, int64_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_NEON, float) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_NEON, double) + + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, int8_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, int16_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, int32_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, int64_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, float) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, double) + +#undef SET_OP_COMPARE_COLUMN_NEON +#undef SET_OP_COMPARE_VAL_NEON +#undef SET_OP_WITHIN_RANGE_COLUMN_NEON +#undef SET_OP_WITHIN_RANGE_VAL_NEON +#undef SET_ARITH_COMPARE_NEON + } + +#ifdef __ARM_FEATURE_SVE + + // sve? 
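+    // Ordering matters here: the NEON assignments above run unconditionally
+    // on aarch64, and this SVE block, compiled only when __ARM_FEATURE_SVE
+    // is defined, then overwrites the same function pointers, so SVE wins
+    // in SVE-enabled builds. There is no runtime SVE detection; availability
+    // is decided at compile time by the -mcpu flag in the CMakeLists.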
+ { +#define SET_OP_COMPARE_COLUMN_SVE(TTYPE, UTYPE, OP) \ + op_compare_column_##TTYPE##_##UTYPE##_##OP = VectorizedSve:: \ + template op_compare_column; +#define SET_OP_COMPARE_VAL_SVE(TTYPE, OP) \ + op_compare_val_##TTYPE##_##OP = \ + VectorizedSve::template op_compare_val; +#define SET_OP_WITHIN_RANGE_COLUMN_SVE(TTYPE, OP) \ + op_within_range_column_##TTYPE##_##OP = \ + VectorizedSve::template op_within_range_column; +#define SET_OP_WITHIN_RANGE_VAL_SVE(TTYPE, OP) \ + op_within_range_val_##TTYPE##_##OP = \ + VectorizedSve::template op_within_range_val; +#define SET_ARITH_COMPARE_SVE(TTYPE, AOP, CMPOP) \ + op_arith_compare_##TTYPE##_##AOP##_##CMPOP = \ + VectorizedSve::template op_arith_compare; + + // assign SVE-related pointers + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_SVE, int8_t, int8_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_SVE, int16_t, int16_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_SVE, int32_t, int32_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_SVE, int64_t, int64_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_SVE, float, float) + ALL_COMPARE_OPS(SET_OP_COMPARE_COLUMN_SVE, double, double) + + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_SVE, int8_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_SVE, int16_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_SVE, int32_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_SVE, int64_t) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_SVE, float) + ALL_COMPARE_OPS(SET_OP_COMPARE_VAL_SVE, double) + + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_SVE, int8_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_SVE, int16_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_SVE, int32_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_SVE, int64_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_SVE, float) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_COLUMN_SVE, double) + + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_SVE, int8_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_SVE, int16_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_SVE, int32_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_SVE, int64_t) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_SVE, float) + ALL_RANGE_OPS(SET_OP_WITHIN_RANGE_VAL_SVE, double) + + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, int8_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, int16_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, int32_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, int64_t) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, float) + ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, double) + +#undef SET_OP_COMPARE_COLUMN_SVE +#undef SET_OP_COMPARE_VAL_SVE +#undef SET_OP_WITHIN_RANGE_COLUMN_SVE +#undef SET_OP_WITHIN_RANGE_VAL_SVE +#undef SET_ARITH_COMPARE_SVE + } +#endif + +#endif +} + +// no longer needed +#undef ALL_COMPARE_OPS +#undef ALL_RANGE_OPS +#undef ALL_ARITH_CMP_OPS + +// +static int init_dynamic_ = []() { + init_dynamic_hook(); + + return 0; +}(); diff --git a/internal/core/src/bitset/detail/platform/dynamic.h b/internal/core/src/bitset/detail/platform/dynamic.h new file mode 100644 index 000000000000..3a050a5e83aa --- /dev/null +++ b/internal/core/src/bitset/detail/platform/dynamic.h @@ -0,0 +1,255 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#include "bitset/common.h"
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+
+namespace dynamic {
+
+///////////////////////////////////////////////////////////////////////////
+// a facility to run through all acceptable data types
+#define ALL_DATATYPES_1(FUNC) \
+    FUNC(int8_t);             \
+    FUNC(int16_t);            \
+    FUNC(int32_t);            \
+    FUNC(int64_t);            \
+    FUNC(float);              \
+    FUNC(double);
+
+///////////////////////////////////////////////////////////////////////////
+// the default implementation
+template <typename T, typename U, CompareOpType Op>
+struct OpCompareColumnImpl {
+    static bool
+    op_compare_column(uint8_t* const __restrict bitmask,
+                      const T* const __restrict t,
+                      const U* const __restrict u,
+                      const size_t size) {
+        return false;
+    }
+};
+
+#define DECLARE_PARTIAL_OP_COMPARE_COLUMN(TTYPE)             \
+    template <CompareOpType Op>                              \
+    struct OpCompareColumnImpl<TTYPE, TTYPE, Op> {           \
+        static bool                                          \
+        op_compare_column(uint8_t* const __restrict bitmask, \
+                          const TTYPE* const __restrict t,   \
+                          const TTYPE* const __restrict u,   \
+                          const size_t size);                \
+    };
+
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_COLUMN)
+
+#undef DECLARE_PARTIAL_OP_COMPARE_COLUMN
+
+///////////////////////////////////////////////////////////////////////////
+// the default implementation
+template <typename T, CompareOpType Op>
+struct OpCompareValImpl {
+    static inline bool
+    op_compare_val(uint8_t* const __restrict bitmask,
+                   const T* const __restrict t,
+                   const size_t size,
+                   const T& value) {
+        return false;
+    }
+};
+
+#define DECLARE_PARTIAL_OP_COMPARE_VAL(TTYPE)             \
+    template <CompareOpType Op>                           \
+    struct OpCompareValImpl<TTYPE, Op> {                  \
+        static bool                                       \
+        op_compare_val(uint8_t* const __restrict bitmask, \
+                       const TTYPE* const __restrict t,   \
+                       const size_t size,                 \
+                       const TTYPE& value);               \
+    };
+
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_VAL)
+
+#undef DECLARE_PARTIAL_OP_COMPARE_VAL
+
+///////////////////////////////////////////////////////////////////////////
+// the default implementation
+template <typename T, RangeType Op>
+struct OpWithinRangeColumnImpl {
+    static inline bool
+    op_within_range_column(uint8_t* const __restrict bitmask,
+                           const T* const __restrict lower,
+                           const T* const __restrict upper,
+                           const T* const __restrict values,
+                           const size_t size) {
+        return false;
+    }
+};
+
+#define DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN(TTYPE)                \
+    template <RangeType Op>                                          \
+    struct OpWithinRangeColumnImpl<TTYPE, Op> {                      \
+        static bool                                                  \
+        op_within_range_column(uint8_t* const __restrict bitmask,    \
+                               const TTYPE* const __restrict lower,  \
+                               const TTYPE* const __restrict upper,  \
+                               const TTYPE* const __restrict values, \
+                               const size_t size);                   \
+    };
+
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN)
+
+#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN
+
+///////////////////////////////////////////////////////////////////////////
+// the default implementation
+template <typename T, RangeType Op>
+struct OpWithinRangeValImpl {
+    static inline bool
+    op_within_range_val(uint8_t* const __restrict bitmask,
+                        const T& lower,
+                        const T& upper,
+                        const T* const __restrict values,
+                        const size_t size) {
+        return false;
+    }
+};
+
+#define DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL(TTYPE)             \
+    template <RangeType Op>                                    \
+    struct OpWithinRangeValImpl<TTYPE, Op> {                   \
+        static bool                                            \
+        op_within_range_val(uint8_t* const __restrict bitmask, \
+                            const TTYPE& lower,                   \
+                            const TTYPE& upper,                   \
+                            const TTYPE* const __restrict values, \
+                            const size_t size);                   \
+    };
+
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL)
+
+#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL
+
+///////////////////////////////////////////////////////////////////////////
+// the default implementation
+template <typename T, ArithOpType AOp, CompareOpType CmpOp>
+struct OpArithCompareImpl {
+    static inline bool
+    op_arith_compare(uint8_t* const __restrict bitmask,
+                     const T* const __restrict src,
+                     const ArithHighPrecisionType<T>& right_operand,
+                     const ArithHighPrecisionType<T>& value,
+                     const size_t size) {
+        return false;
+    }
+};
+
+#define DECLARE_PARTIAL_OP_ARITH_COMPARE(TTYPE)                              \
+    template <ArithOpType AOp, CompareOpType CmpOp>                          \
+    struct OpArithCompareImpl<TTYPE, AOp, CmpOp> {                           \
+        static bool                                                          \
+        op_arith_compare(uint8_t* const __restrict bitmask,                  \
+                         const TTYPE* const __restrict src,                  \
+                         const ArithHighPrecisionType<TTYPE>& right_operand, \
+                         const ArithHighPrecisionType<TTYPE>& value,         \
+                         const size_t size);                                 \
+    };
+
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE)
+
+//
+
+///////////////////////////////////////////////////////////////////////////
+
+#undef ALL_DATATYPES_1
+
+} // namespace dynamic
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+struct VectorizedDynamic {
+    // Fills a bitmask by comparing two arrays element-wise.
+    // API requirement: size % 8 == 0
+    template <typename T, typename U, CompareOpType Op>
+    static bool
+    op_compare_column(uint8_t* const __restrict bitmask,
+                      const T* const __restrict t,
+                      const U* const __restrict u,
+                      const size_t size) {
+        return dynamic::OpCompareColumnImpl<T, U, Op>::op_compare_column(
+            bitmask, t, u, size);
+    }
+
+    // Fills a bitmask by comparing elements of a given array to a
+    // given value.
+    // API requirement: size % 8 == 0
+    template <typename T, CompareOpType Op>
+    static bool
+    op_compare_val(uint8_t* const __restrict bitmask,
+                   const T* const __restrict t,
+                   const size_t size,
+                   const T& value) {
+        return dynamic::OpCompareValImpl<T, Op>::op_compare_val(
+            bitmask, t, size, value);
+    }
+
+    // API requirement: size % 8 == 0
+    template <typename T, RangeType Op>
+    static bool
+    op_within_range_column(uint8_t* const __restrict bitmask,
+                           const T* const __restrict lower,
+                           const T* const __restrict upper,
+                           const T* const __restrict values,
+                           const size_t size) {
+        return dynamic::OpWithinRangeColumnImpl<T, Op>::op_within_range_column(
+            bitmask, lower, upper, values, size);
+    }
+
+    // API requirement: size % 8 == 0
+    template <typename T, RangeType Op>
+    static bool
+    op_within_range_val(uint8_t* const __restrict bitmask,
+                        const T& lower,
+                        const T& upper,
+                        const T* const __restrict values,
+                        const size_t size) {
+        return dynamic::OpWithinRangeValImpl<T, Op>::op_within_range_val(
+            bitmask, lower, upper, values, size);
+    }
+
+    // API requirement: size % 8 == 0
+    template <typename T, ArithOpType AOp, CompareOpType CmpOp>
+    static inline bool
+    op_arith_compare(uint8_t* const __restrict bitmask,
+                     const T* const __restrict src,
+                     const ArithHighPrecisionType<T>& right_operand,
+                     const ArithHighPrecisionType<T>& value,
+                     const size_t size) {
+        return dynamic::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare(
+            bitmask, src, right_operand, value, size);
+    }
+};
+
+} // namespace detail
+} // namespace bitset
+} // namespace milvus
\ No newline at end of file
diff --git a/internal/core/src/bitset/detail/platform/vectorized_ref.h b/internal/core/src/bitset/detail/platform/vectorized_ref.h
new file mode 100644
index 000000000000..20da65406f1f
--- /dev/null
+++ b/internal/core/src/bitset/detail/platform/vectorized_ref.h
@@ -0,0 +1,95 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
+
+#include "bitset/common.h"
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+
+// The default reference vectorizer.
+// Each of its functions returns a boolean value indicating whether a
+// vectorized implementation exists and was invoked. If not, the caller
+// code will fall back to a default non-vectorized implementation.
+// The reference vectorizer provides no vectorized implementations at all,
+// forcing the caller to use the default non-vectorized implementation
+// every time.
+struct VectorizedRef {
+    // Fills a bitmask by comparing two arrays element-wise.
+    // API requirement: size % 8 == 0
+    template <typename T, typename U, CompareOpType Op>
+    static inline bool
+    op_compare_column(uint8_t* const __restrict output,
+                      const T* const __restrict t,
+                      const U* const __restrict u,
+                      const size_t size) {
+        return false;
+    }
+
+    // Fills a bitmask by comparing elements of a given array to a
+    // given value.
+    // API requirement: size % 8 == 0
+    template <typename T, CompareOpType Op>
+    static inline bool
+    op_compare_val(uint8_t* const __restrict output,
+                   const T* const __restrict t,
+                   const size_t size,
+                   const T& value) {
+        return false;
+    }
+
+    // API requirement: size % 8 == 0
+    template <typename T, RangeType Op>
+    static inline bool
+    op_within_range_column(uint8_t* const __restrict data,
+                           const T* const __restrict lower,
+                           const T* const __restrict upper,
+                           const T* const __restrict values,
+                           const size_t size) {
+        return false;
+    }
+
+    // API requirement: size % 8 == 0
+    template <typename T, RangeType Op>
+    static inline bool
+    op_within_range_val(uint8_t* const __restrict data,
+                        const T& lower,
+                        const T& upper,
+                        const T* const __restrict values,
+                        const size_t size) {
+        return false;
+    }
+
+    // API requirement: size % 8 == 0
+    template <typename T, ArithOpType AOp, CompareOpType CmpOp>
+    static inline bool
+    op_arith_compare(uint8_t* const __restrict bitmask,
+                     const T* const __restrict src,
+                     const ArithHighPrecisionType<T>& right_operand,
+                     const ArithHighPrecisionType<T>& value,
+                     const size_t size) {
+        return false;
+    }
+};
+
+} // namespace detail
+} // namespace bitset
+} // namespace milvus
diff --git a/internal/core/src/bitset/detail/platform/x86/avx2-decl.h b/internal/core/src/bitset/detail/platform/x86/avx2-decl.h
new file mode 100644
index 000000000000..cdac2b9713f3
--- /dev/null
+++ b/internal/core/src/bitset/detail/platform/x86/avx2-decl.h
@@ -0,0 +1,201 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AVX2 declaration + +#pragma once + +#include +#include + +#include "bitset/common.h" + +namespace milvus { +namespace bitset { +namespace detail { +namespace x86 { +namespace avx2 { + +/////////////////////////////////////////////////////////////////////////// +// a facility to run through all acceptable data types +#define ALL_DATATYPES_1(FUNC) \ + FUNC(int8_t); \ + FUNC(int16_t); \ + FUNC(int32_t); \ + FUNC(int64_t); \ + FUNC(float); \ + FUNC(double); + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpCompareColumnImpl { + static bool + op_compare_column(uint8_t* const __restrict bitmask, + const T* const __restrict t, + const U* const __restrict u, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_COMPARE_COLUMN(TTYPE) \ + template \ + struct OpCompareColumnImpl { \ + static bool \ + op_compare_column(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict t, \ + const TTYPE* const __restrict u, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_COLUMN) + +#undef DECLARE_PARTIAL_OP_COMPARE_COLUMN + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpCompareValImpl { + static inline bool + op_compare_val(uint8_t* const __restrict bitmask, + const T* const __restrict t, + const size_t size, + const T& value) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_COMPARE_VAL(TTYPE) \ + template \ + struct OpCompareValImpl { \ + static bool \ + op_compare_val(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict t, \ + const size_t size, \ + const TTYPE& value); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_VAL) + +#undef DECLARE_PARTIAL_OP_COMPARE_VAL + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpWithinRangeColumnImpl { + static inline bool + op_within_range_column(uint8_t* const __restrict bitmask, + const T* const __restrict lower, + const T* const __restrict upper, + const T* const __restrict values, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN(TTYPE) \ + template \ + struct OpWithinRangeColumnImpl { \ + static bool \ + op_within_range_column(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict lower, \ + const TTYPE* const __restrict upper, \ + const TTYPE* const __restrict values, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN) + +#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpWithinRangeValImpl { + static inline bool + op_within_range_val(uint8_t* const __restrict bitmask, + const T& lower, + const T& upper, + const T* const 
__restrict values, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL(TTYPE) \ + template \ + struct OpWithinRangeValImpl { \ + static bool \ + op_within_range_val(uint8_t* const __restrict bitmask, \ + const TTYPE& lower, \ + const TTYPE& upper, \ + const TTYPE* const __restrict values, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL) + +#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpArithCompareImpl { + static inline bool + op_arith_compare(uint8_t* const __restrict bitmask, + const T* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_ARITH_COMPARE(TTYPE) \ + template \ + struct OpArithCompareImpl { \ + static bool \ + op_arith_compare(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict src, \ + const ArithHighPrecisionType& right_operand, \ + const ArithHighPrecisionType& value, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) + +#undef DECLARE_PARTIAL_OP_ARITH_COMPARE + +/////////////////////////////////////////////////////////////////////////// + +#undef ALL_DATATYPES_1 + +} // namespace avx2 +} // namespace x86 +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/x86/avx2-impl.h b/internal/core/src/bitset/detail/platform/x86/avx2-impl.h new file mode 100644 index 000000000000..3b74749d2a63 --- /dev/null +++ b/internal/core/src/bitset/detail/platform/x86/avx2-impl.h @@ -0,0 +1,1658 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
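The declaration header that just closed leans on a classic X-macro: `ALL_DATATYPES_1` applies a declaration-stamping macro to every supported element type, so the per-type partial specializations can never drift out of sync with the type list. A compilable toy version of the same pattern is below; `ALL_TYPES` and `print_size_*` are illustrative names, not identifiers from this patch.

```cpp
#include <cstdint>
#include <cstdio>

// the single source of truth: one list of types
#define ALL_TYPES(FUNC) \
    FUNC(int8_t);       \
    FUNC(int32_t);      \
    FUNC(float);

// stamp one declaration per listed type
#define DECLARE_PRINT_SIZE(TTYPE) \
    void print_size_##TTYPE()

ALL_TYPES(DECLARE_PRINT_SIZE)

// stamp one definition per listed type
#define DEFINE_PRINT_SIZE(TTYPE)                      \
    void print_size_##TTYPE() {                       \
        std::printf(#TTYPE ": %zu\n", sizeof(TTYPE)); \
    }

ALL_TYPES(DEFINE_PRINT_SIZE)

int
main() {
    print_size_int8_t();
    print_size_int32_t();
    print_size_float();
    return 0;
}
```

Adding a type to the list automatically adds every declaration and definition; forgetting one becomes a compile-time error rather than a silent gap.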
+ +// AVX2 implementation + +#pragma once + +#include + +#include +#include +#include +#include + +#include "avx2-decl.h" + +#include "bitset/common.h" +#include "common.h" + +namespace milvus { +namespace bitset { +namespace detail { +namespace x86 { +namespace avx2 { + +namespace { + +// count is expected to be in range [0, 32) +inline uint32_t +get_mask(const size_t count) { + return (uint32_t(1) << count) - uint32_t(1); +} + +// +template +struct CmpHelperI8 {}; + +template <> +struct CmpHelperI8 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpeq_epi8(a, b); + } +}; + +template <> +struct CmpHelperI8 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpgt_epi8(b, a), _mm256_set1_epi32(-1)); + } +}; + +template <> +struct CmpHelperI8 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpgt_epi8(a, b); + } +}; + +template <> +struct CmpHelperI8 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpgt_epi8(a, b), _mm256_set1_epi32(-1)); + } +}; + +template <> +struct CmpHelperI8 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpgt_epi8(b, a); + } +}; + +template <> +struct CmpHelperI8 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpeq_epi8(a, b), _mm256_set1_epi32(-1)); + } +}; + +// +template +struct CmpHelperI16 {}; + +template <> +struct CmpHelperI16 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpeq_epi16(a, b); + } + + static inline __m128i + compare(const __m128i a, const __m128i b) { + return _mm_cmpeq_epi16(a, b); + } +}; + +template <> +struct CmpHelperI16 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpgt_epi16(b, a), + _mm256_set1_epi32(-1)); + } + + static inline __m128i + compare(const __m128i a, const __m128i b) { + return _mm_xor_si128(_mm_cmpgt_epi16(b, a), _mm_set1_epi32(-1)); + } +}; + +template <> +struct CmpHelperI16 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpgt_epi16(a, b); + } + + static inline __m128i + compare(const __m128i a, const __m128i b) { + return _mm_cmpgt_epi16(a, b); + } +}; + +template <> +struct CmpHelperI16 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpgt_epi16(a, b), + _mm256_set1_epi32(-1)); + } + + static inline __m128i + compare(const __m128i a, const __m128i b) { + return _mm_xor_si128(_mm_cmpgt_epi16(a, b), _mm_set1_epi32(-1)); + } +}; + +template <> +struct CmpHelperI16 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpgt_epi16(b, a); + } + + static inline __m128i + compare(const __m128i a, const __m128i b) { + return _mm_cmpgt_epi16(b, a); + } +}; + +template <> +struct CmpHelperI16 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpeq_epi16(a, b), + _mm256_set1_epi32(-1)); + } + + static inline __m128i + compare(const __m128i a, const __m128i b) { + return _mm_xor_si128(_mm_cmpeq_epi16(a, b), _mm_set1_epi32(-1)); + } +}; + +// +template +struct CmpHelperI32 {}; + +template <> +struct CmpHelperI32 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpeq_epi32(a, b); + } +}; + +template <> +struct CmpHelperI32 { + 
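        // [editorial note: illustrative sketch, not part of the patch]
        // AVX2 exposes only EQ (_mm256_cmpeq_*) and signed GT (_mm256_cmpgt_*)
        // for integers, so the CmpHelper specializations derive the remaining
        // predicates by swapping operands and negating with an all-ones XOR:
        //   a >= b  ==  NOT(b > a)  ==  _mm256_cmpgt_epi32(b, a) ^ all-ones
        //   a <= b  ==  NOT(a > b)  ==  _mm256_cmpgt_epi32(a, b) ^ all-ones
        //   a != b  ==  NOT(a == b) ==  _mm256_cmpeq_epi32(a, b) ^ all-ones
        // where the all-ones constant comes from _mm256_set1_epi32(-1),
        // exactly as in the bodies below.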
static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpgt_epi32(b, a), + _mm256_set1_epi32(-1)); + } +}; + +template <> +struct CmpHelperI32 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpgt_epi32(a, b); + } +}; + +template <> +struct CmpHelperI32 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpgt_epi32(a, b), + _mm256_set1_epi32(-1)); + } +}; + +template <> +struct CmpHelperI32 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpgt_epi32(b, a); + } +}; + +template <> +struct CmpHelperI32 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpeq_epi32(a, b), + _mm256_set1_epi32(-1)); + } +}; + +// +template +struct CmpHelperI64 {}; + +template <> +struct CmpHelperI64 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpeq_epi64(a, b); + } +}; + +template <> +struct CmpHelperI64 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpgt_epi64(b, a), + _mm256_set1_epi32(-1)); + } +}; + +template <> +struct CmpHelperI64 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpgt_epi64(a, b); + } +}; + +template <> +struct CmpHelperI64 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpgt_epi64(a, b), + _mm256_set1_epi32(-1)); + } +}; + +template <> +struct CmpHelperI64 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_cmpgt_epi64(b, a); + } +}; + +template <> +struct CmpHelperI64 { + static inline __m256i + compare(const __m256i a, const __m256i b) { + return _mm256_xor_si256(_mm256_cmpeq_epi64(a, b), + _mm256_set1_epi32(-1)); + } +}; + +} // namespace + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const int8_t* const __restrict src, + const size_t size, + const int8_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + const __m256i target = _mm256_set1_epi8(val); + + // todo: aligned reads & writes + + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const __m256i v0 = _mm256_loadu_si256((const __m256i*)(src + i)); + const __m256i cmp = CmpHelperI8::compare(v0, target); + const uint32_t mmask = _mm256_movemask_epi8(cmp); + + res_u32[i / 32] = mmask; + } + + if (size32 != size) { + // 8, 16 or 24 elements to process + const __m256i mask = + _mm256_setr_epi64x((size - size32 >= 8) ? (-1) : 0, + (size - size32 >= 16) ? (-1) : 0, + (size - size32 >= 24) ? 
(-1) : 0, + 0); + + const __m256i v0 = + _mm256_maskload_epi64((const long long*)(src + size32), mask); + const __m256i cmp = CmpHelperI8::compare(v0, target); + const uint32_t mmask = _mm256_movemask_epi8(cmp); + + if (size - size32 >= 8) { + res_u8[size32 / 8 + 0] = (mmask & 0xFF); + } + if (size - size32 >= 16) { + res_u8[size32 / 8 + 1] = ((mmask >> 8) & 0xFF); + } + if (size - size32 >= 24) { + res_u8[size32 / 8 + 2] = ((mmask >> 16) & 0xFF); + } + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict src, + const size_t size, + const int16_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + const __m256i target = _mm256_set1_epi16(val); + + // todo: aligned reads & writes + + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m256i v0 = _mm256_loadu_si256((const __m256i*)(src + i)); + const __m256i cmp = CmpHelperI16::compare(v0, target); + const __m256i pcmp = _mm256_packs_epi16(cmp, cmp); + const __m256i qcmp = + _mm256_permute4x64_epi64(pcmp, _MM_SHUFFLE(3, 1, 2, 0)); + const uint16_t mmask = _mm256_movemask_epi8(qcmp); + + res_u16[i / 16] = mmask; + } + + if (size16 != size) { + // 8 elements to process + const __m128i v0 = _mm_loadu_si128((const __m128i*)(src + size16)); + const __m128i target0 = _mm_set1_epi16(val); + const __m128i cmp = CmpHelperI16::compare(v0, target0); + const __m128i pcmp = _mm_packs_epi16(cmp, cmp); + const uint32_t mmask = _mm_movemask_epi8(pcmp) & 0xFF; + + res_u8[size16 / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict src, + const size_t size, + const int32_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256i target = _mm256_set1_epi32(val); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i v0 = _mm256_loadu_si256((const __m256i*)(src + i)); + const __m256i cmp = CmpHelperI32::compare(v0, target); + const uint8_t mmask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp)); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict src, + const size_t size, + const int64_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256i target = _mm256_set1_epi64x(val); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i v0 = _mm256_loadu_si256((const __m256i*)(src + i)); + const __m256i v1 = _mm256_loadu_si256((const __m256i*)(src + i + 4)); + const __m256i cmp0 = CmpHelperI64::compare(v0, target); + const __m256i cmp1 = CmpHelperI64::compare(v1, target); + const uint8_t mmask0 = _mm256_movemask_pd(_mm256_castsi256_pd(cmp0)); + const uint8_t mmask1 = _mm256_movemask_pd(_mm256_castsi256_pd(cmp1)); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const float* const __restrict src, + const size_t size, + const float& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred = ComparePredicate::value; + + const __m256 target = _mm256_set1_ps(val); + + 
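    // [editorial note: illustrative sketch, not part of the patch]
    // Each iteration below handles exactly 8 floats: _mm256_cmp_ps yields a
    // full-width per-lane mask and _mm256_movemask_ps compresses it to one
    // bit per lane, so the 8-bit result maps onto precisely one byte of the
    // output bitmask:
    //   const __m256 cmp = _mm256_cmp_ps(v0, target, pred);
    //   res_u8[i / 8] = (uint8_t)_mm256_movemask_ps(cmp);  // bit j == lane j
    // This is also why the API requires size % 8 == 0: the kernel never
    // writes a partial byte.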
// todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256 v0 = _mm256_loadu_ps(src + i); + const __m256 cmp = _mm256_cmp_ps(v0, target, pred); + const uint8_t mmask = _mm256_movemask_ps(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const double* const __restrict src, + const size_t size, + const double& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred = ComparePredicate::value; + + const __m256d target = _mm256_set1_pd(val); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256d v0 = _mm256_loadu_pd(src + i); + const __m256d v1 = _mm256_loadu_pd(src + i + 4); + const __m256d cmp0 = _mm256_cmp_pd(v0, target, pred); + const __m256d cmp1 = _mm256_cmp_pd(v1, target, pred); + const uint8_t mmask0 = _mm256_movemask_pd(cmp0); + const uint8_t mmask1 = _mm256_movemask_pd(cmp1); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict left, + const int8_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const __m256i v0l = _mm256_loadu_si256((const __m256i*)(left + i)); + const __m256i v0r = _mm256_loadu_si256((const __m256i*)(right + i)); + const __m256i cmp = CmpHelperI8::compare(v0l, v0r); + const uint32_t mmask = _mm256_movemask_epi8(cmp); + + res_u32[i / 32] = mmask; + } + + if (size32 != size) { + // 8, 16 or 24 elements to process + const __m256i mask = + _mm256_setr_epi64x((size - size32 >= 8) ? (-1) : 0, + (size - size32 >= 16) ? (-1) : 0, + (size - size32 >= 24) ? 
(-1) : 0, + 0); + + const __m256i v0l = + _mm256_maskload_epi64((const long long*)(left + size32), mask); + const __m256i v0r = + _mm256_maskload_epi64((const long long*)(right + size32), mask); + const __m256i cmp = CmpHelperI8::compare(v0l, v0r); + const uint32_t mmask = _mm256_movemask_epi8(cmp); + + if (size - size32 >= 8) { + res_u8[size32 / 8 + 0] = (mmask & 0xFF); + } + if (size - size32 >= 16) { + res_u8[size32 / 8 + 1] = ((mmask >> 8) & 0xFF); + } + if (size - size32 >= 24) { + res_u8[size32 / 8 + 2] = ((mmask >> 16) & 0xFF); + } + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict left, + const int16_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m256i v0l = _mm256_loadu_si256((const __m256i*)(left + i)); + const __m256i v0r = _mm256_loadu_si256((const __m256i*)(right + i)); + const __m256i cmp = CmpHelperI16::compare(v0l, v0r); + const __m256i pcmp = _mm256_packs_epi16(cmp, cmp); + const __m256i qcmp = + _mm256_permute4x64_epi64(pcmp, _MM_SHUFFLE(3, 1, 2, 0)); + const uint16_t mmask = _mm256_movemask_epi8(qcmp); + + res_u16[i / 16] = mmask; + } + + if (size16 != size) { + // 8 elements to process + const __m128i v0l = _mm_loadu_si128((const __m128i*)(left + size16)); + const __m128i v0r = _mm_loadu_si128((const __m128i*)(right + size16)); + const __m128i cmp = CmpHelperI16::compare(v0l, v0r); + const __m128i pcmp = _mm_packs_epi16(cmp, cmp); + const uint32_t mmask = _mm_movemask_epi8(pcmp) & 0xFF; + + res_u8[size16 / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict left, + const int32_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i v0l = _mm256_loadu_si256((const __m256i*)(left + i)); + const __m256i v0r = _mm256_loadu_si256((const __m256i*)(right + i)); + const __m256i cmp = CmpHelperI32::compare(v0l, v0r); + const uint8_t mmask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp)); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict left, + const int64_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i v0l = _mm256_loadu_si256((const __m256i*)(left + i)); + const __m256i v1l = _mm256_loadu_si256((const __m256i*)(left + i + 4)); + const __m256i v0r = _mm256_loadu_si256((const __m256i*)(right + i)); + const __m256i v1r = _mm256_loadu_si256((const __m256i*)(right + i + 4)); + const __m256i cmp0 = CmpHelperI64::compare(v0l, v0r); + const __m256i cmp1 = CmpHelperI64::compare(v1l, v1r); + const uint8_t mmask0 = _mm256_movemask_pd(_mm256_castsi256_pd(cmp0)); + const uint8_t mmask1 = _mm256_movemask_pd(_mm256_castsi256_pd(cmp1)); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; +} + +template 
+bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const float* const __restrict left, + const float* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256 v0l = _mm256_loadu_ps(left + i); + const __m256 v0r = _mm256_loadu_ps(right + i); + const __m256 cmp = _mm256_cmp_ps(v0l, v0r, pred); + const uint8_t mmask = _mm256_movemask_ps(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const double* const __restrict left, + const double* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256d v0l = _mm256_loadu_pd(left + i); + const __m256d v1l = _mm256_loadu_pd(left + i + 4); + const __m256d v0r = _mm256_loadu_pd(right + i); + const __m256d v1r = _mm256_loadu_pd(right + i + 4); + const __m256d cmp0 = _mm256_cmp_pd(v0l, v0r, pred); + const __m256d cmp1 = _mm256_cmp_pd(v1l, v1r, pred); + const uint8_t mmask0 = _mm256_movemask_pd(cmp0); + const uint8_t mmask1 = _mm256_movemask_pd(cmp1); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict lower, + const int8_t* const __restrict upper, + const int8_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const __m256i v0l = _mm256_loadu_si256((const __m256i*)(lower + i)); + const __m256i v0u = _mm256_loadu_si256((const __m256i*)(upper + i)); + const __m256i v0v = _mm256_loadu_si256((const __m256i*)(values + i)); + const __m256i cmpl = + CmpHelperI8::lower>::compare(v0l, v0v); + const __m256i cmpu = + CmpHelperI8::upper>::compare(v0v, v0u); + const __m256i cmp = _mm256_and_si256(cmpl, cmpu); + const uint32_t mmask = _mm256_movemask_epi8(cmp); + + res_u32[i / 32] = mmask; + } + + if (size32 != size) { + // 8, 16 or 24 elements to process + const __m256i mask = + _mm256_setr_epi64x((size - size32 >= 8) ? (-1) : 0, + (size - size32 >= 16) ? (-1) : 0, + (size - size32 >= 24) ? 
(-1) : 0, + 0); + + const __m256i v0l = + _mm256_maskload_epi64((const long long*)(lower + size32), mask); + const __m256i v0u = + _mm256_maskload_epi64((const long long*)(upper + size32), mask); + const __m256i v0v = + _mm256_maskload_epi64((const long long*)(values + size32), mask); + const __m256i cmpl = + CmpHelperI8::lower>::compare(v0l, v0v); + const __m256i cmpu = + CmpHelperI8::upper>::compare(v0v, v0u); + const __m256i cmp = _mm256_and_si256(cmpl, cmpu); + const uint32_t mmask = _mm256_movemask_epi8(cmp); + + if (size - size32 >= 8) { + res_u8[size32 / 8 + 0] = (mmask & 0xFF); + } + if (size - size32 >= 16) { + res_u8[size32 / 8 + 1] = ((mmask >> 8) & 0xFF); + } + if (size - size32 >= 24) { + res_u8[size32 / 8 + 2] = ((mmask >> 16) & 0xFF); + } + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict lower, + const int16_t* const __restrict upper, + const int16_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m256i v0l = _mm256_loadu_si256((const __m256i*)(lower + i)); + const __m256i v0u = _mm256_loadu_si256((const __m256i*)(upper + i)); + const __m256i v0v = _mm256_loadu_si256((const __m256i*)(values + i)); + const __m256i cmpl = + CmpHelperI16::lower>::compare(v0l, v0v); + const __m256i cmpu = + CmpHelperI16::upper>::compare(v0v, v0u); + const __m256i cmp = _mm256_and_si256(cmpl, cmpu); + const __m256i pcmp = _mm256_packs_epi16(cmp, cmp); + const __m256i qcmp = + _mm256_permute4x64_epi64(pcmp, _MM_SHUFFLE(3, 1, 2, 0)); + const uint16_t mmask = _mm256_movemask_epi8(qcmp); + + res_u16[i / 16] = mmask; + } + + if (size16 != size) { + // 8 elements to process + const __m128i v0l = _mm_loadu_si128((const __m128i*)(lower + size16)); + const __m128i v0u = _mm_loadu_si128((const __m128i*)(upper + size16)); + const __m128i v0v = _mm_loadu_si128((const __m128i*)(values + size16)); + const __m128i cmpl = + CmpHelperI16::lower>::compare(v0l, v0v); + const __m128i cmpu = + CmpHelperI16::upper>::compare(v0v, v0u); + const __m128i cmp = _mm_and_si128(cmpl, cmpu); + const __m128i pcmp = _mm_packs_epi16(cmp, cmp); + const uint32_t mmask = _mm_movemask_epi8(pcmp) & 0xFF; + + res_u8[size16 / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict lower, + const int32_t* const __restrict upper, + const int32_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i v0l = _mm256_loadu_si256((const __m256i*)(lower + i)); + const __m256i v0u = _mm256_loadu_si256((const __m256i*)(upper + i)); + const __m256i v0v = _mm256_loadu_si256((const __m256i*)(values + i)); + const __m256i cmpl = + CmpHelperI32::lower>::compare(v0l, v0v); + const __m256i cmpu = + CmpHelperI32::upper>::compare(v0v, v0u); + const __m256i cmp = _mm256_and_si256(cmpl, cmpu); + const uint8_t mmask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp)); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + 
uint8_t* const __restrict res_u8, + const int64_t* const __restrict lower, + const int64_t* const __restrict upper, + const int64_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i v0l = _mm256_loadu_si256((const __m256i*)(lower + i)); + const __m256i v1l = _mm256_loadu_si256((const __m256i*)(lower + i + 4)); + const __m256i v0u = _mm256_loadu_si256((const __m256i*)(upper + i)); + const __m256i v1u = _mm256_loadu_si256((const __m256i*)(upper + i + 4)); + const __m256i v0v = _mm256_loadu_si256((const __m256i*)(values + i)); + const __m256i v1v = + _mm256_loadu_si256((const __m256i*)(values + i + 4)); + const __m256i cmp0l = + CmpHelperI64::lower>::compare(v0l, v0v); + const __m256i cmp0u = + CmpHelperI64::upper>::compare(v0v, v0u); + const __m256i cmp1l = + CmpHelperI64::lower>::compare(v1l, v1v); + const __m256i cmp1u = + CmpHelperI64::upper>::compare(v1v, v1u); + const __m256i cmp0 = _mm256_and_si256(cmp0l, cmp0u); + const __m256i cmp1 = _mm256_and_si256(cmp1l, cmp1u); + const uint8_t mmask0 = _mm256_movemask_pd(_mm256_castsi256_pd(cmp0)); + const uint8_t mmask1 = _mm256_movemask_pd(_mm256_castsi256_pd(cmp1)); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const float* const __restrict lower, + const float* const __restrict upper, + const float* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256 v0l = _mm256_loadu_ps(lower + i); + const __m256 v0u = _mm256_loadu_ps(upper + i); + const __m256 v0v = _mm256_loadu_ps(values + i); + const __m256 cmpl = _mm256_cmp_ps(v0l, v0v, pred_lower); + const __m256 cmpu = _mm256_cmp_ps(v0v, v0u, pred_upper); + const __m256 cmp = _mm256_and_ps(cmpl, cmpu); + const uint8_t mmask = _mm256_movemask_ps(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const double* const __restrict lower, + const double* const __restrict upper, + const double* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256d v0l = _mm256_loadu_pd(lower + i); + const __m256d v1l = _mm256_loadu_pd(lower + i + 4); + const __m256d v0u = _mm256_loadu_pd(upper + i); + const __m256d v1u = _mm256_loadu_pd(upper + i + 4); + const __m256d v0v = _mm256_loadu_pd(values + i); + const __m256d v1v = _mm256_loadu_pd(values + i + 4); + const __m256d cmp0l = _mm256_cmp_pd(v0l, v0v, pred_lower); + const __m256d cmp0u = _mm256_cmp_pd(v0v, v0u, pred_upper); + const __m256d cmp1l = _mm256_cmp_pd(v1l, v1v, pred_lower); + const __m256d cmp1u = _mm256_cmp_pd(v1v, v1u, pred_upper); + const __m256d cmp0 = _mm256_and_pd(cmp0l, cmp0u); + const __m256d cmp1 
= _mm256_and_pd(cmp1l, cmp1u); + const uint8_t mmask0 = _mm256_movemask_pd(cmp0); + const uint8_t mmask1 = _mm256_movemask_pd(cmp1); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int8_t& lower, + const int8_t& upper, + const int8_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + const __m256i lower_v = _mm256_set1_epi8(lower); + const __m256i upper_v = _mm256_set1_epi8(upper); + + // todo: aligned reads & writes + + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const __m256i v0v = _mm256_loadu_si256((const __m256i*)(values + i)); + const __m256i cmpl = + CmpHelperI8::lower>::compare(lower_v, v0v); + const __m256i cmpu = + CmpHelperI8::upper>::compare(v0v, upper_v); + const __m256i cmp = _mm256_and_si256(cmpl, cmpu); + const uint32_t mmask = _mm256_movemask_epi8(cmp); + + res_u32[i / 32] = mmask; + } + + if (size32 != size) { + // 8, 16 or 24 elements to process + const __m256i mask = + _mm256_setr_epi64x((size - size32 >= 8) ? (-1) : 0, + (size - size32 >= 16) ? (-1) : 0, + (size - size32 >= 24) ? (-1) : 0, + 0); + + const __m256i v0v = + _mm256_maskload_epi64((const long long*)(values + size32), mask); + const __m256i cmpl = + CmpHelperI8::lower>::compare(lower_v, v0v); + const __m256i cmpu = + CmpHelperI8::upper>::compare(v0v, upper_v); + const __m256i cmp = _mm256_and_si256(cmpl, cmpu); + const uint32_t mmask = _mm256_movemask_epi8(cmp); + + if (size - size32 >= 8) { + res_u8[size32 / 8 + 0] = (mmask & 0xFF); + } + if (size - size32 >= 16) { + res_u8[size32 / 8 + 1] = ((mmask >> 8) & 0xFF); + } + if (size - size32 >= 24) { + res_u8[size32 / 8 + 2] = ((mmask >> 16) & 0xFF); + } + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int16_t& lower, + const int16_t& upper, + const int16_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + const __m256i lower_v = _mm256_set1_epi16(lower); + const __m256i upper_v = _mm256_set1_epi16(upper); + + // todo: aligned reads & writes + + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m256i v0v = _mm256_loadu_si256((const __m256i*)(values + i)); + const __m256i cmpl = + CmpHelperI16::lower>::compare(lower_v, v0v); + const __m256i cmpu = + CmpHelperI16::upper>::compare(v0v, upper_v); + const __m256i cmp = _mm256_and_si256(cmpl, cmpu); + const __m256i pcmp = _mm256_packs_epi16(cmp, cmp); + const __m256i qcmp = + _mm256_permute4x64_epi64(pcmp, _MM_SHUFFLE(3, 1, 2, 0)); + const uint16_t mmask = _mm256_movemask_epi8(qcmp); + + res_u16[i / 16] = mmask; + } + + if (size16 != size) { + // 8 elements to process + const __m128i lower_v1 = _mm_set1_epi16(lower); + const __m128i upper_v1 = _mm_set1_epi16(upper); + const __m128i v0v = _mm_loadu_si128((const __m128i*)(values + size16)); + const __m128i cmpl = + CmpHelperI16::lower>::compare(lower_v1, v0v); + const __m128i cmpu = + CmpHelperI16::upper>::compare(v0v, upper_v1); + const __m128i cmp = _mm_and_si128(cmpl, cmpu); + const __m128i pcmp = _mm_packs_epi16(cmp, cmp); + const uint32_t 
mmask = _mm_movemask_epi8(pcmp) & 0xFF; + + res_u8[size16 / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int32_t& lower, + const int32_t& upper, + const int32_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256i lower_v = _mm256_set1_epi32(lower); + const __m256i upper_v = _mm256_set1_epi32(upper); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i v0v = _mm256_loadu_si256((const __m256i*)(values + i)); + const __m256i cmpl = + CmpHelperI32::lower>::compare(lower_v, v0v); + const __m256i cmpu = + CmpHelperI32::upper>::compare(v0v, upper_v); + const __m256i cmp = _mm256_and_si256(cmpl, cmpu); + const uint8_t mmask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp)); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int64_t& lower, + const int64_t& upper, + const int64_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256i lower_v = _mm256_set1_epi64x(lower); + const __m256i upper_v = _mm256_set1_epi64x(upper); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i v0v = _mm256_loadu_si256((const __m256i*)(values + i)); + const __m256i v1v = + _mm256_loadu_si256((const __m256i*)(values + i + 4)); + const __m256i cmp0l = + CmpHelperI64::lower>::compare(lower_v, v0v); + const __m256i cmp0u = + CmpHelperI64::upper>::compare(v0v, upper_v); + const __m256i cmp1l = + CmpHelperI64::lower>::compare(lower_v, v1v); + const __m256i cmp1u = + CmpHelperI64::upper>::compare(v1v, upper_v); + const __m256i cmp0 = _mm256_and_si256(cmp0l, cmp0u); + const __m256i cmp1 = _mm256_and_si256(cmp1l, cmp1u); + const uint8_t mmask0 = _mm256_movemask_pd(_mm256_castsi256_pd(cmp0)); + const uint8_t mmask1 = _mm256_movemask_pd(_mm256_castsi256_pd(cmp1)); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const float& lower, + const float& upper, + const float* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256 lower_v = _mm256_set1_ps(lower); + const __m256 upper_v = _mm256_set1_ps(upper); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256 v0v = _mm256_loadu_ps(values + i); + const __m256 cmpl = _mm256_cmp_ps(lower_v, v0v, pred_lower); + const __m256 cmpu = _mm256_cmp_ps(v0v, upper_v, pred_upper); + const __m256 cmp = _mm256_and_ps(cmpl, cmpu); + const uint8_t mmask = _mm256_movemask_ps(cmp); + + res_u8[i / 8] = mmask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const double& lower, + const double& upper, + const double* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256d lower_v = _mm256_set1_pd(lower); + const __m256d upper_v = _mm256_set1_pd(upper); + 
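    // [editorial note: illustrative sketch, not part of the patch]
    // Doubles are processed 4 lanes per __m256d register, so each output
    // byte below is assembled from two 4-bit movemask results: mmask0
    // occupies the low nibble and mmask1 * 16 (i.e. mmask1 << 4) the high
    // nibble:
    //   res_u8[i / 8] = mmask0 + mmask1 * 16;  // == mmask0 | (mmask1 << 4)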
constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256d v0v = _mm256_loadu_pd(values + i); + const __m256d v1v = _mm256_loadu_pd(values + i + 4); + const __m256d cmp0l = _mm256_cmp_pd(lower_v, v0v, pred_lower); + const __m256d cmp0u = _mm256_cmp_pd(v0v, upper_v, pred_upper); + const __m256d cmp1l = _mm256_cmp_pd(lower_v, v1v, pred_lower); + const __m256d cmp1u = _mm256_cmp_pd(v1v, upper_v, pred_upper); + const __m256d cmp0 = _mm256_and_pd(cmp0l, cmp0u); + const __m256d cmp1 = _mm256_and_pd(cmp1l, cmp1u); + const uint8_t mmask0 = _mm256_movemask_pd(cmp0); + const uint8_t mmask1 = _mm256_movemask_pd(cmp1); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +namespace { + +// +template +struct ArithHelperI64 {}; + +template +struct ArithHelperI64 { + static inline __m256i + op(const __m256i left, const __m256i right, const __m256i value) { + // left + right ?? value + return CmpHelperI64::compare(_mm256_add_epi64(left, right), + value); + } +}; + +template +struct ArithHelperI64 { + static inline __m256i + op(const __m256i left, const __m256i right, const __m256i value) { + // left - right ?? value + return CmpHelperI64::compare(_mm256_sub_epi64(left, right), + value); + } +}; + +template +struct ArithHelperI64 { + static inline __m256i + op(const __m256i left, const __m256i right, const __m256i value) { + // left * right ?? value + + // draft: the code from Agner Fog's vectorclass library + const __m256i a = left; + const __m256i b = right; + const __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); // swap H<->L + const __m256i prodlh = + _mm256_mullo_epi32(a, bswap); // 32 bit L*H products + const __m256i zero = _mm256_setzero_si256(); // 0 + const __m256i prodlh2 = + _mm256_hadd_epi32(prodlh, zero); // a0Lb0H+a0Hb0L,a1Lb1H+a1Hb1L,0,0 + const __m256i prodlh3 = _mm256_shuffle_epi32( + prodlh2, 0x73); // 0, a0Lb0H+a0Hb0L, 0, a1Lb1H+a1Hb1L + const __m256i prodll = + _mm256_mul_epu32(a, b); // a0Lb0L,a1Lb1L, 64 bit unsigned products + const __m256i prod = _mm256_add_epi64( + prodll, + prodlh3); // a0Lb0L+(a0Lb0H+a0Hb0L)<<32, a1Lb1L+(a1Lb1H+a1Hb1L)<<32 + + return CmpHelperI64::compare(prod, value); + } +}; + +// todo: Mul, Div, Mod + +// +template +struct ArithHelperF32 {}; + +template +struct ArithHelperF32 { + static inline __m256 + op(const __m256 left, const __m256 right, const __m256 value) { + // left + right == value + constexpr auto pred = ComparePredicate::value; + return _mm256_cmp_ps(_mm256_add_ps(left, right), value, pred); + } +}; + +template +struct ArithHelperF32 { + static inline __m256 + op(const __m256 left, const __m256 right, const __m256 value) { + // left - right == value + constexpr auto pred = ComparePredicate::value; + return _mm256_cmp_ps(_mm256_sub_ps(left, right), value, pred); + } +}; + +template +struct ArithHelperF32 { + static inline __m256 + op(const __m256 left, const __m256 right, const __m256 value) { + // left * right == value + constexpr auto pred = ComparePredicate::value; + return _mm256_cmp_ps(_mm256_mul_ps(left, right), value, pred); + } +}; + +template +struct ArithHelperF32 { + static inline __m256 + op(const __m256 left, const __m256 right, const __m256 value) { + // left == right * value + constexpr auto pred = ComparePredicate::value; + return 
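        // [editorial note: illustrative sketch, not part of the patch]
        // As the "left == right * value" comment above states, the Div case
        // avoids an actual division by rewriting "left / right CMP value" as
        // "left CMP right * value". Mathematically this preserves ordered
        // predicates (GT/GE/LT/LE) only for a positive divisor, since the
        // inequality direction flips for a negative one, while EQ/NE are
        // unaffected; whether callers guarantee a positive right operand is
        // not visible in this diff.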
_mm256_cmp_ps(left, _mm256_mul_ps(right, value), pred); + } +}; + +// todo: Mod + +// +template +struct ArithHelperF64 {}; + +template +struct ArithHelperF64 { + static inline __m256d + op(const __m256d left, const __m256d right, const __m256d value) { + // left + right == value + constexpr auto pred = ComparePredicate::value; + return _mm256_cmp_pd(_mm256_add_pd(left, right), value, pred); + } +}; + +template +struct ArithHelperF64 { + static inline __m256d + op(const __m256d left, const __m256d right, const __m256d value) { + // left - right == value + constexpr auto pred = ComparePredicate::value; + return _mm256_cmp_pd(_mm256_sub_pd(left, right), value, pred); + } +}; + +template +struct ArithHelperF64 { + static inline __m256d + op(const __m256d left, const __m256d right, const __m256d value) { + // left * right == value + constexpr auto pred = ComparePredicate::value; + return _mm256_cmp_pd(_mm256_mul_pd(left, right), value, pred); + } +}; + +template +struct ArithHelperF64 { + static inline __m256d + op(const __m256d left, const __m256d right, const __m256d value) { + // left == right * value + constexpr auto pred = ComparePredicate::value; + return _mm256_cmp_pd(left, _mm256_mul_pd(right, value), pred); + } +}; + +} // namespace + +// todo: Mul, Div, Mod + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Div || AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const __m256i right_v = _mm256_set1_epi64x(right_operand); + const __m256i value_v = _mm256_set1_epi64x(value); + const uint64_t* const __restrict src_u64 = + reinterpret_cast(src); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const uint64_t v = src_u64[i / 8]; + const __m256i v0s = _mm256_cvtepi8_epi64(_mm_set_epi64x(0, v)); + const __m256i v1s = + _mm256_cvtepi8_epi64(_mm_set_epi64x(0, v >> 32)); + const __m256i cmp0 = + ArithHelperI64::op(v0s, right_v, value_v); + const __m256i cmp1 = + ArithHelperI64::op(v1s, right_v, value_v); + const uint8_t mmask0 = + _mm256_movemask_pd(_mm256_castsi256_pd(cmp0)); + const uint8_t mmask1 = + _mm256_movemask_pd(_mm256_castsi256_pd(cmp1)); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Div || AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const __m256i right_v = _mm256_set1_epi64x(right_operand); + const __m256i value_v = _mm256_set1_epi64x(value); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m128i vs = _mm_loadu_si128((const __m128i*)(src + i)); + const __m256i v0s = _mm256_cvtepi16_epi64(vs); + const __m128i v1sr = _mm_set_epi64x(0, _mm_extract_epi64(vs, 1)); + const __m256i v1s = _mm256_cvtepi16_epi64(v1sr); + const __m256i cmp0 = + ArithHelperI64::op(v0s, right_v, value_v); + const __m256i cmp1 = + 
ArithHelperI64::op(v1s, right_v, value_v); + const uint8_t mmask0 = + _mm256_movemask_pd(_mm256_castsi256_pd(cmp0)); + const uint8_t mmask1 = + _mm256_movemask_pd(_mm256_castsi256_pd(cmp1)); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Div || AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const __m256i right_v = _mm256_set1_epi64x(right_operand); + const __m256i value_v = _mm256_set1_epi64x(value); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i vs = _mm256_loadu_si256((const __m256i*)(src + i)); + const __m256i v0s = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(vs, 0)); + const __m256i v1s = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(vs, 1)); + const __m256i cmp0 = + ArithHelperI64::op(v0s, right_v, value_v); + const __m256i cmp1 = + ArithHelperI64::op(v1s, right_v, value_v); + const uint8_t mmask0 = + _mm256_movemask_pd(_mm256_castsi256_pd(cmp0)); + const uint8_t mmask1 = + _mm256_movemask_pd(_mm256_castsi256_pd(cmp1)); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Div || AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const __m256i right_v = _mm256_set1_epi64x(right_operand); + const __m256i value_v = _mm256_set1_epi64x(value); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256i v0s = _mm256_loadu_si256((const __m256i*)(src + i)); + const __m256i v1s = + _mm256_loadu_si256((const __m256i*)(src + i + 4)); + const __m256i cmp0 = + ArithHelperI64::op(v0s, right_v, value_v); + const __m256i cmp1 = + ArithHelperI64::op(v1s, right_v, value_v); + const uint8_t mmask0 = + _mm256_movemask_pd(_mm256_castsi256_pd(cmp0)); + const uint8_t mmask1 = + _mm256_movemask_pd(_mm256_castsi256_pd(cmp1)); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const float* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256 right_v = _mm256_set1_ps(right_operand); + const __m256 value_v = _mm256_set1_ps(value); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256 v0s = _mm256_loadu_ps(src + i); + const __m256 cmp = + ArithHelperF32::op(v0s, right_v, value_v); + const uint8_t mmask = _mm256_movemask_ps(cmp); + + res_u8[i / 8] = mmask; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( 
+ uint8_t* const __restrict res_u8, + const double* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256d right_v = _mm256_set1_pd(right_operand); + const __m256d value_v = _mm256_set1_pd(value); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256d v0s = _mm256_loadu_pd(src + i); + const __m256d v1s = _mm256_loadu_pd(src + i + 4); + const __m256d cmp0 = + ArithHelperF64::op(v0s, right_v, value_v); + const __m256d cmp1 = + ArithHelperF64::op(v1s, right_v, value_v); + const uint8_t mmask0 = _mm256_movemask_pd(cmp0); + const uint8_t mmask1 = _mm256_movemask_pd(cmp1); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; + } +} + +/////////////////////////////////////////////////////////////////////////// + +} // namespace avx2 +} // namespace x86 +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/x86/avx2-inst.cpp b/internal/core/src/bitset/detail/platform/x86/avx2-inst.cpp new file mode 100644 index 000000000000..5f73a1ef126e --- /dev/null +++ b/internal/core/src/bitset/detail/platform/x86/avx2-inst.cpp @@ -0,0 +1,199 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// AVX2 instantiation + +#include "bitset/common.h" + +#ifndef BITSET_HEADER_ONLY + +#include "avx2-decl.h" +#include "avx2-impl.h" + +#include +#include + +namespace milvus { +namespace bitset { +namespace detail { +namespace x86 { +namespace avx2 { + +// a facility to run through all possible compare operations +#define ALL_COMPARE_OPS(FUNC, ...) \ + FUNC(__VA_ARGS__, EQ); \ + FUNC(__VA_ARGS__, GE); \ + FUNC(__VA_ARGS__, GT); \ + FUNC(__VA_ARGS__, LE); \ + FUNC(__VA_ARGS__, LT); \ + FUNC(__VA_ARGS__, NE); + +// a facility to run through all possible range operations +#define ALL_RANGE_OPS(FUNC, ...) \ + FUNC(__VA_ARGS__, IncInc); \ + FUNC(__VA_ARGS__, IncExc); \ + FUNC(__VA_ARGS__, ExcInc); \ + FUNC(__VA_ARGS__, ExcExc); + +// a facility to run through all possible arithmetic compare operations +#define ALL_ARITH_CMP_OPS(FUNC, ...) 
\
+    FUNC(__VA_ARGS__, Add, EQ); \
+    FUNC(__VA_ARGS__, Add, GE); \
+    FUNC(__VA_ARGS__, Add, GT); \
+    FUNC(__VA_ARGS__, Add, LE); \
+    FUNC(__VA_ARGS__, Add, LT); \
+    FUNC(__VA_ARGS__, Add, NE); \
+    FUNC(__VA_ARGS__, Sub, EQ); \
+    FUNC(__VA_ARGS__, Sub, GE); \
+    FUNC(__VA_ARGS__, Sub, GT); \
+    FUNC(__VA_ARGS__, Sub, LE); \
+    FUNC(__VA_ARGS__, Sub, LT); \
+    FUNC(__VA_ARGS__, Sub, NE); \
+    FUNC(__VA_ARGS__, Mul, EQ); \
+    FUNC(__VA_ARGS__, Mul, GE); \
+    FUNC(__VA_ARGS__, Mul, GT); \
+    FUNC(__VA_ARGS__, Mul, LE); \
+    FUNC(__VA_ARGS__, Mul, LT); \
+    FUNC(__VA_ARGS__, Mul, NE); \
+    FUNC(__VA_ARGS__, Div, EQ); \
+    FUNC(__VA_ARGS__, Div, GE); \
+    FUNC(__VA_ARGS__, Div, GT); \
+    FUNC(__VA_ARGS__, Div, LE); \
+    FUNC(__VA_ARGS__, Div, LT); \
+    FUNC(__VA_ARGS__, Div, NE); \
+    FUNC(__VA_ARGS__, Mod, EQ); \
+    FUNC(__VA_ARGS__, Mod, GE); \
+    FUNC(__VA_ARGS__, Mod, GT); \
+    FUNC(__VA_ARGS__, Mod, LE); \
+    FUNC(__VA_ARGS__, Mod, LT); \
+    FUNC(__VA_ARGS__, Mod, NE);
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#define INSTANTIATE_COMPARE_VAL_AVX2(TTYPE, OP) \
+    template bool OpCompareValImpl<TTYPE, CompareOpType::OP>::op_compare_val( \
+        uint8_t* const __restrict bitmask, \
+        const TTYPE* const __restrict src, \
+        const size_t size, \
+        const TTYPE& val);
+
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX2, int8_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX2, int16_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX2, int32_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX2, int64_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX2, float)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX2, double)
+
+#undef INSTANTIATE_COMPARE_VAL_AVX2
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#define INSTANTIATE_COMPARE_COLUMN_AVX2(TTYPE, OP) \
+    template bool \
+    OpCompareColumnImpl<TTYPE, TTYPE, CompareOpType::OP>::op_compare_column( \
+        uint8_t* const __restrict bitmask, \
+        const TTYPE* const __restrict left, \
+        const TTYPE* const __restrict right, \
+        const size_t size);
+
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX2, int8_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX2, int16_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX2, int32_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX2, int64_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX2, float)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX2, double)
+
+#undef INSTANTIATE_COMPARE_COLUMN_AVX2
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#define INSTANTIATE_WITHIN_RANGE_COLUMN_AVX2(TTYPE, OP) \
+    template bool \
+    OpWithinRangeColumnImpl<TTYPE, RangeType::OP>::op_within_range_column( \
+        uint8_t* const __restrict res_u8, \
+        const TTYPE* const __restrict lower, \
+        const TTYPE* const __restrict upper, \
+        const TTYPE* const __restrict values, \
+        const size_t size);
+
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX2, int8_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX2, int16_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX2, int32_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX2, int64_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX2, float)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX2, double)
+
+#undef INSTANTIATE_WITHIN_RANGE_COLUMN_AVX2
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#define INSTANTIATE_WITHIN_RANGE_VAL_AVX2(TTYPE, OP) \
+    template bool \
+    OpWithinRangeValImpl<TTYPE, RangeType::OP>::op_within_range_val( \
+        uint8_t* const __restrict res_u8, \
+        const TTYPE& lower, \
+        const TTYPE& upper, \
+        const TTYPE* const __restrict values, \
+        const size_t size);
+
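+// For illustration (not part of the macro machinery itself): each
+// ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX2, T) line below expands
+// into four explicit instantiations, one per interval kind, e.g. for T=float:
+//
+//   template bool
+//   OpWithinRangeValImpl<float, RangeType::IncInc>::op_within_range_val(
+//       uint8_t* const __restrict res_u8,
+//       const float& lower,
+//       const float& upper,
+//       const float* const __restrict values,
+//       const size_t size);
+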
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX2, int8_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX2, int16_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX2, int32_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX2, int64_t) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX2, float) +ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX2, double) + +#undef INSTANTIATE_WITHIN_RANGE_VAL_AVX2 + +/////////////////////////////////////////////////////////////////////////// + +// +#define INSTANTIATE_ARITH_COMPARE_AVX2(TTYPE, OP, CMP) \ + template bool \ + OpArithCompareImpl:: \ + op_arith_compare(uint8_t* const __restrict res_u8, \ + const TTYPE* const __restrict src, \ + const ArithHighPrecisionType& right_operand, \ + const ArithHighPrecisionType& value, \ + const size_t size); + +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX2, int8_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX2, int16_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX2, int32_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX2, int64_t) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX2, float) +ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX2, double) + +#undef INSTANTIATE_ARITH_COMPARE_AVX2 + +/////////////////////////////////////////////////////////////////////////// + +// +#undef ALL_COMPARE_OPS +#undef ALL_RANGE_OPS +#undef ALL_ARITH_CMP_OPS + +} // namespace avx2 +} // namespace x86 +} // namespace detail +} // namespace bitset +} // namespace milvus + +#endif diff --git a/internal/core/src/bitset/detail/platform/x86/avx2.h b/internal/core/src/bitset/detail/platform/x86/avx2.h new file mode 100644 index 000000000000..711b9f2b8f51 --- /dev/null +++ b/internal/core/src/bitset/detail/platform/x86/avx2.h @@ -0,0 +1,63 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
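+
+// This header exposes the AVX2 kernels through a single dispatch point, the
+// VectorizedAvx2 struct defined below. A minimal usage sketch (hypothetical
+// caller, not part of this patch):
+//
+//   // evaluate "src[i] < 3.0f" into a packed bitmask, one bit per element
+//   std::vector<uint8_t> bitmask(size / 8);
+//   const bool handled =
+//       VectorizedAvx2::op_compare_val<float, CompareOpType::LT>(
+//           bitmask.data(), src, size, 3.0f);
+//   // handled == false signals the caller to fall back to a scalar path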
+ +#pragma once + +#include +#include + +#include "bitset/common.h" + +#include "avx2-decl.h" + +#ifdef BITSET_HEADER_ONLY +#include "avx2-impl.h" +#endif + +namespace milvus { +namespace bitset { +namespace detail { +namespace x86 { + +/////////////////////////////////////////////////////////////////////////// + +// +struct VectorizedAvx2 { + template + static constexpr inline auto op_compare_column = + avx2::OpCompareColumnImpl::op_compare_column; + + template + static constexpr inline auto op_compare_val = + avx2::OpCompareValImpl::op_compare_val; + + template + static constexpr inline auto op_within_range_column = + avx2::OpWithinRangeColumnImpl::op_within_range_column; + + template + static constexpr inline auto op_within_range_val = + avx2::OpWithinRangeValImpl::op_within_range_val; + + template + static constexpr inline auto op_arith_compare = + avx2::OpArithCompareImpl::op_arith_compare; +}; + +} // namespace x86 +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/x86/avx512-decl.h b/internal/core/src/bitset/detail/platform/x86/avx512-decl.h new file mode 100644 index 000000000000..3ad5173cda37 --- /dev/null +++ b/internal/core/src/bitset/detail/platform/x86/avx512-decl.h @@ -0,0 +1,201 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
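+
+// Note on the file layout, shared by the avx2 and avx512 backends: the
+// *-decl.h header declares the specializations, *-impl.h defines them, and
+// *-inst.cpp explicitly instantiates them unless BITSET_HEADER_ONLY is set,
+// in which case the impl header is included directly (see the
+// #ifdef BITSET_HEADER_ONLY branch in avx2.h; avx512.h presumably mirrors
+// it). The same kernels can therefore build into a library or fully inline.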
+ +// AVX512 declaration + +#pragma once + +#include +#include + +#include "bitset/common.h" + +namespace milvus { +namespace bitset { +namespace detail { +namespace x86 { +namespace avx512 { + +/////////////////////////////////////////////////////////////////////////// +// a facility to run through all acceptable data types +#define ALL_DATATYPES_1(FUNC) \ + FUNC(int8_t); \ + FUNC(int16_t); \ + FUNC(int32_t); \ + FUNC(int64_t); \ + FUNC(float); \ + FUNC(double); + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpCompareColumnImpl { + static bool + op_compare_column(uint8_t* const __restrict bitmask, + const T* const __restrict t, + const U* const __restrict u, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_COMPARE_COLUMN(TTYPE) \ + template \ + struct OpCompareColumnImpl { \ + static bool \ + op_compare_column(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict t, \ + const TTYPE* const __restrict u, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_COLUMN) + +#undef DECLARE_PARTIAL_OP_COMPARE_COLUMN + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpCompareValImpl { + static inline bool + op_compare_val(uint8_t* const __restrict bitmask, + const T* const __restrict t, + const size_t size, + const T& value) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_COMPARE_VAL(TTYPE) \ + template \ + struct OpCompareValImpl { \ + static bool \ + op_compare_val(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict t, \ + const size_t size, \ + const TTYPE& value); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_COMPARE_VAL) + +#undef DECLARE_PARTIAL_OP_COMPARE_VAL + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpWithinRangeColumnImpl { + static inline bool + op_within_range_column(uint8_t* const __restrict bitmask, + const T* const __restrict lower, + const T* const __restrict upper, + const T* const __restrict values, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN(TTYPE) \ + template \ + struct OpWithinRangeColumnImpl { \ + static bool \ + op_within_range_column(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict lower, \ + const TTYPE* const __restrict upper, \ + const TTYPE* const __restrict values, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN) + +#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_COLUMN + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpWithinRangeValImpl { + static inline bool + op_within_range_val(uint8_t* const __restrict bitmask, + const T& lower, + const T& upper, + const T* const __restrict values, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL(TTYPE) \ + template \ + struct OpWithinRangeValImpl { \ + static bool \ + op_within_range_val(uint8_t* const __restrict bitmask, \ + const TTYPE& lower, \ + const TTYPE& upper, \ + const TTYPE* const __restrict values, \ + const size_t size); \ + }; + 
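+// For illustration: ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL)
+// below stamps out one partial specialization per supported element type,
+// e.g. for int32_t:
+//
+//   template <RangeType Op>
+//   struct OpWithinRangeValImpl<int32_t, Op> {
+//       static bool
+//       op_within_range_val(uint8_t* const __restrict bitmask,
+//                           const int32_t& lower,
+//                           const int32_t& upper,
+//                           const int32_t* const __restrict values,
+//                           const size_t size);
+//   };
+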
+ALL_DATATYPES_1(DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL) + +#undef DECLARE_PARTIAL_OP_WITHIN_RANGE_VAL + +/////////////////////////////////////////////////////////////////////////// + +// the default implementation does nothing +template +struct OpArithCompareImpl { + static inline bool + op_arith_compare(uint8_t* const __restrict bitmask, + const T* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + return false; + } +}; + +// the following use cases are handled +#define DECLARE_PARTIAL_OP_ARITH_COMPARE(TTYPE) \ + template \ + struct OpArithCompareImpl { \ + static bool \ + op_arith_compare(uint8_t* const __restrict bitmask, \ + const TTYPE* const __restrict src, \ + const ArithHighPrecisionType& right_operand, \ + const ArithHighPrecisionType& value, \ + const size_t size); \ + }; + +ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) + +#undef DECLARE_PARTIAL_OP_ARITH_COMPARE + +/////////////////////////////////////////////////////////////////////////// + +#undef ALL_DATATYPES_1 + +} // namespace avx512 +} // namespace x86 +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/x86/avx512-impl.h b/internal/core/src/bitset/detail/platform/x86/avx512-impl.h new file mode 100644 index 000000000000..b460d257ecda --- /dev/null +++ b/internal/core/src/bitset/detail/platform/x86/avx512-impl.h @@ -0,0 +1,1460 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
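+
+// The central difference from the AVX2 backend: AVX-512 compares return
+// their results directly in opmask registers (__mmask8/16/32/64), one bit
+// per lane, so a mask can be stored straight into the output bitmask with
+// no movemask step. A sketch of the pattern used throughout this file:
+//
+//   const __m512i v = _mm512_loadu_si512(src + i);              // 16 x int32
+//   const __mmask16 m = _mm512_cmp_epi32_mask(v, target, pred);
+//   res_u16[i / 16] = m;                                        // 16 result bits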
+ +// AVX512 implementation + +#pragma once + +#include + +#include +#include +#include +#include + +#include "avx512-decl.h" + +#include "bitset/common.h" +#include "common.h" + +namespace milvus { +namespace bitset { +namespace detail { +namespace x86 { +namespace avx512 { + +namespace { + +// count is expected to be in range [0, 64) +inline uint64_t +get_mask(const size_t count) { + return (uint64_t(1) << count) - uint64_t(1); +} + +} // namespace + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const int8_t* const __restrict src, + const size_t size, + const int8_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512i target = _mm512_set1_epi8(val); + uint64_t* const __restrict res_u64 = reinterpret_cast(res_u8); + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size64 = (size / 64) * 64; + for (size_t i = 0; i < size64; i += 64) { + const __m512i v = _mm512_loadu_si512(src + i); + const __mmask64 cmp_mask = _mm512_cmp_epi8_mask(v, target, pred); + + res_u64[i / 64] = cmp_mask; + } + + // process leftovers + if (size64 != size) { + // 8, 16, 24, 32, 40, 48 or 56 elements to process + const uint64_t mask = get_mask(size - size64); + const __m512i v = _mm512_maskz_loadu_epi8(mask, src + size64); + const __mmask64 cmp_mask = _mm512_cmp_epi8_mask(v, target, pred); + + const uint16_t store_mask = get_mask((size - size64) / 8); + _mm_mask_storeu_epi8(res_u64 + size64 / 64, + store_mask, + _mm_setr_epi64(__m64(cmp_mask), __m64(0ULL))); + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict src, + const size_t size, + const int16_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512i target = _mm512_set1_epi16(val); + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const __m512i v = _mm512_loadu_si512(src + i); + const __mmask32 cmp_mask = _mm512_cmp_epi16_mask(v, target, pred); + + res_u32[i / 32] = cmp_mask; + } + + // process leftovers + if (size32 != size) { + // 8, 16 or 24 elements to process + const uint32_t mask = get_mask(size - size32); + const __m512i v = _mm512_maskz_loadu_epi16(mask, src + size32); + const __mmask32 cmp_mask = _mm512_cmp_epi16_mask(v, target, pred); + + const uint16_t store_mask = get_mask((size - size32) / 8); + _mm_mask_storeu_epi8(res_u32 + size32 / 32, + store_mask, + _mm_setr_epi32(cmp_mask, 0, 0, 0)); + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict src, + const size_t size, + const int32_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512i target = _mm512_set1_epi32(val); + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m512i v = _mm512_loadu_si512(src + i); + const __mmask16 cmp_mask = _mm512_cmp_epi32_mask(v, target, pred); + + 
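+        // one __mmask16 bit per int32 lane; 16 results fill one uint16_t of
+        // the output bitmask, which is why res_u8 is reinterpreted as
+        // uint16_t above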
res_u16[i / 16] = cmp_mask; + } + + // process leftovers + if (size16 != size) { + // 8 elements to process + const __m256i v = _mm256_loadu_si256((const __m256i*)(src + size16)); + const __mmask8 cmp_mask = + _mm256_cmp_epi32_mask(v, _mm512_castsi512_si256(target), pred); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict src, + const size_t size, + const int64_t& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512i target = _mm512_set1_epi64(val); + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m512i v = _mm512_loadu_si512(src + i); + const __mmask8 cmp_mask = _mm512_cmp_epi64_mask(v, target, pred); + + res_u8[i / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const float* const __restrict src, + const size_t size, + const float& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + constexpr auto pred = ComparePredicate::value; + + const __m512 target = _mm512_set1_ps(val); + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m512 v = _mm512_loadu_ps(src + i); + const __mmask16 cmp_mask = _mm512_cmp_ps_mask(v, target, pred); + + res_u16[i / 16] = cmp_mask; + } + + // process leftovers + if (size16 != size) { + // 8 elements to process + const __m256 v = _mm256_loadu_ps(src + size16); + const __mmask8 cmp_mask = + _mm256_cmp_ps_mask(v, _mm512_castps512_ps256(target), pred); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpCompareValImpl::op_compare_val(uint8_t* const __restrict res_u8, + const double* const __restrict src, + const size_t size, + const double& val) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred = ComparePredicate::value; + + const __m512d target = _mm512_set1_pd(val); + + // todo: aligned reads & writes + + // process big blocks + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m512d v = _mm512_loadu_pd(src + i); + const __mmask8 cmp_mask = _mm512_cmp_pd_mask(v, target, pred); + + res_u8[i / 8] = cmp_mask; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict left, + const int8_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint64_t* const __restrict res_u64 = reinterpret_cast(res_u8); + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size64 = (size / 64) * 64; + for (size_t i = 0; i < size64; i += 64) { + const __m512i vl = _mm512_loadu_si512(left + i); + const __m512i vr = _mm512_loadu_si512(right + i); + const __mmask64 cmp_mask = _mm512_cmp_epi8_mask(vl, vr, pred); + + res_u64[i / 64] = cmp_mask; + } + + // process leftovers + if (size64 != size) { + // 8, 16, 24, 32, 40, 48 or 56 elements to process + const uint64_t mask = get_mask(size - size64); + 
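+        // the maskz loads zero the lanes past `size`, so the tail is read
+        // without touching memory beyond the input buffers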
const __m512i vl = _mm512_maskz_loadu_epi8(mask, left + size64); + const __m512i vr = _mm512_maskz_loadu_epi8(mask, right + size64); + const __mmask64 cmp_mask = _mm512_cmp_epi8_mask(vl, vr, pred); + + const uint16_t store_mask = get_mask((size - size64) / 8); + _mm_mask_storeu_epi8(res_u64 + size64 / 64, + store_mask, + _mm_setr_epi64(__m64(cmp_mask), __m64(0ULL))); + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict left, + const int16_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const __m512i vl = _mm512_loadu_si512(left + i); + const __m512i vr = _mm512_loadu_si512(right + i); + const __mmask32 cmp_mask = _mm512_cmp_epi16_mask(vl, vr, pred); + + res_u32[i / 32] = cmp_mask; + } + + // process leftovers + if (size32 != size) { + // 8, 16 or 24 elements to process + const uint32_t mask = get_mask(size - size32); + const __m512i vl = _mm512_maskz_loadu_epi16(mask, left + size32); + const __m512i vr = _mm512_maskz_loadu_epi16(mask, right + size32); + const __mmask32 cmp_mask = _mm512_cmp_epi16_mask(vl, vr, pred); + + const uint16_t store_mask = get_mask((size - size32) / 8); + _mm_mask_storeu_epi8(res_u32 + size32 / 32, + store_mask, + _mm_setr_epi32(cmp_mask, 0, 0, 0)); + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict left, + const int32_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m512i vl = _mm512_loadu_si512(left + i); + const __m512i vr = _mm512_loadu_si512(right + i); + const __mmask16 cmp_mask = _mm512_cmp_epi32_mask(vl, vr, pred); + + res_u16[i / 16] = cmp_mask; + } + + // process leftovers + if (size16 != size) { + // 8 elements to process + const __m256i vl = _mm256_loadu_si256((const __m256i*)(left + size16)); + const __m256i vr = _mm256_loadu_si256((const __m256i*)(right + size16)); + const __mmask8 cmp_mask = _mm256_cmp_epi32_mask(vl, vr, pred); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict left, + const int64_t* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m512i vl = _mm512_loadu_si512(left + i); + const __m512i vr = _mm512_loadu_si512(right + i); + const __mmask8 cmp_mask = _mm512_cmp_epi64_mask(vl, vr, pred); + + res_u8[i / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const float* const __restrict left, + const 
float* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m512 vl = _mm512_loadu_ps(left + i); + const __m512 vr = _mm512_loadu_ps(right + i); + const __mmask16 cmp_mask = _mm512_cmp_ps_mask(vl, vr, pred); + + res_u16[i / 16] = cmp_mask; + } + + // process leftovers + if (size16 != size) { + // process 8 elements + const __m256 vl = _mm256_loadu_ps(left + size16); + const __m256 vr = _mm256_loadu_ps(right + size16); + const __mmask8 cmp_mask = _mm256_cmp_ps_mask(vl, vr, pred); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpCompareColumnImpl::op_compare_column( + uint8_t* const __restrict res_u8, + const double* const __restrict left, + const double* const __restrict right, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred = ComparePredicate::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m512d vl = _mm512_loadu_pd(left + i); + const __m512d vr = _mm512_loadu_pd(right + i); + const __mmask8 cmp_mask = _mm512_cmp_pd_mask(vl, vr, pred); + + res_u8[i / 8] = cmp_mask; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict lower, + const int8_t* const __restrict upper, + const int8_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint64_t* const __restrict res_u64 = reinterpret_cast(res_u8); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size64 = (size / 64) * 64; + for (size_t i = 0; i < size64; i += 64) { + const __m512i vl = _mm512_loadu_si512(lower + i); + const __m512i vu = _mm512_loadu_si512(upper + i); + const __m512i vv = _mm512_loadu_si512(values + i); + const __mmask64 cmpl_mask = _mm512_cmp_epi8_mask(vl, vv, pred_lower); + const __mmask64 cmp_mask = + _mm512_mask_cmp_epi8_mask(cmpl_mask, vv, vu, pred_upper); + + res_u64[i / 64] = cmp_mask; + } + + // process leftovers + if (size64 != size) { + // 8, 16, 24, 32, 40, 48 or 56 elements to process + const uint64_t mask = get_mask(size - size64); + const __m512i vl = _mm512_maskz_loadu_epi8(mask, lower + size64); + const __m512i vu = _mm512_maskz_loadu_epi8(mask, upper + size64); + const __m512i vv = _mm512_maskz_loadu_epi8(mask, values + size64); + const __mmask64 cmpl_mask = _mm512_cmp_epi8_mask(vl, vv, pred_lower); + const __mmask64 cmp_mask = + _mm512_mask_cmp_epi8_mask(cmpl_mask, vv, vu, pred_upper); + + const uint16_t store_mask = get_mask((size - size64) / 8); + _mm_mask_storeu_epi8(res_u64 + size64 / 64, + store_mask, + _mm_setr_epi64(__m64(cmp_mask), __m64(0ULL))); + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict lower, + const int16_t* const __restrict upper, + const int16_t* 
const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const __m512i vl = _mm512_loadu_si512(lower + i); + const __m512i vu = _mm512_loadu_si512(upper + i); + const __m512i vv = _mm512_loadu_si512(values + i); + const __mmask32 cmpl_mask = _mm512_cmp_epi16_mask(vl, vv, pred_lower); + const __mmask32 cmp_mask = + _mm512_mask_cmp_epi16_mask(cmpl_mask, vv, vu, pred_upper); + + res_u32[i / 32] = cmp_mask; + } + + // process leftovers + if (size32 != size) { + // 8, 16 or 24 elements to process + const uint32_t mask = get_mask(size - size32); + const __m512i vl = _mm512_maskz_loadu_epi16(mask, lower + size32); + const __m512i vu = _mm512_maskz_loadu_epi16(mask, upper + size32); + const __m512i vv = _mm512_maskz_loadu_epi16(mask, values + size32); + const __mmask32 cmpl_mask = _mm512_cmp_epi16_mask(vl, vv, pred_lower); + const __mmask32 cmp_mask = + _mm512_mask_cmp_epi16_mask(cmpl_mask, vv, vu, pred_upper); + + const uint16_t store_mask = get_mask((size - size32) / 8); + _mm_mask_storeu_epi8(res_u32 + size32 / 32, + store_mask, + _mm_setr_epi32(cmp_mask, 0, 0, 0)); + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict lower, + const int32_t* const __restrict upper, + const int32_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m512i vl = _mm512_loadu_si512(lower + i); + const __m512i vu = _mm512_loadu_si512(upper + i); + const __m512i vv = _mm512_loadu_si512(values + i); + const __mmask16 cmpl_mask = _mm512_cmp_epi32_mask(vl, vv, pred_lower); + const __mmask16 cmp_mask = + _mm512_mask_cmp_epi32_mask(cmpl_mask, vv, vu, pred_upper); + + res_u16[i / 16] = cmp_mask; + } + + // process leftovers + if (size16 != size) { + // 8 elements to process + const __m256i vl = _mm256_loadu_si256((const __m256i*)(lower + size16)); + const __m256i vu = _mm256_loadu_si256((const __m256i*)(upper + size16)); + const __m256i vv = + _mm256_loadu_si256((const __m256i*)(values + size16)); + const __mmask8 cmpl_mask = _mm256_cmp_epi32_mask(vl, vv, pred_lower); + const __mmask8 cmp_mask = + _mm256_mask_cmp_epi32_mask(cmpl_mask, vv, vu, pred_upper); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const int64_t* const __restrict lower, + const int64_t* const __restrict upper, + const int64_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + 
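+    // 8 x int64 per 512-bit register: each iteration emits exactly one byte
+    // of the bitmask, and since the API guarantees size % 8 == 0 there is
+    // no leftover tail to handle for 64-bit elements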
const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m512i vl = _mm512_loadu_si512(lower + i); + const __m512i vu = _mm512_loadu_si512(upper + i); + const __m512i vv = _mm512_loadu_si512(values + i); + const __mmask8 cmpl_mask = _mm512_cmp_epi64_mask(vl, vv, pred_lower); + const __mmask8 cmp_mask = + _mm512_mask_cmp_epi64_mask(cmpl_mask, vv, vu, pred_upper); + + res_u8[i / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const float* const __restrict lower, + const float* const __restrict upper, + const float* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m512 vl = _mm512_loadu_ps(lower + i); + const __m512 vu = _mm512_loadu_ps(upper + i); + const __m512 vv = _mm512_loadu_ps(values + i); + const __mmask16 cmpl_mask = _mm512_cmp_ps_mask(vl, vv, pred_lower); + const __mmask16 cmp_mask = + _mm512_mask_cmp_ps_mask(cmpl_mask, vv, vu, pred_upper); + + res_u16[i / 16] = cmp_mask; + } + + // process leftovers + if (size16 != size) { + // process 8 elements + const __m256 vl = _mm256_loadu_ps(lower + size16); + const __m256 vu = _mm256_loadu_ps(upper + size16); + const __m256 vv = _mm256_loadu_ps(values + size16); + const __mmask8 cmpl_mask = _mm256_cmp_ps_mask(vl, vv, pred_lower); + const __mmask8 cmp_mask = + _mm256_mask_cmp_ps_mask(cmpl_mask, vv, vu, pred_upper); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpWithinRangeColumnImpl::op_within_range_column( + uint8_t* const __restrict res_u8, + const double* const __restrict lower, + const double* const __restrict upper, + const double* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m512d vl = _mm512_loadu_pd(lower + i); + const __m512d vu = _mm512_loadu_pd(upper + i); + const __m512d vv = _mm512_loadu_pd(values + i); + const __mmask8 cmpl_mask = _mm512_cmp_pd_mask(vl, vv, pred_lower); + const __mmask8 cmp_mask = + _mm512_mask_cmp_pd_mask(cmpl_mask, vv, vu, pred_upper); + + res_u8[i / 8] = cmp_mask; + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////// + +// +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int8_t& lower, + const int8_t& upper, + const int8_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512i lower_v = _mm512_set1_epi8(lower); + const __m512i upper_v = _mm512_set1_epi8(upper); + uint64_t* const __restrict res_u64 = reinterpret_cast(res_u8); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + 
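+    // the two-sided range test is fused: the first compare produces the
+    // lower-bound mask, and the second is a masked compare that only tests
+    // lanes which already passed the lower bound, implementing
+    // (lower OP v) && (v OP upper) in two instructions per block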
const size_t size64 = (size / 64) * 64; + for (size_t i = 0; i < size64; i += 64) { + const __m512i vv = _mm512_loadu_si512(values + i); + const __mmask64 cmpl_mask = + _mm512_cmp_epi8_mask(lower_v, vv, pred_lower); + const __mmask64 cmp_mask = + _mm512_mask_cmp_epi8_mask(cmpl_mask, vv, upper_v, pred_upper); + + res_u64[i / 64] = cmp_mask; + } + + // process leftovers + if (size64 != size) { + // 8, 16, 24, 32, 40, 48 or 56 elements to process + const uint64_t mask = get_mask(size - size64); + const __m512i vv = _mm512_maskz_loadu_epi8(mask, values + size64); + const __mmask64 cmpl_mask = + _mm512_cmp_epi8_mask(lower_v, vv, pred_lower); + const __mmask64 cmp_mask = + _mm512_mask_cmp_epi8_mask(cmpl_mask, vv, upper_v, pred_upper); + + const uint16_t store_mask = get_mask((size - size64) / 8); + _mm_mask_storeu_epi8(res_u64 + size64 / 64, + store_mask, + _mm_setr_epi64(__m64(cmp_mask), __m64(0ULL))); + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int16_t& lower, + const int16_t& upper, + const int16_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512i lower_v = _mm512_set1_epi16(lower); + const __m512i upper_v = _mm512_set1_epi16(upper); + uint32_t* const __restrict res_u32 = reinterpret_cast(res_u8); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size32 = (size / 32) * 32; + for (size_t i = 0; i < size32; i += 32) { + const __m512i vv = _mm512_loadu_si512(values + i); + const __mmask32 cmpl_mask = + _mm512_cmp_epi16_mask(lower_v, vv, pred_lower); + const __mmask32 cmp_mask = + _mm512_mask_cmp_epi16_mask(cmpl_mask, vv, upper_v, pred_upper); + + res_u32[i / 32] = cmp_mask; + } + + // process leftovers + if (size32 != size) { + // 8, 16 or 24 elements to process + const uint32_t mask = get_mask(size - size32); + const __m512i vv = _mm512_maskz_loadu_epi16(mask, values + size32); + const __mmask32 cmpl_mask = + _mm512_cmp_epi16_mask(lower_v, vv, pred_lower); + const __mmask32 cmp_mask = + _mm512_mask_cmp_epi16_mask(cmpl_mask, vv, upper_v, pred_upper); + + const uint16_t store_mask = get_mask((size - size32) / 8); + _mm_mask_storeu_epi8(res_u32 + size32 / 32, + store_mask, + _mm_setr_epi32(cmp_mask, 0, 0, 0)); + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int32_t& lower, + const int32_t& upper, + const int32_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512i lower_v = _mm512_set1_epi32(lower); + const __m512i upper_v = _mm512_set1_epi32(upper); + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m512i vv = _mm512_loadu_si512(values + i); + const __mmask16 cmpl_mask = + _mm512_cmp_epi32_mask(lower_v, vv, pred_lower); + const __mmask16 cmp_mask = + _mm512_mask_cmp_epi32_mask(cmpl_mask, vv, upper_v, pred_upper); + + res_u16[i / 16] = cmp_mask; + } + + // process leftovers + if (size16 != size) { + // 8 elements to process + 
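+        // the 8-element tail drops to 256-bit ops; _mm512_castsi512_si256
+        // is a no-op cast that reuses the low half of the broadcast
+        // lower_v / upper_v constants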
const __m256i vv = + _mm256_loadu_si256((const __m256i*)(values + size16)); + const __mmask8 cmpl_mask = _mm256_cmp_epi32_mask( + _mm512_castsi512_si256(lower_v), vv, pred_lower); + const __mmask8 cmp_mask = _mm256_mask_cmp_epi32_mask( + cmpl_mask, vv, _mm512_castsi512_si256(upper_v), pred_upper); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const int64_t& lower, + const int64_t& upper, + const int64_t* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512i lower_v = _mm512_set1_epi64(lower); + const __m512i upper_v = _mm512_set1_epi64(upper); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m512i vv = _mm512_loadu_si512(values + i); + const __mmask8 cmpl_mask = + _mm512_cmp_epi64_mask(lower_v, vv, pred_lower); + const __mmask8 cmp_mask = + _mm512_mask_cmp_epi64_mask(cmpl_mask, vv, upper_v, pred_upper); + + res_u8[i / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const float& lower, + const float& upper, + const float* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512 lower_v = _mm512_set1_ps(lower); + const __m512 upper_v = _mm512_set1_ps(upper); + uint16_t* const __restrict res_u16 = reinterpret_cast(res_u8); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m512 vv = _mm512_loadu_ps(values + i); + const __mmask16 cmpl_mask = _mm512_cmp_ps_mask(lower_v, vv, pred_lower); + const __mmask16 cmp_mask = + _mm512_mask_cmp_ps_mask(cmpl_mask, vv, upper_v, pred_upper); + + res_u16[i / 16] = cmp_mask; + } + + // process leftovers + if (size16 != size) { + // process 8 elements + const __m256 vv = _mm256_loadu_ps(values + size16); + const __mmask8 cmpl_mask = + _mm256_cmp_ps_mask(_mm512_castps512_ps256(lower_v), vv, pred_lower); + const __mmask8 cmp_mask = _mm256_mask_cmp_ps_mask( + cmpl_mask, vv, _mm512_castps512_ps256(upper_v), pred_upper); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; +} + +template +bool +OpWithinRangeValImpl::op_within_range_val( + uint8_t* const __restrict res_u8, + const double& lower, + const double& upper, + const double* const __restrict values, + const size_t size) { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512d lower_v = _mm512_set1_pd(lower); + const __m512d upper_v = _mm512_set1_pd(upper); + constexpr auto pred_lower = + ComparePredicate::lower>::value; + constexpr auto pred_upper = + ComparePredicate::upper>::value; + + // todo: aligned reads & writes + + // process big blocks + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m512d vv = _mm512_loadu_pd(values + i); + const __mmask8 cmpl_mask = _mm512_cmp_pd_mask(lower_v, vv, pred_lower); + const __mmask8 cmp_mask = + _mm512_mask_cmp_pd_mask(cmpl_mask, vv, upper_v, pred_upper); + + res_u8[i / 8] = cmp_mask; + } + + 
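+    // 8 doubles per 512-bit register and size % 8 == 0, so the block loop
+    // above covers the whole input and no tail handling is needed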
return true; +} + +/////////////////////////////////////////////////////////////////////////// + +namespace { + +// +template +struct ArithHelperI64 {}; + +template +struct ArithHelperI64 { + static inline __mmask8 + op(const __m512i left, const __m512i right, const __m512i value) { + // left + right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_epi64_mask( + _mm512_add_epi64(left, right), value, pred); + } +}; + +template +struct ArithHelperI64 { + static inline __mmask8 + op(const __m512i left, const __m512i right, const __m512i value) { + // left - right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_epi64_mask( + _mm512_sub_epi64(left, right), value, pred); + } +}; + +template +struct ArithHelperI64 { + static inline __mmask8 + op(const __m512i left, const __m512i right, const __m512i value) { + // left * right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_epi64_mask( + _mm512_mullo_epi64(left, right), value, pred); + } +}; + +// +template +struct ArithHelperF32 {}; + +template +struct ArithHelperF32 { + static inline __mmask16 + op(const __m512 left, const __m512 right, const __m512 value) { + // left + right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_ps_mask(_mm512_add_ps(left, right), value, pred); + } +}; + +template +struct ArithHelperF32 { + static inline __mmask16 + op(const __m512 left, const __m512 right, const __m512 value) { + // left - right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_ps_mask(_mm512_sub_ps(left, right), value, pred); + } +}; + +template +struct ArithHelperF32 { + static inline __mmask16 + op(const __m512 left, const __m512 right, const __m512 value) { + // left * right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_ps_mask(_mm512_mul_ps(left, right), value, pred); + } +}; + +template +struct ArithHelperF32 { + static inline __mmask16 + op(const __m512 left, const __m512 right, const __m512 value) { + // left == right * value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_ps_mask(left, _mm512_mul_ps(right, value), pred); + } +}; + +// +template +struct ArithHelperF64 {}; + +template +struct ArithHelperF64 { + static inline __mmask8 + op(const __m512d left, const __m512d right, const __m512d value) { + // left + right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_pd_mask(_mm512_add_pd(left, right), value, pred); + } +}; + +template +struct ArithHelperF64 { + static inline __mmask8 + op(const __m512d left, const __m512d right, const __m512d value) { + // left - right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_pd_mask(_mm512_sub_pd(left, right), value, pred); + } +}; + +template +struct ArithHelperF64 { + static inline __mmask8 + op(const __m512d left, const __m512d right, const __m512d value) { + // left * right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_pd_mask(_mm512_mul_pd(left, right), value, pred); + } +}; + +template +struct ArithHelperF64 { + static inline __mmask8 + op(const __m512d left, const __m512d right, const __m512d value) { + // left == right * value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_pd_mask(left, _mm512_mul_pd(right, value), pred); + } +}; + +} // namespace + +// +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int8_t* const __restrict src, + 
const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Div || AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const __m512i right_v = _mm512_set1_epi64(right_operand); + const __m512i value_v = _mm512_set1_epi64(value); + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m128i vs = _mm_loadu_si128((const __m128i*)(src + i)); + const __m512i v0s = _mm512_cvtepi8_epi64( + _mm_unpacklo_epi64(vs, _mm_setzero_si128())); + const __m512i v1s = _mm512_cvtepi8_epi64( + _mm_unpackhi_epi64(vs, _mm_setzero_si128())); + const __mmask8 cmp_mask0 = + ArithHelperI64::op(v0s, right_v, value_v); + const __mmask8 cmp_mask1 = + ArithHelperI64::op(v1s, right_v, value_v); + + res_u8[i / 8 + 0] = cmp_mask0; + res_u8[i / 8 + 1] = cmp_mask1; + } + + // process leftovers + if (size16 != size) { + // process 8 elements + const int64_t* const __restrict src64 = + (const int64_t*)(src + size16); + const __m128i vs = _mm_set_epi64x(0, *src64); + const __m512i v0s = _mm512_cvtepi8_epi64(vs); + const __mmask8 cmp_mask = + ArithHelperI64::op(v0s, right_v, value_v); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int16_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Div || AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const __m512i right_v = _mm512_set1_epi64(right_operand); + const __m512i value_v = _mm512_set1_epi64(value); + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + const __m256i vs = _mm256_loadu_si256((const __m256i*)(src + i)); + const __m512i v0s = + _mm512_cvtepi16_epi64(_mm256_extracti128_si256(vs, 0)); + const __m512i v1s = + _mm512_cvtepi16_epi64(_mm256_extracti128_si256(vs, 1)); + const __mmask8 cmp_mask0 = + ArithHelperI64::op(v0s, right_v, value_v); + const __mmask8 cmp_mask1 = + ArithHelperI64::op(v1s, right_v, value_v); + + res_u8[i / 8 + 0] = cmp_mask0; + res_u8[i / 8 + 1] = cmp_mask1; + } + + // process leftovers + if (size16 != size) { + // process 8 elements + const __m128i vs = _mm_loadu_si128((const __m128i*)(src + size16)); + const __m512i v0s = _mm512_cvtepi16_epi64(vs); + const __mmask8 cmp_mask = + ArithHelperI64::op(v0s, right_v, value_v); + + res_u8[size16 / 8] = cmp_mask; + } + + return true; + } +} + +template +bool +OpArithCompareImpl::op_arith_compare( + uint8_t* const __restrict res_u8, + const int32_t* const __restrict src, + const ArithHighPrecisionType& right_operand, + const ArithHighPrecisionType& value, + const size_t size) { + if constexpr (AOp == ArithOpType::Div || AOp == ArithOpType::Mod) { + return false; + } else { + // the restriction of the API + assert((size % 8) == 0); + static_assert(std::is_same_v>); + + // + const __m512i right_v = _mm512_set1_epi64(right_operand); + const __m512i value_v = _mm512_set1_epi64(value); + + // todo: aligned reads & writes + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = 0; i < size16; i += 16) { + 
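+            // widen 16 x int32 into two 8 x int64 vectors so the arithmetic
+            // and the compare run at ArithHighPrecisionType (int64_t)
+            // precision, yielding two bytes of the bitmask per iteration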
+            const __m512i vs = _mm512_loadu_si512((const __m512i*)(src + i));
+            const __m512i v0s =
+                _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(vs, 0));
+            const __m512i v1s =
+                _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(vs, 1));
+            const __mmask8 cmp_mask0 =
+                ArithHelperI64<AOp, CmpOp>::op(v0s, right_v, value_v);
+            const __mmask8 cmp_mask1 =
+                ArithHelperI64<AOp, CmpOp>::op(v1s, right_v, value_v);
+
+            res_u8[i / 8 + 0] = cmp_mask0;
+            res_u8[i / 8 + 1] = cmp_mask1;
+        }
+
+        // process leftovers
+        if (size16 != size) {
+            // process 8 elements
+            const __m256i vs =
+                _mm256_loadu_si256((const __m256i*)(src + size16));
+            const __m512i v0s = _mm512_cvtepi32_epi64(vs);
+            const __mmask8 cmp_mask =
+                ArithHelperI64<AOp, CmpOp>::op(v0s, right_v, value_v);
+
+            res_u8[size16 / 8] = cmp_mask;
+        }
+
+        return true;
+    }
+}
+
+template <ArithOpType AOp, CompareOpType CmpOp>
+bool
+OpArithCompareImpl<int64_t, AOp, CmpOp>::op_arith_compare(
+    uint8_t* const __restrict res_u8,
+    const int64_t* const __restrict src,
+    const ArithHighPrecisionType<int64_t>& right_operand,
+    const ArithHighPrecisionType<int64_t>& value,
+    const size_t size) {
+    if constexpr (AOp == ArithOpType::Div || AOp == ArithOpType::Mod) {
+        return false;
+    } else {
+        // the restriction of the API
+        assert((size % 8) == 0);
+        static_assert(
+            std::is_same_v<int64_t, ArithHighPrecisionType<int64_t>>);
+
+        //
+        const __m512i right_v = _mm512_set1_epi64(right_operand);
+        const __m512i value_v = _mm512_set1_epi64(value);
+
+        // todo: aligned reads & writes
+
+        // process big blocks
+        const size_t size8 = (size / 8) * 8;
+        for (size_t i = 0; i < size8; i += 8) {
+            const __m512i v0s = _mm512_loadu_si512((const __m512i*)(src + i));
+            const __mmask8 cmp_mask =
+                ArithHelperI64<AOp, CmpOp>::op(v0s, right_v, value_v);
+
+            res_u8[i / 8] = cmp_mask;
+        }
+
+        return true;
+    }
+}
+
+template <ArithOpType AOp, CompareOpType CmpOp>
+bool
+OpArithCompareImpl<float, AOp, CmpOp>::op_arith_compare(
+    uint8_t* const __restrict res_u8,
+    const float* const __restrict src,
+    const ArithHighPrecisionType<float>& right_operand,
+    const ArithHighPrecisionType<float>& value,
+    const size_t size) {
+    if constexpr (AOp == ArithOpType::Mod) {
+        return false;
+    } else {
+        // the restriction of the API
+        assert((size % 8) == 0);
+
+        //
+        const __m512 right_v = _mm512_set1_ps(right_operand);
+        const __m512 value_v = _mm512_set1_ps(value);
+        uint16_t* const __restrict res_u16 =
+            reinterpret_cast<uint16_t*>(res_u8);
+
+        // todo: aligned reads & writes
+
+        // process big blocks
+        const size_t size16 = (size / 16) * 16;
+        for (size_t i = 0; i < size16; i += 16) {
+            const __m512 v0s = _mm512_loadu_ps(src + i);
+            const __mmask16 cmp_mask =
+                ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
+            res_u16[i / 16] = cmp_mask;
+        }
+
+        // process leftovers
+        if (size16 != size) {
+            // process 8 elements
+            const __m256 vs = _mm256_loadu_ps(src + size16);
+            const __m512 v0s = _mm512_castps256_ps512(vs);
+            const __mmask16 cmp_mask =
+                ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
+            res_u8[size16 / 8] = uint8_t(cmp_mask);
+        }
+
+        return true;
+    }
+}
+
+template <ArithOpType AOp, CompareOpType CmpOp>
+bool
+OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
+    uint8_t* const __restrict res_u8,
+    const double* const __restrict src,
+    const ArithHighPrecisionType<double>& right_operand,
+    const ArithHighPrecisionType<double>& value,
+    const size_t size) {
+    if constexpr (AOp == ArithOpType::Mod) {
+        return false;
+    } else {
+        // the restriction of the API
+        assert((size % 8) == 0);
+
+        //
+        const __m512d right_v = _mm512_set1_pd(right_operand);
+        const __m512d value_v = _mm512_set1_pd(value);
+
+        // todo: aligned reads & writes
+
+        // process big blocks
+        const size_t size8 = (size / 8) * 8;
+        for (size_t i = 0; i < size8; i += 8) {
+            const __m512d v0s = _mm512_loadu_pd(src + i);
+            const __mmask8 cmp_mask =
+                ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
+
+            res_u8[i / 8] = cmp_mask;
+        }
+
+        return true;
+    }
+}
+
+///////////////////////////////////////////////////////////////////////////
+
+} // namespace avx512
+} // namespace x86
+} // namespace detail
+} // namespace bitset
+} // namespace milvus
diff --git a/internal/core/src/bitset/detail/platform/x86/avx512-inst.cpp b/internal/core/src/bitset/detail/platform/x86/avx512-inst.cpp
new file mode 100644
index 000000000000..d8c4fd046eb4
--- /dev/null
+++ b/internal/core/src/bitset/detail/platform/x86/avx512-inst.cpp
@@ -0,0 +1,199 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// AVX512 instantiation
+
+#include "bitset/common.h"
+
+#ifndef BITSET_HEADER_ONLY
+
+#include "avx512-decl.h"
+#include "avx512-impl.h"
+
+#include
+#include
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+namespace x86 {
+namespace avx512 {
+
+// a facility to run through all possible compare operations
+#define ALL_COMPARE_OPS(FUNC, ...) \
+    FUNC(__VA_ARGS__, EQ);         \
+    FUNC(__VA_ARGS__, GE);         \
+    FUNC(__VA_ARGS__, GT);         \
+    FUNC(__VA_ARGS__, LE);         \
+    FUNC(__VA_ARGS__, LT);         \
+    FUNC(__VA_ARGS__, NE);
+
+// a facility to run through all possible range operations
+#define ALL_RANGE_OPS(FUNC, ...) \
+    FUNC(__VA_ARGS__, IncInc);   \
+    FUNC(__VA_ARGS__, IncExc);   \
+    FUNC(__VA_ARGS__, ExcInc);   \
+    FUNC(__VA_ARGS__, ExcExc);
+
+// a facility to run through all possible arithmetic compare operations
+#define ALL_ARITH_CMP_OPS(FUNC, ...) \
+    FUNC(__VA_ARGS__, Add, EQ);      \
+    FUNC(__VA_ARGS__, Add, GE);      \
+    FUNC(__VA_ARGS__, Add, GT);      \
+    FUNC(__VA_ARGS__, Add, LE);      \
+    FUNC(__VA_ARGS__, Add, LT);      \
+    FUNC(__VA_ARGS__, Add, NE);      \
+    FUNC(__VA_ARGS__, Sub, EQ);      \
+    FUNC(__VA_ARGS__, Sub, GE);      \
+    FUNC(__VA_ARGS__, Sub, GT);      \
+    FUNC(__VA_ARGS__, Sub, LE);      \
+    FUNC(__VA_ARGS__, Sub, LT);      \
+    FUNC(__VA_ARGS__, Sub, NE);      \
+    FUNC(__VA_ARGS__, Mul, EQ);      \
+    FUNC(__VA_ARGS__, Mul, GE);      \
+    FUNC(__VA_ARGS__, Mul, GT);      \
+    FUNC(__VA_ARGS__, Mul, LE);      \
+    FUNC(__VA_ARGS__, Mul, LT);      \
+    FUNC(__VA_ARGS__, Mul, NE);      \
+    FUNC(__VA_ARGS__, Div, EQ);      \
+    FUNC(__VA_ARGS__, Div, GE);      \
+    FUNC(__VA_ARGS__, Div, GT);      \
+    FUNC(__VA_ARGS__, Div, LE);      \
+    FUNC(__VA_ARGS__, Div, LT);      \
+    FUNC(__VA_ARGS__, Div, NE);      \
+    FUNC(__VA_ARGS__, Mod, EQ);      \
+    FUNC(__VA_ARGS__, Mod, GE);      \
+    FUNC(__VA_ARGS__, Mod, GT);      \
+    FUNC(__VA_ARGS__, Mod, LE);      \
+    FUNC(__VA_ARGS__, Mod, LT);      \
+    FUNC(__VA_ARGS__, Mod, NE);
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#define INSTANTIATE_COMPARE_VAL_AVX512(TTYPE, OP)                             \
+    template bool OpCompareValImpl<TTYPE, CompareOpType::OP>::op_compare_val( \
+        uint8_t* const __restrict bitmask,                                    \
+        const TTYPE* const __restrict src,                                    \
+        const size_t size,                                                    \
+        const TTYPE& val);
+
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX512, int8_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX512, int16_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX512, int32_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX512, int64_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX512, float)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_VAL_AVX512, double)
+
+#undef INSTANTIATE_COMPARE_VAL_AVX512
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#define INSTANTIATE_COMPARE_COLUMN_AVX512(TTYPE, OP)                      \
+    template bool                                                         \
+    OpCompareColumnImpl<TTYPE, TTYPE, CompareOpType::OP>::op_compare_column( \
+        uint8_t* const __restrict bitmask,                                \
+        const TTYPE* const __restrict left,                               \
+        const TTYPE* const __restrict right,                              \
+        const size_t size);
+
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX512, int8_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX512, int16_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX512, int32_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX512, int64_t)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX512, float)
+ALL_COMPARE_OPS(INSTANTIATE_COMPARE_COLUMN_AVX512, double)
+
+#undef INSTANTIATE_COMPARE_COLUMN_AVX512
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#define INSTANTIATE_WITHIN_RANGE_COLUMN_AVX512(TTYPE, OP)                     \
+    template bool                                                             \
+    OpWithinRangeColumnImpl<TTYPE, RangeType::OP>::op_within_range_column(    \
+        uint8_t* const __restrict res_u8,                                     \
+        const TTYPE* const __restrict lower,                                  \
+        const TTYPE* const __restrict upper,                                  \
+        const TTYPE* const __restrict values,                                 \
+        const size_t size);
+
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX512, int8_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX512, int16_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX512, int32_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX512, int64_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX512, float)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_COLUMN_AVX512, double)
+
+#undef INSTANTIATE_WITHIN_RANGE_COLUMN_AVX512
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#define INSTANTIATE_WITHIN_RANGE_VAL_AVX512(TTYPE, OP)                        \
+    template bool                                                             \
+    OpWithinRangeValImpl<TTYPE, RangeType::OP>::op_within_range_val(          \
+        uint8_t* const __restrict res_u8,                                     \
+        const TTYPE& lower,                                                   \
+        const TTYPE& upper,                                                   \
+        const TTYPE* const __restrict values,                                 \
+        const size_t size);
+
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX512, int8_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX512, int16_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX512, int32_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX512, int64_t)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX512, float)
+ALL_RANGE_OPS(INSTANTIATE_WITHIN_RANGE_VAL_AVX512, double)
+
+#undef INSTANTIATE_WITHIN_RANGE_VAL_AVX512
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#define INSTANTIATE_ARITH_COMPARE_AVX512(TTYPE, OP, CMP)                      \
+    template bool                                                             \
+    OpArithCompareImpl<TTYPE, ArithOpType::OP, CompareOpType::CMP>::          \
+        op_arith_compare(uint8_t* const __restrict res_u8,                    \
+                         const TTYPE* const __restrict src,                   \
+                         const ArithHighPrecisionType<TTYPE>& right_operand,  \
+                         const ArithHighPrecisionType<TTYPE>& value,          \
+                         const size_t size);
+
+ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX512, int8_t)
+ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX512, int16_t)
+ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX512, int32_t)
+ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX512, int64_t)
+ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX512, float)
+ALL_ARITH_CMP_OPS(INSTANTIATE_ARITH_COMPARE_AVX512, double)
+
+#undef INSTANTIATE_ARITH_COMPARE_AVX512
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+#undef ALL_COMPARE_OPS
+#undef ALL_RANGE_OPS
+#undef ALL_ARITH_CMP_OPS
+
+} // namespace avx512
+} // namespace x86
+} // namespace detail
+} // namespace bitset
+} // namespace milvus
+
+#endif
diff --git a/internal/core/src/bitset/detail/platform/x86/avx512.h b/internal/core/src/bitset/detail/platform/x86/avx512.h
new file mode 100644
index 000000000000..2582efd7c380
--- /dev/null
+++ b/internal/core/src/bitset/detail/platform/x86/avx512.h
@@ -0,0 +1,63 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
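For orientation: every kernel instantiated above evaluates, per element, the predicate (src[i] ARITH right_operand) CMP value and packs the results LSB-first into bytes, matching the __mmask8 stores in the AVX512 bodies. A scalar reference for the Add/EQ case follows; the function name is hypothetical and added only for illustration, it is not part of this patch.

#include <cstddef>
#include <cstdint>

// Scalar equivalent of the Add/EQ arith-compare kernel:
// bit j of res_u8[i / 8] is set iff (src[i + j] + right_operand) == value.
// Assumes the same API restriction as above: size % 8 == 0.
template <typename T>
void
arith_compare_add_eq_ref(uint8_t* const res_u8,
                         const T* const src,
                         const T right_operand,
                         const T value,
                         const size_t size) {
    for (size_t i = 0; i < size; i += 8) {
        uint8_t mask = 0;
        for (size_t j = 0; j < 8; j++) {
            if (src[i + j] + right_operand == value) {
                mask |= uint8_t(1 << j);  // LSB-first, like __mmask8
            }
        }
        res_u8[i / 8] = mask;
    }
}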
+
+#pragma once
+
+#include
+#include
+
+#include "bitset/common.h"
+
+#include "avx512-decl.h"
+
+#ifdef BITSET_HEADER_ONLY
+#include "avx512-impl.h"
+#endif
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+namespace x86 {
+
+///////////////////////////////////////////////////////////////////////////
+
+//
+struct VectorizedAvx512 {
+    template <typename T, typename U, CompareOpType Op>
+    static constexpr inline auto op_compare_column =
+        avx512::OpCompareColumnImpl<T, U, Op>::op_compare_column;
+
+    template <typename T, CompareOpType Op>
+    static constexpr inline auto op_compare_val =
+        avx512::OpCompareValImpl<T, Op>::op_compare_val;
+
+    template <typename T, RangeType Op>
+    static constexpr inline auto op_within_range_column =
+        avx512::OpWithinRangeColumnImpl<T, Op>::op_within_range_column;
+
+    template <typename T, RangeType Op>
+    static constexpr inline auto op_within_range_val =
+        avx512::OpWithinRangeValImpl<T, Op>::op_within_range_val;
+
+    template <typename T, ArithOpType AOp, CompareOpType CmpOp>
+    static constexpr inline auto op_arith_compare =
+        avx512::OpArithCompareImpl<T, AOp, CmpOp>::op_arith_compare;
+};
+
+} // namespace x86
+} // namespace detail
+} // namespace bitset
+} // namespace milvus
diff --git a/internal/core/src/bitset/detail/platform/x86/common.h b/internal/core/src/bitset/detail/platform/x86/common.h
new file mode 100644
index 000000000000..9bedb78c320f
--- /dev/null
+++ b/internal/core/src/bitset/detail/platform/x86/common.h
@@ -0,0 +1,73 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+
+#include
+
+#include "bitset/common.h"
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+namespace x86 {
+
+//
+template <typename T, CompareOpType Op>
+struct ComparePredicate {};
+
+template <typename T>
+struct ComparePredicate<T, CompareOpType::EQ> {
+    static inline constexpr int value =
+        std::is_floating_point_v<T> ? _CMP_EQ_OQ : _MM_CMPINT_EQ;
+};
+
+template <typename T>
+struct ComparePredicate<T, CompareOpType::LT> {
+    static inline constexpr int value =
+        std::is_floating_point_v<T> ? _CMP_LT_OQ : _MM_CMPINT_LT;
+};
+
+template <typename T>
+struct ComparePredicate<T, CompareOpType::LE> {
+    static inline constexpr int value =
+        std::is_floating_point_v<T> ? _CMP_LE_OQ : _MM_CMPINT_LE;
+};
+
+template <typename T>
+struct ComparePredicate<T, CompareOpType::GT> {
+    static inline constexpr int value =
+        std::is_floating_point_v<T> ? _CMP_GT_OQ : _MM_CMPINT_NLE;
+};
+
+template <typename T>
+struct ComparePredicate<T, CompareOpType::GE> {
+    static inline constexpr int value =
+        std::is_floating_point_v<T> ? _CMP_GE_OQ : _MM_CMPINT_NLT;
+};
+
+template <typename T>
+struct ComparePredicate<T, CompareOpType::NE> {
+    static inline constexpr int value =
+        std::is_floating_point_v<T> ? _CMP_NEQ_OQ : _MM_CMPINT_NE;
+};
+
+} // namespace x86
+} // namespace detail
+} // namespace bitset
+} // namespace milvus
diff --git a/internal/core/src/bitset/detail/platform/x86/instruction_set.cpp b/internal/core/src/bitset/detail/platform/x86/instruction_set.cpp
new file mode 100644
index 000000000000..329dc4243cfa
--- /dev/null
+++ b/internal/core/src/bitset/detail/platform/x86/instruction_set.cpp
@@ -0,0 +1,139 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "instruction_set.h"
+
+#include
+
+#include
+#include
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+namespace x86 {
+
+InstructionSet::InstructionSet()
+    : nIds_{0},
+      nExIds_{0},
+      isIntel_{false},
+      isAMD_{false},
+      f_1_ECX_{0},
+      f_1_EDX_{0},
+      f_7_EBX_{0},
+      f_7_ECX_{0},
+      f_81_ECX_{0},
+      f_81_EDX_{0},
+      data_{},
+      extdata_{} {
+    std::array<int, 4> cpui;
+
+    // Calling __cpuid with 0x0 as the function_id argument
+    // gets the number of the highest valid function ID.
+    __cpuid(0, cpui[0], cpui[1], cpui[2], cpui[3]);
+    nIds_ = cpui[0];
+
+    for (int i = 0; i <= nIds_; ++i) {
+        __cpuid_count(i, 0, cpui[0], cpui[1], cpui[2], cpui[3]);
+        data_.push_back(cpui);
+    }
+
+    // Capture vendor string
+    char vendor[0x20];
+    memset(vendor, 0, sizeof(vendor));
+    *reinterpret_cast<int*>(vendor) = data_[0][1];
+    *reinterpret_cast<int*>(vendor + 4) = data_[0][3];
+    *reinterpret_cast<int*>(vendor + 8) = data_[0][2];
+    vendor_ = vendor;
+    if (vendor_ == "GenuineIntel") {
+        isIntel_ = true;
+    } else if (vendor_ == "AuthenticAMD") {
+        isAMD_ = true;
+    }
+
+    // load bitset with flags for function 0x00000001
+    if (nIds_ >= 1) {
+        f_1_ECX_ = data_[1][2];
+        f_1_EDX_ = data_[1][3];
+    }
+
+    // load bitset with flags for function 0x00000007
+    if (nIds_ >= 7) {
+        f_7_EBX_ = data_[7][1];
+        f_7_ECX_ = data_[7][2];
+    }
+
+    // Calling __cpuid with 0x80000000 as the function_id argument
+    // gets the number of the highest valid extended ID.
+ __cpuid(0x80000000, cpui[0], cpui[1], cpui[2], cpui[3]); + nExIds_ = cpui[0]; + + char brand[0x40]; + memset(brand, 0, sizeof(brand)); + + for (int i = 0x80000000; i <= nExIds_; ++i) { + __cpuid_count(i, 0, cpui[0], cpui[1], cpui[2], cpui[3]); + extdata_.push_back(cpui); + } + + // load bitset with flags for function 0x80000001 + if (nExIds_ >= (int)0x80000001) { + f_81_ECX_ = extdata_[1][2]; + f_81_EDX_ = extdata_[1][3]; + } + + // Interpret CPU brand string if reported + if (nExIds_ >= (int)0x80000004) { + memcpy(brand, extdata_[2].data(), sizeof(cpui)); + memcpy(brand + 16, extdata_[3].data(), sizeof(cpui)); + memcpy(brand + 32, extdata_[4].data(), sizeof(cpui)); + brand_ = brand; + } +}; + +// +bool +cpu_support_avx512() { + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.AVX512F() && instruction_set_inst.AVX512DQ() && + instruction_set_inst.AVX512BW() && instruction_set_inst.AVX512VL()); +} + +// +bool +cpu_support_avx2() { + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.AVX2()); +} + +// +bool +cpu_support_sse4_2() { + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.SSE42()); +} + +// +bool +cpu_support_sse2() { + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.SSE2()); +} + +} // namespace x86 +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/simd/instruction_set.h b/internal/core/src/bitset/detail/platform/x86/instruction_set.h similarity index 54% rename from internal/core/src/simd/instruction_set.h rename to internal/core/src/bitset/detail/platform/x86/instruction_set.h index a80686d1603b..92ab309c9514 100644 --- a/internal/core/src/simd/instruction_set.h +++ b/internal/core/src/bitset/detail/platform/x86/instruction_set.h @@ -1,27 +1,30 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
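The cpu_support_avx512() helper above requires all four AVX512 subsets (F, DQ, BW, VL) that the AVX512 kernels rely on. A sketch of how a caller might pick a kernel level once at startup; the real selection lives in detail/platform/dynamic.cpp, which is not part of this hunk, so the enum and function below are illustrative only:

#include "instruction_set.h"

// Hypothetical one-time selection of the widest usable x86 kernel set.
enum class SimdLevel { Scalar, Avx2, Avx512 };

static SimdLevel
detect_simd_level() {
    using namespace milvus::bitset::detail::x86;
    if (cpu_support_avx512()) {
        // AVX512 F + DQ + BW + VL are all present
        return SimdLevel::Avx512;
    }
    if (cpu_support_avx2()) {
        return SimdLevel::Avx2;
    }
    return SimdLevel::Scalar;
}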
#pragma once -#include - #include #include -#include -#include #include #include namespace milvus { -namespace simd { +namespace bitset { +namespace detail { +namespace x86 { class InstructionSet { public: @@ -32,83 +35,7 @@ class InstructionSet { } private: - InstructionSet() - : nIds_{0}, - nExIds_{0}, - isIntel_{false}, - isAMD_{false}, - f_1_ECX_{0}, - f_1_EDX_{0}, - f_7_EBX_{0}, - f_7_ECX_{0}, - f_81_ECX_{0}, - f_81_EDX_{0}, - data_{}, - extdata_{} { - std::array cpui; - - // Calling __cpuid with 0x0 as the function_id argument - // gets the number of the highest valid function ID. - __cpuid(0, cpui[0], cpui[1], cpui[2], cpui[3]); - nIds_ = cpui[0]; - - for (int i = 0; i <= nIds_; ++i) { - __cpuid_count(i, 0, cpui[0], cpui[1], cpui[2], cpui[3]); - data_.push_back(cpui); - } - - // Capture vendor string - char vendor[0x20]; - memset(vendor, 0, sizeof(vendor)); - *reinterpret_cast(vendor) = data_[0][1]; - *reinterpret_cast(vendor + 4) = data_[0][3]; - *reinterpret_cast(vendor + 8) = data_[0][2]; - vendor_ = vendor; - if (vendor_ == "GenuineIntel") { - isIntel_ = true; - } else if (vendor_ == "AuthenticAMD") { - isAMD_ = true; - } - - // load bitset with flags for function 0x00000001 - if (nIds_ >= 1) { - f_1_ECX_ = data_[1][2]; - f_1_EDX_ = data_[1][3]; - } - - // load bitset with flags for function 0x00000007 - if (nIds_ >= 7) { - f_7_EBX_ = data_[7][1]; - f_7_ECX_ = data_[7][2]; - } - - // Calling __cpuid with 0x80000000 as the function_id argument - // gets the number of the highest valid extended ID. - __cpuid(0x80000000, cpui[0], cpui[1], cpui[2], cpui[3]); - nExIds_ = cpui[0]; - - char brand[0x40]; - memset(brand, 0, sizeof(brand)); - - for (int i = 0x80000000; i <= nExIds_; ++i) { - __cpuid_count(i, 0, cpui[0], cpui[1], cpui[2], cpui[3]); - extdata_.push_back(cpui); - } - - // load bitset with flags for function 0x80000001 - if (nExIds_ >= (int)0x80000001) { - f_81_ECX_ = extdata_[1][2]; - f_81_EDX_ = extdata_[1][3]; - } - - // Interpret CPU brand string if reported - if (nExIds_ >= (int)0x80000004) { - memcpy(brand, extdata_[2].data(), sizeof(cpui)); - memcpy(brand + 16, extdata_[3].data(), sizeof(cpui)); - memcpy(brand + 32, extdata_[4].data(), sizeof(cpui)); - brand_ = brand; - } - }; + InstructionSet(); public: // getters @@ -348,21 +275,32 @@ class InstructionSet { } private: - int nIds_; - int nExIds_; + int nIds_ = 0; + int nExIds_ = 0; std::string vendor_; std::string brand_; - bool isIntel_; - bool isAMD_; - std::bitset<32> f_1_ECX_; - std::bitset<32> f_1_EDX_; - std::bitset<32> f_7_EBX_; - std::bitset<32> f_7_ECX_; - std::bitset<32> f_81_ECX_; - std::bitset<32> f_81_EDX_; + bool isIntel_ = false; + bool isAMD_ = false; + std::bitset<32> f_1_ECX_ = {0}; + std::bitset<32> f_1_EDX_ = {0}; + std::bitset<32> f_7_EBX_ = {0}; + std::bitset<32> f_7_ECX_ = {0}; + std::bitset<32> f_81_ECX_ = {0}; + std::bitset<32> f_81_EDX_ = {0}; std::vector> data_; std::vector> extdata_; }; -} // namespace simd -} // namespace milvus \ No newline at end of file +bool +cpu_support_avx512(); +bool +cpu_support_avx2(); +bool +cpu_support_sse4_2(); +bool +cpu_support_sse2(); + +} // namespace x86 +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/popcount.h b/internal/core/src/bitset/detail/popcount.h new file mode 100644 index 000000000000..05789d437049 --- /dev/null +++ b/internal/core/src/bitset/detail/popcount.h @@ -0,0 +1,64 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+
+//
+template <typename T>
+struct PopCountHelper {};
+
+//
+template <>
+struct PopCountHelper<unsigned long long> {
+    static inline unsigned long long
+    count(const unsigned long long v) {
+        return __builtin_popcountll(v);
+    }
+};
+
+template <>
+struct PopCountHelper<unsigned long> {
+    static inline unsigned long
+    count(const unsigned long v) {
+        return __builtin_popcountl(v);
+    }
+};
+
+template <>
+struct PopCountHelper<unsigned int> {
+    static inline unsigned int
+    count(const unsigned int v) {
+        return __builtin_popcount(v);
+    }
+};
+
+template <>
+struct PopCountHelper<uint8_t> {
+    static inline uint8_t
+    count(const uint8_t v) {
+        return __builtin_popcount(v);
+    }
+};
+
+} // namespace detail
+} // namespace bitset
+} // namespace milvus
diff --git a/internal/core/src/bitset/detail/proxy.h b/internal/core/src/bitset/detail/proxy.h
new file mode 100644
index 000000000000..efcdc0994e57
--- /dev/null
+++ b/internal/core/src/bitset/detail/proxy.h
@@ -0,0 +1,133 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
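A short usage sketch for the PopCountHelper above; it assumes an LP64 target, where uint64_t converts losslessly to unsigned long long, and the wrapper name is illustrative rather than part of the patch:

#include <cstdint>

#include "popcount.h"

// Count the set bits of a single bitset word via the matching
// __builtin_popcount* overload selected by the template parameter.
inline uint64_t
count_word(const uint64_t word) {
    return milvus::bitset::detail::PopCountHelper<unsigned long long>::count(
        word);
}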
+
+#pragma once
+
+namespace milvus {
+namespace bitset {
+namespace detail {
+
+template <typename PolicyT>
+struct ConstProxy {
+    using policy_type = PolicyT;
+    using size_type = typename policy_type::size_type;
+    using data_type = typename policy_type::data_type;
+    using self_type = ConstProxy;
+
+    const data_type& element;
+    data_type mask;
+
+    inline ConstProxy(const data_type& _element, const size_type _shift)
+        : element{_element} {
+        mask = (data_type(1) << _shift);
+    }
+
+    inline operator bool() const {
+        return ((element & mask) != 0);
+    }
+    inline bool
+    operator~() const {
+        return ((element & mask) == 0);
+    }
+};
+
+template <typename PolicyT>
+struct Proxy {
+    using policy_type = PolicyT;
+    using size_type = typename policy_type::size_type;
+    using data_type = typename policy_type::data_type;
+    using self_type = Proxy;
+
+    data_type& element;
+    data_type mask;
+
+    inline Proxy(data_type& _element, const size_type _shift)
+        : element{_element} {
+        mask = (data_type(1) << _shift);
+    }
+
+    inline operator bool() const {
+        return ((element & mask) != 0);
+    }
+    inline bool
+    operator~() const {
+        return ((element & mask) == 0);
+    }
+
+    inline self_type&
+    operator=(const bool value) {
+        if (value) {
+            set();
+        } else {
+            reset();
+        }
+        return *this;
+    }
+
+    inline self_type&
+    operator=(const self_type& other) {
+        bool value = other.operator bool();
+        if (value) {
+            set();
+        } else {
+            reset();
+        }
+        return *this;
+    }
+
+    inline self_type&
+    operator|=(const bool value) {
+        if (value) {
+            set();
+        }
+        return *this;
+    }
+
+    inline self_type&
+    operator&=(const bool value) {
+        if (!value) {
+            reset();
+        }
+        return *this;
+    }
+
+    inline self_type&
+    operator^=(const bool value) {
+        if (value) {
+            flip();
+        }
+        return *this;
+    }
+
+    inline void
+    set() {
+        element |= mask;
+    }
+
+    inline void
+    reset() {
+        element &= ~mask;
+    }
+
+    inline void
+    flip() {
+        element ^= mask;
+    }
+};
+
+} // namespace detail
+} // namespace bitset
+} // namespace milvus
diff --git a/internal/core/src/bitset/readme.txt b/internal/core/src/bitset/readme.txt
new file mode 100644
index 000000000000..95e0e82d96e5
--- /dev/null
+++ b/internal/core/src/bitset/readme.txt
@@ -0,0 +1 @@
+The standalone version of the bitset library is available at https://github.com/alexanderguzhva/bitset
diff --git a/internal/core/src/common/BitsetView.h b/internal/core/src/common/BitsetView.h
index dc0e9d8a5988..3d1a75be6c92 100644
--- a/internal/core/src/common/BitsetView.h
+++ b/internal/core/src/common/BitsetView.h
@@ -41,8 +41,7 @@ class BitsetView : public knowhere::BitsetView {
     }
 
     BitsetView(const BitsetType& bitset)  // NOLINT
-        : BitsetView((uint8_t*)boost_ext::get_data(bitset),
-                     size_t(bitset.size())) {
+        : BitsetView((uint8_t*)(bitset.data()), size_t(bitset.size())) {
     }
 
     BitsetView(const BitsetTypePtr& bitset_ptr) {  // NOLINT
diff --git a/internal/core/src/common/CMakeLists.txt b/internal/core/src/common/CMakeLists.txt
index 3412651778d7..4330b43f8099 100644
--- a/internal/core/src/common/CMakeLists.txt
+++ b/internal/core/src/common/CMakeLists.txt
@@ -29,6 +29,7 @@ set(COMMON_SRC
 add_library(milvus_common SHARED ${COMMON_SRC})
 
 target_link_libraries(milvus_common
+  milvus_bitset
   milvus_config
   milvus_log
   milvus_proto
diff --git a/internal/core/src/common/CustomBitset.h b/internal/core/src/common/CustomBitset.h
new file mode 100644
index 000000000000..476df245ed97
--- /dev/null
+++ b/internal/core/src/common/CustomBitset.h
@@ -0,0 +1,48 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include + +#include "bitset/bitset.h" +#include "bitset/common.h" +#include "bitset/detail/element_vectorized.h" +#include "bitset/detail/platform/dynamic.h" + +namespace milvus { + +namespace { + +using vectorized_type = milvus::bitset::detail::VectorizedDynamic; +using policy_type = + milvus::bitset::detail::VectorizedElementWiseBitsetPolicy; +using container_type = folly::fbvector; +// temporary enable range check +using bitset_type = milvus::bitset::Bitset; +// temporary enable range check +using bitset_view = milvus::bitset::BitsetView; + +} // namespace + +using CustomBitset = bitset_type; +using CustomBitsetView = bitset_view; + +} // namespace milvus diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index ef7adc587215..21c757c9c557 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -46,6 +46,8 @@ #include "pb/segcore.pb.h" #include "Json.h" +#include "CustomBitset.h" + namespace milvus { using idx_t = int64_t; @@ -158,8 +160,9 @@ using OptFieldT = std::unordered_map< using SegOffset = fluent::NamedType; -using BitsetType = boost::dynamic_bitset<>; -using BitsetTypePtr = std::shared_ptr>; +//using BitsetType = boost::dynamic_bitset<>; +using BitsetType = CustomBitset; +using BitsetTypePtr = std::shared_ptr; using BitsetTypeOpt = std::optional; template @@ -167,7 +170,10 @@ using FixedVector = folly::fbvector< Type>; // boost::container::vector has memory leak when version > 1.79, so use folly::fbvector instead using Config = nlohmann::json; -using TargetBitmap = FixedVector; +//using TargetBitmap = std::vector; +//using TargetBitmapPtr = std::unique_ptr; +using TargetBitmap = CustomBitset; +using TargetBitmapView = CustomBitsetView; using TargetBitmapPtr = std::unique_ptr; using BinaryPtr = knowhere::BinaryPtr; @@ -188,9 +194,9 @@ IndexIsSparse(const IndexType& index_type) { // Plus 1 because we can't use greater(>) symbol constexpr size_t REF_SIZE_THRESHOLD = 16 + 1; -using BitsetBlockType = BitsetType::block_type; -constexpr size_t BITSET_BLOCK_SIZE = sizeof(BitsetType::block_type); -constexpr size_t BITSET_BLOCK_BIT_SIZE = sizeof(BitsetType::block_type) * 8; +//using BitsetBlockType = BitsetType::block_type; +//constexpr size_t BITSET_BLOCK_SIZE = sizeof(BitsetType::block_type); +//constexpr size_t BITSET_BLOCK_BIT_SIZE = sizeof(BitsetType::block_type) * 8; template using MayConstRef = std::conditional_t || std::is_same_v, diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h index 0abaf9af60e5..dab66ffb18a3 100644 --- a/internal/core/src/common/Vector.h +++ b/internal/core/src/common/Vector.h @@ -68,10 +68,17 @@ class ColumnVector final : public BaseVector { values_ = InitScalarFieldData(data_type, length); } - ColumnVector(FixedVector&& data) - : 
BaseVector(DataType::BOOL, data.size()) { - values_ = - std::make_shared>(DataType::BOOL, std::move(data)); + // ColumnVector(FixedVector&& data) + // : BaseVector(DataType::BOOL, data.size()) { + // values_ = + // std::make_shared>(DataType::BOOL, std::move(data)); + // } + + // the size is the number of bits + ColumnVector(TargetBitmap&& bitmap) + : BaseVector(DataType::INT8, bitmap.size()) { + values_ = std::make_shared>( + bitmap.size(), DataType::INT8, std::move(bitmap).into()); } virtual ~ColumnVector() override { diff --git a/internal/core/src/exec/CMakeLists.txt b/internal/core/src/exec/CMakeLists.txt index 1573cf3e5702..9b1ca330c7bc 100644 --- a/internal/core/src/exec/CMakeLists.txt +++ b/internal/core/src/exec/CMakeLists.txt @@ -29,8 +29,5 @@ set(MILVUS_EXEC_SRCS ) add_library(milvus_exec STATIC ${MILVUS_EXEC_SRCS}) -if(USE_DYNAMIC_SIMD) - target_link_libraries(milvus_exec milvus_common milvus_simd milvus-storage ${CONAN_LIBS}) -else() - target_link_libraries(milvus_exec milvus_common milvus-storage ${CONAN_LIBS}) -endif() + +target_link_libraries(milvus_exec milvus_common milvus-storage ${CONAN_LIBS}) diff --git a/internal/core/src/exec/expression/AlwaysTrueExpr.cpp b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp index 85594e7e7023..24789c429ac8 100644 --- a/internal/core/src/exec/expression/AlwaysTrueExpr.cpp +++ b/internal/core/src/exec/expression/AlwaysTrueExpr.cpp @@ -31,11 +31,10 @@ PhyAlwaysTrueExpr::Eval(EvalCtx& context, VectorPtr& result) { } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res_bool = (bool*)res_vec->GetRawData(); - for (size_t i = 0; i < real_batch_size; ++i) { - res_bool[i] = true; - } + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + + res.set(); result = res_vec; current_pos_ += real_batch_size; diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp index 2ce0d8b2a9c3..bf944eb6e444 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.cpp @@ -112,8 +112,8 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { return nullptr; } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto op_type = expr_->op_type_; @@ -160,7 +160,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForJson() { auto execute_sub_batch = [op_type, arith_type](const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, ValueType val, ValueType right_operand, const std::string& pointer) { @@ -491,8 +491,8 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { return nullptr; } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); int index = -1; if (expr_->column_.nested_path_.size() > 0) { @@ -520,7 +520,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForArray() { auto execute_sub_batch = [op_type, arith_type](const ArrayView* data, const int size, - bool* res, + TargetBitmapView res, ValueType val, ValueType right_operand, int index) { @@ 
-836,7 +836,7 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForIndex() { Index* index_ptr, HighPrecisionType value, HighPrecisionType right_operand) { - FixedVector res; + TargetBitmap res; switch (op_type) { case proto::plan::OpType::Equal: { switch (arith_type) { @@ -1208,15 +1208,15 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData() { auto right_operand = GetValueFromProto(expr_->right_operand_); auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); auto op_type = expr_->op_type_; auto arith_type = expr_->arith_op_type_; auto execute_sub_batch = [op_type, arith_type]( const T* data, const int size, - bool* res, + TargetBitmapView res, HighPrecisionType value, HighPrecisionType right_operand) { switch (op_type) { diff --git a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h index 805d77c62d68..3c84819dc2b8 100644 --- a/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryArithOpEvalRangeExpr.h @@ -28,6 +28,64 @@ namespace milvus { namespace exec { +namespace { + +template +struct CmpOpHelper { + using op = void; +}; +template <> +struct CmpOpHelper { + static constexpr auto op = milvus::bitset::CompareOpType::EQ; +}; +template <> +struct CmpOpHelper { + static constexpr auto op = milvus::bitset::CompareOpType::GE; +}; +template <> +struct CmpOpHelper { + static constexpr auto op = milvus::bitset::CompareOpType::GT; +}; +template <> +struct CmpOpHelper { + static constexpr auto op = milvus::bitset::CompareOpType::LE; +}; +template <> +struct CmpOpHelper { + static constexpr auto op = milvus::bitset::CompareOpType::LT; +}; +template <> +struct CmpOpHelper { + static constexpr auto op = milvus::bitset::CompareOpType::NE; +}; + +template +struct ArithOpHelper { + using op = void; +}; +template <> +struct ArithOpHelper { + static constexpr auto op = milvus::bitset::ArithOpType::Add; +}; +template <> +struct ArithOpHelper { + static constexpr auto op = milvus::bitset::ArithOpType::Sub; +}; +template <> +struct ArithOpHelper { + static constexpr auto op = milvus::bitset::ArithOpType::Mul; +}; +template <> +struct ArithOpHelper { + static constexpr auto op = milvus::bitset::ArithOpType::Div; +}; +template <> +struct ArithOpHelper { + static constexpr auto op = milvus::bitset::ArithOpType::Mod; +}; + +} // namespace + template @@ -42,7 +100,9 @@ struct ArithOpElementFunc { size_t size, HighPrecisonType val, HighPrecisonType right_operand, - bool* res) { + TargetBitmapView res) { + /* + // This is the original code, kept here for the documentation purposes for (int i = 0; i < size; ++i) { if constexpr (cmp_op == proto::plan::OpType::Equal) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { @@ -178,6 +238,30 @@ struct ArithOpElementFunc { } } } + */ + + if constexpr (!std::is_same_v::op), + void>) { + constexpr auto cmp_op_cvt = CmpOpHelper::op; + if constexpr (!std::is_same_v::op), + void>) { + constexpr auto arith_op_cvt = ArithOpHelper::op; + + res.inplace_arith_compare( + src, right_operand, val, size); + } else { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported arith type:{} for ArithOpElementFunc", + arith_op)); + } + } else { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported cmp type:{} for ArithOpElementFunc", + cmp_op)); + } } }; @@ 
-191,13 +275,12 @@ struct ArithOpIndexFunc { T> HighPrecisonType; using Index = index::ScalarIndex; - FixedVector + TargetBitmap operator()(Index* index, size_t size, HighPrecisonType val, HighPrecisonType right_operand) { - FixedVector res_vec(size); - bool* res = res_vec.data(); + TargetBitmap res(size); for (size_t i = 0; i < size; ++i) { if constexpr (cmp_op == proto::plan::OpType::Equal) { if constexpr (arith_op == proto::plan::ArithOpType::Add) { @@ -339,7 +422,7 @@ struct ArithOpIndexFunc { } } } - return res_vec; + return res; } }; diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.cpp b/internal/core/src/exec/expression/BinaryRangeExpr.cpp index 953591a8961f..d7d916aade21 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.cpp +++ b/internal/core/src/exec/expression/BinaryRangeExpr.cpp @@ -144,7 +144,7 @@ PhyBinaryRangeFilterExpr::PreCheckOverflow(HighPrecisionType& val1, cached_overflow_res_->size() == batch_size) { return cached_overflow_res_; } - auto res = std::make_shared(DataType::BOOL, batch_size); + auto res = std::make_shared(TargetBitmap(batch_size)); return res; }; @@ -235,12 +235,13 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData() { return res; } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + auto execute_sub_batch = [lower_inclusive, upper_inclusive]( const T* data, const int size, - bool* res, + TargetBitmapView res, HighPrecisionType val1, HighPrecisionType val2) { if (lower_inclusive && upper_inclusive) { @@ -295,8 +296,8 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() { return nullptr; } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); bool lower_inclusive = expr_->lower_inclusive_; bool upper_inclusive = expr_->upper_inclusive_; @@ -307,7 +308,7 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForJson() { auto execute_sub_batch = [lower_inclusive, upper_inclusive, pointer]( const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, ValueType val1, ValueType val2) { if (lower_inclusive && upper_inclusive) { @@ -345,8 +346,8 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() { return nullptr; } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); bool lower_inclusive = expr_->lower_inclusive_; bool upper_inclusive = expr_->upper_inclusive_; @@ -360,7 +361,7 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForArray() { auto execute_sub_batch = [lower_inclusive, upper_inclusive]( const milvus::ArrayView* data, const int size, - bool* res, + TargetBitmapView res, ValueType val1, ValueType val2, int index) { diff --git a/internal/core/src/exec/expression/BinaryRangeExpr.h b/internal/core/src/exec/expression/BinaryRangeExpr.h index cb2b0b9e783b..6484a40e5ef1 100644 --- a/internal/core/src/exec/expression/BinaryRangeExpr.h +++ b/internal/core/src/exec/expression/BinaryRangeExpr.h @@ -35,17 +35,19 @@ struct BinaryRangeElementFunc { T> HighPrecisionType; void - operator()(T val1, T val2, const T* src, size_t n, bool* res) { - for (size_t i = 0; i < n; ++i) { - if constexpr 
(lower_inclusive && upper_inclusive) { - res[i] = val1 <= src[i] && src[i] <= val2; - } else if constexpr (lower_inclusive && !upper_inclusive) { - res[i] = val1 <= src[i] && src[i] < val2; - } else if constexpr (!lower_inclusive && upper_inclusive) { - res[i] = val1 < src[i] && src[i] <= val2; - } else { - res[i] = val1 < src[i] && src[i] < val2; - } + operator()(T val1, T val2, const T* src, size_t n, TargetBitmapView res) { + if constexpr (lower_inclusive && upper_inclusive) { + res.inplace_within_range_val( + val1, val2, src, n); + } else if constexpr (lower_inclusive && !upper_inclusive) { + res.inplace_within_range_val( + val1, val2, src, n); + } else if constexpr (!lower_inclusive && upper_inclusive) { + res.inplace_within_range_val( + val1, val2, src, n); + } else { + res.inplace_within_range_val( + val1, val2, src, n); } } }; @@ -80,7 +82,7 @@ struct BinaryRangeElementFuncForJson { const std::string& pointer, const milvus::Json* src, size_t n, - bool* res) { + TargetBitmapView res) { for (size_t i = 0; i < n; ++i) { if constexpr (lower_inclusive && upper_inclusive) { BinaryRangeJSONCompare(val1 <= value && value <= val2); @@ -106,7 +108,7 @@ struct BinaryRangeElementFuncForArray { int index, const milvus::ArrayView* src, size_t n, - bool* res) { + TargetBitmapView res) { for (size_t i = 0; i < n; ++i) { if constexpr (lower_inclusive && upper_inclusive) { if (index >= src[i].length()) { @@ -152,7 +154,7 @@ struct BinaryRangeIndexFunc { int64_t, IndexInnerType> HighPrecisionType; - FixedVector + TargetBitmap operator()(Index* index, IndexInnerType val1, IndexInnerType val2, diff --git a/internal/core/src/exec/expression/CompareExpr.cpp b/internal/core/src/exec/expression/CompareExpr.cpp index 73adcfe9e0dd..f9199d65938a 100644 --- a/internal/core/src/exec/expression/CompareExpr.cpp +++ b/internal/core/src/exec/expression/CompareExpr.cpp @@ -118,8 +118,8 @@ PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_); auto right_data_barrier = segment_->num_chunk_data(expr_->right_field_id_); @@ -257,14 +257,16 @@ PhyCompareFilterExpr::ExecCompareRightType() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + auto expr_type = expr_->op_type_; auto execute_sub_batch = [expr_type](const T* left, const U* right, const int size, - bool* res) { + TargetBitmapView res) { switch (expr_type) { case proto::plan::GreaterThan: { CompareElementFunc func; diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h index 392e6d21c735..ff6069665182 100644 --- a/internal/core/src/exec/expression/CompareExpr.h +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -24,7 +24,6 @@ #include "common/Vector.h" #include "exec/expression/Expr.h" #include "segcore/SegmentInterface.h" -#include "simd/interface.h" namespace milvus { namespace exec { @@ -42,7 +41,12 @@ using ChunkDataAccessor = std::function; template struct CompareElementFunc { void - operator_base(const T* left, const U* right, size_t size, bool* res) { + operator()(const 
T* left, + const U* right, + size_t size, + TargetBitmapView res) { + /* + // This is the original code, kept here for the documentation purposes for (int i = 0; i < size; ++i) { if constexpr (op == proto::plan::OpType::Equal) { res[i] = left[i] == right[i]; @@ -63,24 +67,31 @@ struct CompareElementFunc { op)); } } - } - - void - operator()(const T* left, const U* right, size_t size, bool* res) { -#if defined(USE_DYNAMIC_SIMD) - if constexpr (std::is_same_v) { - milvus::simd::compare_col_func( - static_cast(op), - left, - right, - size, - res); + */ + + if constexpr (op == proto::plan::OpType::Equal) { + res.inplace_compare_column( + left, right, size); + } else if constexpr (op == proto::plan::OpType::NotEqual) { + res.inplace_compare_column( + left, right, size); + } else if constexpr (op == proto::plan::OpType::GreaterThan) { + res.inplace_compare_column( + left, right, size); + } else if constexpr (op == proto::plan::OpType::LessThan) { + res.inplace_compare_column( + left, right, size); + } else if constexpr (op == proto::plan::OpType::GreaterEqual) { + res.inplace_compare_column( + left, right, size); + } else if constexpr (op == proto::plan::OpType::LessEqual) { + res.inplace_compare_column( + left, right, size); } else { - operator_base(left, right, size, res); + PanicInfo(OpTypeInvalid, + fmt::format( + "unsupported op_type:{} for CompareElementFunc", op)); } -#else - operator_base(left, right, size, res); -#endif } }; @@ -148,7 +159,7 @@ class PhyCompareFilterExpr : public Expr { template int64_t - ProcessBothDataChunks(FUNC func, bool* res, ValTypes... values) { + ProcessBothDataChunks(FUNC func, TargetBitmapView res, ValTypes... values) { int64_t processed_size = 0; for (size_t i = current_chunk_id_; i < num_chunk_; i++) { diff --git a/internal/core/src/exec/expression/ConjunctExpr.cpp b/internal/core/src/exec/expression/ConjunctExpr.cpp index a26b98dda783..da535d936d03 100644 --- a/internal/core/src/exec/expression/ConjunctExpr.cpp +++ b/internal/core/src/exec/expression/ConjunctExpr.cpp @@ -15,7 +15,6 @@ // limitations under the License. 
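The expression rewrites in this patch all follow the pattern visible above: wrap the ColumnVector's raw storage in a TargetBitmapView and call the bitset's inplace_* and bulk primitives instead of hand-written bool loops. A minimal end-to-end sketch using only operations that appear elsewhere in this diff (construction, set(), per-bit proxy assignment, all(), flip(), none()); the include path and free function are assumptions for illustration:

#include "common/CustomBitset.h"

void
bitset_view_demo() {
    milvus::TargetBitmap bits(64);  // 64 bits, initially all false
    milvus::TargetBitmapView view(bits.data(), bits.size());

    view.set();                      // all true, as in AllSet()
    view[3] = false;                 // per-bit proxy assignment
    const bool whole = view.all();   // false: bit 3 is cleared
    view.flip();                     // logical NOT, as in PhyLogicalUnaryExpr
    const bool empty = view.none();  // false: bit 3 is now the only set bit
    (void)whole;
    (void)empty;
}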
#include "ConjunctExpr.h" -#include "simd/hook.h" namespace milvus { namespace exec { @@ -39,48 +38,26 @@ PhyConjunctFilterExpr::ResolveType(const std::vector& inputs) { static bool AllTrue(ColumnVectorPtr& vec) { - bool* data = static_cast(vec->GetRawData()); -#if defined(USE_DYNAMIC_SIMD) - return milvus::simd::all_true(data, vec->size()); -#else - for (int i = 0; i < vec->size(); ++i) { - if (!data[i]) { - return false; - } - } - return true; -#endif + TargetBitmapView data(vec->GetRawData(), vec->size()); + return data.all(); } static void AllSet(ColumnVectorPtr& vec) { - bool* data = static_cast(vec->GetRawData()); - for (int i = 0; i < vec->size(); ++i) { - data[i] = true; - } + TargetBitmapView data(vec->GetRawData(), vec->size()); + data.set(); } static void AllReset(ColumnVectorPtr& vec) { - bool* data = static_cast(vec->GetRawData()); - for (int i = 0; i < vec->size(); ++i) { - data[i] = false; - } + TargetBitmapView data(vec->GetRawData(), vec->size()); + data.reset(); } static bool AllFalse(ColumnVectorPtr& vec) { - bool* data = static_cast(vec->GetRawData()); -#if defined(USE_DYNAMIC_SIMD) - return milvus::simd::all_false(data, vec->size()); -#else - for (int i = 0; i < vec->size(); ++i) { - if (data[i]) { - return false; - } - } - return true; -#endif + TargetBitmapView data(vec->GetRawData(), vec->size()); + return data.none(); } int64_t diff --git a/internal/core/src/exec/expression/ConjunctExpr.h b/internal/core/src/exec/expression/ConjunctExpr.h index bd8105997736..de239bcb7551 100644 --- a/internal/core/src/exec/expression/ConjunctExpr.h +++ b/internal/core/src/exec/expression/ConjunctExpr.h @@ -31,8 +31,12 @@ template struct ConjunctElementFunc { int64_t operator()(ColumnVectorPtr& input_result, ColumnVectorPtr& result) { - bool* input_data = static_cast(input_result->GetRawData()); - bool* res_data = static_cast(result->GetRawData()); + TargetBitmapView input_data(input_result->GetRawData(), + input_result->size()); + TargetBitmapView res_data(result->GetRawData(), result->size()); + + /* + // This is the original code, kept here for the documentation purposes int64_t activate_rows = 0; for (int i = 0; i < result->size(); ++i) { if constexpr (is_and) { @@ -47,7 +51,15 @@ struct ConjunctElementFunc { } } } - return activate_rows; + */ + + if constexpr (is_and) { + return (int64_t)res_data.inplace_and_with_count(input_data, + res_data.size()); + } else { + return (int64_t)res_data.inplace_or_with_count(input_data, + res_data.size()); + } } }; diff --git a/internal/core/src/exec/expression/ExistsExpr.cpp b/internal/core/src/exec/expression/ExistsExpr.cpp index 7a9ec4d4c006..9331ae052317 100644 --- a/internal/core/src/exec/expression/ExistsExpr.cpp +++ b/internal/core/src/exec/expression/ExistsExpr.cpp @@ -45,13 +45,13 @@ PhyExistsFilterExpr::EvalJsonExistsForDataSegment() { return nullptr; } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto execute_sub_batch = [](const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, const std::string& pointer) { for (int i = 0; i < size; ++i) { res[i] = data[i].exist(pointer); diff --git a/internal/core/src/exec/expression/ExistsExpr.h b/internal/core/src/exec/expression/ExistsExpr.h index b672e7d65eb7..2b2410853157 100644 --- 
a/internal/core/src/exec/expression/ExistsExpr.h +++ b/internal/core/src/exec/expression/ExistsExpr.h @@ -30,7 +30,7 @@ namespace exec { template struct ExistsElementFunc { void - operator()(const T* src, size_t size, T val, bool* res) { + operator()(const T* src, size_t size, T val, TargetBitmapView res) { } }; diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 587052a5b95d..ea9eeac92cef 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -188,7 +188,7 @@ class SegmentExpr : public Expr { ProcessDataChunks( FUNC func, std::function skip_func, - bool* res, + TargetBitmapView res, ValTypes... values) { int64_t processed_size = 0; @@ -225,9 +225,9 @@ class SegmentExpr : public Expr { } int - ProcessIndexOneChunk(FixedVector& result, + ProcessIndexOneChunk(TargetBitmap& result, size_t chunk_id, - const FixedVector& chunk_res, + const TargetBitmap& chunk_res, int processed_rows) { auto data_pos = chunk_id == current_index_chunk_ ? current_index_chunk_pos_ : 0; @@ -235,20 +235,21 @@ class SegmentExpr : public Expr { std::min(size_per_chunk_ - data_pos, batch_size_ - processed_rows), int64_t(chunk_res.size())); - result.insert(result.end(), - chunk_res.begin() + data_pos, - chunk_res.begin() + data_pos + size); + // result.insert(result.end(), + // chunk_res.begin() + data_pos, + // chunk_res.begin() + data_pos + size); + result.append(chunk_res, data_pos, size); return size; } template - FixedVector + TargetBitmap ProcessIndexChunks(FUNC func, ValTypes... values) { typedef std:: conditional_t, std::string, T> IndexInnerType; using Index = index::ScalarIndex; - FixedVector result; + TargetBitmap result; int processed_rows = 0; for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) { @@ -330,7 +331,7 @@ class SegmentExpr : public Expr { // Cache for index scan to avoid search index every batch int64_t cached_index_chunk_id_{-1}; - FixedVector cached_index_chunk_res_{}; + TargetBitmap cached_index_chunk_res_{}; }; void diff --git a/internal/core/src/exec/expression/JsonContainsExpr.cpp b/internal/core/src/exec/expression/JsonContainsExpr.cpp index eafa08731bf9..13cfd147d5e2 100644 --- a/internal/core/src/exec/expression/JsonContainsExpr.cpp +++ b/internal/core/src/exec/expression/JsonContainsExpr.cpp @@ -168,15 +168,16 @@ PhyJsonContainsFilterExpr::ExecArrayContains() { "[ExecArrayContains]nested path must be null"); auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + std::unordered_set elements; for (auto const& element : expr_->vals_) { elements.insert(GetValueFromProto(element)); } auto execute_sub_batch = [](const milvus::ArrayView* data, const int size, - bool* res, + TargetBitmapView res, const std::unordered_set& elements) { auto executor = [&](size_t i) { const auto& array = data[i]; @@ -215,8 +216,9 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + std::unordered_set elements; auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); for (auto const& element : expr_->vals_) { @@ -224,7 +226,7 @@ PhyJsonContainsFilterExpr::ExecJsonContains() { } auto execute_sub_batch 
= [](const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, const std::string& pointer, const std::unordered_set& elements) { auto executor = [&](size_t i) { @@ -265,9 +267,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); std::vector elements; for (auto const& element : expr_->vals_) { @@ -276,7 +280,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsArray() { auto execute_sub_batch = [](const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, const std::string& pointer, const std::vector& elements) { auto executor = [&](size_t i) -> bool { @@ -333,9 +337,10 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); std::unordered_set elements; for (auto const& element : expr_->vals_) { @@ -344,7 +349,7 @@ PhyJsonContainsFilterExpr::ExecArrayContainsAll() { auto execute_sub_batch = [](const milvus::ArrayView* data, const int size, - bool* res, + TargetBitmapView res, const std::unordered_set& elements) { auto executor = [&](size_t i) { std::unordered_set tmp_elements(elements); @@ -383,9 +388,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); std::unordered_set elements; for (auto const& element : expr_->vals_) { @@ -394,7 +401,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAll() { auto execute_sub_batch = [](const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, const std::string& pointer, const std::unordered_set& elements) { auto executor = [&](const size_t i) -> bool { @@ -439,8 +446,9 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { return nullptr; } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto elements = expr_->vals_; @@ -454,7 +462,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllWithDiffType() { auto execute_sub_batch = [](const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, const std::string& pointer, const std::vector& elements, const std::unordered_set elements_index) { @@ -563,9 +571,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + auto pointer = 
milvus::Json::pointer(expr_->column_.nested_path_); std::vector elements; @@ -575,7 +585,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsAllArray() { auto execute_sub_batch = [](const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, const std::string& pointer, const std::vector& elements) { auto executor = [&](const size_t i) { @@ -629,9 +639,11 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); auto elements = expr_->vals_; @@ -645,7 +657,7 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { auto execute_sub_batch = [](const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, const std::string& pointer, const std::vector& elements) { auto executor = [&](const size_t i) { diff --git a/internal/core/src/exec/expression/LogicalBinaryExpr.cpp b/internal/core/src/exec/expression/LogicalBinaryExpr.cpp index 64144ea9b2ba..d388ab2454cc 100644 --- a/internal/core/src/exec/expression/LogicalBinaryExpr.cpp +++ b/internal/core/src/exec/expression/LogicalBinaryExpr.cpp @@ -21,9 +21,10 @@ namespace exec { void PhyLogicalBinaryExpr::Eval(EvalCtx& context, VectorPtr& result) { - AssertInfo(inputs_.size() == 2, - "logical binary expr must has two input, but now {}", - inputs_.size()); + AssertInfo( + inputs_.size() == 2, + "logical binary expr must have 2 inputs, but {} inputs are provided", + inputs_.size()); VectorPtr left; inputs_[0]->Eval(context, left); VectorPtr right; @@ -31,14 +32,14 @@ PhyLogicalBinaryExpr::Eval(EvalCtx& context, VectorPtr& result) { auto lflat = GetColumnVector(left); auto rflat = GetColumnVector(right); auto size = left->size(); - bool* ldata = static_cast(lflat->GetRawData()); - bool* rdata = static_cast(rflat->GetRawData()); + TargetBitmapView lview(lflat->GetRawData(), size); + TargetBitmapView rview(rflat->GetRawData(), size); if (expr_->op_type_ == expr::LogicalBinaryExpr::OpType::And) { LogicalElementFunc func; - func(ldata, rdata, size); + func(lview, rview, size); } else if (expr_->op_type_ == expr::LogicalBinaryExpr::OpType::Or) { LogicalElementFunc func; - func(ldata, rdata, size); + func(lview, rview, size); } else { PanicInfo(OpTypeInvalid, "unsupported logical operator: {}", diff --git a/internal/core/src/exec/expression/LogicalBinaryExpr.h b/internal/core/src/exec/expression/LogicalBinaryExpr.h index c94df0b8b855..43680772fbbf 100644 --- a/internal/core/src/exec/expression/LogicalBinaryExpr.h +++ b/internal/core/src/exec/expression/LogicalBinaryExpr.h @@ -23,7 +23,6 @@ #include "common/Vector.h" #include "exec/expression/Expr.h" #include "segcore/SegmentInterface.h" -#include "simd/hook.h" namespace milvus { namespace exec { @@ -33,16 +32,9 @@ enum class LogicalOpType { Invalid = 0, And = 1, Or = 2, Xor = 3, Minus = 4 }; template struct LogicalElementFunc { void - operator()(bool* left, bool* right, int n) { -#if defined(USE_DYNAMIC_SIMD) - if constexpr (op == LogicalOpType::And) { - milvus::simd::and_bool(left, right, n); - } else if constexpr (op == LogicalOpType::Or) { - milvus::simd::or_bool(left, right, n); - } else { - PanicInfo(OpTypeInvalid, "unsupported logical operator: {}", op); - } -#else + operator()(TargetBitmapView left, TargetBitmapView right, 
int n) { + /* + // This is the original code, kept here for the documentation purposes for (size_t i = 0; i < n; ++i) { if constexpr (op == LogicalOpType::And) { left[i] &= right[i]; @@ -53,7 +45,19 @@ struct LogicalElementFunc { OpTypeInvalid, "unsupported logical operator: {}", op); } } -#endif + */ + + if constexpr (op == LogicalOpType::And) { + left.inplace_and(right, n); + } else if constexpr (op == LogicalOpType::Or) { + left.inplace_or(right, n); + } else if constexpr (op == LogicalOpType::Xor) { + left.inplace_xor(right, n); + } else if constexpr (op == LogicalOpType::Minus) { + left.inplace_sub(right, n); + } else { + PanicInfo(OpTypeInvalid, "unsupported logical operator: {}", op); + } } }; diff --git a/internal/core/src/exec/expression/LogicalUnaryExpr.cpp b/internal/core/src/exec/expression/LogicalUnaryExpr.cpp index 9b77ef46c31f..4d4bb550691c 100644 --- a/internal/core/src/exec/expression/LogicalUnaryExpr.cpp +++ b/internal/core/src/exec/expression/LogicalUnaryExpr.cpp @@ -15,7 +15,6 @@ // limitations under the License. #include "LogicalUnaryExpr.h" -#include "simd/hook.h" namespace milvus { namespace exec { @@ -29,14 +28,8 @@ PhyLogicalUnaryExpr::Eval(EvalCtx& context, VectorPtr& result) { inputs_[0]->Eval(context, result); if (expr_->op_type_ == milvus::expr::LogicalUnaryExpr::OpType::LogicalNot) { auto flat_vec = GetColumnVector(result); - bool* data = static_cast(flat_vec->GetRawData()); -#if defined(USE_DYNAMIC_SIMD) - milvus::simd::invert_bool(data, flat_vec->size()); -#else - for (int i = 0; i < flat_vec->size(); ++i) { - data[i] = !data[i]; - } -#endif + TargetBitmapView data(flat_vec->GetRawData(), flat_vec->size()); + data.flip(); } } diff --git a/internal/core/src/exec/expression/TermExpr.cpp b/internal/core/src/exec/expression/TermExpr.cpp index d9c53a5bca54..1cf18619e67a 100644 --- a/internal/core/src/exec/expression/TermExpr.cpp +++ b/internal/core/src/exec/expression/TermExpr.cpp @@ -198,8 +198,8 @@ PhyTermFilterExpr::ExecPkTermImpl() { } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); for (size_t i = 0; i < real_batch_size; ++i) { res[i] = cached_bits_[current_data_chunk_pos_++]; @@ -243,9 +243,10 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); AssertInfo(expr_->vals_.size() == 1, "element length in json array must be one"); @@ -253,7 +254,7 @@ PhyTermFilterExpr::ExecTermArrayVariableInField() { auto execute_sub_batch = [](const ArrayView* data, const int size, - bool* res, + TargetBitmapView res, const ValueType& target_val) { auto executor = [&](size_t i) { for (int i = 0; i < data[i].length(); i++) { @@ -290,9 +291,10 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); int index = -1; if (expr_->column_.nested_path_.size() > 0) { @@ -304,15 +306,13 @@ PhyTermFilterExpr::ExecTermArrayFieldInVariable() { } if (term_set.empty()) { - for 
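With views in hand, the logical kernels become single calls: the compile-time op tag selects one of the bitset's in-place operations (which can vectorize internally), and logical NOT is one flip() instead of a byte loop. A usage sketch, assuming TargetBitmap exposes data() and that the stripped template parameter of LogicalElementFunc is the LogicalOpType tag:

    TargetBitmap lhs(n), rhs(n);
    TargetBitmapView lview(lhs.data(), n), rview(rhs.data(), n);
    LogicalElementFunc<LogicalOpType::And>{}(lview, rview, n);  // lhs &= rhs
    lview.flip();                                               // lhs = !lhs

One pre-existing oddity kept as context above: the loop in ExecTermArrayVariableInField reads `for (int i = 0; i < data[i].length(); i++)`, so the inner `i` shadows the lambda's batch index; that appears to predate this patch rather than be introduced by it.
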
(size_t i = 0; i < real_batch_size; ++i) { - res[i] = false; - } + res.reset(); return res_vec; } auto execute_sub_batch = [](const ArrayView* data, const int size, - bool* res, + TargetBitmapView res, int index, const std::unordered_set& term_set) { if (term_set.empty()) { @@ -350,9 +350,11 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + AssertInfo(expr_->vals_.size() == 1, "element length in json array must be one"); ValueType val = GetValueFromProto(expr_->vals_[0]); @@ -360,7 +362,7 @@ PhyTermFilterExpr::ExecTermJsonVariableInField() { auto execute_sub_batch = [](const Json* data, const int size, - bool* res, + TargetBitmapView res, const std::string pointer, const ValueType& target_val) { auto executor = [&](size_t i) { @@ -403,9 +405,11 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); std::unordered_set term_set; for (const auto& element : expr_->vals_) { @@ -421,7 +425,7 @@ PhyTermFilterExpr::ExecTermJsonFieldInVariable() { auto execute_sub_batch = [](const Json* data, const int size, - bool* res, + TargetBitmapView res, const std::string pointer, const std::unordered_set& terms) { auto executor = [&](size_t i) { @@ -532,9 +536,11 @@ PhyTermFilterExpr::ExecVisitorImplForData() { if (real_batch_size == 0) { return nullptr; } + auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); + std::vector vals; for (auto& val : expr_->vals_) { // Integral overflow process @@ -547,7 +553,7 @@ PhyTermFilterExpr::ExecVisitorImplForData() { std::unordered_set vals_set(vals.begin(), vals.end()); auto execute_sub_batch = [](const T* data, const int size, - bool* res, + TargetBitmapView res, const std::unordered_set& vals) { TermElementFuncSet func; for (size_t i = 0; i < size; ++i) { diff --git a/internal/core/src/exec/expression/TermExpr.h b/internal/core/src/exec/expression/TermExpr.h index c6cf9ad98ded..48dc718cc429 100644 --- a/internal/core/src/exec/expression/TermExpr.h +++ b/internal/core/src/exec/expression/TermExpr.h @@ -41,7 +41,7 @@ struct TermIndexFunc { conditional_t, std::string, T> IndexInnerType; using Index = index::ScalarIndex; - FixedVector + TargetBitmap operator()(Index* index, size_t n, const IndexInnerType* val) { return index->In(n, val); } @@ -129,7 +129,7 @@ class PhyTermFilterExpr : public SegmentExpr { bool use_cache_offsets_{false}; bool cached_offsets_inited_{false}; ColumnVectorPtr cached_offsets_; - FixedVector cached_bits_; + TargetBitmap cached_bits_; }; } //namespace exec } // namespace milvus diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index e577bacdfb44..305fd1caef2e 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -125,8 +125,8 @@ 
PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { return nullptr; } auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); ValueType val = GetValueFromProto(expr_->val_); auto op_type = expr_->op_type_; @@ -136,7 +136,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { } auto execute_sub_batch = [op_type](const milvus::ArrayView* data, const int size, - bool* res, + TargetBitmapView res, ValueType val, int index) { switch (op_type) { @@ -210,8 +210,8 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { ExprValueType val = GetValueFromProto(expr_->val_); auto res_vec = - std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); auto op_type = expr_->op_type_; auto pointer = milvus::Json::pointer(expr_->column_.nested_path_); @@ -247,7 +247,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { auto execute_sub_batch = [op_type, pointer](const milvus::Json* data, const int size, - bool* res, + TargetBitmapView res, ExprValueType val) { switch (op_type) { case proto::plan::GreaterThan: { @@ -392,7 +392,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForIndex() { } auto op_type = expr_->op_type_; auto execute_sub_batch = [op_type](Index* index_ptr, IndexInnerType val) { - FixedVector res; + TargetBitmap res; switch (op_type) { case proto::plan::GreaterThan: { UnaryIndexFunc func; @@ -472,13 +472,12 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() { case proto::plan::GreaterThan: case proto::plan::GreaterEqual: { auto res_vec = std::make_shared( - DataType::BOOL, batch_size); + TargetBitmap(batch_size)); cached_overflow_res_ = res_vec; - bool* res = (bool*)res_vec->GetRawData(); + TargetBitmapView res(res_vec->GetRawData(), batch_size); + if (milvus::query::lt_lb(val)) { - for (size_t i = 0; i < batch_size; ++i) { - res[i] = true; - } + res.set(); return res_vec; } return res_vec; @@ -486,35 +485,32 @@ PhyUnaryRangeFilterExpr::PreCheckOverflow() { case proto::plan::LessThan: case proto::plan::LessEqual: { auto res_vec = std::make_shared( - DataType::BOOL, batch_size); + TargetBitmap(batch_size)); cached_overflow_res_ = res_vec; - bool* res = (bool*)res_vec->GetRawData(); + TargetBitmapView res(res_vec->GetRawData(), batch_size); + if (milvus::query::gt_ub(val)) { - for (size_t i = 0; i < batch_size; ++i) { - res[i] = true; - } + res.set(); return res_vec; } return res_vec; } case proto::plan::Equal: { auto res_vec = std::make_shared( - DataType::BOOL, batch_size); + TargetBitmap(batch_size)); cached_overflow_res_ = res_vec; - bool* res = (bool*)res_vec->GetRawData(); - for (size_t i = 0; i < batch_size; ++i) { - res[i] = false; - } + TargetBitmapView res(res_vec->GetRawData(), batch_size); + + res.reset(); return res_vec; } case proto::plan::NotEqual: { auto res_vec = std::make_shared( - DataType::BOOL, batch_size); + TargetBitmap(batch_size)); cached_overflow_res_ = res_vec; - bool* res = (bool*)res_vec->GetRawData(); - for (size_t i = 0; i < batch_size; ++i) { - res[i] = true; - } + TargetBitmapView res(res_vec->GetRawData(), batch_size); + + res.set(); return res_vec; } default: { @@ -544,12 +540,12 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData() { } IndexInnerType val = GetValueFromProto(expr_->val_); auto res_vec = - 
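PreCheckOverflow covers query constants that cannot be represented in the column's type: every row then compares the same way, so the bitmap is filled wholesale with set()/reset() rather than element by element. The logic of the hunks below, restated compactly:

    // val below the type's lower bound: `x > val` / `x >= val` hold for every x.
    if (milvus::query::lt_lb(val)) res.set();
    // val above the type's upper bound: `x < val` / `x <= val` hold for every x.
    if (milvus::query::gt_ub(val)) res.set();
    // Equal can never hold  -> res.reset();
    // NotEqual always holds -> res.set();
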
std::make_shared(DataType::BOOL, real_batch_size); - bool* res = (bool*)res_vec->GetRawData(); + std::make_shared(TargetBitmap(real_batch_size)); + TargetBitmapView res(res_vec->GetRawData(), real_batch_size); auto expr_type = expr_->op_type_; auto execute_sub_batch = [expr_type](const T* data, const int size, - bool* res, + TargetBitmapView res, IndexInnerType val) { switch (expr_type) { case proto::plan::GreaterThan: { diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index d29d6b371fc4..cd260183608e 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -39,7 +39,10 @@ struct UnaryElementFuncForMatch { IndexInnerType; void - operator()(const T* src, size_t size, IndexInnerType val, bool* res) { + operator()(const T* src, + size_t size, + IndexInnerType val, + TargetBitmapView res) { if constexpr (std::is_same_v) { // translate the pattern match in advance, which avoid computing it every loop. std::regex reg(TranslatePatternMatchToRegex(val)); @@ -65,13 +68,18 @@ struct UnaryElementFunc { conditional_t, std::string, T> IndexInnerType; void - operator()(const T* src, size_t size, IndexInnerType val, bool* res) { + operator()(const T* src, + size_t size, + IndexInnerType val, + TargetBitmapView res) { if constexpr (op == proto::plan::OpType::Match) { UnaryElementFuncForMatch func; func(src, size, val, res); return; } + /* + // This is the original code, which is kept for the documentation purposes for (int i = 0; i < size; ++i) { if constexpr (op == proto::plan::OpType::Equal) { res[i] = src[i] == val; @@ -95,6 +103,36 @@ struct UnaryElementFunc { op)); } } + */ + + if constexpr (op == proto::plan::OpType::PrefixMatch) { + for (int i = 0; i < size; ++i) { + res[i] = milvus::query::Match( + src[i], val, proto::plan::OpType::PrefixMatch); + } + } else if constexpr (op == proto::plan::OpType::Equal) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::NotEqual) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::GreaterThan) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::LessThan) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::GreaterEqual) { + res.inplace_compare_val( + src, size, val); + } else if constexpr (op == proto::plan::OpType::LessEqual) { + res.inplace_compare_val( + src, size, val); + } else { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported op_type:{} for UnaryElementFunc", op)); + } } }; @@ -122,7 +160,7 @@ struct UnaryElementFuncForArray { size_t size, ValueType val, int index, - bool* res) { + TargetBitmapView res) { for (int i = 0; i < size; ++i) { if constexpr (op == proto::plan::OpType::Equal) { if constexpr (std::is_same_v) { @@ -172,7 +210,7 @@ struct UnaryIndexFuncForMatch { conditional_t, std::string, T> IndexInnerType; using Index = index::ScalarIndex; - FixedVector + TargetBitmap operator()(Index* index, IndexInnerType val) { if constexpr (!std::is_same_v && !std::is_same_v) { @@ -207,7 +245,7 @@ struct UnaryIndexFunc { conditional_t, std::string, T> IndexInnerType; using Index = index::ScalarIndex; - FixedVector + TargetBitmap operator()(Index* index, IndexInnerType val) { if constexpr (op == proto::plan::OpType::Equal) { return index->In(1, &val); diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index 
588bb8ebe221..6e9c2e4b93ca 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -30,8 +30,5 @@ set(MILVUS_QUERY_SRCS PlanProto.cpp ) add_library(milvus_query ${MILVUS_QUERY_SRCS}) -if(USE_DYNAMIC_SIMD) - target_link_libraries(milvus_query milvus_index milvus_simd) -else() - target_link_libraries(milvus_query milvus_index) -endif() + +target_link_libraries(milvus_query milvus_index milvus_bitset) diff --git a/internal/core/src/query/Utils.h b/internal/core/src/query/Utils.h index 49af5a1d186a..2a42e894c7b7 100644 --- a/internal/core/src/query/Utils.h +++ b/internal/core/src/query/Utils.h @@ -16,7 +16,6 @@ #include "query/Expr.h" #include "common/Utils.h" -#include "simd/hook.h" namespace milvus::query { @@ -72,60 +71,4 @@ out_of_range(int64_t t) { return gt_ub(t) || lt_lb(t); } -inline void -AppendOneChunk(BitsetType& result, const bool* chunk_ptr, size_t chunk_len) { - // Append a value once instead of BITSET_BLOCK_BIT_SIZE times. - auto AppendBlock = [&result](const bool* ptr, int n) { - for (int i = 0; i < n; ++i) { -#if defined(USE_DYNAMIC_SIMD) - auto val = milvus::simd::get_bitset_block(ptr); -#else - BitsetBlockType val = 0; - // This can use CPU SIMD optimzation - uint8_t vals[BITSET_BLOCK_SIZE] = {0}; - for (size_t j = 0; j < 8; ++j) { - for (size_t k = 0; k < BITSET_BLOCK_SIZE; ++k) { - vals[k] |= uint8_t(*(ptr + k * 8 + j)) << j; - } - } - for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) { - val |= BitsetBlockType(vals[j]) << (8 * j); - } -#endif - result.append(val); - ptr += BITSET_BLOCK_SIZE * 8; - } - }; - // Append bit for these bits that can not be union as a block - // Usually n less than BITSET_BLOCK_BIT_SIZE. - auto AppendBit = [&result](const bool* ptr, int n) { - for (int i = 0; i < n; ++i) { - bool bit = *ptr++; - result.push_back(bit); - } - }; - - size_t res_len = result.size(); - - int n_prefix = - res_len % BITSET_BLOCK_BIT_SIZE == 0 - ? 
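The AppendOneChunk being removed here works in three phases, because appending a whole block is only cheap once the destination length is block-aligned: single bits until alignment (n_prefix), then packed blocks, then the leftover tail (n_suffix). The alignment arithmetic, split across the surrounding deleted lines, gathered in one piece:

    size_t res_len  = result.size();
    size_t n_prefix = res_len % BITSET_BLOCK_BIT_SIZE == 0
        ? 0
        : std::min(BITSET_BLOCK_BIT_SIZE - res_len % BITSET_BLOCK_BIT_SIZE,
                   chunk_len);
    size_t n_block  = (chunk_len - n_prefix) / BITSET_BLOCK_BIT_SIZE;
    size_t n_suffix = (chunk_len - n_prefix) % BITSET_BLOCK_BIT_SIZE;
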
0 - : std::min(BITSET_BLOCK_BIT_SIZE - res_len % BITSET_BLOCK_BIT_SIZE, - chunk_len); - - AppendBit(chunk_ptr, n_prefix); - - if (n_prefix == chunk_len) - return; - - size_t n_block = (chunk_len - n_prefix) / BITSET_BLOCK_BIT_SIZE; - size_t n_suffix = (chunk_len - n_prefix) % BITSET_BLOCK_BIT_SIZE; - - AppendBlock(chunk_ptr + n_prefix, n_block); - - AppendBit(chunk_ptr + n_prefix + n_block * BITSET_BLOCK_BIT_SIZE, n_suffix); - - return; -} - } // namespace milvus::query diff --git a/internal/core/src/query/generated/ExecExprVisitor.h b/internal/core/src/query/generated/ExecExprVisitor.h index 6bcc7d05a836..2da1cd0cc5f6 100644 --- a/internal/core/src/query/generated/ExecExprVisitor.h +++ b/internal/core/src/query/generated/ExecExprVisitor.h @@ -25,7 +25,7 @@ namespace milvus::query { void -AppendOneChunk(BitsetType& result, const FixedVector& chunk_res); +AppendOneChunk(BitsetType& result, const TargetBitmapView chunk_res); class ExecExprVisitor : public ExprVisitor { public: diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index e6a8ef901c66..dc94e914a903 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -37,7 +37,6 @@ #include "simdjson/error.h" #include "query/PlanProto.h" #include "index/SkipIndex.h" -#include "simd/hook.h" #include "index/Meta.h" namespace milvus::query { @@ -183,89 +182,28 @@ static auto Assemble(const std::deque& srcs) -> BitsetType { BitsetType res; - if (srcs.size() == 1) { - return srcs[0]; - } - int64_t total_size = 0; for (auto& chunk : srcs) { total_size += chunk.size(); } - res.resize(total_size); + res.reserve(total_size); - int64_t counter = 0; for (auto& chunk : srcs) { - for (int64_t i = 0; i < chunk.size(); ++i) { - res[counter + i] = chunk[i]; - } - counter += chunk.size(); + res.append(chunk); } return res; } void -AppendOneChunk(BitsetType& result, const FixedVector& chunk_res) { - // Append a value once instead of BITSET_BLOCK_BIT_SIZE times. - auto AppendBlock = [&result](const bool* ptr, int n) { - for (int i = 0; i < n; ++i) { -#if defined(USE_DYNAMIC_SIMD) - auto val = milvus::simd::get_bitset_block(ptr); -#else - BitsetBlockType val = 0; - // This can use CPU SIMD optimzation - uint8_t vals[BITSET_BLOCK_SIZE] = {0}; - for (size_t j = 0; j < 8; ++j) { - for (size_t k = 0; k < BITSET_BLOCK_SIZE; ++k) { - vals[k] |= uint8_t(*(ptr + k * 8 + j)) << j; - } - } - for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) { - val |= BitsetBlockType(vals[j]) << (8 * j); - } -#endif - result.append(val); - ptr += BITSET_BLOCK_SIZE * 8; - } - }; - // Append bit for these bits that can not be union as a block - // Usually n less than BITSET_BLOCK_BIT_SIZE. - auto AppendBit = [&result](const bool* ptr, int n) { - for (int i = 0; i < n; ++i) { - bool bit = *ptr++; - result.push_back(bit); - } - }; - - size_t res_len = result.size(); - size_t chunk_len = chunk_res.size(); - const bool* chunk_ptr = chunk_res.data(); - - int n_prefix = - res_len % BITSET_BLOCK_BIT_SIZE == 0 - ? 
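All of that alignment work is what BitsetType::append() now does internally, so Assemble() shrinks to a reserve plus a loop of appends, as the hunk above shows; roughly:

    BitsetType res;
    res.reserve(total_size);    // one allocation up front
    for (auto& chunk : srcs)
        res.append(chunk);      // block-wise copy, no per-bit loop
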
0 - : std::min(BITSET_BLOCK_BIT_SIZE - res_len % BITSET_BLOCK_BIT_SIZE, - chunk_len); - - AppendBit(chunk_ptr, n_prefix); - - if (n_prefix == chunk_len) - return; - - size_t n_block = (chunk_len - n_prefix) / BITSET_BLOCK_BIT_SIZE; - size_t n_suffix = (chunk_len - n_prefix) % BITSET_BLOCK_BIT_SIZE; - - AppendBlock(chunk_ptr + n_prefix, n_block); - - AppendBit(chunk_ptr + n_prefix + n_block * BITSET_BLOCK_BIT_SIZE, n_suffix); - - return; +AppendOneChunk(BitsetType& result, const TargetBitmapView chunk_res) { + result.append(chunk_res); } BitsetType -AssembleChunk(const std::vector>& results) { +AssembleChunk(const std::vector& results) { BitsetType assemble_result; for (auto& result : results) { - AppendOneChunk(assemble_result, result); + AppendOneChunk(assemble_result, result.view()); } return assemble_result; } @@ -285,7 +223,7 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, auto indexing_barrier = segment_.num_chunk_index(field_id); auto size_per_chunk = segment_.size_per_chunk(); auto num_chunk = upper_div(row_count_, size_per_chunk); - std::vector> results; + std::vector results; results.reserve(num_chunk); typedef std:: @@ -307,7 +245,7 @@ ExecExprVisitor::ExecRangeVisitorImpl(FieldId field_id, auto this_size = chunk_id == num_chunk - 1 ? row_count_ - chunk_id * size_per_chunk : size_per_chunk; - FixedVector chunk_res(this_size); + TargetBitmap chunk_res(this_size); //check possible chunk metrics auto& skipIndex = segment_.GetSkipIndex(); if (skip_index_func(skipIndex, field_id, chunk_id)) { @@ -343,7 +281,7 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, auto data_barrier = segment_.num_chunk_data(field_id); AssertInfo(std::max(data_barrier, indexing_barrier) == num_chunk, "max(data_barrier, index_barrier) not equal to num_chunk"); - std::vector> results; + std::vector results; results.reserve(num_chunk); // for growing segment, indexing_barrier will always less than data_barrier @@ -354,7 +292,7 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, auto this_size = chunk_id == num_chunk - 1 ? 
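Chunk results are now TargetBitmap values that get moved, not copied, into the results vector (see results.push_back(std::move(result)) below); together with the clone() calls later in this patch, this suggests the bitset type forbids implicit copies so that accidental deep copies fail to compile. A hedged sketch of the intended usage:

    std::vector<TargetBitmap> results;
    TargetBitmap chunk_res(this_size);
    // ... fill chunk_res ...
    TargetBitmap deep = chunk_res.clone();     // explicit deep copy when needed
    results.push_back(std::move(chunk_res));   // otherwise transfer, O(1)
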
row_count_ - chunk_id * size_per_chunk : size_per_chunk; - FixedVector result(this_size); + TargetBitmap result(this_size); auto chunk = segment_.chunk_data(field_id, chunk_id); const T* data = chunk.data(); for (int index = 0; index < this_size; ++index) { @@ -377,7 +315,7 @@ ExecExprVisitor::ExecDataRangeVisitorImpl(FieldId field_id, auto& indexing = segment_.chunk_scalar_index(field_id, chunk_id); auto this_size = const_cast(&indexing)->Count(); - FixedVector result(this_size); + TargetBitmap result(this_size); for (int offset = 0; offset < this_size; ++offset) { result[offset] = index_func(const_cast(&indexing), offset); } @@ -2112,11 +2050,11 @@ ExecExprVisitor::ExecCompareLeftType(const FieldId& left_field_id, CmpFunc cmp_func) { auto size_per_chunk = segment_.size_per_chunk(); auto num_chunks = upper_div(row_count_, size_per_chunk); - std::vector> results; + std::vector results; results.reserve(num_chunks); for (int64_t chunk_id = 0; chunk_id < num_chunks; ++chunk_id) { - FixedVector result; + TargetBitmap result; const T* left_raw_data = segment_.chunk_data(left_field_id, chunk_id).data(); @@ -2155,7 +2093,7 @@ ExecExprVisitor::ExecCompareLeftType(const FieldId& left_field_id, fmt::format("unsupported right datatype {} of compare expr", right_field_type)); } - results.push_back(result); + results.push_back(std::move(result)); } auto final_result = AssembleChunk(results); AssertInfo(final_result.size() == row_count_, diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index 23e03d1c0ef6..354b0af553fe 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -102,16 +102,15 @@ ExecPlanNodeVisitor::ExecuteExprNodeInternal( "expr result vector's children size not equal one"); LOG_DEBUG("output result length:{}", childrens[0]->size()); if (auto vec = std::dynamic_pointer_cast(childrens[0])) { - AppendOneChunk(bitset_holder, - static_cast(vec->GetRawData()), - vec->size()); + TargetBitmapView view(vec->GetRawData(), vec->size()); + AppendOneChunk(bitset_holder, view); } else if (auto row = std::dynamic_pointer_cast(childrens[0])) { auto bit_vec = std::dynamic_pointer_cast(row->child(0)); - AppendOneChunk(bitset_holder, - static_cast(bit_vec->GetRawData()), - bit_vec->size()); + TargetBitmapView view(bit_vec->GetRawData(), bit_vec->size()); + AppendOneChunk(bitset_holder, view); + if (!cache_offset_getted) { // offset cache only get once because not support iterator batch auto cache_offset_vec = @@ -168,7 +167,7 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) { BitsetType expr_res; ExecuteExprNode( node.filter_plannode_.value(), segment, active_count, expr_res); - bitset_holder = std::make_unique(expr_res); + bitset_holder = std::make_unique(expr_res.clone()); bitset_holder->flip(); } else { bitset_holder = std::make_unique(active_count, false); diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt index 9986882e7a0d..eb92f5657dc1 100644 --- a/internal/core/src/segcore/CMakeLists.txt +++ b/internal/core/src/segcore/CMakeLists.txt @@ -39,6 +39,6 @@ set(SEGCORE_FILES ConcurrentVector.cpp) add_library(milvus_segcore SHARED ${SEGCORE_FILES}) -target_link_libraries(milvus_segcore milvus_query milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage) +target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage) install(TARGETS milvus_segcore 
DESTINATION "${CMAKE_INSTALL_LIBDIR}") diff --git a/internal/core/src/segcore/DeletedRecord.h b/internal/core/src/segcore/DeletedRecord.h index 7529062cfbfe..93e56c81c3eb 100644 --- a/internal/core/src/segcore/DeletedRecord.h +++ b/internal/core/src/segcore/DeletedRecord.h @@ -137,8 +137,9 @@ DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr { auto res = std::make_shared(); res->del_barrier = this->del_barrier; - res->bitmap_ptr = std::make_shared(); - *(res->bitmap_ptr) = *(this->bitmap_ptr); + // res->bitmap_ptr = std::make_shared(); + // *(res->bitmap_ptr) = *(this->bitmap_ptr); + res->bitmap_ptr = std::make_shared(this->bitmap_ptr->clone()); res->bitmap_ptr->resize(capacity, false); return res; } diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index 0c6fb47104f5..962e86eedfa7 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -267,11 +267,15 @@ SegmentInternalInterface::timestamp_filter(BitsetType& bitset, auto pilot = upper_bound(timestamps, 0, cnt, timestamp); // offset bigger than pilot should be filtered out. - for (int offset = pilot; offset < cnt; offset = bitset.find_next(offset)) { - if (offset == BitsetType::npos) { + auto offset = pilot; + while (offset < cnt) { + bitset[offset] = false; + + const auto next_offset = bitset.find_next(offset); + if (!next_offset.has_value()) { return; } - bitset[offset] = false; + offset = next_offset.value(); } } diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index f98a3c00078a..4dd3a894e0d3 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -979,8 +979,9 @@ SegmentSealedImpl::check_search(const query::Plan* plan) const { auto absent_fields = request_fields - field_ready_bitset; if (absent_fields.any()) { + // absent_fields.find_first() returns std::optional<> auto field_id = - FieldId(absent_fields.find_first() + START_USER_FIELDID); + FieldId(absent_fields.find_first().value() + START_USER_FIELDID); auto& field_meta = schema_->operator[](field_id); PanicInfo( FieldNotLoaded, diff --git a/internal/core/src/simd/CMakeLists.txt b/internal/core/src/simd/CMakeLists.txt deleted file mode 100644 index 632373da08ca..000000000000 --- a/internal/core/src/simd/CMakeLists.txt +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (C) 2019-2020 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. 
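timestamp_filter above is restructured around the new lookup API: find_first()/find_next() return std::optional<size_t> rather than an npos sentinel, so "no more set bits" is a disengaged optional instead of a magic value (compare find_first().value() in SegmentSealedImpl). The rewritten loop, in one piece:

    auto offset = pilot;
    while (offset < cnt) {
        bitset[offset] = false;                  // filter this row out
        const auto next = bitset.find_next(offset);
        if (!next.has_value())
            return;                              // nothing left to clear
        offset = next.value();
    }
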
See the License for the specific language governing permissions and limitations under the License - -set(MILVUS_SIMD_SRCS - ref.cpp - hook.cpp -) - -if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") - # x86 cpu simd - message ("simd using x86_64 mode") - list(APPEND MILVUS_SIMD_SRCS - sse2.cpp - sse4.cpp - avx2.cpp - avx512.cpp - ) - set_source_files_properties(sse4.cpp PROPERTIES COMPILE_FLAGS "-msse4.2") - set_source_files_properties(avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") - set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512vl -mavx512dq -mavx512bw") - -elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm*") - # TODO: add arm cpu simd - message ("simd using arm mode") - list(APPEND MILVUS_SIMD_SRCS - neon.cpp - ) -endif() - -add_library(milvus_simd ${MILVUS_SIMD_SRCS}) - -# Link the milvus_simd library with other libraries as needed -target_link_libraries(milvus_simd milvus_log) diff --git a/internal/core/src/simd/avx2.cpp b/internal/core/src/simd/avx2.cpp deleted file mode 100644 index 1ea51faccabb..000000000000 --- a/internal/core/src/simd/avx2.cpp +++ /dev/null @@ -1,310 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#if defined(__x86_64__) - -#include "avx2.h" -#include "sse2.h" -#include "sse4.h" - -#include - -#include -#include - -namespace milvus { -namespace simd { - -BitsetBlockType -GetBitsetBlockAVX2(const bool* src) { - if constexpr (BITSET_BLOCK_SIZE == 8) { - // BitsetBlockType has 64 bits - __m256i highbit = _mm256_set1_epi8(0x7F); - uint32_t tmp[8]; - __m256i boolvec = _mm256_loadu_si256((__m256i*)(src)); - __m256i highbits = _mm256_add_epi8(boolvec, highbit); - tmp[0] = _mm256_movemask_epi8(highbits); - boolvec = _mm256_loadu_si256((__m256i*)(src + 32)); - highbits = _mm256_add_epi8(boolvec, highbit); - tmp[1] = _mm256_movemask_epi8(highbits); - - __m256i tmpvec = _mm256_loadu_si256((__m256i*)tmp); - BitsetBlockType res[4]; - _mm256_storeu_si256((__m256i*)res, tmpvec); - return res[0]; - // __m256i tmpvec = _mm_loadu_si64(tmp); - // BitsetBlockType res; - // _mm_storeu_si64(&res, tmpvec); - // return res; - } else { - // Others has 32 bits - __m256i highbit = _mm256_set1_epi8(0x7F); - uint32_t tmp[8]; - __m256i boolvec = _mm256_loadu_si256((__m256i*)&src[0]); - __m256i highbits = _mm256_add_epi8(boolvec, highbit); - tmp[0] = _mm256_movemask_epi8(highbits); - - __m256i tmpvec = _mm256_loadu_si256((__m256i*)tmp); - BitsetBlockType res[8]; - _mm256_storeu_si256((__m256i*)res, tmpvec); - return res[0]; - } -} - -template <> -bool -FindTermAVX2(const bool* src, size_t vec_size, bool val) { - __m256i ymm_target = _mm256_set1_epi8(val); - __m256i ymm_data; - size_t num_chunks = vec_size / 32; - - for (size_t i = 0; i < 32 * num_chunks; i += 32) { - ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + i)); - __m256i ymm_match = _mm256_cmpeq_epi8(ymm_data, ymm_target); - int mask = _mm256_movemask_epi8(ymm_match); - if (mask != 0) { - return true; - } - } - - for (size_t i = 32 * 
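GetBitsetBlockAVX2 above packs bools into mask bits via a sign-bit trick: adding 0x7F to each byte maps true (0x01) to 0x80 and false (0x00) to 0x7F, after which _mm256_movemask_epi8 collects the high bit of each of the 32 bytes into a 32-bit mask. A scalar restatement of one 32-byte step:

    uint32_t mask = 0;
    for (int b = 0; b < 32; ++b) {
        const uint8_t biased = uint8_t(src[b]) + 0x7F;  // true->0x80, false->0x7F
        mask |= uint32_t(biased >> 7) << b;             // keep only the sign bit
    }
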
num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX2(const int8_t* src, size_t vec_size, int8_t val) { - __m256i ymm_target = _mm256_set1_epi8(val); - __m256i ymm_data; - size_t num_chunks = vec_size / 32; - - for (size_t i = 0; i < 32 * num_chunks; i += 32) { - ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + i)); - __m256i ymm_match = _mm256_cmpeq_epi8(ymm_data, ymm_target); - int mask = _mm256_movemask_epi8(ymm_match); - if (mask != 0) { - return true; - } - } - - for (size_t i = 32 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX2(const int16_t* src, size_t vec_size, int16_t val) { - __m256i ymm_target = _mm256_set1_epi16(val); - __m256i ymm_data; - size_t num_chunks = vec_size / 16; - for (size_t i = 0; i < 16 * num_chunks; i += 16) { - ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + i)); - __m256i ymm_match = _mm256_cmpeq_epi16(ymm_data, ymm_target); - int mask = _mm256_movemask_epi8(ymm_match); - if (mask != 0) { - return true; - } - } - - for (size_t i = 16 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX2(const int32_t* src, size_t vec_size, int32_t val) { - __m256i ymm_target = _mm256_set1_epi32(val); - __m256i ymm_data; - size_t num_chunks = vec_size / 8; - size_t remaining_size = vec_size % 8; - - for (size_t i = 0; i < 8 * num_chunks; i += 8) { - ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + i)); - __m256i ymm_match = _mm256_cmpeq_epi32(ymm_data, ymm_target); - int mask = _mm256_movemask_epi8(ymm_match); - if (mask != 0) { - return true; - } - } - - if (remaining_size == 0) { - return false; - } - return FindTermSSE2(src + 8 * num_chunks, remaining_size, val); -} - -template <> -bool -FindTermAVX2(const int64_t* src, size_t vec_size, int64_t val) { - __m256i ymm_target = _mm256_set1_epi64x(val); - __m256i ymm_data; - size_t num_chunks = vec_size / 4; - - for (size_t i = 0; i < 4 * num_chunks; i += 4) { - ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + i)); - __m256i ymm_match = _mm256_cmpeq_epi64(ymm_data, ymm_target); - int mask = _mm256_movemask_epi8(ymm_match); - if (mask != 0) { - return true; - } - } - - for (size_t i = 4 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX2(const float* src, size_t vec_size, float val) { - __m256 ymm_target = _mm256_set1_ps(val); - __m256 ymm_data; - size_t num_chunks = vec_size / 8; - - for (size_t i = 0; i < 8 * num_chunks; i += 8) { - ymm_data = _mm256_loadu_ps(src + i); - __m256 ymm_match = _mm256_cmp_ps(ymm_data, ymm_target, _CMP_EQ_OQ); - int mask = _mm256_movemask_ps(ymm_match); - if (mask != 0) { - return true; - } - } - - for (size_t i = 8 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX2(const double* src, size_t vec_size, double val) { - __m256d ymm_target = _mm256_set1_pd(val); - __m256d ymm_data; - size_t num_chunks = vec_size / 4; - - for (size_t i = 0; i < 4 * num_chunks; i += 4) { - ymm_data = _mm256_loadu_pd(src + i); - __m256d ymm_match = _mm256_cmp_pd(ymm_data, ymm_target, _CMP_EQ_OQ); - int mask = _mm256_movemask_pd(ymm_match); - if (mask != 0) { - return true; - } - } - - for (size_t i = 4 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return 
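Every FindTermAVX2 specialization is the same shape with a different lane count (32 lanes for 8-bit data, 16 for 16-bit, 8 for 32-bit/float, 4 for 64-bit/double): broadcast the needle, compare one register per iteration, early-return on a non-zero movemask, then handle the remainder with a scalar loop (or a narrower SIMD routine, as the int32 path does via FindTermSSE2). The shared skeleton, as a hedged generic sketch rather than code from the patch:

    template <typename T, typename AnyMatch>
    bool FindTermGeneric(const T* src, size_t n, T val,
                         size_t lanes, AnyMatch any_match) {
        const size_t main = n / lanes * lanes;
        for (size_t i = 0; i < main; i += lanes)
            if (any_match(src + i, val))   // vector compare + movemask != 0
                return true;
        for (size_t i = main; i < n; ++i)  // scalar tail
            if (src[i] == val)
                return true;
        return false;
    }
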
false; -} - -bool -AllFalseAVX2(const bool* src, int64_t size) { - int num_chunk = size / 32; - __m256i highbit = _mm256_set1_epi8(0x7F); - for (size_t i = 0; i < num_chunk * 16; i += 16) { - __m256i data = - _mm256_loadu_si256(reinterpret_cast(src + i)); - __m256i highbits = _mm256_add_epi8(data, highbit); - if (_mm256_movemask_epi8(highbits) != 0) { - return false; - } - } - - for (size_t i = num_chunk * 16; i < size; ++i) { - if (src[i]) { - return false; - } - } - return true; -} - -bool -AllTrueAVX2(const bool* src, int64_t size) { - int num_chunk = size / 16; - __m256i highbit = _mm256_set1_epi8(0x7F); - for (size_t i = 0; i < num_chunk * 16; i += 16) { - __m256i data = - _mm256_loadu_si256(reinterpret_cast(src + i)); - __m256i highbits = _mm256_add_epi8(data, highbit); - if (_mm256_movemask_epi8(highbits) != 0xFFFF) { - return false; - } - } - - for (size_t i = num_chunk * 16; i < size; ++i) { - if (!src[i]) { - return false; - } - } - return true; -} - -void -AndBoolAVX2(bool* left, bool* right, int64_t size) { - int num_chunk = size / 32; - for (size_t i = 0; i < num_chunk * 32; i += 32) { - __m256i l_reg = - _mm256_loadu_si256(reinterpret_cast<__m256i*>(left + i)); - __m256i r_reg = - _mm256_loadu_si256(reinterpret_cast<__m256i*>(right + i)); - __m256i res = _mm256_and_si256(l_reg, r_reg); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(left + i), res); - } - for (size_t i = num_chunk * 32; i < size; ++i) { - left[i] &= right[i]; - } -} - -void -OrBoolAVX2(bool* left, bool* right, int64_t size) { - int num_chunk = size / 32; - for (size_t i = 0; i < num_chunk * 32; i += 32) { - __m256i l_reg = - _mm256_loadu_si256(reinterpret_cast<__m256i*>(left + i)); - __m256i r_reg = - _mm256_loadu_si256(reinterpret_cast<__m256i*>(right + i)); - __m256i res = _mm256_or_si256(l_reg, r_reg); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(left + i), res); - } - for (size_t i = num_chunk * 32; i < size; ++i) { - left[i] |= right[i]; - } -} - -} // namespace simd -} // namespace milvus - -#endif diff --git a/internal/core/src/simd/avx2.h b/internal/core/src/simd/avx2.h deleted file mode 100644 index 90d6c7bbb7e6..000000000000 --- a/internal/core/src/simd/avx2.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. 
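A caution about two of the helpers deleted above rather than ported: AllFalseAVX2 computes num_chunk = size / 32 yet advances 16 bytes per 32-byte load, leaving roughly half the buffer to the scalar tail, and AllTrueAVX2 compares a 32-lane movemask against 0xFFFF, a 16-lane all-ones value, so an all-true 32-byte run (mask 0xFFFFFFFF) would fail the test. Dropping the file in favor of the bitset library sidesteps both issues.
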
- -#pragma once - -#include -#include -#include - -#include "common.h" - -namespace milvus { -namespace simd { - -BitsetBlockType -GetBitsetBlockAVX2(const bool* src); - -template -bool -FindTermAVX2(const T* src, size_t vec_size, T va) { - CHECK_SUPPORTED_TYPE(T, "unsupported type for FindTermAVX2"); - return false; -} - -template <> -bool -FindTermAVX2(const bool* src, size_t vec_size, bool val); - -template <> -bool -FindTermAVX2(const int8_t* src, size_t vec_size, int8_t val); - -template <> -bool -FindTermAVX2(const int16_t* src, size_t vec_size, int16_t val); - -template <> -bool -FindTermAVX2(const int32_t* src, size_t vec_size, int32_t val); - -template <> -bool -FindTermAVX2(const int64_t* src, size_t vec_size, int64_t val); - -template <> -bool -FindTermAVX2(const float* src, size_t vec_size, float val); - -template <> -bool -FindTermAVX2(const double* src, size_t vec_size, double val); - -bool -AllFalseAVX2(const bool* src, int64_t size); - -bool -AllTrueAVX2(const bool* src, int64_t size); - -void -AndBoolAVX2(bool* left, bool* right, int64_t size); - -void -OrBoolAVX2(bool* left, bool* right, int64_t size); - -} // namespace simd -} // namespace milvus diff --git a/internal/core/src/simd/avx512.cpp b/internal/core/src/simd/avx512.cpp deleted file mode 100644 index 2fb8f7f5392b..000000000000 --- a/internal/core/src/simd/avx512.cpp +++ /dev/null @@ -1,918 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. 
- -#include "avx512.h" -#include - -#if defined(__x86_64__) -#include - -namespace milvus { -namespace simd { - -template <> -bool -FindTermAVX512(const bool* src, size_t vec_size, bool val) { - __m512i zmm_target = _mm512_set1_epi8(val); - __m512i zmm_data; - size_t num_chunks = vec_size / 64; - - for (size_t i = 0; i < 64 * num_chunks; i += 64) { - zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + i)); - __mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target); - if (mask != 0) { - return true; - } - } - - for (size_t i = 64 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX512(const int8_t* src, size_t vec_size, int8_t val) { - __m512i zmm_target = _mm512_set1_epi8(val); - __m512i zmm_data; - size_t num_chunks = vec_size / 64; - - for (size_t i = 0; i < 64 * num_chunks; i += 64) { - zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + i)); - __mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target); - if (mask != 0) { - return true; - } - } - - for (size_t i = 64 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX512(const int16_t* src, size_t vec_size, int16_t val) { - __m512i zmm_target = _mm512_set1_epi16(val); - __m512i zmm_data; - size_t num_chunks = vec_size / 32; - - for (size_t i = 0; i < 32 * num_chunks; i += 32) { - zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + i)); - __mmask32 mask = _mm512_cmpeq_epi16_mask(zmm_data, zmm_target); - if (mask != 0) { - return true; - } - } - - for (size_t i = 32 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX512(const int32_t* src, size_t vec_size, int32_t val) { - __m512i zmm_target = _mm512_set1_epi32(val); - __m512i zmm_data; - size_t num_chunks = vec_size / 16; - - for (size_t i = 0; i < 16 * num_chunks; i += 16) { - zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + i)); - __mmask16 mask = _mm512_cmpeq_epi32_mask(zmm_data, zmm_target); - if (mask != 0) { - return true; - } - } - - for (size_t i = 16 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX512(const int64_t* src, size_t vec_size, int64_t val) { - __m512i zmm_target = _mm512_set1_epi64(val); - __m512i zmm_data; - size_t num_chunks = vec_size / 8; - - for (size_t i = 0; i < 8 * num_chunks; i += 8) { - zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + i)); - __mmask8 mask = _mm512_cmpeq_epi64_mask(zmm_data, zmm_target); - if (mask != 0) { - return true; - } - } - - for (size_t i = 8 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX512(const float* src, size_t vec_size, float val) { - __m512 zmm_target = _mm512_set1_ps(val); - __m512 zmm_data; - size_t num_chunks = vec_size / 16; - - for (size_t i = 0; i < 16 * num_chunks; i += 16) { - zmm_data = _mm512_loadu_ps(src + i); - __mmask16 mask = _mm512_cmp_ps_mask(zmm_data, zmm_target, _CMP_EQ_OQ); - if (mask != 0) { - return true; - } - } - - for (size_t i = 16 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermAVX512(const double* src, size_t vec_size, double val) { - __m512d zmm_target = _mm512_set1_pd(val); - __m512d zmm_data; - size_t num_chunks = vec_size / 8; - - for (size_t i = 0; i < 8 * num_chunks; i += 8) { 
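The AVX-512 find paths need no movemask at all: the compare intrinsics write their result straight into a mask register, so "any lane matched" is a plain non-zero test on the mask. One step of the 8-bit variant above, isolated:

    __m512i needle = _mm512_set1_epi8(val);
    __m512i data   = _mm512_loadu_si512(src + i);            // 64 elements
    __mmask64 m    = _mm512_cmpeq_epi8_mask(data, needle);   // 1 bit per lane
    if (m != 0) return true;
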
- zmm_data = _mm512_loadu_pd(src + i); - __mmask8 mask = _mm512_cmp_pd_mask(zmm_data, zmm_target, _CMP_EQ_OQ); - if (mask != 0) { - return true; - } - } - - for (size_t i = 8 * num_chunks; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -void -AndBoolAVX512(bool* left, bool* right, int64_t size) { - int num_chunk = size / 64; - for (size_t i = 0; i < num_chunk * 64; i += 64) { - __m512i l_reg = - _mm512_loadu_si512(reinterpret_cast<__m512i*>(left + i)); - __m512i r_reg = - _mm512_loadu_si512(reinterpret_cast<__m512i*>(right + i)); - __m512i res = _mm512_and_si512(l_reg, r_reg); - _mm512_storeu_si512(reinterpret_cast<__m512i*>(left + i), res); - } - for (size_t i = num_chunk * 64; i < size; ++i) { - left[i] &= right[i]; - } -} - -void -OrBoolAVX512(bool* left, bool* right, int64_t size) { - int num_chunk = size / 64; - for (size_t i = 0; i < num_chunk * 64; i += 64) { - __m512i l_reg = - _mm512_loadu_si512(reinterpret_cast<__m512i*>(left + i)); - __m512i r_reg = - _mm512_loadu_si512(reinterpret_cast<__m512i*>(right + i)); - __m512i res = _mm512_or_si512(l_reg, r_reg); - _mm512_storeu_si512(reinterpret_cast<__m512i*>(left + i), res); - } - for (size_t i = num_chunk * 64; i < size; ++i) { - left[i] |= right[i]; - } -} - -template -struct CompareOperator; - -template -struct CompareOperator { - static constexpr int ComparePredicate = - std::is_floating_point_v ? _CMP_EQ_OQ : _MM_CMPINT_EQ; - static constexpr bool - Op(T a, T b) { - return a == b; - } -}; - -template -struct CompareOperator { - static constexpr int ComparePredicate = - std::is_floating_point_v ? _CMP_NEQ_OQ : _MM_CMPINT_NE; - static constexpr bool - Op(T a, T b) { - return a != b; - } -}; - -template -struct CompareOperator { - static constexpr int ComparePredicate = - std::is_floating_point_v ? _CMP_LT_OQ : _MM_CMPINT_LT; - static constexpr bool - Op(T a, T b) { - return a < b; - } -}; - -template -struct CompareOperator { - static constexpr int ComparePredicate = - std::is_floating_point_v ? _CMP_LE_OQ : _MM_CMPINT_LE; - static constexpr bool - Op(T a, T b) { - return a <= b; - } -}; - -template -struct CompareOperator { - static constexpr int ComparePredicate = - std::is_floating_point_v ? _CMP_GT_OQ : _MM_CMPINT_NLE; - static constexpr bool - Op(T a, T b) { - return a > b; - } -}; - -template -struct CompareOperator { - static constexpr int ComparePredicate = - std::is_floating_point_v ? 
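The CompareOperator trait being deleted below maps one logical comparison to the two predicate encodings AVX-512 uses: _CMP_* immediates for float/double (consumed by _mm512_cmp_ps_mask/_mm512_cmp_pd_mask) and _MM_CMPINT_* immediates for integers, plus a scalar Op() for tail elements. The _OQ float predicates are "ordered, quiet", so any comparison involving NaN yields false without raising an exception. A sketch of how the trait is consumed, assuming the stripped enum parameter is a CompareType-style tag:

    __mmask16 m = _mm512_cmp_ps_mask(
        data, target,
        CompareOperator<float, CompareType::GT>::ComparePredicate);
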
_CMP_GE_OQ : _MM_CMPINT_NLT; - static constexpr bool - Op(T a, T b) { - return a >= b; - } -}; - -template -struct CompareValAVX512Impl { - static void - Compare(const T* src, size_t size, T val, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - } -}; - -template -struct CompareValAVX512Impl { - static void - Compare(const int8_t* src, size_t size, int8_t val, bool* res) { - __m512i target = _mm512_set1_epi8(val); - - int middle = size / 64 * 64; - - for (size_t i = 0; i < middle; i += 64) { - __m512i data = - _mm512_loadu_si512(reinterpret_cast(src + i)); - - __mmask64 cmp_res_mask = _mm512_cmp_epi8_mask( - data, - target, - (CompareOperator::ComparePredicate)); - __m512i cmp_res = _mm512_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm512_storeu_si512(res + i, cmp_res); - } - - for (size_t i = middle; i < size; ++i) { - res[i] = CompareOperator::Op(src[i], val); - } - } -}; - -template -struct CompareValAVX512Impl { - static void - Compare(const int16_t* src, size_t size, int16_t val, bool* res) { - __m512i target = _mm512_set1_epi16(val); - - int middle = size / 32 * 32; - - for (size_t i = 0; i < middle; i += 32) { - __m512i data = - _mm512_loadu_si512(reinterpret_cast(src + i)); - - __mmask32 cmp_res_mask = _mm512_cmp_epi16_mask( - data, - target, - (CompareOperator::ComparePredicate)); - __m256i cmp_res = _mm256_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm256_storeu_si256((__m256i*)(res + i), cmp_res); - } - - for (size_t i = middle; i < size; ++i) { - res[i] = CompareOperator::Op(src[i], val); - } - } -}; - -template -struct CompareValAVX512Impl { - static void - Compare(const int32_t* src, size_t size, int32_t val, bool* res) { - __m512i target = _mm512_set1_epi32(val); - - int middle = size / 16 * 16; - - for (size_t i = 0; i < middle; i += 16) { - __m512i data = - _mm512_loadu_si512(reinterpret_cast(src + i)); - - __mmask16 cmp_res_mask = _mm512_cmp_epi32_mask( - data, - target, - (CompareOperator::ComparePredicate)); - __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm_storeu_si128((__m128i*)(res + i), cmp_res); - } - - for (size_t i = middle; i < size; ++i) { - res[i] = CompareOperator::Op(src[i], val); - } - } -}; - -template -struct CompareValAVX512Impl { - static void - Compare(const int64_t* src, size_t size, int64_t val, bool* res) { - __m512i target = _mm512_set1_epi64(val); - int middle = size / 8 * 8; - int index = 0; - for (size_t i = 0; i < middle; i += 8) { - __m512i data = - _mm512_loadu_si512(reinterpret_cast(src + i)); - __mmask8 mask = _mm512_cmp_epi64_mask( - data, - target, - (CompareOperator::ComparePredicate)); - __m128i cmp_res = _mm_maskz_set1_epi8(mask, 0x01); - _mm_storel_epi64((__m128i_u*)(res + i), cmp_res); - } - - for (size_t i = middle; i < size; ++i) { - res[i] = CompareOperator::Op(src[i], val); - } - } -}; - -template -struct CompareValAVX512Impl { - static void - Compare(const float* src, size_t size, float val, bool* res) { - __m512 target = _mm512_set1_ps(val); - - int middle = size / 16 * 16; - - for (size_t i = 0; i < middle; i += 16) { - __m512 data = _mm512_loadu_ps(src + i); - - __mmask16 cmp_res_mask = _mm512_cmp_ps_mask( - data, target, (CompareOperator::ComparePredicate)); - __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm_storeu_si128((__m128i*)(res + i), cmp_res); - } - - for (size_t i = middle; i < size; ++i) { - res[i] = CompareOperator::Op(src[i], val); - } - } -}; - -template -struct CompareValAVX512Impl { - static void - Compare(const 
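Each CompareValAVX512Impl specialization finishes the same way: the k-mask from the compare is widened back to one bool byte per element with _mm_maskz_set1_epi8(mask, 0x01), which writes 0x01 where the mask bit is set and 0x00 elsewhere, and exactly as many bytes are stored as there were lanes (64/32/16/8). The 16-lane step used by the int32/float paths, isolated:

    __mmask16 m   = _mm512_cmp_epi32_mask(data, target, _MM_CMPINT_EQ);
    __m128i bytes = _mm_maskz_set1_epi8(m, 0x01);   // 16 bools, one per lane
    _mm_storeu_si128(reinterpret_cast<__m128i*>(res + i), bytes);
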
double* src, size_t size, double val, bool* res) { - __m512d target = _mm512_set1_pd(val); - - int middle = size / 8 * 8; - - for (size_t i = 0; i < middle; i += 8) { - __m512d data = _mm512_loadu_pd(src + i); - - __mmask8 cmp_res_mask = _mm512_cmp_pd_mask( - data, - target, - (CompareOperator::ComparePredicate)); - __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm_storel_epi64((__m128i_u*)(res + i), cmp_res); - } - - for (size_t i = middle; i < size; ++i) { - res[i] = CompareOperator::Op(src[i], val); - } - } -}; - -template -void -EqualValAVX512(const T* src, size_t size, T val, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareValAVX512Impl::Compare(src, size, val, res); -}; -template void -EqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); -template void -EqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); -template void -EqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); -template void -EqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); -template void -EqualValAVX512(const float* src, size_t size, float val, bool* res); -template void -EqualValAVX512(const double* src, size_t size, double val, bool* res); - -template -void -LessValAVX512(const T* src, size_t size, T val, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareValAVX512Impl::Compare(src, size, val, res); -}; -template void -LessValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); -template void -LessValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); -template void -LessValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); -template void -LessValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); -template void -LessValAVX512(const float* src, size_t size, float val, bool* res); -template void -LessValAVX512(const double* src, size_t size, double val, bool* res); - -template -void -GreaterValAVX512(const T* src, size_t size, T val, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareValAVX512Impl::Compare(src, size, val, res); -}; -template void -GreaterValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); -template void -GreaterValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); -template void -GreaterValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); -template void -GreaterValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); -template void -GreaterValAVX512(const float* src, size_t size, float val, bool* res); -template void -GreaterValAVX512(const double* src, size_t size, double val, bool* res); - -template -void -NotEqualValAVX512(const T* src, size_t size, T val, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareValAVX512Impl::Compare(src, size, val, res); -}; -template void -NotEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); -template void -NotEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); -template void -NotEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); -template void -NotEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); -template void -NotEqualValAVX512(const float* src, size_t size, 
float val, bool* res); -template void -NotEqualValAVX512(const double* src, size_t size, double val, bool* res); - -template -void -LessEqualValAVX512(const T* src, size_t size, T val, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareValAVX512Impl::Compare(src, size, val, res); -}; -template void -LessEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); -template void -LessEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); -template void -LessEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); -template void -LessEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); -template void -LessEqualValAVX512(const float* src, size_t size, float val, bool* res); -template void -LessEqualValAVX512(const double* src, size_t size, double val, bool* res); - -template -void -GreaterEqualValAVX512(const T* src, size_t size, T val, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareValAVX512Impl::Compare(src, size, val, res); -}; -template void -GreaterEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); -template void -GreaterEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); -template void -GreaterEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); -template void -GreaterEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); -template void -GreaterEqualValAVX512(const float* src, size_t size, float val, bool* res); -template void -GreaterEqualValAVX512(const double* src, size_t size, double val, bool* res); - -template -void -CompareColumnAVX512(const T* left, const T* right, size_t size, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); -} - -template -struct CompareColumnAVX512Impl { - static void - Compare(const T* left, const T* right, size_t size, bool* res) { - static_assert(std::is_integral_v, "T must be integral type"); - - int batch_size = 512 / (sizeof(T) * 8); - int middle = size / batch_size * batch_size; - - for (size_t i = 0; i < middle; i += batch_size) { - __m512i left_reg = - _mm512_loadu_si512(reinterpret_cast(left + i)); - __m512i right_reg = - _mm512_loadu_si512(reinterpret_cast(right + i)); - - if constexpr (std::is_same_v) { - __mmask64 cmp_res_mask = _mm512_cmp_epi8_mask( - left_reg, - right_reg, - (CompareOperator::ComparePredicate)); - - __m512i cmp_res = _mm512_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm512_storeu_si512(res + i, cmp_res); - } else if constexpr (std::is_same_v) { - __mmask32 cmp_res_mask = _mm512_cmp_epi16_mask( - left_reg, - right_reg, - (CompareOperator::ComparePredicate)); - - __m256i cmp_res = _mm256_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm256_storeu_si256((__m256i*)(res + i), cmp_res); - } else if constexpr (std::is_same_v) { - __mmask16 cmp_res_mask = _mm512_cmp_epi32_mask( - left_reg, - right_reg, - (CompareOperator::ComparePredicate)); - - __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm_storeu_si128((__m128i*)(res + i), cmp_res); - } else if constexpr (std::is_same_v) { - __mmask8 mask = _mm512_cmp_epi64_mask( - left_reg, - right_reg, - (CompareOperator::ComparePredicate)); - - __m128i cmp_res = _mm_maskz_set1_epi8(mask, 0x01); - _mm_storel_epi64((__m128i_u*)(res + i), cmp_res); - } - } - - for (size_t i = middle; i < size; ++i) { - res[i] = 
CompareOperator::Op(left[i], right[i]); - } - } -}; - -template -struct CompareColumnAVX512Impl { - static void - Compare(const float* left, const float* right, size_t size, bool* res) { - int batch_size = 512 / (sizeof(float) * 8); - int middle = size / batch_size * batch_size; - - for (size_t i = 0; i < middle; i += batch_size) { - __m512 left_reg = - _mm512_loadu_ps(reinterpret_cast(left + i)); - __m512 right_reg = - _mm512_loadu_ps(reinterpret_cast(right + i)); - - __mmask16 cmp_res_mask = _mm512_cmp_ps_mask( - left_reg, - right_reg, - (CompareOperator::ComparePredicate)); - - __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm_storeu_si128((__m128i*)(res + i), cmp_res); - } - - for (size_t i = middle; i < size; ++i) { - res[i] = CompareOperator::Op(left[i], right[i]); - } - } -}; - -template -struct CompareColumnAVX512Impl { - static void - Compare(const double* left, const double* right, size_t size, bool* res) { - int batch_size = 512 / (sizeof(double) * 8); - int middle = size / batch_size * batch_size; - - for (size_t i = 0; i < middle; i += batch_size) { - __m512d left_reg = - _mm512_loadu_pd(reinterpret_cast(left + i)); - __m512d right_reg = - _mm512_loadu_pd(reinterpret_cast(right + i)); - - __mmask8 cmp_res_mask = _mm512_cmp_pd_mask( - left_reg, - right_reg, - (CompareOperator::ComparePredicate)); - - __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); - _mm_storel_epi64((__m128i_u*)(res + i), cmp_res); - } - - for (size_t i = middle; i < size; ++i) { - res[i] = CompareOperator::Op(left[i], right[i]); - } - } -}; - -template -void -EqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareColumnAVX512Impl::Compare( - left, right, size, res); -}; - -template void -EqualColumnAVX512(const int8_t* left, - const int8_t* right, - size_t size, - bool* res); -template void -EqualColumnAVX512(const int16_t* left, - const int16_t* right, - size_t size, - bool* res); -template void -EqualColumnAVX512(const int32_t* left, - const int32_t* right, - size_t size, - bool* res); -template void -EqualColumnAVX512(const int64_t* left, - const int64_t* right, - size_t size, - bool* res); -template void -EqualColumnAVX512(const float* left, - const float* right, - size_t size, - bool* res); -template void -EqualColumnAVX512(const double* left, - const double* right, - size_t size, - bool* res); - -template -void -LessColumnAVX512(const T* left, const T* right, size_t size, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareColumnAVX512Impl::Compare( - left, right, size, res); -}; -template void -LessColumnAVX512(const int8_t* left, - const int8_t* right, - size_t size, - bool* res); -template void -LessColumnAVX512(const int16_t* left, - const int16_t* right, - size_t size, - bool* res); -template void -LessColumnAVX512(const int32_t* left, - const int32_t* right, - size_t size, - bool* res); -template void -LessColumnAVX512(const int64_t* left, - const int64_t* right, - size_t size, - bool* res); -template void -LessColumnAVX512(const float* left, const float* right, size_t size, bool* res); -template void -LessColumnAVX512(const double* left, - const double* right, - size_t size, - bool* res); - -template -void -GreaterColumnAVX512(const T* left, const T* right, size_t size, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must 
be integral or float/double type"); - CompareColumnAVX512Impl::Compare( - left, right, size, res); -}; -template void -GreaterColumnAVX512(const int8_t* left, - const int8_t* right, - size_t size, - bool* res); -template void -GreaterColumnAVX512(const int16_t* left, - const int16_t* right, - size_t size, - bool* res); -template void -GreaterColumnAVX512(const int32_t* left, - const int32_t* right, - size_t size, - bool* res); -template void -GreaterColumnAVX512(const int64_t* left, - const int64_t* right, - size_t size, - bool* res); -template void -GreaterColumnAVX512(const float* left, - const float* right, - size_t size, - bool* res); -template void -GreaterColumnAVX512(const double* left, - const double* right, - size_t size, - bool* res); - -template -void -LessEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareColumnAVX512Impl::Compare( - left, right, size, res); -}; -template void -LessEqualColumnAVX512(const int8_t* left, - const int8_t* right, - size_t size, - bool* res); -template void -LessEqualColumnAVX512(const int16_t* left, - const int16_t* right, - size_t size, - bool* res); -template void -LessEqualColumnAVX512(const int32_t* left, - const int32_t* right, - size_t size, - bool* res); -template void -LessEqualColumnAVX512(const int64_t* left, - const int64_t* right, - size_t size, - bool* res); -template void -LessEqualColumnAVX512(const float* left, - const float* right, - size_t size, - bool* res); -template void -LessEqualColumnAVX512(const double* left, - const double* right, - size_t size, - bool* res); - -template -void -GreaterEqualColumnAVX512(const T* left, - const T* right, - size_t size, - bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareColumnAVX512Impl::Compare( - left, right, size, res); -}; -template void -GreaterEqualColumnAVX512(const int8_t* left, - const int8_t* right, - size_t size, - bool* res); -template void -GreaterEqualColumnAVX512(const int16_t* left, - const int16_t* right, - size_t size, - bool* res); -template void -GreaterEqualColumnAVX512(const int32_t* left, - const int32_t* right, - size_t size, - bool* res); -template void -GreaterEqualColumnAVX512(const int64_t* left, - const int64_t* right, - size_t size, - bool* res); -template void -GreaterEqualColumnAVX512(const float* left, - const float* right, - size_t size, - bool* res); -template void -GreaterEqualColumnAVX512(const double* left, - const double* right, - size_t size, - bool* res); - -template -void -NotEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) { - static_assert(std::is_integral_v || std::is_floating_point_v, - "T must be integral or float/double type"); - CompareColumnAVX512Impl::Compare( - left, right, size, res); -}; - -template void -NotEqualColumnAVX512(const int8_t* left, - const int8_t* right, - size_t size, - bool* res); -template void -NotEqualColumnAVX512(const int16_t* left, - const int16_t* right, - size_t size, - bool* res); -template void -NotEqualColumnAVX512(const int32_t* left, - const int32_t* right, - size_t size, - bool* res); -template void -NotEqualColumnAVX512(const int64_t* left, - const int64_t* right, - size_t size, - bool* res); -template void -NotEqualColumnAVX512(const float* left, - const float* right, - size_t size, - bool* res); -template void -NotEqualColumnAVX512(const double* left, - const 
double* right,
-                     size_t size,
-                     bool* res);
-
-}  // namespace simd
-}  // namespace milvus
-#endif
diff --git a/internal/core/src/simd/avx512.h b/internal/core/src/simd/avx512.h
deleted file mode 100644
index 9b5c549d3d42..000000000000
--- a/internal/core/src/simd/avx512.h
+++ /dev/null
@@ -1,113 +0,0 @@
-// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License.
-
-#pragma once
-
-#include
-#include
-#include
-
-#include "common.h"
-
-namespace milvus {
-namespace simd {
-
-template <typename T>
-bool
-FindTermAVX512(const T* src, size_t vec_size, T va) {
-    CHECK_SUPPORTED_TYPE(T, "unsupported type for FindTermAVX512");
-    return false;
-}
-
-template <>
-bool
-FindTermAVX512(const bool* src, size_t vec_size, bool val);
-
-template <>
-bool
-FindTermAVX512(const int8_t* src, size_t vec_size, int8_t val);
-
-template <>
-bool
-FindTermAVX512(const int16_t* src, size_t vec_size, int16_t val);
-
-template <>
-bool
-FindTermAVX512(const int32_t* src, size_t vec_size, int32_t val);
-
-template <>
-bool
-FindTermAVX512(const int64_t* src, size_t vec_size, int64_t val);
-
-template <>
-bool
-FindTermAVX512(const float* src, size_t vec_size, float val);
-
-template <>
-bool
-FindTermAVX512(const double* src, size_t vec_size, double val);
-
-void
-AndBoolAVX512(bool* left, bool* right, int64_t size);
-
-void
-OrBoolAVX512(bool* left, bool* right, int64_t size);
-
-template <typename T>
-void
-EqualValAVX512(const T* src, size_t size, T val, bool* res);
-
-template <typename T>
-void
-LessValAVX512(const T* src, size_t size, T val, bool* res);
-
-template <typename T>
-void
-GreaterValAVX512(const T* src, size_t size, T val, bool* res);
-
-template <typename T>
-void
-NotEqualValAVX512(const T* src, size_t size, T val, bool* res);
-
-template <typename T>
-void
-LessEqualValAVX512(const T* src, size_t size, T val, bool* res);
-
-template <typename T>
-void
-GreaterEqualValAVX512(const T* src, size_t size, T val, bool* res);
-
-template <typename T>
-void
-EqualColumnAVX512(const T* left, const T* right, size_t size, bool* res);
-
-template <typename T>
-void
-LessColumnAVX512(const T* left, const T* right, size_t size, bool* res);
-
-template <typename T>
-void
-LessEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res);
-
-template <typename T>
-void
-GreaterColumnAVX512(const T* left, const T* right, size_t size, bool* res);
-
-template <typename T>
-void
-GreaterEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res);
-
-template <typename T>
-void
-NotEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res);
-
-}  // namespace simd
-}  // namespace milvus
diff --git a/internal/core/src/simd/common.h b/internal/core/src/simd/common.h
deleted file mode 100644
index f6e0c9e3c630..000000000000
--- a/internal/core/src/simd/common.h
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License.
-
-#pragma once
-
-#include
-#include
-#include
-
-namespace milvus {
-namespace simd {
-
-using BitsetBlockType = unsigned long;
-constexpr size_t BITSET_BLOCK_SIZE = sizeof(unsigned long);
-
-/*
-* For term sizes smaller than TERM_EXPR_IN_SIZE_THREAD,
-* SIMD search performs better for all numeric types.
-* For term sizes larger than TERM_EXPR_IN_SIZE_THREAD,
-* set search performs better for all numeric types.
-* 50 is an experimental value; a dynamic plan is needed to make it
-* tunable for different situations.
-*/
-const int TERM_EXPR_IN_SIZE_THREAD = 50;
-
-#define CHECK_SUPPORTED_TYPE(T, Message) \
-    static_assert( \
-        std::is_same<T, bool>::value || std::is_same<T, int8_t>::value || \
-            std::is_same<T, int16_t>::value || \
-            std::is_same<T, int32_t>::value || \
-            std::is_same<T, int64_t>::value || \
-            std::is_same<T, float>::value || std::is_same<T, double>::value, \
-        Message);
-
-enum class CompareType {
-    GT = 1,
-    GE = 2,
-    LT = 3,
-    LE = 4,
-    EQ = 5,
-    NEQ = 6,
-};
-
-}  // namespace simd
-}  // namespace milvus
diff --git a/internal/core/src/simd/hook.cpp b/internal/core/src/simd/hook.cpp
deleted file mode 100644
index 89b5b300671b..000000000000
--- a/internal/core/src/simd/hook.cpp
+++ /dev/null
@@ -1,595 +0,0 @@
-// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License.
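// For illustration only: a minimal standalone sketch of the pattern the
// deleted AVX-512 kernels above rely on. A vector compare produces a k-mask
// with one bit per lane, and _mm_maskz_set1_epi8 expands that mask back into
// one 0x00/0x01 byte per lane so the result can be stored straight into a
// bool array (assuming sizeof(bool) == 1, as the deleted code does). The
// function name is hypothetical, not Milvus API; compile with
// -mavx512f -mavx512bw -mavx512vl.
#include <immintrin.h>
#include <cstddef>
#include <cstdint>

void
less_than_i32_avx512_sketch(const int32_t* left,
                            const int32_t* right,
                            size_t size,
                            bool* res) {
    constexpr size_t kBatch = 512 / (sizeof(int32_t) * 8);  // 16 lanes
    const size_t middle = size / kBatch * kBatch;
    for (size_t i = 0; i < middle; i += kBatch) {
        __m512i l = _mm512_loadu_si512(left + i);
        __m512i r = _mm512_loadu_si512(right + i);
        // One bit per int32 lane: set iff left[lane] < right[lane].
        __mmask16 m = _mm512_cmp_epi32_mask(l, r, _MM_CMPINT_LT);
        // Expand the 16 mask bits into 16 bytes of 0x00/0x01.
        _mm_storeu_si128(reinterpret_cast<__m128i*>(res + i),
                         _mm_maskz_set1_epi8(m, 0x01));
    }
    for (size_t i = middle; i < size; ++i) {
        res[i] = left[i] < right[i];  // scalar tail, as in the deleted code
    }
}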
- -// -*- c++ -*- - -#include "hook.h" - -#include -#include -#include - -#include "ref.h" -#include "log/Log.h" -#if defined(__x86_64__) -#include "avx2.h" -#include "avx512.h" -#include "sse2.h" -#include "sse4.h" -#include "instruction_set.h" -#elif defined(__ARM_NEON) -#include "neon.h" -#endif - -namespace milvus { -namespace simd { - -decltype(get_bitset_block) get_bitset_block = GetBitsetBlockRef; -decltype(all_false) all_false = AllFalseRef; -decltype(all_true) all_true = AllTrueRef; -decltype(invert_bool) invert_bool = InvertBoolRef; -decltype(and_bool) and_bool = AndBoolRef; -decltype(or_bool) or_bool = OrBoolRef; - -#define DECLARE_FIND_TERM_PTR(type) \ - FindTermPtr find_term_##type = FindTermRef; -DECLARE_FIND_TERM_PTR(bool) -DECLARE_FIND_TERM_PTR(int8_t) -DECLARE_FIND_TERM_PTR(int16_t) -DECLARE_FIND_TERM_PTR(int32_t) -DECLARE_FIND_TERM_PTR(int64_t) -DECLARE_FIND_TERM_PTR(float) -DECLARE_FIND_TERM_PTR(double) - -#define DECLARE_COMPARE_VAL_PTR(prefix, RefFunc, type) \ - CompareValPtr prefix##_##type = RefFunc; - -DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, bool) -DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int8_t) -DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int16_t) -DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int32_t) -DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int64_t) -DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, float) -DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, double) - -DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, bool) -DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int8_t) -DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int16_t) -DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int32_t) -DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int64_t) -DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, float) -DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, double) - -DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, bool) -DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int8_t) -DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int16_t) -DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int32_t) -DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int64_t) -DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, float) -DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, double) - -DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, bool) -DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int8_t) -DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int16_t) -DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int32_t) -DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int64_t) -DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, float) -DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, double) - -DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, bool) -DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int8_t) -DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int16_t) -DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int32_t) -DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int64_t) -DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, float) -DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, double) - -DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, bool) -DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int8_t) -DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int16_t) -DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int32_t) -DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int64_t) 
-DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, float) -DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, double) - -#define DECLARE_COMPARE_COL_PTR(prefix, RefFunc, type) \ - CompareColPtr prefix##_##type = RefFunc; - -DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, bool) -DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int8_t) -DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int16_t) -DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int32_t) -DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int64_t) -DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, float) -DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, double) - -DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, bool) -DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int8_t) -DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int16_t) -DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int32_t) -DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int64_t) -DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, float) -DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, double) - -DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, bool) -DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int8_t) -DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int16_t) -DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int32_t) -DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int64_t) -DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, float) -DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, double) - -DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, bool) -DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int8_t) -DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int16_t) -DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int32_t) -DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int64_t) -DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, float) -DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, double) - -DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, bool) -DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int8_t) -DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int16_t) -DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int32_t) -DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int64_t) -DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, float) -DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, double) - -DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, bool) -DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int8_t) -DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int16_t) -DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int32_t) -DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int64_t) -DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, float) -DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, double) - -#if defined(__x86_64__) -bool -cpu_support_avx512() { - InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); - return (instruction_set_inst.AVX512F() && instruction_set_inst.AVX512DQ() && - instruction_set_inst.AVX512BW() && instruction_set_inst.AVX512VL()); -} - -bool -cpu_support_avx2() { - InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); - return (instruction_set_inst.AVX2()); -} - -bool -cpu_support_sse4_2() { - InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); - return 
(instruction_set_inst.SSE42()); -} - -bool -cpu_support_sse2() { - InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); - return (instruction_set_inst.SSE2()); -} -#endif - -static void -bitset_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // SSE2 have best performance in test. - if (cpu_support_sse2()) { - simd_type = "SSE2"; - get_bitset_block = GetBitsetBlockSSE2; - } -#endif - // TODO: support arm cpu - LOG_INFO("bitset hook simd type: {}", simd_type); -} - -static void -find_term_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - if (cpu_support_avx512()) { - simd_type = "AVX512"; - find_term_bool = FindTermAVX512; - find_term_int8_t = FindTermAVX512; - find_term_int16_t = FindTermAVX512; - find_term_int32_t = FindTermAVX512; - find_term_int64_t = FindTermAVX512; - find_term_float = FindTermAVX512; - find_term_double = FindTermAVX512; - } else if (cpu_support_avx2()) { - simd_type = "AVX2"; - find_term_bool = FindTermAVX2; - find_term_int8_t = FindTermAVX2; - find_term_int16_t = FindTermAVX2; - find_term_int32_t = FindTermAVX2; - find_term_int64_t = FindTermAVX2; - find_term_float = FindTermAVX2; - find_term_double = FindTermAVX2; - } else if (cpu_support_sse4_2()) { - simd_type = "SSE4"; - find_term_bool = FindTermSSE4; - find_term_int8_t = FindTermSSE4; - find_term_int16_t = FindTermSSE4; - find_term_int32_t = FindTermSSE4; - find_term_int64_t = FindTermSSE4; - find_term_float = FindTermSSE4; - find_term_double = FindTermSSE4; - } else if (cpu_support_sse2()) { - simd_type = "SSE2"; - find_term_bool = FindTermSSE2; - find_term_int8_t = FindTermSSE2; - find_term_int16_t = FindTermSSE2; - find_term_int32_t = FindTermSSE2; - find_term_int64_t = FindTermSSE2; - find_term_float = FindTermSSE2; - find_term_double = FindTermSSE2; - } -#endif - // TODO: support arm cpu - LOG_INFO("find term hook simd type: {}", simd_type); -} - -static void -all_boolean_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - if (cpu_support_sse2()) { - simd_type = "SSE2"; - all_false = AllFalseSSE2; - all_true = AllTrueSSE2; - } -#elif defined(__ARM_NEON) - simd_type = "NEON"; - all_false = AllFalseNEON; - all_true = AllTrueNEON; -#endif - // TODO: support arm cpu - LOG_INFO("AllFalse/AllTrue hook simd type: {}", simd_type); -} - -static void -invert_boolean_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - if (cpu_support_sse2()) { - simd_type = "SSE2"; - invert_bool = InvertBoolSSE2; - } -#elif defined(__ARM_NEON) - simd_type = "NEON"; - invert_bool = InvertBoolNEON; -#endif - // TODO: support arm cpu - LOG_INFO("InvertBoolean hook simd type: {}", simd_type); -} - -static void -logical_boolean_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - if (cpu_support_avx512()) { - simd_type = "AVX512"; - and_bool = AndBoolAVX512; - or_bool = OrBoolAVX512; - } else if (cpu_support_avx2()) { - simd_type = "AVX2"; - and_bool = AndBoolAVX2; - or_bool = OrBoolAVX2; - } else if (cpu_support_sse2()) { - simd_type = "SSE2"; - and_bool = AndBoolSSE2; - or_bool = OrBoolSSE2; - } -#elif defined(__ARM_NEON) - simd_type = "NEON"; - and_bool = AndBoolNEON; - or_bool = OrBoolNEON; -#endif 
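// The InstructionSet singleton used above wraps CPUID. For reference, the
// same AVX-512 gate can be expressed with GCC/Clang builtins; this is an
// equivalent sketch, not the deleted implementation. All four subsets
// (F/DQ/BW/VL) are checked because the kernels mix instructions from each.
#if defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__))
static bool
cpu_supports_avx512_kernels_sketch() {
    __builtin_cpu_init();
    return __builtin_cpu_supports("avx512f") &&
           __builtin_cpu_supports("avx512dq") &&
           __builtin_cpu_supports("avx512bw") &&
           __builtin_cpu_supports("avx512vl");
}
#endif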
- // TODO: support arm cpu - LOG_INFO("InvertBoolean hook simd type: {}", simd_type); -} - -static void -boolean_hook() { - all_boolean_hook(); - invert_boolean_hook(); - logical_boolean_hook(); -} - -static void -equal_val_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - equal_val_int8_t = EqualValAVX512; - equal_val_int16_t = EqualValAVX512; - equal_val_int32_t = EqualValAVX512; - equal_val_int64_t = EqualValAVX512; - equal_val_float = EqualValAVX512; - equal_val_double = EqualValAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("equal val hook simd type: {} ", simd_type); -} - -static void -less_val_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - less_val_int8_t = LessValAVX512; - less_val_int16_t = LessValAVX512; - less_val_int32_t = LessValAVX512; - less_val_int64_t = LessValAVX512; - less_val_float = LessValAVX512; - less_val_double = LessValAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("less than val hook simd type:{} ", simd_type); -} - -static void -greater_val_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - greater_val_int8_t = GreaterValAVX512; - greater_val_int16_t = GreaterValAVX512; - greater_val_int32_t = GreaterValAVX512; - greater_val_int64_t = GreaterValAVX512; - greater_val_float = GreaterValAVX512; - greater_val_double = GreaterValAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("greater than val hook simd type: {} ", simd_type); -} - -static void -less_equal_val_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - less_equal_val_int8_t = LessEqualValAVX512; - less_equal_val_int16_t = LessEqualValAVX512; - less_equal_val_int32_t = LessEqualValAVX512; - less_equal_val_int64_t = LessEqualValAVX512; - less_equal_val_float = LessEqualValAVX512; - less_equal_val_double = LessEqualValAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("less equal than val hook simd type: {} ", simd_type); -} - -static void -greater_equal_val_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - greater_equal_val_int8_t = GreaterEqualValAVX512; - greater_equal_val_int16_t = GreaterEqualValAVX512; - greater_equal_val_int32_t = GreaterEqualValAVX512; - greater_equal_val_int64_t = GreaterEqualValAVX512; - greater_equal_val_float = GreaterEqualValAVX512; - greater_equal_val_double = GreaterEqualValAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("greater equal than val hook simd type: {} ", simd_type); -} - -static void -not_equal_val_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - not_equal_val_int8_t = NotEqualValAVX512; - 
not_equal_val_int16_t = NotEqualValAVX512; - not_equal_val_int32_t = NotEqualValAVX512; - not_equal_val_int64_t = NotEqualValAVX512; - not_equal_val_float = NotEqualValAVX512; - not_equal_val_double = NotEqualValAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("not equal val hook simd type: {}", simd_type); -} - -static void -equal_col_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - equal_col_int8_t = EqualColumnAVX512; - equal_col_int16_t = EqualColumnAVX512; - equal_col_int32_t = EqualColumnAVX512; - equal_col_int64_t = EqualColumnAVX512; - equal_col_float = EqualColumnAVX512; - equal_col_double = EqualColumnAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("equal column hook simd type:{} ", simd_type); -} - -static void -less_col_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - less_col_int8_t = LessColumnAVX512; - less_col_int16_t = LessColumnAVX512; - less_col_int32_t = LessColumnAVX512; - less_col_int64_t = LessColumnAVX512; - less_col_float = LessColumnAVX512; - less_col_double = LessColumnAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("less than column hook simd type:{} ", simd_type); -} - -static void -greater_col_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - greater_col_int8_t = GreaterColumnAVX512; - greater_col_int16_t = GreaterColumnAVX512; - greater_col_int32_t = GreaterColumnAVX512; - greater_col_int64_t = GreaterColumnAVX512; - greater_col_float = GreaterColumnAVX512; - greater_col_double = GreaterColumnAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("greater than column hook simd type:{} ", simd_type); -} - -static void -less_equal_col_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - less_equal_col_int8_t = LessEqualColumnAVX512; - less_equal_col_int16_t = LessEqualColumnAVX512; - less_equal_col_int32_t = LessEqualColumnAVX512; - less_equal_col_int64_t = LessEqualColumnAVX512; - less_equal_col_float = LessEqualColumnAVX512; - less_equal_col_double = LessEqualColumnAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("less equal than column hook simd type: {}", simd_type); -} - -static void -greater_equal_col_hook() { - static std::mutex hook_mutex; - std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - greater_equal_col_int8_t = GreaterEqualColumnAVX512; - greater_equal_col_int16_t = GreaterEqualColumnAVX512; - greater_equal_col_int32_t = GreaterEqualColumnAVX512; - greater_equal_col_int64_t = GreaterEqualColumnAVX512; - greater_equal_col_float = GreaterEqualColumnAVX512; - greater_equal_col_double = GreaterEqualColumnAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("greater equal than column hook simd type:{} ", simd_type); -} - -static void -not_equal_col_hook() { - static std::mutex hook_mutex; - 
std::lock_guard lock(hook_mutex); - std::string simd_type = "REF"; -#if defined(__x86_64__) - // Only support avx512 for now - if (cpu_support_avx512()) { - simd_type = "AVX512"; - not_equal_col_int8_t = NotEqualColumnAVX512; - not_equal_col_int16_t = NotEqualColumnAVX512; - not_equal_col_int32_t = NotEqualColumnAVX512; - not_equal_col_int64_t = NotEqualColumnAVX512; - not_equal_col_float = NotEqualColumnAVX512; - not_equal_col_double = NotEqualColumnAVX512; - } -#endif - // TODO: support arm cpu - LOG_INFO("not equal column hook simd type: {}", simd_type); -} - -static void -compare_hook() { - equal_val_hook(); - less_val_hook(); - greater_val_hook(); - less_equal_val_hook(); - greater_equal_val_hook(); - not_equal_val_hook(); - equal_col_hook(); - less_col_hook(); - greater_col_hook(); - less_equal_col_hook(); - greater_equal_col_hook(); - not_equal_col_hook(); -} - -static int init_hook_ = []() { - bitset_hook(); - boolean_hook(); - find_term_hook(); - compare_hook(); - return 0; -}(); - -} // namespace simd -} // namespace milvus diff --git a/internal/core/src/simd/hook.h b/internal/core/src/simd/hook.h deleted file mode 100644 index 2ffbbd81442d..000000000000 --- a/internal/core/src/simd/hook.h +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include -#include - -#include "common.h" -namespace milvus { -namespace simd { - -#if defined(__x86_64__) -bool -cpu_support_avx512(); -bool -cpu_support_avx2(); -bool -cpu_support_sse4_2(); -#endif - -extern BitsetBlockType (*get_bitset_block)(const bool* src); -extern bool (*all_false)(const bool* src, int64_t size); -extern bool (*all_true)(const bool* src, int64_t size); -extern void (*invert_bool)(bool* src, int64_t size); -extern void (*and_bool)(bool* left, bool* right, int64_t size); -extern void (*or_bool)(bool* left, bool* right, int64_t size); - -template -using FindTermPtr = bool (*)(const T* src, size_t size, T val); -#define EXTERN_FIND_TERM_PTR(type) extern FindTermPtr find_term_##type; - -EXTERN_FIND_TERM_PTR(bool) -EXTERN_FIND_TERM_PTR(int8_t) -EXTERN_FIND_TERM_PTR(int16_t) -EXTERN_FIND_TERM_PTR(int32_t) -EXTERN_FIND_TERM_PTR(int64_t) -EXTERN_FIND_TERM_PTR(float) -EXTERN_FIND_TERM_PTR(double) - -// Compare val function register -// Such as A == 10, A < 10... -template -using CompareValPtr = void (*)(const T* src, size_t size, T val, bool* res); -#define EXTERN_COMPARE_VAL_PTR(prefix, type) \ - extern CompareValPtr prefix##_##type; - -// Compare column function register -// Such as A == B, A < B... 
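// The dispatch machinery deleted here follows one pattern throughout: a
// function pointer is initialized to the portable reference implementation,
// then rebound at most once, before main(), by a static initializer that
// inspects CPU features. A self-contained sketch with hypothetical names:
#include <cstdint>

static void
and_bool_ref_sketch(bool* left, bool* right, int64_t size) {
    for (int64_t i = 0; i < size; ++i) {
        left[i] &= right[i];
    }
}

// Defaults to the portable version; a hook may swap in a SIMD variant.
static void (*and_bool_sketch)(bool*, bool*, int64_t) = and_bool_ref_sketch;

static int init_hook_sketch = []() {
    // e.g. if (cpu_supports_avx512_kernels_sketch())
    //          and_bool_sketch = and_bool_avx512;  // hypothetical
    return 0;
}();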
-template -using CompareColPtr = - void (*)(const T* left, const T* right, size_t size, bool* res); -#define EXTERN_COMPARE_COL_PTR(prefix, type) \ - extern CompareColPtr prefix##_##type; - -EXTERN_COMPARE_VAL_PTR(equal_val, bool) -EXTERN_COMPARE_VAL_PTR(equal_val, int8_t) -EXTERN_COMPARE_VAL_PTR(equal_val, int16_t) -EXTERN_COMPARE_VAL_PTR(equal_val, int32_t) -EXTERN_COMPARE_VAL_PTR(equal_val, int64_t) -EXTERN_COMPARE_VAL_PTR(equal_val, float) -EXTERN_COMPARE_VAL_PTR(equal_val, double) - -EXTERN_COMPARE_VAL_PTR(less_val, bool) -EXTERN_COMPARE_VAL_PTR(less_val, int8_t) -EXTERN_COMPARE_VAL_PTR(less_val, int16_t) -EXTERN_COMPARE_VAL_PTR(less_val, int32_t) -EXTERN_COMPARE_VAL_PTR(less_val, int64_t) -EXTERN_COMPARE_VAL_PTR(less_val, float) -EXTERN_COMPARE_VAL_PTR(less_val, double) - -EXTERN_COMPARE_VAL_PTR(greater_val, bool) -EXTERN_COMPARE_VAL_PTR(greater_val, int8_t) -EXTERN_COMPARE_VAL_PTR(greater_val, int16_t) -EXTERN_COMPARE_VAL_PTR(greater_val, int32_t) -EXTERN_COMPARE_VAL_PTR(greater_val, int64_t) -EXTERN_COMPARE_VAL_PTR(greater_val, float) -EXTERN_COMPARE_VAL_PTR(greater_val, double) - -EXTERN_COMPARE_VAL_PTR(less_equal_val, bool) -EXTERN_COMPARE_VAL_PTR(less_equal_val, int8_t) -EXTERN_COMPARE_VAL_PTR(less_equal_val, int16_t) -EXTERN_COMPARE_VAL_PTR(less_equal_val, int32_t) -EXTERN_COMPARE_VAL_PTR(less_equal_val, int64_t) -EXTERN_COMPARE_VAL_PTR(less_equal_val, float) -EXTERN_COMPARE_VAL_PTR(less_equal_val, double) - -EXTERN_COMPARE_VAL_PTR(greater_equal_val, bool) -EXTERN_COMPARE_VAL_PTR(greater_equal_val, int8_t) -EXTERN_COMPARE_VAL_PTR(greater_equal_val, int16_t) -EXTERN_COMPARE_VAL_PTR(greater_equal_val, int32_t) -EXTERN_COMPARE_VAL_PTR(greater_equal_val, int64_t) -EXTERN_COMPARE_VAL_PTR(greater_equal_val, float) -EXTERN_COMPARE_VAL_PTR(greater_equal_val, double) - -EXTERN_COMPARE_VAL_PTR(not_equal_val, bool) -EXTERN_COMPARE_VAL_PTR(not_equal_val, int8_t) -EXTERN_COMPARE_VAL_PTR(not_equal_val, int16_t) -EXTERN_COMPARE_VAL_PTR(not_equal_val, int32_t) -EXTERN_COMPARE_VAL_PTR(not_equal_val, int64_t) -EXTERN_COMPARE_VAL_PTR(not_equal_val, float) -EXTERN_COMPARE_VAL_PTR(not_equal_val, double) - -EXTERN_COMPARE_COL_PTR(equal_col, bool) -EXTERN_COMPARE_COL_PTR(equal_col, int8_t) -EXTERN_COMPARE_COL_PTR(equal_col, int16_t) -EXTERN_COMPARE_COL_PTR(equal_col, int32_t) -EXTERN_COMPARE_COL_PTR(equal_col, int64_t) -EXTERN_COMPARE_COL_PTR(equal_col, float) -EXTERN_COMPARE_COL_PTR(equal_col, double) - -EXTERN_COMPARE_COL_PTR(less_col, bool) -EXTERN_COMPARE_COL_PTR(less_col, int8_t) -EXTERN_COMPARE_COL_PTR(less_col, int16_t) -EXTERN_COMPARE_COL_PTR(less_col, int32_t) -EXTERN_COMPARE_COL_PTR(less_col, int64_t) -EXTERN_COMPARE_COL_PTR(less_col, float) -EXTERN_COMPARE_COL_PTR(less_col, double) - -EXTERN_COMPARE_COL_PTR(greater_col, bool) -EXTERN_COMPARE_COL_PTR(greater_col, int8_t) -EXTERN_COMPARE_COL_PTR(greater_col, int16_t) -EXTERN_COMPARE_COL_PTR(greater_col, int32_t) -EXTERN_COMPARE_COL_PTR(greater_col, int64_t) -EXTERN_COMPARE_COL_PTR(greater_col, float) -EXTERN_COMPARE_COL_PTR(greater_col, double) - -EXTERN_COMPARE_COL_PTR(less_equal_col, bool) -EXTERN_COMPARE_COL_PTR(less_equal_col, int8_t) -EXTERN_COMPARE_COL_PTR(less_equal_col, int16_t) -EXTERN_COMPARE_COL_PTR(less_equal_col, int32_t) -EXTERN_COMPARE_COL_PTR(less_equal_col, int64_t) -EXTERN_COMPARE_COL_PTR(less_equal_col, float) -EXTERN_COMPARE_COL_PTR(less_equal_col, double) - -EXTERN_COMPARE_COL_PTR(greater_equal_col, bool) -EXTERN_COMPARE_COL_PTR(greater_equal_col, int8_t) -EXTERN_COMPARE_COL_PTR(greater_equal_col, int16_t) 
-EXTERN_COMPARE_COL_PTR(greater_equal_col, int32_t) -EXTERN_COMPARE_COL_PTR(greater_equal_col, int64_t) -EXTERN_COMPARE_COL_PTR(greater_equal_col, float) -EXTERN_COMPARE_COL_PTR(greater_equal_col, double) - -EXTERN_COMPARE_COL_PTR(not_equal_col, bool) -EXTERN_COMPARE_COL_PTR(not_equal_col, int8_t) -EXTERN_COMPARE_COL_PTR(not_equal_col, int16_t) -EXTERN_COMPARE_COL_PTR(not_equal_col, int32_t) -EXTERN_COMPARE_COL_PTR(not_equal_col, int64_t) -EXTERN_COMPARE_COL_PTR(not_equal_col, float) -EXTERN_COMPARE_COL_PTR(not_equal_col, double) - -} // namespace simd -} // namespace milvus diff --git a/internal/core/src/simd/interface.h b/internal/core/src/simd/interface.h deleted file mode 100644 index e93a5c31dc94..000000000000 --- a/internal/core/src/simd/interface.h +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include "hook.h" -namespace milvus { -namespace simd { - -#define DISPATCH_FIND_TERM_SIMD_FUNC(type) \ - if constexpr (std::is_same_v) { \ - return milvus::simd::find_term_##type(data, size, val); \ - } - -#define DISPATCH_COMPARE_VAL_SIMD_FUNC(prefix, type) \ - if constexpr (std::is_same_v) { \ - return milvus::simd::prefix##_##type(data, size, val, res); \ - } - -#define DISPATCH_COMPARE_COL_SIMD_FUNC(prefix, type) \ - if constexpr (std::is_same_v) { \ - return milvus::simd::prefix##_##type(left, right, size, res); \ - } - -template -bool -find_term_func(const T* data, size_t size, T val) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - DISPATCH_FIND_TERM_SIMD_FUNC(bool) - DISPATCH_FIND_TERM_SIMD_FUNC(int8_t) - DISPATCH_FIND_TERM_SIMD_FUNC(int16_t) - DISPATCH_FIND_TERM_SIMD_FUNC(int32_t) - DISPATCH_FIND_TERM_SIMD_FUNC(int64_t) - DISPATCH_FIND_TERM_SIMD_FUNC(float) - DISPATCH_FIND_TERM_SIMD_FUNC(double) -} - -template -void -equal_val_func(const T* data, int64_t size, T val, bool* res) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, bool) - DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int8_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int16_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int32_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int64_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, float) - DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, double) -} - -template -void -less_val_func(const T* data, int64_t size, T val, bool* res) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, bool) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int8_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int16_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int32_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int64_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, float) - 
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, double) -} - -template -void -greater_val_func(const T* data, int64_t size, T val, bool* res) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, bool) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int8_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int16_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int32_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int64_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, float) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, double) -} - -template -void -less_equal_val_func(const T* data, int64_t size, T val, bool* res) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, bool) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int8_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int16_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int32_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int64_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, float) - DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, double) -} - -template -void -greater_equal_val_func(const T* data, int64_t size, T val, bool* res) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, bool) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int8_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int16_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int32_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int64_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, float) - DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, double) -} - -template -void -not_equal_val_func(const T* data, int64_t size, T val, bool* res) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, bool) - DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int8_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int16_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int32_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int64_t) - DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, float) - DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, double) -} - -template -void -equal_col_func(const T* left, const T* right, int64_t size, bool* res) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, bool) - DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int8_t) - DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int16_t) - DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int32_t) - DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int64_t) - DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, float) - DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, double) -} - -template -void -less_col_func(const T* left, const T* right, int64_t size, bool* res) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); - - DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, bool) - DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int8_t) - DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int16_t) - DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int32_t) - DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, 
int64_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, float)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, double)
-}
-
-template <typename T>
-void
-greater_col_func(const T* left, const T* right, int64_t size, bool* res) {
-    static_assert(
-        std::is_integral<T>::value || std::is_floating_point<T>::value,
-        "T must be integral or float/double type");
-
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, bool)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int8_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int16_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int32_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int64_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, float)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, double)
-}
-
-template <typename T>
-void
-less_equal_col_func(const T* left, const T* right, int64_t size, bool* res) {
-    static_assert(
-        std::is_integral<T>::value || std::is_floating_point<T>::value,
-        "T must be integral or float/double type");
-
-    DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, bool)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int8_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int16_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int32_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int64_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, float)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, double)
-}
-
-template <typename T>
-void
-greater_equal_col_func(const T* left, const T* right, int64_t size, bool* res) {
-    static_assert(
-        std::is_integral<T>::value || std::is_floating_point<T>::value,
-        "T must be integral or float/double type");
-
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, bool)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int8_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int16_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int32_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int64_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, float)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, double)
-}
-
-template <typename T>
-void
-not_equal_col_func(const T* left, const T* right, int64_t size, bool* res) {
-    static_assert(
-        std::is_integral<T>::value || std::is_floating_point<T>::value,
-        "T must be integral or float/double type");
-
-    DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, bool)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int8_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int16_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int32_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int64_t)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, float)
-    DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, double)
-}
-
-template <typename T>
-void
-compare_col_func(CompareType cmp_type,
-                 const T* left,
-                 const T* right,
-                 int64_t size,
-                 bool* res) {
-    if (cmp_type == CompareType::EQ) {
-        equal_col_func(left, right, size, res);
-    } else if (cmp_type == CompareType::NEQ) {
-        not_equal_col_func(left, right, size, res);
-    } else if (cmp_type == CompareType::GE) {
-        greater_equal_col_func(left, right, size, res);
-    } else if (cmp_type == CompareType::GT) {
-        greater_col_func(left, right, size, res);
-    } else if (cmp_type == CompareType::LE) {
-        less_equal_col_func(left, right, size, res);
-    } else if (cmp_type == CompareType::LT) {
-        less_col_func(left, right, size, res);
-    }
-}
-
-}  // namespace simd
-}  // namespace milvus
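// Each DISPATCH_* macro above expands to a compile-time type test that
// routes the templated front-end onto the matching per-type function
// pointer. Written out by hand for two types (names hypothetical). Note
// that when T matches no branch the real dispatchers simply fall through;
// the supported-type static_assert is what rules that case out.
#include <cstdint>
#include <type_traits>

template <typename T>
void
less_col_sketch(const T* left, const T* right, int64_t size, bool* res) {
    if constexpr (std::is_same_v<T, int32_t>) {
        // would call milvus::simd::less_col_int32_t(left, right, size, res);
    }
    if constexpr (std::is_same_v<T, float>) {
        // would call milvus::simd::less_col_float(left, right, size, res);
    }
    // Portable fallback so this sketch is self-contained; the real header
    // has one branch per supported type and no fallback.
    for (int64_t i = 0; i < size; ++i) {
        res[i] = left[i] < right[i];
    }
}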
diff --git a/internal/core/src/simd/neon.cpp b/internal/core/src/simd/neon.cpp
deleted file mode 100644
index 6bdda9138e77..000000000000
--- a/internal/core/src/simd/neon.cpp
+++ /dev/null
@@ -1,123 +0,0 @@
-// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License.
-
-#if defined(__ARM_NEON)
-
-#include "neon.h"
-
-#include
-#include
-
-namespace milvus {
-namespace simd {
-
-bool
-AllFalseNEON(const bool* src, int64_t size) {
-    int num_chunk = size / 16;
-
-    const uint8_t* ptr = reinterpret_cast<const uint8_t*>(src);
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        uint8x16_t data = vld1q_u8(ptr + i);
-        if (vmaxvq_u8(data) != 0) {
-            return false;
-        }
-    }
-
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        if (src[i]) {
-            return false;
-        }
-    }
-
-    return true;
-}
-
-bool
-AllTrueNEON(const bool* src, int64_t size) {
-    int num_chunk = size / 16;
-
-    const uint8_t* ptr = reinterpret_cast<const uint8_t*>(src);
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        uint8x16_t data = vld1q_u8(ptr + i);
-        if (vminvq_u8(data) == 0) {
-            return false;
-        }
-    }
-
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        if (!src[i]) {
-            return false;
-        }
-    }
-
-    return true;
-}
-
-void
-InvertBoolNEON(bool* src, int64_t size) {
-    int num_chunk = size / 16;
-    uint8x16_t mask = vdupq_n_u8(0x01);
-    uint8_t* ptr = reinterpret_cast<uint8_t*>(src);
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        uint8x16_t data = vld1q_u8(ptr + i);
-
-        uint8x16_t flipped = veorq_u8(data, mask);
-
-        vst1q_u8(ptr + i, flipped);
-    }
-
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        src[i] = !src[i];
-    }
-}
-
-void
-AndBoolNEON(bool* left, bool* right, int64_t size) {
-    int num_chunk = size / 16;
-    uint8_t* lptr = reinterpret_cast<uint8_t*>(left);
-    uint8_t* rptr = reinterpret_cast<uint8_t*>(right);
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        uint8x16_t l_reg = vld1q_u8(lptr + i);
-        uint8x16_t r_reg = vld1q_u8(rptr + i);
-
-        uint8x16_t res = vandq_u8(l_reg, r_reg);
-
-        vst1q_u8(lptr + i, res);
-    }
-
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        left[i] &= right[i];
-    }
-}
-
-void
-OrBoolNEON(bool* left, bool* right, int64_t size) {
-    int num_chunk = size / 16;
-    uint8_t* lptr = reinterpret_cast<uint8_t*>(left);
-    uint8_t* rptr = reinterpret_cast<uint8_t*>(right);
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        uint8x16_t l_reg = vld1q_u8(lptr + i);
-        uint8x16_t r_reg = vld1q_u8(rptr + i);
-
-        uint8x16_t res = vorrq_u8(l_reg, r_reg);
-
-        vst1q_u8(lptr + i, res);
-    }
-
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        left[i] |= right[i];
-    }
-}
-
-}  // namespace simd
-}  // namespace milvus
-
-#endif
\ No newline at end of file
diff --git a/internal/core/src/simd/neon.h b/internal/core/src/simd/neon.h
deleted file mode 100644
index 2c38a2eb709b..000000000000
--- a/internal/core/src/simd/neon.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License.
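// The NEON kernels above rest on horizontal reductions for their early
// exits: vmaxvq_u8 returns the largest of 16 bytes, so a nonzero max means
// "some bool in this chunk is true"; vminvq_u8 dually detects a false.
// These across-vector reductions are AArch64-only, which is why the sketch
// below is guarded more narrowly than the deleted file's __ARM_NEON check.
// The function name is hypothetical:
#if defined(__aarch64__)
#include <arm_neon.h>
#include <cstdint>

static bool
any_true_neon_sketch(const bool* src, int64_t size) {
    const uint8_t* ptr = reinterpret_cast<const uint8_t*>(src);
    const int64_t chunked = size / 16 * 16;
    for (int64_t i = 0; i < chunked; i += 16) {
        if (vmaxvq_u8(vld1q_u8(ptr + i)) != 0) {
            return true;  // some byte in this 16-bool chunk is nonzero
        }
    }
    for (int64_t i = chunked; i < size; ++i) {  // scalar tail
        if (src[i]) {
            return true;
        }
    }
    return false;
}
#endif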
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once -#include -#include "common.h" -namespace milvus { -namespace simd { - -BitsetBlockType -GetBitsetBlockSSE2(const bool* src); - -bool -AllFalseNEON(const bool* src, int64_t size); - -bool -AllTrueNEON(const bool* src, int64_t size); - -void -InvertBoolNEON(bool* src, int64_t size); - -void -AndBoolNEON(bool* left, bool* right, int64_t size); - -void -OrBoolNEON(bool* left, bool* right, int64_t size); - -} // namespace simd -} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/simd/ref.cpp b/internal/core/src/simd/ref.cpp deleted file mode 100644 index f858fe97d2ef..000000000000 --- a/internal/core/src/simd/ref.cpp +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#include "ref.h" - -namespace milvus { -namespace simd { - -BitsetBlockType -GetBitsetBlockRef(const bool* src) { - BitsetBlockType val = 0; - uint8_t vals[BITSET_BLOCK_SIZE] = {0}; - for (size_t j = 0; j < 8; ++j) { - for (size_t k = 0; k < BITSET_BLOCK_SIZE; ++k) { - vals[k] |= uint8_t(*(src + k * 8 + j)) << j; - } - } - for (size_t j = 0; j < BITSET_BLOCK_SIZE; ++j) { - val |= (BitsetBlockType)(vals[j]) << (8 * j); - } - return val; -} - -bool -AllTrueRef(const bool* src, int64_t size) { - for (size_t i = 0; i < size; ++i) { - if (!src[i]) { - return false; - } - } - return true; -} - -bool -AllFalseRef(const bool* src, int64_t size) { - for (size_t i = 0; i < size; ++i) { - if (src[i]) { - return false; - } - } - return true; -} - -void -InvertBoolRef(bool* src, int64_t size) { - for (size_t i = 0; i < size; ++i) { - src[i] = !src[i]; - } -} - -void -AndBoolRef(bool* left, bool* right, int64_t size) { - for (size_t i = 0; i < size; ++i) { - left[i] &= right[i]; - } -} - -void -OrBoolRef(bool* left, bool* right, int64_t size) { - for (size_t i = 0; i < size; ++i) { - left[i] |= right[i]; - } -} - -} // namespace simd -} // namespace milvus diff --git a/internal/core/src/simd/ref.h b/internal/core/src/simd/ref.h deleted file mode 100644 index f3b7af1a0c62..000000000000 --- a/internal/core/src/simd/ref.h +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. 
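// GetBitsetBlockRef above packs 64 bools into one BitsetBlockType by
// accumulating bit column j across 8 bytes at a time, which hides what it
// computes: bit i of the block is simply src[i]. A direct equivalent for
// comparison (same result; shorter, but with a serial dependency chain):
#include <cstddef>

using BitsetBlockSketch = unsigned long;  // mirrors BitsetBlockType above

static BitsetBlockSketch
get_bitset_block_direct_sketch(const bool* src) {
    BitsetBlockSketch block = 0;
    for (size_t i = 0; i < sizeof(BitsetBlockSketch) * 8; ++i) {
        block |= static_cast<BitsetBlockSketch>(src[i]) << i;
    }
    return block;
}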
You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License.
-
-#pragma once
-
-#include "common.h"
-
-namespace milvus {
-namespace simd {
-
-BitsetBlockType
-GetBitsetBlockRef(const bool* src);
-
-bool
-AllTrueRef(const bool* src, int64_t size);
-
-bool
-AllFalseRef(const bool* src, int64_t size);
-
-void
-InvertBoolRef(bool* src, int64_t size);
-
-void
-AndBoolRef(bool* left, bool* right, int64_t size);
-
-void
-OrBoolRef(bool* left, bool* right, int64_t size);
-
-template <typename T>
-bool
-FindTermRef(const T* src, size_t size, T val) {
-    for (size_t i = 0; i < size; ++i) {
-        if (src[i] == val) {
-            return true;
-        }
-    }
-    return false;
-}
-
-template <typename T>
-void
-EqualValRef(const T* src, size_t size, T val, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = src[i] == val;
-    }
-}
-
-template <typename T>
-void
-LessValRef(const T* src, size_t size, T val, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = src[i] < val;
-    }
-}
-
-template <typename T>
-void
-GreaterValRef(const T* src, size_t size, T val, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = src[i] > val;
-    }
-}
-
-template <typename T>
-void
-LessEqualValRef(const T* src, size_t size, T val, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = src[i] <= val;
-    }
-}
-template <typename T>
-void
-GreaterEqualValRef(const T* src, size_t size, T val, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = src[i] >= val;
-    }
-}
-template <typename T>
-void
-NotEqualValRef(const T* src, size_t size, T val, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = src[i] != val;
-    }
-}
-
-template <typename T>
-void
-EqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = left[i] == right[i];
-    }
-}
-
-template <typename T>
-void
-LessColumnRef(const T* left, const T* right, size_t size, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = left[i] < right[i];
-    }
-}
-
-template <typename T>
-void
-LessEqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = left[i] <= right[i];
-    }
-}
-
-template <typename T>
-void
-GreaterColumnRef(const T* left, const T* right, size_t size, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = left[i] > right[i];
-    }
-}
-
-template <typename T>
-void
-GreaterEqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = left[i] >= right[i];
-    }
-}
-
-template <typename T>
-void
-NotEqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
-    for (size_t i = 0; i < size; ++i) {
-        res[i] = left[i] != right[i];
-    }
-}
-
-}  // namespace simd
-}  // namespace milvus
diff --git a/internal/core/src/simd/sse2.cpp b/internal/core/src/simd/sse2.cpp
deleted file mode 100644
index c0060ef8563f..000000000000
--- a/internal/core/src/simd/sse2.cpp
+++ /dev/null
@@ -1,348 +0,0 @@
-// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License.
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#if defined(__x86_64__) - -#include "sse2.h" - -#include -#include - -namespace milvus { -namespace simd { - -#define ALIGNED(x) __attribute__((aligned(x))) - -BitsetBlockType -GetBitsetBlockSSE2(const bool* src) { - if constexpr (BITSET_BLOCK_SIZE == 8) { - // BitsetBlockType has 64 bits - __m128i highbit = _mm_set1_epi8(0x7F); - uint16_t tmp[4]; - for (size_t i = 0; i < 4; i += 1) { - // Outer function assert (src has 64 * n length) - __m128i boolvec = _mm_loadu_si128((__m128i*)&src[i * 16]); - __m128i highbits = _mm_add_epi8(boolvec, highbit); - tmp[i] = _mm_movemask_epi8(highbits); - } - - __m128i tmpvec = _mm_loadl_epi64((__m128i_u*)tmp); - BitsetBlockType res; - _mm_storel_epi64((__m128i_u*)&res, tmpvec); - return res; - } else { - // Others has 32 bits - __m128i highbit = _mm_set1_epi8(0x7F); - uint16_t tmp[8]; - for (size_t i = 0; i < 2; i += 1) { - __m128i boolvec = _mm_loadu_si128((__m128i*)&src[i * 16]); - __m128i highbits = _mm_add_epi8(boolvec, highbit); - tmp[i] = _mm_movemask_epi8(highbits); - } - - __m128i tmpvec = _mm_loadu_si128((__m128i*)tmp); - BitsetBlockType res[4]; - _mm_storeu_si128((__m128i*)res, tmpvec); - return res[0]; - } -} - -template <> -bool -FindTermSSE2(const bool* src, size_t vec_size, bool val) { - __m128i xmm_target = _mm_set1_epi8(val); - __m128i xmm_data; - size_t num_chunks = vec_size / 16; - for (size_t i = 0; i < num_chunks * 16; i += 16) { - xmm_data = _mm_loadu_si128(reinterpret_cast(src + i)); - __m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target); - int mask = _mm_movemask_epi8(xmm_match); - if (mask != 0) { - return true; - } - } - - for (size_t i = num_chunks * 16; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - - return false; -} - -template <> -bool -FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val) { - __m128i xmm_target = _mm_set1_epi8(val); - __m128i xmm_data; - size_t num_chunks = vec_size / 16; - for (size_t i = 0; i < num_chunks * 16; i += 16) { - xmm_data = _mm_loadu_si128(reinterpret_cast(src + i)); - __m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target); - int mask = _mm_movemask_epi8(xmm_match); - if (mask != 0) { - return true; - } - } - - for (size_t i = num_chunks * 16; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - - return false; -} - -template <> -bool -FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val) { - __m128i xmm_target = _mm_set1_epi16(val); - __m128i xmm_data; - size_t num_chunks = vec_size / 8; - for (size_t i = 0; i < num_chunks * 8; i += 8) { - xmm_data = _mm_loadu_si128(reinterpret_cast(src + i)); - __m128i xmm_match = _mm_cmpeq_epi16(xmm_data, xmm_target); - int mask = _mm_movemask_epi8(xmm_match); - if (mask != 0) { - return true; - } - } - - for (size_t i = num_chunks * 8; i < vec_size; ++i) { - if (src[i] == val) { - return true; - } - } - return false; -} - -template <> -bool -FindTermSSE2(const int32_t* src, size_t vec_size, int32_t val) { - size_t num_chunk = vec_size / 4; - size_t remaining_size = vec_size % 4; - - __m128i xmm_target = _mm_set1_epi32(val); - for (size_t i = 0; i < num_chunk * 4; i += 4) { - __m128i xmm_data 
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
-        __m128i xmm_match = _mm_cmpeq_epi32(xmm_data, xmm_target);
-        int mask = _mm_movemask_epi8(xmm_match);
-        if (mask != 0) {
-            return true;
-        }
-    }
-
-    const int32_t* remaining_ptr = src + num_chunk * 4;
-    if (remaining_size == 0) {
-        return false;
-    } else if (remaining_size == 1) {
-        return *remaining_ptr == val;
-    } else if (remaining_size == 2) {
-        __m128i xmm_data =
-            _mm_set_epi32(0, 0, *(remaining_ptr + 1), *(remaining_ptr));
-        __m128i xmm_match = _mm_cmpeq_epi32(xmm_data, xmm_target);
-        int mask = _mm_movemask_epi8(xmm_match);
-        if ((mask & 0xFF) != 0) {
-            return true;
-        }
-    } else {
-        __m128i xmm_data = _mm_set_epi32(
-            0, *(remaining_ptr + 2), *(remaining_ptr + 1), *(remaining_ptr));
-        __m128i xmm_match = _mm_cmpeq_epi32(xmm_data, xmm_target);
-        int mask = _mm_movemask_epi8(xmm_match);
-        if ((mask & 0xFFF) != 0) {
-            return true;
-        }
-    }
-    return false;
-}
-
-template <>
-bool
-FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val) {
-    // _mm_cmpeq_epi64 is not implemented in SSE2, so compare the two int32 halves instead.
-    int32_t low = static_cast<int32_t>(val);
-    int32_t high = static_cast<int32_t>(val >> 32);
-    size_t num_chunk = vec_size / 2;
-    size_t remaining_size = vec_size % 2;
-
-    for (int64_t i = 0; i < num_chunk * 2; i += 2) {
-        __m128i xmm_vec =
-            _mm_load_si128(reinterpret_cast<const __m128i*>(src + i));
-
-        __m128i xmm_low = _mm_set1_epi32(low);
-        __m128i xmm_high = _mm_set1_epi32(high);
-        __m128i cmp_low = _mm_cmpeq_epi32(xmm_vec, xmm_low);
-        __m128i cmp_high =
-            _mm_cmpeq_epi32(_mm_srli_epi64(xmm_vec, 32), xmm_high);
-        __m128i cmp_result = _mm_and_si128(cmp_low, cmp_high);
-
-        int mask = _mm_movemask_epi8(cmp_result);
-        if (mask != 0) {
-            return true;
-        }
-    }
-
-    if (remaining_size == 1) {
-        if (src[2 * num_chunk] == val) {
-            return true;
-        }
-    }
-    return false;
-}
-
-template <>
-bool
-FindTermSSE2(const float* src, size_t vec_size, float val) {
-    size_t num_chunks = vec_size / 4;
-    __m128 xmm_target = _mm_set1_ps(val);
-    for (int i = 0; i < 4 * num_chunks; i += 4) {
-        __m128 xmm_data = _mm_loadu_ps(src + i);
-        __m128 xmm_match = _mm_cmpeq_ps(xmm_data, xmm_target);
-        int mask = _mm_movemask_ps(xmm_match);
-        if (mask != 0) {
-            return true;
-        }
-    }
-
-    for (size_t i = 4 * num_chunks; i < vec_size; ++i) {
-        if (src[i] == val) {
-            return true;
-        }
-    }
-    return false;
-}
-
-template <>
-bool
-FindTermSSE2(const double* src, size_t vec_size, double val) {
-    size_t num_chunks = vec_size / 2;
-    __m128d xmm_target = _mm_set1_pd(val);
-    for (int i = 0; i < 2 * num_chunks; i += 2) {
-        __m128d xmm_data = _mm_loadu_pd(src + i);
-        __m128d xmm_match = _mm_cmpeq_pd(xmm_data, xmm_target);
-        int mask = _mm_movemask_pd(xmm_match);
-        if (mask != 0) {
-            return true;
-        }
-    }
-
-    for (size_t i = 2 * num_chunks; i < vec_size; ++i) {
-        if (src[i] == val) {
-            return true;
-        }
-    }
-    return false;
-}
-
-void
-print_m128i(__m128i v) {
-    alignas(16) int result[4];
-    _mm_store_si128(reinterpret_cast<__m128i*>(result), v);
-
-    for (int i = 0; i < 4; ++i) {
-        std::cout << std::hex << result[i] << " ";
-    }
-
-    std::cout << std::endl;
-}
-
-bool
-AllFalseSSE2(const bool* src, int64_t size) {
-    int num_chunk = size / 16;
-    __m128i highbit = _mm_set1_epi8(0x7F);
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        __m128i data =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
-        __m128i highbits = _mm_add_epi8(data, highbit);
-        if (_mm_movemask_epi8(highbits) != 0) {
-            return false;
-        }
-    }
-
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        if (src[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
-bool
-AllTrueSSE2(const bool* src, int64_t size) {
-    int num_chunk = size / 16;
-    __m128i highbit = _mm_set1_epi8(0x7F);
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        __m128i data =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
-        __m128i highbits = _mm_add_epi8(data, highbit);
-        if (_mm_movemask_epi8(highbits) != 0xFFFF) {
-            return false;
-        }
-    }
-
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        if (!src[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
-void
-InvertBoolSSE2(bool* src, int64_t size) {
-    int num_chunk = size / 16;
-    __m128i mask = _mm_set1_epi8(0x01);
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        __m128i data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src + i));
-        __m128i flipped = _mm_xor_si128(data, mask);
-        _mm_storeu_si128(reinterpret_cast<__m128i*>(src + i), flipped);
-    }
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        src[i] = !src[i];
-    }
-}
-
-void
-AndBoolSSE2(bool* left, bool* right, int64_t size) {
-    int num_chunk = size / 16;
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        __m128i l_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(left + i));
-        __m128i r_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(right + i));
-        __m128i res = _mm_and_si128(l_reg, r_reg);
-        _mm_storeu_si128(reinterpret_cast<__m128i*>(left + i), res);
-    }
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        left[i] &= right[i];
-    }
-}
-
-void
-OrBoolSSE2(bool* left, bool* right, int64_t size) {
-    int num_chunk = size / 16;
-    for (size_t i = 0; i < num_chunk * 16; i += 16) {
-        __m128i l_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(left + i));
-        __m128i r_reg = _mm_loadu_si128(reinterpret_cast<__m128i*>(right + i));
-        __m128i res = _mm_or_si128(l_reg, r_reg);
-        _mm_storeu_si128(reinterpret_cast<__m128i*>(left + i), res);
-    }
-    for (size_t i = num_chunk * 16; i < size; ++i) {
-        left[i] |= right[i];
-    }
-}
-
-} // namespace simd
-} // namespace milvus
-
-#endif
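For orientation, here is what the removed GetBitsetBlockSSE2 kernel computed: it packs 64 bool bytes into one 64-bit bitset block by adding 0x7F to each 0/1 byte (so the byte's high bit ends up mirroring its truth value) and then harvesting 16 high bits at a time with _mm_movemask_epi8. A scalar sketch of the same contract, with an illustrative function name and the 64-bit block type assumed:

    #include <cstddef>
    #include <cstdint>

    using BitsetBlockType = uint64_t;  // assumes the 64-bit block branch above

    // Scalar equivalent of GetBitsetBlockSSE2: bit i of the result mirrors src[i].
    BitsetBlockType
    GetBitsetBlockScalar(const bool* src) {
        BitsetBlockType res = 0;
        for (size_t i = 0; i < 64; ++i) {
            res |= static_cast<BitsetBlockType>(src[i] ? 1 : 0) << i;
        }
        return res;
    }

The SSE2 version performs the same packing 16 bytes per iteration, which is why callers must guarantee a source length that is a multiple of 64.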
diff --git a/internal/core/src/simd/sse2.h b/internal/core/src/simd/sse2.h
deleted file mode 100644
index 9b53bb869ff3..000000000000
--- a/internal/core/src/simd/sse2.h
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License.
-
-#pragma once
-
-#include <cstddef>
-#include <cstdint>
-
-#include <string>
-#include <string_view>
-
-#include "common.h"
-namespace milvus {
-namespace simd {
-
-BitsetBlockType
-GetBitsetBlockSSE2(const bool* src);
-
-bool
-AllFalseSSE2(const bool* src, int64_t size);
-
-bool
-AllTrueSSE2(const bool* src, int64_t size);
-
-void
-InvertBoolSSE2(bool* src, int64_t size);
-
-void
-AndBoolSSE2(bool* left, bool* right, int64_t size);
-
-void
-OrBoolSSE2(bool* left, bool* right, int64_t size);
-
-template <typename T>
-bool
-FindTermSSE2(const T* src, size_t vec_size, T val) {
-    CHECK_SUPPORTED_TYPE(T, "unsupported type for FindTermSSE2");
-    return false;
-}
-
-template <>
-bool
-FindTermSSE2(const bool* src, size_t vec_size, bool val);
-
-template <>
-bool
-FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val);
-
-template <>
-bool
-FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val);
-
-template <>
-bool
-FindTermSSE2(const int32_t* src, size_t vec_size, int32_t val);
-
-template <>
-bool
-FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val);
-
-template <>
-bool
-FindTermSSE2(const float* src, size_t vec_size, float val);
-
-template <>
-bool
-FindTermSSE2(const double* src, size_t vec_size, double val);
-
-} // namespace simd
-} // namespace milvus
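The header above relies on a dispatch idiom that recurs in sse4.h below: a generic primary template that merely rejects unsupported element types (here via the runtime CHECK_SUPPORTED_TYPE macro) plus one explicit specialization per supported type. A minimal sketch of that shape with illustrative names, swapping in a compile-time rejection instead of the macro:

    #include <cstddef>
    #include <cstdint>

    // Primary template: anything without a dedicated kernel is a compile error
    // when instantiated (the static_assert is dependent on T, so it only fires
    // for types that reach this fallback).
    template <typename T>
    bool
    FindTermGeneric(const T* /*src*/, size_t /*n*/, T /*val*/) {
        static_assert(sizeof(T) == 0, "unsupported type for FindTermGeneric");
        return false;
    }

    // One explicit specialization per supported element type.
    template <>
    bool
    FindTermGeneric<int32_t>(const int32_t* src, size_t n, int32_t val) {
        for (size_t i = 0; i < n; ++i) {
            if (src[i] == val) {
                return true;
            }
        }
        return false;
    }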
diff --git a/internal/core/src/simd/sse4.cpp b/internal/core/src/simd/sse4.cpp
deleted file mode 100644
index bf3d08c76bc7..000000000000
--- a/internal/core/src/simd/sse4.cpp
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License.
-
-#if defined(__x86_64__)
-
-#include "sse4.h"
-#include "sse2.h"
-
-#include <smmintrin.h>
-#include <nmmintrin.h>
-#include <string>
-
-extern "C" {
-extern int
-sse2_strcmp(const char* s1, const char* s2);
-}
-namespace milvus {
-namespace simd {
-
-template <>
-bool
-FindTermSSE4(const int64_t* src, size_t vec_size, int64_t val) {
-    size_t num_chunk = vec_size / 2;
-    size_t remaining_size = vec_size % 2;
-
-    __m128i xmm_target = _mm_set1_epi64x(val);
-    for (size_t i = 0; i < num_chunk * 2; i += 2) {
-        __m128i xmm_data =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
-        __m128i xmm_match = _mm_cmpeq_epi64(xmm_data, xmm_target);
-        int mask = _mm_movemask_epi8(xmm_match);
-        if (mask != 0) {
-            return true;
-        }
-    }
-    if (remaining_size == 1) {
-        if (src[2 * num_chunk] == val) {
-            return true;
-        }
-    }
-    return false;
-}
-
-template <>
-bool
-FindTermSSE4(const std::string* src, size_t vec_size, std::string val) {
-    for (size_t i = 0; i < vec_size; ++i) {
-        if (StrCmpSSE4(src[i].c_str(), val.c_str())) {
-            return true;
-        }
-    }
-    return false;
-}
-
-template <>
-bool
-FindTermSSE4(const std::string_view* src,
-             size_t vec_size,
-             std::string_view val) {
-    for (size_t i = 0; i < vec_size; ++i) {
-        if (!StrCmpSSE4(src[i].data(), val.data())) {
-            return true;
-        }
-    }
-    return false;
-}
-
-int
-StrCmpSSE4(const char* s1, const char* s2) {
-    __m128i* ptr1 = reinterpret_cast<__m128i*>(const_cast<char*>(s1));
-    __m128i* ptr2 = reinterpret_cast<__m128i*>(const_cast<char*>(s2));
-
-    for (;; ptr1++, ptr2++) {
-        const __m128i a = _mm_loadu_si128(ptr1);
-        const __m128i b = _mm_loadu_si128(ptr2);
-
-        const uint8_t mode = _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_EACH |
-                             _SIDD_NEGATIVE_POLARITY | _SIDD_LEAST_SIGNIFICANT;
-
-        if (_mm_cmpistrc(a, b, mode)) {
-            const auto idx = _mm_cmpistri(a, b, mode);
-            const uint8_t b1 = (reinterpret_cast<const uint8_t*>(ptr1))[idx];
-            const uint8_t b2 = (reinterpret_cast<const uint8_t*>(ptr2))[idx];
-
-            if (b1 < b2) {
-                return -1;
-            } else if (b1 > b2) {
-                return +1;
-            } else {
-                return 0;
-            }
-        } else if (_mm_cmpistrz(a, b, mode)) {
-            break;
-        }
-    }
-    return 0;
-}
-
-} // namespace simd
-} // namespace milvus
-
-#endif
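These per-ISA entry points (SSE2/SSE4 here, AVX2/AVX-512 elsewhere in this PR) are exactly what the new bitset library hides behind runtime dispatch: the tests below probe the CPU with cpu_support_avx2() / cpu_support_avx512() before exercising a vectorizer. A condensed sketch of that selection logic; the SimdLevel enum and PickSimdLevel are illustrative, only the two cpu_support_* predicates come from this PR:

    #if defined(__x86_64__)
    #include "bitset/detail/platform/x86/instruction_set.h"

    enum class SimdLevel { Ref, Avx2, Avx512 };

    // Probe once at startup; the widest supported instruction set wins.
    SimdLevel
    PickSimdLevel() {
        using namespace milvus::bitset::detail::x86;
        if (cpu_support_avx512()) {
            return SimdLevel::Avx512;
        }
        if (cpu_support_avx2()) {
            return SimdLevel::Avx2;
        }
        return SimdLevel::Ref;  // scalar fallback
    }
    #endif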
diff --git a/internal/core/src/simd/sse4.h b/internal/core/src/simd/sse4.h
deleted file mode 100644
index 107ab519f73b..000000000000
--- a/internal/core/src/simd/sse4.h
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright (C) 2019-2023 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License.
-
-#pragma once
-
-#include <cstddef>
-#include <cstdint>
-
-#include <string>
-#include <string_view>
-
-#include "common.h"
-#include "sse2.h"
-namespace milvus {
-namespace simd {
-
-template <typename T>
-bool
-FindTermSSE4(const T* src, size_t vec_size, T val) {
-    CHECK_SUPPORTED_TYPE(T, "unsupported type for FindTermSSE4");
-    // SSE4 still uses 128-bit registers, so reuse the SSE2 code
-    return FindTermSSE2(src, vec_size, val);
-}
-
-template <>
-bool
-FindTermSSE4(const int64_t* src, size_t vec_size, int64_t val);
-
-int
-StrCmpSSE4(const char* s1, const char* s2);
-
-} // namespace simd
-} // namespace milvus
diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt
index 3318141200c9..ad2a1cba9cbc 100644
--- a/internal/core/unittest/CMakeLists.txt
+++ b/internal/core/unittest/CMakeLists.txt
@@ -144,16 +144,25 @@ if (LINUX)
     add_subdirectory(bench)
 endif ()
 
-if (USE_DYNAMIC_SIMD)
-add_executable(dynamic_simd_test
-    test_simd.cpp)
-
-target_link_libraries(dynamic_simd_test
-    milvus_simd
-    milvus_log
-    gtest
-    ${CONAN_LIBS})
-
-install(TARGETS dynamic_simd_test DESTINATION unittest)
-endif()
-
+# if (USE_DYNAMIC_SIMD)
+# add_executable(dynamic_simd_test
+#     test_simd.cpp)
+#
+# target_link_libraries(dynamic_simd_test
+#     milvus_simd
+#     milvus_log
+#     gtest
+#     ${CONAN_LIBS})
+#
+# install(TARGETS dynamic_simd_test DESTINATION unittest)
+# endif()
+
+add_executable(bitset_test
+    test_bitset.cpp
+)
+target_link_libraries(bitset_test
+    milvus_bitset
+    gtest
+    ${CONAN_LIBS}
+)
+install(TARGETS bitset_test DESTINATION unittest)
diff --git a/internal/core/unittest/test_bitset.cpp b/internal/core/unittest/test_bitset.cpp
new file mode 100644
index 000000000000..0ebe660b6ccf
--- /dev/null
+++ b/internal/core/unittest/test_bitset.cpp
@@ -0,0 +1,1601 @@
+// Copyright (C) 2019-2024 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License
+
+#include <gtest/gtest.h>
+
+#include <chrono>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <optional>
+#include <random>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "bitset/bitset.h"
+#include "bitset/detail/bit_wise.h"
+#include "bitset/detail/element_wise.h"
+#include "bitset/detail/element_vectorized.h"
+#include "bitset/detail/platform/dynamic.h"
+#include "bitset/detail/platform/vectorized_ref.h"
+
+#if defined(__x86_64__)
+#include "bitset/detail/platform/x86/avx2.h"
+#include "bitset/detail/platform/x86/avx512.h"
+#include "bitset/detail/platform/x86/instruction_set.h"
+#endif
+
+#if defined(__aarch64__)
+#include "bitset/detail/platform/arm/neon.h"
+
+#ifdef __ARM_FEATURE_SVE
+#include "bitset/detail/platform/arm/sve.h"
+#endif
+
+#endif
+
+using namespace milvus::bitset;
+
+//////////////////////////////////////////////////////////////////////////////////////////
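A quick orientation before the trait bundles defined just below: each bundle pins a bitset policy (bit-wise reference, element-wise scalar, or vectorized element-wise) and exposes the resulting owning bitset_type plus a non-owning bitset_view. A minimal usage sketch, under the assumption that ElementImplTraits is instantiated with uint64_t for both parameters as the tests do:

    // Hypothetical snippet: build a bitset via a traits bundle and query it.
    using traits = ElementImplTraits<uint64_t, uint64_t>;
    using bitset_t = traits::bitset_type;

    void
    SketchBitsetUsage() {
        bitset_t bits(128);  // 128 bits
        bits.reset();        // clear everything to false
        bits[3] = true;
        bits[64] = true;

        auto first = bits.find_first();             // optional-like, yields 3
        auto next = bits.find_next(first.value());  // yields 64
        (void)next;

        auto view = bits.view(64);  // non-owning view starting at bit 64
        bool b = view[0];           // mirrors bit 64 of the owner
        (void)b;
    }

Every test below follows this pattern: construct, reset(), mutate through operator[] or one of the inplace_* batch predicates, then verify against a scalar recomputation.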
+
+// * The data is processed using ElementT,
+// * A container stores the data using ContainerValueT elements,
+// * VectorizerT defines the vectorization.
+
+template <typename ElementT, typename ContainerValueT>
+struct RefImplTraits {
+    using policy_type = milvus::bitset::detail::BitWiseBitsetPolicy<ElementT>;
+    using container_type = std::vector<ContainerValueT>;
+    using bitset_type =
+        milvus::bitset::Bitset<policy_type, container_type, false>;
+    using bitset_view = milvus::bitset::BitsetView<policy_type, false>;
+};
+
+template <typename ElementT, typename ContainerValueT>
+struct ElementImplTraits {
+    using policy_type =
+        milvus::bitset::detail::ElementWiseBitsetPolicy<ElementT>;
+    using container_type = std::vector<ContainerValueT>;
+    using bitset_type =
+        milvus::bitset::Bitset<policy_type, container_type, false>;
+    using bitset_view = milvus::bitset::BitsetView<policy_type, false>;
+};
+
+template <typename ElementT, typename ContainerValueT, typename VectorizerT>
+struct VectorizedImplTraits {
+    using policy_type =
+        milvus::bitset::detail::VectorizedElementWiseBitsetPolicy<ElementT,
+                                                                  VectorizerT>;
+    using container_type = std::vector<ContainerValueT>;
+    using bitset_type =
+        milvus::bitset::Bitset<policy_type, container_type, false>;
+    using bitset_view = milvus::bitset::BitsetView<policy_type, false>;
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////
+
+// set running mode to 1 to run a subset of tests
+// set running mode to 2 to run benchmarks
+// otherwise, all of the tests are run
+
+#define RUNNING_MODE 1
+
+#if RUNNING_MODE == 1
+// short tests
+static constexpr bool print_log = false;
+static constexpr bool print_timing = false;
+
+static constexpr size_t typical_sizes[] = {0, 1, 10, 100, 1000};
+static constexpr size_t typical_offsets[] = {
+    0, 1, 2, 3, 4, 5, 6, 7, 11, 21, 35, 55, 63, 127, 703};
+static constexpr CompareOpType typical_compare_ops[] = {CompareOpType::EQ,
+                                                        CompareOpType::GE,
+                                                        CompareOpType::GT,
+                                                        CompareOpType::LE,
+                                                        CompareOpType::LT,
+                                                        CompareOpType::NE};
+static constexpr RangeType typical_range_types[] = {
+    RangeType::IncInc, RangeType::IncExc, RangeType::ExcInc, RangeType::ExcExc};
+static constexpr ArithOpType typical_arith_ops[] = {ArithOpType::Add,
+                                                    ArithOpType::Sub,
+                                                    ArithOpType::Mul,
+                                                    ArithOpType::Div,
+                                                    ArithOpType::Mod};
+
+#elif RUNNING_MODE == 2
+
+// benchmarks
+static constexpr bool print_log = false;
+static constexpr bool print_timing = true;
+
+static constexpr size_t typical_sizes[] = {10000000};
+static constexpr size_t typical_offsets[] = {};
+static constexpr CompareOpType typical_compare_ops[] = {CompareOpType::EQ,
+                                                        CompareOpType::GE,
+                                                        CompareOpType::GT,
+                                                        CompareOpType::LE,
+                                                        CompareOpType::LT,
+                                                        CompareOpType::NE};
+static constexpr RangeType typical_range_types[] = {
+    RangeType::IncInc, RangeType::IncExc, RangeType::ExcInc, RangeType::ExcExc};
+static constexpr ArithOpType typical_arith_ops[] = {ArithOpType::Add,
+                                                    ArithOpType::Sub,
+                                                    ArithOpType::Mul,
+                                                    ArithOpType::Div,
+                                                    ArithOpType::Mod};
+
+#else
+
+// full tests, mostly used for code coverage
+static constexpr bool print_log = false;
+static constexpr bool print_timing = false;
+
+static constexpr size_t typical_sizes[] = {0,
+                                           1,
+                                           10,
+                                           100,
+                                           1000,
+                                           10000,
+                                           2048,
+                                           2056,
+                                           2064,
+                                           2072,
+                                           2080,
+                                           2088,
+                                           2096,
+                                           2104,
+                                           2112};
+static constexpr size_t typical_offsets[] = {
+    0,  1,   2,   3,   4,   5,   6,   7,   11,  21,  35,  45,  55,
+    63, 127, 512, 520, 528, 536, 544, 556, 564, 572, 580, 703};
+static constexpr CompareOpType typical_compare_ops[] = {CompareOpType::EQ,
+                                                        CompareOpType::GE,
+                                                        CompareOpType::GT,
+                                                        CompareOpType::LE,
+                                                        CompareOpType::LT,
+                                                        CompareOpType::NE};
+static constexpr RangeType typical_range_types[] = {
+    RangeType::IncInc, RangeType::IncExc, RangeType::ExcInc, RangeType::ExcExc};
+static constexpr ArithOpType typical_arith_ops[] = {ArithOpType::Add,
+                                                    ArithOpType::Sub,
+                                                    ArithOpType::Mul,
+                                                    ArithOpType::Div,
+                                                    ArithOpType::Mod};
+
+#define FULL_TESTS 1
+#endif
+
+//////////////////////////////////////////////////////////////////////////////////////////
+
+//
combinations to run +using Ttypes2 = ::testing::Types< +#if FULL_TESTS == 1 + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, +#endif + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple + +#if FULL_TESTS == 1 + , + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +#endif + >; + +// combinations to run +using Ttypes1 = ::testing::Types< +#if FULL_TESTS == 1 + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, +#endif + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple + +#if FULL_TESTS == 1 + , + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +#endif + >; + +////////////////////////////////////////////////////////////////////////////////////////// + +struct StopWatch { + using time_type = + std::chrono::time_point; + time_type start; + + StopWatch() { + start = now(); + } + + inline double + elapsed() { + auto current = now(); + return std::chrono::duration(current - start).count(); + } + + static inline time_type + now() { + return std::chrono::high_resolution_clock::now(); + } +}; + +// +template +void +FillRandom(std::vector& t, + std::default_random_engine& rng, + const size_t max_v) { + std::uniform_int_distribution tt(0, max_v); + for (size_t i = 0; i < t.size(); i++) { + t[i] = tt(rng); + } +} + +template +void +FillRandom(BitsetT& bitset, std::default_random_engine& rng) { + std::uniform_int_distribution tt(0, 1); + for (size_t i = 0; i < bitset.size(); i++) { + bitset[i] = (tt(rng) == 0); + } +} + +////////////////////////////////////////////////////////////////////////////////////////// + +// +template +void +TestFindImpl(BitsetT& bitset, const size_t max_v) { + const size_t n = bitset.size(); + + std::default_random_engine rng(123); + std::uniform_int_distribution u(0, max_v); + + std::vector one_pos; + for (size_t i = 0; i < n; i++) { + bool enabled = (u(rng) == 0); + if (enabled) { + one_pos.push_back(i); + bitset[i] = true; + } + } + + StopWatch sw; + + auto bit_idx = bitset.find_first(); + if (!bit_idx.has_value()) { + ASSERT_EQ(one_pos.size(), 0); + return; + } + + for (size_t i = 0; i < one_pos.size(); i++) { + ASSERT_TRUE(bit_idx.has_value()) << n << ", " << max_v; + ASSERT_EQ(bit_idx.value(), one_pos[i]) << n << ", " << max_v; + bit_idx = bitset.find_next(bit_idx.value()); + } + + ASSERT_FALSE(bit_idx.has_value()) + << n << ", " << max_v << ", " << bit_idx.value(); + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } +} + +template +void +TestFindImpl() { + for (const size_t n : typical_sizes) { + for (const size_t pr : {1, 100}) { + BitsetT bitset(n); + bitset.reset(); + + if (print_log) { + printf("Testing bitset, n=%zd, pr=%zd\n", n, pr); + } + + TestFindImpl(bitset, pr); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + + if (print_log) { + printf("Testing bitset view, n=%zd, offset=%zd, pr=%zd\n", + n, + offset, + pr); + } + + TestFindImpl(view, pr); + } + } + } +} + +// +TEST(FindRef, f) { + using impl_traits = RefImplTraits; + TestFindImpl(); +} + +// +TEST(FindElement, f) { + using impl_traits = ElementImplTraits; + TestFindImpl(); +} + +// // +// TEST(FindVectorizedAvx2, f) { +// 
TestFindImpl(); +// } + +////////////////////////////////////////////////////////////////////////////////////////// + +// +template +void +TestInplaceCompareColumnImpl(BitsetT& bitset, CompareOpType op) { + const size_t n = bitset.size(); + constexpr size_t max_v = 2; + + std::vector t(n, 0); + std::vector u(n, 0); + + std::default_random_engine rng(123); + FillRandom(t, rng, max_v); + FillRandom(u, rng, max_v); + + StopWatch sw; + bitset.inplace_compare_column(t.data(), u.data(), n, op); + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } + + for (size_t i = 0; i < n; i++) { + if (op == CompareOpType::EQ) { + ASSERT_EQ(t[i] == u[i], bitset[i]) << i; + } else if (op == CompareOpType::GE) { + ASSERT_EQ(t[i] >= u[i], bitset[i]) << i; + } else if (op == CompareOpType::GT) { + ASSERT_EQ(t[i] > u[i], bitset[i]) << i; + } else if (op == CompareOpType::LE) { + ASSERT_EQ(t[i] <= u[i], bitset[i]) << i; + } else if (op == CompareOpType::LT) { + ASSERT_EQ(t[i] < u[i], bitset[i]) << i; + } else if (op == CompareOpType::NE) { + ASSERT_EQ(t[i] != u[i], bitset[i]) << i; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } +} + +template +void +TestInplaceCompareColumnImpl() { + for (const size_t n : typical_sizes) { + for (const auto op : typical_compare_ops) { + BitsetT bitset(n); + bitset.reset(); + + if (print_log) { + printf("Testing bitset, n=%zd, op=%zd\n", n, (size_t)op); + } + + TestInplaceCompareColumnImpl(bitset, op); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + + if (print_log) { + printf("Testing bitset view, n=%zd, offset=%zd, op=%zd\n", + n, + offset, + (size_t)op); + } + + TestInplaceCompareColumnImpl(view, op); + } + } + } +} + +// +template +class InplaceCompareColumnSuite : public ::testing::Test {}; + +TYPED_TEST_SUITE_P(InplaceCompareColumnSuite); + +// +TYPED_TEST_P(InplaceCompareColumnSuite, BitWise) { + using impl_traits = RefImplTraits, + std::tuple_element_t<3, TypeParam>>; + TestInplaceCompareColumnImpl, + std::tuple_element_t<1, TypeParam>>(); +} + +// +TYPED_TEST_P(InplaceCompareColumnSuite, ElementWise) { + using impl_traits = ElementImplTraits, + std::tuple_element_t<3, TypeParam>>; + TestInplaceCompareColumnImpl, + std::tuple_element_t<1, TypeParam>>(); +} + +// +TYPED_TEST_P(InplaceCompareColumnSuite, Avx2) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx2()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<3, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx2>; + TestInplaceCompareColumnImpl, + std::tuple_element_t<1, TypeParam>>(); + } +#endif +} + +// +TYPED_TEST_P(InplaceCompareColumnSuite, Avx512) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx512()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<3, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx512>; + TestInplaceCompareColumnImpl, + std::tuple_element_t<1, TypeParam>>(); + } +#endif +} + +// +TYPED_TEST_P(InplaceCompareColumnSuite, Neon) { +#if defined(__aarch64__) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<3, TypeParam>, + milvus::bitset::detail::arm::VectorizedNeon>; + TestInplaceCompareColumnImpl, + std::tuple_element_t<1, TypeParam>>(); +#endif +} + +// +TYPED_TEST_P(InplaceCompareColumnSuite, Sve) { +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) + 
using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<3, TypeParam>, + milvus::bitset::detail::arm::VectorizedSve>; + TestInplaceCompareColumnImpl, + std::tuple_element_t<1, TypeParam>>(); +#endif +} + +// +TYPED_TEST_P(InplaceCompareColumnSuite, Dynamic) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<3, TypeParam>, + milvus::bitset::detail::VectorizedDynamic>; + TestInplaceCompareColumnImpl, + std::tuple_element_t<1, TypeParam>>(); +} + +// +TYPED_TEST_P(InplaceCompareColumnSuite, VecRef) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<3, TypeParam>, + milvus::bitset::detail::VectorizedRef>; + TestInplaceCompareColumnImpl, + std::tuple_element_t<1, TypeParam>>(); +} + +// +REGISTER_TYPED_TEST_SUITE_P(InplaceCompareColumnSuite, + BitWise, + ElementWise, + Avx2, + Avx512, + Neon, + Sve, + Dynamic, + VecRef); + +INSTANTIATE_TYPED_TEST_SUITE_P(InplaceCompareColumnTest, + InplaceCompareColumnSuite, + Ttypes2); + +////////////////////////////////////////////////////////////////////////////////////////// + +// +template +void +TestInplaceCompareValImpl(BitsetT& bitset, CompareOpType op) { + const size_t n = bitset.size(); + constexpr size_t max_v = 3; + constexpr T value = 1; + + std::vector t(n, 0); + + std::default_random_engine rng(123); + FillRandom(t, rng, max_v); + + StopWatch sw; + bitset.inplace_compare_val(t.data(), n, value, op); + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } + + for (size_t i = 0; i < n; i++) { + if (op == CompareOpType::EQ) { + ASSERT_EQ(t[i] == value, bitset[i]) << i; + } else if (op == CompareOpType::GE) { + ASSERT_EQ(t[i] >= value, bitset[i]) << i; + } else if (op == CompareOpType::GT) { + ASSERT_EQ(t[i] > value, bitset[i]) << i; + } else if (op == CompareOpType::LE) { + ASSERT_EQ(t[i] <= value, bitset[i]) << i; + } else if (op == CompareOpType::LT) { + ASSERT_EQ(t[i] < value, bitset[i]) << i; + } else if (op == CompareOpType::NE) { + ASSERT_EQ(t[i] != value, bitset[i]) << i; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } +} + +template +void +TestInplaceCompareValImpl() { + for (const size_t n : typical_sizes) { + for (const auto op : typical_compare_ops) { + BitsetT bitset(n); + bitset.reset(); + + if (print_log) { + printf("Testing bitset, n=%zd, op=%zd\n", n, (size_t)op); + } + + TestInplaceCompareValImpl(bitset, op); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + + if (print_log) { + printf("Testing bitset view, n=%zd, offset=%zd, op=%zd\n", + n, + offset, + (size_t)op); + } + + TestInplaceCompareValImpl(view, op); + } + } + } +} + +// +template +class InplaceCompareValSuite : public ::testing::Test {}; + +TYPED_TEST_SUITE_P(InplaceCompareValSuite); + +TYPED_TEST_P(InplaceCompareValSuite, BitWise) { + using impl_traits = RefImplTraits, + std::tuple_element_t<2, TypeParam>>; + TestInplaceCompareValImpl>(); +} + +TYPED_TEST_P(InplaceCompareValSuite, ElementWise) { + using impl_traits = ElementImplTraits, + std::tuple_element_t<2, TypeParam>>; + TestInplaceCompareValImpl>(); +} + +TYPED_TEST_P(InplaceCompareValSuite, Avx2) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx2()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx2>; + TestInplaceCompareValImpl>(); + } +#endif +} + 
+TYPED_TEST_P(InplaceCompareValSuite, Avx512) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx512()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx512>; + TestInplaceCompareValImpl>(); + } +#endif +} + +TYPED_TEST_P(InplaceCompareValSuite, Neon) { +#if defined(__aarch64__) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::arm::VectorizedNeon>; + TestInplaceCompareValImpl>(); +#endif +} + +TYPED_TEST_P(InplaceCompareValSuite, Sve) { +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::arm::VectorizedSve>; + TestInplaceCompareValImpl>(); +#endif +} + +TYPED_TEST_P(InplaceCompareValSuite, Dynamic) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::VectorizedDynamic>; + TestInplaceCompareValImpl>(); +} + +TYPED_TEST_P(InplaceCompareValSuite, VecRef) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::VectorizedRef>; + TestInplaceCompareValImpl>(); +} + +// +REGISTER_TYPED_TEST_SUITE_P(InplaceCompareValSuite, + BitWise, + ElementWise, + Avx2, + Avx512, + Neon, + Sve, + Dynamic, + VecRef); + +INSTANTIATE_TYPED_TEST_SUITE_P(InplaceCompareValTest, + InplaceCompareValSuite, + Ttypes1); + +////////////////////////////////////////////////////////////////////////////////////////// + +// +template +void +TestInplaceWithinRangeColumnImpl(BitsetT& bitset, RangeType op) { + const size_t n = bitset.size(); + constexpr size_t max_v = 3; + + std::vector range(n, 0); + std::vector values(n, 0); + + std::vector lower(n, 0); + std::vector upper(n, 0); + + std::default_random_engine rng(123); + FillRandom(lower, rng, max_v); + FillRandom(range, rng, max_v); + FillRandom(values, rng, 2 * max_v); + + for (size_t i = 0; i < n; i++) { + upper[i] = lower[i] + range[i]; + } + + StopWatch sw; + bitset.inplace_within_range_column( + lower.data(), upper.data(), values.data(), n, op); + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } + + for (size_t i = 0; i < n; i++) { + if (op == RangeType::IncInc) { + ASSERT_EQ(lower[i] <= values[i] && values[i] <= upper[i], bitset[i]) + << i << " " << lower[i] << " " << values[i] << " " << upper[i]; + } else if (op == RangeType::IncExc) { + ASSERT_EQ(lower[i] <= values[i] && values[i] < upper[i], bitset[i]) + << i << " " << lower[i] << " " << values[i] << " " << upper[i]; + } else if (op == RangeType::ExcInc) { + ASSERT_EQ(lower[i] < values[i] && values[i] <= upper[i], bitset[i]) + << i << " " << lower[i] << " " << values[i] << " " << upper[i]; + } else if (op == RangeType::ExcExc) { + ASSERT_EQ(lower[i] < values[i] && values[i] < upper[i], bitset[i]) + << i << " " << lower[i] << " " << values[i] << " " << upper[i]; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } +} + +template +void +TestInplaceWithinRangeColumnImpl() { + for (const size_t n : typical_sizes) { + for (const auto op : typical_range_types) { + BitsetT bitset(n); + bitset.reset(); + + if (print_log) { + printf("Testing bitset, n=%zd, op=%zd\n", n, (size_t)op); + } + + TestInplaceWithinRangeColumnImpl(bitset, op); + + for (const size_t offset : typical_offsets) { + if (offset 
>= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + + if (print_log) { + printf("Testing bitset view, n=%zd, offset=%zd, op=%zd\n", + n, + offset, + (size_t)op); + } + + TestInplaceWithinRangeColumnImpl(view, op); + } + } + } +} + +// +template +class InplaceWithinRangeColumnSuite : public ::testing::Test {}; + +TYPED_TEST_SUITE_P(InplaceWithinRangeColumnSuite); + +TYPED_TEST_P(InplaceWithinRangeColumnSuite, BitWise) { + using impl_traits = RefImplTraits, + std::tuple_element_t<2, TypeParam>>; + TestInplaceWithinRangeColumnImpl>(); +} + +TYPED_TEST_P(InplaceWithinRangeColumnSuite, ElementWise) { + using impl_traits = ElementImplTraits, + std::tuple_element_t<2, TypeParam>>; + TestInplaceWithinRangeColumnImpl>(); +} + +TYPED_TEST_P(InplaceWithinRangeColumnSuite, Avx2) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx2()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx2>; + TestInplaceWithinRangeColumnImpl>(); + } +#endif +} + +TYPED_TEST_P(InplaceWithinRangeColumnSuite, Avx512) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx512()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx512>; + TestInplaceWithinRangeColumnImpl>(); + } +#endif +} + +TYPED_TEST_P(InplaceWithinRangeColumnSuite, Neon) { +#if defined(__aarch64__) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::arm::VectorizedNeon>; + TestInplaceWithinRangeColumnImpl>(); +#endif +} + +TYPED_TEST_P(InplaceWithinRangeColumnSuite, Sve) { +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::arm::VectorizedSve>; + TestInplaceWithinRangeColumnImpl>(); +#endif +} + +TYPED_TEST_P(InplaceWithinRangeColumnSuite, Dynamic) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::VectorizedDynamic>; + TestInplaceWithinRangeColumnImpl>(); +} + +TYPED_TEST_P(InplaceWithinRangeColumnSuite, VecRef) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::VectorizedRef>; + TestInplaceWithinRangeColumnImpl>(); +} + +// +REGISTER_TYPED_TEST_SUITE_P(InplaceWithinRangeColumnSuite, + BitWise, + ElementWise, + Avx2, + Avx512, + Neon, + Sve, + Dynamic, + VecRef); + +INSTANTIATE_TYPED_TEST_SUITE_P(InplaceWithinRangeColumnTest, + InplaceWithinRangeColumnSuite, + Ttypes1); + +////////////////////////////////////////////////////////////////////////////////////////// + +// +template +void +TestInplaceWithinRangeValImpl(BitsetT& bitset, RangeType op) { + const size_t n = bitset.size(); + constexpr size_t max_v = 10; + constexpr T lower_v = 3; + constexpr T upper_v = 7; + + std::vector values(n, 0); + + std::default_random_engine rng(123); + FillRandom(values, rng, max_v); + + StopWatch sw; + bitset.inplace_within_range_val(lower_v, upper_v, values.data(), n, op); + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } + + for (size_t i = 0; i < n; i++) { + if (op == RangeType::IncInc) { + ASSERT_EQ(lower_v <= values[i] && values[i] <= upper_v, bitset[i]) + << i << " " << lower_v << " " << 
values[i] << " " << upper_v; + } else if (op == RangeType::IncExc) { + ASSERT_EQ(lower_v <= values[i] && values[i] < upper_v, bitset[i]) + << i << " " << lower_v << " " << values[i] << " " << upper_v; + } else if (op == RangeType::ExcInc) { + ASSERT_EQ(lower_v < values[i] && values[i] <= upper_v, bitset[i]) + << i << " " << lower_v << " " << values[i] << " " << upper_v; + } else if (op == RangeType::ExcExc) { + ASSERT_EQ(lower_v < values[i] && values[i] < upper_v, bitset[i]) + << i << " " << lower_v << " " << values[i] << " " << upper_v; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } +} + +template +void +TestInplaceWithinRangeValImpl() { + for (const size_t n : typical_sizes) { + for (const auto op : typical_range_types) { + BitsetT bitset(n); + bitset.reset(); + + if (print_log) { + printf("Testing bitset, n=%zd, op=%zd\n", n, (size_t)op); + } + + TestInplaceWithinRangeValImpl(bitset, op); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + + if (print_log) { + printf("Testing bitset view, n=%zd, offset=%zd, op=%zd\n", + n, + offset, + (size_t)op); + } + + TestInplaceWithinRangeValImpl(view, op); + } + } + } +} + +// +template +class InplaceWithinRangeValSuite : public ::testing::Test {}; + +TYPED_TEST_SUITE_P(InplaceWithinRangeValSuite); + +TYPED_TEST_P(InplaceWithinRangeValSuite, BitWise) { + using impl_traits = RefImplTraits, + std::tuple_element_t<2, TypeParam>>; + TestInplaceWithinRangeValImpl>(); +} + +TYPED_TEST_P(InplaceWithinRangeValSuite, ElementWise) { + using impl_traits = ElementImplTraits, + std::tuple_element_t<2, TypeParam>>; + TestInplaceWithinRangeValImpl>(); +} + +TYPED_TEST_P(InplaceWithinRangeValSuite, Avx2) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx2()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx2>; + TestInplaceWithinRangeValImpl>(); + } +#endif +} + +TYPED_TEST_P(InplaceWithinRangeValSuite, Avx512) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx512()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx512>; + TestInplaceWithinRangeValImpl>(); + } +#endif +} + +TYPED_TEST_P(InplaceWithinRangeValSuite, Neon) { +#if defined(__aarch64__) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::arm::VectorizedNeon>; + TestInplaceWithinRangeValImpl>(); +#endif +} + +TYPED_TEST_P(InplaceWithinRangeValSuite, Sve) { +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::arm::VectorizedSve>; + TestInplaceWithinRangeValImpl>(); +#endif +} + +TYPED_TEST_P(InplaceWithinRangeValSuite, Dynamic) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::VectorizedDynamic>; + TestInplaceWithinRangeValImpl>(); +} + +TYPED_TEST_P(InplaceWithinRangeValSuite, VecRef) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::VectorizedRef>; + TestInplaceWithinRangeValImpl>(); +} + +// +REGISTER_TYPED_TEST_SUITE_P(InplaceWithinRangeValSuite, + BitWise, + 
ElementWise, + Avx2, + Avx512, + Neon, + Sve, + Dynamic, + VecRef); + +INSTANTIATE_TYPED_TEST_SUITE_P(InplaceWithinRangeValTest, + InplaceWithinRangeValSuite, + Ttypes1); + +////////////////////////////////////////////////////////////////////////////////////////// + +// +template +void +TestInplaceArithCompareImpl(BitsetT& bitset, + ArithOpType a_op, + CompareOpType cmp_op) { + using HT = ArithHighPrecisionType; + + const size_t n = bitset.size(); + constexpr size_t max_v = 10; + + std::vector left(n, 0); + HT right_operand = 2; + HT value = 5; + + std::default_random_engine rng(123); + FillRandom(left, rng, max_v); + + StopWatch sw; + bitset.inplace_arith_compare( + left.data(), right_operand, value, n, a_op, cmp_op); + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } + + for (size_t i = 0; i < n; i++) { + if (a_op == ArithOpType::Add) { + if (cmp_op == CompareOpType::EQ) { + ASSERT_EQ((left[i] + right_operand) == value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::GE) { + ASSERT_EQ((left[i] + right_operand) >= value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::GT) { + ASSERT_EQ((left[i] + right_operand) > value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::LE) { + ASSERT_EQ((left[i] + right_operand) <= value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::LT) { + ASSERT_EQ((left[i] + right_operand) < value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::NE) { + ASSERT_EQ((left[i] + right_operand) != value, bitset[i]) << i; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } else if (a_op == ArithOpType::Sub) { + if (cmp_op == CompareOpType::EQ) { + ASSERT_EQ((left[i] - right_operand) == value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::GE) { + ASSERT_EQ((left[i] - right_operand) >= value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::GT) { + ASSERT_EQ((left[i] - right_operand) > value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::LE) { + ASSERT_EQ((left[i] - right_operand) <= value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::LT) { + ASSERT_EQ((left[i] - right_operand) < value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::NE) { + ASSERT_EQ((left[i] - right_operand) != value, bitset[i]) << i; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } else if (a_op == ArithOpType::Mul) { + if (cmp_op == CompareOpType::EQ) { + ASSERT_EQ((left[i] * right_operand) == value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::GE) { + ASSERT_EQ((left[i] * right_operand) >= value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::GT) { + ASSERT_EQ((left[i] * right_operand) > value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::LE) { + ASSERT_EQ((left[i] * right_operand) <= value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::LT) { + ASSERT_EQ((left[i] * right_operand) < value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::NE) { + ASSERT_EQ((left[i] * right_operand) != value, bitset[i]) << i; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } else if (a_op == ArithOpType::Div) { + if (cmp_op == CompareOpType::EQ) { + ASSERT_EQ((left[i] / right_operand) == value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::GE) { + ASSERT_EQ((left[i] / right_operand) >= value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::GT) { + ASSERT_EQ((left[i] / right_operand) > value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::LE) { + ASSERT_EQ((left[i] / right_operand) <= value, bitset[i]) << i; + } else if (cmp_op == 
CompareOpType::LT) { + ASSERT_EQ((left[i] / right_operand) < value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::NE) { + ASSERT_EQ((left[i] / right_operand) != value, bitset[i]) << i; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } else if (a_op == ArithOpType::Mod) { + if (cmp_op == CompareOpType::EQ) { + ASSERT_EQ(fmod(left[i], right_operand) == value, bitset[i]) + << i; + } else if (cmp_op == CompareOpType::GE) { + ASSERT_EQ(fmod(left[i], right_operand) >= value, bitset[i]) + << i; + } else if (cmp_op == CompareOpType::GT) { + ASSERT_EQ(fmod(left[i], right_operand) > value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::LE) { + ASSERT_EQ(fmod(left[i], right_operand) <= value, bitset[i]) + << i; + } else if (cmp_op == CompareOpType::LT) { + ASSERT_EQ(fmod(left[i], right_operand) < value, bitset[i]) << i; + } else if (cmp_op == CompareOpType::NE) { + ASSERT_EQ(fmod(left[i], right_operand) != value, bitset[i]) + << i; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } +} + +template +void +TestInplaceArithCompareImpl() { + for (const size_t n : typical_sizes) { + for (const auto a_op : typical_arith_ops) { + for (const auto cmp_op : typical_compare_ops) { + BitsetT bitset(n); + bitset.reset(); + + if (print_log) { + printf( + "Testing bitset, n=%zd, a_op=%zd\n", n, (size_t)a_op); + } + + TestInplaceArithCompareImpl(bitset, a_op, cmp_op); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + + if (print_log) { + printf( + "Testing bitset view, n=%zd, offset=%zd, a_op=%zd, " + "cmp_op=%zd\n", + n, + offset, + (size_t)a_op, + (size_t)cmp_op); + } + + TestInplaceArithCompareImpl( + view, a_op, cmp_op); + } + } + } + } +} + +// +template +class InplaceArithCompareSuite : public ::testing::Test {}; + +TYPED_TEST_SUITE_P(InplaceArithCompareSuite); + +TYPED_TEST_P(InplaceArithCompareSuite, BitWise) { + using impl_traits = RefImplTraits, + std::tuple_element_t<2, TypeParam>>; + TestInplaceArithCompareImpl>(); +} + +TYPED_TEST_P(InplaceArithCompareSuite, ElementWise) { + using impl_traits = ElementImplTraits, + std::tuple_element_t<2, TypeParam>>; + TestInplaceArithCompareImpl>(); +} + +TYPED_TEST_P(InplaceArithCompareSuite, Avx2) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx2()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx2>; + TestInplaceArithCompareImpl>(); + } +#endif +} + +TYPED_TEST_P(InplaceArithCompareSuite, Avx512) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx512()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx512>; + TestInplaceArithCompareImpl>(); + } +#endif +} + +TYPED_TEST_P(InplaceArithCompareSuite, Neon) { +#if defined(__aarch64__) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::arm::VectorizedNeon>; + TestInplaceArithCompareImpl>(); +#endif +} + +TYPED_TEST_P(InplaceArithCompareSuite, Sve) { +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + 
milvus::bitset::detail::arm::VectorizedSve>; + TestInplaceArithCompareImpl>(); +#endif +} + +TYPED_TEST_P(InplaceArithCompareSuite, Dynamic) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::VectorizedDynamic>; + TestInplaceArithCompareImpl>(); +} + +TYPED_TEST_P(InplaceArithCompareSuite, VecRef) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<2, TypeParam>, + milvus::bitset::detail::VectorizedRef>; + TestInplaceArithCompareImpl>(); +} + +// +REGISTER_TYPED_TEST_SUITE_P(InplaceArithCompareSuite, + BitWise, + ElementWise, + Avx2, + Avx512, + Neon, + Sve, + Dynamic, + VecRef); + +INSTANTIATE_TYPED_TEST_SUITE_P(InplaceArithCompareTest, + InplaceArithCompareSuite, + Ttypes1); + +////////////////////////////////////////////////////////////////////////////////////////// + +template +void +TestAppendImpl(BitsetT& bitset_dst, const BitsetU& bitset_src) { + std::vector b_dst; + b_dst.reserve(bitset_src.size() + bitset_dst.size()); + + for (size_t i = 0; i < bitset_dst.size(); i++) { + b_dst.push_back(bitset_dst[i]); + } + for (size_t i = 0; i < bitset_src.size(); i++) { + b_dst.push_back(bitset_src[i]); + } + + StopWatch sw; + bitset_dst.append(bitset_src); + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } + + // + ASSERT_EQ(b_dst.size(), bitset_dst.size()); + for (size_t i = 0; i < bitset_dst.size(); i++) { + ASSERT_EQ(b_dst[i], bitset_dst[i]) << i; + } +} + +template +void +TestAppendImpl() { + std::default_random_engine rng(345); + + std::vector bt0; + for (const size_t n : typical_sizes) { + BitsetT bitset(n); + FillRandom(bitset, rng); + bt0.push_back(std::move(bitset)); + } + + std::vector bt1; + for (const size_t n : typical_sizes) { + BitsetT bitset(n); + FillRandom(bitset, rng); + bt1.push_back(std::move(bitset)); + } + + for (const auto& bt_a : bt0) { + for (const auto& bt_b : bt1) { + auto bt = bt_a.clone(); + + if (print_log) { + printf( + "Testing bitset, n=%zd, m=%zd\n", bt_a.size(), bt_b.size()); + } + + TestAppendImpl(bt, bt_b); + + for (const size_t offset : typical_offsets) { + if (offset >= bt_b.size()) { + continue; + } + + bt = bt_a.clone(); + auto view = bt_b.view(offset); + + if (print_log) { + printf("Testing bitset view, n=%zd, m=%zd, offset=%zd\n", + bt_a.size(), + bt_b.size(), + offset); + } + + TestAppendImpl(bt, view); + } + } + } +} + +TEST(Append, BitWise) { + using impl_traits = RefImplTraits; + TestAppendImpl(); +} + +TEST(Append, ElementWise) { + using impl_traits = ElementImplTraits; + TestAppendImpl(); +} + +////////////////////////////////////////////////////////////////////////////////////////// + +// +template +void +TestCountImpl(BitsetT& bitset, const size_t max_v) { + const size_t n = bitset.size(); + + std::default_random_engine rng(123); + std::uniform_int_distribution u(0, max_v); + + std::vector one_pos; + for (size_t i = 0; i < n; i++) { + bool enabled = (u(rng) == 0); + if (enabled) { + one_pos.push_back(i); + bitset[i] = true; + } + } + + StopWatch sw; + + auto count = bitset.count(); + ASSERT_EQ(count, one_pos.size()); + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } +} + +template +void +TestCountImpl() { + for (const size_t n : typical_sizes) { + for (const size_t pr : {1, 100}) { + BitsetT bitset(n); + bitset.reset(); + + if (print_log) { + printf("Testing bitset, n=%zd, pr=%zd\n", n, pr); + } + + TestCountImpl(bitset, pr); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + 
bitset.reset(); + auto view = bitset.view(offset); + + if (print_log) { + printf("Testing bitset view, n=%zd, offset=%zd, pr=%zd\n", + n, + offset, + pr); + } + + TestCountImpl(view, pr); + } + } + } +} + +// +TEST(CountRef, f) { + using impl_traits = RefImplTraits; + TestCountImpl(); +} + +// +TEST(CountElement, f) { + using impl_traits = ElementImplTraits; + TestCountImpl(); +} + +////////////////////////////////////////////////////////////////////////////////////////// + +int +main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index b8f09b6b515d..12faad22e2c1 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -4696,41 +4696,41 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_L2) { } TEST(CApiTest, AssembeChunkTest) { - FixedVector chunk; + TargetBitmap chunk(1000); for (size_t i = 0; i < 1000; ++i) { - chunk.push_back(i % 2 == 0); + chunk[i] = (i % 2 == 0); } BitsetType result; milvus::query::AppendOneChunk(result, chunk); - std::string s; - boost::to_string(result, s); - std::cout << s << std::endl; + // std::string s; + // boost::to_string(result, s); + // std::cout << s << std::endl; int index = 0; for (size_t i = 0; i < 1000; i++) { ASSERT_EQ(result[index++], chunk[i]) << i; } - chunk.clear(); + chunk = TargetBitmap(934); for (int i = 0; i < 934; ++i) { - chunk.push_back(i % 2 == 0); + chunk[i] = (i % 2 == 0); } milvus::query::AppendOneChunk(result, chunk); for (size_t i = 0; i < 934; i++) { ASSERT_EQ(result[index++], chunk[i]) << i; } - chunk.clear(); + chunk = TargetBitmap(62); for (int i = 0; i < 62; ++i) { - chunk.push_back(i % 2 == 0); + chunk[i] = (i % 2 == 0); } milvus::query::AppendOneChunk(result, chunk); for (size_t i = 0; i < 62; i++) { ASSERT_EQ(result[index++], chunk[i]) << i; } - chunk.clear(); + chunk = TargetBitmap(105); for (int i = 0; i < 105; ++i) { - chunk.push_back(i % 2 == 0); + chunk[i] = (i % 2 == 0); } milvus::query::AppendOneChunk(result, chunk); for (size_t i = 0; i < 105; i++) { @@ -4745,16 +4745,17 @@ search_id(const BitsetType& bitset, bool use_find) { std::vector dst_offset; if (use_find) { - for (int i = bitset.find_first(); i < bitset.size(); - i = bitset.find_next(i)) { - if (i == BitsetType::npos) { - return dst_offset; - } - auto offset = SegOffset(i); + auto i = bitset.find_first(); + while (i.has_value()) { + auto offset = SegOffset(i.value()); if (timestamps[offset.get()] <= timestamp) { dst_offset.push_back(offset); } + + i = bitset.find_next(i.value()); } + + return dst_offset; } else { for (int i = 0; i < bitset.size(); i++) { if (bitset[i]) { @@ -4769,7 +4770,7 @@ search_id(const BitsetType& bitset, } TEST(CApiTest, SearchIdTest) { - using BitsetType = boost::dynamic_bitset<>; + // using BitsetType = boost::dynamic_bitset<>; auto test = [&](int NT) { BitsetType bitset(1000000); @@ -4819,9 +4820,9 @@ TEST(CApiTest, SearchIdTest) { } TEST(CApiTest, AssembeChunkPerfTest) { - FixedVector chunk; + TargetBitmap chunk(100000000); for (size_t i = 0; i < 100000000; ++i) { - chunk.push_back(i % 2 == 0); + chunk[i] = (i % 2 == 0); } BitsetType result; // while (true) { diff --git a/internal/core/unittest/test_simd.cpp b/internal/core/unittest/test_simd.cpp deleted file mode 100644 index 8cfa3d8d223b..000000000000 --- a/internal/core/unittest/test_simd.cpp +++ /dev/null @@ -1,1536 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. 
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License
-
-#include <gtest/gtest.h>
-
-#include <boost/container/vector.hpp>
-#include <chrono>
-#include <cstdint>
-#include <iostream>
-#include <random>
-#include <string>
-#include <type_traits>
-#include <vector>
-
-using namespace std;
-
-template <typename T>
-using FixedVector = boost::container::vector<T>;
-
-#define PRINT_SKPI_TEST                                                    \
-    std::cout                                                              \
-        << "skip "                                                         \
-        << ::testing::UnitTest::GetInstance()->current_test_info()->name() \
-        << std::endl;
-
-#if defined(__x86_64__)
-#include "simd/hook.h"
-#include "simd/ref.h"
-#include "simd/sse2.h"
-#include "simd/sse4.h"
-#include "simd/avx2.h"
-#include "simd/avx512.h"
-
-using namespace milvus::simd;
-TEST(GetBitSetBlock, base_test_sse) {
-    FixedVector<bool> src;
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(false);
-    }
-
-    auto res = GetBitsetBlockSSE2(src.data());
-    std::cout << res << std::endl;
-    ASSERT_EQ(res, 0);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(true);
-    }
-    res = GetBitsetBlockSSE2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0xffffffffffffffff);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    res = GetBitsetBlockSSE2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x5555555555555555);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 4 == 0 ? true : false);
-    }
-    res = GetBitsetBlockSSE2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x1111111111111111);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 8 == 0 ? true : false);
-    }
-    res = GetBitsetBlockSSE2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x0101010101010101);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 16 == 0 ? true : false);
-    }
-    res = GetBitsetBlockSSE2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x0001000100010001);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 32 == 0 ? true : false);
-    }
-    res = GetBitsetBlockSSE2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x0000000100000001);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 5 == 0 ? true : false);
-    }
-    res = GetBitsetBlockSSE2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x1084210842108421);
-}
-
-TEST(GetBitsetBlockPerf, bitset) {
-    FixedVector<bool> srcs;
-    for (size_t i = 0; i < 100000000; ++i) {
-        srcs.push_back(i % 2 == 0);
-    }
-    std::cout << "start test" << std::endl;
-    auto start = std::chrono::steady_clock::now();
-    for (int i = 0; i < 10000000; ++i)
-        auto result = GetBitsetBlockSSE2(srcs.data() + i);
-    std::cout << "cost: "
-              << std::chrono::duration_cast<std::chrono::microseconds>(
-                     std::chrono::steady_clock::now() - start)
-                     .count()
-              << "us" << std::endl;
-    start = std::chrono::steady_clock::now();
-    for (int i = 0; i < 10000000; ++i)
-        auto result = GetBitsetBlockAVX2(srcs.data() + i);
-    std::cout << "cost: "
-              << std::chrono::duration_cast<std::chrono::microseconds>(
-                     std::chrono::steady_clock::now() - start)
-                     .count()
-              << "us" << std::endl;
-}
-
-TEST(GetBitSetBlock, base_test_avx2) {
-    FixedVector<bool> src;
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(false);
-    }
-
-    auto res = GetBitsetBlockAVX2(src.data());
-    std::cout << res << std::endl;
-    ASSERT_EQ(res, 0);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(true);
-    }
-    res = GetBitsetBlockAVX2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0xffffffffffffffff);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    res = GetBitsetBlockAVX2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x5555555555555555);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 4 == 0 ? true : false);
-    }
-    res = GetBitsetBlockAVX2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x1111111111111111);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 8 == 0 ? true : false);
-    }
-    res = GetBitsetBlockAVX2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x0101010101010101);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 16 == 0 ? true : false);
-    }
-    res = GetBitsetBlockAVX2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x0001000100010001);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 32 == 0 ? true : false);
-    }
-    res = GetBitsetBlockAVX2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x0000000100000001);
-
-    src.clear();
-    for (int i = 0; i < 64; ++i) {
-        src.push_back(i % 5 == 0 ? true : false);
-    }
-    res = GetBitsetBlockAVX2(src.data());
-    std::cout << std::hex << res << std::endl;
-    ASSERT_EQ(res, 0x1084210842108421);
-}
-
-TEST(FindTermSSE2, bool_type) {
-    FixedVector<bool> vecs;
-    vecs.push_back(false);
-
-    auto res = FindTermSSE2(vecs.data(), vecs.size(), true);
-    ASSERT_EQ(res, false);
-    res = FindTermSSE2(vecs.data(), vecs.size(), false);
-    ASSERT_EQ(res, true);
-
-    for (int i = 0; i < 16; i++) {
-        vecs.push_back(false);
-    }
-
-    res = FindTermSSE2(vecs.data(), vecs.size(), true);
-    ASSERT_EQ(res, false);
-    res = FindTermSSE2(vecs.data(), vecs.size(), false);
-    ASSERT_EQ(res, true);
-
-    vecs.push_back(true);
-    for (int i = 0; i < 16; i++) {
-        vecs.push_back(false);
-    }
-    res = FindTermSSE2(vecs.data(), vecs.size(), true);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermSSE2, int8_type) {
-    std::vector<int8_t> vecs;
-    for (int i = 0; i < 100; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)10);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)99);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)100);
-    ASSERT_EQ(res, false);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)127);
-    ASSERT_EQ(res, false);
-    vecs.push_back(127);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int8_t)127);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermSSE2, int16_type) {
-    std::vector<int16_t> vecs;
-    for (int i = 0; i < 1000; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)10);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)999);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)1000);
-    ASSERT_EQ(res, false);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)1270);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1000);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int16_t)1000);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermSSE2, int32_type) {
-    std::vector<int32_t> vecs;
-    for (int i = 0; i < 1000; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermSSE2(vecs.data(), vecs.size(), 0);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 10);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 999);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 1000);
-    ASSERT_EQ(res, false);
-
-    vecs.push_back(1000);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 1000);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 1001);
-    ASSERT_EQ(res, false);
-
-    vecs.push_back(1001);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 1001);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 1002);
-    ASSERT_EQ(res, false);
-
-    vecs.push_back(1002);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 1002);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 1003);
-    ASSERT_EQ(res, false);
-
-    res = FindTermSSE2(vecs.data(), vecs.size(), 1270);
-    ASSERT_EQ(res, false);
-}
-
-TEST(FindTermSSE2, int64_type) {
-    std::vector<int64_t> vecs;
-    for (int i = 0; i < 1000; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)10);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)999);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)1000);
-    ASSERT_EQ(res, false);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)1270);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1005);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (int64_t)1005);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermSSE2, float_type) {
-    std::vector<float> vecs;
-    for (int i = 0; i < 10000; i++) {
-        vecs.push_back(i + 0.01);
-    }
-
-    auto res = FindTermSSE2(vecs.data(), vecs.size(), (float)0.01);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (float)10.01);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (float)10000.01);
-    ASSERT_EQ(res, false);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (float)12700.02);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1.001);
-    res = FindTermSSE2(vecs.data(), vecs.size(), (float)1.001);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermSSE2, double_type) {
-    std::vector<double> vecs;
-    for (int i = 0; i < 10000; i++) {
-        vecs.push_back(i + 0.01);
-    }
-
-    auto res = FindTermSSE2(vecs.data(), vecs.size(), 0.01);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 10.01);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 10000.01);
-    ASSERT_EQ(res, false);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 12700.01);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1.001);
-    res = FindTermSSE2(vecs.data(), vecs.size(), 1.001);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermSSE4, int64_type) {
-    if (!cpu_support_sse4_2()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int64_t> srcs;
-    for (size_t i = 0; i < 1000; i++) {
-        srcs.push_back(i);
-    }
-
-    auto res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)1);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)999);
-    ASSERT_EQ(res, true);
-    res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)1000);
-    ASSERT_EQ(res, false);
-    res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)2000);
-    ASSERT_EQ(res, false);
-    srcs.push_back(1000);
-    res = FindTermSSE4(srcs.data(), srcs.size(), (int64_t)1000);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX2, bool_type) {
-    if (!cpu_support_avx2()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int64_t> srcs;
-    for (size_t i = 0; i < 1000; i++) {
-        srcs.push_back(i);
-    }
-    FixedVector<bool> vecs;
-    vecs.push_back(false);
-
-    auto res = FindTermAVX2(vecs.data(), vecs.size(), true);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX2(vecs.data(), vecs.size(), false);
-    ASSERT_EQ(res, true);
-
-    for (int i = 0; i < 16; i++) {
-        vecs.push_back(false);
-    }
-
-    res = FindTermAVX2(vecs.data(), vecs.size(), true);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX2(vecs.data(), vecs.size(), false);
-    ASSERT_EQ(res, true);
-
-    vecs.push_back(true);
-    for (int i = 0; i < 16; i++) {
-        vecs.push_back(false);
-    }
-    res = FindTermAVX2(vecs.data(), vecs.size(), true);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX2, int8_type) {
-    if (!cpu_support_avx2()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int8_t> vecs;
-    for (int i = 0; i < 100; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)10);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)99);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)100);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)127);
-    ASSERT_EQ(res, false);
-    vecs.push_back(127);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int8_t)127);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX2, int16_type) {
-    if (!cpu_support_avx2()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int16_t> vecs;
-    for (int i = 0; i < 1000; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)10);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)999);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)1000);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)1270);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1270);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int16_t)1270);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX2, int32_type) {
-    if (!cpu_support_avx2()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int32_t> vecs;
-    for (int i = 0; i < 1000; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermAVX2(vecs.data(), vecs.size(), 0);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), 10);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), 999);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), 1000);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX2(vecs.data(), vecs.size(), 1270);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1270);
-    res = FindTermAVX2(vecs.data(), vecs.size(), 1270);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX2, int64_type) {
-    if (!cpu_support_avx2()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int64_t> vecs;
-    for (int i = 0; i < 1000; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)10);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)999);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)1000);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)1270);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1270);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (int64_t)1270);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX2, float_type) {
-    if (!cpu_support_avx2()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<float> vecs;
-    for (int i = 0; i < 10000; i++) {
-        vecs.push_back(i + 0.01);
-    }
-
-    auto res = FindTermAVX2(vecs.data(), vecs.size(), (float)0.01);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (float)10.01);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (float)10000.01);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (float)12700.02);
-    ASSERT_EQ(res, false);
-    vecs.push_back(12700.02);
-    res = FindTermAVX2(vecs.data(), vecs.size(), (float)12700.02);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX2, double_type) {
-    if (!cpu_support_avx2()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<double> vecs;
-    for (int i = 0; i < 10000; i++) {
-        vecs.push_back(i + 0.01);
-    }
-
-    auto res = FindTermAVX2(vecs.data(), vecs.size(), 0.01);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), 10.01);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX2(vecs.data(), vecs.size(), 10000.01);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX2(vecs.data(), vecs.size(), 12700.01);
-    ASSERT_EQ(res, false);
-    vecs.push_back(12700.01);
-    res = FindTermAVX2(vecs.data(), vecs.size(), 12700.01);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX512, bool_type) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int64_t> srcs;
-    for (size_t i = 0; i < 1000; i++) {
-        srcs.push_back(i);
-    }
-    FixedVector<bool> vecs;
-    vecs.push_back(false);
-
-    auto res = FindTermAVX512(vecs.data(), vecs.size(), true);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX512(vecs.data(), vecs.size(), false);
-    ASSERT_EQ(res, true);
-
-    for (int i = 0; i < 16; i++) {
-        vecs.push_back(false);
-    }
-
-    res = FindTermAVX512(vecs.data(), vecs.size(), true);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX512(vecs.data(), vecs.size(), false);
-    ASSERT_EQ(res, true);
-
-    vecs.push_back(true);
-    for (int i = 0; i < 16; i++) {
-        vecs.push_back(false);
-    }
-    res = FindTermAVX512(vecs.data(), vecs.size(), true);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX512, int8_type) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int8_t> vecs;
-    for (int i = 0; i < 100; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)10);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)99);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)100);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)127);
-    ASSERT_EQ(res, false);
-    vecs.push_back(127);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int8_t)127);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX512, int16_type) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int16_t> vecs;
-    for (int i = 0; i < 1000; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)10);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)999);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)1000);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)1270);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1270);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int16_t)1270);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX512, int32_type) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int32_t> vecs;
-    for (int i = 0; i < 1000; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermAVX512(vecs.data(), vecs.size(), 0);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), 10);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), 999);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), 1000);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX512(vecs.data(), vecs.size(), 1270);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1270);
-    res = FindTermAVX512(vecs.data(), vecs.size(), 1270);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX512, int64_type) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int64_t> vecs;
-    for (int i = 0; i < 1000; i++) {
-        vecs.push_back(i);
-    }
-
-    auto res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)0);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)10);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)999);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)1000);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)1270);
-    ASSERT_EQ(res, false);
-    vecs.push_back(1270);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (int64_t)1270);
-    ASSERT_EQ(res, true);
-}
-
-TEST(FindTermAVX512, float_type) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<float> vecs;
-    for (int i = 0; i < 10000; i++) {
-        vecs.push_back(i + 0.01);
-    }
-
-    auto res = FindTermAVX512(vecs.data(), vecs.size(), (float)0.01);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (float)10.01);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (float)10000.01);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (float)12700.02);
-    ASSERT_EQ(res, false);
-    vecs.push_back(12700.02);
-    res = FindTermAVX512(vecs.data(), vecs.size(), (float)12700.02);
-    ASSERT_EQ(res, true);
-}
-
-TEST(StrCmpSS4, string_type) {
-    if (!cpu_support_sse4_2()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-
-    std::vector<std::string> s1;
-    for (int i = 0; i < 1000; ++i) {
-        s1.push_back("test" + std::to_string(i));
-    }
-
-    for (int i = 0; i < 1000; ++i) {
-        auto res = StrCmpSSE4(s1[i].c_str(), "test0");
-    }
-
-    string s2;
-    string s3;
-    for (int i = 0; i < 1000; ++i) {
-        s2.push_back('x');
-    }
-    for (int i = 0; i < 1000; ++i) {
-        s3.push_back('x');
-    }
-
-    auto res = StrCmpSSE4(s2.c_str(), s3.c_str());
-    std::cout << res << std::endl;
-}
-
-TEST(FindTermAVX512, double_type) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<double> vecs;
-    for (int i = 0; i < 10000; i++) {
-        vecs.push_back(i + 0.01);
-    }
-
-    auto res = FindTermAVX512(vecs.data(), vecs.size(), 0.01);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), 10.01);
-    ASSERT_EQ(res, true);
-    res = FindTermAVX512(vecs.data(), vecs.size(), 10000.01);
-    ASSERT_EQ(res, false);
-    res = FindTermAVX512(vecs.data(), vecs.size(), 12700.01);
-    ASSERT_EQ(res, false);
-    vecs.push_back(12700.01);
-    res = FindTermAVX512(vecs.data(), vecs.size(), 12700.01);
-    ASSERT_EQ(res, true);
-}
-
-TEST(EqualVal, perf_int8) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<int8_t> srcs(1000000);
-    for (int i = 0; i < 1000000; ++i) {
-        srcs[i] = i % 128;
-    }
-    FixedVector<bool> res(1000000);
-    auto start = std::chrono::steady_clock::now();
-    EqualValRef(srcs.data(), 1000000, (int8_t)10, res.data());
-    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                     std::chrono::steady_clock::now() - start)
-                     .count()
-              << std::endl;
-    start = std::chrono::steady_clock::now();
-    EqualValAVX512(srcs.data(), 1000000, (int8_t)10, res.data());
-    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                     std::chrono::steady_clock::now() - start)
-                     .count()
-              << std::endl;
-}
-
-template <typename T>
-void
-TestCompareValAVX512Perf() {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<T> srcs(1000000);
-    for (int i = 0; i < 1000000; ++i) {
-        srcs[i] = i;
-    }
-    FixedVector<bool> res(1000000);
-    T target = 10;
-    auto start = std::chrono::steady_clock::now();
-    EqualValRef(srcs.data(), 1000000, target, res.data());
-    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                     std::chrono::steady_clock::now() - start)
-                     .count()
-              << std::endl;
-    start = std::chrono::steady_clock::now();
-    EqualValAVX512(srcs.data(), 1000000, target, res.data());
-    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                     std::chrono::steady_clock::now() - start)
-                     .count()
-              << std::endl;
-}
-
-TEST(EqualVal, perf_int16) {
-    TestCompareValAVX512Perf<int16_t>();
-}
-
-TEST(EqualVal, pref_int32) {
-    TestCompareValAVX512Perf<int32_t>();
-}
-
-TEST(EqualVal, perf_int64) {
-    TestCompareValAVX512Perf<int64_t>();
-}
-
-TEST(EqualVal, perf_float) {
-    TestCompareValAVX512Perf<float>();
-}
-
-TEST(EqualVal, perf_double) {
-    TestCompareValAVX512Perf<double>();
-}
-
-template <typename T>
-void
-TestCompareValAVX512(int size, T target) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<T> vecs;
-    for (int i = 0; i < size; ++i) {
-        if constexpr (std::is_same_v<T, int8_t>) {
-            vecs.push_back(i % 127);
-        } else if constexpr (std::is_floating_point_v<T>) {
-            vecs.push_back(i + 0.01);
-        } else {
-            vecs.push_back(i);
-        }
-    }
-    FixedVector<bool> res(size);
-
-    EqualValAVX512(vecs.data(), size, target, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], vecs[i] == target) << i;
-    }
-    LessValAVX512(vecs.data(), size, target, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], vecs[i] < target) << i;
-    }
-    LessEqualValAVX512(vecs.data(), size, target, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], vecs[i] <= target) << i;
-    }
-    GreaterEqualValAVX512(vecs.data(), size, target, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], vecs[i] >= target) << i;
-    }
-    GreaterValAVX512(vecs.data(), size, target, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], vecs[i] > target) << i;
-    }
-    NotEqualValAVX512(vecs.data(), size, target, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], vecs[i] != target) << i;
-    }
-}
-
-TEST(CompareVal, avx512_int8) {
-    TestCompareValAVX512<int8_t>(1000, 9);
-    TestCompareValAVX512<int8_t>(1000, 99);
-    TestCompareValAVX512<int8_t>(1001, 127);
-}
-
-TEST(CompareVal, avx512_int16) {
-    TestCompareValAVX512<int16_t>(1000, 99);
-    TestCompareValAVX512<int16_t>(1000, 999);
-    TestCompareValAVX512<int16_t>(1001, 1000);
-}
-
-TEST(CompareVal, avx512_int32) {
-    TestCompareValAVX512<int32_t>(1000, 99);
-    TestCompareValAVX512<int32_t>(1000, 999);
-    TestCompareValAVX512<int32_t>(1001, 1000);
-}
-
-TEST(CompareVal, avx512_int64) {
-    TestCompareValAVX512<int64_t>(1000, 99);
-    TestCompareValAVX512<int64_t>(1000, 999);
-    TestCompareValAVX512<int64_t>(1001, 1000);
-}
-
-TEST(CompareVal, avx512_float) {
-    TestCompareValAVX512<float>(1000, 99.01);
-    TestCompareValAVX512<float>(1000, 999.01);
-    TestCompareValAVX512<float>(1001, 1000.01);
-}
-
-TEST(CompareVal, avx512_double) {
-    TestCompareValAVX512<double>(1000, 99.01);
-    TestCompareValAVX512<double>(1000, 999.01);
-    TestCompareValAVX512<double>(1001, 1000.01);
-}
-
-template <typename T>
-void
-TestCompareColumnAVX512Perf() {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::vector<T> lefts(1000000);
-    for (int i = 0; i < 1000000; ++i) {
-        lefts[i] = i;
-    }
-    std::vector<T> rights(1000000);
-    for (int i = 0; i < 1000000; ++i) {
-        rights[i] = i;
-    }
-    FixedVector<bool> res(1000000);
-    auto start = std::chrono::steady_clock::now();
-    LessColumnRef(lefts.data(), rights.data(), 1000000, res.data());
-    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                     std::chrono::steady_clock::now() - start)
-                     .count()
-              << std::endl;
-    start = std::chrono::steady_clock::now();
-    LessColumnAVX512(lefts.data(), rights.data(), 1000000, res.data());
-    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                     std::chrono::steady_clock::now() - start)
-                     .count()
-              << std::endl;
-}
-
-TEST(LessColumn, pref_int32) {
-    TestCompareColumnAVX512Perf<int32_t>();
-}
-
-TEST(LessColumn, perf_int64) {
-    TestCompareColumnAVX512Perf<int64_t>();
-}
-
-TEST(LessColumn, perf_float) {
-    TestCompareColumnAVX512Perf<float>();
-}
-
-TEST(LessColumn, perf_double) {
-    TestCompareColumnAVX512Perf<double>();
-}
-
-template <typename T>
-void
-TestCompareColumnAVX512(int size, T min_val, T max_val) {
-    if (!cpu_support_avx512()) {
-        PRINT_SKPI_TEST
-        return;
-    }
-    std::random_device rd;
-    std::mt19937 gen(rd());
-
-    std::vector<T> left;
-    std::vector<T> right;
-    if constexpr (std::is_same_v<T, float>) {
-        std::uniform_real_distribution<float> dis(min_val, max_val);
-        for (int i = 0; i < size; ++i) {
-            left.push_back(dis(gen));
-            right.push_back(dis(gen));
-        }
-    } else if constexpr (std::is_same_v<T, double>) {
-        std::uniform_real_distribution<double> dis(min_val, max_val);
-        for (int i = 0; i < size; ++i) {
-            left.push_back(dis(gen));
-            right.push_back(dis(gen));
-        }
-    } else {
-        std::uniform_int_distribution<> dis(min_val, max_val);
-        for (int i = 0; i < size; ++i) {
-            left.push_back(dis(gen));
-            right.push_back(dis(gen));
-        }
-    }
-
-    FixedVector<bool> res(size);
-
-    EqualColumnAVX512(left.data(), right.data(), size, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], left[i] == right[i]) << i;
-    }
-    LessColumnAVX512(left.data(), right.data(), size, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], left[i] < right[i]) << i;
-    }
-    GreaterColumnAVX512(left.data(), right.data(), size, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], left[i] > right[i]) << i;
-    }
-    LessEqualColumnAVX512(left.data(), right.data(), size, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], left[i] <= right[i]) << i;
-    }
-    GreaterEqualColumnAVX512(left.data(), right.data(), size, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], left[i] >= right[i]) << i;
-    }
-    NotEqualColumnAVX512(left.data(), right.data(), size, res.data());
-    for (int i = 0; i < size; i++) {
-        ASSERT_EQ(res[i], left[i] != right[i]) << i;
-    }
-}
-
-TEST(CompareColumn, avx512_int8) {
-    TestCompareColumnAVX512<int8_t>(1000, -128, 127);
-    TestCompareColumnAVX512<int8_t>(1001, -128, 127);
-}
-
-TEST(CompareColumn, avx512_int16) {
-    TestCompareColumnAVX512<int16_t>(1000, -1000, 1000);
-    TestCompareColumnAVX512<int16_t>(1001, -1000, 1000);
-}
-
-TEST(CompareColumn, avx512_int32) {
-    TestCompareColumnAVX512<int32_t>(1000, -1000, 1000);
-    TestCompareColumnAVX512<int32_t>(1001, -1000, 1000);
-}
-
-TEST(CompareColumn, avx512_int64) {
-    TestCompareColumnAVX512<int64_t>(1000, -1000, 1000);
-    TestCompareColumnAVX512<int64_t>(1001, -1000, 1000);
-}
-
-TEST(CompareColumn, avx512_float) {
-    TestCompareColumnAVX512<float>(1000, -1.0, 1.0);
-    TestCompareColumnAVX512<float>(1001, -1.0, 1.0);
-}
-
-TEST(CompareColumn, avx512_double) {
-    TestCompareColumnAVX512<double>(1000, -1.0, 1.0);
-    TestCompareColumnAVX512<double>(1001, -1.0, 1.0);
-}
-
-TEST(AllBooleanSSE2, function) {
-    FixedVector<bool> src;
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(false);
-    }
-    auto res = AllFalseSSE2(src.data(), src.size());
-    EXPECT_EQ(res, true);
-    res = AllTrueSSE2(src.data(), src.size());
-    EXPECT_EQ(res, false);
-
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    res = AllFalseSSE2(src.data(), src.size());
-    EXPECT_EQ(res, false);
-    res = AllTrueSSE2(src.data(), src.size());
-    EXPECT_EQ(res, false);
-
-    src.clear();
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(true);
-    }
-    res = AllTrueSSE2(src.data(), src.size());
-    EXPECT_EQ(res, true);
-}
-
-TEST(AllBooleanSSE2, performance) {
-    FixedVector<bool> src;
-
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    std::cout << "sse2" << std::endl;
-    for (int j = 0; j < 10; j++) {
-        auto start = std::chrono::system_clock::now();
-        auto res = AllFalseSSE2(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-        start = std::chrono::system_clock::now();
-        res = AllTrueSSE2(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-
-    std::cout << "avx2" << std::endl;
-    for (int j = 0; j < 10; j++) {
-        auto start = std::chrono::system_clock::now();
-        auto res = AllFalseAVX2(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-        start = std::chrono::system_clock::now();
-        res = AllTrueAVX2(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-
-    for (int j = 0; j < 10; j++) {
-        auto start = std::chrono::system_clock::now();
-        auto res = AllFalseRef(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-        start = std::chrono::system_clock::now();
-        res = AllTrueRef(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-}
-
-TEST(InvertBool, function) {
-    FixedVector<bool> src;
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    InvertBoolSSE2(src.data(), src.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(src[i], (i % 2) != 0);
-    }
-
-    src.clear();
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 3 == 0 ? true : false);
-    }
-    InvertBoolSSE2(src.data(), src.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(src[i], (i % 3) != 0);
-    }
-}
-
-TEST(InvertBool, performance) {
-    FixedVector<bool> src;
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    for (int i = 0; i < 10; ++i) {
-        auto start = std::chrono::system_clock::now();
-        InvertBoolSSE2(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-
-    for (int i = 0; i < 10; ++i) {
-        auto start = std::chrono::system_clock::now();
-        InvertBoolRef(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-}
-
-TEST(LogicalBool, function) {
-    FixedVector<bool> left;
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    FixedVector<bool> right;
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 2 == 0 ? true : false);
-    }
-    AndBoolSSE2(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], i % 2 == 0);
-    }
-    OrBoolSSE2(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], i % 2 == 0);
-    }
-
-    left.clear();
-    right.clear();
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 5 == 0 ? true : false);
-    }
-    AndBoolSSE2(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], (i % 2 == 0) && (i % 5 == 0));
-    }
-
-    left.clear();
-    right.clear();
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 5 == 0 ? true : false);
-    }
-    OrBoolSSE2(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], (i % 2 == 0) || (i % 5 == 0));
-    }
-
-    left.clear();
-    right.clear();
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 5 == 0 ? true : false);
-    }
-    AndBoolAVX2(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], (i % 2 == 0) && (i % 5 == 0));
-    }
-
-    left.clear();
-    right.clear();
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 5 == 0 ? true : false);
-    }
-    OrBoolAVX2(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], (i % 2 == 0) || (i % 5 == 0));
-    }
-}
-
-TEST(LogicalBool, performance) {
-    FixedVector<bool> left;
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    FixedVector<bool> right;
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 2 == 0 ? true : false);
-    }
-    std::cout << "sse2" << std::endl;
-    for (int i = 0; i < 10; ++i) {
-        auto start = std::chrono::system_clock::now();
-        AndBoolSSE2(left.data(), right.data(), left.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-    std::cout << "avx2" << std::endl;
-    for (int i = 0; i < 10; ++i) {
-        auto start = std::chrono::system_clock::now();
-        AndBoolAVX2(left.data(), right.data(), left.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-
-    for (int i = 0; i < 10; ++i) {
-        auto start = std::chrono::system_clock::now();
-        AndBoolRef(left.data(), right.data(), left.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-}
-
-#endif
-
-#if defined(__ARM_NEON)
-#include "simd/ref.h"
-#include "simd/neon.h"
-using namespace milvus::simd;
-
-#include <arm_neon.h>
-#include <iostream>
-
-void
-print_uint8x16(uint8x16_t vec) {
-    uint8_t tmp[16];
-    vst1q_u8(tmp, vec);
-
-    std::cout << "Vector contents: ";
-    for (int i = 0; i < 16; ++i) {
-        std::cout << static_cast<int>(tmp[i]) << " ";
-    }
-    std::cout << std::endl;
-}
-
-void
-print_uint8x8(uint8x8_t vec) {
-    uint8_t tmp[8];
-    vst1_u8(tmp, vec);
-
-    std::cout << "Vector contents: ";
-    for (int i = 0; i < 8; ++i) {
-        std::cout << static_cast<int>(tmp[i]) << " ";
-    }
-    std::cout << std::endl;
-}
-
-void
-print_uint16x8(uint16x8_t vec) {
-    uint16_t tmp[8];
-    vst1q_u16(tmp, vec);
-
-    std::cout << "Vector contents: ";
-    for (int i = 0; i < 8; ++i) {
-        std::cout << static_cast<int>(tmp[i]) << " ";
-    }
-    std::cout << std::endl;
-}
-
-TEST(InvertBool, function) {
-    FixedVector<bool> src;
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    InvertBoolNEON(src.data(), src.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(src[i], (i % 2) != 0);
-    }
-
-    src.clear();
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 3 == 0 ? true : false);
-    }
-    InvertBoolNEON(src.data(), src.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(src[i], (i % 3) != 0);
-    }
-}
-
-TEST(InvertBool, performance) {
-    FixedVector<bool> src;
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    for (int i = 0; i < 10; ++i) {
-        auto start = std::chrono::system_clock::now();
-        InvertBoolNEON(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-
-    for (int i = 0; i < 10; ++i) {
-        auto start = std::chrono::system_clock::now();
-        InvertBoolRef(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-}
-
-TEST(LogicalBool, function) {
-    FixedVector<bool> left;
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    FixedVector<bool> right;
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 2 == 0 ? true : false);
-    }
-    AndBoolNEON(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], i % 2 == 0);
-    }
-    OrBoolNEON(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], i % 2 == 0);
-    }
-
-    left.clear();
-    right.clear();
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 5 == 0 ? true : false);
-    }
-    AndBoolNEON(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], (i % 2 == 0) && (i % 5 == 0));
-    }
-
-    left.clear();
-    right.clear();
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 5 == 0 ? true : false);
-    }
-    OrBoolNEON(left.data(), right.data(), right.size());
-    for (int i = 0; i < 8192; ++i) {
-        EXPECT_EQ(left[i], (i % 2 == 0) || (i % 5 == 0));
-    }
-}
-
-TEST(LogicalBool, performance) {
-    FixedVector<bool> left;
-    for (int i = 0; i < 8192; ++i) {
-        left.push_back(i % 2 == 0 ? true : false);
-    }
-    FixedVector<bool> right;
-    for (int i = 0; i < 8192; ++i) {
-        right.push_back(i % 2 == 0 ? true : false);
-    }
-    std::cout << "NEON" << std::endl;
-    for (int i = 0; i < 10; ++i) {
-        auto start = std::chrono::system_clock::now();
-        AndBoolNEON(left.data(), right.data(), left.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-    std::cout << "ref" << std::endl;
-
-    for (int i = 0; i < 10; ++i) {
-        auto start = std::chrono::system_clock::now();
-        AndBoolRef(left.data(), right.data(), left.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-}
-
-TEST(AllBooleanNeon, function) {
-    FixedVector<bool> src;
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(false);
-    }
-    auto res = AllFalseNEON(src.data(), src.size());
-    EXPECT_EQ(res, true);
-    res = AllTrueNEON(src.data(), src.size());
-    EXPECT_EQ(res, false);
-
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    res = AllFalseNEON(src.data(), src.size());
-    EXPECT_EQ(res, false);
-    res = AllTrueNEON(src.data(), src.size());
-    EXPECT_EQ(res, false);
-
-    src.clear();
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(true);
-    }
-    res = AllTrueNEON(src.data(), src.size());
-    EXPECT_EQ(res, true);
-}
-
-TEST(AllBooleanNeon, performance) {
-    FixedVector<bool> src;
-
-    for (int i = 0; i < 8192; ++i) {
-        src.push_back(i % 2 == 0 ? true : false);
-    }
-    std::cout << "NEON" << std::endl;
-    for (int j = 0; j < 10; j++) {
-        auto start = std::chrono::system_clock::now();
-        auto res = AllFalseNEON(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-        start = std::chrono::system_clock::now();
-        res = AllTrueNEON(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-
-    std::cout << "ref" << std::endl;
-    for (int j = 0; j < 10; j++) {
-        auto start = std::chrono::system_clock::now();
-        auto res = AllFalseRef(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-        start = std::chrono::system_clock::now();
-        res = AllTrueRef(src.data(), src.size());
-        std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
-                         std::chrono::system_clock::now() - start)
-                         .count()
-                  << std::endl;
-    }
-}
-
-#endif
-
-int
-main(int argc, char* argv[]) {
-    ::testing::InitGoogleTest(&argc, argv);
-    return RUN_ALL_TESTS();
-}
diff --git a/internal/core/unittest/test_utils/AssertUtils.h b/internal/core/unittest/test_utils/AssertUtils.h
index d089b88650ef..5e92369b9043 100644
--- a/internal/core/unittest/test_utils/AssertUtils.h
+++ b/internal/core/unittest/test_utils/AssertUtils.h
@@ -36,33 +36,18 @@ compare_double(double x, double y, double epsilon = 0.000001f) {
 }
 
 bool
-Any(const milvus::FixedVector<bool>& vec) {
-    for (auto& val : vec) {
-        if (val == false) {
-            return false;
-        }
-    }
-    return true;
+Any(const milvus::TargetBitmap& bitmap) {
+    return bitmap.any();
 }
 
 bool
-BitSetNone(const milvus::FixedVector<bool>& vec) {
-    for (auto& val : vec) {
-        if (val == true) {
-            return false;
-        }
-    }
-    return true;
+BitSetNone(const milvus::TargetBitmap& bitmap) {
+    return bitmap.none();
 }
 
 uint64_t
-Count(const milvus::FixedVector<bool>& vec) {
-    uint64_t count = 0;
-    for (size_t i = 0; i < vec.size(); ++i) {
-        if (vec[i] == true)
-            count++;
-    }
-    return count;
+Count(const milvus::TargetBitmap& bitmap) {
+    return bitmap.count();
 }
 
 inline void