From 00a96de848f0e039ae804787cc2227b5374c5ceb Mon Sep 17 00:00:00 2001 From: robot-clickhouse Date: Tue, 20 Feb 2024 11:07:38 +0000 Subject: [PATCH] Backport #60150 to 23.11: Fix cosineDistance crash with Nullable --- .../en/sql-reference/functions/distance-functions.md | 2 +- src/Functions/vectorFunctions.cpp | 12 ++++++------ .../02994_cosineDistanceNullable.reference | 11 +++++++++++ .../0_stateless/02994_cosineDistanceNullable.sql | 3 +++ 4 files changed, 21 insertions(+), 7 deletions(-) create mode 100644 tests/queries/0_stateless/02994_cosineDistanceNullable.reference create mode 100644 tests/queries/0_stateless/02994_cosineDistanceNullable.sql diff --git a/docs/en/sql-reference/functions/distance-functions.md b/docs/en/sql-reference/functions/distance-functions.md index 1774c22014d6..e20c35c6b6f1 100644 --- a/docs/en/sql-reference/functions/distance-functions.md +++ b/docs/en/sql-reference/functions/distance-functions.md @@ -509,7 +509,7 @@ Result: ## cosineDistance -Calculates the cosine distance between two vectors (the values of the tuples are the coordinates). The less the returned value is, the more similar are the vectors. +Calculates the cosine distance between two vectors (the values of the tuples are the coordinates). The smaller the returned value is, the more similar are the vectors. **Syntax** diff --git a/src/Functions/vectorFunctions.cpp b/src/Functions/vectorFunctions.cpp index 33b0e9f60393..de4a6fb0a5cd 100644 --- a/src/Functions/vectorFunctions.cpp +++ b/src/Functions/vectorFunctions.cpp @@ -1,9 +1,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -1364,11 +1364,11 @@ class FunctionCosineDistance : public ITupleFunction ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { - if (getReturnTypeImpl(arguments)->isNullable()) - { - return DataTypeNullable(std::make_shared()) - .createColumnConstWithDefaultValue(input_rows_count); - } + /// TODO: cosineDistance does not support nullable arguments + /// https://github.com/ClickHouse/ClickHouse/pull/27933#issuecomment-916670286 + auto return_type = getReturnTypeImpl(arguments); + if (return_type->isNullable()) + return return_type->createColumnConstWithDefaultValue(input_rows_count); FunctionDotProduct dot(context); ColumnWithTypeAndName dot_result{dot.executeImpl(arguments, DataTypePtr(), input_rows_count), diff --git a/tests/queries/0_stateless/02994_cosineDistanceNullable.reference b/tests/queries/0_stateless/02994_cosineDistanceNullable.reference new file mode 100644 index 000000000000..e4fe1f97e7e9 --- /dev/null +++ b/tests/queries/0_stateless/02994_cosineDistanceNullable.reference @@ -0,0 +1,11 @@ +\N +\N +\N +\N +\N +\N +\N +\N +\N +\N +\N diff --git a/tests/queries/0_stateless/02994_cosineDistanceNullable.sql b/tests/queries/0_stateless/02994_cosineDistanceNullable.sql new file mode 100644 index 000000000000..a62216982f39 --- /dev/null +++ b/tests/queries/0_stateless/02994_cosineDistanceNullable.sql @@ -0,0 +1,3 @@ +-- https://github.com/ClickHouse/ClickHouse/issues/59596 +SELECT cosineDistance((1, 1), (toNullable(0.5), 0.1)); +SELECT cosineDistance((1, 1), (toNullable(0.5), 0.1)) from numbers(10);