Skip to content

Commit

Permalink
Adds db iter and graph_db -> edge list tool
Browse files Browse the repository at this point in the history
  • Loading branch information
JSybrandt committed May 18, 2020
1 parent a6024df commit 7ef2769
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 1 deletion.
13 changes: 13 additions & 0 deletions agatha/util/misc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@


def iter_to_batches(iterable, batch_size):
"""
Chunks the input iterable into fixed-sized batches.
Example:
```python3
list(iter_to_batches(range(10), 3))
[
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[9]
]
```
"""
args = [iter(iterable)] * batch_size
for batch in zip_longest(*args):
yield list(filter(lambda b: b is not None, batch))
Expand Down
48 changes: 47 additions & 1 deletion agatha/util/sqlite3_lookup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sqlite3
from pathlib import Path
import json
from typing import List, Any, Set, Optional
from typing import List, Any, Set, Optional, Tuple
import dask.bag as dbag
from agatha.util.misc_util import Record
import os
Expand Down Expand Up @@ -400,6 +400,52 @@ def __len__(self)->int:
).fetchone()[0]
return self._len

class _CursorIterator:
"""
Iterates a cursor.
Expects cursor to return key, value pairs
"""
def __init__(self, cursor):
self.cursor = cursor

def __next__(self)->Tuple[str, Any]:
key_value = next(self.cursor)
if key_value is None:
return None
assert len(key_value) == 2, \
"_CursorIterator's cursor must produce key-value pairs."
key, val = key_value
val = json.loads(val)
return key, val

def __iter__(self):
return self

def iterate(self, where:Optional[str]=None):
"""
Returns an iterator to the underlying database.
If `where` is specified, returned rows will be conditioned.
Note, when writing a `where` clause that columns are `key` and `value`
"""
assert self.connected(), "Attempting to operate on closed db."
query_stmt = f"""
SELECT
{self.key_column_name} as key,
{self.value_column_name} as value
FROM {self.table_name}
"""
if where is not None:
query_stmt += f"""
WHERE {where}
"""
return Sqlite3LookupTable._CursorIterator(self._cursor.execute(query_stmt))

def __iter__(self):
"""
Returns an iterator that enables us to loop through the entire database.
"""
return self.iterate()


################################################################################
## SPECIAL CASES ###############################################################
Expand Down
26 changes: 26 additions & 0 deletions agatha/util/test_sqlite3_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,29 @@ def test_len():
db_path = make_sqlite3_db("test_len", expected)
table = Sqlite3LookupTable(db_path)
assert len(table) == len(expected)

def test_iter():
expected = {
"A": ["B", "C"],
"B": ["A"],
"C": ["A"],
}
db_path = make_sqlite3_db("test_iter", expected)
table = Sqlite3LookupTable(db_path)
actual = {k: v for k, v in table}
assert actual == expected

def test_iter_where():
db_data = {
"AA": ["B", "C"],
"BBBB": ["A"],
"CC": ["A"],
}
db_path = make_sqlite3_db("test_iter_where", db_data)
table = Sqlite3LookupTable(db_path)
actual = {k: v for k, v in table.iterate(where="length(key) = 2")}
expected= {
"AA": ["B", "C"],
"CC": ["A"],
}
assert actual == expected
81 changes: 81 additions & 0 deletions tools/py_scripts/sqlite_graph_to_edge_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
"""
Extracts Graph JSON files from an Sqlite3 File.
This is required to perform large-scale graph operations on a graph that has
already been collected into an sqlite3 database. Typically, we discard the
intermediate graph files that are used to create these large graph databases.
"""

from pathlib import Path
from agatha.util.sqlite3_lookup import Sqlite3Graph
from agatha.util.misc_util import iter_to_batches
from fire import Fire
import json
from tqdm import tqdm


def main(
input_db:Path,
output_dir:Path,
nodes_per_file:int=1e6,
output_file_fmt_str:str="{:08d}.json",
disable_pbar:bool=False
):
"""Sqlite3Graph -> Edge Json
Args:
input_db: A graph sqlite3 table.
output_dir: The location of a directory that we are going to make and fill
with json files.
nodes_per_file: Each file is going to contain at most this number of nodes.
output_file_fmt_str: This string will be called with `.format(int)` for
each output file. Must produce unique names for each string.
"""

input_db = Path(input_db)
assert input_db.is_file(), f"Failed to find {input_db}"

output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
assert output_dir.is_dir(), f"Failed to find dir: {output_dir}"
assert len(list(output_dir.iterdir())) == 0, f"{output_dir} is not empty."

nodes_per_file = int(nodes_per_file)
assert nodes_per_file > 0, \
"Must supply positive number of edges per output file."

try:
format_test = output_file_fmt_str.format(1)
except Exception:
assert False, "output_file_fmt_str must contain format component `{}`"

print("Opening", input_db)
graph = Sqlite3Graph(input_db)
if not disable_pbar:
num_nodes = len(graph)
graph = tqdm(graph, total=num_nodes)
graph.set_description("Reading edges")

for batch_idx, edge_batch in enumerate(
iter_to_batches(graph, nodes_per_file)
):
file_path = output_dir.joinpath(output_file_fmt_str.format(batch_idx+1))
if disable_pbar:
print(file_path)
else:
graph.set_description(f"Writing {file_path}")
assert not file_path.exists(), \
f"Error: {file_path} already exists. Is your format string bad?"
with open(file_path, 'w') as json_file:
for node, neighbors in edge_batch:
for neigh in neighbors:
json_file.write(json.dumps({"key": node, "value": neigh}))
json_file.write("\n")
if not disable_pbar:
graph.set_description("Reading edges")
print("Done!")

if __name__ == "__main__":
Fire(main)

0 comments on commit 7ef2769

Please sign in to comment.