Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src_cpp/py_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ void PyDatabase::initialize(py::handle& m) {
bool, bool, bool>(),
py::arg("database_path"), py::arg("buffer_pool_size") = 0,
py::arg("max_num_threads") = 0, py::arg("compression") = true,
py::arg("read_only") = false, py::arg("max_db_size") = (uint64_t)1 << 43,
py::arg("read_only") = false, py::arg("max_db_size") = -1u,
py::arg("auto_checkpoint") = true, py::arg("checkpoint_threshold") = -1,
py::arg("throw_on_wal_replay_failure") = true, py::arg("enable_checksums") = true,
py::arg("enable_multi_writes") = false)
Expand Down
7 changes: 5 additions & 2 deletions src_py/_lbug_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,7 @@ def __init__(
max_num_threads: int = 0,
compression: bool = True,
read_only: bool = False,
max_db_size: int = (1 << 30),
max_db_size: int | None = None,
auto_checkpoint: bool = True,
checkpoint_threshold: int = -1,
throw_on_wal_replay_failure: bool = True,
Expand All @@ -1208,7 +1208,10 @@ def __init__(
config.max_num_threads = max_num_threads
config.enable_compression = compression
config.read_only = read_only
config.max_db_size = max_db_size

if max_db_size is not None:
config.max_db_size = max_db_size

config.auto_checkpoint = auto_checkpoint
if checkpoint_threshold >= 0:
config.checkpoint_threshold = checkpoint_threshold
Expand Down
40 changes: 22 additions & 18 deletions src_py/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
compression: bool = True,
lazy_init: bool = False,
read_only: bool = False,
max_db_size: int = (1 << 30),
max_db_size: int | None = None,
auto_checkpoint: bool = True,
checkpoint_threshold: int = -1,
throw_on_wal_replay_failure: bool = True,
Expand Down Expand Up @@ -77,12 +77,11 @@ def __init__(
database path.
Default to False.

max_db_size : int
max_db_size : int, optional
The maximum size of the database in bytes. Note that this was introduced
as a temporary workaround for the default 8TB mmap address
space limit some environment. This will be removed once we implemente
a better solution later. The value is default to 1 << 43 (8TB) under 64-bit
environment and 1GB under 32-bit one.
space limit in some environments. This will be removed once we implement
a better solution. If not specified, the backend's default is used.

auto_checkpoint: bool
If true, the database will automatically checkpoint when the size of
Expand Down Expand Up @@ -242,19 +241,24 @@ def init_pybind_database(self) -> Any | None:
if pybind_module is None:
return None
if self._pybind_database is None:
self._pybind_database = pybind_module.Database(
self.database_path,
self.buffer_pool_size,
self.max_num_threads,
self.compression,
self.read_only,
self.max_db_size,
self.auto_checkpoint,
self.checkpoint_threshold,
self.throw_on_wal_replay_failure,
self.enable_checksums,
self.enable_multi_writes,
)
kwargs = {
"database_path": self.database_path,
"buffer_pool_size": self.buffer_pool_size,
"max_num_threads": self.max_num_threads,
"compression": self.compression,
"read_only": self.read_only,
"auto_checkpoint": self.auto_checkpoint,
"checkpoint_threshold": self.checkpoint_threshold,
"throw_on_wal_replay_failure": self.throw_on_wal_replay_failure,
"enable_checksums": self.enable_checksums,
"enable_multi_writes": self.enable_multi_writes,
}

if self.max_db_size is not None:
kwargs["max_db_size"] = self.max_db_size

self._pybind_database = pybind_module.Database(**kwargs)

return self._pybind_database

def get_torch_geometric_remote_backend(
Expand Down
21 changes: 19 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ def init_movie_serial(conn: lb.Connection) -> None:


# 256 MiB, passed as buffer_pool_size when the test fixtures open a database
_POOL_SIZE_: int = 256 * 1024 * 1024
# Use 1GB max DB size for tests to avoid exhausting virtual address space
# when many databases are open simultaneously (CI runners may have tight VA limits)
_MAX_DB_SIZE_: int = 1 << 30
Comment thread
aheev marked this conversation as resolved.


def get_db_file_path(tmp_path: Path) -> Path:
Expand Down Expand Up @@ -228,7 +231,12 @@ def _close_cached_readonly_state() -> None:

def create_conn_db(path: Path, *, read_only: bool) -> ConnDB:
    """
    Return a new connection and database.

    Parameters
    ----------
    path : Path
        Location of the database to open.
    read_only : bool
        Whether the database is opened in read-only mode.

    Returns
    -------
    ConnDB
        The ``(connection, database)`` pair; the connection uses 4 threads.
    """
    # Open the database exactly once, capped at the test-wide max_db_size so
    # each open database reserves only a small slice of virtual address space.
    db = lb.Database(
        path,
        buffer_pool_size=_POOL_SIZE_,
        read_only=read_only,
        max_db_size=_MAX_DB_SIZE_,
    )
    conn = lb.Connection(db, num_threads=4)
    return conn, db

Expand Down Expand Up @@ -290,11 +298,20 @@ def conn_db_empty(tmp_path: Path) -> ConnDB:
db.close()


@pytest.fixture(scope="session")
def max_db_size() -> int:
    """Expose the test-wide database size cap (``_MAX_DB_SIZE_``) to tests."""
    size_cap: int = _MAX_DB_SIZE_
    return size_cap


@pytest.fixture
def conn_db_in_mem() -> ConnDB:
"""Return a new in-memory connection and database."""
db = lb.Database(
database_path=":memory:", buffer_pool_size=_POOL_SIZE_, read_only=False
database_path=":memory:",
buffer_pool_size=_POOL_SIZE_,
read_only=False,
max_db_size=_MAX_DB_SIZE_,
)
conn = lb.Connection(db, num_threads=4)
try:
Expand Down
9 changes: 0 additions & 9 deletions test/test_iteration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import ladybug as lb
from type_aliases import ConnDB


Expand Down Expand Up @@ -35,11 +34,7 @@ def test_iteration_loop(conn_db_in_mem: ConnDB) -> None:

def test_get_all(conn_db_in_mem: ConnDB) -> None:
conn, _ = conn_db_in_mem
db = lb.Database(database_path=":memory:")
assert not db.is_closed
assert db._database is not None

conn = lb.Connection(db)
conn.execute("CREATE NODE TABLE person(name STRING, age INT64, PRIMARY KEY(name));")
conn.execute("CREATE (:person {name: 'Alice', age: 30});")
conn.execute("CREATE (:person {name: 'Bob', age: 40});")
Expand All @@ -54,11 +49,7 @@ def test_get_all(conn_db_in_mem: ConnDB) -> None:

def test_get_n(conn_db_in_mem: ConnDB) -> None:
conn, _ = conn_db_in_mem
db = lb.Database(database_path=":memory:")
assert not db.is_closed
assert db._database is not None

conn = lb.Connection(db)
conn.execute("CREATE NODE TABLE person(name STRING, age INT64, PRIMARY KEY(name));")
conn.execute("CREATE (:person {name: 'Alice', age: 30});")
conn.execute("CREATE (:person {name: 'Bob', age: 40});")
Expand Down
18 changes: 13 additions & 5 deletions test/test_mvcc_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,16 +240,21 @@ def run_bank_test(
n_readers: int,
duration: int,
enable_multi_writes: bool,
max_db_size: int,
seed: int = 42,
) -> Stats:
rng = random.Random(seed)
edges = build_edges(N_ACCOUNTS, EDGE_PROB, rng)

try:
db = lb.Database(str(db_path), enable_multi_writes=enable_multi_writes)
db = lb.Database(
str(db_path),
enable_multi_writes=enable_multi_writes,
max_db_size=max_db_size,
)
except TypeError:
# Fallback if binding patch is not applied
db = lb.Database(str(db_path))
db = lb.Database(str(db_path), max_db_size=max_db_size)

setup_db(db, N_ACCOUNTS, edges)

Expand Down Expand Up @@ -287,20 +292,21 @@ def run_bank_test(
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_single_writer_no_anomalies(tmp_path: Path, max_db_size: int) -> None:
    """
    Baseline: single writer, no concurrent write transactions.

    Runs the bank workload with one writer and two readers, multi-writes
    disabled, and the database size capped by the ``max_db_size`` fixture.
    The run must report no anomalies and no failed reads.
    """
    stats = run_bank_test(
        tmp_path / "bank_single.lbdb",
        n_writers=1,
        n_readers=2,
        duration=DURATION_SINGLE_WRITER,
        enable_multi_writes=False,
        max_db_size=max_db_size,
    )
    assert stats.anomalies == [], f"MVCC anomalies detected: {stats.anomalies}"
    assert stats.reads_failed == 0, f"Reader errors: {stats.reads_failed}"


def test_multi_writer_no_anomalies(tmp_path: Path) -> None:
def test_multi_writer_no_anomalies(tmp_path: Path, max_db_size: int) -> None:
"""
Four concurrent writers with enable_multi_writes=True.

Expand All @@ -313,6 +319,7 @@ def test_multi_writer_no_anomalies(tmp_path: Path) -> None:
n_readers=2,
duration=DURATION_MULTI_WRITER,
enable_multi_writes=True,
max_db_size=max_db_size,
)
assert stats.anomalies == [], f"MVCC anomalies detected: {stats.anomalies}"
assert stats.reads_failed == 0, f"Reader errors: {stats.reads_failed}"
Expand All @@ -321,7 +328,7 @@ def test_multi_writer_no_anomalies(tmp_path: Path) -> None:


@pytest.mark.slow
def test_multi_writer_stress_no_anomalies(tmp_path: Path) -> None:
def test_multi_writer_stress_no_anomalies(tmp_path: Path, max_db_size: int) -> None:
"""
Stress: 8 writers / 4 readers for 60 s (matches adsharma README example).

Expand All @@ -333,5 +340,6 @@ def test_multi_writer_stress_no_anomalies(tmp_path: Path) -> None:
n_readers=4,
duration=60,
enable_multi_writes=True,
max_db_size=max_db_size,
)
assert stats.anomalies == [], f"MVCC anomalies detected: {stats.anomalies}"
26 changes: 16 additions & 10 deletions test/test_wal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,24 @@
from conftest import get_db_file_path


def run_query_in_new_process(
    tmp_path: Path, build_dir: Path, queries: str, max_db_size: int
) -> subprocess.Popen:
    """
    Start a child Python interpreter that opens the test database and runs ``queries``.

    Parameters
    ----------
    tmp_path : Path
        Directory from which the database file path is derived.
    build_dir : Path
        Directory appended to ``sys.path`` in the child so ``ladybug`` can be imported.
    queries : str
        Python source appended after the preamble; it can use ``lb`` and ``db``.
    max_db_size : int
        Value passed as ``max_db_size`` when the child opens the database.

    Returns
    -------
    subprocess.Popen
        Handle to the running child process; the caller waits for or kills it.
    """
    db_path = get_db_file_path(tmp_path)
    # The preamble opens the database exactly once, with the size cap applied,
    # and leaves `lb` and `db` defined for the appended queries.
    code = dedent(f"""
        import sys
        sys.path.append(r"{build_dir!s}")

        import ladybug as lb
        db = lb.Database(r"{db_path!s}", max_db_size={max_db_size})
    """) + queries
    return subprocess.Popen([sys.executable, "-c", code])


def run_query_then_kill(tmp_path: Path, build_dir: Path, queries: str):
proc = run_query_in_new_process(tmp_path, build_dir, queries)
def run_query_then_kill(
tmp_path: Path, build_dir: Path, queries: str, max_db_size: int
):
proc = run_query_in_new_process(tmp_path, build_dir, queries, max_db_size)
time.sleep(5)
proc.kill()
proc.wait(5)
Expand All @@ -32,15 +36,15 @@ def run_query_then_kill(tmp_path: Path, build_dir: Path, queries: str):

# Kill the database while it's in the middle of executing a long persistent query
# When we reload the database we will replay from the WAL (which will be incomplete)
def test_replay_after_kill(tmp_path: Path, build_dir: Path) -> None:
def test_replay_after_kill(tmp_path: Path, build_dir: Path, max_db_size: int) -> None:
queries = dedent("""
conn = lb.Connection(db)
conn.execute("CREATE NODE TABLE tab (id INT64, PRIMARY KEY (id));")
conn.execute("UNWIND RANGE(1,100000) AS x UNWIND RANGE(1, 100000) AS y CREATE (:tab {id: x * 100000 + y});")
""")
run_query_then_kill(tmp_path, build_dir, queries)
run_query_then_kill(tmp_path, build_dir, queries, max_db_size)
db_path = get_db_file_path(tmp_path)
with lb.Database(db_path) as db, lb.Connection(db) as conn:
with lb.Database(db_path, max_db_size=max_db_size) as db, lb.Connection(db) as conn:
# previously committed queries should be valid after replaying WAL
result = conn.execute("CALL show_tables() RETURN *")
assert result.has_next()
Expand All @@ -49,7 +53,9 @@ def test_replay_after_kill(tmp_path: Path, build_dir: Path) -> None:
result.close()


def test_replay_with_exception(tmp_path: Path, build_dir: Path) -> None:
def test_replay_with_exception(
tmp_path: Path, build_dir: Path, max_db_size: int
) -> None:
queries = dedent("""
conn = lb.Connection(db)
conn.execute("CREATE NODE TABLE tab (id INT64, PRIMARY KEY (id));")
Expand All @@ -62,9 +68,9 @@ def test_replay_with_exception(tmp_path: Path, build_dir: Path) -> None:
assert i % 2 == 1
conn.execute("UNWIND RANGE(1,100000) AS x UNWIND RANGE(1, 100000) AS y CREATE (:tab {id: x * 100000 + y});")
""")
run_query_then_kill(tmp_path, build_dir, queries)
run_query_then_kill(tmp_path, build_dir, queries, max_db_size)
db_path = get_db_file_path(tmp_path)
with lb.Database(db_path) as db, lb.Connection(db) as conn:
with lb.Database(db_path, max_db_size=max_db_size) as db, lb.Connection(db) as conn:
# previously committed queries should be valid after replaying WAL
result = conn.execute("match (t:tab) where t.id <= 5 return t.id")
assert result.get_num_tuples() == 5
Expand Down
Loading