From e5cbce7764828f46467518a411be3a082e313acf Mon Sep 17 00:00:00 2001 From: Dmitry Novik Date: Tue, 7 Jun 2022 17:20:43 +0200 Subject: [PATCH 1/4] Merge pull request #36944 from excitoon-favorites/better_exp_smooth --- src/Processors/Transforms/WindowTransform.cpp | 327 ++++++++++-------- .../02020_exponential_smoothing.reference | 112 +++--- .../02020_exponential_smoothing.sql | 92 ++++- 3 files changed, 338 insertions(+), 193 deletions(-) diff --git a/src/Processors/Transforms/WindowTransform.cpp b/src/Processors/Transforms/WindowTransform.cpp index b81ed099915f..2a2fed1cc078 100644 --- a/src/Processors/Transforms/WindowTransform.cpp +++ b/src/Processors/Transforms/WindowTransform.cpp @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -14,6 +16,7 @@ #include #include + namespace DB { @@ -1538,65 +1541,21 @@ struct WindowFunctionDenseRank final : public WindowFunction namespace recurrent_detail { - template T getLastValueFromInputColumn(const WindowTransform * /*transform*/, size_t /*function_index*/, size_t /*column_index*/) + template T getValue(const WindowTransform * /*transform*/, size_t /*function_index*/, size_t /*column_index*/, RowNumber /*row*/) { - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "getLastValueFromInputColumn() is not implemented for {} type", typeid(T).name()); + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "recurrent_detail::getValue() is not implemented for {} type", typeid(T).name()); } - template<> Float64 getLastValueFromInputColumn(const WindowTransform * transform, size_t function_index, size_t column_index) + template<> Float64 getValue(const WindowTransform * transform, size_t function_index, size_t column_index, RowNumber row) { const auto & workspace = transform->workspaces[function_index]; - auto current_row = transform->current_row; - - if (current_row.row == 0) - { - if (current_row.block > 0) - { - const auto & column = transform->blockAt(current_row.block - 1).input_columns[workspace.argument_column_indices[column_index]]; - return column->getFloat64(column->size() - 1); - } - } - else - { - const auto & column = transform->blockAt(current_row.block).input_columns[workspace.argument_column_indices[column_index]]; - return column->getFloat64(current_row.row - 1); - } - - return 0; - } - - template T getLastValueFromState(const WindowTransform * /*transform*/, size_t /*function_index*/, size_t /*data_index*/) - { - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "getLastValueFromInputColumn() is not implemented for {} type", typeid(T).name()); - } - - template<> Float64 getLastValueFromState(const WindowTransform * transform, size_t function_index, size_t data_index) - { - const auto & workspace = transform->workspaces[function_index]; - if (workspace.aggregate_function_state.data() == nullptr) - { - return 0.0; - } - else - { - return static_cast(static_cast(workspace.aggregate_function_state.data()))[data_index]; - } - } - - template void setValueToState(const WindowTransform * /*transform*/, size_t /*function_index*/, T /*value*/, size_t /*data_index*/) - { - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "setValueToState() is not implemented for {} type", typeid(T).name()); - } - - template<> void setValueToState(const WindowTransform * transform, size_t function_index, Float64 value, size_t data_index) - { - const auto & workspace = transform->workspaces[function_index]; - static_cast(static_cast(workspace.aggregate_function_state.data()))[data_index] = value; + const auto & column = 
transform->blockAt(row.block).input_columns[workspace.argument_column_indices[column_index]]; + return column->getFloat64(row.row); } template void setValueToOutputColumn(const WindowTransform * /*transform*/, size_t /*function_index*/, T /*value*/) { - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "setValueToOutputColumn() is not implemented for {} type", typeid(T).name()); + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "recurrent_detail::setValueToOutputColumn() is not implemented for {} type", typeid(T).name()); } template<> void setValueToOutputColumn(const WindowTransform * transform, size_t function_index, Float64 value) @@ -1607,82 +1566,77 @@ namespace recurrent_detail assert_cast(to).getData().push_back(value); } +} - template T getCurrentValueFromInputColumn(const WindowTransform * /*transform*/, size_t /*function_index*/, size_t /*column_index*/) +struct WindowFunctionHelpers +{ + template + static T getValue(const WindowTransform * transform, size_t function_index, size_t column_index, RowNumber row) { - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "getCurrentValueFromInputColumn() is not implemented for {} type", typeid(T).name()); + return recurrent_detail::getValue(transform, function_index, column_index, row); } - template<> Float64 getCurrentValueFromInputColumn(const WindowTransform * transform, size_t function_index, size_t column_index) + template + static void setValueToOutputColumn(const WindowTransform * transform, size_t function_index, T value) { - const auto & workspace = transform->workspaces[function_index]; - auto current_row = transform->current_row; - const auto & current_block = transform->blockAt(current_row); - - return (*current_block.input_columns[workspace.argument_column_indices[column_index]]).getFloat64(transform->current_row.row); + recurrent_detail::setValueToOutputColumn(transform, function_index, value); } -} +}; -template -struct RecurrentWindowFunction : public WindowFunction +template +struct StatefulWindowFunction : public WindowFunction { - RecurrentWindowFunction(const std::string & name_, + StatefulWindowFunction(const std::string & name_, const DataTypes & argument_types_, const Array & parameters_) : WindowFunction(name_, argument_types_, parameters_) { } - size_t sizeOfData() const override { return sizeof(Float64)*state_size; } + size_t sizeOfData() const override { return sizeof(State); } size_t alignOfData() const override { return 1; } void create(AggregateDataPtr __restrict place) const override { - auto * const state = static_cast(static_cast(place)); - for (size_t i = 0; i < state_size; ++i) - state[i] = 0.0; + new (place) State(); } - template - static T getLastValueFromInputColumn(const WindowTransform * transform, size_t function_index, size_t column_index) + void destroy(AggregateDataPtr __restrict place) const noexcept override { - return recurrent_detail::getLastValueFromInputColumn(transform, function_index, column_index); + auto * const state = static_cast(static_cast(place)); + state->~State(); } - template - static T getLastValueFromState(const WindowTransform * transform, size_t function_index, size_t data_index) + State & getState(const WindowFunctionWorkspace & workspace) { - return recurrent_detail::getLastValueFromState(transform, function_index, data_index); - } - - template - static void setValueToState(const WindowTransform * transform, size_t function_index, T value, size_t data_index) - { - recurrent_detail::setValueToState(transform, function_index, value, data_index); + return 
*static_cast(static_cast(workspace.aggregate_function_state.data())); } +}; - template - static void setValueToOutputColumn(const WindowTransform * transform, size_t function_index, T value) - { - recurrent_detail::setValueToOutputColumn(transform, function_index, value); - } +struct ExponentialTimeDecayedSumState +{ + RowNumber previous_frame_start; + RowNumber previous_frame_end; + Float64 previous_time; + Float64 previous_sum; +}; - template - static T getCurrentValueFromInputColumn(const WindowTransform * transform, size_t function_index, size_t column_index) - { - return recurrent_detail::getCurrentValueFromInputColumn(transform, function_index, column_index); - } +struct ExponentialTimeDecayedAvgState +{ + RowNumber previous_frame_start; + RowNumber previous_frame_end; + Float64 previous_time; + Float64 previous_sum; + Float64 previous_count; }; -struct WindowFunctionExponentialTimeDecayedSum final : public RecurrentWindowFunction<1> +struct WindowFunctionExponentialTimeDecayedSum final : public StatefulWindowFunction { static constexpr size_t ARGUMENT_VALUE = 0; static constexpr size_t ARGUMENT_TIME = 1; - static constexpr size_t STATE_SUM = 0; - WindowFunctionExponentialTimeDecayedSum(const std::string & name_, const DataTypes & argument_types_, const Array & parameters_) - : RecurrentWindowFunction(name_, argument_types_, parameters_) + : StatefulWindowFunction(name_, argument_types_, parameters_) { if (parameters_.size() != 1) { @@ -1724,33 +1678,60 @@ struct WindowFunctionExponentialTimeDecayedSum final : public RecurrentWindowFun void windowInsertResultInto(const WindowTransform * transform, size_t function_index) override { - Float64 last_sum = getLastValueFromState(transform, function_index, STATE_SUM); - Float64 last_t = getLastValueFromInputColumn(transform, function_index, ARGUMENT_TIME); + const auto & workspace = transform->workspaces[function_index]; + auto & state = getState(workspace); + + Float64 result = 0; + Float64 curr_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, transform->current_row); - Float64 x = getCurrentValueFromInputColumn(transform, function_index, ARGUMENT_VALUE); - Float64 t = getCurrentValueFromInputColumn(transform, function_index, ARGUMENT_TIME); + if (state.previous_frame_start <= transform->frame_start + && transform->frame_start < state.previous_frame_end + && state.previous_frame_end <= transform->frame_end) + { + for (RowNumber i = state.previous_frame_start; i < transform->frame_start; transform->advanceRowNumber(i)) + { + Float64 prev_val = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_VALUE, i); + Float64 prev_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); + result -= std::exp((prev_t - curr_t) / decay_length) * prev_val; + } + result += std::exp((state.previous_time - curr_t) / decay_length) * state.previous_sum; + for (RowNumber i = state.previous_frame_end; i < transform->frame_end; transform->advanceRowNumber(i)) + { + Float64 prev_val = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_VALUE, i); + Float64 prev_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); + result += std::exp((prev_t - curr_t) / decay_length) * prev_val; + } + } + else + { + for (RowNumber i = transform->frame_start; i < transform->frame_end; transform->advanceRowNumber(i)) + { + Float64 prev_val = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_VALUE, i); + Float64 prev_t = 
WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); + result += std::exp((prev_t - curr_t) / decay_length) * prev_val; + } + } - Float64 c = exp((last_t - t) / decay_length); - Float64 result = x + c * last_sum; + state.previous_sum = result; + state.previous_time = curr_t; + state.previous_frame_start = transform->frame_start; + state.previous_frame_end = transform->frame_end; - setValueToOutputColumn(transform, function_index, result); - setValueToState(transform, function_index, result, STATE_SUM); + WindowFunctionHelpers::setValueToOutputColumn(transform, function_index, result); } private: Float64 decay_length; }; -struct WindowFunctionExponentialTimeDecayedMax final : public RecurrentWindowFunction<1> +struct WindowFunctionExponentialTimeDecayedMax final : public WindowFunction { static constexpr size_t ARGUMENT_VALUE = 0; static constexpr size_t ARGUMENT_TIME = 1; - static constexpr size_t STATE_MAX = 0; - WindowFunctionExponentialTimeDecayedMax(const std::string & name_, const DataTypes & argument_types_, const Array & parameters_) - : RecurrentWindowFunction(name_, argument_types_, parameters_) + : WindowFunction(name_, argument_types_, parameters_) { if (parameters_.size() != 1) { @@ -1792,32 +1773,35 @@ struct WindowFunctionExponentialTimeDecayedMax final : public RecurrentWindowFun void windowInsertResultInto(const WindowTransform * transform, size_t function_index) override { - Float64 last_max = getLastValueFromState(transform, function_index, STATE_MAX); - Float64 last_t = getLastValueFromInputColumn(transform, function_index, ARGUMENT_TIME); + Float64 result = std::numeric_limits::lowest(); + Float64 curr_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, transform->current_row); - Float64 x = getCurrentValueFromInputColumn(transform, function_index, ARGUMENT_VALUE); - Float64 t = getCurrentValueFromInputColumn(transform, function_index, ARGUMENT_TIME); + for (RowNumber i = transform->frame_start; i < transform->frame_end; transform->advanceRowNumber(i)) + { + Float64 value = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_VALUE, i); + Float64 t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); - Float64 c = exp((last_t - t) / decay_length); - Float64 result = std::max(x, c * last_max); + /// Avoiding extra calls to `exp` and multiplications. 
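+ /// The decay factor exp((t - curr_t) / decay_length) is at most 1 for rows with t <= curr_t,
+ /// so a candidate that does not already exceed a non-negative running maximum cannot exceed it
+ /// after decaying. Rows from the future part of the frame (t > curr_t) are amplified instead,
+ /// and a negative running maximum can still be overtaken by a decayed value, so both of those
+ /// cases fall through to the full computation.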
+ if (value > result || t > curr_t || result < 0) + { + result = std::max(std::exp((t - curr_t) / decay_length) * value, result); + } + } - setValueToOutputColumn(transform, function_index, result); - setValueToState(transform, function_index, result, STATE_MAX); + WindowFunctionHelpers::setValueToOutputColumn(transform, function_index, result); } private: Float64 decay_length; }; -struct WindowFunctionExponentialTimeDecayedCount final : public RecurrentWindowFunction<1> +struct WindowFunctionExponentialTimeDecayedCount final : public StatefulWindowFunction { static constexpr size_t ARGUMENT_TIME = 0; - static constexpr size_t STATE_COUNT = 0; - WindowFunctionExponentialTimeDecayedCount(const std::string & name_, const DataTypes & argument_types_, const Array & parameters_) - : RecurrentWindowFunction(name_, argument_types_, parameters_) + : StatefulWindowFunction(name_, argument_types_, parameters_) { if (parameters_.size() != 1) { @@ -1851,33 +1835,57 @@ struct WindowFunctionExponentialTimeDecayedCount final : public RecurrentWindowF void windowInsertResultInto(const WindowTransform * transform, size_t function_index) override { - Float64 last_count = getLastValueFromState(transform, function_index, STATE_COUNT); - Float64 last_t = getLastValueFromInputColumn(transform, function_index, ARGUMENT_TIME); + const auto & workspace = transform->workspaces[function_index]; + auto & state = getState(workspace); + + Float64 result = 0; + Float64 curr_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, transform->current_row); - Float64 t = getCurrentValueFromInputColumn(transform, function_index, ARGUMENT_TIME); + if (state.previous_frame_start <= transform->frame_start + && transform->frame_start < state.previous_frame_end + && state.previous_frame_end <= transform->frame_end) + { + for (RowNumber i = state.previous_frame_start; i < transform->frame_start; transform->advanceRowNumber(i)) + { + Float64 prev_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); + result -= std::exp((prev_t - curr_t) / decay_length); + } + result += std::exp((state.previous_time - curr_t) / decay_length) * state.previous_sum; + for (RowNumber i = state.previous_frame_end; i < transform->frame_end; transform->advanceRowNumber(i)) + { + Float64 prev_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); + result += std::exp((prev_t - curr_t) / decay_length); + } + } + else + { + for (RowNumber i = transform->frame_start; i < transform->frame_end; transform->advanceRowNumber(i)) + { + Float64 prev_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); + result += std::exp((prev_t - curr_t) / decay_length); + } + } - Float64 c = exp((last_t - t) / decay_length); - Float64 result = c * last_count + 1.0; + state.previous_sum = result; + state.previous_time = curr_t; + state.previous_frame_start = transform->frame_start; + state.previous_frame_end = transform->frame_end; - setValueToOutputColumn(transform, function_index, result); - setValueToState(transform, function_index, result, STATE_COUNT); + WindowFunctionHelpers::setValueToOutputColumn(transform, function_index, result); } private: Float64 decay_length; }; -struct WindowFunctionExponentialTimeDecayedAvg final : public RecurrentWindowFunction<2> +struct WindowFunctionExponentialTimeDecayedAvg final : public StatefulWindowFunction { static constexpr size_t ARGUMENT_VALUE = 0; static constexpr size_t ARGUMENT_TIME = 1; - static constexpr size_t STATE_SUM = 0; - 
static constexpr size_t STATE_COUNT = 1; - WindowFunctionExponentialTimeDecayedAvg(const std::string & name_, const DataTypes & argument_types_, const Array & parameters_) - : RecurrentWindowFunction(name_, argument_types_, parameters_) + : StatefulWindowFunction(name_, argument_types_, parameters_) { if (parameters_.size() != 1) { @@ -1919,21 +1927,60 @@ struct WindowFunctionExponentialTimeDecayedAvg final : public RecurrentWindowFun void windowInsertResultInto(const WindowTransform * transform, size_t function_index) override { - Float64 last_sum = getLastValueFromState(transform, function_index, STATE_SUM); - Float64 last_count = getLastValueFromState(transform, function_index, STATE_COUNT); - Float64 last_t = getLastValueFromInputColumn(transform, function_index, ARGUMENT_TIME); + const auto & workspace = transform->workspaces[function_index]; + auto & state = getState(workspace); + + Float64 count = 0; + Float64 sum = 0; + Float64 curr_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, transform->current_row); + + if (state.previous_frame_start <= transform->frame_start + && transform->frame_start < state.previous_frame_end + && state.previous_frame_end <= transform->frame_end) + { + for (RowNumber i = state.previous_frame_start; i < transform->frame_start; transform->advanceRowNumber(i)) + { + Float64 prev_val = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_VALUE, i); + Float64 prev_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); + Float64 decay = std::exp((prev_t - curr_t) / decay_length); + sum -= decay * prev_val; + count -= decay; + } - Float64 x = getCurrentValueFromInputColumn(transform, function_index, ARGUMENT_VALUE); - Float64 t = getCurrentValueFromInputColumn(transform, function_index, ARGUMENT_TIME); + { + Float64 decay = std::exp((state.previous_time - curr_t) / decay_length); + sum += decay * state.previous_sum; + count += decay * state.previous_count; + } + + for (RowNumber i = state.previous_frame_end; i < transform->frame_end; transform->advanceRowNumber(i)) + { + Float64 prev_val = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_VALUE, i); + Float64 prev_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); + Float64 decay = std::exp((prev_t - curr_t) / decay_length); + sum += decay * prev_val; + count += decay; + } + } + else + { + for (RowNumber i = transform->frame_start; i < transform->frame_end; transform->advanceRowNumber(i)) + { + Float64 prev_val = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_VALUE, i); + Float64 prev_t = WindowFunctionHelpers::getValue(transform, function_index, ARGUMENT_TIME, i); + Float64 decay = std::exp((prev_t - curr_t) / decay_length); + sum += decay * prev_val; + count += decay; + } + } - Float64 c = exp((last_t - t) / decay_length); - Float64 new_sum = c * last_sum + x; - Float64 new_count = c * last_count + 1.0; - Float64 result = new_sum / new_count; + state.previous_sum = sum; + state.previous_count = count; + state.previous_time = curr_t; + state.previous_frame_start = transform->frame_start; + state.previous_frame_end = transform->frame_end; - setValueToOutputColumn(transform, function_index, result); - setValueToState(transform, function_index, new_sum, STATE_SUM); - setValueToState(transform, function_index, new_count, STATE_COUNT); + WindowFunctionHelpers::setValueToOutputColumn(transform, function_index, sum/count); } private: diff --git 
a/tests/queries/0_stateless/02020_exponential_smoothing.reference b/tests/queries/0_stateless/02020_exponential_smoothing.reference index b3c234206783..334d32e1c163 100644 --- a/tests/queries/0_stateless/02020_exponential_smoothing.reference +++ b/tests/queries/0_stateless/02020_exponential_smoothing.reference @@ -1,13 +1,14 @@ +exponentialMovingAverage 1 0 0.5 0 1 0.25 0 2 0.125 -0 3 0.0625 -0 4 0.03125 -0 5 0.015625 -0 6 0.0078125 -0 7 0.00390625 -0 8 0.001953125 -0 9 0.0009765625 +0 3 0.062 +0 4 0.031 +0 5 0.016 +0 6 0.008 +0 7 0.004 +0 8 0.002 +0 9 0.001 1 0 0.067 0 1 0.062 0 2 0.058 @@ -128,16 +129,17 @@ 0 47 0.129 ██████▍ 0 48 0.065 ███▏ 0 49 0.032 █▌ +exponentialTimeDecayedSum 1 0 1 -0 1 0.36787944117144233 -0 2 0.1353352832366127 -0 3 0.04978706836786395 -0 4 0.018315638888734186 -0 5 0.00673794699908547 -0 6 0.0024787521766663594 -0 7 0.0009118819655545166 -0 8 0.00033546262790251196 -0 9 0.0001234098040866796 +0 1 0.368 +0 2 0.135 +0 3 0.05 +0 4 0.018 +0 5 0.007 +0 6 0.002 +0 7 0.001 +0 8 0 +0 9 0 1 0 1 0 1 0.905 0 2 0.819 @@ -258,16 +260,17 @@ 0 47 0.136 ██████▋ 0 48 0.05 ██▌ 0 49 0.018 ▊ +exponentialTimeDecayedMax 1 0 1 -0 1 0.36787944117144233 -0 2 0.1353352832366127 -0 3 0.04978706836786395 -0 4 0.018315638888734186 -0 5 0.00673794699908547 -0 6 0.0024787521766663594 -0 7 0.0009118819655545166 -0 8 0.00033546262790251196 -0 9 0.0001234098040866796 +0 1 0.368 +0 2 0.135 +0 3 0.05 +0 4 0.018 +0 5 0.007 +0 6 0.002 +0 7 0.001 +0 8 0 +0 9 0 1 0 1 0 1 0.905 0 2 0.819 @@ -388,16 +391,17 @@ 0 47 0.135 ██████▋ 0 48 0.05 ██▍ 0 49 0.018 ▊ +exponentialTimeDecayedCount 1 0 1 -0 1 1.3678794411714423 -0 2 1.5032147244080551 -0 3 1.553001792775919 -0 4 1.5713174316646532 -0 5 1.5780553786637386 -0 6 1.5805341308404048 -0 7 1.5814460128059595 -0 8 1.581781475433862 -0 9 1.5819048852379487 +0 1 1.368 +0 2 1.503 +0 3 1.553 +0 4 1.571 +0 5 1.578 +0 6 1.581 +0 7 1.581 +0 8 1.582 +0 9 1.582 1 0 1 0 1 1.905 0 2 2.724 @@ -518,16 +522,17 @@ 0 47 10.422 ██████████████████████████ 0 48 10.43 ██████████████████████████ 0 49 10.438 ██████████████████████████ +exponentialTimeDecayedAvg 1 0 1 -0 1 0.2689414213699951 -0 2 0.09003057317038046 -0 3 0.032058603280084995 -0 4 0.01165623095603961 -0 5 0.004269778545282112 -0 6 0.0015683003158864733 -0 7 0.000576612769687006 -0 8 0.00021207899644323433 -0 9 0.00007801341612780745 +0 1 0.269 +0 2 0.09 +0 3 0.032 +0 4 0.012 +0 5 0.004 +0 6 0.002 +0 7 0.001 +0 8 0 +0 9 0 1 0 1 0 1 0.475 0 2 0.301 @@ -648,3 +653,24 @@ 0 47 0.206 ████████████████████▋ 0 48 0.201 ████████████████████ 0 49 0.196 ███████████████████▌ +Check `exponentialTimeDecayed.*` supports sliding windows +2 1 3.010050167084 2 3.030251507111 0.993333444442 +1 2 7.060905027605 4.080805360107 4.02030134086 1.756312382816 +0 3 12.091654548833 5.101006700134 5.000500014167 2.418089094006 +4 4 11.050650848754 5.050250835421 5.000500014167 2.209909172572 +5 5 9.970249502081 5 5.000500014167 1.993850509716 +1 6 20.07305726224 10.202013400268 5.000500014167 4.014210020072 +0 7 15.991544871125 10.100501670842 3.98029867414 4.017674596889 +10 8 10.980198673307 10 2.970248507056 3.696727276261 +Check `exponentialTimeDecayedMax` works with negative values +2 1 -1.010050167084 +1 2 -1 +10 3 -0.990049833749 +4 4 -0.980198673307 +5 5 -1.010050167084 +1 6 -1 +10 7 -0.990049833749 +10 8 -0.980198673307 +10 9 -9.801986733068 +9.81 10 -9.801986733068 +9.9 11 -9.712388869079 diff --git a/tests/queries/0_stateless/02020_exponential_smoothing.sql b/tests/queries/0_stateless/02020_exponential_smoothing.sql index 
a39b09a883da..462081b12d6c 100644 --- a/tests/queries/0_stateless/02020_exponential_smoothing.sql +++ b/tests/queries/0_stateless/02020_exponential_smoothing.sql @@ -1,5 +1,6 @@ --- exponentialMovingAverage -SELECT number = 0 AS value, number AS time, exponentialMovingAverage(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); +SELECT 'exponentialMovingAverage'; + +SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialMovingAverage(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10)); SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialMovingAverage(10)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10)); SELECT number AS value, number AS time, exponentialMovingAverage(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); @@ -32,8 +33,9 @@ FROM FROM numbers(50) ); --- exponentialTimeDecayedSum -SELECT number = 0 AS value, number AS time, exponentialTimeDecayedSum(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); +SELECT 'exponentialTimeDecayedSum'; + +SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialTimeDecayedSum(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10)); SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialTimeDecayedSum(10)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10)); SELECT number AS value, number AS time, exponentialTimeDecayedSum(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); @@ -66,8 +68,9 @@ FROM FROM numbers(50) ); --- exponentialTimeDecayedMax -SELECT number = 0 AS value, number AS time, exponentialTimeDecayedMax(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); +SELECT 'exponentialTimeDecayedMax'; + +SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialTimeDecayedMax(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10)); SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialTimeDecayedMax(10)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10)); SELECT number AS value, number AS time, exponentialTimeDecayedMax(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); @@ -100,8 +103,9 @@ FROM FROM numbers(50) ); --- exponentialTimeDecayedCount -SELECT number = 0 AS value, number AS time, exponentialTimeDecayedCount(1)(time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); +SELECT 'exponentialTimeDecayedCount'; + +SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialTimeDecayedCount(1)(time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10)); SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialTimeDecayedCount(10)(time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM 
numbers(10)); SELECT number AS value, number AS time, exponentialTimeDecayedCount(1)(time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); @@ -134,8 +138,9 @@ FROM FROM numbers(50) ); --- exponentialTimeDecayedAvg -SELECT number = 0 AS value, number AS time, exponentialTimeDecayedAvg(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); +SELECT 'exponentialTimeDecayedAvg'; + +SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialTimeDecayedAvg(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10)); SELECT value, time, round(exp_smooth, 3) FROM (SELECT number = 0 AS value, number AS time, exponentialTimeDecayedAvg(10)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10)); SELECT number AS value, number AS time, exponentialTimeDecayedAvg(1)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(10); @@ -167,3 +172,70 @@ FROM exponentialTimeDecayedAvg(100)(value, time) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS exp_smooth FROM numbers(50) ); + +SELECT 'Check `exponentialTimeDecayed.*` supports sliding windows'; + +SELECT + x, + t, + round(sum, 12), + round(max, 12), + round(count, 12), + round(avg, 12) +FROM +( + SELECT + d[1] AS x, + d[2] AS t, + exponentialTimeDecayedSum(100)(x, t) OVER w AS sum, + exponentialTimeDecayedMax(100)(x, t) OVER w AS max, + exponentialTimeDecayedCount(100)(t) OVER w AS count, + exponentialTimeDecayedAvg(100)(x, t) OVER w AS avg + FROM + ( + SELECT [[2, 1], [1, 2], [0, 3], [4, 4], [5, 5], [1, 6], [0, 7], [10, 8]] AS d + ) + ARRAY JOIN d + WINDOW w AS (ORDER BY 1 ASC Rows BETWEEN 2 PRECEDING AND 2 FOLLOWING) +); + +SELECT + x, + t, + round(sum, 12), + round(max, 12), + round(count, 12), + round(avg, 12) +FROM +( + SELECT + sin(number) AS x, + number AS t, + exponentialTimeDecayedSum(100)(x, t) OVER w AS sum, + exponentialTimeDecayedMax(100)(x, t) OVER w AS max, + exponentialTimeDecayedCount(100)(t) OVER w AS count, + exponentialTimeDecayedAvg(100)(x, t) OVER w AS avg + FROM numbers(1000000) + WINDOW w AS (ORDER BY 1 ASC Rows BETWEEN 2 PRECEDING AND 2 FOLLOWING) +) +FORMAT `Null`; + +SELECT 'Check `exponentialTimeDecayedMax` works with negative values'; + +SELECT + x, + t, + round(max, 12) +FROM +( + SELECT + d[1] AS x, + d[2] AS t, + exponentialTimeDecayedMax(100)(-x, t) OVER w AS max + FROM + ( + SELECT [[2, 1], [1, 2], [10, 3], [4, 4], [5, 5], [1, 6], [10, 7], [10, 8], [10, 9], [9.81, 10], [9.9, 11]] AS d + ) + ARRAY JOIN d + WINDOW w AS (ORDER BY 1 ASC Rows BETWEEN 2 PRECEDING AND 2 FOLLOWING) +); From c860f72adbefced2fad5bc401017e74c9912b91e Mon Sep 17 00:00:00 2001 From: Dmitry Novik Date: Mon, 20 Jun 2022 20:03:26 +0200 Subject: [PATCH 2/4] Merge pull request #34632 from excitoon-favorites/optimizedprocessing --- src/Core/Settings.h | 1 + src/Interpreters/ExpressionAnalyzer.cpp | 14 +- src/Interpreters/ExpressionAnalyzer.h | 1 + src/Interpreters/InterpreterSelectQuery.cpp | 8 +- src/Interpreters/InterpreterSelectQuery.h | 3 + src/Interpreters/WindowDescription.h | 4 + .../QueryPlan/Optimizations/Optimizations.h | 9 +- ...reuseStorageOrderingForWindowFunctions.cpp | 122 ++++++++++++++++++ .../QueryPlan/ReadFromMergeTree.cpp | 26 +++- src/Processors/QueryPlan/ReadFromMergeTree.h | 7 + src/Processors/QueryPlan/SortingStep.cpp | 6 + src/Processors/QueryPlan/SortingStep.h | 2 + 
src/Processors/QueryPlan/WindowStep.cpp | 5 + src/Processors/QueryPlan/WindowStep.h | 2 + ...ns_optimize_read_in_window_order.reference | 12 ++ ...mizations_optimize_read_in_window_order.sh | 36 ++++++ ...timize_read_in_window_order_long.reference | 4 + ...ions_optimize_read_in_window_order_long.sh | 35 +++++ 18 files changed, 287 insertions(+), 10 deletions(-) create mode 100644 src/Processors/QueryPlan/Optimizations/reuseStorageOrderingForWindowFunctions.cpp create mode 100644 tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order.reference create mode 100755 tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order.sh create mode 100644 tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order_long.reference create mode 100755 tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order_long.sh diff --git a/src/Core/Settings.h b/src/Core/Settings.h index 0c69b864d528..b40d88f769f0 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -396,6 +396,7 @@ class IColumn; M(Bool, parallel_view_processing, false, "Enables pushing to attached views concurrently instead of sequentially.", 0) \ M(Bool, enable_unaligned_array_join, false, "Allow ARRAY JOIN with multiple arrays that have different sizes. When this settings is enabled, arrays will be resized to the longest one.", 0) \ M(Bool, optimize_read_in_order, true, "Enable ORDER BY optimization for reading data in corresponding order in MergeTree tables.", 0) \ + M(Bool, optimize_read_in_window_order, true, "Enable ORDER BY optimization in window clause for reading data in corresponding order in MergeTree tables.", 0) \ M(Bool, optimize_aggregation_in_order, false, "Enable GROUP BY optimization for aggregating data in corresponding order in MergeTree tables.", 0) \ M(UInt64, aggregation_in_order_max_block_bytes, 50000000, "Maximal size of block in bytes accumulated during aggregation in order of primary key. 
Lower block size allows to parallelize more final merge stage of aggregation.", 0) \ M(UInt64, read_in_order_two_level_merge_threshold, 100, "Minimal number of parts to read to run preliminary merge step during multithread reading in order of primary key.", 0) \ diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index 841d7bc567f9..2cf2a17fe1f3 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -607,7 +607,7 @@ void ExpressionAnalyzer::makeAggregateDescriptions(ActionsDAGPtr & actions, Aggr } } -void makeWindowDescriptionFromAST(const Context & context, +void ExpressionAnalyzer::makeWindowDescriptionFromAST(const Context & context_, const WindowDescriptions & existing_descriptions, WindowDescription & desc, const IAST * ast) { @@ -676,6 +676,10 @@ void makeWindowDescriptionFromAST(const Context & context, desc.partition_by.push_back(SortColumnDescription( with_alias->getColumnName(), 1 /* direction */, 1 /* nulls_direction */)); + + auto actions_dag = std::make_shared(columns_after_join); + getRootActions(column_ast, false, actions_dag); + desc.partition_by_actions.push_back(std::move(actions_dag)); } } @@ -693,6 +697,10 @@ void makeWindowDescriptionFromAST(const Context & context, order_by_element.children.front()->getColumnName(), order_by_element.direction, order_by_element.nulls_direction)); + + auto actions_dag = std::make_shared(columns_after_join); + getRootActions(column_ast, false, actions_dag); + desc.order_by_actions.push_back(std::move(actions_dag)); } } @@ -719,14 +727,14 @@ void makeWindowDescriptionFromAST(const Context & context, if (definition.frame_end_type == WindowFrame::BoundaryType::Offset) { auto [value, _] = evaluateConstantExpression(definition.frame_end_offset, - context.shared_from_this()); + context_.shared_from_this()); desc.frame.end_offset = value; } if (definition.frame_begin_type == WindowFrame::BoundaryType::Offset) { auto [value, _] = evaluateConstantExpression(definition.frame_begin_offset, - context.shared_from_this()); + context_.shared_from_this()); desc.frame.begin_offset = value; } } diff --git a/src/Interpreters/ExpressionAnalyzer.h b/src/Interpreters/ExpressionAnalyzer.h index 8736f6548e34..52a9f9d9ff4d 100644 --- a/src/Interpreters/ExpressionAnalyzer.h +++ b/src/Interpreters/ExpressionAnalyzer.h @@ -132,6 +132,7 @@ class ExpressionAnalyzer : protected ExpressionAnalyzerData, private boost::nonc /// A list of windows for window functions. 
const WindowDescriptions & windowDescriptions() const { return window_descriptions; } + void makeWindowDescriptionFromAST(const Context & context, const WindowDescriptions & existing_descriptions, WindowDescription & desc, const IAST * ast); void makeWindowDescriptions(ActionsDAGPtr actions); /** diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index 35d20be46553..ffa094a82ae5 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -811,7 +811,7 @@ static FillColumnDescription getWithFillDescription(const ASTOrderByElement & or return descr; } -static SortDescription getSortDescription(const ASTSelectQuery & query, ContextPtr context) +SortDescription InterpreterSelectQuery::getSortDescription(const ASTSelectQuery & query, ContextPtr context_) { SortDescription order_descr; order_descr.reserve(query.orderBy()->children.size()); @@ -826,7 +826,7 @@ static SortDescription getSortDescription(const ASTSelectQuery & query, ContextP if (order_by_elem.with_fill) { - FillColumnDescription fill_desc = getWithFillDescription(order_by_elem, context); + FillColumnDescription fill_desc = getWithFillDescription(order_by_elem, context_); order_descr.emplace_back(name, order_by_elem.direction, order_by_elem.nulls_direction, collator, true, fill_desc); } else @@ -885,12 +885,12 @@ static std::pair getLimitLengthAndOffset(const ASTSelectQuery & } -static UInt64 getLimitForSorting(const ASTSelectQuery & query, ContextPtr context) +UInt64 InterpreterSelectQuery::getLimitForSorting(const ASTSelectQuery & query, ContextPtr context_) { /// Partial sort can be done if there is LIMIT but no DISTINCT or LIMIT BY, neither ARRAY JOIN. if (!query.distinct && !query.limitBy() && !query.limit_with_ties && !query.arrayJoinExpressionList().first && query.limitLength()) { - auto [limit_length, limit_offset] = getLimitLengthAndOffset(query, context); + auto [limit_length, limit_offset] = getLimitLengthAndOffset(query, context_); if (limit_length > std::numeric_limits::max() - limit_offset) return 0; diff --git a/src/Interpreters/InterpreterSelectQuery.h b/src/Interpreters/InterpreterSelectQuery.h index aa41d8376013..0eb6ae06e82f 100644 --- a/src/Interpreters/InterpreterSelectQuery.h +++ b/src/Interpreters/InterpreterSelectQuery.h @@ -106,6 +106,9 @@ class InterpreterSelectQuery : public IInterpreterUnionOrSelectQuery Names getRequiredColumns() { return required_columns; } + static SortDescription getSortDescription(const ASTSelectQuery & query, ContextPtr context); + static UInt64 getLimitForSorting(const ASTSelectQuery & query, ContextPtr context); + private: InterpreterSelectQuery( const ASTPtr & query_ptr_, diff --git a/src/Interpreters/WindowDescription.h b/src/Interpreters/WindowDescription.h index bb0130b4d4ee..65c8cb9423c9 100644 --- a/src/Interpreters/WindowDescription.h +++ b/src/Interpreters/WindowDescription.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace DB { @@ -90,6 +91,9 @@ struct WindowDescription // then by ORDER BY. This field holds this combined sort order. SortDescription full_sort_description; + std::vector partition_by_actions; + std::vector order_by_actions; + WindowFrame frame; // The window functions that are calculated for this window. 
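+
+    // partition_by_actions and order_by_actions above keep the ActionsDAGs built from the
+    // PARTITION BY and ORDER BY expressions, so that the plan optimizer can later rebuild
+    // per-expression ExpressionActions and check whether the storage order already satisfies
+    // full_sort_description (see reuseStorageOrderingForWindowFunctions).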
diff --git a/src/Processors/QueryPlan/Optimizations/Optimizations.h b/src/Processors/QueryPlan/Optimizations/Optimizations.h index 10bc62935371..11bdb1c95e50 100644 --- a/src/Processors/QueryPlan/Optimizations/Optimizations.h +++ b/src/Processors/QueryPlan/Optimizations/Optimizations.h @@ -44,16 +44,21 @@ size_t tryMergeExpressions(QueryPlan::Node * parent_node, QueryPlan::Nodes &); /// May split FilterStep and push down only part of it. size_t tryPushDownFilter(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes); +/// Utilize storage sorting when sorting for window functions. +/// Update information about prefix sort description in SortingStep. +size_t tryReuseStorageOrderingForWindowFunctions(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes); + inline const auto & getOptimizations() { - static const std::array optimizations = + static const std::array optimizations = {{ {tryLiftUpArrayJoin, "liftUpArrayJoin", &QueryPlanOptimizationSettings::optimize_plan}, {tryPushDownLimit, "pushDownLimit", &QueryPlanOptimizationSettings::optimize_plan}, {trySplitFilter, "splitFilter", &QueryPlanOptimizationSettings::optimize_plan}, {tryMergeExpressions, "mergeExpressions", &QueryPlanOptimizationSettings::optimize_plan}, {tryPushDownFilter, "pushDownFilter", &QueryPlanOptimizationSettings::filter_push_down}, - }}; + {tryReuseStorageOrderingForWindowFunctions, "reuseStorageOrderingForWindowFunctions", &QueryPlanOptimizationSettings::optimize_plan} + }}; return optimizations; } diff --git a/src/Processors/QueryPlan/Optimizations/reuseStorageOrderingForWindowFunctions.cpp b/src/Processors/QueryPlan/Optimizations/reuseStorageOrderingForWindowFunctions.cpp new file mode 100644 index 000000000000..c68ec47edff0 --- /dev/null +++ b/src/Processors/QueryPlan/Optimizations/reuseStorageOrderingForWindowFunctions.cpp @@ -0,0 +1,122 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB::QueryPlanOptimizations +{ + +size_t tryReuseStorageOrderingForWindowFunctions(QueryPlan::Node * parent_node, QueryPlan::Nodes & /*nodes*/) +{ + /// Find the following sequence of steps, add InputOrderInfo and apply prefix sort description to + /// SortingStep: + /// WindowStep <- SortingStep <- [Expression] <- [SettingQuotaAndLimits] <- ReadFromMergeTree + + auto * window_node = parent_node; + auto * window = typeid_cast(window_node->step.get()); + if (!window) + return 0; + if (window_node->children.size() != 1) + return 0; + + auto * sorting_node = window_node->children.front(); + auto * sorting = typeid_cast(sorting_node->step.get()); + if (!sorting) + return 0; + if (sorting_node->children.size() != 1) + return 0; + + auto * possible_read_from_merge_tree_node = sorting_node->children.front(); + + if (typeid_cast(possible_read_from_merge_tree_node->step.get())) + { + if (possible_read_from_merge_tree_node->children.size() != 1) + return 0; + + possible_read_from_merge_tree_node = possible_read_from_merge_tree_node->children.front(); + } + + if (typeid_cast(possible_read_from_merge_tree_node->step.get())) + { + if (possible_read_from_merge_tree_node->children.size() != 1) + return 0; + + possible_read_from_merge_tree_node = possible_read_from_merge_tree_node->children.front(); + } + + auto * read_from_merge_tree = typeid_cast(possible_read_from_merge_tree_node->step.get()); + if (!read_from_merge_tree) + { + return 0; + } + + 
auto context = read_from_merge_tree->getContext(); + if (!context->getSettings().optimize_read_in_window_order) + { + return 0; + } + + const auto & query_info = read_from_merge_tree->getQueryInfo(); + const auto * select_query = query_info.query->as(); + + ManyExpressionActions order_by_elements_actions; + const auto & window_desc = window->getWindowDescription(); + + for (const auto & actions_dag : window_desc.partition_by_actions) + { + order_by_elements_actions.emplace_back( + std::make_shared(actions_dag, ExpressionActionsSettings::fromContext(context, CompileExpressions::yes))); + } + + for (const auto & actions_dag : window_desc.order_by_actions) + { + order_by_elements_actions.emplace_back( + std::make_shared(actions_dag, ExpressionActionsSettings::fromContext(context, CompileExpressions::yes))); + } + + auto order_optimizer = std::make_shared( + *select_query, + order_by_elements_actions, + window->getWindowDescription().full_sort_description, + query_info.syntax_analyzer_result); + + read_from_merge_tree->setQueryInfoOrderOptimizer(order_optimizer); + + /// If we don't have filtration, we can pushdown limit to reading stage for optimizations. + UInt64 limit = (select_query->hasFiltration() || select_query->groupBy()) ? 0 : InterpreterSelectQuery::getLimitForSorting(*select_query, context); + + auto order_info = order_optimizer->getInputOrder( + query_info.projection ? query_info.projection->desc->metadata : read_from_merge_tree->getStorageMetadata(), + context, + limit); + + if (order_info) + { + read_from_merge_tree->setQueryInfoInputOrderInfo(order_info); + sorting->convertToFinishSorting(order_info->order_key_prefix_descr); + } + + return 0; +} + +} diff --git a/src/Processors/QueryPlan/ReadFromMergeTree.cpp b/src/Processors/QueryPlan/ReadFromMergeTree.cpp index 1bfc1ec7306f..abc753c5fa7b 100644 --- a/src/Processors/QueryPlan/ReadFromMergeTree.cpp +++ b/src/Processors/QueryPlan/ReadFromMergeTree.cpp @@ -977,6 +977,30 @@ MergeTreeDataSelectAnalysisResultPtr ReadFromMergeTree::selectRangesToRead( return std::make_shared(MergeTreeDataSelectAnalysisResult{.result = std::move(result)}); } +void ReadFromMergeTree::setQueryInfoOrderOptimizer(std::shared_ptr order_optimizer) +{ + if (query_info.projection) + { + query_info.projection->order_optimizer = order_optimizer; + } + else + { + query_info.order_optimizer = order_optimizer; + } +} + +void ReadFromMergeTree::setQueryInfoInputOrderInfo(InputOrderInfoPtr order_info) +{ + if (query_info.projection) + { + query_info.projection->input_order_info = order_info; + } + else + { + query_info.input_order_info = order_info; + } +} + ReadFromMergeTree::AnalysisResult ReadFromMergeTree::getAnalysisResult() const { auto result_ptr = analyzed_result_ptr ? 
analyzed_result_ptr : selectRangesToRead(prepared_parts); @@ -1060,7 +1084,7 @@ void ReadFromMergeTree::initializePipeline(QueryPipelineBuilder & pipeline, cons column_names_to_read, result_projection); } - else if ((settings.optimize_read_in_order || settings.optimize_aggregation_in_order) && input_order_info) + else if ((settings.optimize_read_in_order || settings.optimize_aggregation_in_order || settings.optimize_read_in_window_order) && input_order_info) { pipe = spreadMarkRangesAmongStreamsWithOrder( std::move(result.parts_with_ranges), diff --git a/src/Processors/QueryPlan/ReadFromMergeTree.h b/src/Processors/QueryPlan/ReadFromMergeTree.h index 685b99a7bdcb..16333edcaf3b 100644 --- a/src/Processors/QueryPlan/ReadFromMergeTree.h +++ b/src/Processors/QueryPlan/ReadFromMergeTree.h @@ -128,6 +128,13 @@ class ReadFromMergeTree final : public ISourceStep bool sample_factor_column_queried, Poco::Logger * log); + ContextPtr getContext() const { return context; } + const SelectQueryInfo & getQueryInfo() const { return query_info; } + StorageMetadataPtr getStorageMetadata() const { return metadata_for_reading; } + + void setQueryInfoOrderOptimizer(std::shared_ptr read_in_order_optimizer); + void setQueryInfoInputOrderInfo(InputOrderInfoPtr order_info); + private: const MergeTreeReaderSettings reader_settings; diff --git a/src/Processors/QueryPlan/SortingStep.cpp b/src/Processors/QueryPlan/SortingStep.cpp index 32b314b1c505..602680e17182 100644 --- a/src/Processors/QueryPlan/SortingStep.cpp +++ b/src/Processors/QueryPlan/SortingStep.cpp @@ -97,6 +97,12 @@ void SortingStep::updateLimit(size_t limit_) } } +void SortingStep::convertToFinishSorting(SortDescription prefix_description_) +{ + type = Type::FinishSorting; + prefix_description = std::move(prefix_description_); +} + void SortingStep::transformPipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings &) { if (type == Type::FinishSorting) diff --git a/src/Processors/QueryPlan/SortingStep.h b/src/Processors/QueryPlan/SortingStep.h index 8e253e71f441..4da98a15f65d 100644 --- a/src/Processors/QueryPlan/SortingStep.h +++ b/src/Processors/QueryPlan/SortingStep.h @@ -49,6 +49,8 @@ class SortingStep : public ITransformingStep /// Add limit or change it to lower value. 
void updateLimit(size_t limit_); + void convertToFinishSorting(SortDescription prefix_description); + private: enum class Type diff --git a/src/Processors/QueryPlan/WindowStep.cpp b/src/Processors/QueryPlan/WindowStep.cpp index cd4bb5f67307..0916b43b29a5 100644 --- a/src/Processors/QueryPlan/WindowStep.cpp +++ b/src/Processors/QueryPlan/WindowStep.cpp @@ -138,4 +138,9 @@ void WindowStep::describeActions(JSONBuilder::JSONMap & map) const map.add("Functions", std::move(functions_array)); } +const WindowDescription & WindowStep::getWindowDescription() const +{ + return window_description; +} + } diff --git a/src/Processors/QueryPlan/WindowStep.h b/src/Processors/QueryPlan/WindowStep.h index a65b157f4817..9b58cceb972b 100644 --- a/src/Processors/QueryPlan/WindowStep.h +++ b/src/Processors/QueryPlan/WindowStep.h @@ -25,6 +25,8 @@ class WindowStep : public ITransformingStep void describeActions(JSONBuilder::JSONMap & map) const override; void describeActions(FormatSettings & settings) const override; + const WindowDescription & getWindowDescription() const; + private: WindowDescription window_description; std::vector window_functions; diff --git a/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order.reference b/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order.reference new file mode 100644 index 000000000000..7fcd29b5faf9 --- /dev/null +++ b/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order.reference @@ -0,0 +1,12 @@ +Partial sorting plan + optimize_read_in_window_order=0 + Sort description: n ASC, x ASC + optimize_read_in_window_order=1 + Prefix sort description: n ASC + Result sort description: n ASC, x ASC +No sorting plan + optimize_read_in_window_order=0 + Sort description: n ASC, x ASC + optimize_read_in_window_order=1 + Prefix sort description: n ASC, x ASC + Result sort description: n ASC, x ASC diff --git a/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order.sh b/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order.sh new file mode 100755 index 000000000000..418baea81136 --- /dev/null +++ b/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. 
"$CURDIR"/../shell_config.sh + +name=test_01655_plan_optimizations_optimize_read_in_window_order + +$CLICKHOUSE_CLIENT -q "drop table if exists ${name}" +$CLICKHOUSE_CLIENT -q "drop table if exists ${name}_n" +$CLICKHOUSE_CLIENT -q "drop table if exists ${name}_n_x" + +$CLICKHOUSE_CLIENT -q "create table ${name} engine=MergeTree order by tuple() as select toInt64((sin(number)+2)*65535)%10 as n, number as x from numbers_mt(100000)" +$CLICKHOUSE_CLIENT -q "create table ${name}_n engine=MergeTree order by n as select * from ${name} order by n" +$CLICKHOUSE_CLIENT -q "create table ${name}_n_x engine=MergeTree order by (n, x) as select * from ${name} order by n, x" + +$CLICKHOUSE_CLIENT -q "optimize table ${name}_n final" +$CLICKHOUSE_CLIENT -q "optimize table ${name}_n_x final" + +echo 'Partial sorting plan' +echo ' optimize_read_in_window_order=0' +$CLICKHOUSE_CLIENT -q "explain plan actions=1, description=1 select n, sum(x) OVER (ORDER BY n, x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n SETTINGS optimize_read_in_window_order=0" | grep -i "sort description" + +echo ' optimize_read_in_window_order=1' +$CLICKHOUSE_CLIENT -q "explain plan actions=1, description=1 select n, sum(x) OVER (ORDER BY n, x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n SETTINGS optimize_read_in_window_order=1" | grep -i "sort description" + +echo 'No sorting plan' +echo ' optimize_read_in_window_order=0' +$CLICKHOUSE_CLIENT -q "explain plan actions=1, description=1 select n, sum(x) OVER (ORDER BY n, x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n_x SETTINGS optimize_read_in_window_order=0" | grep -i "sort description" + +echo ' optimize_read_in_window_order=1' +$CLICKHOUSE_CLIENT -q "explain plan actions=1, description=1 select n, sum(x) OVER (ORDER BY n, x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n_x SETTINGS optimize_read_in_window_order=1" | grep -i "sort description" + +$CLICKHOUSE_CLIENT -q "drop table ${name}" +$CLICKHOUSE_CLIENT -q "drop table ${name}_n" +$CLICKHOUSE_CLIENT -q "drop table ${name}_n_x" diff --git a/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order_long.reference b/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order_long.reference new file mode 100644 index 000000000000..b462a5a7baa4 --- /dev/null +++ b/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order_long.reference @@ -0,0 +1,4 @@ +OK +OK +OK +OK diff --git a/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order_long.sh b/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order_long.sh new file mode 100755 index 000000000000..297688a29c32 --- /dev/null +++ b/tests/queries/0_stateless/01655_plan_optimizations_optimize_read_in_window_order_long.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# Tags: long + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. 
"$CURDIR"/../shell_config.sh + +name=test_01655_plan_optimizations_optimize_read_in_window_order_long +max_memory_usage=20000000 + +$CLICKHOUSE_CLIENT -q "drop table if exists ${name}" +$CLICKHOUSE_CLIENT -q "drop table if exists ${name}_n" +$CLICKHOUSE_CLIENT -q "drop table if exists ${name}_n_x" + +$CLICKHOUSE_CLIENT -q "create table ${name} engine=MergeTree order by tuple() as select toInt64((sin(number)+2)*65535)%500 as n, number as x from numbers_mt(5000000)" +$CLICKHOUSE_CLIENT -q "create table ${name}_n engine=MergeTree order by n as select * from ${name} order by n" +$CLICKHOUSE_CLIENT -q "create table ${name}_n_x engine=MergeTree order by (n, x) as select * from ${name} order by n, x" + +$CLICKHOUSE_CLIENT -q "optimize table ${name}_n final" +$CLICKHOUSE_CLIENT -q "optimize table ${name}_n_x final" + +$CLICKHOUSE_CLIENT -q "select n, sum(x) OVER (ORDER BY n, x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n SETTINGS optimize_read_in_window_order=0, max_memory_usage=$max_memory_usage, max_threads=1 format Null" 2>&1 | grep -F -q "MEMORY_LIMIT_EXCEEDED" && echo 'OK' || echo 'FAIL' +$CLICKHOUSE_CLIENT -q "select n, sum(x) OVER (ORDER BY n, x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n SETTINGS optimize_read_in_window_order=1, max_memory_usage=$max_memory_usage, max_threads=1 format Null" + +$CLICKHOUSE_CLIENT -q "select n, sum(x) OVER (ORDER BY n, x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n_x SETTINGS optimize_read_in_window_order=0, max_memory_usage=$max_memory_usage, max_threads=1 format Null" 2>&1 | grep -F -q "MEMORY_LIMIT_EXCEEDED" && echo 'OK' || echo 'FAIL' +$CLICKHOUSE_CLIENT -q "select n, sum(x) OVER (ORDER BY n, x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n_x SETTINGS optimize_read_in_window_order=1, max_memory_usage=$max_memory_usage, max_threads=1 format Null" + +$CLICKHOUSE_CLIENT -q "select n, sum(x) OVER (PARTITION BY n ORDER BY x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n_x SETTINGS optimize_read_in_window_order=0, max_memory_usage=$max_memory_usage, max_threads=1 format Null" 2>&1 | grep -F -q "MEMORY_LIMIT_EXCEEDED" && echo 'OK' || echo 'FAIL' +$CLICKHOUSE_CLIENT -q "select n, sum(x) OVER (PARTITION BY n ORDER BY x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n_x SETTINGS optimize_read_in_window_order=1, max_memory_usage=$max_memory_usage, max_threads=1 format Null" + +$CLICKHOUSE_CLIENT -q "select n, sum(x) OVER (PARTITION BY n+x%2 ORDER BY n, x ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) from ${name}_n_x SETTINGS optimize_read_in_window_order=1, max_memory_usage=$max_memory_usage, max_threads=1 format Null" 2>&1 | grep -F -q "MEMORY_LIMIT_EXCEEDED" && echo 'OK' || echo 'FAIL' + +$CLICKHOUSE_CLIENT -q "drop table ${name}" +$CLICKHOUSE_CLIENT -q "drop table ${name}_n" +$CLICKHOUSE_CLIENT -q "drop table ${name}_n_x" From 9fb6a80de64d16bf9f3a0d02d74d4c9f4202c2fa Mon Sep 17 00:00:00 2001 From: alesapin Date: Mon, 13 Jun 2022 13:39:01 +0200 Subject: [PATCH 3/4] Merge pull request #37659 from frew/master Support `batch_delete` capability for GCS --- src/Disks/S3/DiskS3.cpp | 37 +++++++++++++++++++++++++-------- src/Disks/S3/DiskS3.h | 3 +++ src/Disks/S3/S3Capabilities.cpp | 15 +++++++++++++ src/Disks/S3/S3Capabilities.h | 27 ++++++++++++++++++++++++ src/Disks/S3/registerDiskS3.cpp | 2 ++ 5 files changed, 75 insertions(+), 9 deletions(-) create mode 100644 src/Disks/S3/S3Capabilities.cpp create mode 100644 src/Disks/S3/S3Capabilities.h diff --git a/src/Disks/S3/DiskS3.cpp 
b/src/Disks/S3/DiskS3.cpp index e46620d9d1f0..bf5d08370902 100644 --- a/src/Disks/S3/DiskS3.cpp +++ b/src/Disks/S3/DiskS3.cpp @@ -35,6 +35,7 @@ #include #include +#include #include #include #include @@ -155,12 +156,14 @@ DiskS3::DiskS3( DiskPtr metadata_disk_, FileCachePtr cache_, ContextPtr context_, + const S3Capabilities & s3_capabilities_, SettingsPtr settings_, GetDiskSettings settings_getter_) : IDiskRemote(name_, s3_root_path_, metadata_disk_, std::move(cache_), "DiskS3", settings_->thread_pool_size) , bucket(std::move(bucket_)) , current_settings(std::move(settings_)) , settings_getter(settings_getter_) + , s3_capabilities(s3_capabilities_) , context(context_) { } @@ -180,15 +183,31 @@ void DiskS3::removeFromRemoteFS(RemoteFSPathKeeperPtr fs_paths_keeper) s3_paths_keeper->removePaths([&](S3PathKeeper::Chunk && chunk) { String keys = S3PathKeeper::getChunkKeys(chunk); - LOG_TRACE(log, "Remove AWS keys {}", keys); - Aws::S3::Model::Delete delkeys; - delkeys.SetObjects(chunk); - Aws::S3::Model::DeleteObjectsRequest request; - request.SetBucket(bucket); - request.SetDelete(delkeys); - auto outcome = settings->client->DeleteObjects(request); - // Do not throw here, continue deleting other chunks - logIfError(outcome, [&](){return "Can't remove AWS keys: " + keys;}); + if (!s3_capabilities.support_batch_delete) + { + LOG_TRACE(log, "Remove AWS keys {} one by one", keys); + for (const auto & obj : chunk) + { + Aws::S3::Model::DeleteObjectRequest request; + request.SetBucket(bucket); + request.SetKey(obj.GetKey()); + auto outcome = settings->client->DeleteObject(request); + // Do not throw here, continue deleting other keys and chunks + logIfError(outcome, [&](){return "Can't remove AWS key: " + obj.GetKey();}); + } + } + else + { + LOG_TRACE(log, "Remove AWS keys {}", keys); + Aws::S3::Model::Delete delkeys; + delkeys.SetObjects(chunk); + Aws::S3::Model::DeleteObjectsRequest request; + request.SetBucket(bucket); + request.SetDelete(delkeys); + auto outcome = settings->client->DeleteObjects(request); + // Do not throw here, continue deleting other chunks + logIfError(outcome, [&](){return "Can't remove AWS keys: " + keys;}); + } }); } diff --git a/src/Disks/S3/DiskS3.h b/src/Disks/S3/DiskS3.h index 2de1600d906f..4d24fd071acc 100644 --- a/src/Disks/S3/DiskS3.h +++ b/src/Disks/S3/DiskS3.h @@ -9,6 +9,7 @@ #include #include "Disks/DiskFactory.h" #include "Disks/Executor.h" +#include #include #include @@ -76,6 +77,7 @@ class DiskS3 final : public IDiskRemote DiskPtr metadata_disk_, FileCachePtr cache_, ContextPtr context_, + const S3Capabilities & s3_capabilities_, SettingsPtr settings_, GetDiskSettings settings_getter_); @@ -166,6 +168,7 @@ class DiskS3 final : public IDiskRemote MultiVersion current_settings; /// Gets disk settings from context. 
GetDiskSettings settings_getter; + const S3Capabilities s3_capabilities; std::atomic revision_counter = 0; static constexpr UInt64 LATEST_REVISION = std::numeric_limits::max(); diff --git a/src/Disks/S3/S3Capabilities.cpp b/src/Disks/S3/S3Capabilities.cpp new file mode 100644 index 000000000000..f96f0b5539a6 --- /dev/null +++ b/src/Disks/S3/S3Capabilities.cpp @@ -0,0 +1,15 @@ +#include + +namespace DB +{ + +S3Capabilities getCapabilitiesFromConfig(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix) +{ + return S3Capabilities + { + .support_batch_delete = config.getBool(config_prefix + ".support_batch_delete", true), + .support_proxy = config.getBool(config_prefix + ".support_proxy", config.has(config_prefix + ".proxy")), + }; +} + +} diff --git a/src/Disks/S3/S3Capabilities.h b/src/Disks/S3/S3Capabilities.h new file mode 100644 index 000000000000..46e647da89e5 --- /dev/null +++ b/src/Disks/S3/S3Capabilities.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +namespace DB +{ + +/// Supported/unsupported features by different S3 implementations +/// Only useful for implementations that are almost compatible with AWS S3. +struct S3Capabilities +{ + /// Google S3 implementation doesn't support batch delete + /// TODO: possibly we have to use the Google SDK https://github.com/googleapis/google-cloud-cpp/tree/main/google/cloud/storage + /// because it looks like it is missing a lot of features: + /// 1) batch delete + /// 2) list_v2 + /// 3) multipart upload works differently + bool support_batch_delete{true}; + + /// Y.Cloud S3 implementation supports a proxy for connections + bool support_proxy{false}; +}; + +S3Capabilities getCapabilitiesFromConfig(const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix); + +} diff --git a/src/Disks/S3/registerDiskS3.cpp b/src/Disks/S3/registerDiskS3.cpp index 2b5fe3c5a81b..96630ad72624 100644 --- a/src/Disks/S3/registerDiskS3.cpp +++ b/src/Disks/S3/registerDiskS3.cpp @@ -187,6 +187,7 @@ void registerDiskS3(DiskFactory & factory) auto [metadata_path, metadata_disk] = prepareForLocalMetadata(name, config, config_prefix, context); FileCachePtr cache = getCachePtrForDisk(name, config, config_prefix, context); + S3Capabilities s3_capabilities = getCapabilitiesFromConfig(config, config_prefix); std::shared_ptr s3disk = std::make_shared( name, @@ -195,6 +196,7 @@ void registerDiskS3(DiskFactory & factory) metadata_disk, std::move(cache), context, + s3_capabilities, getSettings(config, config_prefix, context), getSettings); From d45018fa08a2f0150ca4e51c5e56505c576484bc Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Tue, 12 Jul 2022 19:09:55 +0200 Subject: [PATCH 4/4] Merge pull request #37882 from excitoon-favorites/nodeleteobjects Fixes for objects removal in `S3ObjectStorage` --- src/Disks/IDiskRemote.h | 3 +- src/Disks/S3/DiskS3.cpp | 47 --------- src/Disks/S3/DiskS3.h | 57 ++++++++++- src/Disks/S3/registerDiskS3.cpp | 97 ++++++++++++++++++- .../configs/config.d/storage_conf.xml | 14 +++ .../s3_mocks/no_delete_objects.py | 92 ++++++++++++++++++ tests/integration/test_merge_tree_s3/test.py | 14 ++- 7 files changed, 272 insertions(+), 52 deletions(-) create mode 100644 tests/integration/test_merge_tree_s3/s3_mocks/no_delete_objects.py diff --git a/src/Disks/IDiskRemote.h b/src/Disks/IDiskRemote.h index 82e76b8f68d4..4c91400c94c5 100644 --- a/src/Disks/IDiskRemote.h +++ b/src/Disks/IDiskRemote.h @@ -165,9 +165,10 @@ friend class DiskRemoteReservation; DiskPtr metadata_disk; FileCachePtr cache; -private: +public:
void removeMetadata(const String & path, RemoteFSPathKeeperPtr fs_paths_keeper); +private: void removeMetadataRecursive(const String & path, RemoteFSPathKeeperPtr fs_paths_keeper); bool tryReserve(UInt64 bytes); diff --git a/src/Disks/S3/DiskS3.cpp b/src/Disks/S3/DiskS3.cpp index bf5d08370902..723b1e7373c1 100644 --- a/src/Disks/S3/DiskS3.cpp +++ b/src/Disks/S3/DiskS3.cpp @@ -11,7 +11,6 @@ #include #include -#include #include #include @@ -58,52 +57,6 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } -/// Helper class to collect keys into chunks of maximum size (to prepare batch requests to AWS API) -/// see https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html -class S3PathKeeper : public RemoteFSPathKeeper -{ -public: - using Chunk = Aws::Vector; - using Chunks = std::list; - - explicit S3PathKeeper(size_t chunk_limit_) : RemoteFSPathKeeper(chunk_limit_) {} - - void addPath(const String & path) override - { - if (chunks.empty() || chunks.back().size() >= chunk_limit) - { - /// add one more chunk - chunks.push_back(Chunks::value_type()); - chunks.back().reserve(chunk_limit); - } - Aws::S3::Model::ObjectIdentifier obj; - obj.SetKey(path); - chunks.back().push_back(obj); - } - - void removePaths(Fn auto && remove_chunk_func) - { - for (auto & chunk : chunks) - remove_chunk_func(std::move(chunk)); - } - - static String getChunkKeys(const Chunk & chunk) - { - String res; - for (const auto & obj : chunk) - { - const auto & key = obj.GetKey(); - if (!res.empty()) - res.append(", "); - res.append(key.c_str(), key.size()); - } - return res; - } - -private: - Chunks chunks; -}; - template void throwIfError(Aws::Utils::Outcome & response) { diff --git a/src/Disks/S3/DiskS3.h b/src/Disks/S3/DiskS3.h index 4d24fd071acc..473dc6a7a751 100644 --- a/src/Disks/S3/DiskS3.h +++ b/src/Disks/S3/DiskS3.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "Disks/DiskFactory.h" #include "Disks/Executor.h" #include @@ -14,6 +15,7 @@ #include #include #include +#include #include #include @@ -121,6 +123,8 @@ class DiskS3 final : public IDiskRemote void applyNewSettings(const Poco::Util::AbstractConfiguration & config, ContextPtr context, const String &, const DisksMap &) override; + void setCapabilitiesSupportBatchDelete(bool value) { s3_capabilities.support_batch_delete = value; } + private: void createFileOperationObject(const String & operation_name, UInt64 revision, const ObjectMetadata & metadata); /// Converts revision to binary string with leading zeroes (64 bit). @@ -168,7 +172,7 @@ class DiskS3 final : public IDiskRemote MultiVersion current_settings; /// Gets disk settings from context. 
GetDiskSettings settings_getter; - const S3Capabilities s3_capabilities; + S3Capabilities s3_capabilities; std::atomic revision_counter = 0; static constexpr UInt64 LATEST_REVISION = std::numeric_limits::max(); @@ -190,6 +194,57 @@ class DiskS3 final : public IDiskRemote ContextPtr context; }; +/// Helper class to collect keys into chunks of maximum size (to prepare batch requests to AWS API) +/// see https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html +class S3PathKeeper : public RemoteFSPathKeeper +{ +public: + using Chunk = Aws::Vector; + using Chunks = std::list; + + explicit S3PathKeeper(size_t chunk_limit_) : RemoteFSPathKeeper(chunk_limit_) {} + + void addPath(const String & path) override + { + if (chunks.empty() || chunks.back().size() >= chunk_limit) + { + /// add one more chunk + chunks.push_back(Chunks::value_type()); + chunks.back().reserve(chunk_limit); + } + Aws::S3::Model::ObjectIdentifier obj; + obj.SetKey(path); + chunks.back().push_back(obj); + } + + void removePaths(Fn auto && remove_chunk_func) + { + for (auto & chunk : chunks) + remove_chunk_func(std::move(chunk)); + } + + Chunks getChunks() const + { + return chunks; + } + + static String getChunkKeys(const Chunk & chunk) + { + String res; + for (const auto & obj : chunk) + { + const auto & key = obj.GetKey(); + if (!res.empty()) + res.append(", "); + res.append(key.c_str(), key.size()); + } + return res; + } + +private: + Chunks chunks; +}; + } #endif diff --git a/src/Disks/S3/registerDiskS3.cpp b/src/Disks/S3/registerDiskS3.cpp index 96630ad72624..474378f4f3e3 100644 --- a/src/Disks/S3/registerDiskS3.cpp +++ b/src/Disks/S3/registerDiskS3.cpp @@ -9,6 +9,8 @@ #if USE_AWS_S3 #include +#include +#include #include #include "DiskS3.h" #include "Disks/DiskCacheWrapper.h" @@ -21,6 +23,7 @@ #include "Disks/RemoteDisksCommon.h" #include + namespace DB { namespace ErrorCodes @@ -46,7 +49,79 @@ void checkReadAccess(const String & disk_name, IDisk & disk) throw Exception("No read access to S3 bucket in disk " + disk_name, ErrorCodes::PATH_ACCESS_DENIED); } -void checkRemoveAccess(IDisk & disk) { disk.removeFile("test_acl"); } +void checkRemoveAccess(IDisk & disk) +{ + disk.removeFile("test_acl"); +} + +bool checkBatchRemoveIsMissing(DiskS3 & disk, std::unique_ptr settings, const String & bucket) +{ + const String path = "_test_remove_objects_capability"; + try + { + auto file = disk.writeFile(path, DBMS_DEFAULT_BUFFER_SIZE, WriteMode::Rewrite); + file->write("test", 4); + file->finalize(); + } + catch (...) + { + try + { + disk.removeFile(path); + } + catch (...) + { + } + return false; /// We don't have write access, therefore no information about batch remove. + } + + /// See `IDiskRemote::removeSharedFile`. + auto fs_paths_keeper = std::dynamic_pointer_cast(disk.createFSPathKeeper()); + disk.removeMetadata(path, fs_paths_keeper); + + auto fs_paths_keeper_copy = std::dynamic_pointer_cast(disk.createFSPathKeeper()); + for (const auto & chunk : fs_paths_keeper->getChunks()) + for (const auto & obj : chunk) + fs_paths_keeper_copy->addPath(obj.GetKey()); + + try + { + /// See `DiskS3::removeFromRemoteFS`. 
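+        /// Replay the keys collected by removeMetadata() as one batch DeleteObjects
+        /// request. Success means the backend supports batch delete; an error
+        /// (e.g. GCS answering NotImplemented) lands in the catch block below,
+        /// which then removes the keys one by one and reports the batch
+        /// capability as missing.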
+ fs_paths_keeper->removePaths([&](S3PathKeeper::Chunk && chunk) + { + String keys = S3PathKeeper::getChunkKeys(chunk); + LOG_TRACE(&Poco::Logger::get("registerDiskS3"), "Remove AWS keys {}", keys); + Aws::S3::Model::Delete delkeys; + delkeys.SetObjects(chunk); + Aws::S3::Model::DeleteObjectsRequest request; + request.SetBucket(bucket); + request.SetDelete(delkeys); + auto outcome = settings->client->DeleteObjects(request); + if (!outcome.IsSuccess()) + { + const auto & err = outcome.GetError(); + throw Exception(err.GetMessage(), static_cast(err.GetErrorType())); + } + }); + return false; + } + catch (const Exception &) + { + fs_paths_keeper_copy->removePaths([&](S3PathKeeper::Chunk && chunk) + { + String keys = S3PathKeeper::getChunkKeys(chunk); + LOG_TRACE(&Poco::Logger::get("registerDiskS3"), "Remove AWS keys {} one by one", keys); + for (const auto & obj : chunk) + { + Aws::S3::Model::DeleteObjectRequest request; + request.SetBucket(bucket); + request.SetKey(obj.GetKey()); + settings->client->DeleteObject(request); + } + }); + return true; + } +} std::shared_ptr getProxyResolverConfiguration( const String & prefix, const Poco::Util::AbstractConfiguration & proxy_resolver_config) @@ -200,8 +275,26 @@ void registerDiskS3(DiskFactory & factory) getSettings(config, config_prefix, context), getSettings); + bool skip_access_check = config.getBool(config_prefix + ".skip_access_check", false); + + if (!skip_access_check) + { + /// If `support_batch_delete` is turned on (default), check and possibly switch it off. + if (s3_capabilities.support_batch_delete && checkBatchRemoveIsMissing(*std::dynamic_pointer_cast(s3disk), getSettings(config, config_prefix, context), uri.bucket)) + { + LOG_WARNING( + &Poco::Logger::get("registerDiskS3"), + "Storage for disk {} does not support batch delete operations, " + "so `s3_capabilities.support_batch_delete` was automatically turned off during the access check. " + "To remove this message set `s3_capabilities.support_batch_delete` for the disk to `false`.", + name + ); + std::dynamic_pointer_cast(s3disk)->setCapabilitiesSupportBatchDelete(false); + } + } + /// This code is used only to check access to the corresponding disk. - if (!config.getBool(config_prefix + ".skip_access_check", false)) + if (!skip_access_check) { checkWriteAccess(*s3disk); checkReadAccess(name, *s3disk); diff --git a/tests/integration/test_merge_tree_s3/configs/config.d/storage_conf.xml b/tests/integration/test_merge_tree_s3/configs/config.d/storage_conf.xml index 2f1b8275a0bb..a6e2d29c5d57 100644 --- a/tests/integration/test_merge_tree_s3/configs/config.d/storage_conf.xml +++ b/tests/integration/test_merge_tree_s3/configs/config.d/storage_conf.xml @@ -15,6 +15,13 @@ minio123 10 + + s3 + http://resolver:8082/root/data/ + minio + minio123 + 10 + local / @@ -46,6 +53,13 @@ + + +
+        <no_delete_objects_s3>
+            <volumes>
+                <main>
+                    <disk>no_delete_objects_s3</disk>
+                </main>
+            </volumes>
+        </no_delete_objects_s3>
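For reviewers: the access-time probe in registerDiskS3.cpp boils down to "try one batch DeleteObjects; if the server rejects it, fall back to per-key deletes and remember that batch delete is unsupported". A minimal standalone sketch of that pattern, assuming boto3 and hypothetical bucket/endpoint names (not part of the patch):

    # Sketch of the fallback pattern from checkBatchRemoveIsMissing() /
    # DiskS3::removeFromRemoteFS; boto3-based, bucket and keys are hypothetical.
    import boto3
    from botocore.exceptions import ClientError

    def remove_keys(client, bucket, keys):
        """Return True if batch delete worked, False if we fell back."""
        try:
            # One DeleteObjects request for the whole chunk
            # (support_batch_delete = true).
            client.delete_objects(
                Bucket=bucket,
                Delete={"Objects": [{"Key": k} for k in keys]},
            )
            return True
        except ClientError:
            # The backend rejected the batch API (as the GCS mock below does):
            # delete the keys one by one (support_batch_delete = false).
            for k in keys:
                client.delete_object(Bucket=bucket, Key=k)
            return False

    client = boto3.client("s3", endpoint_url="http://minio1:9001")  # hypothetical endpoint
    print(remove_keys(client, "my-bucket", ["data/a", "data/b"]))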
diff --git a/tests/integration/test_merge_tree_s3/s3_mocks/no_delete_objects.py b/tests/integration/test_merge_tree_s3/s3_mocks/no_delete_objects.py new file mode 100644 index 000000000000..111f3a490c2b --- /dev/null +++ b/tests/integration/test_merge_tree_s3/s3_mocks/no_delete_objects.py @@ -0,0 +1,92 @@ +import http.client +import http.server +import random +import socketserver +import sys +import urllib.parse + + +UPSTREAM_HOST = "minio1:9001" +random.seed("No delete objects/1.0") + + +def request(command, url, headers={}, data=None): + """Mini-requests.""" + + class Dummy: + pass + + parts = urllib.parse.urlparse(url) + c = http.client.HTTPConnection(parts.hostname, parts.port) + c.request( + command, + urllib.parse.urlunparse(parts._replace(scheme="", netloc="")), + headers=headers, + body=data, + ) + r = c.getresponse() + result = Dummy() + result.status_code = r.status + result.headers = r.headers + result.content = r.read() + return result + + +class RequestHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + if self.path == "/": + self.send_response(200) + self.send_header("Content-Type", "text/plain") + self.end_headers() + self.wfile.write(b"OK") + else: + self.do_HEAD() + + def do_PUT(self): + self.do_HEAD() + + def do_DELETE(self): + self.do_HEAD() + + def do_POST(self): + query = urllib.parse.urlparse(self.path).query + params = urllib.parse.parse_qs(query, keep_blank_values=True) + if "delete" in params: + self.send_response(501) + self.send_header("Content-Type", "application/xml") + self.end_headers() + self.wfile.write( + b""" + + NotImplemented + Ima GCP and I can't do `DeleteObjects` request for ya. See https://issuetracker.google.com/issues/162653700 . + RESOURCE + REQUEST_ID +""" + ) + else: + self.do_HEAD() + + def do_HEAD(self): + content_length = self.headers.get("Content-Length") + data = self.rfile.read(int(content_length)) if content_length else None + r = request( + self.command, + f"http://{UPSTREAM_HOST}{self.path}", + headers=self.headers, + data=data, + ) + self.send_response(r.status_code) + for k, v in r.headers.items(): + self.send_header(k, v) + self.end_headers() + self.wfile.write(r.content) + self.wfile.close() + + +class ThreadedHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer): + """Handle requests in a separate thread.""" + + +httpd = ThreadedHTTPServer(("0.0.0.0", int(sys.argv[1])), RequestHandler) +httpd.serve_forever() diff --git a/tests/integration/test_merge_tree_s3/test.py b/tests/integration/test_merge_tree_s3/test.py index b7ef3ce3ef2e..b04b02246132 100644 --- a/tests/integration/test_merge_tree_s3/test.py +++ b/tests/integration/test_merge_tree_s3/test.py @@ -67,7 +67,10 @@ def create_table(node, table_name, **additional_settings): def run_s3_mocks(cluster): logging.info("Starting s3 mocks") - mocks = (("unstable_proxy.py", "resolver", "8081"),) + mocks = ( + ("unstable_proxy.py", "resolver", "8081"), + ("no_delete_objects.py", "resolver", "8082"), + ) for mock_filename, container, port in mocks: container_id = cluster.get_container_id(container) current_dir = os.path.dirname(__file__) @@ -602,6 +605,15 @@ def restart_disk(): thread.join() +@pytest.mark.parametrize("node_name", ["node"]) +def test_s3_no_delete_objects(cluster, node_name): + node = cluster.instances[node_name] + create_table( + node, "s3_test_no_delete_objects", storage_policy="no_delete_objects_s3" + ) + node.query("DROP TABLE s3_test_no_delete_objects SYNC") + + @pytest.mark.parametrize("node_name", ["node"]) def 
test_s3_disk_reads_on_unstable_connection(cluster, node_name):
    node = cluster.instances[node_name]
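    # The unstable_proxy mock (resolver:8081, registered in run_s3_mocks above)
    # simulates an unstable connection to S3 for this test.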