Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add server setting aggregate_function_group_array_max_element_size. #53550

Merged
merged 4 commits into from Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/AggregateFunctions/AggregateFunctionGroupArray.cpp
Expand Up @@ -4,6 +4,8 @@
#include <AggregateFunctions/FactoryHelpers.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <Interpreters/Context.h>
#include <Core/ServerSettings.h>


namespace DB
Expand Down Expand Up @@ -43,6 +45,13 @@ inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataType
return std::make_shared<GroupArrayGeneralImpl<GroupArrayNodeGeneral, Trait>>(argument_type, parameters, std::forward<TArgs>(args)...);
}

/// Returns the limit passed to groupArray as its maximum size.
/// The value is the `aggregate_function_group_array_max_element_size` server setting read from
/// the global context instance; if no global context instance is available, 0xFFFFFF is returned.
static size_t getMaxArraySize()
{
    const auto global_context = Context::getGlobalContextInstance();

    /// No global context to read the setting from: use the hard-coded fallback.
    if (!global_context)
        return 0xFFFFFF;

    return global_context->getServerSettings().aggregate_function_group_array_max_element_size;
}

template <bool Tlast>
AggregateFunctionPtr createAggregateFunctionGroupArray(
Expand All @@ -51,7 +60,7 @@ AggregateFunctionPtr createAggregateFunctionGroupArray(
assertUnary(name, argument_types);

bool limit_size = false;
UInt64 max_elems = std::numeric_limits<UInt64>::max();
UInt64 max_elems = getMaxArraySize();

if (parameters.empty())
{
Expand All @@ -78,7 +87,7 @@ AggregateFunctionPtr createAggregateFunctionGroupArray(
{
if (Tlast)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "groupArrayLast make sense only with max_elems (groupArrayLast(max_elems)())");
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait</* Thas_limit= */ false, Tlast, /* Tsampler= */ Sampler::NONE>>(argument_types[0], parameters);
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait</* Thas_limit= */ false, Tlast, /* Tsampler= */ Sampler::NONE>>(argument_types[0], parameters, max_elems);
}
else
return createAggregateFunctionGroupArrayImpl<GroupArrayTrait</* Thas_limit= */ true, Tlast, /* Tsampler= */ Sampler::NONE>>(argument_types[0], parameters, max_elems);
Expand Down
56 changes: 34 additions & 22 deletions src/AggregateFunctions/AggregateFunctionGroupArray.h
Expand Up @@ -21,7 +21,7 @@

#include <type_traits>

#define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE 0xFFFFFF
#define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE 0xFFFFFF


namespace DB
Expand Down Expand Up @@ -128,7 +128,7 @@ class GroupArrayNumericImpl final

public:
explicit GroupArrayNumericImpl(
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_, UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayNumericData<T, Trait::sampler != Sampler::NONE>, GroupArrayNumericImpl<T, Trait>>(
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, max_elems(max_elems_)
Expand Down Expand Up @@ -263,10 +263,18 @@ class GroupArrayNumericImpl final
}
}

/// Validates an element count against the allowed maximum.
/// Throws Exception with code TOO_LARGE_ARRAY_SIZE (reporting both values) when elems > max_elems;
/// otherwise does nothing.
static void checkArraySize(size_t elems, size_t max_elems)
{
    const bool over_limit = elems > max_elems;
    if (unlikely(over_limit))
    {
        throw Exception(
            ErrorCodes::TOO_LARGE_ARRAY_SIZE,
            "Too large array size {} (maximum: {})",
            elems,
            max_elems);
    }
}

void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
const auto & value = this->data(place).value;
const size_t size = value.size();
const UInt64 size = value.size();
checkArraySize(size, max_elems);
writeVarUInt(size, buf);
for (const auto & element : value)
writeBinaryLittleEndian(element, buf);
Expand All @@ -287,13 +295,7 @@ class GroupArrayNumericImpl final
{
size_t size = 0;
readVarUInt(size, buf);

if (unlikely(size > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
"Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE);

if (limit_num_elems && unlikely(size > max_elems))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size, it should not exceed {}", max_elems);
checkArraySize(size, max_elems);

auto & value = this->data(place).value;

Expand Down Expand Up @@ -357,9 +359,17 @@ struct GroupArrayNodeBase
const_cast<char *>(arena->alignedInsert(reinterpret_cast<const char *>(this), sizeof(Node) + size, alignof(Node))));
}

/// Validates a single element's size against the allowed maximum.
/// Throws Exception with code TOO_LARGE_ARRAY_SIZE (reporting both values) when size > max_size;
/// otherwise does nothing.
static void checkElementSize(size_t size, size_t max_size)
{
    const bool over_limit = size > max_size;
    if (unlikely(over_limit))
    {
        throw Exception(
            ErrorCodes::TOO_LARGE_ARRAY_SIZE,
            "Too large array element size {} (maximum: {})",
            size,
            max_size);
    }
}

/// Write node to buffer.
/// Output order: `size` as a VarUInt, followed by `size` bytes taken from data().
/// The size is validated first, so checkElementSize may throw before anything is written
/// if `size` exceeds AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE.
void write(WriteBuffer & buf) const
{
/// Validate before emitting any bytes.
checkElementSize(size, AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE);
/// Length prefix, then the raw bytes of the node.
writeVarUInt(size, buf);
buf.write(data(), size);
}
Expand All @@ -369,9 +379,7 @@ struct GroupArrayNodeBase
{
UInt64 size;
readVarUInt(size, buf);
if (unlikely(size > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
"Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE);
checkElementSize(size, AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE);

Node * node = reinterpret_cast<Node *>(arena->alignedAlloc(sizeof(Node) + size, alignof(Node)));
node->size = size;
Expand Down Expand Up @@ -455,7 +463,7 @@ class GroupArrayGeneralImpl final
UInt64 seed;

public:
GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max(), UInt64 seed_ = 123456)
GroupArrayGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_, UInt64 seed_ = 123456)
: IAggregateFunctionDataHelper<GroupArrayGeneralData<Node, Trait::sampler != Sampler::NONE>, GroupArrayGeneralImpl<Node, Trait>>(
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, data_type(this->argument_types[0])
Expand Down Expand Up @@ -596,9 +604,18 @@ class GroupArrayGeneralImpl final
}
}

/// Validates the number of array elements against the allowed maximum.
/// Throws Exception with code TOO_LARGE_ARRAY_SIZE (reporting both values) when elems > max_elems;
/// otherwise does nothing.
static void checkArraySize(size_t elems, size_t max_elems)
{
    const bool over_limit = elems > max_elems;
    if (unlikely(over_limit))
    {
        throw Exception(
            ErrorCodes::TOO_LARGE_ARRAY_SIZE,
            "Too large array size {} (maximum: {})",
            elems,
            max_elems);
    }
}

void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
writeVarUInt(data(place).value.size(), buf);
UInt64 elems = data(place).value.size();
checkArraySize(elems, max_elems);
writeVarUInt(elems, buf);

auto & value = data(place).value;
for (auto & node : value)
Expand All @@ -624,12 +641,7 @@ class GroupArrayGeneralImpl final
if (unlikely(elems == 0))
return;

if (unlikely(elems > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
"Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE);

if (limit_num_elems && unlikely(elems > max_elems))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size, it should not exceed {}", max_elems);
checkArraySize(elems, max_elems);

auto & value = data(place).value;

Expand Down Expand Up @@ -673,6 +685,6 @@ class GroupArrayGeneralImpl final
bool allocatesMemoryInArena() const override { return true; }
};

#undef AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE
#undef AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE

}
1 change: 1 addition & 0 deletions src/Core/ServerSettings.h
Expand Up @@ -44,6 +44,7 @@ namespace DB
M(String, tmp_policy, "", "Policy for storage with temporary data.", 0) \
M(UInt64, max_temporary_data_on_disk_size, 0, "The maximum amount of storage that could be used for external aggregation, joins or sorting., ", 0) \
M(String, temporary_data_in_cache, "", "Cache disk name for temporary data.", 0) \
M(UInt64, aggregate_function_group_array_max_element_size, 0xFFFFFF, "Max array element size in bytes for groupArray function. This limit is checked at serialization and help to avoid large state size.", 0) \
M(UInt64, max_server_memory_usage, 0, "Maximum total memory usage of the server in bytes. Zero means unlimited.", 0) \
M(Double, max_server_memory_usage_to_ram_ratio, 0.9, "Same as max_server_memory_usage but in to RAM ratio. Allows to lower max memory on low-memory systems.", 0) \
M(UInt64, merges_mutations_memory_usage_soft_limit, 0, "Maximum total memory usage for merges and mutations in bytes. Zero means unlimited.", 0) \
Expand Down
5 changes: 5 additions & 0 deletions src/Interpreters/Context.cpp
Expand Up @@ -4595,4 +4595,9 @@ void Context::setClientProtocolVersion(UInt64 version)
client_protocol_version = version;
}

/// Returns a const reference to the server settings held in the shared context state
/// (`shared->server_settings`).
const ServerSettings & Context::getServerSettings() const
{
    const ServerSettings & settings = shared->server_settings;
    return settings;
}

}
4 changes: 4 additions & 0 deletions src/Interpreters/Context.h
Expand Up @@ -206,6 +206,8 @@ using PreparedSetsCachePtr = std::shared_ptr<PreparedSetsCache>;

class SessionTracker;

struct ServerSettings;

/// An empty interface for an arbitrary object that may be attached by a shared pointer
/// to query context, when using ClickHouse as a library.
struct IHostContext
Expand Down Expand Up @@ -1192,6 +1194,8 @@ class Context: public std::enable_shared_from_this<Context>
void setPreparedSetsCache(const PreparedSetsCachePtr & cache);
PreparedSetsCachePtr getPreparedSetsCache() const;

const ServerSettings & getServerSettings() const;

private:
std::unique_lock<std::recursive_mutex> getLock() const;

Expand Down
Empty file.
@@ -0,0 +1,3 @@
<clickhouse>
<aggregate_function_group_array_max_element_size>10</aggregate_function_group_array_max_element_size>
</clickhouse>
65 changes: 65 additions & 0 deletions tests/integration/test_group_array_element_size/test.py
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
import pytest
from helpers.cluster import ClickHouseCluster

# Cluster used by this test module; it contains a single instance, "node1".
cluster = ClickHouseCluster(__file__)
# node1 is started with configs/group_array_max_element_size.xml as a main config;
# the test below edits that file inside the instance and restarts the server.
# NOTE(review): stay_alive=True is presumably what allows restart_clickhouse() to be
# used on this instance -- confirm against helpers.cluster.
node1 = cluster.add_instance(
"node1",
main_configs=["configs/group_array_max_element_size.xml"],
stay_alive=True,
)


@pytest.fixture(scope="module")
def started_cluster():
    """Module-scoped fixture: start the cluster, yield it, and always shut it down afterwards."""
    try:
        cluster.start()
        yield cluster
    finally:
        # Runs even if start-up fails or a test raises.
        cluster.shutdown()


def test_max_exement_size(started_cluster):
    """Check that the groupArray size limit from the server config is enforced.

    The scenario toggles the configured value between "10" and "11" (restarting the
    server each time) and verifies that:
      * a state built from 10 rows is accepted while building one from 11 rows fails;
      * after raising the value to 11 the 11-row state can be inserted and merged;
      * after lowering it back to 10 reading the stored states fails;
      * after raising it again the merged length is 21.
    """
    config_path = "/etc/clickhouse-server/config.d/group_array_max_element_size.xml"
    insert_eleven = "insert into tab3 select groupArrayState([zero]) from zeros(11)"
    merged_length = "select length(groupArrayMerge(x)) from tab3"
    too_large = r"Too large array size"

    def switch_limit(current, updated):
        # Rewrite the value in the instance's config file and restart to apply it.
        node1.replace_in_config(config_path, current, updated)
        node1.restart_clickhouse()

    node1.query(
        "CREATE TABLE tab3 (x AggregateFunction(groupArray, Array(UInt8))) ENGINE = MergeTree ORDER BY tuple()"
    )
    node1.query("insert into tab3 select groupArrayState([zero]) from zeros(10)")
    assert node1.query(merged_length) == "10\n"

    # First query should always fail
    with pytest.raises(Exception, match=too_large):
        node1.query(insert_eleven)

    # With the value raised to 11 the same insert goes through.
    switch_limit("10", "11")
    node1.query(insert_eleven)
    assert node1.query(merged_length) == "21\n"

    # With the value lowered back to 10 the stored states can no longer be read.
    switch_limit("11", "10")
    with pytest.raises(Exception, match=too_large):
        node1.query(merged_length)

    # Raising the value once more makes the data readable again.
    switch_limit("10", "11")
    assert node1.query(merged_length) == "21\n"