Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry pick #60464 to 23.8: Try to avoid calculation of scalar subqueries for CREATE TABLE. #61081

Closed
130 changes: 76 additions & 54 deletions src/Analyzer/Passes/QueryAnalysisPass.cpp
Expand Up @@ -1082,6 +1082,8 @@ class TableExpressionsAliasVisitor : public InDepthQueryTreeVisitor<TableExpress
class QueryAnalyzer
{
public:
explicit QueryAnalyzer(bool only_analyze_) : only_analyze(only_analyze_) {}

void resolve(QueryTreeNodePtr & node, const QueryTreeNodePtr & table_expression, ContextPtr context)
{
IdentifierResolveScope scope(node, nullptr /*parent_scope*/);
Expand Down Expand Up @@ -1444,6 +1446,7 @@ class QueryAnalyzer
/// Global scalar subquery to scalar value map
std::unordered_map<QueryTreeNodePtrWithHash, Block> scalar_subquery_to_scalar_value;

const bool only_analyze;
};

/// Utility functions implementation
Expand Down Expand Up @@ -1991,80 +1994,96 @@ void QueryAnalyzer::evaluateScalarSubqueryIfNeeded(QueryTreeNodePtr & node, Iden
auto interpreter = std::make_unique<InterpreterSelectQueryAnalyzer>(node->toAST(), subquery_context, subquery_context->getViewSource(), options);

auto io = interpreter->execute();

PullingAsyncPipelineExecutor executor(io.pipeline);
io.pipeline.setProgressCallback(context->getProgressCallback());
io.pipeline.setProcessListElement(context->getProcessListElement());

Block block;

while (block.rows() == 0 && executor.pull(block))
{
}

if (block.rows() == 0)
if (only_analyze)
{
auto types = interpreter->getSampleBlock().getDataTypes();
if (types.size() != 1)
types = {std::make_shared<DataTypeTuple>(types)};

auto & type = types[0];
if (!type->isNullable())
/// If query is only analyzed, then constants are not correct.
scalar_block = interpreter->getSampleBlock();
for (auto & column : scalar_block)
{
if (!type->canBeInsideNullable())
throw Exception(ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY,
"Scalar subquery returned empty result of type {} which cannot be Nullable",
type->getName());

type = makeNullable(type);
if (column.column->empty())
{
auto mut_col = column.column->cloneEmpty();
mut_col->insertDefault();
column.column = std::move(mut_col);
}
}

auto scalar_column = type->createColumn();
scalar_column->insert(Null());
scalar_block.insert({std::move(scalar_column), type, "null"});
}
else
{
if (block.rows() != 1)
throw Exception(ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY, "Scalar subquery returned more than one row");
Block block;

Block tmp_block;
while (tmp_block.rows() == 0 && executor.pull(tmp_block))
while (block.rows() == 0 && executor.pull(block))
{
}

if (tmp_block.rows() != 0)
throw Exception(ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY, "Scalar subquery returned more than one row");

block = materializeBlock(block);
size_t columns = block.columns();

if (columns == 1)
if (block.rows() == 0)
{
auto & column = block.getByPosition(0);
/// Here we wrap type to nullable if we can.
/// It is needed cause if subquery return no rows, it's result will be Null.
/// In case of many columns, do not check it cause tuple can't be nullable.
if (!column.type->isNullable() && column.type->canBeInsideNullable())
auto types = interpreter->getSampleBlock().getDataTypes();
if (types.size() != 1)
types = {std::make_shared<DataTypeTuple>(types)};

auto & type = types[0];
if (!type->isNullable())
{
column.type = makeNullable(column.type);
column.column = makeNullable(column.column);
if (!type->canBeInsideNullable())
throw Exception(ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY,
"Scalar subquery returned empty result of type {} which cannot be Nullable",
type->getName());

type = makeNullable(type);
}

scalar_block = block;
auto scalar_column = type->createColumn();
scalar_column->insert(Null());
scalar_block.insert({std::move(scalar_column), type, "null"});
}
else
{
/** Make unique column names for tuple.
*
* Example: SELECT (SELECT 2 AS x, x)
*/
makeUniqueColumnNamesInBlock(block);
if (block.rows() != 1)
throw Exception(ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY, "Scalar subquery returned more than one row");

Block tmp_block;
while (tmp_block.rows() == 0 && executor.pull(tmp_block))
{
}

scalar_block.insert({
ColumnTuple::create(block.getColumns()),
std::make_shared<DataTypeTuple>(block.getDataTypes(), block.getNames()),
"tuple"});
if (tmp_block.rows() != 0)
throw Exception(ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY, "Scalar subquery returned more than one row");

block = materializeBlock(block);
size_t columns = block.columns();

if (columns == 1)
{
auto & column = block.getByPosition(0);
/// Here we wrap type to nullable if we can.
/// It is needed cause if subquery return no rows, it's result will be Null.
/// In case of many columns, do not check it cause tuple can't be nullable.
if (!column.type->isNullable() && column.type->canBeInsideNullable())
{
column.type = makeNullable(column.type);
column.column = makeNullable(column.column);
}

scalar_block = block;
}
else
{
/** Make unique column names for tuple.
*
* Example: SELECT (SELECT 2 AS x, x)
*/
makeUniqueColumnNamesInBlock(block);

scalar_block.insert({
ColumnTuple::create(block.getColumns()),
std::make_shared<DataTypeTuple>(block.getDataTypes(), block.getNames()),
"tuple"});
}
}
}

Expand Down Expand Up @@ -7786,13 +7805,16 @@ void QueryAnalyzer::resolveUnion(const QueryTreeNodePtr & union_node, Identifier

}

QueryAnalysisPass::QueryAnalysisPass(QueryTreeNodePtr table_expression_)
QueryAnalysisPass::QueryAnalysisPass(QueryTreeNodePtr table_expression_, bool only_analyze_)
: table_expression(std::move(table_expression_))
, only_analyze(only_analyze_)
{}

QueryAnalysisPass::QueryAnalysisPass(bool only_analyze_) : only_analyze(only_analyze_) {}

void QueryAnalysisPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context)
{
QueryAnalyzer analyzer;
QueryAnalyzer analyzer(only_analyze);
analyzer.resolve(query_tree_node, table_expression, context);
createUniqueTableAliases(query_tree_node, table_expression, context);
}
Expand Down
5 changes: 3 additions & 2 deletions src/Analyzer/Passes/QueryAnalysisPass.h
Expand Up @@ -71,13 +71,13 @@ class QueryAnalysisPass final : public IQueryTreePass
/** Construct query analysis pass for query or union analysis.
* Available columns are extracted from query node join tree.
*/
QueryAnalysisPass() = default;
explicit QueryAnalysisPass(bool only_analyze_ = false);

/** Construct query analysis pass for expression or list of expressions analysis.
* Available expression columns are extracted from table expression.
* Table expression node must have query, union, table, table function type.
*/
explicit QueryAnalysisPass(QueryTreeNodePtr table_expression_);
QueryAnalysisPass(QueryTreeNodePtr table_expression_, bool only_analyze_ = false);

String getName() override
{
Expand All @@ -93,6 +93,7 @@ class QueryAnalysisPass final : public IQueryTreePass

private:
QueryTreeNodePtr table_expression;
const bool only_analyze;
};

}
4 changes: 2 additions & 2 deletions src/Analyzer/QueryTreePassManager.cpp
Expand Up @@ -246,9 +246,9 @@ void QueryTreePassManager::dump(WriteBuffer & buffer, size_t up_to_pass_index)
}
}

void addQueryTreePasses(QueryTreePassManager & manager)
void addQueryTreePasses(QueryTreePassManager & manager, bool only_analyze)
{
manager.addPass(std::make_unique<QueryAnalysisPass>());
manager.addPass(std::make_unique<QueryAnalysisPass>(only_analyze));
manager.addPass(std::make_unique<GroupingFunctionsResolvePass>());

manager.addPass(std::make_unique<RemoveUnusedProjectionColumnsPass>());
Expand Down
2 changes: 1 addition & 1 deletion src/Analyzer/QueryTreePassManager.h
Expand Up @@ -47,6 +47,6 @@ class QueryTreePassManager : public WithContext
std::vector<QueryTreePassPtr> passes;
};

void addQueryTreePasses(QueryTreePassManager & manager);
void addQueryTreePasses(QueryTreePassManager & manager, bool only_analyze = false);

}
2 changes: 2 additions & 0 deletions src/Functions/FunctionTokens.h
Expand Up @@ -74,6 +74,8 @@ class FunctionTokens : public IFunction

size_t getNumberOfArguments() const override { return Generator::getNumberOfArguments(); }

ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return Generator::getArgumentsThatAreAlwaysConstant(); }

DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
Generator::checkArguments(*this, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/URL/URLHierarchy.cpp
Expand Up @@ -24,6 +24,8 @@ class URLPathHierarchyImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 1; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/URL/URLPathHierarchy.cpp
Expand Up @@ -22,6 +22,8 @@ class URLHierarchyImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 1; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/URL/extractURLParameterNames.cpp
Expand Up @@ -22,6 +22,8 @@ class ExtractURLParameterNamesImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 1; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/URL/extractURLParameters.cpp
Expand Up @@ -23,6 +23,8 @@ class ExtractURLParametersImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 1; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/alphaTokens.cpp
Expand Up @@ -32,6 +32,8 @@ class SplitByAlphaImpl

static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/extractAll.cpp
Expand Up @@ -50,6 +50,8 @@ class ExtractAllImpl
static bool isVariadic() { return false; }
static size_t getNumberOfArguments() { return 2; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
FunctionArgumentDescriptors mandatory_args{
Expand Down
5 changes: 5 additions & 0 deletions src/Functions/identity.cpp
Expand Up @@ -9,4 +9,9 @@ REGISTER_FUNCTION(Identity)
factory.registerFunction<FunctionIdentity>();
}

REGISTER_FUNCTION(ScalarSubqueryResult)
{
factory.registerFunction<FunctionScalarSubqueryResult>();
}

}
20 changes: 17 additions & 3 deletions src/Functions/identity.h
Expand Up @@ -6,11 +6,12 @@
namespace DB
{

class FunctionIdentity : public IFunction
template<typename Name>
class FunctionIdentityBase : public IFunction
{
public:
static constexpr auto name = "identity";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionIdentity>(); }
static constexpr auto name = Name::name;
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionIdentityBase<Name>>(); }

String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 1; }
Expand All @@ -28,4 +29,17 @@ class FunctionIdentity : public IFunction
}
};

struct IdentityName
{
static constexpr auto name = "identity";
};

struct ScalarSubqueryResultName
{
static constexpr auto name = "__scalarSubqueryResult";
};

using FunctionIdentity = FunctionIdentityBase<IdentityName>;
using FunctionScalarSubqueryResult = FunctionIdentityBase<ScalarSubqueryResultName>;

}
2 changes: 2 additions & 0 deletions src/Functions/splitByChar.cpp
Expand Up @@ -40,6 +40,8 @@ class SplitByCharImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {0, 2}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithSeparatorAndOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/splitByNonAlpha.cpp
Expand Up @@ -42,6 +42,8 @@ class SplitByNonAlphaImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/splitByRegexp.cpp
Expand Up @@ -44,6 +44,8 @@ class SplitByRegexpImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {0, 2}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithSeparatorAndOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/splitByString.cpp
Expand Up @@ -39,6 +39,8 @@ class SplitByStringImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {0, 2}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithSeparatorAndOptionalMaxSubstrings(func, arguments);
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/splitByWhitespace.cpp
Expand Up @@ -30,6 +30,8 @@ class SplitByWhitespaceImpl
static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithOptionalMaxSubstrings(func, arguments);
Expand Down