Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/backend/embedding_atlas/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def find_available_port(start_port: int, max_attempts: int = 10, host="localhost
@click.option(
"--duckdb",
type=str,
default="wasm",
default="server",
help="DuckDB connection mode: 'wasm' (run in browser), 'server' (run on this server), or URI (e.g., 'ws://localhost:3000').",
)
@click.option(
Expand Down
26 changes: 15 additions & 11 deletions packages/backend/embedding_atlas/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,16 @@ async def make_archive():
data = data_source.make_archive(static_path)
return Response(content=data, media_type="application/zip")

# Database connection

@lru_cache(maxsize=1)
def get_connection():
con = duckdb.connect(":memory:")
df = data_source.dataset
_ = df
con.sql("CREATE TABLE dataset AS (SELECT * FROM df)")
return con
if duckdb_uri == "server":
duckdb_connection = make_duckdb_connection(data_source.dataset)
else:
duckdb_connection = None

def handle_query(query: dict):
assert duckdb_connection is not None
sql = query["sql"]
command = query["type"]
with get_connection().cursor() as cursor:
with duckdb_connection.cursor() as cursor:
try:
result = cursor.execute(sql)
if command == "exec":
Expand All @@ -111,6 +107,7 @@ def handle_query(query: dict):
return JSONResponse({"error": str(e)}, status_code=500)

def handle_selection(query: dict):
assert duckdb_connection is not None
predicate = query.get("predicate", None)
format = query["format"]
formats = {
Expand All @@ -119,7 +116,7 @@ def handle_selection(query: dict):
"csv": "(FORMAT CSV)",
"parquet": "(FORMAT parquet)",
}
with get_connection().cursor() as cursor:
with duckdb_connection.cursor() as cursor:
filename = ".selection-" + str(uuid.uuid4()) + ".tmp"
try:
if predicate is not None:
Expand Down Expand Up @@ -172,6 +169,13 @@ async def post_selection(req: Request):
return app


def make_duckdb_connection(df):
    """Open an in-memory DuckDB database and load ``df`` into a ``dataset`` table.

    Args:
        df: The data to copy into the new ``dataset`` table. The SQL below
            refers to this parameter by its name, so it must stay ``df``.

    Returns:
        The DuckDB connection holding the populated ``dataset`` table.
    """
    connection = duckdb.connect(":memory:")
    # The SQL text names ``df`` directly; this no-op reference marks the
    # parameter as used for readers and linters.
    _ = df
    connection.sql("CREATE TABLE dataset AS (SELECT * FROM df)")
    return connection


def parse_range_header(request: Request, content_length: int):
value = request.headers.get("Range")
if value is not None:
Expand Down