Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to avoid calculation of scalar subqueries for CREATE TABLE. #60464

Merged
merged 11 commits into from Mar 8, 2024
2 changes: 2 additions & 0 deletions src/Functions/FunctionTokens.h
Expand Up @@ -74,6 +74,8 @@ class FunctionTokens : public IFunction

size_t getNumberOfArguments() const override { return Generator::getNumberOfArguments(); }

ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return Generator::getArgumentsThatAreAlwaysConstant(); }

DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
Generator::checkArguments(*this, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/URL/URLHierarchy.cpp
Expand Up @@ -24,6 +24,8 @@ class URLPathHierarchyImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 1; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/URL/URLPathHierarchy.cpp
Expand Up @@ -22,6 +22,8 @@ class URLHierarchyImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 1; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/URL/extractURLParameterNames.cpp
Expand Up @@ -22,6 +22,8 @@ class ExtractURLParameterNamesImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 1; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/URL/extractURLParameters.cpp
Expand Up @@ -23,6 +23,8 @@ class ExtractURLParametersImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 1; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/alphaTokens.cpp
Expand Up @@ -32,6 +32,8 @@ class SplitByAlphaImpl

static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/extractAll.cpp
Expand Up @@ -50,6 +50,8 @@ class ExtractAllImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 2; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
5 changes: 5 additions & 0 deletions src/Functions/identity.cpp
Expand Up @@ -9,4 +9,9 @@ REGISTER_FUNCTION(Identity)
factory.registerFunction<FunctionIdentity>();
}

REGISTER_FUNCTION(ScalarSubqueryResult)
{
factory.registerFunction<FunctionScalarSubqueryResult>();
}

}
20 changes: 17 additions & 3 deletions src/Functions/identity.h
Expand Up @@ -6,11 +6,12 @@
namespace DB
{

class FunctionIdentity : public IFunction
template<typename Name>
class FunctionIdentityBase : public IFunction
{
public:
static constexpr auto name = "identity";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionIdentity>(); }
static constexpr auto name = Name::name;
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionIdentityBase<Name>>(); }

String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 1; }
Expand All @@ -28,4 +29,17 @@ class FunctionIdentity : public IFunction
}
};

struct IdentityName
{
static constexpr auto name = "identity";
};

struct ScalarSubqueryResultName
{
static constexpr auto name = "__scalarSubqueryResult";
};

using FunctionIdentity = FunctionIdentityBase<IdentityName>;
using FunctionScalarSubqueryResult = FunctionIdentityBase<ScalarSubqueryResultName>;

}
2 changes: 2 additions & 0 deletions src/Functions/splitByChar.cpp
Expand Up @@ -40,6 +40,8 @@ class SplitByCharImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {0, 2}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithSeparatorAndOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/splitByNonAlpha.cpp
Expand Up @@ -42,6 +42,8 @@ class SplitByNonAlphaImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/splitByRegexp.cpp
Expand Up @@ -44,6 +44,8 @@ class SplitByRegexpImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {0, 2}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithSeparatorAndOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/splitByString.cpp
Expand Up @@ -39,6 +39,8 @@ class SplitByStringImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {0, 2}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithSeparatorAndOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/splitByWhitespace.cpp
Expand Up @@ -30,6 +30,8 @@ class SplitByWhitespaceImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithOptionalMaxSubstrings(func, arguments);
Expand Down
44 changes: 44 additions & 0 deletions src/Interpreters/ActionsDAG.cpp
Expand Up @@ -64,6 +64,37 @@ std::pair<ColumnsWithTypeAndName, bool> getFunctionArguments(const ActionsDAG::N
return { std::move(arguments), all_const };
}

bool isConstantFromScalarSubquery(const ActionsDAG::Node * node)
{
std::stack<const ActionsDAG::Node *> stack;
stack.push(node);
while (!stack.empty())
{
const auto * arg = stack.top();
stack.pop();

if (arg->column && isColumnConst(*arg->column))
continue;

while (arg->type == ActionsDAG::ActionType::ALIAS)
arg = arg->children.at(0);

if (arg->type != ActionsDAG::ActionType::FUNCTION)
return false;

if (arg->function_base->getName() == "__scalarSubqueryResult")
continue;

if (arg->children.empty() || !arg->function_base->isSuitableForConstantFolding())
return false;

for (const auto * child : arg->children)
stack.push(child);
}

return true;
}

}

void ActionsDAG::Node::toTree(JSONBuilder::JSONMap & map) const
Expand Down Expand Up @@ -196,6 +227,19 @@ const ActionsDAG::Node & ActionsDAG::addFunction(
{
auto [arguments, all_const] = getFunctionArguments(children);

auto constant_args = function->getArgumentsThatAreAlwaysConstant();
for (size_t pos : constant_args)
{
if (pos >= children.size())
continue;

if (arguments[pos].column && isColumnConst(*arguments[pos].column))
continue;

if (isConstantFromScalarSubquery(children[pos]))
arguments[pos].column = arguments[pos].type->createColumnConstWithDefaultValue(0);
}

auto function_base = function->build(arguments);
return addFunctionImpl(
function_base,
Expand Down
2 changes: 1 addition & 1 deletion src/Interpreters/ExecuteScalarSubqueriesVisitor.cpp
Expand Up @@ -281,7 +281,7 @@ void ExecuteScalarSubqueriesMatcher::visit(const ASTSubquery & subquery, ASTPtr
if (data.only_analyze)
{
ast->as<ASTFunction>()->alias.clear();
auto func = makeASTFunction("identity", std::move(ast));
auto func = makeASTFunction("__scalarSubqueryResult", std::move(ast));
func->alias = subquery_alias;
func->prefer_alias_to_column_name = prefer_alias_to_column_name;
ast = std::move(func);
Expand Down
21 changes: 2 additions & 19 deletions src/Interpreters/InterpreterCreateQuery.cpp
Expand Up @@ -809,24 +809,7 @@ InterpreterCreateQuery::TableProperties InterpreterCreateQuery::getTableProperti
}
else
{
/** To get valid sample block we need to prepare query without only_analyze, because we need to execute scalar
* subqueries. Otherwise functions that expect only constant arguments will throw error during query analysis,
* because the result of scalar subquery is not a constant.
*
* Example:
* CREATE MATERIALIZED VIEW test_mv ENGINE=MergeTree ORDER BY arr
* AS
* WITH (SELECT '\d[a-z]') AS constant_value
* SELECT extractAll(concat(toString(number), 'a'), assumeNotNull(constant_value)) AS arr
* FROM test_table;
*
* For new analyzer this issue does not exists because we always execute scalar subqueries.
* We can improve this in new analyzer, and execute scalar subqueries only in contexts when we expect constant
* for example: LIMIT, OFFSET, functions parameters, functions constant only arguments.
*/

InterpreterSelectWithUnionQuery interpreter(create.select->clone(), getContext(), SelectQueryOptions());
as_select_sample = interpreter.getSampleBlock();
as_select_sample = InterpreterSelectWithUnionQuery::getSampleBlock(create.select->clone(), getContext());
}

properties.columns = ColumnsDescription(as_select_sample.getNamesAndTypesList());
Expand Down Expand Up @@ -1237,7 +1220,7 @@ BlockIO InterpreterCreateQuery::createTable(ASTCreateQuery & create)
{
input_block = InterpreterSelectWithUnionQuery(create.select->clone(),
getContext(),
{}).getSampleBlock();
SelectQueryOptions().analyze()).getSampleBlock();
}

Block output_block = to_table->getInMemoryMetadataPtr()->getSampleBlock();
Expand Down
Expand Up @@ -53,7 +53,7 @@ static bool tryExtractConstValueFromCondition(const ASTPtr & condition, bool & v
}
}
}
else if (function->name == "toUInt8" || function->name == "toInt8" || function->name == "identity")
else if (function->name == "toUInt8" || function->name == "toInt8" || function->name == "identity" || function->name == "__scalarSubqueryResult")
{
if (const auto * expr_list = function->arguments->as<ASTExpressionList>())
{
Expand Down
Expand Up @@ -2,7 +2,7 @@ SELECT 1
WHERE 0
SELECT 1
SELECT 1
WHERE (1 IN (0, 2)) AND (2 = (identity(_CAST(2, \'Nullable(UInt8)\')) AS subquery))
WHERE (1 IN (0, 2)) AND (2 = (__scalarSubqueryResult(_CAST(2, \'Nullable(UInt8)\')) AS subquery))
SELECT 1
WHERE 1 IN ((
SELECT arrayJoin([1, 2, 3])
Expand Down
Expand Up @@ -7,7 +7,7 @@ SELECT parseDateTime64BestEffort('2020-05-14T03:37:03.253184Z', 'bar'); -- {ser
SELECT parseDateTime64BestEffort('2020-05-14T03:37:03.253184Z', 3, 4); -- {serverError 43} -- invalid timezone parameter
SELECT parseDateTime64BestEffort('2020-05-14T03:37:03.253184Z', 3, 'baz'); -- {serverError BAD_ARGUMENTS} -- unknown timezone

SELECT parseDateTime64BestEffort('2020-05-14T03:37:03.253184Z', materialize(3), 4); -- {serverError 44} -- non-const precision
SELECT parseDateTime64BestEffort('2020-05-14T03:37:03.253184Z', materialize(3), 4); -- {serverError 43} -- non-const precision
SELECT parseDateTime64BestEffort('2020-05-14T03:37:03.253184Z', 3, materialize('UTC')); -- {serverError 44} -- non-const timezone

SELECT parseDateTime64BestEffort('2020-05-14T03:37:03.253184012345678910111213141516171819Z', 3, 'UTC'); -- {serverError 6}
Expand Down
Expand Up @@ -5,7 +5,13 @@ SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) FO
1,10
EXPLAIN SYNTAX SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n);
SELECT
identity(_CAST(0, \'Nullable(UInt64)\')) AS n,
__scalarSubqueryResult(_CAST(0, \'Nullable(UInt64)\')) AS n,
toUInt64(10 / n)
SELECT * FROM (WITH (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) as q SELECT * FROM system.one WHERE q > 0);
0
SELECT * FROM (SELECT (SELECT '\d[a-z]') AS n, extractAll('5abc', assumeNotNull(n))) FORMAT CSV;
"\d[a-z]","['5a']"
EXPLAIN SYNTAX SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n);
SELECT
__scalarSubqueryResult(_CAST(0, \'Nullable(UInt64)\')) AS n,
toUInt64(10 / n)
Expand Up @@ -3,3 +3,6 @@ SELECT * FROM (SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUI
SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) FORMAT CSV;
EXPLAIN SYNTAX SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n);
SELECT * FROM (WITH (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) as q SELECT * FROM system.one WHERE q > 0);

SELECT * FROM (SELECT (SELECT '\d[a-z]') AS n, extractAll('5abc', assumeNotNull(n))) FORMAT CSV;
EXPLAIN SYNTAX SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n);
Expand Up @@ -62,6 +62,7 @@ __bitBoolMaskOr
__bitSwapLastTwo
__bitWrapperFunc
__getScalar
__scalarSubqueryResult
abs
accurateCast
accurateCastOrDefault
Expand Down
66 changes: 66 additions & 0 deletions tests/queries/0_stateless/02999_scalar_subqueries_bug_1.reference
@@ -0,0 +1,66 @@
0 0
0 0
0 0
0 0
1 \N
1 \N
2 \N
2 \N
3 \N
3 \N
4 \N
4 \N
5 \N
5 \N
6 \N
6 \N
7 \N
7 \N
8 \N
8 \N
9 \N
9 \N
10 10
10 10
10 10
10 10
11 \N
11 \N
12 \N
12 \N
13 \N
13 \N
14 \N
14 \N
15 \N
15 \N
16 \N
16 \N
17 \N
17 \N
18 \N
18 \N
19 \N
19 \N
20 20
20 20
20 20
20 20
21 \N
21 \N
22 \N
22 \N
23 \N
23 \N
24 \N
24 \N
25 \N
25 \N
26 \N
26 \N
27 \N
27 \N
28 \N
28 \N
29 \N
29 \N
8 changes: 8 additions & 0 deletions tests/queries/0_stateless/02999_scalar_subqueries_bug_1.sql
@@ -0,0 +1,8 @@
drop table if exists t_table_select;
CREATE TABLE t_table_select (id UInt32) ENGINE = MergeTree ORDER BY id;
INSERT INTO t_table_select (id) SELECT number FROM numbers(30);

CREATE TEMPORARY TABLE t_test AS SELECT a.id, b.id FROM remote('127.0.0.{1,2}', currentDatabase(), t_table_select) AS a GLOBAL LEFT JOIN (SELECT id FROM remote('127.0.0.{1,2}', currentDatabase(), t_table_select) AS b WHERE (b.id % 10) = 0) AS b ON b.id = a.id SETTINGS join_use_nulls = 1;

select * from t_test order by id;

Empty file.