diff --git a/be/src/core/column/column_complex.h b/be/src/core/column/column_complex.h index 9f0d7b45e72cf4..00c1987e6b2824 100644 --- a/be/src/core/column/column_complex.h +++ b/be/src/core/column/column_complex.h @@ -70,7 +70,9 @@ class ColumnComplexType final : public COWHelper> } if constexpr (T == TYPE_BITMAP) { - pvalue->deserialize(pos); + if (!pvalue->deserialize(pos, length)) { + throw Exception(Status::DataQualityError("Failed to deserialize bitmap data")); + } } else if constexpr (T == TYPE_HLL) { pvalue->deserialize(Slice(pos, length)); } else if constexpr (T == TYPE_QUANTILE_STATE) { diff --git a/be/src/core/data_type/data_type_bitmap.cpp b/be/src/core/data_type/data_type_bitmap.cpp index 4bb6f2a0f7533f..4ce72626ec5eb3 100644 --- a/be/src/core/data_type/data_type_bitmap.cpp +++ b/be/src/core/data_type/data_type_bitmap.cpp @@ -20,6 +20,8 @@ #include #include "agent/be_exec_version_manager.h" +#include "common/exception.h" +#include "common/status.h" #include "core/assert_cast.h" #include "core/column/column.h" #include "core/column/column_complex.h" @@ -90,8 +92,11 @@ const char* DataTypeBitMap::deserialize(const char* buf, MutableColumnPtr* colum const auto* meta_ptr = reinterpret_cast(buf); const char* data_ptr = buf + sizeof(size_t) * real_have_saved_num; for (size_t i = 0; i < real_have_saved_num; ++i) { - data[i].deserialize(data_ptr); - data_ptr += unaligned_load(&meta_ptr[i]); + auto one_size = unaligned_load(&meta_ptr[i]); + if (!data[i].deserialize(data_ptr, one_size)) { + throw Exception(Status::DataQualityError("Failed to deserialize bitmap data")); + } + data_ptr += one_size; } return data_ptr; } @@ -115,6 +120,8 @@ void DataTypeBitMap::serialize_as_stream(const BitmapValue& cvalue, BufferWritab void DataTypeBitMap::deserialize_as_stream(BitmapValue& value, BufferReadable& buf) { StringRef ref; buf.read_binary(ref); - value.deserialize(ref.data); + if (!value.deserialize(ref.data, ref.size)) { + throw 
Exception(Status::DataQualityError("Failed to deserialize bitmap data")); + } } } // namespace doris diff --git a/be/src/core/data_type_serde/data_type_bitmap_serde.cpp b/be/src/core/data_type_serde/data_type_bitmap_serde.cpp index cb11d03e303ec7..9bd054f876e200 100644 --- a/be/src/core/data_type_serde/data_type_bitmap_serde.cpp +++ b/be/src/core/data_type_serde/data_type_bitmap_serde.cpp @@ -70,7 +70,7 @@ Status DataTypeBitMapSerDe::deserialize_one_cell_from_json(IColumn& column, Slic auto& data = data_column.get_data(); BitmapValue value; - if (!value.deserialize(slice.data)) { + if (!value.deserialize(slice.data, slice.size)) { return Status::InternalError("deserialize BITMAP from string fail!"); } data.push_back(std::move(value)); @@ -98,7 +98,7 @@ Status DataTypeBitMapSerDe::write_column_to_pb(const IColumn& column, PValues& r Status DataTypeBitMapSerDe::read_column_from_pb(IColumn& column, const PValues& arg) const { auto& col = reinterpret_cast(column); for (int i = 0; i < arg.bytes_value_size(); ++i) { - BitmapValue value(arg.bytes_value(i).data()); + BitmapValue value(arg.bytes_value(i).data(), arg.bytes_value(i).size()); col.insert_value(value); } return Status::OK(); @@ -144,7 +144,7 @@ Status DataTypeBitMapSerDe::write_column_to_arrow(const IColumn& column, const N void DataTypeBitMapSerDe::read_one_cell_from_jsonb(IColumn& column, const JsonbValue* arg) const { auto& col = reinterpret_cast(column); auto* blob = arg->unpack(); - BitmapValue bitmap_value(blob->getBlob()); + BitmapValue bitmap_value(blob->getBlob(), blob->getBlobLen()); col.insert_value(bitmap_value); } @@ -224,7 +224,7 @@ Status DataTypeBitMapSerDe::from_string(StringRef& str, IColumn& column, Status DataTypeBitMapSerDe::from_olap_string(const std::string& str, Field& field, const FormatOptions& options) const { BitmapValue value; - if (!value.deserialize(str.data())) { + if (!value.deserialize(str.data(), str.size())) { return Status::InternalError("deserialize BITMAP from string 
fail!"); } field = Field::create_field(std::move(value)); diff --git a/be/src/core/value/bitmap_value.h b/be/src/core/value/bitmap_value.h index 3b7b05d04a10c3..60b22c5aab40a2 100644 --- a/be/src/core/value/bitmap_value.h +++ b/be/src/core/value/bitmap_value.h @@ -39,6 +39,7 @@ #include "common/config.h" #include "common/exception.h" #include "common/logging.h" +#include "common/status.h" #include "core/pod_array.h" #include "core/pod_array_fwd.h" #include "util/coding.h" @@ -563,41 +564,95 @@ class Roaring64Map { } /** - * read a bitmap from a serialized version. + * Read a Doris-encoded BITMAP* payload from `buf`. Reads no more than + * `maxbytes` bytes. Throws doris::Exception if the buffer is too small + * or the encoding is otherwise malformed. * - * This function is unsafe in the sense that if you provide bad data, - * many bytes could be read, possibly causing a buffer overflow. See also readSafe. + * Wraps the upstream CRoaring *_deserialize_safe primitives so a + * tampered map_size varint or per-container count cannot trigger a + * heap out-of-bounds read. */ - static Roaring64Map read(const char* buf) { + static Roaring64Map readSafe(const char* buf, size_t maxbytes) { Roaring64Map result; - + if (maxbytes < 1) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "ran out of bytes while reading bitmap type code"); + } bool is_v1 = BitmapTypeCode::BITMAP32 == *buf || BitmapTypeCode::BITMAP64 == *buf; bool is_bitmap32 = BitmapTypeCode::BITMAP32 == *buf || BitmapTypeCode::BITMAP32_V2 == *buf; bool is_bitmap64 = BitmapTypeCode::BITMAP64 == *buf || BitmapTypeCode::BITMAP64_V2 == *buf; + if (!is_bitmap32 && !is_bitmap64) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "invalid bitmap type code for read: {}", + (int)(*buf)); + } + buf++; + maxbytes--; + + // Doris convention: Roaring::read/write second arg is `portable`, and + // is_v1=true means portable serialization (per existing read() above). 
+ auto read_one_roaring = [&](roaring::Roaring& out) { + roaring::api::roaring_bitmap_t* rb = nullptr; + if (is_v1) { + size_t need = roaring::api::roaring_bitmap_portable_deserialize_size(buf, maxbytes); + if (need == 0 || need > maxbytes) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "ran out of bytes or invalid portable roaring bitmap"); + } + rb = roaring::api::roaring_bitmap_portable_deserialize_safe(buf, maxbytes); + } else { + rb = roaring::api::roaring_bitmap_deserialize_safe(buf, maxbytes); + } + if (rb == nullptr) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "ran out of bytes or invalid roaring bitmap container"); + } + out = roaring::Roaring(rb); + size_t consumed = out.getSizeInBytes(is_v1); + if (consumed > maxbytes) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "inconsistent roaring bitmap size after read"); + } + buf += consumed; + maxbytes -= consumed; + }; + if (is_bitmap32) { - roaring::Roaring read = roaring::Roaring::read(buf + 1, is_v1); - result.emplaceOrInsert(0, std::move(read)); + roaring::Roaring r; + read_one_roaring(r); + result.emplaceOrInsert(0, std::move(r)); return result; } - DCHECK(is_bitmap64); - buf++; - - // get map size (varint64 took 1~10 bytes) - uint64_t map_size; - buf = reinterpret_cast( + // is_bitmap64: read map_size varint within remaining buffer. + uint64_t map_size = 0; + size_t varint_max = maxbytes < 10 ? maxbytes : 10; + const uint8_t* p = decode_varint64_ptr(reinterpret_cast(buf), - reinterpret_cast(buf + 10), &map_size)); - DCHECK(buf != nullptr); + reinterpret_cast(buf + varint_max), &map_size); + if (p == nullptr) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "varint decode failure for bitmap map_size"); + } + size_t consumed = reinterpret_cast(p) - buf; + buf = reinterpret_cast(p); + maxbytes -= consumed; + + // Cheap upper-bound: each entry takes at least 4 bytes (the map key). 
+ if (map_size > maxbytes / sizeof(uint32_t)) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "declared bitmap map_size exceeds buffer"); + } + for (uint64_t lcv = 0; lcv < map_size; lcv++) { - // get map key + if (maxbytes < sizeof(uint32_t)) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "ran out of bytes for bitmap map key"); + } uint32_t key = decode_fixed32_le(reinterpret_cast(buf)); buf += sizeof(uint32_t); - // read map value Roaring - roaring::Roaring read_var = roaring::Roaring::read(buf, is_v1); - // forward buffer past the last Roaring Bitmap - buf += read_var.getSizeInBytes(is_v1); - result.emplaceOrInsert(key, std::move(read_var)); + maxbytes -= sizeof(uint32_t); + + roaring::Roaring r; + read_one_roaring(r); + result.emplaceOrInsert(key, std::move(r)); } return result; } @@ -882,10 +937,13 @@ class BitmapValue { explicit BitmapValue(uint64_t value) : _sv(value), _bitmap(nullptr), _type(SINGLE), _is_shared(false) {} - // Construct a bitmap from serialized data. - explicit BitmapValue(const char* src) : _is_shared(false) { - bool res = deserialize(src); - DCHECK(res); + // Construct a bitmap from serialized data with a bounded buffer length. + // Throws if the input is truncated or malformed. + BitmapValue(const char* src, size_t maxbytes) : _is_shared(false) { + if (!deserialize(src, maxbytes)) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "BitmapValue: failed to deserialize from buffer"); + } } // !FIXME: We should rethink the logic here @@ -1937,12 +1995,21 @@ class BitmapValue { // Deserialize a bitmap value from `src`. // Return false if `src` begins with unknown type code, true otherwise. - bool deserialize(const char* src) { + // + // Bounded deserialize: reads no more than `maxbytes` bytes from `src`. + // Returns false on unknown type code or any bounds/format violation. 
+ bool deserialize(const char* src, size_t maxbytes) { + if (maxbytes < 1) { + return false; + } switch (*src) { case BitmapTypeCode::EMPTY: _type = EMPTY; break; case BitmapTypeCode::SINGLE32: + if (maxbytes < 1 + sizeof(uint32_t)) { + return false; + } _type = SINGLE; _sv = decode_fixed32_le(reinterpret_cast(src + 1)); if (config::enable_set_in_bitmap_value) { @@ -1951,6 +2018,9 @@ class BitmapValue { } break; case BitmapTypeCode::SINGLE64: + if (maxbytes < 1 + sizeof(uint64_t)) { + return false; + } _type = SINGLE; _sv = decode_fixed64_le(reinterpret_cast(src + 1)); if (config::enable_set_in_bitmap_value) { @@ -1965,21 +2035,32 @@ class BitmapValue { _type = BITMAP; _is_shared = false; try { - _bitmap = std::make_shared(detail::Roaring64Map::read(src)); + _bitmap = std::make_shared( + detail::Roaring64Map::readSafe(src, maxbytes)); + } catch (const doris::Exception& e) { + LOG(ERROR) << "Decode roaring bitmap failed: " << e.what(); + return false; } catch (const std::runtime_error& e) { - LOG(ERROR) << "Decode roaring bitmap failed, " << e.what(); + LOG(ERROR) << "Decode roaring bitmap failed: " << e.what(); return false; } break; case BitmapTypeCode::SET: { + if (maxbytes < 2) { + return false; + } _type = SET; ++src; uint8_t count = *src; ++src; + size_t remaining = maxbytes - 2; if (count > SET_TYPE_THRESHOLD) { throw Exception(ErrorCode::INTERNAL_ERROR, "bitmap value with incorrect set count, count: {}", count); } + if (remaining < static_cast(count) * sizeof(uint64_t)) { + return false; + } _set.reserve(count); for (uint8_t i = 0; i != count; ++i, src += sizeof(uint64_t)) { _set.insert(decode_fixed64_le(reinterpret_cast(src))); @@ -2001,15 +2082,22 @@ class BitmapValue { break; } case BitmapTypeCode::SET_V2: { + if (maxbytes < 1 + sizeof(uint32_t)) { + return false; + } uint32_t size = 0; memcpy(&size, src + 1, sizeof(uint32_t)); src += sizeof(uint32_t) + 1; + size_t remaining = maxbytes - 1 - sizeof(uint32_t); + if (static_cast(size) > remaining / 
sizeof(uint64_t)) { + return false; + } if (!config::enable_set_in_bitmap_value || size > SET_TYPE_THRESHOLD) { _type = BITMAP; _prepare_bitmap_for_write(); - for (int i = 0; i < size; ++i) { + for (uint32_t i = 0; i < size; ++i) { uint64_t key {}; memcpy(&key, src, sizeof(uint64_t)); _bitmap->add(key); @@ -2019,7 +2107,7 @@ class BitmapValue { _type = SET; _set.reserve(size); - for (int i = 0; i < size; ++i) { + for (uint32_t i = 0; i < size; ++i) { uint64_t key {}; memcpy(&key, src, sizeof(uint64_t)); _set.insert(key); diff --git a/be/src/exprs/aggregate/aggregate_function_orthogonal_bitmap.h b/be/src/exprs/aggregate/aggregate_function_orthogonal_bitmap.h index 146de6605b8d4a..4b879331881e7c 100644 --- a/be/src/exprs/aggregate/aggregate_function_orthogonal_bitmap.h +++ b/be/src/exprs/aggregate/aggregate_function_orthogonal_bitmap.h @@ -27,6 +27,8 @@ #include #include +#include "common/exception.h" +#include "common/status.h" #include "core/column/column_complex.h" #include "core/column/column_vector.h" #include "core/data_type/data_type_bitmap.h" @@ -167,7 +169,10 @@ struct AggIntersectCount : public AggOrthBitmapBaseData { buf.read_binary(AggOrthBitmapBaseData::first_init); std::string data; buf.read_binary(data); - AggOrthBitmapBaseData::bitmap.deserialize(data.data()); + if (!AggOrthBitmapBaseData::bitmap.deserialize(data.data(), data.size())) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "failed to deserialize BitmapIntersect state"); + } } void get(IColumn& to) const { diff --git a/be/src/exprs/function/function_bitmap.cpp b/be/src/exprs/function/function_bitmap.cpp index 35341f297640b0..cf85d440c53962 100644 --- a/be/src/exprs/function/function_bitmap.cpp +++ b/be/src/exprs/function/function_bitmap.cpp @@ -260,9 +260,6 @@ struct BitmapFromString { } }; -struct NameBitmapFromBase64 { - static constexpr auto name = "bitmap_from_base64"; -}; struct BitmapFromBase64 { using ArgumentType = DataTypeString; @@ -302,7 +299,7 @@ struct BitmapFromBase64 { 
null_map[i] = 1; } else { BitmapValue bitmap_val; - if (!bitmap_val.deserialize(decode_buff.data())) { + if (!bitmap_val.deserialize(decode_buff.data(), outlen)) { return Status::RuntimeError("bitmap_from_base64 decode failed: base64: {}", std::string(src_str, src_size)); } diff --git a/be/src/util/bitmap_expr_calculation.h b/be/src/util/bitmap_expr_calculation.h index ae5adf00b3154b..f4710fd59823b8 100644 --- a/be/src/util/bitmap_expr_calculation.h +++ b/be/src/util/bitmap_expr_calculation.h @@ -31,8 +31,6 @@ class BitmapExprCalculation : public BitmapIntersect { public: BitmapExprCalculation() = default; - explicit BitmapExprCalculation(const char* src) { deserialize(src); } - void bitmap_calculation_init(std::string& input_str) { _polish = reverse_polish(input_str); std::string bitmap_key; diff --git a/be/src/util/bitmap_intersect.h b/be/src/util/bitmap_intersect.h index 7e9d0308843338..a75571fb803c0c 100644 --- a/be/src/util/bitmap_intersect.h +++ b/be/src/util/bitmap_intersect.h @@ -47,10 +47,15 @@ class Helper { // read_from start template - static void read_from(const char** src, T* result) { + static bool read_from(const char** src, size_t* remaining, T* result) { size_t type_size = sizeof(T); + if (*remaining < type_size) { + return false; + } memcpy(result, *src, type_size); *src += type_size; + *remaining -= type_size; + return true; } }; @@ -112,37 +117,68 @@ inline int32_t Helper::serialize_size(const std::string& v) { // serialize_size end template <> -inline void Helper::read_from(const char** src, VecDateTimeValue* result) { +inline bool Helper::read_from(const char** src, size_t* remaining, + VecDateTimeValue* result) { + if (*remaining < (size_t)(DATETIME_PACKED_TIME_BYTE_SIZE + DATETIME_TYPE_BYTE_SIZE)) { + return false; + } result->from_packed_time(*(int64_t*)(*src)); *src += DATETIME_PACKED_TIME_BYTE_SIZE; if (*(int*)(*src) == TIME_DATE) { result->cast_to_date(); } *src += DATETIME_TYPE_BYTE_SIZE; + *remaining -= (DATETIME_PACKED_TIME_BYTE_SIZE 
+ DATETIME_TYPE_BYTE_SIZE); + return true; } template <> -inline void Helper::read_from(const char** src, DecimalV2Value* result) { +inline bool Helper::read_from(const char** src, size_t* remaining, + DecimalV2Value* result) { + if (*remaining < (size_t)DECIMAL_BYTE_SIZE) { + return false; + } __int128 v = 0; memcpy(&v, *src, DECIMAL_BYTE_SIZE); *src += DECIMAL_BYTE_SIZE; + *remaining -= DECIMAL_BYTE_SIZE; *result = DecimalV2Value(v); + return true; } template <> -inline void Helper::read_from(const char** src, StringRef* result) { +inline bool Helper::read_from(const char** src, size_t* remaining, StringRef* result) { + if (*remaining < 4) { + return false; + } int32_t length = *(int32_t*)(*src); *src += 4; + *remaining -= 4; + if (length < 0 || (size_t)length > *remaining) { + return false; + } *result = StringRef((char*)*src, length); *src += length; + *remaining -= length; + return true; } template <> -inline void Helper::read_from(const char** src, std::string* result) { +inline bool Helper::read_from(const char** src, size_t* remaining, + std::string* result) { + if (*remaining < 4) { + return false; + } int32_t length = *(int32_t*)(*src); *src += 4; + *remaining -= 4; + if (length < 0 || (size_t)length > *remaining) { + return false; + } *result = std::string((char*)*src, length); *src += length; + *remaining -= length; + return true; } // read_from end } // namespace detail @@ -156,7 +192,7 @@ struct BitmapIntersect { public: BitmapIntersect() = default; - explicit BitmapIntersect(const char* src) { deserialize(src); } + explicit BitmapIntersect(const char* src, size_t maxbytes) { deserialize(src, maxbytes); } void add_key(const T key) { BitmapValue empty_bitmap; @@ -224,17 +260,38 @@ struct BitmapIntersect { } } - void deserialize(const char* src) { + // Bounded deserialization. Returns true on success and false on any + // truncated / malformed input. Never reads past `src + maxbytes`. 
+ bool deserialize(const char* src, size_t maxbytes) { const char* reader = src; + size_t remaining = maxbytes; + if (remaining < 4) { + return false; + } int32_t bitmaps_size = *(int32_t*)reader; reader += 4; + remaining -= 4; + if (bitmaps_size < 0) { + return false; + } for (int32_t i = 0; i < bitmaps_size; i++) { T key; - detail::Helper::read_from(&reader, &key); - BitmapValue bitmap(reader); - reader += bitmap.getSizeInBytes(); - _bitmaps[key] = bitmap; + if (!detail::Helper::read_from(&reader, &remaining, &key)) { + return false; + } + BitmapValue bitmap; + if (!bitmap.deserialize(reader, remaining)) { + return false; + } + size_t consumed = bitmap.getSizeInBytes(); + if (consumed > remaining) { + return false; + } + reader += consumed; + remaining -= consumed; + _bitmaps[key] = std::move(bitmap); } + return true; } protected: @@ -246,7 +303,7 @@ struct BitmapIntersect { public: BitmapIntersect() = default; - explicit BitmapIntersect(const char* src) { deserialize(src); } + explicit BitmapIntersect(const char* src, size_t maxbytes) { deserialize(src, maxbytes); } void add_key(const std::string_view key) { BitmapValue empty_bitmap; @@ -311,17 +368,36 @@ struct BitmapIntersect { } } - void deserialize(const char* src) { + bool deserialize(const char* src, size_t maxbytes) { const char* reader = src; + size_t remaining = maxbytes; + if (remaining < 4) { + return false; + } int32_t bitmaps_size = *(int32_t*)reader; reader += 4; + remaining -= 4; + if (bitmaps_size < 0) { + return false; + } for (int32_t i = 0; i < bitmaps_size; i++) { std::string key; - detail::Helper::read_from(&reader, &key); - BitmapValue bitmap(reader); - reader += bitmap.getSizeInBytes(); - _bitmaps[key] = bitmap; + if (!detail::Helper::read_from(&reader, &remaining, &key)) { + return false; + } + BitmapValue bitmap; + if (!bitmap.deserialize(reader, remaining)) { + return false; + } + size_t consumed = bitmap.getSizeInBytes(); + if (consumed > remaining) { + return false; + } + reader += 
consumed; + remaining -= consumed; + _bitmaps[key] = std::move(bitmap); } + return true; } protected: diff --git a/be/test/core/value/bitmap_value_test.cpp b/be/test/core/value/bitmap_value_test.cpp index 5b25551ae0c85c..84dd56d25bcd99 100644 --- a/be/test/core/value/bitmap_value_test.cpp +++ b/be/test/core/value/bitmap_value_test.cpp @@ -22,7 +22,10 @@ #include #include +#include +#include #include +#include #include "gtest/gtest.h" #include "gtest/gtest_pred_impl.h" @@ -187,7 +190,7 @@ TEST(BitmapValueTest, Roaring64Map_write_read) { roaring64_map.write(buffer.get(), 1); - detail::Roaring64Map bitmap_read = detail::Roaring64Map::read(buffer.get()); + detail::Roaring64Map bitmap_read = detail::Roaring64Map::readSafe(buffer.get(), bytes); EXPECT_EQ(bitmap_read, roaring64_map); @@ -195,7 +198,7 @@ TEST(BitmapValueTest, Roaring64Map_write_read) { buffer.reset(new char[bytes]); roaring64_map.write(buffer.get(), 2); - bitmap_read = detail::Roaring64Map::read(buffer.get()); + bitmap_read = detail::Roaring64Map::readSafe(buffer.get(), bytes); EXPECT_EQ(bitmap_read, roaring64_map); @@ -207,7 +210,7 @@ TEST(BitmapValueTest, Roaring64Map_write_read) { roaring64_map.write(buffer.get(), 1); - bitmap_read = detail::Roaring64Map::read(buffer.get()); + bitmap_read = detail::Roaring64Map::readSafe(buffer.get(), bytes); EXPECT_EQ(bitmap_read, roaring64_map); @@ -215,7 +218,7 @@ TEST(BitmapValueTest, Roaring64Map_write_read) { buffer.reset(new char[bytes]); roaring64_map.write(buffer.get(), 2); - bitmap_read = detail::Roaring64Map::read(buffer.get()); + bitmap_read = detail::Roaring64Map::readSafe(buffer.get(), bytes); EXPECT_EQ(bitmap_read, roaring64_map); } @@ -529,7 +532,7 @@ TEST(BitmapValueTest, write_read) { std::unique_ptr buffer(new char[size]); bitmap_empty.write_to(buffer.get()); - BitmapValue deserialized(buffer.get()); + BitmapValue deserialized(buffer.get(), size); check_bitmap_equal(deserialized, bitmap_empty); @@ -538,7 +541,7 @@ TEST(BitmapValueTest, write_read) { 
bitmap_single.write_to(buffer.get()); deserialized.reset(); - deserialized.deserialize(buffer.get()); + deserialized.deserialize(buffer.get(), size); check_bitmap_equal(deserialized, bitmap_single); @@ -547,7 +550,7 @@ TEST(BitmapValueTest, write_read) { bitmap_set.write_to(buffer.get()); deserialized.reset(); - deserialized.deserialize(buffer.get()); + deserialized.deserialize(buffer.get(), size); check_bitmap_equal(deserialized, bitmap_set); @@ -556,7 +559,7 @@ TEST(BitmapValueTest, write_read) { bitmap.write_to(buffer.get()); deserialized.reset(); - deserialized.deserialize(buffer.get()); + deserialized.deserialize(buffer.get(), size); check_bitmap_equal(deserialized, bitmap); @@ -973,7 +976,7 @@ TEST(BitmapValueTest, bitmap_serde) { std::string expect_buffer(1, BitmapTypeCode::EMPTY); EXPECT_EQ(expect_buffer, buffer); - BitmapValue out(buffer.data()); + BitmapValue out(buffer.data(), buffer.size()); EXPECT_EQ(0, out.cardinality()); } { // SINGLE32 @@ -984,15 +987,15 @@ TEST(BitmapValueTest, bitmap_serde) { put_fixed32_le(&expect_buffer, i); EXPECT_EQ(expect_buffer, buffer); - BitmapValue out(buffer.data()); + BitmapValue out(buffer.data(), buffer.size()); EXPECT_EQ(1, out.cardinality()); EXPECT_TRUE(out.contains(i)); } { // BITMAP32 - BitmapValue bitmap32({0, UINT32_MAX}); + BitmapValue bitmap32(std::vector {0, UINT32_MAX}); std::string buffer = convert_bitmap_to_string(bitmap32); - BitmapValue out(buffer.data()); + BitmapValue out(buffer.data(), buffer.size()); EXPECT_EQ(2, out.cardinality()); EXPECT_TRUE(out.contains(0)); EXPECT_TRUE(out.contains(UINT32_MAX)); @@ -1005,15 +1008,15 @@ TEST(BitmapValueTest, bitmap_serde) { put_fixed64_le(&expect_buffer, i); EXPECT_EQ(expect_buffer, buffer); - BitmapValue out(buffer.data()); + BitmapValue out(buffer.data(), buffer.size()); EXPECT_EQ(1, out.cardinality()); EXPECT_TRUE(out.contains(i)); } { // BITMAP64 - BitmapValue bitmap64({0, static_cast(UINT32_MAX) + 1}); + BitmapValue bitmap64(std::vector {0, 
static_cast(UINT32_MAX) + 1}); std::string buffer = convert_bitmap_to_string(bitmap64); - BitmapValue out(buffer.data()); + BitmapValue out(buffer.data(), buffer.size()); EXPECT_EQ(2, out.cardinality()); EXPECT_TRUE(out.contains(0)); EXPECT_TRUE(out.contains(static_cast(UINT32_MAX) + 1)); @@ -1078,7 +1081,7 @@ TEST(BitmapValueTest, Roaring64Map) { uint32_t expectedsize = r1.getSizeInBytes(1); char* serializedbytes = new char[expectedsize]; r1.write(serializedbytes, 1); - Roaring64Map t = Roaring64Map::read(serializedbytes); + Roaring64Map t = Roaring64Map::readSafe(serializedbytes, expectedsize); EXPECT_TRUE(r1 == t); delete[] serializedbytes; @@ -1192,6 +1195,97 @@ TEST(BitmapValueTest, bitmap_value_iterator_test) { TEST(BitmapValueTest, invalid_data) { BitmapValue bitmap; char data[] = {0x02, static_cast(0xff), 0x03}; - EXPECT_FALSE(bitmap.deserialize(data)); + EXPECT_FALSE(bitmap.deserialize(data, sizeof(data))); +} + +// Reproduces a heap out-of-bounds read in the legacy BitmapValue::deserialize: +// +// - The type byte is BITMAP64 (v1), so Roaring64Map::read is invoked. +// - The next bytes are a varint encoding a huge map_size (UINT32_MAX). +// - There are NO further bytes in the buffer, so the unbounded loop in +// Roaring64Map::read would dereference far past the buffer end. +// - Roaring::read is unsafe (no maxbytes) and the try/catch in deserialize +// only catches std::runtime_error — an over-read does not necessarily +// throw. +// +// The new bounded deserialize(src, maxbytes) must safely reject this without +// reading past the end of the provided buffer. +TEST(BitmapValueTest, deserialize_malicious_bitmap64_map_size) { + // Build payload: [BITMAP64][varint(UINT32_MAX)] then nothing. 
+ std::string payload; + payload.push_back(static_cast(BitmapTypeCode::BITMAP64)); + uint8_t varint_buf[10]; + uint8_t* end = encode_varint64(varint_buf, std::numeric_limits::max()); + payload.append(reinterpret_cast(varint_buf), reinterpret_cast(end)); + + // Place the payload at the very end of a heap allocation so any + // out-of-bounds read is observable (under ASAN, this would crash). + std::vector heap_buf(payload.size()); + std::memcpy(heap_buf.data(), payload.data(), payload.size()); + + BitmapValue bitmap; + EXPECT_FALSE(bitmap.deserialize(heap_buf.data(), heap_buf.size())); +} + +// Same shape but with the non-portable (native CRoaring) V2 encoding (BITMAP64_V2). +TEST(BitmapValueTest, deserialize_malicious_bitmap64v2_map_size) { + std::string payload; + payload.push_back(static_cast(BitmapTypeCode::BITMAP64_V2)); + uint8_t varint_buf[10]; + uint8_t* end = encode_varint64(varint_buf, 1'000'000ULL); + payload.append(reinterpret_cast(varint_buf), reinterpret_cast(end)); + + std::vector heap_buf(payload.size()); + std::memcpy(heap_buf.data(), payload.data(), payload.size()); + + BitmapValue bitmap; + EXPECT_FALSE(bitmap.deserialize(heap_buf.data(), heap_buf.size())); +} + +// A SINGLE32/SINGLE64 payload whose value bytes are truncated must be rejected rather than read past the end of the buffer. +TEST(BitmapValueTest, deserialize_truncated_single) { + { + char data[] = {static_cast(BitmapTypeCode::SINGLE32), 0x01}; + BitmapValue bitmap; + EXPECT_FALSE(bitmap.deserialize(data, sizeof(data))); + } + { + char data[] = {static_cast(BitmapTypeCode::SINGLE64), 0x01, 0x02}; + BitmapValue bitmap; + EXPECT_FALSE(bitmap.deserialize(data, sizeof(data))); + } +} + +// Truncated SET_V2 with a huge claimed element count must be rejected. 
+TEST(BitmapValueTest, deserialize_malicious_set_v2) { + std::string payload; + payload.push_back(static_cast(BitmapTypeCode::SET_V2)); + uint32_t fake_size = 1'000'000; + payload.append(reinterpret_cast(&fake_size), sizeof(fake_size)); + + std::vector heap_buf(payload.size()); + std::memcpy(heap_buf.data(), payload.data(), payload.size()); + + BitmapValue bitmap; + EXPECT_FALSE(bitmap.deserialize(heap_buf.data(), heap_buf.size())); +} + +// Round-trip: serialize a real bitmap and verify the safe deserialize accepts +// it and reads exactly the right number of bytes. +TEST(BitmapValueTest, deserialize_bounded_roundtrip) { + BitmapValue original; + for (uint64_t v : {1ULL, 100ULL, 1ULL << 40, (1ULL << 40) + 7}) { + original.add(v); + } + std::string buf(original.getSizeInBytes(), '\0'); + original.write_to(buf.data()); + + BitmapValue restored; + EXPECT_TRUE(restored.deserialize(buf.data(), buf.size())); + EXPECT_EQ(restored.cardinality(), original.cardinality()); + + // Even a 1-byte short buffer must be rejected without UB. 
+ BitmapValue short_target; + EXPECT_FALSE(short_target.deserialize(buf.data(), buf.size() - 1)); } } // namespace doris diff --git a/regression-test/suites/query_p0/sql_functions/bitmap_functions/test_bitmap_function.groovy b/regression-test/suites/query_p0/sql_functions/bitmap_functions/test_bitmap_function.groovy index e09f6e02506cbc..3447e28cb49b12 100644 --- a/regression-test/suites/query_p0/sql_functions/bitmap_functions/test_bitmap_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/bitmap_functions/test_bitmap_function.groovy @@ -905,4 +905,63 @@ suite("test_bitmap_function") { sql """ SELECT bitmap_from_base64('CQoL') AS result; """ exception "bitmap_from_base64 decode failed" } + + /* + ┌────────────┬─────────────────────┬─────────────────────────────────────────────────────────────────────────────────────────────────────┐ + │ Base64 │ Payload │ Attack │ + ├────────────┼─────────────────────┼─────────────────────────────────────────────────────────────────────────────────────────────────────┤ + │ BP////8P │ 04 FF FF FF FF 0F │ BITMAP64 with varint map_size = UINT32_MAX, no body — without fix loops huge map + heap over-read │ + ├────────────┼─────────────────────┼─────────────────────────────────────────────────────────────────────────────────────────────────────┤ + │ Df////8P │ 0D FF FF FF FF 0F │ BITMAP64_V2 (non-portable) variant of above │ + ├────────────┼─────────────────────┼─────────────────────────────────────────────────────────────────────────────────────────────────────┤ + │ Cv////8= │ 0A FF FF FF FF │ SET_V2 with size = UINT32_MAX → attempts ~32 GiB alloc + over-read │ + └────────────┴─────────────────────┴─────────────────────────────────────────────────────────────────────────────────────────────────────┘ + */ + // Malicious BITMAP64 (type=0x04) with varint map_size = UINT32_MAX and no payload. 
+ // Without bounds checking, deserialize would loop UINT32_MAX times reading past + // the buffer and trigger a huge allocation / heap over-read. + test { + sql """ SELECT bitmap_from_base64('BP////8P') AS result; """ + exception "bitmap_from_base64 decode failed" + } + + // Malicious BITMAP64_V2 (non-portable, type=0x0D) with varint map_size = UINT32_MAX. + test { + sql """ SELECT bitmap_from_base64('Df////8P') AS result; """ + exception "bitmap_from_base64 decode failed" + } + + // Malicious SET_V2 (type=0x0A) with uint32 size = UINT32_MAX and no elements. + // Without bounds checking this would attempt to allocate 32 GiB. + test { + sql """ SELECT bitmap_from_base64('Cv////8=') AS result; """ + exception "bitmap_from_base64 decode failed" + } + + // Invalid bitmap type codes: should hit the !is_bitmap32 && !is_bitmap64 + // branch in Roaring64Map::readSafe / deserialize and return false instead + // of dereferencing arbitrary memory. + /* +┌────────────┬───────────────────┬─────────────────────────────────────────────────────────┐ +│ Base64 │ First byte (type) │ Description │ +├────────────┼───────────────────┼─────────────────────────────────────────────────────────┤ +│ /wAAAA== │ 0xFF │ Completely invalid type code │ +├────────────┼───────────────────┼─────────────────────────────────────────────────────────┤ +│ BgAAAA== │ 0x06 │ Falls in the gap between SET(5) and SET_V2(10) │ +├────────────┼───────────────────┼─────────────────────────────────────────────────────────┤ +│ CwAAAA== │ 0x0B │ Falls in the gap between SET_V2(10) and BITMAP32_V2(12) │ +└────────────┴───────────────────┴─────────────────────────────────────────────────────────┘ + */ + test { + sql """ SELECT bitmap_from_base64('/wAAAA==') AS result; """ + exception "bitmap_from_base64 decode failed" + } + test { + sql """ SELECT bitmap_from_base64('BgAAAA==') AS result; """ + exception "bitmap_from_base64 decode failed" + } + test { + sql """ SELECT bitmap_from_base64('CwAAAA==') AS result; """ + exception "bitmap_from_base64 decode failed" + } }