diff --git a/.gitignore b/.gitignore index 2613e49..3df7d2c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ _build/ _generate/ *.so *.so.* +*.a *.py[cod] *.egg-info .eggs/ @@ -17,4 +18,26 @@ __pycache__/ .ipynb_checkpoints/ .vscode/ .DS_Store -*.prof \ No newline at end of file +*.prof + +# Test coverage +htmlcov/ +.coverage +.coverage.* +coverage.xml +*.cover + +# Pytest +.pytest_cache/ +.pytest_cache + +# Build artifacts +*.o +*.obj +*.lib +*.exp + +# Temporary files +*.tmp +*.bak +*~ \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..bb4c282 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,312 @@ +# Contributing Guide + +Thank you for your interest in contributing to python_prtree! + +## Development Setup + +### Prerequisites + +- Python 3.8 or higher +- CMake 3.12 or higher +- C++17-compatible compiler (GCC 7+, Clang 5+, MSVC 2017+) +- Git + +### Quick Start + +```bash +# Clone the repository +git clone --recursive https://github.com/atksh/python_prtree +cd python_prtree + +# Setup development environment (first time only) +make dev + +# Build and test +make build +make test +``` + +## Makefile Commands + +We provide a Makefile to streamline the development workflow. 
+ +### Show Help + +```bash +make help +``` + +### Initial Setup + +```bash +make init # Initialize submodules + check dependencies +make install-deps # Install development dependencies +make dev # Run init + install-deps + build at once +``` + +### Building + +```bash +make build # Build in debug mode +make build-release # Build in release mode +make rebuild # Clean build (clean + build) +make debug-build # Build with debug symbols +``` + +### Testing + +```bash +make test # Run all tests +make test-verbose # Run tests in verbose mode +make test-fast # Run tests in parallel (faster) +make test-coverage # Run tests with coverage +make test-one TEST=<pattern> # Run specific test(s) +``` + +### Cleanup + +```bash +make clean # Remove build artifacts +make clean-all # Remove everything including submodules +``` + +### Packaging + +```bash +make wheel # Build wheel package +make sdist # Build source distribution +make release # Build release packages (wheel + sdist) +``` + +### Other Commands + +```bash +make format # Format C++ code (requires clang-format) +make lint # Lint code +make info # Show project information +make check # Run build and tests (for CI) +make quick # Quick test (clean + build + test) +``` + +## Development Workflow + +### Adding New Features + +1. **Create a branch** + ```bash + git checkout -b feature/your-feature-name + ``` + +2. **Setup development environment (first time only)** + ```bash + make dev + ``` + +3. **Make changes** + - C++ code: `cpp/prtree.h`, `cpp/main.cc` + - Python wrapper: `src/python_prtree/__init__.py` + - Tests: `tests/test_PRTree.py` + +4. **Build and test** + ```bash + make rebuild + make test + ``` + +5. **Check code quality** + ```bash + make format # Format code + make lint # Lint code + ``` + +6. **Commit** + ```bash + git add -A + git commit -m "Add: description of new feature" + ``` + +7. **Create pull request** + +### Test-Driven Development (TDD) + +1. 
**Write tests first** + ```python + # tests/test_PRTree.py + def test_new_feature(): + # Write test code + pass + ``` + +2. **Verify test fails** + ```bash + make test-one TEST=test_new_feature + ``` + +3. **Implement feature** + ```cpp + // cpp/prtree.h + // Add implementation + ``` + +4. **Build and test** + ```bash + make build + make test-one TEST=test_new_feature + ``` + +5. **Run all tests** + ```bash + make test + ``` + +### Debugging + +Build with debug symbols for debugging: + +```bash +make debug-build +gdb python3 +(gdb) run -c "from python_prtree import PRTree2D; ..." +``` + +### Checking Coverage + +```bash +make test-coverage +# Open htmlcov/index.html in browser +``` + +## Coding Standards + +### C++ + +- **Style**: Follow Google C++ Style Guide +- **Formatting**: Use clang-format (`make format`) +- **Naming conventions**: + - Classes: `PascalCase` (e.g., `PRTree`) + - Functions/methods: `snake_case` (e.g., `batch_query`) + - Variables: `snake_case` + - Constants: `UPPER_CASE` + +### Python + +- **Style**: Follow PEP 8 +- **Line length**: Maximum 100 characters +- **Documentation**: Use docstrings + +### Tests + +- Add tests for all new features +- Test case naming: `test_<feature>_<scenario>` +- Cover edge cases +- Use parameterized tests (`@pytest.mark.parametrize`) + +## Project Structure + +``` +python_prtree/ +├── cpp/ # C++ implementation +│ ├── prtree.h # PRTree core implementation +│ ├── main.cc # Python bindings +│ ├── parallel.h # Parallel processing utilities +│ └── small_vector.h # Optimized vector +├── src/python_prtree/ # Python wrapper +│ └── __init__.py +├── tests/ # Test suite +│ └── test_PRTree.py +├── third/ # Third-party libraries (submodules) +│ ├── pybind11/ +│ └── snappy/ +├── CMakeLists.txt # CMake configuration +├── setup.py # Packaging configuration +├── Makefile # Development workflow +└── README.md # User documentation +``` + +## Troubleshooting + +### Submodules Not Found + +```bash +make clean-all +make init +``` + +### Build Errors + 
+```bash +make clean +make build +``` + +### Test Failures + +1. Verify build succeeded + ```bash + make build + ``` + +2. Check environment variables + ```bash + echo $PYTHONPATH # Should include src directory + ``` + +3. Run in verbose mode + ```bash + make test-verbose + ``` + +### CMake Errors + +Clear CMake cache: +```bash +rm -rf build +make build +``` + +## Continuous Integration (CI) + +When you create a pull request, the following checks run automatically: + +- Build verification +- All tests +- Code coverage + +Run the same checks locally: +```bash +make check +``` + +## Release Process + +1. **Update version** + - Update version number in `setup.py` + +2. **Update changelog** + - Update "New Features and Changes" section in `README.md` + +3. **Run tests** + ```bash + make clean + make check + ``` + +4. **Build release packages** + ```bash + make release + ``` + +5. **Create tag** + ```bash + git tag -a v0.x.x -m "Release v0.x.x" + git push origin v0.x.x + ``` + +## Questions and Support + +- **Issues**: https://github.com/atksh/python_prtree/issues +- **Discussions**: https://github.com/atksh/python_prtree/discussions + +## License + +Contributions to this project will be released under the same license as the project (MIT). diff --git a/MAKEFILE_USAGE.md b/MAKEFILE_USAGE.md new file mode 100644 index 0000000..f84dce1 --- /dev/null +++ b/MAKEFILE_USAGE.md @@ -0,0 +1,246 @@ +# Makefile Usage Guide + +This document provides a quick reference for all available Make commands in the python_prtree project. 
+ +## Quick Start + +```bash +# First time setup +make dev + +# Build and test +make build +make test +``` + +## Command Reference + +### Essential Commands + +| Command | Description | +|---------|-------------| +| `make help` | Show all available commands | +| `make dev` | Complete development setup (init + install-deps + build) | +| `make build` | Build C++ extension | +| `make test` | Run all tests | +| `make clean` | Remove build artifacts | + +### Initialization + +| Command | Description | +|---------|-------------| +| `make init` | Initialize submodules and check dependencies | +| `make check-deps` | Verify required tools are installed | +| `make init-submodules` | Initialize git submodules | +| `make install-deps` | Install Python development dependencies | + +### Building + +| Command | Description | +|---------|-------------| +| `make build` | Build in debug mode (default) | +| `make build-release` | Build optimized release version | +| `make rebuild` | Clean and rebuild | +| `make debug-build` | Build with debug symbols | + +### Testing + +| Command | Description | Example | +|---------|-------------|---------| +| `make test` | Run all tests | | +| `make test-verbose` | Run tests with detailed output | | +| `make test-fast` | Run tests in parallel | | +| `make test-coverage` | Generate coverage report | | +| `make test-one` | Run specific test(s) | `make test-one TEST=test_result` | + +### Code Quality + +| Command | Description | Requirements | +|---------|-------------|--------------| +| `make format` | Format C++ code | clang-format | +| `make lint-cpp` | Lint C++ code | clang-tidy | +| `make lint-python` | Lint Python code | flake8 | +| `make lint` | Lint all code | clang-tidy, flake8 | + +### Packaging + +| Command | Description | +|---------|-------------| +| `make wheel` | Build wheel package | +| `make sdist` | Build source distribution | +| `make release` | Build both wheel and sdist | + +### Maintenance + +| Command | Description | 
+|---------|-------------| +| `make clean` | Remove build artifacts | +| `make clean-all` | Remove everything including submodules | +| `make info` | Show project and environment info | +| `make check` | Run build + test (for CI) | + +### Other + +| Command | Description | Requirements | +|---------|-------------|--------------| +| `make docs` | Generate documentation | Doxygen | +| `make benchmark` | Run benchmarks | benchmark.py | +| `make watch-test` | Auto-run tests on file changes | pytest-watch | + +## Common Workflows + +### First Time Setup + +```bash +# Clone with submodules +git clone --recursive https://github.com/atksh/python_prtree +cd python_prtree + +# Setup development environment +make dev +``` + +### Daily Development + +```bash +# Make changes to code... + +# Build and test +make rebuild +make test + +# Or use quick command +make quick # clean + build + test +``` + +### Before Committing + +```bash +# Format and lint +make format +make lint + +# Run full test suite +make test + +# Check everything +make check +``` + +### Testing Specific Features + +```bash +# Run tests matching a pattern +make test-one TEST=test_query + +# This will run all tests with "test_query" in the name +``` + +### Release Preparation + +```bash +# Clean everything +make clean + +# Run all checks +make check + +# Build release packages +make release +``` + +## Troubleshooting + +### "Submodules not initialized" + +```bash +make init +``` + +### Build failures + +```bash +make clean +make build +``` + +### Test failures + +```bash +# Run in verbose mode to see details +make test-verbose + +# Check environment +make info +``` + +### CMake cache issues + +```bash +rm -rf build +make build +``` + +## Environment Variables + +The Makefile automatically sets: + +- `PYTHONPATH`: Includes `src/` directory for imports + +You can customize: + +- `PYTHON`: Python executable (default: `python3`) +- `CMAKE_BUILD_TYPE`: Build type for CMake + +Example: +```bash +PYTHON=python3.11 make 
build +``` + +## Tips + +1. **Parallel Testing**: Use `make test-fast` to run tests in parallel +2. **Coverage Reports**: Use `make test-coverage` and open `htmlcov/index.html` +3. **Watch Mode**: Install pytest-watch (`pip install pytest-watch`) and use `make watch-test` +4. **Incremental Builds**: `make build` only rebuilds changed files +5. **Clean Slate**: Use `make rebuild` or `make quick` for a fresh build + +## Integration with IDEs + +### VS Code + +Add to `.vscode/tasks.json`: + +```json +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Build", + "type": "shell", + "command": "make build", + "group": "build" + }, + { + "label": "Test", + "type": "shell", + "command": "make test", + "group": "test" + } + ] +} +``` + +### PyCharm + +Configure External Tools: +- Settings → Tools → External Tools → Add +- Program: `make` +- Arguments: `build` (or any other command) +- Working directory: `$ProjectFileDir$` + +## See Also + +- `CONTRIBUTING.md`: Full development guide +- `README.md`: User documentation +- `make help`: List all commands diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..7a14508 --- /dev/null +++ b/Makefile @@ -0,0 +1,247 @@ +.PHONY: help init build build-release test test-verbose test-coverage clean clean-all install dev-install format lint docs check-submodules + +# Default target +.DEFAULT_GOAL := help + +# Colors for output +BOLD := \033[1m +GREEN := \033[32m +YELLOW := \033[33m +BLUE := \033[34m +RESET := \033[0m + +# Python environment +PYTHON := python3 +PYTEST := $(PYTHON) -m pytest +PIP := $(PYTHON) -m pip + +# Project directories +SRC_DIR := src/python_prtree +CPP_DIR := cpp +TEST_DIR := tests +BUILD_DIR := build +DIST_DIR := dist + +# Set PYTHONPATH +export PYTHONPATH := $(CURDIR)/src:$(PYTHONPATH) + +help: ## Show help message + @echo "$(BOLD)$(BLUE)python_prtree Development Makefile$(RESET)" + @echo "" + @echo "$(BOLD)Available commands:$(RESET)" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \ + awk 
'BEGIN {FS = ":.*?## "}; {printf " $(GREEN)%-20s$(RESET) %s\n", $$1, $$2}' + @echo "" + @echo "$(BOLD)Development workflow:$(RESET)" + @echo " 1. $(YELLOW)make init$(RESET) - Initial setup" + @echo " 2. $(YELLOW)make build$(RESET) - Build C++ extension" + @echo " 3. $(YELLOW)make test$(RESET) - Run tests" + @echo " 4. $(YELLOW)make clean$(RESET) - Clean up" + @echo "" + +init: check-deps init-submodules ## Initial setup (initialize submodules + install dependencies) + @echo "$(BOLD)$(GREEN)✓ Initialization complete$(RESET)" + +check-deps: ## Check for required dependencies + @echo "$(BOLD)Checking dependencies...$(RESET)" + @command -v git >/dev/null 2>&1 || { echo "$(BOLD)Error: git is not installed$(RESET)" >&2; exit 1; } + @command -v cmake >/dev/null 2>&1 || { echo "$(BOLD)Error: cmake is not installed$(RESET)" >&2; exit 1; } + @command -v $(PYTHON) >/dev/null 2>&1 || { echo "$(BOLD)Error: python3 is not installed$(RESET)" >&2; exit 1; } + @echo "$(GREEN)✓ All required dependencies found$(RESET)" + +init-submodules: ## Initialize git submodules + @echo "$(BOLD)Initializing submodules...$(RESET)" + @if [ ! -f third/pybind11/CMakeLists.txt ]; then \ + git submodule update --init --recursive; \ + echo "$(GREEN)✓ Submodules initialized$(RESET)"; \ + else \ + echo "$(YELLOW)Submodules already initialized$(RESET)"; \ + fi + +check-submodules: ## Check submodule status + @if [ ! 
-f third/pybind11/CMakeLists.txt ]; then \ + echo "$(BOLD)$(YELLOW)Warning: Submodules not initialized$(RESET)"; \ + echo "$(YELLOW)Please run: make init$(RESET)"; \ + exit 1; \ + fi + +build: check-submodules ## Build C++ extension (in-place, debug mode) + @echo "$(BOLD)Building C++ extension (debug mode)...$(RESET)" + $(PYTHON) setup.py build_ext --inplace + @echo "$(GREEN)✓ Build complete$(RESET)" + +build-release: check-submodules ## Build C++ extension (release mode) + @echo "$(BOLD)Building C++ extension (release mode)...$(RESET)" + CMAKE_BUILD_TYPE=Release $(PYTHON) setup.py build_ext --inplace + @echo "$(GREEN)✓ Release build complete$(RESET)" + +rebuild: clean build ## Clean build (clean + build) + +test: build ## Run tests + @echo "$(BOLD)Running tests...$(RESET)" + $(PYTEST) $(TEST_DIR) -v + @echo "$(GREEN)✓ All tests passed$(RESET)" + +test-verbose: build ## Run tests in verbose mode + @echo "$(BOLD)Running tests (verbose mode)...$(RESET)" + $(PYTEST) $(TEST_DIR) -vv --tb=long + +test-fast: build ## Run tests in parallel (fast) + @echo "$(BOLD)Running tests in parallel...$(RESET)" + $(PYTEST) $(TEST_DIR) -v -n auto + @echo "$(GREEN)✓ All tests passed$(RESET)" + +test-coverage: build ## Run tests with coverage + @echo "$(BOLD)Running tests with coverage...$(RESET)" + $(PYTEST) $(TEST_DIR) --cov=$(SRC_DIR) --cov-report=html --cov-report=term + @echo "$(GREEN)✓ Coverage report generated: htmlcov/index.html$(RESET)" + +test-one: build ## Run specific tests (e.g., make test-one TEST=test_result) + @if [ -z "$(TEST)" ]; then \ + echo "$(BOLD)Error: TEST variable not specified$(RESET)"; \ + echo "Example: make test-one TEST=test_result"; \ + exit 1; \ + fi + @echo "$(BOLD)Running test $(TEST)...$(RESET)" + $(PYTEST) $(TEST_DIR) -k "$(TEST)" -v + +clean: ## Remove build artifacts + @echo "$(BOLD)Cleaning build artifacts...$(RESET)" + rm -rf $(BUILD_DIR) + rm -rf $(DIST_DIR) + rm -rf *.egg-info + rm -rf .pytest_cache + rm -rf .coverage + rm -rf htmlcov + find . 
-type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type f -name '*.pyc' -delete + find . -type f -name '*.pyo' -delete + find . -type f -name '*.so' -delete + find . -type f -name '*.a' -delete + find $(SRC_DIR) -name '*.so' -delete 2>/dev/null || true + find $(SRC_DIR) -name '*.a' -delete 2>/dev/null || true + @echo "$(GREEN)✓ Cleanup complete$(RESET)" + +clean-all: clean ## Clean everything (including submodules) + @echo "$(BOLD)Cleaning everything...$(RESET)" + git submodule deinit -f --all 2>/dev/null || true + rm -rf third/pybind11/* + rm -rf third/snappy/* + @echo "$(GREEN)✓ Complete cleanup done$(RESET)" + @echo "$(YELLOW)Note: Please run 'make init' again$(RESET)" + +install: ## Install package + @echo "$(BOLD)Installing package...$(RESET)" + $(PIP) install . + @echo "$(GREEN)✓ Installation complete$(RESET)" + +dev-install: ## Install in development mode (pip install -e .) + @echo "$(BOLD)Installing in development mode...$(RESET)" + $(PIP) install -e . + @echo "$(GREEN)✓ Development installation complete$(RESET)" + +install-deps: ## Install development dependencies + @echo "$(BOLD)Installing development dependencies...$(RESET)" + $(PIP) install pytest pytest-cov pytest-xdist numpy + @echo "$(GREEN)✓ Dependencies installed$(RESET)" + +format: ## Format C++ code (requires clang-format) + @if command -v clang-format >/dev/null 2>&1; then \ + echo "$(BOLD)Formatting C++ code...$(RESET)"; \ + find $(CPP_DIR) -name '*.h' -o -name '*.cc' | xargs clang-format -i; \ + echo "$(GREEN)✓ Formatting complete$(RESET)"; \ + else \ + echo "$(YELLOW)Warning: clang-format not installed$(RESET)"; \ + fi + +lint-cpp: ## Lint C++ code (requires clang-tidy) + @if command -v clang-tidy >/dev/null 2>&1; then \ + echo "$(BOLD)Linting C++ code...$(RESET)"; \ + find $(CPP_DIR) -name '*.cc' | xargs clang-tidy; \ + else \ + echo "$(YELLOW)Warning: clang-tidy not installed$(RESET)"; \ + fi + +lint-python: ## Lint Python code (requires flake8) + @if command -v 
flake8 >/dev/null 2>&1; then \ + echo "$(BOLD)Linting Python code...$(RESET)"; \ + flake8 $(SRC_DIR) $(TEST_DIR) --max-line-length=100; \ + else \ + echo "$(YELLOW)Warning: flake8 not installed$(RESET)"; \ + fi + +lint: lint-cpp lint-python ## Lint all code + +docs: ## Generate documentation (requires Doxygen) + @if command -v doxygen >/dev/null 2>&1; then \ + echo "$(BOLD)Generating documentation...$(RESET)"; \ + doxygen Doxyfile 2>/dev/null || echo "$(YELLOW)Doxyfile not found$(RESET)"; \ + else \ + echo "$(YELLOW)Warning: doxygen not installed$(RESET)"; \ + fi + +benchmark: build ## Run benchmarks (if benchmark script exists) + @if [ -f benchmark.py ]; then \ + echo "$(BOLD)Running benchmarks...$(RESET)"; \ + $(PYTHON) benchmark.py; \ + else \ + echo "$(YELLOW)benchmark.py not found$(RESET)"; \ + fi + +wheel: check-submodules ## Build wheel package + @echo "$(BOLD)Building wheel package...$(RESET)" + $(PYTHON) setup.py bdist_wheel + @echo "$(GREEN)✓ Wheel package created: $(DIST_DIR)/$(RESET)" + @ls -lh $(DIST_DIR)/*.whl 2>/dev/null || true + +sdist: ## Build source distribution + @echo "$(BOLD)Building source distribution...$(RESET)" + $(PYTHON) setup.py sdist + @echo "$(GREEN)✓ Source distribution created: $(DIST_DIR)/$(RESET)" + +release: clean check-submodules wheel sdist ## Build release packages (wheel + sdist) + @echo "$(BOLD)$(GREEN)✓ Release packages ready$(RESET)" + @echo "Distribution files:" + @ls -lh $(DIST_DIR)/ + +check: build test ## Run build and tests (for CI) + @echo "$(BOLD)$(GREEN)✓ All checks passed$(RESET)" + +watch-test: ## Run tests in watch mode (requires pytest-watch) + @if command -v ptw >/dev/null 2>&1; then \ + echo "$(BOLD)Starting test watch mode...$(RESET)"; \ + ptw -- $(TEST_DIR) -v; \ + else \ + echo "$(YELLOW)pytest-watch not installed$(RESET)"; \ + echo "Install: pip install pytest-watch"; \ + fi + +debug-build: ## Build with debug information + @echo "$(BOLD)Building with debug info...$(RESET)" + CMAKE_BUILD_TYPE=Debug 
$(PYTHON) setup.py build_ext --inplace + @echo "$(GREEN)✓ Debug build complete$(RESET)" + +info: ## Show project information + @echo "$(BOLD)$(BLUE)Project Information$(RESET)" + @echo "Python: $$($(PYTHON) --version)" + @echo "Pip: $$($(PIP) --version)" + @echo "CMake: $$(cmake --version | head -n1)" + @echo "Git: $$(git --version)" + @echo "" + @echo "$(BOLD)Project structure:$(RESET)" + @echo "Source directory: $(SRC_DIR)" + @echo "C++ directory: $(CPP_DIR)" + @echo "Test directory: $(TEST_DIR)" + @echo "" + @echo "$(BOLD)Submodule status:$(RESET)" + @git submodule status + +# Quick development targets +quick: clean build test ## Quick test (clean + build + test) + +dev: init install-deps build ## Setup development environment + @echo "$(BOLD)$(GREEN)✓ Development environment setup complete$(RESET)" + @echo "" + @echo "Next steps:" + @echo " - $(YELLOW)make test$(RESET) to run tests" + @echo " - $(YELLOW)make watch-test$(RESET) to start auto-testing (requires pytest-watch)" diff --git a/README.md b/README.md index d9a5cf6..727d09c 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ _python_prtree_ is a python/c++ implementation of the Priority R-Tree (see refer - `query` and `batch_query` - `batch_query` is parallelized by `std::thread` and is much faster than the `query` method. - The `query` method has an optional keyword argument `return_obj`; if `return_obj=True`, a Python object is returned. +- `query_intersections` + - Returns all pairs of intersecting AABBs as a numpy array of shape (n_pairs, 2). + - Optimized for performance with parallel processing and double-precision refinement. + - Similar to `scipy.spatial.cKDTree.query_pairs` but for bounding boxes instead of points. - `rebuild` - It improves performance when many insert/delete operations are called since the last rebuild. - Note that if the size changes more than 1.5 times, the insert/erase method also performs `rebuild`. 
@@ -77,6 +81,11 @@ print(prtree.query([0.5, 0.5])) # [1] print(prtree.query(0.5, 0.5)) # 1d-array # [1] + +# Find all pairs of intersecting rectangles +pairs = prtree.query_intersections() +print(pairs) +# [[1 3]] # rectangles with index 1 and 3 intersect ``` ```python diff --git a/cpp/main.cc b/cpp/main.cc index bf6bee4..a5a7a79 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -9,8 +9,7 @@ using T = int64_t; // is a temporary type of template. You can change it and // recompile this. const int B = 8; // the number of children of tree. -PYBIND11_MODULE(PRTree, m) -{ +PYBIND11_MODULE(PRTree, m) { m.doc() = R"pbdoc( INCOMPLETE Priority R-Tree Only supports for construct and find @@ -62,6 +61,12 @@ PYBIND11_MODULE(PRTree, m) )pbdoc") .def("size", &PRTree::size, R"pbdoc( get n + )pbdoc") + .def("query_intersections", &PRTree::query_intersections, + R"pbdoc( + Find all pairs of intersecting AABBs. + Returns a numpy array of shape (n_pairs, 2) where each row contains + a pair of indices (i, j) with i < j representing intersecting AABBs. )pbdoc"); py::class_>(m, "_PRTree3D") @@ -109,6 +114,12 @@ PYBIND11_MODULE(PRTree, m) )pbdoc") .def("size", &PRTree::size, R"pbdoc( get n + )pbdoc") + .def("query_intersections", &PRTree::query_intersections, + R"pbdoc( + Find all pairs of intersecting AABBs. + Returns a numpy array of shape (n_pairs, 2) where each row contains + a pair of indices (i, j) with i < j representing intersecting AABBs. )pbdoc"); py::class_>(m, "_PRTree4D") @@ -156,6 +167,12 @@ PYBIND11_MODULE(PRTree, m) )pbdoc") .def("size", &PRTree::size, R"pbdoc( get n + )pbdoc") + .def("query_intersections", &PRTree::query_intersections, + R"pbdoc( + Find all pairs of intersecting AABBs. + Returns a numpy array of shape (n_pairs, 2) where each row contains + a pair of indices (i, j) with i < j representing intersecting AABBs. 
)pbdoc"); #ifdef VERSION_INFO diff --git a/cpp/parallel.h b/cpp/parallel.h index 4ef1908..a682a35 100644 --- a/cpp/parallel.h +++ b/cpp/parallel.h @@ -1,13 +1,14 @@ #pragma once -#include -#include #include +#include +#include template -void parallel_for_each(const Iter first, const Iter last, T &result, const F &func) -{ +void parallel_for_each(const Iter first, const Iter last, T &result, + const F &func) { auto f = std::ref(func); - const size_t nthreads = (size_t)std::max(1, (int)std::thread::hardware_concurrency()); + const size_t nthreads = + (size_t)std::max(1, (int)std::thread::hardware_concurrency()); const size_t total = std::distance(first, last); std::vector rr(nthreads); { @@ -17,40 +18,35 @@ void parallel_for_each(const Iter first, const Iter last, T &result, const F &fu size_t remaining = total % nthreads; Iter n = first; iters.emplace_back(first); - for (size_t i = 0; i < nthreads - 1; ++i) - { + for (size_t i = 0; i < nthreads - 1; ++i) { std::advance(n, i < remaining ? 
step + 1 : step); iters.emplace_back(n); } iters.emplace_back(last); result.reserve(total); - for (auto &r : rr) - { + for (auto &r : rr) { r.reserve(total / nthreads + 1); } - for (size_t t = 0; t < nthreads; t++) - { - threads.emplace_back(std::thread([&, t] - { std::for_each(iters[t], iters[t + 1], [&](auto &x) - { f(x, rr[t]); }); })); + for (size_t t = 0; t < nthreads; t++) { + threads.emplace_back(std::thread([&, t] { + std::for_each(iters[t], iters[t + 1], [&](auto &x) { f(x, rr[t]); }); + })); } - std::for_each(threads.begin(), threads.end(), [&](std::thread &x) - { x.join(); }); + std::for_each(threads.begin(), threads.end(), + [&](std::thread &x) { x.join(); }); } - for (size_t t = 0; t < nthreads; t++) - { - result.insert(result.end(), - std::make_move_iterator(rr[t].begin()), + for (size_t t = 0; t < nthreads; t++) { + result.insert(result.end(), std::make_move_iterator(rr[t].begin()), std::make_move_iterator(rr[t].end())); } } template -void parallel_for_each(const Iter first, const Iter last, const F &func) -{ +void parallel_for_each(const Iter first, const Iter last, const F &func) { auto f = std::ref(func); - const size_t nthreads = (size_t)std::max(1, (int)std::thread::hardware_concurrency()); + const size_t nthreads = + (size_t)std::max(1, (int)std::thread::hardware_concurrency()); const size_t total = std::distance(first, last); { std::vector threads; @@ -59,19 +55,17 @@ void parallel_for_each(const Iter first, const Iter last, const F &func) size_t remaining = total % nthreads; Iter n = first; iters.emplace_back(first); - for (size_t i = 0; i < nthreads - 1; ++i) - { + for (size_t i = 0; i < nthreads - 1; ++i) { std::advance(n, i < remaining ? 
step + 1 : step); iters.emplace_back(n); } iters.emplace_back(last); - for (size_t t = 0; t < nthreads; t++) - { - threads.emplace_back(std::thread([&, t] - { std::for_each(iters[t], iters[t + 1], [&](auto &x) - { f(x); }); })); + for (size_t t = 0; t < nthreads; t++) { + threads.emplace_back(std::thread([&, t] { + std::for_each(iters[t], iters[t + 1], [&](auto &x) { f(x); }); + })); } - std::for_each(threads.begin(), threads.end(), [&](std::thread &x) - { x.join(); }); + std::for_each(threads.begin(), threads.end(), + [&](std::thread &x) { x.join(); }); } } diff --git a/cpp/prtree.h b/cpp/prtree.h index 074385c..dab3bc2 100644 --- a/cpp/prtree.h +++ b/cpp/prtree.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -12,15 +13,14 @@ #include #include #include +#include #include #include +#include #include #include -#include -#include #include -#include -#include +#include #include #include @@ -30,15 +30,15 @@ #include #include #include +#include #include //for smart pointers #include #include #include -#include -#include #include "parallel.h" #include "small_vector.h" +#include #ifdef MY_DEBUG #include @@ -48,30 +48,27 @@ using Real = float; namespace py = pybind11; -template -using vec = std::vector; +template using vec = std::vector; template -inline py::array_t as_pyarray(Sequence &seq) -{ +inline py::array_t as_pyarray(Sequence &seq) { auto size = seq.size(); auto data = seq.data(); - std::unique_ptr seq_ptr = std::make_unique(std::move(seq)); - auto capsule = py::capsule(seq_ptr.get(), [](void *p) - { std::unique_ptr(reinterpret_cast(p)); }); + std::unique_ptr seq_ptr = + std::make_unique(std::move(seq)); + auto capsule = py::capsule(seq_ptr.get(), [](void *p) { + std::unique_ptr(reinterpret_cast(p)); + }); seq_ptr.release(); return py::array(size, data, capsule); } -template -auto list_list_to_arrays(vec> out_ll) -{ +template auto list_list_to_arrays(vec> out_ll) { vec out_s; out_s.reserve(out_ll.size()); std::size_t sum = 0; - 
for (auto &&i : out_ll) - { + for (auto &&i : out_ll) { out_s.push_back(i.size()); sum += i.size(); } @@ -80,19 +77,15 @@ auto list_list_to_arrays(vec> out_ll) for (const auto &v : out_ll) out.insert(out.end(), v.begin(), v.end()); - return make_tuple( - std::move(as_pyarray(out_s)), - std::move(as_pyarray(out))); + return make_tuple(std::move(as_pyarray(out_s)), std::move(as_pyarray(out))); } template using svec = itlib::small_vector; -template -using deque = std::deque; +template using deque = std::deque; -template -using queue = std::queue>; +template using queue = std::queue>; static const float REBUILD_THRE = 1.25; @@ -104,160 +97,129 @@ static const float REBUILD_THRE = 1.25; #define unlikely(x) (x) #endif -std::string compress(std::string &data) -{ +std::string compress(std::string &data) { std::string output; snappy::Compress(data.data(), data.size(), &output); return output; } -std::string decompress(std::string &data) -{ +std::string decompress(std::string &data) { std::string output; snappy::Uncompress(data.data(), data.size(), &output); return output; } -template -class BB -{ +template class BB { private: Real values[2 * D]; public: BB() { clear(); } - BB(const Real (&minima)[D], const Real (&maxima)[D]) - { + BB(const Real (&minima)[D], const Real (&maxima)[D]) { Real v[2 * D]; - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { v[i] = -minima[i]; v[i + D] = maxima[i]; } validate(v); - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { values[i] = v[i]; values[i + D] = v[i + D]; } } - BB(const Real (&v)[2 * D]) - { + BB(const Real (&v)[2 * D]) { validate(v); - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { values[i] = v[i]; values[i + D] = v[i + D]; } } - Real min(const int dim) const - { - if (unlikely(dim < 0 || D <= dim)) - { + Real min(const int dim) const { + if (unlikely(dim < 0 || D <= dim)) { throw std::runtime_error("Invalid dim"); } return -values[dim]; } - Real max(const int dim) const - { - if 
(unlikely(dim < 0 || D <= dim)) - { + Real max(const int dim) const { + if (unlikely(dim < 0 || D <= dim)) { throw std::runtime_error("Invalid dim"); } return values[dim + D]; } - bool validate(const Real (&v)[2 * D]) const - { + bool validate(const Real (&v)[2 * D]) const { bool flag = false; - for (int i = 0; i < D; ++i) - { - if (unlikely(-v[i] > v[i + D])) - { + for (int i = 0; i < D; ++i) { + if (unlikely(-v[i] > v[i + D])) { flag = true; break; } } - if (unlikely(flag)) - { + if (unlikely(flag)) { throw std::runtime_error("Invalid Bounding Box"); } return flag; } - void clear() - { - for (int i = 0; i < 2 * D; ++i) - { + void clear() { + for (int i = 0; i < 2 * D; ++i) { values[i] = -1e100; } } - Real val_for_comp(const int &axis) const - { + Real val_for_comp(const int &axis) const { const int axis2 = (axis + 1) % (2 * D); return values[axis] + values[axis2]; } - BB operator+(const BB &rhs) const - { + BB operator+(const BB &rhs) const { Real result[2 * D]; - for (int i = 0; i < 2 * D; ++i) - { + for (int i = 0; i < 2 * D; ++i) { result[i] = std::max(values[i], rhs.values[i]); } return BB(result); } - BB operator+=(const BB &rhs) - { - for (int i = 0; i < 2 * D; ++i) - { + BB operator+=(const BB &rhs) { + for (int i = 0; i < 2 * D; ++i) { values[i] = std::max(values[i], rhs.values[i]); } return *this; } - void expand(const Real (&delta)[D]) - { - for (int i = 0; i < D; ++i) - { + void expand(const Real (&delta)[D]) { + for (int i = 0; i < D; ++i) { values[i] += delta[i]; values[i + D] += delta[i]; } } - bool operator()(const BB &target) const - { // whether this and target has any intersect + bool operator()( + const BB &target) const { // whether this and target has any intersect Real minima[D]; Real maxima[D]; bool flags[D]; bool flag = true; - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { minima[i] = std::min(values[i], target.values[i]); maxima[i] = std::min(values[i + D], target.values[i + D]); } - for (int i = 0; i < D; ++i) - { + for 
(int i = 0; i < D; ++i) { flags[i] = -minima[i] <= maxima[i]; } - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { flag &= flags[i]; } return flag; } - Real area() const - { + Real area() const { Real result = 1; - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { result *= max(i) - min(i); } return result; @@ -265,102 +227,82 @@ class BB inline Real operator[](const int i) const { return values[i]; } - template - void serialize(Archive &ar) { ar(values); } + template void serialize(Archive &ar) { ar(values); } }; -template -class DataType -{ +template class DataType { public: BB second; T first; DataType(){}; - DataType(const T &f, const BB &s) - { + DataType(const T &f, const BB &s) { first = f; second = s; } - DataType(T &&f, BB &&s) noexcept - { + DataType(T &&f, BB &&s) noexcept { first = std::move(f); second = std::move(s); } - template - void serialize(Archive &ar) { ar(first, second); } + template void serialize(Archive &ar) { ar(first, second); } }; template -void clean_data(DataType *b, DataType *e) -{ - for (DataType *it = e - 1; it >= b; --it) - { +void clean_data(DataType *b, DataType *e) { + for (DataType *it = e - 1; it >= b; --it) { it->~DataType(); } } -template -class Leaf -{ +template class Leaf { public: BB mbb; svec, B> data; // You can swap when filtering int axis = 0; // T is type of keys(ids) which will be returned when you post a query. 
- Leaf() - { - mbb = BB(); - } - Leaf(const int _axis) - { + Leaf() { mbb = BB(); } + Leaf(const int _axis) { axis = _axis; mbb = BB(); } void set_axis(const int &_axis) { axis = _axis; } - void push(const T &key, const BB &target) - { + void push(const T &key, const BB &target) { data.emplace_back(key, target); update_mbb(); } - void update_mbb() - { + void update_mbb() { mbb.clear(); - for (const auto &datum : data) - { + for (const auto &datum : data) { mbb += datum.second; } } - bool filter(DataType &value) - { // false means given value is ignored - auto comp = [=](const auto &a, const auto &b) noexcept - { return a.second.val_for_comp(axis) < b.second.val_for_comp(axis); }; + bool filter(DataType &value) { // false means given value is ignored + auto comp = [=](const auto &a, const auto &b) noexcept { + return a.second.val_for_comp(axis) < b.second.val_for_comp(axis); + }; - if (data.size() < B) - { // if there is room, just push the candidate + if (data.size() < B) { // if there is room, just push the candidate auto iter = std::lower_bound(data.begin(), data.end(), value, comp); DataType tmp_value = DataType(value); data.insert(iter, std::move(tmp_value)); mbb += value.second; return true; - } - else - { // if there is no room, check the priority and swap if needed - if (data[0].second.val_for_comp(axis) < value.second.val_for_comp(axis)) - { - size_t n_swap = std::lower_bound(data.begin(), data.end(), value, comp) - data.begin(); + } else { // if there is no room, check the priority and swap if needed + if (data[0].second.val_for_comp(axis) < value.second.val_for_comp(axis)) { + size_t n_swap = + std::lower_bound(data.begin(), data.end(), value, comp) - + data.begin(); std::swap(*data.begin(), value); auto iter = data.begin(); - for (size_t i = 0; i < n_swap - 1; ++i) - { + for (size_t i = 0; i < n_swap - 1; ++i) { std::swap(*(iter + i), *(iter + i + 1)); } update_mbb(); @@ -370,65 +312,50 @@ class Leaf } }; -template -class PseudoPRTreeNode -{ +template 
class PseudoPRTreeNode { public: Leaf leaves[2 * D]; std::unique_ptr left, right; - PseudoPRTreeNode() - { - for (int i = 0; i < 2 * D; i++) - { + PseudoPRTreeNode() { + for (int i = 0; i < 2 * D; i++) { leaves[i].set_axis(i); } } - PseudoPRTreeNode(const int axis) - { - for (int i = 0; i < 2 * D; i++) - { + PseudoPRTreeNode(const int axis) { + for (int i = 0; i < 2 * D; i++) { const int j = (axis + i) % (2 * D); leaves[i].set_axis(j); } } - template - void serialize(Archive &archive) - { + template void serialize(Archive &archive) { // archive(cereal::(left), cereal::defer(right), leaves); archive(left, right, leaves); } - void address_of_leaves(vec *> &out) - { - for (auto &leaf : leaves) - { - if (leaf.data.size() > 0) - { + void address_of_leaves(vec *> &out) { + for (auto &leaf : leaves) { + if (leaf.data.size() > 0) { out.emplace_back(&leaf); } } } - template - auto filter(const iterator &b, const iterator &e) - { - auto out = std::remove_if(b, e, [&](auto &x) - { + template auto filter(const iterator &b, const iterator &e) { + auto out = std::remove_if(b, e, [&](auto &x) { for (auto &l : leaves) { if (l.filter(x)) { return true; } } - return false; }); + return false; + }); return out; } }; -template -class PseudoPRTree -{ +template class PseudoPRTree { public: std::unique_ptr> root; vec *> cache_children; @@ -436,30 +363,23 @@ class PseudoPRTree PseudoPRTree() { root = std::make_unique>(); } - template - PseudoPRTree(const iterator &b, const iterator &e) - { - if (!root) - { + template PseudoPRTree(const iterator &b, const iterator &e) { + if (!root) { root = std::make_unique>(); } construct(root.get(), b, e, 0); clean_data(b, e); } - template - void serialize(Archive &archive) - { + template void serialize(Archive &archive) { archive(root); // archive.serializeDeferments(); } template - void construct(PseudoPRTreeNode *node, const iterator &b, const iterator &e, - const int depth) - { - if (e - b > 0 && node != nullptr) - { + void 
construct(PseudoPRTreeNode *node, const iterator &b, + const iterator &e, const int depth) { + if (e - b > 0 && node != nullptr) { bool use_recursive_threads = std::pow(2, depth + 1) <= nthreads; #ifdef MY_DEBUG use_recursive_threads = false; @@ -475,59 +395,44 @@ class PseudoPRTree std::advance(m, (ee - b) / 2); std::nth_element(b, m, ee, [axis](const DataType &lhs, - const DataType &rhs) noexcept - { + const DataType &rhs) noexcept { return lhs.second[axis] < rhs.second[axis]; }); - if (m - b > 0) - { + if (m - b > 0) { node->left = std::make_unique>(axis); node_left = node->left.get(); - if (use_recursive_threads) - { + if (use_recursive_threads) { threads.push_back( - std::thread([&]() - { construct(node_left, b, m, depth + 1); })); - } - else - { + std::thread([&]() { construct(node_left, b, m, depth + 1); })); + } else { construct(node_left, b, m, depth + 1); } } - if (ee - m > 0) - { + if (ee - m > 0) { node->right = std::make_unique>(axis); node_right = node->right.get(); - if (use_recursive_threads) - { + if (use_recursive_threads) { threads.push_back( - std::thread([&]() - { construct(node_right, m, ee, depth + 1); })); - } - else - { + std::thread([&]() { construct(node_right, m, ee, depth + 1); })); + } else { construct(node_right, m, ee, depth + 1); } } std::for_each(threads.begin(), threads.end(), - [&](std::thread &x) - { x.join(); }); + [&](std::thread &x) { x.join(); }); } } - auto get_all_leaves(const int hint) - { - if (cache_children.empty()) - { + auto get_all_leaves(const int hint) { + if (cache_children.empty()) { using U = PseudoPRTreeNode; cache_children.reserve(hint); auto node = root.get(); queue que; que.emplace(node); - while (!que.empty()) - { + while (!que.empty()) { node = que.front(); que.pop(); node->address_of_leaves(cache_children); @@ -540,110 +445,84 @@ class PseudoPRTree return cache_children; } - std::pair *, DataType *> as_X(void *placement, const int hint) - { + std::pair *, DataType *> as_X(void *placement, + const int 
hint) { DataType *b, *e; auto children = get_all_leaves(hint); T total = children.size(); b = reinterpret_cast *>(placement); e = b + total; - for (T i = 0; i < total; i++) - { + for (T i = 0; i < total; i++) { new (b + i) DataType{i, children[i]->mbb}; } return {b, e}; } }; -template -class PRTreeLeaf -{ +template class PRTreeLeaf { public: BB mbb; svec, B> data; - PRTreeLeaf() - { - mbb = BB(); - } + PRTreeLeaf() { mbb = BB(); } - PRTreeLeaf(const Leaf &leaf) - { + PRTreeLeaf(const Leaf &leaf) { mbb = leaf.mbb; data = leaf.data; } - Real area() const - { - return mbb.area(); - } + Real area() const { return mbb.area(); } - void update_mbb() - { + void update_mbb() { mbb.clear(); - for (const auto &datum : data) - { + for (const auto &datum : data) { mbb += datum.second; } } - void operator()(const BB &target, vec &out) const - { - if (mbb(target)) - { - for (const auto &x : data) - { - if (x.second(target)) - { + void operator()(const BB &target, vec &out) const { + if (mbb(target)) { + for (const auto &x : data) { + if (x.second(target)) { out.emplace_back(x.first); } } } } - void del(const T &key, const BB &target) - { - if (mbb(target)) - { + void del(const T &key, const BB &target) { + if (mbb(target)) { auto remove_it = - std::remove_if(data.begin(), data.end(), [&](auto &datum) - { return datum.second(target) && datum.first == key; }); + std::remove_if(data.begin(), data.end(), [&](auto &datum) { + return datum.second(target) && datum.first == key; + }); data.erase(remove_it, data.end()); } } - void push(const T &key, const BB &target) - { + void push(const T &key, const BB &target) { data.emplace_back(key, target); update_mbb(); } - template - void save(Archive &ar) const - { + template void save(Archive &ar) const { vec> _data; - for (const auto &datum : data) - { + for (const auto &datum : data) { _data.push_back(datum); } ar(mbb, _data); } - template - void load(Archive &ar) - { + template void load(Archive &ar) { vec> _data; ar(mbb, _data); - for 
(const auto &datum : _data) - { + for (const auto &datum : _data) { data.push_back(datum); } } }; -template -class PRTreeNode -{ +template class PRTreeNode { public: BB mbb; std::unique_ptr> leaf; @@ -654,8 +533,7 @@ class PRTreeNode PRTreeNode(BB &&_mbb) noexcept { mbb = std::move(_mbb); } - PRTreeNode(Leaf *l) - { + PRTreeNode(Leaf *l) { leaf = std::make_unique>(); mbb = l->mbb; leaf->mbb = std::move(l->mbb); @@ -665,25 +543,20 @@ class PRTreeNode bool operator()(const BB &target) { return mbb(target); } }; -template -class PRTreeElement -{ +template class PRTreeElement { public: BB mbb; std::unique_ptr> leaf; bool is_used = false; - PRTreeElement() - { + PRTreeElement() { mbb = BB(); is_used = false; } - PRTreeElement(const PRTreeNode &node) - { + PRTreeElement(const PRTreeNode &node) { mbb = BB(node.mbb); - if (node.leaf) - { + if (node.leaf) { Leaf tmp_leaf = Leaf(*node.leaf.get()); leaf = std::make_unique>(tmp_leaf); } @@ -692,23 +565,20 @@ class PRTreeElement bool operator()(const BB &target) { return is_used && mbb(target); } - template - void serialize(Archive &archive) - { + template void serialize(Archive &archive) { archive(mbb, leaf, is_used); } }; template -void bfs(const std::function> &)> &func, vec> &flat_tree, const BB target) -{ +void bfs( + const std::function> &)> &func, + vec> &flat_tree, const BB target) { queue que; - auto qpush_if_intersect = [&](const size_t &i) - { + auto qpush_if_intersect = [&](const size_t &i) { PRTreeElement &r = flat_tree[i]; // std::cout << "i " << (long int) i << " : " << (bool) r.leaf << std::endl; - if (r(target)) - { + if (r(target)) { // std::cout << " is pushed" << std::endl; que.emplace(i); } @@ -716,22 +586,17 @@ void bfs(const std::function> &)> &func // std::cout << "size: " << flat_tree.size() << std::endl; qpush_if_intersect(0); - while (!que.empty()) - { + while (!que.empty()) { size_t idx = que.front(); // std::cout << "idx: " << (long int) idx << std::endl; que.pop(); PRTreeElement &elem = 
flat_tree[idx]; - if (elem.leaf) - { + if (elem.leaf) { // std::cout << "func called for " << (long int) idx << std::endl; func(elem.leaf); - } - else - { - for (size_t offset = 0; offset < B; offset++) - { + } else { + for (size_t offset = 0; offset < B; offset++) { size_t jdx = idx * B + offset + 1; qpush_if_intersect(jdx); } @@ -739,28 +604,24 @@ void bfs(const std::function> &)> &func } } -template -class PRTree -{ +template class PRTree { private: vec> flat_tree; std::unordered_map> idx2bb; std::unordered_map idx2data; int64_t n_at_build = 0; std::atomic global_idx = 0; - - // Double-precision storage for exact refinement (optional, only when built from float64) + + // Double-precision storage for exact refinement (optional, only when built + // from float64) std::unordered_map> idx2exact; public: - template - void serialize(Archive &archive) - { + template void serialize(Archive &archive) { archive(flat_tree, idx2bb, idx2data, global_idx, n_at_build, idx2exact); } - void save(std::string fname) - { + void save(std::string fname) { { { std::ofstream ofs(fname, std::ios::binary); @@ -775,8 +636,7 @@ class PRTree } } - void load(std::string fname) - { + void load(std::string fname) { { { std::ifstream ifs(fname, std::ios::binary); @@ -795,43 +655,39 @@ class PRTree PRTree(std::string fname) { load(fname); } - // Helper: Validate bounding box coordinates (reject NaN/Inf, enforce min <= max) - template - void validate_box(const CoordType* coords, int dim_count) const - { - for (int i = 0; i < dim_count; ++i) - { + // Helper: Validate bounding box coordinates (reject NaN/Inf, enforce min <= + // max) + template + void validate_box(const CoordType *coords, int dim_count) const { + for (int i = 0; i < dim_count; ++i) { CoordType min_val = coords[i]; CoordType max_val = coords[i + dim_count]; - + // Check for NaN or Inf - if (!std::isfinite(min_val) || !std::isfinite(max_val)) - { - throw std::runtime_error("Bounding box coordinates must be finite (no NaN or Inf)"); + 
if (!std::isfinite(min_val) || !std::isfinite(max_val)) { + throw std::runtime_error( + "Bounding box coordinates must be finite (no NaN or Inf)"); } - + // Enforce min <= max - if (min_val > max_val) - { - throw std::runtime_error("Bounding box minimum must be <= maximum in each dimension"); + if (min_val > max_val) { + throw std::runtime_error( + "Bounding box minimum must be <= maximum in each dimension"); } } } // Constructor for float32 input (no refinement, pure float32 performance) - PRTree(const py::array_t &idx, const py::array_t &x) - { + PRTree(const py::array_t &idx, const py::array_t &x) { const auto &buff_info_idx = idx.request(); const auto &shape_idx = buff_info_idx.shape; const auto &buff_info_x = x.request(); const auto &shape_x = buff_info_x.shape; - if (unlikely(shape_idx[0] != shape_x[0])) - { + if (unlikely(shape_idx[0] != shape_x[0])) { throw std::runtime_error( "Both index and boudning box must have the same length"); } - if (unlikely(shape_x[1] != 2 * D)) - { + if (unlikely(shape_x[1] != 2 * D)) { throw std::runtime_error( "Bounding box must have the shape (length, 2 * dim)"); } @@ -847,37 +703,32 @@ class PRTree b = reinterpret_cast *>(placement); e = b + length; - for (T i = 0; i < length; i++) - { + for (T i = 0; i < length; i++) { Real minima[D]; Real maxima[D]; - - for (int j = 0; j < D; ++j) - { - minima[j] = rx(i, j); // Direct float32 assignment + + for (int j = 0; j < D; ++j) { + minima[j] = rx(i, j); // Direct float32 assignment maxima[j] = rx(i, j + D); } - + // Validate bounding box (reject NaN/Inf, enforce min <= max) float coords[2 * D]; - for (int j = 0; j < D; ++j) - { + for (int j = 0; j < D; ++j) { coords[j] = minima[j]; coords[j + D] = maxima[j]; } validate_box(coords, D); - + auto bb = BB(minima, maxima); auto ri_i = ri(i); new (b + i) DataType{std::move(ri_i), std::move(bb)}; } - for (T i = 0; i < length; i++) - { + for (T i = 0; i < length; i++) { Real minima[D]; Real maxima[D]; - for (int j = 0; j < D; ++j) - { + for 
(int j = 0; j < D; ++j) { minima[j] = rx(i, j); maxima[j] = rx(i, j + D); } @@ -890,19 +741,16 @@ class PRTree } // Constructor for float64 input (float32 tree + double refinement) - PRTree(const py::array_t &idx, const py::array_t &x) - { + PRTree(const py::array_t &idx, const py::array_t &x) { const auto &buff_info_idx = idx.request(); const auto &shape_idx = buff_info_idx.shape; const auto &buff_info_x = x.request(); const auto &shape_x = buff_info_x.shape; - if (unlikely(shape_idx[0] != shape_x[0])) - { + if (unlikely(shape_idx[0] != shape_x[0])) { throw std::runtime_error( "Both index and boudning box must have the same length"); } - if (unlikely(shape_x[1] != 2 * D)) - { + if (unlikely(shape_x[1] != 2 * D)) { throw std::runtime_error( "Bounding box must have the shape (length, 2 * dim)"); } @@ -911,49 +759,45 @@ class PRTree auto rx = x.template unchecked<2>(); T length = shape_idx[0]; idx2bb.reserve(length); - idx2exact.reserve(length); // Reserve space for exact coordinates + idx2exact.reserve(length); // Reserve space for exact coordinates DataType *b, *e; void *placement = std::malloc(sizeof(DataType) * length); b = reinterpret_cast *>(placement); e = b + length; - for (T i = 0; i < length; i++) - { + for (T i = 0; i < length; i++) { Real minima[D]; Real maxima[D]; std::array exact_coords; - - for (int j = 0; j < D; ++j) - { + + for (int j = 0; j < D; ++j) { double val_min = rx(i, j); double val_max = rx(i, j + D); - exact_coords[j] = val_min; // Store exact double for refinement + exact_coords[j] = val_min; // Store exact double for refinement exact_coords[j + D] = val_max; } - - // Validate bounding box with double precision (reject NaN/Inf, enforce min <= max) + + // Validate bounding box with double precision (reject NaN/Inf, enforce + // min <= max) validate_box(exact_coords.data(), D); - + // Convert to float32 for tree after validation - for (int j = 0; j < D; ++j) - { + for (int j = 0; j < D; ++j) { minima[j] = static_cast(exact_coords[j]); 
maxima[j] = static_cast(exact_coords[j + D]); } - + auto bb = BB(minima, maxima); auto ri_i = ri(i); - idx2exact[ri_i] = exact_coords; // Store exact coordinates + idx2exact[ri_i] = exact_coords; // Store exact coordinates new (b + i) DataType{std::move(ri_i), std::move(bb)}; } - for (T i = 0; i < length; i++) - { + for (T i = 0; i < length; i++) { Real minima[D]; Real maxima[D]; - for (int j = 0; j < D; ++j) - { + for (int j = 0; j < D; ++j) { minima[j] = static_cast(rx(i, j)); maxima[j] = static_cast(rx(i, j + D)); } @@ -965,29 +809,26 @@ class PRTree std::free(placement); } - void set_obj(const T &idx, std::optional objdumps = std::nullopt) - { - if (objdumps) - { + void set_obj(const T &idx, + std::optional objdumps = std::nullopt) { + if (objdumps) { auto val = objdumps.value(); idx2data.emplace(idx, compress(val)); } } - py::object get_obj(const T &idx) - { + py::object get_obj(const T &idx) { py::object obj = py::none(); auto search = idx2data.find(idx); - if (likely(search != idx2data.end())) - { + if (likely(search != idx2data.end())) { auto val = idx2data.at(idx); obj = py::cast(py::bytes(decompress(val))); } return obj; } - void insert(const T &idx, const py::array_t &x, const std::optional objdumps = std::nullopt) - { + void insert(const T &idx, const py::array_t &x, + const std::optional objdumps = std::nullopt) { #ifdef MY_DEBUG ProfilerStart("insert.prof"); std::cout << "profiler start of insert" << std::endl; @@ -998,20 +839,17 @@ class PRTree const auto &buff_info_x = x.request(); const auto &shape_x = buff_info_x.shape; const auto &ndim = buff_info_x.ndim; - if (unlikely((shape_x[0] != 2 * D || ndim != 1))) - { + if (unlikely((shape_x[0] != 2 * D || ndim != 1))) { throw std::runtime_error("invalid shape."); } auto it = idx2bb.find(idx); - if (unlikely(it != idx2bb.end())) - { + if (unlikely(it != idx2bb.end())) { throw std::runtime_error("Given index is already included."); } { Real minima[D]; Real maxima[D]; - for (int i = 0; i < D; ++i) - { + 
for (int i = 0; i < D; ++i) { minima[i] = *x.data(i); maxima[i] = *x.data(i + D); } @@ -1021,48 +859,38 @@ class PRTree set_obj(idx, objdumps); Real delta[D]; - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { delta[i] = bb.max(i) - bb.min(i) + 0.00000001; } // find the leaf node to insert Real c = 0.0; size_t count = flat_tree.size(); - while (cands.empty()) - { + while (cands.empty()) { Real d[D]; - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { d[i] = delta[i] * c; } bb.expand(d); c = (c + 1) * 2; queue que; - auto qpush_if_intersect = [&](const size_t &i) - { - if (flat_tree[i](bb)) - { + auto qpush_if_intersect = [&](const size_t &i) { + if (flat_tree[i](bb)) { que.emplace(i); } }; qpush_if_intersect(0); - while (!que.empty()) - { + while (!que.empty()) { size_t i = que.front(); que.pop(); PRTreeElement &elem = flat_tree[i]; - if (elem.leaf && elem.leaf->mbb(bb)) - { + if (elem.leaf && elem.leaf->mbb(bb)) { cands.push_back(i); - } - else - { - for (size_t offset = 0; offset < B; offset++) - { + } else { + for (size_t offset = 0; offset < B; offset++) { size_t j = i * B + offset + 1; if (j < count) qpush_if_intersect(j); @@ -1077,22 +905,17 @@ class PRTree // Now cands is the list of candidate leaf nodes to insert bb = idx2bb.at(idx); size_t min_leaf = 0; - if (cands.size() == 1) - { + if (cands.size() == 1) { min_leaf = cands[0]; - } - else - { + } else { Real min_diff_area = 1e100; - for (const auto &i : cands) - { + for (const auto &i : cands) { PRTreeLeaf *leaf = flat_tree[i].leaf.get(); PRTreeLeaf tmp_leaf = PRTreeLeaf(*leaf); Real diff_area = -tmp_leaf.area(); tmp_leaf.push(idx, bb); diff_area += tmp_leaf.area(); - if (diff_area < min_diff_area) - { + if (diff_area < min_diff_area) { min_diff_area = diff_area; min_leaf = i; } @@ -1101,15 +924,13 @@ class PRTree flat_tree[min_leaf].leaf->push(idx, bb); // update mbbs of all cands and their parents size_t i = min_leaf; - while (true) - { + while (true) { PRTreeElement &elem = 
flat_tree[i]; if (elem.leaf) elem.mbb += elem.leaf->mbb; - if (i > 0) - { + if (i > 0) { size_t j = (i - 1) / B; flat_tree[j].mbb += flat_tree[i].mbb; } @@ -1118,8 +939,7 @@ class PRTree i = (i - 1) / B; } - if (size() > REBUILD_THRE * n_at_build) - { + if (size() > REBUILD_THRE * n_at_build) { rebuild(); } #ifdef MY_DEBUG @@ -1128,8 +948,7 @@ class PRTree #endif } - void rebuild() - { + void rebuild() { std::stack sta; T length = idx2bb.size(); DataType *b, *e; @@ -1140,28 +959,21 @@ class PRTree T i = 0; sta.push(0); - while (!sta.empty()) - { + while (!sta.empty()) { size_t idx = sta.top(); sta.pop(); PRTreeElement &elem = flat_tree[idx]; - if (elem.leaf) - { - for (const auto &datum : elem.leaf->data) - { + if (elem.leaf) { + for (const auto &datum : elem.leaf->data) { new (b + i) DataType{datum.first, datum.second}; i++; } - } - else - { - for (size_t offset = 0; offset < B; offset++) - { + } else { + for (size_t offset = 0; offset < B; offset++) { size_t jdx = idx * B + offset + 1; - if (likely(flat_tree[jdx].is_used)) - { + if (likely(flat_tree[jdx].is_used)) { sta.push(jdx); } } @@ -1173,8 +985,7 @@ class PRTree } template - void build(const iterator &b, const iterator &e, void *placement) - { + void build(const iterator &b, const iterator &e, void *placement) { #ifdef MY_DEBUG ProfilerStart("build.prof"); std::cout << "profiler start of build" << std::endl; @@ -1187,14 +998,12 @@ class PRTree auto first_tree = PseudoPRTree(b, e); auto first_leaves = first_tree.get_all_leaves(e - b); - for (auto &leaf : first_leaves) - { + for (auto &leaf : first_leaves) { auto pp = std::make_unique>(leaf); prev_nodes.push_back(std::move(pp)); } auto [bb, ee] = first_tree.as_X(placement, e - b); - while (prev_nodes.size() > 1) - { + while (prev_nodes.size() > 1) { auto tree = PseudoPRTree(bb, ee); auto leaves = tree.get_all_leaves(ee - bb); auto leaves_size = leaves.size(); @@ -1202,43 +1011,35 @@ class PRTree vec>> tmp_nodes; tmp_nodes.reserve(leaves_size); - for (auto 
&leaf : leaves) - { + for (auto &leaf : leaves) { int idx, jdx; int len = leaf->data.size(); auto pp = std::make_unique>(leaf->mbb); - if (likely(!leaf->data.empty())) - { - for (int i = 1; i < len; i++) - { + if (likely(!leaf->data.empty())) { + for (int i = 1; i < len; i++) { idx = leaf->data[len - i - 1].first; // reversed way jdx = leaf->data[len - i].first; prev_nodes[idx]->next = std::move(prev_nodes[jdx]); } idx = leaf->data[0].first; pp->head = std::move(prev_nodes[idx]); - if (unlikely(!pp->head)) - { + if (unlikely(!pp->head)) { throw std::runtime_error("ppp"); } tmp_nodes.push_back(std::move(pp)); - } - else - { + } else { throw std::runtime_error("what????"); } } prev_nodes.swap(tmp_nodes); - if (prev_nodes.size() > 1) - { + if (prev_nodes.size() > 1) { auto tmp = tree.as_X(placement, ee - bb); bb = std::move(tmp.first); ee = std::move(tmp.second); } } - if (unlikely(prev_nodes.size() != 1)) - { + if (unlikely(prev_nodes.size() != 1)) { throw std::runtime_error("#roots is not 1."); } root = std::move(prev_nodes[0]); @@ -1251,8 +1052,7 @@ class PRTree int depth = 0; p = root.get(); - while (p->head) - { + while (p->head) { p = p->head.get(); depth++; } @@ -1262,8 +1062,7 @@ class PRTree flat_tree.clear(); flat_tree.shrink_to_fit(); size_t count = 0; - for (int i = 0; i <= depth; i++) - { + for (int i = 0; i <= depth; i++) { count += std::pow(B, depth); } flat_tree.resize(count); @@ -1271,8 +1070,7 @@ class PRTree // assign que.emplace(root.get(), 0); - while (!que.empty()) - { + while (!que.empty()) { auto tmp = que.front(); que.pop(); p = tmp.first; @@ -1280,15 +1078,13 @@ class PRTree flat_tree[idx] = PRTreeElement(*p); size_t child_idx = 0; - if (p->head) - { + if (p->head) { size_t jdx = idx * B + child_idx + 1; ++child_idx; q = p->head.get(); que.emplace(q, jdx); - while (q->next) - { + while (q->next) { jdx = idx * B + child_idx + 1; ++child_idx; @@ -1305,8 +1101,7 @@ class PRTree #endif } - auto find_all(const py::array_t &x) - { + auto 
find_all(const py::array_t &x) { #ifdef MY_DEBUG ProfilerStart("find_all.prof"); std::cout << "profiler start of find_all" << std::endl; @@ -1315,75 +1110,55 @@ class PRTree const auto &ndim = buff_info_x.ndim; const auto &shape_x = buff_info_x.shape; bool is_point = false; - if (unlikely(ndim == 1 && (!(shape_x[0] == 2 * D || shape_x[0] == D)))) - { + if (unlikely(ndim == 1 && (!(shape_x[0] == 2 * D || shape_x[0] == D)))) { throw std::runtime_error("Invalid Bounding box size"); } - if (unlikely((ndim == 2 && (!(shape_x[1] == 2 * D || shape_x[1] == D))))) - { + if (unlikely((ndim == 2 && (!(shape_x[1] == 2 * D || shape_x[1] == D))))) { throw std::runtime_error( "Bounding box must have the shape (length, 2 * dim)"); } - if (unlikely(ndim > 3)) - { + if (unlikely(ndim > 3)) { throw std::runtime_error("invalid shape"); } - if (ndim == 1) - { - if (shape_x[0] == D) - { + if (ndim == 1) { + if (shape_x[0] == D) { is_point = true; } - } - else - { - if (shape_x[1] == D) - { + } else { + if (shape_x[1] == D) { is_point = true; } } vec> X; X.reserve(ndim == 1 ? 
1 : shape_x[0]); BB bb; - if (ndim == 1) - { + if (ndim == 1) { { Real minima[D]; Real maxima[D]; - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { minima[i] = *x.data(i); - if (is_point) - { + if (is_point) { maxima[i] = minima[i]; - } - else - { + } else { maxima[i] = *x.data(i + D); } } bb = BB(minima, maxima); } X.push_back(std::move(bb)); - } - else - { + } else { X.reserve(shape_x[0]); - for (long int i = 0; i < shape_x[0]; i++) - { + for (long int i = 0; i < shape_x[0]; i++) { { Real minima[D]; Real maxima[D]; - for (int j = 0; j < D; ++j) - { + for (int j = 0; j < D; ++j) { minima[j] = *x.data(i, j); - if (is_point) - { + if (is_point) { maxima[j] = minima[j]; - } - else - { + } else { maxima[j] = *x.data(i, j + D); } } @@ -1395,88 +1170,71 @@ class PRTree // Build exact query coordinates for refinement vec> queries_exact; queries_exact.reserve(X.size()); - - if (ndim == 1) - { + + if (ndim == 1) { std::array qe; - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { qe[i] = static_cast(*x.data(i)); - if (is_point) - { + if (is_point) { qe[i + D] = qe[i]; - } - else - { + } else { qe[i + D] = static_cast(*x.data(i + D)); } } queries_exact.push_back(qe); - } - else - { - for (long int i = 0; i < shape_x[0]; i++) - { + } else { + for (long int i = 0; i < shape_x[0]; i++) { std::array qe; - for (int j = 0; j < D; ++j) - { + for (int j = 0; j < D; ++j) { qe[j] = static_cast(*x.data(i, j)); - if (is_point) - { + if (is_point) { qe[j + D] = qe[j]; - } - else - { + } else { qe[j + D] = static_cast(*x.data(i, j + D)); } } queries_exact.push_back(qe); } } - + vec> out; - out.resize(X.size()); // Pre-size for index-based parallel access + out.resize(X.size()); // Pre-size for index-based parallel access #ifdef MY_DEBUG - for (size_t i = 0; i < X.size(); ++i) - { + for (size_t i = 0; i < X.size(); ++i) { auto candidates = find(X[i]); out[i] = refine_candidates(candidates, queries_exact[i]); } #else // Index-based parallel loop (safe, no 
pointer arithmetic) const size_t n_queries = X.size(); - + // Early return if no queries - if (n_queries == 0) - { + if (n_queries == 0) { return out; } - + // Guard against hardware_concurrency() returning 0 (can happen on macOS) size_t hw = std::thread::hardware_concurrency(); size_t n_threads = hw ? hw : 1; n_threads = std::min(n_threads, n_queries); - + const size_t chunk_size = (n_queries + n_threads - 1) / n_threads; - + vec threads; threads.reserve(n_threads); - - for (size_t t = 0; t < n_threads; ++t) - { + + for (size_t t = 0; t < n_threads; ++t) { threads.emplace_back([&, t]() { size_t start = t * chunk_size; size_t end = std::min(start + chunk_size, n_queries); - for (size_t i = start; i < end; ++i) - { + for (size_t i = start; i < end; ++i) { auto candidates = find(X[i]); out[i] = refine_candidates(candidates, queries_exact[i]); } }); } - - for (auto &thread : threads) - { + + for (auto &thread : threads) { thread.join(); } #endif @@ -1487,108 +1245,91 @@ class PRTree return out; } - auto find_all_array(const py::array_t &x) - { + auto find_all_array(const py::array_t &x) { return list_list_to_arrays(std::move(find_all(x))); } - auto find_one(const vec &x) - { + auto find_one(const vec &x) { bool is_point = false; - if (unlikely(!(x.size() == 2 * D || x.size() == D))) - { + if (unlikely(!(x.size() == 2 * D || x.size() == D))) { throw std::runtime_error("invalid shape"); } Real minima[D]; Real maxima[D]; std::array query_exact; - - if (x.size() == D) - { + + if (x.size() == D) { is_point = true; } - for (int i = 0; i < D; ++i) - { + for (int i = 0; i < D; ++i) { minima[i] = x.at(i); query_exact[i] = static_cast(x.at(i)); - - if (is_point) - { + + if (is_point) { maxima[i] = minima[i]; query_exact[i + D] = query_exact[i]; - } - else - { + } else { maxima[i] = x.at(i + D); query_exact[i + D] = static_cast(x.at(i + D)); } } const auto bb = BB(minima, maxima); auto candidates = find(bb); - + // Refine with double precision if exact coordinates are available 
auto out = refine_candidates(candidates, query_exact); return out; } - // Helper method: Check intersection with double precision (closed interval semantics) - bool intersects_exact(const std::array &box_a, const std::array &box_b) const - { - for (int i = 0; i < D; ++i) - { + // Helper method: Check intersection with double precision (closed interval + // semantics) + bool intersects_exact(const std::array &box_a, + const std::array &box_b) const { + for (int i = 0; i < D; ++i) { double a_min = box_a[i]; double a_max = box_a[i + D]; double b_min = box_b[i]; double b_max = box_b[i + D]; - + // Closed interval: boxes touch if a_max == b_min or b_max == a_min - if (a_min > b_max || b_min > a_max) - { + if (a_min > b_max || b_min > a_max) { return false; } } return true; } - + // Refine candidates using double-precision coordinates - vec refine_candidates(const vec &candidates, const std::array &query_exact) const - { - if (idx2exact.empty()) - { + vec refine_candidates(const vec &candidates, + const std::array &query_exact) const { + if (idx2exact.empty()) { // No exact coordinates stored, return candidates as-is return candidates; } - + vec refined; refined.reserve(candidates.size()); - - for (const T &idx : candidates) - { + + for (const T &idx : candidates) { auto it = idx2exact.find(idx); - if (it != idx2exact.end()) - { + if (it != idx2exact.end()) { // Check with double precision - if (intersects_exact(it->second, query_exact)) - { + if (intersects_exact(it->second, query_exact)) { refined.push_back(idx); } // else: false positive from float32, filter it out - } - else - { + } else { // No exact coords for this item (e.g., inserted as float32), keep it refined.push_back(idx); } } - + return refined; } - vec find(const BB &target) - { + vec find(const BB &target) { vec out; - auto find_func = [&](std::unique_ptr> &leaf) - { + auto find_func = [&](std::unique_ptr> &leaf) { (*leaf)(target, out); }; @@ -1597,17 +1338,14 @@ class PRTree return out; } - void 
erase(const T idx) - { + void erase(const T idx) { auto it = idx2bb.find(idx); - if (unlikely(it == idx2bb.end())) - { + if (unlikely(it == idx2bb.end())) { throw std::runtime_error("Given index is not found."); } BB target = it->second; - auto erase_func = [&](std::unique_ptr> &leaf) - { + auto erase_func = [&](std::unique_ptr> &leaf) { leaf->del(idx, target); }; @@ -1615,15 +1353,172 @@ class PRTree idx2bb.erase(idx); idx2data.erase(idx); - idx2exact.erase(idx); // Also remove from exact coordinates if present - if (unlikely(REBUILD_THRE * size() < n_at_build)) - { + idx2exact.erase(idx); // Also remove from exact coordinates if present + if (unlikely(REBUILD_THRE * size() < n_at_build)) { rebuild(); } } - int64_t size() - { - return static_cast(idx2bb.size()); + int64_t size() { return static_cast(idx2bb.size()); } + + /** + * Find all pairs of intersecting AABBs in the tree. + * Returns a numpy array of shape (n_pairs, 2) where each row contains + * a pair of indices (i, j) with i < j representing intersecting AABBs. 
+ * + * This method is optimized for performance by: + * - Using parallel processing for queries + * - Avoiding duplicate pairs by enforcing i < j + * - Performing intersection checks in C++ to minimize Python overhead + * - Using double-precision refinement when exact coordinates are available + * + * @return py::array_t Array of shape (n_pairs, 2) containing index pairs + */ + py::array_t query_intersections() { + // Collect all indices and bounding boxes + vec indices; + vec> bboxes; + vec> exact_coords; + + if (unlikely(idx2bb.empty())) { + // Return empty array of shape (0, 2) + vec empty_data; + std::unique_ptr> data_ptr = + std::make_unique>(std::move(empty_data)); + auto capsule = py::capsule(data_ptr.get(), [](void *p) { + std::unique_ptr>(reinterpret_cast *>(p)); + }); + data_ptr.release(); + return py::array_t({0, 2}, {2 * sizeof(T), sizeof(T)}, nullptr, + capsule); + } + + indices.reserve(idx2bb.size()); + bboxes.reserve(idx2bb.size()); + exact_coords.reserve(idx2bb.size()); + + for (const auto &pair : idx2bb) { + indices.push_back(pair.first); + bboxes.push_back(pair.second); + + // Get exact coordinates if available + auto it = idx2exact.find(pair.first); + if (it != idx2exact.end()) { + exact_coords.push_back(it->second); + } else { + // Create dummy exact coords from float32 BB (won't be used for + // refinement) + std::array dummy; + for (int i = 0; i < D; ++i) { + dummy[i] = static_cast(pair.second.min(i)); + dummy[i + D] = static_cast(pair.second.max(i)); + } + exact_coords.push_back(dummy); + } + } + + const size_t n_items = indices.size(); + + // Use thread-local storage to collect pairs + // Guard against hardware_concurrency() returning 0 (can happen on some + // systems) + size_t hw = std::thread::hardware_concurrency(); + size_t n_threads = hw ? 
hw : 1; + n_threads = std::min(n_threads, n_items); + vec>> thread_pairs(n_threads); + +#ifdef MY_PARALLEL + vec threads; + threads.reserve(n_threads); + + for (size_t t = 0; t < n_threads; ++t) { + threads.emplace_back([&, t]() { + vec> local_pairs; + + for (size_t i = t; i < n_items; i += n_threads) { + const T idx_i = indices[i]; + const BB &bb_i = bboxes[i]; + + // Find all intersections with this bounding box + auto candidates = find(bb_i); + + // Refine candidates using exact coordinates if available + if (!idx2exact.empty()) { + candidates = refine_candidates(candidates, exact_coords[i]); + } + + // Keep only pairs where idx_i < idx_j to avoid duplicates + for (const T &idx_j : candidates) { + if (idx_i < idx_j) { + local_pairs.emplace_back(idx_i, idx_j); + } + } + } + + thread_pairs[t] = std::move(local_pairs); + }); + } + + for (auto &thread : threads) { + thread.join(); + } +#else + // Single-threaded version + vec> local_pairs; + + for (size_t i = 0; i < n_items; ++i) { + const T idx_i = indices[i]; + const BB &bb_i = bboxes[i]; + + // Find all intersections with this bounding box + auto candidates = find(bb_i); + + // Refine candidates using exact coordinates if available + if (!idx2exact.empty()) { + candidates = refine_candidates(candidates, exact_coords[i]); + } + + // Keep only pairs where idx_i < idx_j to avoid duplicates + for (const T &idx_j : candidates) { + if (idx_i < idx_j) { + local_pairs.emplace_back(idx_i, idx_j); + } + } + } + + thread_pairs[0] = std::move(local_pairs); +#endif + + // Merge results from all threads into a flat vector + vec flat_pairs; + size_t total_pairs = 0; + for (const auto &pairs : thread_pairs) { + total_pairs += pairs.size(); + } + flat_pairs.reserve(total_pairs * 2); + + for (const auto &pairs : thread_pairs) { + for (const auto &pair : pairs) { + flat_pairs.push_back(pair.first); + flat_pairs.push_back(pair.second); + } + } + + // Create output numpy array using the same pattern as as_pyarray + auto data = 
flat_pairs.data(); + std::unique_ptr> data_ptr = + std::make_unique>(std::move(flat_pairs)); + auto capsule = py::capsule(data_ptr.get(), [](void *p) { + std::unique_ptr>(reinterpret_cast *>(p)); + }); + data_ptr.release(); + + // Return 2D array with shape (total_pairs, 2) + return py::array_t( + {static_cast(total_pairs), py::ssize_t(2)}, // shape + {2 * sizeof(T), sizeof(T)}, // strides (row-major) + data, // data pointer + capsule // capsule for cleanup + ); } }; diff --git a/cpp/small_vector.h b/cpp/small_vector.h index 1753d17..6cedaa5 100644 --- a/cpp/small_vector.h +++ b/cpp/small_vector.h @@ -44,12 +44,13 @@ // Simply include this file wherever you need. // It defines the class itlib::small_vector, which is a drop-in replacement of // std::vector, but with an initial capacity as a template argument. -// It gives you the benefits of using std::vector, at the cost of having a statically -// allocated buffer for the initial capacity, which gives you cache-local data -// when the vector is small (smaller than the initial capacity). +// It gives you the benefits of using std::vector, at the cost of having a +// statically allocated buffer for the initial capacity, which gives you +// cache-local data when the vector is small (smaller than the initial +// capacity). // -// When the size exceeds the capacity, the vector allocates memory via the provided -// allocator, falling back to classic std::vector behavior. +// When the size exceeds the capacity, the vector allocates memory via the +// provided allocator, falling back to classic std::vector behavior. 
// // The second size_t template argument, RevertToStaticSize, is used when a // small_vector which has already switched to dynamically allocated size reduces @@ -61,14 +62,15 @@ // // Example: // -// itlib::small_vector myvec; // a small_vector of size 0, initial capacity 4, and revert size 4 (smaller than 5) -// myvec.resize(2); // vector is {0,0} in static buffer -// myvec[1] = 11; // vector is {0,11} in static buffer +// itlib::small_vector myvec; // a small_vector of size 0, initial +// capacity 4, and revert size 4 (smaller than 5) myvec.resize(2); // vector is +// {0,0} in static buffer myvec[1] = 11; // vector is {0,11} in static buffer // myvec.push_back(7); // vector is {0,11,7} in static buffer // myvec.insert(myvec.begin() + 1, 3); // vector is {0,3,11,7} in static buffer -// myvec.push_back(5); // vector is {0,3,11,7,5} in dynamically allocated memory buffer -// myvec.erase(myvec.begin()); // vector is {3,11,7,5} back in static buffer -// myvec.resize(5); // vector is {3,11,7,5,0} back in dynamically allocated memory +// myvec.push_back(5); // vector is {0,3,11,7,5} in dynamically allocated memory +// buffer myvec.erase(myvec.begin()); // vector is {3,11,7,5} back in static +// buffer myvec.resize(5); // vector is {3,11,7,5,0} back in dynamically +// allocated memory // // // Reference: @@ -84,13 +86,14 @@ // // Other notes: // -// * the default value for RevertToStaticSize is zero. This means that once a dynamic -// buffer is allocated the data will never be put into the static one, even if the -// size allows it. Even if clear() is called. The only way to do so is to call -// shrink_to_fit() or revert_to_static() +// * the default value for RevertToStaticSize is zero. This means that once a +// dynamic +// buffer is allocated the data will never be put into the static one, even if +// the size allows it. Even if clear() is called. 
The only way to do so is to +// call shrink_to_fit() or revert_to_static() // * shrink_to_fit will free and reallocate if size != capacity and the data -// doesn't fit into the static buffer. It also will revert to the static buffer -// whenever possible regardless of the RevertToStaticSize value +// doesn't fit into the static buffer. It also will revert to the static +// buffer whenever possible regardless of the RevertToStaticSize value // // // Configuration @@ -118,7 +121,8 @@ // // To set this setting by editing the file change the line: // ``` -// # define ITLIB_SMALL_VECTOR_ERROR_HANDLING ITLIB_SMALL_VECTOR_ERROR_HANDLING_THROW +// # define ITLIB_SMALL_VECTOR_ERROR_HANDLING +// ITLIB_SMALL_VECTOR_ERROR_HANDLING_THROW // ``` // to the default setting of your choice // @@ -138,1024 +142,841 @@ // #pragma once -#include #include #include +#include -#define ITLIB_SMALL_VECTOR_ERROR_HANDLING_NONE 0 +#define ITLIB_SMALL_VECTOR_ERROR_HANDLING_NONE 0 #define ITLIB_SMALL_VECTOR_ERROR_HANDLING_THROW 1 #define ITLIB_SMALL_VECTOR_ERROR_HANDLING_ASSERT 2 #define ITLIB_SMALL_VECTOR_ERROR_HANDLING_ASSERT_AND_THROW 3 #if !defined(ITLIB_SMALL_VECTOR_ERROR_HANDLING) -# define ITLIB_SMALL_VECTOR_ERROR_HANDLING ITLIB_SMALL_VECTOR_ERROR_HANDLING_THROW +#define ITLIB_SMALL_VECTOR_ERROR_HANDLING \ + ITLIB_SMALL_VECTOR_ERROR_HANDLING_THROW #endif - #if ITLIB_SMALL_VECTOR_ERROR_HANDLING == ITLIB_SMALL_VECTOR_ERROR_HANDLING_NONE -# define I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(cond) -#elif ITLIB_SMALL_VECTOR_ERROR_HANDLING == ITLIB_SMALL_VECTOR_ERROR_HANDLING_THROW -# include -# define I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(cond) if (cond) throw std::out_of_range("itlib::small_vector out of range") -#elif ITLIB_SMALL_VECTOR_ERROR_HANDLING == ITLIB_SMALL_VECTOR_ERROR_HANDLING_ASSERT -# include -# define I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(cond, rescue_return) assert(!(cond) && "itlib::small_vector out of range") -#elif ITLIB_SMALL_VECTOR_ERROR_HANDLING == 
ITLIB_SMALL_VECTOR_ERROR_HANDLING_ASSERT_AND_THROW -# include -# include -# define I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(cond, rescue_return) \ - do { if (cond) { assert(false && "itlib::small_vector out of range"); throw std::out_of_range("itlib::small_vector out of range"); } } while(false) +#define I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(cond) +#elif ITLIB_SMALL_VECTOR_ERROR_HANDLING == \ + ITLIB_SMALL_VECTOR_ERROR_HANDLING_THROW +#include +#define I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(cond) \ + if (cond) \ + throw std::out_of_range("itlib::small_vector out of range") +#elif ITLIB_SMALL_VECTOR_ERROR_HANDLING == \ + ITLIB_SMALL_VECTOR_ERROR_HANDLING_ASSERT +#include +#define I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(cond, rescue_return) \ + assert(!(cond) && "itlib::small_vector out of range") +#elif ITLIB_SMALL_VECTOR_ERROR_HANDLING == \ + ITLIB_SMALL_VECTOR_ERROR_HANDLING_ASSERT_AND_THROW +#include +#include +#define I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(cond, rescue_return) \ + do { \ + if (cond) { \ + assert(false && "itlib::small_vector out of range"); \ + throw std::out_of_range("itlib::small_vector out of range"); \ + } \ + } while (false) #else #error "Unknown ITLIB_SMALL_VECTOR_ERRROR_HANDLING" #endif - #if defined(ITLIB_SMALL_VECTOR_NO_DEBUG_BOUNDS_CHECK) -# define I_ITLIB_SMALL_VECTOR_BOUNDS_CHECK(i) +#define I_ITLIB_SMALL_VECTOR_BOUNDS_CHECK(i) #else -# include -# define I_ITLIB_SMALL_VECTOR_BOUNDS_CHECK(i) assert((i) < this->size()) +#include +#define I_ITLIB_SMALL_VECTOR_BOUNDS_CHECK(i) assert((i) < this->size()) #endif -namespace itlib -{ +namespace itlib { -template> -struct small_vector -{ - static_assert(RevertToStaticSize <= StaticCapacity + 1, "itlib::small_vector: the revert-to-static size shouldn't exceed the static capacity by more than one"); +template > +struct small_vector { + static_assert(RevertToStaticSize <= StaticCapacity + 1, + "itlib::small_vector: the revert-to-static size shouldn't " + "exceed the static capacity by more than one"); + + 
using atraits = std::allocator_traits; - using atraits = std::allocator_traits; public: - using allocator_type = Alloc; - using value_type = typename atraits::value_type; - using size_type = typename atraits::size_type; - using difference_type = typename atraits::difference_type; - using reference = T&; - using const_reference = const T&; - using pointer = typename atraits::pointer; - using const_pointer = typename atraits::const_pointer; - using iterator = pointer; - using const_iterator = const_pointer; - using reverse_iterator = std::reverse_iterator; - using const_reverse_iterator = std::reverse_iterator; - - static constexpr size_t static_capacity = StaticCapacity; - static constexpr intptr_t revert_to_static_size = RevertToStaticSize; - - small_vector() - : small_vector(Alloc()) - {} - - small_vector(const Alloc& alloc) - : m_alloc(alloc) - , m_capacity(StaticCapacity) - , m_dynamic_capacity(0) - , m_dynamic_data(nullptr) - { - m_begin = m_end = static_begin_ptr(); - } + using allocator_type = Alloc; + using value_type = typename atraits::value_type; + using size_type = typename atraits::size_type; + using difference_type = typename atraits::difference_type; + using reference = T &; + using const_reference = const T &; + using pointer = typename atraits::pointer; + using const_pointer = typename atraits::const_pointer; + using iterator = pointer; + using const_iterator = const_pointer; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + static constexpr size_t static_capacity = StaticCapacity; + static constexpr intptr_t revert_to_static_size = RevertToStaticSize; + + small_vector() : small_vector(Alloc()) {} + + small_vector(const Alloc &alloc) + : m_alloc(alloc), m_capacity(StaticCapacity), m_dynamic_capacity(0), + m_dynamic_data(nullptr) { + m_begin = m_end = static_begin_ptr(); + } + + explicit small_vector(size_t count, const Alloc &alloc = Alloc()) + : small_vector(alloc) { + resize(count); + } + 
+ explicit small_vector(size_t count, const T &value, + const Alloc &alloc = Alloc()) + : small_vector(alloc) { + assign_impl(count, value); + } + + template ())> + small_vector(InputIterator first, InputIterator last, + const Alloc &alloc = Alloc()) + : small_vector(alloc) { + assign_impl(first, last); + } + + small_vector(std::initializer_list l, const Alloc &alloc = Alloc()) + : small_vector(alloc) { + assign_impl(l); + } + + small_vector(const small_vector &v) + : small_vector(v, atraits::select_on_container_copy_construction( + v.get_allocator())) {} + + small_vector(const small_vector &v, const Alloc &alloc) + : m_alloc(alloc), m_dynamic_capacity(0), m_dynamic_data(nullptr) { + if (v.size() > StaticCapacity) { + m_dynamic_capacity = v.size(); + m_begin = m_end = m_dynamic_data = + atraits::allocate(get_alloc(), m_dynamic_capacity); + m_capacity = v.size(); + } else { + m_begin = m_end = static_begin_ptr(); + m_capacity = StaticCapacity; + } + + for (auto p = v.m_begin; p != v.m_end; ++p) { + atraits::construct(get_alloc(), m_end, *p); + ++m_end; + } + } + + small_vector(small_vector &&v) noexcept + : m_alloc(std::move(v.get_alloc())), m_capacity(v.m_capacity), + m_dynamic_capacity(v.m_dynamic_capacity), + m_dynamic_data(v.m_dynamic_data) { + if (v.m_begin == v.static_begin_ptr()) { + m_begin = m_end = static_begin_ptr(); + for (auto p = v.m_begin; p != v.m_end; ++p) { + atraits::construct(get_alloc(), m_end, std::move(*p)); + ++m_end; + } + + v.clear(); + } else { + m_begin = v.m_begin; + m_end = v.m_end; + } + + v.m_dynamic_capacity = 0; + v.m_dynamic_data = nullptr; + v.m_begin = v.m_end = v.static_begin_ptr(); + v.m_capacity = StaticCapacity; + } - explicit small_vector(size_t count, const Alloc& alloc = Alloc()) - : small_vector(alloc) - { - resize(count); - } + ~small_vector() { + clear(); - explicit small_vector(size_t count, const T& value, const Alloc& alloc = Alloc()) - : small_vector(alloc) - { - assign_impl(count, value); + if (m_dynamic_data) { + 
atraits::deallocate(get_alloc(), m_dynamic_data, m_dynamic_capacity); } + } - template ())> - small_vector(InputIterator first, InputIterator last, const Alloc& alloc = Alloc()) - : small_vector(alloc) - { - assign_impl(first, last); + small_vector &operator=(const small_vector &v) { + if (this == &v) { + // prevent self usurp + return *this; } - small_vector(std::initializer_list l, const Alloc& alloc = Alloc()) - : small_vector(alloc) - { - assign_impl(l); - } + clear(); - small_vector(const small_vector& v) - : small_vector(v, atraits::select_on_container_copy_construction(v.get_allocator())) - {} - - small_vector(const small_vector& v, const Alloc& alloc) - : m_alloc(alloc) - , m_dynamic_capacity(0) - , m_dynamic_data(nullptr) - { - if (v.size() > StaticCapacity) - { - m_dynamic_capacity = v.size(); - m_begin = m_end = m_dynamic_data = atraits::allocate(get_alloc(), m_dynamic_capacity); - m_capacity = v.size(); - } - else - { - m_begin = m_end = static_begin_ptr(); - m_capacity = StaticCapacity; - } + m_begin = m_end = choose_data(v.size()); - for (auto p = v.m_begin; p != v.m_end; ++p) - { - atraits::construct(get_alloc(), m_end, *p); - ++m_end; - } + for (auto p = v.m_begin; p != v.m_end; ++p) { + atraits::construct(get_alloc(), m_end, *p); + ++m_end; } - small_vector(small_vector&& v) noexcept - : m_alloc(std::move(v.get_alloc())) - , m_capacity(v.m_capacity) - , m_dynamic_capacity(v.m_dynamic_capacity) - , m_dynamic_data(v.m_dynamic_data) - { - if (v.m_begin == v.static_begin_ptr()) - { - m_begin = m_end = static_begin_ptr(); - for (auto p = v.m_begin; p != v.m_end; ++p) - { - atraits::construct(get_alloc(), m_end, std::move(*p)); - ++m_end; - } - - v.clear(); - } - else - { - m_begin = v.m_begin; - m_end = v.m_end; - } + update_capacity(); - v.m_dynamic_capacity = 0; - v.m_dynamic_data = nullptr; - v.m_begin = v.m_end = v.static_begin_ptr(); - v.m_capacity = StaticCapacity; - } + return *this; + } - ~small_vector() - { - clear(); + small_vector 
&operator=(small_vector &&v) noexcept { + clear(); - if (m_dynamic_data) - { - atraits::deallocate(get_alloc(), m_dynamic_data, m_dynamic_capacity); - } + get_alloc() = std::move(v.get_alloc()); + m_capacity = v.m_capacity; + m_dynamic_capacity = v.m_dynamic_capacity; + m_dynamic_data = v.m_dynamic_data; + + if (v.m_begin == v.static_begin_ptr()) { + m_begin = m_end = static_begin_ptr(); + for (auto p = v.m_begin; p != v.m_end; ++p) { + atraits::construct(get_alloc(), m_end, std::move(*p)); + ++m_end; + } + + v.clear(); + } else { + m_begin = v.m_begin; + m_end = v.m_end; } - small_vector& operator=(const small_vector& v) - { - if (this == &v) - { - // prevent self usurp - return *this; - } + v.m_dynamic_capacity = 0; + v.m_dynamic_data = nullptr; + v.m_begin = v.m_end = v.static_begin_ptr(); + v.m_capacity = StaticCapacity; - clear(); + return *this; + } - m_begin = m_end = choose_data(v.size()); + void assign(size_type count, const T &value) { + clear(); + assign_impl(count, value); + } - for (auto p = v.m_begin; p != v.m_end; ++p) - { - atraits::construct(get_alloc(), m_end, *p); - ++m_end; - } + template ())> + void assign(InputIterator first, InputIterator last) { + clear(); + assign_impl(first, last); + } - update_capacity(); + void assign(std::initializer_list ilist) { + clear(); + assign_impl(ilist); + } - return *this; - } + allocator_type get_allocator() const { return get_alloc(); } - small_vector& operator=(small_vector&& v) noexcept - { - clear(); - - get_alloc() = std::move(v.get_alloc()); - m_capacity = v.m_capacity; - m_dynamic_capacity = v.m_dynamic_capacity; - m_dynamic_data = v.m_dynamic_data; - - if (v.m_begin == v.static_begin_ptr()) - { - m_begin = m_end = static_begin_ptr(); - for (auto p = v.m_begin; p != v.m_end; ++p) - { - atraits::construct(get_alloc(), m_end, std::move(*p)); - ++m_end; - } - - v.clear(); - } - else - { - m_begin = v.m_begin; - m_end = v.m_end; - } + const_reference at(size_type i) const { + 
I_ITLIB_SMALL_VECTOR_BOUNDS_CHECK(i); + return *(m_begin + i); + } - v.m_dynamic_capacity = 0; - v.m_dynamic_data = nullptr; - v.m_begin = v.m_end = v.static_begin_ptr(); - v.m_capacity = StaticCapacity; + reference at(size_type i) { + I_ITLIB_SMALL_VECTOR_BOUNDS_CHECK(i); + return *(m_begin + i); + } - return *this; - } + const_reference operator[](size_type i) const { return at(i); } - void assign(size_type count, const T& value) - { - clear(); - assign_impl(count, value); - } + reference operator[](size_type i) { return at(i); } - template ())> - void assign(InputIterator first, InputIterator last) - { - clear(); - assign_impl(first, last); - } + const_reference front() const { return at(0); } - void assign(std::initializer_list ilist) - { - clear(); - assign_impl(ilist); - } + reference front() { return at(0); } - allocator_type get_allocator() const - { - return get_alloc(); - } + const_reference back() const { return *(m_end - 1); } - const_reference at(size_type i) const - { - I_ITLIB_SMALL_VECTOR_BOUNDS_CHECK(i); - return *(m_begin + i); - } + reference back() { return *(m_end - 1); } - reference at(size_type i) - { - I_ITLIB_SMALL_VECTOR_BOUNDS_CHECK(i); - return *(m_begin + i); - } + const_pointer data() const noexcept { return m_begin; } - const_reference operator[](size_type i) const - { - return at(i); - } + pointer data() noexcept { return m_begin; } - reference operator[](size_type i) - { - return at(i); - } + // iterators + iterator begin() noexcept { return m_begin; } - const_reference front() const - { - return at(0); - } + const_iterator begin() const noexcept { return m_begin; } - reference front() - { - return at(0); - } + const_iterator cbegin() const noexcept { return m_begin; } - const_reference back() const - { - return *(m_end - 1); - } + iterator end() noexcept { return m_end; } - reference back() - { - return *(m_end - 1); - } + const_iterator end() const noexcept { return m_end; } - const_pointer data() const noexcept - { - return 
m_begin; - } + const_iterator cend() const noexcept { return m_end; } - pointer data() noexcept - { - return m_begin; - } + reverse_iterator rbegin() noexcept { return reverse_iterator(end()); } - // iterators - iterator begin() noexcept - { - return m_begin; - } + const_reverse_iterator rbegin() const noexcept { + return const_reverse_iterator(end()); + } - const_iterator begin() const noexcept - { - return m_begin; - } + const_reverse_iterator crbegin() const noexcept { + return const_reverse_iterator(end()); + } - const_iterator cbegin() const noexcept - { - return m_begin; - } + reverse_iterator rend() noexcept { return reverse_iterator(begin()); } - iterator end() noexcept - { - return m_end; - } + const_reverse_iterator rend() const noexcept { + return const_reverse_iterator(begin()); + } - const_iterator end() const noexcept - { - return m_end; - } + const_reverse_iterator crend() const noexcept { + return const_reverse_iterator(begin()); + } - const_iterator cend() const noexcept - { - return m_end; - } + // capacity + bool empty() const noexcept { return m_begin == m_end; } - reverse_iterator rbegin() noexcept - { - return reverse_iterator(end()); - } + size_t size() const noexcept { return m_end - m_begin; } - const_reverse_iterator rbegin() const noexcept - { - return const_reverse_iterator(end()); - } + size_t max_size() const noexcept { return atraits::max_size(); } - const_reverse_iterator crbegin() const noexcept - { - return const_reverse_iterator(end()); - } + void reserve(size_type new_cap) { + if (new_cap <= m_capacity) + return; - reverse_iterator rend() noexcept - { - return reverse_iterator(begin()); - } + auto new_buf = choose_data(new_cap); - const_reverse_iterator rend() const noexcept - { - return const_reverse_iterator(begin()); - } + assert(new_buf != + m_begin); // should've been handled by new_cap <= m_capacity + assert(new_buf != + static_begin_ptr()); // we should never reserve into static memory - const_reverse_iterator crend() 
const noexcept - { - return const_reverse_iterator(begin()); + const auto s = size(); + if (s < RevertToStaticSize) { + // we've allocated enough memory for the dynamic buffer but don't move + // there until we have to + return; } - // capacity - bool empty() const noexcept - { - return m_begin == m_end; + // now we need to transfer the existing elements into the new buffer + for (size_type i = 0; i < s; ++i) { + atraits::construct(get_alloc(), new_buf + i, std::move(*(m_begin + i))); } - size_t size() const noexcept - { - return m_end - m_begin; + // free old elements + for (size_type i = 0; i < s; ++i) { + atraits::destroy(get_alloc(), m_begin + i); } - size_t max_size() const noexcept - { - return atraits::max_size(); + if (m_begin != static_begin_ptr()) { + // we've moved from dyn to dyn memory, so deallocate the old one + atraits::deallocate(get_alloc(), m_begin, m_capacity); } - void reserve(size_type new_cap) - { - if (new_cap <= m_capacity) return; + m_begin = new_buf; + m_end = new_buf + s; + m_capacity = m_dynamic_capacity; + } - auto new_buf = choose_data(new_cap); + size_t capacity() const noexcept { return m_capacity; } - assert(new_buf != m_begin); // should've been handled by new_cap <= m_capacity - assert(new_buf != static_begin_ptr()); // we should never reserve into static memory + void shrink_to_fit() { + const auto s = size(); - const auto s = size(); - if(s < RevertToStaticSize) - { - // we've allocated enough memory for the dynamic buffer but don't move there until we have to - return; - } + if (s == m_capacity) + return; + if (m_begin == static_begin_ptr()) + return; - // now we need to transfer the existing elements into the new buffer - for (size_type i = 0; i < s; ++i) - { - atraits::construct(get_alloc(), new_buf + i, std::move(*(m_begin + i))); - } + auto old_end = m_end; - // free old elements - for (size_type i = 0; i < s; ++i) - { - atraits::destroy(get_alloc(), m_begin + i); - } - - if (m_begin != static_begin_ptr()) - { - // we've 
moved from dyn to dyn memory, so deallocate the old one - atraits::deallocate(get_alloc(), m_begin, m_capacity); - } - - m_begin = new_buf; - m_end = new_buf + s; - m_capacity = m_dynamic_capacity; + if (s < StaticCapacity) { + // revert to static capacity + m_begin = m_end = static_begin_ptr(); + m_capacity = StaticCapacity; + } else { + // alloc new smaller buffer + m_begin = m_end = atraits::allocate(get_alloc(), s); + m_capacity = s; } - size_t capacity() const noexcept - { - return m_capacity; + for (auto p = m_dynamic_data; p != old_end; ++p) { + atraits::construct(get_alloc(), m_end, std::move(*p)); + ++m_end; + atraits::destroy(get_alloc(), p); } - void shrink_to_fit() - { - const auto s = size(); - - if (s == m_capacity) return; - if (m_begin == static_begin_ptr()) return; - - auto old_end = m_end; - - if (s < StaticCapacity) - { - // revert to static capacity - m_begin = m_end = static_begin_ptr(); - m_capacity = StaticCapacity; - } - else - { - // alloc new smaller buffer - m_begin = m_end = atraits::allocate(get_alloc(), s); - m_capacity = s; - } + atraits::deallocate(get_alloc(), m_dynamic_data, m_dynamic_capacity); + m_dynamic_data = nullptr; + m_dynamic_capacity = 0; + } - for (auto p = m_dynamic_data; p != old_end; ++p) - { - atraits::construct(get_alloc(), m_end, std::move(*p)); - ++m_end; - atraits::destroy(get_alloc(), p); - } + void revert_to_static() { + const auto s = size(); + if (m_begin == static_begin_ptr()) + return; // we're already there + if (s > StaticCapacity) + return; // nothing we can do - atraits::deallocate(get_alloc(), m_dynamic_data, m_dynamic_capacity); - m_dynamic_data = nullptr; - m_dynamic_capacity = 0; + // revert to static capacity + auto old_end = m_end; + m_begin = m_end = static_begin_ptr(); + m_capacity = StaticCapacity; + for (auto p = m_dynamic_data; p != old_end; ++p) { + atraits::construct(get_alloc(), m_end, std::move(*p)); + ++m_end; + atraits::destroy(get_alloc(), p); } + } - void revert_to_static() - { - 
const auto s = size(); - if (m_begin == static_begin_ptr()) return; //we're already there - if (s > StaticCapacity) return; // nothing we can do - - // revert to static capacity - auto old_end = m_end; - m_begin = m_end = static_begin_ptr(); - m_capacity = StaticCapacity; - for (auto p = m_dynamic_data; p != old_end; ++p) - { - atraits::construct(get_alloc(), m_end, std::move(*p)); - ++m_end; - atraits::destroy(get_alloc(), p); - } + // modifiers + void clear() noexcept { + for (auto p = m_begin; p != m_end; ++p) { + atraits::destroy(get_alloc(), p); } - // modifiers - void clear() noexcept - { - for (auto p = m_begin; p != m_end; ++p) - { - atraits::destroy(get_alloc(), p); + if (RevertToStaticSize > 0) { + m_begin = m_end = static_begin_ptr(); + m_capacity = StaticCapacity; + } else { + m_end = m_begin; + } + } + + iterator insert(const_iterator position, const value_type &val) { + auto pos = grow_at(position, 1); + atraits::construct(get_alloc(), pos, val); + return pos; + } + + iterator insert(const_iterator position, value_type &&val) { + auto pos = grow_at(position, 1); + atraits::construct(get_alloc(), pos, std::move(val)); + return pos; + } + + iterator insert(const_iterator position, size_type count, + const value_type &val) { + auto pos = grow_at(position, count); + for (size_type i = 0; i < count; ++i) { + atraits::construct(get_alloc(), pos + i, val); + } + return pos; + } + + template ())> + iterator insert(const_iterator position, InputIterator first, + InputIterator last) { + auto pos = grow_at(position, last - first); + size_type i = 0; + auto np = pos; + for (auto p = first; p != last; ++p, ++np) { + atraits::construct(get_alloc(), np, *p); + } + return pos; + } + + iterator insert(const_iterator position, std::initializer_list ilist) { + auto pos = grow_at(position, ilist.size()); + size_type i = 0; + for (auto &elem : ilist) { + atraits::construct(get_alloc(), pos + i, elem); + ++i; + } + return pos; + } + + template + iterator 
emplace(const_iterator position, Args &&...args) { + auto pos = grow_at(position, 1); + atraits::construct(get_alloc(), pos, std::forward(args)...); + return pos; + } + + iterator erase(const_iterator position) { return shrink_at(position, 1); } + + iterator erase(const_iterator first, const_iterator last) { + I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(first > last); + return shrink_at(first, last - first); + } + + void push_back(const_reference val) { + auto pos = grow_at(m_end, 1); + atraits::construct(get_alloc(), pos, val); + } + + void push_back(T &&val) { + auto pos = grow_at(m_end, 1); + atraits::construct(get_alloc(), pos, std::move(val)); + } + + template reference emplace_back(Args &&...args) { + auto pos = grow_at(m_end, 1); + atraits::construct(get_alloc(), pos, std::forward(args)...); + return *pos; + } + + void pop_back() { shrink_at(m_end - 1, 1); } + + void resize(size_type n, const value_type &v) { + auto new_buf = choose_data(n); + + if (new_buf == m_begin) { + // no special transfers needed + + auto new_end = m_begin + n; + + while (m_end > new_end) { + atraits::destroy(get_alloc(), --m_end); + } + + while (new_end > m_end) { + atraits::construct(get_alloc(), m_end++, v); + } + } else { + // we need to transfer the elements into the new buffer + + const auto s = size(); + const auto num_transfer = n < s ? 
n : s; + + for (size_type i = 0; i < num_transfer; ++i) { + atraits::construct(get_alloc(), new_buf + i, std::move(*(m_begin + i))); + } + + // free obsoletes + for (size_type i = 0; i < s; ++i) { + atraits::destroy(get_alloc(), m_begin + i); + } + + // construct new elements + for (size_type i = num_transfer; i < n; ++i) { + atraits::construct(get_alloc(), new_buf + i, v); + } + + if (new_buf == static_begin_ptr()) { + m_capacity = StaticCapacity; + } else { + if (m_begin != static_begin_ptr()) { + // we've moved from dyn to dyn memory, so deallocate the old one + atraits::deallocate(get_alloc(), m_begin, m_capacity); } + m_capacity = m_dynamic_capacity; + } - if (RevertToStaticSize > 0) - { - m_begin = m_end = static_begin_ptr(); - m_capacity = StaticCapacity; - } - else - { - m_end = m_begin; - } + m_begin = new_buf; + m_end = new_buf + n; } + } - iterator insert(const_iterator position, const value_type& val) - { - auto pos = grow_at(position, 1); - atraits::construct(get_alloc(), pos, val); - return pos; - } + void resize(size_type n) { + auto new_buf = choose_data(n); - iterator insert(const_iterator position, value_type&& val) - { - auto pos = grow_at(position, 1); - atraits::construct(get_alloc(), pos, std::move(val)); - return pos; - } + if (new_buf == m_begin) { + // no special transfers needed - iterator insert(const_iterator position, size_type count, const value_type& val) - { - auto pos = grow_at(position, count); - for (size_type i = 0; i < count; ++i) - { - atraits::construct(get_alloc(), pos + i, val); - } - return pos; - } + auto new_end = m_begin + n; - template ())> - iterator insert(const_iterator position, InputIterator first, InputIterator last) - { - auto pos = grow_at(position, last - first); - size_type i = 0; - auto np = pos; - for (auto p = first; p != last; ++p, ++np) - { - atraits::construct(get_alloc(), np, *p); - } - return pos; - } + while (m_end > new_end) { + atraits::destroy(get_alloc(), --m_end); + } - iterator 
insert(const_iterator position, std::initializer_list ilist) - { - auto pos = grow_at(position, ilist.size()); - size_type i = 0; - for (auto& elem : ilist) - { - atraits::construct(get_alloc(), pos + i, elem); - ++i; - } - return pos; - } + while (new_end > m_end) { + atraits::construct(get_alloc(), m_end++); + } + } else { + // we need to transfer the elements into the new buffer - template - iterator emplace(const_iterator position, Args&&... args) - { - auto pos = grow_at(position, 1); - atraits::construct(get_alloc(), pos, std::forward(args)...); - return pos; - } + const auto s = size(); + const auto num_transfer = n < s ? n : s; - iterator erase(const_iterator position) - { - return shrink_at(position, 1); - } + for (size_type i = 0; i < num_transfer; ++i) { + atraits::construct(get_alloc(), new_buf + i, std::move(*(m_begin + i))); + } - iterator erase(const_iterator first, const_iterator last) - { - I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(first > last); - return shrink_at(first, last - first); - } + // free obsoletes + for (size_type i = 0; i < s; ++i) { + atraits::destroy(get_alloc(), m_begin + i); + } - void push_back(const_reference val) - { - auto pos = grow_at(m_end, 1); - atraits::construct(get_alloc(), pos, val); - } + // construct new elements + for (size_type i = num_transfer; i < n; ++i) { + atraits::construct(get_alloc(), new_buf + i); + } - void push_back(T&& val) - { - auto pos = grow_at(m_end, 1); - atraits::construct(get_alloc(), pos, std::move(val)); - } + if (new_buf == static_begin_ptr()) { + m_capacity = StaticCapacity; + } else { + if (m_begin != static_begin_ptr()) { + // we've moved from dyn to dyn memory, so deallocate the old one + atraits::deallocate(get_alloc(), m_begin, m_capacity); + } + m_capacity = m_dynamic_capacity; + } - template - reference emplace_back(Args&&... 
args) - { - auto pos = grow_at(m_end, 1); - atraits::construct(get_alloc(), pos, std::forward(args)...); - return *pos; + m_begin = new_buf; + m_end = new_buf + n; } + } - void pop_back() - { - shrink_at(m_end - 1, 1); - } +private: + T *static_begin_ptr() { return reinterpret_cast(m_static_data + 0); } - void resize(size_type n, const value_type& v) - { - auto new_buf = choose_data(n); + // increase the size by splicing the elements in such a way that + // a hole of uninitialized elements is left at position, with size num + // returns the (potentially new) address of the hole + T *grow_at(const T *cp, size_t num) { + auto position = const_cast(cp); - if (new_buf == m_begin) - { - // no special transfers needed + I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(position < m_begin || + position > m_end); - auto new_end = m_begin + n; + const auto s = size(); + auto new_buf = choose_data(s + num); - while (m_end > new_end) - { - atraits::destroy(get_alloc(), --m_end); - } + if (new_buf == m_begin) { + // no special transfers needed - while (new_end > m_end) - { - atraits::construct(get_alloc(), m_end++, v); - } - } - else - { - // we need to transfer the elements into the new buffer - - const auto s = size(); - const auto num_transfer = n < s ? 
n : s; - - for (size_type i = 0; i < num_transfer; ++i) - { - atraits::construct(get_alloc(), new_buf + i, std::move(*(m_begin + i))); - } - - // free obsoletes - for (size_type i = 0; i < s; ++i) - { - atraits::destroy(get_alloc(), m_begin + i); - } - - // construct new elements - for (size_type i = num_transfer; i < n; ++i) - { - atraits::construct(get_alloc(), new_buf + i, v); - } - - if (new_buf == static_begin_ptr()) - { - m_capacity = StaticCapacity; - } - else - { - if (m_begin != static_begin_ptr()) - { - // we've moved from dyn to dyn memory, so deallocate the old one - atraits::deallocate(get_alloc(), m_begin, m_capacity); - } - m_capacity = m_dynamic_capacity; - } - - m_begin = new_buf; - m_end = new_buf + n; - } - } + m_end = m_begin + s + num; - void resize(size_type n) - { - auto new_buf = choose_data(n); + for (auto p = m_end - num - 1; p >= position; --p) { + atraits::construct(get_alloc(), p + num, std::move(*p)); + atraits::destroy(get_alloc(), p); + } - if (new_buf == m_begin) - { - // no special transfers needed + return position; + } else { + // we need to transfer the elements into the new buffer - auto new_end = m_begin + n; + position = new_buf + (position - m_begin); - while (m_end > new_end) - { - atraits::destroy(get_alloc(), --m_end); - } + auto p = m_begin; + auto np = new_buf; - while (new_end > m_end) - { - atraits::construct(get_alloc(), m_end++); - } - } - else - { - // we need to transfer the elements into the new buffer - - const auto s = size(); - const auto num_transfer = n < s ? 
n : s; - - for (size_type i = 0; i < num_transfer; ++i) - { - atraits::construct(get_alloc(), new_buf + i, std::move(*(m_begin + i))); - } - - // free obsoletes - for (size_type i = 0; i < s; ++i) - { - atraits::destroy(get_alloc(), m_begin + i); - } - - // construct new elements - for (size_type i = num_transfer; i < n; ++i) - { - atraits::construct(get_alloc(), new_buf + i); - } - - if (new_buf == static_begin_ptr()) - { - m_capacity = StaticCapacity; - } - else - { - if (m_begin != static_begin_ptr()) - { - // we've moved from dyn to dyn memory, so deallocate the old one - atraits::deallocate(get_alloc(), m_begin, m_capacity); - } - m_capacity = m_dynamic_capacity; - } - - m_begin = new_buf; - m_end = new_buf + n; - } - } + for (; np != position; ++p, ++np) { + atraits::construct(get_alloc(), np, std::move(*p)); + } -private: - T* static_begin_ptr() - { - return reinterpret_cast(m_static_data + 0); - } + np += num; + for (; p != m_end; ++p, ++np) { + atraits::construct(get_alloc(), np, std::move(*p)); + } - // increase the size by splicing the elements in such a way that - // a hole of uninitialized elements is left at position, with size num - // returns the (potentially new) address of the hole - T* grow_at(const T* cp, size_t num) - { - auto position = const_cast(cp); + // destroy old + for (p = m_begin; p != m_end; ++p) { + atraits::destroy(get_alloc(), p); + } - I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(position < m_begin || position > m_end); + if (m_begin != static_begin_ptr()) { + // we've moved from dyn to dyn memory, so deallocate the old one + atraits::deallocate(get_alloc(), m_begin, m_capacity); + } - const auto s = size(); - auto new_buf = choose_data(s + num); + m_capacity = m_dynamic_capacity; - if (new_buf == m_begin) - { - // no special transfers needed + m_begin = new_buf; + m_end = new_buf + s + num; - m_end = m_begin + s + num; + return position; + } + } - for (auto p = m_end - num - 1; p >= position; --p) - { - atraits::construct(get_alloc(), p 
+ num, std::move(*p)); - atraits::destroy(get_alloc(), p); - } + T *shrink_at(const T *cp, size_t num) { + auto position = const_cast(cp); - return position; - } - else - { - // we need to transfer the elements into the new buffer + I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF( + position < m_begin || position > m_end || position + num > m_end); - position = new_buf + (position - m_begin); + const auto s = size(); + if (s - num == 0) { + clear(); + return m_end; + } - auto p = m_begin; - auto np = new_buf; + auto new_buf = choose_data(s - num); - for (; np != position; ++p, ++np) - { - atraits::construct(get_alloc(), np, std::move(*p)); - } + if (new_buf == m_begin) { + // no special transfers needed - np += num; - for (; p != m_end; ++p, ++np) - { - atraits::construct(get_alloc(), np, std::move(*p)); - } + for (auto p = position, np = position + num; np != m_end; ++p, ++np) { + atraits::destroy(get_alloc(), p); + atraits::construct(get_alloc(), p, std::move(*np)); + } - // destroy old - for (p = m_begin; p != m_end; ++p) - { - atraits::destroy(get_alloc(), p); - } + for (auto p = m_end - num; p != m_end; ++p) { + atraits::destroy(get_alloc(), p); + } - if (m_begin != static_begin_ptr()) - { - // we've moved from dyn to dyn memory, so deallocate the old one - atraits::deallocate(get_alloc(), m_begin, m_capacity); - } + m_end -= num; + } else { + // we need to transfer the elements into the new buffer - m_capacity = m_dynamic_capacity; + assert(new_buf == static_begin_ptr()); // since we're shrinking that's the + // only way to have a new buffer - m_begin = new_buf; - m_end = new_buf + s + num; + m_capacity = StaticCapacity; - return position; - } - } + auto p = m_begin, np = new_buf; + for (; p != position; ++p, ++np) { + atraits::construct(get_alloc(), np, std::move(*p)); + atraits::destroy(get_alloc(), p); + } - T* shrink_at(const T* cp, size_t num) - { - auto position = const_cast(cp); + for (; p != position + num; ++p) { + atraits::destroy(get_alloc(), p); + } - 
I_ITLIB_SMALL_VECTOR_OUT_OF_RANGE_IF(position < m_begin || position > m_end || position + num > m_end); + for (; np != new_buf + s - num; ++p, ++np) { + atraits::construct(get_alloc(), np, std::move(*p)); + atraits::destroy(get_alloc(), p); + } - const auto s = size(); - if (s - num == 0) - { - clear(); - return m_end; - } + position = new_buf + (position - m_begin); + m_begin = new_buf; + m_end = np; + } - auto new_buf = choose_data(s - num); + return position; + } - if (new_buf == m_begin) - { - // no special transfers needed + void assign_impl(size_type count, const T &value) { + assert(m_begin); + assert(m_begin == m_end); - for (auto p = position, np = position + num; np != m_end; ++p, ++np) - { - atraits::destroy(get_alloc(), p); - atraits::construct(get_alloc(), p, std::move(*np)); - } + m_begin = m_end = choose_data(count); + for (size_type i = 0; i < count; ++i) { + atraits::construct(get_alloc(), m_end, value); + ++m_end; + } - for (auto p = m_end - num; p != m_end; ++p) - { - atraits::destroy(get_alloc(), p); - } + update_capacity(); + } - m_end -= num; - } - else - { - // we need to transfer the elements into the new buffer - - assert(new_buf == static_begin_ptr()); // since we're shrinking that's the only way to have a new buffer - - m_capacity = StaticCapacity; - - auto p = m_begin, np = new_buf; - for (; p != position; ++p, ++np) - { - atraits::construct(get_alloc(), np, std::move(*p)); - atraits::destroy(get_alloc(), p); - } - - for (; p != position + num; ++p) - { - atraits::destroy(get_alloc(), p); - } - - for (; np != new_buf + s - num; ++p, ++np) - { - atraits::construct(get_alloc(), np, std::move(*p)); - atraits::destroy(get_alloc(), p); - } - - position = new_buf + (position - m_begin); - m_begin = new_buf; - m_end = np; - } + template + void assign_impl(InputIterator first, InputIterator last) { + assert(m_begin); + assert(m_begin == m_end); - return position; + m_begin = m_end = choose_data(last - first); + for (auto p = first; p != last; 
++p) { + atraits::construct(get_alloc(), m_end, *p); + ++m_end; } - void assign_impl(size_type count, const T& value) - { - assert(m_begin); - assert(m_begin == m_end); + update_capacity(); + } - m_begin = m_end = choose_data(count); - for (size_type i = 0; i < count; ++i) - { - atraits::construct(get_alloc(), m_end, value); - ++m_end; - } + void assign_impl(std::initializer_list ilist) { + assert(m_begin); + assert(m_begin == m_end); - update_capacity(); + m_begin = m_end = choose_data(ilist.size()); + for (auto &elem : ilist) { + atraits::construct(get_alloc(), m_end, elem); + ++m_end; } - template - void assign_impl(InputIterator first, InputIterator last) - { - assert(m_begin); - assert(m_begin == m_end); - - m_begin = m_end = choose_data(last - first); - for (auto p = first; p != last; ++p) - { - atraits::construct(get_alloc(), m_end, *p); - ++m_end; - } + update_capacity(); + } - update_capacity(); + void update_capacity() { + if (m_begin == static_begin_ptr()) { + m_capacity = StaticCapacity; + } else { + m_capacity = m_dynamic_capacity; } + } - void assign_impl(std::initializer_list ilist) - { - assert(m_begin); - assert(m_begin == m_end); + T *choose_data(size_t desired_capacity) { + if (m_begin == m_dynamic_data) { + // we're at the dyn buffer, so see if it needs resize or revert to static - m_begin = m_end = choose_data(ilist.size()); - for (auto& elem : ilist) - { - atraits::construct(get_alloc(), m_end, elem); - ++m_end; + if (desired_capacity > m_dynamic_capacity) { + while (m_dynamic_capacity < desired_capacity) { + // grow by roughly 1.5 + m_dynamic_capacity *= 3; + ++m_dynamic_capacity; + m_dynamic_capacity /= 2; } - update_capacity(); - } + m_dynamic_data = atraits::allocate(get_alloc(), m_dynamic_capacity); + return m_dynamic_data; + } else if (desired_capacity < RevertToStaticSize) { + // we're reverting to the static buffer + return static_begin_ptr(); + } else { + // if the capacity and we don't revert to static, just do nothing + return 
m_dynamic_data; + } + } else { + assert(m_begin == static_begin_ptr()); // corrupt begin ptr? - void update_capacity() - { - if (m_begin == static_begin_ptr()) - { - m_capacity = StaticCapacity; - } - else - { - m_capacity = m_dynamic_capacity; - } - } + if (desired_capacity > StaticCapacity) { + // we must move to dyn memory - T* choose_data(size_t desired_capacity) - { - if (m_begin == m_dynamic_data) - { - // we're at the dyn buffer, so see if it needs resize or revert to static - - if (desired_capacity > m_dynamic_capacity) - { - while (m_dynamic_capacity < desired_capacity) - { - // grow by roughly 1.5 - m_dynamic_capacity *= 3; - ++m_dynamic_capacity; - m_dynamic_capacity /= 2; - } - - m_dynamic_data = atraits::allocate(get_alloc(), m_dynamic_capacity); - return m_dynamic_data; - } - else if (desired_capacity < RevertToStaticSize) - { - // we're reverting to the static buffer - return static_begin_ptr(); - } - else - { - // if the capacity and we don't revert to static, just do nothing - return m_dynamic_data; - } - } - else - { - assert(m_begin == static_begin_ptr()); // corrupt begin ptr? 
- - if (desired_capacity > StaticCapacity) - { - // we must move to dyn memory - - // see if we have enough - if (desired_capacity > m_dynamic_capacity) - { - // we need to allocate more - // we don't have anything to destroy, so we can also deallocate the buffer - if (m_dynamic_data) - { - atraits::deallocate(get_alloc(), m_dynamic_data, m_dynamic_capacity); - } - - m_dynamic_capacity = desired_capacity; - m_dynamic_data = atraits::allocate(get_alloc(), m_dynamic_capacity); - } - - return m_dynamic_data; - } - else - { - // we have enough capacity as it is - return static_begin_ptr(); - } + // see if we have enough + if (desired_capacity > m_dynamic_capacity) { + // we need to allocate more + // we don't have anything to destroy, so we can also deallocate the + // buffer + if (m_dynamic_data) { + atraits::deallocate(get_alloc(), m_dynamic_data, + m_dynamic_capacity); + } + + m_dynamic_capacity = desired_capacity; + m_dynamic_data = atraits::allocate(get_alloc(), m_dynamic_capacity); } + + return m_dynamic_data; + } else { + // we have enough capacity as it is + return static_begin_ptr(); + } } + } - allocator_type& get_alloc() { return m_alloc; } - const allocator_type& get_alloc() const { return m_alloc; } + allocator_type &get_alloc() { return m_alloc; } + const allocator_type &get_alloc() const { return m_alloc; } - allocator_type m_alloc; + allocator_type m_alloc; - pointer m_begin; - pointer m_end; + pointer m_begin; + pointer m_end; - size_t m_capacity; - typename std::aligned_storage::value>::type m_static_data[StaticCapacity]; + size_t m_capacity; + typename std::aligned_storage::value>::type + m_static_data[StaticCapacity]; - size_t m_dynamic_capacity; - pointer m_dynamic_data; + size_t m_dynamic_capacity; + pointer m_dynamic_data; }; -template -bool operator==(const small_vector& a, - const small_vector& b) -{ - if (a.size() != b.size()) - { - return false; - } - - for (size_t i = 0; i < a.size(); ++i) - { - if (!(a[i] == b[i])) - return false; - } - - 
return true; +template +bool operator==( + const small_vector &a, + const small_vector &b) { + if (a.size() != b.size()) { + return false; + } + + for (size_t i = 0; i < a.size(); ++i) { + if (!(a[i] == b[i])) + return false; + } + + return true; } -template -bool operator!=(const small_vector& a, - const small_vector& b) +template +bool operator!=( + const small_vector &a, + const small_vector &b) { - return !operator==(a, b); + return !operator==(a, b); } -} \ No newline at end of file +} // namespace itlib \ No newline at end of file diff --git a/tests/test_PRTree.py b/tests/test_PRTree.py index 3cea56e..6a00529 100644 --- a/tests/test_PRTree.py +++ b/tests/test_PRTree.py @@ -381,3 +381,180 @@ def test_save_load_float32_no_regression(tmp_path): assert results_before == results_after, \ "Float32 path: results changed after save/load cycle" + + +@pytest.mark.parametrize("seed", range(N_SEED)) +@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)]) +def test_query_intersections(seed, PRTree, dim): + """Test query_intersections() method returns correct pairs of intersecting AABBs.""" + np.random.seed(seed) + idx = np.arange(50) # Use smaller dataset for faster testing + x = np.random.rand(len(idx), 2 * dim) + for i in range(dim): + x[:, i + dim] += x[:, i] + + prtree = PRTree(idx, x) + pairs = prtree.query_intersections() + + # Verify output shape + assert pairs.ndim == 2 + assert pairs.shape[1] == 2 + + # Verify all pairs are valid (i < j) + assert np.all(pairs[:, 0] < pairs[:, 1]) + + # Verify pairs are unique + pairs_set = set(map(tuple, pairs)) + assert len(pairs_set) == len(pairs), "Duplicate pairs found" + + # Verify correctness: compare against naive approach + expected_pairs = [] + for i in range(len(idx)): + for j in range(i + 1, len(idx)): + if has_intersect(x[i], x[j], dim): + expected_pairs.append((idx[i], idx[j])) + + expected_set = set(expected_pairs) + assert pairs_set == expected_set, \ + f"Mismatch: expected 
{len(expected_set)} pairs, got {len(pairs_set)}" + + +@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)]) +def test_query_intersections_no_intersections(PRTree, dim): + """Test query_intersections() with non-overlapping AABBs.""" + # Create well-separated boxes + idx = np.arange(10) + x = np.zeros((len(idx), 2 * dim)) + + for i in range(len(idx)): + # Each box at distance 10*i, size 1 + for d in range(dim): + x[i, d] = 10 * i + d * 0.1 + x[i, d + dim] = 10 * i + d * 0.1 + 1 + + prtree = PRTree(idx, x) + pairs = prtree.query_intersections() + + # Should have no intersections + assert pairs.shape[0] == 0 + assert pairs.shape[1] == 2 + + +@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)]) +def test_query_intersections_all_intersections(PRTree, dim): + """Test query_intersections() where all boxes intersect.""" + # Create boxes that all overlap at origin + idx = np.arange(10) + x = np.zeros((len(idx), 2 * dim)) + + for i in range(len(idx)): + for d in range(dim): + x[i, d] = -1.0 - i * 0.1 + x[i, d + dim] = 1.0 + i * 0.1 + + prtree = PRTree(idx, x) + pairs = prtree.query_intersections() + + # All boxes should intersect: n*(n-1)/2 pairs + n = len(idx) + expected_count = n * (n - 1) // 2 + assert pairs.shape[0] == expected_count, \ + f"Expected {expected_count} pairs, got {pairs.shape[0]}" + + +@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)]) +def test_query_intersections_empty_tree(PRTree, dim): + """Test query_intersections() on empty tree.""" + prtree = PRTree() + pairs = prtree.query_intersections() + + assert pairs.shape == (0, 2) + + +@pytest.mark.parametrize("seed", range(N_SEED)) +@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)]) +def test_query_intersections_float64(seed, PRTree, dim): + """Test query_intersections() with float64 input (uses exact coordinate refinement).""" + np.random.seed(seed) + idx = np.arange(50) 
+ x = np.random.rand(len(idx), 2 * dim).astype(np.float64) + for i in range(dim): + x[:, i + dim] += x[:, i] + + prtree = PRTree(idx, x) + pairs = prtree.query_intersections() + + # Verify output shape and constraints + assert pairs.ndim == 2 + assert pairs.shape[1] == 2 + assert np.all(pairs[:, 0] < pairs[:, 1]) + + # Verify correctness + expected_pairs = [] + for i in range(len(idx)): + for j in range(i + 1, len(idx)): + if has_intersect(x[i], x[j], dim): + expected_pairs.append((idx[i], idx[j])) + + pairs_set = set(map(tuple, pairs)) + expected_set = set(expected_pairs) + assert pairs_set == expected_set + + +@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)]) +def test_query_intersections_touching_boxes(PRTree, dim): + """Test that touching boxes are considered intersecting (closed interval semantics).""" + idx = np.array([0, 1]) + x = np.zeros((2, 2 * dim)) + + # Box 0: [0, 1] in all dimensions + for d in range(dim): + x[0, d] = 0.0 + x[0, d + dim] = 1.0 + + # Box 1: [1, 2] in all dimensions (touches box 0) + for d in range(dim): + x[1, d] = 1.0 + x[1, d + dim] = 2.0 + + prtree = PRTree(idx, x) + pairs = prtree.query_intersections() + + # Boxes should be considered intersecting (closed intervals) + assert pairs.shape[0] == 1 + assert tuple(pairs[0]) == (0, 1) + + +@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)]) +def test_query_intersections_after_insert_erase(PRTree, dim): + """Test query_intersections() after dynamic updates.""" + np.random.seed(42) + idx = np.arange(20) + x = np.random.rand(len(idx), 2 * dim) + for i in range(dim): + x[:, i + dim] += x[:, i] + + prtree = PRTree(idx, x) + + # Get initial pairs + pairs_initial = prtree.query_intersections() + + # Insert a new box that overlaps all existing boxes + new_box = np.zeros(2 * dim) + for d in range(dim): + new_box[d] = -10.0 + new_box[d + dim] = 10.0 + + inserted_idx = max(idx) + 1 + prtree.insert(idx=inserted_idx, bb=new_box) 
+ + # Should have more pairs now + pairs_after_insert = prtree.query_intersections() + assert pairs_after_insert.shape[0] > pairs_initial.shape[0] + + # Erase the new box + prtree.erase(inserted_idx) + + # Should go back to original count (approximately - might differ due to rebuilding) + pairs_after_erase = prtree.query_intersections() + assert abs(pairs_after_erase.shape[0] - pairs_initial.shape[0]) <= 1