diff --git a/README.md b/README.md index ae32ecb..5139a9f 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Icebug is a standardized graph format designed for efficient graph data intercha | **icebug-disk** | Parquet files | Object storage, persistence | | **icebug-memory** | Apache Arrow tables | In-process, zero-copy access | -Both represent graphs in [CSR (Compressed Sparse Row)](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)) format, which enables fast adjacency-list traversal. +Both represent *directed* graphs in [CSR (Compressed Sparse Row)](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)) format, which enables fast adjacency-list traversal. --- @@ -74,13 +74,20 @@ Table: follows (FROM user TO user) Convert Arrow tables directly into an in-memory CSR graph ```python -from icebug_format import IcebugMemGraph, convert_arrow_tables_to_csr +from icebug_format import IcebugMemGraph -graph: IcebugMemGraph = convert_arrow_tables_to_csr( +# Directed heterogeneous graph (different node types on each end) +graph: IcebugMemGraph = IcebugMemGraph.from_arrow_tables( from_node_arrow_table=users, # pa.Table, first column is the primary key - to_node_arrow_table=cities, # pa.Table, first column is the primary key rel_arrow_table=livesin, # pa.Table with 'source' and 'target' columns - directed=True, + to_node_arrow_table=cities, # pa.Table, first column is the primary key +) + +# Directed or undirected homogeneous graph (same node type on both ends) +graph: IcebugMemGraph = IcebugMemGraph.from_arrow_tables( + from_node_arrow_table=users, # pa.Table, first column is the primary key + rel_arrow_table=follows, # pa.Table with 'source' and 'target' columns + undirected=True, # undirected=True for undirected (to_node_arrow_table must be omitted) ) # Node tables are passed through unchanged @@ -101,7 +108,12 @@ The `rel_arrow_table` source and target columns are resolved by name in 
priority Any remaining columns are preserved as edge properties in `graph.indices`. -Set `directed=False` to automatically add reverse edges (undirected graph). +Set `undirected=True` to automatically add reverse edges (undirected graph). For undirected graphs, `to_node_arrow_table` must be omitted; the same node table is used for both sides of every edge. + +## Caveats + +- icebug-format always outputs a directed graph. +- To convert an undirected graph, pass `--undirected` to the CLI or `undirected=True` to the Python API; the reverse edges are added automatically. Note that undirected graphs are supported only for rel tables with the same node type on both ends. --- diff --git a/icebug_format/__init__.py b/icebug_format/__init__.py index 60cc79f..7258769 100644 --- a/icebug_format/__init__.py +++ b/icebug_format/__init__.py @@ -1,4 +1,4 @@ from icebug_format.cli import main -from icebug_format.memory import IcebugMemGraph, convert_arrow_tables_to_csr +from icebug_format.memory import IcebugMemGraph -__all__ = ["main", "IcebugMemGraph", "convert_arrow_tables_to_csr"] +__all__ = ["main", "IcebugMemGraph"] diff --git a/icebug_format/cli.py b/icebug_format/cli.py index 256687c..731e070 100644 --- a/icebug_format/cli.py +++ b/icebug_format/cli.py @@ -385,7 +385,7 @@ def create_csr_graph_to_duckdb( source_db_path: str, output_db_path: str, limit_rels: int | None = None, - directed: bool = False, + undirected: bool = False, csr_table_name: str = "csr_graph", node_table: str | None = None, edge_table: str | None = None, @@ -400,7 +400,7 @@ def create_csr_graph_to_duckdb( source_db_path: Path to source DuckDB with edges table output_db_path: Path to output DuckDB for CSR data limit_rels: Limit number of relationships for testing - directed: Whether graph is directed + undirected: Whether graph is undirected csr_table_name: Name of table to store CSR data node_table: Specific node table to use (default: auto-discover) edge_table: Specific edge table to use (default: 
auto-discover) @@ -508,6 +508,17 @@ def create_csr_graph_to_duckdb( src_csr_table = f"{csr_table_name}_{src_table}" dst_csr_table = f"{csr_table_name}_{dst_table}" + # For undirected graphs, from and to node tables must be the same. + if undirected: + if src_table != dst_table: + raise ValueError( + f"Undirected graphs require the same node table on both sides of an " + f"edge, but edge table '{et}' connects '{src_table}' -> '{dst_table}'. " + f"Use --undirected for homogeneous edge tables." + ) + dst_pk = src_pk + dst_csr_table = src_csr_table + # Inline id→csr_index mapping as CTEs — no separate mapping tables needed map_cte = f""" src_map AS ( @@ -536,21 +547,24 @@ def create_csr_graph_to_duckdb( if edge_cols: reverse_cols += ", " + ", ".join(edge_cols) + # Self-loops are not filtered from directed graphs. + # For undirected graphs, the reverse UNION excludes self-loops so + # each self-loop appears exactly once (forward only). join_clause = f""" FROM orig.{et} e JOIN src_map m1 ON e.source = m1.original_node_id - JOIN dst_map m2 ON e.target = m2.original_node_id - WHERE e.source != e.target""" + JOIN dst_map m2 ON e.target = m2.original_node_id""" if limit_rels: limit_per_table = limit_rels // len(edge_tables) - if directed: + if not undirected: rel_query = f""" WITH {map_cte} SELECT {select_cols} {join_clause} LIMIT {limit_per_table} """ else: + # Reverse self-loops using CSR indices (already mapped) rel_query = f""" WITH {map_cte}, limited AS ( @@ -560,9 +574,10 @@ def create_csr_graph_to_duckdb( SELECT * FROM limited UNION ALL SELECT {reverse_cols} FROM limited + WHERE csr_source != csr_target """ else: - if directed: + if not undirected: rel_query = f""" WITH {map_cte} SELECT {select_cols} {join_clause} @@ -573,6 +588,7 @@ def create_csr_graph_to_duckdb( SELECT {select_cols} {join_clause} UNION ALL SELECT {reverse_select_cols} {join_clause} + WHERE e.source != e.target """ con.execute(f"CREATE TABLE relations_{edge_name} AS {rel_query};") @@ -609,9 +625,9 
@@ def create_csr_graph_to_duckdb( # Recreate with leading zero con.execute(f""" CREATE OR REPLACE TABLE {indptr_table} AS - SELECT 0::BIGINT AS ptr + SELECT 0::UBIGINT AS ptr UNION ALL - SELECT ptr::int64 FROM {indptr_table} + SELECT ptr::UBIGINT FROM {indptr_table} ORDER BY ptr; """) @@ -623,7 +639,7 @@ def create_csr_graph_to_duckdb( indices_table = f"{csr_table_name}_indices_{edge_name}" con.execute(f""" CREATE TABLE {indices_table} AS - SELECT csr_target AS target{', ' + ', '.join(edge_cols) if edge_cols else ''} + SELECT csr_target::UBIGINT AS target{', ' + ', '.join(edge_cols) if edge_cols else ''} FROM relations_{edge_name} ORDER BY csr_source, csr_target; """) @@ -714,9 +730,9 @@ def main(): help="Number of edges to use in test mode (default: 50000)", ) parser.add_argument( - "--directed", + "--undirected", action="store_true", - help="Treat graph as directed (default: undirected)", + help="Treat graph as undirected (default: directed)", ) parser.add_argument( "--storage", @@ -757,7 +773,7 @@ def main(): print(f"GraphAr directory: {args.graphar}") print(f"CSR output database: {args.output_db}") print(f"CSR table prefix: {args.csr_table}") - print(f"Directed: {args.directed}") + print(f"Undirected: {args.undirected}") print(f"DuckDB memory limit: {args.memory_limit}") try: @@ -772,7 +788,7 @@ def main(): graphar_dir=args.graphar, output_db_path=args.output_db, csr_table_name=args.csr_table, - directed=args.directed, + undirected=args.undirected, memory_limit=args.memory_limit, ) @@ -796,7 +812,7 @@ def main(): print(f"Source database: {source_db_path}") print(f"CSR output database: {args.output_db}") print(f"CSR table prefix: {args.csr_table}") - print(f"Directed: {args.directed}") + print(f"Undirected: {args.undirected}") print(f"DuckDB memory limit: {args.memory_limit}") # Compute default storage path from output_db if not specified @@ -816,7 +832,7 @@ def main(): source_db_path=source_db_path, output_db_path=args.output_db, limit_rels=test_limit, - 
directed=args.directed, + undirected=args.undirected, csr_table_name=args.csr_table, node_table=args.node_table, edge_table=args.edge_table, diff --git a/icebug_format/graphar.py b/icebug_format/graphar.py index 33316df..c1ca904 100644 --- a/icebug_format/graphar.py +++ b/icebug_format/graphar.py @@ -157,7 +157,7 @@ def convert_graphar_to_graph_std( graphar_dir: str, output_db_path: str, csr_table_name: str = "graph", - directed: bool = False, + undirected: bool = False, memory_limit: str = "80%", ) -> None: """ @@ -167,7 +167,7 @@ def convert_graphar_to_graph_std( graphar_dir: Path to directory with GraphAr data output_db_path: Path to output DuckDB database csr_table_name: Name prefix for CSR tables - directed: Whether graph is directed + undirected: Whether graph is undirected memory_limit: DuckDB memory limit setting """ print("\n=== Converting GraphAr to Graph-Std Format ===") @@ -398,22 +398,22 @@ def convert_graphar_to_graph_std( # Add leading zero temp_table = f"{indptr_table}_temp" con.execute(f"DROP TABLE IF EXISTS {temp_table}") - con.execute(f"CREATE TABLE {temp_table} (ptr BIGINT)") - con.execute(f"INSERT INTO {temp_table} VALUES (CAST(0 AS BIGINT))") - con.execute(f"INSERT INTO {temp_table} SELECT ptr FROM {indptr_table}") + con.execute(f"CREATE TABLE {temp_table} (ptr UBIGINT)") + con.execute(f"INSERT INTO {temp_table} VALUES (CAST(0 AS UBIGINT))") + con.execute(f"INSERT INTO {temp_table} SELECT CAST(ptr AS UBIGINT) FROM {indptr_table}") con.execute(f"DROP TABLE {indptr_table}") con.execute(f"ALTER TABLE {temp_table} RENAME TO {indptr_table}") # Build CSR indices indices_table = f"{csr_table_name}_indices_{edge_type}" - col_defs = "target BIGINT" + col_defs = "target UBIGINT" for prop in prop_cols: col_defs += f", {prop} BIGINT" con.execute(f""" CREATE TABLE {indices_table} AS - SELECT csr_target AS target{', ' + ', '.join(prop_cols) if prop_cols else ''} + SELECT CAST(csr_target AS UBIGINT) AS target{', ' + ', '.join(prop_cols) if prop_cols else ''} 
FROM {rel_table_name} ORDER BY csr_source, csr_target """) @@ -569,9 +569,9 @@ def main(): help="Table name prefix for CSR data (default: graph)", ) parser.add_argument( - "--directed", + "--undirected", action="store_true", - help="Treat graph as directed (default: undirected)", + help="Treat graph as undirected (default: directed)", ) args = parser.parse_args() @@ -580,13 +580,13 @@ def main(): print(f"GraphAr directory: {args.graphar_dir}") print(f"CSR output database: {args.output_db}") print(f"CSR table prefix: {args.csr_table}") - print(f"Directed: {args.directed}") + print(f"Undirected: {args.undirected}") convert_graphar_to_graph_std( graphar_dir=args.graphar_dir, output_db_path=args.output_db, csr_table_name=args.csr_table, - directed=args.directed, + undirected=args.undirected, ) print("\n=== Conversion Completed Successfully! ===") diff --git a/icebug_format/memory.py b/icebug_format/memory.py index 2eed22c..e561171 100644 --- a/icebug_format/memory.py +++ b/icebug_format/memory.py @@ -53,159 +53,182 @@ class IcebugMemGraph: indices: pa.Table indptr: pa.Table + @classmethod + def from_arrow_tables( + cls, + from_node_arrow_table: pa.Table, + rel_arrow_table: pa.Table, + *, + to_node_arrow_table: pa.Table | None = None, + undirected: bool = False, + ) -> "IcebugMemGraph": + """ + Convert node and relationship Arrow tables to an IcebugMemGraph. + + The first column of each node table is treated as the primary key used + to map node IDs to dense 0-based CSR indices. + + The relationship table's source and target columns are resolved by name + in the following priority order, falling back to positional columns: + + - Source: ``source`` → ``src`` → ``from`` → 0th column + - Target: ``target`` → ``destination`` → ``dest`` → ``to`` → 1st column + + Any remaining columns in *rel_arrow_table* are preserved as edge + properties in the *indices* output table. 
+ + For undirected graphs (``undirected=True``), ``to_node_arrow_table`` must + not be provided: the from-node table is used for both sides of every + edge. Providing ``to_node_arrow_table`` while also passing + ``undirected=True`` raises ``ValueError``. + + Args: + from_node_arrow_table: Source node table. + rel_arrow_table: Relationship table. + to_node_arrow_table: Destination node table (directed graphs only). + Defaults to *from_node_arrow_table* when + ``None`` (i.e., homogeneous edges). + undirected: If ``False`` (default), only forward edges are + stored. If ``True``, reverse edges are added + so the graph is treated as undirected. + + Returns: + IcebugMemGraph where *src* and *dest* are the original node tables + and *indices*/*indptr* encode the CSR adjacency structure. + + Raises: + ValueError: If *rel_arrow_table* has fewer than 2 columns. + ValueError: If ``undirected=True`` and *to_node_arrow_table* is + provided (undirected graphs always use a single node + table for both sides). + """ + if undirected and to_node_arrow_table is not None: + raise ValueError( + "to_node_arrow_table must not be provided for undirected graphs; " + "from and to node tables are always the same for undirected edges." + ) -def convert_arrow_tables_to_csr( - from_node_arrow_table: pa.Table, - to_node_arrow_table: pa.Table, - rel_arrow_table: pa.Table, - directed: bool = True, -) -> IcebugMemGraph: - """ - Convert node and relationship Arrow tables to an IcebugMemGraph. - - The first column of each node table is treated as the primary key used - to map node IDs to dense 0-based CSR indices. - - The relationship table's source and target columns are resolved by name - in the following priority order, falling back to positional columns: - - - Source: ``source`` → ``src`` → ``from`` → 0th column - - Target: ``target`` → ``destination`` → ``dest`` → ``to`` → 1st column - - Any remaining columns in *rel_arrow_table* are preserved as edge - properties in the *indices* output table. 
- - Args: - from_node_arrow_table: Source node table. - to_node_arrow_table: Destination node table. - rel_arrow_table: Relationship table. - directed: If True (default), only forward edges are - stored. If False, reverse edges are added - so the graph is treated as undirected. - - Returns: - IcebugMemGraph where *src* and *dest* are the original node tables - passed in unchanged, and *indices*/*indptr* encode the CSR - adjacency structure. - """ - if rel_arrow_table.num_columns < 2: - raise ValueError( - f"rel_arrow_table must have at least 2 columns (source and target), " - f"got {rel_arrow_table.num_columns}" - ) - - src_pk = from_node_arrow_table.schema.names[0] - dst_pk = to_node_arrow_table.schema.names[0] - num_src_nodes = len(from_node_arrow_table) - - src_col, dst_col = _resolve_rel_columns(rel_arrow_table.schema) - edge_cols = [ - c for c in rel_arrow_table.schema.names if c not in (src_col, dst_col) - ] - - select_fwd = "m1.csr_index AS csr_source, m2.csr_index AS csr_target" - select_rev = "m2.csr_index AS csr_source, m1.csr_index AS csr_target" - def q(name: str) -> str: - return '"' + name.replace('"', '""') + '"' - - if edge_cols: - props = ", ".join(f"e.{q(c)}" for c in edge_cols) - select_fwd += f", {props}" - select_rev += f", {props}" - - map_cte = f""" - src_map AS ( - SELECT row_number() OVER () - 1 AS csr_index, - {q(src_pk)} AS original_node_id - FROM from_nodes - ), - dst_map AS ( - SELECT row_number() OVER () - 1 AS csr_index, - {q(dst_pk)} AS original_node_id - FROM to_nodes - ) - """ + if to_node_arrow_table is None: + to_node_arrow_table = from_node_arrow_table - join_clause = f""" - FROM edges e - JOIN src_map m1 ON e.{q(src_col)} = m1.original_node_id - JOIN dst_map m2 ON e.{q(dst_col)} = m2.original_node_id - """ + if rel_arrow_table.num_columns < 2: + raise ValueError( + f"rel_arrow_table must have at least 2 columns (source and target), " + f"got {rel_arrow_table.num_columns}" + ) - if directed: - rel_query = f"WITH {map_cte} SELECT 
{select_fwd} {join_clause}" - else: - # Self-loops appear once (forward only); non-self edges get both directions. - rel_query = f""" - WITH {map_cte} - SELECT {select_fwd} {join_clause} - UNION ALL - SELECT {select_rev} {join_clause} - WHERE e.{q(src_col)} != e.{q(dst_col)} + src_pk = from_node_arrow_table.schema.names[0] + dst_pk = to_node_arrow_table.schema.names[0] + num_src_nodes = len(from_node_arrow_table) + + src_col, dst_col = _resolve_rel_columns(rel_arrow_table.schema) + edge_cols = [ + c for c in rel_arrow_table.schema.names if c not in (src_col, dst_col) + ] + + select_fwd = "m1.csr_index AS csr_source, m2.csr_index AS csr_target" + select_rev = "m2.csr_index AS csr_source, m1.csr_index AS csr_target" + def q(name: str) -> str: + return '"' + name.replace('"', '""') + '"' + + if edge_cols: + props = ", ".join(f"e.{q(c)}" for c in edge_cols) + select_fwd += f", {props}" + select_rev += f", {props}" + + map_cte = f""" + src_map AS ( + SELECT row_number() OVER () - 1 AS csr_index, + {q(src_pk)} AS original_node_id + FROM from_nodes + ), + dst_map AS ( + SELECT row_number() OVER () - 1 AS csr_index, + {q(dst_pk)} AS original_node_id + FROM to_nodes + ) """ - edge_props_select = (", " + ", ".join(q(c) for c in edge_cols)) if edge_cols else "" - - con = duckdb.connect() - try: - con.register("from_nodes", from_node_arrow_table) - con.register("to_nodes", to_node_arrow_table) - con.register("edges", rel_arrow_table) - - con.execute(f"CREATE TABLE relations AS {rel_query}") + join_clause = f""" + FROM edges e + JOIN src_map m1 ON e.{q(src_col)} = m1.original_node_id + JOIN dst_map m2 ON e.{q(dst_col)} = m2.original_node_id + """ - # Build indptr: cumulative degree per source node - con.execute(f""" - CREATE TABLE indptr_table AS - WITH node_range AS ( - SELECT unnest(range(0, {num_src_nodes})) AS node_id - ), - degrees AS ( - SELECT csr_source AS src, COUNT(*) AS deg + if not undirected: + rel_query = f"WITH {map_cte} SELECT {select_fwd} {join_clause}" + else: 
+ # Self-loops appear once (forward only); non-self edges get both directions. + rel_query = f""" + WITH {map_cte} + SELECT {select_fwd} {join_clause} + UNION ALL + SELECT {select_rev} {join_clause} + WHERE e.{q(src_col)} != e.{q(dst_col)} + """ + + edge_props_select = (", " + ", ".join(q(c) for c in edge_cols)) if edge_cols else "" + + con = duckdb.connect() + try: + con.register("from_nodes", from_node_arrow_table) + con.register("to_nodes", to_node_arrow_table) + con.register("edges", rel_arrow_table) + + con.execute(f"CREATE TABLE relations AS {rel_query}") + + # Build indptr: cumulative degree per source node + con.execute(f""" + CREATE TABLE indptr_table AS + WITH node_range AS ( + SELECT unnest(range(0, {num_src_nodes})) AS node_id + ), + degrees AS ( + SELECT csr_source AS src, COUNT(*) AS deg + FROM relations + GROUP BY csr_source + ), + cumulative AS ( + SELECT + node_range.node_id, + COALESCE( + SUM(degrees.deg) OVER ( + ORDER BY node_range.node_id + ROWS UNBOUNDED PRECEDING + ), 0 + ) AS ptr + FROM node_range + LEFT JOIN degrees ON node_range.node_id = degrees.src + ) + SELECT ptr FROM cumulative + ORDER BY node_id + """) + + # Prepend leading zero so indptr[i] = start of node i's adjacency list + con.execute(""" + CREATE OR REPLACE TABLE indptr_table AS + SELECT 0::UINT64 AS ptr + UNION ALL + SELECT ptr::UINT64 FROM indptr_table + ORDER BY ptr + """) + + # Build indices: neighbour list sorted by (source, target) + con.execute(f""" + CREATE TABLE indices_table AS + SELECT csr_target::UINT64 AS target{edge_props_select} FROM relations - GROUP BY csr_source - ), - cumulative AS ( - SELECT - node_range.node_id, - COALESCE( - SUM(degrees.deg) OVER ( - ORDER BY node_range.node_id - ROWS UNBOUNDED PRECEDING - ), 0 - ) AS ptr - FROM node_range - LEFT JOIN degrees ON node_range.node_id = degrees.src - ) - SELECT ptr FROM cumulative - ORDER BY node_id - """) - - # Prepend leading zero so indptr[i] = start of node i's adjacency list - con.execute(""" - CREATE OR 
REPLACE TABLE indptr_table AS - SELECT 0::UINT64 AS ptr - UNION ALL - SELECT ptr::UINT64 FROM indptr_table - ORDER BY ptr - """) - - # Build indices: neighbour list sorted by (source, target) - con.execute(f""" - CREATE TABLE indices_table AS - SELECT csr_target::UINT64 AS target{edge_props_select} - FROM relations - ORDER BY csr_source, csr_target - """) - - indices = con.execute("SELECT * FROM indices_table").arrow().read_all() - indptr = con.execute("SELECT * FROM indptr_table").arrow().read_all() - finally: - con.close() - - return IcebugMemGraph( - src=from_node_arrow_table, - dest=to_node_arrow_table, - indices=indices, - indptr=indptr, - ) + ORDER BY csr_source, csr_target + """) + + indices = con.execute("SELECT * FROM indices_table").arrow().read_all() + indptr = con.execute("SELECT * FROM indptr_table").arrow().read_all() + finally: + con.close() + + return cls( + src=from_node_arrow_table, + dest=to_node_arrow_table, + indices=indices, + indptr=indptr, + ) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..abb424c --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,361 @@ +"""Tests for the icebug-disk converter (create_csr_graph_to_duckdb).""" + +import tempfile +from pathlib import Path + +import duckdb +import pytest + +from icebug_format.cli import create_csr_graph_to_duckdb + +_MEM = "1GB" + + +def _make_source_db(path: str, edges: list[tuple], self_loop: bool = False) -> None: + """Create a minimal source DuckDB with nodes and edges tables.""" + con = duckdb.connect(path) + con.execute("CREATE TABLE nodes (id BIGINT)") + # Collect unique node IDs + node_ids = sorted({n for e in edges for n in e}) + for nid in node_ids: + con.execute(f"INSERT INTO nodes VALUES ({nid})") + + con.execute("CREATE TABLE edges (source BIGINT, target BIGINT)") + for src, dst in edges: + con.execute(f"INSERT INTO edges VALUES ({src}, {dst})") + con.close() + + +def _make_hetero_source_db(path: str) -> None: + """Create a source DuckDB with 
two node types and a heterogeneous edge table.""" + con = duckdb.connect(path) + con.execute("CREATE TABLE nodes_user (id BIGINT)") + con.execute("INSERT INTO nodes_user VALUES (0), (1)") + con.execute("CREATE TABLE nodes_city (id BIGINT)") + con.execute("INSERT INTO nodes_city VALUES (10), (11)") + con.execute("CREATE TABLE edges_livesin (source BIGINT, target BIGINT)") + con.execute("INSERT INTO edges_livesin VALUES (0, 10), (1, 11)") + con.close() + + +def _make_multi_edge_source_db(path: str) -> None: + """Create a source DB with two edge tables: edges_follows and edges_likes.""" + con = duckdb.connect(path) + con.execute("CREATE TABLE nodes (id BIGINT)") + for i in range(4): + con.execute(f"INSERT INTO nodes VALUES ({i})") + con.execute("CREATE TABLE edges_follows (source BIGINT, target BIGINT)") + con.execute("INSERT INTO edges_follows VALUES (0,1),(1,2)") + con.execute("CREATE TABLE edges_likes (source BIGINT, target BIGINT)") + con.execute("INSERT INTO edges_likes VALUES (0,2),(1,3)") + con.close() + + +def _make_multi_node_source_db(path: str) -> None: + """Create a source DB with two node tables and one edge table.""" + con = duckdb.connect(path) + con.execute("CREATE TABLE nodes_user (id BIGINT)") + for i in range(3): + con.execute(f"INSERT INTO nodes_user VALUES ({i})") + con.execute("CREATE TABLE nodes_admin (id BIGINT)") + for i in range(10, 12): + con.execute(f"INSERT INTO nodes_admin VALUES ({i})") + con.execute("CREATE TABLE edges (source BIGINT, target BIGINT)") + con.execute("INSERT INTO edges VALUES (0,1),(1,2)") + con.close() + + +def _parquet_dir(out_path: str) -> Path: + """Return the parquet output directory for a given output_db_path.""" + p = Path(out_path) + return p.parent / p.stem + + +# --------------------------------------------------------------------------- +# Directed graph +# --------------------------------------------------------------------------- + + +def test_directed_basic(): + with tempfile.TemporaryDirectory() as tmpdir: 
+ src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + # 0 -> 1 -> 2 + _make_source_db(src, [(0, 1), (1, 2)]) + create_csr_graph_to_duckdb(src, out, undirected=False, memory_limit=_MEM) + + con = duckdb.connect(out) + indices = con.execute("SELECT target FROM csr_graph_indices_edges ORDER BY rowid").fetchall() + indptr = con.execute("SELECT ptr FROM csr_graph_indptr_edges ORDER BY rowid").fetchall() + con.close() + + assert [r[0] for r in indices] == [1, 2] + assert [r[0] for r in indptr] == [0, 1, 2, 2] + + +def test_csr_columns_are_uint64(): + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + _make_source_db(src, [(0, 1), (1, 2)]) + create_csr_graph_to_duckdb(src, out, undirected=False, memory_limit=_MEM) + + con = duckdb.connect(out) + indices_desc = con.execute("DESCRIBE csr_graph_indices_edges").fetchall() + indptr_desc = con.execute("DESCRIBE csr_graph_indptr_edges").fetchall() + con.close() + + indices_target_type = next(col[1] for col in indices_desc if col[0] == "target") + indptr_ptr_type = next(col[1] for col in indptr_desc if col[0] == "ptr") + assert indices_target_type == "UBIGINT" + assert indptr_ptr_type == "UBIGINT" + + +def test_directed_preserves_self_loops(): + """Self-loops must not be filtered from directed graphs.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + # 0->0 (self-loop) + 0->1 + _make_source_db(src, [(0, 0), (0, 1)]) + create_csr_graph_to_duckdb(src, out, undirected=False, memory_limit=_MEM) + + con = duckdb.connect(out) + indices = con.execute("SELECT target FROM csr_graph_indices_edges ORDER BY target").fetchall() + con.close() + + assert sorted(r[0] for r in indices) == [0, 1] + + +# --------------------------------------------------------------------------- +# Undirected graph +# 
--------------------------------------------------------------------------- + + +def test_undirected_adds_reverse_edges(): + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + # 0 -- 1 + _make_source_db(src, [(0, 1)]) + create_csr_graph_to_duckdb(src, out, undirected=True, memory_limit=_MEM) + + con = duckdb.connect(out) + count = con.execute("SELECT COUNT(*) FROM csr_graph_indices_edges").fetchone()[0] + con.close() + + assert count == 2 # forward + reverse + + +def test_undirected_self_loop_appears_once(): + """Self-loops in an undirected graph must appear exactly once.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + # 0->0 self-loop + 0->1 + _make_source_db(src, [(0, 0), (0, 1)]) + create_csr_graph_to_duckdb(src, out, undirected=True, memory_limit=_MEM) + + con = duckdb.connect(out) + count = con.execute("SELECT COUNT(*) FROM csr_graph_indices_edges").fetchone()[0] + con.close() + + # Edges: 0--0 (once) + 0--1 (forward) + 1--0 (reverse) = 3 + assert count == 3 + + +# --------------------------------------------------------------------------- +# Undirected validation +# --------------------------------------------------------------------------- + + +def test_undirected_heterogeneous_edges_raise(): + """Undirected graphs must not have heterogeneous (bipartite) edge tables.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + _make_hetero_source_db(src) + + schema_path = Path(tmpdir) / "schema.cypher" + schema_path.write_text( + "CREATE REL TABLE livesin(FROM user TO city) WITH (storage='x', format='icebug-disk');\n" + ) + + with pytest.raises(ValueError, match="same node table"): + create_csr_graph_to_duckdb( + src, out, undirected=True, schema_path=str(schema_path), memory_limit=_MEM + ) + + +# 
--------------------------------------------------------------------------- +# limit_rels +# --------------------------------------------------------------------------- + + +def test_limit_rels_caps_edge_count(): + """limit_rels restricts how many edges are stored in the CSR indices table.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + # 10 edges: 0->1, 1->2, ..., 9->10 + _make_source_db(src, [(i, i + 1) for i in range(10)]) + create_csr_graph_to_duckdb(src, out, undirected=False, limit_rels=3, memory_limit=_MEM) + + con = duckdb.connect(out) + count = con.execute("SELECT COUNT(*) FROM csr_graph_indices_edges").fetchone()[0] + con.close() + + assert count <= 3 + + +def test_limit_rels_undirected_adds_reverse_within_limit(): + """For undirected graphs, limit_rels applies to forward edges; reverse are added after.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + # 6 distinct edges: 0-1, 1-2, 2-3, 3-4, 4-5, 5-0 + _make_source_db(src, [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]) + create_csr_graph_to_duckdb(src, out, undirected=True, limit_rels=2, memory_limit=_MEM) + + con = duckdb.connect(out) + count = con.execute("SELECT COUNT(*) FROM csr_graph_indices_edges").fetchone()[0] + con.close() + + # 2 forward edges → 4 total (each gets a reverse) + assert count == 4 + + +# --------------------------------------------------------------------------- +# csr_table_name +# --------------------------------------------------------------------------- + + +def test_csr_table_name_prefix(): + """All output tables should be prefixed with the custom csr_table_name.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + _make_source_db(src, [(0, 1), (1, 2)]) + create_csr_graph_to_duckdb(src, out, undirected=False, 
csr_table_name="mygraph", memory_limit=_MEM) + + con = duckdb.connect(out) + tables = {r[0] for r in con.execute("SHOW TABLES").fetchall()} + con.close() + + assert "mygraph_nodes" in tables + assert "mygraph_indices_edges" in tables + assert "mygraph_indptr_edges" in tables + # Default prefix must NOT appear + assert "csr_graph_indices_edges" not in tables + + +# --------------------------------------------------------------------------- +# node_table / edge_table +# --------------------------------------------------------------------------- + + +def test_node_table_selects_single_table(): + """node_table restricts processing to exactly one node table.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + _make_multi_node_source_db(src) + create_csr_graph_to_duckdb( + src, out, undirected=False, node_table="nodes_user", memory_limit=_MEM + ) + + con = duckdb.connect(out) + tables = {r[0] for r in con.execute("SHOW TABLES").fetchall()} + con.close() + + assert "csr_graph_nodes_user" in tables + assert "csr_graph_nodes_admin" not in tables + + +def test_edge_table_selects_single_table(): + """edge_table restricts processing to exactly one edge table.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + _make_multi_edge_source_db(src) + create_csr_graph_to_duckdb( + src, out, undirected=False, edge_table="edges_follows", memory_limit=_MEM + ) + + con = duckdb.connect(out) + tables = {r[0] for r in con.execute("SHOW TABLES").fetchall()} + con.close() + + assert "csr_graph_indices_follows" in tables + assert "csr_graph_indices_likes" not in tables + + +def test_edge_table_not_found_raises(): + """Specifying a non-existent edge_table should raise ValueError.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + 
_make_source_db(src, [(0, 1)]) + with pytest.raises(ValueError, match="No edge tables found"): + create_csr_graph_to_duckdb( + src, out, undirected=False, edge_table="edges_nonexistent", memory_limit=_MEM + ) + + +# --------------------------------------------------------------------------- +# schema_path +# --------------------------------------------------------------------------- + + +def test_schema_path_maps_from_to_node_types(): + """schema_path controls which node types appear in FROM/TO of the output schema.cypher.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + _make_hetero_source_db(src) + + schema_path = Path(tmpdir) / "in_schema.cypher" + schema_path.write_text( + "CREATE REL TABLE livesin(FROM user TO city) WITH (storage='x', format='icebug-disk');\n" + ) + + create_csr_graph_to_duckdb( + src, out, undirected=False, schema_path=str(schema_path), memory_limit=_MEM + ) + + out_schema = (_parquet_dir(out) / "schema.cypher").read_text() + assert "FROM user TO city" in out_schema + + +# --------------------------------------------------------------------------- +# storage_path +# --------------------------------------------------------------------------- + + +def test_storage_path_appears_in_schema_cypher(): + """Custom storage_path should appear in the WITH clause of the output schema.cypher.""" + with tempfile.TemporaryDirectory() as tmpdir: + src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + _make_source_db(src, [(0, 1), (1, 2)]) + + create_csr_graph_to_duckdb( + src, out, undirected=False, storage_path="./my_custom_store", memory_limit=_MEM + ) + + out_schema = (_parquet_dir(out) / "schema.cypher").read_text() + assert "./my_custom_store" in out_schema + + +def test_storage_path_default_uses_output_stem(): + """When storage_path is omitted the output DB stem is used as the default.""" + with tempfile.TemporaryDirectory() as tmpdir: + 
src = str(Path(tmpdir) / "src.duckdb") + out = str(Path(tmpdir) / "out.duckdb") + _make_source_db(src, [(0, 1)]) + + create_csr_graph_to_duckdb(src, out, undirected=False, memory_limit=_MEM) + + out_schema = (_parquet_dir(out) / "schema.cypher").read_text() + # Default storage_path is "./out" (stem of out.duckdb) + assert "./out" in out_schema diff --git a/tests/test_memory.py b/tests/test_memory.py index 9d190a1..1b7122d 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -1,9 +1,9 @@ -"""Tests for the icebug-memory converter (convert_arrow_tables_to_csr).""" +"""Tests for the icebug-memory converter (IcebugMemGraph.from_arrow_tables).""" import pyarrow as pa import pytest -from icebug_format.memory import IcebugMemGraph, convert_arrow_tables_to_csr +from icebug_format.memory import IcebugMemGraph def _nodes(*ids, pk="id"): @@ -27,7 +27,7 @@ def _rels(sources, targets, src_col="source", dst_col="destination", **props): def test_returns_icebug_mem_graph(): nodes = _nodes(0, 1) rels = _rels([0], [1]) - result = convert_arrow_tables_to_csr(nodes, nodes, rels) + result = IcebugMemGraph.from_arrow_tables(nodes, rels) assert isinstance(result, IcebugMemGraph) @@ -39,14 +39,14 @@ def test_returns_icebug_mem_graph(): def test_indices_target_column_is_uint64(): nodes = _nodes(0, 1, 2) rels = _rels([0, 1], [1, 2]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert g.indices.schema.field("target").type == pa.uint64() def test_indptr_ptr_column_is_uint64(): nodes = _nodes(0, 1, 2) rels = _rels([0, 1], [1, 2]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert g.indptr.schema.field("ptr").type == pa.uint64() @@ -59,7 +59,7 @@ def test_directed_linear_chain(): # Graph: 0 -> 1 -> 2 nodes = _nodes(0, 1, 2) rels = _rels([0, 1], [1, 2]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) # 
indices: neighbour list in source order assert g.indices["target"].to_pylist() == [1, 2] @@ -71,7 +71,7 @@ def test_directed_fan_out(): # Graph: 0 -> 1, 0 -> 2, 0 -> 3 nodes = _nodes(0, 1, 2, 3) rels = _rels([0, 0, 0], [1, 2, 3]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert sorted(g.indices["target"].to_pylist()) == [1, 2, 3] assert g.indptr["ptr"].to_pylist() == [0, 3, 3, 3, 3] @@ -80,31 +80,40 @@ def test_directed_fan_out(): def test_indptr_length_equals_num_src_nodes_plus_one(): nodes = _nodes(0, 1, 2, 3, 4) rels = _rels([0, 2], [1, 3]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert len(g.indptr) == len(nodes) + 1 def test_indptr_starts_with_zero(): nodes = _nodes(0, 1, 2) rels = _rels([0], [1]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert g.indptr["ptr"][0].as_py() == 0 def test_indptr_ends_with_edge_count(): nodes = _nodes(0, 1, 2) rels = _rels([0, 1], [1, 2]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert g.indptr["ptr"][-1].as_py() == 2 def test_indices_length_equals_edge_count(): nodes = _nodes(0, 1, 2) rels = _rels([0, 1], [1, 2]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert len(g.indices) == 2 +def test_directed_preserves_self_loops(): + # Self-loops must not be filtered in directed graphs. 
+ nodes = _nodes(0, 1) + rels = _rels([0, 0], [0, 1]) # 0->0 self-loop + 0->1 + g = IcebugMemGraph.from_arrow_tables(nodes, rels) + assert len(g.indices) == 2 + assert sorted(g.indices["target"].to_pylist()) == [0, 1] + + # --------------------------------------------------------------------------- # Undirected graph # --------------------------------------------------------------------------- @@ -114,7 +123,7 @@ def test_undirected_adds_reverse_edges(): # Graph: 0 -- 1 nodes = _nodes(0, 1) rels = _rels([0], [1]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels, directed=False) + g = IcebugMemGraph.from_arrow_tables(nodes, rels, undirected=True) assert len(g.indices) == 2 targets = sorted(g.indices["target"].to_pylist()) @@ -125,7 +134,7 @@ def test_undirected_indptr_reflects_bidirectional_degree(): # 0 -- 1 -- 2 nodes = _nodes(0, 1, 2) rels = _rels([0, 1], [1, 2]) - g = convert_arrow_tables_to_csr(nodes, nodes, rels, directed=False) + g = IcebugMemGraph.from_arrow_tables(nodes, rels, undirected=True) ptr = g.indptr["ptr"].to_pylist() # node 0: 1 neighbour, node 1: 2 neighbours, node 2: 1 neighbour @@ -140,7 +149,7 @@ def test_undirected_indptr_reflects_bidirectional_degree(): def test_self_loops_appear_once_in_undirected_graph(): nodes = _nodes(0, 1) rels = _rels([0, 0], [0, 1]) # 0->0 self-loop + 0->1 - g = convert_arrow_tables_to_csr(nodes, nodes, rels, directed=False) + g = IcebugMemGraph.from_arrow_tables(nodes, rels, undirected=True) # 0->0 (once), 0->1, 1->0 → 3 entries total assert len(g.indices) == 3 @@ -167,7 +176,7 @@ def test_edge_properties_preserved_in_indices(): "weight": weight, } ) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert "weight" in g.indices.schema.names assert g.indices["weight"].to_pylist() == pytest.approx([0.5, 1.5]) @@ -182,7 +191,7 @@ def test_edge_properties_not_in_indptr(): "weight": pa.array([1.0], type=pa.float32()), } ) - g = convert_arrow_tables_to_csr(nodes, 
nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert g.indptr.schema.names == ["ptr"] @@ -195,7 +204,7 @@ def test_edge_properties_not_in_indptr(): def test_source_column_aliases(src_col): nodes = _nodes(0, 1) rels = _rels([0], [1], src_col=src_col) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert len(g.indices) == 1 @@ -203,7 +212,7 @@ def test_source_column_aliases(src_col): def test_target_column_aliases(dst_col): nodes = _nodes(0, 1) rels = _rels([0], [1], dst_col=dst_col) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert len(g.indices) == 1 @@ -216,11 +225,19 @@ def test_src_and_dest_tables_are_passed_through(): src_nodes = pa.table({"id": pa.array([10, 20], type=pa.int64()), "label": pa.array(["a", "b"])}) dst_nodes = pa.table({"id": pa.array([10, 20], type=pa.int64()), "label": pa.array(["c", "d"])}) rels = _rels([10], [20]) - g = convert_arrow_tables_to_csr(src_nodes, dst_nodes, rels) + g = IcebugMemGraph.from_arrow_tables(src_nodes, rels, to_node_arrow_table=dst_nodes) assert g.src.equals(src_nodes) assert g.dest.equals(dst_nodes) +def test_omitting_to_node_table_uses_from_node_for_both(): + nodes = _nodes(0, 1) + rels = _rels([0], [1]) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) + assert g.src.equals(nodes) + assert g.dest.equals(nodes) + + # --------------------------------------------------------------------------- # Input validation # --------------------------------------------------------------------------- @@ -230,7 +247,14 @@ def test_rel_table_with_fewer_than_two_columns_raises(): nodes = _nodes(0, 1) bad_rels = pa.table({"source": pa.array([0], type=pa.int64())}) with pytest.raises(ValueError, match="at least 2 columns"): - convert_arrow_tables_to_csr(nodes, nodes, bad_rels) + IcebugMemGraph.from_arrow_tables(nodes, bad_rels) + + +def test_undirected_with_to_node_table_raises(): + nodes = 
_nodes(0, 1) + rels = _rels([0], [1]) + with pytest.raises(ValueError, match="to_node_arrow_table must not be provided"): + IcebugMemGraph.from_arrow_tables(nodes, rels, to_node_arrow_table=nodes, undirected=True) def test_column_names_with_spaces_are_handled(): @@ -239,14 +263,14 @@ def test_column_names_with_spaces_are_handled(): "source": pa.array([0], type=pa.int64()), "destination": pa.array([1], type=pa.int64()), }) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert len(g.indices) == 1 def test_empty_edges_produces_zero_indptr(): nodes = _nodes(0, 1, 2) rels = _rels([], []) - g = convert_arrow_tables_to_csr(nodes, nodes, rels) + g = IcebugMemGraph.from_arrow_tables(nodes, rels) assert len(g.indices) == 0 assert g.indptr["ptr"].to_pylist() == [0, 0, 0, 0]