# Patch summary: add a CSR layout option for Arrow memory-backed relationship tables.
#
# examples/ice_disk_to_arrow_memory.py (new file)
#   - register_flat() reads nodes_<node_table>.parquet and rels_<rel_table>.parquet with
#     pyarrow.parquet.read_table and registers them through conn.create_arrow_table /
#     conn.create_arrow_rel_table(layout=FLAT).
#   - register_csr() reads nodes_<node_table>.parquet, indices_<rel_table>.parquet and
#     indptr_<rel_table>.parquet and registers them with layout=CSR, passing the indptr
#     table as indptr_dataframe.
#   - main() parses data_dir, --layout (default "csr"), --node-table, --rel-table,
#     --src-table and --dst-table (both fall back to --node-table), opens an in-memory
#     database, registers the tables, runs a MATCH ... RETURN COUNT(*) query and prints
#     the first value of the first row.
#   - NOTE(review): only --node-table is ever registered as a node table. If --src-table
#     or --dst-table names a different table, nothing in this example creates it in the
#     fresh ":memory:" database -- confirm whether that combination is meant to work.
#   - NOTE(review): table names from the command line are interpolated into the query
#     text with an f-string; acceptable for an example, but they are not validated.
#
# src_cpp/include/py_connection.h, src_cpp/py_connection.cpp
#   - createArrowRelTable() gains two parameters, layout and indptrTable; the pybind
#     binding exposes them as layout (default "FLAT") and indptr_table (default None).
#   - The pandas/polars -> pyarrow conversion and the schema/batch export that were
#     duplicated in createArrowTable() and createArrowRelTable() move into the static
#     helpers normalizeArrowTable() and exportArrowTable(); the latter returns an
#     ExportedArrowTable holding the schema, the exported arrays and a keepAlive list
#     containing the pyarrow table and its batches.
#   - createArrowRelTable() upper-cases layout and then:
#       FLAT  -> throws RuntimeException if indptrTable is not None, otherwise calls
#                ArrowTableSupport::createRelTableFromArrowTable;
#       CSR   -> throws RuntimeException if indptrTable is None, otherwise exports it too
#                and calls ArrowTableSupport::createRelTableFromArrowCSR;
#       other -> throws RuntimeException.
#     On success the keepAlive list (one nested list per exported table) is stored in
#     stateRef.arrowTableRefs[tableName].
#   - NOTE(review): in this copy of the patch the C++ text shows "#include" with no
#     header name and std::unique_ptr / std::vector / reinterpret_cast / static_cast with
#     no template argument. That looks like "<...>" spans lost during extraction rather
#     than the real source -- TODO confirm against the repository before applying.
#
# src_py/types.py, src_py/__init__.py, src_py/connection.py
#   - New enum ArrowRelTableLayout with members FLAT = "FLAT" and CSR = "CSR", imported
#     in __init__.py and added to __all__.
#   - Connection.create_arrow_rel_table() gains layout (default ArrowRelTableLayout.FLAT)
#     and indptr_dataframe (default None). The layout is reduced to an upper-cased string,
#     a ValueError is raised when it is "CSR" and indptr_dataframe is None, and both
#     values are passed positionally to the backend call and to the pybind fallback used
#     after NotImplementedError.
#   - NOTE(review): the Python wrapper does not reject unknown layout strings or
#     FLAT combined with indptr_dataframe; on the pybind path those cases are rejected
#     by the C++ checks above. How the other backend handles the two extra
#     positional arguments is not visible here -- verify.
#   - NOTE(review): the pybind keyword is indptr_table while the Python wrapper uses
#     indptr_dataframe; this works because the wrapper passes the value positionally.
#
# test/test_arrow_memory_backed_table.py
#   - New test test_arrow_memory_backed_csr_arrow_rel_table registers three people
#     (ids 1..3), a CSR relationship with to = [1, 2, 2], weight = [10, 20, 30] and
#     indptr = [0, 2, 3, 3], and expects the edges Alice->Bob (10), Alice->Carol (20),
#     Bob->Carol (30) plus COUNT(*) = 3 and SUM(weight) = 60 when matched in the reverse
#     direction; it then drops both Arrow tables.
#
diff --git a/examples/ice_disk_to_arrow_memory.py b/examples/ice_disk_to_arrow_memory.py new file mode 100644 index 0000000..52df48a --- /dev/null +++ b/examples/ice_disk_to_arrow_memory.py @@ -0,0 +1,89 @@ +""" +Register icebug-disk Parquet files as Arrow memory-backed tables. + +The example keeps the data in PyArrow tables and exposes it to Ladybug as +ice-mem/Arrow tables. Relationship tables can be either FLAT or CSR. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import ladybug as lb +import pyarrow.parquet as pq + + +def register_flat( + conn: lb.Connection, + data_dir: Path, + node_table: str, + rel_table: str, + src_table: str, + dst_table: str, +) -> None: + """Register FLAT icebug-disk Parquet files as Arrow memory-backed tables.""" + nodes = pq.read_table(data_dir / f"nodes_{node_table}.parquet") + rels = pq.read_table(data_dir / f"rels_{rel_table}.parquet") + + conn.create_arrow_table(node_table, nodes) + conn.create_arrow_rel_table( + rel_table, + rels, + src_table, + dst_table, + layout=lb.ArrowRelTableLayout.FLAT, + ) + + +def register_csr( + conn: lb.Connection, + data_dir: Path, + node_table: str, + rel_table: str, + src_table: str, + dst_table: str, +) -> None: + """Register CSR icebug-disk Parquet files as Arrow memory-backed tables.""" + nodes = pq.read_table(data_dir / f"nodes_{node_table}.parquet") + indices = pq.read_table(data_dir / f"indices_{rel_table}.parquet") + indptr = pq.read_table(data_dir / f"indptr_{rel_table}.parquet") + + conn.create_arrow_table(node_table, nodes) + conn.create_arrow_rel_table( + rel_table, + indices, + src_table, + dst_table, + layout=lb.ArrowRelTableLayout.CSR, + indptr_dataframe=indptr, + ) + + +def main() -> None: + """Run the example.""" + parser = argparse.ArgumentParser() + parser.add_argument("data_dir", type=Path) + parser.add_argument("--layout", choices=["flat", "csr"], default="csr") + parser.add_argument("--node-table", required=True) + 
parser.add_argument("--rel-table", required=True) + parser.add_argument("--src-table") + parser.add_argument("--dst-table") + args = parser.parse_args() + + src_table = args.src_table or args.node_table + dst_table = args.dst_table or args.node_table + + db = lb.Database(":memory:") + conn = db.connect() + if args.layout == "flat": + register_flat(conn, args.data_dir, args.node_table, args.rel_table, src_table, dst_table) + else: + register_csr(conn, args.data_dir, args.node_table, args.rel_table, src_table, dst_table) + + result = conn.execute(f"MATCH (a:{src_table})-[r:{args.rel_table}]->(b:{dst_table}) RETURN COUNT(*)") + print(result.get_next()[0]) + + +if __name__ == "__main__": + main() diff --git a/src_cpp/include/py_connection.h b/src_cpp/include/py_connection.h index 97c6ca9..61a971e 100644 --- a/src_cpp/include/py_connection.h +++ b/src_cpp/include/py_connection.h @@ -56,7 +56,8 @@ class PyConnection { std::unique_ptr createArrowTable(const std::string& tableName, py::object arrowTable); std::unique_ptr createArrowRelTable(const std::string& tableName, - py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName); + py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName, + const std::string& layout, py::object indptrTable); std::unique_ptr dropArrowTable(const std::string& tableName); static Value transformPythonValue(const py::handle& val); diff --git a/src_cpp/py_connection.cpp b/src_cpp/py_connection.cpp index 56a7c20..ffa1a56 100644 --- a/src_cpp/py_connection.cpp +++ b/src_cpp/py_connection.cpp @@ -1,5 +1,7 @@ #include "include/py_connection.h" +#include +#include #include #include "cached_import/py_cached_import.h" @@ -52,7 +54,8 @@ void PyConnection::initialize(py::handle& m) { .def("create_arrow_table", &PyConnection::createArrowTable, py::arg("table_name"), py::arg("arrow_table")) .def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"), - 
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name")) + py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"), + py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none()) .def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name")); PyDateTime_IMPORT; } @@ -1013,79 +1016,91 @@ void PyConnection::removeScalarFunction(const std::string& name) { refState().ref().removeUDFFunction(name); } -std::unique_ptr PyConnection::createArrowTable(const std::string& tableName, - py::object arrowTable) { - auto& stateRef = refState(); - py::gil_scoped_acquire acquire; +struct ExportedArrowTable { + ArrowSchemaWrapper schema; + std::vector arrays; + py::list keepAlive; +}; - // Convert pandas/polars to pyarrow if needed +static py::object normalizeArrowTable(py::object arrowTable) { if (PyConnection::isPandasDataframe(arrowTable)) { - arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable); - } else if (PyConnection::isPolarsDataframe(arrowTable)) { - arrowTable = arrowTable.attr("to_arrow")(); + return importCache->pyarrow.lib.Table.from_pandas()(arrowTable); + } + if (PyConnection::isPolarsDataframe(arrowTable)) { + return arrowTable.attr("to_arrow")(); } - - // Ensure we have a pyarrow table if (!PyConnection::isPyArrowTable(arrowTable)) { throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame"); } + return arrowTable; +} - // Export Arrow table to C Data Interface - // First, get the schema - ArrowSchemaWrapper schema; - arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast(&schema)); +static ExportedArrowTable exportArrowTable(py::object arrowTable) { + arrowTable = normalizeArrowTable(std::move(arrowTable)); + + ExportedArrowTable exported; + arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast(&exported.schema)); - // Get the batches (arrays) - std::vector arrays; py::list batches = arrowTable.attr("to_batches")(); for (auto& batch : 
batches) { - arrays.emplace_back(); - batch.attr("_export_to_c")(reinterpret_cast(&arrays.back())); + exported.arrays.emplace_back(); + batch.attr("_export_to_c")(reinterpret_cast(&exported.arrays.back())); } // Keep pyarrow producers alive while C++ accesses exported Arrow memory. - py::list keepAlive; - keepAlive.append(arrowTable); - keepAlive.append(batches); + exported.keepAlive.append(arrowTable); + exported.keepAlive.append(batches); + return exported; +} + +std::unique_ptr PyConnection::createArrowTable(const std::string& tableName, + py::object arrowTable) { + auto& stateRef = refState(); + py::gil_scoped_acquire acquire; + auto exported = exportArrowTable(std::move(arrowTable)); auto result = ArrowTableSupport::createViewFromArrowTable(stateRef.ref(), tableName, - std::move(schema), std::move(arrays)); + std::move(exported.schema), std::move(exported.arrays)); if (result.queryResult && result.queryResult->isSuccess()) { - stateRef.arrowTableRefs[tableName] = std::move(keepAlive); + stateRef.arrowTableRefs[tableName] = std::move(exported.keepAlive); } return checkAndWrapQueryResult(result.queryResult, state); } std::unique_ptr PyConnection::createArrowRelTable(const std::string& tableName, - py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName) { + py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName, + const std::string& layout, py::object indptrTable) { auto& stateRef = refState(); py::gil_scoped_acquire acquire; - if (PyConnection::isPandasDataframe(arrowTable)) { - arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable); - } else if (PyConnection::isPolarsDataframe(arrowTable)) { - arrowTable = arrowTable.attr("to_arrow")(); - } - if (!PyConnection::isPyArrowTable(arrowTable)) { - throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame"); - } - - ArrowSchemaWrapper schema; - 
arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast(&schema)); - std::vector arrays; - py::list batches = arrowTable.attr("to_batches")(); - for (auto& batch : batches) { - arrays.emplace_back(); - batch.attr("_export_to_c")(reinterpret_cast(&arrays.back())); - } + auto layoutUpper = layout; + std::transform(layoutUpper.begin(), layoutUpper.end(), layoutUpper.begin(), + [](unsigned char c) { return static_cast(std::toupper(c)); }); + auto exported = exportArrowTable(std::move(arrowTable)); + ArrowTableCreationResult result; py::list keepAlive; - keepAlive.append(arrowTable); - keepAlive.append(batches); + keepAlive.append(exported.keepAlive); - auto result = ArrowTableSupport::createRelTableFromArrowTable(stateRef.ref(), tableName, - srcTableName, dstTableName, std::move(schema), std::move(arrays)); + if (layoutUpper == "FLAT") { + if (!py::none().is(indptrTable)) { + throw RuntimeException("indptr_table is only valid for CSR Arrow relationship tables"); + } + result = ArrowTableSupport::createRelTableFromArrowTable(stateRef.ref(), tableName, + srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays)); + } else if (layoutUpper == "CSR") { + if (py::none().is(indptrTable)) { + throw RuntimeException("indptr_table is required for CSR Arrow relationship tables"); + } + auto exportedIndptr = exportArrowTable(std::move(indptrTable)); + keepAlive.append(exportedIndptr.keepAlive); + result = ArrowTableSupport::createRelTableFromArrowCSR(stateRef.ref(), tableName, + srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays), + std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays)); + } else { + throw RuntimeException("Arrow relationship table layout must be FLAT or CSR"); + } if (result.queryResult && result.queryResult->isSuccess()) { stateRef.arrowTableRefs[tableName] = std::move(keepAlive); } diff --git a/src_py/__init__.py b/src_py/__init__.py index 782cbda..80475be 100644 --- 
a/src_py/__init__.py +++ b/src_py/__init__.py @@ -57,7 +57,7 @@ from .database import Database # noqa: E402 from .prepared_statement import PreparedStatement # noqa: E402 from .query_result import ArrowQueryResult, CSRResult, QueryResult # noqa: E402 -from .types import Type # noqa: E402 +from .types import ArrowRelTableLayout, Type # noqa: E402 _VERSION_INFO: tuple[str, int] | None = None @@ -81,6 +81,7 @@ def __getattr__(name: str) -> str | int: __all__ = [ "AsyncConnection", "ArrowQueryResult", + "ArrowRelTableLayout", "Connection", "CSRResult", "Database", diff --git a/src_py/connection.py b/src_py/connection.py index 4fe6665..7d81595 100644 --- a/src_py/connection.py +++ b/src_py/connection.py @@ -11,6 +11,7 @@ from ._backend import get_capi_module, get_pybind_module from .prepared_statement import PreparedStatement from .query_result import ArrowQueryResult, QueryResult +from .types import ArrowRelTableLayout if TYPE_CHECKING: import sys @@ -811,6 +812,8 @@ def create_arrow_rel_table( dataframe: Any, src_table_name: str, dst_table_name: str, + layout: ArrowRelTableLayout | str = ArrowRelTableLayout.FLAT, + indptr_dataframe: Any | None = None, ) -> QueryResult: """ Create an Arrow memory-backed relationship table from a DataFrame. @@ -829,6 +832,16 @@ def create_arrow_rel_table( dst_table_name : str Destination node table name in the FROM/TO pair. + layout : ArrowRelTableLayout | str + Relationship layout. FLAT expects ``dataframe`` to contain ``from`` + and ``to`` endpoint columns. CSR expects ``dataframe`` to contain a + ``to`` destination offset column plus properties, and + ``indptr_dataframe`` to contain source offsets. + + indptr_dataframe : Any | None + A pandas DataFrame, polars DataFrame, or PyArrow table containing + CSR source offsets. Required when ``layout`` is CSR. 
+ Returns ------- QueryResult @@ -836,12 +849,20 @@ def create_arrow_rel_table( """ self.init_connection() + layout_value = ( + layout.value if isinstance(layout, ArrowRelTableLayout) else str(layout) + ).upper() + if layout_value == ArrowRelTableLayout.CSR.value and indptr_dataframe is None: + msg = "indptr_dataframe is required when layout is CSR" + raise ValueError(msg) try: query_result_internal = self._connection.create_arrow_rel_table( table_name, dataframe, src_table_name, dst_table_name, + layout_value, + indptr_dataframe, ) except NotImplementedError: py_connection = self._get_pybind_connection() @@ -853,6 +874,8 @@ def create_arrow_rel_table( dataframe, src_table_name, dst_table_name, + layout_value, + indptr_dataframe, ) if not query_result_internal.isSuccess(): raise RuntimeError(query_result_internal.getErrorMessage()) diff --git a/src_py/types.py b/src_py/types.py index 4566e24..844d9c7 100644 --- a/src_py/types.py +++ b/src_py/types.py @@ -37,3 +37,10 @@ class Type(Enum): STRUCT = "STRUCT" MAP = "MAP" UNION = "UNION" + + +class ArrowRelTableLayout(Enum): + """Arrow-backed relationship table layout.""" + + FLAT = "FLAT" + CSR = "CSR" diff --git a/test/test_arrow_memory_backed_table.py b/test/test_arrow_memory_backed_table.py index abd4680..d2edd0f 100644 --- a/test/test_arrow_memory_backed_table.py +++ b/test/test_arrow_memory_backed_table.py @@ -339,6 +339,68 @@ def test_arrow_memory_backed_arrow_node_and_rel_table(conn_db_empty: ConnDB) -> conn.drop_arrow_table("arrow_people") +def test_arrow_memory_backed_csr_arrow_rel_table(conn_db_empty: ConnDB) -> None: + """Test an Arrow memory-backed CSR relationship over Arrow-backed nodes.""" + conn, _ = conn_db_empty + + import ladybug as lb + + pa = pytest.importorskip("pyarrow") + + people = pa.Table.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int64()), + pa.array(["Alice", "Bob", "Carol"], type=pa.string()), + ], + names=["id", "name"], + ) + conn.create_arrow_table("arrow_csr_people", people) + + 
indices = pa.Table.from_arrays( + [ + pa.array([1, 2, 2], type=pa.uint64()), + pa.array([10, 20, 30], type=pa.int64()), + ], + names=["to", "weight"], + ) + indptr = pa.Table.from_arrays( + [pa.array([0, 2, 3, 3], type=pa.uint64())], + names=["indptr"], + ) + conn.create_arrow_rel_table( + "arrow_csr_knows", + indices, + "arrow_csr_people", + "arrow_csr_people", + layout=lb.ArrowRelTableLayout.CSR, + indptr_dataframe=indptr, + ) + + result = conn.execute( + "MATCH (a:arrow_csr_people)-[r:arrow_csr_knows]->(b:arrow_csr_people) " + "RETURN a.name, b.name, r.weight ORDER BY a.id, b.id" + ) + rows = [] + while result.has_next(): + rows.append(result.get_next()) + + assert rows == [ + ["Alice", "Bob", 10], + ["Alice", "Carol", 20], + ["Bob", "Carol", 30], + ] + + result = conn.execute( + "MATCH (:arrow_csr_people)<-[r:arrow_csr_knows]-(:arrow_csr_people) " + "RETURN COUNT(*), SUM(r.weight)" + ) + assert result.get_next() == [3, 60] + assert not result.has_next() + + conn.drop_arrow_table("arrow_csr_knows") + conn.drop_arrow_table("arrow_csr_people") + + def test_arrow_memory_backed_native_node_and_arrow_rel_table( conn_db_empty: ConnDB, ) -> None: