diff --git a/.clang-tidy b/.clang-tidy index 5dd99dd83..d5702893d 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -47,10 +47,10 @@ Checks: > # Warnings from headers outside the regex (PyTorch, pybind11, etc.) are suppressed # entirely and never reach WarningsAsErrors — so the large warning counts printed # by clang-tidy ("N warnings generated") are third-party noise that is silently -# dropped. Only diagnostics in our own headers (.*/gigl/csrc/.*) are reported, +# dropped. Only diagnostics in our own headers (.*/gigl-core/core/.*) are reported, # and those are treated as hard errors. WarningsAsErrors: '*' -HeaderFilterRegex: '.*/gigl-core/csrc/.*' +HeaderFilterRegex: '.*/gigl-core/core/.*' FormatStyle: none # CheckOptions: per-check tuning parameters. Each entry configures a specific # option for an individual check, using the form: diff --git a/.github/cloud_builder/run_command_on_active_checkout.yaml b/.github/cloud_builder/run_command_on_active_checkout.yaml index d99c024a3..5c11c5207 100644 --- a/.github/cloud_builder/run_command_on_active_checkout.yaml +++ b/.github/cloud_builder/run_command_on_active_checkout.yaml @@ -3,7 +3,7 @@ substitutions: options: logging: CLOUD_LOGGING_ONLY steps: - - name: us-central1-docker.pkg.dev/external-snap-ci-github-gigl/gigl-base-images/gigl-builder:b598f3d72eee47f5513dcb39460944459a0a012f.108.1 + - name: us-central1-docker.pkg.dev/external-snap-ci-github-gigl/gigl-base-images/gigl-builder:7d3182eeb6446ce3e35910babba990c8e003879d.109.1 entrypoint: /bin/bash args: - -c diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3325d1c98..0855887be 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -59,7 +59,7 @@ jobs: run: | # Remove stale cmake cache from previous runs on self-hosted runners. rm -rf gigl-core/.cache/cmake_build - uv build --wheel gigl-core/ --locked + uv build --wheel gigl-core/ - name: Publish gigl-core wheel working-directory: gigl-core @@ -73,5 +73,5 @@ jobs: # when gigl becomes available. - name: Build and publish gigl wheel run: | - uv build --wheel --locked + uv build --wheel uv publish --publish-url ${{ matrix.publish-url }} --username oauth2accesstoken --keyring-provider subprocess diff --git a/Makefile b/Makefile index 9db02bb80..b86b1242f 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG?=${DOCKER_IMAGE_MAIN_CPU_NAME}:${DATE} DOCKER_IMAGE_DEV_WORKBENCH_NAME_WITH_TAG?=${DOCKER_IMAGE_DEV_WORKBENCH_NAME}:${DATE} PYTHON_DIRS:=.github/scripts examples gigl tests snapchat scripts -CPP_SOURCES:=$(shell find gigl-core/csrc \( -name "*.cpp" -o -name "*.cu" \) 2>/dev/null) +CPP_SOURCES:=$(shell find gigl-core/core \( -name "*.cpp" -o -name "*.cu" \) 2>/dev/null) # clang-tidy 15 does not fully support CUDA syntax (e.g. <<<...>>>, __global__). # Exclude .cu files from tidy targets; clang-format and clangd handle them fine. CPP_SOURCES_NO_CUDA:=$(filter-out %.cu,$(CPP_SOURCES)) @@ -35,8 +35,8 @@ GIGL_E2E_TEST_COMPILED_PIPELINE_PATH:=/tmp/gigl/pipeline_${DATE}_${GIT_HASH}.yam GIT_BRANCH:=$(shell git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "") -# Find all markdown files in the repo except for those in .venv or tools directories. -MD_FILES := $(shell find . -type f -name "*.md" ! -path "*/.venv/*" ! -path "*/tools/*") +# Find all markdown files in the repo except for those in .venv, tools, or cmake cache directories. +MD_FILES := $(shell find . -type f -name "*.md" ! -path "*/.venv/*" ! -path "*/tools/*" ! 
-path "*/.cache/*") GIGL_ALERT_EMAILS?="" get_ver_hash: @@ -165,20 +165,11 @@ type_check: build_cpp_extensions: $(MAKE) -C gigl-core build_cpp_extensions -check_lint_cpp: build_cpp_extensions - $(if $(CPP_SOURCES_NO_CUDA),uv run python -m scripts.run_cpp_lint $(CPP_SOURCES_NO_CUDA)) - -# Not part of `make format`: clang-tidy --fix rewrites logic (renames identifiers, -# changes expressions, adds/removes keywords), not just style. Run manually and -# review the diff before committing. Note: --fix cannot auto-repair every check; -# some violations require manual edits. -# --extra-arg=-Wno-ignored-optimization-argument suppresses GCC-specific LTO flags -# (-fno-fat-lto-objects, -flto=auto) that cmake writes into compile_commands.json. -# clang-tidy forwards compiler warnings via clang-diagnostic-*, and .clang-tidy sets -# WarningsAsErrors: '*', so the warning must be silenced at the compiler level before -# clang-tidy ever sees it. -fix_lint_cpp: build_cpp_extensions - $(if $(CPP_SOURCES_NO_CUDA),clang-tidy-15 --fix --extra-arg=-Wno-ignored-optimization-argument -p gigl-core/.cache/cmake_build/compile_commands.json $(CPP_SOURCES_NO_CUDA)) +check_lint_cpp: + $(MAKE) -C gigl-core check_lint_cpp + +fix_lint_cpp: + $(MAKE) -C gigl-core fix_lint_cpp lint_test: check_format assert_yaml_configs_parse check_lint_cpp @echo "Lint checks pass!" diff --git a/docs/cpp_style_guide.md b/docs/cpp_style_guide.md index 0a9e684ab..cb38358d0 100644 --- a/docs/cpp_style_guide.md +++ b/docs/cpp_style_guide.md @@ -163,7 +163,7 @@ Enforced via `readability-identifier-naming`: | Option | Value | Effect | | ---------------------------------------------------------- | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | `WarningsAsErrors` | `*` | Every check failure is a hard error in CI | -| `HeaderFilterRegex` | `.*/gigl-core/csrc/.*` | Scopes checks to our own headers. Using `.*` causes clang-tidy to report warnings from every PyTorch/pybind11 header it parses, flooding output with thousands of third-party issues. | +| `HeaderFilterRegex` | `.*/gigl-core/core/.*` | Scopes checks to our own headers. Using `.*` causes clang-tidy to report warnings from every PyTorch/pybind11 header it parses, flooding output with thousands of third-party issues. | | `FormatStyle` | `none` | clang-tidy does not auto-reformat; use clang-format separately | | `bugprone-string-constructor.LargeLengthThreshold` | `8388608` (8 MB) | Strings larger than 8 MB from a length argument are flagged | | `modernize-loop-convert.NamingStyle` | `camelBack` | Auto-generated loop variable names use camelBack, matching `readability-identifier-naming.VariableCase` | @@ -174,7 +174,7 @@ ______________________________________________________________________ ## pybind11 Extension Modules -Extension modules live under `gigl-core/csrc/`. +Extension modules live under `gigl-core/core/`. ### Naming convention @@ -184,11 +184,11 @@ Extension modules live under `gigl-core/csrc/`. 
 | `.cpp` / `.cu` | Implementation — function and class definitions |
 | `.h`           | Declarations (function signatures, class definitions, constants) |
 
-Example: to add a `my_op` extension under `gigl-core/csrc/sampling/`:
+Example: to add a `my_op` extension under `gigl-core/core/sampling/`:
 
 ```
-gigl-core/csrc/sampling/python_my_op.cpp  ← pybind11 bindings
-gigl-core/csrc/sampling/my_op.cpp         ← implementation
+gigl-core/core/sampling/python_my_op.cpp  ← pybind11 bindings
+gigl-core/core/sampling/my_op.cpp         ← implementation
 ```
 
 The compiled `.so` is installed into the `gigl_core` package and importable as `gigl_core.<module_name>`.
diff --git a/gigl-core/.clangd b/gigl-core/.clangd
new file mode 100644
index 000000000..8a771fd30
--- /dev/null
+++ b/gigl-core/.clangd
@@ -0,0 +1,5 @@
+# Point clangd at the test compilation database rather than the default cmake_build one.
+# The test database includes both the extension modules and the test binaries, so clangd
+# can resolve gtest headers for test files alongside production headers.
+CompileFlags:
+  CompilationDatabase: .cache/cpp_tests
diff --git a/gigl-core/CMakeLists.txt b/gigl-core/CMakeLists.txt
index 5b0e6d9f2..3e39eb669 100644
--- a/gigl-core/CMakeLists.txt
+++ b/gigl-core/CMakeLists.txt
@@ -26,19 +26,19 @@ endif()
 # ---------------------------------------------------------------------------
 # Extension modules — auto-discovered.
-# Files named python_*.cpp under csrc/ are compiled as pybind11 extension
+# Files named python_*.cpp under core/ are compiled as pybind11 extension
 # modules. The companion .cpp (without the "python_" prefix) is included
 # automatically when present. Add a new extension by dropping source files
 # here; no changes to this CMakeLists.txt are needed.
 # ---------------------------------------------------------------------------
 if(CMAKE_CUDA_COMPILER)
   file(GLOB_RECURSE _PYTHON_SRCS CONFIGURE_DEPENDS
-    "${CMAKE_CURRENT_SOURCE_DIR}/csrc/python_*.cpp"
-    "${CMAKE_CURRENT_SOURCE_DIR}/csrc/python_*.cu"
+    "${CMAKE_CURRENT_SOURCE_DIR}/core/python_*.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/core/python_*.cu"
   )
 else()
   file(GLOB_RECURSE _PYTHON_SRCS CONFIGURE_DEPENDS
-    "${CMAKE_CURRENT_SOURCE_DIR}/csrc/python_*.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/core/python_*.cpp"
   )
 endif()
diff --git a/gigl-core/Makefile b/gigl-core/Makefile
index f3f63c7b8..2cbc48bd6 100644
--- a/gigl-core/Makefile
+++ b/gigl-core/Makefile
@@ -2,7 +2,7 @@
 # Invoked from the GiGL repo root via: $(MAKE) -C gigl-core
 # All paths below are relative to gigl-core/.
 
-CPP_SOURCES := $(shell find csrc \( -name "*.cpp" -o -name "*.cu" \) 2>/dev/null)
+CPP_SOURCES := $(shell find core \( -name "*.cpp" -o -name "*.cu" \) 2>/dev/null)
 # clang-tidy 15 does not fully support CUDA syntax (e.g. <<<...>>>, __global__).
 # Exclude .cu files from tidy targets; clang-format and clangd handle them fine.
 CPP_SOURCES_NO_CUDA := $(filter-out %.cu,$(CPP_SOURCES))
@@ -12,10 +12,12 @@ CPP_SOURCES_NO_CUDA := $(filter-out %.cu,$(CPP_SOURCES))
 # pyproject.toml, make skips the reinstall unless something actually changed.
 # We cd to the repo root so that no-build-isolation-package in the root pyproject.toml
 # is respected by uv pip install.
-.cache/cmake_build/CMakeInit.txt: $(shell find csrc \( -name '*.cpp' -o -name '*.cu' -o -name '*.h' -o -name '*.cuh' \) 2>/dev/null) CMakeLists.txt pyproject.toml
+.cache/cmake_build/CMakeInit.txt: $(shell find core \( -name '*.cpp' -o -name '*.cu' -o -name '*.h' -o -name '*.cuh' \) 2>/dev/null) CMakeLists.txt pyproject.toml
 	cd $(abspath $(CURDIR)/..) && uv pip install -e gigl-core/
 
-build_cpp_extensions: .cache/cmake_build/CMakeInit.txt
+# Also depend on the test cmake configure so that .cache/cpp_tests/compile_commands.json
+# is generated automatically, giving clangd visibility into test files and gtest headers.
+build_cpp_extensions: .cache/cmake_build/CMakeInit.txt .cache/cpp_tests/.configured
 
 .cache/cpp_tests/.configured: CMakeLists.txt tests/CMakeLists.txt .cache/cmake_build/CMakeInit.txt
 	cmake -C .cache/cmake_build/CMakeInit.txt -S . -B .cache/cpp_tests -DGIGL_CORE_BUILD_TESTS=ON
@@ -25,13 +27,26 @@ unit_test_cpp: .cache/cpp_tests/.configured
 	cmake --build .cache/cpp_tests --parallel
 	ctest --test-dir .cache/cpp_tests --output-on-failure
 
-# TODO: Remove the $(if ...) guards once C++ source files are permanently present in the
-# repo. The guards exist to silently no-op on branches that have no python_*.cpp files yet.
 check_format_cpp:
-	$(if $(CPP_SOURCES),clang-format-15 --dry-run --Werror --style=file $(CPP_SOURCES))
+	clang-format-15 --dry-run --Werror --style=file $(CPP_SOURCES)
 
 format_cpp:
-	$(if $(CPP_SOURCES),clang-format-15 -i --style=file $(CPP_SOURCES))
+	clang-format-15 -i --style=file $(CPP_SOURCES)
+
+# Not part of `make format`: clang-tidy --fix rewrites logic (renames identifiers,
+# changes expressions, adds/removes keywords), not just style. Run manually and
+# review the diff before committing. Note: --fix cannot auto-repair every check;
+# some violations require manual edits.
+# --extra-arg=-Wno-ignored-optimization-argument suppresses GCC-specific LTO flags
+# (-fno-fat-lto-objects, -flto=auto) that cmake writes into compile_commands.json.
+# clang-tidy forwards compiler warnings via clang-diagnostic-*, and .clang-tidy sets
+# WarningsAsErrors: '*', so the warning must be silenced at the compiler level before
+# clang-tidy ever sees it.
+check_lint_cpp: build_cpp_extensions
+	cd $(abspath $(CURDIR)/..) && uv run python gigl-core/scripts/run_cpp_lint.py $(addprefix gigl-core/,$(CPP_SOURCES_NO_CUDA))
+
+fix_lint_cpp: build_cpp_extensions
+	clang-tidy-15 --fix --extra-arg=-Wno-ignored-optimization-argument -p .cache/cmake_build/compile_commands.json $(CPP_SOURCES_NO_CUDA)
 
 # Wipe cmake build caches. Use this if cmake's cached state becomes inconsistent
 # after switching between branches with substantially different CMakeLists.txt structure.
@@ -43,4 +58,4 @@ clean_build_files_cpp:
 
 # Declare targets as phony so make always runs their recipes, even if a file or
 # directory with the same name happens to exist on disk.
-.PHONY: build_cpp_extensions unit_test_cpp check_format_cpp format_cpp clean_cpp clean_build_files_cpp
+.PHONY: build_cpp_extensions unit_test_cpp check_format_cpp format_cpp check_lint_cpp fix_lint_cpp clean_cpp clean_build_files_cpp
diff --git a/gigl-core/core/sampling/ppr_forward_push.cpp b/gigl-core/core/sampling/ppr_forward_push.cpp
new file mode 100644
index 000000000..9a2a17f03
--- /dev/null
+++ b/gigl-core/core/sampling/ppr_forward_push.cpp
@@ -0,0 +1,360 @@
+#include "ppr_forward_push.h"
+
+#include <torch/torch.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <optional>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+namespace gigl {
+
+// Pack (node_id, etype_id) into a single uint64 for use as a hash key.
+// Inputs are cast through uint32_t to avoid sign-extension of negative int32 values.
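+// For instance (hypothetical IDs): packKey(3, 2) == 0x0000000300000002 — the
+// node ID occupies the upper 32 bits and the edge type ID the lower 32.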
+static uint64_t packKey(int32_t nodeId, int32_t edgeTypeId) {
+  return (static_cast<uint64_t>(static_cast<uint32_t>(nodeId)) << 32) | static_cast<uint32_t>(edgeTypeId);
+}
+
+PPRForwardPush::PPRForwardPush(const torch::Tensor& seedNodes,
+                               int32_t seedNodeTypeId,
+                               double alpha,
+                               double requeueThresholdFactor,
+                               std::vector<std::vector<int32_t>> nodeTypeToEdgeTypeIds,
+                               std::vector<int32_t> edgeTypeToDstNtypeId,
+                               std::vector<torch::Tensor> degreeTensors)
+    : _alpha(alpha),
+      _requeueThresholdFactor(requeueThresholdFactor),
+      // std::move transfers ownership of each vector into the member variable
+      // without copying its contents — equivalent to Python's list hand-off
+      // when you no longer need the original.
+      _nodeTypeToEdgeTypeIds(std::move(nodeTypeToEdgeTypeIds)),
+      _edgeTypeToDstNtypeId(std::move(edgeTypeToDstNtypeId)),
+      _degreeTensors(std::move(degreeTensors)) {
+  TORCH_CHECK(seedNodes.dim() == 1, "seedNodes must be 1D");
+  // int32_t is sufficient: batch sizes approaching 2B seeds are not a realistic concern.
+  _batchSize = static_cast<int32_t>(seedNodes.size(0));
+  _numNodeTypes = static_cast<int32_t>(_nodeTypeToEdgeTypeIds.size());
+
+  TORCH_CHECK(seedNodeTypeId >= 0, "seedNodeTypeId ", seedNodeTypeId, " is negative.");
+  TORCH_CHECK(
+      seedNodeTypeId < _numNodeTypes, "seedNodeTypeId ", seedNodeTypeId, " out of range [0, ", _numNodeTypes, ").");
+  auto numEdgeTypes = static_cast<int32_t>(_edgeTypeToDstNtypeId.size());
+  for (int32_t edgeTypeId = 0; edgeTypeId < numEdgeTypes; ++edgeTypeId) {
+    int32_t dstNodeTypeId = _edgeTypeToDstNtypeId[edgeTypeId];
+    TORCH_CHECK(dstNodeTypeId >= 0, "edgeTypeToDstNtypeId[", edgeTypeId, "] = ", dstNodeTypeId, " is negative.");
+    TORCH_CHECK(dstNodeTypeId < _numNodeTypes,
+                "edgeTypeToDstNtypeId[",
+                edgeTypeId,
+                "] = ",
+                dstNodeTypeId,
+                " out of range [0, ",
+                _numNodeTypes,
+                ").");
+  }
+  for (int32_t nodeTypeId = 0; nodeTypeId < _numNodeTypes; ++nodeTypeId) {
+    for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) {
+      TORCH_CHECK(edgeTypeId >= 0,
+                  "nodeTypeToEdgeTypeIds[",
+                  nodeTypeId,
+                  "] contains negative edge type id ",
+                  edgeTypeId,
+                  ".");
+      TORCH_CHECK(edgeTypeId < numEdgeTypes,
+                  "nodeTypeToEdgeTypeIds[",
+                  nodeTypeId,
+                  "] contains edge type id ",
+                  edgeTypeId,
+                  " out of range [0, ",
+                  numEdgeTypes,
+                  ").");
+    }
+  }
+
+  // Allocate per-seed, per-node-type state.
+  // .assign(n, val) fills a vector with n independent copies of val — like [val for _ in range(n)] in Python.
+  _state.assign(_batchSize, std::vector<SeedNodeTypeState>(_numNodeTypes));
+
+  // accessor<int64_t, 1>() returns a typed view into the tensor's data that
+  // supports [i] indexing with bounds checking in debug builds.
+  auto seedNodeAcc = seedNodes.accessor<int64_t, 1>();
+  _numNodesInQueue = _batchSize;
+  for (int32_t seedIdx = 0; seedIdx < _batchSize; ++seedIdx) {
+    auto seedNodeId = static_cast<int32_t>(seedNodeAcc[seedIdx]);
+    // PPR initialisation: each seed starts with residual = alpha (the
+    // restart probability). The first push will move alpha into ppr_score
+    // and distribute (1-alpha)*alpha to the seed's neighbors.
+    _state[seedIdx][seedNodeTypeId].residuals[seedNodeId] = _alpha;
+    _state[seedIdx][seedNodeTypeId].queue.insert(seedNodeId);
+  }
+}
+
+std::optional<std::unordered_map<int32_t, torch::Tensor>> PPRForwardPush::drainQueue() {
+  if (_numNodesInQueue == 0) {
+    return std::nullopt;
+  }
+
+  // Reset the snapshot from the previous iteration.
+  // TODO: if this loop becomes a bottleneck, consider parallelising with
+  // std::for_each(std::execution::par_unseq, ...) or adding vectorisation hints.
+  for (auto& perSeedState : _state) {
+    for (auto& nodeTypeState : perSeedState) {
+      nodeTypeState.queuedNodes.clear();
+    }
+  }
+
+  // nodesToLookup[edgeTypeId] = set of node IDs that need a neighbor fetch for
+  // edge type edgeTypeId this round. Using a set deduplicates nodes that appear
+  // in multiple seeds' queues: we only fetch each (node, etype) pair once.
+  std::unordered_map<int32_t, std::unordered_set<int32_t>> nodesToLookup;
+
+  // TODO: For homogeneous graphs _numNodeTypes == 1, so the inner loop always
+  // executes exactly once (nodeTypeId=0). std::vector indexing is cheap, but a
+  // dedicated homogeneous code path could eliminate the loop entirely. Profile
+  // before splitting.
+  for (int32_t seedIdx = 0; seedIdx < _batchSize; ++seedIdx) {
+    for (int32_t nodeTypeId = 0; nodeTypeId < _numNodeTypes; ++nodeTypeId) {
+      auto& seedNodeTypeState = _state[seedIdx][nodeTypeId];
+      if (seedNodeTypeState.queue.empty()) {
+        continue;
+      }
+
+      // Move the live queue into the snapshot in O(1) — avoids copying all node IDs.
+      // The explicit clear() after move is defensive: the standard only guarantees
+      // a moved-from container is "valid but unspecified", not necessarily empty.
+      seedNodeTypeState.queuedNodes = std::move(seedNodeTypeState.queue);
+      seedNodeTypeState.queue.clear();
+      _numNodesInQueue -= static_cast<int32_t>(seedNodeTypeState.queuedNodes.size());
+
+      for (int32_t nodeId : seedNodeTypeState.queuedNodes) {
+        for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) {
+          if (_neighborCache.find(packKey(nodeId, edgeTypeId)) == _neighborCache.end()) {
+            nodesToLookup[edgeTypeId].insert(nodeId);
+          }
+        }
+      }
+    }
+  }
+
+  std::unordered_map<int32_t, torch::Tensor> result;
+  for (const auto& [edgeTypeId, nodeSet] : nodesToLookup) {
+    std::vector<int64_t> nodeIdsToLookup(nodeSet.begin(), nodeSet.end());
+    result[edgeTypeId] = torch::tensor(nodeIdsToLookup, torch::kLong);
+  }
+  return result;
+}
+
+void PPRForwardPush::pushResiduals(
+    const std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>& fetchedByEtypeId) {
+  // Step 1: Unpack the input map into a C++ map keyed by packKey(nodeId, edgeTypeId)
+  // for fast lookup during the residual-push loop below.
+  std::unordered_map<uint64_t, std::vector<int32_t>> fetched;
+  for (const auto& [edgeTypeId, neighborTensors] : fetchedByEtypeId) {
+    const auto& nodeIdsTensor = std::get<0>(neighborTensors);
+    const auto& flatNeighborIdsTensor = std::get<1>(neighborTensors);
+    const auto& countsTensor = std::get<2>(neighborTensors);
+
+    // accessor<int64_t, 1>() gives a bounds-checked, typed 1-D view into
+    // each tensor's data — equivalent to iterating over a NumPy array.
+    auto nodeIdsAccessor = nodeIdsTensor.accessor<int64_t, 1>();
+    auto flatNeighborIdsAccessor = flatNeighborIdsTensor.accessor<int64_t, 1>();
+    auto countsAccessor = countsTensor.accessor<int64_t, 1>();
+
+    // Walk the flat neighbor list, slicing out each node's neighbors using
+    // the running offset into the concatenated flat buffer.
+    int64_t offset = 0;
+    for (int64_t nodeIdx = 0; nodeIdx < nodeIdsTensor.size(0); ++nodeIdx) {
+      auto nodeId = static_cast<int32_t>(nodeIdsAccessor[nodeIdx]);
+      int64_t count = countsAccessor[nodeIdx];
+      std::vector<int32_t> neighborIds(count);
+      for (int64_t neighborIdx = 0; neighborIdx < count; ++neighborIdx) {
+        neighborIds[neighborIdx] = static_cast<int32_t>(flatNeighborIdsAccessor[offset + neighborIdx]);
+      }
+      fetched[packKey(nodeId, edgeTypeId)] = std::move(neighborIds);
+      offset += count;
+    }
+  }
+
+  // Step 2: For every node that was in the queue (captured in queuedNodes
+  // by drainQueue()), apply one PPR push step:
+  //   a. Absorb residual into the PPR score.
+  //   b. Distribute (1-alpha) * residual equally to each neighbor.
+  //   c. Enqueue any neighbor whose residual now exceeds the requeue threshold.
+  for (int32_t seedIdx = 0; seedIdx < _batchSize; ++seedIdx) {
+    for (int32_t nodeTypeId = 0; nodeTypeId < _numNodeTypes; ++nodeTypeId) {
+      auto& srcNodeTypeState = _state[seedIdx][nodeTypeId];
+      if (srcNodeTypeState.queuedNodes.empty()) {
+        continue;
+      }
+
+      for (int32_t sourceNodeId : srcNodeTypeState.queuedNodes) {
+        auto residualIter = srcNodeTypeState.residuals.find(sourceNodeId);
+        double sourceResidual = (residualIter != srcNodeTypeState.residuals.end()) ? residualIter->second : 0.0;
+
+        // a. Absorb: move residual into the PPR score.
+        srcNodeTypeState.pprScores[sourceNodeId] += sourceResidual;
+        srcNodeTypeState.residuals[sourceNodeId] = 0.0;
+
+        // b. Count total fetched/cached neighbors across all edge types for
+        // this source node. We normalise by the number of neighbors we
+        // actually retrieved, not the true degree, so residual is fully
+        // distributed among known neighbors rather than leaking to unfetched
+        // ones (which matters when num_neighbors_per_hop < true_degree).
+        int32_t totalFetched = 0;
+        for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) {
+          auto fetchedEntry = fetched.find(packKey(sourceNodeId, edgeTypeId));
+          if (fetchedEntry != fetched.end()) {
+            totalFetched += static_cast<int32_t>(fetchedEntry->second.size());
+          } else {
+            auto cachedEntry = _neighborCache.find(packKey(sourceNodeId, edgeTypeId));
+            if (cachedEntry != _neighborCache.end()) {
+              totalFetched += static_cast<int32_t>(cachedEntry->second.size());
+            }
+          }
+        }
+        // Two cases reach here:
+        //   1. True sink node (no outgoing edges): absorbing the full residual is correct.
+        //   2. Budget exhausted, no cache entry: the (1-α)·r that should flow to
+        //      neighbors has nowhere to go, so it gets absorbed into src's score instead.
+        //      This overstates src and understates its neighbors. This is expected
+        //      behavior when max_fetch_iterations is set, which intentionally trades
+        //      theoretical PPR correctness for better throughput.
+        if (totalFetched == 0) {
+          continue;
+        }
+
+        double residualPerNeighbor = (1.0 - _alpha) * sourceResidual / static_cast<double>(totalFetched);
+
+        for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) {
+          // Invariant: fetched and _neighborCache are mutually exclusive for
+          // any given (node, etype) key within one iteration. drainQueue()
+          // only requests a fetch for nodes absent from _neighborCache, so a
+          // key is in at most one of the two.
+          //
+          // Neighbor list for this (src, edgeTypeId) pair, borrowed from whichever
+          // map holds it. reference_wrapper is used because std::optional cannot
+          // hold a reference directly, and we want to avoid copying the vector —
+          // the data already exists in fetched or _neighborCache and both outlive
+          // this loop body. Access via neighborList->get().
+          std::optional<std::reference_wrapper<const std::vector<int32_t>>> neighborList;
+          auto fetchedEntry = fetched.find(packKey(sourceNodeId, edgeTypeId));
+          if (fetchedEntry != fetched.end()) {
+            neighborList = std::cref(fetchedEntry->second);
+          } else {
+            auto cachedEntry = _neighborCache.find(packKey(sourceNodeId, edgeTypeId));
+            if (cachedEntry != _neighborCache.end()) {
+              neighborList = std::cref(cachedEntry->second);
+            }
+          }
+          if (!neighborList || neighborList->get().empty()) {
+            continue;
+          }
+
+          int32_t dstNodeTypeId = _edgeTypeToDstNtypeId[edgeTypeId];
+
+          // c. Accumulate residual for each neighbor and re-enqueue if threshold
+          // exceeded.
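+          // The requeue test below is the standard forward-push condition
+          // r(v) >= eps * d(v): _requeueThresholdFactor is alpha * eps (set by
+          // the Python sampler), so threshold = alpha * eps * degree(neighbor).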
+          auto& dstNodeTypeState = _state[seedIdx][dstNodeTypeId];
+          for (int32_t neighborNodeId : neighborList->get()) {
+            dstNodeTypeState.residuals[neighborNodeId] += residualPerNeighbor;
+
+            double threshold = _requeueThresholdFactor *
+                               static_cast<double>(getTotalDegree(neighborNodeId, dstNodeTypeId));
+
+            if (dstNodeTypeState.queue.find(neighborNodeId) == dstNodeTypeState.queue.end() &&
+                dstNodeTypeState.residuals[neighborNodeId] >= threshold) {
+              dstNodeTypeState.queue.insert(neighborNodeId);
+              ++_numNodesInQueue;
+
+              // Promote neighbor lists to the persistent cache: this node will
+              // be processed next iteration, so caching avoids a re-fetch.
+              for (int32_t neighborEdgeTypeId : _nodeTypeToEdgeTypeIds[dstNodeTypeId]) {
+                uint64_t packedKey = packKey(neighborNodeId, neighborEdgeTypeId);
+                if (_neighborCache.find(packedKey) == _neighborCache.end()) {
+                  auto fetchedNeighborEntry = fetched.find(packedKey);
+                  if (fetchedNeighborEntry != fetched.end()) {
+                    _neighborCache[packedKey] = fetchedNeighborEntry->second;
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> PPRForwardPush::extractTopK(
+    int32_t maxPprNodes) {
+  std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> result;
+  // Emit an entry for every node type, even if unreachable in this batch (empty tensors,
+  // all-zero valid_counts). This keeps the output shape consistent across batches so
+  // downstream model architectures see a fixed set of PPR edge types every iteration.
+  for (int32_t nodeTypeId = 0; nodeTypeId < _numNodeTypes; ++nodeTypeId) {
+    std::vector<int64_t> flatIds;
+    std::vector<double> flatWeights;
+    std::vector<int64_t> validCounts;
+
+    for (int32_t seedIdx = 0; seedIdx < _batchSize; ++seedIdx) {
+      const auto& scores = _state[seedIdx][nodeTypeId].pprScores;
+      int32_t topK = std::min(maxPprNodes, static_cast<int32_t>(scores.size()));
+      if (topK > 0) {
+        std::vector<std::pair<int32_t, double>> scorePairs(scores.begin(), scores.end());
+        std::partial_sort(scorePairs.begin(),
+                          scorePairs.begin() + topK,
+                          scorePairs.end(),
+                          [](const auto& a, const auto& b) { return a.second > b.second; });
+
+        for (int32_t rankIdx = 0; rankIdx < topK; ++rankIdx) {
+          flatIds.push_back(static_cast<int64_t>(scorePairs[rankIdx].first));
+          flatWeights.push_back(scorePairs[rankIdx].second);
+        }
+      }
+      validCounts.push_back(static_cast<int64_t>(topK));
+    }
+
+    result[nodeTypeId] = {torch::tensor(flatIds, torch::kLong),
+                          torch::tensor(flatWeights, torch::kDouble),
+                          torch::tensor(validCounts, torch::kLong)};
+  }
+  return result;
+}
+
+int32_t PPRForwardPush::getTotalDegree(int32_t nodeId, int32_t nodeTypeId) const {
+  TORCH_CHECK(nodeTypeId >= 0, "nodeTypeId ", nodeTypeId, " is negative, which indicates a sampler bug.");
+  TORCH_CHECK(nodeTypeId < static_cast<int32_t>(_degreeTensors.size()),
+              "nodeTypeId ",
+              nodeTypeId,
+              " out of range [0, ",
+              _degreeTensors.size(),
+              "). This indicates a construction bug in the sampler.");
+  const auto& degreeTensor = _degreeTensors[nodeTypeId];
+  if (degreeTensor.numel() == 0) {
+    return 0;
+  }
+  TORCH_CHECK(nodeId >= 0, "Node ID ", nodeId, " is negative, which indicates a sampler bug.");
+  TORCH_CHECK(nodeId < static_cast<int32_t>(degreeTensor.size(0)),
+              "Node ID ",
+              nodeId,
+              " out of range for degree tensor of ntype_id ",
+              nodeTypeId,
+              " (size=",
+              degreeTensor.size(0),
+              "). This indicates corrupted graph data or a sampler bug.");
+  if (degreeTensor.scalar_type() == torch::kInt) {
+    return degreeTensor.data_ptr<int32_t>()[nodeId];
+  }
+  if (degreeTensor.scalar_type() == torch::kLong) {
+    return static_cast<int32_t>(std::min<int64_t>(degreeTensor.data_ptr<int64_t>()[nodeId], INT32_MAX));
+  }
+  TORCH_CHECK(false,
+              "Unsupported degree tensor dtype: ",
+              degreeTensor.scalar_type(),
+              ". Expected torch.int32 or torch.int64.");
+  return 0;  // unreachable; suppresses compiler warning
+}
+
+}  // namespace gigl
diff --git a/gigl-core/core/sampling/ppr_forward_push.h b/gigl-core/core/sampling/ppr_forward_push.h
new file mode 100644
index 000000000..1c1eef670
--- /dev/null
+++ b/gigl-core/core/sampling/ppr_forward_push.h
@@ -0,0 +1,108 @@
+#pragma once
+
+#include <torch/torch.h>
+
+#include <cstdint>
+#include <optional>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace gigl {
+
+// Per-seed, per-node-type PPR algorithm state.
+// Grouping all four tables into one struct is a logical convenience: a single
+// _state[seedIdx][nodeTypeId] access reaches all four tables for a given (seed, ntype)
+// pair, rather than indexing four separate 2D arrays. Note that unordered_map and
+// unordered_set heap-allocate their bucket storage, so the actual key-value data is
+// not co-located in memory — only the control-plane metadata (size, bucket pointer)
+// lives inside the struct.
+struct SeedNodeTypeState {
+  std::unordered_map<int32_t, double> pprScores;  // absorbed PPR mass
+  std::unordered_map<int32_t, double> residuals;  // unabsorbed mass waiting to push
+  std::unordered_set<int32_t> queue;              // nodes queued for the next drain
+  std::unordered_set<int32_t> queuedNodes;        // snapshot captured by drainQueue()
+};
+
+// C++ kernel for PPR Forward Push (Andersen et al., 2006).
+// Hot-loop state lives here; distributed neighbor fetches are driven from Python.
+//
+// Call sequence per batch:
+//   1. PPRForwardPush(seedNodes, ...)
+//   while True:
+//     2. drainQueue() → nodes needing neighbor lookup
+//     3. <fetch neighbors in Python>
+//     4. pushResiduals(fetchedByEtypeId)
+//   5. extractTopK(maxPprNodes)
class PPRForwardPush {
 public:
+  PPRForwardPush(const torch::Tensor& seedNodes,
+                 int32_t seedNodeTypeId,
+                 double alpha,
+                 double requeueThresholdFactor,
+                 std::vector<std::vector<int32_t>> nodeTypeToEdgeTypeIds,
+                 std::vector<int32_t> edgeTypeToDstNtypeId,
+                 std::vector<torch::Tensor> degreeTensors);
+
+  // Drain queued nodes and return {etype_id: int64 node tensor} for neighbor lookup.
+  // Returns nullopt when the queue is empty (convergence). Empty map means all nodes
+  // were cache-hits; call pushResiduals({}) to continue.
+  std::optional<std::unordered_map<int32_t, torch::Tensor>> drainQueue();
+
+  // Push residuals given fetched neighbor data.
+  // fetchedByEtypeId: {etype_id: (node_ids[N], flat_nbrs[sum(counts)], counts[N])}
+  void pushResiduals(const std::unordered_map<
+                     int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>&
+                         fetchedByEtypeId);
+
+  // Return top-k PPR nodes per seed per node type.
+  // Result: {ntype_id: (flat_ids, flat_weights, valid_counts)} — one entry per node type,
+  // including types unreachable in this batch (empty tensors, all-zero valid_counts).
+  std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
+  extractTopK(int32_t maxPprNodes);
+
+ private:
+  // Total out-degree of a node across all edge types. Returns 0 for sink nodes.
+  [[nodiscard]] int32_t getTotalDegree(int32_t nodeId, int32_t nodeTypeId) const;
+
+  double _alpha;
+  double _requeueThresholdFactor;  // alpha * eps; per-node requeue threshold = factor * degree
+
+  // NOTE: int32_t is used for batch size, node IDs, and type IDs throughout this class.
+  // All of this code will break silently (overflow) if batch size or node IDs exceed ~2B
+  // (INT32_MAX = 2,147,483,647). This is not a realistic concern today, but if graph
+  // scale ever approaches that threshold, these should be widened to int64_t.
+  int32_t _batchSize;           // number of seed nodes in the current batch
+  int32_t _numNodeTypes;        // total distinct node types (1 for homogeneous graphs)
+  int32_t _numNodesInQueue{0};  // running count of queued nodes across all seeds and types
+
+  // Graph structure — set at construction, read-only during the algorithm.
+  // _nodeTypeToEdgeTypeIds[ntype_id] → list of edge type IDs that originate from that node type.
+  // _edgeTypeToDstNtypeId[etype_id] → destination node type ID for that edge type.
+  // _degreeTensors[ntype_id] → int32 tensor of total out-degrees, indexed by node ID.
+  std::vector<std::vector<int32_t>> _nodeTypeToEdgeTypeIds;
+  std::vector<int32_t> _edgeTypeToDstNtypeId;
+  std::vector<torch::Tensor> _degreeTensors;
+
+  // Per-seed, per-node-type PPR state. Indexed as _state[seedIdx][nodeTypeId].
+  // 2D vector: both dimensions are dense sequential integers bounded at construction,
+  // so array indexing is O(1) with no hashing (contrast with _neighborCache below).
+  //
+  // int32_t is used for node and type IDs throughout to match PyG/GLT's signed-integer
+  // convention (torch.int32 / torch.int64). Signed types also make nodeId >= 0 checks
+  // meaningful — an unsigned type would make that guard tautological.
+  //
+  // Sized [_batchSize][_numNodeTypes] at construction and never resized,
+  // so [seedIdx][nodeTypeId] indexing is always safe within the loop bounds.
+  std::vector<std::vector<SeedNodeTypeState>> _state;
+
+  // Neighbor lists keyed by packKey(nodeId, edgeTypeId).
+  // Hash map: nodeId is a sparse graph ID from a large graph, so a dense array is
+  // impractical (contrast with _state above). Populated incrementally; avoids re-fetching.
+  std::unordered_map<uint64_t, std::vector<int32_t>> _neighborCache;
+
+};
+
+}  // namespace gigl
diff --git a/gigl-core/core/sampling/python_ppr_forward_push.cpp b/gigl-core/core/sampling/python_ppr_forward_push.cpp
new file mode 100644
index 000000000..22981a48a
--- /dev/null
+++ b/gigl-core/core/sampling/python_ppr_forward_push.cpp
@@ -0,0 +1,61 @@
+// Python bindings for PPRForwardPush.
+//
+// Pure C++ algorithm lives in ppr_forward_push.{h,cpp}; this file only handles
+// type conversion between Python (pybind11) and C++ types, then delegates to
+// the C++ implementation.
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <torch/extension.h>
+#include <tuple>
+#include <unordered_map>
+
+#include "ppr_forward_push.h"
+
+namespace py = pybind11;
+
+namespace gigl {
+
+// pushResiduals: a wrapper is needed solely to release the GIL during the C++ push.
+// pybind11/stl.h handles all type conversions automatically; the other methods use
+// direct member function pointers for the same reason.
+static void pushResidualsWrapper(PPRForwardPush& state, const py::dict& fetchedByEtypeId) {
+  std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> neighborTensorsByEtypeId;
+  // Dict iteration touches Python objects — GIL must be held here.
+  for (auto item : fetchedByEtypeId) {
+    auto edgeTypeId = item.first.cast<int32_t>();
+    auto neighborTensors = item.second.cast<py::tuple>();
+    neighborTensorsByEtypeId[edgeTypeId] = {neighborTensors[0].cast<torch::Tensor>(),
+                                            neighborTensors[1].cast<torch::Tensor>(),
+                                            neighborTensors[2].cast<torch::Tensor>()};
+  }
+  // C++ push only uses tensor accessor/data_ptr APIs — GIL-safe to release.
+  // Releasing here lets the asyncio event loop process RPC completion callbacks
+  // from other concurrent PPR coroutines while this push runs.
+  // REQUIREMENT: no other thread may read or modify neighborTensorsByEtypeId or
+  // the underlying tensor data while the GIL is released. The caller (Python)
+  // must not alias or mutate fetchedByEtypeId until push_residuals returns.
+  {
+    py::gil_scoped_release release;
+    state.pushResiduals(neighborTensorsByEtypeId);
+  }
+}
+
+}  // namespace gigl
+
+// TORCH_EXTENSION_NAME is set by PyTorch's build system to match the Python
+// module name derived from this file's path (e.g. "ppr_forward_push").
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  py::class_<gigl::PPRForwardPush>(m, "PPRForwardPush")
+      .def(py::init<const torch::Tensor&,
+                    int32_t,
+                    double,
+                    double,
+                    std::vector<std::vector<int32_t>>,
+                    std::vector<int32_t>,
+                    std::vector<torch::Tensor>>())
+      .def("drain_queue", &gigl::PPRForwardPush::drainQueue)
+      .def("push_residuals", gigl::pushResidualsWrapper)
+      .def("extract_top_k", &gigl::PPRForwardPush::extractTopK);
+}
diff --git a/gigl-core/pyproject.toml b/gigl-core/pyproject.toml
index 8fc595572..f7b159954 100644
--- a/gigl-core/pyproject.toml
+++ b/gigl-core/pyproject.toml
@@ -29,8 +29,8 @@ editable.rebuild = false
 cache-keys = [
   { file = "pyproject.toml" },
   { file = "CMakeLists.txt" },
-  { file = "csrc/**/*.h" },
-  { file = "csrc/**/*.cpp" },
-  { file = "csrc/**/*.cu" },
-  { file = "csrc/**/*.cuh" },
+  { file = "core/**/*.h" },
+  { file = "core/**/*.cpp" },
+  { file = "core/**/*.cu" },
+  { file = "core/**/*.cuh" },
 ]
diff --git a/scripts/run_cpp_lint.py b/gigl-core/scripts/run_cpp_lint.py
similarity index 92%
rename from scripts/run_cpp_lint.py
rename to gigl-core/scripts/run_cpp_lint.py
index 7e2db01c9..32cebd9e2 100644
--- a/scripts/run_cpp_lint.py
+++ b/gigl-core/scripts/run_cpp_lint.py
@@ -2,7 +2,7 @@
 Runs clangd --check on each file in parallel and prints a clean summary.
 
 Expects compile_commands.json to already exist at
-gigl-core/.cache/cmake_build/compile_commands.json; call
+.cache/cmake_build/compile_commands.json; call
 ``make build_cpp_extensions`` first if it is absent or stale
 (``make check_lint_cpp`` does this automatically via a Makefile
 prerequisite).
@@ -17,10 +17,8 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 
-_REPO_ROOT = Path(__file__).resolve().parent.parent
-COMPILE_COMMANDS = (
-    _REPO_ROOT / "gigl-core" / ".cache" / "cmake_build" / "compile_commands.json"
-)
+_GIGL_CORE_ROOT = Path(__file__).resolve().parent.parent
+COMPILE_COMMANDS = _GIGL_CORE_ROOT / ".cache" / "cmake_build" / "compile_commands.json"
 
 # Matches real clang-tidy diagnostics emitted by clangd:
 #   E[HH:MM:SS.mmm] [check-name] Line N: message
diff --git a/gigl-core/src/gigl_core/__init__.py b/gigl-core/src/gigl_core/__init__.py
index e69de29bb..524135619 100644
--- a/gigl-core/src/gigl_core/__init__.py
+++ b/gigl-core/src/gigl_core/__init__.py
@@ -0,0 +1,3 @@
+from gigl_core.ppr_forward_push import PPRForwardPush
+
+__all__ = ["PPRForwardPush"]
diff --git a/gigl-core/src/gigl_core/ppr_forward_push.pyi b/gigl-core/src/gigl_core/ppr_forward_push.pyi
new file mode 100644
index 000000000..0c1ea79af
--- /dev/null
+++ b/gigl-core/src/gigl_core/ppr_forward_push.pyi
@@ -0,0 +1,21 @@
+import torch
+
+class PPRForwardPush:
+    def __init__(
+        self,
+        seed_nodes: torch.Tensor,
+        seed_node_type_id: int,
+        alpha: float,
+        requeue_threshold_factor: float,
+        node_type_to_edge_type_ids: list[list[int]],
+        edge_type_to_dst_ntype_id: list[int],
+        degree_tensors: list[torch.Tensor],
+    ) -> None: ...
+    def drain_queue(self) -> dict[int, torch.Tensor] | None: ...
+    def push_residuals(
+        self,
+        fetched_by_etype_id: dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
+    ) -> None: ...
+    def extract_top_k(
+        self, max_ppr_nodes: int
+    ) -> dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: ...
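Taken together with the call sequence documented in ppr_forward_push.h, the stub above implies a drive loop like the following sketch (single node/edge type; the `fetch_neighbors` helper and its two-node toy graph are hypothetical stand-ins for the sampler's batched RPC fetch in `_batch_fetch_neighbors`):

```python
import torch
from gigl_core import PPRForwardPush

def fetch_neighbors(nodes_by_etype_id):
    # Hypothetical stand-in for the sampler's batched RPC fetch. Toy graph:
    # node 0 has the single out-neighbor 1; node 1 is a sink.
    adj = {0: [1], 1: []}
    out = {}
    for eid, node_ids in nodes_by_etype_id.items():
        ids = node_ids.tolist()
        flat = [n for i in ids for n in adj[i]]
        counts = [len(adj[i]) for i in ids]
        out[eid] = (
            node_ids,
            torch.tensor(flat, dtype=torch.long),
            torch.tensor(counts, dtype=torch.long),
        )
    return out

state = PPRForwardPush(
    torch.tensor([0], dtype=torch.long),        # seed_nodes
    0,                                          # seed_node_type_id
    0.15,                                       # alpha
    0.15 * 1e-6,                                # requeue_threshold_factor = alpha * eps
    [[0]],                                      # node_type_to_edge_type_ids
    [0],                                        # edge_type_to_dst_ntype_id
    [torch.tensor([1, 0], dtype=torch.int32)],  # degree_tensors (node 1 is a sink)
)

# None → converged; empty dict → all cache hits, still push to flush residuals.
while (nodes_by_etype_id := state.drain_queue()) is not None:
    fetched = fetch_neighbors(nodes_by_etype_id) if nodes_by_etype_id else {}
    state.push_residuals(fetched)

topk = state.extract_top_k(10)  # {ntype_id: (flat_ids, flat_weights, valid_counts)}
```

Note the positional arguments: the bindings register the constructor and methods without `py::arg` names, so keyword arguments are not available from Python.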
diff --git a/gigl-core/src/gigl_core/py.typed b/gigl-core/src/gigl_core/py.typed
new file mode 100644
index 000000000..23cb101e4
--- /dev/null
+++ b/gigl-core/src/gigl_core/py.typed
@@ -0,0 +1 @@
+PEP 561 marker. Presence of this file tells mypy that gigl_core supports type checking.
diff --git a/gigl-core/tests/CMakeLists.txt b/gigl-core/tests/CMakeLists.txt
index 74eac1f40..dbd669060 100644
--- a/gigl-core/tests/CMakeLists.txt
+++ b/gigl-core/tests/CMakeLists.txt
@@ -24,6 +24,29 @@ FetchContent_MakeAvailable(googletest)
 # Required for add_test() to register tests with CTest.
 enable_testing()
 
+# ---------------------------------------------------------------------------
+# Torch + kernel library (required by tests that use C++ kernels)
+# ---------------------------------------------------------------------------
+foreach(_prefix IN LISTS CMAKE_PREFIX_PATH)
+  if(NOT TORCH_CMAKE_PREFIX AND EXISTS "${_prefix}/torch/share/cmake")
+    set(TORCH_CMAKE_PREFIX "${_prefix}/torch/share/cmake")
+  endif()
+endforeach()
+find_package(Torch REQUIRED PATHS "${TORCH_CMAKE_PREFIX}")
+
+# Auto-discover all kernel sources under core/. python_*.cpp are pybind11
+# extension entry points that belong to the wheel, not to the test library.
+if(CMAKE_CUDA_COMPILER)
+  file(GLOB_RECURSE _KERNEL_SRCS "${CMAKE_SOURCE_DIR}/core/*.cpp" "${CMAKE_SOURCE_DIR}/core/*.cu")
+else()
+  file(GLOB_RECURSE _KERNEL_SRCS "${CMAKE_SOURCE_DIR}/core/*.cpp")
+endif()
+list(FILTER _KERNEL_SRCS EXCLUDE REGEX ".*/python_[^/]*$")
+
+add_library(gigl_core_kernels STATIC ${_KERNEL_SRCS})
+target_include_directories(gigl_core_kernels PUBLIC "${CMAKE_SOURCE_DIR}/core")
+target_link_libraries(gigl_core_kernels PUBLIC "${TORCH_LIBRARIES}")
+
 # ---------------------------------------------------------------------------
 # Auto-discover test targets
 # ---------------------------------------------------------------------------
@@ -45,7 +68,7 @@ foreach(test_source ${TEST_SOURCES})
   string(REPLACE "/" "_" test_name "${_rel}")
   string(REGEX REPLACE "\\.[^.]+$" "" test_name "${test_name}")
   add_executable(${test_name} ${test_source})
-  target_link_libraries(${test_name} GTest::gtest_main)
+  target_link_libraries(${test_name} GTest::gtest_main gigl_core_kernels)
   # add_test registers the binary with CTest. Each *_test binary is one
   # CTest entry; GoogleTest itself reports individual TEST() results inside it.
   add_test(NAME ${test_name} COMMAND ${test_name})
diff --git a/gigl-core/tests/ppr_forward_push_test.cpp b/gigl-core/tests/ppr_forward_push_test.cpp
new file mode 100644
index 000000000..604763deb
--- /dev/null
+++ b/gigl-core/tests/ppr_forward_push_test.cpp
@@ -0,0 +1,138 @@
+#include <gtest/gtest.h>
+#include "sampling/ppr_forward_push.h"
+
+using gigl::PPRForwardPush;
+
+// Builds a single-edge-type, single-node-type PPRForwardPush.
+static PPRForwardPush makeState(
+    const std::vector<int64_t>& seeds,
+    double alpha,
+    double requeueThresholdFactor,
+    const std::vector<int32_t>& degrees) {
+  return PPRForwardPush(
+      torch::tensor(seeds, torch::kLong),
+      /*seedNodeTypeId=*/0,
+      alpha,
+      requeueThresholdFactor,
+      /*nodeTypeToEdgeTypeIds=*/{{0}},
+      /*edgeTypeToDstNtypeId=*/{0},
+      {torch::tensor(degrees, torch::kInt)});
+}
+
+// Convenience wrapper: build the fetchedByEtypeId argument for pushResiduals
+// from flat vectors, keeping test call sites readable.
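+// For example, nodeIds={0, 3}, flatNeighborIds={5, 9, 2, 1}, counts={3, 1}
+// encodes node 0 → neighbors {5, 9, 2} and node 3 → neighbor {1}.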
+static std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
+makeFetched(int32_t edgeTypeId,
+            const std::vector<int64_t>& nodeIds,
+            const std::vector<int64_t>& flatNeighborIds,
+            const std::vector<int64_t>& counts) {
+  return {{edgeTypeId,
+           {torch::tensor(nodeIds, torch::kLong),
+            torch::tensor(flatNeighborIds, torch::kLong),
+            torch::tensor(counts, torch::kLong)}}};
+}
+
+// After construction, drainQueue() returns the seed node under etype 0.
+TEST(PPRForwardPush, DrainQueueReturnsSeedNodeInitially) {
+  auto state = makeState(/*seeds=*/{0}, /*alpha=*/0.15, /*requeueThresholdFactor=*/1e-6, /*degrees=*/{1});
+  auto result = state.drainQueue();
+  ASSERT_TRUE(result.has_value());
+  const auto& nodeMap = result.value();
+  ASSERT_NE(nodeMap.find(0), nodeMap.end());
+  EXPECT_EQ(nodeMap.at(0).size(0), 1);
+  EXPECT_EQ(nodeMap.at(0)[0].item<int64_t>(), 0);
+}
+
+// After convergence (sink node absorbs all residual), drainQueue() returns nullopt.
+TEST(PPRForwardPush, DrainQueueReturnsNulloptAfterConvergence) {
+  auto state = makeState(/*seeds=*/{0}, /*alpha=*/0.15, /*requeueThresholdFactor=*/1e-6, /*degrees=*/{0});
+  state.drainQueue();
+  state.pushResiduals({});
+  EXPECT_FALSE(state.drainQueue().has_value());
+}
+
+// A sink seed node absorbs its full residual as PPR score (= alpha).
+TEST(PPRForwardPush, PprScoreAbsorbsAlpha) {
+  const double alpha = 0.15;
+  auto state = makeState(/*seeds=*/{0}, alpha, /*requeueThresholdFactor=*/1e-6, /*degrees=*/{0});
+  state.drainQueue();
+  state.pushResiduals({});
+  auto topk = state.extractTopK(10);
+  ASSERT_NE(topk.find(0), topk.end());
+  const auto& [ids, weights, counts] = topk.at(0);
+  EXPECT_EQ(ids[0].item<int64_t>(), 0);
+  EXPECT_NEAR(weights[0].item<float>(), static_cast<float>(alpha), 1e-5F);
+}
+
+// Node 0 (degree 1) pushes (1-alpha)*alpha residual to node 1 (sink).
+TEST(PPRForwardPush, ResidualDistributedToNeighbor) {
+  const double alpha = 0.15;
+  auto state = makeState(/*seeds=*/{0}, alpha, /*requeueThresholdFactor=*/1e-6, /*degrees=*/{1, 0});
+
+  // Iteration 1: seed node 0 → neighbor node 1.
+  state.drainQueue();
+  state.pushResiduals(makeFetched(/*edgeTypeId=*/0, /*nodeIds=*/{0}, /*flatNeighborIds=*/{1}, /*counts=*/{1}));
+
+  // Iteration 2: node 1 is a sink; absorbs its residual, no further push.
+  state.drainQueue();
+  state.pushResiduals({});
+
+  EXPECT_FALSE(state.drainQueue().has_value());
+
+  auto topk = state.extractTopK(10);
+  ASSERT_NE(topk.find(0), topk.end());
+  const auto& [ids, weights, counts] = topk.at(0);
+  ASSERT_EQ(counts[0].item<int64_t>(), 2);
+  EXPECT_EQ(ids[0].item<int64_t>(), 0);
+  EXPECT_EQ(ids[1].item<int64_t>(), 1);
+  EXPECT_NEAR(weights[0].item<float>(), static_cast<float>(alpha), 1e-5F);
+  EXPECT_NEAR(weights[1].item<float>(), static_cast<float>((1.0 - alpha) * alpha), 1e-5F);
+}
+
+// Two seeds (0 and 1) both push residual to sink node 2. The neighbor-lookup
+// request must deduplicate to one entry for node 2, yet both seeds must still
+// accumulate a PPR score for it.
+TEST(PPRForwardPush, DeduplicatesNodesAcrossSeeds) {
+  auto state = makeState(/*seeds=*/{0, 1}, /*alpha=*/0.15, /*requeueThresholdFactor=*/1e-6, /*degrees=*/{1, 1, 0});
+
+  state.drainQueue();
+  state.pushResiduals(makeFetched(/*edgeTypeId=*/0, /*nodeIds=*/{0, 1}, /*flatNeighborIds=*/{2, 2}, /*counts=*/{1, 1}));
+
+  auto iter2 = state.drainQueue();
+  ASSERT_TRUE(iter2.has_value());
+  const auto& iter2Map = iter2.value();
+  ASSERT_NE(iter2Map.find(0), iter2Map.end());
+  EXPECT_EQ(iter2Map.at(0).size(0), 1);  // node 2 deduplicated in the lookup request
+
+  state.pushResiduals({});
+  EXPECT_FALSE(state.drainQueue().has_value());
+
+  auto topk = state.extractTopK(10);
+  ASSERT_NE(topk.find(0), topk.end());
+  const auto& [ids, weights, counts] = topk.at(0);
+  // Each seed (batch indices 0 and 1) should have 2 nodes in its top-k.
+  EXPECT_EQ(counts[0].item<int64_t>(), 2);  // seed 0: nodes {0, 2}
+  EXPECT_EQ(counts[1].item<int64_t>(), 2);  // seed 1: nodes {1, 2}
+  // The flat id layout is [seed0_top1, seed0_top2, seed1_top1, seed1_top2].
+  // Within each seed the highest scorer comes first, so seed-node beats node 2.
+  EXPECT_EQ(ids[1].item<int64_t>(), 2);  // seed 0's second node is node 2
+  EXPECT_EQ(ids[3].item<int64_t>(), 2);  // seed 1's second node is node 2
+}
+
+// extractTopK respects the maxPprNodes limit.
+TEST(PPRForwardPush, ExtractTopKLimitsResults) {
+  auto state = makeState(/*seeds=*/{0}, /*alpha=*/0.15, /*requeueThresholdFactor=*/1e-6, /*degrees=*/{1, 0});
+
+  state.drainQueue();
+  state.pushResiduals(makeFetched(/*edgeTypeId=*/0, /*nodeIds=*/{0}, /*flatNeighborIds=*/{1}, /*counts=*/{1}));
+  state.drainQueue();
+  state.pushResiduals({});
+
+  auto topk1 = state.extractTopK(1);
+  ASSERT_NE(topk1.find(0), topk1.end());
+  EXPECT_EQ(std::get<2>(topk1.at(0))[0].item<int64_t>(), 1);
+
+  auto topk10 = state.extractTopK(10);
+  ASSERT_NE(topk10.find(0), topk10.end());
+  EXPECT_EQ(std::get<2>(topk10.at(0))[0].item<int64_t>(), 2);
+}
diff --git a/gigl/dep_vars.env b/gigl/dep_vars.env
index 4b28e38b7..6f6ee5584 100644
--- a/gigl/dep_vars.env
+++ b/gigl/dep_vars.env
@@ -1,7 +1,7 @@
 # Note this file only supports static key value pairs so it can be loaded by make, bash, python, and sbt without any additional parsing.
-DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cuda-base:b598f3d72eee47f5513dcb39460944459a0a012f.108.1 -DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:b598f3d72eee47f5513dcb39460944459a0a012f.108.1 -DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:b598f3d72eee47f5513dcb39460944459a0a012f.108.1 +DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cuda-base:7d3182eeb6446ce3e35910babba990c8e003879d.109.1 +DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:7d3182eeb6446ce3e35910babba990c8e003879d.109.1 +DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:7d3182eeb6446ce3e35910babba990c8e003879d.109.1 DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.2.0 DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.2.0 diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index e09a8f3ff..83369d8c2 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -1,18 +1,12 @@ -# TODO (mkolodner-sc): The forward push loop in _compute_ppr_scores is the -# main throughput bottleneck — both the queue drain (preparing batched node -# lookups by edge type) and the residual push/requeue pass are pure Python -# dict/set operations in tight nested loops. Moving these to a C++ extension -# (e.g. pybind11) would eliminate per-operation Python overhead and enable -# cache-friendly memory access patterns. - -# TODO (mkolodner-sc): Investigate whether concurrency for _sample_one_hop and _compute_ppr_scores will -# yield performance benefits. - -import heapq +import asyncio from collections import defaultdict from typing import Optional, Union import torch + +# TODO: Once gigl_core has a stable Python interface, re-export PPRForwardPush +# under a gigl.core namespace rather than importing directly from the C++ extension. +from gigl_core import PPRForwardPush from graphlearn_torch.sampler import ( HeteroSamplerOutput, NeighborOutput, @@ -35,11 +29,6 @@ # Sentinel type names for homogeneous graphs. The PPR algorithm uses # dict[NodeType, ...] internally for both homo and hetero graphs; these # sentinels let the homogeneous path reuse the same dict-based code. -# TODO (mkolodner-sc): The sentinel approach adds an extra dict lookup on -# every operation in the hot loop for homogeneous graphs (always resolving -# the same single key). Profile whether this overhead is meaningful -# compared to the neighbor fetch and residual update costs, and consider -# splitting into separate homo/hetero loop implementations if so. _PPR_HOMOGENEOUS_NODE_TYPE = "ppr_homogeneous_node_type" _PPR_HOMOGENEOUS_EDGE_TYPE = ( _PPR_HOMOGENEOUS_NODE_TYPE, @@ -85,9 +74,10 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler): but require more computation. Typical values: 1e-4 to 1e-6. max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. num_neighbors_per_hop: Maximum number of neighbors to fetch per hop. - total_degree_dtype: Dtype for precomputed total-degree tensors. 
Defaults to
-            ``torch.int32``, which supports total degrees up to ~2 billion. Use a
-            larger dtype if nodes have exceptionally high aggregate degrees.
+        total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults
+            to ``torch.int32``. Use a larger dtype if nodes have exceptionally high
+            aggregate degrees.
+        degree_tensors: Pre-computed degree tensors from the dataset.
+        max_fetch_iterations: Optional cap on the number of neighbor-fetch rounds
+            per batch. Once the budget is exhausted, pushes continue from cached
+            neighbor lists only, trading exact PPR scores for throughput.
     """
 
     def __init__(
@@ -99,14 +99,15 @@ def __init__(
         num_neighbors_per_hop: int = 100_000,
         total_degree_dtype: torch.dtype = torch.int32,
         degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]],
+        max_fetch_iterations: Optional[int] = None,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self._alpha = alpha
-        self._eps = eps
         self._max_ppr_nodes = max_ppr_nodes
         self._requeue_threshold_factor = alpha * eps
         self._num_neighbors_per_hop = num_neighbors_per_hop
+        self._max_fetch_iterations = max_fetch_iterations
 
         # Build mapping from node type to edge types that can be traversed from that node type.
         self._node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict(
@@ -152,6 +143,54 @@ def __init__(
             self._build_total_degree_tensors(degree_tensors, total_degree_dtype)
         )
 
+        # Build integer ID mappings for the C++ forward-push kernel. String
+        # NodeType / EdgeType keys are only used at the Python boundary
+        # (translating to/from _sample_one_hop); all hot-loop state inside
+        # PPRForwardPush is indexed by int32 IDs.
+        #
+        # We include both source types (have outgoing edges) and destination-only
+        # types (no outgoing edges, but may accumulate PPR score during the walk)
+        # so the kernel can index residual/ppr_score tables for any node it sees.
+        source_node_types: set[NodeType] = set(self._node_type_to_edge_types.keys())
+        destination_node_types: set[NodeType] = {
+            self._get_destination_type(et)
+            for etypes in self._node_type_to_edge_types.values()
+            for et in etypes
+        }
+        all_node_types: list[NodeType] = sorted(
+            source_node_types | destination_node_types
+        )
+        all_edge_types: list[EdgeType] = sorted(
+            {et for etypes in self._node_type_to_edge_types.values() for et in etypes}
+        )
+
+        self._node_type_to_id: dict[NodeType, int] = {
+            nt: i for i, nt in enumerate(all_node_types)
+        }
+        self._ntype_id_to_ntype: list[NodeType] = all_node_types
+        self._etype_to_etype_id: dict[EdgeType, int] = {
+            et: i for i, et in enumerate(all_edge_types)
+        }
+        self._etype_id_to_etype: list[EdgeType] = all_edge_types
+
+        self._node_type_id_to_edge_type_ids: list[list[int]] = [
+            [
+                self._etype_to_etype_id[et]
+                for et in self._node_type_to_edge_types.get(nt, [])
+            ]
+            for nt in all_node_types
+        ]
+        self._edge_type_id_to_dst_ntype_id: list[int] = [
+            self._node_type_to_id[self._get_destination_type(et)]
+            for et in all_edge_types
+        ]
+        # Degree tensors indexed by ntype_id. Destination-only types get an empty
+        # tensor; the C++ kernel returns 0 for those, matching _get_total_degree.
+        self._degree_tensors_for_cpp: list[torch.Tensor] = [
+            self._node_type_to_total_degree.get(nt, torch.zeros(0, dtype=torch.int32))
+            for nt in all_node_types
+        ]
+
     def _build_total_degree_tensors(
         self,
         degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]],
@@ -204,105 +243,66 @@ def _build_total_degree_tensors(
 
         return result
 
-    def _get_total_degree(self, node_id: int, node_type: NodeType) -> int:
-        """Look up the precomputed total degree of a node.
-
-        Args:
-            node_id: The ID of the node to look up.
-            node_type: The node type.
-
-        Returns:
-            The total degree (sum across all edge types) for the node.
- - Raises: - ValueError: If the node ID is out of range, indicating corrupted - graph data or a sampler bug. - """ - # Destination-only node types (no outgoing edges) are absent from - # _node_type_to_total_degree because total degree is only computed for - # traversable source types. Returning 0 here is correct: such nodes - # act as terminals — they accumulate PPR score but never push residual - # further. - if node_type not in self._node_type_to_total_degree: - return 0 - degree_tensor = self._node_type_to_total_degree[node_type] - if node_id >= len(degree_tensor): - raise ValueError( - f"Node ID {node_id} exceeds total degree tensor length " - f"({len(degree_tensor)}) for node type {node_type}." - ) - return int(degree_tensor[node_id].item()) - def _get_destination_type(self, edge_type: EdgeType) -> NodeType: """Get the node type at the destination end of an edge type.""" return edge_type[0] if self.edge_dir == "in" else edge_type[-1] async def _batch_fetch_neighbors( self, - nodes_to_lookup: dict[EdgeType, set[int]], + nodes_by_etype_id: dict[int, torch.Tensor], device: torch.device, - ) -> dict[tuple[int, EdgeType], list[int]]: - """Batch fetch neighbors for nodes grouped by edge type. + ) -> dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """Batch fetch neighbors for nodes grouped by integer edge type ID. Issues one ``_sample_one_hop`` call per edge type (not per node), so all nodes of the same edge type are fetched in a single RPC round-trip. Each node's neighbor list is capped at ``self._num_neighbors_per_hop``. Args: - nodes_to_lookup: Dict mapping each edge type to the set of node IDs - whose neighbors should be fetched via that edge type. Only nodes - absent from the caller's ``neighbor_cache`` should be included. + nodes_by_etype_id: Dict mapping integer edge type ID to a 1-D int64 + tensor of node IDs to fetch neighbors for. Comes directly from + ``drain_queue()``; node IDs are already deduplicated. device: Torch device for intermediate tensor creation. Returns: - Dict mapping ``(node_id, edge_type)`` to the list of neighbor node IDs - returned by ``_sample_one_hop``. Only nodes that appeared in - ``nodes_to_lookup`` are present; edge types with an empty node set are - skipped entirely. + Dict mapping etype_id to ``(node_ids, flat_neighbors, counts)`` as + int64 tensors, ready to pass directly to ``push_residuals``. + ``flat_neighbors`` is the flat concatenation of all neighbor lists + for that edge type; ``counts[i]`` is the neighbor count for + ``node_ids[i]``. Example:: - nodes_to_lookup = { - ("user", "buys", "item"): {0, 3}, - ("item", "bought_by", "user"): {7}, + nodes_by_etype_id = { + 2: tensor([0, 3]), # etype_id 2 → nodes 0 and 3 + 5: tensor([7]), # etype_id 5 → node 7 } # Might return (neighbor lists depend on graph structure): { - (0, ("user", "buys", "item")): [5, 9, 2], - (3, ("user", "buys", "item")): [1], - (7, ("item", "bought_by", "user")): [0, 3], + 2: (tensor([0, 3]), tensor([5, 9, 2, 1]), tensor([3, 1])), + 5: (tensor([7]), tensor([0, 3]), tensor([2])), } """ - result: dict[tuple[int, EdgeType], list[int]] = {} - for etype, node_ids in nodes_to_lookup.items(): - if not node_ids: - continue - nodes_list = list(node_ids) - lookup_tensor = torch.tensor(nodes_list, dtype=torch.long, device=device) - - # _sample_one_hop expects None for homogeneous graphs, not the PPR sentinel. 
- output: NeighborOutput = await self._sample_one_hop( - srcs=lookup_tensor, - num_nbr=self._num_neighbors_per_hop, - etype=etype if etype != _PPR_HOMOGENEOUS_EDGE_TYPE else None, + # Fire all per-edge-type RPC calls concurrently. Each _sample_one_hop + # issues a single RPC round-trip; doing them in parallel rather than + # sequentially cuts fetch latency from O(num_edge_types) to O(1). + eids = list(nodes_by_etype_id.keys()) + sample_tasks = [] + for eid in eids: + etype = self._etype_id_to_etype[eid] + sample_tasks.append( + self._sample_one_hop( + srcs=nodes_by_etype_id[eid].to(device), + num_nbr=self._num_neighbors_per_hop, + # _sample_one_hop expects None for homogeneous graphs, not the PPR sentinel. + etype=None if etype == _PPR_HOMOGENEOUS_EDGE_TYPE else etype, + ) ) - neighbors = output.nbr - neighbor_counts = output.nbr_num - - # TODO (mkolodner-sc): Investigate performance of a vectorized version of the below code - neighbors_list = neighbors.tolist() - counts_list = neighbor_counts.tolist() - del neighbors, neighbor_counts - - # neighbors_list is a flat concatenation of all neighbors for all looked-up nodes. - # We use offset to slice out each node's neighbors: node i's neighbors are at - # neighbors_list[offset : offset + count], then we advance offset by count. - offset = 0 - for node_id, count in zip(nodes_list, counts_list): - result[(node_id, etype)] = neighbors_list[offset : offset + count] - offset += count - - return result + outputs: list[NeighborOutput] = await asyncio.gather(*sample_tasks) + return { + eid: (nodes_by_etype_id[eid], output.nbr, output.nbr_num) + for eid, output in zip(eids, outputs) + } async def _compute_ppr_scores( self, @@ -364,226 +364,60 @@ async def _compute_ppr_scores( if seed_node_type is None: seed_node_type = _PPR_HOMOGENEOUS_NODE_TYPE device = seed_nodes.device - batch_size = seed_nodes.size(0) - - # Per-seed PPR state, nested by node type for efficient type-grouped access. - - # ppr_scores[i][node_type][node_id] = accumulated PPR score for node_id - # of type node_type, relative to seed i. Updated each iteration by - # absorbing the node's residual. - ppr_scores: list[dict[NodeType, dict[int, float]]] = [ - defaultdict(lambda: defaultdict(float)) for _ in range(batch_size) - ] - - # residuals[i][node_type][node_id] = unconverged probability mass at node_id - # of type node_type for seed i. Each iteration, a node's residual is - # absorbed into its PPR score and then distributed to its neighbors. - residuals: list[dict[NodeType, dict[int, float]]] = [ - defaultdict(lambda: defaultdict(float)) for _ in range(batch_size) - ] - - # queue[i][node_type] = set of node IDs whose residual exceeds the - # convergence threshold (alpha * eps * total_degree). The algorithm - # terminates when all queues are empty. A set is used because multiple - # neighbors can push residual to the same node in one iteration — - # deduplication avoids redundant processing, and the O(1) membership - # check matters since it runs in the innermost loop. 
- queue: list[dict[NodeType, set[int]]] = [ - defaultdict(set) for _ in range(batch_size) - ] - - seed_list = seed_nodes.tolist() - for i, seed in enumerate(seed_list): - residuals[i][seed_node_type][seed] = self._alpha - queue[i][seed_node_type].add(seed) - - # Cache keyed by (node_id, edge_type) since same node can have different neighbors per edge type - neighbor_cache: dict[tuple[int, EdgeType], list[int]] = {} - - num_nodes_in_queue = batch_size - one_minus_alpha = 1 - self._alpha + ppr_state = PPRForwardPush( + seed_nodes, + self._node_type_to_id[seed_node_type], + self._alpha, + self._requeue_threshold_factor, + self._node_type_id_to_edge_type_ids, + self._edge_type_id_to_dst_ntype_id, + self._degree_tensors_for_cpp, + ) - while num_nodes_in_queue > 0: - # Drain all nodes from all queues and group by edge type for batched lookups - queued_nodes: list[dict[NodeType, set[int]]] = [ - defaultdict(set) for _ in range(batch_size) - ] - nodes_to_lookup: dict[EdgeType, set[int]] = defaultdict(set) - - for seed_idx in range(batch_size): - if queue[seed_idx]: - queued_nodes[seed_idx] = queue[seed_idx] - queue[seed_idx] = defaultdict(set) - for node_type, node_ids in queued_nodes[seed_idx].items(): - num_nodes_in_queue -= len(node_ids) - # We fetch neighbors for ALL edge types originating - # from this node type, not just the edge type that - # caused the node to be queued. This is required for - # correctness: forward push distributes residual to - # all neighbors proportionally by total degree, so - # every edge type must be considered. - # Destination-only types have no entry in _node_type_to_edge_types; - # .get() returns [] so we skip neighbor lookup for them. - edge_types_for_node = self._node_type_to_edge_types.get( - node_type, [] - ) - for node_id in node_ids: - for etype in edge_types_for_node: - cache_key = (node_id, etype) - if cache_key not in neighbor_cache: - # TODO (mkolodner-sc): Investigate switching from set to list - # here. _sample_one_hop handles duplicates correctly (second - # write to result[(node_id, etype)] is a no-op overwrite), so - # dedup is not required for correctness. A list would avoid - # per-add hash cost and the set->list->tensor conversion in - # _batch_fetch_neighbors, though at the cost of redundant - # network calls for any duplicate nodes across seeds. - nodes_to_lookup[etype].add(node_id) - - fetched_neighbors = await self._batch_fetch_neighbors( - nodes_to_lookup=nodes_to_lookup, - device=device, + fetch_iteration_count = 0 + + while True: + # drain_queue returns None when the queue is truly empty (convergence), + # or a dict (possibly empty) when nodes were drained. An empty dict + # means all drained nodes either had cached neighbors or no outgoing + # edges — we still call push_residuals to flush their residuals into + # ppr_scores_. + nodes_by_etype_id = ppr_state.drain_queue() + if nodes_by_etype_id is None: + break + + fetch_budget_remaining = ( + self._max_fetch_iterations is None + or fetch_iteration_count < self._max_fetch_iterations + ) + if nodes_by_etype_id and fetch_budget_remaining: + fetched_by_etype_id = await self._batch_fetch_neighbors( + nodes_by_etype_id, device + ) + fetch_iteration_count += 1 + else: + # Fetch budget exhausted; push_residuals will use the existing neighbor cache. + fetched_by_etype_id = {} + + # Run in executor so the C++ push doesn't block the asyncio event loop. 
+ await asyncio.get_running_loop().run_in_executor( + None, ppr_state.push_residuals, fetched_by_etype_id ) - # fetched_neighbors is intentionally NOT merged into neighbor_cache - # upfront. We only promote entries when a node is requeued — see - # the should_requeue block below. - - # Push residual to neighbors and re-queue in a single pass. This - # is safe because each seed's state is independent, and residuals - # are always positive so the merged loop can never miss a re-queue. - for seed_idx in range(batch_size): - for source_type, source_nodes in queued_nodes[seed_idx].items(): - for source_node in source_nodes: - source_residual = residuals[seed_idx][source_type].get( - source_node, 0.0 - ) - - ppr_scores[seed_idx][source_type][source_node] += ( - source_residual - ) - residuals[seed_idx][source_type][source_node] = 0.0 - - # Same destination-only guard as in the queue drain loop above. - edge_types_for_node = self._node_type_to_edge_types.get( - source_type, [] - ) - - total_degree = self._get_total_degree(source_node, source_type) - - if total_degree == 0: - continue - - residual_per_neighbor = ( - one_minus_alpha * source_residual / total_degree - ) - - for etype in edge_types_for_node: - cache_key = (source_node, etype) - # fetched_neighbors and neighbor_cache are mutually - # exclusive per iteration: the queue drain only adds - # a node to nodes_to_lookup if it is absent from - # neighbor_cache, so a key appears in at most one. - neighbor_list = fetched_neighbors.get( - cache_key, neighbor_cache.get(cache_key, []) - ) - if not neighbor_list: - continue - - neighbor_type = self._get_destination_type(etype) - - for neighbor_node in neighbor_list: - residuals[seed_idx][neighbor_type][neighbor_node] += ( - residual_per_neighbor - ) - - requeue_threshold = ( - self._requeue_threshold_factor - * self._get_total_degree( - neighbor_node, neighbor_type - ) - ) - should_requeue = ( - neighbor_node not in queue[seed_idx][neighbor_type] - and residuals[seed_idx][neighbor_type][ - neighbor_node - ] - >= requeue_threshold - ) - if should_requeue: - queue[seed_idx][neighbor_type].add(neighbor_node) - num_nodes_in_queue += 1 - # Promote this node's neighbor lists to the - # persistent cache: it will be processed next - # iteration, so caching now avoids a re-fetch. - # Nodes that are never requeued (typically - # high-degree) are never promoted, keeping - # their large neighbor lists out of the cache. - for ( - promote_etype - ) in self._node_type_to_edge_types.get( - neighbor_type, [] - ): - promote_key = (neighbor_node, promote_etype) - if ( - promote_key in fetched_neighbors - and promote_key not in neighbor_cache - ): - neighbor_cache[promote_key] = ( - fetched_neighbors[promote_key] - ) - - # Extract top-k nodes by PPR score, grouped by node type. - # Results are three flat tensors per node type (no padding): - # - flat_ids: [id_seed0_0, id_seed0_1, ..., id_seed1_0, ...] - # - flat_weights: [wt_seed0_0, wt_seed0_1, ..., wt_seed1_0, ...] - # - valid_counts: [count_seed0, count_seed1, ...] - # - # valid_counts[i] records how many top-k neighbors seed i contributed. - # The inducer uses valid_counts to slice flat_ids into per-seed groups - # and assign local indices. 
Example: - # - # 4 seeds, valid_counts = [1, 6, 2, 1] (10 total pairs) - # flat_ids = [d0a, d1a, d1b, d1c, d1d, d1e, d1f, d2a, d2b, d3a] - # - # seed 0 owns flat_ids[0:1], seed 1 owns flat_ids[1:7], - # seed 2 owns flat_ids[7:9], seed 3 owns flat_ids[9:10] - # _node_type_to_edge_types only contains source types; destination-only - # types are absent but may have accumulated PPR scores during the walk. - # We union with all types seen in ppr_scores so they appear in the output. - all_node_types: set[NodeType] = set(self._node_type_to_edge_types.keys()) - for seed_ppr in ppr_scores: - all_node_types.update(seed_ppr.keys()) + # Translate ntype_id integer keys back to NodeType strings for the rest + # of the pipeline, and move tensors to the correct device. ntype_to_flat_ids: dict[NodeType, torch.Tensor] = {} ntype_to_flat_weights: dict[NodeType, torch.Tensor] = {} ntype_to_valid_counts: dict[NodeType, torch.Tensor] = {} - for ntype in all_node_types: - flat_ids: list[int] = [] - flat_weights: list[float] = [] - valid_counts: list[int] = [] - - for i in range(batch_size): - type_scores = ppr_scores[i].get(ntype, {}) - top_k = heapq.nlargest( - self._max_ppr_nodes, type_scores.items(), key=lambda x: x[1] - ) - if top_k: - ids, weights = zip(*top_k) - flat_ids.extend(ids) - flat_weights.extend(weights) - valid_counts.append(len(top_k)) - - ntype_to_flat_ids[ntype] = torch.tensor( - flat_ids, dtype=torch.long, device=device - ) - ntype_to_flat_weights[ntype] = torch.tensor( - flat_weights, dtype=torch.float, device=device - ) - ntype_to_valid_counts[ntype] = torch.tensor( - valid_counts, dtype=torch.long, device=device - ) + for ntype_id, (flat_ids, flat_weights, valid_counts) in ppr_state.extract_top_k( + self._max_ppr_nodes + ).items(): + ntype = self._ntype_id_to_ntype[ntype_id] + ntype_to_flat_ids[ntype] = flat_ids.to(device) + ntype_to_flat_weights[ntype] = flat_weights.to(device) + ntype_to_valid_counts[ntype] = valid_counts.to(device) if self._is_homogeneous: assert ( @@ -596,7 +430,11 @@ async def _compute_ppr_scores( ntype_to_valid_counts[_PPR_HOMOGENEOUS_NODE_TYPE], ) else: - return ntype_to_flat_ids, ntype_to_flat_weights, ntype_to_valid_counts + return ( + ntype_to_flat_ids, + ntype_to_flat_weights, + ntype_to_valid_counts, + ) async def _sample_from_nodes( self, @@ -683,20 +521,35 @@ async def _sample_from_nodes( # NodeType -> global IDs (same values as nodes_to_sample). src_dict = inducer.init_node(nodes_to_sample) - # Compute PPR for each seed type, collecting flat global neighbor IDs, - # weights, and per-seed counts. Build nbr_dict for a single - # inducer.induce_next call using PPR edge types (seed_type, 'ppr', ntype) - # — the inducer only cares about etype[0] and etype[-1] as source/dest - # node types, so the relation name is arbitrary. + # Compute PPR for all seed types concurrently, collecting flat global + # neighbor IDs, weights, and per-seed counts. Build nbr_dict for a + # single inducer.induce_next call using PPR edge types + # (seed_type, 'ppr', ntype) — the inducer only cares about etype[0] + # and etype[-1] as source/dest node types, so the relation name is + # arbitrary. + # + # Each seed type's PPR computation is entirely independent: it creates + # its own PPRForwardPush and only reads shared sampler attributes + # (degree tensors, edge-type maps) which are immutable after __init__. + # Running them with asyncio.gather allows their fetch phases to overlap, + # which is most beneficial when there are 2+ distinct seed node types + # (e.g. 
cross-type supervision edges like user→story). + seed_types = list(nodes_to_sample.keys()) + ppr_results = await asyncio.gather( + *[ + self._compute_ppr_scores(nodes_to_sample[seed_type], seed_type) + for seed_type in seed_types + ] + ) + nbr_dict: dict[EdgeType, list[torch.Tensor]] = {} ppr_edge_type_to_flat_weights: dict[EdgeType, torch.Tensor] = {} - for seed_type, seed_nodes in nodes_to_sample.items(): - ( - ntype_to_flat_ids, - ntype_to_flat_weights, - ntype_to_valid_counts, - ) = await self._compute_ppr_scores(seed_nodes, seed_type) + for seed_type, ( + ntype_to_flat_ids, + ntype_to_flat_weights, + ntype_to_valid_counts, + ) in zip(seed_types, ppr_results): assert isinstance(ntype_to_flat_ids, dict) assert isinstance(ntype_to_flat_weights, dict) assert isinstance(ntype_to_valid_counts, dict) @@ -746,14 +599,19 @@ async def _sample_from_nodes( # rows_dict and cols_dict are keyed by PPR edge type and give # flat local source/destination indices respectively, aligned with # the flat_ids order passed to induce_next. - for ppr_edge_type, flat_weights in ppr_edge_type_to_flat_weights.items(): + for ( + ppr_edge_type, + flat_weights, + ) in ppr_edge_type_to_flat_weights.items(): rows = rows_dict.get(ppr_edge_type) cols = cols_dict.get(ppr_edge_type) if rows is not None and cols is not None: edge_index = torch.stack([rows, cols]) else: edge_index = torch.zeros(2, 0, dtype=torch.long, device=self.device) - flat_weights = torch.zeros(0, dtype=torch.float, device=self.device) + flat_weights = torch.zeros( + 0, dtype=torch.double, device=self.device + ) etype_str = repr(ppr_edge_type) metadata[f"{PPR_EDGE_INDEX_METADATA_KEY}{etype_str}"] = edge_index metadata[f"{PPR_WEIGHT_METADATA_KEY}{etype_str}"] = flat_weights diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index a4d2c9335..fccd7a3ba 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -54,18 +54,27 @@ class PPRSamplerOptions: max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. num_neighbors_per_hop: Maximum number of neighbors fetched per node per edge - type during PPR traversal. Set large to approximate fetching all - neighbors. + type during PPR traversal. 1000 is sufficient in practice — high-degree + hub nodes receive diminishing residual per neighbor, so capping the fetch + has little effect on PPR accuracy while keeping per-hop RPC cost bounded. + Set large to approximate fetching all neighbors. total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults to ``torch.int32``, which supports total degrees up to ~2 billion. Use a larger dtype if nodes have exceptionally high aggregate degrees. + max_fetch_iterations: Maximum number of iterations that issue RPC neighbor + fetches. After this many fetch iterations, subsequent iterations push + residuals using only already-cached neighbor lists (no new RPCs). + The algorithm still runs to convergence — re-enqueued nodes propagate + through cached neighbors at negligible cost. ``None`` (default) means + no fetch limit. 
""" alpha: float = 0.5 eps: float = 1e-4 max_ppr_nodes: int = 50 - num_neighbors_per_hop: int = 100_000 + num_neighbors_per_hop: int = 1_000 total_degree_dtype: torch.dtype = torch.int32 + max_fetch_iterations: Optional[int] = None SamplerOptions = Union[KHopNeighborSamplerOptions, PPRSamplerOptions] diff --git a/gigl/distributed/utils/dist_sampler.py b/gigl/distributed/utils/dist_sampler.py index d42c5b104..0333f4138 100644 --- a/gigl/distributed/utils/dist_sampler.py +++ b/gigl/distributed/utils/dist_sampler.py @@ -82,6 +82,7 @@ def create_dist_sampler( alpha=sampler_options.alpha, eps=sampler_options.eps, max_ppr_nodes=sampler_options.max_ppr_nodes, + max_fetch_iterations=sampler_options.max_fetch_iterations, num_neighbors_per_hop=sampler_options.num_neighbors_per_hop, total_degree_dtype=sampler_options.total_degree_dtype, degree_tensors=degree_tensors, diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index 3ecd2b031..b91b411e3 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -190,6 +190,11 @@ def strip_non_ppr_edge_types( for edge_type in list(data.edge_types): if edge_type not in ppr_edge_types: del data[edge_type] + # num_sampled_edges is set by GLT's standard k-hop sampler but not + # by PPR sampling, which constructs HeteroData manually. Guard with + # hasattr rather than assuming it's always present. + if hasattr(data, "num_sampled_edges"): + data.num_sampled_edges.pop(edge_type, None) return data diff --git a/pyproject.toml b/pyproject.toml index ebbc65ea8..de3542607 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,10 +108,10 @@ required-environments = [ no-build-isolation-package = ["gigl-core"] [dependency-groups] -# scikit-build-core is gigl-core's PEP 517 build backend. With no-build-isolation-package -# set, uv does not install [build-system].requires automatically, so it must be present -# in the ambient environment before any gigl-core build (uv sync, uv build, Dockerfiles). -gigl-core-build-backend = ["scikit-build-core>=0.10"] +# These are gigl-core's [build-system].requires. With no-build-isolation-package set, +# uv does not install them automatically, so they must be present in the ambient +# environment before any gigl-core build (uv sync, uv build, Dockerfiles). +gigl-core-build-backend = ["scikit-build-core>=0.10", "pybind11>=2.12"] dev = [ {include-group = "gigl-core-build-backend"}, {include-group = "docs"}, diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py index b0879c306..15369de72 100644 --- a/tests/unit/distributed/dist_ppr_sampler_test.py +++ b/tests/unit/distributed/dist_ppr_sampler_test.py @@ -270,8 +270,10 @@ def _assert_ppr_scores_match_reference( """Assert sampler PPR scores match reference scores per node type. Checks that top-k node sets are identical and that per-node scores - are within atol=1e-6. The forward push error per node is bounded by - O(alpha * eps * degree); observed deltas are ~1e-7 for eps=1e-6. + are within atol=2e-6. The forward push error per node is bounded by + the per-node requeue threshold alpha * eps * degree; for max degree 3, + alpha=0.5, eps=1e-6 the per-node threshold is ~1.5e-6. Tolerance is + set to 2e-6 to provide a small margin above this bound. Args: ntype_to_sampler_ppr: Sampler output from :func:`_extract_hetero_ppr_scores`. 
@@ -290,7 +292,7 @@ def _assert_ppr_scores_match_reference( for node_id in reference_ppr[ntype_str]: ref_score = reference_ppr[ntype_str][node_id] sam_score = ntype_to_sampler_ppr[ntype_str][node_id] - assert abs(sam_score - ref_score) < 1e-6, ( + assert abs(sam_score - ref_score) < 2e-6, ( f"{seed_id}, type {ntype_str}, node {node_id}: " f"sampler={sam_score:.8f} vs reference={ref_score:.8f}" ) @@ -328,6 +330,14 @@ def _run_ppr_loader_correctness_check( for datum in loader: assert isinstance(datum, Data) + # PPR sampling does not count per-hop neighbors, so num_sampled_edges + # should be absent or empty on all PPR output batches. + assert ( + not hasattr(datum, "num_sampled_edges") or len(datum.num_sampled_edges) == 0 + ), ( + f"Expected empty num_sampled_edges for PPR output, got {datum.num_sampled_edges}" + ) + assert hasattr(datum, "edge_index"), "Missing edge_index on Data" assert hasattr(datum, "edge_attr"), "Missing edge_attr on Data" @@ -371,12 +381,15 @@ def _run_ppr_loader_correctness_check( f" Reference: {sorted(reference_ppr.keys())}" ) - # Forward push is an approximation; with eps=1e-6 the per-node error - # is bounded by O(alpha * eps * degree). Observed deltas are ~1e-7. + # Forward push is an approximation; with eps=1e-6 the per-node + # requeue threshold is alpha * eps * degree. For this test graph + # (max degree 3, alpha=0.5, eps=1e-6) the per-node threshold is + # ~1.5e-6. Tolerance is set to 2e-6 to provide a small margin + # above this bound. for node_id in reference_ppr: ref_score = reference_ppr[node_id] sam_score = sampler_ppr[node_id] - assert abs(sam_score - ref_score) < 1e-6, ( + assert abs(sam_score - ref_score) < 2e-6, ( f"Seed {seed_global_id}, node {node_id}: " f"sampler={sam_score:.8f} vs reference={ref_score:.8f}" ) @@ -424,6 +437,14 @@ def _run_ppr_hetero_loader_correctness_check( for datum in loader: assert isinstance(datum, HeteroData) + # PPR sampling does not count per-hop neighbors, so num_sampled_edges + # should be absent or empty on all PPR output batches. + assert ( + not hasattr(datum, "num_sampled_edges") or len(datum.num_sampled_edges) == 0 + ), ( + f"Expected empty num_sampled_edges for PPR output, got {datum.num_sampled_edges}" + ) + seed_global_id = datum[USER].batch[0].item() ntype_to_sampler_ppr = _extract_hetero_ppr_scores( @@ -505,6 +526,14 @@ def _run_ppr_ablp_loader_correctness_check( for datum in loader: assert isinstance(datum, HeteroData) + # PPR sampling does not count per-hop neighbors, so num_sampled_edges + # should be absent or empty on all PPR output batches. 
+ assert ( + not hasattr(datum, "num_sampled_edges") or len(datum.num_sampled_edges) == 0 + ), ( + f"Expected empty num_sampled_edges for PPR output, got {datum.num_sampled_edges}" + ) + # ABLP should produce positive labels alongside PPR metadata assert hasattr(datum, "y_positive"), "Missing y_positive on HeteroData" diff --git a/uv.lock b/uv.lock index b8303f693..afa760197 100644 --- a/uv.lock +++ b/uv.lock @@ -784,6 +784,7 @@ dev = [ { name = "pandas-stubs" }, { name = "parameterized" }, { name = "pre-commit" }, + { name = "pybind11" }, { name = "pydata-sphinx-theme" }, { name = "ruff" }, { name = "scikit-build-core" }, @@ -818,6 +819,7 @@ docs = [ { name = "sphinx-tabs" }, ] gigl-core-build-backend = [ + { name = "pybind11" }, { name = "scikit-build-core" }, ] lint = [ @@ -910,6 +912,7 @@ dev = [ { name = "pandas-stubs", specifier = "==2.2.2.240807" }, { name = "parameterized", specifier = "==0.9.0" }, { name = "pre-commit", specifier = "==3.3.2" }, + { name = "pybind11", specifier = ">=2.12" }, { name = "pydata-sphinx-theme", specifier = "==0.16.1" }, { name = "ruff", specifier = "==0.15.10" }, { name = "scikit-build-core", specifier = ">=0.10" }, @@ -943,7 +946,10 @@ docs = [ { name = "sphinx-rtd-theme", specifier = "==2.0.0" }, { name = "sphinx-tabs", specifier = "==3.4.5" }, ] -gigl-core-build-backend = [{ name = "scikit-build-core", specifier = ">=0.10" }] +gigl-core-build-backend = [ + { name = "pybind11", specifier = ">=2.12" }, + { name = "scikit-build-core", specifier = ">=0.10" }, +] lint = [ { name = "mdformat", specifier = "==0.7.22" }, { name = "mdformat-tables", specifier = "==1.0.0" }, @@ -3070,6 +3076,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, ] +[[package]] +name = "pybind11" +version = "3.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/f0/35145a3c3baffeef55d4b8324caa33abaa8fa56ab345ecd4b2211d09163e/pybind11-3.0.4.tar.gz", hash = "sha256:3286b59c8a774b9ee650169302dd5a4eedc30a8617905a0560dd8ee44775130c", size = 589533, upload-time = "2026-04-19T03:08:15.925Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/06/c3a23c9a0263b136c519f033a58d4641e73065fefc7754e9667ec206d992/pybind11-3.0.4-py3-none-any.whl", hash = "sha256:961720ee652da51d531b7b2451a6bd2bc042b0106e6d9baa48ecb7d58034ce63", size = 314166, upload-time = "2026-04-19T03:08:14.091Z" }, +] + [[package]] name = "pycparser" version = "2.23"