Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions src/lang2sql/adapters/db/dsn_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Form fields → DSN assembly.

The setup wizard collects credentials field-by-field so non-developers never
see a DSN string. Each ``build_*`` here turns those fields into the canonical
SQLAlchemy/D1 URL that :func:`build_explorer` already understands. Splitting
this off keeps the wizard's UI layer (Discord modals) thin and lets us unit-
test the assembly without a Discord runtime.
"""

from __future__ import annotations

from dataclasses import dataclass
from urllib.parse import quote, quote_plus


@dataclass
class ConnectionSpec:
    """Result of the setup wizard: a connection URL plus side-channel secrets.

    ``extras`` exists for values an adapter needs but that are deliberately
    kept out of the URL; the D1 builder in this module puts its API token
    there under ``d1_token``.
    """

    # Connection URL produced by one of the ``build_*`` functions.
    dsn: str
    # Out-of-band secrets keyed by name; empty when the DSN is self-sufficient.
    extras: dict[str, str]


# Database types the wizard offers. The tuple order is significant: it is the
# order in which the options are listed in the wizard's dropdown.
SUPPORTED_DB_TYPES: tuple[str, ...] = (
    "postgresql", "mysql", "snowflake",
    "bigquery", "duckdb", "d1",
)


def _quote(s: str) -> str:
    """Percent-encode ``s`` so it is safe in any component of a URL.

    Every reserved character is escaped (``safe=""``) and a space becomes
    ``%20`` rather than ``+``. The ``+`` spelling is only meaningful in
    form-encoded query strings: a consumer that applies plain percent-decoding
    to the user/password section reads it back as a literal plus sign and
    silently alters the credential. ``%20`` decodes to a space under both
    plain and form-style decoding, so it is correct in the userinfo and in
    query values alike.
    """
    return quote(s, safe="")


def _parse_port(port: str, default: int) -> int:
    """Convert the wizard's free-text port field into a usable TCP port.

    Args:
        port: Raw text from the form; empty means "use the default".
        default: Port to fall back to when ``port`` is empty.

    Returns:
        The port as an integer in the range 1-65535.

    Raises:
        ValueError: If ``port`` is not a whole number or is out of range. The
            message names the field so it can be shown to the user as-is.
    """
    if not port:
        return default
    try:
        number = int(port)
    except ValueError:
        # Replace Python's "invalid literal for int()" text with something a
        # non-developer filling in the form can act on.
        raise ValueError(f"port must be a number, got {port!r}") from None
    if not 1 <= number <= 65535:
        raise ValueError(f"port must be between 1 and 65535, got {number}")
    return number


def build_postgresql(*, host: str, port: str, database: str, user: str, password: str) -> ConnectionSpec:
    """Assemble a ``postgresql+psycopg://`` URL from the wizard's form fields.

    ``user`` and ``password`` are percent-encoded; ``host`` and ``database``
    are inserted verbatim. An empty ``port`` falls back to 5432.

    Raises:
        ValueError: If ``port`` is not a valid port number.
    """
    p = _parse_port(port, 5432)
    dsn = f"postgresql+psycopg://{_quote(user)}:{_quote(password)}@{host}:{p}/{database}"
    return ConnectionSpec(dsn=dsn, extras={})


def build_mysql(*, host: str, port: str, database: str, user: str, password: str) -> ConnectionSpec:
    """Assemble a ``mysql+pymysql://`` URL from the wizard's form fields.

    ``user`` and ``password`` are percent-encoded; ``host`` and ``database``
    are inserted verbatim. An empty ``port`` falls back to 3306.

    Raises:
        ValueError: If ``port`` is not a valid port number.
    """
    p = _parse_port(port, 3306)
    dsn = f"mysql+pymysql://{_quote(user)}:{_quote(password)}@{host}:{p}/{database}"
    return ConnectionSpec(dsn=dsn, extras={})


def build_snowflake(
    *, account: str, user: str, password: str, database: str, warehouse: str
) -> ConnectionSpec:
    """Assemble a ``snowflake://`` URL from the wizard's form fields.

    ``user``, ``password`` and ``warehouse`` are percent-encoded; ``account``
    and ``database`` are inserted verbatim. The warehouse travels as the
    ``warehouse`` query parameter.
    """
    credentials = f"{_quote(user)}:{_quote(password)}"
    query = f"warehouse={_quote(warehouse)}"
    return ConnectionSpec(
        dsn=f"snowflake://{credentials}@{account}/{database}?{query}",
        extras={},
    )


def build_bigquery(*, project: str, dataset: str) -> ConnectionSpec:
    """Assemble the ``bigquery://<project>/<dataset>`` URL.

    No credentials are embedded in the DSN: authentication is expected to come
    from Application Default Credentials (gcloud) in the runtime environment.
    """
    url = "bigquery://{}/{}".format(project, dataset)
    return ConnectionSpec(dsn=url, extras={})


def build_duckdb(*, path: str) -> ConnectionSpec:
    """Assemble a ``duckdb:///<path>`` URL for a DuckDB database file.

    ``path`` is appended verbatim after the three slashes, so an absolute
    path such as ``/data/x.duckdb`` yields four consecutive slashes.
    """
    url = "duckdb:///{}".format(path)
    return ConnectionSpec(dsn=url, extras={})


def build_d1(*, account_id: str, database_id: str, api_token: str) -> ConnectionSpec:
    """Assemble the ``d1://<account_id>/<database_id>`` URL for Cloudflare D1.

    The API token is kept out of the URL and returned separately in
    ``extras`` under the ``d1_token`` key.
    """
    side_channel = {"d1_token": api_token}
    url = "d1://{}/{}".format(account_id, database_id)
    return ConnectionSpec(dsn=url, extras=side_channel)


def _host_port_fields(default_port: str) -> list[tuple[str, str, bool, bool]]:
    """Return a fresh field list for a classic host/port/database/user/password DB."""
    return [
        ("host", "db.example.com", True, False),
        ("port", default_port, False, False),
        ("database", "analytics", True, False),
        ("user", "readonly_user", True, False),
        ("password", "•••••", True, True),
    ]


# Form definition per database type, consumed by the Discord modal layer.
# Each row is (name, placeholder, required, masked). ``name`` is both the
# label shown on the form and the keyword handed to the matching ``build_*``
# function, so it must match that function's parameter name exactly.
FIELD_SCHEMA: dict[str, list[tuple[str, str, bool, bool]]] = {
    "postgresql": _host_port_fields("5432"),
    "mysql": _host_port_fields("3306"),
    "snowflake": [
        ("account", "abc12345.us-east-1", True, False),
        ("user", "readonly_user", True, False),
        ("password", "•••••", True, True),
        ("database", "ANALYTICS", True, False),
        ("warehouse", "COMPUTE_WH", True, False),
    ],
    "bigquery": [
        ("project", "my-gcp-project", True, False),
        ("dataset", "analytics", True, False),
    ],
    "duckdb": [("path", "/data/warehouse.duckdb", True, False)],
    "d1": [
        ("account_id", "Cloudflare account ID", True, False),
        ("database_id", "D1 database ID", True, False),
        ("api_token", "Cloudflare API token", True, True),
    ],
}


# Dispatch table used by ``assemble``: database type -> function that turns
# that type's form fields into a ConnectionSpec.
_BUILDERS = dict(
    postgresql=build_postgresql,
    mysql=build_mysql,
    snowflake=build_snowflake,
    bigquery=build_bigquery,
    duckdb=build_duckdb,
    d1=build_d1,
)


def assemble(db_type: str, fields: dict[str, str]) -> ConnectionSpec:
    """Dispatch by ``db_type`` to the matching builder.

    The wizard hands raw modal inputs in ``fields``; this is the one entry
    point so the UI layer stays dialect-agnostic.

    Args:
        db_type: One of the keys of the builder table (e.g. ``"postgresql"``).
        fields: Raw form values keyed by field name. Keys that are not part of
            the schema for ``db_type`` are ignored; values are stripped of
            surrounding whitespace and ``None`` is treated as empty.

    Returns:
        The ``ConnectionSpec`` produced by the selected builder.

    Raises:
        ValueError: If ``db_type`` is unsupported, a required field is empty
            or absent, or the builder rejects a value.
    """
    builder = _BUILDERS.get(db_type)
    if builder is None:
        raise ValueError(f"unsupported db type: {db_type!r}")

    schema = FIELD_SCHEMA[db_type]
    # Walk the schema rather than the input so every keyword the builder
    # declares is supplied. An optional field left out of ``fields`` (for
    # example ``port``) becomes "" instead of a missing keyword, which would
    # otherwise raise TypeError inside the builder call.
    cleaned = {name: (fields.get(name) or "").strip() for name, *_ in schema}

    missing = [name for name, _, required, _ in schema if required and not cleaned[name]]
    if missing:
        raise ValueError(f"missing required fields: {', '.join(missing)}")
    return builder(**cleaned)
19 changes: 16 additions & 3 deletions src/lang2sql/adapters/db/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@
from .sqlalchemy_explorer import SqlAlchemyExplorer


def build_explorer(connection: str, *, schema: str | None = None) -> ExplorerPort:
def build_explorer(
connection: str,
*,
schema: str | None = None,
extras: dict | None = None,
) -> ExplorerPort:
"""Route a connection string to the matching explorer adapter.

``schema`` is forwarded to the SQLAlchemy explorer (ignored by D1, which is
schema-less SQLite). Raises ``ValueError`` on an empty/unparseable string.
schema-less SQLite). ``extras`` carries per-adapter secrets that don't
belong in the URL — currently ``d1_token`` for the D1 HTTP API. Raises
``ValueError`` on an empty/unparseable string.
"""
if not connection or not connection.strip():
raise ValueError("empty connection string")
Expand All @@ -34,13 +41,19 @@ def build_explorer(connection: str, *, schema: str | None = None) -> ExplorerPor
if not scheme:
raise ValueError(f"connection string has no scheme: {connection!r}")

extras = extras or {}

if scheme == "d1":
parts = urlsplit(connection)
account_id = parts.netloc
database_id = parts.path.lstrip("/")
if not account_id or not database_id:
raise ValueError("d1 URL must be d1://<account_id>/<database_id>")
return D1Explorer(account_id=account_id, database_id=database_id)
return D1Explorer(
account_id=account_id,
database_id=database_id,
token=extras.get("d1_token"),
)

# Anything else is assumed to be a SQLAlchemy URL (driver loaded lazily).
return SqlAlchemyExplorer(connection, schema=schema)
Expand Down
5 changes: 5 additions & 0 deletions src/lang2sql/frontends/discord/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ def _register_commands(self) -> None:
tree = self.tree
handlers = self._handlers

@tree.command(name="setup", description="Connect a database with a guided form (no DSN needed)")
async def setup(interaction: discord.Interaction) -> None:
    # Deferred import: the wizard module is only needed on this Discord-only
    # path, so it is loaded when the command actually runs.
    from . import setup_wizard

    await setup_wizard.start_setup_flow(interaction, handlers, _interaction_context)

@tree.command(name="connect", description="Store a database connection string")
async def connect(interaction: discord.Interaction, dsn: str) -> None:
await self._run(interaction, handlers.connect(to_identity(_interaction_context(interaction)), dsn))
Expand Down
56 changes: 56 additions & 0 deletions src/lang2sql/frontends/discord/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from datetime import datetime, timezone

from ...adapters.db import build_explorer
from ...adapters.db.dsn_builder import assemble
from ...core.identity import Identity
from ...core.ports.frontend import OutboundMessage
from ...harness.loop import agent_loop
Expand Down Expand Up @@ -93,6 +95,60 @@ async def audit_me(self, identity: Identity) -> OutboundMessage:
lines.append(f"- {_fmt_ts(event.ts)} {event.action} @ {event.scope}")
return OutboundMessage(text="\n".join(lines))

async def register_db_for_guild(
    self,
    identity: Identity,
    db_type: str,
    fields: dict[str, str],
) -> OutboundMessage:
    """Commit step of the ``/setup`` wizard.

    Builds a connection spec from the per-field form inputs, proves the
    connection works by listing tables once, and only then persists the DSN
    (plus any out-of-band extras) in the concierge's secrets store under the
    guild scope, or a per-user ``dm:<user_id>`` scope when there is no guild.
    Every failure is reported as a user-facing message rather than raised.
    """
    # Step 1: turn form fields into a DSN; validation problems go to the user.
    try:
        spec = assemble(db_type, fields)
    except ValueError as exc:
        return OutboundMessage(text=f"⚠️ Setup error: {exc}")

    # Step 2: smoke-test the connection before anything is stored.
    try:
        explorer = build_explorer(spec.dsn, extras=spec.extras)
        tables = await explorer.list_tables()
    except ModuleNotFoundError as exc:
        driver_hint = (
            f"⚠️ Connection driver not installed for {db_type}. "
            f"Ask an admin to run `uv sync --extra {db_type}`.\n"
            f"(details: {exc})"
        )
        return OutboundMessage(text=driver_hint)
    except Exception as exc:  # boundary: report what the DB said, never crash the command
        failure = (
            f"❌ Couldn't connect to {db_type}: {type(exc).__name__}: {exc}.\n"
            "Common causes: wrong host/port, network/firewall, "
            "wrong credentials, or read permission missing."
        )
        return OutboundMessage(text=failure)

    # Step 3: persist under the guild, falling back to a per-user DM scope.
    scope = identity.guild_id
    if not scope:
        scope = f"dm:{identity.user_id}"
    await self._concierge.secrets.set(scope, "db_dsn", spec.dsn)
    for extra_name, extra_value in spec.extras.items():
        await self._concierge.secrets.set(scope, f"db_extras.{extra_name}", extra_value)
    # Drop any cached explorer for this scope so the next turn uses the new DB.
    self._concierge.forget_explorer(scope)

    success = (
        f"✅ Connected to **{db_type}** — found **{len(tables)} table(s)**. "
        "Your credentials are stored encrypted; you can `/semantic_show` "
        "or just ask a question now."
    )
    return OutboundMessage(text=success)

async def connect(self, identity: Identity, dsn: str) -> OutboundMessage:
"""V1 stub: stash a DB DSN keyed by guild/DM in the concierge kv store.

Expand Down
124 changes: 124 additions & 0 deletions src/lang2sql/frontends/discord/setup_wizard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""``/setup`` — a zero-DSN connection wizard for non-developers.

The user never sees a SQLAlchemy URL or an env file. They run ``/setup``, pick
their database from a dropdown, and fill a short form. We assemble the DSN,
test the connection by listing tables, and store the credentials encrypted via
:class:`EncryptedSecrets` keyed by the guild scope. The next message in that
guild transparently uses the new database.

Discord coupling lives only here and in ``bot.py``: the actual register-and-
test logic is :meth:`CommandHandlers.register_db_for_guild` (pure, testable).
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import discord
from discord import ui

from ...adapters.db.dsn_builder import FIELD_SCHEMA, SUPPORTED_DB_TYPES
from .session_router import to_identity

if TYPE_CHECKING:
from .commands import CommandHandlers
from .bot import InteractionContext


# Display name for each database type, shown in the dropdown and used in the
# modal title.
_LABELS: dict[str, str] = dict(
    postgresql="PostgreSQL",
    mysql="MySQL",
    snowflake="Snowflake",
    bigquery="BigQuery",
    duckdb="DuckDB (file)",
    d1="Cloudflare D1",
)


class _ConnectionFormModal(ui.Modal):
    """Step 2 of the wizard: the form for one database type.

    One :class:`ui.TextInput` is created per row of :data:`FIELD_SCHEMA` for
    the chosen type. Discord modals cap at 5 text inputs, which matches our
    widest schema (Postgres/MySQL/Snowflake). The schema's ``masked`` flag is
    not applied: passwords/tokens are plain text inputs — Discord has no
    masked input style — but the form is ephemeral so only the user sees what
    they typed.
    """

    def __init__(
        self,
        db_type: str,
        handlers: "CommandHandlers",
        ctx_factory,
    ) -> None:
        title = f"Connect to {_LABELS.get(db_type, db_type)}"
        super().__init__(title=title)
        self._db_type = db_type
        self._handlers = handlers
        # Called with the submitting interaction to produce the context that
        # ``to_identity`` converts into an identity.
        self._ctx_factory = ctx_factory
        self._inputs: dict[str, ui.TextInput] = {}
        for field_name, hint, is_required, _masked in FIELD_SCHEMA[db_type]:
            text_input = ui.TextInput(
                label=field_name,
                placeholder=hint,
                required=is_required,
                style=discord.TextStyle.short,
                max_length=200,
            )
            self._inputs[field_name] = text_input
            self.add_item(text_input)

    async def on_submit(self, interaction: discord.Interaction) -> None:
        # The connection test can take a few seconds, so acknowledge first to
        # keep Discord from timing the interaction out. Ephemeral: only the
        # submitting user sees the outcome.
        await interaction.response.defer(ephemeral=True, thinking=True)
        fields = {}
        for field_name, text_input in self._inputs.items():
            fields[field_name] = text_input.value
        identity = to_identity(self._ctx_factory(interaction))
        outcome = await self._handlers.register_db_for_guild(
            identity, self._db_type, fields
        )
        await interaction.followup.send(outcome.text, ephemeral=True)


class _DbTypeSelect(ui.Select):
    """Step 1 of the wizard: a single-choice dropdown of database types."""

    def __init__(self, handlers: "CommandHandlers", ctx_factory) -> None:
        choices = []
        for db_type in SUPPORTED_DB_TYPES:
            choices.append(discord.SelectOption(label=_LABELS[db_type], value=db_type))
        super().__init__(
            placeholder="Choose your database…",
            options=choices,
            min_values=1,
            max_values=1,
        )
        self._handlers = handlers
        self._ctx_factory = ctx_factory

    async def callback(self, interaction: discord.Interaction) -> None:
        # Exactly one value is selectable, so the first entry is the choice.
        # Sending the modal is itself the response to this select interaction.
        chosen = self.values[0]
        form = _ConnectionFormModal(chosen, self._handlers, self._ctx_factory)
        await interaction.response.send_modal(form)


class _SetupView(ui.View):
    """View that hosts the database-type dropdown; times out after 120 seconds."""

    def __init__(self, handlers: "CommandHandlers", ctx_factory) -> None:
        super().__init__(timeout=120.0)
        picker = _DbTypeSelect(handlers, ctx_factory)
        self.add_item(picker)


async def start_setup_flow(
    interaction: discord.Interaction,
    handlers: "CommandHandlers",
    ctx_factory,
) -> None:
    """Start the ``/setup`` flow by replying with the database-type picker.

    The reply is ephemeral, so only the invoking user sees the picker.
    """
    intro = (
        "Let's connect your database. Pick its type, then fill the form. "
        "Your credentials are stored encrypted; nobody else sees what you type."
    )
    picker_view = _SetupView(handlers, ctx_factory)
    await interaction.response.send_message(intro, view=picker_view, ephemeral=True)
Loading
Loading