diff --git a/src/include/optimizer/count_rel_table_optimizer.h b/src/include/optimizer/count_rel_table_optimizer.h new file mode 100644 index 0000000000..2d6d70e3e5 --- /dev/null +++ b/src/include/optimizer/count_rel_table_optimizer.h @@ -0,0 +1,49 @@ +#pragma once + +#include "logical_operator_visitor.h" +#include "planner/operator/logical_plan.h" + +namespace lbug { +namespace main { +class ClientContext; +} + +namespace optimizer { + +/** + * This optimizer detects patterns where we're counting all rows from a single rel table + * without any filters, and replaces the scan + aggregate with a direct count from table metadata. + * + * Pattern detected: + * AGGREGATE (COUNT_STAR only, no keys) → + * PROJECTION (empty or pass-through) → + * EXTEND (single rel table) → + * SCAN_NODE_TABLE + * + * This pattern is replaced with: + * COUNT_REL_TABLE (new operator that directly reads the count from table metadata) + */ +class CountRelTableOptimizer : public LogicalOperatorVisitor { +public: + explicit CountRelTableOptimizer(main::ClientContext* context) : context{context} {} + + void rewrite(planner::LogicalPlan* plan); + +private: + std::shared_ptr visitOperator( + const std::shared_ptr& op); + + std::shared_ptr visitAggregateReplace( + std::shared_ptr op) override; + + // Check if the aggregate is a simple COUNT(*) with no keys + bool isSimpleCountStar(planner::LogicalOperator* op) const; + + // Check if the plan below aggregate matches the pattern for optimization + bool canOptimize(planner::LogicalOperator* aggregate) const; + + main::ClientContext* context; +}; + +} // namespace optimizer +} // namespace lbug diff --git a/src/include/optimizer/logical_operator_visitor.h b/src/include/optimizer/logical_operator_visitor.h index 355f147c6b..2e4c91a8b5 100644 --- a/src/include/optimizer/logical_operator_visitor.h +++ b/src/include/optimizer/logical_operator_visitor.h @@ -39,6 +39,12 @@ class LogicalOperatorVisitor { return op; } + virtual void visitCountRelTable(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitCountRelTableReplace( + std::shared_ptr op) { + return op; + } + virtual void visitDelete(planner::LogicalOperator* /*op*/) {} virtual std::shared_ptr visitDeleteReplace( std::shared_ptr op) { diff --git a/src/include/planner/operator/logical_operator.h b/src/include/planner/operator/logical_operator.h index 3be8831209..53d217fac0 100644 --- a/src/include/planner/operator/logical_operator.h +++ b/src/include/planner/operator/logical_operator.h @@ -17,6 +17,7 @@ enum class LogicalOperatorType : uint8_t { ATTACH_DATABASE, COPY_FROM, COPY_TO, + COUNT_REL_TABLE, CREATE_MACRO, CREATE_SEQUENCE, CREATE_TABLE, diff --git a/src/include/planner/operator/scan/logical_count_rel_table.h b/src/include/planner/operator/scan/logical_count_rel_table.h new file mode 100644 index 0000000000..19c23450a6 --- /dev/null +++ b/src/include/planner/operator/scan/logical_count_rel_table.h @@ -0,0 +1,84 @@ +#pragma once + +#include "binder/expression/expression.h" +#include "binder/expression/node_expression.h" +#include "catalog/catalog_entry/rel_group_catalog_entry.h" +#include "common/enums/extend_direction.h" +#include "planner/operator/logical_operator.h" + +namespace lbug { +namespace planner { + +struct LogicalCountRelTablePrintInfo final : OPPrintInfo { + std::string relTableName; + std::shared_ptr countExpr; + + LogicalCountRelTablePrintInfo(std::string relTableName, + std::shared_ptr countExpr) + : relTableName{std::move(relTableName)}, countExpr{std::move(countExpr)} {} + + std::string toString() const override { + return "Table: " + relTableName + ", Count: " + countExpr->toString(); + } + + std::unique_ptr copy() const override { + return std::make_unique(relTableName, countExpr); + } +}; + +/** + * LogicalCountRelTable is an optimized operator that counts the number of rows + * in a rel table by scanning through bound nodes and counting edges. + * + * This operator is created by CountRelTableOptimizer when it detects: + * COUNT(*) over a single rel table with no filters + */ +class LogicalCountRelTable final : public LogicalOperator { + static constexpr LogicalOperatorType type_ = LogicalOperatorType::COUNT_REL_TABLE; + +public: + LogicalCountRelTable(catalog::RelGroupCatalogEntry* relGroupEntry, + std::vector relTableIDs, + std::vector boundNodeTableIDs, + std::shared_ptr boundNode, common::ExtendDirection direction, + std::shared_ptr countExpr) + : LogicalOperator{type_}, relGroupEntry{relGroupEntry}, relTableIDs{std::move(relTableIDs)}, + boundNodeTableIDs{std::move(boundNodeTableIDs)}, boundNode{std::move(boundNode)}, + direction{direction}, countExpr{std::move(countExpr)} { + cardinality = 1; // Always returns exactly one row + } + + void computeFactorizedSchema() override; + void computeFlatSchema() override; + + std::string getExpressionsForPrinting() const override { return countExpr->toString(); } + + catalog::RelGroupCatalogEntry* getRelGroupEntry() const { return relGroupEntry; } + const std::vector& getRelTableIDs() const { return relTableIDs; } + const std::vector& getBoundNodeTableIDs() const { + return boundNodeTableIDs; + } + std::shared_ptr getBoundNode() const { return boundNode; } + common::ExtendDirection getDirection() const { return direction; } + std::shared_ptr getCountExpr() const { return countExpr; } + + std::unique_ptr getPrintInfo() const override { + return std::make_unique(relGroupEntry->getName(), countExpr); + } + + std::unique_ptr copy() override { + return std::make_unique(relGroupEntry, relTableIDs, boundNodeTableIDs, + boundNode, direction, countExpr); + } + +private: + catalog::RelGroupCatalogEntry* relGroupEntry; + std::vector relTableIDs; + std::vector boundNodeTableIDs; + std::shared_ptr boundNode; + common::ExtendDirection direction; + std::shared_ptr countExpr; +}; + +} // namespace planner +} // namespace lbug diff --git a/src/include/processor/operator/physical_operator.h b/src/include/processor/operator/physical_operator.h index 795693987b..38cf1fe79f 100644 --- a/src/include/processor/operator/physical_operator.h +++ b/src/include/processor/operator/physical_operator.h @@ -22,6 +22,7 @@ enum class PhysicalOperatorType : uint8_t { ATTACH_DATABASE, BATCH_INSERT, COPY_TO, + COUNT_REL_TABLE, CREATE_MACRO, CREATE_SEQUENCE, CREATE_TABLE, diff --git a/src/include/processor/operator/scan/count_rel_table.h b/src/include/processor/operator/scan/count_rel_table.h new file mode 100644 index 0000000000..9b937548cf --- /dev/null +++ b/src/include/processor/operator/scan/count_rel_table.h @@ -0,0 +1,62 @@ +#pragma once + +#include "common/enums/rel_direction.h" +#include "processor/operator/physical_operator.h" +#include "storage/table/node_table.h" +#include "storage/table/rel_table.h" + +namespace lbug { +namespace processor { + +struct CountRelTablePrintInfo final : OPPrintInfo { + std::string relTableName; + + explicit CountRelTablePrintInfo(std::string relTableName) + : relTableName{std::move(relTableName)} {} + + std::string toString() const override { return "Table: " + relTableName; } + + std::unique_ptr copy() const override { + return std::make_unique(relTableName); + } +}; + +/** + * CountRelTable is a source operator that counts edges in a rel table + * by scanning through all bound nodes and counting their edges. + * It creates its own internal vectors for node scanning (not exposed in ResultSet). + */ +class CountRelTable final : public PhysicalOperator { + static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::COUNT_REL_TABLE; + +public: + CountRelTable(std::vector nodeTables, + std::vector relTables, common::RelDataDirection direction, + DataPos countOutputPos, physical_op_id id, std::unique_ptr printInfo) + : PhysicalOperator{type_, id, std::move(printInfo)}, nodeTables{std::move(nodeTables)}, + relTables{std::move(relTables)}, direction{direction}, countOutputPos{countOutputPos} {} + + bool isSource() const override { return true; } + bool isParallel() const override { return false; } + + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; + + bool getNextTuplesInternal(ExecutionContext* context) override; + + std::unique_ptr copy() override { + return std::make_unique(nodeTables, relTables, direction, countOutputPos, id, + printInfo->copy()); + } + +private: + std::vector nodeTables; + std::vector relTables; + common::RelDataDirection direction; + DataPos countOutputPos; + common::ValueVector* countVector; + bool hasExecuted; + common::row_idx_t totalCount; +}; + +} // namespace processor +} // namespace lbug diff --git a/src/include/processor/plan_mapper.h b/src/include/processor/plan_mapper.h index 3f907facf7..1f838c0bd2 100644 --- a/src/include/processor/plan_mapper.h +++ b/src/include/processor/plan_mapper.h @@ -90,6 +90,8 @@ class PlanMapper { std::unique_ptr mapCopyRelFrom( const planner::LogicalOperator* logicalOperator); std::unique_ptr mapCopyTo(const planner::LogicalOperator* logicalOperator); + std::unique_ptr mapCountRelTable( + const planner::LogicalOperator* logicalOperator); std::unique_ptr mapCreateMacro( const planner::LogicalOperator* logicalOperator); std::unique_ptr mapCreateSequence( diff --git a/src/optimizer/CMakeLists.txt b/src/optimizer/CMakeLists.txt index 7ba4d4ed2b..1ced97ac06 100644 --- a/src/optimizer/CMakeLists.txt +++ b/src/optimizer/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(lbug_optimizer agg_key_dependency_optimizer.cpp cardinality_updater.cpp correlated_subquery_unnest_solver.cpp + count_rel_table_optimizer.cpp factorization_rewriter.cpp filter_push_down_optimizer.cpp logical_operator_collector.cpp diff --git a/src/optimizer/count_rel_table_optimizer.cpp b/src/optimizer/count_rel_table_optimizer.cpp new file mode 100644 index 0000000000..a80b13ee5e --- /dev/null +++ b/src/optimizer/count_rel_table_optimizer.cpp @@ -0,0 +1,217 @@ +#include "optimizer/count_rel_table_optimizer.h" + +#include "binder/expression/aggregate_function_expression.h" +#include "binder/expression/node_expression.h" +#include "catalog/catalog_entry/node_table_id_pair.h" +#include "function/aggregate/count_star.h" +#include "main/client_context.h" +#include "planner/operator/extend/logical_extend.h" +#include "planner/operator/logical_aggregate.h" +#include "planner/operator/logical_projection.h" +#include "planner/operator/scan/logical_count_rel_table.h" +#include "planner/operator/scan/logical_scan_node_table.h" + +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::binder; +using namespace lbug::catalog; + +namespace lbug { +namespace optimizer { + +void CountRelTableOptimizer::rewrite(LogicalPlan* plan) { + visitOperator(plan->getLastOperator()); +} + +std::shared_ptr CountRelTableOptimizer::visitOperator( + const std::shared_ptr& op) { + // bottom-up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + op->setChild(i, visitOperator(op->getChild(i))); + } + auto result = visitOperatorReplaceSwitch(op); + result->computeFlatSchema(); + return result; +} + +bool CountRelTableOptimizer::isSimpleCountStar(LogicalOperator* op) const { + if (op->getOperatorType() != LogicalOperatorType::AGGREGATE) { + return false; + } + auto& aggregate = op->constCast(); + + // Must have no keys (i.e., a simple aggregate without GROUP BY) + if (aggregate.hasKeys()) { + return false; + } + + // Must have exactly one aggregate expression + auto aggregates = aggregate.getAggregates(); + if (aggregates.size() != 1) { + return false; + } + + // Must be COUNT_STAR + auto& aggExpr = aggregates[0]; + if (aggExpr->expressionType != ExpressionType::AGGREGATE_FUNCTION) { + return false; + } + auto& aggFuncExpr = aggExpr->constCast(); + if (aggFuncExpr.getFunction().name != function::CountStarFunction::name) { + return false; + } + + // COUNT_STAR should not be DISTINCT (conceptually it doesn't make sense) + if (aggFuncExpr.isDistinct()) { + return false; + } + + return true; +} + +bool CountRelTableOptimizer::canOptimize(LogicalOperator* aggregate) const { + // Pattern we're looking for: + // AGGREGATE (COUNT_STAR, no keys) + // -> PROJECTION (empty expressions or pass-through) + // -> EXTEND (single rel table, no properties scanned) + // -> SCAN_NODE_TABLE (no properties scanned) + // + // Note: The projection between aggregate and extend might be empty or + // just projecting the count expression. + + auto* current = aggregate->getChild(0).get(); + + // Skip any projections (they should be empty or just for count) + while (current->getOperatorType() == LogicalOperatorType::PROJECTION) { + auto& proj = current->constCast(); + // Empty projection is okay, it's just a passthrough + if (!proj.getExpressionsToProject().empty()) { + // If projection has expressions, they should all be aggregate expressions + // (which means they're just passing through the count) + for (auto& expr : proj.getExpressionsToProject()) { + if (expr->expressionType != ExpressionType::AGGREGATE_FUNCTION) { + return false; + } + } + } + current = current->getChild(0).get(); + } + + // Now we should have EXTEND + if (current->getOperatorType() != LogicalOperatorType::EXTEND) { + return false; + } + auto& extend = current->constCast(); + + // Don't optimize for undirected edges (BOTH direction) - the query pattern + // (a)-[e]-(b) generates a plan that scans both directions, and optimizing + // this would require special handling to avoid double counting. + if (extend.getDirection() == ExtendDirection::BOTH) { + return false; + } + + // The rel should be a single table (not multi-labeled) + auto rel = extend.getRel(); + if (rel->isMultiLabeled()) { + return false; + } + + // Check if we're scanning any properties (we can only optimize when no properties needed) + if (!extend.getProperties().empty()) { + return false; + } + + // The child of extend should be SCAN_NODE_TABLE + auto* extendChild = current->getChild(0).get(); + if (extendChild->getOperatorType() != LogicalOperatorType::SCAN_NODE_TABLE) { + return false; + } + auto& scanNode = extendChild->constCast(); + + // Check if node scan has any properties (we can only optimize when no properties needed) + if (!scanNode.getProperties().empty()) { + return false; + } + + return true; +} + +std::shared_ptr CountRelTableOptimizer::visitAggregateReplace( + std::shared_ptr op) { + if (!isSimpleCountStar(op.get())) { + return op; + } + + if (!canOptimize(op.get())) { + return op; + } + + // Find the EXTEND operator + auto* current = op->getChild(0).get(); + while (current->getOperatorType() == LogicalOperatorType::PROJECTION) { + current = current->getChild(0).get(); + } + + KU_ASSERT(current->getOperatorType() == LogicalOperatorType::EXTEND); + auto& extend = current->constCast(); + auto rel = extend.getRel(); + auto boundNode = extend.getBoundNode(); + auto nbrNode = extend.getNbrNode(); + + // Get the rel group entry + KU_ASSERT(rel->getNumEntries() == 1); + auto* relGroupEntry = rel->getEntry(0)->ptrCast(); + + // Determine the source and destination node table IDs based on extend direction. + // If extendFromSource is true, then boundNode is the source and nbrNode is the destination. + // If extendFromSource is false, then boundNode is the destination and nbrNode is the source. + auto boundNodeTableIDs = boundNode->getTableIDsSet(); + auto nbrNodeTableIDs = nbrNode->getTableIDsSet(); + + // Get only the rel table IDs that match the specific node table ID pairs in the query. + // A rel table connects a specific (srcTableID, dstTableID) pair. + std::vector relTableIDs; + for (auto& info : relGroupEntry->getRelEntryInfos()) { + table_id_t srcTableID = info.nodePair.srcTableID; + table_id_t dstTableID = info.nodePair.dstTableID; + + bool matches = false; + if (extend.extendFromSourceNode()) { + // boundNode is src, nbrNode is dst + matches = + boundNodeTableIDs.contains(srcTableID) && nbrNodeTableIDs.contains(dstTableID); + } else { + // boundNode is dst, nbrNode is src + matches = + boundNodeTableIDs.contains(dstTableID) && nbrNodeTableIDs.contains(srcTableID); + } + + if (matches) { + relTableIDs.push_back(info.oid); + } + } + + // If no matching rel tables, don't optimize (shouldn't happen for valid queries) + if (relTableIDs.empty()) { + return op; + } + + // Get the count expression from the original aggregate + auto& aggregate = op->constCast(); + auto countExpr = aggregate.getAggregates()[0]; + + // Get the bound node table IDs as a vector + std::vector boundNodeTableIDsVec(boundNodeTableIDs.begin(), + boundNodeTableIDs.end()); + + // Create the new COUNT_REL_TABLE operator with all necessary information for scanning + auto countRelTable = + std::make_shared(relGroupEntry, std::move(relTableIDs), + std::move(boundNodeTableIDsVec), boundNode, extend.getDirection(), countExpr); + countRelTable->computeFlatSchema(); + + return countRelTable; +} + +} // namespace optimizer +} // namespace lbug diff --git a/src/optimizer/logical_operator_visitor.cpp b/src/optimizer/logical_operator_visitor.cpp index 89d454ce5f..71c093ad37 100644 --- a/src/optimizer/logical_operator_visitor.cpp +++ b/src/optimizer/logical_operator_visitor.cpp @@ -19,6 +19,9 @@ void LogicalOperatorVisitor::visitOperatorSwitch(LogicalOperator* op) { case LogicalOperatorType::COPY_TO: { visitCopyTo(op); } break; + case LogicalOperatorType::COUNT_REL_TABLE: { + visitCountRelTable(op); + } break; case LogicalOperatorType::DELETE: { visitDelete(op); } break; @@ -108,6 +111,9 @@ std::shared_ptr LogicalOperatorVisitor::visitOperatorReplaceSwi case LogicalOperatorType::COPY_TO: { return visitCopyToReplace(op); } + case LogicalOperatorType::COUNT_REL_TABLE: { + return visitCountRelTableReplace(op); + } case LogicalOperatorType::DELETE: { return visitDeleteReplace(op); } diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index e7f6f04283..02ca4e491f 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -5,6 +5,7 @@ #include "optimizer/agg_key_dependency_optimizer.h" #include "optimizer/cardinality_updater.h" #include "optimizer/correlated_subquery_unnest_solver.h" +#include "optimizer/count_rel_table_optimizer.h" #include "optimizer/factorization_rewriter.h" #include "optimizer/filter_push_down_optimizer.h" #include "optimizer/limit_push_down_optimizer.h" @@ -32,6 +33,11 @@ void Optimizer::optimize(planner::LogicalPlan* plan, main::ClientContext* contex auto removeUnnecessaryJoinOptimizer = RemoveUnnecessaryJoinOptimizer(); removeUnnecessaryJoinOptimizer.rewrite(plan); + // CountRelTableOptimizer should be applied early before other optimizations + // that might change the plan structure. + auto countRelTableOptimizer = CountRelTableOptimizer(context); + countRelTableOptimizer.rewrite(plan); + auto filterPushDownOptimizer = FilterPushDownOptimizer(context); filterPushDownOptimizer.rewrite(plan); diff --git a/src/planner/operator/logical_operator.cpp b/src/planner/operator/logical_operator.cpp index 9f80089c91..70cd455d19 100644 --- a/src/planner/operator/logical_operator.cpp +++ b/src/planner/operator/logical_operator.cpp @@ -22,6 +22,8 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp return "COPY_FROM"; case LogicalOperatorType::COPY_TO: return "COPY_TO"; + case LogicalOperatorType::COUNT_REL_TABLE: + return "COUNT_REL_TABLE"; case LogicalOperatorType::CREATE_MACRO: return "CREATE_MACRO"; case LogicalOperatorType::CREATE_SEQUENCE: diff --git a/src/planner/operator/scan/CMakeLists.txt b/src/planner/operator/scan/CMakeLists.txt index f2b03bca0f..7de6cd2b73 100644 --- a/src/planner/operator/scan/CMakeLists.txt +++ b/src/planner/operator/scan/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(lbug_planner_scan OBJECT + logical_count_rel_table.cpp logical_expressions_scan.cpp logical_index_look_up.cpp logical_scan_node_table.cpp) diff --git a/src/planner/operator/scan/logical_count_rel_table.cpp b/src/planner/operator/scan/logical_count_rel_table.cpp new file mode 100644 index 0000000000..dcf152c0c2 --- /dev/null +++ b/src/planner/operator/scan/logical_count_rel_table.cpp @@ -0,0 +1,24 @@ +#include "planner/operator/scan/logical_count_rel_table.h" + +namespace lbug { +namespace planner { + +void LogicalCountRelTable::computeFactorizedSchema() { + createEmptySchema(); + // Only output the count expression in a single-state group. + // This operator is a source - it has no child in the logical plan. + // The bound node is used internally for scanning but not exposed. + auto groupPos = schema->createGroup(); + schema->insertToGroupAndScope(countExpr, groupPos); + schema->setGroupAsSingleState(groupPos); +} + +void LogicalCountRelTable::computeFlatSchema() { + createEmptySchema(); + // For flat schema, create a single group with the count expression. + auto groupPos = schema->createGroup(); + schema->insertToGroupAndScope(countExpr, groupPos); +} + +} // namespace planner +} // namespace lbug diff --git a/src/processor/map/CMakeLists.txt b/src/processor/map/CMakeLists.txt index 3bd69011e4..3cf2c942fe 100644 --- a/src/processor/map/CMakeLists.txt +++ b/src/processor/map/CMakeLists.txt @@ -7,6 +7,7 @@ add_library(lbug_processor_mapper map_acc_hash_join.cpp map_accumulate.cpp map_aggregate.cpp + map_count_rel_table.cpp map_standalone_call.cpp map_table_function_call.cpp map_copy_to.cpp diff --git a/src/processor/map/map_count_rel_table.cpp b/src/processor/map/map_count_rel_table.cpp new file mode 100644 index 0000000000..f51c5ed290 --- /dev/null +++ b/src/processor/map/map_count_rel_table.cpp @@ -0,0 +1,55 @@ +#include "planner/operator/scan/logical_count_rel_table.h" +#include "processor/operator/scan/count_rel_table.h" +#include "processor/plan_mapper.h" +#include "storage/storage_manager.h" + +using namespace lbug::common; +using namespace lbug::planner; +using namespace lbug::storage; + +namespace lbug { +namespace processor { + +std::unique_ptr PlanMapper::mapCountRelTable( + const LogicalOperator* logicalOperator) { + auto& logicalCountRelTable = logicalOperator->constCast(); + auto outSchema = logicalCountRelTable.getSchema(); + + auto storageManager = StorageManager::Get(*clientContext); + + // Get the node tables for scanning bound nodes + std::vector nodeTables; + for (auto tableID : logicalCountRelTable.getBoundNodeTableIDs()) { + nodeTables.push_back(storageManager->getTable(tableID)->ptrCast()); + } + + // Get the rel tables + std::vector relTables; + for (auto tableID : logicalCountRelTable.getRelTableIDs()) { + relTables.push_back(storageManager->getTable(tableID)->ptrCast()); + } + + // Determine rel data direction from extend direction + auto extendDirection = logicalCountRelTable.getDirection(); + RelDataDirection relDirection; + if (extendDirection == ExtendDirection::FWD) { + relDirection = RelDataDirection::FWD; + } else if (extendDirection == ExtendDirection::BWD) { + relDirection = RelDataDirection::BWD; + } else { + // For BOTH, we'll scan FWD (shouldn't reach here as optimizer filters BOTH) + relDirection = RelDataDirection::FWD; + } + + // Get the output position for the count expression + auto countOutputPos = getDataPos(*logicalCountRelTable.getCountExpr(), *outSchema); + + auto printInfo = std::make_unique( + logicalCountRelTable.getRelGroupEntry()->getName()); + + return std::make_unique(std::move(nodeTables), std::move(relTables), + relDirection, countOutputPos, getOperatorID(), std::move(printInfo)); +} + +} // namespace processor +} // namespace lbug diff --git a/src/processor/map/plan_mapper.cpp b/src/processor/map/plan_mapper.cpp index 3211606d76..7176211c51 100644 --- a/src/processor/map/plan_mapper.cpp +++ b/src/processor/map/plan_mapper.cpp @@ -62,6 +62,9 @@ std::unique_ptr PlanMapper::mapOperator(const LogicalOperator* case LogicalOperatorType::COPY_TO: { physicalOperator = mapCopyTo(logicalOperator); } break; + case LogicalOperatorType::COUNT_REL_TABLE: { + physicalOperator = mapCountRelTable(logicalOperator); + } break; case LogicalOperatorType::CREATE_MACRO: { physicalOperator = mapCreateMacro(logicalOperator); } break; diff --git a/src/processor/operator/physical_operator.cpp b/src/processor/operator/physical_operator.cpp index b4c165ff79..b6c168ccbb 100644 --- a/src/processor/operator/physical_operator.cpp +++ b/src/processor/operator/physical_operator.cpp @@ -27,6 +27,8 @@ std::string PhysicalOperatorUtils::operatorTypeToString(PhysicalOperatorType ope return "BATCH_INSERT"; case PhysicalOperatorType::COPY_TO: return "COPY_TO"; + case PhysicalOperatorType::COUNT_REL_TABLE: + return "COUNT_REL_TABLE"; case PhysicalOperatorType::CREATE_MACRO: return "CREATE_MACRO"; case PhysicalOperatorType::CREATE_SEQUENCE: diff --git a/src/processor/operator/scan/CMakeLists.txt b/src/processor/operator/scan/CMakeLists.txt index 0de0db6334..5b50b49835 100644 --- a/src/processor/operator/scan/CMakeLists.txt +++ b/src/processor/operator/scan/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(lbug_processor_operator_scan OBJECT + count_rel_table.cpp primary_key_scan_node_table.cpp scan_multi_rel_tables.cpp scan_node_table.cpp diff --git a/src/processor/operator/scan/count_rel_table.cpp b/src/processor/operator/scan/count_rel_table.cpp new file mode 100644 index 0000000000..9f14692f3e --- /dev/null +++ b/src/processor/operator/scan/count_rel_table.cpp @@ -0,0 +1,137 @@ +#include "processor/operator/scan/count_rel_table.h" + +#include "common/system_config.h" +#include "main/client_context.h" +#include "main/database.h" +#include "processor/execution_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/local_storage/local_rel_table.h" +#include "storage/local_storage/local_storage.h" +#include "storage/table/column.h" +#include "storage/table/column_chunk_data.h" +#include "storage/table/csr_chunked_node_group.h" +#include "storage/table/csr_node_group.h" +#include "storage/table/rel_table_data.h" +#include "transaction/transaction.h" + +using namespace lbug::common; +using namespace lbug::storage; +using namespace lbug::transaction; + +namespace lbug { +namespace processor { + +void CountRelTable::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*context*/) { + countVector = resultSet->getValueVector(countOutputPos).get(); + hasExecuted = false; + totalCount = 0; +} + +// Count rels by using CSR metadata, accounting for deletions and uncommitted data. +// This is more efficient than scanning through all edges. +bool CountRelTable::getNextTuplesInternal(ExecutionContext* context) { + if (hasExecuted) { + return false; + } + + auto transaction = Transaction::Get(*context->clientContext); + auto* memoryManager = context->clientContext->getDatabase()->getMemoryManager(); + + for (auto* relTable : relTables) { + // Get the RelTableData for the specified direction + auto* relTableData = relTable->getDirectedTableData(direction); + auto numNodeGroups = relTableData->getNumNodeGroups(); + auto* csrLengthColumn = relTableData->getCSRLengthColumn(); + + // For each node group in the rel table + for (node_group_idx_t nodeGroupIdx = 0; nodeGroupIdx < numNodeGroups; nodeGroupIdx++) { + auto* nodeGroup = relTableData->getNodeGroup(nodeGroupIdx); + if (!nodeGroup) { + continue; + } + + auto& csrNodeGroup = nodeGroup->cast(); + + // Count from persistent (checkpointed) data + if (auto* persistentGroup = csrNodeGroup.getPersistentChunkedGroup()) { + // Sum the actual relationship lengths from the CSR header instead of using + // getNumRows() which includes dummy rows added for CSR offset array gaps + auto& csrPersistentGroup = persistentGroup->cast(); + auto& csrHeader = csrPersistentGroup.getCSRHeader(); + + // Get the number of nodes in this CSR header + auto numNodes = csrHeader.length->getNumValues(); + if (numNodes == 0) { + continue; + } + + // Create an in-memory chunk to scan the CSR length column into + auto lengthChunk = + ColumnChunkFactory::createColumnChunkData(*memoryManager, LogicalType::UINT64(), + false /*enableCompression*/, StorageConfig::NODE_GROUP_SIZE, + ResidencyState::IN_MEMORY, false /*initializeToZero*/); + + // Initialize scan state and scan the length column from disk + ChunkState chunkState; + csrHeader.length->initializeScanState(chunkState, csrLengthColumn); + csrLengthColumn->scan(chunkState, lengthChunk.get(), 0 /*offsetInChunk*/, numNodes); + + // Sum all the lengths + auto* lengthData = reinterpret_cast(lengthChunk->getData()); + row_idx_t groupRelCount = 0; + for (offset_t i = 0; i < numNodes; ++i) { + groupRelCount += lengthData[i]; + } + totalCount += groupRelCount; + + // Subtract deletions from persistent data + if (persistentGroup->hasVersionInfo()) { + auto numDeletions = + persistentGroup->getNumDeletions(transaction, 0, groupRelCount); + totalCount -= numDeletions; + } + } + + // Count in-memory committed data (not yet checkpointed) + // This data is stored in chunkedGroups within the NodeGroup + auto numChunkedGroups = csrNodeGroup.getNumChunkedGroups(); + for (node_group_idx_t i = 0; i < numChunkedGroups; i++) { + auto* chunkedGroup = csrNodeGroup.getChunkedNodeGroup(i); + if (chunkedGroup) { + auto numRows = chunkedGroup->getNumRows(); + totalCount += numRows; + // Subtract deletions from in-memory committed data + if (chunkedGroup->hasVersionInfo()) { + auto numDeletions = chunkedGroup->getNumDeletions(transaction, 0, numRows); + totalCount -= numDeletions; + } + } + } + } + + // Add uncommitted insertions from local storage + if (transaction->isWriteTransaction()) { + if (auto* localTable = + transaction->getLocalStorage()->getLocalTable(relTable->getTableID())) { + auto& localRelTable = localTable->cast(); + // Count entries in the CSR index for this direction. + // We can't use getNumTotalRows() because it includes deleted rows. + auto& csrIndex = localRelTable.getCSRIndex(direction); + for (const auto& [nodeOffset, rowIndices] : csrIndex) { + totalCount += rowIndices.size(); + } + } + } + } + + hasExecuted = true; + + // Write the count to the output vector (single value) + countVector->state->getSelVectorUnsafe().setToUnfiltered(1); + countVector->setValue(0, static_cast(totalCount)); + + return true; +} + +} // namespace processor +} // namespace lbug diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index d5bc2c3cb8..608c615d62 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -1,5 +1,6 @@ #include "graph_test/private_graph_test.h" #include "planner/operator/logical_plan_util.h" +#include "planner/operator/scan/logical_count_rel_table.h" #include "test_runner/test_runner.h" namespace lbug { @@ -17,6 +18,19 @@ class OptimizerTest : public DBTest { std::unique_ptr getRoot(const std::string& query) { return TestRunner::getLogicalPlan(query, *conn); } + + // Helper to check if a specific operator type exists in the plan + static bool hasOperatorType(planner::LogicalOperator* op, planner::LogicalOperatorType type) { + if (op->getOperatorType() == type) { + return true; + } + for (auto i = 0u; i < op->getNumChildren(); ++i) { + if (hasOperatorType(op->getChild(i).get(), type)) { + return true; + } + } + return false; + } }; TEST_F(OptimizerTest, JoinHint) { @@ -211,5 +225,37 @@ TEST_F(OptimizerTest, SubqueryHint) { ASSERT_STREQ(getEncodedPlan(q6).c_str(), "Filter()HJ(a._ID){S(a)}{E(a)Filter()S(b)}"); } +TEST_F(OptimizerTest, CountRelTableOptimizer) { + // Test that COUNT(*) over a single rel table is optimized to COUNT_REL_TABLE + auto q1 = "MATCH (a:person)-[e:knows]->(b:person) RETURN COUNT(*);"; + auto plan1 = getRoot(q1); + ASSERT_TRUE(hasOperatorType(plan1->getLastOperator().get(), + planner::LogicalOperatorType::COUNT_REL_TABLE)); + // Verify the query returns the correct result + auto result1 = conn->query(q1); + ASSERT_TRUE(result1->isSuccess()); + ASSERT_EQ(result1->getNumTuples(), 1); + auto tuple1 = result1->getNext(); + ASSERT_EQ(tuple1->getValue(0)->getValue(), 14); + + // Test that COUNT(*) with GROUP BY is NOT optimized (has keys) + auto q2 = "MATCH (a:person)-[e:knows]->(b:person) RETURN a.fName, COUNT(*);"; + auto plan2 = getRoot(q2); + ASSERT_FALSE(hasOperatorType(plan2->getLastOperator().get(), + planner::LogicalOperatorType::COUNT_REL_TABLE)); + + // Test that COUNT(*) with WHERE clause is NOT optimized (has filter) + auto q3 = "MATCH (a:person)-[e:knows]->(b:person) WHERE a.ID > 0 RETURN COUNT(*);"; + auto plan3 = getRoot(q3); + ASSERT_FALSE(hasOperatorType(plan3->getLastOperator().get(), + planner::LogicalOperatorType::COUNT_REL_TABLE)); + + // Test that COUNT(DISTINCT ...) is NOT optimized + auto q4 = "MATCH (a:person)-[e:knows]->(b:person) RETURN COUNT(DISTINCT a);"; + auto plan4 = getRoot(q4); + ASSERT_FALSE(hasOperatorType(plan4->getLastOperator().get(), + planner::LogicalOperatorType::COUNT_REL_TABLE)); +} + } // namespace testing } // namespace lbug diff --git a/tools/benchmark/count_rel_table.benchmark b/tools/benchmark/count_rel_table.benchmark new file mode 100644 index 0000000000..1d45389939 --- /dev/null +++ b/tools/benchmark/count_rel_table.benchmark @@ -0,0 +1,5 @@ +-NAME count_rel_table +-PRERUN CREATE NODE TABLE account(ID INT64 PRIMARY KEY); CREATE REL TABLE follows(FROM account TO account); COPY account FROM "dataset/snap/amazon0601/parquet/amazon-nodes.parquet"; COPY follows FROM "dataset/snap/amazon0601/parquet/amazon-edges.parquet"; +-QUERY MATCH ()-[r:follows]->() RETURN COUNT(*) +---- 1 +3387388