Skip to content

Commit

Permalink
Cosmetics
Browse files Browse the repository at this point in the history
  • Loading branch information
rschu1ze committed Nov 7, 2023
1 parent 842cc36 commit 7223d49
Showing 1 changed file with 23 additions and 28 deletions.
51 changes: 23 additions & 28 deletions src/Functions/array/arrayRandomSample.cpp
@@ -1,11 +1,11 @@
#include <random>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Poco/Logger.h>
#include "Columns/ColumnsNumber.h"

namespace DB
{
Expand Down Expand Up @@ -42,22 +42,21 @@ class FunctionArrayRandomSample : public IFunction
// Return an array with the same nested type as the input array
const DataTypePtr & array_type = arguments[0].type;
const DataTypeArray * array_data_type = checkAndGetDataType<DataTypeArray>(array_type.get());

// Get the nested data type of the array
const DataTypePtr & nested_type = array_data_type->getNestedType();

return std::make_shared<DataTypeArray>(nested_type);
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const ColumnArray * column_array = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
if (!column_array)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument must be an array");
const ColumnArray * col_array = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
if (!col_array)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be an array", getName());

const IColumn * col_samples = arguments[1].column.get();
if (!col_samples)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "The second argument is empty or null, type = {}", arguments[1].type->getName());
throw Exception(ErrorCodes::ILLEGAL_COLUMN,
"The second argument of function {} is empty or null, type = {}",
getName(), arguments[1].type->getName());

UInt64 samples;
try
Expand All @@ -74,39 +73,35 @@ class FunctionArrayRandomSample : public IFunction
std::random_device rd;
std::mt19937 gen(rd());

auto nested_column = column_array->getDataPtr()->cloneEmpty();
auto offsets_column = ColumnUInt64::create();

auto res_data = ColumnArray::create(std::move(nested_column), std::move(offsets_column));
auto col_res_data = col_array->getDataPtr()->cloneEmpty();
auto col_res_offsets = ColumnUInt64::create(input_rows_count);
auto col_res = ColumnArray::create(std::move(col_res_data), std::move(col_res_offsets));

const auto & input_offsets = column_array->getOffsets();
auto & res_offsets = res_data->getOffsets();
res_offsets.resize(input_rows_count);
const auto & array_offsets = col_array->getOffsets();
auto & res_offsets = col_res->getOffsets();

UInt64 cur_samples;
size_t current_offset = 0;
std::vector<size_t> indices;
size_t prev_array_offset = 0;

for (size_t row = 0; row < input_rows_count; row++)
{
size_t row_size = input_offsets[row] - current_offset;
const size_t num_elements = array_offsets[row] - prev_array_offset;

std::vector<size_t> indices(row_size);
indices.resize(num_elements);
std::iota(indices.begin(), indices.end(), 0);
std::shuffle(indices.begin(), indices.end(), gen);

cur_samples = std::min(samples, static_cast<UInt64>(row_size));
const size_t cur_samples = std::min(num_elements, samples);

for (UInt64 j = 0; j < cur_samples; j++)
{
size_t source_index = indices[j];
res_data->getData().insertFrom(column_array->getData(), source_index);
}
for (UInt64 i = 0; i < cur_samples; i++)
col_res->getData().insertFrom(col_array->getData(), indices[i]);

res_offsets[row] = current_offset + cur_samples;
current_offset += cur_samples;
res_offsets[row] = prev_array_offset + cur_samples;
prev_array_offset += cur_samples;
indices.clear();
}

return res_data;
return col_res;
}
};

Expand Down

0 comments on commit 7223d49

Please sign in to comment.