Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/include/optimizer/count_rel_table_optimizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/operator/logical_plan.h"

namespace lbug {
namespace main {
class ClientContext;
}

namespace optimizer {

/**
* This optimizer detects patterns where we're counting all rows from a single rel table
* without any filters, and replaces the scan + aggregate with a direct count from table metadata.
*
* Pattern detected:
* AGGREGATE (COUNT_STAR only, no keys) →
* PROJECTION (empty or pass-through) →
* EXTEND (single rel table) →
* SCAN_NODE_TABLE
*
* This pattern is replaced with:
* COUNT_REL_TABLE (new operator that directly reads the count from table metadata)
*/
class CountRelTableOptimizer : public LogicalOperatorVisitor {
public:
explicit CountRelTableOptimizer(main::ClientContext* context) : context{context} {}

void rewrite(planner::LogicalPlan* plan);

private:
std::shared_ptr<planner::LogicalOperator> visitOperator(
const std::shared_ptr<planner::LogicalOperator>& op);

std::shared_ptr<planner::LogicalOperator> visitAggregateReplace(
std::shared_ptr<planner::LogicalOperator> op) override;

// Check if the aggregate is a simple COUNT(*) with no keys
bool isSimpleCountStar(planner::LogicalOperator* op) const;

// Check if the plan below aggregate matches the pattern for optimization
bool canOptimize(planner::LogicalOperator* aggregate) const;

main::ClientContext* context;
};

} // namespace optimizer
} // namespace lbug
6 changes: 6 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class LogicalOperatorVisitor {
return op;
}

virtual void visitCountRelTable(planner::LogicalOperator* /*op*/) {}
virtual std::shared_ptr<planner::LogicalOperator> visitCountRelTableReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDelete(planner::LogicalOperator* /*op*/) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDeleteReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand Down
1 change: 1 addition & 0 deletions src/include/planner/operator/logical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ enum class LogicalOperatorType : uint8_t {
ATTACH_DATABASE,
COPY_FROM,
COPY_TO,
COUNT_REL_TABLE,
CREATE_MACRO,
CREATE_SEQUENCE,
CREATE_TABLE,
Expand Down
84 changes: 84 additions & 0 deletions src/include/planner/operator/scan/logical_count_rel_table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#pragma once

#include "binder/expression/expression.h"
#include "binder/expression/node_expression.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "common/enums/extend_direction.h"
#include "planner/operator/logical_operator.h"

namespace lbug {
namespace planner {

struct LogicalCountRelTablePrintInfo final : OPPrintInfo {
std::string relTableName;
std::shared_ptr<binder::Expression> countExpr;

LogicalCountRelTablePrintInfo(std::string relTableName,
std::shared_ptr<binder::Expression> countExpr)
: relTableName{std::move(relTableName)}, countExpr{std::move(countExpr)} {}

std::string toString() const override {
return "Table: " + relTableName + ", Count: " + countExpr->toString();
}

std::unique_ptr<OPPrintInfo> copy() const override {
return std::make_unique<LogicalCountRelTablePrintInfo>(relTableName, countExpr);
}
};

/**
* LogicalCountRelTable is an optimized operator that counts the number of rows
* in a rel table by scanning through bound nodes and counting edges.
*
* This operator is created by CountRelTableOptimizer when it detects:
* COUNT(*) over a single rel table with no filters
*/
class LogicalCountRelTable final : public LogicalOperator {
static constexpr LogicalOperatorType type_ = LogicalOperatorType::COUNT_REL_TABLE;

public:
LogicalCountRelTable(catalog::RelGroupCatalogEntry* relGroupEntry,
std::vector<common::table_id_t> relTableIDs,
std::vector<common::table_id_t> boundNodeTableIDs,
std::shared_ptr<binder::NodeExpression> boundNode, common::ExtendDirection direction,
std::shared_ptr<binder::Expression> countExpr)
: LogicalOperator{type_}, relGroupEntry{relGroupEntry}, relTableIDs{std::move(relTableIDs)},
boundNodeTableIDs{std::move(boundNodeTableIDs)}, boundNode{std::move(boundNode)},
direction{direction}, countExpr{std::move(countExpr)} {
cardinality = 1; // Always returns exactly one row
}

void computeFactorizedSchema() override;
void computeFlatSchema() override;

std::string getExpressionsForPrinting() const override { return countExpr->toString(); }

catalog::RelGroupCatalogEntry* getRelGroupEntry() const { return relGroupEntry; }
const std::vector<common::table_id_t>& getRelTableIDs() const { return relTableIDs; }
const std::vector<common::table_id_t>& getBoundNodeTableIDs() const {
return boundNodeTableIDs;
}
std::shared_ptr<binder::NodeExpression> getBoundNode() const { return boundNode; }
common::ExtendDirection getDirection() const { return direction; }
std::shared_ptr<binder::Expression> getCountExpr() const { return countExpr; }

std::unique_ptr<OPPrintInfo> getPrintInfo() const override {
return std::make_unique<LogicalCountRelTablePrintInfo>(relGroupEntry->getName(), countExpr);
}

std::unique_ptr<LogicalOperator> copy() override {
return std::make_unique<LogicalCountRelTable>(relGroupEntry, relTableIDs, boundNodeTableIDs,
boundNode, direction, countExpr);
}

private:
catalog::RelGroupCatalogEntry* relGroupEntry;
std::vector<common::table_id_t> relTableIDs;
std::vector<common::table_id_t> boundNodeTableIDs;
std::shared_ptr<binder::NodeExpression> boundNode;
common::ExtendDirection direction;
std::shared_ptr<binder::Expression> countExpr;
};

} // namespace planner
} // namespace lbug
1 change: 1 addition & 0 deletions src/include/processor/operator/physical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ enum class PhysicalOperatorType : uint8_t {
ATTACH_DATABASE,
BATCH_INSERT,
COPY_TO,
COUNT_REL_TABLE,
CREATE_MACRO,
CREATE_SEQUENCE,
CREATE_TABLE,
Expand Down
62 changes: 62 additions & 0 deletions src/include/processor/operator/scan/count_rel_table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once

#include "common/enums/rel_direction.h"
#include "processor/operator/physical_operator.h"
#include "storage/table/node_table.h"
#include "storage/table/rel_table.h"

namespace lbug {
namespace processor {

struct CountRelTablePrintInfo final : OPPrintInfo {
std::string relTableName;

explicit CountRelTablePrintInfo(std::string relTableName)
: relTableName{std::move(relTableName)} {}

std::string toString() const override { return "Table: " + relTableName; }

std::unique_ptr<OPPrintInfo> copy() const override {
return std::make_unique<CountRelTablePrintInfo>(relTableName);
}
};

/**
* CountRelTable is a source operator that counts edges in a rel table
* by scanning through all bound nodes and counting their edges.
* It creates its own internal vectors for node scanning (not exposed in ResultSet).
*/
class CountRelTable final : public PhysicalOperator {
static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::COUNT_REL_TABLE;

public:
CountRelTable(std::vector<storage::NodeTable*> nodeTables,
std::vector<storage::RelTable*> relTables, common::RelDataDirection direction,
DataPos countOutputPos, physical_op_id id, std::unique_ptr<OPPrintInfo> printInfo)
: PhysicalOperator{type_, id, std::move(printInfo)}, nodeTables{std::move(nodeTables)},
relTables{std::move(relTables)}, direction{direction}, countOutputPos{countOutputPos} {}

bool isSource() const override { return true; }
bool isParallel() const override { return false; }

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal(ExecutionContext* context) override;

std::unique_ptr<PhysicalOperator> copy() override {
return std::make_unique<CountRelTable>(nodeTables, relTables, direction, countOutputPos, id,
printInfo->copy());
}

private:
std::vector<storage::NodeTable*> nodeTables;
std::vector<storage::RelTable*> relTables;
common::RelDataDirection direction;
DataPos countOutputPos;
common::ValueVector* countVector;
bool hasExecuted;
common::row_idx_t totalCount;
};

} // namespace processor
} // namespace lbug
2 changes: 2 additions & 0 deletions src/include/processor/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class PlanMapper {
std::unique_ptr<PhysicalOperator> mapCopyRelFrom(
const planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapCopyTo(const planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapCountRelTable(
const planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapCreateMacro(
const planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapCreateSequence(
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_library(lbug_optimizer
agg_key_dependency_optimizer.cpp
cardinality_updater.cpp
correlated_subquery_unnest_solver.cpp
count_rel_table_optimizer.cpp
factorization_rewriter.cpp
filter_push_down_optimizer.cpp
logical_operator_collector.cpp
Expand Down
Loading