From 18d77d05fe1a185e65cf815c244d47f15b0dfac9 Mon Sep 17 00:00:00 2001 From: Donghao Ren Date: Wed, 8 Oct 2025 10:32:24 -0700 Subject: [PATCH] feat: default to server-side DuckDB --- packages/backend/embedding_atlas/cli.py | 2 +- packages/backend/embedding_atlas/server.py | 26 +++++++++++++--------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/packages/backend/embedding_atlas/cli.py b/packages/backend/embedding_atlas/cli.py index 8f08516..abdf3f5 100644 --- a/packages/backend/embedding_atlas/cli.py +++ b/packages/backend/embedding_atlas/cli.py @@ -177,7 +177,7 @@ def find_available_port(start_port: int, max_attempts: int = 10, host="localhost @click.option( "--duckdb", type=str, - default="wasm", + default="server", help="DuckDB connection mode: 'wasm' (run in browser), 'server' (run on this server), or URI (e.g., 'ws://localhost:3000').", ) @click.option( diff --git a/packages/backend/embedding_atlas/server.py b/packages/backend/embedding_atlas/server.py index 16b7c33..ce80189 100644 --- a/packages/backend/embedding_atlas/server.py +++ b/packages/backend/embedding_atlas/server.py @@ -79,20 +79,16 @@ async def make_archive(): data = data_source.make_archive(static_path) return Response(content=data, media_type="application/zip") - # Database connection - - @lru_cache(maxsize=1) - def get_connection(): - con = duckdb.connect(":memory:") - df = data_source.dataset - _ = df - con.sql("CREATE TABLE dataset AS (SELECT * FROM df)") - return con + if duckdb_uri == "server": + duckdb_connection = make_duckdb_connection(data_source.dataset) + else: + duckdb_connection = None def handle_query(query: dict): + assert duckdb_connection is not None sql = query["sql"] command = query["type"] - with get_connection().cursor() as cursor: + with duckdb_connection.cursor() as cursor: try: result = cursor.execute(sql) if command == "exec": @@ -111,6 +107,7 @@ def handle_query(query: dict): return JSONResponse({"error": str(e)}, status_code=500) def handle_selection(query: dict): + assert duckdb_connection is not None predicate = query.get("predicate", None) format = query["format"] formats = { @@ -119,7 +116,7 @@ def handle_selection(query: dict): "csv": "(FORMAT CSV)", "parquet": "(FORMAT parquet)", } - with get_connection().cursor() as cursor: + with duckdb_connection.cursor() as cursor: filename = ".selection-" + str(uuid.uuid4()) + ".tmp" try: if predicate is not None: @@ -172,6 +169,13 @@ async def post_selection(req: Request): return app +def make_duckdb_connection(df): + con = duckdb.connect(":memory:") + _ = df # used in the query + con.sql("CREATE TABLE dataset AS (SELECT * FROM df)") + return con + + def parse_range_header(request: Request, content_length: int): value = request.headers.get("Range") if value is not None: