Skip to content

Commit

Permalink
Merge pull request #50076 from FFFFFFFHHHHHHH/jaccard_similarity
Browse files Browse the repository at this point in the history
Add function arrayJaccardIndex
  • Loading branch information
rschu1ze committed Jul 17, 2023
2 parents 036fb1f + 0895e47 commit 9d7737b
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 15 deletions.
18 changes: 18 additions & 0 deletions docs/en/sql-reference/functions/array-functions.md
Expand Up @@ -996,6 +996,24 @@ SELECT
└──────────────┴───────────┘
```

## arrayJaccardIndex

Returns the [Jaccard index](https://en.wikipedia.org/wiki/Jaccard_index) of two arrays.

**Example**

Query:
``` sql
SELECT arrayJaccardIndex([1, 2], [2, 3]) AS res
```

Result:
``` text
┌─res────────────────┐
│ 0.3333333333333333 │
└────────────────────┘
```

## arrayReduce

Applies an aggregate function to array elements and returns its result. The name of the aggregation function is passed as a string in single quotes `'max'`, `'sum'`. When using parametric aggregate functions, the parameter is indicated after the function name in parentheses `'uniqUpTo(6)'`.
Expand Down
38 changes: 23 additions & 15 deletions src/DataTypes/IDataType.h
Expand Up @@ -410,21 +410,29 @@ inline bool isDateTime(const T & data_type) { return WhichDataType(data_type).is
/// Returns the result of WhichDataType(data_type).isDateTime64() for the given type object.
template <typename T>
inline bool isDateTime64(const T & data_type) { return WhichDataType(data_type).isDateTime64(); }

inline bool isEnum(const DataTypePtr & data_type) { return WhichDataType(data_type).isEnum(); }
inline bool isDecimal(const DataTypePtr & data_type) { return WhichDataType(data_type).isDecimal(); }
inline bool isTuple(const DataTypePtr & data_type) { return WhichDataType(data_type).isTuple(); }
inline bool isArray(const DataTypePtr & data_type) { return WhichDataType(data_type).isArray(); }
inline bool isMap(const DataTypePtr & data_type) {return WhichDataType(data_type).isMap(); }
inline bool isInterval(const DataTypePtr & data_type) {return WhichDataType(data_type).isInterval(); }
inline bool isNothing(const DataTypePtr & data_type) { return WhichDataType(data_type).isNothing(); }
inline bool isUUID(const DataTypePtr & data_type) { return WhichDataType(data_type).isUUID(); }
inline bool isIPv4(const DataTypePtr & data_type) { return WhichDataType(data_type).isIPv4(); }
inline bool isIPv6(const DataTypePtr & data_type) { return WhichDataType(data_type).isIPv6(); }

template <typename T>
inline bool isObject(const T & data_type)
{
return WhichDataType(data_type).isObject();
/// Type predicates. Each helper constructs a WhichDataType from `data_type` and forwards to the
/// WhichDataType member of the same name, so T can be any type WhichDataType can be constructed from.
/// Being templates, they can also be instantiated explicitly (e.g. isArray<IDataType>) when a function
/// pointer is needed.
template <typename T>
inline bool isEnum(const T & data_type) { return WhichDataType(data_type).isEnum(); }
template <typename T>
inline bool isDecimal(const T & data_type) { return WhichDataType(data_type).isDecimal(); }
template <typename T>
inline bool isTuple(const T & data_type) { return WhichDataType(data_type).isTuple(); }
template <typename T>
inline bool isArray(const T & data_type) { return WhichDataType(data_type).isArray(); }
template <typename T>
inline bool isMap(const T & data_type) {return WhichDataType(data_type).isMap(); }
template <typename T>
inline bool isInterval(const T & data_type) {return WhichDataType(data_type).isInterval(); }
template <typename T>
inline bool isNothing(const T & data_type) { return WhichDataType(data_type).isNothing(); }
template <typename T>
inline bool isUUID(const T & data_type) { return WhichDataType(data_type).isUUID(); }
template <typename T>
inline bool isIPv4(const T & data_type) { return WhichDataType(data_type).isIPv4(); }
template <typename T>
inline bool isIPv6(const T & data_type) { return WhichDataType(data_type).isIPv6(); }

/// Same pattern as the predicates above, for WhichDataType::isObject().
template <typename T>
inline bool isObject(const T & data_type) { return WhichDataType(data_type).isObject();
}

template <typename T>
Expand Down
161 changes: 161 additions & 0 deletions src/Functions/array/arrayJaccardIndex.cpp
@@ -0,0 +1,161 @@
#include <Columns/ColumnArray.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeNothing.h>
#include <DataTypes/getMostSubtype.h>
#include <Core/ColumnsWithTypeAndName.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Interpreters/Context_fwd.h>
#include <base/types.h>

namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int LOGICAL_ERROR;
}

class FunctionArrayJaccardIndex : public IFunction
{
private:
using ResultType = Float64;

struct LeftAndRightSizes
{
size_t left_size;
size_t right_size;
};

template <bool left_is_const, bool right_is_const>
static LeftAndRightSizes getArraySizes(const ColumnArray::Offsets & left_offsets, const ColumnArray::Offsets & right_offsets, size_t i)
{
size_t left_size;
size_t right_size;

if constexpr (left_is_const)
left_size = left_offsets[0];
else
left_size = left_offsets[i] - left_offsets[i - 1];

if constexpr (right_is_const)
right_size = right_offsets[0];
else
right_size = right_offsets[i] - right_offsets[i - 1];

return {left_size, right_size};
}

template <bool left_is_const, bool right_is_const>
static void vector(const ColumnArray::Offsets & intersect_offsets, const ColumnArray::Offsets & left_offsets, const ColumnArray::Offsets & right_offsets, PaddedPODArray<ResultType> & res)
{
for (size_t i = 0; i < res.size(); ++i)
{
LeftAndRightSizes sizes = getArraySizes<left_is_const, right_is_const>(left_offsets, right_offsets, i);
size_t intersect_size = intersect_offsets[i] - intersect_offsets[i - 1];
res[i] = static_cast<ResultType>(intersect_size) / (sizes.left_size + sizes.right_size - intersect_size);
}
}

template <bool left_is_const, bool right_is_const>
static void vectorWithEmptyIntersect(const ColumnArray::Offsets & left_offsets, const ColumnArray::Offsets & right_offsets, PaddedPODArray<ResultType> & res)
{
for (size_t i = 0; i < res.size(); ++i)
{
LeftAndRightSizes sizes = getArraySizes<left_is_const, right_is_const>(left_offsets, right_offsets, i);
if (sizes.left_size == 0 && sizes.right_size == 0)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "array aggregate functions cannot be performed on two empty arrays");
res[i] = 0;
}
}

public:
static constexpr auto name = "arrayJaccardIndex";
String getName() const override { return name; }
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionArrayJaccardIndex>(context_); }
explicit FunctionArrayJaccardIndex(ContextPtr context_) : context(context_) {}
size_t getNumberOfArguments() const override { return 2; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }

DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
FunctionArgumentDescriptors args{
{"array_1", &isArray<IDataType>, nullptr, "Array"},
{"array_2", &isArray<IDataType>, nullptr, "Array"},
};
validateFunctionArgumentTypes(*this, arguments, args);
return std::make_shared<DataTypeNumber<ResultType>>();
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
auto cast_to_array = [&](const ColumnWithTypeAndName & col) -> std::pair<const ColumnArray *, bool>
{
if (const ColumnConst * col_const = typeid_cast<const ColumnConst *>(col.column.get()))
{
const ColumnArray * col_const_array = checkAndGetColumn<ColumnArray>(col_const->getDataColumnPtr().get());
return {col_const_array, true};
}
else if (const ColumnArray * col_non_const_array = checkAndGetColumn<ColumnArray>(col.column.get()))
return {col_non_const_array, false};
else
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Argument for function {} must be array but it has type {}.", col.column->getName(), getName());
};

const auto & [left_array, left_is_const] = cast_to_array(arguments[0]);
const auto & [right_array, right_is_const] = cast_to_array(arguments[1]);

auto intersect_array = FunctionFactory::instance().get("arrayIntersect", context)->build(arguments);

ColumnWithTypeAndName intersect_column;
intersect_column.type = intersect_array->getResultType();
intersect_column.column = intersect_array->execute(arguments, intersect_column.type, input_rows_count);

const auto * intersect_column_type = checkAndGetDataType<DataTypeArray>(intersect_column.type.get());
if (!intersect_column_type)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected return type for function arrayIntersect");

auto col_res = ColumnVector<ResultType>::create();
typename ColumnVector<ResultType>::Container & vec_res = col_res->getData();
vec_res.resize(input_rows_count);

#define EXECUTE_VECTOR(left_is_const, right_is_const) \
if (typeid_cast<const DataTypeNothing *>(intersect_column_type->getNestedType().get())) \
vectorWithEmptyIntersect<left_is_const, right_is_const>(left_array->getOffsets(), right_array->getOffsets(), vec_res); \
else \
{ \
const ColumnArray * intersect_column_array = checkAndGetColumn<ColumnArray>(intersect_column.column.get()); \
vector<left_is_const, right_is_const>(intersect_column_array->getOffsets(), left_array->getOffsets(), right_array->getOffsets(), vec_res); \
}

if (!left_is_const && !right_is_const)
EXECUTE_VECTOR(false, false)
else if (!left_is_const && right_is_const)
EXECUTE_VECTOR(false, true)
else if (left_is_const && !right_is_const)
EXECUTE_VECTOR(true, false)
else
EXECUTE_VECTOR(true, true)

#undef EXECUTE_VECTOR

return col_res;
}

private:
ContextPtr context;
};

/// Registers FunctionArrayJaccardIndex with the function factory.
REGISTER_FUNCTION(ArrayJaccardIndex)
{
factory.registerFunction<FunctionArrayJaccardIndex>();
}

}
Expand Up @@ -112,6 +112,7 @@ arrayFirstIndex
arrayFirstOrNull
arrayFlatten
arrayIntersect
arrayJaccardIndex
arrayJoin
arrayLast
arrayLastIndex
Expand Down
23 changes: 23 additions & 0 deletions tests/queries/0_stateless/02737_arrayJaccardIndex.reference
@@ -0,0 +1,23 @@
negative tests
const arguments
[1,2] [1,2,3,4] 0.5
[1,1.1,2.2] [2.2,3.3,444] 0.2
[1] [1] 1
['a'] ['a','aa','aaa'] 0.33
[[1,2],[3,4]] [[1,2],[3,5]] 0.33
non-const arguments
[1] [1,2] 0.5
[1,2] [1,2] 1
[1,2,3] [1,2] 0.67
[1] [] 0
[1,2] [] 0
[1,2,3] [] 0
[1,2] [1] 0.5
[1,2] [1,2] 1
[1,2] [1,2,3] 0.67
[] [1] 0
[] [1,2] 0
[] [1,2,3] 0
[1] [1] 1
[1,2] [1,2] 1
[1,2,3] [1,2,3] 1
30 changes: 30 additions & 0 deletions tests/queries/0_stateless/02737_arrayJaccardIndex.sql
@@ -0,0 +1,30 @@
-- Tests for arrayJaccardIndex. Results are rounded to 2 digits to keep the reference output stable.

SELECT 'negative tests';

-- Non-array arguments, two empty constant arrays, and arrays without a common element type must all fail.
SELECT 'a' AS arr1, 2 AS arr2, round(arrayJaccardIndex(arr1, arr2), 2); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT [] AS arr1, [] AS arr2, round(arrayJaccardIndex(arr1, arr2), 2); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT ['1', '2'] AS arr1, [1,2] AS arr2, round(arrayJaccardIndex(arr1, arr2), 2); -- { serverError NO_COMMON_TYPE }

SELECT 'const arguments';

-- Both arguments are literals: integers, mixed integer/float, different integer widths, strings, nested arrays.
SELECT [1,2] AS arr1, [1,2,3,4] AS arr2, round(arrayJaccardIndex(arr1, arr2), 2);
SELECT [1, 1.1, 2.2] AS arr1, [2.2, 3.3, 444] AS arr2, round(arrayJaccardIndex(arr1, arr2), 2);
SELECT [toUInt16(1)] AS arr1, [toUInt32(1)] AS arr2, round(arrayJaccardIndex(arr1, arr2), 2);
SELECT ['a'] AS arr1, ['a', 'aa', 'aaa'] AS arr2, round(arrayJaccardIndex(arr1, arr2), 2);
SELECT [[1,2], [3,4]] AS arr1, [[1,2], [3,5]] AS arr2, round(arrayJaccardIndex(arr1, arr2), 2);

SELECT 'non-const arguments';

DROP TABLE IF EXISTS array_jaccard_index;

CREATE TABLE array_jaccard_index (arr Array(UInt8)) engine = MergeTree ORDER BY arr;
INSERT INTO array_jaccard_index values ([1,2,3]);
INSERT INTO array_jaccard_index values ([1,2]);
INSERT INTO array_jaccard_index values ([1]);

-- Table column combined with a literal on either side (including an empty literal), then column with itself.
SELECT arr, [1,2] AS other, round(arrayJaccardIndex(arr, other), 2) FROM array_jaccard_index ORDER BY arr;
SELECT arr, [] AS other, round(arrayJaccardIndex(arr, other), 2) FROM array_jaccard_index ORDER BY arr;
SELECT [1,2] AS other, arr, round(arrayJaccardIndex(other, arr), 2) FROM array_jaccard_index ORDER BY arr;
SELECT [] AS other, arr, round(arrayJaccardIndex(other, arr), 2) FROM array_jaccard_index ORDER BY arr;
SELECT arr, arr, round(arrayJaccardIndex(arr, arr), 2) FROM array_jaccard_index ORDER BY arr;

DROP TABLE array_jaccard_index;
2 changes: 2 additions & 0 deletions utils/check-style/aspell-ignore/en/aspell-dict.txt
Expand Up @@ -1035,6 +1035,7 @@ arrayFirst
arrayFirstIndex
arrayFlatten
arrayIntersect
arrayJaccardIndex
arrayJoin
arrayLast
arrayLastIndex
Expand Down Expand Up @@ -1608,6 +1609,7 @@ isNull
isValidJSON
isValidUTF
iteratively
jaccard
javaHash
javaHashUTF
jbod
Expand Down

0 comments on commit 9d7737b

Please sign in to comment.