Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions examples/ice_disk_to_arrow_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
Register icebug-disk Parquet files as Arrow memory-backed tables.

The example keeps the data in PyArrow tables and exposes it to Ladybug as
ice-mem/Arrow tables. Relationship tables can be either FLAT or CSR.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import ladybug as lb
import pyarrow.parquet as pq


def register_flat(
conn: lb.Connection,
data_dir: Path,
node_table: str,
rel_table: str,
src_table: str,
dst_table: str,
) -> None:
"""Register FLAT icebug-disk Parquet files as Arrow memory-backed tables."""
nodes = pq.read_table(data_dir / f"nodes_{node_table}.parquet")
rels = pq.read_table(data_dir / f"rels_{rel_table}.parquet")

conn.create_arrow_table(node_table, nodes)
conn.create_arrow_rel_table(
rel_table,
rels,
src_table,
dst_table,
layout=lb.ArrowRelTableLayout.FLAT,
)


def register_csr(
conn: lb.Connection,
data_dir: Path,
node_table: str,
rel_table: str,
src_table: str,
dst_table: str,
) -> None:
"""Register CSR icebug-disk Parquet files as Arrow memory-backed tables."""
nodes = pq.read_table(data_dir / f"nodes_{node_table}.parquet")
indices = pq.read_table(data_dir / f"indices_{rel_table}.parquet")
indptr = pq.read_table(data_dir / f"indptr_{rel_table}.parquet")

conn.create_arrow_table(node_table, nodes)
conn.create_arrow_rel_table(
rel_table,
indices,
src_table,
dst_table,
layout=lb.ArrowRelTableLayout.CSR,
indptr_dataframe=indptr,
)


def main() -> None:
"""Run the example."""
parser = argparse.ArgumentParser()
parser.add_argument("data_dir", type=Path)
parser.add_argument("--layout", choices=["flat", "csr"], default="csr")
parser.add_argument("--node-table", required=True)
parser.add_argument("--rel-table", required=True)
parser.add_argument("--src-table")
parser.add_argument("--dst-table")
args = parser.parse_args()

src_table = args.src_table or args.node_table
dst_table = args.dst_table or args.node_table

db = lb.Database(":memory:")
conn = db.connect()
if args.layout == "flat":
register_flat(conn, args.data_dir, args.node_table, args.rel_table, src_table, dst_table)
else:
register_csr(conn, args.data_dir, args.node_table, args.rel_table, src_table, dst_table)

result = conn.execute(f"MATCH (a:{src_table})-[r:{args.rel_table}]->(b:{dst_table}) RETURN COUNT(*)")
print(result.get_next()[0])


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class PyConnection {
std::unique_ptr<PyQueryResult> createArrowTable(const std::string& tableName,
py::object arrowTable);
std::unique_ptr<PyQueryResult> createArrowRelTable(const std::string& tableName,
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName);
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName,
const std::string& layout, py::object indptrTable);
std::unique_ptr<PyQueryResult> dropArrowTable(const std::string& tableName);

static Value transformPythonValue(const py::handle& val);
Expand Down
107 changes: 61 additions & 46 deletions src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "include/py_connection.h"

#include <algorithm>
#include <cctype>
#include <utility>

#include "cached_import/py_cached_import.h"
Expand Down Expand Up @@ -52,7 +54,8 @@ void PyConnection::initialize(py::handle& m) {
.def("create_arrow_table", &PyConnection::createArrowTable, py::arg("table_name"),
py::arg("arrow_table"))
.def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"),
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"))
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"),
py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none())
.def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name"));
PyDateTime_IMPORT;
}
Expand Down Expand Up @@ -1013,79 +1016,91 @@ void PyConnection::removeScalarFunction(const std::string& name) {
refState().ref().removeUDFFunction(name);
}

std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string& tableName,
py::object arrowTable) {
auto& stateRef = refState();
py::gil_scoped_acquire acquire;
struct ExportedArrowTable {
ArrowSchemaWrapper schema;
std::vector<ArrowArrayWrapper> arrays;
py::list keepAlive;
};

// Convert pandas/polars to pyarrow if needed
static py::object normalizeArrowTable(py::object arrowTable) {
if (PyConnection::isPandasDataframe(arrowTable)) {
arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable);
} else if (PyConnection::isPolarsDataframe(arrowTable)) {
arrowTable = arrowTable.attr("to_arrow")();
return importCache->pyarrow.lib.Table.from_pandas()(arrowTable);
}
if (PyConnection::isPolarsDataframe(arrowTable)) {
return arrowTable.attr("to_arrow")();
}

// Ensure we have a pyarrow table
if (!PyConnection::isPyArrowTable(arrowTable)) {
throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame");
}
return arrowTable;
}

// Export Arrow table to C Data Interface
// First, get the schema
ArrowSchemaWrapper schema;
arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema));
static ExportedArrowTable exportArrowTable(py::object arrowTable) {
arrowTable = normalizeArrowTable(std::move(arrowTable));

ExportedArrowTable exported;
arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&exported.schema));

// Get the batches (arrays)
std::vector<ArrowArrayWrapper> arrays;
py::list batches = arrowTable.attr("to_batches")();
for (auto& batch : batches) {
arrays.emplace_back();
batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&arrays.back()));
exported.arrays.emplace_back();
batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&exported.arrays.back()));
}

// Keep pyarrow producers alive while C++ accesses exported Arrow memory.
py::list keepAlive;
keepAlive.append(arrowTable);
keepAlive.append(batches);
exported.keepAlive.append(arrowTable);
exported.keepAlive.append(batches);
return exported;
}

std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string& tableName,
py::object arrowTable) {
auto& stateRef = refState();
py::gil_scoped_acquire acquire;

auto exported = exportArrowTable(std::move(arrowTable));
auto result = ArrowTableSupport::createViewFromArrowTable(stateRef.ref(), tableName,
std::move(schema), std::move(arrays));
std::move(exported.schema), std::move(exported.arrays));
if (result.queryResult && result.queryResult->isSuccess()) {
stateRef.arrowTableRefs[tableName] = std::move(keepAlive);
stateRef.arrowTableRefs[tableName] = std::move(exported.keepAlive);
}

return checkAndWrapQueryResult(result.queryResult, state);
}

std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::string& tableName,
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName) {
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName,
const std::string& layout, py::object indptrTable) {
auto& stateRef = refState();
py::gil_scoped_acquire acquire;

if (PyConnection::isPandasDataframe(arrowTable)) {
arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable);
} else if (PyConnection::isPolarsDataframe(arrowTable)) {
arrowTable = arrowTable.attr("to_arrow")();
}
if (!PyConnection::isPyArrowTable(arrowTable)) {
throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame");
}

ArrowSchemaWrapper schema;
arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema));
std::vector<ArrowArrayWrapper> arrays;
py::list batches = arrowTable.attr("to_batches")();
for (auto& batch : batches) {
arrays.emplace_back();
batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&arrays.back()));
}
auto layoutUpper = layout;
std::transform(layoutUpper.begin(), layoutUpper.end(), layoutUpper.begin(),
[](unsigned char c) { return static_cast<char>(std::toupper(c)); });

auto exported = exportArrowTable(std::move(arrowTable));
ArrowTableCreationResult result;
py::list keepAlive;
keepAlive.append(arrowTable);
keepAlive.append(batches);
keepAlive.append(exported.keepAlive);

auto result = ArrowTableSupport::createRelTableFromArrowTable(stateRef.ref(), tableName,
srcTableName, dstTableName, std::move(schema), std::move(arrays));
if (layoutUpper == "FLAT") {
if (!py::none().is(indptrTable)) {
throw RuntimeException("indptr_table is only valid for CSR Arrow relationship tables");
}
result = ArrowTableSupport::createRelTableFromArrowTable(stateRef.ref(), tableName,
srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays));
} else if (layoutUpper == "CSR") {
if (py::none().is(indptrTable)) {
throw RuntimeException("indptr_table is required for CSR Arrow relationship tables");
}
auto exportedIndptr = exportArrowTable(std::move(indptrTable));
keepAlive.append(exportedIndptr.keepAlive);
result = ArrowTableSupport::createRelTableFromArrowCSR(stateRef.ref(), tableName,
srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays),
std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays));
} else {
throw RuntimeException("Arrow relationship table layout must be FLAT or CSR");
}
if (result.queryResult && result.queryResult->isSuccess()) {
stateRef.arrowTableRefs[tableName] = std::move(keepAlive);
}
Expand Down
3 changes: 2 additions & 1 deletion src_py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .database import Database # noqa: E402
from .prepared_statement import PreparedStatement # noqa: E402
from .query_result import ArrowQueryResult, CSRResult, QueryResult # noqa: E402
from .types import Type # noqa: E402
from .types import ArrowRelTableLayout, Type # noqa: E402

_VERSION_INFO: tuple[str, int] | None = None

Expand All @@ -81,6 +81,7 @@ def __getattr__(name: str) -> str | int:
__all__ = [
"AsyncConnection",
"ArrowQueryResult",
"ArrowRelTableLayout",
"Connection",
"CSRResult",
"Database",
Expand Down
23 changes: 23 additions & 0 deletions src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ._backend import get_capi_module, get_pybind_module
from .prepared_statement import PreparedStatement
from .query_result import ArrowQueryResult, QueryResult
from .types import ArrowRelTableLayout

if TYPE_CHECKING:
import sys
Expand Down Expand Up @@ -811,6 +812,8 @@ def create_arrow_rel_table(
dataframe: Any,
src_table_name: str,
dst_table_name: str,
layout: ArrowRelTableLayout | str = ArrowRelTableLayout.FLAT,
indptr_dataframe: Any | None = None,
) -> QueryResult:
"""
Create an Arrow memory-backed relationship table from a DataFrame.
Expand All @@ -829,19 +832,37 @@ def create_arrow_rel_table(
dst_table_name : str
Destination node table name in the FROM/TO pair.

layout : ArrowRelTableLayout | str
Relationship layout. FLAT expects ``dataframe`` to contain ``from``
and ``to`` endpoint columns. CSR expects ``dataframe`` to contain a
``to`` destination offset column plus properties, and
``indptr_dataframe`` to contain source offsets.

indptr_dataframe : Any | None
A pandas DataFrame, polars DataFrame, or PyArrow table containing
CSR source offsets. Required when ``layout`` is CSR.

Returns
-------
QueryResult
Result of the table creation query.

"""
self.init_connection()
layout_value = (
layout.value if isinstance(layout, ArrowRelTableLayout) else str(layout)
).upper()
if layout_value == ArrowRelTableLayout.CSR.value and indptr_dataframe is None:
msg = "indptr_dataframe is required when layout is CSR"
raise ValueError(msg)
try:
query_result_internal = self._connection.create_arrow_rel_table(
table_name,
dataframe,
src_table_name,
dst_table_name,
layout_value,
indptr_dataframe,
)
except NotImplementedError:
py_connection = self._get_pybind_connection()
Expand All @@ -853,6 +874,8 @@ def create_arrow_rel_table(
dataframe,
src_table_name,
dst_table_name,
layout_value,
indptr_dataframe,
)
if not query_result_internal.isSuccess():
raise RuntimeError(query_result_internal.getErrorMessage())
Expand Down
7 changes: 7 additions & 0 deletions src_py/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,10 @@ class Type(Enum):
STRUCT = "STRUCT"
MAP = "MAP"
UNION = "UNION"


class ArrowRelTableLayout(Enum):
"""Arrow-backed relationship table layout."""

FLAT = "FLAT"
CSR = "CSR"
Loading
Loading