Skip to content

Commit

Permalink
Merge pull request #51560 from ClickHouse/backport/23.3/50551
Browse files Browse the repository at this point in the history
Backport #50551 to 23.3: Fix backward compatibility for IP types hashing in aggregate functions
  • Loading branch information
al13n321 committed Jun 30, 2023
2 parents 1dd052c + 1585803 commit 04e2ce0
Show file tree
Hide file tree
Showing 18 changed files with 584 additions and 27 deletions.
28 changes: 20 additions & 8 deletions base/base/IPv4andIPv6.h
Expand Up @@ -2,21 +2,23 @@

#include <base/strong_typedef.h>
#include <base/extended_types.h>
#include <Common/formatIPv6.h>
#include <Common/memcmpSmall.h>

namespace DB
{

using IPv4 = StrongTypedef<UInt32, struct IPv4Tag>;
struct IPv4 : StrongTypedef<UInt32, struct IPv4Tag>
{
using StrongTypedef::StrongTypedef;
using StrongTypedef::operator=;
constexpr explicit IPv4(UInt64 value): StrongTypedef(static_cast<UnderlyingType>(value)) {}
};

struct IPv6 : StrongTypedef<UInt128, struct IPv6Tag>
{
constexpr IPv6() = default;
constexpr explicit IPv6(const UInt128 & x) : StrongTypedef(x) {}
constexpr explicit IPv6(UInt128 && x) : StrongTypedef(std::move(x)) {}

IPv6 & operator=(const UInt128 & rhs) { StrongTypedef::operator=(rhs); return *this; }
IPv6 & operator=(UInt128 && rhs) { StrongTypedef::operator=(std::move(rhs)); return *this; }
using StrongTypedef::StrongTypedef;
using StrongTypedef::operator=;

bool operator<(const IPv6 & rhs) const
{
Expand Down Expand Up @@ -54,12 +56,22 @@ namespace DB

namespace std
{
/// For historical reasons we hash IPv6 as a FixedString(16)
template <>
struct hash<DB::IPv6>
{
size_t operator()(const DB::IPv6 & x) const
{
return std::hash<DB::IPv6::UnderlyingType>()(x.toUnderType());
return std::hash<std::string_view>{}(std::string_view(reinterpret_cast<const char*>(&x.toUnderType()), IPV6_BINARY_LENGTH));
}
};

template <>
struct hash<DB::IPv4>
{
size_t operator()(const DB::IPv4 & x) const
{
return std::hash<DB::IPv4::UnderlyingType>()(x.toUnderType());
}
};
}
4 changes: 4 additions & 0 deletions docs/en/sql-reference/aggregate-functions/combinators.md
Expand Up @@ -70,6 +70,10 @@ Result:

If you apply this combinator, the aggregate function does not return the resulting value (such as the number of unique values for the [uniq](../../sql-reference/aggregate-functions/reference/uniq.md#agg_function-uniq) function), but an intermediate state of the aggregation (for `uniq`, this is the hash table for calculating the number of unique values). This is an `AggregateFunction(...)` that can be used for further processing or stored in a table to finish aggregating later.

:::note
Please notice, that -MapState is not an invariant for the same data due to the fact that order of data in intermediate state can change, though it doesn't impact ingestion of this data.
:::

To work with these states, use:

- [AggregatingMergeTree](../../engines/table-engines/mergetree-family/aggregatingmergetree.md) table engine.
Expand Down
4 changes: 4 additions & 0 deletions docs/ru/sql-reference/aggregate-functions/combinators.md
Expand Up @@ -66,6 +66,10 @@ WITH anySimpleState(number) AS c SELECT toTypeName(c), c FROM numbers(1);

В случае применения этого комбинатора, агрегатная функция возвращает не готовое значение (например, в случае функции [uniq](reference/uniq.md#agg_function-uniq) — количество уникальных значений), а промежуточное состояние агрегации (например, в случае функции `uniq` — хэш-таблицу для расчёта количества уникальных значений), которое имеет тип `AggregateFunction(...)` и может использоваться для дальнейшей обработки или может быть сохранено в таблицу для последующей доагрегации.

:::note
Промежуточное состояние для -MapState не является инвариантом для одних и тех же исходных данные т.к. порядок данных может меняться. Это не влияет, тем не менее, на загрузку таких данных.
:::

Для работы с промежуточными состояниями предназначены:

- Движок таблиц [AggregatingMergeTree](../../engines/table-engines/mergetree-family/aggregatingmergetree.md).
Expand Down
1 change: 1 addition & 0 deletions src/AggregateFunctions/AggregateFunctionGroupArray.cpp
Expand Up @@ -25,6 +25,7 @@ IAggregateFunction * createWithNumericOrTimeType(const IDataType & argument_type
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date) return new AggregateFunctionTemplate<UInt16, Data>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::DateTime) return new AggregateFunctionTemplate<UInt32, Data>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::IPv4) return new AggregateFunctionTemplate<IPv4, Data>(std::forward<TArgs>(args)...);
return createWithNumericType<AggregateFunctionTemplate, Data, TArgs...>(argument_type, std::forward<TArgs>(args)...);
}

Expand Down
11 changes: 11 additions & 0 deletions src/AggregateFunctions/AggregateFunctionGroupUniqArray.cpp
Expand Up @@ -4,6 +4,7 @@
#include <AggregateFunctions/FactoryHelpers.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeIPv4andIPv6.h>


namespace DB
Expand Down Expand Up @@ -39,12 +40,22 @@ class AggregateFunctionGroupUniqArrayDateTime : public AggregateFunctionGroupUni
static DataTypePtr createResultType() { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
};

template <typename HasLimit>
class AggregateFunctionGroupUniqArrayIPv4 : public AggregateFunctionGroupUniqArray<DataTypeIPv4::FieldType, HasLimit>
{
public:
explicit AggregateFunctionGroupUniqArrayIPv4(const DataTypePtr & argument_type, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: AggregateFunctionGroupUniqArray<DataTypeIPv4::FieldType, HasLimit>(argument_type, parameters_, createResultType(), max_elems_) {}
static DataTypePtr createResultType() { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeIPv4>()); }
};

template <typename HasLimit, typename ... TArgs>
IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date) return new AggregateFunctionGroupUniqArrayDate<HasLimit>(argument_type, std::forward<TArgs>(args)...);
else if (which.idx == TypeIndex::DateTime) return new AggregateFunctionGroupUniqArrayDateTime<HasLimit>(argument_type, std::forward<TArgs>(args)...);
else if (which.idx == TypeIndex::IPv4) return new AggregateFunctionGroupUniqArrayIPv4<HasLimit>(argument_type, std::forward<TArgs>(args)...);
else
{
/// Check that we can use plain version of AggregateFunctionGroupUniqArrayGeneric
Expand Down
4 changes: 4 additions & 0 deletions src/AggregateFunctions/AggregateFunctionMap.cpp
Expand Up @@ -100,6 +100,10 @@ class AggregateFunctionCombinatorMap final : public IAggregateFunctionCombinator
return std::make_shared<AggregateFunctionMap<UInt256>>(nested_function, arguments);
case TypeIndex::UUID:
return std::make_shared<AggregateFunctionMap<UUID>>(nested_function, arguments);
case TypeIndex::IPv4:
return std::make_shared<AggregateFunctionMap<IPv4>>(nested_function, arguments);
case TypeIndex::IPv6:
return std::make_shared<AggregateFunctionMap<IPv6>>(nested_function, arguments);
case TypeIndex::FixedString:
case TypeIndex::String:
return std::make_shared<AggregateFunctionMap<String>>(nested_function, arguments);
Expand Down
29 changes: 29 additions & 0 deletions src/AggregateFunctions/AggregateFunctionMap.h
Expand Up @@ -19,7 +19,9 @@
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include "DataTypes/Serializations/ISerialization.h"
#include <base/IPv4andIPv6.h>
#include "base/types.h"
#include <Common/formatIPv6.h>
#include <Common/Arena.h>
#include "AggregateFunctions/AggregateFunctionFactory.h"

Expand Down Expand Up @@ -69,6 +71,31 @@ struct AggregateFunctionMapCombinatorData<String>
}
};

/// Specialization for IPv6 - for historical reasons it should be stored as FixedString(16)
template <>
struct AggregateFunctionMapCombinatorData<IPv6>
{
struct IPv6Hash
{
using hash_type = std::hash<IPv6>;
using is_transparent = void;

size_t operator()(const IPv6 & ip) const { return hash_type{}(ip); }
};

using SearchType = IPv6;
std::unordered_map<IPv6, AggregateDataPtr, IPv6Hash, std::equal_to<>> merged_maps;

static void writeKey(const IPv6 & key, WriteBuffer & buf)
{
writeIPv6Binary(key, buf);
}
static void readKey(IPv6 & key, ReadBuffer & buf)
{
readIPv6Binary(key, buf);
}
};

template <typename KeyType>
class AggregateFunctionMap final
: public IAggregateFunctionDataHelper<AggregateFunctionMapCombinatorData<KeyType>, AggregateFunctionMap<KeyType>>
Expand Down Expand Up @@ -147,6 +174,8 @@ class AggregateFunctionMap final
StringRef key_ref;
if (key_type->getTypeId() == TypeIndex::FixedString)
key_ref = assert_cast<const ColumnFixedString &>(key_column).getDataAt(offset + i);
else if (key_type->getTypeId() == TypeIndex::IPv6)
key_ref = assert_cast<const ColumnIPv6 &>(key_column).getDataAt(offset + i);
else
key_ref = assert_cast<const ColumnString &>(key_column).getDataAt(offset + i);

Expand Down
19 changes: 19 additions & 0 deletions src/AggregateFunctions/AggregateFunctionTopK.cpp
Expand Up @@ -5,6 +5,7 @@
#include <Common/FieldVisitorConvertToNumber.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeIPv4andIPv6.h>


static inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;
Expand Down Expand Up @@ -60,6 +61,22 @@ class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateT
{}
};

template <bool is_weighted>
class AggregateFunctionTopKIPv4 : public AggregateFunctionTopK<DataTypeIPv4::FieldType, is_weighted>
{
public:
using AggregateFunctionTopK<DataTypeIPv4::FieldType, is_weighted>::AggregateFunctionTopK;

AggregateFunctionTopKIPv4(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params)
: AggregateFunctionTopK<DataTypeIPv4::FieldType, is_weighted>(
threshold_,
load_factor,
argument_types_,
params,
std::make_shared<DataTypeArray>(std::make_shared<DataTypeIPv4>()))
{}
};


template <bool is_weighted>
IAggregateFunction * createWithExtraTypes(const DataTypes & argument_types, UInt64 threshold, UInt64 load_factor, const Array & params)
Expand All @@ -72,6 +89,8 @@ IAggregateFunction * createWithExtraTypes(const DataTypes & argument_types, UInt
return new AggregateFunctionTopKDate<is_weighted>(threshold, load_factor, argument_types, params);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionTopKDateTime<is_weighted>(threshold, load_factor, argument_types, params);
if (which.idx == TypeIndex::IPv4)
return new AggregateFunctionTopKIPv4<is_weighted>(threshold, load_factor, argument_types, params);

/// Check that we can use plain version of AggregateFunctionTopKGeneric
if (argument_types[0]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
Expand Down
9 changes: 9 additions & 0 deletions src/AggregateFunctions/AggregateFunctionUniq.cpp
Expand Up @@ -8,6 +8,7 @@
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeUUID.h>
#include <DataTypes/DataTypeIPv4andIPv6.h>

#include <Core/Settings.h>

Expand Down Expand Up @@ -60,6 +61,10 @@ createAggregateFunctionUniq(const std::string & name, const DataTypes & argument
return std::make_shared<AggregateFunctionUniq<String, Data>>(argument_types);
else if (which.isUUID())
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data>>(argument_types);
else if (which.isIPv4())
return std::make_shared<AggregateFunctionUniq<DataTypeIPv4::FieldType, Data>>(argument_types);
else if (which.isIPv6())
return std::make_shared<AggregateFunctionUniq<DataTypeIPv6::FieldType, Data>>(argument_types);
else if (which.isTuple())
{
if (use_exact_hash_function)
Expand Down Expand Up @@ -109,6 +114,10 @@ createAggregateFunctionUniq(const std::string & name, const DataTypes & argument
return std::make_shared<AggregateFunctionUniq<String, Data<String, is_able_to_parallelize_merge>>>(argument_types);
else if (which.isUUID())
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data<DataTypeUUID::FieldType, is_able_to_parallelize_merge>>>(argument_types);
else if (which.isIPv4())
return std::make_shared<AggregateFunctionUniq<DataTypeIPv4::FieldType, Data<DataTypeIPv4::FieldType, is_able_to_parallelize_merge>>>(argument_types);
else if (which.isIPv6())
return std::make_shared<AggregateFunctionUniq<DataTypeIPv6::FieldType, Data<DataTypeIPv6::FieldType, is_able_to_parallelize_merge>>>(argument_types);
else if (which.isTuple())
{
if (use_exact_hash_function)
Expand Down
55 changes: 43 additions & 12 deletions src/AggregateFunctions/AggregateFunctionUniq.h
Expand Up @@ -101,6 +101,18 @@ struct AggregateFunctionUniqHLL12Data<UUID, false>
static String getName() { return "uniqHLL12"; }
};

template <>
struct AggregateFunctionUniqHLL12Data<IPv6, false>
{
using Set = HyperLogLogWithSmallSetOptimization<UInt64, 16, 12>;
Set set;

constexpr static bool is_able_to_parallelize_merge = false;
constexpr static bool is_variadic = false;

static String getName() { return "uniqHLL12"; }
};

template <bool is_exact_, bool argument_is_tuple_, bool is_able_to_parallelize_merge_>
struct AggregateFunctionUniqHLL12DataForVariadic
{
Expand Down Expand Up @@ -155,6 +167,25 @@ struct AggregateFunctionUniqExactData<String, is_able_to_parallelize_merge_>
static String getName() { return "uniqExact"; }
};

/// For historical reasons IPv6 is treated as FixedString(16)
template <bool is_able_to_parallelize_merge_>
struct AggregateFunctionUniqExactData<IPv6, is_able_to_parallelize_merge_>
{
using Key = UInt128;

/// When creating, the hash table must be small.
using SingleLevelSet = HashSet<Key, UInt128TrivialHash, HashTableGrower<3>, HashTableAllocatorWithStackMemory<sizeof(Key) * (1 << 3)>>;
using TwoLevelSet = TwoLevelHashSet<Key, UInt128TrivialHash>;
using Set = UniqExactSet<SingleLevelSet, TwoLevelSet>;

Set set;

constexpr static bool is_able_to_parallelize_merge = is_able_to_parallelize_merge_;
constexpr static bool is_variadic = false;

static String getName() { return "uniqExact"; }
};

template <bool is_exact_, bool argument_is_tuple_, bool is_able_to_parallelize_merge_>
struct AggregateFunctionUniqExactDataForVariadic : AggregateFunctionUniqExactData<String, is_able_to_parallelize_merge_>
{
Expand Down Expand Up @@ -248,27 +279,22 @@ struct Adder
AggregateFunctionUniqUniquesHashSetData> || std::is_same_v<Data, AggregateFunctionUniqHLL12Data<T, Data::is_able_to_parallelize_merge>>)
{
const auto & column = *columns[0];
if constexpr (!std::is_same_v<T, String>)
if constexpr (std::is_same_v<T, String> || std::is_same_v<T, IPv6>)
{
using ValueType = typename decltype(data.set)::value_type;
const auto & value = assert_cast<const ColumnVector<T> &>(column).getElement(row_num);
data.set.insert(static_cast<ValueType>(AggregateFunctionUniqTraits<T>::hash(value)));
StringRef value = column.getDataAt(row_num);
data.set.insert(CityHash_v1_0_2::CityHash64(value.data, value.size));
}
else
{
StringRef value = column.getDataAt(row_num);
data.set.insert(CityHash_v1_0_2::CityHash64(value.data, value.size));
using ValueType = typename decltype(data.set)::value_type;
const auto & value = assert_cast<const ColumnVector<T> &>(column).getElement(row_num);
data.set.insert(static_cast<ValueType>(AggregateFunctionUniqTraits<T>::hash(value)));
}
}
else if constexpr (std::is_same_v<Data, AggregateFunctionUniqExactData<T, Data::is_able_to_parallelize_merge>>)
{
const auto & column = *columns[0];
if constexpr (!std::is_same_v<T, String>)
{
data.set.template insert<const T &, use_single_level_hash_table>(
assert_cast<const ColumnVector<T> &>(column).getData()[row_num]);
}
else
if constexpr (std::is_same_v<T, String> || std::is_same_v<T, IPv6>)
{
StringRef value = column.getDataAt(row_num);

Expand All @@ -279,6 +305,11 @@ struct Adder

data.set.template insert<const UInt128 &, use_single_level_hash_table>(key);
}
else
{
data.set.template insert<const T &, use_single_level_hash_table>(
assert_cast<const ColumnVector<T> &>(column).getData()[row_num]);
}
}
#if USE_DATASKETCHES
else if constexpr (std::is_same_v<Data, AggregateFunctionUniqThetaData>)
Expand Down
5 changes: 5 additions & 0 deletions src/AggregateFunctions/AggregateFunctionUniqCombined.cpp
Expand Up @@ -8,6 +8,7 @@
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeIPv4andIPv6.h>

#include <functional>

Expand Down Expand Up @@ -60,6 +61,10 @@ namespace
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<String>>(argument_types, params);
else if (which.isUUID())
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeUUID::FieldType>>(argument_types, params);
else if (which.isIPv4())
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeIPv4::FieldType>>(argument_types, params);
else if (which.isIPv6())
return std::make_shared<typename WithK<K, HashValueType>::template AggregateFunction<DataTypeIPv6::FieldType>>(argument_types, params);
else if (which.isTuple())
{
if (use_exact_hash_function)
Expand Down
14 changes: 9 additions & 5 deletions src/AggregateFunctions/AggregateFunctionUniqCombined.h
Expand Up @@ -119,6 +119,10 @@ struct AggregateFunctionUniqCombinedData<String, K, HashValueType> : public Aggr
{
};

template <UInt8 K, typename HashValueType>
struct AggregateFunctionUniqCombinedData<IPv6, K, HashValueType> : public AggregateFunctionUniqCombinedDataWithKey<UInt64 /*always*/, K>
{
};

template <typename T, UInt8 K, typename HashValueType>
class AggregateFunctionUniqCombined final
Expand All @@ -141,15 +145,15 @@ class AggregateFunctionUniqCombined final

void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (!std::is_same_v<T, String>)
if constexpr (std::is_same_v<T, String> || std::is_same_v<T, IPv6>)
{
const auto & value = assert_cast<const ColumnVector<T> &>(*columns[0]).getElement(row_num);
this->data(place).set.insert(detail::AggregateFunctionUniqCombinedTraits<T, HashValueType>::hash(value));
StringRef value = columns[0]->getDataAt(row_num);
this->data(place).set.insert(CityHash_v1_0_2::CityHash64(value.data, value.size));
}
else
{
StringRef value = columns[0]->getDataAt(row_num);
this->data(place).set.insert(CityHash_v1_0_2::CityHash64(value.data, value.size));
const auto & value = assert_cast<const ColumnVector<T> &>(*columns[0]).getElement(row_num);
this->data(place).set.insert(detail::AggregateFunctionUniqCombinedTraits<T, HashValueType>::hash(value));
}
}

Expand Down

0 comments on commit 04e2ce0

Please sign in to comment.